Handle paging of partial responses of lists like child_device_info (#862)

When devices have lists greater than 10 for child devices only the first
10 are returned. This retrieves the rest of the items (currently with
single requests rather than multiple requests)
This commit is contained in:
Steven B 2024-04-24 19:32:30 +01:00 committed by GitHub
parent eff8db450d
commit 53b84b7683
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 131 additions and 8 deletions

View File

@ -67,7 +67,9 @@ class SmartProtocol(BaseProtocol):
async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
return await self._execute_query(
request, retry_count=retry, iterate_list_pages=True
)
except _ConnectionError as sdex:
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
@ -145,6 +147,9 @@ class SmartProtocol(BaseProtocol):
method = response["method"]
self._handle_response_error_code(response, method, raise_on_error=False)
result = response.get("result", None)
await self._handle_response_lists(
result, method, retry_count=retry_count
)
multi_result[method] = result
# Multi requests don't continue after errors so requery any missing
for method, params in requests.items():
@ -156,7 +161,9 @@ class SmartProtocol(BaseProtocol):
multi_result[method] = resp.get("result")
return multi_result
async def _execute_query(self, request: str | dict, retry_count: int) -> dict:
async def _execute_query(
self, request: str | dict, *, retry_count: int, iterate_list_pages: bool = True
) -> dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if isinstance(request, dict):
@ -189,8 +196,40 @@ class SmartProtocol(BaseProtocol):
# Single set_ requests do not return a result
result = response_data.get("result")
if iterate_list_pages and result:
await self._handle_response_lists(
result, smart_method, retry_count=retry_count
)
return {smart_method: result}
async def _handle_response_lists(
self, response_result: dict[str, Any], method, retry_count
):
if (
isinstance(response_result, SmartErrorCode)
or "start_index" not in response_result
or (list_sum := response_result.get("sum")) is None
):
return
response_list_name = next(
iter(
[
key
for key in response_result
if isinstance(response_result[key], list)
]
)
)
while (list_length := len(response_result[response_list_name])) < list_sum:
response = await self._execute_query(
{method: {"start_index": list_length}},
retry_count=retry_count,
iterate_list_pages=False,
)
next_batch = response[method]
response_result[response_list_name].extend(next_batch[response_list_name])
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:

View File

@ -21,7 +21,14 @@ class FakeSmartProtocol(SmartProtocol):
class FakeSmartTransport(BaseTransport):
def __init__(self, info, fixture_name):
def __init__(
self,
info,
fixture_name,
*,
list_return_size=10,
component_nego_not_included=False,
):
super().__init__(
config=DeviceConfig(
"127.0.0.123",
@ -33,10 +40,12 @@ class FakeSmartTransport(BaseTransport):
)
self.fixture_name = fixture_name
self.info = copy.deepcopy(info)
self.components = {
comp["id"]: comp["ver_code"]
for comp in self.info["component_nego"]["component_list"]
}
if not component_nego_not_included:
self.components = {
comp["id"]: comp["ver_code"]
for comp in self.info["component_nego"]["component_list"]
}
self.list_return_size = list_return_size
@property
def default_port(self):
@ -177,7 +186,20 @@ class FakeSmartTransport(BaseTransport):
elif method == "component_nego" or method[:4] == "get_":
if method in info:
result = copy.deepcopy(info[method])
if "start_index" in result and "sum" in result:
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
start_index = (
start_index
if (params and (start_index := params.get("start_index")))
else 0
)
result[list_key] = result[list_key][
start_index : start_index + self.list_return_size
]
return {"result": result, "error_code": 0}
if (
# FIXTURE_MISSING is for service calls not in place when
# SMART fixtures started to be generated

View File

@ -7,7 +7,8 @@ from ..exceptions import (
KasaException,
SmartErrorCode,
)
from ..smartprotocol import _ChildProtocolWrapper
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
from .fakeprotocol_smart import FakeSmartTransport
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
DUMMY_MULTIPLE_QUERY = {
@ -180,3 +181,64 @@ async def test_childdevicewrapper_multiplerequest_error(dummy_protocol, mocker):
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
with pytest.raises(KasaException):
await wrapped_protocol.query(DUMMY_QUERY)
@pytest.mark.parametrize("list_sum", [5, 10, 30])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
async def test_smart_protocol_lists_single_request(mocker, list_sum, batch_size):
child_device_list = [{"foo": i} for i in range(list_sum)]
response = {
"get_child_device_list": {
"child_device_list": child_device_list,
"start_index": 0,
"sum": list_sum,
}
}
request = {"get_child_device_list": None}
ft = FakeSmartTransport(
response,
"foobar",
list_return_size=batch_size,
component_nego_not_included=True,
)
protocol = SmartProtocol(transport=ft)
query_spy = mocker.spy(protocol, "_execute_query")
resp = await protocol.query(request)
expected_count = int(list_sum / batch_size) + (1 if list_sum % batch_size else 0)
assert query_spy.call_count == expected_count
assert resp == response
@pytest.mark.parametrize("list_sum", [5, 10, 30])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_size):
child_list = [{"foo": i} for i in range(list_sum)]
response = {
"get_child_device_list": {
"child_device_list": child_list,
"start_index": 0,
"sum": list_sum,
},
"get_child_device_component_list": {
"child_component_list": child_list,
"start_index": 0,
"sum": list_sum,
},
}
request = {"get_child_device_list": None, "get_child_device_component_list": None}
ft = FakeSmartTransport(
response,
"foobar",
list_return_size=batch_size,
component_nego_not_included=True,
)
protocol = SmartProtocol(transport=ft)
query_spy = mocker.spy(protocol, "_execute_query")
resp = await protocol.query(request)
expected_count = 1 + 2 * (
int(list_sum / batch_size) + (0 if list_sum % batch_size else -1)
)
assert query_spy.call_count == expected_count
assert resp == response