Enable multiple requests in smartprotocol (#584)

* Enable multiple requests in smartprotocol

* Update following review

* Remove error_code parameter in exceptions
This commit is contained in:
sdb9696
2023-12-20 17:08:04 +00:00
committed by GitHub
parent 20ea6700a5
commit 6819c746d7
12 changed files with 260 additions and 76 deletions

View File

@@ -196,9 +196,13 @@ def parametrize(desc, devices, protocol_filter=None, ids=None):
)
has_emeter = parametrize("has emeter", WITH_EMETER_IOT, protocol_filter={"IOT"})
has_emeter = parametrize("has emeter", WITH_EMETER, protocol_filter={"SMART", "IOT"})
no_emeter = parametrize(
"no emeter", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"SMART", "IOT"}
"no emeter", ALL_DEVICES - WITH_EMETER, protocol_filter={"SMART", "IOT"}
)
has_emeter_iot = parametrize("has emeter iot", WITH_EMETER_IOT, protocol_filter={"IOT"})
no_emeter_iot = parametrize(
"no emeter iot", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"IOT"}
)
bulb = parametrize("bulbs", BULBS, protocol_filter={"SMART", "IOT"})

View File

@@ -1,6 +1,7 @@
import copy
import logging
import re
import warnings
from json import loads as json_loads
from voluptuous import (
@@ -294,9 +295,7 @@ class FakeSmartProtocol(SmartProtocol):
async def query(self, request, retry_count: int = 3):
"""Implement query here so can still patch SmartProtocol.query."""
resp_dict = await self._query(request, retry_count)
if "result" in resp_dict:
return resp_dict["result"]
return {}
return resp_dict
class FakeSmartTransport(BaseTransport):
@@ -306,26 +305,34 @@ class FakeSmartTransport(BaseTransport):
)
self.info = info
@property
def needs_handshake(self) -> bool:
return False
@property
def needs_login(self) -> bool:
return False
async def login(self, request: str) -> None:
pass
async def handshake(self) -> None:
pass
async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]
params = request_dict["params"]
if method == "multipleRequest":
responses = []
for request in params["requests"]:
response = self._send_request(request) # type: ignore[arg-type]
response["method"] = request["method"] # type: ignore[index]
responses.append(response)
return {"result": {"responses": responses}, "error_code": 0}
else:
return self._send_request(request_dict)
def _send_request(self, request_dict: dict):
method = request_dict["method"]
params = request_dict["params"]
if method == "component_nego" or method[:4] == "get_":
return {"result": self.info[method], "error_code": 0}
if method in self.info:
return {"result": self.info[method], "error_code": 0}
else:
warnings.warn(
UserWarning(
f"Fixture missing expected method {method}, try to regenerate"
),
stacklevel=1,
)
return {"result": {}, "error_code": 0}
elif method[:4] == "set_":
target_method = f"get_{method[4:]}"
self.info[target_method].update(params)

View File

@@ -12,7 +12,12 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
from ..aestransport import AesEncyptionSession, AesTransport
from ..credentials import Credentials
from ..exceptions import SmartDeviceException
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
SmartDeviceException,
SmartErrorCode,
)
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
@@ -105,6 +110,32 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
assert "result" in res
ERRORS = [e for e in SmartErrorCode if e != 0]
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_passthrough_errors(mocker, error_code):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with pytest.raises(SmartDeviceException):
await transport.send(json_dumps(request))
class MockAesDevice:
class _mock_response:
def __init__(self, status_code, json: dict):

View File

@@ -2,7 +2,7 @@ import pytest
from kasa import EmeterStatus, SmartDeviceException
from .conftest import has_emeter, no_emeter
from .conftest import has_emeter, has_emeter_iot, no_emeter
from .newfakes import CURRENT_CONSUMPTION_SCHEMA
@@ -20,7 +20,7 @@ async def test_no_emeter(dev):
await dev.erase_emeter_stats()
@has_emeter
@has_emeter_iot
async def test_get_emeter_realtime(dev):
assert dev.has_emeter
@@ -28,7 +28,7 @@ async def test_get_emeter_realtime(dev):
CURRENT_CONSUMPTION_SCHEMA(current_emeter)
@has_emeter
@has_emeter_iot
@pytest.mark.requires_dummy
async def test_get_emeter_daily(dev):
assert dev.has_emeter
@@ -48,7 +48,7 @@ async def test_get_emeter_daily(dev):
assert v * 1000 == v2
@has_emeter
@has_emeter_iot
@pytest.mark.requires_dummy
async def test_get_emeter_monthly(dev):
assert dev.has_emeter
@@ -68,7 +68,7 @@ async def test_get_emeter_monthly(dev):
assert v * 1000 == v2
@has_emeter
@has_emeter_iot
async def test_emeter_status(dev):
assert dev.has_emeter

View File

@@ -26,20 +26,29 @@ class _mock_response:
self.content = content
@pytest.mark.parametrize(
"error, retry_expectation",
[
(Exception("dummy exception"), True),
(SmartDeviceException("dummy exception"), False),
],
ids=("Exception", "SmartDeviceException"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class):
async def test_protocol_retries(
mocker, retry_count, protocol_class, transport_class, error, retry_expectation
):
host = "127.0.0.1"
conn = mocker.patch.object(
httpx.AsyncClient, "post", side_effect=Exception("dummy exception")
)
conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error)
with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=retry_count
)
assert conn.call_count == retry_count + 1
expected_count = retry_count + 1 if retry_expectation else 1
assert conn.call_count == expected_count
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@@ -109,7 +118,7 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
response = await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=retry_count
)
assert "result" in response or "great" in response
assert "result" in response or "foobar" in response
assert send_mock.call_count == retry_count

View File

@@ -8,7 +8,7 @@ import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
# List of all SmartXXX classes including the SmartDevice base class
@@ -48,7 +48,7 @@ async def test_initial_update_emeter(dev, mocker):
assert spy.call_count == expected_queries + len(dev.children)
@no_emeter
@no_emeter_iot
async def test_initial_update_no_emeter(dev, mocker):
"""Test that the initial update performs second query if emeter is available."""
dev._last_update = None

View File

@@ -0,0 +1,81 @@
import errno
import json
import logging
import secrets
import struct
import sys
import time
from contextlib import nullcontext as does_not_raise
from itertools import chain
import httpx
import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
SmartDeviceException,
SmartErrorCode,
)
from ..iotprotocol import IotProtocol
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
from ..smartprotocol import SmartProtocol
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
ERRORS = [e for e in SmartErrorCode if e != 0]
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_smart_device_errors(mocker, error_code):
host = "127.0.0.1"
mock_response = {"result": {"great": "success"}, "error_code": error_code.value}
mocker.patch.object(AesTransport, "perform_handshake")
mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
protocol = SmartProtocol(host, transport=AesTransport(host))
with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
expected_calls = 3
else:
expected_calls = 1
assert send_mock.call_count == expected_calls
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_smart_device_errors_in_multiple_request(mocker, error_code):
host = "127.0.0.1"
mock_response = {
"result": {
"responses": [
{"method": "foobar1", "result": {"great": "success"}, "error_code": 0},
{
"method": "foobar2",
"result": {"great": "success"},
"error_code": error_code.value,
},
{"method": "foobar3", "result": {"great": "success"}, "error_code": 0},
]
},
"error_code": 0,
}
mocker.patch.object(AesTransport, "perform_handshake")
mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
protocol = SmartProtocol(host, transport=AesTransport(host))
with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
expected_calls = 3
else:
expected_calls = 1
assert send_mock.call_count == expected_calls