Make get_module return typed module (#892)

Passing in a string still works and returns either `IotModule` or
`SmartModule` type when called on `IotDevice` or `SmartDevice`
respectively. When calling on `Device` will return `Module` type.

Passing in a module type is then typed to that module, i.e.:
```py
smartdev.get_module(FanModule)  # type is FanModule
smartdev.get_module("FanModule")  # type is SmartModule
```
Only thing this doesn't do is check that you can't pass an `IotModule`
to a `SmartDevice.get_module()`. However there is a runtime check which
will return null if the passed `ModuleType` is not a subclass of
`SmartModule`.

Many thanks to @cdce8p for helping with this.
This commit is contained in:
Steven B 2024-05-03 16:01:21 +01:00 committed by GitHub
parent 530fb841b0
commit c5d65b624b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 114 additions and 16 deletions

View File

@ -6,7 +6,7 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Mapping, Sequence from typing import Any, Mapping, Sequence, overload
from .credentials import Credentials from .credentials import Credentials
from .device_type import DeviceType from .device_type import DeviceType
@ -15,7 +15,7 @@ from .emeterstatus import EmeterStatus
from .exceptions import KasaException from .exceptions import KasaException
from .feature import Feature from .feature import Feature
from .iotprotocol import IotProtocol from .iotprotocol import IotProtocol
from .module import Module from .module import Module, ModuleT
from .protocol import BaseProtocol from .protocol import BaseProtocol
from .xortransport import XorTransport from .xortransport import XorTransport
@ -116,6 +116,18 @@ class Device(ABC):
def modules(self) -> Mapping[str, Module]: def modules(self) -> Mapping[str, Module]:
"""Return the device modules.""" """Return the device modules."""
@overload
@abstractmethod
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
@overload
@abstractmethod
def get_module(self, module_type: str) -> Module | None: ...
@abstractmethod
def get_module(self, module_type: type[ModuleT] | str) -> ModuleT | Module | None:
"""Return the module from the device modules or None if not present."""
@property @property
@abstractmethod @abstractmethod
def is_on(self) -> bool: def is_on(self) -> bool:

View File

@ -19,13 +19,14 @@ import functools
import inspect import inspect
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Mapping, Sequence, cast from typing import Any, Mapping, Sequence, cast, overload
from ..device import Device, WifiNetwork from ..device import Device, WifiNetwork
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
from ..emeterstatus import EmeterStatus from ..emeterstatus import EmeterStatus
from ..exceptions import KasaException from ..exceptions import KasaException
from ..feature import Feature from ..feature import Feature
from ..module import ModuleT
from ..protocol import BaseProtocol from ..protocol import BaseProtocol
from .iotmodule import IotModule from .iotmodule import IotModule
from .modules import Emeter, Time from .modules import Emeter, Time
@ -201,6 +202,26 @@ class IotDevice(Device):
"""Return the device modules.""" """Return the device modules."""
return self._modules return self._modules
@overload
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
@overload
def get_module(self, module_type: str) -> IotModule | None: ...
def get_module(
self, module_type: type[ModuleT] | str
) -> ModuleT | IotModule | None:
"""Return the module from the device modules or None if not present."""
if isinstance(module_type, str):
module_name = module_type.lower()
elif issubclass(module_type, IotModule):
module_name = module_type.__name__.lower()
else:
return None
if module_name in self.modules:
return self.modules[module_name]
return None
def add_module(self, name: str, module: IotModule): def add_module(self, name: str, module: IotModule):
"""Register a module.""" """Register a module."""
if name in self.modules: if name in self.modules:

View File

@ -4,7 +4,10 @@ from __future__ import annotations
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING from typing import (
TYPE_CHECKING,
TypeVar,
)
from .exceptions import KasaException from .exceptions import KasaException
from .feature import Feature from .feature import Feature
@ -14,6 +17,8 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ModuleT = TypeVar("ModuleT", bound="Module")
class Module(ABC): class Module(ABC):
"""Base class implemention for all modules. """Base class implemention for all modules.

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import base64 import base64
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast from typing import Any, Mapping, Sequence, cast, overload
from ..aestransport import AesTransport from ..aestransport import AesTransport
from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange
@ -16,6 +16,7 @@ from ..emeterstatus import EmeterStatus
from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode
from ..fan import Fan from ..fan import Fan
from ..feature import Feature from ..feature import Feature
from ..module import ModuleT
from ..smartprotocol import SmartProtocol from ..smartprotocol import SmartProtocol
from .modules import ( from .modules import (
Brightness, Brightness,
@ -28,11 +29,10 @@ from .modules import (
Firmware, Firmware,
TimeModule, TimeModule,
) )
from .smartmodule import SmartModule
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from .smartmodule import SmartModule
# List of modules that wall switches with children, i.e. ks240 report on # List of modules that wall switches with children, i.e. ks240 report on
# the child but only work on the parent. See longer note below in _initialize_modules. # the child but only work on the parent. See longer note below in _initialize_modules.
@ -305,8 +305,22 @@ class SmartDevice(Bulb, Fan, Device):
for feat in module._module_features.values(): for feat in module._module_features.values():
self._add_feature(feat) self._add_feature(feat)
def get_module(self, module_name) -> SmartModule | None: @overload
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
@overload
def get_module(self, module_type: str) -> SmartModule | None: ...
def get_module(
self, module_type: type[ModuleT] | str
) -> ModuleT | SmartModule | None:
"""Return the module from the device modules or None if not present.""" """Return the module from the device modules or None if not present."""
if isinstance(module_type, str):
module_name = module_type
elif issubclass(module_type, SmartModule):
module_name = module_type.__name__
else:
return None
if module_name in self.modules: if module_name in self.modules:
return self.modules[module_name] return self.modules[module_name]
elif self._exposes_child_modules: elif self._exposes_child_modules:

View File

@ -33,7 +33,7 @@ async def test_brightness_component(dev: SmartDevice):
@dimmable @dimmable
async def test_brightness_dimmable(dev: SmartDevice): async def test_brightness_dimmable(dev: IotDevice):
"""Test brightness feature.""" """Test brightness feature."""
assert isinstance(dev, IotDevice) assert isinstance(dev, IotDevice)
assert "brightness" in dev.sys_info or bool(dev.sys_info["is_dimmable"]) assert "brightness" in dev.sys_info or bool(dev.sys_info["is_dimmable"])

View File

@ -1,5 +1,3 @@
from typing import cast
import pytest import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
@ -13,7 +11,7 @@ fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"S
@fan @fan
async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed feature.""" """Test fan speed feature."""
fan = cast(FanModule, dev.get_module("FanModule")) fan = dev.get_module(FanModule)
assert fan assert fan
level_feature = fan._module_features["fan_speed_level"] level_feature = fan._module_features["fan_speed_level"]
@ -38,7 +36,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
@fan @fan
async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
"""Test sleep mode feature.""" """Test sleep mode feature."""
fan = cast(FanModule, dev.get_module("FanModule")) fan = dev.get_module(FanModule)
assert fan assert fan
sleep_feature = fan._module_features["fan_sleep_mode"] sleep_feature = fan._module_features["fan_sleep_mode"]
assert isinstance(sleep_feature.value, bool) assert isinstance(sleep_feature.value, bool)
@ -57,7 +55,8 @@ async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture): async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed on device interface.""" """Test fan speed on device interface."""
assert isinstance(dev, SmartDevice) assert isinstance(dev, SmartDevice)
fan = cast(FanModule, dev.get_module("FanModule")) fan = dev.get_module(FanModule)
assert fan
device = fan._device device = fan._device
assert device.is_fan assert device.is_fan

View File

@ -19,7 +19,7 @@ from voluptuous import (
from kasa import KasaException from kasa import KasaException
from kasa.iot import IotDevice from kasa.iot import IotDevice
from .conftest import handle_turn_on, turn_on from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on
from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot
from .fakeprotocol_iot import FakeIotProtocol from .fakeprotocol_iot import FakeIotProtocol
@ -258,3 +258,30 @@ async def test_modules_not_supported(dev: IotDevice):
await dev.update() await dev.update()
for module in dev.modules.values(): for module in dev.modules.values():
assert module.is_supported is not None assert module.is_supported is not None
async def test_get_modules():
"""Test get_modules for child and parent modules."""
dummy_device = await get_device_for_fixture_protocol(
"HS100(US)_2.0_1.5.6.json", "IOT"
)
from kasa.iot.modules import Cloud
from kasa.smart.modules import CloudModule
# Modules on device
module = dummy_device.get_module("Cloud")
assert module
assert module._device == dummy_device
assert isinstance(module, Cloud)
module = dummy_device.get_module(Cloud)
assert module
assert module._device == dummy_device
assert isinstance(module, Cloud)
# Invalid modules
module = dummy_device.get_module("DummyModule")
assert module is None
module = dummy_device.get_module(CloudModule)
assert module is None

View File

@ -122,23 +122,43 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
spies[device].assert_not_called() spies[device].assert_not_called()
async def test_get_modules(mocker): async def test_get_modules():
"""Test get_modules for child and parent modules.""" """Test get_modules for child and parent modules."""
dummy_device = await get_device_for_fixture_protocol( dummy_device = await get_device_for_fixture_protocol(
"KS240(US)_1.0_1.0.5.json", "SMART" "KS240(US)_1.0_1.0.5.json", "SMART"
) )
from kasa.iot.modules import AmbientLight
from kasa.smart.modules import CloudModule, FanModule
# Modules on device
module = dummy_device.get_module("CloudModule") module = dummy_device.get_module("CloudModule")
assert module assert module
assert module._device == dummy_device assert module._device == dummy_device
assert isinstance(module, CloudModule)
module = dummy_device.get_module(CloudModule)
assert module
assert module._device == dummy_device
assert isinstance(module, CloudModule)
# Modules on child
module = dummy_device.get_module("FanModule") module = dummy_device.get_module("FanModule")
assert module assert module
assert module._device != dummy_device assert module._device != dummy_device
assert module._device._parent == dummy_device assert module._device._parent == dummy_device
module = dummy_device.get_module(FanModule)
assert module
assert module._device != dummy_device
assert module._device._parent == dummy_device
# Invalid modules
module = dummy_device.get_module("DummyModule") module = dummy_device.get_module("DummyModule")
assert module is None assert module is None
module = dummy_device.get_module(AmbientLight)
assert module is None
@bulb_smart @bulb_smart
async def test_smartdevice_brightness(dev: SmartDevice): async def test_smartdevice_brightness(dev: SmartDevice):