mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-08 22:07:06 +00:00
Let caller handle SMART errors on multi-requests (#754)
* Fix for missing get_device_usage * Fix coverage and add methods to exceptions * Remove unused caplog fixture
This commit is contained in:
parent
64da736717
commit
9ab9420ad6
@ -45,6 +45,9 @@ class ConnectionException(SmartDeviceException):
|
||||
class SmartErrorCode(IntEnum):
|
||||
"""Enum for SMART Error Codes."""
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}({self.value})"
|
||||
|
||||
SUCCESS = 0
|
||||
|
||||
# Transport Errors
|
||||
|
@ -9,7 +9,7 @@ from ..device import Device, WifiNetwork
|
||||
from ..device_type import DeviceType
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..emeterstatus import EmeterStatus
|
||||
from ..exceptions import AuthenticationException, SmartDeviceException
|
||||
from ..exceptions import AuthenticationException, SmartDeviceException, SmartErrorCode
|
||||
from ..feature import Feature, FeatureType
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
@ -61,6 +61,24 @@ class SmartDevice(Device):
|
||||
"""Return list of children."""
|
||||
return list(self._children.values())
|
||||
|
||||
def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
|
||||
response = responses.get(request)
|
||||
if isinstance(response, SmartErrorCode):
|
||||
_LOGGER.debug(
|
||||
"Error %s getting request %s for device %s",
|
||||
response,
|
||||
request,
|
||||
self.host,
|
||||
)
|
||||
response = None
|
||||
if response is not None:
|
||||
return response
|
||||
if default is not None:
|
||||
return default
|
||||
raise SmartDeviceException(
|
||||
f"{request} not found in {responses} for device {self.host}"
|
||||
)
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Update the device."""
|
||||
if self.credentials is None and self.credentials_hash is None:
|
||||
@ -87,7 +105,7 @@ class SmartDevice(Device):
|
||||
"get_current_power": None,
|
||||
}
|
||||
|
||||
if self._components["device"] >= 2:
|
||||
if self._components.get("device", 0) >= 2:
|
||||
extra_reqs = {
|
||||
**extra_reqs,
|
||||
"get_device_usage": None,
|
||||
@ -101,13 +119,13 @@ class SmartDevice(Device):
|
||||
|
||||
resp = await self.protocol.query(req)
|
||||
|
||||
self._info = resp["get_device_info"]
|
||||
self._time = resp["get_device_time"]
|
||||
self._info = self._try_get_response(resp, "get_device_info")
|
||||
self._time = self._try_get_response(resp, "get_device_time", {})
|
||||
# Device usage is not available on older firmware versions
|
||||
self._usage = resp.get("get_device_usage", {})
|
||||
self._usage = self._try_get_response(resp, "get_device_usage", {})
|
||||
# Emeter is not always available, but we set them still for now.
|
||||
self._energy = resp.get("get_energy_usage", {})
|
||||
self._emeter = resp.get("get_current_power", {})
|
||||
self._energy = self._try_get_response(resp, "get_energy_usage", {})
|
||||
self._emeter = self._try_get_response(resp, "get_current_power", {})
|
||||
|
||||
self._last_update = {
|
||||
"components": self._components_raw,
|
||||
@ -116,7 +134,7 @@ class SmartDevice(Device):
|
||||
"time": self._time,
|
||||
"energy": self._energy,
|
||||
"emeter": self._emeter,
|
||||
"child_info": resp.get("get_child_device_list", {}),
|
||||
"child_info": self._try_get_response(resp, "get_child_device_list", {}),
|
||||
}
|
||||
|
||||
if child_info := self._last_update.get("child_info"):
|
||||
|
@ -129,19 +129,21 @@ class SmartProtocol(BaseProtocol):
|
||||
pf(smart_request),
|
||||
)
|
||||
response_step = await self._transport.send(smart_request)
|
||||
batch_name = f"multi-request-batch-{i+1}"
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s multi-request-batch-%s << %s",
|
||||
"%s %s << %s",
|
||||
self._host,
|
||||
i + 1,
|
||||
batch_name,
|
||||
pf(response_step),
|
||||
)
|
||||
self._handle_response_error_code(response_step)
|
||||
self._handle_response_error_code(response_step, batch_name)
|
||||
responses = response_step["result"]["responses"]
|
||||
for response in responses:
|
||||
self._handle_response_error_code(response)
|
||||
method = response["method"]
|
||||
self._handle_response_error_code(response, method, raise_on_error=False)
|
||||
result = response.get("result", None)
|
||||
multi_result[response["method"]] = result
|
||||
multi_result[method] = result
|
||||
return multi_result
|
||||
|
||||
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
||||
@ -173,22 +175,24 @@ class SmartProtocol(BaseProtocol):
|
||||
pf(response_data),
|
||||
)
|
||||
|
||||
self._handle_response_error_code(response_data)
|
||||
self._handle_response_error_code(response_data, smart_method)
|
||||
|
||||
# Single set_ requests do not return a result
|
||||
result = response_data.get("result")
|
||||
return {smart_method: result}
|
||||
|
||||
def _handle_response_error_code(self, resp_dict: dict):
|
||||
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
|
||||
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||
if error_code == SmartErrorCode.SUCCESS:
|
||||
return
|
||||
if not raise_on_error:
|
||||
resp_dict["result"] = error_code
|
||||
return
|
||||
msg = (
|
||||
f"Error querying device: {self._host}: "
|
||||
+ f"{error_code.name}({error_code.value})"
|
||||
+ f" for method: {method}"
|
||||
)
|
||||
if method := resp_dict.get("method"):
|
||||
msg += f" for method: {method}"
|
||||
if error_code in SMART_TIMEOUT_ERRORS:
|
||||
raise TimeoutException(msg, error_code=error_code)
|
||||
if error_code in SMART_RETRYABLE_ERRORS:
|
||||
@ -338,7 +342,7 @@ class _ChildProtocolWrapper(SmartProtocol):
|
||||
result = response.get("control_child")
|
||||
# Unwrap responseData for control_child
|
||||
if result and (response_data := result.get("responseData")):
|
||||
self._handle_response_error_code(response_data)
|
||||
self._handle_response_error_code(response_data, "control_child")
|
||||
result = response_data.get("result")
|
||||
|
||||
# TODO: handle multipleRequest unwrapping
|
||||
|
@ -1,5 +1,6 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import pkgutil
|
||||
import re
|
||||
import sys
|
||||
@ -21,10 +22,18 @@ from voluptuous import (
|
||||
|
||||
import kasa
|
||||
from kasa import Credentials, Device, DeviceConfig, SmartDeviceException
|
||||
from kasa.exceptions import SmartErrorCode
|
||||
from kasa.iot import IotDevice
|
||||
from kasa.smart import SmartChildDevice, SmartDevice
|
||||
|
||||
from .conftest import device_iot, handle_turn_on, has_emeter_iot, no_emeter_iot, turn_on
|
||||
from .conftest import (
|
||||
device_iot,
|
||||
device_smart,
|
||||
handle_turn_on,
|
||||
has_emeter_iot,
|
||||
no_emeter_iot,
|
||||
turn_on,
|
||||
)
|
||||
from .fakeprotocol_iot import FakeIotProtocol
|
||||
|
||||
|
||||
@ -300,6 +309,33 @@ async def test_modules_not_supported(dev: IotDevice):
|
||||
assert module.is_supported is not None
|
||||
|
||||
|
||||
@device_smart
|
||||
async def test_update_sub_errors(dev: SmartDevice, caplog):
|
||||
mock_response: dict = {
|
||||
"get_device_info": {},
|
||||
"get_device_usage": SmartErrorCode.PARAMS_ERROR,
|
||||
"get_device_time": {},
|
||||
}
|
||||
caplog.set_level(logging.DEBUG)
|
||||
with patch.object(dev.protocol, "query", return_value=mock_response):
|
||||
await dev.update()
|
||||
msg = "Error PARAMS_ERROR(-1008) getting request get_device_usage for device 127.0.0.123"
|
||||
assert msg in caplog.text
|
||||
|
||||
|
||||
@device_smart
|
||||
async def test_update_no_device_info(dev: SmartDevice):
|
||||
mock_response: dict = {
|
||||
"get_device_usage": {},
|
||||
"get_device_time": {},
|
||||
}
|
||||
msg = f"get_device_info not found in {mock_response} for device 127.0.0.123"
|
||||
with patch.object(dev.protocol, "query", return_value=mock_response), pytest.raises(
|
||||
SmartDeviceException, match=msg
|
||||
):
|
||||
await dev.update()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device_class, use_class", kasa.deprecated_smart_devices.items()
|
||||
)
|
||||
|
@ -60,13 +60,10 @@ async def test_smart_device_errors_in_multiple_request(
|
||||
send_mock = mocker.patch.object(
|
||||
dummy_protocol._transport, "send", return_value=mock_response
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
|
||||
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
|
||||
expected_calls = 3
|
||||
else:
|
||||
expected_calls = 1
|
||||
assert send_mock.call_count == expected_calls
|
||||
|
||||
resp_dict = await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
|
||||
assert resp_dict["foobar2"] == error_code
|
||||
assert send_mock.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("request_size", [1, 3, 5, 10])
|
||||
|
Loading…
Reference in New Issue
Block a user