mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Allow getting Annotated features from modules (#1018)
Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
parent
cae9decb02
commit
37cc4da7b6
@ -42,10 +42,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import cache
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Final,
|
Final,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .exceptions import KasaException
|
from .exceptions import KasaException
|
||||||
@ -64,6 +67,10 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
ModuleT = TypeVar("ModuleT", bound="Module")
|
ModuleT = TypeVar("ModuleT", bound="Module")
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureAttribute:
|
||||||
|
"""Class for annotating attributes bound to feature."""
|
||||||
|
|
||||||
|
|
||||||
class Module(ABC):
|
class Module(ABC):
|
||||||
"""Base class implemention for all modules.
|
"""Base class implemention for all modules.
|
||||||
|
|
||||||
@ -140,6 +147,14 @@ class Module(ABC):
|
|||||||
self._module = module
|
self._module = module
|
||||||
self._module_features: dict[str, Feature] = {}
|
self._module_features: dict[str, Feature] = {}
|
||||||
|
|
||||||
|
def has_feature(self, attribute: str | property | Callable) -> bool:
|
||||||
|
"""Return True if the module attribute feature is supported."""
|
||||||
|
return bool(self.get_feature(attribute))
|
||||||
|
|
||||||
|
def get_feature(self, attribute: str | property | Callable) -> Feature | None:
|
||||||
|
"""Get Feature for a module attribute or None if not supported."""
|
||||||
|
return _get_bound_feature(self, attribute)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def query(self) -> dict:
|
def query(self) -> dict:
|
||||||
"""Query to execute during the update cycle.
|
"""Query to execute during the update cycle.
|
||||||
@ -183,3 +198,60 @@ class Module(ABC):
|
|||||||
f"<Module {self.__class__.__name__} ({self._module})"
|
f"<Module {self.__class__.__name__} ({self._module})"
|
||||||
f" for {self._device.host}>"
|
f" for {self._device.host}>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_bound_feature(attribute: property | Callable) -> bool:
|
||||||
|
"""Check if an attribute is bound to a feature with FeatureAttribute."""
|
||||||
|
if isinstance(attribute, property):
|
||||||
|
hints = get_type_hints(attribute.fget, include_extras=True)
|
||||||
|
else:
|
||||||
|
hints = get_type_hints(attribute, include_extras=True)
|
||||||
|
|
||||||
|
if (return_hints := hints.get("return")) and hasattr(return_hints, "__metadata__"):
|
||||||
|
metadata = hints["return"].__metadata__
|
||||||
|
for meta in metadata:
|
||||||
|
if isinstance(meta, FeatureAttribute):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def _get_bound_feature(
|
||||||
|
module: Module, attribute: str | property | Callable
|
||||||
|
) -> Feature | None:
|
||||||
|
"""Get Feature for a bound property or None if not supported."""
|
||||||
|
if not isinstance(attribute, str):
|
||||||
|
if isinstance(attribute, property):
|
||||||
|
# Properties have __name__ in 3.13 so this could be simplified
|
||||||
|
# when only 3.13 supported
|
||||||
|
attribute_name = attribute.fget.__name__ # type: ignore[union-attr]
|
||||||
|
else:
|
||||||
|
attribute_name = attribute.__name__
|
||||||
|
attribute_callable = attribute
|
||||||
|
else:
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert isinstance(attribute, str)
|
||||||
|
attribute_name = attribute
|
||||||
|
attribute_callable = getattr(module.__class__, attribute, None) # type: ignore[assignment]
|
||||||
|
if not attribute_callable:
|
||||||
|
raise KasaException(
|
||||||
|
f"No attribute named {attribute_name} in "
|
||||||
|
f"module {module.__class__.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _is_bound_feature(attribute_callable):
|
||||||
|
raise KasaException(
|
||||||
|
f"Attribute {attribute_name} of module {module.__class__.__name__}"
|
||||||
|
" is not bound to a feature"
|
||||||
|
)
|
||||||
|
|
||||||
|
check = {attribute_name, attribute_callable}
|
||||||
|
for feature in module._module_features.values():
|
||||||
|
if (getter := feature.attribute_getter) and getter in check:
|
||||||
|
return feature
|
||||||
|
|
||||||
|
if (setter := feature.attribute_setter) and setter in check:
|
||||||
|
return feature
|
||||||
|
|
||||||
|
return None
|
||||||
|
@ -2,8 +2,11 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
from ...feature import Feature
|
from ...feature import Feature
|
||||||
from ...interfaces.fan import Fan as FanInterface
|
from ...interfaces.fan import Fan as FanInterface
|
||||||
|
from ...module import FeatureAttribute
|
||||||
from ..smartmodule import SmartModule
|
from ..smartmodule import SmartModule
|
||||||
|
|
||||||
|
|
||||||
@ -46,11 +49,13 @@ class Fan(SmartModule, FanInterface):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fan_speed_level(self) -> int:
|
def fan_speed_level(self) -> Annotated[int, FeatureAttribute()]:
|
||||||
"""Return fan speed level."""
|
"""Return fan speed level."""
|
||||||
return 0 if self.data["device_on"] is False else self.data["fan_speed_level"]
|
return 0 if self.data["device_on"] is False else self.data["fan_speed_level"]
|
||||||
|
|
||||||
async def set_fan_speed_level(self, level: int) -> dict:
|
async def set_fan_speed_level(
|
||||||
|
self, level: int
|
||||||
|
) -> Annotated[dict, FeatureAttribute()]:
|
||||||
"""Set fan speed level, 0 for off, 1-4 for on."""
|
"""Set fan speed level, 0 for off, 1-4 for on."""
|
||||||
if level < 0 or level > 4:
|
if level < 0 or level > 4:
|
||||||
raise ValueError("Invalid level, should be in range 0-4.")
|
raise ValueError("Invalid level, should be in range 0-4.")
|
||||||
@ -61,11 +66,11 @@ class Fan(SmartModule, FanInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sleep_mode(self) -> bool:
|
def sleep_mode(self) -> Annotated[bool, FeatureAttribute()]:
|
||||||
"""Return sleep mode status."""
|
"""Return sleep mode status."""
|
||||||
return self.data["fan_sleep_mode_on"]
|
return self.data["fan_sleep_mode_on"]
|
||||||
|
|
||||||
async def set_sleep_mode(self, on: bool) -> dict:
|
async def set_sleep_mode(self, on: bool) -> Annotated[dict, FeatureAttribute()]:
|
||||||
"""Set sleep mode."""
|
"""Set sleep mode."""
|
||||||
return await self.call("set_device_info", {"fan_sleep_mode_on": on})
|
return await self.call("set_device_info", {"fan_sleep_mode_on": on})
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from kasa import Module
|
from kasa import KasaException, Module
|
||||||
from kasa.smart import SmartDevice
|
from kasa.smart import SmartDevice
|
||||||
|
from kasa.smart.modules import Fan
|
||||||
|
|
||||||
from ...device_fixtures import get_parent_and_child_modules, parametrize
|
from ...device_fixtures import get_parent_and_child_modules, parametrize
|
||||||
|
|
||||||
@ -77,8 +78,42 @@ async def test_fan_module(dev: SmartDevice, mocker: MockerFixture):
|
|||||||
await dev.update()
|
await dev.update()
|
||||||
assert not device.is_on
|
assert not device.is_on
|
||||||
|
|
||||||
|
fan_speed_level_feature = fan._module_features["fan_speed_level"]
|
||||||
|
max_level = fan_speed_level_feature.maximum_value
|
||||||
|
min_level = fan_speed_level_feature.minimum_value
|
||||||
with pytest.raises(ValueError, match="Invalid level"):
|
with pytest.raises(ValueError, match="Invalid level"):
|
||||||
await fan.set_fan_speed_level(-1)
|
await fan.set_fan_speed_level(min_level - 1)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Invalid level"):
|
with pytest.raises(ValueError, match="Invalid level"):
|
||||||
await fan.set_fan_speed_level(5)
|
await fan.set_fan_speed_level(max_level - 5)
|
||||||
|
|
||||||
|
|
||||||
|
@fan
|
||||||
|
async def test_fan_features(dev: SmartDevice, mocker: MockerFixture):
|
||||||
|
"""Test fan speed on device interface."""
|
||||||
|
assert isinstance(dev, SmartDevice)
|
||||||
|
fan = next(get_parent_and_child_modules(dev, Module.Fan))
|
||||||
|
assert fan
|
||||||
|
expected_feature = fan._module_features["fan_speed_level"]
|
||||||
|
|
||||||
|
fan_speed_level_feature = fan.get_feature(Fan.set_fan_speed_level)
|
||||||
|
assert expected_feature == fan_speed_level_feature
|
||||||
|
|
||||||
|
fan_speed_level_feature = fan.get_feature(fan.set_fan_speed_level)
|
||||||
|
assert expected_feature == fan_speed_level_feature
|
||||||
|
|
||||||
|
fan_speed_level_feature = fan.get_feature(Fan.fan_speed_level)
|
||||||
|
assert expected_feature == fan_speed_level_feature
|
||||||
|
|
||||||
|
fan_speed_level_feature = fan.get_feature("fan_speed_level")
|
||||||
|
assert expected_feature == fan_speed_level_feature
|
||||||
|
|
||||||
|
assert fan.has_feature(Fan.fan_speed_level)
|
||||||
|
|
||||||
|
msg = "Attribute _check_supported of module Fan is not bound to a feature"
|
||||||
|
with pytest.raises(KasaException, match=msg):
|
||||||
|
assert fan.has_feature(fan._check_supported)
|
||||||
|
|
||||||
|
msg = "No attribute named foobar in module Fan"
|
||||||
|
with pytest.raises(KasaException, match=msg):
|
||||||
|
assert fan.has_feature("foobar")
|
||||||
|
Loading…
Reference in New Issue
Block a user