Allow erroring modules to recover (#1080)

Re-query failed modules after some delay instead of immediately disabling them.
Changes to features so they can still be created when modules are erroring.
This commit is contained in:
Steven B. 2024-07-30 19:23:07 +01:00 committed by GitHub
parent 445f74eed7
commit 7bba9926ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 264 additions and 187 deletions

View File

@ -69,6 +69,7 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from enum import Enum, auto
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
@ -142,11 +143,9 @@ class Feature:
container: Any = None
#: Icon suggestion
icon: str | None = None
#: Unit, if applicable
unit: str | None = None
#: Attribute containing the name of the unit getter property.
#: If set, this property will be used to set *unit*.
unit_getter: str | None = None
#: If set, this property will be used to get the *unit*.
unit_getter: str | Callable[[], str] | None = None
#: Category hint for downstreams
category: Feature.Category = Category.Unset
@ -154,38 +153,18 @@ class Feature:
#: Hint to help rounding the sensor values to given after-comma digits
precision_hint: int | None = None
# Number-specific attributes
#: Minimum value
minimum_value: int = 0
#: Maximum value
maximum_value: int = DEFAULT_MAX
#: Attribute containing the name of the range getter property.
#: If set, this property will be used to set *minimum_value* and *maximum_value*.
range_getter: str | None = None
range_getter: str | Callable[[], tuple[int, int]] | None = None
# Choice-specific attributes
#: List of choices as enum
choices: list[str] | None = None
#: Attribute name of the choices getter property.
#: If set, this property will be used to set *choices*.
choices_getter: str | None = None
#: If set, this property will be used to get *choices*.
choices_getter: str | Callable[[], list[str]] | None = None
def __post_init__(self):
"""Handle late-binding of members."""
# Populate minimum & maximum values, if range_getter is given
container = self.container if self.container is not None else self.device
if self.range_getter is not None:
self.minimum_value, self.maximum_value = getattr(
container, self.range_getter
)
# Populate choices, if choices_getter is given
if self.choices_getter is not None:
self.choices = getattr(container, self.choices_getter)
# Populate unit, if unit_getter is given
if self.unit_getter is not None:
self.unit = getattr(container, self.unit_getter)
self._container = self.container if self.container is not None else self.device
# Set the category, if unset
if self.category is Feature.Category.Unset:
@ -208,6 +187,44 @@ class Feature:
f"Read-only feat defines attribute_setter: {self.name} ({self.id}):"
)
def _get_property_value(self, getter):
if getter is None:
return None
if isinstance(getter, str):
return getattr(self._container, getter)
if callable(getter):
return getter()
raise ValueError("Invalid getter: %s", getter) # pragma: no cover
@property
def choices(self) -> list[str] | None:
"""List of choices."""
return self._get_property_value(self.choices_getter)
@property
def unit(self) -> str | None:
"""Unit if applicable."""
return self._get_property_value(self.unit_getter)
@cached_property
def range(self) -> tuple[int, int] | None:
"""Range of values if applicable."""
return self._get_property_value(self.range_getter)
@cached_property
def maximum_value(self) -> int:
"""Maximum value."""
if range := self.range:
return range[1]
return self.DEFAULT_MAX
@cached_property
def minimum_value(self) -> int:
"""Minimum value."""
if range := self.range:
return range[0]
return 0
@property
def value(self):
"""Return the current value."""

View File

@ -40,7 +40,7 @@ class Energy(Module, ABC):
name="Current consumption",
attribute_getter="current_consumption",
container=self,
unit="W",
unit_getter=lambda: "W",
id="current_consumption",
precision_hint=1,
category=Feature.Category.Primary,
@ -53,7 +53,7 @@ class Energy(Module, ABC):
name="Today's consumption",
attribute_getter="consumption_today",
container=self,
unit="kWh",
unit_getter=lambda: "kWh",
id="consumption_today",
precision_hint=3,
category=Feature.Category.Info,
@ -67,7 +67,7 @@ class Energy(Module, ABC):
name="This month's consumption",
attribute_getter="consumption_this_month",
container=self,
unit="kWh",
unit_getter=lambda: "kWh",
precision_hint=3,
category=Feature.Category.Info,
type=Feature.Type.Sensor,
@ -80,7 +80,7 @@ class Energy(Module, ABC):
name="Total consumption since reboot",
attribute_getter="consumption_total",
container=self,
unit="kWh",
unit_getter=lambda: "kWh",
id="consumption_total",
precision_hint=3,
category=Feature.Category.Info,
@ -94,7 +94,7 @@ class Energy(Module, ABC):
name="Voltage",
attribute_getter="voltage",
container=self,
unit="V",
unit_getter=lambda: "V",
id="voltage",
precision_hint=1,
category=Feature.Category.Primary,
@ -107,7 +107,7 @@ class Energy(Module, ABC):
name="Current",
attribute_getter="current",
container=self,
unit="A",
unit_getter=lambda: "A",
id="current",
precision_hint=2,
category=Feature.Category.Primary,

View File

@ -340,7 +340,7 @@ class IotDevice(Device):
name="RSSI",
attribute_getter="rssi",
icon="mdi:signal",
unit="dBm",
unit_getter=lambda: "dBm",
category=Feature.Category.Debug,
type=Feature.Type.Sensor,
)

View File

@ -28,7 +28,7 @@ class AmbientLight(IotModule):
attribute_getter="ambientlight_brightness",
type=Feature.Type.Sensor,
category=Feature.Category.Primary,
unit="%",
unit_getter=lambda: "%",
)
)

View File

@ -41,8 +41,7 @@ class Light(IotModule, LightInterface):
container=self,
attribute_getter="brightness",
attribute_setter="set_brightness",
minimum_value=BRIGHTNESS_MIN,
maximum_value=BRIGHTNESS_MAX,
range_getter=lambda: (BRIGHTNESS_MIN, BRIGHTNESS_MAX),
type=Feature.Type.Number,
category=Feature.Category.Primary,
)

View File

@ -69,7 +69,7 @@ class Alarm(SmartModule):
attribute_setter="set_alarm_volume",
category=Feature.Category.Config,
type=Feature.Type.Choice,
choices=["low", "normal", "high"],
choices_getter=lambda: ["low", "normal", "high"],
)
)
self._add_feature(

View File

@ -39,7 +39,7 @@ class AutoOff(SmartModule):
attribute_getter="delay",
attribute_setter="set_delay",
type=Feature.Type.Number,
unit="min", # ha-friendly unit, see UnitOfTime.MINUTES
unit_getter=lambda: "min", # ha-friendly unit, see UnitOfTime.MINUTES
)
)
self._add_feature(

View File

@ -37,7 +37,7 @@ class BatterySensor(SmartModule):
container=self,
attribute_getter="battery",
icon="mdi:battery",
unit="%",
unit_getter=lambda: "%",
category=Feature.Category.Info,
type=Feature.Type.Sensor,
)

View File

@ -27,8 +27,7 @@ class Brightness(SmartModule):
container=self,
attribute_getter="brightness",
attribute_setter="set_brightness",
minimum_value=BRIGHTNESS_MIN,
maximum_value=BRIGHTNESS_MAX,
range_getter=lambda: (BRIGHTNESS_MIN, BRIGHTNESS_MAX),
type=Feature.Type.Number,
category=Feature.Category.Primary,
)

View File

@ -18,13 +18,6 @@ class Cloud(SmartModule):
REQUIRED_COMPONENT = "cloud_connect"
MINIMUM_UPDATE_INTERVAL_SECS = 60
def _post_update_hook(self):
"""Perform actions after a device update.
Overrides the default behaviour to disable a module if the query returns
an error because the logic here is to treat that as not connected.
"""
def __init__(self, device: SmartDevice, module: str):
super().__init__(device, module)

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from ...emeterstatus import EmeterStatus
from ...exceptions import KasaException
from ...interfaces.energy import Energy as EnergyInterface
from ..smartmodule import SmartModule
from ..smartmodule import SmartModule, raise_if_update_error
class Energy(SmartModule, EnergyInterface):
@ -23,6 +23,7 @@ class Energy(SmartModule, EnergyInterface):
return req
@property
@raise_if_update_error
def current_consumption(self) -> float | None:
"""Current power in watts."""
if (power := self.energy.get("current_power")) is not None:
@ -30,6 +31,7 @@ class Energy(SmartModule, EnergyInterface):
return None
@property
@raise_if_update_error
def energy(self):
"""Return get_energy_usage results."""
if en := self.data.get("get_energy_usage"):
@ -45,6 +47,7 @@ class Energy(SmartModule, EnergyInterface):
)
@property
@raise_if_update_error
def status(self):
"""Get the emeter status."""
return self._get_status_from_energy(self.energy)
@ -55,26 +58,31 @@ class Energy(SmartModule, EnergyInterface):
return self._get_status_from_energy(res["get_energy_usage"])
@property
@raise_if_update_error
def consumption_this_month(self) -> float | None:
"""Get the emeter value for this month in kWh."""
return self.energy.get("month_energy") / 1_000
@property
@raise_if_update_error
def consumption_today(self) -> float | None:
"""Get the emeter value for today in kWh."""
return self.energy.get("today_energy") / 1_000
@property
@raise_if_update_error
def consumption_total(self) -> float | None:
"""Return total consumption since last reboot in kWh."""
return None
@property
@raise_if_update_error
def current(self) -> float | None:
"""Return the current in A."""
return None
@property
@raise_if_update_error
def voltage(self) -> float | None:
"""Get the current voltage in V."""
return None

View File

@ -30,8 +30,7 @@ class Fan(SmartModule, FanInterface):
attribute_setter="set_fan_speed_level",
icon="mdi:fan",
type=Feature.Type.Number,
minimum_value=0,
maximum_value=4,
range_getter=lambda: (0, 4),
category=Feature.Category.Primary,
)
)

View File

@ -27,7 +27,7 @@ class HumiditySensor(SmartModule):
container=self,
attribute_getter="humidity",
icon="mdi:water-percent",
unit="%",
unit_getter=lambda: "%",
category=Feature.Category.Primary,
type=Feature.Type.Sensor,
)

View File

@ -73,7 +73,7 @@ class LightTransition(SmartModule):
attribute_setter="set_turn_on_transition",
icon=icon,
type=Feature.Type.Number,
maximum_value=self._turn_on_transition_max,
range_getter=lambda: (0, self._turn_on_transition_max),
)
)
self._add_feature(
@ -86,7 +86,7 @@ class LightTransition(SmartModule):
attribute_setter="set_turn_off_transition",
icon=icon,
type=Feature.Type.Number,
maximum_value=self._turn_off_transition_max,
range_getter=lambda: (0, self._turn_off_transition_max),
)
)

View File

@ -26,7 +26,7 @@ class ReportMode(SmartModule):
name="Report interval",
container=self,
attribute_getter="report_interval",
unit="s",
unit_getter=lambda: "s",
category=Feature.Category.Debug,
type=Feature.Type.Sensor,
)

View File

@ -51,8 +51,7 @@ class TemperatureControl(SmartModule):
container=self,
attribute_getter="temperature_offset",
attribute_setter="set_temperature_offset",
minimum_value=-10,
maximum_value=10,
range_getter=lambda: (-10, 10),
type=Feature.Type.Number,
category=Feature.Category.Config,
)

View File

@ -54,7 +54,7 @@ class TemperatureSensor(SmartModule):
attribute_getter="temperature_unit",
attribute_setter="set_temperature_unit",
type=Feature.Type.Choice,
choices=["celsius", "fahrenheit"],
choices_getter=lambda: ["celsius", "fahrenheit"],
)
)

View File

@ -10,6 +10,7 @@ from ..device_type import DeviceType
from ..deviceconfig import DeviceConfig
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
from .smartdevice import SmartDevice
from .smartmodule import SmartModule
_LOGGER = logging.getLogger(__name__)
@ -49,13 +50,21 @@ class SmartChildDevice(SmartDevice):
Internal implementation to allow patching of public update in the cli
or test framework.
"""
now = time.monotonic()
module_queries: list[SmartModule] = []
req: dict[str, Any] = {}
for module in self.modules.values():
if mod_query := module.query():
if module.disabled is False and (mod_query := module.query()):
module_queries.append(module)
req.update(mod_query)
if req:
self._last_update = await self.protocol.query(req)
self._last_update_time = time.time()
for module in self.modules.values():
self._handle_module_post_update(
module, now, had_query=module in module_queries
)
self._last_update_time = now
@classmethod
async def create(cls, parent: SmartDevice, child_info, child_components):

View File

@ -165,28 +165,25 @@ class SmartDevice(Device):
if first_update:
await self._negotiate()
await self._initialize_modules()
# Run post update for the cloud module
if cloud_mod := self.modules.get(Module.Cloud):
self._handle_module_post_update(cloud_mod, now, had_query=True)
resp = await self._modular_update(first_update, now)
# Call child update which will only update module calls, info is updated
# from get_child_device_list. update_children only affects hub devices, other
# devices will always update children to prevent errors on module access.
if update_children or self.device_type != DeviceType.Hub:
for child in self._children.values():
await child._update()
if child_info := self._try_get_response(
self._last_update, "get_child_device_list", {}
):
for info in child_info["child_device_list"]:
self._children[info["device_id"]]._update_internal_state(info)
for child in self._children.values():
errors = []
for child_module_name, child_module in child._modules.items():
if not self._handle_module_post_update_hook(child_module):
errors.append(child_module_name)
for error in errors:
child._modules.pop(error)
# Call child update which will only update module calls, info is updated
# from get_child_device_list. update_children only affects hub devices, other
# devices will always update children to prevent errors on module access.
# This needs to go after updating the internal state of the children so that
# child modules have access to their sysinfo.
if update_children or self.device_type != DeviceType.Hub:
for child in self._children.values():
await child._update()
# We can first initialize the features after the first update.
# We make here an assumption that every device has at least a single feature.
@ -197,18 +194,26 @@ class SmartDevice(Device):
updated = self._last_update if first_update else resp
_LOGGER.debug("Update completed %s: %s", self.host, list(updated.keys()))
def _handle_module_post_update_hook(self, module: SmartModule) -> bool:
def _handle_module_post_update(
self, module: SmartModule, update_time: float, had_query: bool
):
if module.disabled:
return # pragma: no cover
if had_query:
module._last_update_time = update_time
try:
module._post_update_hook()
return True
module._set_error(None)
except Exception as ex:
_LOGGER.warning(
"Error processing %s for device %s, module will be unavailable: %s",
module.name,
self.host,
ex,
)
return False
# Only set the error if a query happened.
if had_query:
module._set_error(ex)
_LOGGER.warning(
"Error processing %s for device %s, module will be unavailable: %s",
module.name,
self.host,
ex,
)
async def _modular_update(
self, first_update: bool, update_time: float
@ -221,17 +226,16 @@ class SmartDevice(Device):
mq = {
module: query
for module in self._modules.values()
if (query := module.query())
if module.disabled is False and (query := module.query())
}
for module, query in mq.items():
if first_update and module.__class__ in FIRST_UPDATE_MODULES:
module._last_update_time = update_time
continue
if (
not module.MINIMUM_UPDATE_INTERVAL_SECS
not module.update_interval
or not module._last_update_time
or (update_time - module._last_update_time)
>= module.MINIMUM_UPDATE_INTERVAL_SECS
or (update_time - module._last_update_time) >= module.update_interval
):
module_queries.append(module)
req.update(query)
@ -254,16 +258,10 @@ class SmartDevice(Device):
self._info = self._try_get_response(info_resp, "get_device_info")
# Call handle update for modules that want to update internal data
errors = []
for module_name, module in self._modules.items():
if not self._handle_module_post_update_hook(module):
errors.append(module_name)
for error in errors:
self._modules.pop(error)
# Set the last update time for modules that had queries made.
for module in module_queries:
module._last_update_time = update_time
for module in self._modules.values():
self._handle_module_post_update(
module, update_time, had_query=module in module_queries
)
return resp
@ -392,7 +390,7 @@ class SmartDevice(Device):
name="RSSI",
attribute_getter=lambda x: x._info["rssi"],
icon="mdi:signal",
unit="dBm",
unit_getter=lambda: "dBm",
category=Feature.Category.Debug,
type=Feature.Type.Sensor,
)

View File

@ -18,6 +18,7 @@ _LOGGER = logging.getLogger(__name__)
_T = TypeVar("_T", bound="SmartModule")
_P = ParamSpec("_P")
_R = TypeVar("_R")
def allow_update_after(
@ -38,6 +39,17 @@ def allow_update_after(
return _async_wrap
def raise_if_update_error(func: Callable[[_T], _R]) -> Callable[[_T], _R]:
"""Define a wrapper to raise an error if the last module update was an error."""
def _wrap(self: _T) -> _R:
if err := self._last_update_error:
raise err
return func(self)
return _wrap
class SmartModule(Module):
"""Base class for SMART modules."""
@ -52,17 +64,58 @@ class SmartModule(Module):
REGISTERED_MODULES: dict[str, type[SmartModule]] = {}
MINIMUM_UPDATE_INTERVAL_SECS = 0
UPDATE_INTERVAL_AFTER_ERROR_SECS = 30
DISABLE_AFTER_ERROR_COUNT = 10
def __init__(self, device: SmartDevice, module: str):
self._device: SmartDevice
super().__init__(device, module)
self._last_update_time: float | None = None
self._last_update_error: KasaException | None = None
self._error_count = 0
def __init_subclass__(cls, **kwargs):
name = getattr(cls, "NAME", cls.__name__)
_LOGGER.debug("Registering %s" % cls)
cls.REGISTERED_MODULES[name] = cls
def _set_error(self, err: Exception | None):
if err is None:
self._error_count = 0
self._last_update_error = None
else:
self._last_update_error = KasaException("Module update error", err)
self._error_count += 1
if self._error_count == self.DISABLE_AFTER_ERROR_COUNT:
_LOGGER.error(
"Error processing %s for device %s, module will be disabled: %s",
self.name,
self._device.host,
err,
)
if self._error_count > self.DISABLE_AFTER_ERROR_COUNT:
_LOGGER.error( # pragma: no cover
"Unexpected error processing %s for device %s, "
"module should be disabled: %s",
self.name,
self._device.host,
err,
)
@property
def update_interval(self) -> int:
"""Time to wait between updates."""
if self._last_update_error is None:
return self.MINIMUM_UPDATE_INTERVAL_SECS
return self.UPDATE_INTERVAL_AFTER_ERROR_SECS * self._error_count
@property
def disabled(self) -> bool:
"""Return true if the module is disabled due to errors."""
return self._error_count >= self.DISABLE_AFTER_ERROR_COUNT
@property
def name(self) -> str:
"""Name of the module."""

View File

@ -114,6 +114,7 @@ class FakeSmartTransport(BaseTransport):
},
),
"get_device_usage": ("device", {}),
"get_connect_cloud_state": ("cloud_connect", {"status": 0}),
}
async def send(self, request: str):

View File

@ -27,7 +27,7 @@ def dummy_feature() -> Feature:
container=None,
icon="mdi:dummy",
type=Feature.Type.Switch,
unit="dummyunit",
unit_getter=lambda: "dummyunit",
)
return feat
@ -127,7 +127,7 @@ async def test_feature_action(mocker):
async def test_feature_choice_list(dummy_feature, caplog, mocker: MockerFixture):
"""Test the choice feature type."""
dummy_feature.type = Feature.Type.Choice
dummy_feature.choices = ["first", "second"]
dummy_feature.choices_getter = lambda: ["first", "second"]
mock_setter = mocker.patch.object(dummy_feature.device, "dummysetter", create=True)
await dummy_feature.set_value("first")

View File

@ -12,8 +12,11 @@ from freezegun.api import FrozenDateTimeFactory
from pytest_mock import MockerFixture
from kasa import Device, KasaException, Module
from kasa.exceptions import SmartErrorCode
from kasa.exceptions import DeviceError, SmartErrorCode
from kasa.smart import SmartDevice
from kasa.smart.modules.energy import Energy
from kasa.smart.smartmodule import SmartModule
from kasa.smartprotocol import _ChildProtocolWrapper
from .conftest import (
device_smart,
@ -139,78 +142,6 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
spies[device].assert_not_called()
@device_smart
async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture):
"""Test that modules that error are disabled / removed."""
# We need to have some modules initialized by now
assert dev._modules
critical_modules = {Module.DeviceModule, Module.ChildDevice}
not_disabling_modules = {Module.Cloud}
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
module_queries = {
modname: q
for modname, module in dev._modules.items()
if (q := module.query()) and modname not in critical_modules
}
child_module_queries = {
modname: q
for child in dev.children
for modname, module in child._modules.items()
if (q := module.query()) and modname not in critical_modules
}
all_queries_names = {
key for mod_query in module_queries.values() for key in mod_query
}
all_child_queries_names = {
key for mod_query in child_module_queries.values() for key in mod_query
}
async def _query(request, *args, **kwargs):
responses = await dev.protocol._query(request, *args, **kwargs)
for k in responses:
if k in all_queries_names:
responses[k] = SmartErrorCode.PARAMS_ERROR
return responses
async def _child_query(self, request, *args, **kwargs):
responses = await child_protocols[self._device_id]._query(
request, *args, **kwargs
)
for k in responses:
if k in all_child_queries_names:
responses[k] = SmartErrorCode.PARAMS_ERROR
return responses
mocker.patch.object(new_dev.protocol, "query", side_effect=_query)
from kasa.smartprotocol import _ChildProtocolWrapper
child_protocols = {
cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol
for child in dev.children
}
# children not created yet so cannot patch.object
mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query)
await new_dev.update()
for modname in module_queries:
no_disable = modname in not_disabling_modules
mod_present = modname in new_dev._modules
assert (
mod_present is no_disable
), f"{modname} present {mod_present} when no_disable {no_disable}"
for modname in child_module_queries:
no_disable = modname in not_disabling_modules
mod_present = any(modname in child._modules for child in new_dev.children)
assert (
mod_present is no_disable
), f"{modname} present {mod_present} when no_disable {no_disable}"
@device_smart
async def test_update_module_update_delays(
dev: SmartDevice,
@ -218,7 +149,7 @@ async def test_update_module_update_delays(
caplog: pytest.LogCaptureFixture,
freezer: FrozenDateTimeFactory,
):
"""Test that modules that disabled / removed on query failures."""
"""Test that modules with minimum delays delay."""
# We need to have some modules initialized by now
assert dev._modules
@ -257,6 +188,20 @@ async def test_update_module_update_delays(
pytest.param(False, id="First update false"),
],
)
@pytest.mark.parametrize(
("error_type"),
[
pytest.param(SmartErrorCode.PARAMS_ERROR, id="Device error"),
pytest.param(TimeoutError("Dummy timeout"), id="Query error"),
],
)
@pytest.mark.parametrize(
("recover"),
[
pytest.param(True, id="recover"),
pytest.param(False, id="no recover"),
],
)
@device_smart
async def test_update_module_query_errors(
dev: SmartDevice,
@ -264,15 +209,20 @@ async def test_update_module_query_errors(
caplog: pytest.LogCaptureFixture,
freezer: FrozenDateTimeFactory,
first_update,
error_type,
recover,
):
"""Test that modules that disabled / removed on query failures."""
"""Test that modules that disabled / removed on query failures.
i.e. the whole query times out rather than device returns an error.
"""
# We need to have some modules initialized by now
assert dev._modules
SmartModule.DISABLE_AFTER_ERROR_COUNT = 2
first_update_queries = {"get_device_info", "get_connect_cloud_state"}
critical_modules = {Module.DeviceModule, Module.ChildDevice}
not_disabling_modules = {Module.Cloud}
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
if not first_update:
@ -293,13 +243,18 @@ async def test_update_module_query_errors(
or "get_child_device_component_list" in request
or "control_child" in request
):
return await dev.protocol._query(request, *args, **kwargs)
resp = await dev.protocol._query(request, *args, **kwargs)
resp["get_connect_cloud_state"] = SmartErrorCode.CLOUD_FAILED_ERROR
return resp
# Don't test for errors on get_device_info as that is likely terminal
if len(request) == 1 and "get_device_info" in request:
return await dev.protocol._query(request, *args, **kwargs)
raise TimeoutError("Dummy timeout")
from kasa.smartprotocol import _ChildProtocolWrapper
if isinstance(error_type, SmartErrorCode):
if len(request) == 1:
raise DeviceError("Dummy device error", error_code=error_type)
raise TimeoutError("Dummy timeout")
raise error_type
child_protocols = {
cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol
@ -314,19 +269,66 @@ async def test_update_module_query_errors(
mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query)
await new_dev.update()
msg = f"Error querying {new_dev.host} for modules"
assert msg in caplog.text
for modname in module_queries:
no_disable = modname in not_disabling_modules
mod_present = modname in new_dev._modules
assert (
mod_present is no_disable
), f"{modname} present {mod_present} when no_disable {no_disable}"
mod = cast(SmartModule, new_dev.modules[modname])
assert mod.disabled is False, f"{modname} disabled"
assert mod.update_interval == mod.UPDATE_INTERVAL_AFTER_ERROR_SECS
for mod_query in module_queries[modname]:
if not first_update or mod_query not in first_update_queries:
msg = f"Error querying {new_dev.host} individually for module query '{mod_query}"
assert msg in caplog.text
# Query again should not run for the modules
caplog.clear()
await new_dev.update()
for modname in module_queries:
mod = cast(SmartModule, new_dev.modules[modname])
assert mod.disabled is False, f"{modname} disabled"
freezer.tick(SmartModule.UPDATE_INTERVAL_AFTER_ERROR_SECS)
caplog.clear()
if recover:
mocker.patch.object(
new_dev.protocol, "query", side_effect=new_dev.protocol._query
)
mocker.patch(
"kasa.smartprotocol._ChildProtocolWrapper.query",
new=_ChildProtocolWrapper._query,
)
await new_dev.update()
msg = f"Error querying {new_dev.host} for modules"
if not recover:
assert msg in caplog.text
for modname in module_queries:
mod = cast(SmartModule, new_dev.modules[modname])
if not recover:
assert mod.disabled is True, f"{modname} not disabled"
assert mod._error_count == 2
assert mod._last_update_error
for mod_query in module_queries[modname]:
if not first_update or mod_query not in first_update_queries:
msg = f"Error querying {new_dev.host} individually for module query '{mod_query}"
assert msg in caplog.text
# Test one of the raise_if_update_error
if mod.name == "Energy":
emod = cast(Energy, mod)
with pytest.raises(KasaException, match="Module update error"):
assert emod.current_consumption is not None
else:
assert mod.disabled is False
assert mod._error_count == 0
assert mod._last_update_error is None
# Test one of the raise_if_update_error doesn't raise
if mod.name == "Energy":
emod = cast(Energy, mod)
assert emod.current_consumption is not None
async def test_get_modules():
"""Test getting modules for child and parent modules."""