Initial support for tapos with child devices (#720)

* Add ChildDevice and ChildProtocolWrapper

* Initialize & update children

* Fix circular imports

* Add dummy_protocol fixture and tests for unwrapping responseData

* Use dummy_protocol for existing smartprotocol tests

* Move _ChildProtocolWrapper to smartprotocol.py

* Use dummy_protocol for test multiple requests

* Use device_id instead of position for selecting the child

* Fix wrapping for regular requests

* Remove unused imports

* tweak

* rename child_device to childdevice

* Fix import
This commit is contained in:
Teemu R 2024-01-29 17:11:29 +01:00 committed by GitHub
parent b479b6d84d
commit 1ad2a05b65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 280 additions and 43 deletions

View File

@ -279,3 +279,70 @@ class SnowflakeId:
while timestamp <= last_timestamp:
timestamp = self._current_millis()
return timestamp
class _ChildProtocolWrapper(SmartProtocol):
"""Protocol wrapper for controlling child devices.
This is an internal class used to communicate with child devices,
and should not be used directly.
This class overrides query() method of the protocol to modify all
outgoing queries to use ``control_child`` command, and unwraps the
device responses before returning to the caller.
"""
def __init__(self, device_id: str, base_protocol: SmartProtocol):
self._device_id = device_id
self._protocol = base_protocol
self._transport = base_protocol._transport
def _get_method_and_params_for_request(self, request):
"""Return payload for wrapping.
TODO: this does not support batches and requires refactoring in the future.
"""
if isinstance(request, dict):
if len(request) == 1:
smart_method = next(iter(request))
smart_params = request[smart_method]
else:
smart_method = "multipleRequest"
requests = [
{"method": method, "params": params}
for method, params in request.items()
]
smart_params = {"requests": requests}
else:
smart_method = request
smart_params = None
return smart_method, smart_params
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Wrap request inside control_child envelope."""
method, params = self._get_method_and_params_for_request(request)
request_data = {
"method": method,
"params": params,
}
wrapped_payload = {
"control_child": {
"device_id": self._device_id,
"requestData": request_data,
}
}
response = await self._protocol.query(wrapped_payload, retry_count)
result = response.get("control_child")
# Unwrap responseData for control_child
if result and (response_data := result.get("responseData")):
self._handle_response_error_code(response_data)
result = response_data.get("result")
# TODO: handle multipleRequest unwrapping
return {method: result}
async def close(self) -> None:
"""Do nothing as the parent owns the protocol."""

44
kasa/tapo/childdevice.py Normal file
View File

@ -0,0 +1,44 @@
"""Child device implementation."""
from typing import Dict, Optional
from ..deviceconfig import DeviceConfig
from ..exceptions import SmartDeviceException
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
from .tapodevice import TapoDevice
class ChildDevice(TapoDevice):
"""Presentation of a child device.
This wraps the protocol communications and sets internal data for the child.
"""
def __init__(
self,
parent: TapoDevice,
child_id: str,
config: Optional[DeviceConfig] = None,
protocol: Optional[SmartProtocol] = None,
) -> None:
super().__init__(parent.host, config=parent.config, protocol=parent.protocol)
self._parent = parent
self._id = child_id
self.protocol = _ChildProtocolWrapper(child_id, parent.protocol)
async def update(self, update_children: bool = True):
"""We just set the info here accordingly."""
def _get_child_info() -> Dict:
"""Return the subdevice information for this device."""
for child in self._parent._last_update["child_info"]["child_device_list"]:
if child["device_id"] == self._id:
return child
raise SmartDeviceException(
f"Unable to find child device with position {self._id}"
)
self._last_update = self._sys_info = self._info = _get_child_info()
def __repr__(self):
return f"<ChildDevice {self.alias} of {self._parent}>"

View File

@ -5,11 +5,11 @@ from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Set, cast
from ..aestransport import AesTransport
from ..device_type import DeviceType
from ..deviceconfig import DeviceConfig
from ..emeterstatus import EmeterStatus
from ..exceptions import AuthenticationException, SmartDeviceException
from ..modules import Emeter
from ..protocol import BaseProtocol
from ..smartdevice import SmartDevice, WifiNetwork
from ..smartprotocol import SmartProtocol
@ -24,17 +24,27 @@ class TapoDevice(SmartDevice):
host: str,
*,
config: Optional[DeviceConfig] = None,
protocol: Optional[BaseProtocol] = None,
protocol: Optional[SmartProtocol] = None,
) -> None:
_protocol = protocol or SmartProtocol(
transport=AesTransport(config=config or DeviceConfig(host=host)),
)
super().__init__(host=host, config=config, protocol=_protocol)
self.protocol: SmartProtocol
self._components_raw: Optional[Dict[str, Any]] = None
self._components: Dict[str, int]
self._state_information: Dict[str, Any] = {}
self._discovery_info: Optional[Dict[str, Any]] = None
self.modules: Dict[str, Any] = {}
async def _initialize_children(self):
children = self._last_update["child_info"]["child_device_list"]
# TODO: Use the type information to construct children,
# as hubs can also have them.
from .childdevice import ChildDevice
self.children = [
ChildDevice(parent=self, child_id=child["device_id"]) for child in children
]
self._device_type = DeviceType.Strip
async def update(self, update_children: bool = True):
"""Update the device."""
@ -51,6 +61,10 @@ class TapoDevice(SmartDevice):
await self._initialize_modules()
extra_reqs: Dict[str, Any] = {}
if "child_device" in self._components:
extra_reqs = {**extra_reqs, "get_child_device_list": None}
if "energy_monitoring" in self._components:
extra_reqs = {
**extra_reqs,
@ -81,8 +95,15 @@ class TapoDevice(SmartDevice):
"time": self._time,
"energy": self._energy,
"emeter": self._emeter,
"child_info": resp.get("get_child_device_list", {}),
}
if self._last_update["child_info"]:
if not self.children:
await self._initialize_children()
for child in self.children:
await child.update()
_LOGGER.debug("Got an update: %s", self._data)
async def _initialize_modules(self):

View File

@ -4,8 +4,8 @@ from datetime import datetime, timedelta
from typing import Any, Dict, Optional, cast
from ..deviceconfig import DeviceConfig
from ..protocol import BaseProtocol
from ..smartdevice import DeviceType
from ..smartprotocol import SmartProtocol
from .tapodevice import TapoDevice
_LOGGER = logging.getLogger(__name__)
@ -19,7 +19,7 @@ class TapoPlug(TapoDevice):
host: str,
*,
config: Optional[DeviceConfig] = None,
protocol: Optional[BaseProtocol] = None,
protocol: Optional[SmartProtocol] = None,
) -> None:
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Plug

View File

@ -482,6 +482,9 @@ def _get_subclasses(of_class):
"class_name_obj", _get_subclasses(BaseProtocol), ids=lambda t: t[0]
)
def test_protocol_init_signature(class_name_obj):
if class_name_obj[0].startswith("_"):
pytest.skip("Skipping internal protocols")
return
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 2

View File

@ -1,16 +1,8 @@
import errno
import json
import logging
import secrets
import struct
import sys
import time
from contextlib import nullcontext as does_not_raise
from itertools import chain
from typing import Dict
import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import (
@ -19,9 +11,8 @@ from ..exceptions import (
SmartDeviceException,
SmartErrorCode,
)
from ..iotprotocol import IotProtocol
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
from ..smartprotocol import SmartProtocol
from ..protocol import BaseTransport
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
DUMMY_MULTIPLE_QUERY = {
@ -31,20 +22,45 @@ DUMMY_MULTIPLE_QUERY = {
ERRORS = [e for e in SmartErrorCode if e != 0]
# TODO: this could be moved to conftest to make it available for other tests?
@pytest.fixture()
def dummy_protocol():
"""Return a smart protocol instance with a mocking-ready dummy transport."""
class DummyTransport(BaseTransport):
@property
def default_port(self) -> int:
return -1
@property
def credentials_hash(self) -> str:
return "dummy hash"
async def send(self, request: str) -> Dict:
return {}
async def close(self) -> None:
pass
async def reset(self) -> None:
pass
transport = DummyTransport(config=DeviceConfig(host="127.0.0.123"))
protocol = SmartProtocol(transport=transport)
return protocol
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_smart_device_errors(mocker, error_code):
host = "127.0.0.1"
async def test_smart_device_errors(dummy_protocol, mocker, error_code):
mock_response = {"result": {"great": "success"}, "error_code": error_code.value}
mocker.patch.object(AesTransport, "perform_handshake")
mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(
dummy_protocol._transport, "send", return_value=mock_response
)
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
protocol = SmartProtocol(transport=AesTransport(config=config))
with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2)
await dummy_protocol.query(DUMMY_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
expected_calls = 3
@ -54,8 +70,9 @@ async def test_smart_device_errors(mocker, error_code):
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_smart_device_errors_in_multiple_request(mocker, error_code):
host = "127.0.0.1"
async def test_smart_device_errors_in_multiple_request(
dummy_protocol, mocker, error_code
):
mock_response = {
"result": {
"responses": [
@ -71,14 +88,11 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
"error_code": 0,
}
mocker.patch.object(AesTransport, "perform_handshake")
mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
protocol = SmartProtocol(transport=AesTransport(config=config))
send_mock = mocker.patch.object(
dummy_protocol._transport, "send", return_value=mock_response
)
with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
expected_calls = 3
else:
@ -88,7 +102,9 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
@pytest.mark.parametrize("request_size", [1, 3, 5, 10])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5])
async def test_smart_device_multiple_request(mocker, request_size, batch_size):
async def test_smart_device_multiple_request(
dummy_protocol, mocker, request_size, batch_size
):
host = "127.0.0.1"
requests = {}
mock_response = {
@ -102,15 +118,101 @@ async def test_smart_device_multiple_request(mocker, request_size, batch_size):
{"method": method, "result": {"great": "success"}, "error_code": 0}
)
mocker.patch.object(AesTransport, "perform_handshake")
mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
send_mock = mocker.patch.object(
dummy_protocol._transport, "send", return_value=mock_response
)
config = DeviceConfig(
host, credentials=Credentials("foo", "bar"), batch_size=batch_size
)
protocol = SmartProtocol(transport=AesTransport(config=config))
dummy_protocol._transport._config = config
await protocol.query(requests, retry_count=0)
await dummy_protocol.query(requests, retry_count=0)
expected_count = int(request_size / batch_size) + (request_size % batch_size > 0)
assert send_mock.call_count == expected_count
async def test_childdevicewrapper_unwrapping(dummy_protocol, mocker):
"""Test that responseData gets unwrapped correctly."""
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
mock_response = {"error_code": 0, "result": {"responseData": {"error_code": 0}}}
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
res = await wrapped_protocol.query(DUMMY_QUERY)
assert res == {"foobar": None}
async def test_childdevicewrapper_unwrapping_with_payload(dummy_protocol, mocker):
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
mock_response = {
"error_code": 0,
"result": {"responseData": {"error_code": 0, "result": {"bar": "bar"}}},
}
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
res = await wrapped_protocol.query(DUMMY_QUERY)
assert res == {"foobar": {"bar": "bar"}}
async def test_childdevicewrapper_error(dummy_protocol, mocker):
"""Test that errors inside the responseData payload cause an exception."""
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
mock_response = {"error_code": 0, "result": {"responseData": {"error_code": -1001}}}
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
with pytest.raises(SmartDeviceException):
await wrapped_protocol.query(DUMMY_QUERY)
@pytest.mark.skip("childprotocolwrapper does not yet support multirequests")
async def test_childdevicewrapper_unwrapping_multiplerequest(dummy_protocol, mocker):
"""Test that unwrapping multiplerequest works correctly."""
mock_response = {
"error_code": 0,
"result": {
"responseData": {
"result": {
"responses": [
{
"error_code": 0,
"method": "get_device_info",
"result": {"foo": "bar"},
},
{
"error_code": 0,
"method": "second_command",
"result": {"bar": "foo"},
},
]
}
}
},
}
mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response)
resp = await dummy_protocol.query(DUMMY_QUERY)
assert resp == {"get_device_info": {"foo": "bar"}, "second_command": {"bar": "foo"}}
@pytest.mark.skip("childprotocolwrapper does not yet support multirequests")
async def test_childdevicewrapper_multiplerequest_error(dummy_protocol, mocker):
"""Test that errors inside multipleRequest response of responseData raise an exception."""
mock_response = {
"error_code": 0,
"result": {
"responseData": {
"result": {
"responses": [
{
"error_code": 0,
"method": "get_device_info",
"result": {"foo": "bar"},
},
{"error_code": -1001, "method": "invalid_command"},
]
}
}
},
}
mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response)
with pytest.raises(SmartDeviceException):
await dummy_protocol.query(DUMMY_QUERY)