From 214b26a1ea99d890d2b5620c26054d1d3a776cee Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Sat, 20 Apr 2024 16:24:49 +0100 Subject: [PATCH] Re-query missing responses after multi request errors (#850) When smart devices encounter an error during a multipleRequest they return the previous successes and the current error and stop processing subsequent requests. This checks the responses returned and re-queries individually for any missing responses so that individual errors do not break other components. --- kasa/smartprotocol.py | 18 +++++++++++++----- kasa/tests/fakeprotocol_smart.py | 3 +++ kasa/tests/test_smartprotocol.py | 8 +++++++- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 3020a575..9a1482b1 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -105,21 +105,21 @@ class SmartProtocol(BaseProtocol): # make mypy happy, this should never be reached.. raise KasaException("Query reached somehow to unreachable") - async def _execute_multiple_query(self, request: dict, retry_count: int) -> dict: + async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dict: debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) multi_result: dict[str, Any] = {} smart_method = "multipleRequest" - requests = [ - {"method": method, "params": params} for method, params in request.items() + multi_requests = [ + {"method": method, "params": params} for method, params in requests.items() ] - end = len(requests) + end = len(multi_requests) # Break the requests down as there can be a size limit step = ( self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE ) for i in range(0, end, step): - requests_step = requests[i : i + step] + requests_step = multi_requests[i : i + step] smart_params = {"requests": requests_step} smart_request = self.get_smart_request(smart_method, smart_params) @@ -146,6 +146,14 @@ class SmartProtocol(BaseProtocol): self._handle_response_error_code(response, method, raise_on_error=False) result = response.get("result", None) multi_result[method] = result + # Multi requests don't continue after errors so requery any missing + for method, params in requests.items(): + if method not in multi_result: + resp = await self._transport.send( + self.get_smart_request(method, params) + ) + self._handle_response_error_code(resp, method, raise_on_error=False) + multi_result[method] = resp.get("result") return multi_result async def _execute_query(self, request: str | dict, retry_count: int) -> dict: diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index 024e7636..d03d04c4 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -113,6 +113,9 @@ class FakeSmartTransport(BaseTransport): responses = [] for request in params["requests"]: response = self._send_request(request) # type: ignore[arg-type] + # Devices do not continue after error + if response["error_code"] != 0: + break response["method"] = request["method"] # type: ignore[index] responses.append(response) return {"result": {"responses": responses}, "error_code": 0} diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index 541d17c9..b970eaa5 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -36,6 +36,11 @@ async def test_smart_device_errors(dummy_protocol, mocker, error_code): async def test_smart_device_errors_in_multiple_request( dummy_protocol, mocker, error_code ): + mock_request = { + "foobar1": {"foo": "bar", "bar": "foo"}, + "foobar2": {"foo": "bar", "bar": "foo"}, + "foobar3": {"foo": "bar", "bar": "foo"}, + } mock_response = { "result": { "responses": [ @@ -55,9 +60,10 @@ async def test_smart_device_errors_in_multiple_request( dummy_protocol._transport, "send", return_value=mock_response ) - resp_dict = await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2) + resp_dict = await dummy_protocol.query(mock_request, retry_count=2) assert resp_dict["foobar2"] == error_code assert send_mock.call_count == 1 + assert len(resp_dict) == len(mock_request) @pytest.mark.parametrize("request_size", [1, 3, 5, 10])