Use dict as store for child devices

This allows accessing child devices directly by their device_id,
which will be necessary to improve the child device support.
This commit is contained in:
Teemu Rytilahti 2024-02-18 18:47:39 +01:00
parent 9ab9420ad6
commit e27d5a3dec
8 changed files with 57 additions and 51 deletions

View File

@ -559,7 +559,7 @@ async def state(ctx, dev: Device):
echo(f"\tDevice state: {dev.is_on}") echo(f"\tDevice state: {dev.is_on}")
if dev.is_strip: if dev.is_strip:
echo("\t[bold]== Plugs ==[/bold]") echo("\t[bold]== Plugs ==[/bold]")
for plug in dev.children: # type: ignore for plug in dev.children.values(): # type: ignore
echo(f"\t* Socket '{plug.alias}' state: {plug.is_on} since {plug.on_since}") echo(f"\t* Socket '{plug.alias}' state: {plug.is_on} since {plug.on_since}")
echo() echo()

View File

@ -3,7 +3,7 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Union from typing import Any, Dict, List, Mapping, Optional, Union
from .credentials import Credentials from .credentials import Credentials
from .device_type import DeviceType from .device_type import DeviceType
@ -70,6 +70,7 @@ class Device(ABC):
self._discovery_info: Optional[Dict[str, Any]] = None self._discovery_info: Optional[Dict[str, Any]] = None
self.modules: Dict[str, Any] = {} self.modules: Dict[str, Any] = {}
self._children: Dict[str, "Device"] = {}
self._features: Dict[str, Feature] = {} self._features: Dict[str, Feature] = {}
@staticmethod @staticmethod
@ -183,7 +184,7 @@ class Device(ABC):
@property @property
@abstractmethod @abstractmethod
def children(self) -> Sequence["Device"]: def children(self) -> Mapping[str, "Device"]:
"""Returns the child devices.""" """Returns the child devices."""
@property @property
@ -238,7 +239,7 @@ class Device(ABC):
def get_plug_by_name(self, name: str) -> "Device": def get_plug_by_name(self, name: str) -> "Device":
"""Return child device for the given name.""" """Return child device for the given name."""
for p in self.children: for p in self.children.values():
if p.alias == name: if p.alias == name:
return p return p
@ -250,7 +251,7 @@ class Device(ABC):
raise SmartDeviceException( raise SmartDeviceException(
f"Invalid index {index}, device has {len(self.children)} plugs" f"Invalid index {index}, device has {len(self.children)} plugs"
) )
return self.children[index] return list(self.children.values())[index]
@property @property
@abstractmethod @abstractmethod

View File

@ -16,7 +16,7 @@ import functools
import inspect import inspect
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Sequence, Set from typing import Any, Dict, List, Optional, Set, cast
from ..device import Device, WifiNetwork from ..device import Device, WifiNetwork
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
@ -185,19 +185,13 @@ class IotDevice(Device):
super().__init__(host=host, config=config, protocol=protocol) super().__init__(host=host, config=config, protocol=protocol)
self._sys_info: Any = None # TODO: this is here to avoid changing tests self._sys_info: Any = None # TODO: this is here to avoid changing tests
self._children: Sequence["IotDevice"] = []
self._supported_modules: Optional[Dict[str, IotModule]] = None self._supported_modules: Optional[Dict[str, IotModule]] = None
self._legacy_features: Set[str] = set() self._legacy_features: Set[str] = set()
@property @property
def children(self) -> Sequence["IotDevice"]: def children(self) -> Dict[str, "IotDevice"]:
"""Return list of children.""" """Return list of children."""
return self._children return cast(Dict[str, "IotDevice"], self._children)
@children.setter
def children(self, children):
"""Initialize from a list of children."""
self._children = children
def add_module(self, name: str, module: IotModule): def add_module(self, name: str, module: IotModule):
"""Register a module.""" """Register a module."""

View File

@ -55,7 +55,7 @@ class IotStrip(IotDevice):
All methods act on the whole strip: All methods act on the whole strip:
>>> for plug in strip.children: >>> for plug in strip.children.values():
>>> print(f"{plug.alias}: {plug.is_on}") >>> print(f"{plug.alias}: {plug.is_on}")
Plug 1: True Plug 1: True
Plug 2: False Plug 2: False
@ -68,12 +68,12 @@ class IotStrip(IotDevice):
>>> len(strip.children) >>> len(strip.children)
3 3
>>> for plug in strip.children: >>> for plug in strip.children.values():
>>> print(f"{plug.alias}: {plug.is_on}") >>> print(f"{plug.alias}: {plug.is_on}")
Plug 1: False Plug 1: False
Plug 2: False Plug 2: False
Plug 3: False Plug 3: False
>>> asyncio.run(strip.children[1].turn_on()) >>> asyncio.run(list(strip.children.values())[1].turn_on())
>>> asyncio.run(strip.update()) >>> asyncio.run(strip.update())
>>> strip.is_on >>> strip.is_on
True True
@ -102,7 +102,7 @@ class IotStrip(IotDevice):
@requires_update @requires_update
def is_on(self) -> bool: def is_on(self) -> bool:
"""Return if any of the outlets are on.""" """Return if any of the outlets are on."""
return any(plug.is_on for plug in self.children) return any(plug.is_on for plug in self.children.values())
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""Update some of the attributes. """Update some of the attributes.
@ -115,13 +115,13 @@ class IotStrip(IotDevice):
if not self.children: if not self.children:
children = self.sys_info["children"] children = self.sys_info["children"]
_LOGGER.debug("Initializing %s child sockets", len(children)) _LOGGER.debug("Initializing %s child sockets", len(children))
self.children = [ self._children = {
IotStripPlug(self.host, parent=self, child_id=child["id"]) child["id"]: IotStripPlug(self.host, parent=self, child_id=child["id"])
for child in children for child in children
] }
if update_children and self.has_emeter: if update_children and self.has_emeter:
for plug in self.children: for plug in self.children.values():
await plug.update() await plug.update()
async def turn_on(self, **kwargs): async def turn_on(self, **kwargs):
@ -139,7 +139,11 @@ class IotStrip(IotDevice):
if self.is_off: if self.is_off:
return None return None
return max(plug.on_since for plug in self.children if plug.on_since is not None) return max(
plug.on_since
for plug in self.children.values()
if plug.on_since is not None
)
@property # type: ignore @property # type: ignore
@requires_update @requires_update
@ -167,7 +171,9 @@ class IotStrip(IotDevice):
async def current_consumption(self) -> float: async def current_consumption(self) -> float:
"""Get the current power consumption in watts.""" """Get the current power consumption in watts."""
return sum([await plug.current_consumption() for plug in self.children]) return sum(
[await plug.current_consumption() for plug in self.children.values()]
)
@requires_update @requires_update
async def get_emeter_realtime(self) -> EmeterStatus: async def get_emeter_realtime(self) -> EmeterStatus:
@ -211,32 +217,32 @@ class IotStrip(IotDevice):
"""Retreive emeter stats for a time period from children.""" """Retreive emeter stats for a time period from children."""
self._verify_emeter() self._verify_emeter()
return merge_sums( return merge_sums(
[await getattr(plug, func)(**kwargs) for plug in self.children] [await getattr(plug, func)(**kwargs) for plug in self.children.values()]
) )
@requires_update @requires_update
async def erase_emeter_stats(self): async def erase_emeter_stats(self):
"""Erase energy meter statistics for all plugs.""" """Erase energy meter statistics for all plugs."""
for plug in self.children: for plug in self.children.values():
await plug.erase_emeter_stats() await plug.erase_emeter_stats()
@property # type: ignore @property # type: ignore
@requires_update @requires_update
def emeter_this_month(self) -> Optional[float]: def emeter_this_month(self) -> Optional[float]:
"""Return this month's energy consumption in kWh.""" """Return this month's energy consumption in kWh."""
return sum(plug.emeter_this_month for plug in self.children) return sum(plug.emeter_this_month for plug in self.children.values())
@property # type: ignore @property # type: ignore
@requires_update @requires_update
def emeter_today(self) -> Optional[float]: def emeter_today(self) -> Optional[float]:
"""Return this month's energy consumption in kWh.""" """Return this month's energy consumption in kWh."""
return sum(plug.emeter_today for plug in self.children) return sum(plug.emeter_today for plug in self.children.values())
@property # type: ignore @property # type: ignore
@requires_update @requires_update
def emeter_realtime(self) -> EmeterStatus: def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings.""" """Return current energy readings."""
emeter = merge_sums([plug.emeter_realtime for plug in self.children]) emeter = merge_sums([plug.emeter_realtime for plug in self.children.values()])
# Voltage is averaged since each read will result # Voltage is averaged since each read will result
# in a slightly different voltage since they are not atomic # in a slightly different voltage since they are not atomic
emeter["voltage_mv"] = int(emeter["voltage_mv"] / len(self.children)) emeter["voltage_mv"] = int(emeter["voltage_mv"] / len(self.children))

View File

@ -2,7 +2,7 @@
import base64 import base64
import logging import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, cast from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, cast
from ..aestransport import AesTransport from ..aestransport import AesTransport
from ..device import Device, WifiNetwork from ..device import Device, WifiNetwork
@ -16,7 +16,7 @@ from ..smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from .smartchilddevice import SmartChildDevice pass
class SmartDevice(Device): class SmartDevice(Device):
@ -36,7 +36,6 @@ class SmartDevice(Device):
self.protocol: SmartProtocol self.protocol: SmartProtocol
self._components_raw: Optional[Dict[str, Any]] = None self._components_raw: Optional[Dict[str, Any]] = None
self._components: Dict[str, int] = {} self._components: Dict[str, int] = {}
self._children: Dict[str, "SmartChildDevice"] = {}
self._energy: Dict[str, Any] = {} self._energy: Dict[str, Any] = {}
self._state_information: Dict[str, Any] = {} self._state_information: Dict[str, Any] = {}
self._time: Dict[str, Any] = {} self._time: Dict[str, Any] = {}
@ -57,9 +56,9 @@ class SmartDevice(Device):
self._device_type = DeviceType.Strip self._device_type = DeviceType.Strip
@property @property
def children(self) -> Sequence["SmartDevice"]: def children(self) -> Mapping[str, "SmartDevice"]:
"""Return list of children.""" """Return list of children."""
return list(self._children.values()) return cast(Mapping[str, "SmartDevice"], self._children)
def _try_get_response(self, responses: dict, request: str, default=None) -> dict: def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
response = responses.get(request) response = responses.get(request)
@ -141,7 +140,7 @@ class SmartDevice(Device):
if not self.children: if not self.children:
await self._initialize_children() await self._initialize_children()
for info in child_info["child_device_list"]: for info in child_info["child_device_list"]:
self._children[info["device_id"]].update_internal_state(info) self.children[info["device_id"]].update_internal_state(info) # type: ignore[attr-defined]
# We can first initialize the features after the first update. # We can first initialize the features after the first update.
# We make here an assumption that every device has at least a single feature. # We make here an assumption that every device has at least a single feature.

View File

@ -15,7 +15,7 @@ def test_childdevice_init(dev, dummy_protocol, mocker):
assert len(dev.children) > 0 assert len(dev.children) > 0
assert dev.is_strip assert dev.is_strip
first = dev.children[0] first = list(dev.children.values())[0]
assert isinstance(first.protocol, _ChildProtocolWrapper) assert isinstance(first.protocol, _ChildProtocolWrapper)
assert first._info["category"] == "plug.powerstrip.sub-plug" assert first._info["category"] == "plug.powerstrip.sub-plug"
@ -29,7 +29,7 @@ async def test_childdevice_update(dev, dummy_protocol, mocker):
child_list = child_info["child_device_list"] child_list = child_info["child_device_list"]
assert len(dev.children) == child_info["sum"] assert len(dev.children) == child_info["sum"]
first = dev.children[0] first = list(dev.children.values())[0]
await dev.update() await dev.update()
@ -46,8 +46,7 @@ async def test_childdevice_properties(dev: SmartChildDevice):
"""Check that accessing childdevice properties do not raise exceptions.""" """Check that accessing childdevice properties do not raise exceptions."""
assert len(dev.children) > 0 assert len(dev.children) > 0
first = dev.children[0] first = list(dev.children.values())[0]
assert first.is_strip_socket
# children do not have children # children do not have children
assert not first.children assert not first.children
@ -60,10 +59,15 @@ async def test_childdevice_properties(dev: SmartChildDevice):
) )
for prop in properties: for prop in properties:
name, _ = prop name, _ = prop
try: if (
_ = getattr(first, name) name.startswith("emeter_")
except Exception as ex: or name.startswith("time")
exceptions.append(ex) or name.startswith("on_since")
):
try:
_ = getattr(first, name)
except Exception as ex:
exceptions.append(ex)
return exceptions return exceptions

View File

@ -241,7 +241,8 @@ async def test_emeter(dev: Device, mocker):
assert "Index and name are only for power strips!" in res.output assert "Index and name are only for power strips!" in res.output
if dev.is_strip and len(dev.children) > 0: if dev.is_strip and len(dev.children) > 0:
realtime_emeter = mocker.patch.object(dev.children[0], "get_emeter_realtime") first_child = list(dev.children.values())[0]
realtime_emeter = mocker.patch.object(first_child, "get_emeter_realtime")
realtime_emeter.return_value = EmeterStatus({"voltage_mv": 122066}) realtime_emeter.return_value = EmeterStatus({"voltage_mv": 122066})
res = await runner.invoke(emeter, ["--index", "0"], obj=dev) res = await runner.invoke(emeter, ["--index", "0"], obj=dev)
@ -249,7 +250,7 @@ async def test_emeter(dev: Device, mocker):
realtime_emeter.assert_called() realtime_emeter.assert_called()
assert realtime_emeter.call_count == 1 assert realtime_emeter.call_count == 1
res = await runner.invoke(emeter, ["--name", dev.children[0].alias], obj=dev) res = await runner.invoke(emeter, ["--name", first_child.alias], obj=dev)
assert "Voltage: 122.066 V" in res.output assert "Voltage: 122.066 V" in res.output
assert realtime_emeter.call_count == 2 assert realtime_emeter.call_count == 2

View File

@ -12,7 +12,7 @@ from .conftest import handle_turn_on, strip, turn_on
@turn_on @turn_on
async def test_children_change_state(dev, turn_on): async def test_children_change_state(dev, turn_on):
await handle_turn_on(dev, turn_on) await handle_turn_on(dev, turn_on)
for plug in dev.children: for plug in dev.children.values():
orig_state = plug.is_on orig_state = plug.is_on
if orig_state: if orig_state:
await plug.turn_off() await plug.turn_off()
@ -39,7 +39,7 @@ async def test_children_change_state(dev, turn_on):
@strip @strip
async def test_children_alias(dev): async def test_children_alias(dev):
test_alias = "TEST1234" test_alias = "TEST1234"
for plug in dev.children: for plug in dev.children.values():
original = plug.alias original = plug.alias
await plug.set_alias(alias=test_alias) await plug.set_alias(alias=test_alias)
await dev.update() # TODO: set_alias does not call parent's update().. await dev.update() # TODO: set_alias does not call parent's update()..
@ -53,7 +53,7 @@ async def test_children_alias(dev):
@strip @strip
async def test_children_on_since(dev): async def test_children_on_since(dev):
on_sinces = [] on_sinces = []
for plug in dev.children: for plug in dev.children.values():
if plug.is_on: if plug.is_on:
on_sinces.append(plug.on_since) on_sinces.append(plug.on_since)
assert isinstance(plug.on_since, datetime) assert isinstance(plug.on_since, datetime)
@ -70,8 +70,9 @@ async def test_children_on_since(dev):
@strip @strip
async def test_get_plug_by_name(dev: IotStrip): async def test_get_plug_by_name(dev: IotStrip):
name = dev.children[0].alias children = list(dev.children.values())
assert dev.get_plug_by_name(name) == dev.children[0] # type: ignore[arg-type] name = children[0].alias
assert dev.get_plug_by_name(name) == children[0] # type: ignore[arg-type]
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
dev.get_plug_by_name("NONEXISTING NAME") dev.get_plug_by_name("NONEXISTING NAME")
@ -79,7 +80,7 @@ async def test_get_plug_by_name(dev: IotStrip):
@strip @strip
async def test_get_plug_by_index(dev: IotStrip): async def test_get_plug_by_index(dev: IotStrip):
assert dev.get_plug_by_index(0) == dev.children[0] assert dev.get_plug_by_index(0) == list(dev.children.values())[0]
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
dev.get_plug_by_index(-1) dev.get_plug_by_index(-1)