diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 9a1482b1..cbfd16b0 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -67,7 +67,9 @@ class SmartProtocol(BaseProtocol): async def _query(self, request: str | dict, retry_count: int = 3) -> dict: for retry in range(retry_count + 1): try: - return await self._execute_query(request, retry) + return await self._execute_query( + request, retry_count=retry, iterate_list_pages=True + ) except _ConnectionError as sdex: if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) @@ -145,6 +147,9 @@ class SmartProtocol(BaseProtocol): method = response["method"] self._handle_response_error_code(response, method, raise_on_error=False) result = response.get("result", None) + await self._handle_response_lists( + result, method, retry_count=retry_count + ) multi_result[method] = result # Multi requests don't continue after errors so requery any missing for method, params in requests.items(): @@ -156,7 +161,9 @@ class SmartProtocol(BaseProtocol): multi_result[method] = resp.get("result") return multi_result - async def _execute_query(self, request: str | dict, retry_count: int) -> dict: + async def _execute_query( + self, request: str | dict, *, retry_count: int, iterate_list_pages: bool = True + ) -> dict: debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) if isinstance(request, dict): @@ -189,8 +196,40 @@ class SmartProtocol(BaseProtocol): # Single set_ requests do not return a result result = response_data.get("result") + if iterate_list_pages and result: + await self._handle_response_lists( + result, smart_method, retry_count=retry_count + ) return {smart_method: result} + async def _handle_response_lists( + self, response_result: dict[str, Any], method, retry_count + ): + if ( + isinstance(response_result, SmartErrorCode) + or "start_index" not in response_result + or (list_sum := response_result.get("sum")) is None + ): + return + + response_list_name = next( + iter( + [ + key + for key in response_result + if isinstance(response_result[key], list) + ] + ) + ) + while (list_length := len(response_result[response_list_name])) < list_sum: + response = await self._execute_query( + {method: {"start_index": list_length}}, + retry_count=retry_count, + iterate_list_pages=False, + ) + next_batch = response[method] + response_result[response_list_name].extend(next_batch[response_list_name]) + def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True): error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] if error_code == SmartErrorCode.SUCCESS: diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index b46f8f3d..7340b5b7 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -21,7 +21,14 @@ class FakeSmartProtocol(SmartProtocol): class FakeSmartTransport(BaseTransport): - def __init__(self, info, fixture_name): + def __init__( + self, + info, + fixture_name, + *, + list_return_size=10, + component_nego_not_included=False, + ): super().__init__( config=DeviceConfig( "127.0.0.123", @@ -33,10 +40,12 @@ class FakeSmartTransport(BaseTransport): ) self.fixture_name = fixture_name self.info = copy.deepcopy(info) - self.components = { - comp["id"]: comp["ver_code"] - for comp in self.info["component_nego"]["component_list"] - } + if not component_nego_not_included: + self.components = { + comp["id"]: comp["ver_code"] + for comp in self.info["component_nego"]["component_list"] + } + self.list_return_size = list_return_size @property def default_port(self): @@ -177,7 +186,20 @@ class FakeSmartTransport(BaseTransport): elif method == "component_nego" or method[:4] == "get_": if method in info: result = copy.deepcopy(info[method]) + if "start_index" in result and "sum" in result: + list_key = next( + iter([key for key in result if isinstance(result[key], list)]) + ) + start_index = ( + start_index + if (params and (start_index := params.get("start_index"))) + else 0 + ) + result[list_key] = result[list_key][ + start_index : start_index + self.list_return_size + ] return {"result": result, "error_code": 0} + if ( # FIXTURE_MISSING is for service calls not in place when # SMART fixtures started to be generated diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index b970eaa5..ca62ba02 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -7,7 +7,8 @@ from ..exceptions import ( KasaException, SmartErrorCode, ) -from ..smartprotocol import _ChildProtocolWrapper +from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper +from .fakeprotocol_smart import FakeSmartTransport DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} DUMMY_MULTIPLE_QUERY = { @@ -180,3 +181,64 @@ async def test_childdevicewrapper_multiplerequest_error(dummy_protocol, mocker): mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response) with pytest.raises(KasaException): await wrapped_protocol.query(DUMMY_QUERY) + + +@pytest.mark.parametrize("list_sum", [5, 10, 30]) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 50]) +async def test_smart_protocol_lists_single_request(mocker, list_sum, batch_size): + child_device_list = [{"foo": i} for i in range(list_sum)] + response = { + "get_child_device_list": { + "child_device_list": child_device_list, + "start_index": 0, + "sum": list_sum, + } + } + request = {"get_child_device_list": None} + + ft = FakeSmartTransport( + response, + "foobar", + list_return_size=batch_size, + component_nego_not_included=True, + ) + protocol = SmartProtocol(transport=ft) + query_spy = mocker.spy(protocol, "_execute_query") + resp = await protocol.query(request) + expected_count = int(list_sum / batch_size) + (1 if list_sum % batch_size else 0) + assert query_spy.call_count == expected_count + assert resp == response + + +@pytest.mark.parametrize("list_sum", [5, 10, 30]) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 50]) +async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_size): + child_list = [{"foo": i} for i in range(list_sum)] + response = { + "get_child_device_list": { + "child_device_list": child_list, + "start_index": 0, + "sum": list_sum, + }, + "get_child_device_component_list": { + "child_component_list": child_list, + "start_index": 0, + "sum": list_sum, + }, + } + request = {"get_child_device_list": None, "get_child_device_component_list": None} + + ft = FakeSmartTransport( + response, + "foobar", + list_return_size=batch_size, + component_nego_not_included=True, + ) + protocol = SmartProtocol(transport=ft) + query_spy = mocker.spy(protocol, "_execute_query") + resp = await protocol.query(request) + expected_count = 1 + 2 * ( + int(list_sum / batch_size) + (0 if list_sum % batch_size else -1) + ) + assert query_spy.call_count == expected_count + assert resp == response