Defer module updates for less volatile modules (#1052)

Addresses stability issues on older hw device versions

 - Handles module timeout errors better by querying modules individually on errors and disabling problematic modules like Firmware that go out to the internet to get updates.
- Addresses an issue with the Led module on P100 hardware version 1.0 which appears to have a memory leak and will cause the device to crash after approximately 500 calls.
- Delays updates of modules that do not have regular changes like LightPreset and LightEffect and enables them to be updated on the next update cycle only if required values have changed.
This commit is contained in:
Steven B 2024-07-11 16:21:59 +01:00 committed by GitHub
parent a044063526
commit 7fd5c213e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 364 additions and 56 deletions

View File

@ -146,7 +146,9 @@ class AesTransport(BaseTransport):
try: try:
error_code = SmartErrorCode.from_int(error_code_raw) error_code = SmartErrorCode.from_int(error_code_raw)
except ValueError: except ValueError:
_LOGGER.warning("Received unknown error code: %s", error_code_raw) _LOGGER.warning(
"Device %s received unknown error code: %s", self._host, error_code_raw
)
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
if error_code is SmartErrorCode.SUCCESS: if error_code is SmartErrorCode.SUCCESS:
return return
@ -216,10 +218,18 @@ class AesTransport(BaseTransport):
"""Login to the device.""" """Login to the device."""
try: try:
await self.try_login(self._login_params) await self.try_login(self._login_params)
_LOGGER.debug(
"%s: logged in with provided credentials",
self._host,
)
except AuthenticationError as aex: except AuthenticationError as aex:
try: try:
if aex.error_code is not SmartErrorCode.LOGIN_ERROR: if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
raise aex raise aex
_LOGGER.debug(
"%s: trying login with default TAPO credentials",
self._host,
)
if self._default_credentials is None: if self._default_credentials is None:
self._default_credentials = get_default_credentials( self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"] DEFAULT_CREDENTIALS["TAPO"]
@ -227,7 +237,7 @@ class AesTransport(BaseTransport):
await self.perform_handshake() await self.perform_handshake()
await self.try_login(self._get_login_params(self._default_credentials)) await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug( _LOGGER.debug(
"%s: logged in with default credentials", "%s: logged in with default TAPO credentials",
self._host, self._host,
) )
except (AuthenticationError, _ConnectionError, TimeoutError): except (AuthenticationError, _ConnectionError, TimeoutError):

View File

@ -128,6 +128,8 @@ class SmartErrorCode(IntEnum):
# Library internal for unknown error codes # Library internal for unknown error codes
INTERNAL_UNKNOWN_ERROR = -100_000 INTERNAL_UNKNOWN_ERROR = -100_000
# Library internal for query errors
INTERNAL_QUERY_ERROR = -100_001
SMART_RETRYABLE_ERRORS = [ SMART_RETRYABLE_ERRORS = [

View File

@ -75,13 +75,21 @@ class HttpClient:
now = time.time() now = time.time()
gap = now - self._last_request_time gap = now - self._last_request_time
if gap < self._wait_between_requests: if gap < self._wait_between_requests:
await asyncio.sleep(self._wait_between_requests - gap) sleep = self._wait_between_requests - gap
_LOGGER.debug(
"Device %s waiting %s seconds to send request",
self._config.host,
sleep,
)
await asyncio.sleep(sleep)
_LOGGER.debug("Posting to %s", url) _LOGGER.debug("Posting to %s", url)
response_data = None response_data = None
self._last_url = url self._last_url = url
self.client.cookie_jar.clear() self.client.cookie_jar.clear()
return_json = bool(json) return_json = bool(json)
client_timeout = aiohttp.ClientTimeout(total=self._config.timeout)
# If json is not a dict send as data. # If json is not a dict send as data.
# This allows the json parameter to be used to pass other # This allows the json parameter to be used to pass other
# types of data such as async_generator and still have json # types of data such as async_generator and still have json
@ -95,9 +103,10 @@ class HttpClient:
params=params, params=params,
data=data, data=data,
json=json, json=json,
timeout=self._config.timeout, timeout=client_timeout,
cookies=cookies_dict, cookies=cookies_dict,
headers=headers, headers=headers,
ssl=False,
) )
async with resp: async with resp:
if resp.status == 200: if resp.status == 200:
@ -106,9 +115,15 @@ class HttpClient:
response_data = json_loads(response_data.decode()) response_data = json_loads(response_data.decode())
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
if isinstance(ex, aiohttp.ClientOSError): if not self._wait_between_requests:
_LOGGER.debug(
"Device %s received an os error, "
"enabling sequential request delay: %s",
self._config.host,
ex,
)
self._wait_between_requests = self.WAIT_BETWEEN_REQUESTS_ON_OSERROR self._wait_between_requests = self.WAIT_BETWEEN_REQUESTS_ON_OSERROR
self._last_request_time = time.time() self._last_request_time = time.time()
raise _ConnectionError( raise _ConnectionError(
f"Device connection error: {self._config.host}: {ex}", ex f"Device connection error: {self._config.host}: {ex}", ex
) from ex ) from ex

View File

@ -16,6 +16,7 @@ class Cloud(SmartModule):
QUERY_GETTER_NAME = "get_connect_cloud_state" QUERY_GETTER_NAME = "get_connect_cloud_state"
REQUIRED_COMPONENT = "cloud_connect" REQUIRED_COMPONENT = "cloud_connect"
MINIMUM_UPDATE_INTERVAL_SECS = 60
def _post_update_hook(self): def _post_update_hook(self):
"""Perform actions after a device update. """Perform actions after a device update.

View File

@ -14,7 +14,7 @@ from async_timeout import timeout as asyncio_timeout
from pydantic.v1 import BaseModel, Field, validator from pydantic.v1 import BaseModel, Field, validator
from ...feature import Feature from ...feature import Feature
from ..smartmodule import SmartModule from ..smartmodule import SmartModule, allow_update_after
if TYPE_CHECKING: if TYPE_CHECKING:
from ..smartdevice import SmartDevice from ..smartdevice import SmartDevice
@ -66,6 +66,7 @@ class Firmware(SmartModule):
"""Implementation of firmware module.""" """Implementation of firmware module."""
REQUIRED_COMPONENT = "firmware" REQUIRED_COMPONENT = "firmware"
MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60 * 24
def __init__(self, device: SmartDevice, module: str): def __init__(self, device: SmartDevice, module: str):
super().__init__(device, module) super().__init__(device, module)
@ -122,13 +123,6 @@ class Firmware(SmartModule):
req["get_auto_update_info"] = None req["get_auto_update_info"] = None
return req return req
def _post_update_hook(self):
"""Perform actions after a device update.
Overrides the default behaviour to disable a module if the query returns
an error because some of the module still functions.
"""
@property @property
def current_firmware(self) -> str: def current_firmware(self) -> str:
"""Return the current firmware version.""" """Return the current firmware version."""
@ -162,6 +156,7 @@ class Firmware(SmartModule):
state = resp["get_fw_download_state"] state = resp["get_fw_download_state"]
return DownloadState(**state) return DownloadState(**state)
@allow_update_after
async def update( async def update(
self, progress_cb: Callable[[DownloadState], Coroutine] | None = None self, progress_cb: Callable[[DownloadState], Coroutine] | None = None
): ):
@ -219,6 +214,7 @@ class Firmware(SmartModule):
and self.data["get_auto_update_info"]["enable"] and self.data["get_auto_update_info"]["enable"]
) )
@allow_update_after
async def set_auto_update_enabled(self, enabled: bool): async def set_auto_update_enabled(self, enabled: bool):
"""Change autoupdate setting.""" """Change autoupdate setting."""
data = {**self.data["get_auto_update_info"], "enable": enabled} data = {**self.data["get_auto_update_info"], "enable": enabled}

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from ...interfaces.led import Led as LedInterface from ...interfaces.led import Led as LedInterface
from ..smartmodule import SmartModule from ..smartmodule import SmartModule, allow_update_after
class Led(SmartModule, LedInterface): class Led(SmartModule, LedInterface):
@ -11,6 +11,8 @@ class Led(SmartModule, LedInterface):
REQUIRED_COMPONENT = "led" REQUIRED_COMPONENT = "led"
QUERY_GETTER_NAME = "get_led_info" QUERY_GETTER_NAME = "get_led_info"
# Led queries can cause device to crash on P100
MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60
def query(self) -> dict: def query(self) -> dict:
"""Query to execute during the update cycle.""" """Query to execute during the update cycle."""
@ -29,6 +31,7 @@ class Led(SmartModule, LedInterface):
"""Return current led status.""" """Return current led status."""
return self.data["led_rule"] != "never" return self.data["led_rule"] != "never"
@allow_update_after
async def set_led(self, enable: bool): async def set_led(self, enable: bool):
"""Set led. """Set led.

View File

@ -9,7 +9,7 @@ import copy
from typing import Any from typing import Any
from ..effects import SmartLightEffect from ..effects import SmartLightEffect
from ..smartmodule import Module, SmartModule from ..smartmodule import Module, SmartModule, allow_update_after
class LightEffect(SmartModule, SmartLightEffect): class LightEffect(SmartModule, SmartLightEffect):
@ -17,6 +17,7 @@ class LightEffect(SmartModule, SmartLightEffect):
REQUIRED_COMPONENT = "light_effect" REQUIRED_COMPONENT = "light_effect"
QUERY_GETTER_NAME = "get_dynamic_light_effect_rules" QUERY_GETTER_NAME = "get_dynamic_light_effect_rules"
MINIMUM_UPDATE_INTERVAL_SECS = 60
AVAILABLE_BULB_EFFECTS = { AVAILABLE_BULB_EFFECTS = {
"L1": "Party", "L1": "Party",
"L2": "Relax", "L2": "Relax",
@ -130,6 +131,7 @@ class LightEffect(SmartModule, SmartLightEffect):
return brightness return brightness
@allow_update_after
async def set_brightness( async def set_brightness(
self, self,
brightness: int, brightness: int,
@ -156,6 +158,7 @@ class LightEffect(SmartModule, SmartLightEffect):
return await self.call("edit_dynamic_light_effect_rule", new_effect) return await self.call("edit_dynamic_light_effect_rule", new_effect)
@allow_update_after
async def set_custom_effect( async def set_custom_effect(
self, self,
effect_dict: dict, effect_dict: dict,

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
from ...interfaces import LightPreset as LightPresetInterface from ...interfaces import LightPreset as LightPresetInterface
from ...interfaces import LightState from ...interfaces import LightState
from ..smartmodule import SmartModule from ..smartmodule import SmartModule, allow_update_after
if TYPE_CHECKING: if TYPE_CHECKING:
from ..smartdevice import SmartDevice from ..smartdevice import SmartDevice
@ -22,6 +22,7 @@ class LightPreset(SmartModule, LightPresetInterface):
REQUIRED_COMPONENT = "preset" REQUIRED_COMPONENT = "preset"
QUERY_GETTER_NAME = "get_preset_rules" QUERY_GETTER_NAME = "get_preset_rules"
MINIMUM_UPDATE_INTERVAL_SECS = 60
SYS_INFO_STATE_KEY = "preset_state" SYS_INFO_STATE_KEY = "preset_state"
@ -124,6 +125,7 @@ class LightPreset(SmartModule, LightPresetInterface):
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}") raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
await self._device.modules[SmartModule.Light].set_state(preset) await self._device.modules[SmartModule.Light].set_state(preset)
@allow_update_after
async def save_preset( async def save_preset(
self, self,
preset_name: str, preset_name: str,

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ..effects import EFFECT_MAPPING, EFFECT_NAMES, SmartLightEffect from ..effects import EFFECT_MAPPING, EFFECT_NAMES, SmartLightEffect
from ..smartmodule import Module, SmartModule from ..smartmodule import Module, SmartModule, allow_update_after
if TYPE_CHECKING: if TYPE_CHECKING:
from ..smartdevice import SmartDevice from ..smartdevice import SmartDevice
@ -84,6 +84,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
""" """
return self._effect_list return self._effect_list
@allow_update_after
async def set_effect( async def set_effect(
self, self,
effect: str, effect: str,
@ -126,6 +127,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
await self.set_custom_effect(effect_dict) await self.set_custom_effect(effect_dict)
@allow_update_after
async def set_custom_effect( async def set_custom_effect(
self, self,
effect_dict: dict, effect_dict: dict,

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, TypedDict
from ...exceptions import KasaException from ...exceptions import KasaException
from ...feature import Feature from ...feature import Feature
from ..smartmodule import SmartModule from ..smartmodule import SmartModule, allow_update_after
if TYPE_CHECKING: if TYPE_CHECKING:
from ..smartdevice import SmartDevice from ..smartdevice import SmartDevice
@ -23,6 +23,7 @@ class LightTransition(SmartModule):
REQUIRED_COMPONENT = "on_off_gradually" REQUIRED_COMPONENT = "on_off_gradually"
QUERY_GETTER_NAME = "get_on_off_gradually_info" QUERY_GETTER_NAME = "get_on_off_gradually_info"
MINIMUM_UPDATE_INTERVAL_SECS = 60
MAXIMUM_DURATION = 60 MAXIMUM_DURATION = 60
# Key in sysinfo that indicates state can be retrieved from there. # Key in sysinfo that indicates state can be retrieved from there.
@ -136,6 +137,7 @@ class LightTransition(SmartModule):
"max_duration": off_max, "max_duration": off_max,
} }
@allow_update_after
async def set_enabled(self, enable: bool): async def set_enabled(self, enable: bool):
"""Enable gradual on/off.""" """Enable gradual on/off."""
if not self._supports_on_and_off: if not self._supports_on_and_off:
@ -168,6 +170,7 @@ class LightTransition(SmartModule):
# v3 added max_duration, we default to 60 when it's not available # v3 added max_duration, we default to 60 when it's not available
return self._on_state["max_duration"] return self._on_state["max_duration"]
@allow_update_after
async def set_turn_on_transition(self, seconds: int): async def set_turn_on_transition(self, seconds: int):
"""Set turn on transition in seconds. """Set turn on transition in seconds.
@ -203,6 +206,7 @@ class LightTransition(SmartModule):
# v3 added max_duration, we default to 60 when it's not available # v3 added max_duration, we default to 60 when it's not available
return self._off_state["max_duration"] return self._off_state["max_duration"]
@allow_update_after
async def set_turn_off_transition(self, seconds: int): async def set_turn_off_transition(self, seconds: int):
"""Set turn on transition in seconds. """Set turn on transition in seconds.

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import time
from typing import Any from typing import Any
from ..device_type import DeviceType from ..device_type import DeviceType
@ -54,6 +55,7 @@ class SmartChildDevice(SmartDevice):
req.update(mod_query) req.update(mod_query)
if req: if req:
self._last_update = await self.protocol.query(req) self._last_update = await self.protocol.query(req)
self._last_update_time = time.time()
@classmethod @classmethod
async def create(cls, parent: SmartDevice, child_info, child_components): async def create(cls, parent: SmartDevice, child_info, child_components):

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import base64 import base64
import logging import logging
import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, cast from typing import Any, cast
@ -18,6 +19,7 @@ from ..module import Module
from ..modulemapping import ModuleMapping, ModuleName from ..modulemapping import ModuleMapping, ModuleName
from ..smartprotocol import SmartProtocol from ..smartprotocol import SmartProtocol
from .modules import ( from .modules import (
ChildDevice,
Cloud, Cloud,
DeviceModule, DeviceModule,
Firmware, Firmware,
@ -35,6 +37,9 @@ _LOGGER = logging.getLogger(__name__)
# same issue, homekit perhaps? # same issue, homekit perhaps?
NON_HUB_PARENT_ONLY_MODULES = [DeviceModule, Time, Firmware, Cloud] NON_HUB_PARENT_ONLY_MODULES = [DeviceModule, Time, Firmware, Cloud]
# Modules that are called as part of the init procedure on first update
FIRST_UPDATE_MODULES = {DeviceModule, ChildDevice, Cloud}
# Device must go last as the other interfaces also inherit Device # Device must go last as the other interfaces also inherit Device
# and python needs a consistent method resolution order. # and python needs a consistent method resolution order.
@ -60,6 +65,7 @@ class SmartDevice(Device):
self._parent: SmartDevice | None = None self._parent: SmartDevice | None = None
self._children: Mapping[str, SmartDevice] = {} self._children: Mapping[str, SmartDevice] = {}
self._last_update = {} self._last_update = {}
self._last_update_time: float | None = None
async def _initialize_children(self): async def _initialize_children(self):
"""Initialize children for power strips.""" """Initialize children for power strips."""
@ -152,19 +158,15 @@ class SmartDevice(Device):
if self.credentials is None and self.credentials_hash is None: if self.credentials is None and self.credentials_hash is None:
raise AuthenticationError("Tapo plug requires authentication.") raise AuthenticationError("Tapo plug requires authentication.")
if self._components_raw is None: first_update = self._last_update_time is None
now = time.time()
self._last_update_time = now
if first_update:
await self._negotiate() await self._negotiate()
await self._initialize_modules() await self._initialize_modules()
req: dict[str, Any] = {} resp = await self._modular_update(first_update, now)
# TODO: this could be optimized by constructing the query only once
for module in self._modules.values():
req.update(module.query())
self._last_update = resp = await self.protocol.query(req)
self._info = self._try_get_response(resp, "get_device_info")
# Call child update which will only update module calls, info is updated # Call child update which will only update module calls, info is updated
# from get_child_device_list. update_children only affects hub devices, other # from get_child_device_list. update_children only affects hub devices, other
@ -172,18 +174,12 @@ class SmartDevice(Device):
if update_children or self.device_type != DeviceType.Hub: if update_children or self.device_type != DeviceType.Hub:
for child in self._children.values(): for child in self._children.values():
await child._update() await child._update()
if child_info := self._try_get_response(resp, "get_child_device_list", {}): if child_info := self._try_get_response(
self._last_update, "get_child_device_list", {}
):
for info in child_info["child_device_list"]: for info in child_info["child_device_list"]:
self._children[info["device_id"]]._update_internal_state(info) self._children[info["device_id"]]._update_internal_state(info)
# Call handle update for modules that want to update internal data
errors = []
for module_name, module in self._modules.items():
if not self._handle_module_post_update_hook(module):
errors.append(module_name)
for error in errors:
self._modules.pop(error)
for child in self._children.values(): for child in self._children.values():
errors = [] errors = []
for child_module_name, child_module in child._modules.items(): for child_module_name, child_module in child._modules.items():
@ -197,14 +193,18 @@ class SmartDevice(Device):
if not self._features: if not self._features:
await self._initialize_features() await self._initialize_features()
_LOGGER.debug("Got an update: %s", self._last_update) _LOGGER.debug(
"Update completed %s: %s",
self.host,
self._last_update if first_update else resp,
)
def _handle_module_post_update_hook(self, module: SmartModule) -> bool: def _handle_module_post_update_hook(self, module: SmartModule) -> bool:
try: try:
module._post_update_hook() module._post_update_hook()
return True return True
except Exception as ex: except Exception as ex:
_LOGGER.error( _LOGGER.warning(
"Error processing %s for device %s, module will be unavailable: %s", "Error processing %s for device %s, module will be unavailable: %s",
module.name, module.name,
self.host, self.host,
@ -212,6 +212,100 @@ class SmartDevice(Device):
) )
return False return False
async def _modular_update(
self, first_update: bool, update_time: float
) -> dict[str, Any]:
"""Update the device with via the module queries."""
req: dict[str, Any] = {}
# Keep a track of actual module queries so we can track the time for
# modules that do not need to be updated frequently
module_queries: list[SmartModule] = []
mq = {
module: query
for module in self._modules.values()
if (query := module.query())
}
for module, query in mq.items():
if first_update and module.__class__ in FIRST_UPDATE_MODULES:
module._last_update_time = update_time
continue
if (
not module.MINIMUM_UPDATE_INTERVAL_SECS
or not module._last_update_time
or (update_time - module._last_update_time)
>= module.MINIMUM_UPDATE_INTERVAL_SECS
):
module_queries.append(module)
req.update(query)
_LOGGER.debug(
"Querying %s for modules: %s",
self.host,
", ".join(mod.name for mod in module_queries),
)
try:
resp = await self.protocol.query(req)
except Exception as ex:
resp = await self._handle_modular_update_error(
ex, first_update, ", ".join(mod.name for mod in module_queries), req
)
info_resp = self._last_update if first_update else resp
self._last_update.update(**resp)
self._info = self._try_get_response(info_resp, "get_device_info")
# Call handle update for modules that want to update internal data
errors = []
for module_name, module in self._modules.items():
if not self._handle_module_post_update_hook(module):
errors.append(module_name)
for error in errors:
self._modules.pop(error)
# Set the last update time for modules that had queries made.
for module in module_queries:
module._last_update_time = update_time
return resp
async def _handle_modular_update_error(
self,
ex: Exception,
first_update: bool,
module_names: str,
requests: dict[str, Any],
) -> dict[str, Any]:
"""Handle an error on calling module update.
Will try to call all modules individually
and any errors such as timeouts will be set as a SmartErrorCode.
"""
msg_part = "on first update" if first_update else "after first update"
_LOGGER.error(
"Error querying %s for modules '%s' %s: %s",
self.host,
module_names,
msg_part,
ex,
)
responses = {}
for meth, params in requests.items():
try:
resp = await self.protocol.query({meth: params})
responses[meth] = resp[meth]
except Exception as iex:
_LOGGER.error(
"Error querying %s individually for module query '%s' %s: %s",
self.host,
meth,
msg_part,
iex,
)
responses[meth] = SmartErrorCode.INTERNAL_QUERY_ERROR
return responses
async def _initialize_modules(self): async def _initialize_modules(self):
"""Initialize modules based on component negotiation response.""" """Initialize modules based on component negotiation response."""
from .smartmodule import SmartModule from .smartmodule import SmartModule
@ -229,8 +323,6 @@ class SmartDevice(Device):
skip_parent_only_modules = True skip_parent_only_modules = True
for mod in SmartModule.REGISTERED_MODULES.values(): for mod in SmartModule.REGISTERED_MODULES.values():
_LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT)
if ( if (
skip_parent_only_modules and mod in NON_HUB_PARENT_ONLY_MODULES skip_parent_only_modules and mod in NON_HUB_PARENT_ONLY_MODULES
) or mod.__name__ in child_modules_to_skip: ) or mod.__name__ in child_modules_to_skip:
@ -240,7 +332,8 @@ class SmartDevice(Device):
or self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None or self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None
): ):
_LOGGER.debug( _LOGGER.debug(
"Found required %s, adding %s to modules.", "Device %s, found required %s, adding %s to modules.",
self.host,
mod.REQUIRED_COMPONENT, mod.REQUIRED_COMPONENT,
mod.__name__, mod.__name__,
) )

View File

@ -3,7 +3,10 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING from collections.abc import Awaitable, Callable, Coroutine
from typing import TYPE_CHECKING, Any
from typing_extensions import Concatenate, ParamSpec, TypeVar
from ..exceptions import DeviceError, KasaException, SmartErrorCode from ..exceptions import DeviceError, KasaException, SmartErrorCode
from ..module import Module from ..module import Module
@ -13,6 +16,27 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_T = TypeVar("_T", bound="SmartModule")
_P = ParamSpec("_P")
def allow_update_after(
func: Callable[Concatenate[_T, _P], Awaitable[None]],
) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, None]]:
"""Define a wrapper to set _last_update_time to None.
This will ensure that a module is updated in the next update cycle after
a value has been changed.
"""
async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> None:
try:
await func(self, *args, **kwargs)
finally:
self._last_update_time = None
return _async_wrap
class SmartModule(Module): class SmartModule(Module):
"""Base class for SMART modules.""" """Base class for SMART modules."""
@ -27,9 +51,12 @@ class SmartModule(Module):
REGISTERED_MODULES: dict[str, type[SmartModule]] = {} REGISTERED_MODULES: dict[str, type[SmartModule]] = {}
MINIMUM_UPDATE_INTERVAL_SECS = 0
def __init__(self, device: SmartDevice, module: str): def __init__(self, device: SmartDevice, module: str):
self._device: SmartDevice self._device: SmartDevice
super().__init__(device, module) super().__init__(device, module)
self._last_update_time: float | None = None
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
name = getattr(cls, "NAME", cls.__name__) name = getattr(cls, "NAME", cls.__name__)

View File

@ -73,18 +73,32 @@ class SmartProtocol(BaseProtocol):
return await self._execute_query( return await self._execute_query(
request, retry_count=retry, iterate_list_pages=True request, retry_count=retry, iterate_list_pages=True
) )
except _ConnectionError as sdex: except _ConnectionError as ex:
if retry == 0:
_LOGGER.debug(
"Device %s got a connection error, will retry %s times: %s",
self._host,
retry_count,
ex,
)
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex raise ex
continue continue
except AuthenticationError as auex: except AuthenticationError as ex:
await self._transport.reset() await self._transport.reset()
_LOGGER.debug( _LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host "Unable to authenticate with %s, not retrying: %s", self._host, ex
) )
raise auex raise ex
except _RetryableError as ex: except _RetryableError as ex:
if retry == 0:
_LOGGER.debug(
"Device %s got a retryable error, will retry %s times: %s",
self._host,
retry_count,
ex,
)
await self._transport.reset() await self._transport.reset()
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
@ -92,6 +106,13 @@ class SmartProtocol(BaseProtocol):
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue continue
except TimeoutError as ex: except TimeoutError as ex:
if retry == 0:
_LOGGER.debug(
"Device %s got a timeout error, will retry %s times: %s",
self._host,
retry_count,
ex,
)
await self._transport.reset() await self._transport.reset()
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
@ -130,20 +151,21 @@ class SmartProtocol(BaseProtocol):
self._handle_response_error_code(resp, method, raise_on_error=False) self._handle_response_error_code(resp, method, raise_on_error=False)
multi_result[method] = resp["result"] multi_result[method] = resp["result"]
return multi_result return multi_result
for i in range(0, end, step):
for batch_num, i in enumerate(range(0, end, step)):
requests_step = multi_requests[i : i + step] requests_step = multi_requests[i : i + step]
smart_params = {"requests": requests_step} smart_params = {"requests": requests_step}
smart_request = self.get_smart_request(smart_method, smart_params) smart_request = self.get_smart_request(smart_method, smart_params)
batch_name = f"multi-request-batch-{batch_num+1}-of-{int(end/step)+1}"
if debug_enabled: if debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s multi-request-batch-%s >> %s", "%s %s >> %s",
self._host, self._host,
i + 1, batch_name,
pf(smart_request), pf(smart_request),
) )
response_step = await self._transport.send(smart_request) response_step = await self._transport.send(smart_request)
batch_name = f"multi-request-batch-{i+1}"
if debug_enabled: if debug_enabled:
_LOGGER.debug( _LOGGER.debug(
"%s %s << %s", "%s %s << %s",
@ -271,7 +293,9 @@ class SmartProtocol(BaseProtocol):
try: try:
error_code = SmartErrorCode.from_int(error_code_raw) error_code = SmartErrorCode.from_int(error_code_raw)
except ValueError: except ValueError:
_LOGGER.warning("Received unknown error code: %s", error_code_raw) _LOGGER.warning(
"Device %s received unknown error code: %s", self._host, error_code_raw
)
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
if error_code is SmartErrorCode.SUCCESS: if error_code is SmartErrorCode.SUCCESS:

View File

@ -3,10 +3,12 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import time
from typing import Any, cast from typing import Any, cast
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from freezegun.api import FrozenDateTimeFactory
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from kasa import Device, KasaException, Module from kasa import Device, KasaException, Module
@ -54,6 +56,8 @@ async def test_initial_update(dev: SmartDevice, mocker: MockerFixture):
dev._modules = {} dev._modules = {}
dev._features = {} dev._features = {}
dev._children = {} dev._children = {}
dev._last_update = {}
dev._last_update_time = None
negotiate = mocker.spy(dev, "_negotiate") negotiate = mocker.spy(dev, "_negotiate")
initialize_modules = mocker.spy(dev, "_initialize_modules") initialize_modules = mocker.spy(dev, "_initialize_modules")
@ -109,6 +113,9 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
"""Test that the regular update uses queries from all supported modules.""" """Test that the regular update uses queries from all supported modules."""
# We need to have some modules initialized by now # We need to have some modules initialized by now
assert dev._modules assert dev._modules
# Reset last update so all modules will query
for mod in dev._modules.values():
mod._last_update_time = None
device_queries: dict[SmartDevice, dict[str, Any]] = {} device_queries: dict[SmartDevice, dict[str, Any]] = {}
for mod in dev._modules.values(): for mod in dev._modules.values():
@ -139,7 +146,7 @@ async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture):
assert dev._modules assert dev._modules
critical_modules = {Module.DeviceModule, Module.ChildDevice} critical_modules = {Module.DeviceModule, Module.ChildDevice}
not_disabling_modules = {Module.Firmware, Module.Cloud} not_disabling_modules = {Module.Cloud}
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
@ -204,6 +211,123 @@ async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture):
), f"{modname} present {mod_present} when no_disable {no_disable}" ), f"{modname} present {mod_present} when no_disable {no_disable}"
@device_smart
async def test_update_module_update_delays(
dev: SmartDevice,
mocker: MockerFixture,
caplog: pytest.LogCaptureFixture,
freezer: FrozenDateTimeFactory,
):
"""Test that modules that disabled / removed on query failures."""
# We need to have some modules initialized by now
assert dev._modules
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
await new_dev.update()
first_update_time = time.time()
assert new_dev._last_update_time == first_update_time
for module in new_dev.modules.values():
if module.query():
assert module._last_update_time == first_update_time
seconds = 0
tick = 30
while seconds <= 180:
seconds += tick
freezer.tick(tick)
now = time.time()
await new_dev.update()
for module in new_dev.modules.values():
mod_delay = module.MINIMUM_UPDATE_INTERVAL_SECS
if module.query():
expected_update_time = (
now if mod_delay == 0 else now - (seconds % mod_delay)
)
assert (
module._last_update_time == expected_update_time
), f"Expected update time {expected_update_time} after {seconds} seconds for {module.name} with delay {mod_delay} got {module._last_update_time}"
@pytest.mark.parametrize(
("first_update"),
[
pytest.param(True, id="First update true"),
pytest.param(False, id="First update false"),
],
)
@device_smart
async def test_update_module_query_errors(
dev: SmartDevice,
mocker: MockerFixture,
caplog: pytest.LogCaptureFixture,
freezer: FrozenDateTimeFactory,
first_update,
):
"""Test that modules that disabled / removed on query failures."""
# We need to have some modules initialized by now
assert dev._modules
first_update_queries = {"get_device_info", "get_connect_cloud_state"}
critical_modules = {Module.DeviceModule, Module.ChildDevice}
not_disabling_modules = {Module.Cloud}
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
if not first_update:
await new_dev.update()
freezer.tick(
max(module.MINIMUM_UPDATE_INTERVAL_SECS for module in dev._modules.values())
)
module_queries = {
modname: q
for modname, module in dev._modules.items()
if (q := module.query()) and modname not in critical_modules
}
async def _query(request, *args, **kwargs):
if (
"component_nego" in request
or "get_child_device_component_list" in request
or "control_child" in request
):
return await dev.protocol._query(request, *args, **kwargs)
if len(request) == 1 and "get_device_info" in request:
return await dev.protocol._query(request, *args, **kwargs)
raise TimeoutError("Dummy timeout")
from kasa.smartprotocol import _ChildProtocolWrapper
child_protocols = {
cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol
for child in dev.children
}
async def _child_query(self, request, *args, **kwargs):
return await child_protocols[self._device_id]._query(request, *args, **kwargs)
mocker.patch.object(new_dev.protocol, "query", side_effect=_query)
# children not created yet so cannot patch.object
mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query)
await new_dev.update()
msg = f"Error querying {new_dev.host} for modules"
assert msg in caplog.text
for modname in module_queries:
no_disable = modname in not_disabling_modules
mod_present = modname in new_dev._modules
assert (
mod_present is no_disable
), f"{modname} present {mod_present} when no_disable {no_disable}"
for mod_query in module_queries[modname]:
if not first_update or mod_query not in first_update_queries:
msg = f"Error querying {new_dev.host} individually for module query '{mod_query}"
assert msg in caplog.text
async def test_get_modules(): async def test_get_modules():
"""Test getting modules for child and parent modules.""" """Test getting modules for child and parent modules."""
dummy_device = await get_device_for_fixture_protocol( dummy_device = await get_device_for_fixture_protocol(

View File

@ -66,7 +66,7 @@ async def test_smart_device_unknown_errors(
assert res is SmartErrorCode.INTERNAL_UNKNOWN_ERROR assert res is SmartErrorCode.INTERNAL_UNKNOWN_ERROR
send_mock.assert_called_once() send_mock.assert_called_once()
assert f"Received unknown error code: {error_code}" in caplog.text assert f"received unknown error code: {error_code}" in caplog.text
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)