Handle smartcam partial list responses (#1411)

This commit is contained in:
Steven B. 2025-01-06 09:23:46 +00:00 committed by GitHub
parent 1f45f425a0
commit 6aa019280b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 91 additions and 18 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from pprint import pformat as pf
from typing import Any
from typing import Any, cast
from ..exceptions import (
AuthenticationError,
@ -49,10 +49,13 @@ class SingleRequest:
class SmartCamProtocol(SmartProtocol):
"""Class for SmartCam Protocol."""
async def _handle_response_lists(
self, response_result: dict[str, Any], method: str, retry_count: int
) -> None:
pass
def _get_list_request(
self, method: str, params: dict | None, start_index: int
) -> dict:
# All smartcam requests have params
params = cast(dict, params)
module_name = next(iter(params))
return {method: {module_name: {"start_index": start_index}}}
def _handle_response_error_code(
self, resp_dict: dict, method: str, raise_on_error: bool = True
@ -147,7 +150,9 @@ class SmartCamProtocol(SmartProtocol):
if len(request) == 1 and method in {"get", "set", "do", "multipleRequest"}:
single_request = self._get_smart_camera_single_request(request)
else:
return await self._execute_multiple_query(request, retry_count)
return await self._execute_multiple_query(
request, retry_count, iterate_list_pages
)
else:
single_request = self._make_smart_camera_single_request(request)

View File

@ -180,7 +180,9 @@ class SmartProtocol(BaseProtocol):
# make mypy happy, this should never be reached..
raise KasaException("Query reached somehow to unreachable")
async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dict:
async def _execute_multiple_query(
self, requests: dict, retry_count: int, iterate_list_pages: bool
) -> dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
multi_result: dict[str, Any] = {}
smart_method = "multipleRequest"
@ -275,8 +277,10 @@ class SmartProtocol(BaseProtocol):
response, method, raise_on_error=raise_on_error
)
result = response.get("result", None)
request_params = rp if (rp := requests.get(method)) else None
if iterate_list_pages and result:
await self._handle_response_lists(
result, method, retry_count=retry_count
result, method, request_params, retry_count=retry_count
)
multi_result[method] = result
@ -303,7 +307,9 @@ class SmartProtocol(BaseProtocol):
smart_method = next(iter(request))
smart_params = request[smart_method]
else:
return await self._execute_multiple_query(request, retry_count)
return await self._execute_multiple_query(
request, retry_count, iterate_list_pages
)
else:
smart_method = request
smart_params = None
@ -330,12 +336,21 @@ class SmartProtocol(BaseProtocol):
result = response_data.get("result")
if iterate_list_pages and result:
await self._handle_response_lists(
result, smart_method, retry_count=retry_count
result, smart_method, smart_params, retry_count=retry_count
)
return {smart_method: result}
def _get_list_request(
self, method: str, params: dict | None, start_index: int
) -> dict:
return {method: {"start_index": start_index}}
async def _handle_response_lists(
self, response_result: dict[str, Any], method: str, retry_count: int
self,
response_result: dict[str, Any],
method: str,
params: dict | None,
retry_count: int,
) -> None:
if (
response_result is None
@ -355,8 +370,9 @@ class SmartProtocol(BaseProtocol):
)
)
while (list_length := len(response_result[response_list_name])) < list_sum:
request = self._get_list_request(method, params, list_length)
response = await self._execute_query(
{method: {"start_index": list_length}},
request,
retry_count=retry_count,
iterate_list_pages=False,
)

View File

@ -33,6 +33,7 @@ class FakeSmartCamTransport(BaseTransport):
*,
list_return_size=10,
is_child=False,
get_child_fixtures=True,
verbatim=False,
components_not_included=False,
):
@ -52,6 +53,9 @@ class FakeSmartCamTransport(BaseTransport):
self.verbatim = verbatim
if not is_child:
self.info = copy.deepcopy(info)
# We don't need to get the child fixtures if testing things like
# lists
if get_child_fixtures:
self.child_protocols = FakeSmartTransport._get_child_protocols(
self.info, self.fixture_name, "getChildDeviceList"
)
@ -229,9 +233,16 @@ class FakeSmartCamTransport(BaseTransport):
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
assert isinstance(params, dict)
module_name = next(iter(params))
start_index = (
start_index
if (params and (start_index := params.get("start_index")))
if (
params
and module_name
and (start_index := params[module_name].get("start_index"))
)
else 0
)

View File

@ -10,6 +10,7 @@ from kasa.exceptions import (
KasaException,
SmartErrorCode,
)
from kasa.protocols.smartcamprotocol import SmartCamProtocol
from kasa.protocols.smartprotocol import SmartProtocol, _ChildProtocolWrapper
from kasa.smart import SmartDevice
@ -373,6 +374,46 @@ async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_siz
assert resp == response
@pytest.mark.parametrize("list_sum", [5, 10, 30])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
async def test_smartcam_protocol_list_request(mocker, list_sum, batch_size):
"""Test smartcam protocol list handling for lists."""
child_list = [{"foo": i} for i in range(list_sum)]
response = {
"getChildDeviceList": {
"child_device_list": child_list,
"start_index": 0,
"sum": list_sum,
},
"getChildDeviceComponentList": {
"child_component_list": child_list,
"start_index": 0,
"sum": list_sum,
},
}
request = {
"getChildDeviceList": {"childControl": {"start_index": 0}},
"getChildDeviceComponentList": {"childControl": {"start_index": 0}},
}
ft = FakeSmartCamTransport(
response,
"foobar",
list_return_size=batch_size,
components_not_included=True,
get_child_fixtures=False,
)
protocol = SmartCamProtocol(transport=ft)
query_spy = mocker.spy(protocol, "_execute_query")
resp = await protocol.query(request)
expected_count = 1 + 2 * (
int(list_sum / batch_size) + (0 if list_sum % batch_size else -1)
)
assert query_spy.call_count == expected_count
assert resp == response
async def test_incomplete_list(mocker, caplog):
"""Test for handling incomplete lists returned from queries."""
info = {