Add generic typing support to features

This commit is contained in:
sdb9696 2024-05-21 16:03:21 +01:00
parent 5e619af29f
commit a3a5c5df55
11 changed files with 149 additions and 24 deletions

View File

@ -23,7 +23,7 @@ repos:
additional_dependencies: [types-click] additional_dependencies: [types-click]
exclude: | exclude: |
(?x)^( (?x)^(
kasa/modulemapping\.py| kasa/typedmapping\.py|
)$ )$

View File

@ -11,7 +11,7 @@ import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import singledispatch, wraps from functools import singledispatch, wraps
from pprint import pformat as pf from pprint import pformat as pf
from typing import Any, cast from typing import TYPE_CHECKING, Any, cast
import asyncclick as click import asyncclick as click
from pydantic.v1 import ValidationError from pydantic.v1 import ValidationError
@ -43,6 +43,9 @@ from kasa.iot import (
from kasa.iot.modules import Usage from kasa.iot.modules import Usage
from kasa.smart import SmartDevice from kasa.smart import SmartDevice
if TYPE_CHECKING:
from kasa.typedmapping import FeatureId, FeatureMapping
try: try:
from rich import print as _do_echo from rich import print as _do_echo
except ImportError: except ImportError:
@ -582,7 +585,7 @@ async def sysinfo(dev):
def _echo_features( def _echo_features(
features: dict[str, Feature], features: FeatureMapping | dict[FeatureId | str, Feature],
title: str, title: str,
category: Feature.Category | None = None, category: Feature.Category | None = None,
verbose: bool = False, verbose: bool = False,

View File

@ -6,7 +6,7 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Mapping, Sequence from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast
from warnings import warn from warnings import warn
from .credentials import Credentials from .credentials import Credentials
@ -18,10 +18,11 @@ from .feature import Feature
from .iotprotocol import IotProtocol from .iotprotocol import IotProtocol
from .module import Module from .module import Module
from .protocol import BaseProtocol from .protocol import BaseProtocol
from .typedmapping import FeatureMapping
from .xortransport import XorTransport from .xortransport import XorTransport
if TYPE_CHECKING: if TYPE_CHECKING:
from .modulemapping import ModuleMapping, ModuleName from .typedmapping import ModuleMapping, ModuleName
@dataclass @dataclass
@ -271,9 +272,9 @@ class Device(ABC):
return {feat.name: feat.value for feat in self._features.values()} return {feat.name: feat.value for feat in self._features.values()}
@property @property
def features(self) -> dict[str, Feature]: def features(self) -> FeatureMapping:
"""Return the list of supported features.""" """Return the list of supported features."""
return self._features return cast(FeatureMapping, self._features)
def _add_feature(self, feature: Feature): def _add_feature(self, feature: Feature):
"""Add a new feature to the device.""" """Add a new feature to the device."""

View File

@ -4,20 +4,56 @@ from __future__ import annotations
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, cast
from kasa.typedmapping import FeatureId
if TYPE_CHECKING: if TYPE_CHECKING:
from .device import Device from .device import Device
from .interfaces.light import HSV
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_T = TypeVar("_T")
@dataclass @dataclass
class Feature: class Feature(Generic[_T]):
"""Feature defines a generic interface for device features.""" """Feature defines a generic interface for device features."""
class Id:
"""Class containing typed common feature ids."""
LED: Final[FeatureId[bool]] = FeatureId("led")
LIGHT_EFFECT: Final[FeatureId[str]] = FeatureId("light_effect")
LIGHT_PRESET: Final[FeatureId[str]] = FeatureId("light_preset")
RSSI: Final[FeatureId[int]] = FeatureId("rssi")
ON_SINCE: Final[FeatureId[datetime]] = FeatureId("on_since")
AMBIENT_LIGHT: Final[FeatureId[int]] = FeatureId("ambient_light")
CLOUD_CONNECTION: Final[FeatureId[bool]] = FeatureId("cloud_connection")
CURRENT_CONSUMPTION: Final[FeatureId[float]] = FeatureId("current_consumption")
EMETER_TODAY: Final[FeatureId[float]] = FeatureId("emeter_today")
CONSUMPTION_THIS_MONTH: Final[FeatureId[float]] = FeatureId(
"consumption_this_month"
)
EMETER_TOTAL: Final[FeatureId[float]] = FeatureId("emeter_total")
VOLTAGE: Final[FeatureId[float]] = FeatureId("voltage")
CURRENT: Final[FeatureId[float]] = FeatureId("current")
BRIGHTNESS: Final[FeatureId[int]] = FeatureId("brightness")
COLOUR_TEMPERATURE: Final[FeatureId[int]] = FeatureId("color_temp")
HSV: Final[FeatureId[HSV]] = FeatureId("hsv")
DEVICE_ID: Final[FeatureId[str]] = FeatureId("device_id")
STATE: Final[FeatureId[bool]] = FeatureId("state")
SIGNAL_LEVEL: Final[FeatureId[int]] = FeatureId("signal_level")
SSID: Final[FeatureId[str]] = FeatureId("ssid")
OVERHEATED: Final[FeatureId[bool]] = FeatureId("overheated")
class Type(Enum): class Type(Enum):
"""Type to help decide how to present the feature.""" """Type to help decide how to present the feature."""
@ -96,7 +132,7 @@ class Feature:
# Choice-specific attributes # Choice-specific attributes
#: List of choices as enum #: List of choices as enum
choices: list[str] | None = None choices: list[_T] | None = None
#: Attribute name of the choices getter property. #: Attribute name of the choices getter property.
#: If set, this property will be used to set *choices*. #: If set, this property will be used to set *choices*.
choices_getter: str | None = None choices_getter: str | None = None
@ -131,30 +167,32 @@ class Feature:
) )
@property @property
def value(self): def value(self) -> _T:
"""Return the current value.""" """Return the current value."""
if self.type == Feature.Type.Action: if self.type == Feature.Type.Action:
return "<Action>" return cast(_T, "<Action>")
if self.attribute_getter is None: if self.attribute_getter is None:
raise ValueError("Not an action and no attribute_getter set") raise ValueError("Not an action and no attribute_getter set")
container = self.container if self.container is not None else self.device container = self.container if self.container is not None else self.device
if isinstance(self.attribute_getter, Callable): if callable(self.attribute_getter):
return self.attribute_getter(container) return self.attribute_getter(container)
return getattr(container, self.attribute_getter) return getattr(container, self.attribute_getter)
async def set_value(self, value): async def set_value(self, value: _T) -> Any:
"""Set the value.""" """Set the value."""
if self.attribute_setter is None: if self.attribute_setter is None:
raise ValueError("Tried to set read-only feature.") raise ValueError("Tried to set read-only feature.")
if self.type == Feature.Type.Number: # noqa: SIM102 if self.type == Feature.Type.Number: # noqa: SIM102
if not isinstance(value, (int, float)):
raise ValueError("value must be a number")
if value < self.minimum_value or value > self.maximum_value: if value < self.minimum_value or value > self.maximum_value:
raise ValueError( raise ValueError(
f"Value {value} out of range " f"Value {value} out of range "
f"[{self.minimum_value}, {self.maximum_value}]" f"[{self.minimum_value}, {self.maximum_value}]"
) )
elif self.type == Feature.Type.Choice: # noqa: SIM102 elif self.type == Feature.Type.Choice: # noqa: SIM102
if value not in self.choices: if not self.choices or value not in self.choices:
raise ValueError( raise ValueError(
f"Unexpected value for {self.name}: {value}" f"Unexpected value for {self.name}: {value}"
f" - allowed: {self.choices}" f" - allowed: {self.choices}"

View File

@ -27,8 +27,8 @@ from ..emeterstatus import EmeterStatus
from ..exceptions import KasaException from ..exceptions import KasaException
from ..feature import Feature from ..feature import Feature
from ..module import Module from ..module import Module
from ..modulemapping import ModuleMapping, ModuleName
from ..protocol import BaseProtocol from ..protocol import BaseProtocol
from ..typedmapping import ModuleMapping, ModuleName
from .iotmodule import IotModule from .iotmodule import IotModule
from .modules import Emeter from .modules import Emeter

View File

@ -12,7 +12,7 @@ from typing import (
from .exceptions import KasaException from .exceptions import KasaException
from .feature import Feature from .feature import Feature
from .modulemapping import ModuleName from .typedmapping import ModuleName
if TYPE_CHECKING: if TYPE_CHECKING:
from . import interfaces from . import interfaces

View File

@ -15,8 +15,8 @@ from ..emeterstatus import EmeterStatus
from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode
from ..feature import Feature from ..feature import Feature
from ..module import Module from ..module import Module
from ..modulemapping import ModuleMapping, ModuleName
from ..smartprotocol import SmartProtocol from ..smartprotocol import SmartProtocol
from ..typedmapping import ModuleMapping, ModuleName
from .modules import ( from .modules import (
Cloud, Cloud,
DeviceModule, DeviceModule,

View File

@ -1,5 +1,9 @@
import pytest from typing import TYPE_CHECKING
import pytest
from typing_extensions import assert_type
from kasa import Feature
from kasa.iot import IotDevice from kasa.iot import IotDevice
from kasa.smart import SmartDevice from kasa.smart import SmartDevice
from kasa.tests.conftest import dimmable_iot, parametrize from kasa.tests.conftest import dimmable_iot, parametrize
@ -16,7 +20,9 @@ async def test_brightness_component(dev: SmartDevice):
assert "brightness" in dev._components assert "brightness" in dev._components
# Test getting the value # Test getting the value
feature = dev.features["brightness"] feature = dev.features[Feature.Id.BRIGHTNESS]
if TYPE_CHECKING:
assert_type(feature.value, int)
assert isinstance(feature.value, int) assert isinstance(feature.value, int)
assert feature.value > 1 and feature.value <= 100 assert feature.value > 1 and feature.value <= 100

View File

@ -18,7 +18,7 @@ class DummyDevice:
def dummy_feature() -> Feature: def dummy_feature() -> Feature:
# create_autospec for device slows tests way too much, so we use a dummy here # create_autospec for device slows tests way too much, so we use a dummy here
feat = Feature( feat: Feature = Feature(
device=DummyDevice(), # type: ignore[arg-type] device=DummyDevice(), # type: ignore[arg-type]
id="dummy_feature", id="dummy_feature",
name="dummy_feature", name="dummy_feature",

View File

@ -1,6 +1,6 @@
"""Module for Implementation for ModuleMapping and ModuleName types. """Module for Implementation for typed mappings.
Custom dict for getting typed modules from the module dict. Custom mappings for getting typed modules and features from mapping collections.
""" """
from __future__ import annotations from __future__ import annotations
@ -12,6 +12,8 @@ if TYPE_CHECKING:
_ModuleT = TypeVar("_ModuleT", bound="Module") _ModuleT = TypeVar("_ModuleT", bound="Module")
_FeatureT = TypeVar("_FeatureT")
class ModuleName(str, Generic[_ModuleT]): class ModuleName(str, Generic[_ModuleT]):
"""Generic Module name type. """Generic Module name type.
@ -22,4 +24,14 @@ class ModuleName(str, Generic[_ModuleT]):
__slots__ = () __slots__ = ()
class FeatureId(str, Generic[_FeatureT]):
"""Generic feature id type.
At runtime this is a generic subclass of str.
"""
__slots__ = ()
ModuleMapping = dict ModuleMapping = dict
FeatureMapping = dict

View File

@ -2,8 +2,9 @@
from abc import ABCMeta from abc import ABCMeta
from collections.abc import Mapping from collections.abc import Mapping
from typing import Generic, TypeVar, overload from typing import Any, Generic, TypeVar, overload
from .feature import Feature
from .module import Module from .module import Module
__all__ = [ __all__ = [
@ -14,6 +15,9 @@ __all__ = [
_ModuleT = TypeVar("_ModuleT", bound=Module, covariant=True) _ModuleT = TypeVar("_ModuleT", bound=Module, covariant=True)
_ModuleBaseT = TypeVar("_ModuleBaseT", bound=Module, covariant=True) _ModuleBaseT = TypeVar("_ModuleBaseT", bound=Module, covariant=True)
_FeatureT = TypeVar("_FeatureT")
_T = TypeVar("_T")
class ModuleName(Generic[_ModuleT]): class ModuleName(Generic[_ModuleT]):
"""Class for typed Module names. At runtime delegated to str.""" """Class for typed Module names. At runtime delegated to str."""
@ -23,6 +27,15 @@ class ModuleName(Generic[_ModuleT]):
def __eq__(self, other: object) -> bool: ... def __eq__(self, other: object) -> bool: ...
def __getitem__(self, index: int) -> str: ... def __getitem__(self, index: int) -> str: ...
class FeatureId(Generic[_FeatureT]):
"""Class for typed Module names. At runtime delegated to str."""
def __init__(self, value: str, /) -> None: ...
def __len__(self) -> int: ...
def __hash__(self) -> int: ...
def __eq__(self, other: object) -> bool: ...
def __getitem__(self, index: int) -> str: ...
class ModuleMapping( class ModuleMapping(
Mapping[ModuleName[_ModuleBaseT] | str, _ModuleBaseT], metaclass=ABCMeta Mapping[ModuleName[_ModuleBaseT] | str, _ModuleBaseT], metaclass=ABCMeta
): ):
@ -45,6 +58,26 @@ class ModuleMapping(
self, key: ModuleName[_ModuleT] | str, / self, key: ModuleName[_ModuleT] | str, /
) -> _ModuleT | _ModuleBaseT | None: ... ) -> _ModuleT | _ModuleBaseT | None: ...
class FeatureMapping(Mapping[FeatureId[Any] | str, Any], metaclass=ABCMeta):
"""Custom dict type to provide better value type hints for Module key types."""
@overload
def __getitem__(self, key: FeatureId[_FeatureT], /) -> Feature[_FeatureT]: ...
@overload
def __getitem__(self, key: str, /) -> Feature[Any]: ...
@overload
def __getitem__(
self, key: FeatureId[_FeatureT] | str, /
) -> Feature[_FeatureT] | Feature[Any]: ...
@overload # type: ignore[override]
def get(self, key: FeatureId[_FeatureT], /) -> Feature[_FeatureT] | None: ...
@overload
def get(self, key: str, /) -> Feature[Any] | None: ...
@overload
def get(
self, key: FeatureId[_FeatureT] | str, /
) -> Feature[_FeatureT] | Feature[Any] | None: ...
def _test_module_mapping_typing() -> None: def _test_module_mapping_typing() -> None:
"""Test ModuleMapping overloads work as intended. """Test ModuleMapping overloads work as intended.
@ -94,3 +127,35 @@ def _test_module_mapping_typing() -> None:
device_modules_3: ModuleMapping[Module] = smart_modules # noqa: F841 device_modules_3: ModuleMapping[Module] = smart_modules # noqa: F841
NEW_MODULE: ModuleName[Module] = NEW_SMART_MODULE # noqa: F841 NEW_MODULE: ModuleName[Module] = NEW_SMART_MODULE # noqa: F841
NEW_MODULE_2: ModuleName[Module] = NEW_IOT_MODULE # noqa: F841 NEW_MODULE_2: ModuleName[Module] = NEW_IOT_MODULE # noqa: F841
def _test_feature_mapping_typing() -> None:
"""Test ModuleMapping overloads work as intended.
This is tested during the mypy run and needs to be in this file.
"""
from typing import Any, cast
from typing_extensions import assert_type
from .feature import Feature
featstr: Feature[str]
featint: Feature[int]
assert_type(featstr.value, str)
assert_type(featint.value, int)
INT_FEATURE_ID: FeatureId[int] = FeatureId("intfeature")
STR_FEATURE_ID: FeatureId[str] = FeatureId("strfeature")
features: FeatureMapping = cast(FeatureMapping, {})
assert_type(features[INT_FEATURE_ID], Feature[int])
assert_type(features[STR_FEATURE_ID], Feature[str])
assert_type(features["foobar"], Feature)
assert_type(features[INT_FEATURE_ID].value, int)
assert_type(features[STR_FEATURE_ID].value, str)
assert_type(features["foobar"].value, Any)
assert_type(features.get(INT_FEATURE_ID), Feature[int] | None)
assert_type(features.get(STR_FEATURE_ID), Feature[str] | None)
assert_type(features.get("foobar"), Feature | None)