mirror of https://github.com/python-kasa/python-kasa.git
Re-query missing responses after multi request errors (#850)
When smart devices encounter an error during a multipleRequest, they return the previous successes plus the current error and stop processing the subsequent requests. This change checks the responses that were returned and re-queries any missing ones individually, so that an error in one component does not break the other components.
parent aeb2c923c6
commit 214b26a1ea
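
For context, here is a minimal, self-contained sketch of the behavior this change addresses, assuming simplified payload shapes modelled on the fake transport further down in the diff. The names send, query_all, get_device_info, get_bad_module, and get_energy_usage are illustrative stand-ins, not the library's API: the point is only that the device answers a multipleRequest up to and including the first failing sub-request, so anything after it has to be re-queried individually and merged into the same result dict.

import asyncio
from typing import Any

# Hypothetical device transport: a multipleRequest is answered only up to and
# including the first failing sub-request; later sub-requests get no response.
async def send(request: dict[str, Any]) -> dict[str, Any]:
    if request["method"] == "multipleRequest":
        return {
            "error_code": 0,
            "result": {
                "responses": [
                    {"method": "get_device_info", "error_code": 0, "result": {"model": "example"}},
                    {"method": "get_bad_module", "error_code": -1001},
                    # "get_energy_usage" was never processed by the device
                ]
            },
        }
    # A method re-queried on its own succeeds here.
    return {"error_code": 0, "result": {"requeried": request["method"]}}

async def query_all(requests: dict[str, Any]) -> dict[str, Any]:
    """Send everything as one multipleRequest, then re-query whatever is missing."""
    multi_result: dict[str, Any] = {}
    multi = [{"method": m, "params": p} for m, p in requests.items()]
    response = await send({"method": "multipleRequest", "params": {"requests": multi}})
    for sub in response["result"]["responses"]:
        # Keep the error code for failed sub-requests instead of raising,
        # so one bad component does not hide the others.
        multi_result[sub["method"]] = sub.get("result", sub["error_code"])
    # The device stopped at the first error, so re-query any missing method individually.
    for method, params in requests.items():
        if method not in multi_result:
            resp = await send({"method": method, "params": params})
            multi_result[method] = resp.get("result")
    return multi_result

if __name__ == "__main__":
    result = asyncio.run(
        query_all({"get_device_info": {}, "get_bad_module": {}, "get_energy_usage": {}})
    )
    # All three methods are present: two results and one error code.
    print(result)
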
@@ -105,21 +105,21 @@ class SmartProtocol(BaseProtocol):
         # make mypy happy, this should never be reached..
         raise KasaException("Query reached somehow to unreachable")
 
-    async def _execute_multiple_query(self, request: dict, retry_count: int) -> dict:
+    async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dict:
         debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
         multi_result: dict[str, Any] = {}
         smart_method = "multipleRequest"
-        requests = [
-            {"method": method, "params": params} for method, params in request.items()
+        multi_requests = [
+            {"method": method, "params": params} for method, params in requests.items()
         ]
 
-        end = len(requests)
+        end = len(multi_requests)
         # Break the requests down as there can be a size limit
         step = (
             self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
         )
         for i in range(0, end, step):
-            requests_step = requests[i : i + step]
+            requests_step = multi_requests[i : i + step]
 
             smart_params = {"requests": requests_step}
             smart_request = self.get_smart_request(smart_method, smart_params)
@@ -146,6 +146,14 @@ class SmartProtocol(BaseProtocol):
                 self._handle_response_error_code(response, method, raise_on_error=False)
                 result = response.get("result", None)
                 multi_result[method] = result
+        # Multi requests don't continue after errors so requery any missing
+        for method, params in requests.items():
+            if method not in multi_result:
+                resp = await self._transport.send(
+                    self.get_smart_request(method, params)
+                )
+                self._handle_response_error_code(resp, method, raise_on_error=False)
+                multi_result[method] = resp.get("result")
         return multi_result
 
     async def _execute_query(self, request: str | dict, retry_count: int) -> dict:
@@ -113,6 +113,9 @@ class FakeSmartTransport(BaseTransport):
            responses = []
            for request in params["requests"]:
                response = self._send_request(request)  # type: ignore[arg-type]
+               # Devices do not continue after error
+               if response["error_code"] != 0:
+                   break
                response["method"] = request["method"]  # type: ignore[index]
                responses.append(response)
            return {"result": {"responses": responses}, "error_code": 0}
@@ -36,6 +36,11 @@ async def test_smart_device_errors(dummy_protocol, mocker, error_code):
 async def test_smart_device_errors_in_multiple_request(
     dummy_protocol, mocker, error_code
 ):
+    mock_request = {
+        "foobar1": {"foo": "bar", "bar": "foo"},
+        "foobar2": {"foo": "bar", "bar": "foo"},
+        "foobar3": {"foo": "bar", "bar": "foo"},
+    }
     mock_response = {
         "result": {
             "responses": [
@@ -55,9 +60,10 @@ async def test_smart_device_errors_in_multiple_request(
         dummy_protocol._transport, "send", return_value=mock_response
     )
 
-    resp_dict = await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
+    resp_dict = await dummy_protocol.query(mock_request, retry_count=2)
     assert resp_dict["foobar2"] == error_code
     assert send_mock.call_count == 1
+    assert len(resp_dict) == len(mock_request)
 
 
 @pytest.mark.parametrize("request_size", [1, 3, 5, 10])