mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-10-31 20:51:54 +00:00 
			
		
		
		
	Update hub children on first update and delay subsequent updates (#1438)
	
		
			
	
		
	
	
		
	
		
			Some checks are pending
		
		
	
	
		
			
				
	
				CI / Perform linting checks (3.13) (push) Waiting to run
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, macos-latest, 3.11) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, macos-latest, 3.12) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, macos-latest, 3.13) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, ubuntu-latest, 3.11) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, ubuntu-latest, 3.12) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, ubuntu-latest, 3.13) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, windows-latest, 3.11) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, windows-latest, 3.12) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, windows-latest, 3.13) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (true, ubuntu-latest, 3.11) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (true, ubuntu-latest, 3.12) (push) Blocked by required conditions
				
			
		
			
				
	
				CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (true, ubuntu-latest, 3.13) (push) Blocked by required conditions
				
			
		
			
				
	
				CodeQL checks / Analyze (python) (push) Waiting to run
				
			
		
		
	
	
				
					
				
			
		
			Some checks are pending
		
		
	
	CI / Perform linting checks (3.13) (push) Waiting to run
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, macos-latest, 3.11) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, macos-latest, 3.12) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, macos-latest, 3.13) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, ubuntu-latest, 3.11) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, ubuntu-latest, 3.12) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, ubuntu-latest, 3.13) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, windows-latest, 3.11) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, windows-latest, 3.12) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, windows-latest, 3.13) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (true, ubuntu-latest, 3.11) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (true, ubuntu-latest, 3.12) (push) Blocked by required conditions
				
			CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (true, ubuntu-latest, 3.13) (push) Blocked by required conditions
				
			CodeQL checks / Analyze (python) (push) Waiting to run
				
			This commit is contained in:
		| @@ -19,12 +19,15 @@ class DeviceModule(SmartModule): | ||||
|  | ||||
|     def query(self) -> dict: | ||||
|         """Query to execute during the update cycle.""" | ||||
|         if self._device._is_hub_child: | ||||
|             # Child devices get their device info updated by the parent device. | ||||
|             return {} | ||||
|         query = { | ||||
|             "get_device_info": None, | ||||
|         } | ||||
|         # Device usage is not available on older firmware versions | ||||
|         # or child devices of hubs | ||||
|         if self.supported_version >= 2 and not self._device._is_hub_child: | ||||
|         if self.supported_version >= 2: | ||||
|             query["get_device_usage"] = None | ||||
|  | ||||
|         return query | ||||
|   | ||||
| @@ -86,11 +86,22 @@ class SmartChildDevice(SmartDevice): | ||||
|         module_queries: list[SmartModule] = [] | ||||
|         req: dict[str, Any] = {} | ||||
|         for module in self.modules.values(): | ||||
|             if module.disabled is False and (mod_query := module.query()): | ||||
|             if ( | ||||
|                 module.disabled is False | ||||
|                 and (mod_query := module.query()) | ||||
|                 and module._should_update(now) | ||||
|             ): | ||||
|                 module_queries.append(module) | ||||
|                 req.update(mod_query) | ||||
|         if req: | ||||
|             self._last_update = await self.protocol.query(req) | ||||
|             first_update = self._last_update != {} | ||||
|             try: | ||||
|                 resp = await self.protocol.query(req) | ||||
|             except Exception as ex: | ||||
|                 resp = await self._handle_modular_update_error( | ||||
|                     ex, first_update, ", ".join(mod.name for mod in module_queries), req | ||||
|                 ) | ||||
|             self._last_update = resp | ||||
|  | ||||
|         for module in self.modules.values(): | ||||
|             await self._handle_module_post_update( | ||||
|   | ||||
| @@ -183,7 +183,7 @@ class SmartDevice(Device): | ||||
|         """Update the internal device info.""" | ||||
|         self._info = self._try_get_response(info_resp, "get_device_info") | ||||
|  | ||||
|     async def update(self, update_children: bool = False) -> None: | ||||
|     async def update(self, update_children: bool = True) -> None: | ||||
|         """Update the device.""" | ||||
|         if self.credentials is None and self.credentials_hash is None: | ||||
|             raise AuthenticationError("Tapo plug requires authentication.") | ||||
| @@ -207,7 +207,7 @@ class SmartDevice(Device): | ||||
|         # devices will always update children to prevent errors on module access. | ||||
|         # This needs to go after updating the internal state of the children so that | ||||
|         # child modules have access to their sysinfo. | ||||
|         if update_children or self.device_type != DeviceType.Hub: | ||||
|         if first_update or update_children or self.device_type != DeviceType.Hub: | ||||
|             for child in self._children.values(): | ||||
|                 if TYPE_CHECKING: | ||||
|                     assert isinstance(child, SmartChildDevice) | ||||
| @@ -260,11 +260,7 @@ class SmartDevice(Device): | ||||
|             if first_update and module.__class__ in self.FIRST_UPDATE_MODULES: | ||||
|                 module._last_update_time = update_time | ||||
|                 continue | ||||
|             if ( | ||||
|                 not module.update_interval | ||||
|                 or not module._last_update_time | ||||
|                 or (update_time - module._last_update_time) >= module.update_interval | ||||
|             ): | ||||
|             if module._should_update(update_time): | ||||
|                 module_queries.append(module) | ||||
|                 req.update(query) | ||||
|  | ||||
|   | ||||
| @@ -62,6 +62,8 @@ class SmartModule(Module): | ||||
|     REGISTERED_MODULES: dict[str, type[SmartModule]] = {} | ||||
|  | ||||
|     MINIMUM_UPDATE_INTERVAL_SECS = 0 | ||||
|     MINIMUM_HUB_CHILD_UPDATE_INTERVAL_SECS = 60 * 60 * 24 | ||||
|  | ||||
|     UPDATE_INTERVAL_AFTER_ERROR_SECS = 30 | ||||
|  | ||||
|     DISABLE_AFTER_ERROR_COUNT = 10 | ||||
| @@ -107,16 +109,27 @@ class SmartModule(Module): | ||||
|     @property | ||||
|     def update_interval(self) -> int: | ||||
|         """Time to wait between updates.""" | ||||
|         if self._last_update_error is None: | ||||
|             return self.MINIMUM_UPDATE_INTERVAL_SECS | ||||
|         if self._last_update_error: | ||||
|             return self.UPDATE_INTERVAL_AFTER_ERROR_SECS * self._error_count | ||||
|  | ||||
|         return self.UPDATE_INTERVAL_AFTER_ERROR_SECS * self._error_count | ||||
|         if self._device._is_hub_child: | ||||
|             return self.MINIMUM_HUB_CHILD_UPDATE_INTERVAL_SECS | ||||
|  | ||||
|         return self.MINIMUM_UPDATE_INTERVAL_SECS | ||||
|  | ||||
|     @property | ||||
|     def disabled(self) -> bool: | ||||
|         """Return true if the module is disabled due to errors.""" | ||||
|         return self._error_count >= self.DISABLE_AFTER_ERROR_COUNT | ||||
|  | ||||
|     def _should_update(self, update_time: float) -> bool: | ||||
|         """Return true if module should update based on delay parameters.""" | ||||
|         return ( | ||||
|             not self.update_interval | ||||
|             or not self._last_update_time | ||||
|             or (update_time - self._last_update_time) >= self.update_interval | ||||
|         ) | ||||
|  | ||||
|     @classmethod | ||||
|     def _module_name(cls) -> str: | ||||
|         return getattr(cls, "NAME", cls.__name__) | ||||
|   | ||||
| @@ -16,6 +16,11 @@ class DeviceModule(SmartCamModule): | ||||
|  | ||||
|     def query(self) -> dict: | ||||
|         """Query to execute during the update cycle.""" | ||||
|         if self._device._is_hub_child: | ||||
|             # Child devices get their device info updated by the parent device. | ||||
|             # and generally don't support connection type as they're not | ||||
|             # connected to the network | ||||
|             return {} | ||||
|         q = super().query() | ||||
|         q["getConnectionType"] = {"network": {"get_connection_type": []}} | ||||
|  | ||||
| @@ -70,14 +75,14 @@ class DeviceModule(SmartCamModule): | ||||
|     @property | ||||
|     def device_id(self) -> str: | ||||
|         """Return the device id.""" | ||||
|         return self.data[self.QUERY_GETTER_NAME]["basic_info"]["dev_id"] | ||||
|         return self._device._info["device_id"] | ||||
|  | ||||
|     @property | ||||
|     def rssi(self) -> int | None: | ||||
|         """Return the device id.""" | ||||
|         return self.data["getConnectionType"].get("rssiValue") | ||||
|         return self.data.get("getConnectionType", {}).get("rssiValue") | ||||
|  | ||||
|     @property | ||||
|     def signal_level(self) -> int | None: | ||||
|         """Return the device id.""" | ||||
|         return self.data["getConnectionType"].get("rssi") | ||||
|         return self.data.get("getConnectionType", {}).get("rssi") | ||||
|   | ||||
| @@ -5,7 +5,7 @@ from __future__ import annotations | ||||
| import copy | ||||
| import logging | ||||
| import time | ||||
| from typing import Any, cast | ||||
| from typing import TYPE_CHECKING, Any, cast | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import pytest | ||||
| @@ -14,7 +14,6 @@ from pytest_mock import MockerFixture | ||||
|  | ||||
| from kasa import Device, DeviceType, KasaException, Module | ||||
| from kasa.exceptions import DeviceError, SmartErrorCode | ||||
| from kasa.protocols.smartprotocol import _ChildProtocolWrapper | ||||
| from kasa.smart import SmartDevice | ||||
| from kasa.smart.modules.energy import Energy | ||||
| from kasa.smart.smartmodule import SmartModule | ||||
| @@ -25,7 +24,16 @@ from tests.conftest import ( | ||||
|     get_parent_and_child_modules, | ||||
|     smart_discovery, | ||||
| ) | ||||
| from tests.device_fixtures import variable_temp_smart | ||||
| from tests.device_fixtures import ( | ||||
|     hub_smartcam, | ||||
|     hubs_smart, | ||||
|     parametrize_combine, | ||||
|     variable_temp_smart, | ||||
| ) | ||||
|  | ||||
| DUMMY_CHILD_REQUEST_PREFIX = "get_dummy_" | ||||
|  | ||||
| hub_all = parametrize_combine([hubs_smart, hub_smartcam]) | ||||
|  | ||||
|  | ||||
| @device_smart | ||||
| @@ -214,6 +222,166 @@ async def test_update_module_update_delays( | ||||
|                 ), f"Expected update time {expected_update_time} after {seconds} seconds for {module.name} with delay {mod_delay} got {module._last_update_time}" | ||||
|  | ||||
|  | ||||
| async def _get_child_responses(child_requests: list[dict[str, Any]], child_protocol): | ||||
|     """Get dummy responses for testing all child modules. | ||||
|  | ||||
|     Even if they don't return really return query. | ||||
|     """ | ||||
|     child_req = {item["method"]: item.get("params") for item in child_requests} | ||||
|     child_resp = {k: v for k, v in child_req.items() if k.startswith("get_dummy")} | ||||
|     child_req = { | ||||
|         k: v for k, v in child_req.items() if k.startswith("get_dummy") is False | ||||
|     } | ||||
|     resp = await child_protocol._query(child_req) | ||||
|     resp = {**child_resp, **resp} | ||||
|     return [ | ||||
|         {"method": k, "error_code": 0, "result": v or {"dummy": "dummy"}} | ||||
|         for k, v in resp.items() | ||||
|     ] | ||||
|  | ||||
|  | ||||
| @hub_all | ||||
| @pytest.mark.xdist_group(name="caplog") | ||||
| async def test_hub_children_update_delays( | ||||
|     dev: SmartDevice, | ||||
|     mocker: MockerFixture, | ||||
|     caplog: pytest.LogCaptureFixture, | ||||
|     freezer: FrozenDateTimeFactory, | ||||
| ): | ||||
|     """Test that hub children use the correct delay.""" | ||||
|     if not dev.children: | ||||
|         pytest.skip(f"Device {dev.model} does not have children.") | ||||
|     # We need to have some modules initialized by now | ||||
|     assert dev._modules | ||||
|  | ||||
|     new_dev = type(dev)("127.0.0.1", protocol=dev.protocol) | ||||
|     module_queries: dict[str, dict[str, dict]] = {} | ||||
|  | ||||
|     # children should always update on first update | ||||
|     await new_dev.update(update_children=False) | ||||
|  | ||||
|     if TYPE_CHECKING: | ||||
|         from ..fakeprotocol_smart import FakeSmartTransport | ||||
|  | ||||
|         assert isinstance(dev.protocol._transport, FakeSmartTransport) | ||||
|     if dev.protocol._transport.child_protocols: | ||||
|         for child in new_dev.children: | ||||
|             for modname, module in child._modules.items(): | ||||
|                 if ( | ||||
|                     not (q := module.query()) | ||||
|                     and modname not in {"DeviceModule", "Light"} | ||||
|                     and not module.SYSINFO_LOOKUP_KEYS | ||||
|                 ): | ||||
|                     q = {f"get_dummy_{modname}": {}} | ||||
|                     mocker.patch.object(module, "query", return_value=q) | ||||
|                 if q: | ||||
|                     queries = module_queries.setdefault(child.device_id, {}) | ||||
|                     queries[cast(str, modname)] = q | ||||
|                 module._last_update_time = None | ||||
|  | ||||
|     module_queries[""] = { | ||||
|         cast(str, modname): q | ||||
|         for modname, module in dev._modules.items() | ||||
|         if (q := module.query()) | ||||
|     } | ||||
|  | ||||
|     async def _query(request, *args, **kwargs): | ||||
|         # If this is a child multipleRequest query return the error wrapped | ||||
|         child_id = None | ||||
|         # smart hub | ||||
|         if ( | ||||
|             (cc := request.get("control_child")) | ||||
|             and (child_id := cc.get("device_id")) | ||||
|             and (requestData := cc["requestData"]) | ||||
|             and requestData["method"] == "multipleRequest" | ||||
|             and (child_requests := requestData["params"]["requests"]) | ||||
|         ): | ||||
|             child_protocol = dev.protocol._transport.child_protocols[child_id] | ||||
|             resp = await _get_child_responses(child_requests, child_protocol) | ||||
|             return {"control_child": {"responseData": {"result": {"responses": resp}}}} | ||||
|         # smartcam hub | ||||
|         if ( | ||||
|             (mr := request.get("multipleRequest")) | ||||
|             and (requests := mr.get("requests")) | ||||
|             # assumes all requests for the same child | ||||
|             and ( | ||||
|                 child_id := next(iter(requests)) | ||||
|                 .get("params", {}) | ||||
|                 .get("childControl", {}) | ||||
|                 .get("device_id") | ||||
|             ) | ||||
|             and ( | ||||
|                 child_requests := [ | ||||
|                     cc["request_data"] | ||||
|                     for req in requests | ||||
|                     if (cc := req["params"].get("childControl")) | ||||
|                 ] | ||||
|             ) | ||||
|         ): | ||||
|             child_protocol = dev.protocol._transport.child_protocols[child_id] | ||||
|             resp = await _get_child_responses(child_requests, child_protocol) | ||||
|             resp = [{"result": {"response_data": resp}} for resp in resp] | ||||
|             return {"multipleRequest": {"responses": resp}} | ||||
|  | ||||
|         if child_id:  # child single query | ||||
|             child_protocol = dev.protocol._transport.child_protocols[child_id] | ||||
|             resp_list = await _get_child_responses([requestData], child_protocol) | ||||
|             resp = {"control_child": {"responseData": resp_list[0]}} | ||||
|         else: | ||||
|             resp = await dev.protocol._query(request, *args, **kwargs) | ||||
|  | ||||
|         return resp | ||||
|  | ||||
|     mocker.patch.object(new_dev.protocol, "query", side_effect=_query) | ||||
|  | ||||
|     first_update_time = time.monotonic() | ||||
|     assert new_dev._last_update_time == first_update_time | ||||
|  | ||||
|     await new_dev.update() | ||||
|  | ||||
|     for dev_id, modqueries in module_queries.items(): | ||||
|         check_dev = new_dev._children[dev_id] if dev_id else new_dev | ||||
|         for modname in modqueries: | ||||
|             mod = cast(SmartModule, check_dev.modules[modname]) | ||||
|             assert mod._last_update_time == first_update_time | ||||
|  | ||||
|     for mod in new_dev.modules.values(): | ||||
|         mod.MINIMUM_UPDATE_INTERVAL_SECS = 5 | ||||
|     freezer.tick(180) | ||||
|  | ||||
|     now = time.monotonic() | ||||
|     await new_dev.update() | ||||
|  | ||||
|     child_tick = max( | ||||
|         module.MINIMUM_HUB_CHILD_UPDATE_INTERVAL_SECS | ||||
|         for child in new_dev.children | ||||
|         for module in child.modules.values() | ||||
|     ) | ||||
|  | ||||
|     for dev_id, modqueries in module_queries.items(): | ||||
|         check_dev = new_dev._children[dev_id] if dev_id else new_dev | ||||
|         for modname in modqueries: | ||||
|             if modname in {"Firmware"}: | ||||
|                 continue | ||||
|             mod = cast(SmartModule, check_dev.modules[modname]) | ||||
|             expected_update_time = first_update_time if dev_id else now | ||||
|             assert mod._last_update_time == expected_update_time | ||||
|  | ||||
|     freezer.tick(child_tick) | ||||
|  | ||||
|     now = time.monotonic() | ||||
|     await new_dev.update() | ||||
|  | ||||
|     for dev_id, modqueries in module_queries.items(): | ||||
|         check_dev = new_dev._children[dev_id] if dev_id else new_dev | ||||
|         for modname in modqueries: | ||||
|             if modname in {"Firmware"}: | ||||
|                 continue | ||||
|             mod = cast(SmartModule, check_dev.modules[modname]) | ||||
|  | ||||
|             assert mod._last_update_time == now | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     ("first_update"), | ||||
|     [ | ||||
| @@ -261,25 +429,77 @@ async def test_update_module_query_errors( | ||||
|     new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) | ||||
|     if not first_update: | ||||
|         await new_dev.update() | ||||
|         freezer.tick( | ||||
|             max(module.MINIMUM_UPDATE_INTERVAL_SECS for module in dev._modules.values()) | ||||
|         ) | ||||
|         freezer.tick(max(module.update_interval for module in dev._modules.values())) | ||||
|  | ||||
|     module_queries = { | ||||
|         modname: q | ||||
|     module_queries: dict[str, dict[str, dict]] = {} | ||||
|     if TYPE_CHECKING: | ||||
|         from ..fakeprotocol_smart import FakeSmartTransport | ||||
|  | ||||
|         assert isinstance(dev.protocol._transport, FakeSmartTransport) | ||||
|     if dev.protocol._transport.child_protocols: | ||||
|         for child in new_dev.children: | ||||
|             for modname, module in child._modules.items(): | ||||
|                 if ( | ||||
|                     not (q := module.query()) | ||||
|                     and modname not in {"DeviceModule", "Light"} | ||||
|                     and not module.SYSINFO_LOOKUP_KEYS | ||||
|                 ): | ||||
|                     q = {f"get_dummy_{modname}": {}} | ||||
|                     mocker.patch.object(module, "query", return_value=q) | ||||
|                 if q: | ||||
|                     queries = module_queries.setdefault(child.device_id, {}) | ||||
|                     queries[cast(str, modname)] = q | ||||
|  | ||||
|     module_queries[""] = { | ||||
|         cast(str, modname): q | ||||
|         for modname, module in dev._modules.items() | ||||
|         if (q := module.query()) and modname not in critical_modules | ||||
|     } | ||||
|  | ||||
|     raise_error = True | ||||
|  | ||||
|     async def _query(request, *args, **kwargs): | ||||
|         pass | ||||
|         # If this is a childmultipleRequest query return the error wrapped | ||||
|         child_id = None | ||||
|         if ( | ||||
|             "component_nego" in request | ||||
|             or "get_child_device_component_list" in request | ||||
|             or "control_child" in request | ||||
|             (cc := request.get("control_child")) | ||||
|             and (child_id := cc.get("device_id")) | ||||
|             and (requestData := cc["requestData"]) | ||||
|             and requestData["method"] == "multipleRequest" | ||||
|             and (child_requests := requestData["params"]["requests"]) | ||||
|         ): | ||||
|             resp = await dev.protocol._query(request, *args, **kwargs) | ||||
|             resp["get_connect_cloud_state"] = SmartErrorCode.CLOUD_FAILED_ERROR | ||||
|             if raise_error: | ||||
|                 if not isinstance(error_type, SmartErrorCode): | ||||
|                     raise TimeoutError() | ||||
|                 if len(child_requests) > 1: | ||||
|                     raise TimeoutError() | ||||
|  | ||||
|             if raise_error: | ||||
|                 resp = { | ||||
|                     "method": child_requests[0]["method"], | ||||
|                     "error_code": error_type.value, | ||||
|                 } | ||||
|             else: | ||||
|                 child_protocol = dev.protocol._transport.child_protocols[child_id] | ||||
|                 resp = await _get_child_responses(child_requests, child_protocol) | ||||
|             return {"control_child": {"responseData": {"result": {"responses": resp}}}} | ||||
|  | ||||
|         if ( | ||||
|             not raise_error | ||||
|             or "component_nego" in request | ||||
|             or "get_child_device_component_list" in request | ||||
|         ): | ||||
|             if child_id:  # child single query | ||||
|                 child_protocol = dev.protocol._transport.child_protocols[child_id] | ||||
|                 resp_list = await _get_child_responses([requestData], child_protocol) | ||||
|                 resp = {"control_child": {"responseData": resp_list[0]}} | ||||
|             else: | ||||
|                 resp = await dev.protocol._query(request, *args, **kwargs) | ||||
|             if raise_error: | ||||
|                 resp["get_connect_cloud_state"] = SmartErrorCode.CLOUD_FAILED_ERROR | ||||
|             return resp | ||||
|  | ||||
|         # Don't test for errors on get_device_info as that is likely terminal | ||||
|         if len(request) == 1 and "get_device_info" in request: | ||||
|             return await dev.protocol._query(request, *args, **kwargs) | ||||
| @@ -290,80 +510,77 @@ async def test_update_module_query_errors( | ||||
|             raise TimeoutError("Dummy timeout") | ||||
|         raise error_type | ||||
|  | ||||
|     child_protocols = { | ||||
|         cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol | ||||
|         for child in dev.children | ||||
|     } | ||||
|  | ||||
|     async def _child_query(self, request, *args, **kwargs): | ||||
|         return await child_protocols[self._device_id]._query(request, *args, **kwargs) | ||||
|  | ||||
|     mocker.patch.object(new_dev.protocol, "query", side_effect=_query) | ||||
|     # children not created yet so cannot patch.object | ||||
|     mocker.patch( | ||||
|         "kasa.protocols.smartprotocol._ChildProtocolWrapper.query", new=_child_query | ||||
|     ) | ||||
|  | ||||
|     await new_dev.update() | ||||
|  | ||||
|     msg = f"Error querying {new_dev.host} for modules" | ||||
|     assert msg in caplog.text | ||||
|     for modname in module_queries: | ||||
|         mod = cast(SmartModule, new_dev.modules[modname]) | ||||
|         assert mod.disabled is False, f"{modname} disabled" | ||||
|         assert mod.update_interval == mod.UPDATE_INTERVAL_AFTER_ERROR_SECS | ||||
|         for mod_query in module_queries[modname]: | ||||
|             if not first_update or mod_query not in first_update_queries: | ||||
|                 msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" | ||||
|                 assert msg in caplog.text | ||||
|     for dev_id, modqueries in module_queries.items(): | ||||
|         check_dev = new_dev._children[dev_id] if dev_id else new_dev | ||||
|         for modname in modqueries: | ||||
|             mod = cast(SmartModule, check_dev.modules[modname]) | ||||
|             if modname in {"DeviceModule"} or ( | ||||
|                 hasattr(mod, "_state_in_sysinfo") and mod._state_in_sysinfo is True | ||||
|             ): | ||||
|                 continue | ||||
|             assert mod.disabled is False, f"{modname} disabled" | ||||
|             assert mod.update_interval == mod.UPDATE_INTERVAL_AFTER_ERROR_SECS | ||||
|             for mod_query in modqueries[modname]: | ||||
|                 if not first_update or mod_query not in first_update_queries: | ||||
|                     msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" | ||||
|                     assert msg in caplog.text | ||||
|  | ||||
|     # Query again should not run for the modules | ||||
|     caplog.clear() | ||||
|     await new_dev.update() | ||||
|     for modname in module_queries: | ||||
|         mod = cast(SmartModule, new_dev.modules[modname]) | ||||
|         assert mod.disabled is False, f"{modname} disabled" | ||||
|     for dev_id, modqueries in module_queries.items(): | ||||
|         check_dev = new_dev._children[dev_id] if dev_id else new_dev | ||||
|         for modname in modqueries: | ||||
|             mod = cast(SmartModule, check_dev.modules[modname]) | ||||
|             assert mod.disabled is False, f"{modname} disabled" | ||||
|  | ||||
|     freezer.tick(SmartModule.UPDATE_INTERVAL_AFTER_ERROR_SECS) | ||||
|  | ||||
|     caplog.clear() | ||||
|  | ||||
|     if recover: | ||||
|         mocker.patch.object( | ||||
|             new_dev.protocol, "query", side_effect=new_dev.protocol._query | ||||
|         ) | ||||
|         mocker.patch( | ||||
|             "kasa.protocols.smartprotocol._ChildProtocolWrapper.query", | ||||
|             new=_ChildProtocolWrapper._query, | ||||
|         ) | ||||
|         raise_error = False | ||||
|  | ||||
|     await new_dev.update() | ||||
|     msg = f"Error querying {new_dev.host} for modules" | ||||
|     if not recover: | ||||
|         assert msg in caplog.text | ||||
|     for modname in module_queries: | ||||
|         mod = cast(SmartModule, new_dev.modules[modname]) | ||||
|         if not recover: | ||||
|             assert mod.disabled is True, f"{modname} not disabled" | ||||
|             assert mod._error_count == 2 | ||||
|             assert mod._last_update_error | ||||
|             for mod_query in module_queries[modname]: | ||||
|                 if not first_update or mod_query not in first_update_queries: | ||||
|                     msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" | ||||
|                     assert msg in caplog.text | ||||
|             # Test one of the raise_if_update_error | ||||
|             if mod.name == "Energy": | ||||
|                 emod = cast(Energy, mod) | ||||
|                 with pytest.raises(KasaException, match="Module update error"): | ||||
|  | ||||
|     for dev_id, modqueries in module_queries.items(): | ||||
|         check_dev = new_dev._children[dev_id] if dev_id else new_dev | ||||
|         for modname in modqueries: | ||||
|             mod = cast(SmartModule, check_dev.modules[modname]) | ||||
|             if modname in {"DeviceModule"} or ( | ||||
|                 hasattr(mod, "_state_in_sysinfo") and mod._state_in_sysinfo is True | ||||
|             ): | ||||
|                 continue | ||||
|             if not recover: | ||||
|                 assert mod.disabled is True, f"{modname} not disabled" | ||||
|                 assert mod._error_count == 2 | ||||
|                 assert mod._last_update_error | ||||
|                 for mod_query in modqueries[modname]: | ||||
|                     if not first_update or mod_query not in first_update_queries: | ||||
|                         msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" | ||||
|                         assert msg in caplog.text | ||||
|                 # Test one of the raise_if_update_error | ||||
|                 if mod.name == "Energy": | ||||
|                     emod = cast(Energy, mod) | ||||
|                     with pytest.raises(KasaException, match="Module update error"): | ||||
|                         assert emod.status is not None | ||||
|             else: | ||||
|                 assert mod.disabled is False | ||||
|                 assert mod._error_count == 0 | ||||
|                 assert mod._last_update_error is None | ||||
|                 # Test one of the raise_if_update_error doesn't raise | ||||
|                 if mod.name == "Energy": | ||||
|                     emod = cast(Energy, mod) | ||||
|                     assert emod.status is not None | ||||
|         else: | ||||
|             assert mod.disabled is False | ||||
|             assert mod._error_count == 0 | ||||
|             assert mod._last_update_error is None | ||||
|             # Test one of the raise_if_update_error doesn't raise | ||||
|             if mod.name == "Energy": | ||||
|                 emod = cast(Energy, mod) | ||||
|                 assert emod.status is not None | ||||
|  | ||||
|  | ||||
| async def test_get_modules(): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steven B.
					Steven B.