diff --git a/docs/source/conf.py b/docs/source/conf.py index 5554abf1..03e44d95 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -66,6 +66,6 @@ todo_include_todos = True myst_heading_anchors = 3 -def setup(app): +def setup(app): # noqa: ANN201,ANN001 # add copybutton to hide the >>> prompts, see https://github.com/readthedocs/sphinx_rtd_theme/issues/167 app.add_js_file("copybutton.js") diff --git a/kasa/__init__.py b/kasa/__init__.py index a74cb4c4..ffeaa503 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -13,7 +13,7 @@ to be handled by the user of the library. """ from importlib.metadata import version -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from warnings import warn from kasa.credentials import Credentials @@ -101,7 +101,7 @@ deprecated_classes = { } -def __getattr__(name): +def __getattr__(name: str) -> Any: if name in deprecated_names: warn(f"{name} is deprecated", DeprecationWarning, stacklevel=2) return globals()[f"_deprecated_{name}"] @@ -117,7 +117,7 @@ def __getattr__(name): ) return new_class if name in deprecated_classes: - new_class = deprecated_classes[name] + new_class = deprecated_classes[name] # type: ignore[assignment] msg = f"{name} is deprecated, use {new_class.__name__} instead" warn(msg, DeprecationWarning, stacklevel=2) return new_class diff --git a/kasa/aestransport.py b/kasa/aestransport.py index ae75117c..fc807fb3 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -146,7 +146,7 @@ class AesTransport(BaseTransport): pw = base64.b64encode(credentials.password.encode()).decode() return un, pw - def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: + def _handle_response_error_code(self, resp_dict: dict, msg: str) -> None: error_code_raw = resp_dict.get("error_code") try: error_code = SmartErrorCode.from_int(error_code_raw) @@ -191,14 +191,14 @@ class AesTransport(BaseTransport): + f"status code {status_code} to passthrough" ) - self._handle_response_error_code( - resp_dict, "Error sending secure_passthrough message" - ) - if TYPE_CHECKING: resp_dict = cast(Dict[str, Any], resp_dict) assert self._encryption_session is not None + self._handle_response_error_code( + resp_dict, "Error sending secure_passthrough message" + ) + raw_response: str = resp_dict["result"]["response"] try: @@ -219,7 +219,7 @@ class AesTransport(BaseTransport): ) from ex return ret_val # type: ignore[return-value] - async def perform_login(self): + async def perform_login(self) -> None: """Login to the device.""" try: await self.try_login(self._login_params) @@ -324,11 +324,11 @@ class AesTransport(BaseTransport): + f"status code {status_code} to handshake" ) - self._handle_response_error_code(resp_dict, "Unable to complete handshake") - if TYPE_CHECKING: resp_dict = cast(Dict[str, Any], resp_dict) + self._handle_response_error_code(resp_dict, "Unable to complete handshake") + handshake_key = resp_dict["result"]["key"] if ( @@ -355,7 +355,7 @@ class AesTransport(BaseTransport): _LOGGER.debug("Handshake with %s complete", self._host) - def _handshake_session_expired(self): + def _handshake_session_expired(self) -> bool: """Return true if session has expired.""" return ( self._session_expire_at is None @@ -394,7 +394,9 @@ class AesEncyptionSession: """Class for an AES encryption session.""" @staticmethod - def create_from_keypair(handshake_key: str, keypair: KeyPair): + def create_from_keypair( + handshake_key: str, keypair: KeyPair + ) -> AesEncyptionSession: """Create the encryption session.""" handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode()) @@ -404,11 +406,11 @@ class AesEncyptionSession: return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:]) - def __init__(self, key, iv): + def __init__(self, key: bytes, iv: bytes) -> None: self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv)) self.padding_strategy = padding.PKCS7(algorithms.AES.block_size) - def encrypt(self, data) -> bytes: + def encrypt(self, data: bytes) -> bytes: """Encrypt the message.""" encryptor = self.cipher.encryptor() padder = self.padding_strategy.padder() @@ -416,7 +418,7 @@ class AesEncyptionSession: encrypted = encryptor.update(padded_data) + encryptor.finalize() return base64.b64encode(encrypted) - def decrypt(self, data) -> str: + def decrypt(self, data: str | bytes) -> str: """Decrypt the message.""" decryptor = self.cipher.decryptor() unpadder = self.padding_strategy.unpadder() @@ -429,14 +431,16 @@ class KeyPair: """Class for generating key pairs.""" @staticmethod - def create_key_pair(key_size: int = 1024): + def create_key_pair(key_size: int = 1024) -> KeyPair: """Create a key pair.""" private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) public_key = private_key.public_key() return KeyPair(private_key, public_key) @staticmethod - def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str): + def create_from_der_keys( + private_key_der_b64: str, public_key_der_b64: str + ) -> KeyPair: """Create a key pair.""" key_bytes = base64.b64decode(private_key_der_b64.encode()) private_key = cast( @@ -449,7 +453,9 @@ class KeyPair: return KeyPair(private_key, public_key) - def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey): + def __init__( + self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey + ) -> None: self.private_key = private_key self.public_key = public_key self.private_key_der_bytes = self.private_key.private_bytes( diff --git a/kasa/cli/common.py b/kasa/cli/common.py index fbd6291b..fe7be761 100644 --- a/kasa/cli/common.py +++ b/kasa/cli/common.py @@ -7,7 +7,7 @@ import re import sys from contextlib import contextmanager from functools import singledispatch, update_wrapper, wraps -from typing import Final +from typing import TYPE_CHECKING, Any, Callable, Final import asyncclick as click @@ -37,7 +37,7 @@ except ImportError: """Strip rich formatting from messages.""" @wraps(echo_func) - def wrapper(message=None, *args, **kwargs): + def wrapper(message=None, *args, **kwargs) -> None: if message is not None: message = rich_formatting.sub("", message) echo_func(message, *args, **kwargs) @@ -47,20 +47,20 @@ except ImportError: _echo = _strip_rich_formatting(click.echo) -def echo(*args, **kwargs): +def echo(*args, **kwargs) -> None: """Print a message.""" ctx = click.get_current_context().find_root() if "json" not in ctx.params or ctx.params["json"] is False: _echo(*args, **kwargs) -def error(msg: str): +def error(msg: str) -> None: """Print an error and exit.""" echo(f"[bold red]{msg}[/bold red]") sys.exit(1) -def json_formatter_cb(result, **kwargs): +def json_formatter_cb(result: Any, **kwargs) -> None: """Format and output the result as JSON, if requested.""" if not kwargs.get("json"): return @@ -82,7 +82,7 @@ def json_formatter_cb(result, **kwargs): print(json_content) -def pass_dev_or_child(wrapped_function): +def pass_dev_or_child(wrapped_function: Callable) -> Callable: """Pass the device or child to the click command based on the child options.""" child_help = ( "Child ID or alias for controlling sub-devices. " @@ -133,7 +133,10 @@ def pass_dev_or_child(wrapped_function): async def _get_child_device( - device: Device, child_option, child_index_option, info_command + device: Device, + child_option: str | None, + child_index_option: int | None, + info_command: str | None, ) -> Device | None: def _list_children(): return "\n".join( @@ -178,11 +181,15 @@ async def _get_child_device( f"{child_option} children are:\n{_list_children()}" ) + if TYPE_CHECKING: + assert isinstance(child_index_option, int) + if child_index_option + 1 > len(device.children) or child_index_option < 0: error( f"Invalid index {child_index_option}, " f"device has {len(device.children)} children" ) + child_by_index = device.children[child_index_option] echo(f"Targeting child device {child_by_index.alias}") return child_by_index @@ -195,7 +202,7 @@ def CatchAllExceptions(cls): https://stackoverflow.com/questions/52213375 """ - def _handle_exception(debug, exc): + def _handle_exception(debug, exc) -> None: if isinstance(exc, click.ClickException): raise # Handle exit request from click. diff --git a/kasa/cli/device.py b/kasa/cli/device.py index 9814108c..2e621368 100644 --- a/kasa/cli/device.py +++ b/kasa/cli/device.py @@ -22,7 +22,7 @@ from .common import ( @click.group() @pass_dev_or_child -def device(dev): +def device(dev) -> None: """Commands to control basic device settings.""" diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py index 6a55cb43..8df59de8 100644 --- a/kasa/cli/discover.py +++ b/kasa/cli/discover.py @@ -36,7 +36,7 @@ async def detail(ctx): auth_failed = [] sem = asyncio.Semaphore() - async def print_unsupported(unsupported_exception: UnsupportedDeviceError): + async def print_unsupported(unsupported_exception: UnsupportedDeviceError) -> None: unsupported.append(unsupported_exception) async with sem: if unsupported_exception.discovery_result: @@ -50,7 +50,7 @@ async def detail(ctx): from .device import state - async def print_discovered(dev: Device): + async def print_discovered(dev: Device) -> None: async with sem: try: await dev.update() @@ -189,7 +189,7 @@ async def config(ctx): error(f"Unable to connect to {host}") -def _echo_dictionary(discovery_info: dict): +def _echo_dictionary(discovery_info: dict) -> None: echo("\t[bold]== Discovery information ==[/bold]") for key, value in discovery_info.items(): key_name = " ".join(x.capitalize() or "_" for x in key.split("_")) @@ -197,7 +197,7 @@ def _echo_dictionary(discovery_info: dict): echo(f"\t{key_name_and_spaces}{value}") -def _echo_discovery_info(discovery_info): +def _echo_discovery_info(discovery_info) -> None: # We don't have discovery info when all connection params are passed manually if discovery_info is None: return diff --git a/kasa/cli/feature.py b/kasa/cli/feature.py index f8cba4e3..2c5fa045 100644 --- a/kasa/cli/feature.py +++ b/kasa/cli/feature.py @@ -24,7 +24,7 @@ def _echo_features( category: Feature.Category | None = None, verbose: bool = False, indent: str = "\t", -): +) -> None: """Print out a listing of features and their values.""" if category is not None: features = { @@ -43,7 +43,9 @@ def _echo_features( echo(f"{indent}{feat.name} ({feat.id}): [red]got exception ({ex})[/red]") -def _echo_all_features(features, *, verbose=False, title_prefix=None, indent=""): +def _echo_all_features( + features, *, verbose=False, title_prefix=None, indent="" +) -> None: """Print out all features by category.""" if title_prefix is not None: echo(f"[bold]\n{indent}== {title_prefix} ==[/bold]") diff --git a/kasa/cli/lazygroup.py b/kasa/cli/lazygroup.py index 9e9724aa..a2858634 100644 --- a/kasa/cli/lazygroup.py +++ b/kasa/cli/lazygroup.py @@ -3,6 +3,8 @@ Taken from the click help files. """ +from __future__ import annotations + import importlib import asyncclick as click @@ -11,7 +13,7 @@ import asyncclick as click class LazyGroup(click.Group): """Lazy group class.""" - def __init__(self, *args, lazy_subcommands=None, **kwargs): + def __init__(self, *args, lazy_subcommands=None, **kwargs) -> None: super().__init__(*args, **kwargs) # lazy_subcommands is a map of the form: # @@ -31,9 +33,9 @@ class LazyGroup(click.Group): return self._lazy_load(cmd_name) return super().get_command(ctx, cmd_name) - def format_commands(self, ctx, formatter): + def format_commands(self, ctx, formatter) -> None: """Format the top level help output.""" - sections = {} + sections: dict[str, list] = {} for cmd, parent in self.lazy_subcommands.items(): sections.setdefault(parent, []) cmd_obj = self.get_command(ctx, cmd) diff --git a/kasa/cli/light.py b/kasa/cli/light.py index d9feee78..6b342c3d 100644 --- a/kasa/cli/light.py +++ b/kasa/cli/light.py @@ -15,7 +15,7 @@ from .common import echo, error, pass_dev_or_child @click.group() @pass_dev_or_child -def light(dev): +def light(dev) -> None: """Commands to control light settings.""" diff --git a/kasa/cli/main.py b/kasa/cli/main.py index a386fe4b..d6b9fa9d 100755 --- a/kasa/cli/main.py +++ b/kasa/cli/main.py @@ -43,7 +43,7 @@ ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in DeviceEncryptionType] DEFAULT_TARGET = "255.255.255.255" -def _legacy_type_to_class(_type): +def _legacy_type_to_class(_type: str) -> Any: from kasa.iot import ( IotBulb, IotDimmer, @@ -396,9 +396,9 @@ async def cli( @cli.command() @pass_dev_or_child -async def shell(dev: Device): +async def shell(dev: Device) -> None: """Open interactive shell.""" - echo("Opening shell for %s" % dev) + echo(f"Opening shell for {dev}") from ptpython.repl import embed logging.getLogger("parso").setLevel(logging.WARNING) # prompt parsing diff --git a/kasa/cli/schedule.py b/kasa/cli/schedule.py index 8deda315..7c9c7381 100644 --- a/kasa/cli/schedule.py +++ b/kasa/cli/schedule.py @@ -14,7 +14,7 @@ from .common import ( @click.group() @pass_dev -async def schedule(dev): +async def schedule(dev) -> None: """Scheduling commands.""" diff --git a/kasa/cli/time.py b/kasa/cli/time.py index 904da2ca..9e930108 100644 --- a/kasa/cli/time.py +++ b/kasa/cli/time.py @@ -23,7 +23,7 @@ from .common import ( @click.group(invoke_without_command=True) @click.pass_context -async def time(ctx: click.Context): +async def time(ctx: click.Context) -> None: """Get and set time.""" if ctx.invoked_subcommand is None: await ctx.invoke(time_get) diff --git a/kasa/cli/usage.py b/kasa/cli/usage.py index 1a336c74..314182fd 100644 --- a/kasa/cli/usage.py +++ b/kasa/cli/usage.py @@ -78,13 +78,13 @@ async def energy(dev: Device, year, month, erase): else: emeter_status = dev.emeter_realtime - echo("Current: %s A" % emeter_status["current"]) - echo("Voltage: %s V" % emeter_status["voltage"]) - echo("Power: %s W" % emeter_status["power"]) - echo("Total consumption: %s kWh" % emeter_status["total"]) + echo("Current: {} A".format(emeter_status["current"])) + echo("Voltage: {} V".format(emeter_status["voltage"])) + echo("Power: {} W".format(emeter_status["power"])) + echo("Total consumption: {} kWh".format(emeter_status["total"])) - echo("Today: %s kWh" % dev.emeter_today) - echo("This month: %s kWh" % dev.emeter_this_month) + echo(f"Today: {dev.emeter_today} kWh") + echo(f"This month: {dev.emeter_this_month} kWh") return emeter_status @@ -122,8 +122,8 @@ async def usage(dev: Device, year, month, erase): usage_data = await usage.get_daystat(year=month.year, month=month.month) else: # Call with no argument outputs summary data and returns - echo("Today: %s minutes" % usage.usage_today) - echo("This month: %s minutes" % usage.usage_this_month) + echo(f"Today: {usage.usage_today} minutes") + echo(f"This month: {usage.usage_this_month} minutes") return usage diff --git a/kasa/cli/wifi.py b/kasa/cli/wifi.py index 07fb5f20..924e83f1 100644 --- a/kasa/cli/wifi.py +++ b/kasa/cli/wifi.py @@ -16,7 +16,7 @@ from .common import ( @click.group() @pass_dev -def wifi(dev): +def wifi(dev) -> None: """Commands to control wifi settings.""" diff --git a/kasa/device.py b/kasa/device.py index fb9b9f0c..72c56717 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -234,10 +234,10 @@ class Device(ABC): return await connect(host=host, config=config) # type: ignore[arg-type] @abstractmethod - async def update(self, update_children: bool = True): + async def update(self, update_children: bool = True) -> None: """Update the device.""" - async def disconnect(self): + async def disconnect(self) -> None: """Disconnect and close any underlying connection resources.""" await self.protocol.close() @@ -257,15 +257,15 @@ class Device(ABC): return not self.is_on @abstractmethod - async def turn_on(self, **kwargs) -> dict | None: + async def turn_on(self, **kwargs) -> dict: """Turn on the device.""" @abstractmethod - async def turn_off(self, **kwargs) -> dict | None: + async def turn_off(self, **kwargs) -> dict: """Turn off the device.""" @abstractmethod - async def set_state(self, on: bool): + async def set_state(self, on: bool) -> dict: """Set the device state to *on*. This allows turning the device on and off. @@ -278,7 +278,7 @@ class Device(ABC): return self.protocol._transport._host @host.setter - def host(self, value): + def host(self, value: str) -> None: """Set the device host. Generally used by discovery to set the hostname after ip discovery. @@ -307,7 +307,7 @@ class Device(ABC): return self._device_type @abstractmethod - def update_from_discover_info(self, info): + def update_from_discover_info(self, info: dict) -> None: """Update state from info from the discover call.""" @property @@ -325,7 +325,7 @@ class Device(ABC): def alias(self) -> str | None: """Returns the device alias or nickname.""" - async def _raw_query(self, request: str | dict) -> Any: + async def _raw_query(self, request: str | dict) -> dict: """Send a raw query to the device.""" return await self.protocol.query(request=request) @@ -407,7 +407,7 @@ class Device(ABC): @property @abstractmethod - def internal_state(self) -> Any: + def internal_state(self) -> dict: """Return all the internal state data.""" @property @@ -420,10 +420,10 @@ class Device(ABC): """Return the list of supported features.""" return self._features - def _add_feature(self, feature: Feature): + def _add_feature(self, feature: Feature) -> None: """Add a new feature to the device.""" if feature.id in self._features: - raise KasaException("Duplicate feature id %s" % feature.id) + raise KasaException(f"Duplicate feature id {feature.id}") assert feature.id is not None # TODO: hack for typing # noqa: S101 self._features[feature.id] = feature @@ -446,11 +446,13 @@ class Device(ABC): """Scan for available wifi networks.""" @abstractmethod - async def wifi_join(self, ssid: str, password: str, keytype: str = "wpa2_psk"): + async def wifi_join( + self, ssid: str, password: str, keytype: str = "wpa2_psk" + ) -> dict: """Join the given wifi network.""" @abstractmethod - async def set_alias(self, alias: str): + async def set_alias(self, alias: str) -> dict: """Set the device name (alias).""" @abstractmethod @@ -468,7 +470,7 @@ class Device(ABC): Note, this does not downgrade the firmware. """ - def __repr__(self): + def __repr__(self) -> str: update_needed = " - update() needed" if not self._last_update else "" return ( f"<{self.device_type} at {self.host} -" @@ -486,7 +488,9 @@ class Device(ABC): "is_strip_socket": (None, DeviceType.StripSocket), } - def _get_replacing_attr(self, module_name: ModuleName, *attrs): + def _get_replacing_attr( + self, module_name: ModuleName | None, *attrs: Any + ) -> str | None: # If module name is None check self if not module_name: check = self @@ -540,7 +544,7 @@ class Device(ABC): "supported_modules": (None, ["modules"]), } - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # is_device_type if dep_device_type_attr := self._deprecated_device_type_attributes.get(name): module = dep_device_type_attr[0] diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 7f2150d7..0c1ed427 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -83,7 +83,7 @@ async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> Device: if debug_enabled: start_time = time.perf_counter() - def _perf_log(has_params, perf_type): + def _perf_log(has_params: bool, perf_type: str) -> None: nonlocal start_time if debug_enabled: end_time = time.perf_counter() @@ -150,7 +150,7 @@ def _get_device_type_from_sys_info(info: dict[str, Any]) -> DeviceType: return DeviceType.LightStrip return DeviceType.Bulb - raise UnsupportedDeviceError("Unknown device type: %s" % type_) + raise UnsupportedDeviceError(f"Unknown device type: {type_}") def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]: diff --git a/kasa/deviceconfig.py b/kasa/deviceconfig.py index e0fd1725..f4a5f2a3 100644 --- a/kasa/deviceconfig.py +++ b/kasa/deviceconfig.py @@ -75,14 +75,14 @@ class DeviceFamily(Enum): SmartIpCamera = "SMART.IPCAMERA" -def _dataclass_from_dict(klass, in_val): +def _dataclass_from_dict(klass: Any, in_val: dict) -> Any: if is_dataclass(klass): fieldtypes = {f.name: f.type for f in fields(klass)} val = {} for dict_key in in_val: if dict_key in fieldtypes: if hasattr(fieldtypes[dict_key], "from_dict"): - val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) + val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) # type: ignore[union-attr] else: val[dict_key] = _dataclass_from_dict( fieldtypes[dict_key], in_val[dict_key] @@ -91,12 +91,12 @@ def _dataclass_from_dict(klass, in_val): raise KasaException( f"Cannot create dataclass from dict, unknown key: {dict_key}" ) - return klass(**val) + return klass(**val) # type: ignore[operator] else: return in_val -def _dataclass_to_dict(in_val): +def _dataclass_to_dict(in_val: Any) -> dict: fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare} out_val = {} for field_name in fieldtypes: @@ -210,7 +210,7 @@ class DeviceConfig: aes_keys: Optional[KeyPairDict] = None - def __post_init__(self): + def __post_init__(self) -> None: if self.connection_type is None: self.connection_type = DeviceConnectionParameters( DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor diff --git a/kasa/discover.py b/kasa/discover.py index a774ebde..efb1e5e4 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -89,9 +89,19 @@ import logging import secrets import socket import struct -from collections.abc import Awaitable +from asyncio.transports import DatagramTransport from pprint import pformat as pf -from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Dict, + NamedTuple, + Optional, + Type, + cast, +) from aiohttp import ClientSession @@ -140,8 +150,8 @@ class ConnectAttempt(NamedTuple): device: type -OnDiscoveredCallable = Callable[[Device], Awaitable[None]] -OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]] +OnDiscoveredCallable = Callable[[Device], Coroutine] +OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Coroutine] OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None] DeviceDict = Dict[str, Device] @@ -156,7 +166,7 @@ class _AesDiscoveryQuery: keypair: KeyPair | None = None @classmethod - def generate_query(cls): + def generate_query(cls) -> bytearray: if not cls.keypair: cls.keypair = KeyPair.create_key_pair(key_size=2048) secret = secrets.token_bytes(4) @@ -215,7 +225,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): credentials: Credentials | None = None, timeout: int | None = None, ) -> None: - self.transport = None + self.transport: DatagramTransport | None = None self.discovery_packets = discovery_packets self.interface = interface self.on_discovered = on_discovered @@ -239,16 +249,19 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.target_discovered: bool = False self._started_event = asyncio.Event() - def _run_callback_task(self, coro): - task = asyncio.create_task(coro) + def _run_callback_task(self, coro: Coroutine) -> None: + task: asyncio.Task = asyncio.create_task(coro) self.callback_tasks.append(task) - async def wait_for_discovery_to_complete(self): + async def wait_for_discovery_to_complete(self) -> None: """Wait for the discovery task to complete.""" # Give some time for connection_made event to be received async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT): await self._started_event.wait() try: + if TYPE_CHECKING: + assert isinstance(self.discover_task, asyncio.Task) + await self.discover_task except asyncio.CancelledError: # if target_discovered then cancel was called internally @@ -257,11 +270,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): # Wait for any pending callbacks to complete await asyncio.gather(*self.callback_tasks) - def connection_made(self, transport) -> None: + def connection_made(self, transport: DatagramTransport) -> None: # type: ignore[override] """Set socket options for broadcasting.""" - self.transport = transport + self.transport = cast(DatagramTransport, transport) - sock = transport.get_extra_info("socket") + sock = self.transport.get_extra_info("socket") sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -292,7 +305,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore await asyncio.sleep(sleep_between_packets) - def datagram_received(self, data, addr) -> None: + def datagram_received( + self, + data: bytes, + addr: tuple[str, int], + ) -> None: """Handle discovery responses.""" if TYPE_CHECKING: assert _AesDiscoveryQuery.keypair @@ -338,18 +355,18 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self._handle_discovered_event() - def _handle_discovered_event(self): + def _handle_discovered_event(self) -> None: """If target is in seen_hosts cancel discover_task.""" if self.target in self.seen_hosts: self.target_discovered = True if self.discover_task: self.discover_task.cancel() - def error_received(self, ex): + def error_received(self, ex: Exception) -> None: """Handle asyncio.Protocol errors.""" _LOGGER.error("Got error: %s", ex) - def connection_lost(self, ex): # pragma: no cover + def connection_lost(self, ex: Exception | None) -> None: # pragma: no cover """Cancel the discover task if running.""" if self.discover_task: self.discover_task.cancel() @@ -372,17 +389,17 @@ class Discover: @staticmethod async def discover( *, - target="255.255.255.255", - on_discovered=None, - discovery_timeout=5, - discovery_packets=3, - interface=None, - on_unsupported=None, - credentials=None, + target: str = "255.255.255.255", + on_discovered: OnDiscoveredCallable | None = None, + discovery_timeout: int = 5, + discovery_packets: int = 3, + interface: str | None = None, + on_unsupported: OnUnsupportedCallable | None = None, + credentials: Credentials | None = None, username: str | None = None, password: str | None = None, - port=None, - timeout=None, + port: int | None = None, + timeout: int | None = None, ) -> DeviceDict: """Discover supported devices. @@ -636,7 +653,7 @@ class Discover: ) if not dev_class: raise UnsupportedDeviceError( - "Unknown device type: %s" % discovery_result.device_type, + f"Unknown device type: {discovery_result.device_type}", discovery_result=info, ) return dev_class diff --git a/kasa/emeterstatus.py b/kasa/emeterstatus.py index 0112b33a..acb87789 100644 --- a/kasa/emeterstatus.py +++ b/kasa/emeterstatus.py @@ -49,13 +49,13 @@ class EmeterStatus(dict): except ValueError: return None - def __repr__(self): + def __repr__(self) -> str: return ( f"" ) - def __getitem__(self, item): + def __getitem__(self, item: str) -> float | None: """Return value in wanted units.""" valid_keys = [ "voltage_mv", diff --git a/kasa/exceptions.py b/kasa/exceptions.py index b646e514..7bc79653 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -15,10 +15,10 @@ class KasaException(Exception): class TimeoutError(KasaException, _asyncioTimeoutError): """Timeout exception for device errors.""" - def __repr__(self): + def __repr__(self) -> str: return KasaException.__repr__(self) - def __str__(self): + def __str__(self) -> str: return KasaException.__str__(self) @@ -42,11 +42,11 @@ class DeviceError(KasaException): self.error_code: SmartErrorCode | None = kwargs.get("error_code", None) super().__init__(*args) - def __repr__(self): + def __repr__(self) -> str: err_code = self.error_code.__repr__() if self.error_code else "" return f"{self.__class__.__name__}({err_code})" - def __str__(self): + def __str__(self) -> str: err_code = f" (error_code={self.error_code.name})" if self.error_code else "" return super().__str__() + err_code @@ -62,7 +62,7 @@ class _RetryableError(DeviceError): class SmartErrorCode(IntEnum): """Enum for SMART Error Codes.""" - def __str__(self): + def __str__(self) -> str: return f"{self.name}({self.value})" @staticmethod diff --git a/kasa/experimental/__init__.py b/kasa/experimental/__init__.py index 388c5736..a866787e 100644 --- a/kasa/experimental/__init__.py +++ b/kasa/experimental/__init__.py @@ -12,12 +12,12 @@ class Experimental: ENV_VAR = "KASA_EXPERIMENTAL" @classmethod - def set_enabled(cls, enabled): + def set_enabled(cls, enabled: bool) -> None: """Set the enabled value.""" cls._enabled = enabled @classmethod - def enabled(cls): + def enabled(cls) -> bool: """Get the enabled value.""" if cls._enabled is not None: return cls._enabled diff --git a/kasa/experimental/smartcameraprotocol.py b/kasa/experimental/smartcameraprotocol.py index b298fbd2..38530b16 100644 --- a/kasa/experimental/smartcameraprotocol.py +++ b/kasa/experimental/smartcameraprotocol.py @@ -50,11 +50,13 @@ class SmartCameraProtocol(SmartProtocol): """Class for SmartCamera Protocol.""" async def _handle_response_lists( - self, response_result: dict[str, Any], method, retry_count - ): + self, response_result: dict[str, Any], method: str, retry_count: int + ) -> None: pass - def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True): + def _handle_response_error_code( + self, resp_dict: dict, method: str, raise_on_error: bool = True + ) -> None: error_code_raw = resp_dict.get("error_code") try: error_code = SmartErrorCode.from_int(error_code_raw) @@ -203,7 +205,7 @@ class _ChildCameraProtocolWrapper(SmartProtocol): device responses before returning to the caller. """ - def __init__(self, device_id: str, base_protocol: SmartProtocol): + def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None: self._device_id = device_id self._protocol = base_protocol self._transport = base_protocol._transport diff --git a/kasa/experimental/sslaestransport.py b/kasa/experimental/sslaestransport.py index 68420f89..f188f144 100644 --- a/kasa/experimental/sslaestransport.py +++ b/kasa/experimental/sslaestransport.py @@ -256,7 +256,9 @@ class SslAesTransport(BaseTransport): return ret_val # type: ignore[return-value] @staticmethod - def generate_confirm_hash(local_nonce, server_nonce, pwd_hash): + def generate_confirm_hash( + local_nonce: str, server_nonce: str, pwd_hash: str + ) -> str: """Generate an auth hash for the protocol on the supplied credentials.""" expected_confirm_bytes = _sha256_hash( local_nonce.encode() + pwd_hash.encode() + server_nonce.encode() @@ -264,7 +266,9 @@ class SslAesTransport(BaseTransport): return expected_confirm_bytes + server_nonce + local_nonce @staticmethod - def generate_digest_password(local_nonce, server_nonce, pwd_hash): + def generate_digest_password( + local_nonce: str, server_nonce: str, pwd_hash: str + ) -> str: """Generate an auth hash for the protocol on the supplied credentials.""" digest_password_hash = _sha256_hash( pwd_hash.encode() + local_nonce.encode() + server_nonce.encode() @@ -275,7 +279,7 @@ class SslAesTransport(BaseTransport): @staticmethod def generate_encryption_token( - token_type, local_nonce, server_nonce, pwd_hash + token_type: str, local_nonce: str, server_nonce: str, pwd_hash: str ) -> bytes: """Generate encryption token.""" hashedKey = _sha256_hash( @@ -302,7 +306,9 @@ class SslAesTransport(BaseTransport): local_nonce, server_nonce, pwd_hash = await self.perform_handshake1() await self.perform_handshake2(local_nonce, server_nonce, pwd_hash) - async def perform_handshake2(self, local_nonce, server_nonce, pwd_hash) -> None: + async def perform_handshake2( + self, local_nonce: str, server_nonce: str, pwd_hash: str + ) -> None: """Perform the handshake.""" _LOGGER.debug("Performing handshake2 ...") digest_password = self.generate_digest_password( diff --git a/kasa/feature.py b/kasa/feature.py index e20a926d..e61cba07 100644 --- a/kasa/feature.py +++ b/kasa/feature.py @@ -162,7 +162,7 @@ class Feature: #: If set, this property will be used to get *choices*. choices_getter: str | Callable[[], list[str]] | None = None - def __post_init__(self): + def __post_init__(self) -> None: """Handle late-binding of members.""" # Populate minimum & maximum values, if range_getter is given self._container = self.container if self.container is not None else self.device @@ -188,7 +188,7 @@ class Feature: f"Read-only feat defines attribute_setter: {self.name} ({self.id}):" ) - def _get_property_value(self, getter): + def _get_property_value(self, getter: str | Callable | None) -> Any: if getter is None: return None if isinstance(getter, str): @@ -227,7 +227,7 @@ class Feature: return 0 @property - def value(self): + def value(self) -> int | float | bool | str | Enum | None: """Return the current value.""" if self.type == Feature.Type.Action: return "" @@ -264,7 +264,7 @@ class Feature: return await getattr(container, self.attribute_setter)(value) - def __repr__(self): + def __repr__(self) -> str: try: value = self.value choices = self.choices @@ -286,8 +286,8 @@ class Feature: value = " ".join( [f"*{choice}*" if choice == value else choice for choice in choices] ) - if self.precision_hint is not None and value is not None: - value = round(self.value, self.precision_hint) + if self.precision_hint is not None and isinstance(value, float): + value = round(value, self.precision_hint) s = f"{self.name} ({self.id}): {value}" if self.unit is not None: diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 6b8e234c..8b69df52 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio import logging +import ssl import time from typing import Any, Dict @@ -64,7 +65,7 @@ class HttpClient: json: dict | Any | None = None, headers: dict[str, str] | None = None, cookies_dict: dict[str, str] | None = None, - ssl=False, + ssl: ssl.SSLContext | bool = False, ) -> tuple[int, dict | bytes | None]: """Send an http post request to the device. diff --git a/kasa/interfaces/energy.py b/kasa/interfaces/energy.py index 4e040e6f..7092788e 100644 --- a/kasa/interfaces/energy.py +++ b/kasa/interfaces/energy.py @@ -4,6 +4,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from enum import IntFlag, auto +from typing import Any from warnings import warn from ..emeterstatus import EmeterStatus @@ -31,7 +32,7 @@ class Energy(Module, ABC): """Return True if module supports the feature.""" return module_feature in self._supported - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" device = self._device self._add_feature( @@ -151,22 +152,26 @@ class Energy(Module, ABC): """Get the current voltage in V.""" @abstractmethod - async def get_status(self): + async def get_status(self) -> EmeterStatus: """Return real-time statistics.""" @abstractmethod - async def erase_stats(self): + async def erase_stats(self) -> dict: """Erase all stats.""" @abstractmethod - async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict: + async def get_daily_stats( + self, *, year: int | None = None, month: int | None = None, kwh: bool = True + ) -> dict: """Return daily stats for the given year & month. The return value is a dictionary of {day: energy, ...}. """ @abstractmethod - async def get_monthly_stats(self, *, year=None, kwh=True) -> dict: + async def get_monthly_stats( + self, *, year: int | None = None, kwh: bool = True + ) -> dict: """Return monthly stats for the given year.""" _deprecated_attributes = { @@ -179,7 +184,7 @@ class Energy(Module, ABC): "get_monthstat": "get_monthly_stats", } - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if attr := self._deprecated_attributes.get(name): msg = f"{name} is deprecated, use {attr} instead" warn(msg, DeprecationWarning, stacklevel=2) diff --git a/kasa/interfaces/fan.py b/kasa/interfaces/fan.py index 89d8d82b..ade00928 100644 --- a/kasa/interfaces/fan.py +++ b/kasa/interfaces/fan.py @@ -16,5 +16,5 @@ class Fan(Module, ABC): """Return fan speed level.""" @abstractmethod - async def set_fan_speed_level(self, level: int): + async def set_fan_speed_level(self, level: int) -> dict: """Set fan speed level.""" diff --git a/kasa/interfaces/led.py b/kasa/interfaces/led.py index 2ddba00c..2d34597b 100644 --- a/kasa/interfaces/led.py +++ b/kasa/interfaces/led.py @@ -11,7 +11,7 @@ from ..module import Module class Led(Module, ABC): """Base interface to represent a LED module.""" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" device = self._device self._add_feature( @@ -34,5 +34,5 @@ class Led(Module, ABC): """Return current led status.""" @abstractmethod - async def set_led(self, enable: bool) -> None: + async def set_led(self, enable: bool) -> dict: """Set led.""" diff --git a/kasa/interfaces/light.py b/kasa/interfaces/light.py index 5d206d1a..298ad1f8 100644 --- a/kasa/interfaces/light.py +++ b/kasa/interfaces/light.py @@ -166,7 +166,7 @@ class Light(Module, ABC): @abstractmethod async def set_color_temp( - self, temp: int, *, brightness=None, transition: int | None = None + self, temp: int, *, brightness: int | None = None, transition: int | None = None ) -> dict: """Set the color temperature of the device in kelvin. diff --git a/kasa/interfaces/lighteffect.py b/kasa/interfaces/lighteffect.py index e4efa2c2..9a69f2d0 100644 --- a/kasa/interfaces/lighteffect.py +++ b/kasa/interfaces/lighteffect.py @@ -53,7 +53,7 @@ class LightEffect(Module, ABC): LIGHT_EFFECTS_OFF = "Off" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" device = self._device self._add_feature( @@ -96,7 +96,7 @@ class LightEffect(Module, ABC): *, brightness: int | None = None, transition: int | None = None, - ) -> None: + ) -> dict: """Set an effect on the device. If brightness or transition is defined, @@ -110,10 +110,11 @@ class LightEffect(Module, ABC): :param int transition: The wanted transition time """ + @abstractmethod async def set_custom_effect( self, effect_dict: dict, - ) -> None: + ) -> dict: """Set a custom effect on the device. :param str effect_dict: The custom effect dict to set diff --git a/kasa/interfaces/lightpreset.py b/kasa/interfaces/lightpreset.py index fc292419..586671e7 100644 --- a/kasa/interfaces/lightpreset.py +++ b/kasa/interfaces/lightpreset.py @@ -83,7 +83,7 @@ class LightPreset(Module): PRESET_NOT_SET = "Not set" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" device = self._device self._add_feature( @@ -127,7 +127,7 @@ class LightPreset(Module): async def set_preset( self, preset_name: str, - ) -> None: + ) -> dict: """Set a light preset for the device.""" @abstractmethod @@ -135,7 +135,7 @@ class LightPreset(Module): self, preset_name: str, preset_info: LightState, - ) -> None: + ) -> dict: """Update the preset with *preset_name* with the new *preset_info*.""" @property diff --git a/kasa/iot/iotbulb.py b/kasa/iot/iotbulb.py index 3302e80d..481a9da8 100644 --- a/kasa/iot/iotbulb.py +++ b/kasa/iot/iotbulb.py @@ -54,7 +54,7 @@ class TurnOnBehavior(BaseModel): mode: BehaviorMode @root_validator - def _mode_based_on_preset(cls, values): + def _mode_based_on_preset(cls, values: dict) -> dict: """Set the mode based on the preset value.""" if values["preset"] is not None: values["mode"] = BehaviorMode.Preset @@ -209,7 +209,7 @@ class IotBulb(IotDevice): super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Bulb - async def _initialize_modules(self): + async def _initialize_modules(self) -> None: """Initialize modules not added in init.""" await super()._initialize_modules() self.add_module( @@ -307,7 +307,7 @@ class IotBulb(IotDevice): await self._query_helper(self.LIGHT_SERVICE, "get_default_behavior") ) - async def set_turn_on_behavior(self, behavior: TurnOnBehaviors): + async def set_turn_on_behavior(self, behavior: TurnOnBehaviors) -> dict: """Set the behavior for turning the bulb on. If you do not want to manually construct the behavior object, @@ -426,7 +426,7 @@ class IotBulb(IotDevice): @requires_update async def _set_color_temp( - self, temp: int, *, brightness=None, transition: int | None = None + self, temp: int, *, brightness: int | None = None, transition: int | None = None ) -> dict: """Set the color temperature of the device in kelvin. @@ -450,7 +450,7 @@ class IotBulb(IotDevice): return await self._set_light_state(light_state, transition=transition) - def _raise_for_invalid_brightness(self, value): + def _raise_for_invalid_brightness(self, value: int) -> None: if not isinstance(value, int): raise TypeError("Brightness must be an integer") if not (0 <= value <= 100): @@ -517,7 +517,7 @@ class IotBulb(IotDevice): """Return that the bulb has an emeter.""" return True - async def set_alias(self, alias: str) -> None: + async def set_alias(self, alias: str) -> dict: """Set the device name (alias). Overridden to use a different module name. diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 69296823..4ee403db 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -19,7 +19,7 @@ import inspect import logging from collections.abc import Mapping, Sequence from datetime import datetime, timedelta, tzinfo -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Callable, cast from warnings import warn from ..device import Device, WifiNetwork @@ -35,12 +35,12 @@ from .modules import Emeter _LOGGER = logging.getLogger(__name__) -def requires_update(f): +def requires_update(f: Callable) -> Any: """Indicate that `update` should be called before accessing this method.""" # noqa: D202 if inspect.iscoroutinefunction(f): @functools.wraps(f) - async def wrapped(*args, **kwargs): + async def wrapped(*args: Any, **kwargs: Any) -> Any: self = args[0] if self._last_update is None and f.__name__ not in self._sys_info: raise KasaException("You need to await update() to access the data") @@ -49,13 +49,13 @@ def requires_update(f): else: @functools.wraps(f) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> Any: self = args[0] if self._last_update is None and f.__name__ not in self._sys_info: raise KasaException("You need to await update() to access the data") return f(*args, **kwargs) - f.requires_update = True + f.requires_update = True # type: ignore[attr-defined] return wrapped @@ -197,7 +197,7 @@ class IotDevice(Device): return cast(ModuleMapping[IotModule], self._supported_modules) return self._supported_modules - def add_module(self, name: str | ModuleName[Module], module: IotModule): + def add_module(self, name: str | ModuleName[Module], module: IotModule) -> None: """Register a module.""" if name in self._modules: _LOGGER.debug("Module %s already registered, ignoring...", name) @@ -207,8 +207,12 @@ class IotDevice(Device): self._modules[name] = module def _create_request( - self, target: str, cmd: str, arg: dict | None = None, child_ids=None - ): + self, + target: str, + cmd: str, + arg: dict | None = None, + child_ids: list | None = None, + ) -> dict: if arg is None: arg = {} request: dict[str, Any] = {target: {cmd: arg}} @@ -225,8 +229,12 @@ class IotDevice(Device): raise KasaException("update() required prior accessing emeter") async def _query_helper( - self, target: str, cmd: str, arg: dict | None = None, child_ids=None - ) -> Any: + self, + target: str, + cmd: str, + arg: dict | None = None, + child_ids: list | None = None, + ) -> dict: """Query device, return results or raise an exception. :param target: Target system {system, time, emeter, ..} @@ -276,7 +284,7 @@ class IotDevice(Device): """Retrieve system information.""" return await self._query_helper("system", "get_sysinfo") - async def update(self, update_children: bool = True): + async def update(self, update_children: bool = True) -> None: """Query the device to update the data. Needed for properties that are decorated with `requires_update`. @@ -305,7 +313,7 @@ class IotDevice(Device): if not self._features: await self._initialize_features() - async def _initialize_modules(self): + async def _initialize_modules(self) -> None: """Initialize modules not added in init.""" if self.has_emeter: _LOGGER.debug( @@ -313,7 +321,7 @@ class IotDevice(Device): ) self.add_module(Module.Energy, Emeter(self, self.emeter_type)) - async def _initialize_features(self): + async def _initialize_features(self) -> None: """Initialize common features.""" self._add_feature( Feature( @@ -364,7 +372,7 @@ class IotDevice(Device): ) ) - for module in self._supported_modules.values(): + for module in self.modules.values(): module._initialize_features() for module_feat in module._module_features.values(): self._add_feature(module_feat) @@ -453,7 +461,7 @@ class IotDevice(Device): sys_info = self._sys_info return sys_info.get("alias") if sys_info else None - async def set_alias(self, alias: str) -> None: + async def set_alias(self, alias: str) -> dict: """Set the device name (alias).""" return await self._query_helper("system", "set_dev_alias", {"alias": alias}) @@ -550,7 +558,7 @@ class IotDevice(Device): return mac - async def set_mac(self, mac): + async def set_mac(self, mac: str) -> dict: """Set the mac address. :param str mac: mac in hexadecimal with colons, e.g. 01:23:45:67:89:ab @@ -576,7 +584,7 @@ class IotDevice(Device): """Turn off the device.""" raise NotImplementedError("Device subclass needs to implement this.") - async def turn_on(self, **kwargs) -> dict | None: + async def turn_on(self, **kwargs) -> dict: """Turn device on.""" raise NotImplementedError("Device subclass needs to implement this.") @@ -586,7 +594,7 @@ class IotDevice(Device): """Return True if the device is on.""" raise NotImplementedError("Device subclass needs to implement this.") - async def set_state(self, on: bool): + async def set_state(self, on: bool) -> dict: """Set the device state.""" if on: return await self.turn_on() @@ -627,7 +635,7 @@ class IotDevice(Device): async def wifi_scan(self) -> list[WifiNetwork]: # noqa: D202 """Scan for available wifi networks.""" - async def _scan(target): + async def _scan(target: str) -> dict: return await self._query_helper(target, "get_scaninfo", {"refresh": 1}) try: @@ -639,17 +647,17 @@ class IotDevice(Device): info = await _scan("smartlife.iot.common.softaponboarding") if "ap_list" not in info: - raise KasaException("Invalid response for wifi scan: %s" % info) + raise KasaException(f"Invalid response for wifi scan: {info}") return [WifiNetwork(**x) for x in info["ap_list"]] - async def wifi_join(self, ssid: str, password: str, keytype: str = "3"): # noqa: D202 + async def wifi_join(self, ssid: str, password: str, keytype: str = "3") -> dict: # noqa: D202 """Join the given wifi network. If joining the network fails, the device will return to AP mode after a while. """ - async def _join(target, payload): + async def _join(target: str, payload: dict) -> dict: return await self._query_helper(target, "set_stainfo", payload) payload = {"ssid": ssid, "password": password, "key_type": int(keytype)} diff --git a/kasa/iot/iotdimmer.py b/kasa/iot/iotdimmer.py index 04510fe2..2cd8de44 100644 --- a/kasa/iot/iotdimmer.py +++ b/kasa/iot/iotdimmer.py @@ -80,7 +80,7 @@ class IotDimmer(IotPlug): super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Dimmer - async def _initialize_modules(self): + async def _initialize_modules(self) -> None: """Initialize modules.""" await super()._initialize_modules() # TODO: need to be verified if it's okay to call these on HS220 w/o these @@ -103,7 +103,9 @@ class IotDimmer(IotPlug): return int(sys_info["brightness"]) @requires_update - async def _set_brightness(self, brightness: int, *, transition: int | None = None): + async def _set_brightness( + self, brightness: int, *, transition: int | None = None + ) -> dict: """Set the new dimmer brightness level in percentage. :param int transition: transition duration in milliseconds. @@ -134,7 +136,7 @@ class IotDimmer(IotPlug): self.DIMMER_SERVICE, "set_brightness", {"brightness": brightness} ) - async def turn_off(self, *, transition: int | None = None, **kwargs): + async def turn_off(self, *, transition: int | None = None, **kwargs) -> dict: """Turn the bulb off. :param int transition: transition duration in milliseconds. @@ -145,7 +147,7 @@ class IotDimmer(IotPlug): return await super().turn_off() @requires_update - async def turn_on(self, *, transition: int | None = None, **kwargs): + async def turn_on(self, *, transition: int | None = None, **kwargs) -> dict: """Turn the bulb on. :param int transition: transition duration in milliseconds. @@ -157,7 +159,7 @@ class IotDimmer(IotPlug): return await super().turn_on() - async def set_dimmer_transition(self, brightness: int, transition: int): + async def set_dimmer_transition(self, brightness: int, transition: int) -> dict: """Turn the bulb on to brightness percentage over transition milliseconds. A brightness value of 0 will turn off the dimmer. @@ -176,7 +178,7 @@ class IotDimmer(IotPlug): if not isinstance(transition, int): raise TypeError(f"Transition must be integer, not of {type(transition)}.") if transition <= 0: - raise ValueError("Transition value %s is not valid." % transition) + raise ValueError(f"Transition value {transition} is not valid.") return await self._query_helper( self.DIMMER_SERVICE, @@ -185,7 +187,7 @@ class IotDimmer(IotPlug): ) @requires_update - async def get_behaviors(self): + async def get_behaviors(self) -> dict: """Return button behavior settings.""" behaviors = await self._query_helper( self.DIMMER_SERVICE, "get_default_behavior", {} @@ -195,7 +197,7 @@ class IotDimmer(IotPlug): @requires_update async def set_button_action( self, action_type: ActionType, action: ButtonAction, index: int | None = None - ): + ) -> dict: """Set action to perform on button click/hold. :param action_type ActionType: whether to control double click or hold action. @@ -209,15 +211,17 @@ class IotDimmer(IotPlug): if index is not None: payload["index"] = index - await self._query_helper(self.DIMMER_SERVICE, action_type_setter, payload) + return await self._query_helper( + self.DIMMER_SERVICE, action_type_setter, payload + ) @requires_update - async def set_fade_time(self, fade_type: FadeType, time: int): + async def set_fade_time(self, fade_type: FadeType, time: int) -> dict: """Set time for fade in / fade out.""" fade_type_setter = f"set_{fade_type}_time" payload = {"fadeTime": time} - await self._query_helper(self.DIMMER_SERVICE, fade_type_setter, payload) + return await self._query_helper(self.DIMMER_SERVICE, fade_type_setter, payload) @property # type: ignore @requires_update diff --git a/kasa/iot/iotlightstrip.py b/kasa/iot/iotlightstrip.py index abe532f7..14e98684 100644 --- a/kasa/iot/iotlightstrip.py +++ b/kasa/iot/iotlightstrip.py @@ -57,7 +57,7 @@ class IotLightStrip(IotBulb): super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.LightStrip - async def _initialize_modules(self): + async def _initialize_modules(self) -> None: """Initialize modules not added in init.""" await super()._initialize_modules() self.add_module( diff --git a/kasa/iot/iotmodule.py b/kasa/iot/iotmodule.py index 7829c856..ddb0da2c 100644 --- a/kasa/iot/iotmodule.py +++ b/kasa/iot/iotmodule.py @@ -1,6 +1,9 @@ """Base class for IOT module implementations.""" +from __future__ import annotations + import logging +from typing import Any from ..exceptions import KasaException from ..module import Module @@ -24,16 +27,16 @@ merge = _merge_dict class IotModule(Module): """Base class implemention for all IOT modules.""" - def call(self, method, params=None): + async def call(self, method: str, params: dict | None = None) -> dict: """Call the given method with the given parameters.""" - return self._device._query_helper(self._module, method, params) + return await self._device._query_helper(self._module, method, params) - def query_for_command(self, query, params=None): + def query_for_command(self, query: str, params: dict | None = None) -> dict: """Create a request object for the given parameters.""" return self._device._create_request(self._module, query, params) @property - def estimated_query_response_size(self): + def estimated_query_response_size(self) -> int: """Estimated maximum size of query response. The inheriting modules implement this to estimate how large a query response @@ -42,7 +45,7 @@ class IotModule(Module): return 256 # Estimate for modules that don't specify @property - def data(self): + def data(self) -> dict[str, Any]: """Return the module specific raw data from the last update.""" dev = self._device q = self.query() diff --git a/kasa/iot/iotplug.py b/kasa/iot/iotplug.py index 3a119318..ab10e932 100644 --- a/kasa/iot/iotplug.py +++ b/kasa/iot/iotplug.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from typing import Any from ..device_type import DeviceType from ..deviceconfig import DeviceConfig @@ -54,7 +55,7 @@ class IotPlug(IotDevice): super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Plug - async def _initialize_modules(self): + async def _initialize_modules(self) -> None: """Initialize modules.""" await super()._initialize_modules() self.add_module(Module.IotSchedule, Schedule(self, "schedule")) @@ -71,11 +72,11 @@ class IotPlug(IotDevice): sys_info = self.sys_info return bool(sys_info["relay_state"]) - async def turn_on(self, **kwargs): + async def turn_on(self, **kwargs: Any) -> dict: """Turn the switch on.""" return await self._query_helper("system", "set_relay_state", {"state": 1}) - async def turn_off(self, **kwargs): + async def turn_off(self, **kwargs: Any) -> dict: """Turn the switch off.""" return await self._query_helper("system", "set_relay_state", {"state": 0}) diff --git a/kasa/iot/iotstrip.py b/kasa/iot/iotstrip.py index a18f2756..a212dd61 100755 --- a/kasa/iot/iotstrip.py +++ b/kasa/iot/iotstrip.py @@ -26,7 +26,7 @@ from .modules import Antitheft, Cloud, Countdown, Emeter, Led, Schedule, Time, U _LOGGER = logging.getLogger(__name__) -def merge_sums(dicts): +def merge_sums(dicts: list[dict]) -> dict: """Merge the sum of dicts.""" total_dict: defaultdict[int, float] = defaultdict(lambda: 0.0) for sum_dict in dicts: @@ -99,7 +99,7 @@ class IotStrip(IotDevice): self.emeter_type = "emeter" self._device_type = DeviceType.Strip - async def _initialize_modules(self): + async def _initialize_modules(self) -> None: """Initialize modules.""" # Strip has different modules to plug so do not call super self.add_module(Module.IotAntitheft, Antitheft(self, "anti_theft")) @@ -121,7 +121,7 @@ class IotStrip(IotDevice): """Return if any of the outlets are on.""" return any(plug.is_on for plug in self.children) - async def update(self, update_children: bool = True): + async def update(self, update_children: bool = True) -> None: """Update some of the attributes. Needed for methods that are decorated with `requires_update`. @@ -150,20 +150,20 @@ class IotStrip(IotDevice): if not self.features: await self._initialize_features() - async def _initialize_features(self): + async def _initialize_features(self) -> None: """Initialize common features.""" # Do not initialize features until children are created if not self.children: return await super()._initialize_features() - async def turn_on(self, **kwargs): + async def turn_on(self, **kwargs) -> dict: """Turn the strip on.""" - await self._query_helper("system", "set_relay_state", {"state": 1}) + return await self._query_helper("system", "set_relay_state", {"state": 1}) - async def turn_off(self, **kwargs): + async def turn_off(self, **kwargs) -> dict: """Turn the strip off.""" - await self._query_helper("system", "set_relay_state", {"state": 0}) + return await self._query_helper("system", "set_relay_state", {"state": 0}) @property # type: ignore @requires_update @@ -188,7 +188,7 @@ class StripEmeter(IotModule, Energy): """Return True if module supports the feature.""" return module_feature in self._supported - def query(self): + def query(self) -> dict: """Return the base query.""" return {} @@ -246,11 +246,13 @@ class StripEmeter(IotModule, Energy): ] ) - async def erase_stats(self): + async def erase_stats(self) -> dict: """Erase energy meter statistics for all plugs.""" for plug in self._device.children: await plug.modules[Module.Energy].erase_stats() + return {} + @property # type: ignore def consumption_this_month(self) -> float | None: """Return this month's energy consumption in kWh.""" @@ -320,7 +322,7 @@ class IotStripPlug(IotPlug): self.protocol = parent.protocol # Must use the same connection as the parent self._on_since: datetime | None = None - async def _initialize_modules(self): + async def _initialize_modules(self) -> None: """Initialize modules not added in init.""" if self.has_emeter: self.add_module(Module.Energy, Emeter(self, self.emeter_type)) @@ -329,7 +331,7 @@ class IotStripPlug(IotPlug): self.add_module(Module.IotSchedule, Schedule(self, "schedule")) self.add_module(Module.IotCountdown, Countdown(self, "countdown")) - async def _initialize_features(self): + async def _initialize_features(self) -> None: """Initialize common features.""" self._add_feature( Feature( @@ -353,19 +355,20 @@ class IotStripPlug(IotPlug): type=Feature.Type.Sensor, ) ) - for module in self._supported_modules.values(): + + for module in self.modules.values(): module._initialize_features() for module_feat in module._module_features.values(): self._add_feature(module_feat) - async def update(self, update_children: bool = True): + async def update(self, update_children: bool = True) -> None: """Query the device to update the data. Needed for properties that are decorated with `requires_update`. """ await self._update(update_children) - async def _update(self, update_children: bool = True): + async def _update(self, update_children: bool = True) -> None: """Query the device to update the data. Internal implementation to allow patching of public update in the cli @@ -379,8 +382,12 @@ class IotStripPlug(IotPlug): await self._initialize_features() def _create_request( - self, target: str, cmd: str, arg: dict | None = None, child_ids=None - ): + self, + target: str, + cmd: str, + arg: dict | None = None, + child_ids: list | None = None, + ) -> dict: request: dict[str, Any] = { "context": {"child_ids": [self.child_id]}, target: {cmd: arg}, @@ -388,8 +395,12 @@ class IotStripPlug(IotPlug): return request async def _query_helper( - self, target: str, cmd: str, arg: dict | None = None, child_ids=None - ) -> Any: + self, + target: str, + cmd: str, + arg: dict | None = None, + child_ids: list | None = None, + ) -> dict: """Override query helper to include the child_ids.""" return await self._parent._query_helper( target, cmd, arg, child_ids=[self.child_id] diff --git a/kasa/iot/modules/ambientlight.py b/kasa/iot/modules/ambientlight.py index 691f88f1..ac5c3488 100644 --- a/kasa/iot/modules/ambientlight.py +++ b/kasa/iot/modules/ambientlight.py @@ -11,7 +11,7 @@ _LOGGER = logging.getLogger(__name__) class AmbientLight(IotModule): """Implements ambient light controls for the motion sensor.""" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -40,7 +40,7 @@ class AmbientLight(IotModule): ) ) - def query(self): + def query(self) -> dict: """Request configuration.""" req = merge( self.query_for_command("get_config"), @@ -74,18 +74,18 @@ class AmbientLight(IotModule): """Return True if the module is enabled.""" return int(self.data["get_current_brt"]["value"]) - async def set_enabled(self, state: bool): + async def set_enabled(self, state: bool) -> dict: """Enable/disable LAS.""" return await self.call("set_enable", {"enable": int(state)}) - async def current_brightness(self) -> int: + async def current_brightness(self) -> dict: """Return current brightness. Return value units. """ return await self.call("get_current_brt") - async def set_brightness_limit(self, value: int): + async def set_brightness_limit(self, value: int) -> dict: """Set the limit when the motion sensor is inactive. See `presets` for preset values. Custom values are also likely allowed. diff --git a/kasa/iot/modules/cloud.py b/kasa/iot/modules/cloud.py index 8be393d9..10097e64 100644 --- a/kasa/iot/modules/cloud.py +++ b/kasa/iot/modules/cloud.py @@ -24,7 +24,7 @@ class CloudInfo(BaseModel): class Cloud(IotModule): """Module implementing support for cloud services.""" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -44,7 +44,7 @@ class Cloud(IotModule): """Return true if device is connected to the cloud.""" return self.info.binded - def query(self): + def query(self) -> dict: """Request cloud connectivity info.""" return self.query_for_command("get_info") @@ -53,20 +53,20 @@ class Cloud(IotModule): """Return information about the cloud connectivity.""" return CloudInfo.parse_obj(self.data["get_info"]) - def get_available_firmwares(self): + def get_available_firmwares(self) -> dict: """Return list of available firmwares.""" return self.query_for_command("get_intl_fw_list") - def set_server(self, url: str): + def set_server(self, url: str) -> dict: """Set the update server URL.""" return self.query_for_command("set_server_url", {"server": url}) - def connect(self, username: str, password: str): + def connect(self, username: str, password: str) -> dict: """Login to the cloud using given information.""" return self.query_for_command( "bind", {"username": username, "password": password} ) - def disconnect(self): + def disconnect(self) -> dict: """Disconnect from the cloud.""" return self.query_for_command("unbind") diff --git a/kasa/iot/modules/emeter.py b/kasa/iot/modules/emeter.py index 1764af90..012bda04 100644 --- a/kasa/iot/modules/emeter.py +++ b/kasa/iot/modules/emeter.py @@ -70,7 +70,7 @@ class Emeter(Usage, EnergyInterface): """Get the current voltage in V.""" return self.status.voltage - async def erase_stats(self): + async def erase_stats(self) -> dict: """Erase all stats. Uses different query than usage meter. @@ -81,7 +81,9 @@ class Emeter(Usage, EnergyInterface): """Return real-time statistics.""" return EmeterStatus(await self.call("get_realtime")) - async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict: + async def get_daily_stats( + self, *, year: int | None = None, month: int | None = None, kwh: bool = True + ) -> dict: """Return daily stats for the given year & month. The return value is a dictionary of {day: energy, ...}. @@ -90,7 +92,9 @@ class Emeter(Usage, EnergyInterface): data = self._convert_stat_data(data["day_list"], entry_key="day", kwh=kwh) return data - async def get_monthly_stats(self, *, year=None, kwh=True) -> dict: + async def get_monthly_stats( + self, *, year: int | None = None, kwh: bool = True + ) -> dict: """Return monthly stats for the given year. The return value is a dictionary of {month: energy, ...}. diff --git a/kasa/iot/modules/led.py b/kasa/iot/modules/led.py index 48301f23..8a5727b0 100644 --- a/kasa/iot/modules/led.py +++ b/kasa/iot/modules/led.py @@ -14,7 +14,7 @@ class Led(IotModule, LedInterface): return {} @property - def mode(self): + def mode(self) -> str: """LED mode setting. "always", "never" @@ -27,7 +27,7 @@ class Led(IotModule, LedInterface): sys_info = self.data return bool(1 - sys_info["led_off"]) - async def set_led(self, state: bool): + async def set_led(self, state: bool) -> dict: """Set the state of the led (night mode).""" return await self.call("set_led_off", {"off": int(not state)}) diff --git a/kasa/iot/modules/light.py b/kasa/iot/modules/light.py index d83031c8..7c9342c9 100644 --- a/kasa/iot/modules/light.py +++ b/kasa/iot/modules/light.py @@ -27,7 +27,7 @@ class Light(IotModule, LightInterface): _device: IotBulb | IotDimmer _light_state: LightState - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" super()._initialize_features() device = self._device @@ -185,7 +185,7 @@ class Light(IotModule, LightInterface): return bulb._color_temp async def set_color_temp( - self, temp: int, *, brightness=None, transition: int | None = None + self, temp: int, *, brightness: int | None = None, transition: int | None = None ) -> dict: """Set the color temperature of the device in kelvin. diff --git a/kasa/iot/modules/lighteffect.py b/kasa/iot/modules/lighteffect.py index 3a13f680..cdfaaae1 100644 --- a/kasa/iot/modules/lighteffect.py +++ b/kasa/iot/modules/lighteffect.py @@ -50,7 +50,7 @@ class LightEffect(IotModule, LightEffectInterface): *, brightness: int | None = None, transition: int | None = None, - ) -> None: + ) -> dict: """Set an effect on the device. If brightness or transition is defined, @@ -73,7 +73,7 @@ class LightEffect(IotModule, LightEffectInterface): effect_dict = EFFECT_MAPPING_V1["Aurora"] effect_dict = {**effect_dict} effect_dict["enable"] = 0 - await self.set_custom_effect(effect_dict) + return await self.set_custom_effect(effect_dict) elif effect not in EFFECT_MAPPING_V1: raise ValueError(f"The effect {effect} is not a built in effect.") else: @@ -84,12 +84,12 @@ class LightEffect(IotModule, LightEffectInterface): if transition is not None: effect_dict["transition"] = transition - await self.set_custom_effect(effect_dict) + return await self.set_custom_effect(effect_dict) async def set_custom_effect( self, effect_dict: dict, - ) -> None: + ) -> dict: """Set a custom effect on the device. :param str effect_dict: The custom effect dict to set @@ -104,7 +104,7 @@ class LightEffect(IotModule, LightEffectInterface): """Return True if the device supports setting custom effects.""" return True - def query(self): + def query(self) -> dict: """Return the base query.""" return {} diff --git a/kasa/iot/modules/lightpreset.py b/kasa/iot/modules/lightpreset.py index bae401ef..13fee33e 100644 --- a/kasa/iot/modules/lightpreset.py +++ b/kasa/iot/modules/lightpreset.py @@ -41,7 +41,7 @@ class LightPreset(IotModule, LightPresetInterface): _presets: dict[str, IotLightPreset] _preset_list: list[str] - async def _post_update_hook(self): + async def _post_update_hook(self) -> None: """Update the internal presets.""" self._presets = { f"Light preset {index+1}": IotLightPreset(**vals) @@ -93,7 +93,7 @@ class LightPreset(IotModule, LightPresetInterface): async def set_preset( self, preset_name: str, - ) -> None: + ) -> dict: """Set a light preset for the device.""" light = self._device.modules[Module.Light] if preset_name == self.PRESET_NOT_SET: @@ -104,7 +104,7 @@ class LightPreset(IotModule, LightPresetInterface): elif (preset := self._presets.get(preset_name)) is None: # type: ignore[assignment] raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}") - await light.set_state(preset) + return await light.set_state(preset) @property def has_save_preset(self) -> bool: @@ -115,7 +115,7 @@ class LightPreset(IotModule, LightPresetInterface): self, preset_name: str, preset_state: LightState, - ) -> None: + ) -> dict: """Update the preset with preset_name with the new preset_info.""" if len(self._presets) == 0: raise KasaException("Device does not supported saving presets") @@ -129,7 +129,7 @@ class LightPreset(IotModule, LightPresetInterface): return await self.call("set_preferred_state", state) - def query(self): + def query(self) -> dict: """Return the base query.""" return {} @@ -142,7 +142,7 @@ class LightPreset(IotModule, LightPresetInterface): if "id" not in vals ] - async def _deprecated_save_preset(self, preset: IotLightPreset): + async def _deprecated_save_preset(self, preset: IotLightPreset) -> dict: """Save a setting preset. You can either construct a preset object manually, or pass an existing one diff --git a/kasa/iot/modules/motion.py b/kasa/iot/modules/motion.py index db272e2f..e65cbd93 100644 --- a/kasa/iot/modules/motion.py +++ b/kasa/iot/modules/motion.py @@ -24,7 +24,7 @@ class Range(Enum): class Motion(IotModule): """Implements the motion detection (PIR) module.""" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" # Only add features if the device supports the module if "get_config" not in self.data: @@ -48,7 +48,7 @@ class Motion(IotModule): ) ) - def query(self): + def query(self) -> dict: """Request PIR configuration.""" return self.query_for_command("get_config") @@ -67,13 +67,13 @@ class Motion(IotModule): """Return True if module is enabled.""" return bool(self.config["enable"]) - async def set_enabled(self, state: bool): + async def set_enabled(self, state: bool) -> dict: """Enable/disable PIR.""" return await self.call("set_enable", {"enable": int(state)}) async def set_range( self, *, range: Range | None = None, custom_range: int | None = None - ): + ) -> dict: """Set the range for the sensor. :param range: for using standard ranges @@ -93,7 +93,7 @@ class Motion(IotModule): """Return inactivity timeout in milliseconds.""" return self.config["cold_time"] - async def set_inactivity_timeout(self, timeout: int): + async def set_inactivity_timeout(self, timeout: int) -> dict: """Set inactivity timeout in milliseconds. Note, that you need to delete the default "Smart Control" rule in the app diff --git a/kasa/iot/modules/rulemodule.py b/kasa/iot/modules/rulemodule.py index 6e3a2b22..2515b71b 100644 --- a/kasa/iot/modules/rulemodule.py +++ b/kasa/iot/modules/rulemodule.py @@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__) class RuleModule(IotModule): """Base class for rule-based modules, such as countdown and antitheft.""" - def query(self): + def query(self) -> dict: """Prepare the query for rules.""" q = self.query_for_command("get_rules") return merge(q, self.query_for_command("get_next_action")) @@ -73,14 +73,14 @@ class RuleModule(IotModule): _LOGGER.error("Unable to read rule list: %s (data: %s)", ex, self.data) return [] - async def set_enabled(self, state: bool): + async def set_enabled(self, state: bool) -> dict: """Enable or disable the service.""" - return await self.call("set_overall_enable", state) + return await self.call("set_overall_enable", {"enable": state}) - async def delete_rule(self, rule: Rule): + async def delete_rule(self, rule: Rule) -> dict: """Delete the given rule.""" return await self.call("delete_rule", {"id": rule.id}) - async def delete_all_rules(self): + async def delete_all_rules(self) -> dict: """Delete all rules.""" return await self.call("delete_all_rules") diff --git a/kasa/iot/modules/time.py b/kasa/iot/modules/time.py index 8c672d21..f65dd910 100644 --- a/kasa/iot/modules/time.py +++ b/kasa/iot/modules/time.py @@ -15,14 +15,14 @@ class Time(IotModule, TimeInterface): _timezone: tzinfo = timezone.utc - def query(self): + def query(self) -> dict: """Request time and timezone.""" q = self.query_for_command("get_time") merge(q, self.query_for_command("get_timezone")) return q - async def _post_update_hook(self): + async def _post_update_hook(self) -> None: """Perform actions after a device update.""" if res := self.data.get("get_timezone"): self._timezone = await get_timezone(res.get("index")) @@ -47,7 +47,7 @@ class Time(IotModule, TimeInterface): """Return current timezone.""" return self._timezone - async def get_time(self): + async def get_time(self) -> datetime | None: """Return current device time.""" try: res = await self.call("get_time") @@ -88,6 +88,6 @@ class Time(IotModule, TimeInterface): except Exception as ex: raise KasaException(ex) from ex - async def get_timezone(self): + async def get_timezone(self) -> dict: """Request timezone information from the device.""" return await self.call("get_timezone") diff --git a/kasa/iot/modules/usage.py b/kasa/iot/modules/usage.py index 5acf1dbe..89d8cca2 100644 --- a/kasa/iot/modules/usage.py +++ b/kasa/iot/modules/usage.py @@ -10,7 +10,7 @@ from ..iotmodule import IotModule, merge class Usage(IotModule): """Baseclass for emeter/usage interfaces.""" - def query(self): + def query(self) -> dict: """Return the base query.""" now = datetime.now() year = now.year @@ -25,22 +25,22 @@ class Usage(IotModule): return req @property - def estimated_query_response_size(self): + def estimated_query_response_size(self) -> int: """Estimated maximum query response size.""" return 2048 @property - def daily_data(self): + def daily_data(self) -> list[dict]: """Return statistics on daily basis.""" return self.data["get_daystat"]["day_list"] @property - def monthly_data(self): + def monthly_data(self) -> list[dict]: """Return statistics on monthly basis.""" return self.data["get_monthstat"]["month_list"] @property - def usage_today(self): + def usage_today(self) -> int | None: """Return today's usage in minutes.""" today = datetime.now().day # Traverse the list in reverse order to find the latest entry. @@ -50,7 +50,7 @@ class Usage(IotModule): return None @property - def usage_this_month(self): + def usage_this_month(self) -> int | None: """Return usage in this month in minutes.""" this_month = datetime.now().month # Traverse the list in reverse order to find the latest entry. @@ -59,7 +59,9 @@ class Usage(IotModule): return entry["time"] return None - async def get_raw_daystat(self, *, year=None, month=None) -> dict: + async def get_raw_daystat( + self, *, year: int | None = None, month: int | None = None + ) -> dict: """Return raw daily stats for the given year & month.""" if year is None: year = datetime.now().year @@ -68,14 +70,16 @@ class Usage(IotModule): return await self.call("get_daystat", {"year": year, "month": month}) - async def get_raw_monthstat(self, *, year=None) -> dict: + async def get_raw_monthstat(self, *, year: int | None = None) -> dict: """Return raw monthly stats for the given year.""" if year is None: year = datetime.now().year return await self.call("get_monthstat", {"year": year}) - async def get_daystat(self, *, year=None, month=None) -> dict: + async def get_daystat( + self, *, year: int | None = None, month: int | None = None + ) -> dict: """Return daily stats for the given year & month. The return value is a dictionary of {day: time, ...}. @@ -84,7 +88,7 @@ class Usage(IotModule): data = self._convert_stat_data(data["day_list"], entry_key="day") return data - async def get_monthstat(self, *, year=None) -> dict: + async def get_monthstat(self, *, year: int | None = None) -> dict: """Return monthly stats for the given year. The return value is a dictionary of {month: time, ...}. @@ -93,11 +97,11 @@ class Usage(IotModule): data = self._convert_stat_data(data["month_list"], entry_key="month") return data - async def erase_stats(self): + async def erase_stats(self) -> dict: """Erase all stats.""" return await self.call("erase_runtime_stat") - def _convert_stat_data(self, data, entry_key) -> dict: + def _convert_stat_data(self, data: list[dict], entry_key: str) -> dict: """Return usage information keyed with the day/month. The incoming data is a list of dictionaries:: @@ -113,6 +117,6 @@ class Usage(IotModule): if not data: return {} - data = {entry[entry_key]: entry["time"] for entry in data} + res = {entry[entry_key]: entry["time"] for entry in data} - return data + return res diff --git a/kasa/json.py b/kasa/json.py index aed8cd56..10edc690 100755 --- a/kasa/json.py +++ b/kasa/json.py @@ -1,9 +1,13 @@ """JSON abstraction.""" +from __future__ import annotations + +from typing import Any, Callable + try: import orjson - def dumps(obj, *, default=None): + def dumps(obj: Any, *, default: Callable | None = None) -> str: """Dump JSON.""" return orjson.dumps(obj).decode() @@ -11,7 +15,7 @@ try: except ImportError: import json - def dumps(obj, *, default=None): + def dumps(obj: Any, *, default: Callable | None = None) -> str: """Dump JSON.""" # Separators specified for consistency with orjson return json.dumps(obj, separators=(",", ":")) diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 02e0b2b7..870304d1 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -50,7 +50,8 @@ import logging import secrets import struct import time -from typing import TYPE_CHECKING, Any, cast +from asyncio import Future +from typing import TYPE_CHECKING, Any, Generator, cast from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes @@ -110,10 +111,10 @@ class KlapTransport(BaseTransport): else: self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr] self._default_credentials_auth_hash: dict[str, bytes] = {} - self._blank_auth_hash = None + self._blank_auth_hash: bytes | None = None self._handshake_lock = asyncio.Lock() self._query_lock = asyncio.Lock() - self._handshake_done = False + self._handshake_done: bool = False self._encryption_session: KlapEncryptionSession | None = None self._session_expire_at: float | None = None @@ -125,7 +126,7 @@ class KlapTransport(BaseTransport): self._request_url = self._app_url / "request" @property - def default_port(self): + def default_port(self) -> int: """Default port for the transport.""" return self.DEFAULT_PORT @@ -242,7 +243,7 @@ class KlapTransport(BaseTransport): raise AuthenticationError(msg) async def perform_handshake2( - self, local_seed, remote_seed, auth_hash + self, local_seed: bytes, remote_seed: bytes, auth_hash: bytes ) -> KlapEncryptionSession: """Perform handshake2.""" # Handshake 2 has the following payload: @@ -277,7 +278,7 @@ class KlapTransport(BaseTransport): return KlapEncryptionSession(local_seed, remote_seed, auth_hash) - async def perform_handshake(self) -> Any: + async def perform_handshake(self) -> None: """Perform handshake1 and handshake2. Sets the encryption_session if successful. @@ -309,14 +310,14 @@ class KlapTransport(BaseTransport): _LOGGER.debug("Handshake with %s complete", self._host) - def _handshake_session_expired(self): + def _handshake_session_expired(self) -> bool: """Return true if session has expired.""" return ( self._session_expire_at is None or self._session_expire_at - time.monotonic() <= 0 ) - async def send(self, request: str): + async def send(self, request: str) -> Generator[Future, None, dict[str, str]]: # type: ignore[override] """Send the request.""" if not self._handshake_done or self._handshake_session_expired(): await self.perform_handshake() @@ -355,6 +356,7 @@ class KlapTransport(BaseTransport): if TYPE_CHECKING: assert self._encryption_session + assert isinstance(response_data, bytes) try: decrypted_response = self._encryption_session.decrypt(response_data) except Exception as ex: @@ -378,7 +380,7 @@ class KlapTransport(BaseTransport): self._handshake_done = False @staticmethod - def generate_auth_hash(creds: Credentials): + def generate_auth_hash(creds: Credentials) -> bytes: """Generate an md5 auth hash for the protocol on the supplied credentials.""" un = creds.username pw = creds.password @@ -388,19 +390,19 @@ class KlapTransport(BaseTransport): @staticmethod def handshake1_seed_auth_hash( local_seed: bytes, remote_seed: bytes, auth_hash: bytes - ): + ) -> bytes: """Generate an md5 auth hash for the protocol on the supplied credentials.""" return _sha256(local_seed + auth_hash) @staticmethod def handshake2_seed_auth_hash( local_seed: bytes, remote_seed: bytes, auth_hash: bytes - ): + ) -> bytes: """Generate an md5 auth hash for the protocol on the supplied credentials.""" return _sha256(remote_seed + auth_hash) @staticmethod - def generate_owner_hash(creds: Credentials): + def generate_owner_hash(creds: Credentials) -> bytes: """Return the MD5 hash of the username in this object.""" un = creds.username return md5(un.encode()) @@ -410,7 +412,7 @@ class KlapTransportV2(KlapTransport): """Implementation of the KLAP encryption protocol with v2 hanshake hashes.""" @staticmethod - def generate_auth_hash(creds: Credentials): + def generate_auth_hash(creds: Credentials) -> bytes: """Generate an md5 auth hash for the protocol on the supplied credentials.""" un = creds.username pw = creds.password @@ -420,14 +422,14 @@ class KlapTransportV2(KlapTransport): @staticmethod def handshake1_seed_auth_hash( local_seed: bytes, remote_seed: bytes, auth_hash: bytes - ): + ) -> bytes: """Generate an md5 auth hash for the protocol on the supplied credentials.""" return _sha256(local_seed + remote_seed + auth_hash) @staticmethod def handshake2_seed_auth_hash( local_seed: bytes, remote_seed: bytes, auth_hash: bytes - ): + ) -> bytes: """Generate an md5 auth hash for the protocol on the supplied credentials.""" return _sha256(remote_seed + local_seed + auth_hash) @@ -440,7 +442,7 @@ class KlapEncryptionSession: _cipher: Cipher - def __init__(self, local_seed, remote_seed, user_hash): + def __init__(self, local_seed: bytes, remote_seed: bytes, user_hash: bytes) -> None: self.local_seed = local_seed self.remote_seed = remote_seed self.user_hash = user_hash @@ -449,11 +451,15 @@ class KlapEncryptionSession: self._aes = algorithms.AES(self._key) self._sig = self._sig_derive(local_seed, remote_seed, user_hash) - def _key_derive(self, local_seed, remote_seed, user_hash): + def _key_derive( + self, local_seed: bytes, remote_seed: bytes, user_hash: bytes + ) -> bytes: payload = b"lsk" + local_seed + remote_seed + user_hash return hashlib.sha256(payload).digest()[:16] - def _iv_derive(self, local_seed, remote_seed, user_hash): + def _iv_derive( + self, local_seed: bytes, remote_seed: bytes, user_hash: bytes + ) -> tuple[bytes, int]: # iv is first 16 bytes of sha256, where the last 4 bytes forms the # sequence number used in requests and is incremented on each request payload = b"iv" + local_seed + remote_seed + user_hash @@ -461,17 +467,19 @@ class KlapEncryptionSession: seq = int.from_bytes(fulliv[-4:], "big", signed=True) return (fulliv[:12], seq) - def _sig_derive(self, local_seed, remote_seed, user_hash): + def _sig_derive( + self, local_seed: bytes, remote_seed: bytes, user_hash: bytes + ) -> bytes: # used to create a hash with which to prefix each request payload = b"ldk" + local_seed + remote_seed + user_hash return hashlib.sha256(payload).digest()[:28] - def _generate_cipher(self): + def _generate_cipher(self) -> None: iv_seq = self._iv + PACK_SIGNED_LONG(self._seq) cbc = modes.CBC(iv_seq) self._cipher = Cipher(self._aes, cbc) - def encrypt(self, msg): + def encrypt(self, msg: bytes | str) -> tuple[bytes, int]: """Encrypt the data and increment the sequence number.""" self._seq += 1 self._generate_cipher() @@ -488,7 +496,7 @@ class KlapEncryptionSession: ).digest() return (signature + ciphertext, self._seq) - def decrypt(self, msg): + def decrypt(self, msg: bytes) -> str: """Decrypt the data.""" decryptor = self._cipher.decryptor() dp = decryptor.update(msg[32:]) + decryptor.finalize() diff --git a/kasa/module.py b/kasa/module.py index 8b68881e..c4e9f9a1 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -135,13 +135,13 @@ class Module(ABC): # SMARTCAMERA only modules Camera: Final[ModuleName[experimental.Camera]] = ModuleName("Camera") - def __init__(self, device: Device, module: str): + def __init__(self, device: Device, module: str) -> None: self._device = device self._module = module self._module_features: dict[str, Feature] = {} @abstractmethod - def query(self): + def query(self) -> dict: """Query to execute during the update cycle. The inheriting modules implement this to include their wanted @@ -150,10 +150,10 @@ class Module(ABC): @property @abstractmethod - def data(self): + def data(self) -> dict: """Return the module specific raw data from the last update.""" - def _initialize_features(self): # noqa: B027 + def _initialize_features(self) -> None: # noqa: B027 """Initialize features after the initial update. This can be implemented if features depend on module query responses. @@ -162,7 +162,7 @@ class Module(ABC): children's modules. """ - async def _post_update_hook(self): # noqa: B027 + async def _post_update_hook(self) -> None: # noqa: B027 """Perform actions after a device update. This can be implemented if a module needs to perform actions each time @@ -171,11 +171,11 @@ class Module(ABC): *_initialize_features* on the first update. """ - def _add_feature(self, feature: Feature): + def _add_feature(self, feature: Feature) -> None: """Add module feature.""" id_ = feature.id if id_ in self._module_features: - raise KasaException("Duplicate id detected %s" % id_) + raise KasaException(f"Duplicate id detected {id_}") self._module_features[id_] = feature def __repr__(self) -> str: diff --git a/kasa/protocol.py b/kasa/protocol.py index 140e9c41..f2560987 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -130,7 +130,7 @@ class BaseProtocol(ABC): self._transport = transport @property - def _host(self): + def _host(self) -> str: return self._transport._host @property diff --git a/kasa/smart/effects.py b/kasa/smart/effects.py index e0ed615c..815f777b 100644 --- a/kasa/smart/effects.py +++ b/kasa/smart/effects.py @@ -15,7 +15,9 @@ class SmartLightEffect(LightEffectInterface, ABC): """ @abstractmethod - async def set_brightness(self, brightness: int, *, transition: int | None = None): + async def set_brightness( + self, brightness: int, *, transition: int | None = None + ) -> dict: """Set effect brightness.""" @property diff --git a/kasa/smart/modules/alarm.py b/kasa/smart/modules/alarm.py index 1dacf181..f1bf7236 100644 --- a/kasa/smart/modules/alarm.py +++ b/kasa/smart/modules/alarm.py @@ -20,7 +20,7 @@ class Alarm(SmartModule): "get_support_alarm_type_list": None, # This should be needed only once } - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features. This is implemented as some features depend on device responses. @@ -100,7 +100,7 @@ class Alarm(SmartModule): """Return current alarm sound.""" return self.data["get_alarm_configure"]["type"] - async def set_alarm_sound(self, sound: str): + async def set_alarm_sound(self, sound: str) -> dict: """Set alarm sound. See *alarm_sounds* for list of available sounds. @@ -119,7 +119,7 @@ class Alarm(SmartModule): """Return alarm volume.""" return self.data["get_alarm_configure"]["volume"] - async def set_alarm_volume(self, volume: Literal["low", "normal", "high"]): + async def set_alarm_volume(self, volume: Literal["low", "normal", "high"]) -> dict: """Set alarm volume.""" payload = self.data["get_alarm_configure"].copy() payload["volume"] = volume diff --git a/kasa/smart/modules/autooff.py b/kasa/smart/modules/autooff.py index ae1bb082..4fefb000 100644 --- a/kasa/smart/modules/autooff.py +++ b/kasa/smart/modules/autooff.py @@ -17,7 +17,7 @@ class AutoOff(SmartModule): REQUIRED_COMPONENT = "auto_off" QUERY_GETTER_NAME = "get_auto_off_config" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -63,7 +63,7 @@ class AutoOff(SmartModule): """Return True if enabled.""" return self.data["enable"] - async def set_enabled(self, enable: bool): + async def set_enabled(self, enable: bool) -> dict: """Enable/disable auto off.""" return await self.call( "set_auto_off_config", @@ -75,7 +75,7 @@ class AutoOff(SmartModule): """Return time until auto off.""" return self.data["delay_min"] - async def set_delay(self, delay: int): + async def set_delay(self, delay: int) -> dict: """Set time until auto off.""" return await self.call( "set_auto_off_config", {"delay_min": delay, "enable": self.data["enable"]} @@ -96,7 +96,7 @@ class AutoOff(SmartModule): return self._device.time + timedelta(seconds=sysinfo["auto_off_remain_time"]) - async def _check_supported(self): + async def _check_supported(self) -> bool: """Additional check to see if the module is supported by the device. Parent devices that report components of children such as P300 will not have diff --git a/kasa/smart/modules/batterysensor.py b/kasa/smart/modules/batterysensor.py index 7ecfad20..87072b10 100644 --- a/kasa/smart/modules/batterysensor.py +++ b/kasa/smart/modules/batterysensor.py @@ -12,7 +12,7 @@ class BatterySensor(SmartModule): REQUIRED_COMPONENT = "battery_detect" QUERY_GETTER_NAME = "get_battery_detect_info" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" self._add_feature( Feature( @@ -48,11 +48,11 @@ class BatterySensor(SmartModule): return {} @property - def battery(self): + def battery(self) -> int: """Return battery level.""" return self._device.sys_info["battery_percentage"] @property - def battery_low(self): + def battery_low(self) -> bool: """Return True if battery is low.""" return self._device.sys_info["at_low_battery"] diff --git a/kasa/smart/modules/brightness.py b/kasa/smart/modules/brightness.py index f6e5c322..b5b8d354 100644 --- a/kasa/smart/modules/brightness.py +++ b/kasa/smart/modules/brightness.py @@ -14,7 +14,7 @@ class Brightness(SmartModule): REQUIRED_COMPONENT = "brightness" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" super()._initialize_features() @@ -39,7 +39,7 @@ class Brightness(SmartModule): return {} @property - def brightness(self): + def brightness(self) -> int: """Return current brightness.""" # If the device supports effects and one is active, use its brightness if ( @@ -49,7 +49,9 @@ class Brightness(SmartModule): return self.data["brightness"] - async def set_brightness(self, brightness: int, *, transition: int | None = None): + async def set_brightness( + self, brightness: int, *, transition: int | None = None + ) -> dict: """Set the brightness. A brightness value of 0 will turn off the light. Note, transition is not supported and will be ignored. @@ -73,6 +75,6 @@ class Brightness(SmartModule): return await self.call("set_device_info", {"brightness": brightness}) - async def _check_supported(self): + async def _check_supported(self) -> bool: """Additional check to see if the module is supported by the device.""" return "brightness" in self.data diff --git a/kasa/smart/modules/childprotection.py b/kasa/smart/modules/childprotection.py index d9670a23..fba89cc0 100644 --- a/kasa/smart/modules/childprotection.py +++ b/kasa/smart/modules/childprotection.py @@ -12,7 +12,7 @@ class ChildProtection(SmartModule): REQUIRED_COMPONENT = "child_protection" QUERY_GETTER_NAME = "get_child_protection" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( diff --git a/kasa/smart/modules/cloud.py b/kasa/smart/modules/cloud.py index 347b9ec8..fd6d0a0f 100644 --- a/kasa/smart/modules/cloud.py +++ b/kasa/smart/modules/cloud.py @@ -13,7 +13,7 @@ class Cloud(SmartModule): REQUIRED_COMPONENT = "cloud_connect" MINIMUM_UPDATE_INTERVAL_SECS = 60 - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -29,7 +29,7 @@ class Cloud(SmartModule): ) @property - def is_connected(self): + def is_connected(self) -> bool: """Return True if device is connected to the cloud.""" if self._has_data_error(): return False diff --git a/kasa/smart/modules/color.py b/kasa/smart/modules/color.py index 3faa1a82..de0c3f74 100644 --- a/kasa/smart/modules/color.py +++ b/kasa/smart/modules/color.py @@ -12,7 +12,7 @@ class Color(SmartModule): REQUIRED_COMPONENT = "color" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -48,7 +48,7 @@ class Color(SmartModule): # due to the cpython implementation. return tuple.__new__(HSV, (h, s, v)) - def _raise_for_invalid_brightness(self, value): + def _raise_for_invalid_brightness(self, value: int) -> None: """Raise error on invalid brightness value.""" if not isinstance(value, int): raise TypeError("Brightness must be an integer") diff --git a/kasa/smart/modules/colortemperature.py b/kasa/smart/modules/colortemperature.py index 920fa6d2..32d6e67d 100644 --- a/kasa/smart/modules/colortemperature.py +++ b/kasa/smart/modules/colortemperature.py @@ -18,7 +18,7 @@ class ColorTemperature(SmartModule): REQUIRED_COMPONENT = "color_temperature" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" self._add_feature( Feature( @@ -52,11 +52,11 @@ class ColorTemperature(SmartModule): return ColorTempRange(*ct_range) @property - def color_temp(self): + def color_temp(self) -> int: """Return current color temperature.""" return self.data["color_temp"] - async def set_color_temp(self, temp: int, *, brightness=None): + async def set_color_temp(self, temp: int, *, brightness: int | None = None) -> dict: """Set the color temperature.""" valid_temperature_range = self.valid_temperature_range if temp < valid_temperature_range[0] or temp > valid_temperature_range[1]: diff --git a/kasa/smart/modules/contactsensor.py b/kasa/smart/modules/contactsensor.py index 0bfa1bde..f388b781 100644 --- a/kasa/smart/modules/contactsensor.py +++ b/kasa/smart/modules/contactsensor.py @@ -12,7 +12,7 @@ class ContactSensor(SmartModule): REQUIRED_COMPONENT = None # we depend on availability of key REQUIRED_KEY_ON_PARENT = "open" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -32,6 +32,6 @@ class ContactSensor(SmartModule): return {} @property - def is_open(self): + def is_open(self) -> bool: """Return True if the contact sensor is open.""" return self._device.sys_info["open"] diff --git a/kasa/smart/modules/devicemodule.py b/kasa/smart/modules/devicemodule.py index 89c87c20..bf112e2d 100644 --- a/kasa/smart/modules/devicemodule.py +++ b/kasa/smart/modules/devicemodule.py @@ -10,7 +10,7 @@ class DeviceModule(SmartModule): REQUIRED_COMPONENT = "device" - async def _post_update_hook(self): + async def _post_update_hook(self) -> None: """Perform actions after a device update. Overrides the default behaviour to disable a module if the query returns diff --git a/kasa/smart/modules/energy.py b/kasa/smart/modules/energy.py index ab89c319..16a4890e 100644 --- a/kasa/smart/modules/energy.py +++ b/kasa/smart/modules/energy.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import NoReturn + from ...emeterstatus import EmeterStatus from ...exceptions import KasaException from ...interfaces.energy import Energy as EnergyInterface @@ -31,34 +33,34 @@ class Energy(SmartModule, EnergyInterface): # Fallback if get_energy_usage does not provide current_power, # which can happen on some newer devices (e.g. P304M). elif ( - power := self.data.get("get_current_power").get("current_power") + power := self.data.get("get_current_power", {}).get("current_power") ) is not None: return power return None @property @raise_if_update_error - def energy(self): + def energy(self) -> dict: """Return get_energy_usage results.""" if en := self.data.get("get_energy_usage"): return en return self.data - def _get_status_from_energy(self, energy) -> EmeterStatus: + def _get_status_from_energy(self, energy: dict) -> EmeterStatus: return EmeterStatus( { - "power_mw": energy.get("current_power"), - "total": energy.get("today_energy") / 1_000, + "power_mw": energy.get("current_power", 0), + "total": energy.get("today_energy", 0) / 1_000, } ) @property @raise_if_update_error - def status(self): + def status(self) -> EmeterStatus: """Get the emeter status.""" return self._get_status_from_energy(self.energy) - async def get_status(self): + async def get_status(self) -> EmeterStatus: """Return real-time statistics.""" res = await self.call("get_energy_usage") return self._get_status_from_energy(res["get_energy_usage"]) @@ -67,13 +69,13 @@ class Energy(SmartModule, EnergyInterface): @raise_if_update_error def consumption_this_month(self) -> float | None: """Get the emeter value for this month in kWh.""" - return self.energy.get("month_energy") / 1_000 + return self.energy.get("month_energy", 0) / 1_000 @property @raise_if_update_error def consumption_today(self) -> float | None: """Get the emeter value for today in kWh.""" - return self.energy.get("today_energy") / 1_000 + return self.energy.get("today_energy", 0) / 1_000 @property @raise_if_update_error @@ -97,22 +99,26 @@ class Energy(SmartModule, EnergyInterface): """Retrieve current energy readings.""" return self.status - async def erase_stats(self): + async def erase_stats(self) -> NoReturn: """Erase all stats.""" raise KasaException("Device does not support periodic statistics") - async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict: + async def get_daily_stats( + self, *, year: int | None = None, month: int | None = None, kwh: bool = True + ) -> dict: """Return daily stats for the given year & month. The return value is a dictionary of {day: energy, ...}. """ raise KasaException("Device does not support periodic statistics") - async def get_monthly_stats(self, *, year=None, kwh=True) -> dict: + async def get_monthly_stats( + self, *, year: int | None = None, kwh: bool = True + ) -> dict: """Return monthly stats for the given year.""" raise KasaException("Device does not support periodic statistics") - async def _check_supported(self): + async def _check_supported(self) -> bool: """Additional check to see if the module is supported by the device.""" # Energy module is not supported on P304M parent device return "device_on" in self._device.sys_info diff --git a/kasa/smart/modules/fan.py b/kasa/smart/modules/fan.py index 9cb1a8df..36b3aadf 100644 --- a/kasa/smart/modules/fan.py +++ b/kasa/smart/modules/fan.py @@ -12,7 +12,7 @@ class Fan(SmartModule, FanInterface): REQUIRED_COMPONENT = "fan_control" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -50,7 +50,7 @@ class Fan(SmartModule, FanInterface): """Return fan speed level.""" return 0 if self.data["device_on"] is False else self.data["fan_speed_level"] - async def set_fan_speed_level(self, level: int): + async def set_fan_speed_level(self, level: int) -> dict: """Set fan speed level, 0 for off, 1-4 for on.""" if level < 0 or level > 4: raise ValueError("Invalid level, should be in range 0-4.") @@ -65,10 +65,10 @@ class Fan(SmartModule, FanInterface): """Return sleep mode status.""" return self.data["fan_sleep_mode_on"] - async def set_sleep_mode(self, on: bool): + async def set_sleep_mode(self, on: bool) -> dict: """Set sleep mode.""" return await self.call("set_device_info", {"fan_sleep_mode_on": on}) - async def _check_supported(self): + async def _check_supported(self) -> bool: """Is the module available on this device.""" return "fan_speed_level" in self.data diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 036c0b6c..f9e6b034 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -49,14 +49,14 @@ class UpdateInfo(BaseModel): needs_upgrade: bool = Field(alias="need_to_upgrade") @validator("release_date", pre=True) - def _release_date_optional(cls, v): + def _release_date_optional(cls, v: str) -> str | None: if not v: return None return v @property - def update_available(self): + def update_available(self) -> bool: """Return True if update available.""" if self.status != 0: return True @@ -69,11 +69,11 @@ class Firmware(SmartModule): REQUIRED_COMPONENT = "firmware" MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60 * 24 - def __init__(self, device: SmartDevice, module: str): + def __init__(self, device: SmartDevice, module: str) -> None: super().__init__(device, module) self._firmware_update_info: UpdateInfo | None = None - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" device = self._device if self.supported_version > 1: @@ -183,7 +183,7 @@ class Firmware(SmartModule): @allow_update_after async def update( self, progress_cb: Callable[[DownloadState], Coroutine] | None = None - ): + ) -> dict: """Update the device firmware.""" if not self._firmware_update_info: raise KasaException( @@ -236,13 +236,15 @@ class Firmware(SmartModule): else: _LOGGER.warning("Unhandled state code: %s", state) + return state.dict() + @property def auto_update_enabled(self) -> bool: """Return True if autoupdate is enabled.""" return "enable" in self.data and self.data["enable"] @allow_update_after - async def set_auto_update_enabled(self, enabled: bool): + async def set_auto_update_enabled(self, enabled: bool) -> dict: """Change autoupdate setting.""" data = {**self.data, "enable": enabled} - await self.call("set_auto_update_info", data) + return await self.call("set_auto_update_info", data) diff --git a/kasa/smart/modules/frostprotection.py b/kasa/smart/modules/frostprotection.py index 440e1ed1..dd3671a0 100644 --- a/kasa/smart/modules/frostprotection.py +++ b/kasa/smart/modules/frostprotection.py @@ -23,7 +23,7 @@ class FrostProtection(SmartModule): """Return True if frost protection is on.""" return self._device.sys_info["frost_protection_on"] - async def set_enabled(self, enable: bool): + async def set_enabled(self, enable: bool) -> dict: """Enable/disable frost protection.""" return await self.call( "set_device_info", diff --git a/kasa/smart/modules/humiditysensor.py b/kasa/smart/modules/humiditysensor.py index fab30f05..8ce9e576 100644 --- a/kasa/smart/modules/humiditysensor.py +++ b/kasa/smart/modules/humiditysensor.py @@ -12,7 +12,7 @@ class HumiditySensor(SmartModule): REQUIRED_COMPONENT = "humidity" QUERY_GETTER_NAME = "get_comfort_humidity_config" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -45,7 +45,7 @@ class HumiditySensor(SmartModule): return {} @property - def humidity(self): + def humidity(self) -> int: """Return current humidity in percentage.""" return self._device.sys_info["current_humidity"] diff --git a/kasa/smart/modules/led.py b/kasa/smart/modules/led.py index 9c02be85..1733c3ce 100644 --- a/kasa/smart/modules/led.py +++ b/kasa/smart/modules/led.py @@ -19,7 +19,7 @@ class Led(SmartModule, LedInterface): return {self.QUERY_GETTER_NAME: None} @property - def mode(self): + def mode(self) -> str: """LED mode setting. "always", "never", "night_mode" @@ -27,12 +27,12 @@ class Led(SmartModule, LedInterface): return self.data["led_rule"] @property - def led(self): + def led(self) -> bool: """Return current led status.""" return self.data["led_rule"] != "never" @allow_update_after - async def set_led(self, enable: bool): + async def set_led(self, enable: bool) -> dict: """Set led. This should probably be a select with always/never/nightmode. @@ -41,7 +41,7 @@ class Led(SmartModule, LedInterface): return await self.call("set_led_info", dict(self.data, **{"led_rule": rule})) @property - def night_mode_settings(self): + def night_mode_settings(self) -> dict: """Night mode settings.""" return { "start": self.data["start_time"], diff --git a/kasa/smart/modules/light.py b/kasa/smart/modules/light.py index 487c25f3..e637b607 100644 --- a/kasa/smart/modules/light.py +++ b/kasa/smart/modules/light.py @@ -96,7 +96,7 @@ class Light(SmartModule, LightInterface): return await self._device.modules[Module.Color].set_hsv(hue, saturation, value) async def set_color_temp( - self, temp: int, *, brightness=None, transition: int | None = None + self, temp: int, *, brightness: int | None = None, transition: int | None = None ) -> dict: """Set the color temperature of the device in kelvin. diff --git a/kasa/smart/modules/lighteffect.py b/kasa/smart/modules/lighteffect.py index 55dd3d49..96135de4 100644 --- a/kasa/smart/modules/lighteffect.py +++ b/kasa/smart/modules/lighteffect.py @@ -81,7 +81,7 @@ class LightEffect(SmartModule, SmartLightEffect): *, brightness: int | None = None, transition: int | None = None, - ) -> None: + ) -> dict: """Set an effect for the device. Calling this will modify the brightness of the effect on the device. @@ -107,7 +107,7 @@ class LightEffect(SmartModule, SmartLightEffect): ) await self.set_brightness(brightness, effect_id=effect_id) - await self.call("set_dynamic_light_effect_rule_enable", params) + return await self.call("set_dynamic_light_effect_rule_enable", params) @property def is_active(self) -> bool: @@ -139,11 +139,11 @@ class LightEffect(SmartModule, SmartLightEffect): *, transition: int | None = None, effect_id: str | None = None, - ): + ) -> dict: """Set effect brightness.""" new_effect = self._get_effect_data(effect_id=effect_id).copy() - def _replace_brightness(data, new_brightness): + def _replace_brightness(data: list[int], new_brightness: int) -> list[int]: """Replace brightness. The first element is the brightness, the rest are unknown. @@ -163,7 +163,7 @@ class LightEffect(SmartModule, SmartLightEffect): async def set_custom_effect( self, effect_dict: dict, - ) -> None: + ) -> dict: """Set a custom effect on the device. :param str effect_dict: The custom effect dict to set diff --git a/kasa/smart/modules/lightpreset.py b/kasa/smart/modules/lightpreset.py index 56ca42c2..2eba7572 100644 --- a/kasa/smart/modules/lightpreset.py +++ b/kasa/smart/modules/lightpreset.py @@ -29,12 +29,12 @@ class LightPreset(SmartModule, LightPresetInterface): _presets: dict[str, LightState] _preset_list: list[str] - def __init__(self, device: SmartDevice, module: str): + def __init__(self, device: SmartDevice, module: str) -> None: super().__init__(device, module) self._state_in_sysinfo = self.SYS_INFO_STATE_KEY in device.sys_info self._brightness_only: bool = False - async def _post_update_hook(self): + async def _post_update_hook(self) -> None: """Update the internal presets.""" index = 0 self._presets = {} @@ -113,7 +113,7 @@ class LightPreset(SmartModule, LightPresetInterface): async def set_preset( self, preset_name: str, - ) -> None: + ) -> dict: """Set a light preset for the device.""" light = self._device.modules[SmartModule.Light] if preset_name == self.PRESET_NOT_SET: @@ -123,14 +123,14 @@ class LightPreset(SmartModule, LightPresetInterface): preset = LightState(brightness=100) elif (preset := self._presets.get(preset_name)) is None: # type: ignore[assignment] raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}") - await self._device.modules[SmartModule.Light].set_state(preset) + return await self._device.modules[SmartModule.Light].set_state(preset) @allow_update_after async def save_preset( self, preset_name: str, preset_state: LightState, - ) -> None: + ) -> dict: """Update the preset with preset_name with the new preset_info.""" if preset_name not in self._presets: raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}") @@ -138,11 +138,13 @@ class LightPreset(SmartModule, LightPresetInterface): if self._brightness_only: bright_list = [state.brightness for state in self._presets.values()] bright_list[index] = preset_state.brightness - await self.call("set_preset_rules", {"brightness": bright_list}) + return await self.call("set_preset_rules", {"brightness": bright_list}) else: state_params = asdict(preset_state) new_info = {k: v for k, v in state_params.items() if v is not None} - await self.call("edit_preset_rules", {"index": index, "state": new_info}) + return await self.call( + "edit_preset_rules", {"index": index, "state": new_info} + ) @property def has_save_preset(self) -> bool: @@ -158,7 +160,7 @@ class LightPreset(SmartModule, LightPresetInterface): return {self.QUERY_GETTER_NAME: {"start_index": 0}} - async def _check_supported(self): + async def _check_supported(self) -> bool: """Additional check to see if the module is supported by the device. Parent devices that report components of children such as ks240 will not have diff --git a/kasa/smart/modules/lightstripeffect.py b/kasa/smart/modules/lightstripeffect.py index 3b0ff7da..91d89188 100644 --- a/kasa/smart/modules/lightstripeffect.py +++ b/kasa/smart/modules/lightstripeffect.py @@ -16,7 +16,7 @@ class LightStripEffect(SmartModule, SmartLightEffect): REQUIRED_COMPONENT = "light_strip_lighting_effect" - def __init__(self, device: SmartDevice, module: str): + def __init__(self, device: SmartDevice, module: str) -> None: super().__init__(device, module) effect_list = [self.LIGHT_EFFECTS_OFF] effect_list.extend(EFFECT_NAMES) @@ -66,7 +66,9 @@ class LightStripEffect(SmartModule, SmartLightEffect): eff = self.data["lighting_effect"] return eff["brightness"] - async def set_brightness(self, brightness: int, *, transition: int | None = None): + async def set_brightness( + self, brightness: int, *, transition: int | None = None + ) -> dict: """Set effect brightness.""" if brightness <= 0: return await self.set_effect(self.LIGHT_EFFECTS_OFF) @@ -91,7 +93,7 @@ class LightStripEffect(SmartModule, SmartLightEffect): *, brightness: int | None = None, transition: int | None = None, - ) -> None: + ) -> dict: """Set an effect on the device. If brightness or transition is defined, @@ -115,8 +117,7 @@ class LightStripEffect(SmartModule, SmartLightEffect): effect_dict = self._effect_mapping["Aurora"] effect_dict = {**effect_dict} effect_dict["enable"] = 0 - await self.set_custom_effect(effect_dict) - return + return await self.set_custom_effect(effect_dict) if effect not in self._effect_mapping: raise ValueError(f"The effect {effect} is not a built in effect.") @@ -134,13 +135,13 @@ class LightStripEffect(SmartModule, SmartLightEffect): if transition is not None: effect_dict["transition"] = transition - await self.set_custom_effect(effect_dict) + return await self.set_custom_effect(effect_dict) @allow_update_after async def set_custom_effect( self, effect_dict: dict, - ) -> None: + ) -> dict: """Set a custom effect on the device. :param str effect_dict: The custom effect dict to set @@ -155,7 +156,7 @@ class LightStripEffect(SmartModule, SmartLightEffect): """Return True if the device supports setting custom effects.""" return True - def query(self): + def query(self) -> dict: """Return the base query.""" return {} diff --git a/kasa/smart/modules/lighttransition.py b/kasa/smart/modules/lighttransition.py index 947f8b0e..68c4af23 100644 --- a/kasa/smart/modules/lighttransition.py +++ b/kasa/smart/modules/lighttransition.py @@ -39,14 +39,14 @@ class LightTransition(SmartModule): _off_state: _State _enabled: bool - def __init__(self, device: SmartDevice, module: str): + def __init__(self, device: SmartDevice, module: str) -> None: super().__init__(device, module) self._state_in_sysinfo = all( key in device.sys_info for key in self.SYS_INFO_STATE_KEYS ) self._supports_on_and_off: bool = self.supported_version > 1 - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" icon = "mdi:transition" if not self._supports_on_and_off: @@ -138,7 +138,7 @@ class LightTransition(SmartModule): } @allow_update_after - async def set_enabled(self, enable: bool): + async def set_enabled(self, enable: bool) -> dict: """Enable gradual on/off.""" if not self._supports_on_and_off: return await self.call("set_on_off_gradually_info", {"enable": enable}) @@ -171,7 +171,7 @@ class LightTransition(SmartModule): return self._on_state["max_duration"] @allow_update_after - async def set_turn_on_transition(self, seconds: int): + async def set_turn_on_transition(self, seconds: int) -> dict: """Set turn on transition in seconds. Setting to 0 turns the feature off. @@ -207,7 +207,7 @@ class LightTransition(SmartModule): return self._off_state["max_duration"] @allow_update_after - async def set_turn_off_transition(self, seconds: int): + async def set_turn_off_transition(self, seconds: int) -> dict: """Set turn on transition in seconds. Setting to 0 turns the feature off. @@ -236,7 +236,7 @@ class LightTransition(SmartModule): else: return {self.QUERY_GETTER_NAME: None} - async def _check_supported(self): + async def _check_supported(self) -> bool: """Additional check to see if the module is supported by the device.""" # For devices that report child components on the parent that are not # actually supported by the parent. diff --git a/kasa/smart/modules/motionsensor.py b/kasa/smart/modules/motionsensor.py index 169b25b6..fe9ac5c0 100644 --- a/kasa/smart/modules/motionsensor.py +++ b/kasa/smart/modules/motionsensor.py @@ -11,7 +11,7 @@ class MotionSensor(SmartModule): REQUIRED_COMPONENT = "sensitivity" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features.""" self._add_feature( Feature( @@ -31,6 +31,6 @@ class MotionSensor(SmartModule): return {} @property - def motion_detected(self): + def motion_detected(self) -> bool: """Return True if the motion has been detected.""" return self._device.sys_info["detected"] diff --git a/kasa/smart/modules/reportmode.py b/kasa/smart/modules/reportmode.py index 34559cab..4765b4e1 100644 --- a/kasa/smart/modules/reportmode.py +++ b/kasa/smart/modules/reportmode.py @@ -12,7 +12,7 @@ class ReportMode(SmartModule): REQUIRED_COMPONENT = "report_mode" QUERY_GETTER_NAME = "get_report_mode" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -32,6 +32,6 @@ class ReportMode(SmartModule): return {} @property - def report_interval(self): + def report_interval(self) -> int: """Reporting interval of a sensor device.""" return self._device.sys_info["report_interval"] diff --git a/kasa/smart/modules/temperaturecontrol.py b/kasa/smart/modules/temperaturecontrol.py index 96630ce5..138c3d2e 100644 --- a/kasa/smart/modules/temperaturecontrol.py +++ b/kasa/smart/modules/temperaturecontrol.py @@ -26,7 +26,7 @@ class TemperatureControl(SmartModule): REQUIRED_COMPONENT = "temp_control" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -92,7 +92,7 @@ class TemperatureControl(SmartModule): """Return thermostat state.""" return self._device.sys_info["frost_protection_on"] is False - async def set_state(self, enabled: bool): + async def set_state(self, enabled: bool) -> dict: """Set thermostat state.""" return await self.call("set_device_info", {"frost_protection_on": not enabled}) @@ -147,7 +147,7 @@ class TemperatureControl(SmartModule): """Return thermostat states.""" return set(self._device.sys_info["trv_states"]) - async def set_target_temperature(self, target: float): + async def set_target_temperature(self, target: float) -> dict: """Set target temperature.""" if ( target < self.minimum_target_temperature @@ -170,7 +170,7 @@ class TemperatureControl(SmartModule): """Return temperature offset.""" return self._device.sys_info["temp_offset"] - async def set_temperature_offset(self, offset: int): + async def set_temperature_offset(self, offset: int) -> dict: """Set temperature offset.""" if offset < -10 or offset > 10: raise ValueError("Temperature offset must be [-10, 10]") diff --git a/kasa/smart/modules/temperaturesensor.py b/kasa/smart/modules/temperaturesensor.py index 8162ce60..0a591a3d 100644 --- a/kasa/smart/modules/temperaturesensor.py +++ b/kasa/smart/modules/temperaturesensor.py @@ -14,7 +14,7 @@ class TemperatureSensor(SmartModule): REQUIRED_COMPONENT = "temperature" QUERY_GETTER_NAME = "get_comfort_temp_config" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -60,7 +60,7 @@ class TemperatureSensor(SmartModule): return {} @property - def temperature(self): + def temperature(self) -> float: """Return current humidity in percentage.""" return self._device.sys_info["current_temp"] @@ -74,6 +74,8 @@ class TemperatureSensor(SmartModule): """Return current temperature unit.""" return self._device.sys_info["temp_unit"] - async def set_temperature_unit(self, unit: Literal["celsius", "fahrenheit"]): + async def set_temperature_unit( + self, unit: Literal["celsius", "fahrenheit"] + ) -> dict: """Set the device temperature unit.""" return await self.call("set_temperature_unit", {"temp_unit": unit}) diff --git a/kasa/smart/modules/time.py b/kasa/smart/modules/time.py index cac01d73..d82991c1 100644 --- a/kasa/smart/modules/time.py +++ b/kasa/smart/modules/time.py @@ -21,7 +21,7 @@ class Time(SmartModule, TimeInterface): _timezone: tzinfo = timezone.utc - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( @@ -35,7 +35,7 @@ class Time(SmartModule, TimeInterface): ) ) - async def _post_update_hook(self): + async def _post_update_hook(self) -> None: """Perform actions after a device update.""" td = timedelta(minutes=cast(float, self.data.get("time_diff"))) if region := self.data.get("region"): @@ -84,7 +84,7 @@ class Time(SmartModule, TimeInterface): params["region"] = region return await self.call("set_device_time", params) - async def _check_supported(self): + async def _check_supported(self) -> bool: """Additional check to see if the module is supported by the device. Hub attached sensors report the time module but do return device time. diff --git a/kasa/smart/modules/waterleaksensor.py b/kasa/smart/modules/waterleaksensor.py index 6b8a7ae7..b6f01017 100644 --- a/kasa/smart/modules/waterleaksensor.py +++ b/kasa/smart/modules/waterleaksensor.py @@ -22,7 +22,7 @@ class WaterleakSensor(SmartModule): REQUIRED_COMPONENT = "sensor_alarm" - def _initialize_features(self): + def _initialize_features(self) -> None: """Initialize features after the initial update.""" self._add_feature( Feature( diff --git a/kasa/smart/smartchilddevice.py b/kasa/smart/smartchilddevice.py index a5b24fd5..49c92229 100644 --- a/kasa/smart/smartchilddevice.py +++ b/kasa/smart/smartchilddevice.py @@ -49,7 +49,7 @@ class SmartChildDevice(SmartDevice): self._update_internal_state(info) self._components = component_info - async def update(self, update_children: bool = True): + async def update(self, update_children: bool = True) -> None: """Update child module info. The parent updates our internal info so just update modules with @@ -57,7 +57,7 @@ class SmartChildDevice(SmartDevice): """ await self._update(update_children) - async def _update(self, update_children: bool = True): + async def _update(self, update_children: bool = True) -> None: """Update child module info. Internal implementation to allow patching of public update in the cli @@ -118,5 +118,5 @@ class SmartChildDevice(SmartDevice): dev_type = DeviceType.Unknown return dev_type - def __repr__(self): + def __repr__(self) -> str: return f"<{self.device_type} {self.alias} ({self.model}) of {self._parent}>" diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 17386e07..35524ee8 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -69,7 +69,7 @@ class SmartDevice(Device): self._on_since: datetime | None = None self._info: dict[str, Any] = {} - async def _initialize_children(self): + async def _initialize_children(self) -> None: """Initialize children for power strips.""" child_info_query = { "get_child_device_component_list": None, @@ -108,7 +108,9 @@ class SmartDevice(Device): """Return the device modules.""" return cast(ModuleMapping[SmartModule], self._modules) - def _try_get_response(self, responses: dict, request: str, default=None) -> dict: + def _try_get_response( + self, responses: dict, request: str, default: Any | None = None + ) -> dict: response = responses.get(request) if isinstance(response, SmartErrorCode): _LOGGER.debug( @@ -126,7 +128,7 @@ class SmartDevice(Device): f"{request} not found in {responses} for device {self.host}" ) - async def _negotiate(self): + async def _negotiate(self) -> None: """Perform initialization. We fetch the device info and the available components as early as possible. @@ -146,7 +148,8 @@ class SmartDevice(Device): self._info = self._try_get_response(resp, "get_device_info") # Create our internal presentation of available components - self._components_raw = resp["component_nego"] + self._components_raw = cast(dict, resp["component_nego"]) + self._components = { comp["id"]: int(comp["ver_code"]) for comp in self._components_raw["component_list"] @@ -167,7 +170,7 @@ class SmartDevice(Device): """Update the internal device info.""" self._info = self._try_get_response(info_resp, "get_device_info") - async def update(self, update_children: bool = False): + async def update(self, update_children: bool = False) -> None: """Update the device.""" if self.credentials is None and self.credentials_hash is None: raise AuthenticationError("Tapo plug requires authentication.") @@ -206,7 +209,7 @@ class SmartDevice(Device): async def _handle_module_post_update( self, module: SmartModule, update_time: float, had_query: bool - ): + ) -> None: if module.disabled: return # pragma: no cover if had_query: @@ -312,7 +315,7 @@ class SmartDevice(Device): responses[meth] = SmartErrorCode.INTERNAL_QUERY_ERROR return responses - async def _initialize_modules(self): + async def _initialize_modules(self) -> None: """Initialize modules based on component negotiation response.""" from .smartmodule import SmartModule @@ -324,7 +327,7 @@ class SmartDevice(Device): # It also ensures that devices like power strips do not add modules such as # firmware to the child devices. skip_parent_only_modules = False - child_modules_to_skip = {} + child_modules_to_skip: dict = {} # TODO: this is never non-empty if self._parent and self._parent.device_type != DeviceType.Hub: skip_parent_only_modules = True @@ -333,17 +336,18 @@ class SmartDevice(Device): skip_parent_only_modules and mod in NON_HUB_PARENT_ONLY_MODULES ) or mod.__name__ in child_modules_to_skip: continue - if ( - mod.REQUIRED_COMPONENT in self._components - or self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None + required_component = cast(str, mod.REQUIRED_COMPONENT) + if required_component in self._components or ( + mod.REQUIRED_KEY_ON_PARENT + and self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None ): _LOGGER.debug( "Device %s, found required %s, adding %s to modules.", self.host, - mod.REQUIRED_COMPONENT, + required_component, mod.__name__, ) - module = mod(self, mod.REQUIRED_COMPONENT) + module = mod(self, required_component) if await module._check_supported(): self._modules[module.name] = module @@ -354,7 +358,7 @@ class SmartDevice(Device): ): self._modules[Light.__name__] = Light(self, "light") - async def _initialize_features(self): + async def _initialize_features(self) -> None: """Initialize device features.""" self._add_feature( Feature( @@ -575,11 +579,11 @@ class SmartDevice(Device): return str(self._info.get("device_id")) @property - def internal_state(self) -> Any: + def internal_state(self) -> dict: """Return all the internal state data.""" return self._last_update - def _update_internal_state(self, info: dict) -> None: + def _update_internal_state(self, info: dict[str, Any]) -> None: """Update the internal info state. This is used by the parent to push updates to its children. @@ -587,8 +591,8 @@ class SmartDevice(Device): self._info = info async def _query_helper( - self, method: str, params: dict | None = None, child_ids=None - ) -> Any: + self, method: str, params: dict | None = None, child_ids: None = None + ) -> dict: res = await self.protocol.query({method: params}) return res @@ -610,22 +614,25 @@ class SmartDevice(Device): """Return true if the device is on.""" return bool(self._info.get("device_on")) - async def set_state(self, on: bool): # TODO: better name wanted. + async def set_state(self, on: bool) -> dict: """Set the device state. See :meth:`is_on`. """ return await self.protocol.query({"set_device_info": {"device_on": on}}) - async def turn_on(self, **kwargs): + async def turn_on(self, **kwargs: Any) -> dict: """Turn on the device.""" - await self.set_state(True) + return await self.set_state(True) - async def turn_off(self, **kwargs): + async def turn_off(self, **kwargs: Any) -> dict: """Turn off the device.""" - await self.set_state(False) + return await self.set_state(False) - def update_from_discover_info(self, info): + def update_from_discover_info( + self, + info: dict, + ) -> None: """Update state from info from the discover call.""" self._discovery_info = info self._info = info @@ -633,7 +640,7 @@ class SmartDevice(Device): async def wifi_scan(self) -> list[WifiNetwork]: """Scan for available wifi networks.""" - def _net_for_scan_info(res): + def _net_for_scan_info(res: dict) -> WifiNetwork: return WifiNetwork( ssid=base64.b64decode(res["ssid"]).decode(), cipher_type=res["cipher_type"], @@ -651,7 +658,9 @@ class SmartDevice(Device): ] return networks - async def wifi_join(self, ssid: str, password: str, keytype: str = "wpa2_psk"): + async def wifi_join( + self, ssid: str, password: str, keytype: str = "wpa2_psk" + ) -> dict: """Join the given wifi network. This method returns nothing as the device tries to activate the new @@ -688,9 +697,12 @@ class SmartDevice(Device): except DeviceError: raise # Re-raise on device-reported errors except KasaException: - _LOGGER.debug("Received an expected for wifi join, but this is expected") + _LOGGER.debug( + "Received a kasa exception for wifi join, but this is expected" + ) + return {} - async def update_credentials(self, username: str, password: str): + async def update_credentials(self, username: str, password: str) -> dict: """Update device credentials. This will replace the existing authentication credentials on the device. @@ -705,7 +717,7 @@ class SmartDevice(Device): } return await self.protocol.query({"set_qs_info": payload}) - async def set_alias(self, alias: str): + async def set_alias(self, alias: str) -> dict: """Set the device name (alias).""" return await self.protocol.query( {"set_device_info": {"nickname": base64.b64encode(alias.encode()).decode()}} diff --git a/kasa/smart/smartmodule.py b/kasa/smart/smartmodule.py index f20186ec..f0b95ecb 100644 --- a/kasa/smart/smartmodule.py +++ b/kasa/smart/smartmodule.py @@ -22,17 +22,17 @@ _R = TypeVar("_R") def allow_update_after( - func: Callable[Concatenate[_T, _P], Awaitable[None]], -) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, None]]: + func: Callable[Concatenate[_T, _P], Awaitable[dict]], +) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, dict]]: """Define a wrapper to set _last_update_time to None. This will ensure that a module is updated in the next update cycle after a value has been changed. """ - async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> None: + async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> dict: try: - await func(self, *args, **kwargs) + return await func(self, *args, **kwargs) finally: self._last_update_time = None @@ -68,21 +68,21 @@ class SmartModule(Module): DISABLE_AFTER_ERROR_COUNT = 10 - def __init__(self, device: SmartDevice, module: str): + def __init__(self, device: SmartDevice, module: str) -> None: self._device: SmartDevice super().__init__(device, module) self._last_update_time: float | None = None self._last_update_error: KasaException | None = None self._error_count = 0 - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs) -> None: # We only want to register submodules in a modules package so that # other classes can inherit from smartmodule and not be registered if cls.__module__.split(".")[-2] == "modules": _LOGGER.debug("Registering %s", cls) cls.REGISTERED_MODULES[cls._module_name()] = cls - def _set_error(self, err: Exception | None): + def _set_error(self, err: Exception | None) -> None: if err is None: self._error_count = 0 self._last_update_error = None @@ -119,7 +119,7 @@ class SmartModule(Module): return self._error_count >= self.DISABLE_AFTER_ERROR_COUNT @classmethod - def _module_name(cls): + def _module_name(cls) -> str: return getattr(cls, "NAME", cls.__name__) @property @@ -127,7 +127,7 @@ class SmartModule(Module): """Name of the module.""" return self._module_name() - async def _post_update_hook(self): # noqa: B027 + async def _post_update_hook(self) -> None: # noqa: B027 """Perform actions after a device update. Any modules overriding this should ensure that self.data is @@ -142,7 +142,7 @@ class SmartModule(Module): """ return {self.QUERY_GETTER_NAME: None} - async def call(self, method, params=None): + async def call(self, method: str, params: dict | None = None) -> dict: """Call a method. Just a helper method. @@ -150,7 +150,7 @@ class SmartModule(Module): return await self._device._query_helper(method, params) @property - def data(self): + def data(self) -> dict[str, Any]: """Return response data for the module. If the module performs only a single query, the resulting response is unwrapped. diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 71be7dee..e2ff6af7 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -72,7 +72,7 @@ class SmartProtocol(BaseProtocol): ) self._redact_data = True - def get_smart_request(self, method, params=None) -> str: + def get_smart_request(self, method: str, params: dict | None = None) -> str: """Get a request message as a string.""" request = { "method": method, @@ -289,8 +289,8 @@ class SmartProtocol(BaseProtocol): return {smart_method: result} async def _handle_response_lists( - self, response_result: dict[str, Any], method, retry_count - ): + self, response_result: dict[str, Any], method: str, retry_count: int + ) -> None: if ( response_result is None or isinstance(response_result, SmartErrorCode) @@ -325,7 +325,9 @@ class SmartProtocol(BaseProtocol): break response_result[response_list_name].extend(next_batch[response_list_name]) - def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True): + def _handle_response_error_code( + self, resp_dict: dict, method: str, raise_on_error: bool = True + ) -> None: error_code_raw = resp_dict.get("error_code") try: error_code = SmartErrorCode.from_int(error_code_raw) @@ -369,12 +371,12 @@ class _ChildProtocolWrapper(SmartProtocol): device responses before returning to the caller. """ - def __init__(self, device_id: str, base_protocol: SmartProtocol): + def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None: self._device_id = device_id self._protocol = base_protocol self._transport = base_protocol._transport - def _get_method_and_params_for_request(self, request): + def _get_method_and_params_for_request(self, request: dict[str, Any] | str) -> Any: """Return payload for wrapping. TODO: this does not support batches and requires refactoring in the future. diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index 2deebf90..842147f3 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -310,9 +310,7 @@ class FakeSmartTransport(BaseTransport): } return retval - raise NotImplementedError( - "Method %s not implemented for children" % child_method - ) + raise NotImplementedError(f"Method {child_method} not implemented for children") def _get_on_off_gradually_info(self, info, params): if self.components["on_off_gradually"] == 1: diff --git a/kasa/tests/smart/modules/test_firmware.py b/kasa/tests/smart/modules/test_firmware.py index c10d9086..013533d0 100644 --- a/kasa/tests/smart/modules/test_firmware.py +++ b/kasa/tests/smart/modules/test_firmware.py @@ -41,7 +41,7 @@ async def test_firmware_features( await fw.check_latest_firmware() if fw.supported_version < required_version: - pytest.skip("Feature %s requires newer version" % feature) + pytest.skip(f"Feature {feature} requires newer version") prop = getattr(fw, prop_name) assert isinstance(prop, type) diff --git a/kasa/xortransport.py b/kasa/xortransport.py index e8d0303b..7abc2a3b 100644 --- a/kasa/xortransport.py +++ b/kasa/xortransport.py @@ -48,7 +48,7 @@ class XorTransport(BaseTransport): self.loop: asyncio.AbstractEventLoop | None = None @property - def default_port(self): + def default_port(self) -> int: """Default port for the transport.""" return self.DEFAULT_PORT diff --git a/pyproject.toml b/pyproject.toml index c2ad3a36..8374a711 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,10 +139,15 @@ select = [ "PT", # flake8-pytest-style "LOG", # flake8-logging "G", # flake8-logging-format + "ANN", # annotations ] ignore = [ "D105", # Missing docstring in magic method "D107", # Missing docstring in `__init__` + "ANN101", # Missing type annotation for `self` + "ANN102", # Missing type annotation for `cls` in classmethod + "ANN003", # Missing type annotation for `**kwargs` + "ANN401", # allow any ] [tool.ruff.lint.pydocstyle] @@ -157,11 +162,21 @@ convention = "pep257" "D104", "S101", # allow asserts "E501", # ignore line-too-longs + "ANN", # skip for now ] "docs/source/conf.py" = [ "D100", "D103", ] +# Temporary ANN disable +"kasa/cli/*.py" = [ + "ANN", +] +# Temporary ANN disable +"devtools/*.py" = [ + "ANN", +] + [tool.mypy] warn_unused_configs = true # warns if overrides sections unused/mis-spelled