From e27d5a3dec9aa4f094806fcd5b22572a436f6ad4 Mon Sep 17 00:00:00 2001 From: Teemu Rytilahti Date: Sun, 18 Feb 2024 18:47:39 +0100 Subject: [PATCH] Use dict as store for child devices This allows accessing child devices directly by their device_id, which will be necessary to improve the child device support. --- kasa/cli.py | 2 +- kasa/device.py | 9 +++++---- kasa/iot/iotdevice.py | 12 +++--------- kasa/iot/iotstrip.py | 36 ++++++++++++++++++++-------------- kasa/smart/smartdevice.py | 11 +++++------ kasa/tests/test_childdevice.py | 20 +++++++++++-------- kasa/tests/test_cli.py | 5 +++-- kasa/tests/test_strip.py | 13 ++++++------ 8 files changed, 57 insertions(+), 51 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index e922ec81..6c8fc46f 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -559,7 +559,7 @@ async def state(ctx, dev: Device): echo(f"\tDevice state: {dev.is_on}") if dev.is_strip: echo("\t[bold]== Plugs ==[/bold]") - for plug in dev.children: # type: ignore + for plug in dev.children.values(): # type: ignore echo(f"\t* Socket '{plug.alias}' state: {plug.is_on} since {plug.on_since}") echo() diff --git a/kasa/device.py b/kasa/device.py index 3c38b544..f0ea0c4d 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -3,7 +3,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Mapping, Optional, Union from .credentials import Credentials from .device_type import DeviceType @@ -70,6 +70,7 @@ class Device(ABC): self._discovery_info: Optional[Dict[str, Any]] = None self.modules: Dict[str, Any] = {} + self._children: Dict[str, "Device"] = {} self._features: Dict[str, Feature] = {} @staticmethod @@ -183,7 +184,7 @@ class Device(ABC): @property @abstractmethod - def children(self) -> Sequence["Device"]: + def children(self) -> Mapping[str, "Device"]: """Returns the child devices.""" @property @@ -238,7 +239,7 @@ class Device(ABC): def get_plug_by_name(self, name: str) -> "Device": """Return child device for the given name.""" - for p in self.children: + for p in self.children.values(): if p.alias == name: return p @@ -250,7 +251,7 @@ class Device(ABC): raise SmartDeviceException( f"Invalid index {index}, device has {len(self.children)} plugs" ) - return self.children[index] + return list(self.children.values())[index] @property @abstractmethod diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 8ec7cd4b..14194d00 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -16,7 +16,7 @@ import functools import inspect import logging from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Sequence, Set +from typing import Any, Dict, List, Optional, Set, cast from ..device import Device, WifiNetwork from ..deviceconfig import DeviceConfig @@ -185,19 +185,13 @@ class IotDevice(Device): super().__init__(host=host, config=config, protocol=protocol) self._sys_info: Any = None # TODO: this is here to avoid changing tests - self._children: Sequence["IotDevice"] = [] self._supported_modules: Optional[Dict[str, IotModule]] = None self._legacy_features: Set[str] = set() @property - def children(self) -> Sequence["IotDevice"]: + def children(self) -> Dict[str, "IotDevice"]: """Return list of children.""" - return self._children - - @children.setter - def children(self, children): - """Initialize from a list of children.""" - self._children = children + return cast(Dict[str, "IotDevice"], self._children) def add_module(self, name: str, module: IotModule): """Register a module.""" diff --git a/kasa/iot/iotstrip.py b/kasa/iot/iotstrip.py index 7cbb10b0..a621d498 100755 --- a/kasa/iot/iotstrip.py +++ b/kasa/iot/iotstrip.py @@ -55,7 +55,7 @@ class IotStrip(IotDevice): All methods act on the whole strip: - >>> for plug in strip.children: + >>> for plug in strip.children.values(): >>> print(f"{plug.alias}: {plug.is_on}") Plug 1: True Plug 2: False @@ -68,12 +68,12 @@ class IotStrip(IotDevice): >>> len(strip.children) 3 - >>> for plug in strip.children: + >>> for plug in strip.children.values(): >>> print(f"{plug.alias}: {plug.is_on}") Plug 1: False Plug 2: False Plug 3: False - >>> asyncio.run(strip.children[1].turn_on()) + >>> asyncio.run(list(strip.children.values())[1].turn_on()) >>> asyncio.run(strip.update()) >>> strip.is_on True @@ -102,7 +102,7 @@ class IotStrip(IotDevice): @requires_update def is_on(self) -> bool: """Return if any of the outlets are on.""" - return any(plug.is_on for plug in self.children) + return any(plug.is_on for plug in self.children.values()) async def update(self, update_children: bool = True): """Update some of the attributes. @@ -115,13 +115,13 @@ class IotStrip(IotDevice): if not self.children: children = self.sys_info["children"] _LOGGER.debug("Initializing %s child sockets", len(children)) - self.children = [ - IotStripPlug(self.host, parent=self, child_id=child["id"]) + self._children = { + child["id"]: IotStripPlug(self.host, parent=self, child_id=child["id"]) for child in children - ] + } if update_children and self.has_emeter: - for plug in self.children: + for plug in self.children.values(): await plug.update() async def turn_on(self, **kwargs): @@ -139,7 +139,11 @@ class IotStrip(IotDevice): if self.is_off: return None - return max(plug.on_since for plug in self.children if plug.on_since is not None) + return max( + plug.on_since + for plug in self.children.values() + if plug.on_since is not None + ) @property # type: ignore @requires_update @@ -167,7 +171,9 @@ class IotStrip(IotDevice): async def current_consumption(self) -> float: """Get the current power consumption in watts.""" - return sum([await plug.current_consumption() for plug in self.children]) + return sum( + [await plug.current_consumption() for plug in self.children.values()] + ) @requires_update async def get_emeter_realtime(self) -> EmeterStatus: @@ -211,32 +217,32 @@ class IotStrip(IotDevice): """Retreive emeter stats for a time period from children.""" self._verify_emeter() return merge_sums( - [await getattr(plug, func)(**kwargs) for plug in self.children] + [await getattr(plug, func)(**kwargs) for plug in self.children.values()] ) @requires_update async def erase_emeter_stats(self): """Erase energy meter statistics for all plugs.""" - for plug in self.children: + for plug in self.children.values(): await plug.erase_emeter_stats() @property # type: ignore @requires_update def emeter_this_month(self) -> Optional[float]: """Return this month's energy consumption in kWh.""" - return sum(plug.emeter_this_month for plug in self.children) + return sum(plug.emeter_this_month for plug in self.children.values()) @property # type: ignore @requires_update def emeter_today(self) -> Optional[float]: """Return this month's energy consumption in kWh.""" - return sum(plug.emeter_today for plug in self.children) + return sum(plug.emeter_today for plug in self.children.values()) @property # type: ignore @requires_update def emeter_realtime(self) -> EmeterStatus: """Return current energy readings.""" - emeter = merge_sums([plug.emeter_realtime for plug in self.children]) + emeter = merge_sums([plug.emeter_realtime for plug in self.children.values()]) # Voltage is averaged since each read will result # in a slightly different voltage since they are not atomic emeter["voltage_mv"] = int(emeter["voltage_mv"] / len(self.children)) diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index d2259434..51b104b2 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -2,7 +2,7 @@ import base64 import logging from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, cast from ..aestransport import AesTransport from ..device import Device, WifiNetwork @@ -16,7 +16,7 @@ from ..smartprotocol import SmartProtocol _LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: - from .smartchilddevice import SmartChildDevice + pass class SmartDevice(Device): @@ -36,7 +36,6 @@ class SmartDevice(Device): self.protocol: SmartProtocol self._components_raw: Optional[Dict[str, Any]] = None self._components: Dict[str, int] = {} - self._children: Dict[str, "SmartChildDevice"] = {} self._energy: Dict[str, Any] = {} self._state_information: Dict[str, Any] = {} self._time: Dict[str, Any] = {} @@ -57,9 +56,9 @@ class SmartDevice(Device): self._device_type = DeviceType.Strip @property - def children(self) -> Sequence["SmartDevice"]: + def children(self) -> Mapping[str, "SmartDevice"]: """Return list of children.""" - return list(self._children.values()) + return cast(Mapping[str, "SmartDevice"], self._children) def _try_get_response(self, responses: dict, request: str, default=None) -> dict: response = responses.get(request) @@ -141,7 +140,7 @@ class SmartDevice(Device): if not self.children: await self._initialize_children() for info in child_info["child_device_list"]: - self._children[info["device_id"]].update_internal_state(info) + self.children[info["device_id"]].update_internal_state(info) # type: ignore[attr-defined] # We can first initialize the features after the first update. # We make here an assumption that every device has at least a single feature. diff --git a/kasa/tests/test_childdevice.py b/kasa/tests/test_childdevice.py index 3247c917..31b730dc 100644 --- a/kasa/tests/test_childdevice.py +++ b/kasa/tests/test_childdevice.py @@ -15,7 +15,7 @@ def test_childdevice_init(dev, dummy_protocol, mocker): assert len(dev.children) > 0 assert dev.is_strip - first = dev.children[0] + first = list(dev.children.values())[0] assert isinstance(first.protocol, _ChildProtocolWrapper) assert first._info["category"] == "plug.powerstrip.sub-plug" @@ -29,7 +29,7 @@ async def test_childdevice_update(dev, dummy_protocol, mocker): child_list = child_info["child_device_list"] assert len(dev.children) == child_info["sum"] - first = dev.children[0] + first = list(dev.children.values())[0] await dev.update() @@ -46,8 +46,7 @@ async def test_childdevice_properties(dev: SmartChildDevice): """Check that accessing childdevice properties do not raise exceptions.""" assert len(dev.children) > 0 - first = dev.children[0] - assert first.is_strip_socket + first = list(dev.children.values())[0] # children do not have children assert not first.children @@ -60,10 +59,15 @@ async def test_childdevice_properties(dev: SmartChildDevice): ) for prop in properties: name, _ = prop - try: - _ = getattr(first, name) - except Exception as ex: - exceptions.append(ex) + if ( + name.startswith("emeter_") + or name.startswith("time") + or name.startswith("on_since") + ): + try: + _ = getattr(first, name) + except Exception as ex: + exceptions.append(ex) return exceptions diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 51155f40..b769f6f8 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -241,7 +241,8 @@ async def test_emeter(dev: Device, mocker): assert "Index and name are only for power strips!" in res.output if dev.is_strip and len(dev.children) > 0: - realtime_emeter = mocker.patch.object(dev.children[0], "get_emeter_realtime") + first_child = list(dev.children.values())[0] + realtime_emeter = mocker.patch.object(first_child, "get_emeter_realtime") realtime_emeter.return_value = EmeterStatus({"voltage_mv": 122066}) res = await runner.invoke(emeter, ["--index", "0"], obj=dev) @@ -249,7 +250,7 @@ async def test_emeter(dev: Device, mocker): realtime_emeter.assert_called() assert realtime_emeter.call_count == 1 - res = await runner.invoke(emeter, ["--name", dev.children[0].alias], obj=dev) + res = await runner.invoke(emeter, ["--name", first_child.alias], obj=dev) assert "Voltage: 122.066 V" in res.output assert realtime_emeter.call_count == 2 diff --git a/kasa/tests/test_strip.py b/kasa/tests/test_strip.py index 623adde6..2918a4b3 100644 --- a/kasa/tests/test_strip.py +++ b/kasa/tests/test_strip.py @@ -12,7 +12,7 @@ from .conftest import handle_turn_on, strip, turn_on @turn_on async def test_children_change_state(dev, turn_on): await handle_turn_on(dev, turn_on) - for plug in dev.children: + for plug in dev.children.values(): orig_state = plug.is_on if orig_state: await plug.turn_off() @@ -39,7 +39,7 @@ async def test_children_change_state(dev, turn_on): @strip async def test_children_alias(dev): test_alias = "TEST1234" - for plug in dev.children: + for plug in dev.children.values(): original = plug.alias await plug.set_alias(alias=test_alias) await dev.update() # TODO: set_alias does not call parent's update().. @@ -53,7 +53,7 @@ async def test_children_alias(dev): @strip async def test_children_on_since(dev): on_sinces = [] - for plug in dev.children: + for plug in dev.children.values(): if plug.is_on: on_sinces.append(plug.on_since) assert isinstance(plug.on_since, datetime) @@ -70,8 +70,9 @@ async def test_children_on_since(dev): @strip async def test_get_plug_by_name(dev: IotStrip): - name = dev.children[0].alias - assert dev.get_plug_by_name(name) == dev.children[0] # type: ignore[arg-type] + children = list(dev.children.values()) + name = children[0].alias + assert dev.get_plug_by_name(name) == children[0] # type: ignore[arg-type] with pytest.raises(SmartDeviceException): dev.get_plug_by_name("NONEXISTING NAME") @@ -79,7 +80,7 @@ async def test_get_plug_by_name(dev: IotStrip): @strip async def test_get_plug_by_index(dev: IotStrip): - assert dev.get_plug_by_index(0) == dev.children[0] + assert dev.get_plug_by_index(0) == list(dev.children.values())[0] with pytest.raises(SmartDeviceException): dev.get_plug_by_index(-1)