From 86e1f5ae383fe562fe657a604d6ad19fea770922 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Thu, 23 Jan 2025 17:34:00 +0000 Subject: [PATCH] Apply suggestions from code review --- kasa/module.py | 16 +++-- kasa/smart/modules/powerprotection.py | 57 ++++++++++------- tests/smart/modules/test_powerprotection.py | 70 ++++++++++++++------- 3 files changed, 92 insertions(+), 51 deletions(-) diff --git a/kasa/module.py b/kasa/module.py index 0e098a1c..a21e69e4 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -81,6 +81,9 @@ ModuleT = TypeVar("ModuleT", bound="Module") class FeatureAttribute: """Class for annotating attributes bound to feature.""" + def __init__(self, feature_name: str | None = None) -> None: + self.feature_name = feature_name + def __repr__(self) -> str: return "FeatureAttribute" @@ -237,7 +240,7 @@ class Module(ABC): ) -def _is_bound_feature(attribute: property | Callable) -> bool: +def _get_feature_attribute(attribute: property | Callable) -> FeatureAttribute | None: """Check if an attribute is bound to a feature with FeatureAttribute.""" if isinstance(attribute, property): hints = get_type_hints(attribute.fget, include_extras=True) @@ -248,9 +251,9 @@ def _is_bound_feature(attribute: property | Callable) -> bool: metadata = hints["return"].__metadata__ for meta in metadata: if isinstance(meta, FeatureAttribute): - return True + return meta - return False + return None @cache @@ -277,12 +280,17 @@ def _get_bound_feature( f"module {module.__class__.__name__}" ) - if not _is_bound_feature(attribute_callable): + if not (fa := _get_feature_attribute(attribute_callable)): raise KasaException( f"Attribute {attribute_name} of module {module.__class__.__name__}" " is not bound to a feature" ) + # If a feature_name was passed to the FeatureAttribute use that to check + # for the feature. Otherwise check the getters and setters in the features + if fa.feature_name: + return module._all_features.get(fa.feature_name) + check = {attribute_name, attribute_callable} for feature in module._all_features.values(): if (getter := feature.attribute_getter) and getter in check: diff --git a/kasa/smart/modules/powerprotection.py b/kasa/smart/modules/powerprotection.py index 095bfbe1..b4d9c988 100644 --- a/kasa/smart/modules/powerprotection.py +++ b/kasa/smart/modules/powerprotection.py @@ -2,7 +2,10 @@ from __future__ import annotations +from typing import Annotated + from ...feature import Feature +from ...module import FeatureAttribute from ..smartmodule import SmartModule @@ -24,36 +27,28 @@ class PowerProtection(SmartModule): category=Feature.Category.Info, ) ) - self._add_feature( - Feature( - device=self._device, - id="power_protection_enabled", - name="Power protection enabled", - container=self, - attribute_getter="enabled", - attribute_setter="set_enabled", - type=Feature.Type.Switch, - category=Feature.Category.Config, - ) - ) self._add_feature( Feature( device=self._device, id="power_protection_threshold", name="Power protection threshold", container=self, - attribute_getter="protection_threshold", - attribute_setter="set_protection_threshold", + attribute_getter=lambda x: self.protection_threshold + if self.enabled + else 0, + attribute_setter=lambda x: self.set_enabled(False) + if x == 0 + else self.set_enabled(True, threshold=x), unit_getter=lambda: "W", type=Feature.Type.Number, - range_getter="protection_threshold_range", + range_getter=lambda: (0, self._max_power), category=Feature.Category.Config, ) ) def query(self) -> dict: """Query to execute during the update cycle.""" - return {"get_protection_power": None, "get_max_power": None} + return {"get_protection_power": {}, "get_max_power": {}} @property def overloaded(self) -> bool: @@ -68,25 +63,41 @@ class PowerProtection(SmartModule): """Return True if child protection is enabled.""" return self.data["get_protection_power"]["enabled"] - async def set_enabled(self, enabled: bool) -> dict: - """Set child protection.""" + async def set_enabled(self, enabled: bool, *, threshold: int | None = None) -> dict: + """Set child protection. + + If power protection has never been enabled before the threshold will + be 0 so if threshold is not provided it will be set to half the max. + """ + if threshold is None and enabled and self.protection_threshold == 0: + threshold = int(self._max_power / 2) + + if threshold and (threshold < 0 or threshold > self._max_power): + raise ValueError( + "Threshold out of range: %s (%s)", threshold, self.protection_threshold + ) + params = {**self.data["get_protection_power"], "enabled": enabled} + if threshold is not None: + params["protection_power"] = threshold return await self.call("set_protection_power", params) @property - def protection_threshold_range(self) -> tuple[int, int]: - """Return threshold range.""" - return 0, self.data["get_max_power"]["max_power"] + def _max_power(self) -> int: + """Return max power.""" + return self.data["get_max_power"]["max_power"] @property - def protection_threshold(self) -> int: + def protection_threshold( + self, + ) -> Annotated[int, FeatureAttribute("power_protection_threshold")]: """Return protection threshold in watts.""" # If never configured, there is no value set. return self.data["get_protection_power"].get("protection_power", 0) async def set_protection_threshold(self, threshold: int) -> dict: """Set protection threshold.""" - if threshold < 0 or threshold > self.protection_threshold_range[1]: + if threshold < 0 or threshold > self._max_power: raise ValueError( "Threshold out of range: %s (%s)", threshold, self.protection_threshold ) diff --git a/tests/smart/modules/test_powerprotection.py b/tests/smart/modules/test_powerprotection.py index 10e94ede..7f03c0e9 100644 --- a/tests/smart/modules/test_powerprotection.py +++ b/tests/smart/modules/test_powerprotection.py @@ -2,9 +2,8 @@ import pytest from pytest_mock import MockerFixture from kasa import Module, SmartDevice -from kasa.smart.modules import PowerProtection -from ...device_fixtures import parametrize +from ...device_fixtures import get_parent_and_child_modules, parametrize powerprotection = parametrize( "has powerprotection", @@ -13,30 +12,24 @@ powerprotection = parametrize( ) -def _skip_on_unavailable(dev: SmartDevice): - if Module.PowerProtection not in dev.modules: - pytest.skip(f"No powerprotection module on {dev}, maybe a strip parent?") - - @powerprotection @pytest.mark.parametrize( ("feature", "prop_name", "type"), [ ("overloaded", "overloaded", bool), - ("power_protection_enabled", "enabled", bool), ("power_protection_threshold", "protection_threshold", int), ], ) async def test_features(dev, feature, prop_name, type): """Test that features are registered and work as expected.""" - _skip_on_unavailable(dev) - - powerprot: PowerProtection = dev.modules[Module.PowerProtection] + powerprot = next(get_parent_and_child_modules(dev, Module.PowerProtection)) + assert powerprot + device = powerprot._device prop = getattr(powerprot, prop_name) assert isinstance(prop, type) - feat = dev.features[feature] + feat = device.features[feature] assert feat.value == prop assert isinstance(feat.value, type) @@ -44,25 +37,54 @@ async def test_features(dev, feature, prop_name, type): @powerprotection async def test_set_enable(dev: SmartDevice, mocker: MockerFixture): """Test enable.""" - _skip_on_unavailable(dev) + powerprot = next(get_parent_and_child_modules(dev, Module.PowerProtection)) + assert powerprot + device = powerprot._device - powerprot: PowerProtection = dev.modules[Module.PowerProtection] + original_enabled = powerprot.enabled + original_threshold = powerprot.protection_threshold - call_spy = mocker.spy(powerprot, "call") - await powerprot.set_enabled(True) - params = { - "enabled": mocker.ANY, - "protection_power": mocker.ANY, - } - call_spy.assert_called_with("set_protection_power", params) + try: + # Simple enable with an existing threshold + call_spy = mocker.spy(powerprot, "call") + await powerprot.set_enabled(True) + params = { + "enabled": True, + "protection_power": mocker.ANY, + } + call_spy.assert_called_with("set_protection_power", params) + + # Enable with no threshold param when 0 + call_spy.reset_mock() + await powerprot.set_protection_threshold(0) + await device.update() + await powerprot.set_enabled(True) + params = { + "enabled": True, + "protection_power": int(powerprot._max_power / 2), + } + call_spy.assert_called_with("set_protection_power", params) + + # Enable false should not update the threshold + call_spy.reset_mock() + await powerprot.set_protection_threshold(0) + await device.update() + await powerprot.set_enabled(False) + params = { + "enabled": False, + "protection_power": 0, + } + call_spy.assert_called_with("set_protection_power", params) + + finally: + await powerprot.set_enabled(original_enabled, threshold=original_threshold) @powerprotection async def test_set_threshold(dev: SmartDevice, mocker: MockerFixture): """Test enable.""" - _skip_on_unavailable(dev) - - powerprot: PowerProtection = dev.modules[Module.PowerProtection] + powerprot = next(get_parent_and_child_modules(dev, Module.PowerProtection)) + assert powerprot call_spy = mocker.spy(powerprot, "call") await powerprot.set_protection_threshold(123)