Handle smartcam partial list responses (#1411)

This commit is contained in:
Steven B. 2025-01-06 09:23:46 +00:00 committed by GitHub
parent 1f45f425a0
commit 6aa019280b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 91 additions and 18 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from pprint import pformat as pf from pprint import pformat as pf
from typing import Any from typing import Any, cast
from ..exceptions import ( from ..exceptions import (
AuthenticationError, AuthenticationError,
@ -49,10 +49,13 @@ class SingleRequest:
class SmartCamProtocol(SmartProtocol): class SmartCamProtocol(SmartProtocol):
"""Class for SmartCam Protocol.""" """Class for SmartCam Protocol."""
async def _handle_response_lists( def _get_list_request(
self, response_result: dict[str, Any], method: str, retry_count: int self, method: str, params: dict | None, start_index: int
) -> None: ) -> dict:
pass # All smartcam requests have params
params = cast(dict, params)
module_name = next(iter(params))
return {method: {module_name: {"start_index": start_index}}}
def _handle_response_error_code( def _handle_response_error_code(
self, resp_dict: dict, method: str, raise_on_error: bool = True self, resp_dict: dict, method: str, raise_on_error: bool = True
@ -147,7 +150,9 @@ class SmartCamProtocol(SmartProtocol):
if len(request) == 1 and method in {"get", "set", "do", "multipleRequest"}: if len(request) == 1 and method in {"get", "set", "do", "multipleRequest"}:
single_request = self._get_smart_camera_single_request(request) single_request = self._get_smart_camera_single_request(request)
else: else:
return await self._execute_multiple_query(request, retry_count) return await self._execute_multiple_query(
request, retry_count, iterate_list_pages
)
else: else:
single_request = self._make_smart_camera_single_request(request) single_request = self._make_smart_camera_single_request(request)

View File

@ -180,7 +180,9 @@ class SmartProtocol(BaseProtocol):
# make mypy happy, this should never be reached.. # make mypy happy, this should never be reached..
raise KasaException("Query reached somehow to unreachable") raise KasaException("Query reached somehow to unreachable")
async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dict: async def _execute_multiple_query(
self, requests: dict, retry_count: int, iterate_list_pages: bool
) -> dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
multi_result: dict[str, Any] = {} multi_result: dict[str, Any] = {}
smart_method = "multipleRequest" smart_method = "multipleRequest"
@ -275,8 +277,10 @@ class SmartProtocol(BaseProtocol):
response, method, raise_on_error=raise_on_error response, method, raise_on_error=raise_on_error
) )
result = response.get("result", None) result = response.get("result", None)
request_params = rp if (rp := requests.get(method)) else None
if iterate_list_pages and result:
await self._handle_response_lists( await self._handle_response_lists(
result, method, retry_count=retry_count result, method, request_params, retry_count=retry_count
) )
multi_result[method] = result multi_result[method] = result
@ -303,7 +307,9 @@ class SmartProtocol(BaseProtocol):
smart_method = next(iter(request)) smart_method = next(iter(request))
smart_params = request[smart_method] smart_params = request[smart_method]
else: else:
return await self._execute_multiple_query(request, retry_count) return await self._execute_multiple_query(
request, retry_count, iterate_list_pages
)
else: else:
smart_method = request smart_method = request
smart_params = None smart_params = None
@ -330,12 +336,21 @@ class SmartProtocol(BaseProtocol):
result = response_data.get("result") result = response_data.get("result")
if iterate_list_pages and result: if iterate_list_pages and result:
await self._handle_response_lists( await self._handle_response_lists(
result, smart_method, retry_count=retry_count result, smart_method, smart_params, retry_count=retry_count
) )
return {smart_method: result} return {smart_method: result}
def _get_list_request(
self, method: str, params: dict | None, start_index: int
) -> dict:
return {method: {"start_index": start_index}}
async def _handle_response_lists( async def _handle_response_lists(
self, response_result: dict[str, Any], method: str, retry_count: int self,
response_result: dict[str, Any],
method: str,
params: dict | None,
retry_count: int,
) -> None: ) -> None:
if ( if (
response_result is None response_result is None
@ -355,8 +370,9 @@ class SmartProtocol(BaseProtocol):
) )
) )
while (list_length := len(response_result[response_list_name])) < list_sum: while (list_length := len(response_result[response_list_name])) < list_sum:
request = self._get_list_request(method, params, list_length)
response = await self._execute_query( response = await self._execute_query(
{method: {"start_index": list_length}}, request,
retry_count=retry_count, retry_count=retry_count,
iterate_list_pages=False, iterate_list_pages=False,
) )

View File

@ -33,6 +33,7 @@ class FakeSmartCamTransport(BaseTransport):
*, *,
list_return_size=10, list_return_size=10,
is_child=False, is_child=False,
get_child_fixtures=True,
verbatim=False, verbatim=False,
components_not_included=False, components_not_included=False,
): ):
@ -52,6 +53,9 @@ class FakeSmartCamTransport(BaseTransport):
self.verbatim = verbatim self.verbatim = verbatim
if not is_child: if not is_child:
self.info = copy.deepcopy(info) self.info = copy.deepcopy(info)
# We don't need to get the child fixtures if testing things like
# lists
if get_child_fixtures:
self.child_protocols = FakeSmartTransport._get_child_protocols( self.child_protocols = FakeSmartTransport._get_child_protocols(
self.info, self.fixture_name, "getChildDeviceList" self.info, self.fixture_name, "getChildDeviceList"
) )
@ -229,9 +233,16 @@ class FakeSmartCamTransport(BaseTransport):
list_key = next( list_key = next(
iter([key for key in result if isinstance(result[key], list)]) iter([key for key in result if isinstance(result[key], list)])
) )
assert isinstance(params, dict)
module_name = next(iter(params))
start_index = ( start_index = (
start_index start_index
if (params and (start_index := params.get("start_index"))) if (
params
and module_name
and (start_index := params[module_name].get("start_index"))
)
else 0 else 0
) )

View File

@ -10,6 +10,7 @@ from kasa.exceptions import (
KasaException, KasaException,
SmartErrorCode, SmartErrorCode,
) )
from kasa.protocols.smartcamprotocol import SmartCamProtocol
from kasa.protocols.smartprotocol import SmartProtocol, _ChildProtocolWrapper from kasa.protocols.smartprotocol import SmartProtocol, _ChildProtocolWrapper
from kasa.smart import SmartDevice from kasa.smart import SmartDevice
@ -373,6 +374,46 @@ async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_siz
assert resp == response assert resp == response
@pytest.mark.parametrize("list_sum", [5, 10, 30])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
async def test_smartcam_protocol_list_request(mocker, list_sum, batch_size):
"""Test smartcam protocol list handling for lists."""
child_list = [{"foo": i} for i in range(list_sum)]
response = {
"getChildDeviceList": {
"child_device_list": child_list,
"start_index": 0,
"sum": list_sum,
},
"getChildDeviceComponentList": {
"child_component_list": child_list,
"start_index": 0,
"sum": list_sum,
},
}
request = {
"getChildDeviceList": {"childControl": {"start_index": 0}},
"getChildDeviceComponentList": {"childControl": {"start_index": 0}},
}
ft = FakeSmartCamTransport(
response,
"foobar",
list_return_size=batch_size,
components_not_included=True,
get_child_fixtures=False,
)
protocol = SmartCamProtocol(transport=ft)
query_spy = mocker.spy(protocol, "_execute_query")
resp = await protocol.query(request)
expected_count = 1 + 2 * (
int(list_sum / batch_size) + (0 if list_sum % batch_size else -1)
)
assert query_spy.call_count == expected_count
assert resp == response
async def test_incomplete_list(mocker, caplog): async def test_incomplete_list(mocker, caplog):
"""Test for handling incomplete lists returned from queries.""" """Test for handling incomplete lists returned from queries."""
info = { info = {