mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 11:13:34 +00:00
Enable batching of multiple requests (#662)
* Enable batching of multiple requests * Test for debug enabled outside of loop * tweaks * tweaks * tweaks * Update kasa/smartprotocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * revert * Update pyproject.toml * Add batch test and make batch_size configurable --------- Co-authored-by: J. Nick Koston <nick@koston.org> Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
parent
cedffc5c9f
commit
9c0a831027
@ -146,6 +146,8 @@ class DeviceConfig:
|
|||||||
#: Credentials hash can be retrieved from :attr:`SmartDevice.credentials_hash`
|
#: Credentials hash can be retrieved from :attr:`SmartDevice.credentials_hash`
|
||||||
credentials_hash: Optional[str] = None
|
credentials_hash: Optional[str] = None
|
||||||
#: The protocol specific type of connection. Defaults to the legacy type.
|
#: The protocol specific type of connection. Defaults to the legacy type.
|
||||||
|
batch_size: Optional[int] = None
|
||||||
|
#: The batch size for protoools supporting multiple request batches.
|
||||||
connection_type: ConnectionType = field(
|
connection_type: ConnectionType = field(
|
||||||
default_factory=lambda: ConnectionType(
|
default_factory=lambda: ConnectionType(
|
||||||
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1
|
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1
|
||||||
|
@ -10,7 +10,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import Dict, Union
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
SMART_AUTHENTICATION_ERRORS,
|
SMART_AUTHENTICATION_ERRORS,
|
||||||
@ -33,6 +33,7 @@ class SmartProtocol(BaseProtocol):
|
|||||||
"""Class for the new TPLink SMART protocol."""
|
"""Class for the new TPLink SMART protocol."""
|
||||||
|
|
||||||
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
|
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
|
||||||
|
DEFAULT_MULTI_REQUEST_BATCH_SIZE = 5
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -101,51 +102,81 @@ class SmartProtocol(BaseProtocol):
|
|||||||
# make mypy happy, this should never be reached..
|
# make mypy happy, this should never be reached..
|
||||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||||
|
|
||||||
|
async def _execute_multiple_query(self, request: Dict, retry_count: int) -> Dict:
|
||||||
|
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||||
|
multi_result: Dict[str, Any] = {}
|
||||||
|
smart_method = "multipleRequest"
|
||||||
|
requests = [
|
||||||
|
{"method": method, "params": params} for method, params in request.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
end = len(requests)
|
||||||
|
# Break the requests down as there can be a size limit
|
||||||
|
step = (
|
||||||
|
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
|
||||||
|
)
|
||||||
|
for i in range(0, end, step):
|
||||||
|
requests_step = requests[i : i + step]
|
||||||
|
|
||||||
|
smart_params = {"requests": requests_step}
|
||||||
|
smart_request = self.get_smart_request(smart_method, smart_params)
|
||||||
|
if debug_enabled:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"%s multi-request-batch-%s >> %s",
|
||||||
|
self._host,
|
||||||
|
i + 1,
|
||||||
|
pf(smart_request),
|
||||||
|
)
|
||||||
|
response_step = await self._transport.send(smart_request)
|
||||||
|
if debug_enabled:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"%s multi-request-batch-%s << %s",
|
||||||
|
self._host,
|
||||||
|
i + 1,
|
||||||
|
pf(response_step),
|
||||||
|
)
|
||||||
|
self._handle_response_error_code(response_step)
|
||||||
|
responses = response_step["result"]["responses"]
|
||||||
|
for response in responses:
|
||||||
|
self._handle_response_error_code(response)
|
||||||
|
result = response.get("result", None)
|
||||||
|
multi_result[response["method"]] = result
|
||||||
|
return multi_result
|
||||||
|
|
||||||
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
||||||
|
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||||
|
|
||||||
if isinstance(request, dict):
|
if isinstance(request, dict):
|
||||||
if len(request) == 1:
|
if len(request) == 1:
|
||||||
smart_method = next(iter(request))
|
smart_method = next(iter(request))
|
||||||
smart_params = request[smart_method]
|
smart_params = request[smart_method]
|
||||||
else:
|
else:
|
||||||
requests = []
|
return await self._execute_multiple_query(request, retry_count)
|
||||||
for method, params in request.items():
|
|
||||||
requests.append({"method": method, "params": params})
|
|
||||||
smart_method = "multipleRequest"
|
|
||||||
smart_params = {"requests": requests}
|
|
||||||
else:
|
else:
|
||||||
smart_method = request
|
smart_method = request
|
||||||
smart_params = None
|
smart_params = None
|
||||||
|
|
||||||
smart_request = self.get_smart_request(smart_method, smart_params)
|
smart_request = self.get_smart_request(smart_method, smart_params)
|
||||||
_LOGGER.debug(
|
if debug_enabled:
|
||||||
"%s >> %s",
|
_LOGGER.debug(
|
||||||
self._host,
|
"%s >> %s",
|
||||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request),
|
self._host,
|
||||||
)
|
pf(smart_request),
|
||||||
|
)
|
||||||
response_data = await self._transport.send(smart_request)
|
response_data = await self._transport.send(smart_request)
|
||||||
|
|
||||||
_LOGGER.debug(
|
if debug_enabled:
|
||||||
"%s << %s",
|
_LOGGER.debug(
|
||||||
self._host,
|
"%s << %s",
|
||||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
self._host,
|
||||||
)
|
pf(response_data),
|
||||||
|
)
|
||||||
|
|
||||||
self._handle_response_error_code(response_data)
|
self._handle_response_error_code(response_data)
|
||||||
|
|
||||||
if (result := response_data.get("result")) is None:
|
# Single set_ requests do not return a result
|
||||||
# Single set_ requests do not return a result
|
result = response_data.get("result")
|
||||||
return {smart_method: None}
|
return {smart_method: result}
|
||||||
|
|
||||||
if (responses := result.get("responses")) is None:
|
|
||||||
return {smart_method: result}
|
|
||||||
|
|
||||||
# responses is returned for multipleRequest
|
|
||||||
multi_result = {}
|
|
||||||
for response in responses:
|
|
||||||
self._handle_response_error_code(response)
|
|
||||||
result = response.get("result", None)
|
|
||||||
multi_result[response["method"]] = result
|
|
||||||
return multi_result
|
|
||||||
|
|
||||||
def _handle_response_error_code(self, resp_dict: dict):
|
def _handle_response_error_code(self, resp_dict: dict):
|
||||||
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
|
@ -24,6 +24,10 @@ from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
|
|||||||
from ..smartprotocol import SmartProtocol
|
from ..smartprotocol import SmartProtocol
|
||||||
|
|
||||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||||
|
DUMMY_MULTIPLE_QUERY = {
|
||||||
|
"foobar": {"foo": "bar", "bar": "foo"},
|
||||||
|
"barfoo": {"foo": "bar", "bar": "foo"},
|
||||||
|
}
|
||||||
ERRORS = [e for e in SmartErrorCode if e != 0]
|
ERRORS = [e for e in SmartErrorCode if e != 0]
|
||||||
|
|
||||||
|
|
||||||
@ -74,9 +78,39 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
|
|||||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||||
protocol = SmartProtocol(transport=AesTransport(config=config))
|
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await protocol.query(DUMMY_QUERY, retry_count=2)
|
await protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
|
||||||
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
|
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
|
||||||
expected_calls = 3
|
expected_calls = 3
|
||||||
else:
|
else:
|
||||||
expected_calls = 1
|
expected_calls = 1
|
||||||
assert send_mock.call_count == expected_calls
|
assert send_mock.call_count == expected_calls
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("request_size", [1, 3, 5, 10])
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5])
|
||||||
|
async def test_smart_device_multiple_request(mocker, request_size, batch_size):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
requests = {}
|
||||||
|
mock_response = {
|
||||||
|
"result": {"responses": []},
|
||||||
|
"error_code": 0,
|
||||||
|
}
|
||||||
|
for i in range(request_size):
|
||||||
|
method = f"get_method_{i}"
|
||||||
|
requests[method] = {"foo": "bar", "bar": "foo"}
|
||||||
|
mock_response["result"]["responses"].append(
|
||||||
|
{"method": method, "result": {"great": "success"}, "error_code": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
mocker.patch.object(AesTransport, "perform_handshake")
|
||||||
|
mocker.patch.object(AesTransport, "perform_login")
|
||||||
|
|
||||||
|
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||||
|
config = DeviceConfig(
|
||||||
|
host, credentials=Credentials("foo", "bar"), batch_size=batch_size
|
||||||
|
)
|
||||||
|
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||||
|
|
||||||
|
await protocol.query(requests, retry_count=0)
|
||||||
|
expected_count = int(request_size / batch_size) + (request_size % batch_size > 0)
|
||||||
|
assert send_mock.call_count == expected_count
|
||||||
|
@ -65,6 +65,8 @@ omit = ["kasa/tests/*"]
|
|||||||
|
|
||||||
[tool.coverage.report]
|
[tool.coverage.report]
|
||||||
exclude_lines = [
|
exclude_lines = [
|
||||||
|
# ignore debug logging
|
||||||
|
"if debug_enabled:",
|
||||||
# Don't complain if tests don't hit defensive assertion code:
|
# Don't complain if tests don't hit defensive assertion code:
|
||||||
"raise AssertionError",
|
"raise AssertionError",
|
||||||
"raise NotImplementedError",
|
"raise NotImplementedError",
|
||||||
|
Loading…
Reference in New Issue
Block a user