mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-04-26 16:46:23 +00:00
Initial support for tapos with child devices (#720)
* Add ChildDevice and ChildProtocolWrapper * Initialize & update children * Fix circular imports * Add dummy_protocol fixture and tests for unwrapping responseData * Use dummy_protocol for existing smartprotocol tests * Move _ChildProtocolWrapper to smartprotocol.py * Use dummy_protocol for test multiple requests * Use device_id instead of position for selecting the child * Fix wrapping for regular requests * Remove unused imports * tweak * rename child_device to childdevice * Fix import
This commit is contained in:
parent
b479b6d84d
commit
1ad2a05b65
@ -279,3 +279,70 @@ class SnowflakeId:
|
||||
while timestamp <= last_timestamp:
|
||||
timestamp = self._current_millis()
|
||||
return timestamp
|
||||
|
||||
|
||||
class _ChildProtocolWrapper(SmartProtocol):
|
||||
"""Protocol wrapper for controlling child devices.
|
||||
|
||||
This is an internal class used to communicate with child devices,
|
||||
and should not be used directly.
|
||||
|
||||
This class overrides query() method of the protocol to modify all
|
||||
outgoing queries to use ``control_child`` command, and unwraps the
|
||||
device responses before returning to the caller.
|
||||
"""
|
||||
|
||||
def __init__(self, device_id: str, base_protocol: SmartProtocol):
|
||||
self._device_id = device_id
|
||||
self._protocol = base_protocol
|
||||
self._transport = base_protocol._transport
|
||||
|
||||
def _get_method_and_params_for_request(self, request):
|
||||
"""Return payload for wrapping.
|
||||
|
||||
TODO: this does not support batches and requires refactoring in the future.
|
||||
"""
|
||||
if isinstance(request, dict):
|
||||
if len(request) == 1:
|
||||
smart_method = next(iter(request))
|
||||
smart_params = request[smart_method]
|
||||
else:
|
||||
smart_method = "multipleRequest"
|
||||
requests = [
|
||||
{"method": method, "params": params}
|
||||
for method, params in request.items()
|
||||
]
|
||||
smart_params = {"requests": requests}
|
||||
else:
|
||||
smart_method = request
|
||||
smart_params = None
|
||||
|
||||
return smart_method, smart_params
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Wrap request inside control_child envelope."""
|
||||
method, params = self._get_method_and_params_for_request(request)
|
||||
request_data = {
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
wrapped_payload = {
|
||||
"control_child": {
|
||||
"device_id": self._device_id,
|
||||
"requestData": request_data,
|
||||
}
|
||||
}
|
||||
|
||||
response = await self._protocol.query(wrapped_payload, retry_count)
|
||||
result = response.get("control_child")
|
||||
# Unwrap responseData for control_child
|
||||
if result and (response_data := result.get("responseData")):
|
||||
self._handle_response_error_code(response_data)
|
||||
result = response_data.get("result")
|
||||
|
||||
# TODO: handle multipleRequest unwrapping
|
||||
|
||||
return {method: result}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Do nothing as the parent owns the protocol."""
|
||||
|
44
kasa/tapo/childdevice.py
Normal file
44
kasa/tapo/childdevice.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""Child device implementation."""
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import SmartDeviceException
|
||||
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
|
||||
from .tapodevice import TapoDevice
|
||||
|
||||
|
||||
class ChildDevice(TapoDevice):
|
||||
"""Presentation of a child device.
|
||||
|
||||
This wraps the protocol communications and sets internal data for the child.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent: TapoDevice,
|
||||
child_id: str,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[SmartProtocol] = None,
|
||||
) -> None:
|
||||
super().__init__(parent.host, config=parent.config, protocol=parent.protocol)
|
||||
self._parent = parent
|
||||
self._id = child_id
|
||||
self.protocol = _ChildProtocolWrapper(child_id, parent.protocol)
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
"""We just set the info here accordingly."""
|
||||
|
||||
def _get_child_info() -> Dict:
|
||||
"""Return the subdevice information for this device."""
|
||||
for child in self._parent._last_update["child_info"]["child_device_list"]:
|
||||
if child["device_id"] == self._id:
|
||||
return child
|
||||
|
||||
raise SmartDeviceException(
|
||||
f"Unable to find child device with position {self._id}"
|
||||
)
|
||||
|
||||
self._last_update = self._sys_info = self._info = _get_child_info()
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ChildDevice {self.alias} of {self._parent}>"
|
@ -5,11 +5,11 @@ from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional, Set, cast
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..device_type import DeviceType
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..emeterstatus import EmeterStatus
|
||||
from ..exceptions import AuthenticationException, SmartDeviceException
|
||||
from ..modules import Emeter
|
||||
from ..protocol import BaseProtocol
|
||||
from ..smartdevice import SmartDevice, WifiNetwork
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
@ -24,17 +24,27 @@ class TapoDevice(SmartDevice):
|
||||
host: str,
|
||||
*,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[BaseProtocol] = None,
|
||||
protocol: Optional[SmartProtocol] = None,
|
||||
) -> None:
|
||||
_protocol = protocol or SmartProtocol(
|
||||
transport=AesTransport(config=config or DeviceConfig(host=host)),
|
||||
)
|
||||
super().__init__(host=host, config=config, protocol=_protocol)
|
||||
self.protocol: SmartProtocol
|
||||
self._components_raw: Optional[Dict[str, Any]] = None
|
||||
self._components: Dict[str, int]
|
||||
self._state_information: Dict[str, Any] = {}
|
||||
self._discovery_info: Optional[Dict[str, Any]] = None
|
||||
self.modules: Dict[str, Any] = {}
|
||||
|
||||
async def _initialize_children(self):
|
||||
children = self._last_update["child_info"]["child_device_list"]
|
||||
# TODO: Use the type information to construct children,
|
||||
# as hubs can also have them.
|
||||
from .childdevice import ChildDevice
|
||||
|
||||
self.children = [
|
||||
ChildDevice(parent=self, child_id=child["device_id"]) for child in children
|
||||
]
|
||||
self._device_type = DeviceType.Strip
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Update the device."""
|
||||
@ -51,6 +61,10 @@ class TapoDevice(SmartDevice):
|
||||
await self._initialize_modules()
|
||||
|
||||
extra_reqs: Dict[str, Any] = {}
|
||||
|
||||
if "child_device" in self._components:
|
||||
extra_reqs = {**extra_reqs, "get_child_device_list": None}
|
||||
|
||||
if "energy_monitoring" in self._components:
|
||||
extra_reqs = {
|
||||
**extra_reqs,
|
||||
@ -81,8 +95,15 @@ class TapoDevice(SmartDevice):
|
||||
"time": self._time,
|
||||
"energy": self._energy,
|
||||
"emeter": self._emeter,
|
||||
"child_info": resp.get("get_child_device_list", {}),
|
||||
}
|
||||
|
||||
if self._last_update["child_info"]:
|
||||
if not self.children:
|
||||
await self._initialize_children()
|
||||
for child in self.children:
|
||||
await child.update()
|
||||
|
||||
_LOGGER.debug("Got an update: %s", self._data)
|
||||
|
||||
async def _initialize_modules(self):
|
||||
|
@ -4,8 +4,8 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..protocol import BaseProtocol
|
||||
from ..smartdevice import DeviceType
|
||||
from ..smartprotocol import SmartProtocol
|
||||
from .tapodevice import TapoDevice
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -19,7 +19,7 @@ class TapoPlug(TapoDevice):
|
||||
host: str,
|
||||
*,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[BaseProtocol] = None,
|
||||
protocol: Optional[SmartProtocol] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self._device_type = DeviceType.Plug
|
||||
|
@ -482,6 +482,9 @@ def _get_subclasses(of_class):
|
||||
"class_name_obj", _get_subclasses(BaseProtocol), ids=lambda t: t[0]
|
||||
)
|
||||
def test_protocol_init_signature(class_name_obj):
|
||||
if class_name_obj[0].startswith("_"):
|
||||
pytest.skip("Skipping internal protocols")
|
||||
return
|
||||
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||
|
||||
assert len(params) == 2
|
||||
|
@ -1,16 +1,8 @@
|
||||
import errno
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from contextlib import nullcontext as does_not_raise
|
||||
from itertools import chain
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import (
|
||||
@ -19,9 +11,8 @@ from ..exceptions import (
|
||||
SmartDeviceException,
|
||||
SmartErrorCode,
|
||||
)
|
||||
from ..iotprotocol import IotProtocol
|
||||
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
|
||||
from ..smartprotocol import SmartProtocol
|
||||
from ..protocol import BaseTransport
|
||||
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
|
||||
|
||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||
DUMMY_MULTIPLE_QUERY = {
|
||||
@ -31,20 +22,45 @@ DUMMY_MULTIPLE_QUERY = {
|
||||
ERRORS = [e for e in SmartErrorCode if e != 0]
|
||||
|
||||
|
||||
# TODO: this could be moved to conftest to make it available for other tests?
|
||||
@pytest.fixture()
|
||||
def dummy_protocol():
|
||||
"""Return a smart protocol instance with a mocking-ready dummy transport."""
|
||||
|
||||
class DummyTransport(BaseTransport):
|
||||
@property
|
||||
def default_port(self) -> int:
|
||||
return -1
|
||||
|
||||
@property
|
||||
def credentials_hash(self) -> str:
|
||||
return "dummy hash"
|
||||
|
||||
async def send(self, request: str) -> Dict:
|
||||
return {}
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
pass
|
||||
|
||||
transport = DummyTransport(config=DeviceConfig(host="127.0.0.123"))
|
||||
protocol = SmartProtocol(transport=transport)
|
||||
|
||||
return protocol
|
||||
|
||||
|
||||
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
|
||||
async def test_smart_device_errors(mocker, error_code):
|
||||
host = "127.0.0.1"
|
||||
async def test_smart_device_errors(dummy_protocol, mocker, error_code):
|
||||
mock_response = {"result": {"great": "success"}, "error_code": error_code.value}
|
||||
|
||||
mocker.patch.object(AesTransport, "perform_handshake")
|
||||
mocker.patch.object(AesTransport, "perform_login")
|
||||
send_mock = mocker.patch.object(
|
||||
dummy_protocol._transport, "send", return_value=mock_response
|
||||
)
|
||||
|
||||
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol.query(DUMMY_QUERY, retry_count=2)
|
||||
await dummy_protocol.query(DUMMY_QUERY, retry_count=2)
|
||||
|
||||
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
|
||||
expected_calls = 3
|
||||
@ -54,8 +70,9 @@ async def test_smart_device_errors(mocker, error_code):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
|
||||
async def test_smart_device_errors_in_multiple_request(mocker, error_code):
|
||||
host = "127.0.0.1"
|
||||
async def test_smart_device_errors_in_multiple_request(
|
||||
dummy_protocol, mocker, error_code
|
||||
):
|
||||
mock_response = {
|
||||
"result": {
|
||||
"responses": [
|
||||
@ -71,14 +88,11 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
|
||||
"error_code": 0,
|
||||
}
|
||||
|
||||
mocker.patch.object(AesTransport, "perform_handshake")
|
||||
mocker.patch.object(AesTransport, "perform_login")
|
||||
|
||||
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||
send_mock = mocker.patch.object(
|
||||
dummy_protocol._transport, "send", return_value=mock_response
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
|
||||
await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
|
||||
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
|
||||
expected_calls = 3
|
||||
else:
|
||||
@ -88,7 +102,9 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
|
||||
|
||||
@pytest.mark.parametrize("request_size", [1, 3, 5, 10])
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5])
|
||||
async def test_smart_device_multiple_request(mocker, request_size, batch_size):
|
||||
async def test_smart_device_multiple_request(
|
||||
dummy_protocol, mocker, request_size, batch_size
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
requests = {}
|
||||
mock_response = {
|
||||
@ -102,15 +118,101 @@ async def test_smart_device_multiple_request(mocker, request_size, batch_size):
|
||||
{"method": method, "result": {"great": "success"}, "error_code": 0}
|
||||
)
|
||||
|
||||
mocker.patch.object(AesTransport, "perform_handshake")
|
||||
mocker.patch.object(AesTransport, "perform_login")
|
||||
|
||||
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||
send_mock = mocker.patch.object(
|
||||
dummy_protocol._transport, "send", return_value=mock_response
|
||||
)
|
||||
config = DeviceConfig(
|
||||
host, credentials=Credentials("foo", "bar"), batch_size=batch_size
|
||||
)
|
||||
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||
dummy_protocol._transport._config = config
|
||||
|
||||
await protocol.query(requests, retry_count=0)
|
||||
await dummy_protocol.query(requests, retry_count=0)
|
||||
expected_count = int(request_size / batch_size) + (request_size % batch_size > 0)
|
||||
assert send_mock.call_count == expected_count
|
||||
|
||||
|
||||
async def test_childdevicewrapper_unwrapping(dummy_protocol, mocker):
|
||||
"""Test that responseData gets unwrapped correctly."""
|
||||
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
|
||||
mock_response = {"error_code": 0, "result": {"responseData": {"error_code": 0}}}
|
||||
|
||||
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
|
||||
res = await wrapped_protocol.query(DUMMY_QUERY)
|
||||
assert res == {"foobar": None}
|
||||
|
||||
|
||||
async def test_childdevicewrapper_unwrapping_with_payload(dummy_protocol, mocker):
|
||||
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
|
||||
mock_response = {
|
||||
"error_code": 0,
|
||||
"result": {"responseData": {"error_code": 0, "result": {"bar": "bar"}}},
|
||||
}
|
||||
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
|
||||
res = await wrapped_protocol.query(DUMMY_QUERY)
|
||||
assert res == {"foobar": {"bar": "bar"}}
|
||||
|
||||
|
||||
async def test_childdevicewrapper_error(dummy_protocol, mocker):
|
||||
"""Test that errors inside the responseData payload cause an exception."""
|
||||
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
|
||||
mock_response = {"error_code": 0, "result": {"responseData": {"error_code": -1001}}}
|
||||
|
||||
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await wrapped_protocol.query(DUMMY_QUERY)
|
||||
|
||||
|
||||
@pytest.mark.skip("childprotocolwrapper does not yet support multirequests")
|
||||
async def test_childdevicewrapper_unwrapping_multiplerequest(dummy_protocol, mocker):
|
||||
"""Test that unwrapping multiplerequest works correctly."""
|
||||
mock_response = {
|
||||
"error_code": 0,
|
||||
"result": {
|
||||
"responseData": {
|
||||
"result": {
|
||||
"responses": [
|
||||
{
|
||||
"error_code": 0,
|
||||
"method": "get_device_info",
|
||||
"result": {"foo": "bar"},
|
||||
},
|
||||
{
|
||||
"error_code": 0,
|
||||
"method": "second_command",
|
||||
"result": {"bar": "foo"},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response)
|
||||
resp = await dummy_protocol.query(DUMMY_QUERY)
|
||||
assert resp == {"get_device_info": {"foo": "bar"}, "second_command": {"bar": "foo"}}
|
||||
|
||||
|
||||
@pytest.mark.skip("childprotocolwrapper does not yet support multirequests")
|
||||
async def test_childdevicewrapper_multiplerequest_error(dummy_protocol, mocker):
|
||||
"""Test that errors inside multipleRequest response of responseData raise an exception."""
|
||||
mock_response = {
|
||||
"error_code": 0,
|
||||
"result": {
|
||||
"responseData": {
|
||||
"result": {
|
||||
"responses": [
|
||||
{
|
||||
"error_code": 0,
|
||||
"method": "get_device_info",
|
||||
"result": {"foo": "bar"},
|
||||
},
|
||||
{"error_code": -1001, "method": "invalid_command"},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await dummy_protocol.query(DUMMY_QUERY)
|
||||
|
Loading…
x
Reference in New Issue
Block a user