diff --git a/kasa/deviceconfig.py b/kasa/deviceconfig.py index 77ce6df4..ffb2988e 100644 --- a/kasa/deviceconfig.py +++ b/kasa/deviceconfig.py @@ -146,6 +146,8 @@ class DeviceConfig: #: Credentials hash can be retrieved from :attr:`SmartDevice.credentials_hash` credentials_hash: Optional[str] = None #: The protocol specific type of connection. Defaults to the legacy type. + batch_size: Optional[int] = None + #: The batch size for protoools supporting multiple request batches. connection_type: ConnectionType = field( default_factory=lambda: ConnectionType( DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1 diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 6f0648ea..9ec2547d 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -10,7 +10,7 @@ import logging import time import uuid from pprint import pformat as pf -from typing import Dict, Union +from typing import Any, Dict, Union from .exceptions import ( SMART_AUTHENTICATION_ERRORS, @@ -33,6 +33,7 @@ class SmartProtocol(BaseProtocol): """Class for the new TPLink SMART protocol.""" BACKOFF_SECONDS_AFTER_TIMEOUT = 1 + DEFAULT_MULTI_REQUEST_BATCH_SIZE = 5 def __init__( self, @@ -101,51 +102,81 @@ class SmartProtocol(BaseProtocol): # make mypy happy, this should never be reached.. raise SmartDeviceException("Query reached somehow to unreachable") + async def _execute_multiple_query(self, request: Dict, retry_count: int) -> Dict: + debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) + multi_result: Dict[str, Any] = {} + smart_method = "multipleRequest" + requests = [ + {"method": method, "params": params} for method, params in request.items() + ] + + end = len(requests) + # Break the requests down as there can be a size limit + step = ( + self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE + ) + for i in range(0, end, step): + requests_step = requests[i : i + step] + + smart_params = {"requests": requests_step} + smart_request = self.get_smart_request(smart_method, smart_params) + if debug_enabled: + _LOGGER.debug( + "%s multi-request-batch-%s >> %s", + self._host, + i + 1, + pf(smart_request), + ) + response_step = await self._transport.send(smart_request) + if debug_enabled: + _LOGGER.debug( + "%s multi-request-batch-%s << %s", + self._host, + i + 1, + pf(response_step), + ) + self._handle_response_error_code(response_step) + responses = response_step["result"]["responses"] + for response in responses: + self._handle_response_error_code(response) + result = response.get("result", None) + multi_result[response["method"]] = result + return multi_result + async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: + debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) + if isinstance(request, dict): if len(request) == 1: smart_method = next(iter(request)) smart_params = request[smart_method] else: - requests = [] - for method, params in request.items(): - requests.append({"method": method, "params": params}) - smart_method = "multipleRequest" - smart_params = {"requests": requests} + return await self._execute_multiple_query(request, retry_count) else: smart_method = request smart_params = None smart_request = self.get_smart_request(smart_method, smart_params) - _LOGGER.debug( - "%s >> %s", - self._host, - _LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request), - ) + if debug_enabled: + _LOGGER.debug( + "%s >> %s", + self._host, + pf(smart_request), + ) response_data = await self._transport.send(smart_request) - _LOGGER.debug( - "%s << %s", - self._host, - _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), - ) + if debug_enabled: + _LOGGER.debug( + "%s << %s", + self._host, + pf(response_data), + ) self._handle_response_error_code(response_data) - if (result := response_data.get("result")) is None: - # Single set_ requests do not return a result - return {smart_method: None} - - if (responses := result.get("responses")) is None: - return {smart_method: result} - - # responses is returned for multipleRequest - multi_result = {} - for response in responses: - self._handle_response_error_code(response) - result = response.get("result", None) - multi_result[response["method"]] = result - return multi_result + # Single set_ requests do not return a result + result = response_data.get("result") + return {smart_method: result} def _handle_response_error_code(self, resp_dict: dict): error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index af2fce4c..9b597b51 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -24,6 +24,10 @@ from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256 from ..smartprotocol import SmartProtocol DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} +DUMMY_MULTIPLE_QUERY = { + "foobar": {"foo": "bar", "bar": "foo"}, + "barfoo": {"foo": "bar", "bar": "foo"}, +} ERRORS = [e for e in SmartErrorCode if e != 0] @@ -74,9 +78,39 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code): config = DeviceConfig(host, credentials=Credentials("foo", "bar")) protocol = SmartProtocol(transport=AesTransport(config=config)) with pytest.raises(SmartDeviceException): - await protocol.query(DUMMY_QUERY, retry_count=2) + await protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2) if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): expected_calls = 3 else: expected_calls = 1 assert send_mock.call_count == expected_calls + + +@pytest.mark.parametrize("request_size", [1, 3, 5, 10]) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5]) +async def test_smart_device_multiple_request(mocker, request_size, batch_size): + host = "127.0.0.1" + requests = {} + mock_response = { + "result": {"responses": []}, + "error_code": 0, + } + for i in range(request_size): + method = f"get_method_{i}" + requests[method] = {"foo": "bar", "bar": "foo"} + mock_response["result"]["responses"].append( + {"method": method, "result": {"great": "success"}, "error_code": 0} + ) + + mocker.patch.object(AesTransport, "perform_handshake") + mocker.patch.object(AesTransport, "perform_login") + + send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) + config = DeviceConfig( + host, credentials=Credentials("foo", "bar"), batch_size=batch_size + ) + protocol = SmartProtocol(transport=AesTransport(config=config)) + + await protocol.query(requests, retry_count=0) + expected_count = int(request_size / batch_size) + (request_size % batch_size > 0) + assert send_mock.call_count == expected_count diff --git a/pyproject.toml b/pyproject.toml index f6092024..9db1474a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,8 @@ omit = ["kasa/tests/*"] [tool.coverage.report] exclude_lines = [ + # ignore debug logging + "if debug_enabled:", # Don't complain if tests don't hit defensive assertion code: "raise AssertionError", "raise NotImplementedError",