Enable batching of multiple requests (#662)

* Enable batching of multiple requests

* Test for debug enabled outside of loop

* tweaks

* tweaks

* tweaks

* Update kasa/smartprotocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* revert

* Update pyproject.toml

* Add batch test and make batch_size configurable

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
Steven B 2024-01-29 10:55:54 +00:00 committed by GitHub
parent cedffc5c9f
commit 9c0a831027
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 100 additions and 31 deletions

View File

@ -146,6 +146,8 @@ class DeviceConfig:
#: Credentials hash can be retrieved from :attr:`SmartDevice.credentials_hash` #: Credentials hash can be retrieved from :attr:`SmartDevice.credentials_hash`
credentials_hash: Optional[str] = None credentials_hash: Optional[str] = None
#: The protocol specific type of connection. Defaults to the legacy type. #: The protocol specific type of connection. Defaults to the legacy type.
batch_size: Optional[int] = None
#: The batch size for protoools supporting multiple request batches.
connection_type: ConnectionType = field( connection_type: ConnectionType = field(
default_factory=lambda: ConnectionType( default_factory=lambda: ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1 DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1

View File

@ -10,7 +10,7 @@ import logging
import time import time
import uuid import uuid
from pprint import pformat as pf from pprint import pformat as pf
from typing import Dict, Union from typing import Any, Dict, Union
from .exceptions import ( from .exceptions import (
SMART_AUTHENTICATION_ERRORS, SMART_AUTHENTICATION_ERRORS,
@ -33,6 +33,7 @@ class SmartProtocol(BaseProtocol):
"""Class for the new TPLink SMART protocol.""" """Class for the new TPLink SMART protocol."""
BACKOFF_SECONDS_AFTER_TIMEOUT = 1 BACKOFF_SECONDS_AFTER_TIMEOUT = 1
DEFAULT_MULTI_REQUEST_BATCH_SIZE = 5
def __init__( def __init__(
self, self,
@ -101,51 +102,81 @@ class SmartProtocol(BaseProtocol):
# make mypy happy, this should never be reached.. # make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable") raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_multiple_query(self, request: Dict, retry_count: int) -> Dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
multi_result: Dict[str, Any] = {}
smart_method = "multipleRequest"
requests = [
{"method": method, "params": params} for method, params in request.items()
]
end = len(requests)
# Break the requests down as there can be a size limit
step = (
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
)
for i in range(0, end, step):
requests_step = requests[i : i + step]
smart_params = {"requests": requests_step}
smart_request = self.get_smart_request(smart_method, smart_params)
if debug_enabled:
_LOGGER.debug(
"%s multi-request-batch-%s >> %s",
self._host,
i + 1,
pf(smart_request),
)
response_step = await self._transport.send(smart_request)
if debug_enabled:
_LOGGER.debug(
"%s multi-request-batch-%s << %s",
self._host,
i + 1,
pf(response_step),
)
self._handle_response_error_code(response_step)
responses = response_step["result"]["responses"]
for response in responses:
self._handle_response_error_code(response)
result = response.get("result", None)
multi_result[response["method"]] = result
return multi_result
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if isinstance(request, dict): if isinstance(request, dict):
if len(request) == 1: if len(request) == 1:
smart_method = next(iter(request)) smart_method = next(iter(request))
smart_params = request[smart_method] smart_params = request[smart_method]
else: else:
requests = [] return await self._execute_multiple_query(request, retry_count)
for method, params in request.items():
requests.append({"method": method, "params": params})
smart_method = "multipleRequest"
smart_params = {"requests": requests}
else: else:
smart_method = request smart_method = request
smart_params = None smart_params = None
smart_request = self.get_smart_request(smart_method, smart_params) smart_request = self.get_smart_request(smart_method, smart_params)
_LOGGER.debug( if debug_enabled:
"%s >> %s", _LOGGER.debug(
self._host, "%s >> %s",
_LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request), self._host,
) pf(smart_request),
)
response_data = await self._transport.send(smart_request) response_data = await self._transport.send(smart_request)
_LOGGER.debug( if debug_enabled:
"%s << %s", _LOGGER.debug(
self._host, "%s << %s",
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), self._host,
) pf(response_data),
)
self._handle_response_error_code(response_data) self._handle_response_error_code(response_data)
if (result := response_data.get("result")) is None: # Single set_ requests do not return a result
# Single set_ requests do not return a result result = response_data.get("result")
return {smart_method: None} return {smart_method: result}
if (responses := result.get("responses")) is None:
return {smart_method: result}
# responses is returned for multipleRequest
multi_result = {}
for response in responses:
self._handle_response_error_code(response)
result = response.get("result", None)
multi_result[response["method"]] = result
return multi_result
def _handle_response_error_code(self, resp_dict: dict): def _handle_response_error_code(self, resp_dict: dict):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]

View File

@ -24,6 +24,10 @@ from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
from ..smartprotocol import SmartProtocol from ..smartprotocol import SmartProtocol
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
DUMMY_MULTIPLE_QUERY = {
"foobar": {"foo": "bar", "bar": "foo"},
"barfoo": {"foo": "bar", "bar": "foo"},
}
ERRORS = [e for e in SmartErrorCode if e != 0] ERRORS = [e for e in SmartErrorCode if e != 0]
@ -74,9 +78,39 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
config = DeviceConfig(host, credentials=Credentials("foo", "bar")) config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
protocol = SmartProtocol(transport=AesTransport(config=config)) protocol = SmartProtocol(transport=AesTransport(config=config))
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2) await protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
expected_calls = 3 expected_calls = 3
else: else:
expected_calls = 1 expected_calls = 1
assert send_mock.call_count == expected_calls assert send_mock.call_count == expected_calls
@pytest.mark.parametrize("request_size", [1, 3, 5, 10])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5])
async def test_smart_device_multiple_request(mocker, request_size, batch_size):
host = "127.0.0.1"
requests = {}
mock_response = {
"result": {"responses": []},
"error_code": 0,
}
for i in range(request_size):
method = f"get_method_{i}"
requests[method] = {"foo": "bar", "bar": "foo"}
mock_response["result"]["responses"].append(
{"method": method, "result": {"great": "success"}, "error_code": 0}
)
mocker.patch.object(AesTransport, "perform_handshake")
mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
config = DeviceConfig(
host, credentials=Credentials("foo", "bar"), batch_size=batch_size
)
protocol = SmartProtocol(transport=AesTransport(config=config))
await protocol.query(requests, retry_count=0)
expected_count = int(request_size / batch_size) + (request_size % batch_size > 0)
assert send_mock.call_count == expected_count

View File

@ -65,6 +65,8 @@ omit = ["kasa/tests/*"]
[tool.coverage.report] [tool.coverage.report]
exclude_lines = [ exclude_lines = [
# ignore debug logging
"if debug_enabled:",
# Don't complain if tests don't hit defensive assertion code: # Don't complain if tests don't hit defensive assertion code:
"raise AssertionError", "raise AssertionError",
"raise NotImplementedError", "raise NotImplementedError",