Put modules back on children for wall switches (#881)

Puts modules back on the children for `WallSwitches` (i.e. ks240) and
makes them accessible from the `modules` property on the parent.
This commit is contained in:
Steven B 2024-04-29 17:34:20 +01:00 committed by GitHub
parent 6724506fab
commit cb11b36511
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 80 additions and 47 deletions

View File

@ -39,6 +39,7 @@ from kasa.iot import (
IotStrip, IotStrip,
IotWallSwitch, IotWallSwitch,
) )
from kasa.iot.modules import Usage
from kasa.smart import SmartBulb, SmartDevice from kasa.smart import SmartBulb, SmartDevice
try: try:
@ -829,7 +830,7 @@ async def usage(dev: Device, year, month, erase):
Daily and monthly data provided in CSV format. Daily and monthly data provided in CSV format.
""" """
echo("[bold]== Usage ==[/bold]") echo("[bold]== Usage ==[/bold]")
usage = dev.modules["usage"] usage = cast(Usage, dev.modules["usage"])
if erase: if erase:
echo("Erasing usage statistics..") echo("Erasing usage statistics..")

View File

@ -15,6 +15,7 @@ from .emeterstatus import EmeterStatus
from .exceptions import KasaException from .exceptions import KasaException
from .feature import Feature from .feature import Feature
from .iotprotocol import IotProtocol from .iotprotocol import IotProtocol
from .module import Module
from .protocol import BaseProtocol from .protocol import BaseProtocol
from .xortransport import XorTransport from .xortransport import XorTransport
@ -72,7 +73,6 @@ class Device(ABC):
self._last_update: Any = None self._last_update: Any = None
self._discovery_info: dict[str, Any] | None = None self._discovery_info: dict[str, Any] | None = None
self.modules: dict[str, Any] = {}
self._features: dict[str, Feature] = {} self._features: dict[str, Feature] = {}
self._parent: Device | None = None self._parent: Device | None = None
self._children: Mapping[str, Device] = {} self._children: Mapping[str, Device] = {}
@ -111,6 +111,11 @@ class Device(ABC):
"""Disconnect and close any underlying connection resources.""" """Disconnect and close any underlying connection resources."""
await self.protocol.close() await self.protocol.close()
@property
@abstractmethod
def modules(self) -> Mapping[str, Module]:
"""Return the device modules."""
@property @property
@abstractmethod @abstractmethod
def is_on(self) -> bool: def is_on(self) -> bool:

View File

@ -19,7 +19,7 @@ import functools
import inspect import inspect
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Mapping, Sequence from typing import Any, Mapping, Sequence, cast
from ..device import Device, WifiNetwork from ..device import Device, WifiNetwork
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
@ -28,7 +28,7 @@ from ..exceptions import KasaException
from ..feature import Feature from ..feature import Feature
from ..protocol import BaseProtocol from ..protocol import BaseProtocol
from .iotmodule import IotModule from .iotmodule import IotModule
from .modules import Emeter from .modules import Emeter, Time
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -189,12 +189,18 @@ class IotDevice(Device):
self._supported_modules: dict[str, IotModule] | None = None self._supported_modules: dict[str, IotModule] | None = None
self._legacy_features: set[str] = set() self._legacy_features: set[str] = set()
self._children: Mapping[str, IotDevice] = {} self._children: Mapping[str, IotDevice] = {}
self._modules: dict[str, IotModule] = {}
@property @property
def children(self) -> Sequence[IotDevice]: def children(self) -> Sequence[IotDevice]:
"""Return list of children.""" """Return list of children."""
return list(self._children.values()) return list(self._children.values())
@property
def modules(self) -> dict[str, IotModule]:
"""Return the device modules."""
return self._modules
def add_module(self, name: str, module: IotModule): def add_module(self, name: str, module: IotModule):
"""Register a module.""" """Register a module."""
if name in self.modules: if name in self.modules:
@ -420,31 +426,31 @@ class IotDevice(Device):
"""Set the device name (alias).""" """Set the device name (alias)."""
return await self._query_helper("system", "set_dev_alias", {"alias": alias}) return await self._query_helper("system", "set_dev_alias", {"alias": alias})
@property # type: ignore @property
@requires_update @requires_update
def time(self) -> datetime: def time(self) -> datetime:
"""Return current time from the device.""" """Return current time from the device."""
return self.modules["time"].time return cast(Time, self.modules["time"]).time
@property # type: ignore @property
@requires_update @requires_update
def timezone(self) -> dict: def timezone(self) -> dict:
"""Return the current timezone.""" """Return the current timezone."""
return self.modules["time"].timezone return cast(Time, self.modules["time"]).timezone
async def get_time(self) -> datetime | None: async def get_time(self) -> datetime | None:
"""Return current time from the device, if available.""" """Return current time from the device, if available."""
_LOGGER.warning( _LOGGER.warning(
"Use `time` property instead, this call will be removed in the future." "Use `time` property instead, this call will be removed in the future."
) )
return await self.modules["time"].get_time() return await cast(Time, self.modules["time"]).get_time()
async def get_timezone(self) -> dict: async def get_timezone(self) -> dict:
"""Return timezone information.""" """Return timezone information."""
_LOGGER.warning( _LOGGER.warning(
"Use `timezone` property instead, this call will be removed in the future." "Use `timezone` property instead, this call will be removed in the future."
) )
return await self.modules["time"].get_timezone() return await cast(Time, self.modules["time"]).get_timezone()
@property # type: ignore @property # type: ignore
@requires_update @requires_update
@ -520,31 +526,31 @@ class IotDevice(Device):
""" """
return await self._query_helper("system", "set_mac_addr", {"mac": mac}) return await self._query_helper("system", "set_mac_addr", {"mac": mac})
@property # type: ignore @property
@requires_update @requires_update
def emeter_realtime(self) -> EmeterStatus: def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings.""" """Return current energy readings."""
self._verify_emeter() self._verify_emeter()
return EmeterStatus(self.modules["emeter"].realtime) return EmeterStatus(cast(Emeter, self.modules["emeter"]).realtime)
async def get_emeter_realtime(self) -> EmeterStatus: async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings.""" """Retrieve current energy readings."""
self._verify_emeter() self._verify_emeter()
return EmeterStatus(await self.modules["emeter"].get_realtime()) return EmeterStatus(await cast(Emeter, self.modules["emeter"]).get_realtime())
@property # type: ignore @property
@requires_update @requires_update
def emeter_today(self) -> float | None: def emeter_today(self) -> float | None:
"""Return today's energy consumption in kWh.""" """Return today's energy consumption in kWh."""
self._verify_emeter() self._verify_emeter()
return self.modules["emeter"].emeter_today return cast(Emeter, self.modules["emeter"]).emeter_today
@property # type: ignore @property
@requires_update @requires_update
def emeter_this_month(self) -> float | None: def emeter_this_month(self) -> float | None:
"""Return this month's energy consumption in kWh.""" """Return this month's energy consumption in kWh."""
self._verify_emeter() self._verify_emeter()
return self.modules["emeter"].emeter_this_month return cast(Emeter, self.modules["emeter"]).emeter_this_month
async def get_emeter_daily( async def get_emeter_daily(
self, year: int | None = None, month: int | None = None, kwh: bool = True self, year: int | None = None, month: int | None = None, kwh: bool = True
@ -558,7 +564,9 @@ class IotDevice(Device):
:return: mapping of day of month to value :return: mapping of day of month to value
""" """
self._verify_emeter() self._verify_emeter()
return await self.modules["emeter"].get_daystat(year=year, month=month, kwh=kwh) return await cast(Emeter, self.modules["emeter"]).get_daystat(
year=year, month=month, kwh=kwh
)
@requires_update @requires_update
async def get_emeter_monthly( async def get_emeter_monthly(
@ -571,13 +579,15 @@ class IotDevice(Device):
:return: dict: mapping of month to value :return: dict: mapping of month to value
""" """
self._verify_emeter() self._verify_emeter()
return await self.modules["emeter"].get_monthstat(year=year, kwh=kwh) return await cast(Emeter, self.modules["emeter"]).get_monthstat(
year=year, kwh=kwh
)
@requires_update @requires_update
async def erase_emeter_stats(self) -> dict: async def erase_emeter_stats(self) -> dict:
"""Erase energy meter statistics.""" """Erase energy meter statistics."""
self._verify_emeter() self._verify_emeter()
return await self.modules["emeter"].erase_stats() return await cast(Emeter, self.modules["emeter"]).erase_stats()
@requires_update @requires_update
async def current_consumption(self) -> float: async def current_consumption(self) -> float:

View File

@ -253,7 +253,7 @@ class IotStripPlug(IotPlug):
self._last_update = parent._last_update self._last_update = parent._last_update
self._set_sys_info(parent.sys_info) self._set_sys_info(parent.sys_info)
self._device_type = DeviceType.StripSocket self._device_type = DeviceType.StripSocket
self.modules = {} self._modules = {}
self.protocol = parent.protocol # Must use the same connection as the parent self.protocol = parent.protocol # Must use the same connection as the parent
self.add_module("time", Time(self, "time")) self.add_module("time", Time(self, "time"))

View File

@ -4,11 +4,14 @@ from __future__ import annotations
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from .device import Device
from .exceptions import KasaException from .exceptions import KasaException
from .feature import Feature from .feature import Feature
if TYPE_CHECKING:
from .device import Device
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View File

@ -47,7 +47,8 @@ class SmartDevice(Device):
self._components_raw: dict[str, Any] | None = None self._components_raw: dict[str, Any] | None = None
self._components: dict[str, int] = {} self._components: dict[str, int] = {}
self._state_information: dict[str, Any] = {} self._state_information: dict[str, Any] = {}
self.modules: dict[str, SmartModule] = {} self._modules: dict[str, SmartModule] = {}
self._exposes_child_modules = False
self._parent: SmartDevice | None = None self._parent: SmartDevice | None = None
self._children: Mapping[str, SmartDevice] = {} self._children: Mapping[str, SmartDevice] = {}
self._last_update = {} self._last_update = {}
@ -84,11 +85,13 @@ class SmartDevice(Device):
@property @property
def children(self) -> Sequence[SmartDevice]: def children(self) -> Sequence[SmartDevice]:
"""Return list of children.""" """Return list of children."""
# Wall switches with children report all modules on the parent only
if self.device_type == DeviceType.WallSwitch:
return []
return list(self._children.values()) return list(self._children.values())
@property
def modules(self) -> dict[str, SmartModule]:
"""Return the device modules."""
return self._modules
def _try_get_response(self, responses: dict, request: str, default=None) -> dict: def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
response = responses.get(request) response = responses.get(request)
if isinstance(response, SmartErrorCode): if isinstance(response, SmartErrorCode):
@ -148,7 +151,7 @@ class SmartDevice(Device):
req: dict[str, Any] = {} req: dict[str, Any] = {}
# TODO: this could be optimized by constructing the query only once # TODO: this could be optimized by constructing the query only once
for module in self.modules.values(): for module in self._modules.values():
req.update(module.query()) req.update(module.query())
self._last_update = resp = await self.protocol.query(req) self._last_update = resp = await self.protocol.query(req)
@ -174,19 +177,24 @@ class SmartDevice(Device):
# Some wall switches (like ks240) are internally presented as having child # Some wall switches (like ks240) are internally presented as having child
# devices which report the child's components on the parent's sysinfo, even # devices which report the child's components on the parent's sysinfo, even
# when they need to be accessed through the children. # when they need to be accessed through the children.
# The logic below ensures that such devices report all but whitelisted, the # The logic below ensures that such devices add all but whitelisted, only on
# child modules at the parent level to create an illusion of a single device. # the child device.
skip_parent_only_modules = False
child_modules_to_skip = {}
if self._parent and self._parent.device_type == DeviceType.WallSwitch: if self._parent and self._parent.device_type == DeviceType.WallSwitch:
modules = self._parent.modules
skip_parent_only_modules = True skip_parent_only_modules = True
else: elif self._children and self.device_type == DeviceType.WallSwitch:
modules = self.modules # _initialize_modules is called on the parent after the children
skip_parent_only_modules = False self._exposes_child_modules = True
for child in self._children.values():
child_modules_to_skip.update(**child.modules)
for mod in SmartModule.REGISTERED_MODULES.values(): for mod in SmartModule.REGISTERED_MODULES.values():
_LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT) _LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT)
if skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES: if (
skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES
) or mod.__name__ in child_modules_to_skip:
continue continue
if mod.REQUIRED_COMPONENT in self._components: if mod.REQUIRED_COMPONENT in self._components:
_LOGGER.debug( _LOGGER.debug(
@ -195,8 +203,11 @@ class SmartDevice(Device):
mod.__name__, mod.__name__,
) )
module = mod(self, mod.REQUIRED_COMPONENT) module = mod(self, mod.REQUIRED_COMPONENT)
if module.name not in modules and await module._check_supported(): if await module._check_supported():
modules[module.name] = module self._modules[module.name] = module
if self._exposes_child_modules:
self._modules.update(**child_modules_to_skip)
async def _initialize_features(self): async def _initialize_features(self):
"""Initialize device features.""" """Initialize device features."""
@ -278,7 +289,7 @@ class SmartDevice(Device):
) )
) )
for module in self.modules.values(): for module in self._modules.values():
for feat in module._module_features.values(): for feat in module._module_features.values():
self._add_feature(feat) self._add_feature(feat)

View File

@ -1,6 +1,9 @@
from typing import cast
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from kasa import SmartDevice from kasa import SmartDevice
from kasa.smart.modules import FanModule
from kasa.tests.device_fixtures import parametrize from kasa.tests.device_fixtures import parametrize
fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"SMART"}) fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"SMART"})
@ -9,7 +12,7 @@ fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"S
@fan @fan
async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed feature.""" """Test fan speed feature."""
fan = dev.modules.get("FanModule") fan = cast(FanModule, dev.modules.get("FanModule"))
assert fan assert fan
level_feature = fan._module_features["fan_speed_level"] level_feature = fan._module_features["fan_speed_level"]
@ -32,7 +35,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
@fan @fan
async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
"""Test sleep mode feature.""" """Test sleep mode feature."""
fan = dev.modules.get("FanModule") fan = cast(FanModule, dev.modules.get("FanModule"))
assert fan assert fan
sleep_feature = fan._module_features["fan_sleep_mode"] sleep_feature = fan._module_features["fan_sleep_mode"]
assert isinstance(sleep_feature.value, bool) assert isinstance(sleep_feature.value, bool)

View File

@ -103,22 +103,22 @@ async def test_negotiate(dev: SmartDevice, mocker: MockerFixture):
async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
"""Test that the regular update uses queries from all supported modules.""" """Test that the regular update uses queries from all supported modules."""
# We need to have some modules initialized by now # We need to have some modules initialized by now
assert dev.modules assert dev._modules
device_queries: dict[SmartDevice, dict[str, Any]] = {} device_queries: dict[SmartDevice, dict[str, Any]] = {}
for mod in dev.modules.values(): for mod in dev._modules.values():
device_queries.setdefault(mod._device, {}).update(mod.query()) device_queries.setdefault(mod._device, {}).update(mod.query())
spies = {} spies = {}
for dev in device_queries: for device in device_queries:
spies[dev] = mocker.spy(dev.protocol, "query") spies[device] = mocker.spy(device.protocol, "query")
await dev.update() await dev.update()
for dev in device_queries: for device in device_queries:
if device_queries[dev]: if device_queries[device]:
spies[dev].assert_called_with(device_queries[dev]) spies[device].assert_called_with(device_queries[device])
else: else:
spies[dev].assert_not_called() spies[device].assert_not_called()
@bulb_smart @bulb_smart