diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 22fd49dc..f7551e33 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -47,6 +47,9 @@ class SmartProtocol(BaseProtocol): self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode() self._request_id_generator = SnowflakeId(1, 1) self._query_lock = asyncio.Lock() + self._multi_request_batch_size = ( + self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE + ) def get_smart_request(self, method, params=None) -> str: """Get a request message as a string.""" @@ -117,9 +120,16 @@ class SmartProtocol(BaseProtocol): end = len(multi_requests) # Break the requests down as there can be a size limit - step = ( - self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE - ) + step = self._multi_request_batch_size + if step == 1: + # If step is 1 do not send request batches + for request in multi_requests: + method = request["method"] + req = self.get_smart_request(method, request["params"]) + resp = await self._transport.send(req) + self._handle_response_error_code(resp, method, raise_on_error=False) + multi_result[method] = resp["result"] + return multi_result for i in range(0, end, step): requests_step = multi_requests[i : i + step] @@ -141,7 +151,21 @@ class SmartProtocol(BaseProtocol): batch_name, pf(response_step), ) - self._handle_response_error_code(response_step, batch_name) + try: + self._handle_response_error_code(response_step, batch_name) + except DeviceError as ex: + # P100 sometimes raises JSON_DECODE_FAIL_ERROR on batched request so + # disable batching + if ( + ex.error_code is SmartErrorCode.JSON_DECODE_FAIL_ERROR + and self._multi_request_batch_size != 1 + ): + self._multi_request_batch_size = 1 + raise _RetryableError( + "JSON Decode failure, multi requests disabled" + ) from ex + raise ex + responses = response_step["result"]["responses"] for response in responses: method = response["method"] diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index 5ead00d6..d362fd00 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -2,10 +2,9 @@ import logging import pytest -from ..credentials import Credentials -from ..deviceconfig import DeviceConfig from ..exceptions import ( SMART_RETRYABLE_ERRORS, + DeviceError, KasaException, SmartErrorCode, ) @@ -93,7 +92,6 @@ async def test_smart_device_errors_in_multiple_request( async def test_smart_device_multiple_request( dummy_protocol, mocker, request_size, batch_size ): - host = "127.0.0.1" requests = {} mock_response = { "result": {"responses": []}, @@ -109,16 +107,101 @@ async def test_smart_device_multiple_request( send_mock = mocker.patch.object( dummy_protocol._transport, "send", return_value=mock_response ) - config = DeviceConfig( - host, credentials=Credentials("foo", "bar"), batch_size=batch_size - ) - dummy_protocol._transport._config = config + dummy_protocol._multi_request_batch_size = batch_size await dummy_protocol.query(requests, retry_count=0) expected_count = int(request_size / batch_size) + (request_size % batch_size > 0) assert send_mock.call_count == expected_count +async def test_smart_device_multiple_request_json_decode_failure( + dummy_protocol, mocker +): + """Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR.""" + requests = {} + mock_responses = [] + + mock_json_error = { + "result": {"responses": []}, + "error_code": SmartErrorCode.JSON_DECODE_FAIL_ERROR.value, + } + for i in range(10): + method = f"get_method_{i}" + requests[method] = {"foo": "bar", "bar": "foo"} + mock_responses.append( + {"method": method, "result": {"great": "success"}, "error_code": 0} + ) + + send_mock = mocker.patch.object( + dummy_protocol._transport, + "send", + side_effect=[mock_json_error, *mock_responses], + ) + dummy_protocol._multi_request_batch_size = 5 + assert dummy_protocol._multi_request_batch_size == 5 + await dummy_protocol.query(requests, retry_count=1) + assert dummy_protocol._multi_request_batch_size == 1 + # Call count should be the first error + number of requests + assert send_mock.call_count == len(requests) + 1 + + +async def test_smart_device_multiple_request_json_decode_failure_twice( + dummy_protocol, mocker +): + """Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR.""" + requests = {} + + mock_json_error = { + "result": {"responses": []}, + "error_code": SmartErrorCode.JSON_DECODE_FAIL_ERROR.value, + } + for i in range(10): + method = f"get_method_{i}" + requests[method] = {"foo": "bar", "bar": "foo"} + + send_mock = mocker.patch.object( + dummy_protocol._transport, + "send", + side_effect=[mock_json_error, KasaException], + ) + dummy_protocol._multi_request_batch_size = 5 + with pytest.raises(KasaException): + await dummy_protocol.query(requests, retry_count=1) + assert dummy_protocol._multi_request_batch_size == 1 + + assert send_mock.call_count == 2 + + +async def test_smart_device_multiple_request_non_json_decode_failure( + dummy_protocol, mocker +): + """Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR. + + Ensure other exception types behave as expected. + """ + requests = {} + + mock_json_error = { + "result": {"responses": []}, + "error_code": SmartErrorCode.UNKNOWN_METHOD_ERROR.value, + } + for i in range(10): + method = f"get_method_{i}" + requests[method] = {"foo": "bar", "bar": "foo"} + + send_mock = mocker.patch.object( + dummy_protocol._transport, + "send", + side_effect=[mock_json_error, KasaException], + ) + dummy_protocol._multi_request_batch_size = 5 + with pytest.raises(DeviceError): + await dummy_protocol.query(requests, retry_count=1) + assert dummy_protocol._multi_request_batch_size == 5 + + assert send_mock.call_count == 1 + + async def test_childdevicewrapper_unwrapping(dummy_protocol, mocker): """Test that responseData gets unwrapped correctly.""" wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)