mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-04-27 09:06:24 +00:00
Handle paging of partial responses of lists like child_device_info (#862)
When devices have lists greater than 10 for child devices only the first 10 are returned. This retrieves the rest of the items (currently with single requests rather than multiple requests)
This commit is contained in:
parent
eff8db450d
commit
53b84b7683
@ -67,7 +67,9 @@ class SmartProtocol(BaseProtocol):
|
|||||||
async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
|
async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
|
||||||
for retry in range(retry_count + 1):
|
for retry in range(retry_count + 1):
|
||||||
try:
|
try:
|
||||||
return await self._execute_query(request, retry)
|
return await self._execute_query(
|
||||||
|
request, retry_count=retry, iterate_list_pages=True
|
||||||
|
)
|
||||||
except _ConnectionError as sdex:
|
except _ConnectionError as sdex:
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
@ -145,6 +147,9 @@ class SmartProtocol(BaseProtocol):
|
|||||||
method = response["method"]
|
method = response["method"]
|
||||||
self._handle_response_error_code(response, method, raise_on_error=False)
|
self._handle_response_error_code(response, method, raise_on_error=False)
|
||||||
result = response.get("result", None)
|
result = response.get("result", None)
|
||||||
|
await self._handle_response_lists(
|
||||||
|
result, method, retry_count=retry_count
|
||||||
|
)
|
||||||
multi_result[method] = result
|
multi_result[method] = result
|
||||||
# Multi requests don't continue after errors so requery any missing
|
# Multi requests don't continue after errors so requery any missing
|
||||||
for method, params in requests.items():
|
for method, params in requests.items():
|
||||||
@ -156,7 +161,9 @@ class SmartProtocol(BaseProtocol):
|
|||||||
multi_result[method] = resp.get("result")
|
multi_result[method] = resp.get("result")
|
||||||
return multi_result
|
return multi_result
|
||||||
|
|
||||||
async def _execute_query(self, request: str | dict, retry_count: int) -> dict:
|
async def _execute_query(
|
||||||
|
self, request: str | dict, *, retry_count: int, iterate_list_pages: bool = True
|
||||||
|
) -> dict:
|
||||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||||
|
|
||||||
if isinstance(request, dict):
|
if isinstance(request, dict):
|
||||||
@ -189,8 +196,40 @@ class SmartProtocol(BaseProtocol):
|
|||||||
|
|
||||||
# Single set_ requests do not return a result
|
# Single set_ requests do not return a result
|
||||||
result = response_data.get("result")
|
result = response_data.get("result")
|
||||||
|
if iterate_list_pages and result:
|
||||||
|
await self._handle_response_lists(
|
||||||
|
result, smart_method, retry_count=retry_count
|
||||||
|
)
|
||||||
return {smart_method: result}
|
return {smart_method: result}
|
||||||
|
|
||||||
|
async def _handle_response_lists(
|
||||||
|
self, response_result: dict[str, Any], method, retry_count
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
isinstance(response_result, SmartErrorCode)
|
||||||
|
or "start_index" not in response_result
|
||||||
|
or (list_sum := response_result.get("sum")) is None
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
response_list_name = next(
|
||||||
|
iter(
|
||||||
|
[
|
||||||
|
key
|
||||||
|
for key in response_result
|
||||||
|
if isinstance(response_result[key], list)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
while (list_length := len(response_result[response_list_name])) < list_sum:
|
||||||
|
response = await self._execute_query(
|
||||||
|
{method: {"start_index": list_length}},
|
||||||
|
retry_count=retry_count,
|
||||||
|
iterate_list_pages=False,
|
||||||
|
)
|
||||||
|
next_batch = response[method]
|
||||||
|
response_result[response_list_name].extend(next_batch[response_list_name])
|
||||||
|
|
||||||
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
|
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
|
||||||
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
if error_code == SmartErrorCode.SUCCESS:
|
if error_code == SmartErrorCode.SUCCESS:
|
||||||
|
@ -21,7 +21,14 @@ class FakeSmartProtocol(SmartProtocol):
|
|||||||
|
|
||||||
|
|
||||||
class FakeSmartTransport(BaseTransport):
|
class FakeSmartTransport(BaseTransport):
|
||||||
def __init__(self, info, fixture_name):
|
def __init__(
|
||||||
|
self,
|
||||||
|
info,
|
||||||
|
fixture_name,
|
||||||
|
*,
|
||||||
|
list_return_size=10,
|
||||||
|
component_nego_not_included=False,
|
||||||
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
config=DeviceConfig(
|
config=DeviceConfig(
|
||||||
"127.0.0.123",
|
"127.0.0.123",
|
||||||
@ -33,10 +40,12 @@ class FakeSmartTransport(BaseTransport):
|
|||||||
)
|
)
|
||||||
self.fixture_name = fixture_name
|
self.fixture_name = fixture_name
|
||||||
self.info = copy.deepcopy(info)
|
self.info = copy.deepcopy(info)
|
||||||
|
if not component_nego_not_included:
|
||||||
self.components = {
|
self.components = {
|
||||||
comp["id"]: comp["ver_code"]
|
comp["id"]: comp["ver_code"]
|
||||||
for comp in self.info["component_nego"]["component_list"]
|
for comp in self.info["component_nego"]["component_list"]
|
||||||
}
|
}
|
||||||
|
self.list_return_size = list_return_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_port(self):
|
def default_port(self):
|
||||||
@ -177,7 +186,20 @@ class FakeSmartTransport(BaseTransport):
|
|||||||
elif method == "component_nego" or method[:4] == "get_":
|
elif method == "component_nego" or method[:4] == "get_":
|
||||||
if method in info:
|
if method in info:
|
||||||
result = copy.deepcopy(info[method])
|
result = copy.deepcopy(info[method])
|
||||||
|
if "start_index" in result and "sum" in result:
|
||||||
|
list_key = next(
|
||||||
|
iter([key for key in result if isinstance(result[key], list)])
|
||||||
|
)
|
||||||
|
start_index = (
|
||||||
|
start_index
|
||||||
|
if (params and (start_index := params.get("start_index")))
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
result[list_key] = result[list_key][
|
||||||
|
start_index : start_index + self.list_return_size
|
||||||
|
]
|
||||||
return {"result": result, "error_code": 0}
|
return {"result": result, "error_code": 0}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
# FIXTURE_MISSING is for service calls not in place when
|
# FIXTURE_MISSING is for service calls not in place when
|
||||||
# SMART fixtures started to be generated
|
# SMART fixtures started to be generated
|
||||||
|
@ -7,7 +7,8 @@ from ..exceptions import (
|
|||||||
KasaException,
|
KasaException,
|
||||||
SmartErrorCode,
|
SmartErrorCode,
|
||||||
)
|
)
|
||||||
from ..smartprotocol import _ChildProtocolWrapper
|
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
|
||||||
|
from .fakeprotocol_smart import FakeSmartTransport
|
||||||
|
|
||||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||||
DUMMY_MULTIPLE_QUERY = {
|
DUMMY_MULTIPLE_QUERY = {
|
||||||
@ -180,3 +181,64 @@ async def test_childdevicewrapper_multiplerequest_error(dummy_protocol, mocker):
|
|||||||
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
|
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
|
||||||
with pytest.raises(KasaException):
|
with pytest.raises(KasaException):
|
||||||
await wrapped_protocol.query(DUMMY_QUERY)
|
await wrapped_protocol.query(DUMMY_QUERY)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("list_sum", [5, 10, 30])
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
|
||||||
|
async def test_smart_protocol_lists_single_request(mocker, list_sum, batch_size):
|
||||||
|
child_device_list = [{"foo": i} for i in range(list_sum)]
|
||||||
|
response = {
|
||||||
|
"get_child_device_list": {
|
||||||
|
"child_device_list": child_device_list,
|
||||||
|
"start_index": 0,
|
||||||
|
"sum": list_sum,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request = {"get_child_device_list": None}
|
||||||
|
|
||||||
|
ft = FakeSmartTransport(
|
||||||
|
response,
|
||||||
|
"foobar",
|
||||||
|
list_return_size=batch_size,
|
||||||
|
component_nego_not_included=True,
|
||||||
|
)
|
||||||
|
protocol = SmartProtocol(transport=ft)
|
||||||
|
query_spy = mocker.spy(protocol, "_execute_query")
|
||||||
|
resp = await protocol.query(request)
|
||||||
|
expected_count = int(list_sum / batch_size) + (1 if list_sum % batch_size else 0)
|
||||||
|
assert query_spy.call_count == expected_count
|
||||||
|
assert resp == response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("list_sum", [5, 10, 30])
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
|
||||||
|
async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_size):
|
||||||
|
child_list = [{"foo": i} for i in range(list_sum)]
|
||||||
|
response = {
|
||||||
|
"get_child_device_list": {
|
||||||
|
"child_device_list": child_list,
|
||||||
|
"start_index": 0,
|
||||||
|
"sum": list_sum,
|
||||||
|
},
|
||||||
|
"get_child_device_component_list": {
|
||||||
|
"child_component_list": child_list,
|
||||||
|
"start_index": 0,
|
||||||
|
"sum": list_sum,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
request = {"get_child_device_list": None, "get_child_device_component_list": None}
|
||||||
|
|
||||||
|
ft = FakeSmartTransport(
|
||||||
|
response,
|
||||||
|
"foobar",
|
||||||
|
list_return_size=batch_size,
|
||||||
|
component_nego_not_included=True,
|
||||||
|
)
|
||||||
|
protocol = SmartProtocol(transport=ft)
|
||||||
|
query_spy = mocker.spy(protocol, "_execute_query")
|
||||||
|
resp = await protocol.query(request)
|
||||||
|
expected_count = 1 + 2 * (
|
||||||
|
int(list_sum / batch_size) + (0 if list_sum % batch_size else -1)
|
||||||
|
)
|
||||||
|
assert query_spy.call_count == expected_count
|
||||||
|
assert resp == response
|
||||||
|
Loading…
x
Reference in New Issue
Block a user