From 44c561b04d77dd590f235118929f889da4f3b80e Mon Sep 17 00:00:00 2001
From: "Steven B." <51370195+sdb9696@users.noreply.github.com>
Date: Wed, 29 Jan 2025 19:32:01 +0000
Subject: [PATCH] Add FeatureAttributes to smartcam Alarm (#1489)

Co-authored-by: Teemu R. <tpr@iki.fi>
---
 kasa/interfaces/energy.py      |  2 +-
 kasa/smart/smartmodule.py      | 11 ++++---
 kasa/smartcam/modules/alarm.py | 19 ++++++++----
 tests/test_common_modules.py   | 57 ++++++++++++++++++++++++++++++++++
 4 files changed, 78 insertions(+), 11 deletions(-)

diff --git a/kasa/interfaces/energy.py b/kasa/interfaces/energy.py
index c57a3ed8..b6cc203f 100644
--- a/kasa/interfaces/energy.py
+++ b/kasa/interfaces/energy.py
@@ -28,7 +28,7 @@ class Energy(Module, ABC):
 
     _supported: ModuleFeature = ModuleFeature(0)
 
-    def supports(self, module_feature: ModuleFeature) -> bool:
+    def supports(self, module_feature: Energy.ModuleFeature) -> bool:
         """Return True if module supports the feature."""
         return module_feature in self._supported
 
diff --git a/kasa/smart/smartmodule.py b/kasa/smart/smartmodule.py
index 243852e0..91efa33d 100644
--- a/kasa/smart/smartmodule.py
+++ b/kasa/smart/smartmodule.py
@@ -3,7 +3,8 @@
 from __future__ import annotations
 
 import logging
-from collections.abc import Awaitable, Callable, Coroutine
+from collections.abc import Callable, Coroutine
+from functools import wraps
 from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar
 
 from ..exceptions import DeviceError, KasaException, SmartErrorCode
@@ -20,15 +21,16 @@ _R = TypeVar("_R")
 
 
 def allow_update_after(
-    func: Callable[Concatenate[_T, _P], Awaitable[dict]],
-) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, dict]]:
+    func: Callable[Concatenate[_T, _P], Coroutine[Any, Any, _R]],
+) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, _R]]:
     """Define a wrapper to set _last_update_time to None.
 
     This will ensure that a module is updated in the next update cycle after
     a value has been changed.
     """
 
-    async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> dict:
+    @wraps(func)
+    async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
         try:
             return await func(self, *args, **kwargs)
         finally:
@@ -40,6 +42,7 @@ def allow_update_after(
 def raise_if_update_error(func: Callable[[_T], _R]) -> Callable[[_T], _R]:
     """Define a wrapper to raise an error if the last module update was an error."""
 
+    @wraps(func)
     def _wrap(self: _T) -> _R:
         if err := self._last_update_error:
             raise err
diff --git a/kasa/smartcam/modules/alarm.py b/kasa/smartcam/modules/alarm.py
index 18833d82..df1891ec 100644
--- a/kasa/smartcam/modules/alarm.py
+++ b/kasa/smartcam/modules/alarm.py
@@ -2,8 +2,11 @@
 
 from __future__ import annotations
 
+from typing import Annotated
+
 from ...feature import Feature
 from ...interfaces import Alarm as AlarmInterface
+from ...module import FeatureAttribute
 from ...smart.smartmodule import allow_update_after
 from ..smartcammodule import SmartCamModule
 
@@ -105,12 +108,12 @@ class Alarm(SmartCamModule, AlarmInterface):
         )
 
     @property
-    def alarm_sound(self) -> str:
+    def alarm_sound(self) -> Annotated[str, FeatureAttribute()]:
         """Return current alarm sound."""
         return self.data["getSirenConfig"]["siren_type"]
 
     @allow_update_after
-    async def set_alarm_sound(self, sound: str) -> dict:
+    async def set_alarm_sound(self, sound: str) -> Annotated[dict, FeatureAttribute()]:
         """Set alarm sound.
 
         See *alarm_sounds* for list of available sounds.
@@ -124,7 +127,7 @@ class Alarm(SmartCamModule, AlarmInterface):
         return self.data["getSirenTypeList"]["siren_type_list"]
 
     @property
-    def alarm_volume(self) -> int:
+    def alarm_volume(self) -> Annotated[int, FeatureAttribute()]:
         """Return alarm volume.
 
         Unlike duration the device expects/returns a string for volume.
@@ -132,18 +135,22 @@ class Alarm(SmartCamModule, AlarmInterface):
         return int(self.data["getSirenConfig"]["volume"])
 
     @allow_update_after
-    async def set_alarm_volume(self, volume: int) -> dict:
+    async def set_alarm_volume(
+        self, volume: int
+    ) -> Annotated[dict, FeatureAttribute()]:
         """Set alarm volume."""
         config = self._validate_and_get_config(volume=volume)
         return await self.call("setSirenConfig", {"siren": config})
 
     @property
-    def alarm_duration(self) -> int:
+    def alarm_duration(self) -> Annotated[int, FeatureAttribute()]:
         """Return alarm duration."""
         return self.data["getSirenConfig"]["duration"]
 
     @allow_update_after
-    async def set_alarm_duration(self, duration: int) -> dict:
+    async def set_alarm_duration(
+        self, duration: int
+    ) -> Annotated[dict, FeatureAttribute()]:
         """Set alarm volume."""
         config = self._validate_and_get_config(duration=duration)
         return await self.call("setSirenConfig", {"siren": config})
diff --git a/tests/test_common_modules.py b/tests/test_common_modules.py
index 3b1d8988..869ba27d 100644
--- a/tests/test_common_modules.py
+++ b/tests/test_common_modules.py
@@ -1,10 +1,16 @@
+import importlib
+import inspect
+import pkgutil
+import sys
 from datetime import datetime
 from zoneinfo import ZoneInfo
 
 import pytest
 from pytest_mock import MockerFixture
 
+import kasa.interfaces
 from kasa import Device, LightState, Module, ThermostatState
+from kasa.module import _get_feature_attribute
 
 from .device_fixtures import (
     bulb_iot,
@@ -64,6 +70,57 @@ temp_control_smart = parametrize(
 )
 
 
+interfaces = pytest.mark.parametrize("interface", kasa.interfaces.__all__)
+
+
+def _get_subclasses(of_class, package):
+    """Get all the subclasses of a given class."""
+    subclasses = set()
+    # iter_modules returns ModuleInfo: (module_finder, name, ispkg)
+    for _, modname, ispkg in pkgutil.iter_modules(package.__path__):
+        importlib.import_module("." + modname, package=package.__name__)
+        module = sys.modules[package.__name__ + "." + modname]
+        for _, obj in inspect.getmembers(module):
+            if (
+                inspect.isclass(obj)
+                and issubclass(obj, of_class)
+                and obj is not of_class
+            ):
+                subclasses.add(obj)
+
+        if ispkg:
+            res = _get_subclasses(of_class, module)
+            subclasses.update(res)
+
+    return subclasses
+
+
+@interfaces
+def test_feature_attributes(interface):
+    """Test that all common derived classes define the FeatureAttributes."""
+    klass = getattr(kasa.interfaces, interface)
+
+    package = sys.modules["kasa"]
+    sub_classes = _get_subclasses(klass, package)
+
+    feat_attributes: set[str] = set()
+    attribute_names = [
+        k
+        for k, v in vars(klass).items()
+        if (callable(v) and not inspect.isclass(v)) or isinstance(v, property)
+    ]
+    for attr_name in attribute_names:
+        attribute = getattr(klass, attr_name)
+        if _get_feature_attribute(attribute):
+            feat_attributes.add(attr_name)
+
+    for sub_class in sub_classes:
+        for attr_name in feat_attributes:
+            attribute = getattr(sub_class, attr_name)
+            fa = _get_feature_attribute(attribute)
+            assert fa, f"{attr_name} is not a defined module feature for {sub_class}"
+
+
 @led
 async def test_led_module(dev: Device, mocker: MockerFixture):
     """Test fan speed feature."""