mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-09 22:37:08 +00:00
Handle smartcam partial list responses (#1411)
This commit is contained in:
parent
1f45f425a0
commit
6aa019280b
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat as pf
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from ..exceptions import (
|
||||
AuthenticationError,
|
||||
@ -49,10 +49,13 @@ class SingleRequest:
|
||||
class SmartCamProtocol(SmartProtocol):
|
||||
"""Class for SmartCam Protocol."""
|
||||
|
||||
async def _handle_response_lists(
|
||||
self, response_result: dict[str, Any], method: str, retry_count: int
|
||||
) -> None:
|
||||
pass
|
||||
def _get_list_request(
|
||||
self, method: str, params: dict | None, start_index: int
|
||||
) -> dict:
|
||||
# All smartcam requests have params
|
||||
params = cast(dict, params)
|
||||
module_name = next(iter(params))
|
||||
return {method: {module_name: {"start_index": start_index}}}
|
||||
|
||||
def _handle_response_error_code(
|
||||
self, resp_dict: dict, method: str, raise_on_error: bool = True
|
||||
@ -147,7 +150,9 @@ class SmartCamProtocol(SmartProtocol):
|
||||
if len(request) == 1 and method in {"get", "set", "do", "multipleRequest"}:
|
||||
single_request = self._get_smart_camera_single_request(request)
|
||||
else:
|
||||
return await self._execute_multiple_query(request, retry_count)
|
||||
return await self._execute_multiple_query(
|
||||
request, retry_count, iterate_list_pages
|
||||
)
|
||||
else:
|
||||
single_request = self._make_smart_camera_single_request(request)
|
||||
|
||||
|
@ -180,7 +180,9 @@ class SmartProtocol(BaseProtocol):
|
||||
# make mypy happy, this should never be reached..
|
||||
raise KasaException("Query reached somehow to unreachable")
|
||||
|
||||
async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dict:
|
||||
async def _execute_multiple_query(
|
||||
self, requests: dict, retry_count: int, iterate_list_pages: bool
|
||||
) -> dict:
|
||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
multi_result: dict[str, Any] = {}
|
||||
smart_method = "multipleRequest"
|
||||
@ -275,9 +277,11 @@ class SmartProtocol(BaseProtocol):
|
||||
response, method, raise_on_error=raise_on_error
|
||||
)
|
||||
result = response.get("result", None)
|
||||
await self._handle_response_lists(
|
||||
result, method, retry_count=retry_count
|
||||
)
|
||||
request_params = rp if (rp := requests.get(method)) else None
|
||||
if iterate_list_pages and result:
|
||||
await self._handle_response_lists(
|
||||
result, method, request_params, retry_count=retry_count
|
||||
)
|
||||
multi_result[method] = result
|
||||
|
||||
# Multi requests don't continue after errors so requery any missing.
|
||||
@ -303,7 +307,9 @@ class SmartProtocol(BaseProtocol):
|
||||
smart_method = next(iter(request))
|
||||
smart_params = request[smart_method]
|
||||
else:
|
||||
return await self._execute_multiple_query(request, retry_count)
|
||||
return await self._execute_multiple_query(
|
||||
request, retry_count, iterate_list_pages
|
||||
)
|
||||
else:
|
||||
smart_method = request
|
||||
smart_params = None
|
||||
@ -330,12 +336,21 @@ class SmartProtocol(BaseProtocol):
|
||||
result = response_data.get("result")
|
||||
if iterate_list_pages and result:
|
||||
await self._handle_response_lists(
|
||||
result, smart_method, retry_count=retry_count
|
||||
result, smart_method, smart_params, retry_count=retry_count
|
||||
)
|
||||
return {smart_method: result}
|
||||
|
||||
def _get_list_request(
|
||||
self, method: str, params: dict | None, start_index: int
|
||||
) -> dict:
|
||||
return {method: {"start_index": start_index}}
|
||||
|
||||
async def _handle_response_lists(
|
||||
self, response_result: dict[str, Any], method: str, retry_count: int
|
||||
self,
|
||||
response_result: dict[str, Any],
|
||||
method: str,
|
||||
params: dict | None,
|
||||
retry_count: int,
|
||||
) -> None:
|
||||
if (
|
||||
response_result is None
|
||||
@ -355,8 +370,9 @@ class SmartProtocol(BaseProtocol):
|
||||
)
|
||||
)
|
||||
while (list_length := len(response_result[response_list_name])) < list_sum:
|
||||
request = self._get_list_request(method, params, list_length)
|
||||
response = await self._execute_query(
|
||||
{method: {"start_index": list_length}},
|
||||
request,
|
||||
retry_count=retry_count,
|
||||
iterate_list_pages=False,
|
||||
)
|
||||
|
@ -33,6 +33,7 @@ class FakeSmartCamTransport(BaseTransport):
|
||||
*,
|
||||
list_return_size=10,
|
||||
is_child=False,
|
||||
get_child_fixtures=True,
|
||||
verbatim=False,
|
||||
components_not_included=False,
|
||||
):
|
||||
@ -52,9 +53,12 @@ class FakeSmartCamTransport(BaseTransport):
|
||||
self.verbatim = verbatim
|
||||
if not is_child:
|
||||
self.info = copy.deepcopy(info)
|
||||
self.child_protocols = FakeSmartTransport._get_child_protocols(
|
||||
self.info, self.fixture_name, "getChildDeviceList"
|
||||
)
|
||||
# We don't need to get the child fixtures if testing things like
|
||||
# lists
|
||||
if get_child_fixtures:
|
||||
self.child_protocols = FakeSmartTransport._get_child_protocols(
|
||||
self.info, self.fixture_name, "getChildDeviceList"
|
||||
)
|
||||
else:
|
||||
self.info = info
|
||||
# self.child_protocols = self._get_child_protocols()
|
||||
@ -229,9 +233,16 @@ class FakeSmartCamTransport(BaseTransport):
|
||||
list_key = next(
|
||||
iter([key for key in result if isinstance(result[key], list)])
|
||||
)
|
||||
assert isinstance(params, dict)
|
||||
module_name = next(iter(params))
|
||||
|
||||
start_index = (
|
||||
start_index
|
||||
if (params and (start_index := params.get("start_index")))
|
||||
if (
|
||||
params
|
||||
and module_name
|
||||
and (start_index := params[module_name].get("start_index"))
|
||||
)
|
||||
else 0
|
||||
)
|
||||
|
||||
|
@ -10,6 +10,7 @@ from kasa.exceptions import (
|
||||
KasaException,
|
||||
SmartErrorCode,
|
||||
)
|
||||
from kasa.protocols.smartcamprotocol import SmartCamProtocol
|
||||
from kasa.protocols.smartprotocol import SmartProtocol, _ChildProtocolWrapper
|
||||
from kasa.smart import SmartDevice
|
||||
|
||||
@ -373,6 +374,46 @@ async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_siz
|
||||
assert resp == response
|
||||
|
||||
|
||||
@pytest.mark.parametrize("list_sum", [5, 10, 30])
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
|
||||
async def test_smartcam_protocol_list_request(mocker, list_sum, batch_size):
|
||||
"""Test smartcam protocol list handling for lists."""
|
||||
child_list = [{"foo": i} for i in range(list_sum)]
|
||||
|
||||
response = {
|
||||
"getChildDeviceList": {
|
||||
"child_device_list": child_list,
|
||||
"start_index": 0,
|
||||
"sum": list_sum,
|
||||
},
|
||||
"getChildDeviceComponentList": {
|
||||
"child_component_list": child_list,
|
||||
"start_index": 0,
|
||||
"sum": list_sum,
|
||||
},
|
||||
}
|
||||
request = {
|
||||
"getChildDeviceList": {"childControl": {"start_index": 0}},
|
||||
"getChildDeviceComponentList": {"childControl": {"start_index": 0}},
|
||||
}
|
||||
|
||||
ft = FakeSmartCamTransport(
|
||||
response,
|
||||
"foobar",
|
||||
list_return_size=batch_size,
|
||||
components_not_included=True,
|
||||
get_child_fixtures=False,
|
||||
)
|
||||
protocol = SmartCamProtocol(transport=ft)
|
||||
query_spy = mocker.spy(protocol, "_execute_query")
|
||||
resp = await protocol.query(request)
|
||||
expected_count = 1 + 2 * (
|
||||
int(list_sum / batch_size) + (0 if list_sum % batch_size else -1)
|
||||
)
|
||||
assert query_spy.call_count == expected_count
|
||||
assert resp == response
|
||||
|
||||
|
||||
async def test_incomplete_list(mocker, caplog):
|
||||
"""Test for handling incomplete lists returned from queries."""
|
||||
info = {
|
||||
|
Loading…
Reference in New Issue
Block a user