From c5d65b624b52ffc1d0d1ae1317eee4dd7d50c802 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Fri, 3 May 2024 16:01:21 +0100 Subject: [PATCH] Make get_module return typed module (#892) Passing in a string still works and returns either `IotModule` or `SmartModule` type when called on `IotDevice` or `SmartDevice` respectively. When calling on `Device` will return `Module` type. Passing in a module type is then typed to that module, i.e.: ```py smartdev.get_module(FanModule) # type is FanModule smartdev.get_module("FanModule") # type is SmartModule ``` Only thing this doesn't do is check that you can't pass an `IotModule` to a `SmartDevice.get_module()`. However there is a runtime check which will return null if the passed `ModuleType` is not a subclass of `SmartModule`. Many thanks to @cdce8p for helping with this. --- kasa/device.py | 16 +++++++++-- kasa/iot/iotdevice.py | 23 +++++++++++++++- kasa/module.py | 7 ++++- kasa/smart/smartdevice.py | 22 ++++++++++++--- kasa/tests/smart/features/test_brightness.py | 2 +- kasa/tests/smart/modules/test_fan.py | 9 +++--- kasa/tests/test_iotdevice.py | 29 +++++++++++++++++++- kasa/tests/test_smartdevice.py | 22 ++++++++++++++- 8 files changed, 114 insertions(+), 16 deletions(-) diff --git a/kasa/device.py b/kasa/device.py index 4cb6bd98..ea358a8d 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, Mapping, Sequence +from typing import Any, Mapping, Sequence, overload from .credentials import Credentials from .device_type import DeviceType @@ -15,7 +15,7 @@ from .emeterstatus import EmeterStatus from .exceptions import KasaException from .feature import Feature from .iotprotocol import IotProtocol -from .module import Module +from .module import Module, ModuleT from .protocol import BaseProtocol from .xortransport import XorTransport @@ -116,6 +116,18 @@ class Device(ABC): def modules(self) -> Mapping[str, Module]: """Return the device modules.""" + @overload + @abstractmethod + def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ... + + @overload + @abstractmethod + def get_module(self, module_type: str) -> Module | None: ... + + @abstractmethod + def get_module(self, module_type: type[ModuleT] | str) -> ModuleT | Module | None: + """Return the module from the device modules or None if not present.""" + @property @abstractmethod def is_on(self) -> bool: diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 81b5edda..e69de80c 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -19,13 +19,14 @@ import functools import inspect import logging from datetime import datetime, timedelta -from typing import Any, Mapping, Sequence, cast +from typing import Any, Mapping, Sequence, cast, overload from ..device import Device, WifiNetwork from ..deviceconfig import DeviceConfig from ..emeterstatus import EmeterStatus from ..exceptions import KasaException from ..feature import Feature +from ..module import ModuleT from ..protocol import BaseProtocol from .iotmodule import IotModule from .modules import Emeter, Time @@ -201,6 +202,26 @@ class IotDevice(Device): """Return the device modules.""" return self._modules + @overload + def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ... + + @overload + def get_module(self, module_type: str) -> IotModule | None: ... + + def get_module( + self, module_type: type[ModuleT] | str + ) -> ModuleT | IotModule | None: + """Return the module from the device modules or None if not present.""" + if isinstance(module_type, str): + module_name = module_type.lower() + elif issubclass(module_type, IotModule): + module_name = module_type.__name__.lower() + else: + return None + if module_name in self.modules: + return self.modules[module_name] + return None + def add_module(self, name: str, module: IotModule): """Register a module.""" if name in self.modules: diff --git a/kasa/module.py b/kasa/module.py index 8422eaf9..5b6354a9 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -4,7 +4,10 @@ from __future__ import annotations import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + TypeVar, +) from .exceptions import KasaException from .feature import Feature @@ -14,6 +17,8 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) +ModuleT = TypeVar("ModuleT", bound="Module") + class Module(ABC): """Base class implemention for all modules. diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index e5df10be..98c5f7ef 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -5,7 +5,7 @@ from __future__ import annotations import base64 import logging from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast +from typing import Any, Mapping, Sequence, cast, overload from ..aestransport import AesTransport from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange @@ -16,6 +16,7 @@ from ..emeterstatus import EmeterStatus from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode from ..fan import Fan from ..feature import Feature +from ..module import ModuleT from ..smartprotocol import SmartProtocol from .modules import ( Brightness, @@ -28,11 +29,10 @@ from .modules import ( Firmware, TimeModule, ) +from .smartmodule import SmartModule _LOGGER = logging.getLogger(__name__) -if TYPE_CHECKING: - from .smartmodule import SmartModule # List of modules that wall switches with children, i.e. ks240 report on # the child but only work on the parent. See longer note below in _initialize_modules. @@ -305,8 +305,22 @@ class SmartDevice(Bulb, Fan, Device): for feat in module._module_features.values(): self._add_feature(feat) - def get_module(self, module_name) -> SmartModule | None: + @overload + def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ... + + @overload + def get_module(self, module_type: str) -> SmartModule | None: ... + + def get_module( + self, module_type: type[ModuleT] | str + ) -> ModuleT | SmartModule | None: """Return the module from the device modules or None if not present.""" + if isinstance(module_type, str): + module_name = module_type + elif issubclass(module_type, SmartModule): + module_name = module_type.__name__ + else: + return None if module_name in self.modules: return self.modules[module_name] elif self._exposes_child_modules: diff --git a/kasa/tests/smart/features/test_brightness.py b/kasa/tests/smart/features/test_brightness.py index 79df0abf..02a396aa 100644 --- a/kasa/tests/smart/features/test_brightness.py +++ b/kasa/tests/smart/features/test_brightness.py @@ -33,7 +33,7 @@ async def test_brightness_component(dev: SmartDevice): @dimmable -async def test_brightness_dimmable(dev: SmartDevice): +async def test_brightness_dimmable(dev: IotDevice): """Test brightness feature.""" assert isinstance(dev, IotDevice) assert "brightness" in dev.sys_info or bool(dev.sys_info["is_dimmable"]) diff --git a/kasa/tests/smart/modules/test_fan.py b/kasa/tests/smart/modules/test_fan.py index 429a5d18..37245951 100644 --- a/kasa/tests/smart/modules/test_fan.py +++ b/kasa/tests/smart/modules/test_fan.py @@ -1,5 +1,3 @@ -from typing import cast - import pytest from pytest_mock import MockerFixture @@ -13,7 +11,7 @@ fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"S @fan async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): """Test fan speed feature.""" - fan = cast(FanModule, dev.get_module("FanModule")) + fan = dev.get_module(FanModule) assert fan level_feature = fan._module_features["fan_speed_level"] @@ -38,7 +36,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): @fan async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): """Test sleep mode feature.""" - fan = cast(FanModule, dev.get_module("FanModule")) + fan = dev.get_module(FanModule) assert fan sleep_feature = fan._module_features["fan_sleep_mode"] assert isinstance(sleep_feature.value, bool) @@ -57,7 +55,8 @@ async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture): """Test fan speed on device interface.""" assert isinstance(dev, SmartDevice) - fan = cast(FanModule, dev.get_module("FanModule")) + fan = dev.get_module(FanModule) + assert fan device = fan._device assert device.is_fan diff --git a/kasa/tests/test_iotdevice.py b/kasa/tests/test_iotdevice.py index 4c5d5126..b4d56291 100644 --- a/kasa/tests/test_iotdevice.py +++ b/kasa/tests/test_iotdevice.py @@ -19,7 +19,7 @@ from voluptuous import ( from kasa import KasaException from kasa.iot import IotDevice -from .conftest import handle_turn_on, turn_on +from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot from .fakeprotocol_iot import FakeIotProtocol @@ -258,3 +258,30 @@ async def test_modules_not_supported(dev: IotDevice): await dev.update() for module in dev.modules.values(): assert module.is_supported is not None + + +async def test_get_modules(): + """Test get_modules for child and parent modules.""" + dummy_device = await get_device_for_fixture_protocol( + "HS100(US)_2.0_1.5.6.json", "IOT" + ) + from kasa.iot.modules import Cloud + from kasa.smart.modules import CloudModule + + # Modules on device + module = dummy_device.get_module("Cloud") + assert module + assert module._device == dummy_device + assert isinstance(module, Cloud) + + module = dummy_device.get_module(Cloud) + assert module + assert module._device == dummy_device + assert isinstance(module, Cloud) + + # Invalid modules + module = dummy_device.get_module("DummyModule") + assert module is None + + module = dummy_device.get_module(CloudModule) + assert module is None diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 476a37ae..bb2f81bf 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -122,23 +122,43 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): spies[device].assert_not_called() -async def test_get_modules(mocker): +async def test_get_modules(): """Test get_modules for child and parent modules.""" dummy_device = await get_device_for_fixture_protocol( "KS240(US)_1.0_1.0.5.json", "SMART" ) + from kasa.iot.modules import AmbientLight + from kasa.smart.modules import CloudModule, FanModule + + # Modules on device module = dummy_device.get_module("CloudModule") assert module assert module._device == dummy_device + assert isinstance(module, CloudModule) + module = dummy_device.get_module(CloudModule) + assert module + assert module._device == dummy_device + assert isinstance(module, CloudModule) + + # Modules on child module = dummy_device.get_module("FanModule") assert module assert module._device != dummy_device assert module._device._parent == dummy_device + module = dummy_device.get_module(FanModule) + assert module + assert module._device != dummy_device + assert module._device._parent == dummy_device + + # Invalid modules module = dummy_device.get_module("DummyModule") assert module is None + module = dummy_device.get_module(AmbientLight) + assert module is None + @bulb_smart async def test_smartdevice_brightness(dev: SmartDevice):