mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Defer module updates for less volatile modules (#1052)
Addresses stability issues on older hw device versions - Handles module timeout errors better by querying modules individually on errors and disabling problematic modules like Firmware that go out to the internet to get updates. - Addresses an issue with the Led module on P100 hardware version 1.0 which appears to have a memory leak and will cause the device to crash after approximately 500 calls. - Delays updates of modules that do not have regular changes like LightPreset and LightEffect and enables them to be updated on the next update cycle only if required values have changed.
This commit is contained in:
parent
a044063526
commit
7fd5c213e6
@ -146,7 +146,9 @@ class AesTransport(BaseTransport):
|
||||
try:
|
||||
error_code = SmartErrorCode.from_int(error_code_raw)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Received unknown error code: %s", error_code_raw)
|
||||
_LOGGER.warning(
|
||||
"Device %s received unknown error code: %s", self._host, error_code_raw
|
||||
)
|
||||
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
|
||||
if error_code is SmartErrorCode.SUCCESS:
|
||||
return
|
||||
@ -216,10 +218,18 @@ class AesTransport(BaseTransport):
|
||||
"""Login to the device."""
|
||||
try:
|
||||
await self.try_login(self._login_params)
|
||||
_LOGGER.debug(
|
||||
"%s: logged in with provided credentials",
|
||||
self._host,
|
||||
)
|
||||
except AuthenticationError as aex:
|
||||
try:
|
||||
if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
|
||||
raise aex
|
||||
_LOGGER.debug(
|
||||
"%s: trying login with default TAPO credentials",
|
||||
self._host,
|
||||
)
|
||||
if self._default_credentials is None:
|
||||
self._default_credentials = get_default_credentials(
|
||||
DEFAULT_CREDENTIALS["TAPO"]
|
||||
@ -227,7 +237,7 @@ class AesTransport(BaseTransport):
|
||||
await self.perform_handshake()
|
||||
await self.try_login(self._get_login_params(self._default_credentials))
|
||||
_LOGGER.debug(
|
||||
"%s: logged in with default credentials",
|
||||
"%s: logged in with default TAPO credentials",
|
||||
self._host,
|
||||
)
|
||||
except (AuthenticationError, _ConnectionError, TimeoutError):
|
||||
|
@ -128,6 +128,8 @@ class SmartErrorCode(IntEnum):
|
||||
|
||||
# Library internal for unknown error codes
|
||||
INTERNAL_UNKNOWN_ERROR = -100_000
|
||||
# Library internal for query errors
|
||||
INTERNAL_QUERY_ERROR = -100_001
|
||||
|
||||
|
||||
SMART_RETRYABLE_ERRORS = [
|
||||
|
@ -75,13 +75,21 @@ class HttpClient:
|
||||
now = time.time()
|
||||
gap = now - self._last_request_time
|
||||
if gap < self._wait_between_requests:
|
||||
await asyncio.sleep(self._wait_between_requests - gap)
|
||||
sleep = self._wait_between_requests - gap
|
||||
_LOGGER.debug(
|
||||
"Device %s waiting %s seconds to send request",
|
||||
self._config.host,
|
||||
sleep,
|
||||
)
|
||||
await asyncio.sleep(sleep)
|
||||
|
||||
_LOGGER.debug("Posting to %s", url)
|
||||
response_data = None
|
||||
self._last_url = url
|
||||
self.client.cookie_jar.clear()
|
||||
return_json = bool(json)
|
||||
client_timeout = aiohttp.ClientTimeout(total=self._config.timeout)
|
||||
|
||||
# If json is not a dict send as data.
|
||||
# This allows the json parameter to be used to pass other
|
||||
# types of data such as async_generator and still have json
|
||||
@ -95,9 +103,10 @@ class HttpClient:
|
||||
params=params,
|
||||
data=data,
|
||||
json=json,
|
||||
timeout=self._config.timeout,
|
||||
timeout=client_timeout,
|
||||
cookies=cookies_dict,
|
||||
headers=headers,
|
||||
ssl=False,
|
||||
)
|
||||
async with resp:
|
||||
if resp.status == 200:
|
||||
@ -106,7 +115,13 @@ class HttpClient:
|
||||
response_data = json_loads(response_data.decode())
|
||||
|
||||
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
|
||||
if isinstance(ex, aiohttp.ClientOSError):
|
||||
if not self._wait_between_requests:
|
||||
_LOGGER.debug(
|
||||
"Device %s received an os error, "
|
||||
"enabling sequential request delay: %s",
|
||||
self._config.host,
|
||||
ex,
|
||||
)
|
||||
self._wait_between_requests = self.WAIT_BETWEEN_REQUESTS_ON_OSERROR
|
||||
self._last_request_time = time.time()
|
||||
raise _ConnectionError(
|
||||
|
@ -16,6 +16,7 @@ class Cloud(SmartModule):
|
||||
|
||||
QUERY_GETTER_NAME = "get_connect_cloud_state"
|
||||
REQUIRED_COMPONENT = "cloud_connect"
|
||||
MINIMUM_UPDATE_INTERVAL_SECS = 60
|
||||
|
||||
def _post_update_hook(self):
|
||||
"""Perform actions after a device update.
|
||||
|
@ -14,7 +14,7 @@ from async_timeout import timeout as asyncio_timeout
|
||||
from pydantic.v1 import BaseModel, Field, validator
|
||||
|
||||
from ...feature import Feature
|
||||
from ..smartmodule import SmartModule
|
||||
from ..smartmodule import SmartModule, allow_update_after
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..smartdevice import SmartDevice
|
||||
@ -66,6 +66,7 @@ class Firmware(SmartModule):
|
||||
"""Implementation of firmware module."""
|
||||
|
||||
REQUIRED_COMPONENT = "firmware"
|
||||
MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60 * 24
|
||||
|
||||
def __init__(self, device: SmartDevice, module: str):
|
||||
super().__init__(device, module)
|
||||
@ -122,13 +123,6 @@ class Firmware(SmartModule):
|
||||
req["get_auto_update_info"] = None
|
||||
return req
|
||||
|
||||
def _post_update_hook(self):
|
||||
"""Perform actions after a device update.
|
||||
|
||||
Overrides the default behaviour to disable a module if the query returns
|
||||
an error because some of the module still functions.
|
||||
"""
|
||||
|
||||
@property
|
||||
def current_firmware(self) -> str:
|
||||
"""Return the current firmware version."""
|
||||
@ -162,6 +156,7 @@ class Firmware(SmartModule):
|
||||
state = resp["get_fw_download_state"]
|
||||
return DownloadState(**state)
|
||||
|
||||
@allow_update_after
|
||||
async def update(
|
||||
self, progress_cb: Callable[[DownloadState], Coroutine] | None = None
|
||||
):
|
||||
@ -219,6 +214,7 @@ class Firmware(SmartModule):
|
||||
and self.data["get_auto_update_info"]["enable"]
|
||||
)
|
||||
|
||||
@allow_update_after
|
||||
async def set_auto_update_enabled(self, enabled: bool):
|
||||
"""Change autoupdate setting."""
|
||||
data = {**self.data["get_auto_update_info"], "enable": enabled}
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...interfaces.led import Led as LedInterface
|
||||
from ..smartmodule import SmartModule
|
||||
from ..smartmodule import SmartModule, allow_update_after
|
||||
|
||||
|
||||
class Led(SmartModule, LedInterface):
|
||||
@ -11,6 +11,8 @@ class Led(SmartModule, LedInterface):
|
||||
|
||||
REQUIRED_COMPONENT = "led"
|
||||
QUERY_GETTER_NAME = "get_led_info"
|
||||
# Led queries can cause device to crash on P100
|
||||
MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60
|
||||
|
||||
def query(self) -> dict:
|
||||
"""Query to execute during the update cycle."""
|
||||
@ -29,6 +31,7 @@ class Led(SmartModule, LedInterface):
|
||||
"""Return current led status."""
|
||||
return self.data["led_rule"] != "never"
|
||||
|
||||
@allow_update_after
|
||||
async def set_led(self, enable: bool):
|
||||
"""Set led.
|
||||
|
||||
|
@ -9,7 +9,7 @@ import copy
|
||||
from typing import Any
|
||||
|
||||
from ..effects import SmartLightEffect
|
||||
from ..smartmodule import Module, SmartModule
|
||||
from ..smartmodule import Module, SmartModule, allow_update_after
|
||||
|
||||
|
||||
class LightEffect(SmartModule, SmartLightEffect):
|
||||
@ -17,6 +17,7 @@ class LightEffect(SmartModule, SmartLightEffect):
|
||||
|
||||
REQUIRED_COMPONENT = "light_effect"
|
||||
QUERY_GETTER_NAME = "get_dynamic_light_effect_rules"
|
||||
MINIMUM_UPDATE_INTERVAL_SECS = 60
|
||||
AVAILABLE_BULB_EFFECTS = {
|
||||
"L1": "Party",
|
||||
"L2": "Relax",
|
||||
@ -130,6 +131,7 @@ class LightEffect(SmartModule, SmartLightEffect):
|
||||
|
||||
return brightness
|
||||
|
||||
@allow_update_after
|
||||
async def set_brightness(
|
||||
self,
|
||||
brightness: int,
|
||||
@ -156,6 +158,7 @@ class LightEffect(SmartModule, SmartLightEffect):
|
||||
|
||||
return await self.call("edit_dynamic_light_effect_rule", new_effect)
|
||||
|
||||
@allow_update_after
|
||||
async def set_custom_effect(
|
||||
self,
|
||||
effect_dict: dict,
|
||||
|
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from ...interfaces import LightPreset as LightPresetInterface
|
||||
from ...interfaces import LightState
|
||||
from ..smartmodule import SmartModule
|
||||
from ..smartmodule import SmartModule, allow_update_after
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..smartdevice import SmartDevice
|
||||
@ -22,6 +22,7 @@ class LightPreset(SmartModule, LightPresetInterface):
|
||||
|
||||
REQUIRED_COMPONENT = "preset"
|
||||
QUERY_GETTER_NAME = "get_preset_rules"
|
||||
MINIMUM_UPDATE_INTERVAL_SECS = 60
|
||||
|
||||
SYS_INFO_STATE_KEY = "preset_state"
|
||||
|
||||
@ -124,6 +125,7 @@ class LightPreset(SmartModule, LightPresetInterface):
|
||||
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
|
||||
await self._device.modules[SmartModule.Light].set_state(preset)
|
||||
|
||||
@allow_update_after
|
||||
async def save_preset(
|
||||
self,
|
||||
preset_name: str,
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..effects import EFFECT_MAPPING, EFFECT_NAMES, SmartLightEffect
|
||||
from ..smartmodule import Module, SmartModule
|
||||
from ..smartmodule import Module, SmartModule, allow_update_after
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..smartdevice import SmartDevice
|
||||
@ -84,6 +84,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
|
||||
"""
|
||||
return self._effect_list
|
||||
|
||||
@allow_update_after
|
||||
async def set_effect(
|
||||
self,
|
||||
effect: str,
|
||||
@ -126,6 +127,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
|
||||
|
||||
await self.set_custom_effect(effect_dict)
|
||||
|
||||
@allow_update_after
|
||||
async def set_custom_effect(
|
||||
self,
|
||||
effect_dict: dict,
|
||||
|
@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
from ...exceptions import KasaException
|
||||
from ...feature import Feature
|
||||
from ..smartmodule import SmartModule
|
||||
from ..smartmodule import SmartModule, allow_update_after
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..smartdevice import SmartDevice
|
||||
@ -23,6 +23,7 @@ class LightTransition(SmartModule):
|
||||
|
||||
REQUIRED_COMPONENT = "on_off_gradually"
|
||||
QUERY_GETTER_NAME = "get_on_off_gradually_info"
|
||||
MINIMUM_UPDATE_INTERVAL_SECS = 60
|
||||
MAXIMUM_DURATION = 60
|
||||
|
||||
# Key in sysinfo that indicates state can be retrieved from there.
|
||||
@ -136,6 +137,7 @@ class LightTransition(SmartModule):
|
||||
"max_duration": off_max,
|
||||
}
|
||||
|
||||
@allow_update_after
|
||||
async def set_enabled(self, enable: bool):
|
||||
"""Enable gradual on/off."""
|
||||
if not self._supports_on_and_off:
|
||||
@ -168,6 +170,7 @@ class LightTransition(SmartModule):
|
||||
# v3 added max_duration, we default to 60 when it's not available
|
||||
return self._on_state["max_duration"]
|
||||
|
||||
@allow_update_after
|
||||
async def set_turn_on_transition(self, seconds: int):
|
||||
"""Set turn on transition in seconds.
|
||||
|
||||
@ -203,6 +206,7 @@ class LightTransition(SmartModule):
|
||||
# v3 added max_duration, we default to 60 when it's not available
|
||||
return self._off_state["max_duration"]
|
||||
|
||||
@allow_update_after
|
||||
async def set_turn_off_transition(self, seconds: int):
|
||||
"""Set turn on transition in seconds.
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from ..device_type import DeviceType
|
||||
@ -54,6 +55,7 @@ class SmartChildDevice(SmartDevice):
|
||||
req.update(mod_query)
|
||||
if req:
|
||||
self._last_update = await self.protocol.query(req)
|
||||
self._last_update_time = time.time()
|
||||
|
||||
@classmethod
|
||||
async def create(cls, parent: SmartDevice, child_info, child_components):
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, cast
|
||||
@ -18,6 +19,7 @@ from ..module import Module
|
||||
from ..modulemapping import ModuleMapping, ModuleName
|
||||
from ..smartprotocol import SmartProtocol
|
||||
from .modules import (
|
||||
ChildDevice,
|
||||
Cloud,
|
||||
DeviceModule,
|
||||
Firmware,
|
||||
@ -35,6 +37,9 @@ _LOGGER = logging.getLogger(__name__)
|
||||
# same issue, homekit perhaps?
|
||||
NON_HUB_PARENT_ONLY_MODULES = [DeviceModule, Time, Firmware, Cloud]
|
||||
|
||||
# Modules that are called as part of the init procedure on first update
|
||||
FIRST_UPDATE_MODULES = {DeviceModule, ChildDevice, Cloud}
|
||||
|
||||
|
||||
# Device must go last as the other interfaces also inherit Device
|
||||
# and python needs a consistent method resolution order.
|
||||
@ -60,6 +65,7 @@ class SmartDevice(Device):
|
||||
self._parent: SmartDevice | None = None
|
||||
self._children: Mapping[str, SmartDevice] = {}
|
||||
self._last_update = {}
|
||||
self._last_update_time: float | None = None
|
||||
|
||||
async def _initialize_children(self):
|
||||
"""Initialize children for power strips."""
|
||||
@ -152,19 +158,15 @@ class SmartDevice(Device):
|
||||
if self.credentials is None and self.credentials_hash is None:
|
||||
raise AuthenticationError("Tapo plug requires authentication.")
|
||||
|
||||
if self._components_raw is None:
|
||||
first_update = self._last_update_time is None
|
||||
now = time.time()
|
||||
self._last_update_time = now
|
||||
|
||||
if first_update:
|
||||
await self._negotiate()
|
||||
await self._initialize_modules()
|
||||
|
||||
req: dict[str, Any] = {}
|
||||
|
||||
# TODO: this could be optimized by constructing the query only once
|
||||
for module in self._modules.values():
|
||||
req.update(module.query())
|
||||
|
||||
self._last_update = resp = await self.protocol.query(req)
|
||||
|
||||
self._info = self._try_get_response(resp, "get_device_info")
|
||||
resp = await self._modular_update(first_update, now)
|
||||
|
||||
# Call child update which will only update module calls, info is updated
|
||||
# from get_child_device_list. update_children only affects hub devices, other
|
||||
@ -172,18 +174,12 @@ class SmartDevice(Device):
|
||||
if update_children or self.device_type != DeviceType.Hub:
|
||||
for child in self._children.values():
|
||||
await child._update()
|
||||
if child_info := self._try_get_response(resp, "get_child_device_list", {}):
|
||||
if child_info := self._try_get_response(
|
||||
self._last_update, "get_child_device_list", {}
|
||||
):
|
||||
for info in child_info["child_device_list"]:
|
||||
self._children[info["device_id"]]._update_internal_state(info)
|
||||
|
||||
# Call handle update for modules that want to update internal data
|
||||
errors = []
|
||||
for module_name, module in self._modules.items():
|
||||
if not self._handle_module_post_update_hook(module):
|
||||
errors.append(module_name)
|
||||
for error in errors:
|
||||
self._modules.pop(error)
|
||||
|
||||
for child in self._children.values():
|
||||
errors = []
|
||||
for child_module_name, child_module in child._modules.items():
|
||||
@ -197,14 +193,18 @@ class SmartDevice(Device):
|
||||
if not self._features:
|
||||
await self._initialize_features()
|
||||
|
||||
_LOGGER.debug("Got an update: %s", self._last_update)
|
||||
_LOGGER.debug(
|
||||
"Update completed %s: %s",
|
||||
self.host,
|
||||
self._last_update if first_update else resp,
|
||||
)
|
||||
|
||||
def _handle_module_post_update_hook(self, module: SmartModule) -> bool:
|
||||
try:
|
||||
module._post_update_hook()
|
||||
return True
|
||||
except Exception as ex:
|
||||
_LOGGER.error(
|
||||
_LOGGER.warning(
|
||||
"Error processing %s for device %s, module will be unavailable: %s",
|
||||
module.name,
|
||||
self.host,
|
||||
@ -212,6 +212,100 @@ class SmartDevice(Device):
|
||||
)
|
||||
return False
|
||||
|
||||
async def _modular_update(
|
||||
self, first_update: bool, update_time: float
|
||||
) -> dict[str, Any]:
|
||||
"""Update the device with via the module queries."""
|
||||
req: dict[str, Any] = {}
|
||||
# Keep a track of actual module queries so we can track the time for
|
||||
# modules that do not need to be updated frequently
|
||||
module_queries: list[SmartModule] = []
|
||||
mq = {
|
||||
module: query
|
||||
for module in self._modules.values()
|
||||
if (query := module.query())
|
||||
}
|
||||
for module, query in mq.items():
|
||||
if first_update and module.__class__ in FIRST_UPDATE_MODULES:
|
||||
module._last_update_time = update_time
|
||||
continue
|
||||
if (
|
||||
not module.MINIMUM_UPDATE_INTERVAL_SECS
|
||||
or not module._last_update_time
|
||||
or (update_time - module._last_update_time)
|
||||
>= module.MINIMUM_UPDATE_INTERVAL_SECS
|
||||
):
|
||||
module_queries.append(module)
|
||||
req.update(query)
|
||||
|
||||
_LOGGER.debug(
|
||||
"Querying %s for modules: %s",
|
||||
self.host,
|
||||
", ".join(mod.name for mod in module_queries),
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await self.protocol.query(req)
|
||||
except Exception as ex:
|
||||
resp = await self._handle_modular_update_error(
|
||||
ex, first_update, ", ".join(mod.name for mod in module_queries), req
|
||||
)
|
||||
|
||||
info_resp = self._last_update if first_update else resp
|
||||
self._last_update.update(**resp)
|
||||
self._info = self._try_get_response(info_resp, "get_device_info")
|
||||
|
||||
# Call handle update for modules that want to update internal data
|
||||
errors = []
|
||||
for module_name, module in self._modules.items():
|
||||
if not self._handle_module_post_update_hook(module):
|
||||
errors.append(module_name)
|
||||
for error in errors:
|
||||
self._modules.pop(error)
|
||||
|
||||
# Set the last update time for modules that had queries made.
|
||||
for module in module_queries:
|
||||
module._last_update_time = update_time
|
||||
|
||||
return resp
|
||||
|
||||
async def _handle_modular_update_error(
|
||||
self,
|
||||
ex: Exception,
|
||||
first_update: bool,
|
||||
module_names: str,
|
||||
requests: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Handle an error on calling module update.
|
||||
|
||||
Will try to call all modules individually
|
||||
and any errors such as timeouts will be set as a SmartErrorCode.
|
||||
"""
|
||||
msg_part = "on first update" if first_update else "after first update"
|
||||
|
||||
_LOGGER.error(
|
||||
"Error querying %s for modules '%s' %s: %s",
|
||||
self.host,
|
||||
module_names,
|
||||
msg_part,
|
||||
ex,
|
||||
)
|
||||
responses = {}
|
||||
for meth, params in requests.items():
|
||||
try:
|
||||
resp = await self.protocol.query({meth: params})
|
||||
responses[meth] = resp[meth]
|
||||
except Exception as iex:
|
||||
_LOGGER.error(
|
||||
"Error querying %s individually for module query '%s' %s: %s",
|
||||
self.host,
|
||||
meth,
|
||||
msg_part,
|
||||
iex,
|
||||
)
|
||||
responses[meth] = SmartErrorCode.INTERNAL_QUERY_ERROR
|
||||
return responses
|
||||
|
||||
async def _initialize_modules(self):
|
||||
"""Initialize modules based on component negotiation response."""
|
||||
from .smartmodule import SmartModule
|
||||
@ -229,8 +323,6 @@ class SmartDevice(Device):
|
||||
skip_parent_only_modules = True
|
||||
|
||||
for mod in SmartModule.REGISTERED_MODULES.values():
|
||||
_LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT)
|
||||
|
||||
if (
|
||||
skip_parent_only_modules and mod in NON_HUB_PARENT_ONLY_MODULES
|
||||
) or mod.__name__ in child_modules_to_skip:
|
||||
@ -240,7 +332,8 @@ class SmartDevice(Device):
|
||||
or self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None
|
||||
):
|
||||
_LOGGER.debug(
|
||||
"Found required %s, adding %s to modules.",
|
||||
"Device %s, found required %s, adding %s to modules.",
|
||||
self.host,
|
||||
mod.REQUIRED_COMPONENT,
|
||||
mod.__name__,
|
||||
)
|
||||
|
@ -3,7 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from ..exceptions import DeviceError, KasaException, SmartErrorCode
|
||||
from ..module import Module
|
||||
@ -13,6 +16,27 @@ if TYPE_CHECKING:
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_T = TypeVar("_T", bound="SmartModule")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def allow_update_after(
|
||||
func: Callable[Concatenate[_T, _P], Awaitable[None]],
|
||||
) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, None]]:
|
||||
"""Define a wrapper to set _last_update_time to None.
|
||||
|
||||
This will ensure that a module is updated in the next update cycle after
|
||||
a value has been changed.
|
||||
"""
|
||||
|
||||
async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> None:
|
||||
try:
|
||||
await func(self, *args, **kwargs)
|
||||
finally:
|
||||
self._last_update_time = None
|
||||
|
||||
return _async_wrap
|
||||
|
||||
|
||||
class SmartModule(Module):
|
||||
"""Base class for SMART modules."""
|
||||
@ -27,9 +51,12 @@ class SmartModule(Module):
|
||||
|
||||
REGISTERED_MODULES: dict[str, type[SmartModule]] = {}
|
||||
|
||||
MINIMUM_UPDATE_INTERVAL_SECS = 0
|
||||
|
||||
def __init__(self, device: SmartDevice, module: str):
|
||||
self._device: SmartDevice
|
||||
super().__init__(device, module)
|
||||
self._last_update_time: float | None = None
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
name = getattr(cls, "NAME", cls.__name__)
|
||||
|
@ -73,18 +73,32 @@ class SmartProtocol(BaseProtocol):
|
||||
return await self._execute_query(
|
||||
request, retry_count=retry, iterate_list_pages=True
|
||||
)
|
||||
except _ConnectionError as sdex:
|
||||
except _ConnectionError as ex:
|
||||
if retry == 0:
|
||||
_LOGGER.debug(
|
||||
"Device %s got a connection error, will retry %s times: %s",
|
||||
self._host,
|
||||
retry_count,
|
||||
ex,
|
||||
)
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise sdex
|
||||
raise ex
|
||||
continue
|
||||
except AuthenticationError as auex:
|
||||
except AuthenticationError as ex:
|
||||
await self._transport.reset()
|
||||
_LOGGER.debug(
|
||||
"Unable to authenticate with %s, not retrying", self._host
|
||||
"Unable to authenticate with %s, not retrying: %s", self._host, ex
|
||||
)
|
||||
raise auex
|
||||
raise ex
|
||||
except _RetryableError as ex:
|
||||
if retry == 0:
|
||||
_LOGGER.debug(
|
||||
"Device %s got a retryable error, will retry %s times: %s",
|
||||
self._host,
|
||||
retry_count,
|
||||
ex,
|
||||
)
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
@ -92,6 +106,13 @@ class SmartProtocol(BaseProtocol):
|
||||
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
|
||||
continue
|
||||
except TimeoutError as ex:
|
||||
if retry == 0:
|
||||
_LOGGER.debug(
|
||||
"Device %s got a timeout error, will retry %s times: %s",
|
||||
self._host,
|
||||
retry_count,
|
||||
ex,
|
||||
)
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
@ -130,20 +151,21 @@ class SmartProtocol(BaseProtocol):
|
||||
self._handle_response_error_code(resp, method, raise_on_error=False)
|
||||
multi_result[method] = resp["result"]
|
||||
return multi_result
|
||||
for i in range(0, end, step):
|
||||
|
||||
for batch_num, i in enumerate(range(0, end, step)):
|
||||
requests_step = multi_requests[i : i + step]
|
||||
|
||||
smart_params = {"requests": requests_step}
|
||||
smart_request = self.get_smart_request(smart_method, smart_params)
|
||||
batch_name = f"multi-request-batch-{batch_num+1}-of-{int(end/step)+1}"
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s multi-request-batch-%s >> %s",
|
||||
"%s %s >> %s",
|
||||
self._host,
|
||||
i + 1,
|
||||
batch_name,
|
||||
pf(smart_request),
|
||||
)
|
||||
response_step = await self._transport.send(smart_request)
|
||||
batch_name = f"multi-request-batch-{i+1}"
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s %s << %s",
|
||||
@ -271,7 +293,9 @@ class SmartProtocol(BaseProtocol):
|
||||
try:
|
||||
error_code = SmartErrorCode.from_int(error_code_raw)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Received unknown error code: %s", error_code_raw)
|
||||
_LOGGER.warning(
|
||||
"Device %s received unknown error code: %s", self._host, error_code_raw
|
||||
)
|
||||
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
|
||||
|
||||
if error_code is SmartErrorCode.SUCCESS:
|
||||
|
@ -3,10 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from kasa import Device, KasaException, Module
|
||||
@ -54,6 +56,8 @@ async def test_initial_update(dev: SmartDevice, mocker: MockerFixture):
|
||||
dev._modules = {}
|
||||
dev._features = {}
|
||||
dev._children = {}
|
||||
dev._last_update = {}
|
||||
dev._last_update_time = None
|
||||
|
||||
negotiate = mocker.spy(dev, "_negotiate")
|
||||
initialize_modules = mocker.spy(dev, "_initialize_modules")
|
||||
@ -109,6 +113,9 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
|
||||
"""Test that the regular update uses queries from all supported modules."""
|
||||
# We need to have some modules initialized by now
|
||||
assert dev._modules
|
||||
# Reset last update so all modules will query
|
||||
for mod in dev._modules.values():
|
||||
mod._last_update_time = None
|
||||
|
||||
device_queries: dict[SmartDevice, dict[str, Any]] = {}
|
||||
for mod in dev._modules.values():
|
||||
@ -139,7 +146,7 @@ async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture):
|
||||
assert dev._modules
|
||||
|
||||
critical_modules = {Module.DeviceModule, Module.ChildDevice}
|
||||
not_disabling_modules = {Module.Firmware, Module.Cloud}
|
||||
not_disabling_modules = {Module.Cloud}
|
||||
|
||||
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
|
||||
|
||||
@ -204,6 +211,123 @@ async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture):
|
||||
), f"{modname} present {mod_present} when no_disable {no_disable}"
|
||||
|
||||
|
||||
@device_smart
|
||||
async def test_update_module_update_delays(
|
||||
dev: SmartDevice,
|
||||
mocker: MockerFixture,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
):
|
||||
"""Test that modules that disabled / removed on query failures."""
|
||||
# We need to have some modules initialized by now
|
||||
assert dev._modules
|
||||
|
||||
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
|
||||
await new_dev.update()
|
||||
first_update_time = time.time()
|
||||
assert new_dev._last_update_time == first_update_time
|
||||
for module in new_dev.modules.values():
|
||||
if module.query():
|
||||
assert module._last_update_time == first_update_time
|
||||
|
||||
seconds = 0
|
||||
tick = 30
|
||||
while seconds <= 180:
|
||||
seconds += tick
|
||||
freezer.tick(tick)
|
||||
|
||||
now = time.time()
|
||||
await new_dev.update()
|
||||
for module in new_dev.modules.values():
|
||||
mod_delay = module.MINIMUM_UPDATE_INTERVAL_SECS
|
||||
if module.query():
|
||||
expected_update_time = (
|
||||
now if mod_delay == 0 else now - (seconds % mod_delay)
|
||||
)
|
||||
|
||||
assert (
|
||||
module._last_update_time == expected_update_time
|
||||
), f"Expected update time {expected_update_time} after {seconds} seconds for {module.name} with delay {mod_delay} got {module._last_update_time}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("first_update"),
|
||||
[
|
||||
pytest.param(True, id="First update true"),
|
||||
pytest.param(False, id="First update false"),
|
||||
],
|
||||
)
|
||||
@device_smart
|
||||
async def test_update_module_query_errors(
|
||||
dev: SmartDevice,
|
||||
mocker: MockerFixture,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
first_update,
|
||||
):
|
||||
"""Test that modules that disabled / removed on query failures."""
|
||||
# We need to have some modules initialized by now
|
||||
assert dev._modules
|
||||
|
||||
first_update_queries = {"get_device_info", "get_connect_cloud_state"}
|
||||
|
||||
critical_modules = {Module.DeviceModule, Module.ChildDevice}
|
||||
not_disabling_modules = {Module.Cloud}
|
||||
|
||||
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
|
||||
if not first_update:
|
||||
await new_dev.update()
|
||||
freezer.tick(
|
||||
max(module.MINIMUM_UPDATE_INTERVAL_SECS for module in dev._modules.values())
|
||||
)
|
||||
|
||||
module_queries = {
|
||||
modname: q
|
||||
for modname, module in dev._modules.items()
|
||||
if (q := module.query()) and modname not in critical_modules
|
||||
}
|
||||
|
||||
async def _query(request, *args, **kwargs):
|
||||
if (
|
||||
"component_nego" in request
|
||||
or "get_child_device_component_list" in request
|
||||
or "control_child" in request
|
||||
):
|
||||
return await dev.protocol._query(request, *args, **kwargs)
|
||||
if len(request) == 1 and "get_device_info" in request:
|
||||
return await dev.protocol._query(request, *args, **kwargs)
|
||||
|
||||
raise TimeoutError("Dummy timeout")
|
||||
|
||||
from kasa.smartprotocol import _ChildProtocolWrapper
|
||||
|
||||
child_protocols = {
|
||||
cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol
|
||||
for child in dev.children
|
||||
}
|
||||
|
||||
async def _child_query(self, request, *args, **kwargs):
|
||||
return await child_protocols[self._device_id]._query(request, *args, **kwargs)
|
||||
|
||||
mocker.patch.object(new_dev.protocol, "query", side_effect=_query)
|
||||
# children not created yet so cannot patch.object
|
||||
mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query)
|
||||
|
||||
await new_dev.update()
|
||||
msg = f"Error querying {new_dev.host} for modules"
|
||||
assert msg in caplog.text
|
||||
for modname in module_queries:
|
||||
no_disable = modname in not_disabling_modules
|
||||
mod_present = modname in new_dev._modules
|
||||
assert (
|
||||
mod_present is no_disable
|
||||
), f"{modname} present {mod_present} when no_disable {no_disable}"
|
||||
for mod_query in module_queries[modname]:
|
||||
if not first_update or mod_query not in first_update_queries:
|
||||
msg = f"Error querying {new_dev.host} individually for module query '{mod_query}"
|
||||
assert msg in caplog.text
|
||||
|
||||
|
||||
async def test_get_modules():
|
||||
"""Test getting modules for child and parent modules."""
|
||||
dummy_device = await get_device_for_fixture_protocol(
|
||||
|
@ -66,7 +66,7 @@ async def test_smart_device_unknown_errors(
|
||||
assert res is SmartErrorCode.INTERNAL_UNKNOWN_ERROR
|
||||
|
||||
send_mock.assert_called_once()
|
||||
assert f"Received unknown error code: {error_code}" in caplog.text
|
||||
assert f"received unknown error code: {error_code}" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
|
||||
|
Loading…
Reference in New Issue
Block a user