Generalize smartdevice child support (#775)

* Initialize children's modules (and features) using the child component negotiation results
* Set device_type based on the device response
* Print out child features in cli 'state'
* Add --child option to cli 'command' to allow targeting child devices
* Guard "generic" features like rssi, ssid, etc. only to devices which have this information

Note, we do not currently perform queries on child modules so some data may not be available. At the moment, a stop-gap solution to use parent's data is used but this is not always correct; even if the device shares the same clock and cloud connectivity, it may have its own firmware updates.
This commit is contained in:
Teemu R 2024-02-22 20:46:19 +01:00 committed by GitHub
parent f965b14021
commit 2b0721aea9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 198 additions and 99 deletions

View File

@ -582,9 +582,14 @@ async def state(ctx, dev: Device):
echo(f"\tPort: {dev.port}") echo(f"\tPort: {dev.port}")
echo(f"\tDevice state: {dev.is_on}") echo(f"\tDevice state: {dev.is_on}")
if dev.is_strip: if dev.is_strip:
echo("\t[bold]== Plugs ==[/bold]") echo("\t[bold]== Children ==[/bold]")
for plug in dev.children: # type: ignore for child in dev.children:
echo(f"\t* Socket '{plug.alias}' state: {plug.is_on} since {plug.on_since}") echo(f"\t* {child.alias} ({child.model}, {child.device_type})")
for feat in child.features.values():
try:
echo(f"\t\t{feat.name}: {feat.value}")
except Exception as ex:
echo(f"\t\t{feat.name}: got exception (%s)" % ex)
echo() echo()
echo("\t[bold]== Generic information ==[/bold]") echo("\t[bold]== Generic information ==[/bold]")
@ -665,13 +670,22 @@ async def raw_command(ctx, dev: Device, module, command, parameters):
@cli.command(name="command") @cli.command(name="command")
@pass_dev @pass_dev
@click.option("--module", required=False, help="Module for IOT protocol.") @click.option("--module", required=False, help="Module for IOT protocol.")
@click.option("--child", required=False, help="Child ID for controlling sub-devices")
@click.argument("command") @click.argument("command")
@click.argument("parameters", default=None, required=False) @click.argument("parameters", default=None, required=False)
async def cmd_command(dev: Device, module, command, parameters): async def cmd_command(dev: Device, module, child, command, parameters):
"""Run a raw command on the device.""" """Run a raw command on the device."""
if parameters is not None: if parameters is not None:
parameters = ast.literal_eval(parameters) parameters = ast.literal_eval(parameters)
if child:
# The way child devices are accessed requires a ChildDevice to
# wrap the communications. Doing this properly would require creating
# a common interfaces for both IOT and SMART child devices.
# As a stop-gap solution, we perform an update instead.
await dev.update()
dev = dev.get_child_device(child)
if isinstance(dev, IotDevice): if isinstance(dev, IotDevice):
res = await dev._query_helper(module, command, parameters) res = await dev._query_helper(module, command, parameters)
elif isinstance(dev, SmartDevice): elif isinstance(dev, SmartDevice):

View File

@ -3,7 +3,7 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Union from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
from .credentials import Credentials from .credentials import Credentials
from .device_type import DeviceType from .device_type import DeviceType
@ -71,6 +71,8 @@ class Device(ABC):
self.modules: Dict[str, Any] = {} self.modules: Dict[str, Any] = {}
self._features: Dict[str, Feature] = {} self._features: Dict[str, Feature] = {}
self._parent: Optional["Device"] = None
self._children: Mapping[str, "Device"] = {}
@staticmethod @staticmethod
async def connect( async def connect(
@ -182,9 +184,13 @@ class Device(ABC):
return await self.protocol.query(request=request) return await self.protocol.query(request=request)
@property @property
@abstractmethod
def children(self) -> Sequence["Device"]: def children(self) -> Sequence["Device"]:
"""Returns the child devices.""" """Returns the child devices."""
return list(self._children.values())
def get_child_device(self, id_: str) -> "Device":
"""Return child device by its ID."""
return self._children[id_]
@property @property
@abstractmethod @abstractmethod

View File

@ -14,6 +14,7 @@ class DeviceType(Enum):
StripSocket = "stripsocket" StripSocket = "stripsocket"
Dimmer = "dimmer" Dimmer = "dimmer"
LightStrip = "lightstrip" LightStrip = "lightstrip"
Sensor = "sensor"
Unknown = "unknown" Unknown = "unknown"
@staticmethod @staticmethod

View File

@ -16,7 +16,7 @@ import functools
import inspect import inspect
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Sequence, Set from typing import Any, Dict, List, Mapping, Optional, Sequence, Set
from ..device import Device, WifiNetwork from ..device import Device, WifiNetwork
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
@ -183,19 +183,14 @@ class IotDevice(Device):
super().__init__(host=host, config=config, protocol=protocol) super().__init__(host=host, config=config, protocol=protocol)
self._sys_info: Any = None # TODO: this is here to avoid changing tests self._sys_info: Any = None # TODO: this is here to avoid changing tests
self._children: Sequence["IotDevice"] = []
self._supported_modules: Optional[Dict[str, IotModule]] = None self._supported_modules: Optional[Dict[str, IotModule]] = None
self._legacy_features: Set[str] = set() self._legacy_features: Set[str] = set()
self._children: Mapping[str, "IotDevice"] = {}
@property @property
def children(self) -> Sequence["IotDevice"]: def children(self) -> Sequence["IotDevice"]:
"""Return list of children.""" """Return list of children."""
return self._children return list(self._children.values())
@children.setter
def children(self, children):
"""Initialize from a list of children."""
self._children = children
def add_module(self, name: str, module: IotModule): def add_module(self, name: str, module: IotModule):
"""Register a module.""" """Register a module."""
@ -408,15 +403,6 @@ class IotDevice(Device):
sys_info = self._sys_info sys_info = self._sys_info
return str(sys_info["model"]) return str(sys_info["model"])
@property
def has_children(self) -> bool:
"""Return true if the device has children devices."""
# Ideally we would check for the 'child_num' key in sys_info,
# but devices that speak klap do not populate this key via
# update_from_discover_info so we check for the devices
# we know have children instead.
return self.is_strip
@property # type: ignore @property # type: ignore
def alias(self) -> Optional[str]: def alias(self) -> Optional[str]:
"""Return device name (alias).""" """Return device name (alias)."""

View File

@ -115,10 +115,12 @@ class IotStrip(IotDevice):
if not self.children: if not self.children:
children = self.sys_info["children"] children = self.sys_info["children"]
_LOGGER.debug("Initializing %s child sockets", len(children)) _LOGGER.debug("Initializing %s child sockets", len(children))
self.children = [ self._children = {
IotStripPlug(self.host, parent=self, child_id=child["id"]) f"{self.mac}_{child['id']}": IotStripPlug(
self.host, parent=self, child_id=child["id"]
)
for child in children for child in children
] }
if update_children and self.has_emeter: if update_children and self.has_emeter:
for plug in self.children: for plug in self.children:

View File

@ -1,4 +1,6 @@
"""Implementation for child devices.""" """Implementation for child devices."""
from typing import Dict
from ..smartmodule import SmartModule from ..smartmodule import SmartModule
@ -6,4 +8,12 @@ class ChildDeviceModule(SmartModule):
"""Implementation for child devices.""" """Implementation for child devices."""
REQUIRED_COMPONENT = "child_device" REQUIRED_COMPONENT = "child_device"
QUERY_GETTER_NAME = "get_child_device_list"
def query(self) -> Dict:
"""Query to execute during the update cycle."""
# TODO: There is no need to fetch the component list every time,
# so this should be optimized only for the init.
return {
"get_child_device_list": None,
"get_child_device_component_list": None,
}

View File

@ -1,4 +1,5 @@
"""Child device implementation.""" """Child device implementation."""
import logging
from typing import Optional from typing import Optional
from ..device_type import DeviceType from ..device_type import DeviceType
@ -6,6 +7,8 @@ from ..deviceconfig import DeviceConfig
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
from .smartdevice import SmartDevice from .smartdevice import SmartDevice
_LOGGER = logging.getLogger(__name__)
class SmartChildDevice(SmartDevice): class SmartChildDevice(SmartDevice):
"""Presentation of a child device. """Presentation of a child device.
@ -16,23 +19,41 @@ class SmartChildDevice(SmartDevice):
def __init__( def __init__(
self, self,
parent: SmartDevice, parent: SmartDevice,
child_id: str, info,
component_info,
config: Optional[DeviceConfig] = None, config: Optional[DeviceConfig] = None,
protocol: Optional[SmartProtocol] = None, protocol: Optional[SmartProtocol] = None,
) -> None: ) -> None:
super().__init__(parent.host, config=parent.config, protocol=parent.protocol) super().__init__(parent.host, config=parent.config, protocol=parent.protocol)
self._parent = parent self._parent = parent
self._id = child_id self._update_internal_state(info)
self.protocol = _ChildProtocolWrapper(child_id, parent.protocol) self._components = component_info
self._device_type = DeviceType.StripSocket self._id = info["device_id"]
self.protocol = _ChildProtocolWrapper(self._id, parent.protocol)
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""Noop update. The parent updates our internals.""" """Noop update. The parent updates our internals."""
def update_internal_state(self, info): @classmethod
"""Set internal state for the child.""" async def create(cls, parent: SmartDevice, child_info, child_components):
# TODO: cleanup the _last_update, _sys_info, _info, _data mess. """Create a child device based on device info and component listing."""
self._last_update = self._sys_info = self._info = info child: "SmartChildDevice" = cls(parent, child_info, child_components)
await child._initialize_modules()
await child._initialize_features()
return child
@property
def device_type(self) -> DeviceType:
"""Return child device type."""
child_device_map = {
"plug.powerstrip.sub-plug": DeviceType.Plug,
"subg.trigger.temp-hmdt-sensor": DeviceType.Sensor,
}
dev_type = child_device_map.get(self.sys_info["category"])
if dev_type is None:
_LOGGER.warning("Unknown child device type, please open issue ")
dev_type = DeviceType.Unknown
return dev_type
def __repr__(self): def __repr__(self):
return f"<ChildDevice {self.alias} of {self._parent}>" return f"<ChildDevice {self.alias} of {self._parent}>"

View File

@ -2,7 +2,7 @@
import base64 import base64
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, cast from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, cast
from ..aestransport import AesTransport from ..aestransport import AesTransport
from ..device import Device, WifiNetwork from ..device import Device, WifiNetwork
@ -12,22 +12,12 @@ from ..emeterstatus import EmeterStatus
from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode
from ..feature import Feature, FeatureType from ..feature import Feature, FeatureType
from ..smartprotocol import SmartProtocol from ..smartprotocol import SmartProtocol
from .modules import ( # noqa: F401 from .modules import * # noqa: F403
AutoOffModule,
ChildDeviceModule,
CloudModule,
DeviceModule,
EnergyModule,
LedModule,
LightTransitionModule,
TimeModule,
)
from .smartmodule import SmartModule
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from .smartchilddevice import SmartChildDevice from .smartmodule import SmartModule
class SmartDevice(Device): class SmartDevice(Device):
@ -47,23 +37,34 @@ class SmartDevice(Device):
self.protocol: SmartProtocol self.protocol: SmartProtocol
self._components_raw: Optional[Dict[str, Any]] = None self._components_raw: Optional[Dict[str, Any]] = None
self._components: Dict[str, int] = {} self._components: Dict[str, int] = {}
self._children: Dict[str, "SmartChildDevice"] = {}
self._state_information: Dict[str, Any] = {} self._state_information: Dict[str, Any] = {}
self.modules: Dict[str, SmartModule] = {} self.modules: Dict[str, "SmartModule"] = {}
self._parent: Optional["SmartDevice"] = None
self._children: Mapping[str, "SmartDevice"] = {}
async def _initialize_children(self): async def _initialize_children(self):
"""Initialize children for power strips.""" """Initialize children for power strips."""
children = self._last_update["child_info"]["child_device_list"] children = self.internal_state["child_info"]["child_device_list"]
# TODO: Use the type information to construct children, children_components = {
# as hubs can also have them. child["device_id"]: {
comp["id"]: int(comp["ver_code"]) for comp in child["component_list"]
}
for child in self.internal_state["get_child_device_component_list"][
"child_component_list"
]
}
from .smartchilddevice import SmartChildDevice from .smartchilddevice import SmartChildDevice
self._children = { self._children = {
child["device_id"]: SmartChildDevice( child_info["device_id"]: await SmartChildDevice.create(
parent=self, child_id=child["device_id"] parent=self,
child_info=child_info,
child_components=children_components[child_info["device_id"]],
) )
for child in children for child_info in children
} }
# TODO: if all are sockets, then we are a strip, and otherwise a hub?
# doesn't work for the walldimmer with fancontrol...
self._device_type = DeviceType.Strip self._device_type = DeviceType.Strip
@property @property
@ -126,8 +127,10 @@ class SmartDevice(Device):
if not self.children: if not self.children:
await self._initialize_children() await self._initialize_children()
# TODO: we don't currently perform queries on children based on modules,
# but just update the information that is returned in the main query.
for info in child_info["child_device_list"]: for info in child_info["child_device_list"]:
self._children[info["device_id"]].update_internal_state(info) self._children[info["device_id"]]._update_internal_state(info)
# We can first initialize the features after the first update. # We can first initialize the features after the first update.
# We make here an assumption that every device has at least a single feature. # We make here an assumption that every device has at least a single feature.
@ -153,6 +156,7 @@ class SmartDevice(Device):
async def _initialize_features(self): async def _initialize_features(self):
"""Initialize device features.""" """Initialize device features."""
self._add_feature(Feature(self, "Device ID", attribute_getter="device_id"))
if "device_on" in self._info: if "device_on" in self._info:
self._add_feature( self._add_feature(
Feature( Feature(
@ -164,25 +168,32 @@ class SmartDevice(Device):
) )
) )
self._add_feature( if "signal_level" in self._info:
Feature( self._add_feature(
self, Feature(
"Signal Level", self,
attribute_getter=lambda x: x._info["signal_level"], "Signal Level",
icon="mdi:signal", attribute_getter=lambda x: x._info["signal_level"],
icon="mdi:signal",
)
) )
)
self._add_feature( if "rssi" in self._info:
Feature( self._add_feature(
self, Feature(
"RSSI", self,
attribute_getter=lambda x: x._info["rssi"], "RSSI",
icon="mdi:signal", attribute_getter=lambda x: x._info["rssi"],
icon="mdi:signal",
)
)
if "ssid" in self._info:
self._add_feature(
Feature(
device=self, name="SSID", attribute_getter="ssid", icon="mdi:wifi"
)
) )
)
self._add_feature(
Feature(device=self, name="SSID", attribute_getter="ssid", icon="mdi:wifi")
)
if "overheated" in self._info: if "overheated" in self._info:
self._add_feature( self._add_feature(
@ -232,7 +243,12 @@ class SmartDevice(Device):
@property @property
def time(self) -> datetime: def time(self) -> datetime:
"""Return the time.""" """Return the time."""
_timemod = cast(TimeModule, self.modules["TimeModule"]) # TODO: Default to parent's time module for child devices
if self._parent and "TimeModule" in self.modules:
_timemod = cast(TimeModule, self._parent.modules["TimeModule"]) # noqa: F405
else:
_timemod = cast(TimeModule, self.modules["TimeModule"]) # noqa: F405
return _timemod.time return _timemod.time
@property @property
@ -284,6 +300,14 @@ class SmartDevice(Device):
"""Return all the internal state data.""" """Return all the internal state data."""
return self._last_update return self._last_update
def _update_internal_state(self, info):
"""Update internal state.
This is used by the parent to push updates to its children
"""
# TODO: cleanup the _last_update, _info mess.
self._last_update = self._info = info
async def _query_helper( async def _query_helper(
self, method: str, params: Optional[Dict] = None, child_ids=None self, method: str, params: Optional[Dict] = None, child_ids=None
) -> Any: ) -> Any:
@ -347,19 +371,19 @@ class SmartDevice(Device):
@property @property
def emeter_realtime(self) -> EmeterStatus: def emeter_realtime(self) -> EmeterStatus:
"""Get the emeter status.""" """Get the emeter status."""
energy = cast(EnergyModule, self.modules["EnergyModule"]) energy = cast(EnergyModule, self.modules["EnergyModule"]) # noqa: F405
return energy.emeter_realtime return energy.emeter_realtime
@property @property
def emeter_this_month(self) -> Optional[float]: def emeter_this_month(self) -> Optional[float]:
"""Get the emeter value for this month.""" """Get the emeter value for this month."""
energy = cast(EnergyModule, self.modules["EnergyModule"]) energy = cast(EnergyModule, self.modules["EnergyModule"]) # noqa: F405
return energy.emeter_this_month return energy.emeter_this_month
@property @property
def emeter_today(self) -> Optional[float]: def emeter_today(self) -> Optional[float]:
"""Get the emeter value for today.""" """Get the emeter value for today."""
energy = cast(EnergyModule, self.modules["EnergyModule"]) energy = cast(EnergyModule, self.modules["EnergyModule"]) # noqa: F405
return energy.emeter_today return energy.emeter_today
@property @property
@ -372,7 +396,7 @@ class SmartDevice(Device):
return None return None
on_time = cast(float, on_time) on_time = cast(float, on_time)
if (timemod := self.modules.get("TimeModule")) is not None: if (timemod := self.modules.get("TimeModule")) is not None:
timemod = cast(TimeModule, timemod) timemod = cast(TimeModule, timemod) # noqa: F405
return timemod.time - timedelta(seconds=on_time) return timemod.time - timedelta(seconds=on_time)
else: # We have no device time, use current local time. else: # We have no device time, use current local time.
return datetime.now().replace(microsecond=0) - timedelta(seconds=on_time) return datetime.now().replace(microsecond=0) - timedelta(seconds=on_time)

View File

@ -57,16 +57,25 @@ class SmartModule(Module):
""" """
q = self.query() q = self.query()
q_keys = list(q.keys()) q_keys = list(q.keys())
# TODO: hacky way to check if update has been called. query_key = q_keys[0]
if q_keys[0] not in self._device._last_update:
raise KasaException( dev = self._device
f"You need to call update() prior accessing module data"
f" for '{self._module}'" # TODO: hacky way to check if update has been called.
) # The way this falls back to parent may not always be wanted.
# Especially, devices can have their own firmware updates.
if query_key not in dev._last_update:
if dev._parent and query_key in dev._parent._last_update:
_LOGGER.debug("%s not found child, but found on parent", query_key)
dev = dev._parent
else:
raise KasaException(
f"You need to call update() prior accessing module data"
f" for '{self._module}'"
)
filtered_data = {k: v for k, v in dev._last_update.items() if k in q_keys}
filtered_data = {
k: v for k, v in self._device._last_update.items() if k in q_keys
}
if len(filtered_data) == 1: if len(filtered_data) == 1:
return next(iter(filtered_data.values())) return next(iter(filtered_data.values()))

View File

@ -47,7 +47,6 @@ async def test_childdevice_properties(dev: SmartChildDevice):
assert len(dev.children) > 0 assert len(dev.children) > 0
first = dev.children[0] first = dev.children[0]
assert first.is_strip_socket
# children do not have children # children do not have children
assert not first.children assert not first.children

View File

@ -19,6 +19,7 @@ from kasa.cli import (
alias, alias,
brightness, brightness,
cli, cli,
cmd_command,
emeter, emeter,
raw_command, raw_command,
reboot, reboot,
@ -136,6 +137,32 @@ async def test_raw_command(dev, mocker):
assert "Usage" in res.output assert "Usage" in res.output
async def test_command_with_child(dev, mocker):
"""Test 'command' command with --child."""
runner = CliRunner()
update_mock = mocker.patch.object(dev, "update")
dummy_child = mocker.create_autospec(IotDevice)
query_mock = mocker.patch.object(
dummy_child, "_query_helper", return_value={"dummy": "response"}
)
mocker.patch.object(dev, "_children", {"XYZ": dummy_child})
mocker.patch.object(dev, "get_child_device", return_value=dummy_child)
res = await runner.invoke(
cmd_command,
["--child", "XYZ", "command", "'params'"],
obj=dev,
catch_exceptions=False,
)
update_mock.assert_called()
query_mock.assert_called()
assert '{"dummy": "response"}' in res.output
assert res.exit_code == 0
@device_smart @device_smart
async def test_reboot(dev, mocker): async def test_reboot(dev, mocker):
"""Test that reboot works on SMART devices.""" """Test that reboot works on SMART devices."""

View File

@ -37,6 +37,7 @@ from .conftest import (
lightstrip, lightstrip,
no_emeter_iot, no_emeter_iot,
plug, plug,
strip,
turn_on, turn_on,
) )
from .fakeprotocol_iot import FakeIotProtocol from .fakeprotocol_iot import FakeIotProtocol
@ -201,13 +202,12 @@ async def test_representation(dev):
assert pattern.match(str(dev)) assert pattern.match(str(dev))
@device_iot @strip
async def test_childrens(dev): def test_children_api(dev):
"""Make sure that children property is exposed by every device.""" """Test the child device API."""
if dev.is_strip: first = dev.children[0]
assert len(dev.children) > 0 first_by_get_child_device = dev.get_child_device(first.device_id)
else: assert first == first_by_get_child_device
assert len(dev.children) == 0
@device_iot @device_iot
@ -215,10 +215,8 @@ async def test_children(dev):
"""Make sure that children property is exposed by every device.""" """Make sure that children property is exposed by every device."""
if dev.is_strip: if dev.is_strip:
assert len(dev.children) > 0 assert len(dev.children) > 0
assert dev.has_children is True
else: else:
assert len(dev.children) == 0 assert len(dev.children) == 0
assert dev.has_children is False
@device_iot @device_iot
@ -260,7 +258,9 @@ async def test_device_class_ctors(device_class_name_obj):
klass = device_class_name_obj[1] klass = device_class_name_obj[1]
if issubclass(klass, SmartChildDevice): if issubclass(klass, SmartChildDevice):
parent = SmartDevice(host, config=config) parent = SmartDevice(host, config=config)
dev = klass(parent, 1) dev = klass(
parent, {"dummy": "info", "device_id": "dummy"}, {"dummy": "components"}
)
else: else:
dev = klass(host, config=config) dev = klass(host, config=config)
assert dev.host == host assert dev.host == host