Put modules back on children for wall switches (#881)

Puts modules back on the children for `WallSwitches` (i.e. ks240) and
makes them accessible from the `modules` property on the parent.
This commit is contained in:
Steven B 2024-04-29 17:34:20 +01:00 committed by GitHub
parent 6724506fab
commit cb11b36511
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 80 additions and 47 deletions

View File

@ -39,6 +39,7 @@ from kasa.iot import (
IotStrip,
IotWallSwitch,
)
from kasa.iot.modules import Usage
from kasa.smart import SmartBulb, SmartDevice
try:
@ -829,7 +830,7 @@ async def usage(dev: Device, year, month, erase):
Daily and monthly data provided in CSV format.
"""
echo("[bold]== Usage ==[/bold]")
usage = dev.modules["usage"]
usage = cast(Usage, dev.modules["usage"])
if erase:
echo("Erasing usage statistics..")

View File

@ -15,6 +15,7 @@ from .emeterstatus import EmeterStatus
from .exceptions import KasaException
from .feature import Feature
from .iotprotocol import IotProtocol
from .module import Module
from .protocol import BaseProtocol
from .xortransport import XorTransport
@ -72,7 +73,6 @@ class Device(ABC):
self._last_update: Any = None
self._discovery_info: dict[str, Any] | None = None
self.modules: dict[str, Any] = {}
self._features: dict[str, Feature] = {}
self._parent: Device | None = None
self._children: Mapping[str, Device] = {}
@ -111,6 +111,11 @@ class Device(ABC):
"""Disconnect and close any underlying connection resources."""
await self.protocol.close()
@property
@abstractmethod
def modules(self) -> Mapping[str, Module]:
"""Return the device modules."""
@property
@abstractmethod
def is_on(self) -> bool:

View File

@ -19,7 +19,7 @@ import functools
import inspect
import logging
from datetime import datetime, timedelta
from typing import Any, Mapping, Sequence
from typing import Any, Mapping, Sequence, cast
from ..device import Device, WifiNetwork
from ..deviceconfig import DeviceConfig
@ -28,7 +28,7 @@ from ..exceptions import KasaException
from ..feature import Feature
from ..protocol import BaseProtocol
from .iotmodule import IotModule
from .modules import Emeter
from .modules import Emeter, Time
_LOGGER = logging.getLogger(__name__)
@ -189,12 +189,18 @@ class IotDevice(Device):
self._supported_modules: dict[str, IotModule] | None = None
self._legacy_features: set[str] = set()
self._children: Mapping[str, IotDevice] = {}
self._modules: dict[str, IotModule] = {}
@property
def children(self) -> Sequence[IotDevice]:
"""Return list of children."""
return list(self._children.values())
@property
def modules(self) -> dict[str, IotModule]:
"""Return the device modules."""
return self._modules
def add_module(self, name: str, module: IotModule):
"""Register a module."""
if name in self.modules:
@ -420,31 +426,31 @@ class IotDevice(Device):
"""Set the device name (alias)."""
return await self._query_helper("system", "set_dev_alias", {"alias": alias})
@property # type: ignore
@property
@requires_update
def time(self) -> datetime:
"""Return current time from the device."""
return self.modules["time"].time
return cast(Time, self.modules["time"]).time
@property # type: ignore
@property
@requires_update
def timezone(self) -> dict:
"""Return the current timezone."""
return self.modules["time"].timezone
return cast(Time, self.modules["time"]).timezone
async def get_time(self) -> datetime | None:
"""Return current time from the device, if available."""
_LOGGER.warning(
"Use `time` property instead, this call will be removed in the future."
)
return await self.modules["time"].get_time()
return await cast(Time, self.modules["time"]).get_time()
async def get_timezone(self) -> dict:
"""Return timezone information."""
_LOGGER.warning(
"Use `timezone` property instead, this call will be removed in the future."
)
return await self.modules["time"].get_timezone()
return await cast(Time, self.modules["time"]).get_timezone()
@property # type: ignore
@requires_update
@ -520,31 +526,31 @@ class IotDevice(Device):
"""
return await self._query_helper("system", "set_mac_addr", {"mac": mac})
@property # type: ignore
@property
@requires_update
def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings."""
self._verify_emeter()
return EmeterStatus(self.modules["emeter"].realtime)
return EmeterStatus(cast(Emeter, self.modules["emeter"]).realtime)
async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings."""
self._verify_emeter()
return EmeterStatus(await self.modules["emeter"].get_realtime())
return EmeterStatus(await cast(Emeter, self.modules["emeter"]).get_realtime())
@property # type: ignore
@property
@requires_update
def emeter_today(self) -> float | None:
"""Return today's energy consumption in kWh."""
self._verify_emeter()
return self.modules["emeter"].emeter_today
return cast(Emeter, self.modules["emeter"]).emeter_today
@property # type: ignore
@property
@requires_update
def emeter_this_month(self) -> float | None:
"""Return this month's energy consumption in kWh."""
self._verify_emeter()
return self.modules["emeter"].emeter_this_month
return cast(Emeter, self.modules["emeter"]).emeter_this_month
async def get_emeter_daily(
self, year: int | None = None, month: int | None = None, kwh: bool = True
@ -558,7 +564,9 @@ class IotDevice(Device):
:return: mapping of day of month to value
"""
self._verify_emeter()
return await self.modules["emeter"].get_daystat(year=year, month=month, kwh=kwh)
return await cast(Emeter, self.modules["emeter"]).get_daystat(
year=year, month=month, kwh=kwh
)
@requires_update
async def get_emeter_monthly(
@ -571,13 +579,15 @@ class IotDevice(Device):
:return: dict: mapping of month to value
"""
self._verify_emeter()
return await self.modules["emeter"].get_monthstat(year=year, kwh=kwh)
return await cast(Emeter, self.modules["emeter"]).get_monthstat(
year=year, kwh=kwh
)
@requires_update
async def erase_emeter_stats(self) -> dict:
"""Erase energy meter statistics."""
self._verify_emeter()
return await self.modules["emeter"].erase_stats()
return await cast(Emeter, self.modules["emeter"]).erase_stats()
@requires_update
async def current_consumption(self) -> float:

View File

@ -253,7 +253,7 @@ class IotStripPlug(IotPlug):
self._last_update = parent._last_update
self._set_sys_info(parent.sys_info)
self._device_type = DeviceType.StripSocket
self.modules = {}
self._modules = {}
self.protocol = parent.protocol # Must use the same connection as the parent
self.add_module("time", Time(self, "time"))

View File

@ -4,11 +4,14 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from .device import Device
from .exceptions import KasaException
from .feature import Feature
if TYPE_CHECKING:
from .device import Device
_LOGGER = logging.getLogger(__name__)

View File

@ -47,7 +47,8 @@ class SmartDevice(Device):
self._components_raw: dict[str, Any] | None = None
self._components: dict[str, int] = {}
self._state_information: dict[str, Any] = {}
self.modules: dict[str, SmartModule] = {}
self._modules: dict[str, SmartModule] = {}
self._exposes_child_modules = False
self._parent: SmartDevice | None = None
self._children: Mapping[str, SmartDevice] = {}
self._last_update = {}
@ -84,11 +85,13 @@ class SmartDevice(Device):
@property
def children(self) -> Sequence[SmartDevice]:
"""Return list of children."""
# Wall switches with children report all modules on the parent only
if self.device_type == DeviceType.WallSwitch:
return []
return list(self._children.values())
@property
def modules(self) -> dict[str, SmartModule]:
"""Return the device modules."""
return self._modules
def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
response = responses.get(request)
if isinstance(response, SmartErrorCode):
@ -148,7 +151,7 @@ class SmartDevice(Device):
req: dict[str, Any] = {}
# TODO: this could be optimized by constructing the query only once
for module in self.modules.values():
for module in self._modules.values():
req.update(module.query())
self._last_update = resp = await self.protocol.query(req)
@ -174,19 +177,24 @@ class SmartDevice(Device):
# Some wall switches (like ks240) are internally presented as having child
# devices which report the child's components on the parent's sysinfo, even
# when they need to be accessed through the children.
# The logic below ensures that such devices report all but whitelisted, the
# child modules at the parent level to create an illusion of a single device.
# The logic below ensures that such devices add all but whitelisted, only on
# the child device.
skip_parent_only_modules = False
child_modules_to_skip = {}
if self._parent and self._parent.device_type == DeviceType.WallSwitch:
modules = self._parent.modules
skip_parent_only_modules = True
else:
modules = self.modules
skip_parent_only_modules = False
elif self._children and self.device_type == DeviceType.WallSwitch:
# _initialize_modules is called on the parent after the children
self._exposes_child_modules = True
for child in self._children.values():
child_modules_to_skip.update(**child.modules)
for mod in SmartModule.REGISTERED_MODULES.values():
_LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT)
if skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES:
if (
skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES
) or mod.__name__ in child_modules_to_skip:
continue
if mod.REQUIRED_COMPONENT in self._components:
_LOGGER.debug(
@ -195,8 +203,11 @@ class SmartDevice(Device):
mod.__name__,
)
module = mod(self, mod.REQUIRED_COMPONENT)
if module.name not in modules and await module._check_supported():
modules[module.name] = module
if await module._check_supported():
self._modules[module.name] = module
if self._exposes_child_modules:
self._modules.update(**child_modules_to_skip)
async def _initialize_features(self):
"""Initialize device features."""
@ -278,7 +289,7 @@ class SmartDevice(Device):
)
)
for module in self.modules.values():
for module in self._modules.values():
for feat in module._module_features.values():
self._add_feature(feat)

View File

@ -1,6 +1,9 @@
from typing import cast
from pytest_mock import MockerFixture
from kasa import SmartDevice
from kasa.smart.modules import FanModule
from kasa.tests.device_fixtures import parametrize
fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"SMART"})
@ -9,7 +12,7 @@ fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"S
@fan
async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed feature."""
fan = dev.modules.get("FanModule")
fan = cast(FanModule, dev.modules.get("FanModule"))
assert fan
level_feature = fan._module_features["fan_speed_level"]
@ -32,7 +35,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
@fan
async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
"""Test sleep mode feature."""
fan = dev.modules.get("FanModule")
fan = cast(FanModule, dev.modules.get("FanModule"))
assert fan
sleep_feature = fan._module_features["fan_sleep_mode"]
assert isinstance(sleep_feature.value, bool)

View File

@ -103,22 +103,22 @@ async def test_negotiate(dev: SmartDevice, mocker: MockerFixture):
async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
"""Test that the regular update uses queries from all supported modules."""
# We need to have some modules initialized by now
assert dev.modules
assert dev._modules
device_queries: dict[SmartDevice, dict[str, Any]] = {}
for mod in dev.modules.values():
for mod in dev._modules.values():
device_queries.setdefault(mod._device, {}).update(mod.query())
spies = {}
for dev in device_queries:
spies[dev] = mocker.spy(dev.protocol, "query")
for device in device_queries:
spies[device] = mocker.spy(device.protocol, "query")
await dev.update()
for dev in device_queries:
if device_queries[dev]:
spies[dev].assert_called_with(device_queries[dev])
for device in device_queries:
if device_queries[device]:
spies[device].assert_called_with(device_queries[device])
else:
spies[dev].assert_not_called()
spies[device].assert_not_called()
@bulb_smart