diff --git a/kasa/interfaces/energy.py b/kasa/interfaces/energy.py index c57a3ed8..b6cc203f 100644 --- a/kasa/interfaces/energy.py +++ b/kasa/interfaces/energy.py @@ -28,7 +28,7 @@ class Energy(Module, ABC): _supported: ModuleFeature = ModuleFeature(0) - def supports(self, module_feature: ModuleFeature) -> bool: + def supports(self, module_feature: Energy.ModuleFeature) -> bool: """Return True if module supports the feature.""" return module_feature in self._supported diff --git a/kasa/smart/smartmodule.py b/kasa/smart/smartmodule.py index 243852e0..91efa33d 100644 --- a/kasa/smart/smartmodule.py +++ b/kasa/smart/smartmodule.py @@ -3,7 +3,8 @@ from __future__ import annotations import logging -from collections.abc import Awaitable, Callable, Coroutine +from collections.abc import Callable, Coroutine +from functools import wraps from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar from ..exceptions import DeviceError, KasaException, SmartErrorCode @@ -20,15 +21,16 @@ _R = TypeVar("_R") def allow_update_after( - func: Callable[Concatenate[_T, _P], Awaitable[dict]], -) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, dict]]: + func: Callable[Concatenate[_T, _P], Coroutine[Any, Any, _R]], +) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, _R]]: """Define a wrapper to set _last_update_time to None. This will ensure that a module is updated in the next update cycle after a value has been changed. """ - async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> dict: + @wraps(func) + async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R: try: return await func(self, *args, **kwargs) finally: @@ -40,6 +42,7 @@ def allow_update_after( def raise_if_update_error(func: Callable[[_T], _R]) -> Callable[[_T], _R]: """Define a wrapper to raise an error if the last module update was an error.""" + @wraps(func) def _wrap(self: _T) -> _R: if err := self._last_update_error: raise err diff --git a/kasa/smartcam/modules/alarm.py b/kasa/smartcam/modules/alarm.py index 18833d82..df1891ec 100644 --- a/kasa/smartcam/modules/alarm.py +++ b/kasa/smartcam/modules/alarm.py @@ -2,8 +2,11 @@ from __future__ import annotations +from typing import Annotated + from ...feature import Feature from ...interfaces import Alarm as AlarmInterface +from ...module import FeatureAttribute from ...smart.smartmodule import allow_update_after from ..smartcammodule import SmartCamModule @@ -105,12 +108,12 @@ class Alarm(SmartCamModule, AlarmInterface): ) @property - def alarm_sound(self) -> str: + def alarm_sound(self) -> Annotated[str, FeatureAttribute()]: """Return current alarm sound.""" return self.data["getSirenConfig"]["siren_type"] @allow_update_after - async def set_alarm_sound(self, sound: str) -> dict: + async def set_alarm_sound(self, sound: str) -> Annotated[dict, FeatureAttribute()]: """Set alarm sound. See *alarm_sounds* for list of available sounds. @@ -124,7 +127,7 @@ class Alarm(SmartCamModule, AlarmInterface): return self.data["getSirenTypeList"]["siren_type_list"] @property - def alarm_volume(self) -> int: + def alarm_volume(self) -> Annotated[int, FeatureAttribute()]: """Return alarm volume. Unlike duration the device expects/returns a string for volume. @@ -132,18 +135,22 @@ class Alarm(SmartCamModule, AlarmInterface): return int(self.data["getSirenConfig"]["volume"]) @allow_update_after - async def set_alarm_volume(self, volume: int) -> dict: + async def set_alarm_volume( + self, volume: int + ) -> Annotated[dict, FeatureAttribute()]: """Set alarm volume.""" config = self._validate_and_get_config(volume=volume) return await self.call("setSirenConfig", {"siren": config}) @property - def alarm_duration(self) -> int: + def alarm_duration(self) -> Annotated[int, FeatureAttribute()]: """Return alarm duration.""" return self.data["getSirenConfig"]["duration"] @allow_update_after - async def set_alarm_duration(self, duration: int) -> dict: + async def set_alarm_duration( + self, duration: int + ) -> Annotated[dict, FeatureAttribute()]: """Set alarm volume.""" config = self._validate_and_get_config(duration=duration) return await self.call("setSirenConfig", {"siren": config}) diff --git a/tests/test_common_modules.py b/tests/test_common_modules.py index 3b1d8988..869ba27d 100644 --- a/tests/test_common_modules.py +++ b/tests/test_common_modules.py @@ -1,10 +1,16 @@ +import importlib +import inspect +import pkgutil +import sys from datetime import datetime from zoneinfo import ZoneInfo import pytest from pytest_mock import MockerFixture +import kasa.interfaces from kasa import Device, LightState, Module, ThermostatState +from kasa.module import _get_feature_attribute from .device_fixtures import ( bulb_iot, @@ -64,6 +70,57 @@ temp_control_smart = parametrize( ) +interfaces = pytest.mark.parametrize("interface", kasa.interfaces.__all__) + + +def _get_subclasses(of_class, package): + """Get all the subclasses of a given class.""" + subclasses = set() + # iter_modules returns ModuleInfo: (module_finder, name, ispkg) + for _, modname, ispkg in pkgutil.iter_modules(package.__path__): + importlib.import_module("." + modname, package=package.__name__) + module = sys.modules[package.__name__ + "." + modname] + for _, obj in inspect.getmembers(module): + if ( + inspect.isclass(obj) + and issubclass(obj, of_class) + and obj is not of_class + ): + subclasses.add(obj) + + if ispkg: + res = _get_subclasses(of_class, module) + subclasses.update(res) + + return subclasses + + +@interfaces +def test_feature_attributes(interface): + """Test that all common derived classes define the FeatureAttributes.""" + klass = getattr(kasa.interfaces, interface) + + package = sys.modules["kasa"] + sub_classes = _get_subclasses(klass, package) + + feat_attributes: set[str] = set() + attribute_names = [ + k + for k, v in vars(klass).items() + if (callable(v) and not inspect.isclass(v)) or isinstance(v, property) + ] + for attr_name in attribute_names: + attribute = getattr(klass, attr_name) + if _get_feature_attribute(attribute): + feat_attributes.add(attr_name) + + for sub_class in sub_classes: + for attr_name in feat_attributes: + attribute = getattr(sub_class, attr_name) + fa = _get_feature_attribute(attribute) + assert fa, f"{attr_name} is not a defined module feature for {sub_class}" + + @led async def test_led_module(dev: Device, mocker: MockerFixture): """Test fan speed feature."""