From 6aa019280ba248f318776d65441eefaad3f3b322 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Mon, 6 Jan 2025 09:23:46 +0000 Subject: [PATCH] Handle smartcam partial list responses (#1411) --- kasa/protocols/smartcamprotocol.py | 17 +++++++---- kasa/protocols/smartprotocol.py | 32 +++++++++++++++------ tests/fakeprotocol_smartcam.py | 19 ++++++++++--- tests/protocols/test_smartprotocol.py | 41 +++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 18 deletions(-) diff --git a/kasa/protocols/smartcamprotocol.py b/kasa/protocols/smartcamprotocol.py index 324f8056..a1d6ae9c 100644 --- a/kasa/protocols/smartcamprotocol.py +++ b/kasa/protocols/smartcamprotocol.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging from dataclasses import dataclass from pprint import pformat as pf -from typing import Any +from typing import Any, cast from ..exceptions import ( AuthenticationError, @@ -49,10 +49,13 @@ class SingleRequest: class SmartCamProtocol(SmartProtocol): """Class for SmartCam Protocol.""" - async def _handle_response_lists( - self, response_result: dict[str, Any], method: str, retry_count: int - ) -> None: - pass + def _get_list_request( + self, method: str, params: dict | None, start_index: int + ) -> dict: + # All smartcam requests have params + params = cast(dict, params) + module_name = next(iter(params)) + return {method: {module_name: {"start_index": start_index}}} def _handle_response_error_code( self, resp_dict: dict, method: str, raise_on_error: bool = True @@ -147,7 +150,9 @@ class SmartCamProtocol(SmartProtocol): if len(request) == 1 and method in {"get", "set", "do", "multipleRequest"}: single_request = self._get_smart_camera_single_request(request) else: - return await self._execute_multiple_query(request, retry_count) + return await self._execute_multiple_query( + request, retry_count, iterate_list_pages + ) else: single_request = self._make_smart_camera_single_request(request) diff --git a/kasa/protocols/smartprotocol.py b/kasa/protocols/smartprotocol.py index 7f02b45e..28a20641 100644 --- a/kasa/protocols/smartprotocol.py +++ b/kasa/protocols/smartprotocol.py @@ -180,7 +180,9 @@ class SmartProtocol(BaseProtocol): # make mypy happy, this should never be reached.. raise KasaException("Query reached somehow to unreachable") - async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dict: + async def _execute_multiple_query( + self, requests: dict, retry_count: int, iterate_list_pages: bool + ) -> dict: debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) multi_result: dict[str, Any] = {} smart_method = "multipleRequest" @@ -275,9 +277,11 @@ class SmartProtocol(BaseProtocol): response, method, raise_on_error=raise_on_error ) result = response.get("result", None) - await self._handle_response_lists( - result, method, retry_count=retry_count - ) + request_params = rp if (rp := requests.get(method)) else None + if iterate_list_pages and result: + await self._handle_response_lists( + result, method, request_params, retry_count=retry_count + ) multi_result[method] = result # Multi requests don't continue after errors so requery any missing. @@ -303,7 +307,9 @@ class SmartProtocol(BaseProtocol): smart_method = next(iter(request)) smart_params = request[smart_method] else: - return await self._execute_multiple_query(request, retry_count) + return await self._execute_multiple_query( + request, retry_count, iterate_list_pages + ) else: smart_method = request smart_params = None @@ -330,12 +336,21 @@ class SmartProtocol(BaseProtocol): result = response_data.get("result") if iterate_list_pages and result: await self._handle_response_lists( - result, smart_method, retry_count=retry_count + result, smart_method, smart_params, retry_count=retry_count ) return {smart_method: result} + def _get_list_request( + self, method: str, params: dict | None, start_index: int + ) -> dict: + return {method: {"start_index": start_index}} + async def _handle_response_lists( - self, response_result: dict[str, Any], method: str, retry_count: int + self, + response_result: dict[str, Any], + method: str, + params: dict | None, + retry_count: int, ) -> None: if ( response_result is None @@ -355,8 +370,9 @@ class SmartProtocol(BaseProtocol): ) ) while (list_length := len(response_result[response_list_name])) < list_sum: + request = self._get_list_request(method, params, list_length) response = await self._execute_query( - {method: {"start_index": list_length}}, + request, retry_count=retry_count, iterate_list_pages=False, ) diff --git a/tests/fakeprotocol_smartcam.py b/tests/fakeprotocol_smartcam.py index 381a0a89..eee014e8 100644 --- a/tests/fakeprotocol_smartcam.py +++ b/tests/fakeprotocol_smartcam.py @@ -33,6 +33,7 @@ class FakeSmartCamTransport(BaseTransport): *, list_return_size=10, is_child=False, + get_child_fixtures=True, verbatim=False, components_not_included=False, ): @@ -52,9 +53,12 @@ class FakeSmartCamTransport(BaseTransport): self.verbatim = verbatim if not is_child: self.info = copy.deepcopy(info) - self.child_protocols = FakeSmartTransport._get_child_protocols( - self.info, self.fixture_name, "getChildDeviceList" - ) + # We don't need to get the child fixtures if testing things like + # lists + if get_child_fixtures: + self.child_protocols = FakeSmartTransport._get_child_protocols( + self.info, self.fixture_name, "getChildDeviceList" + ) else: self.info = info # self.child_protocols = self._get_child_protocols() @@ -229,9 +233,16 @@ class FakeSmartCamTransport(BaseTransport): list_key = next( iter([key for key in result if isinstance(result[key], list)]) ) + assert isinstance(params, dict) + module_name = next(iter(params)) + start_index = ( start_index - if (params and (start_index := params.get("start_index"))) + if ( + params + and module_name + and (start_index := params[module_name].get("start_index")) + ) else 0 ) diff --git a/tests/protocols/test_smartprotocol.py b/tests/protocols/test_smartprotocol.py index 7961df68..51492635 100644 --- a/tests/protocols/test_smartprotocol.py +++ b/tests/protocols/test_smartprotocol.py @@ -10,6 +10,7 @@ from kasa.exceptions import ( KasaException, SmartErrorCode, ) +from kasa.protocols.smartcamprotocol import SmartCamProtocol from kasa.protocols.smartprotocol import SmartProtocol, _ChildProtocolWrapper from kasa.smart import SmartDevice @@ -373,6 +374,46 @@ async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_siz assert resp == response +@pytest.mark.parametrize("list_sum", [5, 10, 30]) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 50]) +async def test_smartcam_protocol_list_request(mocker, list_sum, batch_size): + """Test smartcam protocol list handling for lists.""" + child_list = [{"foo": i} for i in range(list_sum)] + + response = { + "getChildDeviceList": { + "child_device_list": child_list, + "start_index": 0, + "sum": list_sum, + }, + "getChildDeviceComponentList": { + "child_component_list": child_list, + "start_index": 0, + "sum": list_sum, + }, + } + request = { + "getChildDeviceList": {"childControl": {"start_index": 0}}, + "getChildDeviceComponentList": {"childControl": {"start_index": 0}}, + } + + ft = FakeSmartCamTransport( + response, + "foobar", + list_return_size=batch_size, + components_not_included=True, + get_child_fixtures=False, + ) + protocol = SmartCamProtocol(transport=ft) + query_spy = mocker.spy(protocol, "_execute_query") + resp = await protocol.query(request) + expected_count = 1 + 2 * ( + int(list_sum / batch_size) + (0 if list_sum % batch_size else -1) + ) + assert query_spy.call_count == expected_count + assert resp == response + + async def test_incomplete_list(mocker, caplog): """Test for handling incomplete lists returned from queries.""" info = {