From 37cc4da7b6b536df3e67bda99d9ccd2532a58472 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Fri, 22 Nov 2024 07:52:23 +0000 Subject: [PATCH] Allow getting Annotated features from modules (#1018) Co-authored-by: Teemu R. --- kasa/module.py | 72 +++++++++++++++++++++++++++++++++ kasa/smart/modules/fan.py | 13 ++++-- tests/smart/modules/test_fan.py | 41 +++++++++++++++++-- 3 files changed, 119 insertions(+), 7 deletions(-) diff --git a/kasa/module.py b/kasa/module.py index f3d0dade..ccd22d4e 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -42,10 +42,13 @@ from __future__ import annotations import logging from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import cache from typing import ( TYPE_CHECKING, Final, TypeVar, + get_type_hints, ) from .exceptions import KasaException @@ -64,6 +67,10 @@ _LOGGER = logging.getLogger(__name__) ModuleT = TypeVar("ModuleT", bound="Module") +class FeatureAttribute: + """Class for annotating attributes bound to feature.""" + + class Module(ABC): """Base class implemention for all modules. @@ -140,6 +147,14 @@ class Module(ABC): self._module = module self._module_features: dict[str, Feature] = {} + def has_feature(self, attribute: str | property | Callable) -> bool: + """Return True if the module attribute feature is supported.""" + return bool(self.get_feature(attribute)) + + def get_feature(self, attribute: str | property | Callable) -> Feature | None: + """Get Feature for a module attribute or None if not supported.""" + return _get_bound_feature(self, attribute) + @abstractmethod def query(self) -> dict: """Query to execute during the update cycle. @@ -183,3 +198,60 @@ class Module(ABC): f"" ) + + +def _is_bound_feature(attribute: property | Callable) -> bool: + """Check if an attribute is bound to a feature with FeatureAttribute.""" + if isinstance(attribute, property): + hints = get_type_hints(attribute.fget, include_extras=True) + else: + hints = get_type_hints(attribute, include_extras=True) + + if (return_hints := hints.get("return")) and hasattr(return_hints, "__metadata__"): + metadata = hints["return"].__metadata__ + for meta in metadata: + if isinstance(meta, FeatureAttribute): + return True + + return False + + +@cache +def _get_bound_feature( + module: Module, attribute: str | property | Callable +) -> Feature | None: + """Get Feature for a bound property or None if not supported.""" + if not isinstance(attribute, str): + if isinstance(attribute, property): + # Properties have __name__ in 3.13 so this could be simplified + # when only 3.13 supported + attribute_name = attribute.fget.__name__ # type: ignore[union-attr] + else: + attribute_name = attribute.__name__ + attribute_callable = attribute + else: + if TYPE_CHECKING: + assert isinstance(attribute, str) + attribute_name = attribute + attribute_callable = getattr(module.__class__, attribute, None) # type: ignore[assignment] + if not attribute_callable: + raise KasaException( + f"No attribute named {attribute_name} in " + f"module {module.__class__.__name__}" + ) + + if not _is_bound_feature(attribute_callable): + raise KasaException( + f"Attribute {attribute_name} of module {module.__class__.__name__}" + " is not bound to a feature" + ) + + check = {attribute_name, attribute_callable} + for feature in module._module_features.values(): + if (getter := feature.attribute_getter) and getter in check: + return feature + + if (setter := feature.attribute_setter) and setter in check: + return feature + + return None diff --git a/kasa/smart/modules/fan.py b/kasa/smart/modules/fan.py index 36b3aadf..6443cbac 100644 --- a/kasa/smart/modules/fan.py +++ b/kasa/smart/modules/fan.py @@ -2,8 +2,11 @@ from __future__ import annotations +from typing import Annotated + from ...feature import Feature from ...interfaces.fan import Fan as FanInterface +from ...module import FeatureAttribute from ..smartmodule import SmartModule @@ -46,11 +49,13 @@ class Fan(SmartModule, FanInterface): return {} @property - def fan_speed_level(self) -> int: + def fan_speed_level(self) -> Annotated[int, FeatureAttribute()]: """Return fan speed level.""" return 0 if self.data["device_on"] is False else self.data["fan_speed_level"] - async def set_fan_speed_level(self, level: int) -> dict: + async def set_fan_speed_level( + self, level: int + ) -> Annotated[dict, FeatureAttribute()]: """Set fan speed level, 0 for off, 1-4 for on.""" if level < 0 or level > 4: raise ValueError("Invalid level, should be in range 0-4.") @@ -61,11 +66,11 @@ class Fan(SmartModule, FanInterface): ) @property - def sleep_mode(self) -> bool: + def sleep_mode(self) -> Annotated[bool, FeatureAttribute()]: """Return sleep mode status.""" return self.data["fan_sleep_mode_on"] - async def set_sleep_mode(self, on: bool) -> dict: + async def set_sleep_mode(self, on: bool) -> Annotated[dict, FeatureAttribute()]: """Set sleep mode.""" return await self.call("set_device_info", {"fan_sleep_mode_on": on}) diff --git a/tests/smart/modules/test_fan.py b/tests/smart/modules/test_fan.py index a032794c..9a6878e5 100644 --- a/tests/smart/modules/test_fan.py +++ b/tests/smart/modules/test_fan.py @@ -1,8 +1,9 @@ import pytest from pytest_mock import MockerFixture -from kasa import Module +from kasa import KasaException, Module from kasa.smart import SmartDevice +from kasa.smart.modules import Fan from ...device_fixtures import get_parent_and_child_modules, parametrize @@ -77,8 +78,42 @@ async def test_fan_module(dev: SmartDevice, mocker: MockerFixture): await dev.update() assert not device.is_on + fan_speed_level_feature = fan._module_features["fan_speed_level"] + max_level = fan_speed_level_feature.maximum_value + min_level = fan_speed_level_feature.minimum_value with pytest.raises(ValueError, match="Invalid level"): - await fan.set_fan_speed_level(-1) + await fan.set_fan_speed_level(min_level - 1) with pytest.raises(ValueError, match="Invalid level"): - await fan.set_fan_speed_level(5) + await fan.set_fan_speed_level(max_level - 5) + + +@fan +async def test_fan_features(dev: SmartDevice, mocker: MockerFixture): + """Test fan speed on device interface.""" + assert isinstance(dev, SmartDevice) + fan = next(get_parent_and_child_modules(dev, Module.Fan)) + assert fan + expected_feature = fan._module_features["fan_speed_level"] + + fan_speed_level_feature = fan.get_feature(Fan.set_fan_speed_level) + assert expected_feature == fan_speed_level_feature + + fan_speed_level_feature = fan.get_feature(fan.set_fan_speed_level) + assert expected_feature == fan_speed_level_feature + + fan_speed_level_feature = fan.get_feature(Fan.fan_speed_level) + assert expected_feature == fan_speed_level_feature + + fan_speed_level_feature = fan.get_feature("fan_speed_level") + assert expected_feature == fan_speed_level_feature + + assert fan.has_feature(Fan.fan_speed_level) + + msg = "Attribute _check_supported of module Fan is not bound to a feature" + with pytest.raises(KasaException, match=msg): + assert fan.has_feature(fan._check_supported) + + msg = "No attribute named foobar in module Fan" + with pytest.raises(KasaException, match=msg): + assert fan.has_feature("foobar")