Allow getting Annotated features from modules (#1018)

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
Steven B. 2024-11-22 07:52:23 +00:00 committed by GitHub
parent cae9decb02
commit 37cc4da7b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 119 additions and 7 deletions

View File

@ -42,10 +42,13 @@ from __future__ import annotations
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import cache
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Final, Final,
TypeVar, TypeVar,
get_type_hints,
) )
from .exceptions import KasaException from .exceptions import KasaException
@ -64,6 +67,10 @@ _LOGGER = logging.getLogger(__name__)
ModuleT = TypeVar("ModuleT", bound="Module") ModuleT = TypeVar("ModuleT", bound="Module")
class FeatureAttribute:
"""Class for annotating attributes bound to feature."""
class Module(ABC): class Module(ABC):
"""Base class implemention for all modules. """Base class implemention for all modules.
@ -140,6 +147,14 @@ class Module(ABC):
self._module = module self._module = module
self._module_features: dict[str, Feature] = {} self._module_features: dict[str, Feature] = {}
def has_feature(self, attribute: str | property | Callable) -> bool:
"""Return True if the module attribute feature is supported."""
return bool(self.get_feature(attribute))
def get_feature(self, attribute: str | property | Callable) -> Feature | None:
"""Get Feature for a module attribute or None if not supported."""
return _get_bound_feature(self, attribute)
@abstractmethod @abstractmethod
def query(self) -> dict: def query(self) -> dict:
"""Query to execute during the update cycle. """Query to execute during the update cycle.
@ -183,3 +198,60 @@ class Module(ABC):
f"<Module {self.__class__.__name__} ({self._module})" f"<Module {self.__class__.__name__} ({self._module})"
f" for {self._device.host}>" f" for {self._device.host}>"
) )
def _is_bound_feature(attribute: property | Callable) -> bool:
"""Check if an attribute is bound to a feature with FeatureAttribute."""
if isinstance(attribute, property):
hints = get_type_hints(attribute.fget, include_extras=True)
else:
hints = get_type_hints(attribute, include_extras=True)
if (return_hints := hints.get("return")) and hasattr(return_hints, "__metadata__"):
metadata = hints["return"].__metadata__
for meta in metadata:
if isinstance(meta, FeatureAttribute):
return True
return False
@cache
def _get_bound_feature(
module: Module, attribute: str | property | Callable
) -> Feature | None:
"""Get Feature for a bound property or None if not supported."""
if not isinstance(attribute, str):
if isinstance(attribute, property):
# Properties have __name__ in 3.13 so this could be simplified
# when only 3.13 supported
attribute_name = attribute.fget.__name__ # type: ignore[union-attr]
else:
attribute_name = attribute.__name__
attribute_callable = attribute
else:
if TYPE_CHECKING:
assert isinstance(attribute, str)
attribute_name = attribute
attribute_callable = getattr(module.__class__, attribute, None) # type: ignore[assignment]
if not attribute_callable:
raise KasaException(
f"No attribute named {attribute_name} in "
f"module {module.__class__.__name__}"
)
if not _is_bound_feature(attribute_callable):
raise KasaException(
f"Attribute {attribute_name} of module {module.__class__.__name__}"
" is not bound to a feature"
)
check = {attribute_name, attribute_callable}
for feature in module._module_features.values():
if (getter := feature.attribute_getter) and getter in check:
return feature
if (setter := feature.attribute_setter) and setter in check:
return feature
return None

View File

@ -2,8 +2,11 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from ...feature import Feature from ...feature import Feature
from ...interfaces.fan import Fan as FanInterface from ...interfaces.fan import Fan as FanInterface
from ...module import FeatureAttribute
from ..smartmodule import SmartModule from ..smartmodule import SmartModule
@ -46,11 +49,13 @@ class Fan(SmartModule, FanInterface):
return {} return {}
@property @property
def fan_speed_level(self) -> int: def fan_speed_level(self) -> Annotated[int, FeatureAttribute()]:
"""Return fan speed level.""" """Return fan speed level."""
return 0 if self.data["device_on"] is False else self.data["fan_speed_level"] return 0 if self.data["device_on"] is False else self.data["fan_speed_level"]
async def set_fan_speed_level(self, level: int) -> dict: async def set_fan_speed_level(
self, level: int
) -> Annotated[dict, FeatureAttribute()]:
"""Set fan speed level, 0 for off, 1-4 for on.""" """Set fan speed level, 0 for off, 1-4 for on."""
if level < 0 or level > 4: if level < 0 or level > 4:
raise ValueError("Invalid level, should be in range 0-4.") raise ValueError("Invalid level, should be in range 0-4.")
@ -61,11 +66,11 @@ class Fan(SmartModule, FanInterface):
) )
@property @property
def sleep_mode(self) -> bool: def sleep_mode(self) -> Annotated[bool, FeatureAttribute()]:
"""Return sleep mode status.""" """Return sleep mode status."""
return self.data["fan_sleep_mode_on"] return self.data["fan_sleep_mode_on"]
async def set_sleep_mode(self, on: bool) -> dict: async def set_sleep_mode(self, on: bool) -> Annotated[dict, FeatureAttribute()]:
"""Set sleep mode.""" """Set sleep mode."""
return await self.call("set_device_info", {"fan_sleep_mode_on": on}) return await self.call("set_device_info", {"fan_sleep_mode_on": on})

View File

@ -1,8 +1,9 @@
import pytest import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from kasa import Module from kasa import KasaException, Module
from kasa.smart import SmartDevice from kasa.smart import SmartDevice
from kasa.smart.modules import Fan
from ...device_fixtures import get_parent_and_child_modules, parametrize from ...device_fixtures import get_parent_and_child_modules, parametrize
@ -77,8 +78,42 @@ async def test_fan_module(dev: SmartDevice, mocker: MockerFixture):
await dev.update() await dev.update()
assert not device.is_on assert not device.is_on
fan_speed_level_feature = fan._module_features["fan_speed_level"]
max_level = fan_speed_level_feature.maximum_value
min_level = fan_speed_level_feature.minimum_value
with pytest.raises(ValueError, match="Invalid level"): with pytest.raises(ValueError, match="Invalid level"):
await fan.set_fan_speed_level(-1) await fan.set_fan_speed_level(min_level - 1)
with pytest.raises(ValueError, match="Invalid level"): with pytest.raises(ValueError, match="Invalid level"):
await fan.set_fan_speed_level(5) await fan.set_fan_speed_level(max_level - 5)
@fan
async def test_fan_features(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed on device interface."""
assert isinstance(dev, SmartDevice)
fan = next(get_parent_and_child_modules(dev, Module.Fan))
assert fan
expected_feature = fan._module_features["fan_speed_level"]
fan_speed_level_feature = fan.get_feature(Fan.set_fan_speed_level)
assert expected_feature == fan_speed_level_feature
fan_speed_level_feature = fan.get_feature(fan.set_fan_speed_level)
assert expected_feature == fan_speed_level_feature
fan_speed_level_feature = fan.get_feature(Fan.fan_speed_level)
assert expected_feature == fan_speed_level_feature
fan_speed_level_feature = fan.get_feature("fan_speed_level")
assert expected_feature == fan_speed_level_feature
assert fan.has_feature(Fan.fan_speed_level)
msg = "Attribute _check_supported of module Fan is not bound to a feature"
with pytest.raises(KasaException, match=msg):
assert fan.has_feature(fan._check_supported)
msg = "No attribute named foobar in module Fan"
with pytest.raises(KasaException, match=msg):
assert fan.has_feature("foobar")