diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c274bb97..d322dd99 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: additional_dependencies: [types-click] exclude: | (?x)^( - kasa/modulemapping\.py| + kasa/typedmapping\.py| )$ diff --git a/kasa/cli.py b/kasa/cli.py index 235387bc..090c00b1 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -11,7 +11,7 @@ import sys from contextlib import asynccontextmanager from functools import singledispatch, wraps from pprint import pformat as pf -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import asyncclick as click from pydantic.v1 import ValidationError @@ -43,6 +43,9 @@ from kasa.iot import ( from kasa.iot.modules import Usage from kasa.smart import SmartDevice +if TYPE_CHECKING: + from kasa.typedmapping import FeatureId, FeatureMapping + try: from rich import print as _do_echo except ImportError: @@ -582,7 +585,7 @@ async def sysinfo(dev): def _echo_features( - features: dict[str, Feature], + features: FeatureMapping | dict[FeatureId | str, Feature], title: str, category: Feature.Category | None = None, verbose: bool = False, diff --git a/kasa/device.py b/kasa/device.py index 7156a219..89c7746b 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING, Any, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast from warnings import warn from .credentials import Credentials @@ -18,10 +18,11 @@ from .feature import Feature from .iotprotocol import IotProtocol from .module import Module from .protocol import BaseProtocol +from .typedmapping import FeatureMapping from .xortransport import XorTransport if TYPE_CHECKING: - from .modulemapping import ModuleMapping, ModuleName + from .typedmapping import ModuleMapping, ModuleName @dataclass @@ -271,9 +272,9 @@ class Device(ABC): return {feat.name: feat.value for feat in self._features.values()} @property - def features(self) -> dict[str, Feature]: + def features(self) -> FeatureMapping: """Return the list of supported features.""" - return self._features + return cast(FeatureMapping, self._features) def _add_feature(self, feature: Feature): """Add a new feature to the device.""" diff --git a/kasa/feature.py b/kasa/feature.py index 1f7d3f3d..58885481 100644 --- a/kasa/feature.py +++ b/kasa/feature.py @@ -4,20 +4,56 @@ from __future__ import annotations import logging from dataclasses import dataclass +from datetime import datetime from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, cast + +from kasa.typedmapping import FeatureId if TYPE_CHECKING: from .device import Device + from .interfaces.light import HSV _LOGGER = logging.getLogger(__name__) +_T = TypeVar("_T") + @dataclass -class Feature: +class Feature(Generic[_T]): """Feature defines a generic interface for device features.""" + class Id: + """Class containing typed common feature ids.""" + + LED: Final[FeatureId[bool]] = FeatureId("led") + LIGHT_EFFECT: Final[FeatureId[str]] = FeatureId("light_effect") + LIGHT_PRESET: Final[FeatureId[str]] = FeatureId("light_preset") + RSSI: Final[FeatureId[int]] = FeatureId("rssi") + ON_SINCE: Final[FeatureId[datetime]] = FeatureId("on_since") + AMBIENT_LIGHT: Final[FeatureId[int]] = FeatureId("ambient_light") + + CLOUD_CONNECTION: Final[FeatureId[bool]] = FeatureId("cloud_connection") + CURRENT_CONSUMPTION: Final[FeatureId[float]] = FeatureId("current_consumption") + EMETER_TODAY: Final[FeatureId[float]] = FeatureId("emeter_today") + CONSUMPTION_THIS_MONTH: Final[FeatureId[float]] = FeatureId( + "consumption_this_month" + ) + EMETER_TOTAL: Final[FeatureId[float]] = FeatureId("emeter_total") + VOLTAGE: Final[FeatureId[float]] = FeatureId("voltage") + CURRENT: Final[FeatureId[float]] = FeatureId("current") + + BRIGHTNESS: Final[FeatureId[int]] = FeatureId("brightness") + COLOUR_TEMPERATURE: Final[FeatureId[int]] = FeatureId("color_temp") + HSV: Final[FeatureId[HSV]] = FeatureId("hsv") + + DEVICE_ID: Final[FeatureId[str]] = FeatureId("device_id") + STATE: Final[FeatureId[bool]] = FeatureId("state") + SIGNAL_LEVEL: Final[FeatureId[int]] = FeatureId("signal_level") + SSID: Final[FeatureId[str]] = FeatureId("ssid") + OVERHEATED: Final[FeatureId[bool]] = FeatureId("overheated") + class Type(Enum): """Type to help decide how to present the feature.""" @@ -96,7 +132,7 @@ class Feature: # Choice-specific attributes #: List of choices as enum - choices: list[str] | None = None + choices: list[_T] | None = None #: Attribute name of the choices getter property. #: If set, this property will be used to set *choices*. choices_getter: str | None = None @@ -131,30 +167,32 @@ class Feature: ) @property - def value(self): + def value(self) -> _T: """Return the current value.""" if self.type == Feature.Type.Action: - return "" + return cast(_T, "") if self.attribute_getter is None: raise ValueError("Not an action and no attribute_getter set") container = self.container if self.container is not None else self.device - if isinstance(self.attribute_getter, Callable): + if callable(self.attribute_getter): return self.attribute_getter(container) return getattr(container, self.attribute_getter) - async def set_value(self, value): + async def set_value(self, value: _T) -> Any: """Set the value.""" if self.attribute_setter is None: raise ValueError("Tried to set read-only feature.") if self.type == Feature.Type.Number: # noqa: SIM102 + if not isinstance(value, (int, float)): + raise ValueError("value must be a number") if value < self.minimum_value or value > self.maximum_value: raise ValueError( f"Value {value} out of range " f"[{self.minimum_value}, {self.maximum_value}]" ) elif self.type == Feature.Type.Choice: # noqa: SIM102 - if value not in self.choices: + if not self.choices or value not in self.choices: raise ValueError( f"Unexpected value for {self.name}: {value}" f" - allowed: {self.choices}" diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 25e3b44d..3ef48d99 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -27,8 +27,8 @@ from ..emeterstatus import EmeterStatus from ..exceptions import KasaException from ..feature import Feature from ..module import Module -from ..modulemapping import ModuleMapping, ModuleName from ..protocol import BaseProtocol +from ..typedmapping import ModuleMapping, ModuleName from .iotmodule import IotModule from .modules import Emeter diff --git a/kasa/module.py b/kasa/module.py index a2a9c931..e0fbef22 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -12,7 +12,7 @@ from typing import ( from .exceptions import KasaException from .feature import Feature -from .modulemapping import ModuleName +from .typedmapping import ModuleName if TYPE_CHECKING: from . import interfaces diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 3250c98e..1d34511c 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -15,8 +15,8 @@ from ..emeterstatus import EmeterStatus from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode from ..feature import Feature from ..module import Module -from ..modulemapping import ModuleMapping, ModuleName from ..smartprotocol import SmartProtocol +from ..typedmapping import ModuleMapping, ModuleName from .modules import ( Cloud, DeviceModule, diff --git a/kasa/tests/smart/features/test_brightness.py b/kasa/tests/smart/features/test_brightness.py index e3c3c530..1adcf6aa 100644 --- a/kasa/tests/smart/features/test_brightness.py +++ b/kasa/tests/smart/features/test_brightness.py @@ -1,5 +1,9 @@ -import pytest +from typing import TYPE_CHECKING +import pytest +from typing_extensions import assert_type + +from kasa import Feature from kasa.iot import IotDevice from kasa.smart import SmartDevice from kasa.tests.conftest import dimmable_iot, parametrize @@ -16,7 +20,9 @@ async def test_brightness_component(dev: SmartDevice): assert "brightness" in dev._components # Test getting the value - feature = dev.features["brightness"] + feature = dev.features[Feature.Id.BRIGHTNESS] + if TYPE_CHECKING: + assert_type(feature.value, int) assert isinstance(feature.value, int) assert feature.value > 1 and feature.value <= 100 diff --git a/kasa/tests/test_feature.py b/kasa/tests/test_feature.py index 0fb7156d..4bc2a7d8 100644 --- a/kasa/tests/test_feature.py +++ b/kasa/tests/test_feature.py @@ -18,7 +18,7 @@ class DummyDevice: def dummy_feature() -> Feature: # create_autospec for device slows tests way too much, so we use a dummy here - feat = Feature( + feat: Feature = Feature( device=DummyDevice(), # type: ignore[arg-type] id="dummy_feature", name="dummy_feature", diff --git a/kasa/modulemapping.py b/kasa/typedmapping.py similarity index 52% rename from kasa/modulemapping.py rename to kasa/typedmapping.py index 06ba8619..2e311b69 100644 --- a/kasa/modulemapping.py +++ b/kasa/typedmapping.py @@ -1,6 +1,6 @@ -"""Module for Implementation for ModuleMapping and ModuleName types. +"""Module for Implementation for typed mappings. -Custom dict for getting typed modules from the module dict. +Custom mappings for getting typed modules and features from mapping collections. """ from __future__ import annotations @@ -12,6 +12,8 @@ if TYPE_CHECKING: _ModuleT = TypeVar("_ModuleT", bound="Module") +_FeatureT = TypeVar("_FeatureT") + class ModuleName(str, Generic[_ModuleT]): """Generic Module name type. @@ -22,4 +24,14 @@ class ModuleName(str, Generic[_ModuleT]): __slots__ = () +class FeatureId(str, Generic[_FeatureT]): + """Generic feature id type. + + At runtime this is a generic subclass of str. + """ + + __slots__ = () + + ModuleMapping = dict +FeatureMapping = dict diff --git a/kasa/modulemapping.pyi b/kasa/typedmapping.pyi similarity index 61% rename from kasa/modulemapping.pyi rename to kasa/typedmapping.pyi index 8d110d39..29b9757b 100644 --- a/kasa/modulemapping.pyi +++ b/kasa/typedmapping.pyi @@ -2,8 +2,9 @@ from abc import ABCMeta from collections.abc import Mapping -from typing import Generic, TypeVar, overload +from typing import Any, Generic, TypeVar, overload +from .feature import Feature from .module import Module __all__ = [ @@ -14,6 +15,9 @@ __all__ = [ _ModuleT = TypeVar("_ModuleT", bound=Module, covariant=True) _ModuleBaseT = TypeVar("_ModuleBaseT", bound=Module, covariant=True) +_FeatureT = TypeVar("_FeatureT") +_T = TypeVar("_T") + class ModuleName(Generic[_ModuleT]): """Class for typed Module names. At runtime delegated to str.""" @@ -23,6 +27,15 @@ class ModuleName(Generic[_ModuleT]): def __eq__(self, other: object) -> bool: ... def __getitem__(self, index: int) -> str: ... +class FeatureId(Generic[_FeatureT]): + """Class for typed Module names. At runtime delegated to str.""" + + def __init__(self, value: str, /) -> None: ... + def __len__(self) -> int: ... + def __hash__(self) -> int: ... + def __eq__(self, other: object) -> bool: ... + def __getitem__(self, index: int) -> str: ... + class ModuleMapping( Mapping[ModuleName[_ModuleBaseT] | str, _ModuleBaseT], metaclass=ABCMeta ): @@ -45,6 +58,26 @@ class ModuleMapping( self, key: ModuleName[_ModuleT] | str, / ) -> _ModuleT | _ModuleBaseT | None: ... +class FeatureMapping(Mapping[FeatureId[Any] | str, Any], metaclass=ABCMeta): + """Custom dict type to provide better value type hints for Module key types.""" + + @overload + def __getitem__(self, key: FeatureId[_FeatureT], /) -> Feature[_FeatureT]: ... + @overload + def __getitem__(self, key: str, /) -> Feature[Any]: ... + @overload + def __getitem__( + self, key: FeatureId[_FeatureT] | str, / + ) -> Feature[_FeatureT] | Feature[Any]: ... + @overload # type: ignore[override] + def get(self, key: FeatureId[_FeatureT], /) -> Feature[_FeatureT] | None: ... + @overload + def get(self, key: str, /) -> Feature[Any] | None: ... + @overload + def get( + self, key: FeatureId[_FeatureT] | str, / + ) -> Feature[_FeatureT] | Feature[Any] | None: ... + def _test_module_mapping_typing() -> None: """Test ModuleMapping overloads work as intended. @@ -94,3 +127,35 @@ def _test_module_mapping_typing() -> None: device_modules_3: ModuleMapping[Module] = smart_modules # noqa: F841 NEW_MODULE: ModuleName[Module] = NEW_SMART_MODULE # noqa: F841 NEW_MODULE_2: ModuleName[Module] = NEW_IOT_MODULE # noqa: F841 + +def _test_feature_mapping_typing() -> None: + """Test ModuleMapping overloads work as intended. + + This is tested during the mypy run and needs to be in this file. + """ + from typing import Any, cast + + from typing_extensions import assert_type + + from .feature import Feature + + featstr: Feature[str] + featint: Feature[int] + assert_type(featstr.value, str) + assert_type(featint.value, int) + + INT_FEATURE_ID: FeatureId[int] = FeatureId("intfeature") + STR_FEATURE_ID: FeatureId[str] = FeatureId("strfeature") + + features: FeatureMapping = cast(FeatureMapping, {}) + assert_type(features[INT_FEATURE_ID], Feature[int]) + assert_type(features[STR_FEATURE_ID], Feature[str]) + assert_type(features["foobar"], Feature) + + assert_type(features[INT_FEATURE_ID].value, int) + assert_type(features[STR_FEATURE_ID].value, str) + assert_type(features["foobar"].value, Any) + + assert_type(features.get(INT_FEATURE_ID), Feature[int] | None) + assert_type(features.get(STR_FEATURE_ID), Feature[str] | None) + assert_type(features.get("foobar"), Feature | None)