mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Enable multiple requests in smartprotocol (#584)
* Enable multiple requests in smartprotocol * Update following review * Remove error_code parameter in exceptions
This commit is contained in:
parent
20ea6700a5
commit
6819c746d7
@ -125,19 +125,19 @@ class AesTransport(BaseTransport):
|
|||||||
return resp.status_code, response_data
|
return resp.status_code, response_data
|
||||||
|
|
||||||
def _handle_response_error_code(self, resp_dict: dict, msg: str):
|
def _handle_response_error_code(self, resp_dict: dict, msg: str):
|
||||||
if (
|
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
if error_code == SmartErrorCode.SUCCESS:
|
||||||
) != SmartErrorCode.SUCCESS:
|
return
|
||||||
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
|
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
|
||||||
if error_code in SMART_TIMEOUT_ERRORS:
|
if error_code in SMART_TIMEOUT_ERRORS:
|
||||||
raise TimeoutException(msg)
|
raise TimeoutException(msg)
|
||||||
if error_code in SMART_RETRYABLE_ERRORS:
|
if error_code in SMART_RETRYABLE_ERRORS:
|
||||||
raise RetryableException(msg)
|
raise RetryableException(msg)
|
||||||
if error_code in SMART_AUTHENTICATION_ERRORS:
|
if error_code in SMART_AUTHENTICATION_ERRORS:
|
||||||
self._handshake_done = False
|
self._handshake_done = False
|
||||||
self._login_token = None
|
self._login_token = None
|
||||||
raise AuthenticationException(msg)
|
raise AuthenticationException(msg)
|
||||||
raise SmartDeviceException(msg)
|
raise SmartDeviceException(msg)
|
||||||
|
|
||||||
async def send_secure_passthrough(self, request: str):
|
async def send_secure_passthrough(self, request: str):
|
||||||
"""Send encrypted message as passthrough."""
|
"""Send encrypted message as passthrough."""
|
||||||
|
@ -62,6 +62,13 @@ class IotProtocol(TPLinkProtocol):
|
|||||||
"Unable to authenticate with %s, not retrying", self._host
|
"Unable to authenticate with %s, not retrying", self._host
|
||||||
)
|
)
|
||||||
raise auex
|
raise auex
|
||||||
|
except SmartDeviceException as ex:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Unable to connect to the device: %s, not retrying: %s",
|
||||||
|
self._host,
|
||||||
|
ex,
|
||||||
|
)
|
||||||
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
await self.close()
|
await self.close()
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
|
@ -62,26 +62,7 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
"""Query the device retrying for retry_count on failure."""
|
"""Query the device retrying for retry_count on failure."""
|
||||||
async with self._query_lock:
|
async with self._query_lock:
|
||||||
resp_dict = await self._query(request, retry_count)
|
return await self._query(request, retry_count)
|
||||||
|
|
||||||
if (
|
|
||||||
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
|
||||||
) != SmartErrorCode.SUCCESS:
|
|
||||||
msg = (
|
|
||||||
f"Error querying device: {self._host}: "
|
|
||||||
+ f"{error_code.name}({error_code.value})"
|
|
||||||
)
|
|
||||||
if error_code in SMART_TIMEOUT_ERRORS:
|
|
||||||
raise TimeoutException(msg)
|
|
||||||
if error_code in SMART_RETRYABLE_ERRORS:
|
|
||||||
raise RetryableException(msg)
|
|
||||||
if error_code in SMART_AUTHENTICATION_ERRORS:
|
|
||||||
raise AuthenticationException(msg)
|
|
||||||
raise SmartDeviceException(msg)
|
|
||||||
|
|
||||||
if "result" in resp_dict:
|
|
||||||
return resp_dict["result"]
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
for retry in range(retry_count + 1):
|
for retry in range(retry_count + 1):
|
||||||
@ -128,6 +109,11 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
raise ex
|
raise ex
|
||||||
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
||||||
continue
|
continue
|
||||||
|
except SmartDeviceException as ex:
|
||||||
|
# Transport would have raised RetryableException if retry makes sense.
|
||||||
|
await self.close()
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
await self.close()
|
await self.close()
|
||||||
@ -145,8 +131,15 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
|
|
||||||
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
||||||
if isinstance(request, dict):
|
if isinstance(request, dict):
|
||||||
smart_method = next(iter(request))
|
if len(request) == 1:
|
||||||
smart_params = request[smart_method]
|
smart_method = next(iter(request))
|
||||||
|
smart_params = request[smart_method]
|
||||||
|
else:
|
||||||
|
requests = []
|
||||||
|
for method, params in request.items():
|
||||||
|
requests.append({"method": method, "params": params})
|
||||||
|
smart_method = "multipleRequest"
|
||||||
|
smart_params = {"requests": requests}
|
||||||
else:
|
else:
|
||||||
smart_method = request
|
smart_method = request
|
||||||
smart_params = None
|
smart_params = None
|
||||||
@ -165,7 +158,40 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
||||||
)
|
)
|
||||||
|
|
||||||
return response_data
|
self._handle_response_error_code(response_data)
|
||||||
|
|
||||||
|
if (result := response_data.get("result")) is None:
|
||||||
|
# Single set_ requests do not return a result
|
||||||
|
return {smart_method: None}
|
||||||
|
|
||||||
|
if (responses := result.get("responses")) is None:
|
||||||
|
return {smart_method: result}
|
||||||
|
|
||||||
|
# responses is returned for multipleRequest
|
||||||
|
multi_result = {}
|
||||||
|
for response in responses:
|
||||||
|
self._handle_response_error_code(response)
|
||||||
|
result = response.get("result", None)
|
||||||
|
multi_result[response["method"]] = result
|
||||||
|
return multi_result
|
||||||
|
|
||||||
|
def _handle_response_error_code(self, resp_dict: dict):
|
||||||
|
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
|
if error_code == SmartErrorCode.SUCCESS:
|
||||||
|
return
|
||||||
|
msg = (
|
||||||
|
f"Error querying device: {self._host}: "
|
||||||
|
+ f"{error_code.name}({error_code.value})"
|
||||||
|
)
|
||||||
|
if method := resp_dict.get("method"):
|
||||||
|
msg += f" for method: {method}"
|
||||||
|
if error_code in SMART_TIMEOUT_ERRORS:
|
||||||
|
raise TimeoutException(msg)
|
||||||
|
if error_code in SMART_RETRYABLE_ERRORS:
|
||||||
|
raise RetryableException(msg)
|
||||||
|
if error_code in SMART_AUTHENTICATION_ERRORS:
|
||||||
|
raise AuthenticationException(msg)
|
||||||
|
raise SmartDeviceException(msg)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the protocol."""
|
"""Close the protocol."""
|
||||||
|
@ -41,11 +41,18 @@ class TapoDevice(SmartDevice):
|
|||||||
raise AuthenticationException("Tapo plug requires authentication.")
|
raise AuthenticationException("Tapo plug requires authentication.")
|
||||||
|
|
||||||
if self._components is None:
|
if self._components is None:
|
||||||
self._components = await self.protocol.query("component_nego")
|
resp = await self.protocol.query("component_nego")
|
||||||
|
self._components = resp["component_nego"]
|
||||||
|
|
||||||
self._info = await self.protocol.query("get_device_info")
|
req = {
|
||||||
self._usage = await self.protocol.query("get_device_usage")
|
"get_device_info": None,
|
||||||
self._time = await self.protocol.query("get_device_time")
|
"get_device_usage": None,
|
||||||
|
"get_device_time": None,
|
||||||
|
}
|
||||||
|
resp = await self.protocol.query(req)
|
||||||
|
self._info = resp["get_device_info"]
|
||||||
|
self._usage = resp["get_device_usage"]
|
||||||
|
self._time = resp["get_device_time"]
|
||||||
|
|
||||||
self._last_update = self._data = {
|
self._last_update = self._data = {
|
||||||
"components": self._components,
|
"components": self._components,
|
||||||
|
@ -39,8 +39,13 @@ class TapoPlug(TapoDevice):
|
|||||||
"""Call the device endpoint and update the device data."""
|
"""Call the device endpoint and update the device data."""
|
||||||
await super().update(update_children)
|
await super().update(update_children)
|
||||||
|
|
||||||
self._energy = await self.protocol.query("get_energy_usage")
|
req = {
|
||||||
self._emeter = await self.protocol.query("get_current_power")
|
"get_energy_usage": None,
|
||||||
|
"get_current_power": None,
|
||||||
|
}
|
||||||
|
resp = await self.protocol.query(req)
|
||||||
|
self._energy = resp["get_energy_usage"]
|
||||||
|
self._emeter = resp["get_current_power"]
|
||||||
|
|
||||||
self._data["energy"] = self._energy
|
self._data["energy"] = self._energy
|
||||||
self._data["emeter"] = self._emeter
|
self._data["emeter"] = self._emeter
|
||||||
@ -71,6 +76,13 @@ class TapoPlug(TapoDevice):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_emeter_realtime(self) -> EmeterStatus:
|
||||||
|
"""Retrieve current energy readings."""
|
||||||
|
self._verify_emeter()
|
||||||
|
resp = await self.protocol.query("get_energy_usage")
|
||||||
|
self._energy = resp["get_energy_usage"]
|
||||||
|
return self.emeter_realtime
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def emeter_today(self) -> Optional[float]:
|
def emeter_today(self) -> Optional[float]:
|
||||||
"""Get the emeter value for today."""
|
"""Get the emeter value for today."""
|
||||||
|
@ -196,9 +196,13 @@ def parametrize(desc, devices, protocol_filter=None, ids=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
has_emeter = parametrize("has emeter", WITH_EMETER_IOT, protocol_filter={"IOT"})
|
has_emeter = parametrize("has emeter", WITH_EMETER, protocol_filter={"SMART", "IOT"})
|
||||||
no_emeter = parametrize(
|
no_emeter = parametrize(
|
||||||
"no emeter", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"SMART", "IOT"}
|
"no emeter", ALL_DEVICES - WITH_EMETER, protocol_filter={"SMART", "IOT"}
|
||||||
|
)
|
||||||
|
has_emeter_iot = parametrize("has emeter iot", WITH_EMETER_IOT, protocol_filter={"IOT"})
|
||||||
|
no_emeter_iot = parametrize(
|
||||||
|
"no emeter iot", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"IOT"}
|
||||||
)
|
)
|
||||||
|
|
||||||
bulb = parametrize("bulbs", BULBS, protocol_filter={"SMART", "IOT"})
|
bulb = parametrize("bulbs", BULBS, protocol_filter={"SMART", "IOT"})
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import warnings
|
||||||
from json import loads as json_loads
|
from json import loads as json_loads
|
||||||
|
|
||||||
from voluptuous import (
|
from voluptuous import (
|
||||||
@ -294,9 +295,7 @@ class FakeSmartProtocol(SmartProtocol):
|
|||||||
async def query(self, request, retry_count: int = 3):
|
async def query(self, request, retry_count: int = 3):
|
||||||
"""Implement query here so can still patch SmartProtocol.query."""
|
"""Implement query here so can still patch SmartProtocol.query."""
|
||||||
resp_dict = await self._query(request, retry_count)
|
resp_dict = await self._query(request, retry_count)
|
||||||
if "result" in resp_dict:
|
return resp_dict
|
||||||
return resp_dict["result"]
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class FakeSmartTransport(BaseTransport):
|
class FakeSmartTransport(BaseTransport):
|
||||||
@ -306,26 +305,34 @@ class FakeSmartTransport(BaseTransport):
|
|||||||
)
|
)
|
||||||
self.info = info
|
self.info = info
|
||||||
|
|
||||||
@property
|
|
||||||
def needs_handshake(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def needs_login(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def login(self, request: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def handshake(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def send(self, request: str):
|
async def send(self, request: str):
|
||||||
request_dict = json_loads(request)
|
request_dict = json_loads(request)
|
||||||
|
method = request_dict["method"]
|
||||||
|
params = request_dict["params"]
|
||||||
|
if method == "multipleRequest":
|
||||||
|
responses = []
|
||||||
|
for request in params["requests"]:
|
||||||
|
response = self._send_request(request) # type: ignore[arg-type]
|
||||||
|
response["method"] = request["method"] # type: ignore[index]
|
||||||
|
responses.append(response)
|
||||||
|
return {"result": {"responses": responses}, "error_code": 0}
|
||||||
|
else:
|
||||||
|
return self._send_request(request_dict)
|
||||||
|
|
||||||
|
def _send_request(self, request_dict: dict):
|
||||||
method = request_dict["method"]
|
method = request_dict["method"]
|
||||||
params = request_dict["params"]
|
params = request_dict["params"]
|
||||||
if method == "component_nego" or method[:4] == "get_":
|
if method == "component_nego" or method[:4] == "get_":
|
||||||
return {"result": self.info[method], "error_code": 0}
|
if method in self.info:
|
||||||
|
return {"result": self.info[method], "error_code": 0}
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
UserWarning(
|
||||||
|
f"Fixture missing expected method {method}, try to regenerate"
|
||||||
|
),
|
||||||
|
stacklevel=1,
|
||||||
|
)
|
||||||
|
return {"result": {}, "error_code": 0}
|
||||||
elif method[:4] == "set_":
|
elif method[:4] == "set_":
|
||||||
target_method = f"get_{method[4:]}"
|
target_method = f"get_{method[4:]}"
|
||||||
self.info[target_method].update(params)
|
self.info[target_method].update(params)
|
||||||
|
@ -12,7 +12,12 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
|
|||||||
|
|
||||||
from ..aestransport import AesEncyptionSession, AesTransport
|
from ..aestransport import AesEncyptionSession, AesTransport
|
||||||
from ..credentials import Credentials
|
from ..credentials import Credentials
|
||||||
from ..exceptions import SmartDeviceException
|
from ..exceptions import (
|
||||||
|
SMART_RETRYABLE_ERRORS,
|
||||||
|
SMART_TIMEOUT_ERRORS,
|
||||||
|
SmartDeviceException,
|
||||||
|
SmartErrorCode,
|
||||||
|
)
|
||||||
|
|
||||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||||
|
|
||||||
@ -105,6 +110,32 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
|
|||||||
assert "result" in res
|
assert "result" in res
|
||||||
|
|
||||||
|
|
||||||
|
ERRORS = [e for e in SmartErrorCode if e != 0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
|
||||||
|
async def test_passthrough_errors(mocker, error_code):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
|
||||||
|
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||||
|
|
||||||
|
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||||
|
transport._handshake_done = True
|
||||||
|
transport._session_expire_at = time.time() + 86400
|
||||||
|
transport._encryption_session = mock_aes_device.encryption_session
|
||||||
|
transport._login_token = mock_aes_device.token
|
||||||
|
|
||||||
|
request = {
|
||||||
|
"method": "get_device_info",
|
||||||
|
"params": None,
|
||||||
|
"request_time_milis": round(time.time() * 1000),
|
||||||
|
"requestID": 1,
|
||||||
|
"terminal_uuid": "foobar",
|
||||||
|
}
|
||||||
|
with pytest.raises(SmartDeviceException):
|
||||||
|
await transport.send(json_dumps(request))
|
||||||
|
|
||||||
|
|
||||||
class MockAesDevice:
|
class MockAesDevice:
|
||||||
class _mock_response:
|
class _mock_response:
|
||||||
def __init__(self, status_code, json: dict):
|
def __init__(self, status_code, json: dict):
|
||||||
|
@ -2,7 +2,7 @@ import pytest
|
|||||||
|
|
||||||
from kasa import EmeterStatus, SmartDeviceException
|
from kasa import EmeterStatus, SmartDeviceException
|
||||||
|
|
||||||
from .conftest import has_emeter, no_emeter
|
from .conftest import has_emeter, has_emeter_iot, no_emeter
|
||||||
from .newfakes import CURRENT_CONSUMPTION_SCHEMA
|
from .newfakes import CURRENT_CONSUMPTION_SCHEMA
|
||||||
|
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ async def test_no_emeter(dev):
|
|||||||
await dev.erase_emeter_stats()
|
await dev.erase_emeter_stats()
|
||||||
|
|
||||||
|
|
||||||
@has_emeter
|
@has_emeter_iot
|
||||||
async def test_get_emeter_realtime(dev):
|
async def test_get_emeter_realtime(dev):
|
||||||
assert dev.has_emeter
|
assert dev.has_emeter
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ async def test_get_emeter_realtime(dev):
|
|||||||
CURRENT_CONSUMPTION_SCHEMA(current_emeter)
|
CURRENT_CONSUMPTION_SCHEMA(current_emeter)
|
||||||
|
|
||||||
|
|
||||||
@has_emeter
|
@has_emeter_iot
|
||||||
@pytest.mark.requires_dummy
|
@pytest.mark.requires_dummy
|
||||||
async def test_get_emeter_daily(dev):
|
async def test_get_emeter_daily(dev):
|
||||||
assert dev.has_emeter
|
assert dev.has_emeter
|
||||||
@ -48,7 +48,7 @@ async def test_get_emeter_daily(dev):
|
|||||||
assert v * 1000 == v2
|
assert v * 1000 == v2
|
||||||
|
|
||||||
|
|
||||||
@has_emeter
|
@has_emeter_iot
|
||||||
@pytest.mark.requires_dummy
|
@pytest.mark.requires_dummy
|
||||||
async def test_get_emeter_monthly(dev):
|
async def test_get_emeter_monthly(dev):
|
||||||
assert dev.has_emeter
|
assert dev.has_emeter
|
||||||
@ -68,7 +68,7 @@ async def test_get_emeter_monthly(dev):
|
|||||||
assert v * 1000 == v2
|
assert v * 1000 == v2
|
||||||
|
|
||||||
|
|
||||||
@has_emeter
|
@has_emeter_iot
|
||||||
async def test_emeter_status(dev):
|
async def test_emeter_status(dev):
|
||||||
assert dev.has_emeter
|
assert dev.has_emeter
|
||||||
|
|
||||||
|
@ -26,20 +26,29 @@ class _mock_response:
|
|||||||
self.content = content
|
self.content = content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"error, retry_expectation",
|
||||||
|
[
|
||||||
|
(Exception("dummy exception"), True),
|
||||||
|
(SmartDeviceException("dummy exception"), False),
|
||||||
|
],
|
||||||
|
ids=("Exception", "SmartDeviceException"),
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||||
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||||
async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class):
|
async def test_protocol_retries(
|
||||||
|
mocker, retry_count, protocol_class, transport_class, error, retry_expectation
|
||||||
|
):
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
conn = mocker.patch.object(
|
conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error)
|
||||||
httpx.AsyncClient, "post", side_effect=Exception("dummy exception")
|
|
||||||
)
|
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await protocol_class(host, transport=transport_class(host)).query(
|
await protocol_class(host, transport=transport_class(host)).query(
|
||||||
DUMMY_QUERY, retry_count=retry_count
|
DUMMY_QUERY, retry_count=retry_count
|
||||||
)
|
)
|
||||||
|
|
||||||
assert conn.call_count == retry_count + 1
|
expected_count = retry_count + 1 if retry_expectation else 1
|
||||||
|
assert conn.call_count == expected_count
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||||
@ -109,7 +118,7 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
|
|||||||
response = await protocol_class(host, transport=transport_class(host)).query(
|
response = await protocol_class(host, transport=transport_class(host)).query(
|
||||||
DUMMY_QUERY, retry_count=retry_count
|
DUMMY_QUERY, retry_count=retry_count
|
||||||
)
|
)
|
||||||
assert "result" in response or "great" in response
|
assert "result" in response or "foobar" in response
|
||||||
assert send_mock.call_count == retry_count
|
assert send_mock.call_count == retry_count
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ import kasa
|
|||||||
from kasa import Credentials, SmartDevice, SmartDeviceException
|
from kasa import Credentials, SmartDevice, SmartDeviceException
|
||||||
from kasa.smartdevice import DeviceType
|
from kasa.smartdevice import DeviceType
|
||||||
|
|
||||||
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on
|
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on
|
||||||
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
|
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
|
||||||
|
|
||||||
# List of all SmartXXX classes including the SmartDevice base class
|
# List of all SmartXXX classes including the SmartDevice base class
|
||||||
@ -48,7 +48,7 @@ async def test_initial_update_emeter(dev, mocker):
|
|||||||
assert spy.call_count == expected_queries + len(dev.children)
|
assert spy.call_count == expected_queries + len(dev.children)
|
||||||
|
|
||||||
|
|
||||||
@no_emeter
|
@no_emeter_iot
|
||||||
async def test_initial_update_no_emeter(dev, mocker):
|
async def test_initial_update_no_emeter(dev, mocker):
|
||||||
"""Test that the initial update performs second query if emeter is available."""
|
"""Test that the initial update performs second query if emeter is available."""
|
||||||
dev._last_update = None
|
dev._last_update = None
|
||||||
|
81
kasa/tests/test_smartprotocol.py
Normal file
81
kasa/tests/test_smartprotocol.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import errno
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from contextlib import nullcontext as does_not_raise
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ..aestransport import AesTransport
|
||||||
|
from ..credentials import Credentials
|
||||||
|
from ..exceptions import (
|
||||||
|
SMART_RETRYABLE_ERRORS,
|
||||||
|
SMART_TIMEOUT_ERRORS,
|
||||||
|
SmartDeviceException,
|
||||||
|
SmartErrorCode,
|
||||||
|
)
|
||||||
|
from ..iotprotocol import IotProtocol
|
||||||
|
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
|
||||||
|
from ..smartprotocol import SmartProtocol
|
||||||
|
|
||||||
|
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||||
|
ERRORS = [e for e in SmartErrorCode if e != 0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
|
||||||
|
async def test_smart_device_errors(mocker, error_code):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_response = {"result": {"great": "success"}, "error_code": error_code.value}
|
||||||
|
|
||||||
|
mocker.patch.object(AesTransport, "perform_handshake")
|
||||||
|
mocker.patch.object(AesTransport, "perform_login")
|
||||||
|
|
||||||
|
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||||
|
|
||||||
|
protocol = SmartProtocol(host, transport=AesTransport(host))
|
||||||
|
with pytest.raises(SmartDeviceException):
|
||||||
|
await protocol.query(DUMMY_QUERY, retry_count=2)
|
||||||
|
|
||||||
|
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
|
||||||
|
expected_calls = 3
|
||||||
|
else:
|
||||||
|
expected_calls = 1
|
||||||
|
assert send_mock.call_count == expected_calls
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
|
||||||
|
async def test_smart_device_errors_in_multiple_request(mocker, error_code):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_response = {
|
||||||
|
"result": {
|
||||||
|
"responses": [
|
||||||
|
{"method": "foobar1", "result": {"great": "success"}, "error_code": 0},
|
||||||
|
{
|
||||||
|
"method": "foobar2",
|
||||||
|
"result": {"great": "success"},
|
||||||
|
"error_code": error_code.value,
|
||||||
|
},
|
||||||
|
{"method": "foobar3", "result": {"great": "success"}, "error_code": 0},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"error_code": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
mocker.patch.object(AesTransport, "perform_handshake")
|
||||||
|
mocker.patch.object(AesTransport, "perform_login")
|
||||||
|
|
||||||
|
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||||
|
|
||||||
|
protocol = SmartProtocol(host, transport=AesTransport(host))
|
||||||
|
with pytest.raises(SmartDeviceException):
|
||||||
|
await protocol.query(DUMMY_QUERY, retry_count=2)
|
||||||
|
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
|
||||||
|
expected_calls = 3
|
||||||
|
else:
|
||||||
|
expected_calls = 1
|
||||||
|
assert send_mock.call_count == expected_calls
|
Loading…
Reference in New Issue
Block a user