diff --git a/kasa/cli.py b/kasa/cli.py index 66eb8936..317bf038 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -39,6 +39,7 @@ from kasa.iot import ( IotStrip, IotWallSwitch, ) +from kasa.iot.modules import Usage from kasa.smart import SmartBulb, SmartDevice try: @@ -829,7 +830,7 @@ async def usage(dev: Device, year, month, erase): Daily and monthly data provided in CSV format. """ echo("[bold]== Usage ==[/bold]") - usage = dev.modules["usage"] + usage = cast(Usage, dev.modules["usage"]) if erase: echo("Erasing usage statistics..") diff --git a/kasa/device.py b/kasa/device.py index dda7822f..8a81030f 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -15,6 +15,7 @@ from .emeterstatus import EmeterStatus from .exceptions import KasaException from .feature import Feature from .iotprotocol import IotProtocol +from .module import Module from .protocol import BaseProtocol from .xortransport import XorTransport @@ -72,7 +73,6 @@ class Device(ABC): self._last_update: Any = None self._discovery_info: dict[str, Any] | None = None - self.modules: dict[str, Any] = {} self._features: dict[str, Feature] = {} self._parent: Device | None = None self._children: Mapping[str, Device] = {} @@ -111,6 +111,11 @@ class Device(ABC): """Disconnect and close any underlying connection resources.""" await self.protocol.close() + @property + @abstractmethod + def modules(self) -> Mapping[str, Module]: + """Return the device modules.""" + @property @abstractmethod def is_on(self) -> bool: diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index d4551d0d..81b5edda 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -19,7 +19,7 @@ import functools import inspect import logging from datetime import datetime, timedelta -from typing import Any, Mapping, Sequence +from typing import Any, Mapping, Sequence, cast from ..device import Device, WifiNetwork from ..deviceconfig import DeviceConfig @@ -28,7 +28,7 @@ from ..exceptions import KasaException from ..feature import Feature from ..protocol import BaseProtocol from .iotmodule import IotModule -from .modules import Emeter +from .modules import Emeter, Time _LOGGER = logging.getLogger(__name__) @@ -189,12 +189,18 @@ class IotDevice(Device): self._supported_modules: dict[str, IotModule] | None = None self._legacy_features: set[str] = set() self._children: Mapping[str, IotDevice] = {} + self._modules: dict[str, IotModule] = {} @property def children(self) -> Sequence[IotDevice]: """Return list of children.""" return list(self._children.values()) + @property + def modules(self) -> dict[str, IotModule]: + """Return the device modules.""" + return self._modules + def add_module(self, name: str, module: IotModule): """Register a module.""" if name in self.modules: @@ -420,31 +426,31 @@ class IotDevice(Device): """Set the device name (alias).""" return await self._query_helper("system", "set_dev_alias", {"alias": alias}) - @property # type: ignore + @property @requires_update def time(self) -> datetime: """Return current time from the device.""" - return self.modules["time"].time + return cast(Time, self.modules["time"]).time - @property # type: ignore + @property @requires_update def timezone(self) -> dict: """Return the current timezone.""" - return self.modules["time"].timezone + return cast(Time, self.modules["time"]).timezone async def get_time(self) -> datetime | None: """Return current time from the device, if available.""" _LOGGER.warning( "Use `time` property instead, this call will be removed in the future." ) - return await self.modules["time"].get_time() + return await cast(Time, self.modules["time"]).get_time() async def get_timezone(self) -> dict: """Return timezone information.""" _LOGGER.warning( "Use `timezone` property instead, this call will be removed in the future." ) - return await self.modules["time"].get_timezone() + return await cast(Time, self.modules["time"]).get_timezone() @property # type: ignore @requires_update @@ -520,31 +526,31 @@ class IotDevice(Device): """ return await self._query_helper("system", "set_mac_addr", {"mac": mac}) - @property # type: ignore + @property @requires_update def emeter_realtime(self) -> EmeterStatus: """Return current energy readings.""" self._verify_emeter() - return EmeterStatus(self.modules["emeter"].realtime) + return EmeterStatus(cast(Emeter, self.modules["emeter"]).realtime) async def get_emeter_realtime(self) -> EmeterStatus: """Retrieve current energy readings.""" self._verify_emeter() - return EmeterStatus(await self.modules["emeter"].get_realtime()) + return EmeterStatus(await cast(Emeter, self.modules["emeter"]).get_realtime()) - @property # type: ignore + @property @requires_update def emeter_today(self) -> float | None: """Return today's energy consumption in kWh.""" self._verify_emeter() - return self.modules["emeter"].emeter_today + return cast(Emeter, self.modules["emeter"]).emeter_today - @property # type: ignore + @property @requires_update def emeter_this_month(self) -> float | None: """Return this month's energy consumption in kWh.""" self._verify_emeter() - return self.modules["emeter"].emeter_this_month + return cast(Emeter, self.modules["emeter"]).emeter_this_month async def get_emeter_daily( self, year: int | None = None, month: int | None = None, kwh: bool = True @@ -558,7 +564,9 @@ class IotDevice(Device): :return: mapping of day of month to value """ self._verify_emeter() - return await self.modules["emeter"].get_daystat(year=year, month=month, kwh=kwh) + return await cast(Emeter, self.modules["emeter"]).get_daystat( + year=year, month=month, kwh=kwh + ) @requires_update async def get_emeter_monthly( @@ -571,13 +579,15 @@ class IotDevice(Device): :return: dict: mapping of month to value """ self._verify_emeter() - return await self.modules["emeter"].get_monthstat(year=year, kwh=kwh) + return await cast(Emeter, self.modules["emeter"]).get_monthstat( + year=year, kwh=kwh + ) @requires_update async def erase_emeter_stats(self) -> dict: """Erase energy meter statistics.""" self._verify_emeter() - return await self.modules["emeter"].erase_stats() + return await cast(Emeter, self.modules["emeter"]).erase_stats() @requires_update async def current_consumption(self) -> float: diff --git a/kasa/iot/iotstrip.py b/kasa/iot/iotstrip.py index 99f5913d..9e99a074 100755 --- a/kasa/iot/iotstrip.py +++ b/kasa/iot/iotstrip.py @@ -253,7 +253,7 @@ class IotStripPlug(IotPlug): self._last_update = parent._last_update self._set_sys_info(parent.sys_info) self._device_type = DeviceType.StripSocket - self.modules = {} + self._modules = {} self.protocol = parent.protocol # Must use the same connection as the parent self.add_module("time", Time(self, "time")) diff --git a/kasa/module.py b/kasa/module.py index ad0b5562..213a2e0a 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -4,11 +4,14 @@ from __future__ import annotations import logging from abc import ABC, abstractmethod +from typing import TYPE_CHECKING -from .device import Device from .exceptions import KasaException from .feature import Feature +if TYPE_CHECKING: + from .device import Device + _LOGGER = logging.getLogger(__name__) diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index b325614b..80528fe4 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -47,7 +47,8 @@ class SmartDevice(Device): self._components_raw: dict[str, Any] | None = None self._components: dict[str, int] = {} self._state_information: dict[str, Any] = {} - self.modules: dict[str, SmartModule] = {} + self._modules: dict[str, SmartModule] = {} + self._exposes_child_modules = False self._parent: SmartDevice | None = None self._children: Mapping[str, SmartDevice] = {} self._last_update = {} @@ -84,11 +85,13 @@ class SmartDevice(Device): @property def children(self) -> Sequence[SmartDevice]: """Return list of children.""" - # Wall switches with children report all modules on the parent only - if self.device_type == DeviceType.WallSwitch: - return [] return list(self._children.values()) + @property + def modules(self) -> dict[str, SmartModule]: + """Return the device modules.""" + return self._modules + def _try_get_response(self, responses: dict, request: str, default=None) -> dict: response = responses.get(request) if isinstance(response, SmartErrorCode): @@ -148,7 +151,7 @@ class SmartDevice(Device): req: dict[str, Any] = {} # TODO: this could be optimized by constructing the query only once - for module in self.modules.values(): + for module in self._modules.values(): req.update(module.query()) self._last_update = resp = await self.protocol.query(req) @@ -174,19 +177,24 @@ class SmartDevice(Device): # Some wall switches (like ks240) are internally presented as having child # devices which report the child's components on the parent's sysinfo, even # when they need to be accessed through the children. - # The logic below ensures that such devices report all but whitelisted, the - # child modules at the parent level to create an illusion of a single device. + # The logic below ensures that such devices add all but whitelisted, only on + # the child device. + skip_parent_only_modules = False + child_modules_to_skip = {} if self._parent and self._parent.device_type == DeviceType.WallSwitch: - modules = self._parent.modules skip_parent_only_modules = True - else: - modules = self.modules - skip_parent_only_modules = False + elif self._children and self.device_type == DeviceType.WallSwitch: + # _initialize_modules is called on the parent after the children + self._exposes_child_modules = True + for child in self._children.values(): + child_modules_to_skip.update(**child.modules) for mod in SmartModule.REGISTERED_MODULES.values(): _LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT) - if skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES: + if ( + skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES + ) or mod.__name__ in child_modules_to_skip: continue if mod.REQUIRED_COMPONENT in self._components: _LOGGER.debug( @@ -195,8 +203,11 @@ class SmartDevice(Device): mod.__name__, ) module = mod(self, mod.REQUIRED_COMPONENT) - if module.name not in modules and await module._check_supported(): - modules[module.name] = module + if await module._check_supported(): + self._modules[module.name] = module + + if self._exposes_child_modules: + self._modules.update(**child_modules_to_skip) async def _initialize_features(self): """Initialize device features.""" @@ -278,7 +289,7 @@ class SmartDevice(Device): ) ) - for module in self.modules.values(): + for module in self._modules.values(): for feat in module._module_features.values(): self._add_feature(feat) diff --git a/kasa/tests/smart/modules/test_fan.py b/kasa/tests/smart/modules/test_fan.py index 559ffefe..41d5706c 100644 --- a/kasa/tests/smart/modules/test_fan.py +++ b/kasa/tests/smart/modules/test_fan.py @@ -1,6 +1,9 @@ +from typing import cast + from pytest_mock import MockerFixture from kasa import SmartDevice +from kasa.smart.modules import FanModule from kasa.tests.device_fixtures import parametrize fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"SMART"}) @@ -9,7 +12,7 @@ fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"S @fan async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): """Test fan speed feature.""" - fan = dev.modules.get("FanModule") + fan = cast(FanModule, dev.modules.get("FanModule")) assert fan level_feature = fan._module_features["fan_speed_level"] @@ -32,7 +35,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): @fan async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): """Test sleep mode feature.""" - fan = dev.modules.get("FanModule") + fan = cast(FanModule, dev.modules.get("FanModule")) assert fan sleep_feature = fan._module_features["fan_sleep_mode"] assert isinstance(sleep_feature.value, bool) diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 037edaf9..2b39e105 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -103,22 +103,22 @@ async def test_negotiate(dev: SmartDevice, mocker: MockerFixture): async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): """Test that the regular update uses queries from all supported modules.""" # We need to have some modules initialized by now - assert dev.modules + assert dev._modules device_queries: dict[SmartDevice, dict[str, Any]] = {} - for mod in dev.modules.values(): + for mod in dev._modules.values(): device_queries.setdefault(mod._device, {}).update(mod.query()) spies = {} - for dev in device_queries: - spies[dev] = mocker.spy(dev.protocol, "query") + for device in device_queries: + spies[device] = mocker.spy(device.protocol, "query") await dev.update() - for dev in device_queries: - if device_queries[dev]: - spies[dev].assert_called_with(device_queries[dev]) + for device in device_queries: + if device_queries[device]: + spies[device].assert_called_with(device_queries[device]) else: - spies[dev].assert_not_called() + spies[device].assert_not_called() @bulb_smart