diff --git a/kasa/smart/modules/devicemodule.py b/kasa/smart/modules/devicemodule.py index bf112e2d..692745bb 100644 --- a/kasa/smart/modules/devicemodule.py +++ b/kasa/smart/modules/devicemodule.py @@ -19,12 +19,15 @@ class DeviceModule(SmartModule): def query(self) -> dict: """Query to execute during the update cycle.""" + if self._device._is_hub_child: + # Child devices get their device info updated by the parent device. + return {} query = { "get_device_info": None, } # Device usage is not available on older firmware versions # or child devices of hubs - if self.supported_version >= 2 and not self._device._is_hub_child: + if self.supported_version >= 2: query["get_device_usage"] = None return query diff --git a/kasa/smart/smartchilddevice.py b/kasa/smart/smartchilddevice.py index 5ed7feb6..760a18a1 100644 --- a/kasa/smart/smartchilddevice.py +++ b/kasa/smart/smartchilddevice.py @@ -86,11 +86,22 @@ class SmartChildDevice(SmartDevice): module_queries: list[SmartModule] = [] req: dict[str, Any] = {} for module in self.modules.values(): - if module.disabled is False and (mod_query := module.query()): + if ( + module.disabled is False + and (mod_query := module.query()) + and module._should_update(now) + ): module_queries.append(module) req.update(mod_query) if req: - self._last_update = await self.protocol.query(req) + first_update = self._last_update != {} + try: + resp = await self.protocol.query(req) + except Exception as ex: + resp = await self._handle_modular_update_error( + ex, first_update, ", ".join(mod.name for mod in module_queries), req + ) + self._last_update = resp for module in self.modules.values(): await self._handle_module_post_update( diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 5fd22115..89f2f950 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -183,7 +183,7 @@ class SmartDevice(Device): """Update the internal device info.""" self._info = self._try_get_response(info_resp, "get_device_info") - async def update(self, update_children: bool = False) -> None: + async def update(self, update_children: bool = True) -> None: """Update the device.""" if self.credentials is None and self.credentials_hash is None: raise AuthenticationError("Tapo plug requires authentication.") @@ -207,7 +207,7 @@ class SmartDevice(Device): # devices will always update children to prevent errors on module access. # This needs to go after updating the internal state of the children so that # child modules have access to their sysinfo. - if update_children or self.device_type != DeviceType.Hub: + if first_update or update_children or self.device_type != DeviceType.Hub: for child in self._children.values(): if TYPE_CHECKING: assert isinstance(child, SmartChildDevice) @@ -260,11 +260,7 @@ class SmartDevice(Device): if first_update and module.__class__ in self.FIRST_UPDATE_MODULES: module._last_update_time = update_time continue - if ( - not module.update_interval - or not module._last_update_time - or (update_time - module._last_update_time) >= module.update_interval - ): + if module._should_update(update_time): module_queries.append(module) req.update(query) diff --git a/kasa/smart/smartmodule.py b/kasa/smart/smartmodule.py index a5666f63..243852e0 100644 --- a/kasa/smart/smartmodule.py +++ b/kasa/smart/smartmodule.py @@ -62,6 +62,8 @@ class SmartModule(Module): REGISTERED_MODULES: dict[str, type[SmartModule]] = {} MINIMUM_UPDATE_INTERVAL_SECS = 0 + MINIMUM_HUB_CHILD_UPDATE_INTERVAL_SECS = 60 * 60 * 24 + UPDATE_INTERVAL_AFTER_ERROR_SECS = 30 DISABLE_AFTER_ERROR_COUNT = 10 @@ -107,16 +109,27 @@ class SmartModule(Module): @property def update_interval(self) -> int: """Time to wait between updates.""" - if self._last_update_error is None: - return self.MINIMUM_UPDATE_INTERVAL_SECS + if self._last_update_error: + return self.UPDATE_INTERVAL_AFTER_ERROR_SECS * self._error_count - return self.UPDATE_INTERVAL_AFTER_ERROR_SECS * self._error_count + if self._device._is_hub_child: + return self.MINIMUM_HUB_CHILD_UPDATE_INTERVAL_SECS + + return self.MINIMUM_UPDATE_INTERVAL_SECS @property def disabled(self) -> bool: """Return true if the module is disabled due to errors.""" return self._error_count >= self.DISABLE_AFTER_ERROR_COUNT + def _should_update(self, update_time: float) -> bool: + """Return true if module should update based on delay parameters.""" + return ( + not self.update_interval + or not self._last_update_time + or (update_time - self._last_update_time) >= self.update_interval + ) + @classmethod def _module_name(cls) -> str: return getattr(cls, "NAME", cls.__name__) diff --git a/kasa/smartcam/modules/device.py b/kasa/smartcam/modules/device.py index 655a92da..7f84de1e 100644 --- a/kasa/smartcam/modules/device.py +++ b/kasa/smartcam/modules/device.py @@ -16,6 +16,11 @@ class DeviceModule(SmartCamModule): def query(self) -> dict: """Query to execute during the update cycle.""" + if self._device._is_hub_child: + # Child devices get their device info updated by the parent device. + # and generally don't support connection type as they're not + # connected to the network + return {} q = super().query() q["getConnectionType"] = {"network": {"get_connection_type": []}} @@ -70,14 +75,14 @@ class DeviceModule(SmartCamModule): @property def device_id(self) -> str: """Return the device id.""" - return self.data[self.QUERY_GETTER_NAME]["basic_info"]["dev_id"] + return self._device._info["device_id"] @property def rssi(self) -> int | None: """Return the device id.""" - return self.data["getConnectionType"].get("rssiValue") + return self.data.get("getConnectionType", {}).get("rssiValue") @property def signal_level(self) -> int | None: """Return the device id.""" - return self.data["getConnectionType"].get("rssi") + return self.data.get("getConnectionType", {}).get("rssi") diff --git a/tests/smart/test_smartdevice.py b/tests/smart/test_smartdevice.py index 549eb8ad..1cae0abc 100644 --- a/tests/smart/test_smartdevice.py +++ b/tests/smart/test_smartdevice.py @@ -5,7 +5,7 @@ from __future__ import annotations import copy import logging import time -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from unittest.mock import patch import pytest @@ -14,7 +14,6 @@ from pytest_mock import MockerFixture from kasa import Device, DeviceType, KasaException, Module from kasa.exceptions import DeviceError, SmartErrorCode -from kasa.protocols.smartprotocol import _ChildProtocolWrapper from kasa.smart import SmartDevice from kasa.smart.modules.energy import Energy from kasa.smart.smartmodule import SmartModule @@ -25,7 +24,16 @@ from tests.conftest import ( get_parent_and_child_modules, smart_discovery, ) -from tests.device_fixtures import variable_temp_smart +from tests.device_fixtures import ( + hub_smartcam, + hubs_smart, + parametrize_combine, + variable_temp_smart, +) + +DUMMY_CHILD_REQUEST_PREFIX = "get_dummy_" + +hub_all = parametrize_combine([hubs_smart, hub_smartcam]) @device_smart @@ -214,6 +222,166 @@ async def test_update_module_update_delays( ), f"Expected update time {expected_update_time} after {seconds} seconds for {module.name} with delay {mod_delay} got {module._last_update_time}" +async def _get_child_responses(child_requests: list[dict[str, Any]], child_protocol): + """Get dummy responses for testing all child modules. + + Even if they don't return really return query. + """ + child_req = {item["method"]: item.get("params") for item in child_requests} + child_resp = {k: v for k, v in child_req.items() if k.startswith("get_dummy")} + child_req = { + k: v for k, v in child_req.items() if k.startswith("get_dummy") is False + } + resp = await child_protocol._query(child_req) + resp = {**child_resp, **resp} + return [ + {"method": k, "error_code": 0, "result": v or {"dummy": "dummy"}} + for k, v in resp.items() + ] + + +@hub_all +@pytest.mark.xdist_group(name="caplog") +async def test_hub_children_update_delays( + dev: SmartDevice, + mocker: MockerFixture, + caplog: pytest.LogCaptureFixture, + freezer: FrozenDateTimeFactory, +): + """Test that hub children use the correct delay.""" + if not dev.children: + pytest.skip(f"Device {dev.model} does not have children.") + # We need to have some modules initialized by now + assert dev._modules + + new_dev = type(dev)("127.0.0.1", protocol=dev.protocol) + module_queries: dict[str, dict[str, dict]] = {} + + # children should always update on first update + await new_dev.update(update_children=False) + + if TYPE_CHECKING: + from ..fakeprotocol_smart import FakeSmartTransport + + assert isinstance(dev.protocol._transport, FakeSmartTransport) + if dev.protocol._transport.child_protocols: + for child in new_dev.children: + for modname, module in child._modules.items(): + if ( + not (q := module.query()) + and modname not in {"DeviceModule", "Light"} + and not module.SYSINFO_LOOKUP_KEYS + ): + q = {f"get_dummy_{modname}": {}} + mocker.patch.object(module, "query", return_value=q) + if q: + queries = module_queries.setdefault(child.device_id, {}) + queries[cast(str, modname)] = q + module._last_update_time = None + + module_queries[""] = { + cast(str, modname): q + for modname, module in dev._modules.items() + if (q := module.query()) + } + + async def _query(request, *args, **kwargs): + # If this is a child multipleRequest query return the error wrapped + child_id = None + # smart hub + if ( + (cc := request.get("control_child")) + and (child_id := cc.get("device_id")) + and (requestData := cc["requestData"]) + and requestData["method"] == "multipleRequest" + and (child_requests := requestData["params"]["requests"]) + ): + child_protocol = dev.protocol._transport.child_protocols[child_id] + resp = await _get_child_responses(child_requests, child_protocol) + return {"control_child": {"responseData": {"result": {"responses": resp}}}} + # smartcam hub + if ( + (mr := request.get("multipleRequest")) + and (requests := mr.get("requests")) + # assumes all requests for the same child + and ( + child_id := next(iter(requests)) + .get("params", {}) + .get("childControl", {}) + .get("device_id") + ) + and ( + child_requests := [ + cc["request_data"] + for req in requests + if (cc := req["params"].get("childControl")) + ] + ) + ): + child_protocol = dev.protocol._transport.child_protocols[child_id] + resp = await _get_child_responses(child_requests, child_protocol) + resp = [{"result": {"response_data": resp}} for resp in resp] + return {"multipleRequest": {"responses": resp}} + + if child_id: # child single query + child_protocol = dev.protocol._transport.child_protocols[child_id] + resp_list = await _get_child_responses([requestData], child_protocol) + resp = {"control_child": {"responseData": resp_list[0]}} + else: + resp = await dev.protocol._query(request, *args, **kwargs) + + return resp + + mocker.patch.object(new_dev.protocol, "query", side_effect=_query) + + first_update_time = time.monotonic() + assert new_dev._last_update_time == first_update_time + + await new_dev.update() + + for dev_id, modqueries in module_queries.items(): + check_dev = new_dev._children[dev_id] if dev_id else new_dev + for modname in modqueries: + mod = cast(SmartModule, check_dev.modules[modname]) + assert mod._last_update_time == first_update_time + + for mod in new_dev.modules.values(): + mod.MINIMUM_UPDATE_INTERVAL_SECS = 5 + freezer.tick(180) + + now = time.monotonic() + await new_dev.update() + + child_tick = max( + module.MINIMUM_HUB_CHILD_UPDATE_INTERVAL_SECS + for child in new_dev.children + for module in child.modules.values() + ) + + for dev_id, modqueries in module_queries.items(): + check_dev = new_dev._children[dev_id] if dev_id else new_dev + for modname in modqueries: + if modname in {"Firmware"}: + continue + mod = cast(SmartModule, check_dev.modules[modname]) + expected_update_time = first_update_time if dev_id else now + assert mod._last_update_time == expected_update_time + + freezer.tick(child_tick) + + now = time.monotonic() + await new_dev.update() + + for dev_id, modqueries in module_queries.items(): + check_dev = new_dev._children[dev_id] if dev_id else new_dev + for modname in modqueries: + if modname in {"Firmware"}: + continue + mod = cast(SmartModule, check_dev.modules[modname]) + + assert mod._last_update_time == now + + @pytest.mark.parametrize( ("first_update"), [ @@ -261,25 +429,77 @@ async def test_update_module_query_errors( new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) if not first_update: await new_dev.update() - freezer.tick( - max(module.MINIMUM_UPDATE_INTERVAL_SECS for module in dev._modules.values()) - ) + freezer.tick(max(module.update_interval for module in dev._modules.values())) - module_queries = { - modname: q + module_queries: dict[str, dict[str, dict]] = {} + if TYPE_CHECKING: + from ..fakeprotocol_smart import FakeSmartTransport + + assert isinstance(dev.protocol._transport, FakeSmartTransport) + if dev.protocol._transport.child_protocols: + for child in new_dev.children: + for modname, module in child._modules.items(): + if ( + not (q := module.query()) + and modname not in {"DeviceModule", "Light"} + and not module.SYSINFO_LOOKUP_KEYS + ): + q = {f"get_dummy_{modname}": {}} + mocker.patch.object(module, "query", return_value=q) + if q: + queries = module_queries.setdefault(child.device_id, {}) + queries[cast(str, modname)] = q + + module_queries[""] = { + cast(str, modname): q for modname, module in dev._modules.items() if (q := module.query()) and modname not in critical_modules } + raise_error = True + async def _query(request, *args, **kwargs): + pass + # If this is a childmultipleRequest query return the error wrapped + child_id = None if ( - "component_nego" in request - or "get_child_device_component_list" in request - or "control_child" in request + (cc := request.get("control_child")) + and (child_id := cc.get("device_id")) + and (requestData := cc["requestData"]) + and requestData["method"] == "multipleRequest" + and (child_requests := requestData["params"]["requests"]) ): - resp = await dev.protocol._query(request, *args, **kwargs) - resp["get_connect_cloud_state"] = SmartErrorCode.CLOUD_FAILED_ERROR + if raise_error: + if not isinstance(error_type, SmartErrorCode): + raise TimeoutError() + if len(child_requests) > 1: + raise TimeoutError() + + if raise_error: + resp = { + "method": child_requests[0]["method"], + "error_code": error_type.value, + } + else: + child_protocol = dev.protocol._transport.child_protocols[child_id] + resp = await _get_child_responses(child_requests, child_protocol) + return {"control_child": {"responseData": {"result": {"responses": resp}}}} + + if ( + not raise_error + or "component_nego" in request + or "get_child_device_component_list" in request + ): + if child_id: # child single query + child_protocol = dev.protocol._transport.child_protocols[child_id] + resp_list = await _get_child_responses([requestData], child_protocol) + resp = {"control_child": {"responseData": resp_list[0]}} + else: + resp = await dev.protocol._query(request, *args, **kwargs) + if raise_error: + resp["get_connect_cloud_state"] = SmartErrorCode.CLOUD_FAILED_ERROR return resp + # Don't test for errors on get_device_info as that is likely terminal if len(request) == 1 and "get_device_info" in request: return await dev.protocol._query(request, *args, **kwargs) @@ -290,80 +510,77 @@ async def test_update_module_query_errors( raise TimeoutError("Dummy timeout") raise error_type - child_protocols = { - cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol - for child in dev.children - } - - async def _child_query(self, request, *args, **kwargs): - return await child_protocols[self._device_id]._query(request, *args, **kwargs) - mocker.patch.object(new_dev.protocol, "query", side_effect=_query) - # children not created yet so cannot patch.object - mocker.patch( - "kasa.protocols.smartprotocol._ChildProtocolWrapper.query", new=_child_query - ) await new_dev.update() msg = f"Error querying {new_dev.host} for modules" assert msg in caplog.text - for modname in module_queries: - mod = cast(SmartModule, new_dev.modules[modname]) - assert mod.disabled is False, f"{modname} disabled" - assert mod.update_interval == mod.UPDATE_INTERVAL_AFTER_ERROR_SECS - for mod_query in module_queries[modname]: - if not first_update or mod_query not in first_update_queries: - msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" - assert msg in caplog.text + for dev_id, modqueries in module_queries.items(): + check_dev = new_dev._children[dev_id] if dev_id else new_dev + for modname in modqueries: + mod = cast(SmartModule, check_dev.modules[modname]) + if modname in {"DeviceModule"} or ( + hasattr(mod, "_state_in_sysinfo") and mod._state_in_sysinfo is True + ): + continue + assert mod.disabled is False, f"{modname} disabled" + assert mod.update_interval == mod.UPDATE_INTERVAL_AFTER_ERROR_SECS + for mod_query in modqueries[modname]: + if not first_update or mod_query not in first_update_queries: + msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" + assert msg in caplog.text # Query again should not run for the modules caplog.clear() await new_dev.update() - for modname in module_queries: - mod = cast(SmartModule, new_dev.modules[modname]) - assert mod.disabled is False, f"{modname} disabled" + for dev_id, modqueries in module_queries.items(): + check_dev = new_dev._children[dev_id] if dev_id else new_dev + for modname in modqueries: + mod = cast(SmartModule, check_dev.modules[modname]) + assert mod.disabled is False, f"{modname} disabled" freezer.tick(SmartModule.UPDATE_INTERVAL_AFTER_ERROR_SECS) caplog.clear() if recover: - mocker.patch.object( - new_dev.protocol, "query", side_effect=new_dev.protocol._query - ) - mocker.patch( - "kasa.protocols.smartprotocol._ChildProtocolWrapper.query", - new=_ChildProtocolWrapper._query, - ) + raise_error = False await new_dev.update() msg = f"Error querying {new_dev.host} for modules" if not recover: assert msg in caplog.text - for modname in module_queries: - mod = cast(SmartModule, new_dev.modules[modname]) - if not recover: - assert mod.disabled is True, f"{modname} not disabled" - assert mod._error_count == 2 - assert mod._last_update_error - for mod_query in module_queries[modname]: - if not first_update or mod_query not in first_update_queries: - msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" - assert msg in caplog.text - # Test one of the raise_if_update_error - if mod.name == "Energy": - emod = cast(Energy, mod) - with pytest.raises(KasaException, match="Module update error"): + + for dev_id, modqueries in module_queries.items(): + check_dev = new_dev._children[dev_id] if dev_id else new_dev + for modname in modqueries: + mod = cast(SmartModule, check_dev.modules[modname]) + if modname in {"DeviceModule"} or ( + hasattr(mod, "_state_in_sysinfo") and mod._state_in_sysinfo is True + ): + continue + if not recover: + assert mod.disabled is True, f"{modname} not disabled" + assert mod._error_count == 2 + assert mod._last_update_error + for mod_query in modqueries[modname]: + if not first_update or mod_query not in first_update_queries: + msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" + assert msg in caplog.text + # Test one of the raise_if_update_error + if mod.name == "Energy": + emod = cast(Energy, mod) + with pytest.raises(KasaException, match="Module update error"): + assert emod.status is not None + else: + assert mod.disabled is False + assert mod._error_count == 0 + assert mod._last_update_error is None + # Test one of the raise_if_update_error doesn't raise + if mod.name == "Energy": + emod = cast(Energy, mod) assert emod.status is not None - else: - assert mod.disabled is False - assert mod._error_count == 0 - assert mod._last_update_error is None - # Test one of the raise_if_update_error doesn't raise - if mod.name == "Energy": - emod = cast(Energy, mod) - assert emod.status is not None async def test_get_modules():