From 7bba9926ed89b50cba503e1d50571bf880cc1433 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Tue, 30 Jul 2024 19:23:07 +0100 Subject: [PATCH] Allow erroring modules to recover (#1080) Re-query failed modules after some delay instead of immediately disabling them. Changes to features so they can still be created when modules are erroring. --- kasa/feature.py | 73 ++++++---- kasa/interfaces/energy.py | 12 +- kasa/iot/iotdevice.py | 2 +- kasa/iot/modules/ambientlight.py | 2 +- kasa/iot/modules/light.py | 3 +- kasa/smart/modules/alarm.py | 2 +- kasa/smart/modules/autooff.py | 2 +- kasa/smart/modules/batterysensor.py | 2 +- kasa/smart/modules/brightness.py | 3 +- kasa/smart/modules/cloud.py | 7 - kasa/smart/modules/energy.py | 10 +- kasa/smart/modules/fan.py | 3 +- kasa/smart/modules/humiditysensor.py | 2 +- kasa/smart/modules/lighttransition.py | 4 +- kasa/smart/modules/reportmode.py | 2 +- kasa/smart/modules/temperaturecontrol.py | 3 +- kasa/smart/modules/temperaturesensor.py | 2 +- kasa/smart/smartchilddevice.py | 13 +- kasa/smart/smartdevice.py | 74 +++++----- kasa/smart/smartmodule.py | 53 +++++++ kasa/tests/fakeprotocol_smart.py | 1 + kasa/tests/test_feature.py | 4 +- kasa/tests/test_smartdevice.py | 172 ++++++++++++----------- 23 files changed, 264 insertions(+), 187 deletions(-) diff --git a/kasa/feature.py b/kasa/feature.py index ab73f991..18bed554 100644 --- a/kasa/feature.py +++ b/kasa/feature.py @@ -69,6 +69,7 @@ from __future__ import annotations import logging from dataclasses import dataclass from enum import Enum, auto +from functools import cached_property from typing import TYPE_CHECKING, Any, Callable if TYPE_CHECKING: @@ -142,11 +143,9 @@ class Feature: container: Any = None #: Icon suggestion icon: str | None = None - #: Unit, if applicable - unit: str | None = None #: Attribute containing the name of the unit getter property. - #: If set, this property will be used to set *unit*. - unit_getter: str | None = None + #: If set, this property will be used to get the *unit*. + unit_getter: str | Callable[[], str] | None = None #: Category hint for downstreams category: Feature.Category = Category.Unset @@ -154,38 +153,18 @@ class Feature: #: Hint to help rounding the sensor values to given after-comma digits precision_hint: int | None = None - # Number-specific attributes - #: Minimum value - minimum_value: int = 0 - #: Maximum value - maximum_value: int = DEFAULT_MAX #: Attribute containing the name of the range getter property. #: If set, this property will be used to set *minimum_value* and *maximum_value*. - range_getter: str | None = None + range_getter: str | Callable[[], tuple[int, int]] | None = None - # Choice-specific attributes - #: List of choices as enum - choices: list[str] | None = None #: Attribute name of the choices getter property. - #: If set, this property will be used to set *choices*. - choices_getter: str | None = None + #: If set, this property will be used to get *choices*. + choices_getter: str | Callable[[], list[str]] | None = None def __post_init__(self): """Handle late-binding of members.""" # Populate minimum & maximum values, if range_getter is given - container = self.container if self.container is not None else self.device - if self.range_getter is not None: - self.minimum_value, self.maximum_value = getattr( - container, self.range_getter - ) - - # Populate choices, if choices_getter is given - if self.choices_getter is not None: - self.choices = getattr(container, self.choices_getter) - - # Populate unit, if unit_getter is given - if self.unit_getter is not None: - self.unit = getattr(container, self.unit_getter) + self._container = self.container if self.container is not None else self.device # Set the category, if unset if self.category is Feature.Category.Unset: @@ -208,6 +187,44 @@ class Feature: f"Read-only feat defines attribute_setter: {self.name} ({self.id}):" ) + def _get_property_value(self, getter): + if getter is None: + return None + if isinstance(getter, str): + return getattr(self._container, getter) + if callable(getter): + return getter() + raise ValueError("Invalid getter: %s", getter) # pragma: no cover + + @property + def choices(self) -> list[str] | None: + """List of choices.""" + return self._get_property_value(self.choices_getter) + + @property + def unit(self) -> str | None: + """Unit if applicable.""" + return self._get_property_value(self.unit_getter) + + @cached_property + def range(self) -> tuple[int, int] | None: + """Range of values if applicable.""" + return self._get_property_value(self.range_getter) + + @cached_property + def maximum_value(self) -> int: + """Maximum value.""" + if range := self.range: + return range[1] + return self.DEFAULT_MAX + + @cached_property + def minimum_value(self) -> int: + """Minimum value.""" + if range := self.range: + return range[0] + return 0 + @property def value(self): """Return the current value.""" diff --git a/kasa/interfaces/energy.py b/kasa/interfaces/energy.py index 76859647..51579322 100644 --- a/kasa/interfaces/energy.py +++ b/kasa/interfaces/energy.py @@ -40,7 +40,7 @@ class Energy(Module, ABC): name="Current consumption", attribute_getter="current_consumption", container=self, - unit="W", + unit_getter=lambda: "W", id="current_consumption", precision_hint=1, category=Feature.Category.Primary, @@ -53,7 +53,7 @@ class Energy(Module, ABC): name="Today's consumption", attribute_getter="consumption_today", container=self, - unit="kWh", + unit_getter=lambda: "kWh", id="consumption_today", precision_hint=3, category=Feature.Category.Info, @@ -67,7 +67,7 @@ class Energy(Module, ABC): name="This month's consumption", attribute_getter="consumption_this_month", container=self, - unit="kWh", + unit_getter=lambda: "kWh", precision_hint=3, category=Feature.Category.Info, type=Feature.Type.Sensor, @@ -80,7 +80,7 @@ class Energy(Module, ABC): name="Total consumption since reboot", attribute_getter="consumption_total", container=self, - unit="kWh", + unit_getter=lambda: "kWh", id="consumption_total", precision_hint=3, category=Feature.Category.Info, @@ -94,7 +94,7 @@ class Energy(Module, ABC): name="Voltage", attribute_getter="voltage", container=self, - unit="V", + unit_getter=lambda: "V", id="voltage", precision_hint=1, category=Feature.Category.Primary, @@ -107,7 +107,7 @@ class Energy(Module, ABC): name="Current", attribute_getter="current", container=self, - unit="A", + unit_getter=lambda: "A", id="current", precision_hint=2, category=Feature.Category.Primary, diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 28ae1228..234ea9fe 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -340,7 +340,7 @@ class IotDevice(Device): name="RSSI", attribute_getter="rssi", icon="mdi:signal", - unit="dBm", + unit_getter=lambda: "dBm", category=Feature.Category.Debug, type=Feature.Type.Sensor, ) diff --git a/kasa/iot/modules/ambientlight.py b/kasa/iot/modules/ambientlight.py index d49768ef..fd693ed5 100644 --- a/kasa/iot/modules/ambientlight.py +++ b/kasa/iot/modules/ambientlight.py @@ -28,7 +28,7 @@ class AmbientLight(IotModule): attribute_getter="ambientlight_brightness", type=Feature.Type.Sensor, category=Feature.Category.Primary, - unit="%", + unit_getter=lambda: "%", ) ) diff --git a/kasa/iot/modules/light.py b/kasa/iot/modules/light.py index 8c4e22c9..358771a6 100644 --- a/kasa/iot/modules/light.py +++ b/kasa/iot/modules/light.py @@ -41,8 +41,7 @@ class Light(IotModule, LightInterface): container=self, attribute_getter="brightness", attribute_setter="set_brightness", - minimum_value=BRIGHTNESS_MIN, - maximum_value=BRIGHTNESS_MAX, + range_getter=lambda: (BRIGHTNESS_MIN, BRIGHTNESS_MAX), type=Feature.Type.Number, category=Feature.Category.Primary, ) diff --git a/kasa/smart/modules/alarm.py b/kasa/smart/modules/alarm.py index 89f133f5..439bc571 100644 --- a/kasa/smart/modules/alarm.py +++ b/kasa/smart/modules/alarm.py @@ -69,7 +69,7 @@ class Alarm(SmartModule): attribute_setter="set_alarm_volume", category=Feature.Category.Config, type=Feature.Type.Choice, - choices=["low", "normal", "high"], + choices_getter=lambda: ["low", "normal", "high"], ) ) self._add_feature( diff --git a/kasa/smart/modules/autooff.py b/kasa/smart/modules/autooff.py index 5e4b100f..ae1bb082 100644 --- a/kasa/smart/modules/autooff.py +++ b/kasa/smart/modules/autooff.py @@ -39,7 +39,7 @@ class AutoOff(SmartModule): attribute_getter="delay", attribute_setter="set_delay", type=Feature.Type.Number, - unit="min", # ha-friendly unit, see UnitOfTime.MINUTES + unit_getter=lambda: "min", # ha-friendly unit, see UnitOfTime.MINUTES ) ) self._add_feature( diff --git a/kasa/smart/modules/batterysensor.py b/kasa/smart/modules/batterysensor.py index 7ff7df2d..7ecfad20 100644 --- a/kasa/smart/modules/batterysensor.py +++ b/kasa/smart/modules/batterysensor.py @@ -37,7 +37,7 @@ class BatterySensor(SmartModule): container=self, attribute_getter="battery", icon="mdi:battery", - unit="%", + unit_getter=lambda: "%", category=Feature.Category.Info, type=Feature.Type.Sensor, ) diff --git a/kasa/smart/modules/brightness.py b/kasa/smart/modules/brightness.py index f5e6d6d6..f6e5c322 100644 --- a/kasa/smart/modules/brightness.py +++ b/kasa/smart/modules/brightness.py @@ -27,8 +27,7 @@ class Brightness(SmartModule): container=self, attribute_getter="brightness", attribute_setter="set_brightness", - minimum_value=BRIGHTNESS_MIN, - maximum_value=BRIGHTNESS_MAX, + range_getter=lambda: (BRIGHTNESS_MIN, BRIGHTNESS_MAX), type=Feature.Type.Number, category=Feature.Category.Primary, ) diff --git a/kasa/smart/modules/cloud.py b/kasa/smart/modules/cloud.py index e7513a56..e66f1858 100644 --- a/kasa/smart/modules/cloud.py +++ b/kasa/smart/modules/cloud.py @@ -18,13 +18,6 @@ class Cloud(SmartModule): REQUIRED_COMPONENT = "cloud_connect" MINIMUM_UPDATE_INTERVAL_SECS = 60 - def _post_update_hook(self): - """Perform actions after a device update. - - Overrides the default behaviour to disable a module if the query returns - an error because the logic here is to treat that as not connected. - """ - def __init__(self, device: SmartDevice, module: str): super().__init__(device, module) diff --git a/kasa/smart/modules/energy.py b/kasa/smart/modules/energy.py index 3edbddb4..166f688e 100644 --- a/kasa/smart/modules/energy.py +++ b/kasa/smart/modules/energy.py @@ -5,7 +5,7 @@ from __future__ import annotations from ...emeterstatus import EmeterStatus from ...exceptions import KasaException from ...interfaces.energy import Energy as EnergyInterface -from ..smartmodule import SmartModule +from ..smartmodule import SmartModule, raise_if_update_error class Energy(SmartModule, EnergyInterface): @@ -23,6 +23,7 @@ class Energy(SmartModule, EnergyInterface): return req @property + @raise_if_update_error def current_consumption(self) -> float | None: """Current power in watts.""" if (power := self.energy.get("current_power")) is not None: @@ -30,6 +31,7 @@ class Energy(SmartModule, EnergyInterface): return None @property + @raise_if_update_error def energy(self): """Return get_energy_usage results.""" if en := self.data.get("get_energy_usage"): @@ -45,6 +47,7 @@ class Energy(SmartModule, EnergyInterface): ) @property + @raise_if_update_error def status(self): """Get the emeter status.""" return self._get_status_from_energy(self.energy) @@ -55,26 +58,31 @@ class Energy(SmartModule, EnergyInterface): return self._get_status_from_energy(res["get_energy_usage"]) @property + @raise_if_update_error def consumption_this_month(self) -> float | None: """Get the emeter value for this month in kWh.""" return self.energy.get("month_energy") / 1_000 @property + @raise_if_update_error def consumption_today(self) -> float | None: """Get the emeter value for today in kWh.""" return self.energy.get("today_energy") / 1_000 @property + @raise_if_update_error def consumption_total(self) -> float | None: """Return total consumption since last reboot in kWh.""" return None @property + @raise_if_update_error def current(self) -> float | None: """Return the current in A.""" return None @property + @raise_if_update_error def voltage(self) -> float | None: """Get the current voltage in V.""" return None diff --git a/kasa/smart/modules/fan.py b/kasa/smart/modules/fan.py index 153f9c8f..245bef2c 100644 --- a/kasa/smart/modules/fan.py +++ b/kasa/smart/modules/fan.py @@ -30,8 +30,7 @@ class Fan(SmartModule, FanInterface): attribute_setter="set_fan_speed_level", icon="mdi:fan", type=Feature.Type.Number, - minimum_value=0, - maximum_value=4, + range_getter=lambda: (0, 4), category=Feature.Category.Primary, ) ) diff --git a/kasa/smart/modules/humiditysensor.py b/kasa/smart/modules/humiditysensor.py index b137736f..606b1d54 100644 --- a/kasa/smart/modules/humiditysensor.py +++ b/kasa/smart/modules/humiditysensor.py @@ -27,7 +27,7 @@ class HumiditySensor(SmartModule): container=self, attribute_getter="humidity", icon="mdi:water-percent", - unit="%", + unit_getter=lambda: "%", category=Feature.Category.Primary, type=Feature.Type.Sensor, ) diff --git a/kasa/smart/modules/lighttransition.py b/kasa/smart/modules/lighttransition.py index e0aeb4d7..da05995d 100644 --- a/kasa/smart/modules/lighttransition.py +++ b/kasa/smart/modules/lighttransition.py @@ -73,7 +73,7 @@ class LightTransition(SmartModule): attribute_setter="set_turn_on_transition", icon=icon, type=Feature.Type.Number, - maximum_value=self._turn_on_transition_max, + range_getter=lambda: (0, self._turn_on_transition_max), ) ) self._add_feature( @@ -86,7 +86,7 @@ class LightTransition(SmartModule): attribute_setter="set_turn_off_transition", icon=icon, type=Feature.Type.Number, - maximum_value=self._turn_off_transition_max, + range_getter=lambda: (0, self._turn_off_transition_max), ) ) diff --git a/kasa/smart/modules/reportmode.py b/kasa/smart/modules/reportmode.py index 8d210a5b..d2c9d929 100644 --- a/kasa/smart/modules/reportmode.py +++ b/kasa/smart/modules/reportmode.py @@ -26,7 +26,7 @@ class ReportMode(SmartModule): name="Report interval", container=self, attribute_getter="report_interval", - unit="s", + unit_getter=lambda: "s", category=Feature.Category.Debug, type=Feature.Type.Sensor, ) diff --git a/kasa/smart/modules/temperaturecontrol.py b/kasa/smart/modules/temperaturecontrol.py index 00afe5b5..96630ce5 100644 --- a/kasa/smart/modules/temperaturecontrol.py +++ b/kasa/smart/modules/temperaturecontrol.py @@ -51,8 +51,7 @@ class TemperatureControl(SmartModule): container=self, attribute_getter="temperature_offset", attribute_setter="set_temperature_offset", - minimum_value=-10, - maximum_value=10, + range_getter=lambda: (-10, 10), type=Feature.Type.Number, category=Feature.Category.Config, ) diff --git a/kasa/smart/modules/temperaturesensor.py b/kasa/smart/modules/temperaturesensor.py index a61859cd..1741b26b 100644 --- a/kasa/smart/modules/temperaturesensor.py +++ b/kasa/smart/modules/temperaturesensor.py @@ -54,7 +54,7 @@ class TemperatureSensor(SmartModule): attribute_getter="temperature_unit", attribute_setter="set_temperature_unit", type=Feature.Type.Choice, - choices=["celsius", "fahrenheit"], + choices_getter=lambda: ["celsius", "fahrenheit"], ) ) diff --git a/kasa/smart/smartchilddevice.py b/kasa/smart/smartchilddevice.py index 679692ba..8fe3b969 100644 --- a/kasa/smart/smartchilddevice.py +++ b/kasa/smart/smartchilddevice.py @@ -10,6 +10,7 @@ from ..device_type import DeviceType from ..deviceconfig import DeviceConfig from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper from .smartdevice import SmartDevice +from .smartmodule import SmartModule _LOGGER = logging.getLogger(__name__) @@ -49,13 +50,21 @@ class SmartChildDevice(SmartDevice): Internal implementation to allow patching of public update in the cli or test framework. """ + now = time.monotonic() + module_queries: list[SmartModule] = [] req: dict[str, Any] = {} for module in self.modules.values(): - if mod_query := module.query(): + if module.disabled is False and (mod_query := module.query()): + module_queries.append(module) req.update(mod_query) if req: self._last_update = await self.protocol.query(req) - self._last_update_time = time.time() + + for module in self.modules.values(): + self._handle_module_post_update( + module, now, had_query=module in module_queries + ) + self._last_update_time = now @classmethod async def create(cls, parent: SmartDevice, child_info, child_components): diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index fcdbef97..04a9608a 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -165,28 +165,25 @@ class SmartDevice(Device): if first_update: await self._negotiate() await self._initialize_modules() + # Run post update for the cloud module + if cloud_mod := self.modules.get(Module.Cloud): + self._handle_module_post_update(cloud_mod, now, had_query=True) resp = await self._modular_update(first_update, now) - # Call child update which will only update module calls, info is updated - # from get_child_device_list. update_children only affects hub devices, other - # devices will always update children to prevent errors on module access. - if update_children or self.device_type != DeviceType.Hub: - for child in self._children.values(): - await child._update() if child_info := self._try_get_response( self._last_update, "get_child_device_list", {} ): for info in child_info["child_device_list"]: self._children[info["device_id"]]._update_internal_state(info) - - for child in self._children.values(): - errors = [] - for child_module_name, child_module in child._modules.items(): - if not self._handle_module_post_update_hook(child_module): - errors.append(child_module_name) - for error in errors: - child._modules.pop(error) + # Call child update which will only update module calls, info is updated + # from get_child_device_list. update_children only affects hub devices, other + # devices will always update children to prevent errors on module access. + # This needs to go after updating the internal state of the children so that + # child modules have access to their sysinfo. + if update_children or self.device_type != DeviceType.Hub: + for child in self._children.values(): + await child._update() # We can first initialize the features after the first update. # We make here an assumption that every device has at least a single feature. @@ -197,18 +194,26 @@ class SmartDevice(Device): updated = self._last_update if first_update else resp _LOGGER.debug("Update completed %s: %s", self.host, list(updated.keys())) - def _handle_module_post_update_hook(self, module: SmartModule) -> bool: + def _handle_module_post_update( + self, module: SmartModule, update_time: float, had_query: bool + ): + if module.disabled: + return # pragma: no cover + if had_query: + module._last_update_time = update_time try: module._post_update_hook() - return True + module._set_error(None) except Exception as ex: - _LOGGER.warning( - "Error processing %s for device %s, module will be unavailable: %s", - module.name, - self.host, - ex, - ) - return False + # Only set the error if a query happened. + if had_query: + module._set_error(ex) + _LOGGER.warning( + "Error processing %s for device %s, module will be unavailable: %s", + module.name, + self.host, + ex, + ) async def _modular_update( self, first_update: bool, update_time: float @@ -221,17 +226,16 @@ class SmartDevice(Device): mq = { module: query for module in self._modules.values() - if (query := module.query()) + if module.disabled is False and (query := module.query()) } for module, query in mq.items(): if first_update and module.__class__ in FIRST_UPDATE_MODULES: module._last_update_time = update_time continue if ( - not module.MINIMUM_UPDATE_INTERVAL_SECS + not module.update_interval or not module._last_update_time - or (update_time - module._last_update_time) - >= module.MINIMUM_UPDATE_INTERVAL_SECS + or (update_time - module._last_update_time) >= module.update_interval ): module_queries.append(module) req.update(query) @@ -254,16 +258,10 @@ class SmartDevice(Device): self._info = self._try_get_response(info_resp, "get_device_info") # Call handle update for modules that want to update internal data - errors = [] - for module_name, module in self._modules.items(): - if not self._handle_module_post_update_hook(module): - errors.append(module_name) - for error in errors: - self._modules.pop(error) - - # Set the last update time for modules that had queries made. - for module in module_queries: - module._last_update_time = update_time + for module in self._modules.values(): + self._handle_module_post_update( + module, update_time, had_query=module in module_queries + ) return resp @@ -392,7 +390,7 @@ class SmartDevice(Device): name="RSSI", attribute_getter=lambda x: x._info["rssi"], icon="mdi:signal", - unit="dBm", + unit_getter=lambda: "dBm", category=Feature.Category.Debug, type=Feature.Type.Sensor, ) diff --git a/kasa/smart/smartmodule.py b/kasa/smart/smartmodule.py index f5f2c212..0e6256a0 100644 --- a/kasa/smart/smartmodule.py +++ b/kasa/smart/smartmodule.py @@ -18,6 +18,7 @@ _LOGGER = logging.getLogger(__name__) _T = TypeVar("_T", bound="SmartModule") _P = ParamSpec("_P") +_R = TypeVar("_R") def allow_update_after( @@ -38,6 +39,17 @@ def allow_update_after( return _async_wrap +def raise_if_update_error(func: Callable[[_T], _R]) -> Callable[[_T], _R]: + """Define a wrapper to raise an error if the last module update was an error.""" + + def _wrap(self: _T) -> _R: + if err := self._last_update_error: + raise err + return func(self) + + return _wrap + + class SmartModule(Module): """Base class for SMART modules.""" @@ -52,17 +64,58 @@ class SmartModule(Module): REGISTERED_MODULES: dict[str, type[SmartModule]] = {} MINIMUM_UPDATE_INTERVAL_SECS = 0 + UPDATE_INTERVAL_AFTER_ERROR_SECS = 30 + + DISABLE_AFTER_ERROR_COUNT = 10 def __init__(self, device: SmartDevice, module: str): self._device: SmartDevice super().__init__(device, module) self._last_update_time: float | None = None + self._last_update_error: KasaException | None = None + self._error_count = 0 def __init_subclass__(cls, **kwargs): name = getattr(cls, "NAME", cls.__name__) _LOGGER.debug("Registering %s" % cls) cls.REGISTERED_MODULES[name] = cls + def _set_error(self, err: Exception | None): + if err is None: + self._error_count = 0 + self._last_update_error = None + else: + self._last_update_error = KasaException("Module update error", err) + self._error_count += 1 + if self._error_count == self.DISABLE_AFTER_ERROR_COUNT: + _LOGGER.error( + "Error processing %s for device %s, module will be disabled: %s", + self.name, + self._device.host, + err, + ) + if self._error_count > self.DISABLE_AFTER_ERROR_COUNT: + _LOGGER.error( # pragma: no cover + "Unexpected error processing %s for device %s, " + "module should be disabled: %s", + self.name, + self._device.host, + err, + ) + + @property + def update_interval(self) -> int: + """Time to wait between updates.""" + if self._last_update_error is None: + return self.MINIMUM_UPDATE_INTERVAL_SECS + + return self.UPDATE_INTERVAL_AFTER_ERROR_SECS * self._error_count + + @property + def disabled(self) -> bool: + """Return true if the module is disabled due to errors.""" + return self._error_count >= self.DISABLE_AFTER_ERROR_COUNT + @property def name(self) -> str: """Name of the module.""" diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index 7a54be17..40465b6f 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -114,6 +114,7 @@ class FakeSmartTransport(BaseTransport): }, ), "get_device_usage": ("device", {}), + "get_connect_cloud_state": ("cloud_connect", {"status": 0}), } async def send(self, request: str): diff --git a/kasa/tests/test_feature.py b/kasa/tests/test_feature.py index 440c9c1b..fd400856 100644 --- a/kasa/tests/test_feature.py +++ b/kasa/tests/test_feature.py @@ -27,7 +27,7 @@ def dummy_feature() -> Feature: container=None, icon="mdi:dummy", type=Feature.Type.Switch, - unit="dummyunit", + unit_getter=lambda: "dummyunit", ) return feat @@ -127,7 +127,7 @@ async def test_feature_action(mocker): async def test_feature_choice_list(dummy_feature, caplog, mocker: MockerFixture): """Test the choice feature type.""" dummy_feature.type = Feature.Type.Choice - dummy_feature.choices = ["first", "second"] + dummy_feature.choices_getter = lambda: ["first", "second"] mock_setter = mocker.patch.object(dummy_feature.device, "dummysetter", create=True) await dummy_feature.set_value("first") diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 4e670644..d96542e5 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -12,8 +12,11 @@ from freezegun.api import FrozenDateTimeFactory from pytest_mock import MockerFixture from kasa import Device, KasaException, Module -from kasa.exceptions import SmartErrorCode +from kasa.exceptions import DeviceError, SmartErrorCode from kasa.smart import SmartDevice +from kasa.smart.modules.energy import Energy +from kasa.smart.smartmodule import SmartModule +from kasa.smartprotocol import _ChildProtocolWrapper from .conftest import ( device_smart, @@ -139,78 +142,6 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): spies[device].assert_not_called() -@device_smart -async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture): - """Test that modules that error are disabled / removed.""" - # We need to have some modules initialized by now - assert dev._modules - - critical_modules = {Module.DeviceModule, Module.ChildDevice} - not_disabling_modules = {Module.Cloud} - - new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) - - module_queries = { - modname: q - for modname, module in dev._modules.items() - if (q := module.query()) and modname not in critical_modules - } - child_module_queries = { - modname: q - for child in dev.children - for modname, module in child._modules.items() - if (q := module.query()) and modname not in critical_modules - } - all_queries_names = { - key for mod_query in module_queries.values() for key in mod_query - } - all_child_queries_names = { - key for mod_query in child_module_queries.values() for key in mod_query - } - - async def _query(request, *args, **kwargs): - responses = await dev.protocol._query(request, *args, **kwargs) - for k in responses: - if k in all_queries_names: - responses[k] = SmartErrorCode.PARAMS_ERROR - return responses - - async def _child_query(self, request, *args, **kwargs): - responses = await child_protocols[self._device_id]._query( - request, *args, **kwargs - ) - for k in responses: - if k in all_child_queries_names: - responses[k] = SmartErrorCode.PARAMS_ERROR - return responses - - mocker.patch.object(new_dev.protocol, "query", side_effect=_query) - - from kasa.smartprotocol import _ChildProtocolWrapper - - child_protocols = { - cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol - for child in dev.children - } - # children not created yet so cannot patch.object - mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query) - - await new_dev.update() - for modname in module_queries: - no_disable = modname in not_disabling_modules - mod_present = modname in new_dev._modules - assert ( - mod_present is no_disable - ), f"{modname} present {mod_present} when no_disable {no_disable}" - - for modname in child_module_queries: - no_disable = modname in not_disabling_modules - mod_present = any(modname in child._modules for child in new_dev.children) - assert ( - mod_present is no_disable - ), f"{modname} present {mod_present} when no_disable {no_disable}" - - @device_smart async def test_update_module_update_delays( dev: SmartDevice, @@ -218,7 +149,7 @@ async def test_update_module_update_delays( caplog: pytest.LogCaptureFixture, freezer: FrozenDateTimeFactory, ): - """Test that modules that disabled / removed on query failures.""" + """Test that modules with minimum delays delay.""" # We need to have some modules initialized by now assert dev._modules @@ -257,6 +188,20 @@ async def test_update_module_update_delays( pytest.param(False, id="First update false"), ], ) +@pytest.mark.parametrize( + ("error_type"), + [ + pytest.param(SmartErrorCode.PARAMS_ERROR, id="Device error"), + pytest.param(TimeoutError("Dummy timeout"), id="Query error"), + ], +) +@pytest.mark.parametrize( + ("recover"), + [ + pytest.param(True, id="recover"), + pytest.param(False, id="no recover"), + ], +) @device_smart async def test_update_module_query_errors( dev: SmartDevice, @@ -264,15 +209,20 @@ async def test_update_module_query_errors( caplog: pytest.LogCaptureFixture, freezer: FrozenDateTimeFactory, first_update, + error_type, + recover, ): - """Test that modules that disabled / removed on query failures.""" + """Test that modules that disabled / removed on query failures. + + i.e. the whole query times out rather than device returns an error. + """ # We need to have some modules initialized by now assert dev._modules + SmartModule.DISABLE_AFTER_ERROR_COUNT = 2 first_update_queries = {"get_device_info", "get_connect_cloud_state"} critical_modules = {Module.DeviceModule, Module.ChildDevice} - not_disabling_modules = {Module.Cloud} new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) if not first_update: @@ -293,13 +243,18 @@ async def test_update_module_query_errors( or "get_child_device_component_list" in request or "control_child" in request ): - return await dev.protocol._query(request, *args, **kwargs) + resp = await dev.protocol._query(request, *args, **kwargs) + resp["get_connect_cloud_state"] = SmartErrorCode.CLOUD_FAILED_ERROR + return resp + # Don't test for errors on get_device_info as that is likely terminal if len(request) == 1 and "get_device_info" in request: return await dev.protocol._query(request, *args, **kwargs) - raise TimeoutError("Dummy timeout") - - from kasa.smartprotocol import _ChildProtocolWrapper + if isinstance(error_type, SmartErrorCode): + if len(request) == 1: + raise DeviceError("Dummy device error", error_code=error_type) + raise TimeoutError("Dummy timeout") + raise error_type child_protocols = { cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol @@ -314,19 +269,66 @@ async def test_update_module_query_errors( mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query) await new_dev.update() + msg = f"Error querying {new_dev.host} for modules" assert msg in caplog.text for modname in module_queries: - no_disable = modname in not_disabling_modules - mod_present = modname in new_dev._modules - assert ( - mod_present is no_disable - ), f"{modname} present {mod_present} when no_disable {no_disable}" + mod = cast(SmartModule, new_dev.modules[modname]) + assert mod.disabled is False, f"{modname} disabled" + assert mod.update_interval == mod.UPDATE_INTERVAL_AFTER_ERROR_SECS for mod_query in module_queries[modname]: if not first_update or mod_query not in first_update_queries: msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" assert msg in caplog.text + # Query again should not run for the modules + caplog.clear() + await new_dev.update() + for modname in module_queries: + mod = cast(SmartModule, new_dev.modules[modname]) + assert mod.disabled is False, f"{modname} disabled" + + freezer.tick(SmartModule.UPDATE_INTERVAL_AFTER_ERROR_SECS) + + caplog.clear() + + if recover: + mocker.patch.object( + new_dev.protocol, "query", side_effect=new_dev.protocol._query + ) + mocker.patch( + "kasa.smartprotocol._ChildProtocolWrapper.query", + new=_ChildProtocolWrapper._query, + ) + + await new_dev.update() + msg = f"Error querying {new_dev.host} for modules" + if not recover: + assert msg in caplog.text + for modname in module_queries: + mod = cast(SmartModule, new_dev.modules[modname]) + if not recover: + assert mod.disabled is True, f"{modname} not disabled" + assert mod._error_count == 2 + assert mod._last_update_error + for mod_query in module_queries[modname]: + if not first_update or mod_query not in first_update_queries: + msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" + assert msg in caplog.text + # Test one of the raise_if_update_error + if mod.name == "Energy": + emod = cast(Energy, mod) + with pytest.raises(KasaException, match="Module update error"): + assert emod.current_consumption is not None + else: + assert mod.disabled is False + assert mod._error_count == 0 + assert mod._last_update_error is None + # Test one of the raise_if_update_error doesn't raise + if mod.name == "Energy": + emod = cast(Energy, mod) + assert emod.current_consumption is not None + async def test_get_modules(): """Test getting modules for child and parent modules."""