From 1ad2a05b6578747f0d375372fc90599328b9d6e3 Mon Sep 17 00:00:00 2001 From: Teemu R Date: Mon, 29 Jan 2024 17:11:29 +0100 Subject: [PATCH] Initial support for tapos with child devices (#720) * Add ChildDevice and ChildProtocolWrapper * Initialize & update children * Fix circular imports * Add dummy_protocol fixture and tests for unwrapping responseData * Use dummy_protocol for existing smartprotocol tests * Move _ChildProtocolWrapper to smartprotocol.py * Use dummy_protocol for test multiple requests * Use device_id instead of position for selecting the child * Fix wrapping for regular requests * Remove unused imports * tweak * rename child_device to childdevice * Fix import --- kasa/smartprotocol.py | 67 ++++++++++++ kasa/tapo/childdevice.py | 44 ++++++++ kasa/tapo/tapodevice.py | 29 ++++- kasa/tapo/tapoplug.py | 4 +- kasa/tests/test_protocol.py | 3 + kasa/tests/test_smartprotocol.py | 176 ++++++++++++++++++++++++------- 6 files changed, 280 insertions(+), 43 deletions(-) create mode 100644 kasa/tapo/childdevice.py diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 9ec2547d..74f2275d 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -279,3 +279,70 @@ class SnowflakeId: while timestamp <= last_timestamp: timestamp = self._current_millis() return timestamp + + +class _ChildProtocolWrapper(SmartProtocol): + """Protocol wrapper for controlling child devices. + + This is an internal class used to communicate with child devices, + and should not be used directly. + + This class overrides query() method of the protocol to modify all + outgoing queries to use ``control_child`` command, and unwraps the + device responses before returning to the caller. + """ + + def __init__(self, device_id: str, base_protocol: SmartProtocol): + self._device_id = device_id + self._protocol = base_protocol + self._transport = base_protocol._transport + + def _get_method_and_params_for_request(self, request): + """Return payload for wrapping. + + TODO: this does not support batches and requires refactoring in the future. + """ + if isinstance(request, dict): + if len(request) == 1: + smart_method = next(iter(request)) + smart_params = request[smart_method] + else: + smart_method = "multipleRequest" + requests = [ + {"method": method, "params": params} + for method, params in request.items() + ] + smart_params = {"requests": requests} + else: + smart_method = request + smart_params = None + + return smart_method, smart_params + + async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: + """Wrap request inside control_child envelope.""" + method, params = self._get_method_and_params_for_request(request) + request_data = { + "method": method, + "params": params, + } + wrapped_payload = { + "control_child": { + "device_id": self._device_id, + "requestData": request_data, + } + } + + response = await self._protocol.query(wrapped_payload, retry_count) + result = response.get("control_child") + # Unwrap responseData for control_child + if result and (response_data := result.get("responseData")): + self._handle_response_error_code(response_data) + result = response_data.get("result") + + # TODO: handle multipleRequest unwrapping + + return {method: result} + + async def close(self) -> None: + """Do nothing as the parent owns the protocol.""" diff --git a/kasa/tapo/childdevice.py b/kasa/tapo/childdevice.py new file mode 100644 index 00000000..c1b108a3 --- /dev/null +++ b/kasa/tapo/childdevice.py @@ -0,0 +1,44 @@ +"""Child device implementation.""" +from typing import Dict, Optional + +from ..deviceconfig import DeviceConfig +from ..exceptions import SmartDeviceException +from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper +from .tapodevice import TapoDevice + + +class ChildDevice(TapoDevice): + """Presentation of a child device. + + This wraps the protocol communications and sets internal data for the child. + """ + + def __init__( + self, + parent: TapoDevice, + child_id: str, + config: Optional[DeviceConfig] = None, + protocol: Optional[SmartProtocol] = None, + ) -> None: + super().__init__(parent.host, config=parent.config, protocol=parent.protocol) + self._parent = parent + self._id = child_id + self.protocol = _ChildProtocolWrapper(child_id, parent.protocol) + + async def update(self, update_children: bool = True): + """We just set the info here accordingly.""" + + def _get_child_info() -> Dict: + """Return the subdevice information for this device.""" + for child in self._parent._last_update["child_info"]["child_device_list"]: + if child["device_id"] == self._id: + return child + + raise SmartDeviceException( + f"Unable to find child device with position {self._id}" + ) + + self._last_update = self._sys_info = self._info = _get_child_info() + + def __repr__(self): + return f"" diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index 9edcca86..a7e57a6d 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -5,11 +5,11 @@ from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional, Set, cast from ..aestransport import AesTransport +from ..device_type import DeviceType from ..deviceconfig import DeviceConfig from ..emeterstatus import EmeterStatus from ..exceptions import AuthenticationException, SmartDeviceException from ..modules import Emeter -from ..protocol import BaseProtocol from ..smartdevice import SmartDevice, WifiNetwork from ..smartprotocol import SmartProtocol @@ -24,17 +24,27 @@ class TapoDevice(SmartDevice): host: str, *, config: Optional[DeviceConfig] = None, - protocol: Optional[BaseProtocol] = None, + protocol: Optional[SmartProtocol] = None, ) -> None: _protocol = protocol or SmartProtocol( transport=AesTransport(config=config or DeviceConfig(host=host)), ) super().__init__(host=host, config=config, protocol=_protocol) + self.protocol: SmartProtocol self._components_raw: Optional[Dict[str, Any]] = None self._components: Dict[str, int] self._state_information: Dict[str, Any] = {} - self._discovery_info: Optional[Dict[str, Any]] = None - self.modules: Dict[str, Any] = {} + + async def _initialize_children(self): + children = self._last_update["child_info"]["child_device_list"] + # TODO: Use the type information to construct children, + # as hubs can also have them. + from .childdevice import ChildDevice + + self.children = [ + ChildDevice(parent=self, child_id=child["device_id"]) for child in children + ] + self._device_type = DeviceType.Strip async def update(self, update_children: bool = True): """Update the device.""" @@ -51,6 +61,10 @@ class TapoDevice(SmartDevice): await self._initialize_modules() extra_reqs: Dict[str, Any] = {} + + if "child_device" in self._components: + extra_reqs = {**extra_reqs, "get_child_device_list": None} + if "energy_monitoring" in self._components: extra_reqs = { **extra_reqs, @@ -81,8 +95,15 @@ class TapoDevice(SmartDevice): "time": self._time, "energy": self._energy, "emeter": self._emeter, + "child_info": resp.get("get_child_device_list", {}), } + if self._last_update["child_info"]: + if not self.children: + await self._initialize_children() + for child in self.children: + await child.update() + _LOGGER.debug("Got an update: %s", self._data) async def _initialize_modules(self): diff --git a/kasa/tapo/tapoplug.py b/kasa/tapo/tapoplug.py index 1bd90fd3..e4355e4b 100644 --- a/kasa/tapo/tapoplug.py +++ b/kasa/tapo/tapoplug.py @@ -4,8 +4,8 @@ from datetime import datetime, timedelta from typing import Any, Dict, Optional, cast from ..deviceconfig import DeviceConfig -from ..protocol import BaseProtocol from ..smartdevice import DeviceType +from ..smartprotocol import SmartProtocol from .tapodevice import TapoDevice _LOGGER = logging.getLogger(__name__) @@ -19,7 +19,7 @@ class TapoPlug(TapoDevice): host: str, *, config: Optional[DeviceConfig] = None, - protocol: Optional[BaseProtocol] = None, + protocol: Optional[SmartProtocol] = None, ) -> None: super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Plug diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index e71f4296..46359742 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -482,6 +482,9 @@ def _get_subclasses(of_class): "class_name_obj", _get_subclasses(BaseProtocol), ids=lambda t: t[0] ) def test_protocol_init_signature(class_name_obj): + if class_name_obj[0].startswith("_"): + pytest.skip("Skipping internal protocols") + return params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) assert len(params) == 2 diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index 9b597b51..619caef0 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -1,16 +1,8 @@ -import errno -import json -import logging -import secrets -import struct -import sys -import time -from contextlib import nullcontext as does_not_raise from itertools import chain +from typing import Dict import pytest -from ..aestransport import AesTransport from ..credentials import Credentials from ..deviceconfig import DeviceConfig from ..exceptions import ( @@ -19,9 +11,8 @@ from ..exceptions import ( SmartDeviceException, SmartErrorCode, ) -from ..iotprotocol import IotProtocol -from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256 -from ..smartprotocol import SmartProtocol +from ..protocol import BaseTransport +from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} DUMMY_MULTIPLE_QUERY = { @@ -31,20 +22,45 @@ DUMMY_MULTIPLE_QUERY = { ERRORS = [e for e in SmartErrorCode if e != 0] +# TODO: this could be moved to conftest to make it available for other tests? +@pytest.fixture() +def dummy_protocol(): + """Return a smart protocol instance with a mocking-ready dummy transport.""" + + class DummyTransport(BaseTransport): + @property + def default_port(self) -> int: + return -1 + + @property + def credentials_hash(self) -> str: + return "dummy hash" + + async def send(self, request: str) -> Dict: + return {} + + async def close(self) -> None: + pass + + async def reset(self) -> None: + pass + + transport = DummyTransport(config=DeviceConfig(host="127.0.0.123")) + protocol = SmartProtocol(transport=transport) + + return protocol + + @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) -async def test_smart_device_errors(mocker, error_code): - host = "127.0.0.1" +async def test_smart_device_errors(dummy_protocol, mocker, error_code): mock_response = {"result": {"great": "success"}, "error_code": error_code.value} - mocker.patch.object(AesTransport, "perform_handshake") - mocker.patch.object(AesTransport, "perform_login") + send_mock = mocker.patch.object( + dummy_protocol._transport, "send", return_value=mock_response + ) - send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) - - config = DeviceConfig(host, credentials=Credentials("foo", "bar")) - protocol = SmartProtocol(transport=AesTransport(config=config)) with pytest.raises(SmartDeviceException): - await protocol.query(DUMMY_QUERY, retry_count=2) + await dummy_protocol.query(DUMMY_QUERY, retry_count=2) if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): expected_calls = 3 @@ -54,8 +70,9 @@ async def test_smart_device_errors(mocker, error_code): @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) -async def test_smart_device_errors_in_multiple_request(mocker, error_code): - host = "127.0.0.1" +async def test_smart_device_errors_in_multiple_request( + dummy_protocol, mocker, error_code +): mock_response = { "result": { "responses": [ @@ -71,14 +88,11 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code): "error_code": 0, } - mocker.patch.object(AesTransport, "perform_handshake") - mocker.patch.object(AesTransport, "perform_login") - - send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) - config = DeviceConfig(host, credentials=Credentials("foo", "bar")) - protocol = SmartProtocol(transport=AesTransport(config=config)) + send_mock = mocker.patch.object( + dummy_protocol._transport, "send", return_value=mock_response + ) with pytest.raises(SmartDeviceException): - await protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2) + await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2) if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): expected_calls = 3 else: @@ -88,7 +102,9 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code): @pytest.mark.parametrize("request_size", [1, 3, 5, 10]) @pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5]) -async def test_smart_device_multiple_request(mocker, request_size, batch_size): +async def test_smart_device_multiple_request( + dummy_protocol, mocker, request_size, batch_size +): host = "127.0.0.1" requests = {} mock_response = { @@ -102,15 +118,101 @@ async def test_smart_device_multiple_request(mocker, request_size, batch_size): {"method": method, "result": {"great": "success"}, "error_code": 0} ) - mocker.patch.object(AesTransport, "perform_handshake") - mocker.patch.object(AesTransport, "perform_login") - - send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) + send_mock = mocker.patch.object( + dummy_protocol._transport, "send", return_value=mock_response + ) config = DeviceConfig( host, credentials=Credentials("foo", "bar"), batch_size=batch_size ) - protocol = SmartProtocol(transport=AesTransport(config=config)) + dummy_protocol._transport._config = config - await protocol.query(requests, retry_count=0) + await dummy_protocol.query(requests, retry_count=0) expected_count = int(request_size / batch_size) + (request_size % batch_size > 0) assert send_mock.call_count == expected_count + + +async def test_childdevicewrapper_unwrapping(dummy_protocol, mocker): + """Test that responseData gets unwrapped correctly.""" + wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol) + mock_response = {"error_code": 0, "result": {"responseData": {"error_code": 0}}} + + mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response) + res = await wrapped_protocol.query(DUMMY_QUERY) + assert res == {"foobar": None} + + +async def test_childdevicewrapper_unwrapping_with_payload(dummy_protocol, mocker): + wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol) + mock_response = { + "error_code": 0, + "result": {"responseData": {"error_code": 0, "result": {"bar": "bar"}}}, + } + mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response) + res = await wrapped_protocol.query(DUMMY_QUERY) + assert res == {"foobar": {"bar": "bar"}} + + +async def test_childdevicewrapper_error(dummy_protocol, mocker): + """Test that errors inside the responseData payload cause an exception.""" + wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol) + mock_response = {"error_code": 0, "result": {"responseData": {"error_code": -1001}}} + + mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response) + with pytest.raises(SmartDeviceException): + await wrapped_protocol.query(DUMMY_QUERY) + + +@pytest.mark.skip("childprotocolwrapper does not yet support multirequests") +async def test_childdevicewrapper_unwrapping_multiplerequest(dummy_protocol, mocker): + """Test that unwrapping multiplerequest works correctly.""" + mock_response = { + "error_code": 0, + "result": { + "responseData": { + "result": { + "responses": [ + { + "error_code": 0, + "method": "get_device_info", + "result": {"foo": "bar"}, + }, + { + "error_code": 0, + "method": "second_command", + "result": {"bar": "foo"}, + }, + ] + } + } + }, + } + + mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response) + resp = await dummy_protocol.query(DUMMY_QUERY) + assert resp == {"get_device_info": {"foo": "bar"}, "second_command": {"bar": "foo"}} + + +@pytest.mark.skip("childprotocolwrapper does not yet support multirequests") +async def test_childdevicewrapper_multiplerequest_error(dummy_protocol, mocker): + """Test that errors inside multipleRequest response of responseData raise an exception.""" + mock_response = { + "error_code": 0, + "result": { + "responseData": { + "result": { + "responses": [ + { + "error_code": 0, + "method": "get_device_info", + "result": {"foo": "bar"}, + }, + {"error_code": -1001, "method": "invalid_command"}, + ] + } + } + }, + } + + mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response) + with pytest.raises(SmartDeviceException): + await dummy_protocol.query(DUMMY_QUERY)