From 905a14895d6111a7a0b50b773bbbd89ca955e5d4 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Thu, 4 Jul 2024 08:02:50 +0100 Subject: [PATCH] Handle module errors more robustly and add query params to light preset and transition (#1036) Ensures that all modules try to access their data in `_post_update_hook` in a safe manner and disable themselves if there's an error. Also adds parameters to get_preset_rules and get_on_off_gradually_info to fix issues with recent firmware updates. [#1033](https://github.com/python-kasa/python-kasa/issues/1033) --- devtools/helpers/smartrequests.py | 11 ++- kasa/smart/modules/autooff.py | 6 -- kasa/smart/modules/batterysensor.py | 4 ++ kasa/smart/modules/cloud.py | 10 ++- kasa/smart/modules/devicemodule.py | 7 ++ kasa/smart/modules/firmware.py | 12 +++- kasa/smart/modules/frostprotection.py | 4 ++ kasa/smart/modules/humiditysensor.py | 4 ++ kasa/smart/modules/lightpreset.py | 2 +- kasa/smart/modules/lighttransition.py | 2 +- kasa/smart/modules/reportmode.py | 4 ++ kasa/smart/modules/temperaturesensor.py | 4 ++ kasa/smart/smartdevice.py | 30 ++++++-- kasa/smart/smartmodule.py | 22 +++++- kasa/smartprotocol.py | 4 ++ kasa/tests/test_smartdevice.py | 94 ++++++++++++++++++++++--- kasa/tests/test_smartprotocol.py | 16 +++++ 17 files changed, 206 insertions(+), 30 deletions(-) diff --git a/devtools/helpers/smartrequests.py b/devtools/helpers/smartrequests.py index 881488b5..4db1f7a1 100644 --- a/devtools/helpers/smartrequests.py +++ b/devtools/helpers/smartrequests.py @@ -284,6 +284,15 @@ class SmartRequest: """Get preset rules.""" return SmartRequest("get_preset_rules", params or SmartRequest.GetRulesParams()) + @staticmethod + def get_on_off_gradually_info( + params: SmartRequestParams | None = None, + ) -> SmartRequest: + """Get preset rules.""" + return SmartRequest( + "get_on_off_gradually_info", params or SmartRequest.SmartRequestParams() + ) + @staticmethod def get_auto_light_info() -> SmartRequest: """Get auto light info.""" @@ -382,7 +391,7 @@ COMPONENT_REQUESTS = { "auto_light": [SmartRequest.get_auto_light_info()], "light_effect": [SmartRequest.get_dynamic_light_effect_rules()], "bulb_quick_control": [], - "on_off_gradually": [SmartRequest.get_raw_request("get_on_off_gradually_info")], + "on_off_gradually": [SmartRequest.get_on_off_gradually_info()], "light_strip": [], "light_strip_lighting_effect": [ SmartRequest.get_raw_request("get_lighting_effect") diff --git a/kasa/smart/modules/autooff.py b/kasa/smart/modules/autooff.py index 0004aec4..5e4b100f 100644 --- a/kasa/smart/modules/autooff.py +++ b/kasa/smart/modules/autooff.py @@ -19,12 +19,6 @@ class AutoOff(SmartModule): def _initialize_features(self): """Initialize features after the initial update.""" - if not isinstance(self.data, dict): - _LOGGER.warning( - "No data available for module, skipping %s: %s", self, self.data - ) - return - self._add_feature( Feature( self._device, diff --git a/kasa/smart/modules/batterysensor.py b/kasa/smart/modules/batterysensor.py index 415e47d1..7ff7df2d 100644 --- a/kasa/smart/modules/batterysensor.py +++ b/kasa/smart/modules/batterysensor.py @@ -43,6 +43,10 @@ class BatterySensor(SmartModule): ) ) + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + @property def battery(self): """Return battery level.""" diff --git a/kasa/smart/modules/cloud.py b/kasa/smart/modules/cloud.py index 1b64f090..8346af57 100644 --- a/kasa/smart/modules/cloud.py +++ b/kasa/smart/modules/cloud.py @@ -4,7 +4,6 @@ from __future__ import annotations from typing import TYPE_CHECKING -from ...exceptions import SmartErrorCode from ...feature import Feature from ..smartmodule import SmartModule @@ -18,6 +17,13 @@ class Cloud(SmartModule): QUERY_GETTER_NAME = "get_connect_cloud_state" REQUIRED_COMPONENT = "cloud_connect" + def _post_update_hook(self): + """Perform actions after a device update. + + Overrides the default behaviour to disable a module if the query returns + an error because the logic here is to treat that as not connected. + """ + def __init__(self, device: SmartDevice, module: str): super().__init__(device, module) @@ -37,6 +43,6 @@ class Cloud(SmartModule): @property def is_connected(self): """Return True if device is connected to the cloud.""" - if isinstance(self.data, SmartErrorCode): + if self._has_data_error(): return False return self.data["status"] == 0 diff --git a/kasa/smart/modules/devicemodule.py b/kasa/smart/modules/devicemodule.py index 6a846d54..3203e82f 100644 --- a/kasa/smart/modules/devicemodule.py +++ b/kasa/smart/modules/devicemodule.py @@ -10,6 +10,13 @@ class DeviceModule(SmartModule): REQUIRED_COMPONENT = "device" + def _post_update_hook(self): + """Perform actions after a device update. + + Overrides the default behaviour to disable a module if the query returns + an error because this module is critical. + """ + def query(self) -> dict: """Query to execute during the update cycle.""" query = { diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 3dcaddd6..10a6b824 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Any, Callable, Optional from async_timeout import timeout as asyncio_timeout from pydantic.v1 import BaseModel, Field, validator -from ...exceptions import SmartErrorCode from ...feature import Feature from ..smartmodule import SmartModule @@ -123,6 +122,13 @@ class Firmware(SmartModule): req["get_auto_update_info"] = None return req + def _post_update_hook(self): + """Perform actions after a device update. + + Overrides the default behaviour to disable a module if the query returns + an error because some of the module still functions. + """ + @property def current_firmware(self) -> str: """Return the current firmware version.""" @@ -136,11 +142,11 @@ class Firmware(SmartModule): @property def firmware_update_info(self): """Return latest firmware information.""" - fw = self.data.get("get_latest_fw") or self.data - if not self._device.is_cloud_connected or isinstance(fw, SmartErrorCode): + if not self._device.is_cloud_connected or self._has_data_error(): # Error in response, probably disconnected from the cloud. return UpdateInfo(type=0, need_to_upgrade=False) + fw = self.data.get("get_latest_fw") or self.data return UpdateInfo.parse_obj(fw) @property diff --git a/kasa/smart/modules/frostprotection.py b/kasa/smart/modules/frostprotection.py index f1811012..440e1ed1 100644 --- a/kasa/smart/modules/frostprotection.py +++ b/kasa/smart/modules/frostprotection.py @@ -14,6 +14,10 @@ class FrostProtection(SmartModule): REQUIRED_COMPONENT = "frost_protection" QUERY_GETTER_NAME = "get_frost_protection" + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + @property def enabled(self) -> bool: """Return True if frost protection is on.""" diff --git a/kasa/smart/modules/humiditysensor.py b/kasa/smart/modules/humiditysensor.py index f0dcc18a..b137736f 100644 --- a/kasa/smart/modules/humiditysensor.py +++ b/kasa/smart/modules/humiditysensor.py @@ -45,6 +45,10 @@ class HumiditySensor(SmartModule): ) ) + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + @property def humidity(self): """Return current humidity in percentage.""" diff --git a/kasa/smart/modules/lightpreset.py b/kasa/smart/modules/lightpreset.py index 8e5cae20..7635a5f8 100644 --- a/kasa/smart/modules/lightpreset.py +++ b/kasa/smart/modules/lightpreset.py @@ -140,7 +140,7 @@ class LightPreset(SmartModule, LightPresetInterface): """Query to execute during the update cycle.""" if self._state_in_sysinfo: # Child lights can have states in the child info return {} - return {self.QUERY_GETTER_NAME: None} + return {self.QUERY_GETTER_NAME: {"start_index": 0}} async def _check_supported(self): """Additional check to see if the module is supported by the device. diff --git a/kasa/smart/modules/lighttransition.py b/kasa/smart/modules/lighttransition.py index 29a4bb05..ca0eca86 100644 --- a/kasa/smart/modules/lighttransition.py +++ b/kasa/smart/modules/lighttransition.py @@ -230,7 +230,7 @@ class LightTransition(SmartModule): if self._state_in_sysinfo: return {} else: - return {self.QUERY_GETTER_NAME: None} + return {self.QUERY_GETTER_NAME: {}} async def _check_supported(self): """Additional check to see if the module is supported by the device.""" diff --git a/kasa/smart/modules/reportmode.py b/kasa/smart/modules/reportmode.py index 79c8ae62..8d210a5b 100644 --- a/kasa/smart/modules/reportmode.py +++ b/kasa/smart/modules/reportmode.py @@ -32,6 +32,10 @@ class ReportMode(SmartModule): ) ) + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + @property def report_interval(self): """Reporting interval of a sensor device.""" diff --git a/kasa/smart/modules/temperaturesensor.py b/kasa/smart/modules/temperaturesensor.py index d9850150..a61859cd 100644 --- a/kasa/smart/modules/temperaturesensor.py +++ b/kasa/smart/modules/temperaturesensor.py @@ -58,6 +58,10 @@ class TemperatureSensor(SmartModule): ) ) + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + @property def temperature(self): """Return current humidity in percentage.""" diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 408ba027..5bf2400b 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -177,11 +177,20 @@ class SmartDevice(Device): self._children[info["device_id"]]._update_internal_state(info) # Call handle update for modules that want to update internal data - for module in self._modules.values(): - module._post_update_hook() + errors = [] + for module_name, module in self._modules.items(): + if not self._handle_module_post_update_hook(module): + errors.append(module_name) + for error in errors: + self._modules.pop(error) + for child in self._children.values(): - for child_module in child._modules.values(): - child_module._post_update_hook() + errors = [] + for child_module_name, child_module in child._modules.items(): + if not self._handle_module_post_update_hook(child_module): + errors.append(child_module_name) + for error in errors: + child._modules.pop(error) # We can first initialize the features after the first update. # We make here an assumption that every device has at least a single feature. @@ -190,6 +199,19 @@ class SmartDevice(Device): _LOGGER.debug("Got an update: %s", self._last_update) + def _handle_module_post_update_hook(self, module: SmartModule) -> bool: + try: + module._post_update_hook() + return True + except Exception as ex: + _LOGGER.error( + "Error processing %s for device %s, module will be unavailable: %s", + module.name, + self.host, + ex, + ) + return False + async def _initialize_modules(self): """Initialize modules based on component negotiation response.""" from .smartmodule import SmartModule diff --git a/kasa/smart/smartmodule.py b/kasa/smart/smartmodule.py index e78f4393..fb946a8b 100644 --- a/kasa/smart/smartmodule.py +++ b/kasa/smart/smartmodule.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING -from ..exceptions import KasaException +from ..exceptions import DeviceError, KasaException, SmartErrorCode from ..module import Module if TYPE_CHECKING: @@ -41,6 +41,14 @@ class SmartModule(Module): """Name of the module.""" return getattr(self, "NAME", self.__class__.__name__) + def _post_update_hook(self): # noqa: B027 + """Perform actions after a device update. + + Any modules overriding this should ensure that self.data is + accessed unless the module should remain active despite errors. + """ + assert self.data # noqa: S101 + def query(self) -> dict: """Query to execute during the update cycle. @@ -87,6 +95,11 @@ class SmartModule(Module): filtered_data = {k: v for k, v in dev._last_update.items() if k in q_keys} + for data_item in filtered_data: + if isinstance(filtered_data[data_item], SmartErrorCode): + raise DeviceError( + f"{data_item} for {self.name}", error_code=filtered_data[data_item] + ) if len(filtered_data) == 1: return next(iter(filtered_data.values())) @@ -110,3 +123,10 @@ class SmartModule(Module): color_temp_range but only supports one value. """ return True + + def _has_data_error(self) -> bool: + try: + assert self.data # noqa: S101 + return False + except DeviceError: + return True diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index e6741bc4..3085714c 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -416,6 +416,10 @@ class _ChildProtocolWrapper(SmartProtocol): return smart_method, smart_params async def query(self, request: str | dict, retry_count: int = 3) -> dict: + """Wrap request inside control_child envelope.""" + return await self._query(request, retry_count) + + async def _query(self, request: str | dict, retry_count: int = 3) -> dict: """Wrap request inside control_child envelope.""" method, params = self._get_method_and_params_for_request(request) request_data = { diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 48475a90..44fabc71 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import Any +from typing import Any, cast from unittest.mock import patch import pytest @@ -132,6 +132,78 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): spies[device].assert_not_called() +@device_smart +async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture): + """Test that modules that error are disabled / removed.""" + # We need to have some modules initialized by now + assert dev._modules + + critical_modules = {Module.DeviceModule, Module.ChildDevice} + not_disabling_modules = {Module.Firmware, Module.Cloud} + + new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) + + module_queries = { + modname: q + for modname, module in dev._modules.items() + if (q := module.query()) and modname not in critical_modules + } + child_module_queries = { + modname: q + for child in dev.children + for modname, module in child._modules.items() + if (q := module.query()) and modname not in critical_modules + } + all_queries_names = { + key for mod_query in module_queries.values() for key in mod_query + } + all_child_queries_names = { + key for mod_query in child_module_queries.values() for key in mod_query + } + + async def _query(request, *args, **kwargs): + responses = await dev.protocol._query(request, *args, **kwargs) + for k in responses: + if k in all_queries_names: + responses[k] = SmartErrorCode.PARAMS_ERROR + return responses + + async def _child_query(self, request, *args, **kwargs): + responses = await child_protocols[self._device_id]._query( + request, *args, **kwargs + ) + for k in responses: + if k in all_child_queries_names: + responses[k] = SmartErrorCode.PARAMS_ERROR + return responses + + mocker.patch.object(new_dev.protocol, "query", side_effect=_query) + + from kasa.smartprotocol import _ChildProtocolWrapper + + child_protocols = { + cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol + for child in dev.children + } + # children not created yet so cannot patch.object + mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query) + + await new_dev.update() + for modname in module_queries: + no_disable = modname in not_disabling_modules + mod_present = modname in new_dev._modules + assert ( + mod_present is no_disable + ), f"{modname} present {mod_present} when no_disable {no_disable}" + + for modname in child_module_queries: + no_disable = modname in not_disabling_modules + mod_present = any(modname in child._modules for child in new_dev.children) + assert ( + mod_present is no_disable + ), f"{modname} present {mod_present} when no_disable {no_disable}" + + async def test_get_modules(): """Test getting modules for child and parent modules.""" dummy_device = await get_device_for_fixture_protocol( @@ -181,6 +253,9 @@ async def test_smartdevice_cloud_connection(dev: SmartDevice, mocker: MockerFixt assert dev.is_cloud_connected == is_connected last_update = dev._last_update + for child in dev.children: + mocker.patch.object(child.protocol, "query", return_value=child._last_update) + last_update["get_connect_cloud_state"] = {"status": 0} with patch.object(dev.protocol, "query", return_value=last_update): await dev.update() @@ -207,21 +282,18 @@ async def test_smartdevice_cloud_connection(dev: SmartDevice, mocker: MockerFixt "get_connect_cloud_state": last_update["get_connect_cloud_state"], "get_device_info": last_update["get_device_info"], } - # Child component list is not stored on the device - if "get_child_device_list" in last_update: - child_component_list = await dev.protocol.query( - "get_child_device_component_list" - ) - last_update["get_child_device_component_list"] = child_component_list[ - "get_child_device_component_list" - ] + new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) first_call = True - def side_effect_func(*_, **__): + async def side_effect_func(*args, **kwargs): nonlocal first_call - resp = initial_response if first_call else last_update + resp = ( + initial_response + if first_call + else await new_dev.protocol._query(*args, **kwargs) + ) first_call = False return resp diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index d362fd00..71125ca8 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -1,6 +1,7 @@ import logging import pytest +import pytest_mock from ..exceptions import ( SMART_RETRYABLE_ERRORS, @@ -19,6 +20,21 @@ DUMMY_MULTIPLE_QUERY = { ERRORS = [e for e in SmartErrorCode if e != 0] +async def test_smart_queries(dummy_protocol, mocker: pytest_mock.MockerFixture): + mock_response = {"result": {"great": "success"}, "error_code": 0} + + mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response) + # test sending a method name as a string + resp = await dummy_protocol.query("foobar") + assert "foobar" in resp + assert resp["foobar"] == mock_response["result"] + + # test sending a method name as a dict + resp = await dummy_protocol.query(DUMMY_QUERY) + assert "foobar" in resp + assert resp["foobar"] == mock_response["result"] + + @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) async def test_smart_device_errors(dummy_protocol, mocker, error_code): mock_response = {"result": {"great": "success"}, "error_code": error_code.value}