Let caller handle SMART errors on multi-requests (#754)

* Fix for missing get_device_usage

* Fix coverage and add methods to exceptions

* Remove unused caplog fixture
This commit is contained in:
Steven B 2024-02-15 18:10:34 +00:00 committed by GitHub
parent 64da736717
commit 9ab9420ad6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 84 additions and 26 deletions

View File

@ -45,6 +45,9 @@ class ConnectionException(SmartDeviceException):
class SmartErrorCode(IntEnum):
"""Enum for SMART Error Codes."""
def __str__(self):
return f"{self.name}({self.value})"
SUCCESS = 0
# Transport Errors

View File

@ -9,7 +9,7 @@ from ..device import Device, WifiNetwork
from ..device_type import DeviceType
from ..deviceconfig import DeviceConfig
from ..emeterstatus import EmeterStatus
from ..exceptions import AuthenticationException, SmartDeviceException
from ..exceptions import AuthenticationException, SmartDeviceException, SmartErrorCode
from ..feature import Feature, FeatureType
from ..smartprotocol import SmartProtocol
@ -61,6 +61,24 @@ class SmartDevice(Device):
"""Return list of children."""
return list(self._children.values())
def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
response = responses.get(request)
if isinstance(response, SmartErrorCode):
_LOGGER.debug(
"Error %s getting request %s for device %s",
response,
request,
self.host,
)
response = None
if response is not None:
return response
if default is not None:
return default
raise SmartDeviceException(
f"{request} not found in {responses} for device {self.host}"
)
async def update(self, update_children: bool = True):
"""Update the device."""
if self.credentials is None and self.credentials_hash is None:
@ -87,7 +105,7 @@ class SmartDevice(Device):
"get_current_power": None,
}
if self._components["device"] >= 2:
if self._components.get("device", 0) >= 2:
extra_reqs = {
**extra_reqs,
"get_device_usage": None,
@ -101,13 +119,13 @@ class SmartDevice(Device):
resp = await self.protocol.query(req)
self._info = resp["get_device_info"]
self._time = resp["get_device_time"]
self._info = self._try_get_response(resp, "get_device_info")
self._time = self._try_get_response(resp, "get_device_time", {})
# Device usage is not available on older firmware versions
self._usage = resp.get("get_device_usage", {})
self._usage = self._try_get_response(resp, "get_device_usage", {})
# Emeter is not always available, but we set them still for now.
self._energy = resp.get("get_energy_usage", {})
self._emeter = resp.get("get_current_power", {})
self._energy = self._try_get_response(resp, "get_energy_usage", {})
self._emeter = self._try_get_response(resp, "get_current_power", {})
self._last_update = {
"components": self._components_raw,
@ -116,7 +134,7 @@ class SmartDevice(Device):
"time": self._time,
"energy": self._energy,
"emeter": self._emeter,
"child_info": resp.get("get_child_device_list", {}),
"child_info": self._try_get_response(resp, "get_child_device_list", {}),
}
if child_info := self._last_update.get("child_info"):

View File

@ -129,19 +129,21 @@ class SmartProtocol(BaseProtocol):
pf(smart_request),
)
response_step = await self._transport.send(smart_request)
batch_name = f"multi-request-batch-{i+1}"
if debug_enabled:
_LOGGER.debug(
"%s multi-request-batch-%s << %s",
"%s %s << %s",
self._host,
i + 1,
batch_name,
pf(response_step),
)
self._handle_response_error_code(response_step)
self._handle_response_error_code(response_step, batch_name)
responses = response_step["result"]["responses"]
for response in responses:
self._handle_response_error_code(response)
method = response["method"]
self._handle_response_error_code(response, method, raise_on_error=False)
result = response.get("result", None)
multi_result[response["method"]] = result
multi_result[method] = result
return multi_result
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
@ -173,22 +175,24 @@ class SmartProtocol(BaseProtocol):
pf(response_data),
)
self._handle_response_error_code(response_data)
self._handle_response_error_code(response_data, smart_method)
# Single set_ requests do not return a result
result = response_data.get("result")
return {smart_method: result}
def _handle_response_error_code(self, resp_dict: dict):
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
return
if not raise_on_error:
resp_dict["result"] = error_code
return
msg = (
f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})"
+ f" for method: {method}"
)
if method := resp_dict.get("method"):
msg += f" for method: {method}"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg, error_code=error_code)
if error_code in SMART_RETRYABLE_ERRORS:
@ -338,7 +342,7 @@ class _ChildProtocolWrapper(SmartProtocol):
result = response.get("control_child")
# Unwrap responseData for control_child
if result and (response_data := result.get("responseData")):
self._handle_response_error_code(response_data)
self._handle_response_error_code(response_data, "control_child")
result = response_data.get("result")
# TODO: handle multipleRequest unwrapping

View File

@ -1,5 +1,6 @@
import importlib
import inspect
import logging
import pkgutil
import re
import sys
@ -21,10 +22,18 @@ from voluptuous import (
import kasa
from kasa import Credentials, Device, DeviceConfig, SmartDeviceException
from kasa.exceptions import SmartErrorCode
from kasa.iot import IotDevice
from kasa.smart import SmartChildDevice, SmartDevice
from .conftest import device_iot, handle_turn_on, has_emeter_iot, no_emeter_iot, turn_on
from .conftest import (
device_iot,
device_smart,
handle_turn_on,
has_emeter_iot,
no_emeter_iot,
turn_on,
)
from .fakeprotocol_iot import FakeIotProtocol
@ -300,6 +309,33 @@ async def test_modules_not_supported(dev: IotDevice):
assert module.is_supported is not None
@device_smart
async def test_update_sub_errors(dev: SmartDevice, caplog):
mock_response: dict = {
"get_device_info": {},
"get_device_usage": SmartErrorCode.PARAMS_ERROR,
"get_device_time": {},
}
caplog.set_level(logging.DEBUG)
with patch.object(dev.protocol, "query", return_value=mock_response):
await dev.update()
msg = "Error PARAMS_ERROR(-1008) getting request get_device_usage for device 127.0.0.123"
assert msg in caplog.text
@device_smart
async def test_update_no_device_info(dev: SmartDevice):
mock_response: dict = {
"get_device_usage": {},
"get_device_time": {},
}
msg = f"get_device_info not found in {mock_response} for device 127.0.0.123"
with patch.object(dev.protocol, "query", return_value=mock_response), pytest.raises(
SmartDeviceException, match=msg
):
await dev.update()
@pytest.mark.parametrize(
"device_class, use_class", kasa.deprecated_smart_devices.items()
)

View File

@ -60,13 +60,10 @@ async def test_smart_device_errors_in_multiple_request(
send_mock = mocker.patch.object(
dummy_protocol._transport, "send", return_value=mock_response
)
with pytest.raises(SmartDeviceException):
await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
expected_calls = 3
else:
expected_calls = 1
assert send_mock.call_count == expected_calls
resp_dict = await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
assert resp_dict["foobar2"] == error_code
assert send_mock.call_count == 1
@pytest.mark.parametrize("request_size", [1, 3, 5, 10])