mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-08 22:07:06 +00:00
Make get_module return typed module (#892)
Passing in a string still works and returns either `IotModule` or `SmartModule` type when called on `IotDevice` or `SmartDevice` respectively. When calling on `Device` will return `Module` type. Passing in a module type is then typed to that module, i.e.: ```py smartdev.get_module(FanModule) # type is FanModule smartdev.get_module("FanModule") # type is SmartModule ``` Only thing this doesn't do is check that you can't pass an `IotModule` to a `SmartDevice.get_module()`. However there is a runtime check which will return null if the passed `ModuleType` is not a subclass of `SmartModule`. Many thanks to @cdce8p for helping with this.
This commit is contained in:
parent
530fb841b0
commit
c5d65b624b
@ -6,7 +6,7 @@ import logging
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Mapping, Sequence
|
from typing import Any, Mapping, Sequence, overload
|
||||||
|
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .device_type import DeviceType
|
from .device_type import DeviceType
|
||||||
@ -15,7 +15,7 @@ from .emeterstatus import EmeterStatus
|
|||||||
from .exceptions import KasaException
|
from .exceptions import KasaException
|
||||||
from .feature import Feature
|
from .feature import Feature
|
||||||
from .iotprotocol import IotProtocol
|
from .iotprotocol import IotProtocol
|
||||||
from .module import Module
|
from .module import Module, ModuleT
|
||||||
from .protocol import BaseProtocol
|
from .protocol import BaseProtocol
|
||||||
from .xortransport import XorTransport
|
from .xortransport import XorTransport
|
||||||
|
|
||||||
@ -116,6 +116,18 @@ class Device(ABC):
|
|||||||
def modules(self) -> Mapping[str, Module]:
|
def modules(self) -> Mapping[str, Module]:
|
||||||
"""Return the device modules."""
|
"""Return the device modules."""
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@abstractmethod
|
||||||
|
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@abstractmethod
|
||||||
|
def get_module(self, module_type: str) -> Module | None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_module(self, module_type: type[ModuleT] | str) -> ModuleT | Module | None:
|
||||||
|
"""Return the module from the device modules or None if not present."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_on(self) -> bool:
|
def is_on(self) -> bool:
|
||||||
|
@ -19,13 +19,14 @@ import functools
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Mapping, Sequence, cast
|
from typing import Any, Mapping, Sequence, cast, overload
|
||||||
|
|
||||||
from ..device import Device, WifiNetwork
|
from ..device import Device, WifiNetwork
|
||||||
from ..deviceconfig import DeviceConfig
|
from ..deviceconfig import DeviceConfig
|
||||||
from ..emeterstatus import EmeterStatus
|
from ..emeterstatus import EmeterStatus
|
||||||
from ..exceptions import KasaException
|
from ..exceptions import KasaException
|
||||||
from ..feature import Feature
|
from ..feature import Feature
|
||||||
|
from ..module import ModuleT
|
||||||
from ..protocol import BaseProtocol
|
from ..protocol import BaseProtocol
|
||||||
from .iotmodule import IotModule
|
from .iotmodule import IotModule
|
||||||
from .modules import Emeter, Time
|
from .modules import Emeter, Time
|
||||||
@ -201,6 +202,26 @@ class IotDevice(Device):
|
|||||||
"""Return the device modules."""
|
"""Return the device modules."""
|
||||||
return self._modules
|
return self._modules
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_module(self, module_type: str) -> IotModule | None: ...
|
||||||
|
|
||||||
|
def get_module(
|
||||||
|
self, module_type: type[ModuleT] | str
|
||||||
|
) -> ModuleT | IotModule | None:
|
||||||
|
"""Return the module from the device modules or None if not present."""
|
||||||
|
if isinstance(module_type, str):
|
||||||
|
module_name = module_type.lower()
|
||||||
|
elif issubclass(module_type, IotModule):
|
||||||
|
module_name = module_type.__name__.lower()
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
if module_name in self.modules:
|
||||||
|
return self.modules[module_name]
|
||||||
|
return None
|
||||||
|
|
||||||
def add_module(self, name: str, module: IotModule):
|
def add_module(self, name: str, module: IotModule):
|
||||||
"""Register a module."""
|
"""Register a module."""
|
||||||
if name in self.modules:
|
if name in self.modules:
|
||||||
|
@ -4,7 +4,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
from .exceptions import KasaException
|
from .exceptions import KasaException
|
||||||
from .feature import Feature
|
from .feature import Feature
|
||||||
@ -14,6 +17,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ModuleT = TypeVar("ModuleT", bound="Module")
|
||||||
|
|
||||||
|
|
||||||
class Module(ABC):
|
class Module(ABC):
|
||||||
"""Base class implemention for all modules.
|
"""Base class implemention for all modules.
|
||||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast
|
from typing import Any, Mapping, Sequence, cast, overload
|
||||||
|
|
||||||
from ..aestransport import AesTransport
|
from ..aestransport import AesTransport
|
||||||
from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange
|
from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange
|
||||||
@ -16,6 +16,7 @@ from ..emeterstatus import EmeterStatus
|
|||||||
from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode
|
from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode
|
||||||
from ..fan import Fan
|
from ..fan import Fan
|
||||||
from ..feature import Feature
|
from ..feature import Feature
|
||||||
|
from ..module import ModuleT
|
||||||
from ..smartprotocol import SmartProtocol
|
from ..smartprotocol import SmartProtocol
|
||||||
from .modules import (
|
from .modules import (
|
||||||
Brightness,
|
Brightness,
|
||||||
@ -28,11 +29,10 @@ from .modules import (
|
|||||||
Firmware,
|
Firmware,
|
||||||
TimeModule,
|
TimeModule,
|
||||||
)
|
)
|
||||||
|
from .smartmodule import SmartModule
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .smartmodule import SmartModule
|
|
||||||
|
|
||||||
# List of modules that wall switches with children, i.e. ks240 report on
|
# List of modules that wall switches with children, i.e. ks240 report on
|
||||||
# the child but only work on the parent. See longer note below in _initialize_modules.
|
# the child but only work on the parent. See longer note below in _initialize_modules.
|
||||||
@ -305,8 +305,22 @@ class SmartDevice(Bulb, Fan, Device):
|
|||||||
for feat in module._module_features.values():
|
for feat in module._module_features.values():
|
||||||
self._add_feature(feat)
|
self._add_feature(feat)
|
||||||
|
|
||||||
def get_module(self, module_name) -> SmartModule | None:
|
@overload
|
||||||
|
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_module(self, module_type: str) -> SmartModule | None: ...
|
||||||
|
|
||||||
|
def get_module(
|
||||||
|
self, module_type: type[ModuleT] | str
|
||||||
|
) -> ModuleT | SmartModule | None:
|
||||||
"""Return the module from the device modules or None if not present."""
|
"""Return the module from the device modules or None if not present."""
|
||||||
|
if isinstance(module_type, str):
|
||||||
|
module_name = module_type
|
||||||
|
elif issubclass(module_type, SmartModule):
|
||||||
|
module_name = module_type.__name__
|
||||||
|
else:
|
||||||
|
return None
|
||||||
if module_name in self.modules:
|
if module_name in self.modules:
|
||||||
return self.modules[module_name]
|
return self.modules[module_name]
|
||||||
elif self._exposes_child_modules:
|
elif self._exposes_child_modules:
|
||||||
|
@ -33,7 +33,7 @@ async def test_brightness_component(dev: SmartDevice):
|
|||||||
|
|
||||||
|
|
||||||
@dimmable
|
@dimmable
|
||||||
async def test_brightness_dimmable(dev: SmartDevice):
|
async def test_brightness_dimmable(dev: IotDevice):
|
||||||
"""Test brightness feature."""
|
"""Test brightness feature."""
|
||||||
assert isinstance(dev, IotDevice)
|
assert isinstance(dev, IotDevice)
|
||||||
assert "brightness" in dev.sys_info or bool(dev.sys_info["is_dimmable"])
|
assert "brightness" in dev.sys_info or bool(dev.sys_info["is_dimmable"])
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
@ -13,7 +11,7 @@ fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"S
|
|||||||
@fan
|
@fan
|
||||||
async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
|
async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
|
||||||
"""Test fan speed feature."""
|
"""Test fan speed feature."""
|
||||||
fan = cast(FanModule, dev.get_module("FanModule"))
|
fan = dev.get_module(FanModule)
|
||||||
assert fan
|
assert fan
|
||||||
|
|
||||||
level_feature = fan._module_features["fan_speed_level"]
|
level_feature = fan._module_features["fan_speed_level"]
|
||||||
@ -38,7 +36,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
|
|||||||
@fan
|
@fan
|
||||||
async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
|
async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
|
||||||
"""Test sleep mode feature."""
|
"""Test sleep mode feature."""
|
||||||
fan = cast(FanModule, dev.get_module("FanModule"))
|
fan = dev.get_module(FanModule)
|
||||||
assert fan
|
assert fan
|
||||||
sleep_feature = fan._module_features["fan_sleep_mode"]
|
sleep_feature = fan._module_features["fan_sleep_mode"]
|
||||||
assert isinstance(sleep_feature.value, bool)
|
assert isinstance(sleep_feature.value, bool)
|
||||||
@ -57,7 +55,8 @@ async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
|
|||||||
async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture):
|
async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture):
|
||||||
"""Test fan speed on device interface."""
|
"""Test fan speed on device interface."""
|
||||||
assert isinstance(dev, SmartDevice)
|
assert isinstance(dev, SmartDevice)
|
||||||
fan = cast(FanModule, dev.get_module("FanModule"))
|
fan = dev.get_module(FanModule)
|
||||||
|
assert fan
|
||||||
device = fan._device
|
device = fan._device
|
||||||
assert device.is_fan
|
assert device.is_fan
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ from voluptuous import (
|
|||||||
from kasa import KasaException
|
from kasa import KasaException
|
||||||
from kasa.iot import IotDevice
|
from kasa.iot import IotDevice
|
||||||
|
|
||||||
from .conftest import handle_turn_on, turn_on
|
from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on
|
||||||
from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot
|
from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot
|
||||||
from .fakeprotocol_iot import FakeIotProtocol
|
from .fakeprotocol_iot import FakeIotProtocol
|
||||||
|
|
||||||
@ -258,3 +258,30 @@ async def test_modules_not_supported(dev: IotDevice):
|
|||||||
await dev.update()
|
await dev.update()
|
||||||
for module in dev.modules.values():
|
for module in dev.modules.values():
|
||||||
assert module.is_supported is not None
|
assert module.is_supported is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_modules():
|
||||||
|
"""Test get_modules for child and parent modules."""
|
||||||
|
dummy_device = await get_device_for_fixture_protocol(
|
||||||
|
"HS100(US)_2.0_1.5.6.json", "IOT"
|
||||||
|
)
|
||||||
|
from kasa.iot.modules import Cloud
|
||||||
|
from kasa.smart.modules import CloudModule
|
||||||
|
|
||||||
|
# Modules on device
|
||||||
|
module = dummy_device.get_module("Cloud")
|
||||||
|
assert module
|
||||||
|
assert module._device == dummy_device
|
||||||
|
assert isinstance(module, Cloud)
|
||||||
|
|
||||||
|
module = dummy_device.get_module(Cloud)
|
||||||
|
assert module
|
||||||
|
assert module._device == dummy_device
|
||||||
|
assert isinstance(module, Cloud)
|
||||||
|
|
||||||
|
# Invalid modules
|
||||||
|
module = dummy_device.get_module("DummyModule")
|
||||||
|
assert module is None
|
||||||
|
|
||||||
|
module = dummy_device.get_module(CloudModule)
|
||||||
|
assert module is None
|
||||||
|
@ -122,23 +122,43 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
|
|||||||
spies[device].assert_not_called()
|
spies[device].assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
async def test_get_modules(mocker):
|
async def test_get_modules():
|
||||||
"""Test get_modules for child and parent modules."""
|
"""Test get_modules for child and parent modules."""
|
||||||
dummy_device = await get_device_for_fixture_protocol(
|
dummy_device = await get_device_for_fixture_protocol(
|
||||||
"KS240(US)_1.0_1.0.5.json", "SMART"
|
"KS240(US)_1.0_1.0.5.json", "SMART"
|
||||||
)
|
)
|
||||||
|
from kasa.iot.modules import AmbientLight
|
||||||
|
from kasa.smart.modules import CloudModule, FanModule
|
||||||
|
|
||||||
|
# Modules on device
|
||||||
module = dummy_device.get_module("CloudModule")
|
module = dummy_device.get_module("CloudModule")
|
||||||
assert module
|
assert module
|
||||||
assert module._device == dummy_device
|
assert module._device == dummy_device
|
||||||
|
assert isinstance(module, CloudModule)
|
||||||
|
|
||||||
|
module = dummy_device.get_module(CloudModule)
|
||||||
|
assert module
|
||||||
|
assert module._device == dummy_device
|
||||||
|
assert isinstance(module, CloudModule)
|
||||||
|
|
||||||
|
# Modules on child
|
||||||
module = dummy_device.get_module("FanModule")
|
module = dummy_device.get_module("FanModule")
|
||||||
assert module
|
assert module
|
||||||
assert module._device != dummy_device
|
assert module._device != dummy_device
|
||||||
assert module._device._parent == dummy_device
|
assert module._device._parent == dummy_device
|
||||||
|
|
||||||
|
module = dummy_device.get_module(FanModule)
|
||||||
|
assert module
|
||||||
|
assert module._device != dummy_device
|
||||||
|
assert module._device._parent == dummy_device
|
||||||
|
|
||||||
|
# Invalid modules
|
||||||
module = dummy_device.get_module("DummyModule")
|
module = dummy_device.get_module("DummyModule")
|
||||||
assert module is None
|
assert module is None
|
||||||
|
|
||||||
|
module = dummy_device.get_module(AmbientLight)
|
||||||
|
assert module is None
|
||||||
|
|
||||||
|
|
||||||
@bulb_smart
|
@bulb_smart
|
||||||
async def test_smartdevice_brightness(dev: SmartDevice):
|
async def test_smartdevice_brightness(dev: SmartDevice):
|
||||||
|
Loading…
Reference in New Issue
Block a user