mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 11:13:34 +00:00
Enable batching of multiple requests (#662)
* Enable batching of multiple requests * Test for debug enabled outside of loop * tweaks * tweaks * tweaks * Update kasa/smartprotocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * revert * Update pyproject.toml * Add batch test and make batch_size configurable --------- Co-authored-by: J. Nick Koston <nick@koston.org> Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
parent
cedffc5c9f
commit
9c0a831027
@ -146,6 +146,8 @@ class DeviceConfig:
|
||||
#: Credentials hash can be retrieved from :attr:`SmartDevice.credentials_hash`
|
||||
credentials_hash: Optional[str] = None
|
||||
#: The protocol specific type of connection. Defaults to the legacy type.
|
||||
batch_size: Optional[int] = None
|
||||
#: The batch size for protoools supporting multiple request batches.
|
||||
connection_type: ConnectionType = field(
|
||||
default_factory=lambda: ConnectionType(
|
||||
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1
|
||||
|
@ -10,7 +10,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from .exceptions import (
|
||||
SMART_AUTHENTICATION_ERRORS,
|
||||
@ -33,6 +33,7 @@ class SmartProtocol(BaseProtocol):
|
||||
"""Class for the new TPLink SMART protocol."""
|
||||
|
||||
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
|
||||
DEFAULT_MULTI_REQUEST_BATCH_SIZE = 5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -101,52 +102,82 @@ class SmartProtocol(BaseProtocol):
|
||||
# make mypy happy, this should never be reached..
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
||||
if isinstance(request, dict):
|
||||
if len(request) == 1:
|
||||
smart_method = next(iter(request))
|
||||
smart_params = request[smart_method]
|
||||
else:
|
||||
requests = []
|
||||
for method, params in request.items():
|
||||
requests.append({"method": method, "params": params})
|
||||
async def _execute_multiple_query(self, request: Dict, retry_count: int) -> Dict:
|
||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
multi_result: Dict[str, Any] = {}
|
||||
smart_method = "multipleRequest"
|
||||
smart_params = {"requests": requests}
|
||||
else:
|
||||
smart_method = request
|
||||
smart_params = None
|
||||
requests = [
|
||||
{"method": method, "params": params} for method, params in request.items()
|
||||
]
|
||||
|
||||
end = len(requests)
|
||||
# Break the requests down as there can be a size limit
|
||||
step = (
|
||||
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
|
||||
)
|
||||
for i in range(0, end, step):
|
||||
requests_step = requests[i : i + step]
|
||||
|
||||
smart_params = {"requests": requests_step}
|
||||
smart_request = self.get_smart_request(smart_method, smart_params)
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s >> %s",
|
||||
"%s multi-request-batch-%s >> %s",
|
||||
self._host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request),
|
||||
i + 1,
|
||||
pf(smart_request),
|
||||
)
|
||||
response_data = await self._transport.send(smart_request)
|
||||
|
||||
response_step = await self._transport.send(smart_request)
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s << %s",
|
||||
"%s multi-request-batch-%s << %s",
|
||||
self._host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
||||
i + 1,
|
||||
pf(response_step),
|
||||
)
|
||||
|
||||
self._handle_response_error_code(response_data)
|
||||
|
||||
if (result := response_data.get("result")) is None:
|
||||
# Single set_ requests do not return a result
|
||||
return {smart_method: None}
|
||||
|
||||
if (responses := result.get("responses")) is None:
|
||||
return {smart_method: result}
|
||||
|
||||
# responses is returned for multipleRequest
|
||||
multi_result = {}
|
||||
self._handle_response_error_code(response_step)
|
||||
responses = response_step["result"]["responses"]
|
||||
for response in responses:
|
||||
self._handle_response_error_code(response)
|
||||
result = response.get("result", None)
|
||||
multi_result[response["method"]] = result
|
||||
return multi_result
|
||||
|
||||
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
|
||||
if isinstance(request, dict):
|
||||
if len(request) == 1:
|
||||
smart_method = next(iter(request))
|
||||
smart_params = request[smart_method]
|
||||
else:
|
||||
return await self._execute_multiple_query(request, retry_count)
|
||||
else:
|
||||
smart_method = request
|
||||
smart_params = None
|
||||
|
||||
smart_request = self.get_smart_request(smart_method, smart_params)
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s >> %s",
|
||||
self._host,
|
||||
pf(smart_request),
|
||||
)
|
||||
response_data = await self._transport.send(smart_request)
|
||||
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s << %s",
|
||||
self._host,
|
||||
pf(response_data),
|
||||
)
|
||||
|
||||
self._handle_response_error_code(response_data)
|
||||
|
||||
# Single set_ requests do not return a result
|
||||
result = response_data.get("result")
|
||||
return {smart_method: result}
|
||||
|
||||
def _handle_response_error_code(self, resp_dict: dict):
|
||||
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||
if error_code == SmartErrorCode.SUCCESS:
|
||||
|
@ -24,6 +24,10 @@ from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||
DUMMY_MULTIPLE_QUERY = {
|
||||
"foobar": {"foo": "bar", "bar": "foo"},
|
||||
"barfoo": {"foo": "bar", "bar": "foo"},
|
||||
}
|
||||
ERRORS = [e for e in SmartErrorCode if e != 0]
|
||||
|
||||
|
||||
@ -74,9 +78,39 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol.query(DUMMY_QUERY, retry_count=2)
|
||||
await protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
|
||||
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
|
||||
expected_calls = 3
|
||||
else:
|
||||
expected_calls = 1
|
||||
assert send_mock.call_count == expected_calls
|
||||
|
||||
|
||||
@pytest.mark.parametrize("request_size", [1, 3, 5, 10])
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5])
|
||||
async def test_smart_device_multiple_request(mocker, request_size, batch_size):
|
||||
host = "127.0.0.1"
|
||||
requests = {}
|
||||
mock_response = {
|
||||
"result": {"responses": []},
|
||||
"error_code": 0,
|
||||
}
|
||||
for i in range(request_size):
|
||||
method = f"get_method_{i}"
|
||||
requests[method] = {"foo": "bar", "bar": "foo"}
|
||||
mock_response["result"]["responses"].append(
|
||||
{"method": method, "result": {"great": "success"}, "error_code": 0}
|
||||
)
|
||||
|
||||
mocker.patch.object(AesTransport, "perform_handshake")
|
||||
mocker.patch.object(AesTransport, "perform_login")
|
||||
|
||||
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||
config = DeviceConfig(
|
||||
host, credentials=Credentials("foo", "bar"), batch_size=batch_size
|
||||
)
|
||||
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||
|
||||
await protocol.query(requests, retry_count=0)
|
||||
expected_count = int(request_size / batch_size) + (request_size % batch_size > 0)
|
||||
assert send_mock.call_count == expected_count
|
||||
|
@ -65,6 +65,8 @@ omit = ["kasa/tests/*"]
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
# ignore debug logging
|
||||
"if debug_enabled:",
|
||||
# Don't complain if tests don't hit defensive assertion code:
|
||||
"raise AssertionError",
|
||||
"raise NotImplementedError",
|
||||
|
Loading…
Reference in New Issue
Block a user