Make get_module return typed module (#892)

Passing in a string still works and returns either `IotModule` or
`SmartModule` type when called on `IotDevice` or `SmartDevice`
respectively. When calling on `Device` will return `Module` type.

Passing in a module type is then typed to that module, i.e.:
```py
smartdev.get_module(FanModule)  # type is FanModule
smartdev.get_module("FanModule")  # type is SmartModule
```
Only thing this doesn't do is check that you can't pass an `IotModule`
to a `SmartDevice.get_module()`. However there is a runtime check which
will return null if the passed `ModuleType` is not a subclass of
`SmartModule`.

Many thanks to @cdce8p for helping with this.
This commit is contained in:
Steven B 2024-05-03 16:01:21 +01:00 committed by GitHub
parent 530fb841b0
commit c5d65b624b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 114 additions and 16 deletions

View File

@ -6,7 +6,7 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Mapping, Sequence
from typing import Any, Mapping, Sequence, overload
from .credentials import Credentials
from .device_type import DeviceType
@ -15,7 +15,7 @@ from .emeterstatus import EmeterStatus
from .exceptions import KasaException
from .feature import Feature
from .iotprotocol import IotProtocol
from .module import Module
from .module import Module, ModuleT
from .protocol import BaseProtocol
from .xortransport import XorTransport
@ -116,6 +116,18 @@ class Device(ABC):
def modules(self) -> Mapping[str, Module]:
"""Return the device modules."""
@overload
@abstractmethod
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
@overload
@abstractmethod
def get_module(self, module_type: str) -> Module | None: ...
@abstractmethod
def get_module(self, module_type: type[ModuleT] | str) -> ModuleT | Module | None:
"""Return the module from the device modules or None if not present."""
@property
@abstractmethod
def is_on(self) -> bool:

View File

@ -19,13 +19,14 @@ import functools
import inspect
import logging
from datetime import datetime, timedelta
from typing import Any, Mapping, Sequence, cast
from typing import Any, Mapping, Sequence, cast, overload
from ..device import Device, WifiNetwork
from ..deviceconfig import DeviceConfig
from ..emeterstatus import EmeterStatus
from ..exceptions import KasaException
from ..feature import Feature
from ..module import ModuleT
from ..protocol import BaseProtocol
from .iotmodule import IotModule
from .modules import Emeter, Time
@ -201,6 +202,26 @@ class IotDevice(Device):
"""Return the device modules."""
return self._modules
@overload
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
@overload
def get_module(self, module_type: str) -> IotModule | None: ...
def get_module(
self, module_type: type[ModuleT] | str
) -> ModuleT | IotModule | None:
"""Return the module from the device modules or None if not present."""
if isinstance(module_type, str):
module_name = module_type.lower()
elif issubclass(module_type, IotModule):
module_name = module_type.__name__.lower()
else:
return None
if module_name in self.modules:
return self.modules[module_name]
return None
def add_module(self, name: str, module: IotModule):
"""Register a module."""
if name in self.modules:

View File

@ -4,7 +4,10 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
TypeVar,
)
from .exceptions import KasaException
from .feature import Feature
@ -14,6 +17,8 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__)
ModuleT = TypeVar("ModuleT", bound="Module")
class Module(ABC):
"""Base class implemention for all modules.

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import base64
import logging
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast
from typing import Any, Mapping, Sequence, cast, overload
from ..aestransport import AesTransport
from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange
@ -16,6 +16,7 @@ from ..emeterstatus import EmeterStatus
from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode
from ..fan import Fan
from ..feature import Feature
from ..module import ModuleT
from ..smartprotocol import SmartProtocol
from .modules import (
Brightness,
@ -28,11 +29,10 @@ from .modules import (
Firmware,
TimeModule,
)
from .smartmodule import SmartModule
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from .smartmodule import SmartModule
# List of modules that wall switches with children, i.e. ks240 report on
# the child but only work on the parent. See longer note below in _initialize_modules.
@ -305,8 +305,22 @@ class SmartDevice(Bulb, Fan, Device):
for feat in module._module_features.values():
self._add_feature(feat)
def get_module(self, module_name) -> SmartModule | None:
@overload
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
@overload
def get_module(self, module_type: str) -> SmartModule | None: ...
def get_module(
self, module_type: type[ModuleT] | str
) -> ModuleT | SmartModule | None:
"""Return the module from the device modules or None if not present."""
if isinstance(module_type, str):
module_name = module_type
elif issubclass(module_type, SmartModule):
module_name = module_type.__name__
else:
return None
if module_name in self.modules:
return self.modules[module_name]
elif self._exposes_child_modules:

View File

@ -33,7 +33,7 @@ async def test_brightness_component(dev: SmartDevice):
@dimmable
async def test_brightness_dimmable(dev: SmartDevice):
async def test_brightness_dimmable(dev: IotDevice):
"""Test brightness feature."""
assert isinstance(dev, IotDevice)
assert "brightness" in dev.sys_info or bool(dev.sys_info["is_dimmable"])

View File

@ -1,5 +1,3 @@
from typing import cast
import pytest
from pytest_mock import MockerFixture
@ -13,7 +11,7 @@ fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"S
@fan
async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed feature."""
fan = cast(FanModule, dev.get_module("FanModule"))
fan = dev.get_module(FanModule)
assert fan
level_feature = fan._module_features["fan_speed_level"]
@ -38,7 +36,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
@fan
async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
"""Test sleep mode feature."""
fan = cast(FanModule, dev.get_module("FanModule"))
fan = dev.get_module(FanModule)
assert fan
sleep_feature = fan._module_features["fan_sleep_mode"]
assert isinstance(sleep_feature.value, bool)
@ -57,7 +55,8 @@ async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed on device interface."""
assert isinstance(dev, SmartDevice)
fan = cast(FanModule, dev.get_module("FanModule"))
fan = dev.get_module(FanModule)
assert fan
device = fan._device
assert device.is_fan

View File

@ -19,7 +19,7 @@ from voluptuous import (
from kasa import KasaException
from kasa.iot import IotDevice
from .conftest import handle_turn_on, turn_on
from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on
from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot
from .fakeprotocol_iot import FakeIotProtocol
@ -258,3 +258,30 @@ async def test_modules_not_supported(dev: IotDevice):
await dev.update()
for module in dev.modules.values():
assert module.is_supported is not None
async def test_get_modules():
"""Test get_modules for child and parent modules."""
dummy_device = await get_device_for_fixture_protocol(
"HS100(US)_2.0_1.5.6.json", "IOT"
)
from kasa.iot.modules import Cloud
from kasa.smart.modules import CloudModule
# Modules on device
module = dummy_device.get_module("Cloud")
assert module
assert module._device == dummy_device
assert isinstance(module, Cloud)
module = dummy_device.get_module(Cloud)
assert module
assert module._device == dummy_device
assert isinstance(module, Cloud)
# Invalid modules
module = dummy_device.get_module("DummyModule")
assert module is None
module = dummy_device.get_module(CloudModule)
assert module is None

View File

@ -122,23 +122,43 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
spies[device].assert_not_called()
async def test_get_modules(mocker):
async def test_get_modules():
"""Test get_modules for child and parent modules."""
dummy_device = await get_device_for_fixture_protocol(
"KS240(US)_1.0_1.0.5.json", "SMART"
)
from kasa.iot.modules import AmbientLight
from kasa.smart.modules import CloudModule, FanModule
# Modules on device
module = dummy_device.get_module("CloudModule")
assert module
assert module._device == dummy_device
assert isinstance(module, CloudModule)
module = dummy_device.get_module(CloudModule)
assert module
assert module._device == dummy_device
assert isinstance(module, CloudModule)
# Modules on child
module = dummy_device.get_module("FanModule")
assert module
assert module._device != dummy_device
assert module._device._parent == dummy_device
module = dummy_device.get_module(FanModule)
assert module
assert module._device != dummy_device
assert module._device._parent == dummy_device
# Invalid modules
module = dummy_device.get_module("DummyModule")
assert module is None
module = dummy_device.get_module(AmbientLight)
assert module is None
@bulb_smart
async def test_smartdevice_brightness(dev: SmartDevice):