Enable multiple requests in smartprotocol (#584)

* Enable multiple requests in smartprotocol

* Update following review

* Remove error_code parameter in exceptions
This commit is contained in:
sdb9696 2023-12-20 17:08:04 +00:00 committed by GitHub
parent 20ea6700a5
commit 6819c746d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 260 additions and 76 deletions

View File

@ -125,19 +125,19 @@ class AesTransport(BaseTransport):
return resp.status_code, response_data return resp.status_code, response_data
def _handle_response_error_code(self, resp_dict: dict, msg: str): def _handle_response_error_code(self, resp_dict: dict, msg: str):
if ( error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] if error_code == SmartErrorCode.SUCCESS:
) != SmartErrorCode.SUCCESS: return
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_TIMEOUT_ERRORS: if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg) raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS: if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg) raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS: if error_code in SMART_AUTHENTICATION_ERRORS:
self._handshake_done = False self._handshake_done = False
self._login_token = None self._login_token = None
raise AuthenticationException(msg) raise AuthenticationException(msg)
raise SmartDeviceException(msg) raise SmartDeviceException(msg)
async def send_secure_passthrough(self, request: str): async def send_secure_passthrough(self, request: str):
"""Send encrypted message as passthrough.""" """Send encrypted message as passthrough."""

View File

@ -62,6 +62,13 @@ class IotProtocol(TPLinkProtocol):
"Unable to authenticate with %s, not retrying", self._host "Unable to authenticate with %s, not retrying", self._host
) )
raise auex raise auex
except SmartDeviceException as ex:
_LOGGER.debug(
"Unable to connect to the device: %s, not retrying: %s",
self._host,
ex,
)
raise ex
except Exception as ex: except Exception as ex:
await self.close() await self.close()
if retry >= retry_count: if retry >= retry_count:

View File

@ -62,26 +62,7 @@ class SmartProtocol(TPLinkProtocol):
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure.""" """Query the device retrying for retry_count on failure."""
async with self._query_lock: async with self._query_lock:
resp_dict = await self._query(request, retry_count) return await self._query(request, retry_count)
if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = (
f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})"
)
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS:
raise AuthenticationException(msg)
raise SmartDeviceException(msg)
if "result" in resp_dict:
return resp_dict["result"]
return {}
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1): for retry in range(retry_count + 1):
@ -128,6 +109,11 @@ class SmartProtocol(TPLinkProtocol):
raise ex raise ex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue continue
except SmartDeviceException as ex:
# Transport would have raised RetryableException if retry makes sense.
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
except Exception as ex: except Exception as ex:
if retry >= retry_count: if retry >= retry_count:
await self.close() await self.close()
@ -145,8 +131,15 @@ class SmartProtocol(TPLinkProtocol):
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
if isinstance(request, dict): if isinstance(request, dict):
smart_method = next(iter(request)) if len(request) == 1:
smart_params = request[smart_method] smart_method = next(iter(request))
smart_params = request[smart_method]
else:
requests = []
for method, params in request.items():
requests.append({"method": method, "params": params})
smart_method = "multipleRequest"
smart_params = {"requests": requests}
else: else:
smart_method = request smart_method = request
smart_params = None smart_params = None
@ -165,7 +158,40 @@ class SmartProtocol(TPLinkProtocol):
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
) )
return response_data self._handle_response_error_code(response_data)
if (result := response_data.get("result")) is None:
# Single set_ requests do not return a result
return {smart_method: None}
if (responses := result.get("responses")) is None:
return {smart_method: result}
# responses is returned for multipleRequest
multi_result = {}
for response in responses:
self._handle_response_error_code(response)
result = response.get("result", None)
multi_result[response["method"]] = result
return multi_result
def _handle_response_error_code(self, resp_dict: dict):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
return
msg = (
f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})"
)
if method := resp_dict.get("method"):
msg += f" for method: {method}"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS:
raise AuthenticationException(msg)
raise SmartDeviceException(msg)
async def close(self) -> None: async def close(self) -> None:
"""Close the protocol.""" """Close the protocol."""

View File

@ -41,11 +41,18 @@ class TapoDevice(SmartDevice):
raise AuthenticationException("Tapo plug requires authentication.") raise AuthenticationException("Tapo plug requires authentication.")
if self._components is None: if self._components is None:
self._components = await self.protocol.query("component_nego") resp = await self.protocol.query("component_nego")
self._components = resp["component_nego"]
self._info = await self.protocol.query("get_device_info") req = {
self._usage = await self.protocol.query("get_device_usage") "get_device_info": None,
self._time = await self.protocol.query("get_device_time") "get_device_usage": None,
"get_device_time": None,
}
resp = await self.protocol.query(req)
self._info = resp["get_device_info"]
self._usage = resp["get_device_usage"]
self._time = resp["get_device_time"]
self._last_update = self._data = { self._last_update = self._data = {
"components": self._components, "components": self._components,

View File

@ -39,8 +39,13 @@ class TapoPlug(TapoDevice):
"""Call the device endpoint and update the device data.""" """Call the device endpoint and update the device data."""
await super().update(update_children) await super().update(update_children)
self._energy = await self.protocol.query("get_energy_usage") req = {
self._emeter = await self.protocol.query("get_current_power") "get_energy_usage": None,
"get_current_power": None,
}
resp = await self.protocol.query(req)
self._energy = resp["get_energy_usage"]
self._emeter = resp["get_current_power"]
self._data["energy"] = self._energy self._data["energy"] = self._energy
self._data["emeter"] = self._emeter self._data["emeter"] = self._emeter
@ -71,6 +76,13 @@ class TapoPlug(TapoDevice):
} }
) )
async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings."""
self._verify_emeter()
resp = await self.protocol.query("get_energy_usage")
self._energy = resp["get_energy_usage"]
return self.emeter_realtime
@property @property
def emeter_today(self) -> Optional[float]: def emeter_today(self) -> Optional[float]:
"""Get the emeter value for today.""" """Get the emeter value for today."""

View File

@ -196,9 +196,13 @@ def parametrize(desc, devices, protocol_filter=None, ids=None):
) )
has_emeter = parametrize("has emeter", WITH_EMETER_IOT, protocol_filter={"IOT"}) has_emeter = parametrize("has emeter", WITH_EMETER, protocol_filter={"SMART", "IOT"})
no_emeter = parametrize( no_emeter = parametrize(
"no emeter", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"SMART", "IOT"} "no emeter", ALL_DEVICES - WITH_EMETER, protocol_filter={"SMART", "IOT"}
)
has_emeter_iot = parametrize("has emeter iot", WITH_EMETER_IOT, protocol_filter={"IOT"})
no_emeter_iot = parametrize(
"no emeter iot", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"IOT"}
) )
bulb = parametrize("bulbs", BULBS, protocol_filter={"SMART", "IOT"}) bulb = parametrize("bulbs", BULBS, protocol_filter={"SMART", "IOT"})

View File

@ -1,6 +1,7 @@
import copy import copy
import logging import logging
import re import re
import warnings
from json import loads as json_loads from json import loads as json_loads
from voluptuous import ( from voluptuous import (
@ -294,9 +295,7 @@ class FakeSmartProtocol(SmartProtocol):
async def query(self, request, retry_count: int = 3): async def query(self, request, retry_count: int = 3):
"""Implement query here so can still patch SmartProtocol.query.""" """Implement query here so can still patch SmartProtocol.query."""
resp_dict = await self._query(request, retry_count) resp_dict = await self._query(request, retry_count)
if "result" in resp_dict: return resp_dict
return resp_dict["result"]
return {}
class FakeSmartTransport(BaseTransport): class FakeSmartTransport(BaseTransport):
@ -306,26 +305,34 @@ class FakeSmartTransport(BaseTransport):
) )
self.info = info self.info = info
@property
def needs_handshake(self) -> bool:
return False
@property
def needs_login(self) -> bool:
return False
async def login(self, request: str) -> None:
pass
async def handshake(self) -> None:
pass
async def send(self, request: str): async def send(self, request: str):
request_dict = json_loads(request) request_dict = json_loads(request)
method = request_dict["method"]
params = request_dict["params"]
if method == "multipleRequest":
responses = []
for request in params["requests"]:
response = self._send_request(request) # type: ignore[arg-type]
response["method"] = request["method"] # type: ignore[index]
responses.append(response)
return {"result": {"responses": responses}, "error_code": 0}
else:
return self._send_request(request_dict)
def _send_request(self, request_dict: dict):
method = request_dict["method"] method = request_dict["method"]
params = request_dict["params"] params = request_dict["params"]
if method == "component_nego" or method[:4] == "get_": if method == "component_nego" or method[:4] == "get_":
return {"result": self.info[method], "error_code": 0} if method in self.info:
return {"result": self.info[method], "error_code": 0}
else:
warnings.warn(
UserWarning(
f"Fixture missing expected method {method}, try to regenerate"
),
stacklevel=1,
)
return {"result": {}, "error_code": 0}
elif method[:4] == "set_": elif method[:4] == "set_":
target_method = f"get_{method[4:]}" target_method = f"get_{method[4:]}"
self.info[target_method].update(params) self.info[target_method].update(params)

View File

@ -12,7 +12,12 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
from ..aestransport import AesEncyptionSession, AesTransport from ..aestransport import AesEncyptionSession, AesTransport
from ..credentials import Credentials from ..credentials import Credentials
from ..exceptions import SmartDeviceException from ..exceptions import (
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
SmartDeviceException,
SmartErrorCode,
)
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
@ -105,6 +110,32 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
assert "result" in res assert "result" in res
ERRORS = [e for e in SmartErrorCode if e != 0]
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_passthrough_errors(mocker, error_code):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with pytest.raises(SmartDeviceException):
await transport.send(json_dumps(request))
class MockAesDevice: class MockAesDevice:
class _mock_response: class _mock_response:
def __init__(self, status_code, json: dict): def __init__(self, status_code, json: dict):

View File

@ -2,7 +2,7 @@ import pytest
from kasa import EmeterStatus, SmartDeviceException from kasa import EmeterStatus, SmartDeviceException
from .conftest import has_emeter, no_emeter from .conftest import has_emeter, has_emeter_iot, no_emeter
from .newfakes import CURRENT_CONSUMPTION_SCHEMA from .newfakes import CURRENT_CONSUMPTION_SCHEMA
@ -20,7 +20,7 @@ async def test_no_emeter(dev):
await dev.erase_emeter_stats() await dev.erase_emeter_stats()
@has_emeter @has_emeter_iot
async def test_get_emeter_realtime(dev): async def test_get_emeter_realtime(dev):
assert dev.has_emeter assert dev.has_emeter
@ -28,7 +28,7 @@ async def test_get_emeter_realtime(dev):
CURRENT_CONSUMPTION_SCHEMA(current_emeter) CURRENT_CONSUMPTION_SCHEMA(current_emeter)
@has_emeter @has_emeter_iot
@pytest.mark.requires_dummy @pytest.mark.requires_dummy
async def test_get_emeter_daily(dev): async def test_get_emeter_daily(dev):
assert dev.has_emeter assert dev.has_emeter
@ -48,7 +48,7 @@ async def test_get_emeter_daily(dev):
assert v * 1000 == v2 assert v * 1000 == v2
@has_emeter @has_emeter_iot
@pytest.mark.requires_dummy @pytest.mark.requires_dummy
async def test_get_emeter_monthly(dev): async def test_get_emeter_monthly(dev):
assert dev.has_emeter assert dev.has_emeter
@ -68,7 +68,7 @@ async def test_get_emeter_monthly(dev):
assert v * 1000 == v2 assert v * 1000 == v2
@has_emeter @has_emeter_iot
async def test_emeter_status(dev): async def test_emeter_status(dev):
assert dev.has_emeter assert dev.has_emeter

View File

@ -26,20 +26,29 @@ class _mock_response:
self.content = content self.content = content
@pytest.mark.parametrize(
"error, retry_expectation",
[
(Exception("dummy exception"), True),
(SmartDeviceException("dummy exception"), False),
],
ids=("Exception", "SmartDeviceException"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5]) @pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class): async def test_protocol_retries(
mocker, retry_count, protocol_class, transport_class, error, retry_expectation
):
host = "127.0.0.1" host = "127.0.0.1"
conn = mocker.patch.object( conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error)
httpx.AsyncClient, "post", side_effect=Exception("dummy exception")
)
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query( await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=retry_count DUMMY_QUERY, retry_count=retry_count
) )
assert conn.call_count == retry_count + 1 expected_count = retry_count + 1 if retry_expectation else 1
assert conn.call_count == expected_count
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@ -109,7 +118,7 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
response = await protocol_class(host, transport=transport_class(host)).query( response = await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=retry_count DUMMY_QUERY, retry_count=retry_count
) )
assert "result" in response or "great" in response assert "result" in response or "foobar" in response
assert send_mock.call_count == retry_count assert send_mock.call_count == retry_count

View File

@ -8,7 +8,7 @@ import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType from kasa.smartdevice import DeviceType
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
# List of all SmartXXX classes including the SmartDevice base class # List of all SmartXXX classes including the SmartDevice base class
@ -48,7 +48,7 @@ async def test_initial_update_emeter(dev, mocker):
assert spy.call_count == expected_queries + len(dev.children) assert spy.call_count == expected_queries + len(dev.children)
@no_emeter @no_emeter_iot
async def test_initial_update_no_emeter(dev, mocker): async def test_initial_update_no_emeter(dev, mocker):
"""Test that the initial update performs second query if emeter is available.""" """Test that the initial update performs second query if emeter is available."""
dev._last_update = None dev._last_update = None

View File

@ -0,0 +1,81 @@
import errno
import json
import logging
import secrets
import struct
import sys
import time
from contextlib import nullcontext as does_not_raise
from itertools import chain
import httpx
import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
SmartDeviceException,
SmartErrorCode,
)
from ..iotprotocol import IotProtocol
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
from ..smartprotocol import SmartProtocol
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
ERRORS = [e for e in SmartErrorCode if e != 0]
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_smart_device_errors(mocker, error_code):
host = "127.0.0.1"
mock_response = {"result": {"great": "success"}, "error_code": error_code.value}
mocker.patch.object(AesTransport, "perform_handshake")
mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
protocol = SmartProtocol(host, transport=AesTransport(host))
with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
expected_calls = 3
else:
expected_calls = 1
assert send_mock.call_count == expected_calls
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_smart_device_errors_in_multiple_request(mocker, error_code):
host = "127.0.0.1"
mock_response = {
"result": {
"responses": [
{"method": "foobar1", "result": {"great": "success"}, "error_code": 0},
{
"method": "foobar2",
"result": {"great": "success"},
"error_code": error_code.value,
},
{"method": "foobar3", "result": {"great": "success"}, "error_code": 0},
]
},
"error_code": 0,
}
mocker.patch.object(AesTransport, "perform_handshake")
mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
protocol = SmartProtocol(host, transport=AesTransport(host))
with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
expected_calls = 3
else:
expected_calls = 1
assert send_mock.call_count == expected_calls