Disable multi requests on json decode error during multi-request (#1025)

Issue affecting some P100 devices
This commit is contained in:
Steven B 2024-07-01 13:57:13 +01:00 committed by GitHub
parent b31a2ede7f
commit 8d1a4a4229
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 118 additions and 11 deletions

View File

@ -47,6 +47,9 @@ class SmartProtocol(BaseProtocol):
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode() self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
self._request_id_generator = SnowflakeId(1, 1) self._request_id_generator = SnowflakeId(1, 1)
self._query_lock = asyncio.Lock() self._query_lock = asyncio.Lock()
self._multi_request_batch_size = (
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
)
def get_smart_request(self, method, params=None) -> str: def get_smart_request(self, method, params=None) -> str:
"""Get a request message as a string.""" """Get a request message as a string."""
@ -117,9 +120,16 @@ class SmartProtocol(BaseProtocol):
end = len(multi_requests) end = len(multi_requests)
# Break the requests down as there can be a size limit # Break the requests down as there can be a size limit
step = ( step = self._multi_request_batch_size
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE if step == 1:
) # If step is 1 do not send request batches
for request in multi_requests:
method = request["method"]
req = self.get_smart_request(method, request["params"])
resp = await self._transport.send(req)
self._handle_response_error_code(resp, method, raise_on_error=False)
multi_result[method] = resp["result"]
return multi_result
for i in range(0, end, step): for i in range(0, end, step):
requests_step = multi_requests[i : i + step] requests_step = multi_requests[i : i + step]
@ -141,7 +151,21 @@ class SmartProtocol(BaseProtocol):
batch_name, batch_name,
pf(response_step), pf(response_step),
) )
self._handle_response_error_code(response_step, batch_name) try:
self._handle_response_error_code(response_step, batch_name)
except DeviceError as ex:
# P100 sometimes raises JSON_DECODE_FAIL_ERROR on batched request so
# disable batching
if (
ex.error_code is SmartErrorCode.JSON_DECODE_FAIL_ERROR
and self._multi_request_batch_size != 1
):
self._multi_request_batch_size = 1
raise _RetryableError(
"JSON Decode failure, multi requests disabled"
) from ex
raise ex
responses = response_step["result"]["responses"] responses = response_step["result"]["responses"]
for response in responses: for response in responses:
method = response["method"] method = response["method"]

View File

@ -2,10 +2,9 @@ import logging
import pytest import pytest
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import ( from ..exceptions import (
SMART_RETRYABLE_ERRORS, SMART_RETRYABLE_ERRORS,
DeviceError,
KasaException, KasaException,
SmartErrorCode, SmartErrorCode,
) )
@ -93,7 +92,6 @@ async def test_smart_device_errors_in_multiple_request(
async def test_smart_device_multiple_request( async def test_smart_device_multiple_request(
dummy_protocol, mocker, request_size, batch_size dummy_protocol, mocker, request_size, batch_size
): ):
host = "127.0.0.1"
requests = {} requests = {}
mock_response = { mock_response = {
"result": {"responses": []}, "result": {"responses": []},
@ -109,16 +107,101 @@ async def test_smart_device_multiple_request(
send_mock = mocker.patch.object( send_mock = mocker.patch.object(
dummy_protocol._transport, "send", return_value=mock_response dummy_protocol._transport, "send", return_value=mock_response
) )
config = DeviceConfig( dummy_protocol._multi_request_batch_size = batch_size
host, credentials=Credentials("foo", "bar"), batch_size=batch_size
)
dummy_protocol._transport._config = config
await dummy_protocol.query(requests, retry_count=0) await dummy_protocol.query(requests, retry_count=0)
expected_count = int(request_size / batch_size) + (request_size % batch_size > 0) expected_count = int(request_size / batch_size) + (request_size % batch_size > 0)
assert send_mock.call_count == expected_count assert send_mock.call_count == expected_count
async def test_smart_device_multiple_request_json_decode_failure(
dummy_protocol, mocker
):
"""Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR."""
requests = {}
mock_responses = []
mock_json_error = {
"result": {"responses": []},
"error_code": SmartErrorCode.JSON_DECODE_FAIL_ERROR.value,
}
for i in range(10):
method = f"get_method_{i}"
requests[method] = {"foo": "bar", "bar": "foo"}
mock_responses.append(
{"method": method, "result": {"great": "success"}, "error_code": 0}
)
send_mock = mocker.patch.object(
dummy_protocol._transport,
"send",
side_effect=[mock_json_error, *mock_responses],
)
dummy_protocol._multi_request_batch_size = 5
assert dummy_protocol._multi_request_batch_size == 5
await dummy_protocol.query(requests, retry_count=1)
assert dummy_protocol._multi_request_batch_size == 1
# Call count should be the first error + number of requests
assert send_mock.call_count == len(requests) + 1
async def test_smart_device_multiple_request_json_decode_failure_twice(
dummy_protocol, mocker
):
"""Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR."""
requests = {}
mock_json_error = {
"result": {"responses": []},
"error_code": SmartErrorCode.JSON_DECODE_FAIL_ERROR.value,
}
for i in range(10):
method = f"get_method_{i}"
requests[method] = {"foo": "bar", "bar": "foo"}
send_mock = mocker.patch.object(
dummy_protocol._transport,
"send",
side_effect=[mock_json_error, KasaException],
)
dummy_protocol._multi_request_batch_size = 5
with pytest.raises(KasaException):
await dummy_protocol.query(requests, retry_count=1)
assert dummy_protocol._multi_request_batch_size == 1
assert send_mock.call_count == 2
async def test_smart_device_multiple_request_non_json_decode_failure(
dummy_protocol, mocker
):
"""Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR.
Ensure other exception types behave as expected.
"""
requests = {}
mock_json_error = {
"result": {"responses": []},
"error_code": SmartErrorCode.UNKNOWN_METHOD_ERROR.value,
}
for i in range(10):
method = f"get_method_{i}"
requests[method] = {"foo": "bar", "bar": "foo"}
send_mock = mocker.patch.object(
dummy_protocol._transport,
"send",
side_effect=[mock_json_error, KasaException],
)
dummy_protocol._multi_request_batch_size = 5
with pytest.raises(DeviceError):
await dummy_protocol.query(requests, retry_count=1)
assert dummy_protocol._multi_request_batch_size == 5
assert send_mock.call_count == 1
async def test_childdevicewrapper_unwrapping(dummy_protocol, mocker): async def test_childdevicewrapper_unwrapping(dummy_protocol, mocker):
"""Test that responseData gets unwrapped correctly.""" """Test that responseData gets unwrapped correctly."""
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol) wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)