Apply suggestions from code review

This commit is contained in:
Steven B 2025-01-23 17:34:00 +00:00
parent 1c1c0dcb8c
commit 86e1f5ae38
No known key found for this signature in database
GPG Key ID: 6D5B46B3679F2A43
3 changed files with 92 additions and 51 deletions

View File

@ -81,6 +81,9 @@ ModuleT = TypeVar("ModuleT", bound="Module")
class FeatureAttribute: class FeatureAttribute:
"""Class for annotating attributes bound to feature.""" """Class for annotating attributes bound to feature."""
def __init__(self, feature_name: str | None = None) -> None:
self.feature_name = feature_name
def __repr__(self) -> str: def __repr__(self) -> str:
return "FeatureAttribute" return "FeatureAttribute"
@ -237,7 +240,7 @@ class Module(ABC):
) )
def _is_bound_feature(attribute: property | Callable) -> bool: def _get_feature_attribute(attribute: property | Callable) -> FeatureAttribute | None:
"""Check if an attribute is bound to a feature with FeatureAttribute.""" """Check if an attribute is bound to a feature with FeatureAttribute."""
if isinstance(attribute, property): if isinstance(attribute, property):
hints = get_type_hints(attribute.fget, include_extras=True) hints = get_type_hints(attribute.fget, include_extras=True)
@ -248,9 +251,9 @@ def _is_bound_feature(attribute: property | Callable) -> bool:
metadata = hints["return"].__metadata__ metadata = hints["return"].__metadata__
for meta in metadata: for meta in metadata:
if isinstance(meta, FeatureAttribute): if isinstance(meta, FeatureAttribute):
return True return meta
return False return None
@cache @cache
@ -277,12 +280,17 @@ def _get_bound_feature(
f"module {module.__class__.__name__}" f"module {module.__class__.__name__}"
) )
if not _is_bound_feature(attribute_callable): if not (fa := _get_feature_attribute(attribute_callable)):
raise KasaException( raise KasaException(
f"Attribute {attribute_name} of module {module.__class__.__name__}" f"Attribute {attribute_name} of module {module.__class__.__name__}"
" is not bound to a feature" " is not bound to a feature"
) )
# If a feature_name was passed to the FeatureAttribute use that to check
# for the feature. Otherwise check the getters and setters in the features
if fa.feature_name:
return module._all_features.get(fa.feature_name)
check = {attribute_name, attribute_callable} check = {attribute_name, attribute_callable}
for feature in module._all_features.values(): for feature in module._all_features.values():
if (getter := feature.attribute_getter) and getter in check: if (getter := feature.attribute_getter) and getter in check:

View File

@ -2,7 +2,10 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from ...feature import Feature from ...feature import Feature
from ...module import FeatureAttribute
from ..smartmodule import SmartModule from ..smartmodule import SmartModule
@ -24,36 +27,28 @@ class PowerProtection(SmartModule):
category=Feature.Category.Info, category=Feature.Category.Info,
) )
) )
self._add_feature(
Feature(
device=self._device,
id="power_protection_enabled",
name="Power protection enabled",
container=self,
attribute_getter="enabled",
attribute_setter="set_enabled",
type=Feature.Type.Switch,
category=Feature.Category.Config,
)
)
self._add_feature( self._add_feature(
Feature( Feature(
device=self._device, device=self._device,
id="power_protection_threshold", id="power_protection_threshold",
name="Power protection threshold", name="Power protection threshold",
container=self, container=self,
attribute_getter="protection_threshold", attribute_getter=lambda x: self.protection_threshold
attribute_setter="set_protection_threshold", if self.enabled
else 0,
attribute_setter=lambda x: self.set_enabled(False)
if x == 0
else self.set_enabled(True, threshold=x),
unit_getter=lambda: "W", unit_getter=lambda: "W",
type=Feature.Type.Number, type=Feature.Type.Number,
range_getter="protection_threshold_range", range_getter=lambda: (0, self._max_power),
category=Feature.Category.Config, category=Feature.Category.Config,
) )
) )
def query(self) -> dict: def query(self) -> dict:
"""Query to execute during the update cycle.""" """Query to execute during the update cycle."""
return {"get_protection_power": None, "get_max_power": None} return {"get_protection_power": {}, "get_max_power": {}}
@property @property
def overloaded(self) -> bool: def overloaded(self) -> bool:
@ -68,25 +63,41 @@ class PowerProtection(SmartModule):
"""Return True if child protection is enabled.""" """Return True if child protection is enabled."""
return self.data["get_protection_power"]["enabled"] return self.data["get_protection_power"]["enabled"]
async def set_enabled(self, enabled: bool) -> dict: async def set_enabled(self, enabled: bool, *, threshold: int | None = None) -> dict:
"""Set child protection.""" """Set child protection.
If power protection has never been enabled before the threshold will
be 0 so if threshold is not provided it will be set to half the max.
"""
if threshold is None and enabled and self.protection_threshold == 0:
threshold = int(self._max_power / 2)
if threshold and (threshold < 0 or threshold > self._max_power):
raise ValueError(
"Threshold out of range: %s (%s)", threshold, self.protection_threshold
)
params = {**self.data["get_protection_power"], "enabled": enabled} params = {**self.data["get_protection_power"], "enabled": enabled}
if threshold is not None:
params["protection_power"] = threshold
return await self.call("set_protection_power", params) return await self.call("set_protection_power", params)
@property @property
def protection_threshold_range(self) -> tuple[int, int]: def _max_power(self) -> int:
"""Return threshold range.""" """Return max power."""
return 0, self.data["get_max_power"]["max_power"] return self.data["get_max_power"]["max_power"]
@property @property
def protection_threshold(self) -> int: def protection_threshold(
self,
) -> Annotated[int, FeatureAttribute("power_protection_threshold")]:
"""Return protection threshold in watts.""" """Return protection threshold in watts."""
# If never configured, there is no value set. # If never configured, there is no value set.
return self.data["get_protection_power"].get("protection_power", 0) return self.data["get_protection_power"].get("protection_power", 0)
async def set_protection_threshold(self, threshold: int) -> dict: async def set_protection_threshold(self, threshold: int) -> dict:
"""Set protection threshold.""" """Set protection threshold."""
if threshold < 0 or threshold > self.protection_threshold_range[1]: if threshold < 0 or threshold > self._max_power:
raise ValueError( raise ValueError(
"Threshold out of range: %s (%s)", threshold, self.protection_threshold "Threshold out of range: %s (%s)", threshold, self.protection_threshold
) )

View File

@ -2,9 +2,8 @@ import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from kasa import Module, SmartDevice from kasa import Module, SmartDevice
from kasa.smart.modules import PowerProtection
from ...device_fixtures import parametrize from ...device_fixtures import get_parent_and_child_modules, parametrize
powerprotection = parametrize( powerprotection = parametrize(
"has powerprotection", "has powerprotection",
@ -13,30 +12,24 @@ powerprotection = parametrize(
) )
def _skip_on_unavailable(dev: SmartDevice):
if Module.PowerProtection not in dev.modules:
pytest.skip(f"No powerprotection module on {dev}, maybe a strip parent?")
@powerprotection @powerprotection
@pytest.mark.parametrize( @pytest.mark.parametrize(
("feature", "prop_name", "type"), ("feature", "prop_name", "type"),
[ [
("overloaded", "overloaded", bool), ("overloaded", "overloaded", bool),
("power_protection_enabled", "enabled", bool),
("power_protection_threshold", "protection_threshold", int), ("power_protection_threshold", "protection_threshold", int),
], ],
) )
async def test_features(dev, feature, prop_name, type): async def test_features(dev, feature, prop_name, type):
"""Test that features are registered and work as expected.""" """Test that features are registered and work as expected."""
_skip_on_unavailable(dev) powerprot = next(get_parent_and_child_modules(dev, Module.PowerProtection))
assert powerprot
powerprot: PowerProtection = dev.modules[Module.PowerProtection] device = powerprot._device
prop = getattr(powerprot, prop_name) prop = getattr(powerprot, prop_name)
assert isinstance(prop, type) assert isinstance(prop, type)
feat = dev.features[feature] feat = device.features[feature]
assert feat.value == prop assert feat.value == prop
assert isinstance(feat.value, type) assert isinstance(feat.value, type)
@ -44,25 +37,54 @@ async def test_features(dev, feature, prop_name, type):
@powerprotection @powerprotection
async def test_set_enable(dev: SmartDevice, mocker: MockerFixture): async def test_set_enable(dev: SmartDevice, mocker: MockerFixture):
"""Test enable.""" """Test enable."""
_skip_on_unavailable(dev) powerprot = next(get_parent_and_child_modules(dev, Module.PowerProtection))
assert powerprot
device = powerprot._device
powerprot: PowerProtection = dev.modules[Module.PowerProtection] original_enabled = powerprot.enabled
original_threshold = powerprot.protection_threshold
call_spy = mocker.spy(powerprot, "call") try:
await powerprot.set_enabled(True) # Simple enable with an existing threshold
params = { call_spy = mocker.spy(powerprot, "call")
"enabled": mocker.ANY, await powerprot.set_enabled(True)
"protection_power": mocker.ANY, params = {
} "enabled": True,
call_spy.assert_called_with("set_protection_power", params) "protection_power": mocker.ANY,
}
call_spy.assert_called_with("set_protection_power", params)
# Enable with no threshold param when 0
call_spy.reset_mock()
await powerprot.set_protection_threshold(0)
await device.update()
await powerprot.set_enabled(True)
params = {
"enabled": True,
"protection_power": int(powerprot._max_power / 2),
}
call_spy.assert_called_with("set_protection_power", params)
# Enable false should not update the threshold
call_spy.reset_mock()
await powerprot.set_protection_threshold(0)
await device.update()
await powerprot.set_enabled(False)
params = {
"enabled": False,
"protection_power": 0,
}
call_spy.assert_called_with("set_protection_power", params)
finally:
await powerprot.set_enabled(original_enabled, threshold=original_threshold)
@powerprotection @powerprotection
async def test_set_threshold(dev: SmartDevice, mocker: MockerFixture): async def test_set_threshold(dev: SmartDevice, mocker: MockerFixture):
"""Test enable.""" """Test enable."""
_skip_on_unavailable(dev) powerprot = next(get_parent_and_child_modules(dev, Module.PowerProtection))
assert powerprot
powerprot: PowerProtection = dev.modules[Module.PowerProtection]
call_spy = mocker.spy(powerprot, "call") call_spy = mocker.spy(powerprot, "call")
await powerprot.set_protection_threshold(123) await powerprot.set_protection_threshold(123)