Allow getting Annotated features from modules (#1018)

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
Steven B. 2024-11-22 07:52:23 +00:00 committed by GitHub
parent cae9decb02
commit 37cc4da7b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 119 additions and 7 deletions

View File

@ -42,10 +42,13 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import cache
from typing import (
TYPE_CHECKING,
Final,
TypeVar,
get_type_hints,
)
from .exceptions import KasaException
@ -64,6 +67,10 @@ _LOGGER = logging.getLogger(__name__)
ModuleT = TypeVar("ModuleT", bound="Module")
class FeatureAttribute:
"""Class for annotating attributes bound to feature."""
class Module(ABC):
"""Base class implemention for all modules.
@ -140,6 +147,14 @@ class Module(ABC):
self._module = module
self._module_features: dict[str, Feature] = {}
def has_feature(self, attribute: str | property | Callable) -> bool:
"""Return True if the module attribute feature is supported."""
return bool(self.get_feature(attribute))
def get_feature(self, attribute: str | property | Callable) -> Feature | None:
"""Get Feature for a module attribute or None if not supported."""
return _get_bound_feature(self, attribute)
@abstractmethod
def query(self) -> dict:
"""Query to execute during the update cycle.
@ -183,3 +198,60 @@ class Module(ABC):
f"<Module {self.__class__.__name__} ({self._module})"
f" for {self._device.host}>"
)
def _is_bound_feature(attribute: property | Callable) -> bool:
"""Check if an attribute is bound to a feature with FeatureAttribute."""
if isinstance(attribute, property):
hints = get_type_hints(attribute.fget, include_extras=True)
else:
hints = get_type_hints(attribute, include_extras=True)
if (return_hints := hints.get("return")) and hasattr(return_hints, "__metadata__"):
metadata = hints["return"].__metadata__
for meta in metadata:
if isinstance(meta, FeatureAttribute):
return True
return False
@cache
def _get_bound_feature(
module: Module, attribute: str | property | Callable
) -> Feature | None:
"""Get Feature for a bound property or None if not supported."""
if not isinstance(attribute, str):
if isinstance(attribute, property):
# Properties have __name__ in 3.13 so this could be simplified
# when only 3.13 supported
attribute_name = attribute.fget.__name__ # type: ignore[union-attr]
else:
attribute_name = attribute.__name__
attribute_callable = attribute
else:
if TYPE_CHECKING:
assert isinstance(attribute, str)
attribute_name = attribute
attribute_callable = getattr(module.__class__, attribute, None) # type: ignore[assignment]
if not attribute_callable:
raise KasaException(
f"No attribute named {attribute_name} in "
f"module {module.__class__.__name__}"
)
if not _is_bound_feature(attribute_callable):
raise KasaException(
f"Attribute {attribute_name} of module {module.__class__.__name__}"
" is not bound to a feature"
)
check = {attribute_name, attribute_callable}
for feature in module._module_features.values():
if (getter := feature.attribute_getter) and getter in check:
return feature
if (setter := feature.attribute_setter) and setter in check:
return feature
return None

View File

@ -2,8 +2,11 @@
from __future__ import annotations
from typing import Annotated
from ...feature import Feature
from ...interfaces.fan import Fan as FanInterface
from ...module import FeatureAttribute
from ..smartmodule import SmartModule
@ -46,11 +49,13 @@ class Fan(SmartModule, FanInterface):
return {}
@property
def fan_speed_level(self) -> int:
def fan_speed_level(self) -> Annotated[int, FeatureAttribute()]:
"""Return fan speed level."""
return 0 if self.data["device_on"] is False else self.data["fan_speed_level"]
async def set_fan_speed_level(self, level: int) -> dict:
async def set_fan_speed_level(
self, level: int
) -> Annotated[dict, FeatureAttribute()]:
"""Set fan speed level, 0 for off, 1-4 for on."""
if level < 0 or level > 4:
raise ValueError("Invalid level, should be in range 0-4.")
@ -61,11 +66,11 @@ class Fan(SmartModule, FanInterface):
)
@property
def sleep_mode(self) -> bool:
def sleep_mode(self) -> Annotated[bool, FeatureAttribute()]:
"""Return sleep mode status."""
return self.data["fan_sleep_mode_on"]
async def set_sleep_mode(self, on: bool) -> dict:
async def set_sleep_mode(self, on: bool) -> Annotated[dict, FeatureAttribute()]:
"""Set sleep mode."""
return await self.call("set_device_info", {"fan_sleep_mode_on": on})

View File

@ -1,8 +1,9 @@
import pytest
from pytest_mock import MockerFixture
from kasa import Module
from kasa import KasaException, Module
from kasa.smart import SmartDevice
from kasa.smart.modules import Fan
from ...device_fixtures import get_parent_and_child_modules, parametrize
@ -77,8 +78,42 @@ async def test_fan_module(dev: SmartDevice, mocker: MockerFixture):
await dev.update()
assert not device.is_on
fan_speed_level_feature = fan._module_features["fan_speed_level"]
max_level = fan_speed_level_feature.maximum_value
min_level = fan_speed_level_feature.minimum_value
with pytest.raises(ValueError, match="Invalid level"):
await fan.set_fan_speed_level(-1)
await fan.set_fan_speed_level(min_level - 1)
with pytest.raises(ValueError, match="Invalid level"):
await fan.set_fan_speed_level(5)
await fan.set_fan_speed_level(max_level - 5)
@fan
async def test_fan_features(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed on device interface."""
assert isinstance(dev, SmartDevice)
fan = next(get_parent_and_child_modules(dev, Module.Fan))
assert fan
expected_feature = fan._module_features["fan_speed_level"]
fan_speed_level_feature = fan.get_feature(Fan.set_fan_speed_level)
assert expected_feature == fan_speed_level_feature
fan_speed_level_feature = fan.get_feature(fan.set_fan_speed_level)
assert expected_feature == fan_speed_level_feature
fan_speed_level_feature = fan.get_feature(Fan.fan_speed_level)
assert expected_feature == fan_speed_level_feature
fan_speed_level_feature = fan.get_feature("fan_speed_level")
assert expected_feature == fan_speed_level_feature
assert fan.has_feature(Fan.fan_speed_level)
msg = "Attribute _check_supported of module Fan is not bound to a feature"
with pytest.raises(KasaException, match=msg):
assert fan.has_feature(fan._check_supported)
msg = "No attribute named foobar in module Fan"
with pytest.raises(KasaException, match=msg):
assert fan.has_feature("foobar")