mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-11-04 06:32:07 +00:00 
			
		
		
		
	Make get_module return typed module (#892)
Passing in a string still works and returns either `IotModule` or
`SmartModule` type when called on `IotDevice` or `SmartDevice`
respectively. When calling on `Device` will return `Module` type.
Passing in a module type is then typed to that module, i.e.:
```py
smartdev.get_module(FanModule)  # type is FanModule
smartdev.get_module("FanModule")  # type is SmartModule
```
Only thing this doesn't do is check that you can't pass an `IotModule`
to a `SmartDevice.get_module()`. However there is a runtime check which
will return null if the passed `ModuleType` is not a subclass of
`SmartModule`.
Many thanks to @cdce8p for helping with this.
			
			
This commit is contained in:
		@@ -6,7 +6,7 @@ import logging
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
from typing import Any, Mapping, Sequence
 | 
			
		||||
from typing import Any, Mapping, Sequence, overload
 | 
			
		||||
 | 
			
		||||
from .credentials import Credentials
 | 
			
		||||
from .device_type import DeviceType
 | 
			
		||||
@@ -15,7 +15,7 @@ from .emeterstatus import EmeterStatus
 | 
			
		||||
from .exceptions import KasaException
 | 
			
		||||
from .feature import Feature
 | 
			
		||||
from .iotprotocol import IotProtocol
 | 
			
		||||
from .module import Module
 | 
			
		||||
from .module import Module, ModuleT
 | 
			
		||||
from .protocol import BaseProtocol
 | 
			
		||||
from .xortransport import XorTransport
 | 
			
		||||
 | 
			
		||||
@@ -116,6 +116,18 @@ class Device(ABC):
 | 
			
		||||
    def modules(self) -> Mapping[str, Module]:
 | 
			
		||||
        """Return the device modules."""
 | 
			
		||||
 | 
			
		||||
    @overload
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
 | 
			
		||||
 | 
			
		||||
    @overload
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def get_module(self, module_type: str) -> Module | None: ...
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def get_module(self, module_type: type[ModuleT] | str) -> ModuleT | Module | None:
 | 
			
		||||
        """Return the module from the device modules or None if not present."""
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def is_on(self) -> bool:
 | 
			
		||||
 
 | 
			
		||||
@@ -19,13 +19,14 @@ import functools
 | 
			
		||||
import inspect
 | 
			
		||||
import logging
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from typing import Any, Mapping, Sequence, cast
 | 
			
		||||
from typing import Any, Mapping, Sequence, cast, overload
 | 
			
		||||
 | 
			
		||||
from ..device import Device, WifiNetwork
 | 
			
		||||
from ..deviceconfig import DeviceConfig
 | 
			
		||||
from ..emeterstatus import EmeterStatus
 | 
			
		||||
from ..exceptions import KasaException
 | 
			
		||||
from ..feature import Feature
 | 
			
		||||
from ..module import ModuleT
 | 
			
		||||
from ..protocol import BaseProtocol
 | 
			
		||||
from .iotmodule import IotModule
 | 
			
		||||
from .modules import Emeter, Time
 | 
			
		||||
@@ -201,6 +202,26 @@ class IotDevice(Device):
 | 
			
		||||
        """Return the device modules."""
 | 
			
		||||
        return self._modules
 | 
			
		||||
 | 
			
		||||
    @overload
 | 
			
		||||
    def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
 | 
			
		||||
 | 
			
		||||
    @overload
 | 
			
		||||
    def get_module(self, module_type: str) -> IotModule | None: ...
 | 
			
		||||
 | 
			
		||||
    def get_module(
 | 
			
		||||
        self, module_type: type[ModuleT] | str
 | 
			
		||||
    ) -> ModuleT | IotModule | None:
 | 
			
		||||
        """Return the module from the device modules or None if not present."""
 | 
			
		||||
        if isinstance(module_type, str):
 | 
			
		||||
            module_name = module_type.lower()
 | 
			
		||||
        elif issubclass(module_type, IotModule):
 | 
			
		||||
            module_name = module_type.__name__.lower()
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
        if module_name in self.modules:
 | 
			
		||||
            return self.modules[module_name]
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def add_module(self, name: str, module: IotModule):
 | 
			
		||||
        """Register a module."""
 | 
			
		||||
        if name in self.modules:
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,10 @@ from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from typing import TYPE_CHECKING
 | 
			
		||||
from typing import (
 | 
			
		||||
    TYPE_CHECKING,
 | 
			
		||||
    TypeVar,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from .exceptions import KasaException
 | 
			
		||||
from .feature import Feature
 | 
			
		||||
@@ -14,6 +17,8 @@ if TYPE_CHECKING:
 | 
			
		||||
 | 
			
		||||
_LOGGER = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
ModuleT = TypeVar("ModuleT", bound="Module")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Module(ABC):
 | 
			
		||||
    """Base class implemention for all modules.
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,7 @@ from __future__ import annotations
 | 
			
		||||
import base64
 | 
			
		||||
import logging
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast
 | 
			
		||||
from typing import Any, Mapping, Sequence, cast, overload
 | 
			
		||||
 | 
			
		||||
from ..aestransport import AesTransport
 | 
			
		||||
from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange
 | 
			
		||||
@@ -16,6 +16,7 @@ from ..emeterstatus import EmeterStatus
 | 
			
		||||
from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode
 | 
			
		||||
from ..fan import Fan
 | 
			
		||||
from ..feature import Feature
 | 
			
		||||
from ..module import ModuleT
 | 
			
		||||
from ..smartprotocol import SmartProtocol
 | 
			
		||||
from .modules import (
 | 
			
		||||
    Brightness,
 | 
			
		||||
@@ -28,11 +29,10 @@ from .modules import (
 | 
			
		||||
    Firmware,
 | 
			
		||||
    TimeModule,
 | 
			
		||||
)
 | 
			
		||||
from .smartmodule import SmartModule
 | 
			
		||||
 | 
			
		||||
_LOGGER = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from .smartmodule import SmartModule
 | 
			
		||||
 | 
			
		||||
# List of modules that wall switches with children, i.e. ks240 report on
 | 
			
		||||
# the child but only work on the parent.  See longer note below in _initialize_modules.
 | 
			
		||||
@@ -305,8 +305,22 @@ class SmartDevice(Bulb, Fan, Device):
 | 
			
		||||
            for feat in module._module_features.values():
 | 
			
		||||
                self._add_feature(feat)
 | 
			
		||||
 | 
			
		||||
    def get_module(self, module_name) -> SmartModule | None:
 | 
			
		||||
    @overload
 | 
			
		||||
    def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
 | 
			
		||||
 | 
			
		||||
    @overload
 | 
			
		||||
    def get_module(self, module_type: str) -> SmartModule | None: ...
 | 
			
		||||
 | 
			
		||||
    def get_module(
 | 
			
		||||
        self, module_type: type[ModuleT] | str
 | 
			
		||||
    ) -> ModuleT | SmartModule | None:
 | 
			
		||||
        """Return the module from the device modules or None if not present."""
 | 
			
		||||
        if isinstance(module_type, str):
 | 
			
		||||
            module_name = module_type
 | 
			
		||||
        elif issubclass(module_type, SmartModule):
 | 
			
		||||
            module_name = module_type.__name__
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
        if module_name in self.modules:
 | 
			
		||||
            return self.modules[module_name]
 | 
			
		||||
        elif self._exposes_child_modules:
 | 
			
		||||
 
 | 
			
		||||
@@ -33,7 +33,7 @@ async def test_brightness_component(dev: SmartDevice):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dimmable
 | 
			
		||||
async def test_brightness_dimmable(dev: SmartDevice):
 | 
			
		||||
async def test_brightness_dimmable(dev: IotDevice):
 | 
			
		||||
    """Test brightness feature."""
 | 
			
		||||
    assert isinstance(dev, IotDevice)
 | 
			
		||||
    assert "brightness" in dev.sys_info or bool(dev.sys_info["is_dimmable"])
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,3 @@
 | 
			
		||||
from typing import cast
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
from pytest_mock import MockerFixture
 | 
			
		||||
 | 
			
		||||
@@ -13,7 +11,7 @@ fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"S
 | 
			
		||||
@fan
 | 
			
		||||
async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
 | 
			
		||||
    """Test fan speed feature."""
 | 
			
		||||
    fan = cast(FanModule, dev.get_module("FanModule"))
 | 
			
		||||
    fan = dev.get_module(FanModule)
 | 
			
		||||
    assert fan
 | 
			
		||||
 | 
			
		||||
    level_feature = fan._module_features["fan_speed_level"]
 | 
			
		||||
@@ -38,7 +36,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
 | 
			
		||||
@fan
 | 
			
		||||
async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
 | 
			
		||||
    """Test sleep mode feature."""
 | 
			
		||||
    fan = cast(FanModule, dev.get_module("FanModule"))
 | 
			
		||||
    fan = dev.get_module(FanModule)
 | 
			
		||||
    assert fan
 | 
			
		||||
    sleep_feature = fan._module_features["fan_sleep_mode"]
 | 
			
		||||
    assert isinstance(sleep_feature.value, bool)
 | 
			
		||||
@@ -57,7 +55,8 @@ async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
 | 
			
		||||
async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture):
 | 
			
		||||
    """Test fan speed on device interface."""
 | 
			
		||||
    assert isinstance(dev, SmartDevice)
 | 
			
		||||
    fan = cast(FanModule, dev.get_module("FanModule"))
 | 
			
		||||
    fan = dev.get_module(FanModule)
 | 
			
		||||
    assert fan
 | 
			
		||||
    device = fan._device
 | 
			
		||||
    assert device.is_fan
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -19,7 +19,7 @@ from voluptuous import (
 | 
			
		||||
from kasa import KasaException
 | 
			
		||||
from kasa.iot import IotDevice
 | 
			
		||||
 | 
			
		||||
from .conftest import handle_turn_on, turn_on
 | 
			
		||||
from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on
 | 
			
		||||
from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot
 | 
			
		||||
from .fakeprotocol_iot import FakeIotProtocol
 | 
			
		||||
 | 
			
		||||
@@ -258,3 +258,30 @@ async def test_modules_not_supported(dev: IotDevice):
 | 
			
		||||
    await dev.update()
 | 
			
		||||
    for module in dev.modules.values():
 | 
			
		||||
        assert module.is_supported is not None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_get_modules():
 | 
			
		||||
    """Test get_modules for child and parent modules."""
 | 
			
		||||
    dummy_device = await get_device_for_fixture_protocol(
 | 
			
		||||
        "HS100(US)_2.0_1.5.6.json", "IOT"
 | 
			
		||||
    )
 | 
			
		||||
    from kasa.iot.modules import Cloud
 | 
			
		||||
    from kasa.smart.modules import CloudModule
 | 
			
		||||
 | 
			
		||||
    # Modules on device
 | 
			
		||||
    module = dummy_device.get_module("Cloud")
 | 
			
		||||
    assert module
 | 
			
		||||
    assert module._device == dummy_device
 | 
			
		||||
    assert isinstance(module, Cloud)
 | 
			
		||||
 | 
			
		||||
    module = dummy_device.get_module(Cloud)
 | 
			
		||||
    assert module
 | 
			
		||||
    assert module._device == dummy_device
 | 
			
		||||
    assert isinstance(module, Cloud)
 | 
			
		||||
 | 
			
		||||
    # Invalid modules
 | 
			
		||||
    module = dummy_device.get_module("DummyModule")
 | 
			
		||||
    assert module is None
 | 
			
		||||
 | 
			
		||||
    module = dummy_device.get_module(CloudModule)
 | 
			
		||||
    assert module is None
 | 
			
		||||
 
 | 
			
		||||
@@ -122,23 +122,43 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
 | 
			
		||||
            spies[device].assert_not_called()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_get_modules(mocker):
 | 
			
		||||
async def test_get_modules():
 | 
			
		||||
    """Test get_modules for child and parent modules."""
 | 
			
		||||
    dummy_device = await get_device_for_fixture_protocol(
 | 
			
		||||
        "KS240(US)_1.0_1.0.5.json", "SMART"
 | 
			
		||||
    )
 | 
			
		||||
    from kasa.iot.modules import AmbientLight
 | 
			
		||||
    from kasa.smart.modules import CloudModule, FanModule
 | 
			
		||||
 | 
			
		||||
    # Modules on device
 | 
			
		||||
    module = dummy_device.get_module("CloudModule")
 | 
			
		||||
    assert module
 | 
			
		||||
    assert module._device == dummy_device
 | 
			
		||||
    assert isinstance(module, CloudModule)
 | 
			
		||||
 | 
			
		||||
    module = dummy_device.get_module(CloudModule)
 | 
			
		||||
    assert module
 | 
			
		||||
    assert module._device == dummy_device
 | 
			
		||||
    assert isinstance(module, CloudModule)
 | 
			
		||||
 | 
			
		||||
    # Modules on child
 | 
			
		||||
    module = dummy_device.get_module("FanModule")
 | 
			
		||||
    assert module
 | 
			
		||||
    assert module._device != dummy_device
 | 
			
		||||
    assert module._device._parent == dummy_device
 | 
			
		||||
 | 
			
		||||
    module = dummy_device.get_module(FanModule)
 | 
			
		||||
    assert module
 | 
			
		||||
    assert module._device != dummy_device
 | 
			
		||||
    assert module._device._parent == dummy_device
 | 
			
		||||
 | 
			
		||||
    # Invalid modules
 | 
			
		||||
    module = dummy_device.get_module("DummyModule")
 | 
			
		||||
    assert module is None
 | 
			
		||||
 | 
			
		||||
    module = dummy_device.get_module(AmbientLight)
 | 
			
		||||
    assert module is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@bulb_smart
 | 
			
		||||
async def test_smartdevice_brightness(dev: SmartDevice):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user