diff --git a/kasa/aestransport.py b/kasa/aestransport.py index cc373b19..abe282c0 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -144,7 +144,9 @@ class AesTransport(BaseTransport): try: error_code = SmartErrorCode.from_int(error_code_raw) except ValueError: - _LOGGER.warning("Received unknown error code: %s", error_code_raw) + _LOGGER.warning( + "Device %s received unknown error code: %s", self._host, error_code_raw + ) error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR if error_code is SmartErrorCode.SUCCESS: return @@ -214,10 +216,18 @@ class AesTransport(BaseTransport): """Login to the device.""" try: await self.try_login(self._login_params) + _LOGGER.debug( + "%s: logged in with provided credentials", + self._host, + ) except AuthenticationError as aex: try: if aex.error_code is not SmartErrorCode.LOGIN_ERROR: raise aex + _LOGGER.debug( + "%s: trying login with default TAPO credentials", + self._host, + ) if self._default_credentials is None: self._default_credentials = get_default_credentials( DEFAULT_CREDENTIALS["TAPO"] @@ -225,7 +235,7 @@ class AesTransport(BaseTransport): await self.perform_handshake() await self.try_login(self._get_login_params(self._default_credentials)) _LOGGER.debug( - "%s: logged in with default credentials", + "%s: logged in with default TAPO credentials", self._host, ) except (AuthenticationError, _ConnectionError, TimeoutError): diff --git a/kasa/exceptions.py b/kasa/exceptions.py index f5c26ff0..3f7f301b 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -128,6 +128,8 @@ class SmartErrorCode(IntEnum): # Library internal for unknown error codes INTERNAL_UNKNOWN_ERROR = -100_000 + # Library internal for query errors + INTERNAL_QUERY_ERROR = -100_001 SMART_RETRYABLE_ERRORS = [ diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 02e69782..1c8c46e2 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -75,13 +75,21 @@ class HttpClient: now = time.time() gap = now - self._last_request_time if gap < self._wait_between_requests: - await asyncio.sleep(self._wait_between_requests - gap) + sleep = self._wait_between_requests - gap + _LOGGER.debug( + "Device %s waiting %s seconds to send request", + self._config.host, + sleep, + ) + await asyncio.sleep(sleep) _LOGGER.debug("Posting to %s", url) response_data = None self._last_url = url self.client.cookie_jar.clear() return_json = bool(json) + client_timeout = aiohttp.ClientTimeout(total=self._config.timeout) + # If json is not a dict send as data. # This allows the json parameter to be used to pass other # types of data such as async_generator and still have json @@ -95,9 +103,10 @@ class HttpClient: params=params, data=data, json=json, - timeout=self._config.timeout, + timeout=client_timeout, cookies=cookies_dict, headers=headers, + ssl=False, ) async with resp: if resp.status == 200: @@ -106,9 +115,15 @@ class HttpClient: response_data = json_loads(response_data.decode()) except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: - if isinstance(ex, aiohttp.ClientOSError): + if not self._wait_between_requests: + _LOGGER.debug( + "Device %s received an os error, " + "enabling sequential request delay: %s", + self._config.host, + ex, + ) self._wait_between_requests = self.WAIT_BETWEEN_REQUESTS_ON_OSERROR - self._last_request_time = time.time() + self._last_request_time = time.time() raise _ConnectionError( f"Device connection error: {self._config.host}: {ex}", ex ) from ex diff --git a/kasa/smart/modules/cloud.py b/kasa/smart/modules/cloud.py index 8346af57..e7513a56 100644 --- a/kasa/smart/modules/cloud.py +++ b/kasa/smart/modules/cloud.py @@ -16,6 +16,7 @@ class Cloud(SmartModule): QUERY_GETTER_NAME = "get_connect_cloud_state" REQUIRED_COMPONENT = "cloud_connect" + MINIMUM_UPDATE_INTERVAL_SECS = 60 def _post_update_hook(self): """Perform actions after a device update. diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 10a6b824..dc0483e7 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -14,7 +14,7 @@ from async_timeout import timeout as asyncio_timeout from pydantic.v1 import BaseModel, Field, validator from ...feature import Feature -from ..smartmodule import SmartModule +from ..smartmodule import SmartModule, allow_update_after if TYPE_CHECKING: from ..smartdevice import SmartDevice @@ -66,6 +66,7 @@ class Firmware(SmartModule): """Implementation of firmware module.""" REQUIRED_COMPONENT = "firmware" + MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60 * 24 def __init__(self, device: SmartDevice, module: str): super().__init__(device, module) @@ -122,13 +123,6 @@ class Firmware(SmartModule): req["get_auto_update_info"] = None return req - def _post_update_hook(self): - """Perform actions after a device update. - - Overrides the default behaviour to disable a module if the query returns - an error because some of the module still functions. - """ - @property def current_firmware(self) -> str: """Return the current firmware version.""" @@ -162,6 +156,7 @@ class Firmware(SmartModule): state = resp["get_fw_download_state"] return DownloadState(**state) + @allow_update_after async def update( self, progress_cb: Callable[[DownloadState], Coroutine] | None = None ): @@ -219,6 +214,7 @@ class Firmware(SmartModule): and self.data["get_auto_update_info"]["enable"] ) + @allow_update_after async def set_auto_update_enabled(self, enabled: bool): """Change autoupdate setting.""" data = {**self.data["get_auto_update_info"], "enable": enabled} diff --git a/kasa/smart/modules/led.py b/kasa/smart/modules/led.py index 2d0a354c..bbfe3579 100644 --- a/kasa/smart/modules/led.py +++ b/kasa/smart/modules/led.py @@ -3,7 +3,7 @@ from __future__ import annotations from ...interfaces.led import Led as LedInterface -from ..smartmodule import SmartModule +from ..smartmodule import SmartModule, allow_update_after class Led(SmartModule, LedInterface): @@ -11,6 +11,8 @@ class Led(SmartModule, LedInterface): REQUIRED_COMPONENT = "led" QUERY_GETTER_NAME = "get_led_info" + # Led queries can cause device to crash on P100 + MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60 def query(self) -> dict: """Query to execute during the update cycle.""" @@ -29,6 +31,7 @@ class Led(SmartModule, LedInterface): """Return current led status.""" return self.data["led_rule"] != "never" + @allow_update_after async def set_led(self, enable: bool): """Set led. diff --git a/kasa/smart/modules/lighteffect.py b/kasa/smart/modules/lighteffect.py index 07f6aece..5f589d6d 100644 --- a/kasa/smart/modules/lighteffect.py +++ b/kasa/smart/modules/lighteffect.py @@ -9,7 +9,7 @@ import copy from typing import Any from ..effects import SmartLightEffect -from ..smartmodule import Module, SmartModule +from ..smartmodule import Module, SmartModule, allow_update_after class LightEffect(SmartModule, SmartLightEffect): @@ -17,6 +17,7 @@ class LightEffect(SmartModule, SmartLightEffect): REQUIRED_COMPONENT = "light_effect" QUERY_GETTER_NAME = "get_dynamic_light_effect_rules" + MINIMUM_UPDATE_INTERVAL_SECS = 60 AVAILABLE_BULB_EFFECTS = { "L1": "Party", "L2": "Relax", @@ -130,6 +131,7 @@ class LightEffect(SmartModule, SmartLightEffect): return brightness + @allow_update_after async def set_brightness( self, brightness: int, @@ -156,6 +158,7 @@ class LightEffect(SmartModule, SmartLightEffect): return await self.call("edit_dynamic_light_effect_rule", new_effect) + @allow_update_after async def set_custom_effect( self, effect_dict: dict, diff --git a/kasa/smart/modules/lightpreset.py b/kasa/smart/modules/lightpreset.py index 7635a5f8..b9692438 100644 --- a/kasa/smart/modules/lightpreset.py +++ b/kasa/smart/modules/lightpreset.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING from ...interfaces import LightPreset as LightPresetInterface from ...interfaces import LightState -from ..smartmodule import SmartModule +from ..smartmodule import SmartModule, allow_update_after if TYPE_CHECKING: from ..smartdevice import SmartDevice @@ -19,6 +19,7 @@ class LightPreset(SmartModule, LightPresetInterface): REQUIRED_COMPONENT = "preset" QUERY_GETTER_NAME = "get_preset_rules" + MINIMUM_UPDATE_INTERVAL_SECS = 60 SYS_INFO_STATE_KEY = "preset_state" @@ -113,6 +114,7 @@ class LightPreset(SmartModule, LightPresetInterface): raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}") await self._device.modules[SmartModule.Light].set_state(preset) + @allow_update_after async def save_preset( self, preset_name: str, diff --git a/kasa/smart/modules/lightstripeffect.py b/kasa/smart/modules/lightstripeffect.py index a80c20f3..f7562068 100644 --- a/kasa/smart/modules/lightstripeffect.py +++ b/kasa/smart/modules/lightstripeffect.py @@ -5,7 +5,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from ..effects import EFFECT_MAPPING, EFFECT_NAMES, SmartLightEffect -from ..smartmodule import Module, SmartModule +from ..smartmodule import Module, SmartModule, allow_update_after if TYPE_CHECKING: from ..smartdevice import SmartDevice @@ -84,6 +84,7 @@ class LightStripEffect(SmartModule, SmartLightEffect): """ return self._effect_list + @allow_update_after async def set_effect( self, effect: str, @@ -126,6 +127,7 @@ class LightStripEffect(SmartModule, SmartLightEffect): await self.set_custom_effect(effect_dict) + @allow_update_after async def set_custom_effect( self, effect_dict: dict, diff --git a/kasa/smart/modules/lighttransition.py b/kasa/smart/modules/lighttransition.py index ca0eca86..3a5897d1 100644 --- a/kasa/smart/modules/lighttransition.py +++ b/kasa/smart/modules/lighttransition.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, TypedDict from ...exceptions import KasaException from ...feature import Feature -from ..smartmodule import SmartModule +from ..smartmodule import SmartModule, allow_update_after if TYPE_CHECKING: from ..smartdevice import SmartDevice @@ -23,6 +23,7 @@ class LightTransition(SmartModule): REQUIRED_COMPONENT = "on_off_gradually" QUERY_GETTER_NAME = "get_on_off_gradually_info" + MINIMUM_UPDATE_INTERVAL_SECS = 60 MAXIMUM_DURATION = 60 # Key in sysinfo that indicates state can be retrieved from there. @@ -136,6 +137,7 @@ class LightTransition(SmartModule): "max_duration": off_max, } + @allow_update_after async def set_enabled(self, enable: bool): """Enable gradual on/off.""" if not self._supports_on_and_off: @@ -168,6 +170,7 @@ class LightTransition(SmartModule): # v3 added max_duration, we default to 60 when it's not available return self._on_state["max_duration"] + @allow_update_after async def set_turn_on_transition(self, seconds: int): """Set turn on transition in seconds. @@ -203,6 +206,7 @@ class LightTransition(SmartModule): # v3 added max_duration, we default to 60 when it's not available return self._off_state["max_duration"] + @allow_update_after async def set_turn_off_transition(self, seconds: int): """Set turn on transition in seconds. diff --git a/kasa/smart/smartchilddevice.py b/kasa/smart/smartchilddevice.py index c6596b96..3dfbd146 100644 --- a/kasa/smart/smartchilddevice.py +++ b/kasa/smart/smartchilddevice.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import time from typing import Any from ..device_type import DeviceType @@ -46,6 +47,7 @@ class SmartChildDevice(SmartDevice): req.update(mod_query) if req: self._last_update = await self.protocol.query(req) + self._last_update_time = time.time() @classmethod async def create(cls, parent: SmartDevice, child_info, child_components): diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index fcbc8a15..731789a0 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -4,6 +4,7 @@ from __future__ import annotations import base64 import logging +import time from collections.abc import Mapping, Sequence from datetime import datetime, timedelta, timezone from typing import Any, cast @@ -18,6 +19,7 @@ from ..module import Module from ..modulemapping import ModuleMapping, ModuleName from ..smartprotocol import SmartProtocol from .modules import ( + ChildDevice, Cloud, DeviceModule, Firmware, @@ -35,6 +37,9 @@ _LOGGER = logging.getLogger(__name__) # same issue, homekit perhaps? NON_HUB_PARENT_ONLY_MODULES = [DeviceModule, Time, Firmware, Cloud] +# Modules that are called as part of the init procedure on first update +FIRST_UPDATE_MODULES = {DeviceModule, ChildDevice, Cloud} + # Device must go last as the other interfaces also inherit Device # and python needs a consistent method resolution order. @@ -60,6 +65,7 @@ class SmartDevice(Device): self._parent: SmartDevice | None = None self._children: Mapping[str, SmartDevice] = {} self._last_update = {} + self._last_update_time: float | None = None async def _initialize_children(self): """Initialize children for power strips.""" @@ -152,19 +158,15 @@ class SmartDevice(Device): if self.credentials is None and self.credentials_hash is None: raise AuthenticationError("Tapo plug requires authentication.") - if self._components_raw is None: + first_update = self._last_update_time is None + now = time.time() + self._last_update_time = now + + if first_update: await self._negotiate() await self._initialize_modules() - req: dict[str, Any] = {} - - # TODO: this could be optimized by constructing the query only once - for module in self._modules.values(): - req.update(module.query()) - - self._last_update = resp = await self.protocol.query(req) - - self._info = self._try_get_response(resp, "get_device_info") + resp = await self._modular_update(first_update, now) # Call child update which will only update module calls, info is updated # from get_child_device_list. update_children only affects hub devices, other @@ -172,18 +174,12 @@ class SmartDevice(Device): if update_children or self.device_type != DeviceType.Hub: for child in self._children.values(): await child.update() - if child_info := self._try_get_response(resp, "get_child_device_list", {}): + if child_info := self._try_get_response( + self._last_update, "get_child_device_list", {} + ): for info in child_info["child_device_list"]: self._children[info["device_id"]]._update_internal_state(info) - # Call handle update for modules that want to update internal data - errors = [] - for module_name, module in self._modules.items(): - if not self._handle_module_post_update_hook(module): - errors.append(module_name) - for error in errors: - self._modules.pop(error) - for child in self._children.values(): errors = [] for child_module_name, child_module in child._modules.items(): @@ -197,14 +193,18 @@ class SmartDevice(Device): if not self._features: await self._initialize_features() - _LOGGER.debug("Got an update: %s", self._last_update) + _LOGGER.debug( + "Update completed %s: %s", + self.host, + self._last_update if first_update else resp, + ) def _handle_module_post_update_hook(self, module: SmartModule) -> bool: try: module._post_update_hook() return True except Exception as ex: - _LOGGER.error( + _LOGGER.warning( "Error processing %s for device %s, module will be unavailable: %s", module.name, self.host, @@ -212,6 +212,100 @@ class SmartDevice(Device): ) return False + async def _modular_update( + self, first_update: bool, update_time: float + ) -> dict[str, Any]: + """Update the device with via the module queries.""" + req: dict[str, Any] = {} + # Keep a track of actual module queries so we can track the time for + # modules that do not need to be updated frequently + module_queries: list[SmartModule] = [] + mq = { + module: query + for module in self._modules.values() + if (query := module.query()) + } + for module, query in mq.items(): + if first_update and module.__class__ in FIRST_UPDATE_MODULES: + module._last_update_time = update_time + continue + if ( + not module.MINIMUM_UPDATE_INTERVAL_SECS + or not module._last_update_time + or (update_time - module._last_update_time) + >= module.MINIMUM_UPDATE_INTERVAL_SECS + ): + module_queries.append(module) + req.update(query) + + _LOGGER.debug( + "Querying %s for modules: %s", + self.host, + ", ".join(mod.name for mod in module_queries), + ) + + try: + resp = await self.protocol.query(req) + except Exception as ex: + resp = await self._handle_modular_update_error( + ex, first_update, ", ".join(mod.name for mod in module_queries), req + ) + + info_resp = self._last_update if first_update else resp + self._last_update.update(**resp) + self._info = self._try_get_response(info_resp, "get_device_info") + + # Call handle update for modules that want to update internal data + errors = [] + for module_name, module in self._modules.items(): + if not self._handle_module_post_update_hook(module): + errors.append(module_name) + for error in errors: + self._modules.pop(error) + + # Set the last update time for modules that had queries made. + for module in module_queries: + module._last_update_time = update_time + + return resp + + async def _handle_modular_update_error( + self, + ex: Exception, + first_update: bool, + module_names: str, + requests: dict[str, Any], + ) -> dict[str, Any]: + """Handle an error on calling module update. + + Will try to call all modules individually + and any errors such as timeouts will be set as a SmartErrorCode. + """ + msg_part = "on first update" if first_update else "after first update" + + _LOGGER.error( + "Error querying %s for modules '%s' %s: %s", + self.host, + module_names, + msg_part, + ex, + ) + responses = {} + for meth, params in requests.items(): + try: + resp = await self.protocol.query({meth: params}) + responses[meth] = resp[meth] + except Exception as iex: + _LOGGER.error( + "Error querying %s individually for module query '%s' %s: %s", + self.host, + meth, + msg_part, + iex, + ) + responses[meth] = SmartErrorCode.INTERNAL_QUERY_ERROR + return responses + async def _initialize_modules(self): """Initialize modules based on component negotiation response.""" from .smartmodule import SmartModule @@ -229,8 +323,6 @@ class SmartDevice(Device): skip_parent_only_modules = True for mod in SmartModule.REGISTERED_MODULES.values(): - _LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT) - if ( skip_parent_only_modules and mod in NON_HUB_PARENT_ONLY_MODULES ) or mod.__name__ in child_modules_to_skip: @@ -240,7 +332,8 @@ class SmartDevice(Device): or self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None ): _LOGGER.debug( - "Found required %s, adding %s to modules.", + "Device %s, found required %s, adding %s to modules.", + self.host, mod.REQUIRED_COMPONENT, mod.__name__, ) diff --git a/kasa/smart/smartmodule.py b/kasa/smart/smartmodule.py index fb946a8b..f5f2c212 100644 --- a/kasa/smart/smartmodule.py +++ b/kasa/smart/smartmodule.py @@ -3,7 +3,10 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from collections.abc import Awaitable, Callable, Coroutine +from typing import TYPE_CHECKING, Any + +from typing_extensions import Concatenate, ParamSpec, TypeVar from ..exceptions import DeviceError, KasaException, SmartErrorCode from ..module import Module @@ -13,6 +16,27 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) +_T = TypeVar("_T", bound="SmartModule") +_P = ParamSpec("_P") + + +def allow_update_after( + func: Callable[Concatenate[_T, _P], Awaitable[None]], +) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, None]]: + """Define a wrapper to set _last_update_time to None. + + This will ensure that a module is updated in the next update cycle after + a value has been changed. + """ + + async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> None: + try: + await func(self, *args, **kwargs) + finally: + self._last_update_time = None + + return _async_wrap + class SmartModule(Module): """Base class for SMART modules.""" @@ -27,9 +51,12 @@ class SmartModule(Module): REGISTERED_MODULES: dict[str, type[SmartModule]] = {} + MINIMUM_UPDATE_INTERVAL_SECS = 0 + def __init__(self, device: SmartDevice, module: str): self._device: SmartDevice super().__init__(device, module) + self._last_update_time: float | None = None def __init_subclass__(cls, **kwargs): name = getattr(cls, "NAME", cls.__name__) diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 3085714c..0c95325a 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -73,18 +73,32 @@ class SmartProtocol(BaseProtocol): return await self._execute_query( request, retry_count=retry, iterate_list_pages=True ) - except _ConnectionError as sdex: + except _ConnectionError as ex: + if retry == 0: + _LOGGER.debug( + "Device %s got a connection error, will retry %s times: %s", + self._host, + retry_count, + ex, + ) if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) - raise sdex + raise ex continue - except AuthenticationError as auex: + except AuthenticationError as ex: await self._transport.reset() _LOGGER.debug( - "Unable to authenticate with %s, not retrying", self._host + "Unable to authenticate with %s, not retrying: %s", self._host, ex ) - raise auex + raise ex except _RetryableError as ex: + if retry == 0: + _LOGGER.debug( + "Device %s got a retryable error, will retry %s times: %s", + self._host, + retry_count, + ex, + ) await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) @@ -92,6 +106,13 @@ class SmartProtocol(BaseProtocol): await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) continue except TimeoutError as ex: + if retry == 0: + _LOGGER.debug( + "Device %s got a timeout error, will retry %s times: %s", + self._host, + retry_count, + ex, + ) await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) @@ -130,20 +151,21 @@ class SmartProtocol(BaseProtocol): self._handle_response_error_code(resp, method, raise_on_error=False) multi_result[method] = resp["result"] return multi_result - for i in range(0, end, step): + + for batch_num, i in enumerate(range(0, end, step)): requests_step = multi_requests[i : i + step] smart_params = {"requests": requests_step} smart_request = self.get_smart_request(smart_method, smart_params) + batch_name = f"multi-request-batch-{batch_num+1}-of-{int(end/step)+1}" if debug_enabled: _LOGGER.debug( - "%s multi-request-batch-%s >> %s", + "%s %s >> %s", self._host, - i + 1, + batch_name, pf(smart_request), ) response_step = await self._transport.send(smart_request) - batch_name = f"multi-request-batch-{i+1}" if debug_enabled: _LOGGER.debug( "%s %s << %s", @@ -271,7 +293,9 @@ class SmartProtocol(BaseProtocol): try: error_code = SmartErrorCode.from_int(error_code_raw) except ValueError: - _LOGGER.warning("Received unknown error code: %s", error_code_raw) + _LOGGER.warning( + "Device %s received unknown error code: %s", self._host, error_code_raw + ) error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR if error_code is SmartErrorCode.SUCCESS: diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 44fabc71..99e2ddb9 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -3,10 +3,12 @@ from __future__ import annotations import logging +import time from typing import Any, cast from unittest.mock import patch import pytest +from freezegun.api import FrozenDateTimeFactory from pytest_mock import MockerFixture from kasa import Device, KasaException, Module @@ -54,6 +56,8 @@ async def test_initial_update(dev: SmartDevice, mocker: MockerFixture): dev._modules = {} dev._features = {} dev._children = {} + dev._last_update = {} + dev._last_update_time = None negotiate = mocker.spy(dev, "_negotiate") initialize_modules = mocker.spy(dev, "_initialize_modules") @@ -109,6 +113,9 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): """Test that the regular update uses queries from all supported modules.""" # We need to have some modules initialized by now assert dev._modules + # Reset last update so all modules will query + for mod in dev._modules.values(): + mod._last_update_time = None device_queries: dict[SmartDevice, dict[str, Any]] = {} for mod in dev._modules.values(): @@ -139,7 +146,7 @@ async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture): assert dev._modules critical_modules = {Module.DeviceModule, Module.ChildDevice} - not_disabling_modules = {Module.Firmware, Module.Cloud} + not_disabling_modules = {Module.Cloud} new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) @@ -204,6 +211,123 @@ async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture): ), f"{modname} present {mod_present} when no_disable {no_disable}" +@device_smart +async def test_update_module_update_delays( + dev: SmartDevice, + mocker: MockerFixture, + caplog: pytest.LogCaptureFixture, + freezer: FrozenDateTimeFactory, +): + """Test that modules that disabled / removed on query failures.""" + # We need to have some modules initialized by now + assert dev._modules + + new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) + await new_dev.update() + first_update_time = time.time() + assert new_dev._last_update_time == first_update_time + for module in new_dev.modules.values(): + if module.query(): + assert module._last_update_time == first_update_time + + seconds = 0 + tick = 30 + while seconds <= 180: + seconds += tick + freezer.tick(tick) + + now = time.time() + await new_dev.update() + for module in new_dev.modules.values(): + mod_delay = module.MINIMUM_UPDATE_INTERVAL_SECS + if module.query(): + expected_update_time = ( + now if mod_delay == 0 else now - (seconds % mod_delay) + ) + + assert ( + module._last_update_time == expected_update_time + ), f"Expected update time {expected_update_time} after {seconds} seconds for {module.name} with delay {mod_delay} got {module._last_update_time}" + + +@pytest.mark.parametrize( + ("first_update"), + [ + pytest.param(True, id="First update true"), + pytest.param(False, id="First update false"), + ], +) +@device_smart +async def test_update_module_query_errors( + dev: SmartDevice, + mocker: MockerFixture, + caplog: pytest.LogCaptureFixture, + freezer: FrozenDateTimeFactory, + first_update, +): + """Test that modules that disabled / removed on query failures.""" + # We need to have some modules initialized by now + assert dev._modules + + first_update_queries = {"get_device_info", "get_connect_cloud_state"} + + critical_modules = {Module.DeviceModule, Module.ChildDevice} + not_disabling_modules = {Module.Cloud} + + new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) + if not first_update: + await new_dev.update() + freezer.tick( + max(module.MINIMUM_UPDATE_INTERVAL_SECS for module in dev._modules.values()) + ) + + module_queries = { + modname: q + for modname, module in dev._modules.items() + if (q := module.query()) and modname not in critical_modules + } + + async def _query(request, *args, **kwargs): + if ( + "component_nego" in request + or "get_child_device_component_list" in request + or "control_child" in request + ): + return await dev.protocol._query(request, *args, **kwargs) + if len(request) == 1 and "get_device_info" in request: + return await dev.protocol._query(request, *args, **kwargs) + + raise TimeoutError("Dummy timeout") + + from kasa.smartprotocol import _ChildProtocolWrapper + + child_protocols = { + cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol + for child in dev.children + } + + async def _child_query(self, request, *args, **kwargs): + return await child_protocols[self._device_id]._query(request, *args, **kwargs) + + mocker.patch.object(new_dev.protocol, "query", side_effect=_query) + # children not created yet so cannot patch.object + mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query) + + await new_dev.update() + msg = f"Error querying {new_dev.host} for modules" + assert msg in caplog.text + for modname in module_queries: + no_disable = modname in not_disabling_modules + mod_present = modname in new_dev._modules + assert ( + mod_present is no_disable + ), f"{modname} present {mod_present} when no_disable {no_disable}" + for mod_query in module_queries[modname]: + if not first_update or mod_query not in first_update_queries: + msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" + assert msg in caplog.text + + async def test_get_modules(): """Test getting modules for child and parent modules.""" dummy_device = await get_device_for_fixture_protocol( diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index 71125ca8..204d0c7f 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -66,7 +66,7 @@ async def test_smart_device_unknown_errors( assert res is SmartErrorCode.INTERNAL_UNKNOWN_ERROR send_mock.assert_called_once() - assert f"Received unknown error code: {error_code}" in caplog.text + assert f"received unknown error code: {error_code}" in caplog.text @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)