diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8c0438d9..c274bb97 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,6 +21,11 @@ repos: hooks: - id: mypy additional_dependencies: [types-click] + exclude: | + (?x)^( + kasa/modulemapping\.py| + )$ + - repo: https://github.com/PyCQA/doc8 rev: 'v1.1.1' diff --git a/devtools/create_module_fixtures.py b/devtools/create_module_fixtures.py index 8372bfff..ed881a88 100644 --- a/devtools/create_module_fixtures.py +++ b/devtools/create_module_fixtures.py @@ -19,7 +19,7 @@ app = typer.Typer() def create_fixtures(dev: IotDevice, outputdir: Path): """Iterate over supported modules and create version-specific fixture files.""" for name, module in dev.modules.items(): - module_dir = outputdir / name + module_dir = outputdir / str(name) if not module_dir.exists(): module_dir.mkdir(exist_ok=True, parents=True) diff --git a/kasa/__init__.py b/kasa/__init__.py index 62d54502..e9f64c70 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -16,7 +16,7 @@ from importlib.metadata import version from typing import TYPE_CHECKING from warnings import warn -from kasa.bulb import Bulb +from kasa.bulb import Bulb, BulbPreset from kasa.credentials import Credentials from kasa.device import Device from kasa.device_type import DeviceType @@ -36,12 +36,11 @@ from kasa.exceptions import ( UnsupportedDeviceError, ) from kasa.feature import Feature -from kasa.iot.iotbulb import BulbPreset, TurnOnBehavior, TurnOnBehaviors from kasa.iotprotocol import ( IotProtocol, _deprecated_TPLinkSmartHomeProtocol, # noqa: F401 ) -from kasa.plug import Plug +from kasa.module import Module from kasa.protocol import BaseProtocol from kasa.smartprotocol import SmartProtocol @@ -62,6 +61,7 @@ __all__ = [ "Device", "Bulb", "Plug", + "Module", "KasaException", "AuthenticationError", "DeviceError", diff --git a/kasa/bulb.py b/kasa/bulb.py index 01065dc0..52a722d9 100644 --- a/kasa/bulb.py +++ b/kasa/bulb.py @@ -54,11 +54,6 @@ class Bulb(Device, ABC): def is_color(self) -> bool: """Whether the bulb supports color changes.""" - @property - @abstractmethod - def is_dimmable(self) -> bool: - """Whether the bulb supports brightness changes.""" - @property @abstractmethod def is_variable_color_temp(self) -> bool: diff --git a/kasa/device.py b/kasa/device.py index ea358a8d..8150352d 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, Mapping, Sequence, overload +from typing import TYPE_CHECKING, Any, Mapping, Sequence from .credentials import Credentials from .device_type import DeviceType @@ -15,10 +15,13 @@ from .emeterstatus import EmeterStatus from .exceptions import KasaException from .feature import Feature from .iotprotocol import IotProtocol -from .module import Module, ModuleT +from .module import Module from .protocol import BaseProtocol from .xortransport import XorTransport +if TYPE_CHECKING: + from .modulemapping import ModuleMapping + @dataclass class WifiNetwork: @@ -113,21 +116,9 @@ class Device(ABC): @property @abstractmethod - def modules(self) -> Mapping[str, Module]: + def modules(self) -> ModuleMapping[Module]: """Return the device modules.""" - @overload - @abstractmethod - def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ... - - @overload - @abstractmethod - def get_module(self, module_type: str) -> Module | None: ... - - @abstractmethod - def get_module(self, module_type: type[ModuleT] | str) -> ModuleT | Module | None: - """Return the module from the device modules or None if not present.""" - @property @abstractmethod def is_on(self) -> bool: diff --git a/kasa/interfaces/led.py b/kasa/interfaces/led.py new file mode 100644 index 00000000..2ddba00c --- /dev/null +++ b/kasa/interfaces/led.py @@ -0,0 +1,38 @@ +"""Module for base light effect module.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +from ..feature import Feature +from ..module import Module + + +class Led(Module, ABC): + """Base interface to represent a LED module.""" + + def _initialize_features(self): + """Initialize features.""" + device = self._device + self._add_feature( + Feature( + device=device, + container=self, + name="LED", + id="led", + icon="mdi:led", + attribute_getter="led", + attribute_setter="set_led", + type=Feature.Type.Switch, + category=Feature.Category.Config, + ) + ) + + @property + @abstractmethod + def led(self) -> bool: + """Return current led status.""" + + @abstractmethod + async def set_led(self, enable: bool) -> None: + """Set led.""" diff --git a/kasa/interfaces/lighteffect.py b/kasa/interfaces/lighteffect.py new file mode 100644 index 00000000..0eb11b5b --- /dev/null +++ b/kasa/interfaces/lighteffect.py @@ -0,0 +1,80 @@ +"""Module for base light effect module.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +from ..feature import Feature +from ..module import Module + + +class LightEffect(Module, ABC): + """Interface to represent a light effect module.""" + + LIGHT_EFFECTS_OFF = "Off" + + def _initialize_features(self): + """Initialize features.""" + device = self._device + self._add_feature( + Feature( + device, + id="light_effect", + name="Light effect", + container=self, + attribute_getter="effect", + attribute_setter="set_effect", + category=Feature.Category.Primary, + type=Feature.Type.Choice, + choices_getter="effect_list", + ) + ) + + @property + @abstractmethod + def has_custom_effects(self) -> bool: + """Return True if the device supports setting custom effects.""" + + @property + @abstractmethod + def effect(self) -> str: + """Return effect state or name.""" + + @property + @abstractmethod + def effect_list(self) -> list[str]: + """Return built-in effects list. + + Example: + ['Aurora', 'Bubbling Cauldron', ...] + """ + + @abstractmethod + async def set_effect( + self, + effect: str, + *, + brightness: int | None = None, + transition: int | None = None, + ) -> None: + """Set an effect on the device. + + If brightness or transition is defined, + its value will be used instead of the effect-specific default. + + See :meth:`effect_list` for available effects, + or use :meth:`set_custom_effect` for custom effects. + + :param str effect: The effect to set + :param int brightness: The wanted brightness + :param int transition: The wanted transition time + """ + + async def set_custom_effect( + self, + effect_dict: dict, + ) -> None: + """Set a custom effect on the device. + + :param str effect_dict: The custom effect dict to set + """ diff --git a/kasa/effects.py b/kasa/iot/effects.py similarity index 100% rename from kasa/effects.py rename to kasa/iot/effects.py diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 29ba3155..762fc06c 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -19,14 +19,15 @@ import functools import inspect import logging from datetime import datetime, timedelta -from typing import Any, Mapping, Sequence, cast, overload +from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast from ..device import Device, WifiNetwork from ..deviceconfig import DeviceConfig from ..emeterstatus import EmeterStatus from ..exceptions import KasaException from ..feature import Feature -from ..module import ModuleT +from ..module import Module +from ..modulemapping import ModuleMapping, ModuleName from ..protocol import BaseProtocol from .iotmodule import IotModule from .modules import Emeter, Time @@ -190,7 +191,7 @@ class IotDevice(Device): self._supported_modules: dict[str, IotModule] | None = None self._legacy_features: set[str] = set() self._children: Mapping[str, IotDevice] = {} - self._modules: dict[str, IotModule] = {} + self._modules: dict[str | ModuleName[Module], IotModule] = {} @property def children(self) -> Sequence[IotDevice]: @@ -198,38 +199,20 @@ class IotDevice(Device): return list(self._children.values()) @property - def modules(self) -> dict[str, IotModule]: + def modules(self) -> ModuleMapping[IotModule]: """Return the device modules.""" + if TYPE_CHECKING: + return cast(ModuleMapping[IotModule], self._modules) return self._modules - @overload - def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ... - - @overload - def get_module(self, module_type: str) -> IotModule | None: ... - - def get_module( - self, module_type: type[ModuleT] | str - ) -> ModuleT | IotModule | None: - """Return the module from the device modules or None if not present.""" - if isinstance(module_type, str): - module_name = module_type.lower() - elif issubclass(module_type, IotModule): - module_name = module_type.__name__.lower() - else: - return None - if module_name in self.modules: - return self.modules[module_name] - return None - - def add_module(self, name: str, module: IotModule): + def add_module(self, name: str | ModuleName[Module], module: IotModule): """Register a module.""" if name in self.modules: _LOGGER.debug("Module %s already registered, ignoring..." % name) return _LOGGER.debug("Adding module %s", module) - self.modules[name] = module + self._modules[name] = module def _create_request( self, target: str, cmd: str, arg: dict | None = None, child_ids=None @@ -291,11 +274,11 @@ class IotDevice(Device): @property # type: ignore @requires_update - def supported_modules(self) -> list[str]: + def supported_modules(self) -> list[str | ModuleName[Module]]: """Return a set of modules supported by the device.""" # TODO: this should rather be called `features`, but we don't want to break # the API now. Maybe just deprecate it and point the users to use this? - return list(self.modules.keys()) + return list(self._modules.keys()) @property # type: ignore @requires_update @@ -324,10 +307,11 @@ class IotDevice(Device): self._last_update = response self._set_sys_info(response["system"]["get_sysinfo"]) + await self._modular_update(req) + if not self._features: await self._initialize_features() - await self._modular_update(req) self._set_sys_info(self._last_update["system"]["get_sysinfo"]) async def _initialize_features(self): @@ -352,6 +336,11 @@ class IotDevice(Device): ) ) + for module in self._modules.values(): + module._initialize_features() + for module_feat in module._module_features.values(): + self._add_feature(module_feat) + async def _modular_update(self, req: dict) -> None: """Execute an update query.""" if self.has_emeter: @@ -364,17 +353,15 @@ class IotDevice(Device): # making separate handling for this unnecessary if self._supported_modules is None: supported = {} - for module in self.modules.values(): + for module in self._modules.values(): if module.is_supported: supported[module._module] = module - for module_feat in module._module_features.values(): - self._add_feature(module_feat) self._supported_modules = supported request_list = [] est_response_size = 1024 if "system" in req else 0 - for module in self.modules.values(): + for module in self._modules.values(): if not module.is_supported: _LOGGER.debug("Module %s not supported, skipping" % module) continue diff --git a/kasa/iot/iotlightstrip.py b/kasa/iot/iotlightstrip.py index 57b3282f..a120be7a 100644 --- a/kasa/iot/iotlightstrip.py +++ b/kasa/iot/iotlightstrip.py @@ -4,10 +4,12 @@ from __future__ import annotations from ..device_type import DeviceType from ..deviceconfig import DeviceConfig -from ..effects import EFFECT_MAPPING_V1, EFFECT_NAMES_V1 +from ..module import Module from ..protocol import BaseProtocol +from .effects import EFFECT_NAMES_V1 from .iotbulb import IotBulb from .iotdevice import KasaException, requires_update +from .modules.lighteffectmodule import LightEffectModule class IotLightStrip(IotBulb): @@ -54,6 +56,10 @@ class IotLightStrip(IotBulb): ) -> None: super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.LightStrip + self.add_module( + Module.LightEffect, + LightEffectModule(self, "smartlife.iot.lighting_effect"), + ) @property # type: ignore @requires_update @@ -73,6 +79,8 @@ class IotLightStrip(IotBulb): 'id': '', 'name': ''} """ + # LightEffectModule returns the current effect name + # so return the dict here for backwards compatibility return self.sys_info["lighting_effect_state"] @property # type: ignore @@ -83,6 +91,8 @@ class IotLightStrip(IotBulb): Example: ['Aurora', 'Bubbling Cauldron', ...] """ + # LightEffectModule returns effect names along with a LIGHT_EFFECTS_OFF value + # so return the original effect names here for backwards compatibility return EFFECT_NAMES_V1 if self.has_effects else None @requires_update @@ -105,15 +115,9 @@ class IotLightStrip(IotBulb): :param int brightness: The wanted brightness :param int transition: The wanted transition time """ - if effect not in EFFECT_MAPPING_V1: - raise KasaException(f"The effect {effect} is not a built in effect.") - effect_dict = EFFECT_MAPPING_V1[effect] - if brightness is not None: - effect_dict["brightness"] = brightness - if transition is not None: - effect_dict["transition"] = transition - - await self.set_custom_effect(effect_dict) + await self.modules[Module.LightEffect].set_effect( + effect, brightness=brightness, transition=transition + ) @requires_update async def set_custom_effect( @@ -126,8 +130,4 @@ class IotLightStrip(IotBulb): """ if not self.has_effects: raise KasaException("Bulb does not support effects.") - await self._query_helper( - "smartlife.iot.lighting_effect", - "set_lighting_effect", - effect_dict, - ) + await self.modules[Module.LightEffect].set_custom_effect(effect_dict) diff --git a/kasa/iot/iotmodule.py b/kasa/iot/iotmodule.py index d8fb4812..ca0c3adb 100644 --- a/kasa/iot/iotmodule.py +++ b/kasa/iot/iotmodule.py @@ -43,13 +43,19 @@ class IotModule(Module): @property def data(self): """Return the module specific raw data from the last update.""" - if self._module not in self._device._last_update: + dev = self._device + q = self.query() + + if not q: + return dev.sys_info + + if self._module not in dev._last_update: raise KasaException( f"You need to call update() prior accessing module data" f" for '{self._module}'" ) - return self._device._last_update[self._module] + return dev._last_update[self._module] @property def is_supported(self) -> bool: diff --git a/kasa/iot/iotplug.py b/kasa/iot/iotplug.py index dadb38f2..22238c7a 100644 --- a/kasa/iot/iotplug.py +++ b/kasa/iot/iotplug.py @@ -6,10 +6,10 @@ import logging from ..device_type import DeviceType from ..deviceconfig import DeviceConfig -from ..feature import Feature +from ..module import Module from ..protocol import BaseProtocol from .iotdevice import IotDevice, requires_update -from .modules import Antitheft, Cloud, Schedule, Time, Usage +from .modules import Antitheft, Cloud, LedModule, Schedule, Time, Usage _LOGGER = logging.getLogger(__name__) @@ -58,21 +58,7 @@ class IotPlug(IotDevice): self.add_module("antitheft", Antitheft(self, "anti_theft")) self.add_module("time", Time(self, "time")) self.add_module("cloud", Cloud(self, "cnCloud")) - - async def _initialize_features(self): - await super()._initialize_features() - - self._add_feature( - Feature( - device=self, - id="led", - name="LED", - icon="mdi:led-{state}", - attribute_getter="led", - attribute_setter="set_led", - type=Feature.Type.Switch, - ) - ) + self.add_module(Module.Led, LedModule(self, "system")) @property # type: ignore @requires_update @@ -93,14 +79,11 @@ class IotPlug(IotDevice): @requires_update def led(self) -> bool: """Return the state of the led.""" - sys_info = self.sys_info - return bool(1 - sys_info["led_off"]) + return self.modules[Module.Led].led async def set_led(self, state: bool): """Set the state of the led (night mode).""" - return await self._query_helper( - "system", "set_led_off", {"off": int(not state)} - ) + return await self.modules[Module.Led].set_led(state) class IotWallSwitch(IotPlug): diff --git a/kasa/iot/iotstrip.py b/kasa/iot/iotstrip.py index 9e99a074..ab14abb0 100755 --- a/kasa/iot/iotstrip.py +++ b/kasa/iot/iotstrip.py @@ -253,7 +253,6 @@ class IotStripPlug(IotPlug): self._last_update = parent._last_update self._set_sys_info(parent.sys_info) self._device_type = DeviceType.StripSocket - self._modules = {} self.protocol = parent.protocol # Must use the same connection as the parent self.add_module("time", Time(self, "time")) diff --git a/kasa/iot/modules/__init__.py b/kasa/iot/modules/__init__.py index 41e03bbd..f061e607 100644 --- a/kasa/iot/modules/__init__.py +++ b/kasa/iot/modules/__init__.py @@ -5,6 +5,7 @@ from .antitheft import Antitheft from .cloud import Cloud from .countdown import Countdown from .emeter import Emeter +from .ledmodule import LedModule from .motion import Motion from .rulemodule import Rule, RuleModule from .schedule import Schedule @@ -17,6 +18,7 @@ __all__ = [ "Cloud", "Countdown", "Emeter", + "LedModule", "Motion", "Rule", "RuleModule", diff --git a/kasa/iot/modules/ledmodule.py b/kasa/iot/modules/ledmodule.py new file mode 100644 index 00000000..6b3c6194 --- /dev/null +++ b/kasa/iot/modules/ledmodule.py @@ -0,0 +1,32 @@ +"""Module for led controls.""" + +from __future__ import annotations + +from ...interfaces.led import Led +from ..iotmodule import IotModule + + +class LedModule(IotModule, Led): + """Implementation of led controls.""" + + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + + @property + def mode(self): + """LED mode setting. + + "always", "never" + """ + return "always" if self.led else "never" + + @property + def led(self) -> bool: + """Return the state of the led.""" + sys_info = self.data + return bool(1 - sys_info["led_off"]) + + async def set_led(self, state: bool): + """Set the state of the led (night mode).""" + return await self.call("set_led_off", {"off": int(not state)}) diff --git a/kasa/iot/modules/lighteffectmodule.py b/kasa/iot/modules/lighteffectmodule.py new file mode 100644 index 00000000..c53de192 --- /dev/null +++ b/kasa/iot/modules/lighteffectmodule.py @@ -0,0 +1,97 @@ +"""Module for light effects.""" + +from __future__ import annotations + +from ...interfaces.lighteffect import LightEffect +from ..effects import EFFECT_MAPPING_V1, EFFECT_NAMES_V1 +from ..iotmodule import IotModule + + +class LightEffectModule(IotModule, LightEffect): + """Implementation of dynamic light effects.""" + + @property + def effect(self) -> str: + """Return effect state. + + Example: + {'brightness': 50, + 'custom': 0, + 'enable': 0, + 'id': '', + 'name': ''} + """ + if ( + (state := self.data.get("lighting_effect_state")) + and state.get("enable") + and (name := state.get("name")) + and name in EFFECT_NAMES_V1 + ): + return name + return self.LIGHT_EFFECTS_OFF + + @property + def effect_list(self) -> list[str]: + """Return built-in effects list. + + Example: + ['Aurora', 'Bubbling Cauldron', ...] + """ + effect_list = [self.LIGHT_EFFECTS_OFF] + effect_list.extend(EFFECT_NAMES_V1) + return effect_list + + async def set_effect( + self, + effect: str, + *, + brightness: int | None = None, + transition: int | None = None, + ) -> None: + """Set an effect on the device. + + If brightness or transition is defined, + its value will be used instead of the effect-specific default. + + See :meth:`effect_list` for available effects, + or use :meth:`set_custom_effect` for custom effects. + + :param str effect: The effect to set + :param int brightness: The wanted brightness + :param int transition: The wanted transition time + """ + if effect == self.LIGHT_EFFECTS_OFF: + effect_dict = dict(self.data["lighting_effect_state"]) + effect_dict["enable"] = 0 + elif effect not in EFFECT_MAPPING_V1: + raise ValueError(f"The effect {effect} is not a built in effect.") + else: + effect_dict = EFFECT_MAPPING_V1[effect] + if brightness is not None: + effect_dict["brightness"] = brightness + if transition is not None: + effect_dict["transition"] = transition + + await self.set_custom_effect(effect_dict) + + async def set_custom_effect( + self, + effect_dict: dict, + ) -> None: + """Set a custom effect on the device. + + :param str effect_dict: The custom effect dict to set + """ + return await self.call( + "set_lighting_effect", + effect_dict, + ) + + @property + def has_custom_effects(self) -> bool: + """Return True if the device supports setting custom effects.""" + return True + + def query(self): + """Return the base query.""" + return {} diff --git a/kasa/module.py b/kasa/module.py index 3da0c1ad..b65f0499 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -6,14 +6,20 @@ import logging from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, + Final, TypeVar, ) from .exceptions import KasaException from .feature import Feature +from .modulemapping import ModuleName if TYPE_CHECKING: - from .device import Device + from .device import Device as DeviceType # avoid name clash with Device module + from .interfaces.led import Led + from .interfaces.lighteffect import LightEffect + from .iot import modules as iot + from .smart import modules as smart _LOGGER = logging.getLogger(__name__) @@ -27,7 +33,59 @@ class Module(ABC): executed during the regular update cycle. """ - def __init__(self, device: Device, module: str): + # Common Modules + LightEffect: Final[ModuleName[LightEffect]] = ModuleName("LightEffectModule") + Led: Final[ModuleName[Led]] = ModuleName("LedModule") + + # IOT only Modules + IotAmbientLight: Final[ModuleName[iot.AmbientLight]] = ModuleName("ambient") + IotAntitheft: Final[ModuleName[iot.Antitheft]] = ModuleName("anti_theft") + IotCountdown: Final[ModuleName[iot.Countdown]] = ModuleName("countdown") + IotEmeter: Final[ModuleName[iot.Emeter]] = ModuleName("emeter") + IotMotion: Final[ModuleName[iot.Motion]] = ModuleName("motion") + IotSchedule: Final[ModuleName[iot.Schedule]] = ModuleName("schedule") + IotUsage: Final[ModuleName[iot.Usage]] = ModuleName("usage") + IotCloud: Final[ModuleName[iot.Cloud]] = ModuleName("cloud") + IotTime: Final[ModuleName[iot.Time]] = ModuleName("time") + + # SMART only Modules + Alarm: Final[ModuleName[smart.AlarmModule]] = ModuleName("AlarmModule") + AutoOff: Final[ModuleName[smart.AutoOffModule]] = ModuleName("AutoOffModule") + BatterySensor: Final[ModuleName[smart.BatterySensor]] = ModuleName("BatterySensor") + Brightness: Final[ModuleName[smart.Brightness]] = ModuleName("Brightness") + ChildDevice: Final[ModuleName[smart.ChildDeviceModule]] = ModuleName( + "ChildDeviceModule" + ) + Cloud: Final[ModuleName[smart.CloudModule]] = ModuleName("CloudModule") + Color: Final[ModuleName[smart.ColorModule]] = ModuleName("ColorModule") + ColorTemp: Final[ModuleName[smart.ColorTemperatureModule]] = ModuleName( + "ColorTemperatureModule" + ) + ContactSensor: Final[ModuleName[smart.ContactSensor]] = ModuleName("ContactSensor") + Device: Final[ModuleName[smart.DeviceModule]] = ModuleName("DeviceModule") + Energy: Final[ModuleName[smart.EnergyModule]] = ModuleName("EnergyModule") + Fan: Final[ModuleName[smart.FanModule]] = ModuleName("FanModule") + Firmware: Final[ModuleName[smart.Firmware]] = ModuleName("Firmware") + FrostProtection: Final[ModuleName[smart.FrostProtectionModule]] = ModuleName( + "FrostProtectionModule" + ) + Humidity: Final[ModuleName[smart.HumiditySensor]] = ModuleName("HumiditySensor") + LightTransition: Final[ModuleName[smart.LightTransitionModule]] = ModuleName( + "LightTransitionModule" + ) + Report: Final[ModuleName[smart.ReportModule]] = ModuleName("ReportModule") + Temperature: Final[ModuleName[smart.TemperatureSensor]] = ModuleName( + "TemperatureSensor" + ) + TemperatureSensor: Final[ModuleName[smart.TemperatureControl]] = ModuleName( + "TemperatureControl" + ) + Time: Final[ModuleName[smart.TimeModule]] = ModuleName("TimeModule") + WaterleakSensor: Final[ModuleName[smart.WaterleakSensor]] = ModuleName( + "WaterleakSensor" + ) + + def __init__(self, device: DeviceType, module: str): self._device = device self._module = module self._module_features: dict[str, Feature] = {} diff --git a/kasa/modulemapping.py b/kasa/modulemapping.py new file mode 100644 index 00000000..06ba8619 --- /dev/null +++ b/kasa/modulemapping.py @@ -0,0 +1,25 @@ +"""Module for Implementation for ModuleMapping and ModuleName types. + +Custom dict for getting typed modules from the module dict. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + from .module import Module + +_ModuleT = TypeVar("_ModuleT", bound="Module") + + +class ModuleName(str, Generic[_ModuleT]): + """Generic Module name type. + + At runtime this is a generic subclass of str. + """ + + __slots__ = () + + +ModuleMapping = dict diff --git a/kasa/modulemapping.pyi b/kasa/modulemapping.pyi new file mode 100644 index 00000000..8d110d39 --- /dev/null +++ b/kasa/modulemapping.pyi @@ -0,0 +1,96 @@ +"""Typing stub file for ModuleMapping.""" + +from abc import ABCMeta +from collections.abc import Mapping +from typing import Generic, TypeVar, overload + +from .module import Module + +__all__ = [ + "ModuleMapping", + "ModuleName", +] + +_ModuleT = TypeVar("_ModuleT", bound=Module, covariant=True) +_ModuleBaseT = TypeVar("_ModuleBaseT", bound=Module, covariant=True) + +class ModuleName(Generic[_ModuleT]): + """Class for typed Module names. At runtime delegated to str.""" + + def __init__(self, value: str, /) -> None: ... + def __len__(self) -> int: ... + def __hash__(self) -> int: ... + def __eq__(self, other: object) -> bool: ... + def __getitem__(self, index: int) -> str: ... + +class ModuleMapping( + Mapping[ModuleName[_ModuleBaseT] | str, _ModuleBaseT], metaclass=ABCMeta +): + """Custom dict type to provide better value type hints for Module key types.""" + + @overload + def __getitem__(self, key: ModuleName[_ModuleT], /) -> _ModuleT: ... + @overload + def __getitem__(self, key: str, /) -> _ModuleBaseT: ... + @overload + def __getitem__( + self, key: ModuleName[_ModuleT] | str, / + ) -> _ModuleT | _ModuleBaseT: ... + @overload # type: ignore[override] + def get(self, key: ModuleName[_ModuleT], /) -> _ModuleT | None: ... + @overload + def get(self, key: str, /) -> _ModuleBaseT | None: ... + @overload + def get( + self, key: ModuleName[_ModuleT] | str, / + ) -> _ModuleT | _ModuleBaseT | None: ... + +def _test_module_mapping_typing() -> None: + """Test ModuleMapping overloads work as intended. + + This is tested during the mypy run and needs to be in this file. + """ + from typing import Any, NewType, cast + + from typing_extensions import assert_type + + from .iot.iotmodule import IotModule + from .module import Module + from .smart.smartmodule import SmartModule + + NewCommonModule = NewType("NewCommonModule", Module) + NewIotModule = NewType("NewIotModule", IotModule) + NewSmartModule = NewType("NewSmartModule", SmartModule) + NotModule = NewType("NotModule", list) + + NEW_COMMON_MODULE: ModuleName[NewCommonModule] = ModuleName("NewCommonModule") + NEW_IOT_MODULE: ModuleName[NewIotModule] = ModuleName("NewIotModule") + NEW_SMART_MODULE: ModuleName[NewSmartModule] = ModuleName("NewSmartModule") + + # TODO Enable --warn-unused-ignores + NOT_MODULE: ModuleName[NotModule] = ModuleName("NotModule") # type: ignore[type-var] # noqa: F841 + NOT_MODULE_2 = ModuleName[NotModule]("NotModule2") # type: ignore[type-var] # noqa: F841 + + device_modules: ModuleMapping[Module] = cast(ModuleMapping[Module], {}) + assert_type(device_modules[NEW_COMMON_MODULE], NewCommonModule) + assert_type(device_modules[NEW_IOT_MODULE], NewIotModule) + assert_type(device_modules[NEW_SMART_MODULE], NewSmartModule) + assert_type(device_modules["foobar"], Module) + assert_type(device_modules[3], Any) # type: ignore[call-overload] + + assert_type(device_modules.get(NEW_COMMON_MODULE), NewCommonModule | None) + assert_type(device_modules.get(NEW_IOT_MODULE), NewIotModule | None) + assert_type(device_modules.get(NEW_SMART_MODULE), NewSmartModule | None) + assert_type(device_modules.get(NEW_COMMON_MODULE, default=[1, 2]), Any) # type: ignore[call-overload] + + iot_modules: ModuleMapping[IotModule] = cast(ModuleMapping[IotModule], {}) + smart_modules: ModuleMapping[SmartModule] = cast(ModuleMapping[SmartModule], {}) + + assert_type(smart_modules["foobar"], SmartModule) + assert_type(iot_modules["foobar"], IotModule) + + # Test for covariance + device_modules_2: ModuleMapping[Module] = iot_modules # noqa: F841 + device_modules_3: ModuleMapping[Module] = smart_modules # noqa: F841 + NEW_MODULE: ModuleName[Module] = NEW_SMART_MODULE # noqa: F841 + NEW_MODULE_2: ModuleName[Module] = NEW_IOT_MODULE # noqa: F841 diff --git a/kasa/plug.py b/kasa/plug.py deleted file mode 100644 index 00796d1c..00000000 --- a/kasa/plug.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Module for a TAPO Plug.""" - -import logging -from abc import ABC - -from .device import Device - -_LOGGER = logging.getLogger(__name__) - - -class Plug(Device, ABC): - """Base class to represent a Plug.""" diff --git a/kasa/smart/modules/ledmodule.py b/kasa/smart/modules/ledmodule.py index e3113159..587be51c 100644 --- a/kasa/smart/modules/ledmodule.py +++ b/kasa/smart/modules/ledmodule.py @@ -2,37 +2,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING - -from ...feature import Feature +from ...interfaces.led import Led from ..smartmodule import SmartModule -if TYPE_CHECKING: - from ..smartdevice import SmartDevice - -class LedModule(SmartModule): +class LedModule(SmartModule, Led): """Implementation of led controls.""" REQUIRED_COMPONENT = "led" QUERY_GETTER_NAME = "get_led_info" - def __init__(self, device: SmartDevice, module: str): - super().__init__(device, module) - self._add_feature( - Feature( - device=device, - container=self, - id="led", - name="LED", - icon="mdi:led-{state}", - attribute_getter="led", - attribute_setter="set_led", - type=Feature.Type.Switch, - category=Feature.Category.Config, - ) - ) - def query(self) -> dict: """Query to execute during the update cycle.""" return {self.QUERY_GETTER_NAME: {"led_rule": None}} @@ -56,7 +35,7 @@ class LedModule(SmartModule): This should probably be a select with always/never/nightmode. """ rule = "always" if enable else "never" - return await self.call("set_led_info", self.data | {"led_rule": rule}) + return await self.call("set_led_info", dict(self.data, **{"led_rule": rule})) @property def night_mode_settings(self): diff --git a/kasa/smart/modules/lighteffectmodule.py b/kasa/smart/modules/lighteffectmodule.py index bd0eea0a..a06e979a 100644 --- a/kasa/smart/modules/lighteffectmodule.py +++ b/kasa/smart/modules/lighteffectmodule.py @@ -6,14 +6,14 @@ import base64 import copy from typing import TYPE_CHECKING, Any -from ...feature import Feature +from ...interfaces.lighteffect import LightEffect from ..smartmodule import SmartModule if TYPE_CHECKING: from ..smartdevice import SmartDevice -class LightEffectModule(SmartModule): +class LightEffectModule(SmartModule, LightEffect): """Implementation of dynamic light effects.""" REQUIRED_COMPONENT = "light_effect" @@ -22,29 +22,11 @@ class LightEffectModule(SmartModule): "L1": "Party", "L2": "Relax", } - LIGHT_EFFECTS_OFF = "Off" def __init__(self, device: SmartDevice, module: str): super().__init__(device, module) self._scenes_names_to_id: dict[str, str] = {} - def _initialize_features(self): - """Initialize features.""" - device = self._device - self._add_feature( - Feature( - device, - id="light_effect", - name="Light effect", - container=self, - attribute_getter="effect", - attribute_setter="set_effect", - category=Feature.Category.Config, - type=Feature.Type.Choice, - choices_getter="effect_list", - ) - ) - def _initialize_effects(self) -> dict[str, dict[str, Any]]: """Return built-in effects.""" # Copy the effects so scene name updates do not update the underlying dict. @@ -64,7 +46,7 @@ class LightEffectModule(SmartModule): return effects @property - def effect_list(self) -> list[str] | None: + def effect_list(self) -> list[str]: """Return built-in effects list. Example: @@ -90,6 +72,9 @@ class LightEffectModule(SmartModule): async def set_effect( self, effect: str, + *, + brightness: int | None = None, + transition: int | None = None, ) -> None: """Set an effect for the device. @@ -108,6 +93,24 @@ class LightEffectModule(SmartModule): params["id"] = effect_id return await self.call("set_dynamic_light_effect_rule_enable", params) + async def set_custom_effect( + self, + effect_dict: dict, + ) -> None: + """Set a custom effect on the device. + + :param str effect_dict: The custom effect dict to set + """ + raise NotImplementedError( + "Device does not support setting custom effects. " + "Use has_custom_effects to check for support." + ) + + @property + def has_custom_effects(self) -> bool: + """Return True if the device supports setting custom effects.""" + return False + def query(self) -> dict: """Query to execute during the update cycle.""" return {self.QUERY_GETTER_NAME: {"start_index": 0}} diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 68b08902..194e7c17 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -5,7 +5,7 @@ from __future__ import annotations import base64 import logging from datetime import datetime, timedelta -from typing import Any, Mapping, Sequence, cast, overload +from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast from ..aestransport import AesTransport from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange @@ -16,7 +16,8 @@ from ..emeterstatus import EmeterStatus from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode from ..fan import Fan from ..feature import Feature -from ..module import ModuleT +from ..module import Module +from ..modulemapping import ModuleMapping, ModuleName from ..smartprotocol import SmartProtocol from .modules import ( Brightness, @@ -61,7 +62,7 @@ class SmartDevice(Bulb, Fan, Device): self._components_raw: dict[str, Any] | None = None self._components: dict[str, int] = {} self._state_information: dict[str, Any] = {} - self._modules: dict[str, SmartModule] = {} + self._modules: dict[str | ModuleName[Module], SmartModule] = {} self._exposes_child_modules = False self._parent: SmartDevice | None = None self._children: Mapping[str, SmartDevice] = {} @@ -102,8 +103,20 @@ class SmartDevice(Bulb, Fan, Device): return list(self._children.values()) @property - def modules(self) -> dict[str, SmartModule]: + def modules(self) -> ModuleMapping[SmartModule]: """Return the device modules.""" + if self._exposes_child_modules: + modules = {k: v for k, v in self._modules.items()} + for child in self._children.values(): + for k, v in child._modules.items(): + if k not in modules: + modules[k] = v + if TYPE_CHECKING: + return cast(ModuleMapping[SmartModule], modules) + return modules + + if TYPE_CHECKING: # Needed for python 3.8 + return cast(ModuleMapping[SmartModule], self._modules) return self._modules def _try_get_response(self, responses: dict, request: str, default=None) -> dict: @@ -315,30 +328,6 @@ class SmartDevice(Bulb, Fan, Device): for feat in module._module_features.values(): self._add_feature(feat) - @overload - def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ... - - @overload - def get_module(self, module_type: str) -> SmartModule | None: ... - - def get_module( - self, module_type: type[ModuleT] | str - ) -> ModuleT | SmartModule | None: - """Return the module from the device modules or None if not present.""" - if isinstance(module_type, str): - module_name = module_type - elif issubclass(module_type, SmartModule): - module_name = module_type.__name__ - else: - return None - if module_name in self.modules: - return self.modules[module_name] - elif self._exposes_child_modules: - for child in self._children.values(): - if module_name in child.modules: - return child.modules[module_name] - return None - @property def is_cloud_connected(self): """Returns if the device is connected to the cloud.""" diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index 5ca4a8ae..7c73c71e 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -189,6 +189,11 @@ class FakeSmartTransport(BaseTransport): if "current_rule_id" in info["get_dynamic_light_effect_rules"]: del info["get_dynamic_light_effect_rules"]["current_rule_id"] + def _set_led_info(self, info, params): + """Set or remove values as per the device behaviour.""" + info["get_led_info"]["led_status"] = params["led_rule"] != "never" + info["get_led_info"]["led_rule"] = params["led_rule"] + def _send_request(self, request_dict: dict): method = request_dict["method"] params = request_dict["params"] @@ -218,7 +223,9 @@ class FakeSmartTransport(BaseTransport): # SMART fixtures started to be generated missing_result := self.FIXTURE_MISSING_MAP.get(method) ) and missing_result[0] in self.components: - result = copy.deepcopy(missing_result[1]) + # Copy to info so it will work with update methods + info[method] = copy.deepcopy(missing_result[1]) + result = copy.deepcopy(info[method]) retval = {"result": result, "error_code": 0} else: # PARAMS error returned for KS240 when get_device_usage called @@ -239,6 +246,9 @@ class FakeSmartTransport(BaseTransport): elif method == "set_dynamic_light_effect_rule_enable": self._set_light_effect(info, params) return {"error_code": 0} + elif method == "set_led_info": + self._set_led_info(info, params) + return {"error_code": 0} elif method[:4] == "set_": target_method = f"get_{method[4:]}" info[target_method].update(params) diff --git a/kasa/tests/smart/features/test_brightness.py b/kasa/tests/smart/features/test_brightness.py index 02a396aa..3c00a4d1 100644 --- a/kasa/tests/smart/features/test_brightness.py +++ b/kasa/tests/smart/features/test_brightness.py @@ -10,7 +10,7 @@ brightness = parametrize("brightness smart", component_filter="brightness") @brightness async def test_brightness_component(dev: SmartDevice): """Test brightness feature.""" - brightness = dev.get_module("Brightness") + brightness = dev.modules.get("Brightness") assert brightness assert isinstance(dev, SmartDevice) assert "brightness" in dev._components diff --git a/kasa/tests/smart/modules/test_contact.py b/kasa/tests/smart/modules/test_contact.py index fc337545..88677c58 100644 --- a/kasa/tests/smart/modules/test_contact.py +++ b/kasa/tests/smart/modules/test_contact.py @@ -1,7 +1,6 @@ import pytest -from kasa import SmartDevice -from kasa.smart.modules import ContactSensor +from kasa import Module, SmartDevice from kasa.tests.device_fixtures import parametrize contact = parametrize( @@ -18,7 +17,7 @@ contact = parametrize( ) async def test_contact_features(dev: SmartDevice, feature, type): """Test that features are registered and work as expected.""" - contact = dev.get_module(ContactSensor) + contact = dev.modules.get(Module.ContactSensor) assert contact is not None prop = getattr(contact, feature) diff --git a/kasa/tests/smart/modules/test_fan.py b/kasa/tests/smart/modules/test_fan.py index 37245951..9597471b 100644 --- a/kasa/tests/smart/modules/test_fan.py +++ b/kasa/tests/smart/modules/test_fan.py @@ -1,8 +1,8 @@ import pytest from pytest_mock import MockerFixture +from kasa import Module from kasa.smart import SmartDevice -from kasa.smart.modules import FanModule from kasa.tests.device_fixtures import parametrize fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"SMART"}) @@ -11,7 +11,7 @@ fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"S @fan async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): """Test fan speed feature.""" - fan = dev.get_module(FanModule) + fan = dev.modules.get(Module.Fan) assert fan level_feature = fan._module_features["fan_speed_level"] @@ -36,7 +36,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): @fan async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): """Test sleep mode feature.""" - fan = dev.get_module(FanModule) + fan = dev.modules.get(Module.Fan) assert fan sleep_feature = fan._module_features["fan_sleep_mode"] assert isinstance(sleep_feature.value, bool) @@ -55,7 +55,7 @@ async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture): """Test fan speed on device interface.""" assert isinstance(dev, SmartDevice) - fan = dev.get_module(FanModule) + fan = dev.modules.get(Module.Fan) assert fan device = fan._device assert device.is_fan diff --git a/kasa/tests/smart/modules/test_firmware.py b/kasa/tests/smart/modules/test_firmware.py index d0df87ca..8f329f70 100644 --- a/kasa/tests/smart/modules/test_firmware.py +++ b/kasa/tests/smart/modules/test_firmware.py @@ -6,8 +6,8 @@ import logging import pytest from pytest_mock import MockerFixture +from kasa import Module from kasa.smart import SmartDevice -from kasa.smart.modules import Firmware from kasa.smart.modules.firmware import DownloadState from kasa.tests.device_fixtures import parametrize @@ -31,7 +31,7 @@ async def test_firmware_features( dev: SmartDevice, feature, prop_name, type, required_version, mocker: MockerFixture ): """Test light effect.""" - fw = dev.get_module(Firmware) + fw = dev.modules.get(Module.Firmware) assert fw if not dev.is_cloud_connected: @@ -51,7 +51,7 @@ async def test_firmware_features( @firmware async def test_update_available_without_cloud(dev: SmartDevice): """Test that update_available returns None when disconnected.""" - fw = dev.get_module(Firmware) + fw = dev.modules.get(Module.Firmware) assert fw if dev.is_cloud_connected: @@ -67,7 +67,7 @@ async def test_firmware_update( """Test updating firmware.""" caplog.set_level(logging.INFO) - fw = dev.get_module(Firmware) + fw = dev.modules.get(Module.Firmware) assert fw upgrade_time = 5 diff --git a/kasa/tests/smart/modules/test_light_effect.py b/kasa/tests/smart/modules/test_light_effect.py index ba1b2293..cc0eee8a 100644 --- a/kasa/tests/smart/modules/test_light_effect.py +++ b/kasa/tests/smart/modules/test_light_effect.py @@ -1,12 +1,11 @@ from __future__ import annotations from itertools import chain -from typing import cast import pytest from pytest_mock import MockerFixture -from kasa import Device, Feature +from kasa import Device, Feature, Module from kasa.smart.modules import LightEffectModule from kasa.tests.device_fixtures import parametrize @@ -18,8 +17,8 @@ light_effect = parametrize( @light_effect async def test_light_effect(dev: Device, mocker: MockerFixture): """Test light effect.""" - light_effect = cast(LightEffectModule, dev.modules.get("LightEffectModule")) - assert light_effect + light_effect = dev.modules.get(Module.LightEffect) + assert isinstance(light_effect, LightEffectModule) feature = light_effect._module_features["light_effect"] assert feature.type == Feature.Type.Choice diff --git a/kasa/tests/test_common_modules.py b/kasa/tests/test_common_modules.py new file mode 100644 index 00000000..8f7def95 --- /dev/null +++ b/kasa/tests/test_common_modules.py @@ -0,0 +1,95 @@ +import pytest +from pytest_mock import MockerFixture + +from kasa import Device, Module +from kasa.tests.device_fixtures import ( + lightstrip, + parametrize, + parametrize_combine, + plug_iot, +) + +led_smart = parametrize( + "has led smart", component_filter="led", protocol_filter={"SMART"} +) +led = parametrize_combine([led_smart, plug_iot]) + +light_effect_smart = parametrize( + "has light effect smart", component_filter="light_effect", protocol_filter={"SMART"} +) +light_effect = parametrize_combine([light_effect_smart, lightstrip]) + + +@led +async def test_led_module(dev: Device, mocker: MockerFixture): + """Test fan speed feature.""" + led_module = dev.modules.get(Module.Led) + assert led_module + feat = led_module._module_features["led"] + + call = mocker.spy(led_module, "call") + await led_module.set_led(True) + assert call.call_count == 1 + await dev.update() + assert led_module.led is True + assert feat.value is True + + await led_module.set_led(False) + assert call.call_count == 2 + await dev.update() + assert led_module.led is False + assert feat.value is False + + await feat.set_value(True) + assert call.call_count == 3 + await dev.update() + assert feat.value is True + assert led_module.led is True + + +@light_effect +async def test_light_effect_module(dev: Device, mocker: MockerFixture): + """Test fan speed feature.""" + light_effect_module = dev.modules[Module.LightEffect] + assert light_effect_module + feat = light_effect_module._module_features["light_effect"] + + call = mocker.spy(light_effect_module, "call") + effect_list = light_effect_module.effect_list + assert "Off" in effect_list + assert effect_list.index("Off") == 0 + assert len(effect_list) > 1 + assert effect_list == feat.choices + + assert light_effect_module.has_custom_effects is not None + + await light_effect_module.set_effect("Off") + assert call.call_count == 1 + await dev.update() + assert light_effect_module.effect == "Off" + assert feat.value == "Off" + + second_effect = effect_list[1] + await light_effect_module.set_effect(second_effect) + assert call.call_count == 2 + await dev.update() + assert light_effect_module.effect == second_effect + assert feat.value == second_effect + + last_effect = effect_list[len(effect_list) - 1] + await light_effect_module.set_effect(last_effect) + assert call.call_count == 3 + await dev.update() + assert light_effect_module.effect == last_effect + assert feat.value == last_effect + + # Test feature set + await feat.set_value(second_effect) + assert call.call_count == 4 + await dev.update() + assert light_effect_module.effect == second_effect + assert feat.value == second_effect + + with pytest.raises(ValueError): + await light_effect_module.set_effect("foobar") + assert call.call_count == 4 diff --git a/kasa/tests/test_iotdevice.py b/kasa/tests/test_iotdevice.py index b4d56291..d5c76192 100644 --- a/kasa/tests/test_iotdevice.py +++ b/kasa/tests/test_iotdevice.py @@ -16,7 +16,7 @@ from voluptuous import ( Schema, ) -from kasa import KasaException +from kasa import KasaException, Module from kasa.iot import IotDevice from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on @@ -261,27 +261,26 @@ async def test_modules_not_supported(dev: IotDevice): async def test_get_modules(): - """Test get_modules for child and parent modules.""" + """Test getting modules for child and parent modules.""" dummy_device = await get_device_for_fixture_protocol( "HS100(US)_2.0_1.5.6.json", "IOT" ) from kasa.iot.modules import Cloud - from kasa.smart.modules import CloudModule # Modules on device - module = dummy_device.get_module("Cloud") + module = dummy_device.modules.get("cloud") assert module assert module._device == dummy_device assert isinstance(module, Cloud) - module = dummy_device.get_module(Cloud) + module = dummy_device.modules.get(Module.IotCloud) assert module assert module._device == dummy_device assert isinstance(module, Cloud) # Invalid modules - module = dummy_device.get_module("DummyModule") + module = dummy_device.modules.get("DummyModule") assert module is None - module = dummy_device.get_module(CloudModule) + module = dummy_device.modules.get(Module.Cloud) assert module is None diff --git a/kasa/tests/test_lightstrip.py b/kasa/tests/test_lightstrip.py index fc987d2e..f51f1805 100644 --- a/kasa/tests/test_lightstrip.py +++ b/kasa/tests/test_lightstrip.py @@ -1,7 +1,6 @@ import pytest from kasa import DeviceType -from kasa.exceptions import KasaException from kasa.iot import IotLightStrip from .conftest import lightstrip @@ -23,7 +22,7 @@ async def test_lightstrip_effect(dev: IotLightStrip): @lightstrip async def test_effects_lightstrip_set_effect(dev: IotLightStrip): - with pytest.raises(KasaException): + with pytest.raises(ValueError): await dev.set_effect("Not real") await dev.set_effect("Candy Cane") diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index bb2f81bf..a0af2cb1 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -9,7 +9,7 @@ from unittest.mock import patch import pytest from pytest_mock import MockerFixture -from kasa import KasaException +from kasa import KasaException, Module from kasa.exceptions import SmartErrorCode from kasa.smart import SmartDevice @@ -123,40 +123,39 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): async def test_get_modules(): - """Test get_modules for child and parent modules.""" + """Test getting modules for child and parent modules.""" dummy_device = await get_device_for_fixture_protocol( "KS240(US)_1.0_1.0.5.json", "SMART" ) - from kasa.iot.modules import AmbientLight - from kasa.smart.modules import CloudModule, FanModule + from kasa.smart.modules import CloudModule # Modules on device - module = dummy_device.get_module("CloudModule") + module = dummy_device.modules.get("CloudModule") assert module assert module._device == dummy_device assert isinstance(module, CloudModule) - module = dummy_device.get_module(CloudModule) + module = dummy_device.modules.get(Module.Cloud) assert module assert module._device == dummy_device assert isinstance(module, CloudModule) # Modules on child - module = dummy_device.get_module("FanModule") + module = dummy_device.modules.get("FanModule") assert module assert module._device != dummy_device assert module._device._parent == dummy_device - module = dummy_device.get_module(FanModule) + module = dummy_device.modules.get(Module.Fan) assert module assert module._device != dummy_device assert module._device._parent == dummy_device # Invalid modules - module = dummy_device.get_module("DummyModule") + module = dummy_device.modules.get("DummyModule") assert module is None - module = dummy_device.get_module(AmbientLight) + module = dummy_device.modules.get(Module.IotAmbientLight) assert module is None