Let caller handle SMART errors on multi-requests (#754)

* Fix for missing get_device_usage

* Fix coverage and add methods to exceptions

* Remove unused caplog fixture
This commit is contained in:
Steven B 2024-02-15 18:10:34 +00:00 committed by GitHub
parent 64da736717
commit 9ab9420ad6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 84 additions and 26 deletions

View File

@ -45,6 +45,9 @@ class ConnectionException(SmartDeviceException):
class SmartErrorCode(IntEnum): class SmartErrorCode(IntEnum):
"""Enum for SMART Error Codes.""" """Enum for SMART Error Codes."""
def __str__(self):
return f"{self.name}({self.value})"
SUCCESS = 0 SUCCESS = 0
# Transport Errors # Transport Errors

View File

@ -9,7 +9,7 @@ from ..device import Device, WifiNetwork
from ..device_type import DeviceType from ..device_type import DeviceType
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
from ..emeterstatus import EmeterStatus from ..emeterstatus import EmeterStatus
from ..exceptions import AuthenticationException, SmartDeviceException from ..exceptions import AuthenticationException, SmartDeviceException, SmartErrorCode
from ..feature import Feature, FeatureType from ..feature import Feature, FeatureType
from ..smartprotocol import SmartProtocol from ..smartprotocol import SmartProtocol
@ -61,6 +61,24 @@ class SmartDevice(Device):
"""Return list of children.""" """Return list of children."""
return list(self._children.values()) return list(self._children.values())
def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
response = responses.get(request)
if isinstance(response, SmartErrorCode):
_LOGGER.debug(
"Error %s getting request %s for device %s",
response,
request,
self.host,
)
response = None
if response is not None:
return response
if default is not None:
return default
raise SmartDeviceException(
f"{request} not found in {responses} for device {self.host}"
)
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""Update the device.""" """Update the device."""
if self.credentials is None and self.credentials_hash is None: if self.credentials is None and self.credentials_hash is None:
@ -87,7 +105,7 @@ class SmartDevice(Device):
"get_current_power": None, "get_current_power": None,
} }
if self._components["device"] >= 2: if self._components.get("device", 0) >= 2:
extra_reqs = { extra_reqs = {
**extra_reqs, **extra_reqs,
"get_device_usage": None, "get_device_usage": None,
@ -101,13 +119,13 @@ class SmartDevice(Device):
resp = await self.protocol.query(req) resp = await self.protocol.query(req)
self._info = resp["get_device_info"] self._info = self._try_get_response(resp, "get_device_info")
self._time = resp["get_device_time"] self._time = self._try_get_response(resp, "get_device_time", {})
# Device usage is not available on older firmware versions # Device usage is not available on older firmware versions
self._usage = resp.get("get_device_usage", {}) self._usage = self._try_get_response(resp, "get_device_usage", {})
# Emeter is not always available, but we set them still for now. # Emeter is not always available, but we set them still for now.
self._energy = resp.get("get_energy_usage", {}) self._energy = self._try_get_response(resp, "get_energy_usage", {})
self._emeter = resp.get("get_current_power", {}) self._emeter = self._try_get_response(resp, "get_current_power", {})
self._last_update = { self._last_update = {
"components": self._components_raw, "components": self._components_raw,
@ -116,7 +134,7 @@ class SmartDevice(Device):
"time": self._time, "time": self._time,
"energy": self._energy, "energy": self._energy,
"emeter": self._emeter, "emeter": self._emeter,
"child_info": resp.get("get_child_device_list", {}), "child_info": self._try_get_response(resp, "get_child_device_list", {}),
} }
if child_info := self._last_update.get("child_info"): if child_info := self._last_update.get("child_info"):

View File

@ -129,19 +129,21 @@ class SmartProtocol(BaseProtocol):
pf(smart_request), pf(smart_request),
) )
response_step = await self._transport.send(smart_request) response_step = await self._transport.send(smart_request)
batch_name = f"multi-request-batch-{i+1}"
if debug_enabled: if debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s multi-request-batch-%s << %s", "%s %s << %s",
self._host, self._host,
i + 1, batch_name,
pf(response_step), pf(response_step),
) )
self._handle_response_error_code(response_step) self._handle_response_error_code(response_step, batch_name)
responses = response_step["result"]["responses"] responses = response_step["result"]["responses"]
for response in responses: for response in responses:
self._handle_response_error_code(response) method = response["method"]
self._handle_response_error_code(response, method, raise_on_error=False)
result = response.get("result", None) result = response.get("result", None)
multi_result[response["method"]] = result multi_result[method] = result
return multi_result return multi_result
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
@ -173,22 +175,24 @@ class SmartProtocol(BaseProtocol):
pf(response_data), pf(response_data),
) )
self._handle_response_error_code(response_data) self._handle_response_error_code(response_data, smart_method)
# Single set_ requests do not return a result # Single set_ requests do not return a result
result = response_data.get("result") result = response_data.get("result")
return {smart_method: result} return {smart_method: result}
def _handle_response_error_code(self, resp_dict: dict): def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS: if error_code == SmartErrorCode.SUCCESS:
return return
if not raise_on_error:
resp_dict["result"] = error_code
return
msg = ( msg = (
f"Error querying device: {self._host}: " f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})" + f"{error_code.name}({error_code.value})"
+ f" for method: {method}"
) )
if method := resp_dict.get("method"):
msg += f" for method: {method}"
if error_code in SMART_TIMEOUT_ERRORS: if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg, error_code=error_code) raise TimeoutException(msg, error_code=error_code)
if error_code in SMART_RETRYABLE_ERRORS: if error_code in SMART_RETRYABLE_ERRORS:
@ -338,7 +342,7 @@ class _ChildProtocolWrapper(SmartProtocol):
result = response.get("control_child") result = response.get("control_child")
# Unwrap responseData for control_child # Unwrap responseData for control_child
if result and (response_data := result.get("responseData")): if result and (response_data := result.get("responseData")):
self._handle_response_error_code(response_data) self._handle_response_error_code(response_data, "control_child")
result = response_data.get("result") result = response_data.get("result")
# TODO: handle multipleRequest unwrapping # TODO: handle multipleRequest unwrapping

View File

@ -1,5 +1,6 @@
import importlib import importlib
import inspect import inspect
import logging
import pkgutil import pkgutil
import re import re
import sys import sys
@ -21,10 +22,18 @@ from voluptuous import (
import kasa import kasa
from kasa import Credentials, Device, DeviceConfig, SmartDeviceException from kasa import Credentials, Device, DeviceConfig, SmartDeviceException
from kasa.exceptions import SmartErrorCode
from kasa.iot import IotDevice from kasa.iot import IotDevice
from kasa.smart import SmartChildDevice, SmartDevice from kasa.smart import SmartChildDevice, SmartDevice
from .conftest import device_iot, handle_turn_on, has_emeter_iot, no_emeter_iot, turn_on from .conftest import (
device_iot,
device_smart,
handle_turn_on,
has_emeter_iot,
no_emeter_iot,
turn_on,
)
from .fakeprotocol_iot import FakeIotProtocol from .fakeprotocol_iot import FakeIotProtocol
@ -300,6 +309,33 @@ async def test_modules_not_supported(dev: IotDevice):
assert module.is_supported is not None assert module.is_supported is not None
@device_smart
async def test_update_sub_errors(dev: SmartDevice, caplog):
mock_response: dict = {
"get_device_info": {},
"get_device_usage": SmartErrorCode.PARAMS_ERROR,
"get_device_time": {},
}
caplog.set_level(logging.DEBUG)
with patch.object(dev.protocol, "query", return_value=mock_response):
await dev.update()
msg = "Error PARAMS_ERROR(-1008) getting request get_device_usage for device 127.0.0.123"
assert msg in caplog.text
@device_smart
async def test_update_no_device_info(dev: SmartDevice):
mock_response: dict = {
"get_device_usage": {},
"get_device_time": {},
}
msg = f"get_device_info not found in {mock_response} for device 127.0.0.123"
with patch.object(dev.protocol, "query", return_value=mock_response), pytest.raises(
SmartDeviceException, match=msg
):
await dev.update()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"device_class, use_class", kasa.deprecated_smart_devices.items() "device_class, use_class", kasa.deprecated_smart_devices.items()
) )

View File

@ -60,13 +60,10 @@ async def test_smart_device_errors_in_multiple_request(
send_mock = mocker.patch.object( send_mock = mocker.patch.object(
dummy_protocol._transport, "send", return_value=mock_response dummy_protocol._transport, "send", return_value=mock_response
) )
with pytest.raises(SmartDeviceException):
await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2) resp_dict = await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): assert resp_dict["foobar2"] == error_code
expected_calls = 3 assert send_mock.call_count == 1
else:
expected_calls = 1
assert send_mock.call_count == expected_calls
@pytest.mark.parametrize("request_size", [1, 3, 5, 10]) @pytest.mark.parametrize("request_size", [1, 3, 5, 10])