Use dict as store for child devices

This allows accessing child devices directly by their device_id,
which will be necessary to improve the child device support.
This commit is contained in:
Teemu Rytilahti 2024-02-18 18:47:39 +01:00
parent 9ab9420ad6
commit e27d5a3dec
8 changed files with 57 additions and 51 deletions

View File

@ -559,7 +559,7 @@ async def state(ctx, dev: Device):
echo(f"\tDevice state: {dev.is_on}")
if dev.is_strip:
echo("\t[bold]== Plugs ==[/bold]")
for plug in dev.children: # type: ignore
for plug in dev.children.values(): # type: ignore
echo(f"\t* Socket '{plug.alias}' state: {plug.is_on} since {plug.on_since}")
echo()

View File

@ -3,7 +3,7 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Mapping, Optional, Union
from .credentials import Credentials
from .device_type import DeviceType
@ -70,6 +70,7 @@ class Device(ABC):
self._discovery_info: Optional[Dict[str, Any]] = None
self.modules: Dict[str, Any] = {}
self._children: Dict[str, "Device"] = {}
self._features: Dict[str, Feature] = {}
@staticmethod
@ -183,7 +184,7 @@ class Device(ABC):
@property
@abstractmethod
def children(self) -> Sequence["Device"]:
def children(self) -> Mapping[str, "Device"]:
"""Returns the child devices."""
@property
@ -238,7 +239,7 @@ class Device(ABC):
def get_plug_by_name(self, name: str) -> "Device":
"""Return child device for the given name."""
for p in self.children:
for p in self.children.values():
if p.alias == name:
return p
@ -250,7 +251,7 @@ class Device(ABC):
raise SmartDeviceException(
f"Invalid index {index}, device has {len(self.children)} plugs"
)
return self.children[index]
return list(self.children.values())[index]
@property
@abstractmethod

View File

@ -16,7 +16,7 @@ import functools
import inspect
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Sequence, Set
from typing import Any, Dict, List, Optional, Set, cast
from ..device import Device, WifiNetwork
from ..deviceconfig import DeviceConfig
@ -185,19 +185,13 @@ class IotDevice(Device):
super().__init__(host=host, config=config, protocol=protocol)
self._sys_info: Any = None # TODO: this is here to avoid changing tests
self._children: Sequence["IotDevice"] = []
self._supported_modules: Optional[Dict[str, IotModule]] = None
self._legacy_features: Set[str] = set()
@property
def children(self) -> Sequence["IotDevice"]:
def children(self) -> Dict[str, "IotDevice"]:
"""Return list of children."""
return self._children
@children.setter
def children(self, children):
"""Initialize from a list of children."""
self._children = children
return cast(Dict[str, "IotDevice"], self._children)
def add_module(self, name: str, module: IotModule):
"""Register a module."""

View File

@ -55,7 +55,7 @@ class IotStrip(IotDevice):
All methods act on the whole strip:
>>> for plug in strip.children:
>>> for plug in strip.children.values():
>>> print(f"{plug.alias}: {plug.is_on}")
Plug 1: True
Plug 2: False
@ -68,12 +68,12 @@ class IotStrip(IotDevice):
>>> len(strip.children)
3
>>> for plug in strip.children:
>>> for plug in strip.children.values():
>>> print(f"{plug.alias}: {plug.is_on}")
Plug 1: False
Plug 2: False
Plug 3: False
>>> asyncio.run(strip.children[1].turn_on())
>>> asyncio.run(list(strip.children.values())[1].turn_on())
>>> asyncio.run(strip.update())
>>> strip.is_on
True
@ -102,7 +102,7 @@ class IotStrip(IotDevice):
@requires_update
def is_on(self) -> bool:
"""Return if any of the outlets are on."""
return any(plug.is_on for plug in self.children)
return any(plug.is_on for plug in self.children.values())
async def update(self, update_children: bool = True):
"""Update some of the attributes.
@ -115,13 +115,13 @@ class IotStrip(IotDevice):
if not self.children:
children = self.sys_info["children"]
_LOGGER.debug("Initializing %s child sockets", len(children))
self.children = [
IotStripPlug(self.host, parent=self, child_id=child["id"])
self._children = {
child["id"]: IotStripPlug(self.host, parent=self, child_id=child["id"])
for child in children
]
}
if update_children and self.has_emeter:
for plug in self.children:
for plug in self.children.values():
await plug.update()
async def turn_on(self, **kwargs):
@ -139,7 +139,11 @@ class IotStrip(IotDevice):
if self.is_off:
return None
return max(plug.on_since for plug in self.children if plug.on_since is not None)
return max(
plug.on_since
for plug in self.children.values()
if plug.on_since is not None
)
@property # type: ignore
@requires_update
@ -167,7 +171,9 @@ class IotStrip(IotDevice):
async def current_consumption(self) -> float:
"""Get the current power consumption in watts."""
return sum([await plug.current_consumption() for plug in self.children])
return sum(
[await plug.current_consumption() for plug in self.children.values()]
)
@requires_update
async def get_emeter_realtime(self) -> EmeterStatus:
@ -211,32 +217,32 @@ class IotStrip(IotDevice):
"""Retreive emeter stats for a time period from children."""
self._verify_emeter()
return merge_sums(
[await getattr(plug, func)(**kwargs) for plug in self.children]
[await getattr(plug, func)(**kwargs) for plug in self.children.values()]
)
@requires_update
async def erase_emeter_stats(self):
"""Erase energy meter statistics for all plugs."""
for plug in self.children:
for plug in self.children.values():
await plug.erase_emeter_stats()
@property # type: ignore
@requires_update
def emeter_this_month(self) -> Optional[float]:
"""Return this month's energy consumption in kWh."""
return sum(plug.emeter_this_month for plug in self.children)
return sum(plug.emeter_this_month for plug in self.children.values())
@property # type: ignore
@requires_update
def emeter_today(self) -> Optional[float]:
"""Return this month's energy consumption in kWh."""
return sum(plug.emeter_today for plug in self.children)
return sum(plug.emeter_today for plug in self.children.values())
@property # type: ignore
@requires_update
def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings."""
emeter = merge_sums([plug.emeter_realtime for plug in self.children])
emeter = merge_sums([plug.emeter_realtime for plug in self.children.values()])
# Voltage is averaged since each read will result
# in a slightly different voltage since they are not atomic
emeter["voltage_mv"] = int(emeter["voltage_mv"] / len(self.children))

View File

@ -2,7 +2,7 @@
import base64
import logging
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, cast
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, cast
from ..aestransport import AesTransport
from ..device import Device, WifiNetwork
@ -16,7 +16,7 @@ from ..smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from .smartchilddevice import SmartChildDevice
pass
class SmartDevice(Device):
@ -36,7 +36,6 @@ class SmartDevice(Device):
self.protocol: SmartProtocol
self._components_raw: Optional[Dict[str, Any]] = None
self._components: Dict[str, int] = {}
self._children: Dict[str, "SmartChildDevice"] = {}
self._energy: Dict[str, Any] = {}
self._state_information: Dict[str, Any] = {}
self._time: Dict[str, Any] = {}
@ -57,9 +56,9 @@ class SmartDevice(Device):
self._device_type = DeviceType.Strip
@property
def children(self) -> Sequence["SmartDevice"]:
def children(self) -> Mapping[str, "SmartDevice"]:
"""Return list of children."""
return list(self._children.values())
return cast(Mapping[str, "SmartDevice"], self._children)
def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
response = responses.get(request)
@ -141,7 +140,7 @@ class SmartDevice(Device):
if not self.children:
await self._initialize_children()
for info in child_info["child_device_list"]:
self._children[info["device_id"]].update_internal_state(info)
self.children[info["device_id"]].update_internal_state(info) # type: ignore[attr-defined]
# We can first initialize the features after the first update.
# We make here an assumption that every device has at least a single feature.

View File

@ -15,7 +15,7 @@ def test_childdevice_init(dev, dummy_protocol, mocker):
assert len(dev.children) > 0
assert dev.is_strip
first = dev.children[0]
first = list(dev.children.values())[0]
assert isinstance(first.protocol, _ChildProtocolWrapper)
assert first._info["category"] == "plug.powerstrip.sub-plug"
@ -29,7 +29,7 @@ async def test_childdevice_update(dev, dummy_protocol, mocker):
child_list = child_info["child_device_list"]
assert len(dev.children) == child_info["sum"]
first = dev.children[0]
first = list(dev.children.values())[0]
await dev.update()
@ -46,8 +46,7 @@ async def test_childdevice_properties(dev: SmartChildDevice):
"""Check that accessing childdevice properties do not raise exceptions."""
assert len(dev.children) > 0
first = dev.children[0]
assert first.is_strip_socket
first = list(dev.children.values())[0]
# children do not have children
assert not first.children
@ -60,6 +59,11 @@ async def test_childdevice_properties(dev: SmartChildDevice):
)
for prop in properties:
name, _ = prop
if (
name.startswith("emeter_")
or name.startswith("time")
or name.startswith("on_since")
):
try:
_ = getattr(first, name)
except Exception as ex:

View File

@ -241,7 +241,8 @@ async def test_emeter(dev: Device, mocker):
assert "Index and name are only for power strips!" in res.output
if dev.is_strip and len(dev.children) > 0:
realtime_emeter = mocker.patch.object(dev.children[0], "get_emeter_realtime")
first_child = list(dev.children.values())[0]
realtime_emeter = mocker.patch.object(first_child, "get_emeter_realtime")
realtime_emeter.return_value = EmeterStatus({"voltage_mv": 122066})
res = await runner.invoke(emeter, ["--index", "0"], obj=dev)
@ -249,7 +250,7 @@ async def test_emeter(dev: Device, mocker):
realtime_emeter.assert_called()
assert realtime_emeter.call_count == 1
res = await runner.invoke(emeter, ["--name", dev.children[0].alias], obj=dev)
res = await runner.invoke(emeter, ["--name", first_child.alias], obj=dev)
assert "Voltage: 122.066 V" in res.output
assert realtime_emeter.call_count == 2

View File

@ -12,7 +12,7 @@ from .conftest import handle_turn_on, strip, turn_on
@turn_on
async def test_children_change_state(dev, turn_on):
await handle_turn_on(dev, turn_on)
for plug in dev.children:
for plug in dev.children.values():
orig_state = plug.is_on
if orig_state:
await plug.turn_off()
@ -39,7 +39,7 @@ async def test_children_change_state(dev, turn_on):
@strip
async def test_children_alias(dev):
test_alias = "TEST1234"
for plug in dev.children:
for plug in dev.children.values():
original = plug.alias
await plug.set_alias(alias=test_alias)
await dev.update() # TODO: set_alias does not call parent's update()..
@ -53,7 +53,7 @@ async def test_children_alias(dev):
@strip
async def test_children_on_since(dev):
on_sinces = []
for plug in dev.children:
for plug in dev.children.values():
if plug.is_on:
on_sinces.append(plug.on_since)
assert isinstance(plug.on_since, datetime)
@ -70,8 +70,9 @@ async def test_children_on_since(dev):
@strip
async def test_get_plug_by_name(dev: IotStrip):
name = dev.children[0].alias
assert dev.get_plug_by_name(name) == dev.children[0] # type: ignore[arg-type]
children = list(dev.children.values())
name = children[0].alias
assert dev.get_plug_by_name(name) == children[0] # type: ignore[arg-type]
with pytest.raises(SmartDeviceException):
dev.get_plug_by_name("NONEXISTING NAME")
@ -79,7 +80,7 @@ async def test_get_plug_by_name(dev: IotStrip):
@strip
async def test_get_plug_by_index(dev: IotStrip):
assert dev.get_plug_by_index(0) == dev.children[0]
assert dev.get_plug_by_index(0) == list(dev.children.values())[0]
with pytest.raises(SmartDeviceException):
dev.get_plug_by_index(-1)