Enable batching of multiple requests (#662)

* Enable batching of multiple requests

* Test for debug enabled outside of loop

* tweaks

* tweaks

* tweaks

* Update kasa/smartprotocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* revert

* Update pyproject.toml

* Add batch test and make batch_size configurable

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: Teemu R. <tpr@iki.fi>
Authored by Steven B on 2024-01-29 10:55:54 +00:00, committed via GitHub.
Commit 9c0a831027, parent cedffc5c9f.
4 changed files with 100 additions and 31 deletions
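For illustration (not part of the diff): the change splits a multi-method query into "multipleRequest" payloads of at most batch_size entries. A minimal standalone sketch of the arithmetic, assuming ten invented methods and a batch size of 3:

# Standalone sketch; method names and params are invented.
requests = [{"method": f"get_method_{i}", "params": {"foo": "bar"}} for i in range(10)]
batch_size = 3

# Same slicing the protocol uses: range(0, len(requests), batch_size).
batches = [requests[i : i + batch_size] for i in range(0, len(requests), batch_size)]
assert [len(b) for b in batches] == [3, 3, 3, 1]  # four transport round-trips

for batch in batches:
    payload = {"method": "multipleRequest", "params": {"requests": batch}}  # one send() each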

kasa/deviceconfig.py

@@ -146,6 +146,8 @@ class DeviceConfig:
    #: Credentials hash can be retrieved from :attr:`SmartDevice.credentials_hash`
    credentials_hash: Optional[str] = None
    #: The batch size for protocols supporting multiple request batches.
    batch_size: Optional[int] = None
    #: The protocol specific type of connection. Defaults to the legacy type.
    connection_type: ConnectionType = field(
        default_factory=lambda: ConnectionType(
            DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1
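The new field can then be supplied when building a DeviceConfig. A usage sketch (host, credentials and import paths are assumptions; the constructor call mirrors the test further down):

from kasa.credentials import Credentials
from kasa.deviceconfig import DeviceConfig

# batch_size caps how many sub-requests are packed into a single multipleRequest
# call; None falls back to the protocol default of 5.
config = DeviceConfig(
    "127.0.0.1", credentials=Credentials("user", "pass"), batch_size=3
)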

kasa/smartprotocol.py

@@ -10,7 +10,7 @@ import logging
import time
import uuid
from pprint import pformat as pf
from typing import Dict, Union
from typing import Any, Dict, Union
from .exceptions import (
    SMART_AUTHENTICATION_ERRORS,
@@ -33,6 +33,7 @@ class SmartProtocol(BaseProtocol):
    """Class for the new TPLink SMART protocol."""

    BACKOFF_SECONDS_AFTER_TIMEOUT = 1
    DEFAULT_MULTI_REQUEST_BATCH_SIZE = 5

    def __init__(
        self,
@@ -101,51 +102,81 @@ class SmartProtocol(BaseProtocol):
        # make mypy happy, this should never be reached..
        raise SmartDeviceException("Query reached somehow to unreachable")

    async def _execute_multiple_query(self, request: Dict, retry_count: int) -> Dict:
        debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
        multi_result: Dict[str, Any] = {}
        smart_method = "multipleRequest"
        requests = [
            {"method": method, "params": params} for method, params in request.items()
        ]
        end = len(requests)
        # Break the requests down as there can be a size limit
        step = (
            self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
        )
        for i in range(0, end, step):
            requests_step = requests[i : i + step]
            smart_params = {"requests": requests_step}
            smart_request = self.get_smart_request(smart_method, smart_params)
            if debug_enabled:
                _LOGGER.debug(
                    "%s multi-request-batch-%s >> %s",
                    self._host,
                    i + 1,
                    pf(smart_request),
                )
            response_step = await self._transport.send(smart_request)
            if debug_enabled:
                _LOGGER.debug(
                    "%s multi-request-batch-%s << %s",
                    self._host,
                    i + 1,
                    pf(response_step),
                )
            self._handle_response_error_code(response_step)
            responses = response_step["result"]["responses"]
            for response in responses:
                self._handle_response_error_code(response)
                result = response.get("result", None)
                multi_result[response["method"]] = result
        return multi_result
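To make the merge step above concrete, a standalone sketch (response payloads invented) of how per-batch responses collapse into the single dict keyed by method name that _execute_multiple_query returns:

# Two batches worth of multipleRequest responses, merged the same way as above.
batch_responses = [
    [{"method": "get_device_info", "result": {"model": "P110"}, "error_code": 0}],
    [{"method": "get_energy_usage", "result": {"power_mw": 2300}, "error_code": 0}],
]

multi_result = {}
for responses in batch_responses:
    for response in responses:
        multi_result[response["method"]] = response.get("result")

assert multi_result == {
    "get_device_info": {"model": "P110"},
    "get_energy_usage": {"power_mw": 2300},
}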
    async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
        debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)

        if isinstance(request, dict):
            if len(request) == 1:
                smart_method = next(iter(request))
                smart_params = request[smart_method]
            else:
                requests = []
                for method, params in request.items():
                    requests.append({"method": method, "params": params})
                smart_method = "multipleRequest"
                smart_params = {"requests": requests}
                return await self._execute_multiple_query(request, retry_count)
        else:
            smart_method = request
            smart_params = None

        smart_request = self.get_smart_request(smart_method, smart_params)
        _LOGGER.debug(
            "%s >> %s",
            self._host,
            _LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request),
        )
        if debug_enabled:
            _LOGGER.debug(
                "%s >> %s",
                self._host,
                pf(smart_request),
            )
        response_data = await self._transport.send(smart_request)
        _LOGGER.debug(
            "%s << %s",
            self._host,
            _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
        )
        if debug_enabled:
            _LOGGER.debug(
                "%s << %s",
                self._host,
                pf(response_data),
            )
        self._handle_response_error_code(response_data)
        if (result := response_data.get("result")) is None:
            # Single set_ requests do not return a result
            return {smart_method: None}
        if (responses := result.get("responses")) is None:
            return {smart_method: result}
        # responses is returned for multipleRequest
        multi_result = {}
        for response in responses:
            self._handle_response_error_code(response)
            result = response.get("result", None)
            multi_result[response["method"]] = result
        return multi_result
        # Single set_ requests do not return a result
        result = response_data.get("result")
        return {smart_method: result}

    def _handle_response_error_code(self, resp_dict: dict):
        error_code = SmartErrorCode(resp_dict.get("error_code"))  # type: ignore[arg-type]
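From the caller's perspective nothing changes except that large dict queries are now batched transparently. An end-to-end sketch, assuming the module paths below and an invented host, credentials and method names (a reachable device is required for this to actually respond):

import asyncio

from kasa.aestransport import AesTransport
from kasa.credentials import Credentials
from kasa.deviceconfig import DeviceConfig
from kasa.smartprotocol import SmartProtocol


async def main() -> None:
    config = DeviceConfig(
        "192.0.2.10", credentials=Credentials("user", "pass"), batch_size=2
    )
    protocol = SmartProtocol(transport=AesTransport(config=config))

    # Three methods with batch_size=2 -> two multipleRequest round-trips;
    # the merged result has one entry per method.
    result = await protocol.query(
        {"get_device_info": None, "get_device_time": None, "get_device_usage": None}
    )
    print(list(result))


asyncio.run(main())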

kasa/tests/test_smartprotocol.py

@@ -24,6 +24,10 @@ from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
from ..smartprotocol import SmartProtocol

DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
DUMMY_MULTIPLE_QUERY = {
    "foobar": {"foo": "bar", "bar": "foo"},
    "barfoo": {"foo": "bar", "bar": "foo"},
}

ERRORS = [e for e in SmartErrorCode if e != 0]
@@ -74,9 +78,39 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
    config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
    protocol = SmartProtocol(transport=AesTransport(config=config))

    with pytest.raises(SmartDeviceException):
        await protocol.query(DUMMY_QUERY, retry_count=2)
        await protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)

    if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
        expected_calls = 3
    else:
        expected_calls = 1

    assert send_mock.call_count == expected_calls
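The expected_calls values follow from the retry logic: retry_count=2 allows two retries after the initial attempt for timeout/retryable error codes, while any other error aborts after the first send. As an illustrative check (the helper below is hypothetical):

def expected_send_calls(retry_count: int, retryable: bool) -> int:
    # Initial attempt plus one send per allowed retry for retryable errors;
    # non-retryable error codes fail immediately.
    return 1 + retry_count if retryable else 1


assert expected_send_calls(2, retryable=True) == 3
assert expected_send_calls(2, retryable=False) == 1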
@pytest.mark.parametrize("request_size", [1, 3, 5, 10])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5])
async def test_smart_device_multiple_request(mocker, request_size, batch_size):
    host = "127.0.0.1"
    requests = {}
    mock_response = {
        "result": {"responses": []},
        "error_code": 0,
    }
    for i in range(request_size):
        method = f"get_method_{i}"
        requests[method] = {"foo": "bar", "bar": "foo"}
        mock_response["result"]["responses"].append(
            {"method": method, "result": {"great": "success"}, "error_code": 0}
        )
    mocker.patch.object(AesTransport, "perform_handshake")
    mocker.patch.object(AesTransport, "perform_login")
    send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)

    config = DeviceConfig(
        host, credentials=Credentials("foo", "bar"), batch_size=batch_size
    )
    protocol = SmartProtocol(transport=AesTransport(config=config))

    await protocol.query(requests, retry_count=0)
    expected_count = int(request_size / batch_size) + (request_size % batch_size > 0)
    assert send_mock.call_count == expected_count
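The expected_count expression is just a ceiling division: for example, request_size=10 with batch_size=3 needs four batches. Equivalently:

import math

request_size, batch_size = 10, 3
expected_count = int(request_size / batch_size) + (request_size % batch_size > 0)
assert expected_count == math.ceil(request_size / batch_size) == 4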

View File

@@ -65,6 +65,8 @@ omit = ["kasa/tests/*"]
[tool.coverage.report]
exclude_lines = [
# ignore debug logging
"if debug_enabled:",
# Don't complain if tests don't hit defensive assertion code:
"raise AssertionError",
"raise NotImplementedError",