Enable ruff check for ANN (#1139)

This commit is contained in:
Teemu R. 2024-11-10 19:55:13 +01:00 committed by GitHub
parent 6b44fe6242
commit 66eb17057e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
89 changed files with 596 additions and 452 deletions

View File

@ -66,6 +66,6 @@ todo_include_todos = True
myst_heading_anchors = 3
def setup(app):
def setup(app): # noqa: ANN201,ANN001
# add copybutton to hide the >>> prompts, see https://github.com/readthedocs/sphinx_rtd_theme/issues/167
app.add_js_file("copybutton.js")

View File

@ -13,7 +13,7 @@ to be handled by the user of the library.
"""
from importlib.metadata import version
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from warnings import warn
from kasa.credentials import Credentials
@ -101,7 +101,7 @@ deprecated_classes = {
}
def __getattr__(name):
def __getattr__(name: str) -> Any:
if name in deprecated_names:
warn(f"{name} is deprecated", DeprecationWarning, stacklevel=2)
return globals()[f"_deprecated_{name}"]
@ -117,7 +117,7 @@ def __getattr__(name):
)
return new_class
if name in deprecated_classes:
new_class = deprecated_classes[name]
new_class = deprecated_classes[name] # type: ignore[assignment]
msg = f"{name} is deprecated, use {new_class.__name__} instead"
warn(msg, DeprecationWarning, stacklevel=2)
return new_class

View File

@ -146,7 +146,7 @@ class AesTransport(BaseTransport):
pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
def _handle_response_error_code(self, resp_dict: dict, msg: str) -> None:
error_code_raw = resp_dict.get("error_code")
try:
error_code = SmartErrorCode.from_int(error_code_raw)
@ -191,14 +191,14 @@ class AesTransport(BaseTransport):
+ f"status code {status_code} to passthrough"
)
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
assert self._encryption_session is not None
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)
raw_response: str = resp_dict["result"]["response"]
try:
@ -219,7 +219,7 @@ class AesTransport(BaseTransport):
) from ex
return ret_val # type: ignore[return-value]
async def perform_login(self):
async def perform_login(self) -> None:
"""Login to the device."""
try:
await self.try_login(self._login_params)
@ -324,11 +324,11 @@ class AesTransport(BaseTransport):
+ f"status code {status_code} to handshake"
)
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
handshake_key = resp_dict["result"]["key"]
if (
@ -355,7 +355,7 @@ class AesTransport(BaseTransport):
_LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self):
def _handshake_session_expired(self) -> bool:
"""Return true if session has expired."""
return (
self._session_expire_at is None
@ -394,7 +394,9 @@ class AesEncyptionSession:
"""Class for an AES encryption session."""
@staticmethod
def create_from_keypair(handshake_key: str, keypair: KeyPair):
def create_from_keypair(
handshake_key: str, keypair: KeyPair
) -> AesEncyptionSession:
"""Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())
@ -404,11 +406,11 @@ class AesEncyptionSession:
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
def __init__(self, key, iv):
def __init__(self, key: bytes, iv: bytes) -> None:
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
def encrypt(self, data) -> bytes:
def encrypt(self, data: bytes) -> bytes:
"""Encrypt the message."""
encryptor = self.cipher.encryptor()
padder = self.padding_strategy.padder()
@ -416,7 +418,7 @@ class AesEncyptionSession:
encrypted = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(encrypted)
def decrypt(self, data) -> str:
def decrypt(self, data: str | bytes) -> str:
"""Decrypt the message."""
decryptor = self.cipher.decryptor()
unpadder = self.padding_strategy.unpadder()
@ -429,14 +431,16 @@ class KeyPair:
"""Class for generating key pairs."""
@staticmethod
def create_key_pair(key_size: int = 1024):
def create_key_pair(key_size: int = 1024) -> KeyPair:
"""Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key()
return KeyPair(private_key, public_key)
@staticmethod
def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str):
def create_from_der_keys(
private_key_der_b64: str, public_key_der_b64: str
) -> KeyPair:
"""Create a key pair."""
key_bytes = base64.b64decode(private_key_der_b64.encode())
private_key = cast(
@ -449,7 +453,9 @@ class KeyPair:
return KeyPair(private_key, public_key)
def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey):
def __init__(
self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey
) -> None:
self.private_key = private_key
self.public_key = public_key
self.private_key_der_bytes = self.private_key.private_bytes(

View File

@ -7,7 +7,7 @@ import re
import sys
from contextlib import contextmanager
from functools import singledispatch, update_wrapper, wraps
from typing import Final
from typing import TYPE_CHECKING, Any, Callable, Final
import asyncclick as click
@ -37,7 +37,7 @@ except ImportError:
"""Strip rich formatting from messages."""
@wraps(echo_func)
def wrapper(message=None, *args, **kwargs):
def wrapper(message=None, *args, **kwargs) -> None:
if message is not None:
message = rich_formatting.sub("", message)
echo_func(message, *args, **kwargs)
@ -47,20 +47,20 @@ except ImportError:
_echo = _strip_rich_formatting(click.echo)
def echo(*args, **kwargs):
def echo(*args, **kwargs) -> None:
"""Print a message."""
ctx = click.get_current_context().find_root()
if "json" not in ctx.params or ctx.params["json"] is False:
_echo(*args, **kwargs)
def error(msg: str):
def error(msg: str) -> None:
"""Print an error and exit."""
echo(f"[bold red]{msg}[/bold red]")
sys.exit(1)
def json_formatter_cb(result, **kwargs):
def json_formatter_cb(result: Any, **kwargs) -> None:
"""Format and output the result as JSON, if requested."""
if not kwargs.get("json"):
return
@ -82,7 +82,7 @@ def json_formatter_cb(result, **kwargs):
print(json_content)
def pass_dev_or_child(wrapped_function):
def pass_dev_or_child(wrapped_function: Callable) -> Callable:
"""Pass the device or child to the click command based on the child options."""
child_help = (
"Child ID or alias for controlling sub-devices. "
@ -133,7 +133,10 @@ def pass_dev_or_child(wrapped_function):
async def _get_child_device(
device: Device, child_option, child_index_option, info_command
device: Device,
child_option: str | None,
child_index_option: int | None,
info_command: str | None,
) -> Device | None:
def _list_children():
return "\n".join(
@ -178,11 +181,15 @@ async def _get_child_device(
f"{child_option} children are:\n{_list_children()}"
)
if TYPE_CHECKING:
assert isinstance(child_index_option, int)
if child_index_option + 1 > len(device.children) or child_index_option < 0:
error(
f"Invalid index {child_index_option}, "
f"device has {len(device.children)} children"
)
child_by_index = device.children[child_index_option]
echo(f"Targeting child device {child_by_index.alias}")
return child_by_index
@ -195,7 +202,7 @@ def CatchAllExceptions(cls):
https://stackoverflow.com/questions/52213375
"""
def _handle_exception(debug, exc):
def _handle_exception(debug, exc) -> None:
if isinstance(exc, click.ClickException):
raise
# Handle exit request from click.

View File

@ -22,7 +22,7 @@ from .common import (
@click.group()
@pass_dev_or_child
def device(dev):
def device(dev) -> None:
"""Commands to control basic device settings."""

View File

@ -36,7 +36,7 @@ async def detail(ctx):
auth_failed = []
sem = asyncio.Semaphore()
async def print_unsupported(unsupported_exception: UnsupportedDeviceError):
async def print_unsupported(unsupported_exception: UnsupportedDeviceError) -> None:
unsupported.append(unsupported_exception)
async with sem:
if unsupported_exception.discovery_result:
@ -50,7 +50,7 @@ async def detail(ctx):
from .device import state
async def print_discovered(dev: Device):
async def print_discovered(dev: Device) -> None:
async with sem:
try:
await dev.update()
@ -189,7 +189,7 @@ async def config(ctx):
error(f"Unable to connect to {host}")
def _echo_dictionary(discovery_info: dict):
def _echo_dictionary(discovery_info: dict) -> None:
echo("\t[bold]== Discovery information ==[/bold]")
for key, value in discovery_info.items():
key_name = " ".join(x.capitalize() or "_" for x in key.split("_"))
@ -197,7 +197,7 @@ def _echo_dictionary(discovery_info: dict):
echo(f"\t{key_name_and_spaces}{value}")
def _echo_discovery_info(discovery_info):
def _echo_discovery_info(discovery_info) -> None:
# We don't have discovery info when all connection params are passed manually
if discovery_info is None:
return

View File

@ -24,7 +24,7 @@ def _echo_features(
category: Feature.Category | None = None,
verbose: bool = False,
indent: str = "\t",
):
) -> None:
"""Print out a listing of features and their values."""
if category is not None:
features = {
@ -43,7 +43,9 @@ def _echo_features(
echo(f"{indent}{feat.name} ({feat.id}): [red]got exception ({ex})[/red]")
def _echo_all_features(features, *, verbose=False, title_prefix=None, indent=""):
def _echo_all_features(
features, *, verbose=False, title_prefix=None, indent=""
) -> None:
"""Print out all features by category."""
if title_prefix is not None:
echo(f"[bold]\n{indent}== {title_prefix} ==[/bold]")

View File

@ -3,6 +3,8 @@
Taken from the click help files.
"""
from __future__ import annotations
import importlib
import asyncclick as click
@ -11,7 +13,7 @@ import asyncclick as click
class LazyGroup(click.Group):
"""Lazy group class."""
def __init__(self, *args, lazy_subcommands=None, **kwargs):
def __init__(self, *args, lazy_subcommands=None, **kwargs) -> None:
super().__init__(*args, **kwargs)
# lazy_subcommands is a map of the form:
#
@ -31,9 +33,9 @@ class LazyGroup(click.Group):
return self._lazy_load(cmd_name)
return super().get_command(ctx, cmd_name)
def format_commands(self, ctx, formatter):
def format_commands(self, ctx, formatter) -> None:
"""Format the top level help output."""
sections = {}
sections: dict[str, list] = {}
for cmd, parent in self.lazy_subcommands.items():
sections.setdefault(parent, [])
cmd_obj = self.get_command(ctx, cmd)

View File

@ -15,7 +15,7 @@ from .common import echo, error, pass_dev_or_child
@click.group()
@pass_dev_or_child
def light(dev):
def light(dev) -> None:
"""Commands to control light settings."""

View File

@ -43,7 +43,7 @@ ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in DeviceEncryptionType]
DEFAULT_TARGET = "255.255.255.255"
def _legacy_type_to_class(_type):
def _legacy_type_to_class(_type: str) -> Any:
from kasa.iot import (
IotBulb,
IotDimmer,
@ -396,9 +396,9 @@ async def cli(
@cli.command()
@pass_dev_or_child
async def shell(dev: Device):
async def shell(dev: Device) -> None:
"""Open interactive shell."""
echo("Opening shell for %s" % dev)
echo(f"Opening shell for {dev}")
from ptpython.repl import embed
logging.getLogger("parso").setLevel(logging.WARNING) # prompt parsing

View File

@ -14,7 +14,7 @@ from .common import (
@click.group()
@pass_dev
async def schedule(dev):
async def schedule(dev) -> None:
"""Scheduling commands."""

View File

@ -23,7 +23,7 @@ from .common import (
@click.group(invoke_without_command=True)
@click.pass_context
async def time(ctx: click.Context):
async def time(ctx: click.Context) -> None:
"""Get and set time."""
if ctx.invoked_subcommand is None:
await ctx.invoke(time_get)

View File

@ -78,13 +78,13 @@ async def energy(dev: Device, year, month, erase):
else:
emeter_status = dev.emeter_realtime
echo("Current: %s A" % emeter_status["current"])
echo("Voltage: %s V" % emeter_status["voltage"])
echo("Power: %s W" % emeter_status["power"])
echo("Total consumption: %s kWh" % emeter_status["total"])
echo("Current: {} A".format(emeter_status["current"]))
echo("Voltage: {} V".format(emeter_status["voltage"]))
echo("Power: {} W".format(emeter_status["power"]))
echo("Total consumption: {} kWh".format(emeter_status["total"]))
echo("Today: %s kWh" % dev.emeter_today)
echo("This month: %s kWh" % dev.emeter_this_month)
echo(f"Today: {dev.emeter_today} kWh")
echo(f"This month: {dev.emeter_this_month} kWh")
return emeter_status
@ -122,8 +122,8 @@ async def usage(dev: Device, year, month, erase):
usage_data = await usage.get_daystat(year=month.year, month=month.month)
else:
# Call with no argument outputs summary data and returns
echo("Today: %s minutes" % usage.usage_today)
echo("This month: %s minutes" % usage.usage_this_month)
echo(f"Today: {usage.usage_today} minutes")
echo(f"This month: {usage.usage_this_month} minutes")
return usage

View File

@ -16,7 +16,7 @@ from .common import (
@click.group()
@pass_dev
def wifi(dev):
def wifi(dev) -> None:
"""Commands to control wifi settings."""

View File

@ -234,10 +234,10 @@ class Device(ABC):
return await connect(host=host, config=config) # type: ignore[arg-type]
@abstractmethod
async def update(self, update_children: bool = True):
async def update(self, update_children: bool = True) -> None:
"""Update the device."""
async def disconnect(self):
async def disconnect(self) -> None:
"""Disconnect and close any underlying connection resources."""
await self.protocol.close()
@ -257,15 +257,15 @@ class Device(ABC):
return not self.is_on
@abstractmethod
async def turn_on(self, **kwargs) -> dict | None:
async def turn_on(self, **kwargs) -> dict:
"""Turn on the device."""
@abstractmethod
async def turn_off(self, **kwargs) -> dict | None:
async def turn_off(self, **kwargs) -> dict:
"""Turn off the device."""
@abstractmethod
async def set_state(self, on: bool):
async def set_state(self, on: bool) -> dict:
"""Set the device state to *on*.
This allows turning the device on and off.
@ -278,7 +278,7 @@ class Device(ABC):
return self.protocol._transport._host
@host.setter
def host(self, value):
def host(self, value: str) -> None:
"""Set the device host.
Generally used by discovery to set the hostname after ip discovery.
@ -307,7 +307,7 @@ class Device(ABC):
return self._device_type
@abstractmethod
def update_from_discover_info(self, info):
def update_from_discover_info(self, info: dict) -> None:
"""Update state from info from the discover call."""
@property
@ -325,7 +325,7 @@ class Device(ABC):
def alias(self) -> str | None:
"""Returns the device alias or nickname."""
async def _raw_query(self, request: str | dict) -> Any:
async def _raw_query(self, request: str | dict) -> dict:
"""Send a raw query to the device."""
return await self.protocol.query(request=request)
@ -407,7 +407,7 @@ class Device(ABC):
@property
@abstractmethod
def internal_state(self) -> Any:
def internal_state(self) -> dict:
"""Return all the internal state data."""
@property
@ -420,10 +420,10 @@ class Device(ABC):
"""Return the list of supported features."""
return self._features
def _add_feature(self, feature: Feature):
def _add_feature(self, feature: Feature) -> None:
"""Add a new feature to the device."""
if feature.id in self._features:
raise KasaException("Duplicate feature id %s" % feature.id)
raise KasaException(f"Duplicate feature id {feature.id}")
assert feature.id is not None # TODO: hack for typing # noqa: S101
self._features[feature.id] = feature
@ -446,11 +446,13 @@ class Device(ABC):
"""Scan for available wifi networks."""
@abstractmethod
async def wifi_join(self, ssid: str, password: str, keytype: str = "wpa2_psk"):
async def wifi_join(
self, ssid: str, password: str, keytype: str = "wpa2_psk"
) -> dict:
"""Join the given wifi network."""
@abstractmethod
async def set_alias(self, alias: str):
async def set_alias(self, alias: str) -> dict:
"""Set the device name (alias)."""
@abstractmethod
@ -468,7 +470,7 @@ class Device(ABC):
Note, this does not downgrade the firmware.
"""
def __repr__(self):
def __repr__(self) -> str:
update_needed = " - update() needed" if not self._last_update else ""
return (
f"<{self.device_type} at {self.host} -"
@ -486,7 +488,9 @@ class Device(ABC):
"is_strip_socket": (None, DeviceType.StripSocket),
}
def _get_replacing_attr(self, module_name: ModuleName, *attrs):
def _get_replacing_attr(
self, module_name: ModuleName | None, *attrs: Any
) -> str | None:
# If module name is None check self
if not module_name:
check = self
@ -540,7 +544,7 @@ class Device(ABC):
"supported_modules": (None, ["modules"]),
}
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
# is_device_type
if dep_device_type_attr := self._deprecated_device_type_attributes.get(name):
module = dep_device_type_attr[0]

View File

@ -83,7 +83,7 @@ async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> Device:
if debug_enabled:
start_time = time.perf_counter()
def _perf_log(has_params, perf_type):
def _perf_log(has_params: bool, perf_type: str) -> None:
nonlocal start_time
if debug_enabled:
end_time = time.perf_counter()
@ -150,7 +150,7 @@ def _get_device_type_from_sys_info(info: dict[str, Any]) -> DeviceType:
return DeviceType.LightStrip
return DeviceType.Bulb
raise UnsupportedDeviceError("Unknown device type: %s" % type_)
raise UnsupportedDeviceError(f"Unknown device type: {type_}")
def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:

View File

@ -75,14 +75,14 @@ class DeviceFamily(Enum):
SmartIpCamera = "SMART.IPCAMERA"
def _dataclass_from_dict(klass, in_val):
def _dataclass_from_dict(klass: Any, in_val: dict) -> Any:
if is_dataclass(klass):
fieldtypes = {f.name: f.type for f in fields(klass)}
val = {}
for dict_key in in_val:
if dict_key in fieldtypes:
if hasattr(fieldtypes[dict_key], "from_dict"):
val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key])
val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) # type: ignore[union-attr]
else:
val[dict_key] = _dataclass_from_dict(
fieldtypes[dict_key], in_val[dict_key]
@ -91,12 +91,12 @@ def _dataclass_from_dict(klass, in_val):
raise KasaException(
f"Cannot create dataclass from dict, unknown key: {dict_key}"
)
return klass(**val)
return klass(**val) # type: ignore[operator]
else:
return in_val
def _dataclass_to_dict(in_val):
def _dataclass_to_dict(in_val: Any) -> dict:
fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare}
out_val = {}
for field_name in fieldtypes:
@ -210,7 +210,7 @@ class DeviceConfig:
aes_keys: Optional[KeyPairDict] = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.connection_type is None:
self.connection_type = DeviceConnectionParameters(
DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor

View File

@ -89,9 +89,19 @@ import logging
import secrets
import socket
import struct
from collections.abc import Awaitable
from asyncio.transports import DatagramTransport
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
NamedTuple,
Optional,
Type,
cast,
)
from aiohttp import ClientSession
@ -140,8 +150,8 @@ class ConnectAttempt(NamedTuple):
device: type
OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
OnDiscoveredCallable = Callable[[Device], Coroutine]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Coroutine]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = Dict[str, Device]
@ -156,7 +166,7 @@ class _AesDiscoveryQuery:
keypair: KeyPair | None = None
@classmethod
def generate_query(cls):
def generate_query(cls) -> bytearray:
if not cls.keypair:
cls.keypair = KeyPair.create_key_pair(key_size=2048)
secret = secrets.token_bytes(4)
@ -215,7 +225,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
credentials: Credentials | None = None,
timeout: int | None = None,
) -> None:
self.transport = None
self.transport: DatagramTransport | None = None
self.discovery_packets = discovery_packets
self.interface = interface
self.on_discovered = on_discovered
@ -239,16 +249,19 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.target_discovered: bool = False
self._started_event = asyncio.Event()
def _run_callback_task(self, coro):
task = asyncio.create_task(coro)
def _run_callback_task(self, coro: Coroutine) -> None:
task: asyncio.Task = asyncio.create_task(coro)
self.callback_tasks.append(task)
async def wait_for_discovery_to_complete(self):
async def wait_for_discovery_to_complete(self) -> None:
"""Wait for the discovery task to complete."""
# Give some time for connection_made event to be received
async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT):
await self._started_event.wait()
try:
if TYPE_CHECKING:
assert isinstance(self.discover_task, asyncio.Task)
await self.discover_task
except asyncio.CancelledError:
# if target_discovered then cancel was called internally
@ -257,11 +270,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
# Wait for any pending callbacks to complete
await asyncio.gather(*self.callback_tasks)
def connection_made(self, transport) -> None:
def connection_made(self, transport: DatagramTransport) -> None: # type: ignore[override]
"""Set socket options for broadcasting."""
self.transport = transport
self.transport = cast(DatagramTransport, transport)
sock = transport.get_extra_info("socket")
sock = self.transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -292,7 +305,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore
await asyncio.sleep(sleep_between_packets)
def datagram_received(self, data, addr) -> None:
def datagram_received(
self,
data: bytes,
addr: tuple[str, int],
) -> None:
"""Handle discovery responses."""
if TYPE_CHECKING:
assert _AesDiscoveryQuery.keypair
@ -338,18 +355,18 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self._handle_discovered_event()
def _handle_discovered_event(self):
def _handle_discovered_event(self) -> None:
"""If target is in seen_hosts cancel discover_task."""
if self.target in self.seen_hosts:
self.target_discovered = True
if self.discover_task:
self.discover_task.cancel()
def error_received(self, ex):
def error_received(self, ex: Exception) -> None:
"""Handle asyncio.Protocol errors."""
_LOGGER.error("Got error: %s", ex)
def connection_lost(self, ex): # pragma: no cover
def connection_lost(self, ex: Exception | None) -> None: # pragma: no cover
"""Cancel the discover task if running."""
if self.discover_task:
self.discover_task.cancel()
@ -372,17 +389,17 @@ class Discover:
@staticmethod
async def discover(
*,
target="255.255.255.255",
on_discovered=None,
discovery_timeout=5,
discovery_packets=3,
interface=None,
on_unsupported=None,
credentials=None,
target: str = "255.255.255.255",
on_discovered: OnDiscoveredCallable | None = None,
discovery_timeout: int = 5,
discovery_packets: int = 3,
interface: str | None = None,
on_unsupported: OnUnsupportedCallable | None = None,
credentials: Credentials | None = None,
username: str | None = None,
password: str | None = None,
port=None,
timeout=None,
port: int | None = None,
timeout: int | None = None,
) -> DeviceDict:
"""Discover supported devices.
@ -636,7 +653,7 @@ class Discover:
)
if not dev_class:
raise UnsupportedDeviceError(
"Unknown device type: %s" % discovery_result.device_type,
f"Unknown device type: {discovery_result.device_type}",
discovery_result=info,
)
return dev_class

View File

@ -49,13 +49,13 @@ class EmeterStatus(dict):
except ValueError:
return None
def __repr__(self):
def __repr__(self) -> str:
return (
f"<EmeterStatus power={self.power} voltage={self.voltage}"
f" current={self.current} total={self.total}>"
)
def __getitem__(self, item):
def __getitem__(self, item: str) -> float | None:
"""Return value in wanted units."""
valid_keys = [
"voltage_mv",

View File

@ -15,10 +15,10 @@ class KasaException(Exception):
class TimeoutError(KasaException, _asyncioTimeoutError):
"""Timeout exception for device errors."""
def __repr__(self):
def __repr__(self) -> str:
return KasaException.__repr__(self)
def __str__(self):
def __str__(self) -> str:
return KasaException.__str__(self)
@ -42,11 +42,11 @@ class DeviceError(KasaException):
self.error_code: SmartErrorCode | None = kwargs.get("error_code", None)
super().__init__(*args)
def __repr__(self):
def __repr__(self) -> str:
err_code = self.error_code.__repr__() if self.error_code else ""
return f"{self.__class__.__name__}({err_code})"
def __str__(self):
def __str__(self) -> str:
err_code = f" (error_code={self.error_code.name})" if self.error_code else ""
return super().__str__() + err_code
@ -62,7 +62,7 @@ class _RetryableError(DeviceError):
class SmartErrorCode(IntEnum):
"""Enum for SMART Error Codes."""
def __str__(self):
def __str__(self) -> str:
return f"{self.name}({self.value})"
@staticmethod

View File

@ -12,12 +12,12 @@ class Experimental:
ENV_VAR = "KASA_EXPERIMENTAL"
@classmethod
def set_enabled(cls, enabled):
def set_enabled(cls, enabled: bool) -> None:
"""Set the enabled value."""
cls._enabled = enabled
@classmethod
def enabled(cls):
def enabled(cls) -> bool:
"""Get the enabled value."""
if cls._enabled is not None:
return cls._enabled

View File

@ -50,11 +50,13 @@ class SmartCameraProtocol(SmartProtocol):
"""Class for SmartCamera Protocol."""
async def _handle_response_lists(
self, response_result: dict[str, Any], method, retry_count
):
self, response_result: dict[str, Any], method: str, retry_count: int
) -> None:
pass
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
def _handle_response_error_code(
self, resp_dict: dict, method: str, raise_on_error: bool = True
) -> None:
error_code_raw = resp_dict.get("error_code")
try:
error_code = SmartErrorCode.from_int(error_code_raw)
@ -203,7 +205,7 @@ class _ChildCameraProtocolWrapper(SmartProtocol):
device responses before returning to the caller.
"""
def __init__(self, device_id: str, base_protocol: SmartProtocol):
def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
self._device_id = device_id
self._protocol = base_protocol
self._transport = base_protocol._transport

View File

@ -256,7 +256,9 @@ class SslAesTransport(BaseTransport):
return ret_val # type: ignore[return-value]
@staticmethod
def generate_confirm_hash(local_nonce, server_nonce, pwd_hash):
def generate_confirm_hash(
local_nonce: str, server_nonce: str, pwd_hash: str
) -> str:
"""Generate an auth hash for the protocol on the supplied credentials."""
expected_confirm_bytes = _sha256_hash(
local_nonce.encode() + pwd_hash.encode() + server_nonce.encode()
@ -264,7 +266,9 @@ class SslAesTransport(BaseTransport):
return expected_confirm_bytes + server_nonce + local_nonce
@staticmethod
def generate_digest_password(local_nonce, server_nonce, pwd_hash):
def generate_digest_password(
local_nonce: str, server_nonce: str, pwd_hash: str
) -> str:
"""Generate an auth hash for the protocol on the supplied credentials."""
digest_password_hash = _sha256_hash(
pwd_hash.encode() + local_nonce.encode() + server_nonce.encode()
@ -275,7 +279,7 @@ class SslAesTransport(BaseTransport):
@staticmethod
def generate_encryption_token(
token_type, local_nonce, server_nonce, pwd_hash
token_type: str, local_nonce: str, server_nonce: str, pwd_hash: str
) -> bytes:
"""Generate encryption token."""
hashedKey = _sha256_hash(
@ -302,7 +306,9 @@ class SslAesTransport(BaseTransport):
local_nonce, server_nonce, pwd_hash = await self.perform_handshake1()
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
async def perform_handshake2(self, local_nonce, server_nonce, pwd_hash) -> None:
async def perform_handshake2(
self, local_nonce: str, server_nonce: str, pwd_hash: str
) -> None:
"""Perform the handshake."""
_LOGGER.debug("Performing handshake2 ...")
digest_password = self.generate_digest_password(

View File

@ -162,7 +162,7 @@ class Feature:
#: If set, this property will be used to get *choices*.
choices_getter: str | Callable[[], list[str]] | None = None
def __post_init__(self):
def __post_init__(self) -> None:
"""Handle late-binding of members."""
# Populate minimum & maximum values, if range_getter is given
self._container = self.container if self.container is not None else self.device
@ -188,7 +188,7 @@ class Feature:
f"Read-only feat defines attribute_setter: {self.name} ({self.id}):"
)
def _get_property_value(self, getter):
def _get_property_value(self, getter: str | Callable | None) -> Any:
if getter is None:
return None
if isinstance(getter, str):
@ -227,7 +227,7 @@ class Feature:
return 0
@property
def value(self):
def value(self) -> int | float | bool | str | Enum | None:
"""Return the current value."""
if self.type == Feature.Type.Action:
return "<Action>"
@ -264,7 +264,7 @@ class Feature:
return await getattr(container, self.attribute_setter)(value)
def __repr__(self):
def __repr__(self) -> str:
try:
value = self.value
choices = self.choices
@ -286,8 +286,8 @@ class Feature:
value = " ".join(
[f"*{choice}*" if choice == value else choice for choice in choices]
)
if self.precision_hint is not None and value is not None:
value = round(self.value, self.precision_hint)
if self.precision_hint is not None and isinstance(value, float):
value = round(value, self.precision_hint)
s = f"{self.name} ({self.id}): {value}"
if self.unit is not None:

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import logging
import ssl
import time
from typing import Any, Dict
@ -64,7 +65,7 @@ class HttpClient:
json: dict | Any | None = None,
headers: dict[str, str] | None = None,
cookies_dict: dict[str, str] | None = None,
ssl=False,
ssl: ssl.SSLContext | bool = False,
) -> tuple[int, dict | bytes | None]:
"""Send an http post request to the device.

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import IntFlag, auto
from typing import Any
from warnings import warn
from ..emeterstatus import EmeterStatus
@ -31,7 +32,7 @@ class Energy(Module, ABC):
"""Return True if module supports the feature."""
return module_feature in self._supported
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
device = self._device
self._add_feature(
@ -151,22 +152,26 @@ class Energy(Module, ABC):
"""Get the current voltage in V."""
@abstractmethod
async def get_status(self):
async def get_status(self) -> EmeterStatus:
"""Return real-time statistics."""
@abstractmethod
async def erase_stats(self):
async def erase_stats(self) -> dict:
"""Erase all stats."""
@abstractmethod
async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict:
async def get_daily_stats(
self, *, year: int | None = None, month: int | None = None, kwh: bool = True
) -> dict:
"""Return daily stats for the given year & month.
The return value is a dictionary of {day: energy, ...}.
"""
@abstractmethod
async def get_monthly_stats(self, *, year=None, kwh=True) -> dict:
async def get_monthly_stats(
self, *, year: int | None = None, kwh: bool = True
) -> dict:
"""Return monthly stats for the given year."""
_deprecated_attributes = {
@ -179,7 +184,7 @@ class Energy(Module, ABC):
"get_monthstat": "get_monthly_stats",
}
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
if attr := self._deprecated_attributes.get(name):
msg = f"{name} is deprecated, use {attr} instead"
warn(msg, DeprecationWarning, stacklevel=2)

View File

@ -16,5 +16,5 @@ class Fan(Module, ABC):
"""Return fan speed level."""
@abstractmethod
async def set_fan_speed_level(self, level: int):
async def set_fan_speed_level(self, level: int) -> dict:
"""Set fan speed level."""

View File

@ -11,7 +11,7 @@ from ..module import Module
class Led(Module, ABC):
"""Base interface to represent a LED module."""
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
device = self._device
self._add_feature(
@ -34,5 +34,5 @@ class Led(Module, ABC):
"""Return current led status."""
@abstractmethod
async def set_led(self, enable: bool) -> None:
async def set_led(self, enable: bool) -> dict:
"""Set led."""

View File

@ -166,7 +166,7 @@ class Light(Module, ABC):
@abstractmethod
async def set_color_temp(
self, temp: int, *, brightness=None, transition: int | None = None
self, temp: int, *, brightness: int | None = None, transition: int | None = None
) -> dict:
"""Set the color temperature of the device in kelvin.

View File

@ -53,7 +53,7 @@ class LightEffect(Module, ABC):
LIGHT_EFFECTS_OFF = "Off"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
device = self._device
self._add_feature(
@ -96,7 +96,7 @@ class LightEffect(Module, ABC):
*,
brightness: int | None = None,
transition: int | None = None,
) -> None:
) -> dict:
"""Set an effect on the device.
If brightness or transition is defined,
@ -110,10 +110,11 @@ class LightEffect(Module, ABC):
:param int transition: The wanted transition time
"""
@abstractmethod
async def set_custom_effect(
self,
effect_dict: dict,
) -> None:
) -> dict:
"""Set a custom effect on the device.
:param str effect_dict: The custom effect dict to set

View File

@ -83,7 +83,7 @@ class LightPreset(Module):
PRESET_NOT_SET = "Not set"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
device = self._device
self._add_feature(
@ -127,7 +127,7 @@ class LightPreset(Module):
async def set_preset(
self,
preset_name: str,
) -> None:
) -> dict:
"""Set a light preset for the device."""
@abstractmethod
@ -135,7 +135,7 @@ class LightPreset(Module):
self,
preset_name: str,
preset_info: LightState,
) -> None:
) -> dict:
"""Update the preset with *preset_name* with the new *preset_info*."""
@property

View File

@ -54,7 +54,7 @@ class TurnOnBehavior(BaseModel):
mode: BehaviorMode
@root_validator
def _mode_based_on_preset(cls, values):
def _mode_based_on_preset(cls, values: dict) -> dict:
"""Set the mode based on the preset value."""
if values["preset"] is not None:
values["mode"] = BehaviorMode.Preset
@ -209,7 +209,7 @@ class IotBulb(IotDevice):
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Bulb
async def _initialize_modules(self):
async def _initialize_modules(self) -> None:
"""Initialize modules not added in init."""
await super()._initialize_modules()
self.add_module(
@ -307,7 +307,7 @@ class IotBulb(IotDevice):
await self._query_helper(self.LIGHT_SERVICE, "get_default_behavior")
)
async def set_turn_on_behavior(self, behavior: TurnOnBehaviors):
async def set_turn_on_behavior(self, behavior: TurnOnBehaviors) -> dict:
"""Set the behavior for turning the bulb on.
If you do not want to manually construct the behavior object,
@ -426,7 +426,7 @@ class IotBulb(IotDevice):
@requires_update
async def _set_color_temp(
self, temp: int, *, brightness=None, transition: int | None = None
self, temp: int, *, brightness: int | None = None, transition: int | None = None
) -> dict:
"""Set the color temperature of the device in kelvin.
@ -450,7 +450,7 @@ class IotBulb(IotDevice):
return await self._set_light_state(light_state, transition=transition)
def _raise_for_invalid_brightness(self, value):
def _raise_for_invalid_brightness(self, value: int) -> None:
if not isinstance(value, int):
raise TypeError("Brightness must be an integer")
if not (0 <= value <= 100):
@ -517,7 +517,7 @@ class IotBulb(IotDevice):
"""Return that the bulb has an emeter."""
return True
async def set_alias(self, alias: str) -> None:
async def set_alias(self, alias: str) -> dict:
"""Set the device name (alias).
Overridden to use a different module name.

View File

@ -19,7 +19,7 @@ import inspect
import logging
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta, tzinfo
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Callable, cast
from warnings import warn
from ..device import Device, WifiNetwork
@ -35,12 +35,12 @@ from .modules import Emeter
_LOGGER = logging.getLogger(__name__)
def requires_update(f):
def requires_update(f: Callable) -> Any:
"""Indicate that `update` should be called before accessing this method.""" # noqa: D202
if inspect.iscoroutinefunction(f):
@functools.wraps(f)
async def wrapped(*args, **kwargs):
async def wrapped(*args: Any, **kwargs: Any) -> Any:
self = args[0]
if self._last_update is None and f.__name__ not in self._sys_info:
raise KasaException("You need to await update() to access the data")
@ -49,13 +49,13 @@ def requires_update(f):
else:
@functools.wraps(f)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> Any:
self = args[0]
if self._last_update is None and f.__name__ not in self._sys_info:
raise KasaException("You need to await update() to access the data")
return f(*args, **kwargs)
f.requires_update = True
f.requires_update = True # type: ignore[attr-defined]
return wrapped
@ -197,7 +197,7 @@ class IotDevice(Device):
return cast(ModuleMapping[IotModule], self._supported_modules)
return self._supported_modules
def add_module(self, name: str | ModuleName[Module], module: IotModule):
def add_module(self, name: str | ModuleName[Module], module: IotModule) -> None:
"""Register a module."""
if name in self._modules:
_LOGGER.debug("Module %s already registered, ignoring...", name)
@ -207,8 +207,12 @@ class IotDevice(Device):
self._modules[name] = module
def _create_request(
self, target: str, cmd: str, arg: dict | None = None, child_ids=None
):
self,
target: str,
cmd: str,
arg: dict | None = None,
child_ids: list | None = None,
) -> dict:
if arg is None:
arg = {}
request: dict[str, Any] = {target: {cmd: arg}}
@ -225,8 +229,12 @@ class IotDevice(Device):
raise KasaException("update() required prior accessing emeter")
async def _query_helper(
self, target: str, cmd: str, arg: dict | None = None, child_ids=None
) -> Any:
self,
target: str,
cmd: str,
arg: dict | None = None,
child_ids: list | None = None,
) -> dict:
"""Query device, return results or raise an exception.
:param target: Target system {system, time, emeter, ..}
@ -276,7 +284,7 @@ class IotDevice(Device):
"""Retrieve system information."""
return await self._query_helper("system", "get_sysinfo")
async def update(self, update_children: bool = True):
async def update(self, update_children: bool = True) -> None:
"""Query the device to update the data.
Needed for properties that are decorated with `requires_update`.
@ -305,7 +313,7 @@ class IotDevice(Device):
if not self._features:
await self._initialize_features()
async def _initialize_modules(self):
async def _initialize_modules(self) -> None:
"""Initialize modules not added in init."""
if self.has_emeter:
_LOGGER.debug(
@ -313,7 +321,7 @@ class IotDevice(Device):
)
self.add_module(Module.Energy, Emeter(self, self.emeter_type))
async def _initialize_features(self):
async def _initialize_features(self) -> None:
"""Initialize common features."""
self._add_feature(
Feature(
@ -364,7 +372,7 @@ class IotDevice(Device):
)
)
for module in self._supported_modules.values():
for module in self.modules.values():
module._initialize_features()
for module_feat in module._module_features.values():
self._add_feature(module_feat)
@ -453,7 +461,7 @@ class IotDevice(Device):
sys_info = self._sys_info
return sys_info.get("alias") if sys_info else None
async def set_alias(self, alias: str) -> None:
async def set_alias(self, alias: str) -> dict:
"""Set the device name (alias)."""
return await self._query_helper("system", "set_dev_alias", {"alias": alias})
@ -550,7 +558,7 @@ class IotDevice(Device):
return mac
async def set_mac(self, mac):
async def set_mac(self, mac: str) -> dict:
"""Set the mac address.
:param str mac: mac in hexadecimal with colons, e.g. 01:23:45:67:89:ab
@ -576,7 +584,7 @@ class IotDevice(Device):
"""Turn off the device."""
raise NotImplementedError("Device subclass needs to implement this.")
async def turn_on(self, **kwargs) -> dict | None:
async def turn_on(self, **kwargs) -> dict:
"""Turn device on."""
raise NotImplementedError("Device subclass needs to implement this.")
@ -586,7 +594,7 @@ class IotDevice(Device):
"""Return True if the device is on."""
raise NotImplementedError("Device subclass needs to implement this.")
async def set_state(self, on: bool):
async def set_state(self, on: bool) -> dict:
"""Set the device state."""
if on:
return await self.turn_on()
@ -627,7 +635,7 @@ class IotDevice(Device):
async def wifi_scan(self) -> list[WifiNetwork]: # noqa: D202
"""Scan for available wifi networks."""
async def _scan(target):
async def _scan(target: str) -> dict:
return await self._query_helper(target, "get_scaninfo", {"refresh": 1})
try:
@ -639,17 +647,17 @@ class IotDevice(Device):
info = await _scan("smartlife.iot.common.softaponboarding")
if "ap_list" not in info:
raise KasaException("Invalid response for wifi scan: %s" % info)
raise KasaException(f"Invalid response for wifi scan: {info}")
return [WifiNetwork(**x) for x in info["ap_list"]]
async def wifi_join(self, ssid: str, password: str, keytype: str = "3"): # noqa: D202
async def wifi_join(self, ssid: str, password: str, keytype: str = "3") -> dict: # noqa: D202
"""Join the given wifi network.
If joining the network fails, the device will return to AP mode after a while.
"""
async def _join(target, payload):
async def _join(target: str, payload: dict) -> dict:
return await self._query_helper(target, "set_stainfo", payload)
payload = {"ssid": ssid, "password": password, "key_type": int(keytype)}

View File

@ -80,7 +80,7 @@ class IotDimmer(IotPlug):
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Dimmer
async def _initialize_modules(self):
async def _initialize_modules(self) -> None:
"""Initialize modules."""
await super()._initialize_modules()
# TODO: need to be verified if it's okay to call these on HS220 w/o these
@ -103,7 +103,9 @@ class IotDimmer(IotPlug):
return int(sys_info["brightness"])
@requires_update
async def _set_brightness(self, brightness: int, *, transition: int | None = None):
async def _set_brightness(
self, brightness: int, *, transition: int | None = None
) -> dict:
"""Set the new dimmer brightness level in percentage.
:param int transition: transition duration in milliseconds.
@ -134,7 +136,7 @@ class IotDimmer(IotPlug):
self.DIMMER_SERVICE, "set_brightness", {"brightness": brightness}
)
async def turn_off(self, *, transition: int | None = None, **kwargs):
async def turn_off(self, *, transition: int | None = None, **kwargs) -> dict:
"""Turn the bulb off.
:param int transition: transition duration in milliseconds.
@ -145,7 +147,7 @@ class IotDimmer(IotPlug):
return await super().turn_off()
@requires_update
async def turn_on(self, *, transition: int | None = None, **kwargs):
async def turn_on(self, *, transition: int | None = None, **kwargs) -> dict:
"""Turn the bulb on.
:param int transition: transition duration in milliseconds.
@ -157,7 +159,7 @@ class IotDimmer(IotPlug):
return await super().turn_on()
async def set_dimmer_transition(self, brightness: int, transition: int):
async def set_dimmer_transition(self, brightness: int, transition: int) -> dict:
"""Turn the bulb on to brightness percentage over transition milliseconds.
A brightness value of 0 will turn off the dimmer.
@ -176,7 +178,7 @@ class IotDimmer(IotPlug):
if not isinstance(transition, int):
raise TypeError(f"Transition must be integer, not of {type(transition)}.")
if transition <= 0:
raise ValueError("Transition value %s is not valid." % transition)
raise ValueError(f"Transition value {transition} is not valid.")
return await self._query_helper(
self.DIMMER_SERVICE,
@ -185,7 +187,7 @@ class IotDimmer(IotPlug):
)
@requires_update
async def get_behaviors(self):
async def get_behaviors(self) -> dict:
"""Return button behavior settings."""
behaviors = await self._query_helper(
self.DIMMER_SERVICE, "get_default_behavior", {}
@ -195,7 +197,7 @@ class IotDimmer(IotPlug):
@requires_update
async def set_button_action(
self, action_type: ActionType, action: ButtonAction, index: int | None = None
):
) -> dict:
"""Set action to perform on button click/hold.
:param action_type ActionType: whether to control double click or hold action.
@ -209,15 +211,17 @@ class IotDimmer(IotPlug):
if index is not None:
payload["index"] = index
await self._query_helper(self.DIMMER_SERVICE, action_type_setter, payload)
return await self._query_helper(
self.DIMMER_SERVICE, action_type_setter, payload
)
@requires_update
async def set_fade_time(self, fade_type: FadeType, time: int):
async def set_fade_time(self, fade_type: FadeType, time: int) -> dict:
"""Set time for fade in / fade out."""
fade_type_setter = f"set_{fade_type}_time"
payload = {"fadeTime": time}
await self._query_helper(self.DIMMER_SERVICE, fade_type_setter, payload)
return await self._query_helper(self.DIMMER_SERVICE, fade_type_setter, payload)
@property # type: ignore
@requires_update

View File

@ -57,7 +57,7 @@ class IotLightStrip(IotBulb):
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.LightStrip
async def _initialize_modules(self):
async def _initialize_modules(self) -> None:
"""Initialize modules not added in init."""
await super()._initialize_modules()
self.add_module(

View File

@ -1,6 +1,9 @@
"""Base class for IOT module implementations."""
from __future__ import annotations
import logging
from typing import Any
from ..exceptions import KasaException
from ..module import Module
@ -24,16 +27,16 @@ merge = _merge_dict
class IotModule(Module):
"""Base class implemention for all IOT modules."""
def call(self, method, params=None):
async def call(self, method: str, params: dict | None = None) -> dict:
"""Call the given method with the given parameters."""
return self._device._query_helper(self._module, method, params)
return await self._device._query_helper(self._module, method, params)
def query_for_command(self, query, params=None):
def query_for_command(self, query: str, params: dict | None = None) -> dict:
"""Create a request object for the given parameters."""
return self._device._create_request(self._module, query, params)
@property
def estimated_query_response_size(self):
def estimated_query_response_size(self) -> int:
"""Estimated maximum size of query response.
The inheriting modules implement this to estimate how large a query response
@ -42,7 +45,7 @@ class IotModule(Module):
return 256 # Estimate for modules that don't specify
@property
def data(self):
def data(self) -> dict[str, Any]:
"""Return the module specific raw data from the last update."""
dev = self._device
q = self.query()

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import logging
from typing import Any
from ..device_type import DeviceType
from ..deviceconfig import DeviceConfig
@ -54,7 +55,7 @@ class IotPlug(IotDevice):
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Plug
async def _initialize_modules(self):
async def _initialize_modules(self) -> None:
"""Initialize modules."""
await super()._initialize_modules()
self.add_module(Module.IotSchedule, Schedule(self, "schedule"))
@ -71,11 +72,11 @@ class IotPlug(IotDevice):
sys_info = self.sys_info
return bool(sys_info["relay_state"])
async def turn_on(self, **kwargs):
async def turn_on(self, **kwargs: Any) -> dict:
"""Turn the switch on."""
return await self._query_helper("system", "set_relay_state", {"state": 1})
async def turn_off(self, **kwargs):
async def turn_off(self, **kwargs: Any) -> dict:
"""Turn the switch off."""
return await self._query_helper("system", "set_relay_state", {"state": 0})

View File

@ -26,7 +26,7 @@ from .modules import Antitheft, Cloud, Countdown, Emeter, Led, Schedule, Time, U
_LOGGER = logging.getLogger(__name__)
def merge_sums(dicts):
def merge_sums(dicts: list[dict]) -> dict:
"""Merge the sum of dicts."""
total_dict: defaultdict[int, float] = defaultdict(lambda: 0.0)
for sum_dict in dicts:
@ -99,7 +99,7 @@ class IotStrip(IotDevice):
self.emeter_type = "emeter"
self._device_type = DeviceType.Strip
async def _initialize_modules(self):
async def _initialize_modules(self) -> None:
"""Initialize modules."""
# Strip has different modules to plug so do not call super
self.add_module(Module.IotAntitheft, Antitheft(self, "anti_theft"))
@ -121,7 +121,7 @@ class IotStrip(IotDevice):
"""Return if any of the outlets are on."""
return any(plug.is_on for plug in self.children)
async def update(self, update_children: bool = True):
async def update(self, update_children: bool = True) -> None:
"""Update some of the attributes.
Needed for methods that are decorated with `requires_update`.
@ -150,20 +150,20 @@ class IotStrip(IotDevice):
if not self.features:
await self._initialize_features()
async def _initialize_features(self):
async def _initialize_features(self) -> None:
"""Initialize common features."""
# Do not initialize features until children are created
if not self.children:
return
await super()._initialize_features()
async def turn_on(self, **kwargs):
async def turn_on(self, **kwargs) -> dict:
"""Turn the strip on."""
await self._query_helper("system", "set_relay_state", {"state": 1})
return await self._query_helper("system", "set_relay_state", {"state": 1})
async def turn_off(self, **kwargs):
async def turn_off(self, **kwargs) -> dict:
"""Turn the strip off."""
await self._query_helper("system", "set_relay_state", {"state": 0})
return await self._query_helper("system", "set_relay_state", {"state": 0})
@property # type: ignore
@requires_update
@ -188,7 +188,7 @@ class StripEmeter(IotModule, Energy):
"""Return True if module supports the feature."""
return module_feature in self._supported
def query(self):
def query(self) -> dict:
"""Return the base query."""
return {}
@ -246,11 +246,13 @@ class StripEmeter(IotModule, Energy):
]
)
async def erase_stats(self):
async def erase_stats(self) -> dict:
"""Erase energy meter statistics for all plugs."""
for plug in self._device.children:
await plug.modules[Module.Energy].erase_stats()
return {}
@property # type: ignore
def consumption_this_month(self) -> float | None:
"""Return this month's energy consumption in kWh."""
@ -320,7 +322,7 @@ class IotStripPlug(IotPlug):
self.protocol = parent.protocol # Must use the same connection as the parent
self._on_since: datetime | None = None
async def _initialize_modules(self):
async def _initialize_modules(self) -> None:
"""Initialize modules not added in init."""
if self.has_emeter:
self.add_module(Module.Energy, Emeter(self, self.emeter_type))
@ -329,7 +331,7 @@ class IotStripPlug(IotPlug):
self.add_module(Module.IotSchedule, Schedule(self, "schedule"))
self.add_module(Module.IotCountdown, Countdown(self, "countdown"))
async def _initialize_features(self):
async def _initialize_features(self) -> None:
"""Initialize common features."""
self._add_feature(
Feature(
@ -353,19 +355,20 @@ class IotStripPlug(IotPlug):
type=Feature.Type.Sensor,
)
)
for module in self._supported_modules.values():
for module in self.modules.values():
module._initialize_features()
for module_feat in module._module_features.values():
self._add_feature(module_feat)
async def update(self, update_children: bool = True):
async def update(self, update_children: bool = True) -> None:
"""Query the device to update the data.
Needed for properties that are decorated with `requires_update`.
"""
await self._update(update_children)
async def _update(self, update_children: bool = True):
async def _update(self, update_children: bool = True) -> None:
"""Query the device to update the data.
Internal implementation to allow patching of public update in the cli
@ -379,8 +382,12 @@ class IotStripPlug(IotPlug):
await self._initialize_features()
def _create_request(
self, target: str, cmd: str, arg: dict | None = None, child_ids=None
):
self,
target: str,
cmd: str,
arg: dict | None = None,
child_ids: list | None = None,
) -> dict:
request: dict[str, Any] = {
"context": {"child_ids": [self.child_id]},
target: {cmd: arg},
@ -388,8 +395,12 @@ class IotStripPlug(IotPlug):
return request
async def _query_helper(
self, target: str, cmd: str, arg: dict | None = None, child_ids=None
) -> Any:
self,
target: str,
cmd: str,
arg: dict | None = None,
child_ids: list | None = None,
) -> dict:
"""Override query helper to include the child_ids."""
return await self._parent._query_helper(
target, cmd, arg, child_ids=[self.child_id]

View File

@ -11,7 +11,7 @@ _LOGGER = logging.getLogger(__name__)
class AmbientLight(IotModule):
"""Implements ambient light controls for the motion sensor."""
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -40,7 +40,7 @@ class AmbientLight(IotModule):
)
)
def query(self):
def query(self) -> dict:
"""Request configuration."""
req = merge(
self.query_for_command("get_config"),
@ -74,18 +74,18 @@ class AmbientLight(IotModule):
"""Return True if the module is enabled."""
return int(self.data["get_current_brt"]["value"])
async def set_enabled(self, state: bool):
async def set_enabled(self, state: bool) -> dict:
"""Enable/disable LAS."""
return await self.call("set_enable", {"enable": int(state)})
async def current_brightness(self) -> int:
async def current_brightness(self) -> dict:
"""Return current brightness.
Return value units.
"""
return await self.call("get_current_brt")
async def set_brightness_limit(self, value: int):
async def set_brightness_limit(self, value: int) -> dict:
"""Set the limit when the motion sensor is inactive.
See `presets` for preset values. Custom values are also likely allowed.

View File

@ -24,7 +24,7 @@ class CloudInfo(BaseModel):
class Cloud(IotModule):
"""Module implementing support for cloud services."""
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -44,7 +44,7 @@ class Cloud(IotModule):
"""Return true if device is connected to the cloud."""
return self.info.binded
def query(self):
def query(self) -> dict:
"""Request cloud connectivity info."""
return self.query_for_command("get_info")
@ -53,20 +53,20 @@ class Cloud(IotModule):
"""Return information about the cloud connectivity."""
return CloudInfo.parse_obj(self.data["get_info"])
def get_available_firmwares(self):
def get_available_firmwares(self) -> dict:
"""Return list of available firmwares."""
return self.query_for_command("get_intl_fw_list")
def set_server(self, url: str):
def set_server(self, url: str) -> dict:
"""Set the update server URL."""
return self.query_for_command("set_server_url", {"server": url})
def connect(self, username: str, password: str):
def connect(self, username: str, password: str) -> dict:
"""Login to the cloud using given information."""
return self.query_for_command(
"bind", {"username": username, "password": password}
)
def disconnect(self):
def disconnect(self) -> dict:
"""Disconnect from the cloud."""
return self.query_for_command("unbind")

View File

@ -70,7 +70,7 @@ class Emeter(Usage, EnergyInterface):
"""Get the current voltage in V."""
return self.status.voltage
async def erase_stats(self):
async def erase_stats(self) -> dict:
"""Erase all stats.
Uses different query than usage meter.
@ -81,7 +81,9 @@ class Emeter(Usage, EnergyInterface):
"""Return real-time statistics."""
return EmeterStatus(await self.call("get_realtime"))
async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict:
async def get_daily_stats(
self, *, year: int | None = None, month: int | None = None, kwh: bool = True
) -> dict:
"""Return daily stats for the given year & month.
The return value is a dictionary of {day: energy, ...}.
@ -90,7 +92,9 @@ class Emeter(Usage, EnergyInterface):
data = self._convert_stat_data(data["day_list"], entry_key="day", kwh=kwh)
return data
async def get_monthly_stats(self, *, year=None, kwh=True) -> dict:
async def get_monthly_stats(
self, *, year: int | None = None, kwh: bool = True
) -> dict:
"""Return monthly stats for the given year.
The return value is a dictionary of {month: energy, ...}.

View File

@ -14,7 +14,7 @@ class Led(IotModule, LedInterface):
return {}
@property
def mode(self):
def mode(self) -> str:
"""LED mode setting.
"always", "never"
@ -27,7 +27,7 @@ class Led(IotModule, LedInterface):
sys_info = self.data
return bool(1 - sys_info["led_off"])
async def set_led(self, state: bool):
async def set_led(self, state: bool) -> dict:
"""Set the state of the led (night mode)."""
return await self.call("set_led_off", {"off": int(not state)})

View File

@ -27,7 +27,7 @@ class Light(IotModule, LightInterface):
_device: IotBulb | IotDimmer
_light_state: LightState
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
super()._initialize_features()
device = self._device
@ -185,7 +185,7 @@ class Light(IotModule, LightInterface):
return bulb._color_temp
async def set_color_temp(
self, temp: int, *, brightness=None, transition: int | None = None
self, temp: int, *, brightness: int | None = None, transition: int | None = None
) -> dict:
"""Set the color temperature of the device in kelvin.

View File

@ -50,7 +50,7 @@ class LightEffect(IotModule, LightEffectInterface):
*,
brightness: int | None = None,
transition: int | None = None,
) -> None:
) -> dict:
"""Set an effect on the device.
If brightness or transition is defined,
@ -73,7 +73,7 @@ class LightEffect(IotModule, LightEffectInterface):
effect_dict = EFFECT_MAPPING_V1["Aurora"]
effect_dict = {**effect_dict}
effect_dict["enable"] = 0
await self.set_custom_effect(effect_dict)
return await self.set_custom_effect(effect_dict)
elif effect not in EFFECT_MAPPING_V1:
raise ValueError(f"The effect {effect} is not a built in effect.")
else:
@ -84,12 +84,12 @@ class LightEffect(IotModule, LightEffectInterface):
if transition is not None:
effect_dict["transition"] = transition
await self.set_custom_effect(effect_dict)
return await self.set_custom_effect(effect_dict)
async def set_custom_effect(
self,
effect_dict: dict,
) -> None:
) -> dict:
"""Set a custom effect on the device.
:param str effect_dict: The custom effect dict to set
@ -104,7 +104,7 @@ class LightEffect(IotModule, LightEffectInterface):
"""Return True if the device supports setting custom effects."""
return True
def query(self):
def query(self) -> dict:
"""Return the base query."""
return {}

View File

@ -41,7 +41,7 @@ class LightPreset(IotModule, LightPresetInterface):
_presets: dict[str, IotLightPreset]
_preset_list: list[str]
async def _post_update_hook(self):
async def _post_update_hook(self) -> None:
"""Update the internal presets."""
self._presets = {
f"Light preset {index+1}": IotLightPreset(**vals)
@ -93,7 +93,7 @@ class LightPreset(IotModule, LightPresetInterface):
async def set_preset(
self,
preset_name: str,
) -> None:
) -> dict:
"""Set a light preset for the device."""
light = self._device.modules[Module.Light]
if preset_name == self.PRESET_NOT_SET:
@ -104,7 +104,7 @@ class LightPreset(IotModule, LightPresetInterface):
elif (preset := self._presets.get(preset_name)) is None: # type: ignore[assignment]
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
await light.set_state(preset)
return await light.set_state(preset)
@property
def has_save_preset(self) -> bool:
@ -115,7 +115,7 @@ class LightPreset(IotModule, LightPresetInterface):
self,
preset_name: str,
preset_state: LightState,
) -> None:
) -> dict:
"""Update the preset with preset_name with the new preset_info."""
if len(self._presets) == 0:
raise KasaException("Device does not supported saving presets")
@ -129,7 +129,7 @@ class LightPreset(IotModule, LightPresetInterface):
return await self.call("set_preferred_state", state)
def query(self):
def query(self) -> dict:
"""Return the base query."""
return {}
@ -142,7 +142,7 @@ class LightPreset(IotModule, LightPresetInterface):
if "id" not in vals
]
async def _deprecated_save_preset(self, preset: IotLightPreset):
async def _deprecated_save_preset(self, preset: IotLightPreset) -> dict:
"""Save a setting preset.
You can either construct a preset object manually, or pass an existing one

View File

@ -24,7 +24,7 @@ class Range(Enum):
class Motion(IotModule):
"""Implements the motion detection (PIR) module."""
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
# Only add features if the device supports the module
if "get_config" not in self.data:
@ -48,7 +48,7 @@ class Motion(IotModule):
)
)
def query(self):
def query(self) -> dict:
"""Request PIR configuration."""
return self.query_for_command("get_config")
@ -67,13 +67,13 @@ class Motion(IotModule):
"""Return True if module is enabled."""
return bool(self.config["enable"])
async def set_enabled(self, state: bool):
async def set_enabled(self, state: bool) -> dict:
"""Enable/disable PIR."""
return await self.call("set_enable", {"enable": int(state)})
async def set_range(
self, *, range: Range | None = None, custom_range: int | None = None
):
) -> dict:
"""Set the range for the sensor.
:param range: for using standard ranges
@ -93,7 +93,7 @@ class Motion(IotModule):
"""Return inactivity timeout in milliseconds."""
return self.config["cold_time"]
async def set_inactivity_timeout(self, timeout: int):
async def set_inactivity_timeout(self, timeout: int) -> dict:
"""Set inactivity timeout in milliseconds.
Note, that you need to delete the default "Smart Control" rule in the app

View File

@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__)
class RuleModule(IotModule):
"""Base class for rule-based modules, such as countdown and antitheft."""
def query(self):
def query(self) -> dict:
"""Prepare the query for rules."""
q = self.query_for_command("get_rules")
return merge(q, self.query_for_command("get_next_action"))
@ -73,14 +73,14 @@ class RuleModule(IotModule):
_LOGGER.error("Unable to read rule list: %s (data: %s)", ex, self.data)
return []
async def set_enabled(self, state: bool):
async def set_enabled(self, state: bool) -> dict:
"""Enable or disable the service."""
return await self.call("set_overall_enable", state)
return await self.call("set_overall_enable", {"enable": state})
async def delete_rule(self, rule: Rule):
async def delete_rule(self, rule: Rule) -> dict:
"""Delete the given rule."""
return await self.call("delete_rule", {"id": rule.id})
async def delete_all_rules(self):
async def delete_all_rules(self) -> dict:
"""Delete all rules."""
return await self.call("delete_all_rules")

View File

@ -15,14 +15,14 @@ class Time(IotModule, TimeInterface):
_timezone: tzinfo = timezone.utc
def query(self):
def query(self) -> dict:
"""Request time and timezone."""
q = self.query_for_command("get_time")
merge(q, self.query_for_command("get_timezone"))
return q
async def _post_update_hook(self):
async def _post_update_hook(self) -> None:
"""Perform actions after a device update."""
if res := self.data.get("get_timezone"):
self._timezone = await get_timezone(res.get("index"))
@ -47,7 +47,7 @@ class Time(IotModule, TimeInterface):
"""Return current timezone."""
return self._timezone
async def get_time(self):
async def get_time(self) -> datetime | None:
"""Return current device time."""
try:
res = await self.call("get_time")
@ -88,6 +88,6 @@ class Time(IotModule, TimeInterface):
except Exception as ex:
raise KasaException(ex) from ex
async def get_timezone(self):
async def get_timezone(self) -> dict:
"""Request timezone information from the device."""
return await self.call("get_timezone")

View File

@ -10,7 +10,7 @@ from ..iotmodule import IotModule, merge
class Usage(IotModule):
"""Baseclass for emeter/usage interfaces."""
def query(self):
def query(self) -> dict:
"""Return the base query."""
now = datetime.now()
year = now.year
@ -25,22 +25,22 @@ class Usage(IotModule):
return req
@property
def estimated_query_response_size(self):
def estimated_query_response_size(self) -> int:
"""Estimated maximum query response size."""
return 2048
@property
def daily_data(self):
def daily_data(self) -> list[dict]:
"""Return statistics on daily basis."""
return self.data["get_daystat"]["day_list"]
@property
def monthly_data(self):
def monthly_data(self) -> list[dict]:
"""Return statistics on monthly basis."""
return self.data["get_monthstat"]["month_list"]
@property
def usage_today(self):
def usage_today(self) -> int | None:
"""Return today's usage in minutes."""
today = datetime.now().day
# Traverse the list in reverse order to find the latest entry.
@ -50,7 +50,7 @@ class Usage(IotModule):
return None
@property
def usage_this_month(self):
def usage_this_month(self) -> int | None:
"""Return usage in this month in minutes."""
this_month = datetime.now().month
# Traverse the list in reverse order to find the latest entry.
@ -59,7 +59,9 @@ class Usage(IotModule):
return entry["time"]
return None
async def get_raw_daystat(self, *, year=None, month=None) -> dict:
async def get_raw_daystat(
self, *, year: int | None = None, month: int | None = None
) -> dict:
"""Return raw daily stats for the given year & month."""
if year is None:
year = datetime.now().year
@ -68,14 +70,16 @@ class Usage(IotModule):
return await self.call("get_daystat", {"year": year, "month": month})
async def get_raw_monthstat(self, *, year=None) -> dict:
async def get_raw_monthstat(self, *, year: int | None = None) -> dict:
"""Return raw monthly stats for the given year."""
if year is None:
year = datetime.now().year
return await self.call("get_monthstat", {"year": year})
async def get_daystat(self, *, year=None, month=None) -> dict:
async def get_daystat(
self, *, year: int | None = None, month: int | None = None
) -> dict:
"""Return daily stats for the given year & month.
The return value is a dictionary of {day: time, ...}.
@ -84,7 +88,7 @@ class Usage(IotModule):
data = self._convert_stat_data(data["day_list"], entry_key="day")
return data
async def get_monthstat(self, *, year=None) -> dict:
async def get_monthstat(self, *, year: int | None = None) -> dict:
"""Return monthly stats for the given year.
The return value is a dictionary of {month: time, ...}.
@ -93,11 +97,11 @@ class Usage(IotModule):
data = self._convert_stat_data(data["month_list"], entry_key="month")
return data
async def erase_stats(self):
async def erase_stats(self) -> dict:
"""Erase all stats."""
return await self.call("erase_runtime_stat")
def _convert_stat_data(self, data, entry_key) -> dict:
def _convert_stat_data(self, data: list[dict], entry_key: str) -> dict:
"""Return usage information keyed with the day/month.
The incoming data is a list of dictionaries::
@ -113,6 +117,6 @@ class Usage(IotModule):
if not data:
return {}
data = {entry[entry_key]: entry["time"] for entry in data}
res = {entry[entry_key]: entry["time"] for entry in data}
return data
return res

View File

@ -1,9 +1,13 @@
"""JSON abstraction."""
from __future__ import annotations
from typing import Any, Callable
try:
import orjson
def dumps(obj, *, default=None):
def dumps(obj: Any, *, default: Callable | None = None) -> str:
"""Dump JSON."""
return orjson.dumps(obj).decode()
@ -11,7 +15,7 @@ try:
except ImportError:
import json
def dumps(obj, *, default=None):
def dumps(obj: Any, *, default: Callable | None = None) -> str:
"""Dump JSON."""
# Separators specified for consistency with orjson
return json.dumps(obj, separators=(",", ":"))

View File

@ -50,7 +50,8 @@ import logging
import secrets
import struct
import time
from typing import TYPE_CHECKING, Any, cast
from asyncio import Future
from typing import TYPE_CHECKING, Any, Generator, cast
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
@ -110,10 +111,10 @@ class KlapTransport(BaseTransport):
else:
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
self._default_credentials_auth_hash: dict[str, bytes] = {}
self._blank_auth_hash = None
self._blank_auth_hash: bytes | None = None
self._handshake_lock = asyncio.Lock()
self._query_lock = asyncio.Lock()
self._handshake_done = False
self._handshake_done: bool = False
self._encryption_session: KlapEncryptionSession | None = None
self._session_expire_at: float | None = None
@ -125,7 +126,7 @@ class KlapTransport(BaseTransport):
self._request_url = self._app_url / "request"
@property
def default_port(self):
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT
@ -242,7 +243,7 @@ class KlapTransport(BaseTransport):
raise AuthenticationError(msg)
async def perform_handshake2(
self, local_seed, remote_seed, auth_hash
self, local_seed: bytes, remote_seed: bytes, auth_hash: bytes
) -> KlapEncryptionSession:
"""Perform handshake2."""
# Handshake 2 has the following payload:
@ -277,7 +278,7 @@ class KlapTransport(BaseTransport):
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
async def perform_handshake(self) -> Any:
async def perform_handshake(self) -> None:
"""Perform handshake1 and handshake2.
Sets the encryption_session if successful.
@ -309,14 +310,14 @@ class KlapTransport(BaseTransport):
_LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self):
def _handshake_session_expired(self) -> bool:
"""Return true if session has expired."""
return (
self._session_expire_at is None
or self._session_expire_at - time.monotonic() <= 0
)
async def send(self, request: str):
async def send(self, request: str) -> Generator[Future, None, dict[str, str]]: # type: ignore[override]
"""Send the request."""
if not self._handshake_done or self._handshake_session_expired():
await self.perform_handshake()
@ -355,6 +356,7 @@ class KlapTransport(BaseTransport):
if TYPE_CHECKING:
assert self._encryption_session
assert isinstance(response_data, bytes)
try:
decrypted_response = self._encryption_session.decrypt(response_data)
except Exception as ex:
@ -378,7 +380,7 @@ class KlapTransport(BaseTransport):
self._handshake_done = False
@staticmethod
def generate_auth_hash(creds: Credentials):
def generate_auth_hash(creds: Credentials) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username
pw = creds.password
@ -388,19 +390,19 @@ class KlapTransport(BaseTransport):
@staticmethod
def handshake1_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(local_seed + auth_hash)
@staticmethod
def handshake2_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(remote_seed + auth_hash)
@staticmethod
def generate_owner_hash(creds: Credentials):
def generate_owner_hash(creds: Credentials) -> bytes:
"""Return the MD5 hash of the username in this object."""
un = creds.username
return md5(un.encode())
@ -410,7 +412,7 @@ class KlapTransportV2(KlapTransport):
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
@staticmethod
def generate_auth_hash(creds: Credentials):
def generate_auth_hash(creds: Credentials) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username
pw = creds.password
@ -420,14 +422,14 @@ class KlapTransportV2(KlapTransport):
@staticmethod
def handshake1_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(local_seed + remote_seed + auth_hash)
@staticmethod
def handshake2_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(remote_seed + local_seed + auth_hash)
@ -440,7 +442,7 @@ class KlapEncryptionSession:
_cipher: Cipher
def __init__(self, local_seed, remote_seed, user_hash):
def __init__(self, local_seed: bytes, remote_seed: bytes, user_hash: bytes) -> None:
self.local_seed = local_seed
self.remote_seed = remote_seed
self.user_hash = user_hash
@ -449,11 +451,15 @@ class KlapEncryptionSession:
self._aes = algorithms.AES(self._key)
self._sig = self._sig_derive(local_seed, remote_seed, user_hash)
def _key_derive(self, local_seed, remote_seed, user_hash):
def _key_derive(
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
) -> bytes:
payload = b"lsk" + local_seed + remote_seed + user_hash
return hashlib.sha256(payload).digest()[:16]
def _iv_derive(self, local_seed, remote_seed, user_hash):
def _iv_derive(
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
) -> tuple[bytes, int]:
# iv is first 16 bytes of sha256, where the last 4 bytes forms the
# sequence number used in requests and is incremented on each request
payload = b"iv" + local_seed + remote_seed + user_hash
@ -461,17 +467,19 @@ class KlapEncryptionSession:
seq = int.from_bytes(fulliv[-4:], "big", signed=True)
return (fulliv[:12], seq)
def _sig_derive(self, local_seed, remote_seed, user_hash):
def _sig_derive(
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
) -> bytes:
# used to create a hash with which to prefix each request
payload = b"ldk" + local_seed + remote_seed + user_hash
return hashlib.sha256(payload).digest()[:28]
def _generate_cipher(self):
def _generate_cipher(self) -> None:
iv_seq = self._iv + PACK_SIGNED_LONG(self._seq)
cbc = modes.CBC(iv_seq)
self._cipher = Cipher(self._aes, cbc)
def encrypt(self, msg):
def encrypt(self, msg: bytes | str) -> tuple[bytes, int]:
"""Encrypt the data and increment the sequence number."""
self._seq += 1
self._generate_cipher()
@ -488,7 +496,7 @@ class KlapEncryptionSession:
).digest()
return (signature + ciphertext, self._seq)
def decrypt(self, msg):
def decrypt(self, msg: bytes) -> str:
"""Decrypt the data."""
decryptor = self._cipher.decryptor()
dp = decryptor.update(msg[32:]) + decryptor.finalize()

View File

@ -135,13 +135,13 @@ class Module(ABC):
# SMARTCAMERA only modules
Camera: Final[ModuleName[experimental.Camera]] = ModuleName("Camera")
def __init__(self, device: Device, module: str):
def __init__(self, device: Device, module: str) -> None:
self._device = device
self._module = module
self._module_features: dict[str, Feature] = {}
@abstractmethod
def query(self):
def query(self) -> dict:
"""Query to execute during the update cycle.
The inheriting modules implement this to include their wanted
@ -150,10 +150,10 @@ class Module(ABC):
@property
@abstractmethod
def data(self):
def data(self) -> dict:
"""Return the module specific raw data from the last update."""
def _initialize_features(self): # noqa: B027
def _initialize_features(self) -> None: # noqa: B027
"""Initialize features after the initial update.
This can be implemented if features depend on module query responses.
@ -162,7 +162,7 @@ class Module(ABC):
children's modules.
"""
async def _post_update_hook(self): # noqa: B027
async def _post_update_hook(self) -> None: # noqa: B027
"""Perform actions after a device update.
This can be implemented if a module needs to perform actions each time
@ -171,11 +171,11 @@ class Module(ABC):
*_initialize_features* on the first update.
"""
def _add_feature(self, feature: Feature):
def _add_feature(self, feature: Feature) -> None:
"""Add module feature."""
id_ = feature.id
if id_ in self._module_features:
raise KasaException("Duplicate id detected %s" % id_)
raise KasaException(f"Duplicate id detected {id_}")
self._module_features[id_] = feature
def __repr__(self) -> str:

View File

@ -130,7 +130,7 @@ class BaseProtocol(ABC):
self._transport = transport
@property
def _host(self):
def _host(self) -> str:
return self._transport._host
@property

View File

@ -15,7 +15,9 @@ class SmartLightEffect(LightEffectInterface, ABC):
"""
@abstractmethod
async def set_brightness(self, brightness: int, *, transition: int | None = None):
async def set_brightness(
self, brightness: int, *, transition: int | None = None
) -> dict:
"""Set effect brightness."""
@property

View File

@ -20,7 +20,7 @@ class Alarm(SmartModule):
"get_support_alarm_type_list": None, # This should be needed only once
}
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features.
This is implemented as some features depend on device responses.
@ -100,7 +100,7 @@ class Alarm(SmartModule):
"""Return current alarm sound."""
return self.data["get_alarm_configure"]["type"]
async def set_alarm_sound(self, sound: str):
async def set_alarm_sound(self, sound: str) -> dict:
"""Set alarm sound.
See *alarm_sounds* for list of available sounds.
@ -119,7 +119,7 @@ class Alarm(SmartModule):
"""Return alarm volume."""
return self.data["get_alarm_configure"]["volume"]
async def set_alarm_volume(self, volume: Literal["low", "normal", "high"]):
async def set_alarm_volume(self, volume: Literal["low", "normal", "high"]) -> dict:
"""Set alarm volume."""
payload = self.data["get_alarm_configure"].copy()
payload["volume"] = volume

View File

@ -17,7 +17,7 @@ class AutoOff(SmartModule):
REQUIRED_COMPONENT = "auto_off"
QUERY_GETTER_NAME = "get_auto_off_config"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -63,7 +63,7 @@ class AutoOff(SmartModule):
"""Return True if enabled."""
return self.data["enable"]
async def set_enabled(self, enable: bool):
async def set_enabled(self, enable: bool) -> dict:
"""Enable/disable auto off."""
return await self.call(
"set_auto_off_config",
@ -75,7 +75,7 @@ class AutoOff(SmartModule):
"""Return time until auto off."""
return self.data["delay_min"]
async def set_delay(self, delay: int):
async def set_delay(self, delay: int) -> dict:
"""Set time until auto off."""
return await self.call(
"set_auto_off_config", {"delay_min": delay, "enable": self.data["enable"]}
@ -96,7 +96,7 @@ class AutoOff(SmartModule):
return self._device.time + timedelta(seconds=sysinfo["auto_off_remain_time"])
async def _check_supported(self):
async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device.
Parent devices that report components of children such as P300 will not have

View File

@ -12,7 +12,7 @@ class BatterySensor(SmartModule):
REQUIRED_COMPONENT = "battery_detect"
QUERY_GETTER_NAME = "get_battery_detect_info"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
self._add_feature(
Feature(
@ -48,11 +48,11 @@ class BatterySensor(SmartModule):
return {}
@property
def battery(self):
def battery(self) -> int:
"""Return battery level."""
return self._device.sys_info["battery_percentage"]
@property
def battery_low(self):
def battery_low(self) -> bool:
"""Return True if battery is low."""
return self._device.sys_info["at_low_battery"]

View File

@ -14,7 +14,7 @@ class Brightness(SmartModule):
REQUIRED_COMPONENT = "brightness"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
super()._initialize_features()
@ -39,7 +39,7 @@ class Brightness(SmartModule):
return {}
@property
def brightness(self):
def brightness(self) -> int:
"""Return current brightness."""
# If the device supports effects and one is active, use its brightness
if (
@ -49,7 +49,9 @@ class Brightness(SmartModule):
return self.data["brightness"]
async def set_brightness(self, brightness: int, *, transition: int | None = None):
async def set_brightness(
self, brightness: int, *, transition: int | None = None
) -> dict:
"""Set the brightness. A brightness value of 0 will turn off the light.
Note, transition is not supported and will be ignored.
@ -73,6 +75,6 @@ class Brightness(SmartModule):
return await self.call("set_device_info", {"brightness": brightness})
async def _check_supported(self):
async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device."""
return "brightness" in self.data

View File

@ -12,7 +12,7 @@ class ChildProtection(SmartModule):
REQUIRED_COMPONENT = "child_protection"
QUERY_GETTER_NAME = "get_child_protection"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(

View File

@ -13,7 +13,7 @@ class Cloud(SmartModule):
REQUIRED_COMPONENT = "cloud_connect"
MINIMUM_UPDATE_INTERVAL_SECS = 60
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -29,7 +29,7 @@ class Cloud(SmartModule):
)
@property
def is_connected(self):
def is_connected(self) -> bool:
"""Return True if device is connected to the cloud."""
if self._has_data_error():
return False

View File

@ -12,7 +12,7 @@ class Color(SmartModule):
REQUIRED_COMPONENT = "color"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -48,7 +48,7 @@ class Color(SmartModule):
# due to the cpython implementation.
return tuple.__new__(HSV, (h, s, v))
def _raise_for_invalid_brightness(self, value):
def _raise_for_invalid_brightness(self, value: int) -> None:
"""Raise error on invalid brightness value."""
if not isinstance(value, int):
raise TypeError("Brightness must be an integer")

View File

@ -18,7 +18,7 @@ class ColorTemperature(SmartModule):
REQUIRED_COMPONENT = "color_temperature"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
self._add_feature(
Feature(
@ -52,11 +52,11 @@ class ColorTemperature(SmartModule):
return ColorTempRange(*ct_range)
@property
def color_temp(self):
def color_temp(self) -> int:
"""Return current color temperature."""
return self.data["color_temp"]
async def set_color_temp(self, temp: int, *, brightness=None):
async def set_color_temp(self, temp: int, *, brightness: int | None = None) -> dict:
"""Set the color temperature."""
valid_temperature_range = self.valid_temperature_range
if temp < valid_temperature_range[0] or temp > valid_temperature_range[1]:

View File

@ -12,7 +12,7 @@ class ContactSensor(SmartModule):
REQUIRED_COMPONENT = None # we depend on availability of key
REQUIRED_KEY_ON_PARENT = "open"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -32,6 +32,6 @@ class ContactSensor(SmartModule):
return {}
@property
def is_open(self):
def is_open(self) -> bool:
"""Return True if the contact sensor is open."""
return self._device.sys_info["open"]

View File

@ -10,7 +10,7 @@ class DeviceModule(SmartModule):
REQUIRED_COMPONENT = "device"
async def _post_update_hook(self):
async def _post_update_hook(self) -> None:
"""Perform actions after a device update.
Overrides the default behaviour to disable a module if the query returns

View File

@ -2,6 +2,8 @@
from __future__ import annotations
from typing import NoReturn
from ...emeterstatus import EmeterStatus
from ...exceptions import KasaException
from ...interfaces.energy import Energy as EnergyInterface
@ -31,34 +33,34 @@ class Energy(SmartModule, EnergyInterface):
# Fallback if get_energy_usage does not provide current_power,
# which can happen on some newer devices (e.g. P304M).
elif (
power := self.data.get("get_current_power").get("current_power")
power := self.data.get("get_current_power", {}).get("current_power")
) is not None:
return power
return None
@property
@raise_if_update_error
def energy(self):
def energy(self) -> dict:
"""Return get_energy_usage results."""
if en := self.data.get("get_energy_usage"):
return en
return self.data
def _get_status_from_energy(self, energy) -> EmeterStatus:
def _get_status_from_energy(self, energy: dict) -> EmeterStatus:
return EmeterStatus(
{
"power_mw": energy.get("current_power"),
"total": energy.get("today_energy") / 1_000,
"power_mw": energy.get("current_power", 0),
"total": energy.get("today_energy", 0) / 1_000,
}
)
@property
@raise_if_update_error
def status(self):
def status(self) -> EmeterStatus:
"""Get the emeter status."""
return self._get_status_from_energy(self.energy)
async def get_status(self):
async def get_status(self) -> EmeterStatus:
"""Return real-time statistics."""
res = await self.call("get_energy_usage")
return self._get_status_from_energy(res["get_energy_usage"])
@ -67,13 +69,13 @@ class Energy(SmartModule, EnergyInterface):
@raise_if_update_error
def consumption_this_month(self) -> float | None:
"""Get the emeter value for this month in kWh."""
return self.energy.get("month_energy") / 1_000
return self.energy.get("month_energy", 0) / 1_000
@property
@raise_if_update_error
def consumption_today(self) -> float | None:
"""Get the emeter value for today in kWh."""
return self.energy.get("today_energy") / 1_000
return self.energy.get("today_energy", 0) / 1_000
@property
@raise_if_update_error
@ -97,22 +99,26 @@ class Energy(SmartModule, EnergyInterface):
"""Retrieve current energy readings."""
return self.status
async def erase_stats(self):
async def erase_stats(self) -> NoReturn:
"""Erase all stats."""
raise KasaException("Device does not support periodic statistics")
async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict:
async def get_daily_stats(
self, *, year: int | None = None, month: int | None = None, kwh: bool = True
) -> dict:
"""Return daily stats for the given year & month.
The return value is a dictionary of {day: energy, ...}.
"""
raise KasaException("Device does not support periodic statistics")
async def get_monthly_stats(self, *, year=None, kwh=True) -> dict:
async def get_monthly_stats(
self, *, year: int | None = None, kwh: bool = True
) -> dict:
"""Return monthly stats for the given year."""
raise KasaException("Device does not support periodic statistics")
async def _check_supported(self):
async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device."""
# Energy module is not supported on P304M parent device
return "device_on" in self._device.sys_info

View File

@ -12,7 +12,7 @@ class Fan(SmartModule, FanInterface):
REQUIRED_COMPONENT = "fan_control"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -50,7 +50,7 @@ class Fan(SmartModule, FanInterface):
"""Return fan speed level."""
return 0 if self.data["device_on"] is False else self.data["fan_speed_level"]
async def set_fan_speed_level(self, level: int):
async def set_fan_speed_level(self, level: int) -> dict:
"""Set fan speed level, 0 for off, 1-4 for on."""
if level < 0 or level > 4:
raise ValueError("Invalid level, should be in range 0-4.")
@ -65,10 +65,10 @@ class Fan(SmartModule, FanInterface):
"""Return sleep mode status."""
return self.data["fan_sleep_mode_on"]
async def set_sleep_mode(self, on: bool):
async def set_sleep_mode(self, on: bool) -> dict:
"""Set sleep mode."""
return await self.call("set_device_info", {"fan_sleep_mode_on": on})
async def _check_supported(self):
async def _check_supported(self) -> bool:
"""Is the module available on this device."""
return "fan_speed_level" in self.data

View File

@ -49,14 +49,14 @@ class UpdateInfo(BaseModel):
needs_upgrade: bool = Field(alias="need_to_upgrade")
@validator("release_date", pre=True)
def _release_date_optional(cls, v):
def _release_date_optional(cls, v: str) -> str | None:
if not v:
return None
return v
@property
def update_available(self):
def update_available(self) -> bool:
"""Return True if update available."""
if self.status != 0:
return True
@ -69,11 +69,11 @@ class Firmware(SmartModule):
REQUIRED_COMPONENT = "firmware"
MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60 * 24
def __init__(self, device: SmartDevice, module: str):
def __init__(self, device: SmartDevice, module: str) -> None:
super().__init__(device, module)
self._firmware_update_info: UpdateInfo | None = None
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
device = self._device
if self.supported_version > 1:
@ -183,7 +183,7 @@ class Firmware(SmartModule):
@allow_update_after
async def update(
self, progress_cb: Callable[[DownloadState], Coroutine] | None = None
):
) -> dict:
"""Update the device firmware."""
if not self._firmware_update_info:
raise KasaException(
@ -236,13 +236,15 @@ class Firmware(SmartModule):
else:
_LOGGER.warning("Unhandled state code: %s", state)
return state.dict()
@property
def auto_update_enabled(self) -> bool:
"""Return True if autoupdate is enabled."""
return "enable" in self.data and self.data["enable"]
@allow_update_after
async def set_auto_update_enabled(self, enabled: bool):
async def set_auto_update_enabled(self, enabled: bool) -> dict:
"""Change autoupdate setting."""
data = {**self.data, "enable": enabled}
await self.call("set_auto_update_info", data)
return await self.call("set_auto_update_info", data)

View File

@ -23,7 +23,7 @@ class FrostProtection(SmartModule):
"""Return True if frost protection is on."""
return self._device.sys_info["frost_protection_on"]
async def set_enabled(self, enable: bool):
async def set_enabled(self, enable: bool) -> dict:
"""Enable/disable frost protection."""
return await self.call(
"set_device_info",

View File

@ -12,7 +12,7 @@ class HumiditySensor(SmartModule):
REQUIRED_COMPONENT = "humidity"
QUERY_GETTER_NAME = "get_comfort_humidity_config"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -45,7 +45,7 @@ class HumiditySensor(SmartModule):
return {}
@property
def humidity(self):
def humidity(self) -> int:
"""Return current humidity in percentage."""
return self._device.sys_info["current_humidity"]

View File

@ -19,7 +19,7 @@ class Led(SmartModule, LedInterface):
return {self.QUERY_GETTER_NAME: None}
@property
def mode(self):
def mode(self) -> str:
"""LED mode setting.
"always", "never", "night_mode"
@ -27,12 +27,12 @@ class Led(SmartModule, LedInterface):
return self.data["led_rule"]
@property
def led(self):
def led(self) -> bool:
"""Return current led status."""
return self.data["led_rule"] != "never"
@allow_update_after
async def set_led(self, enable: bool):
async def set_led(self, enable: bool) -> dict:
"""Set led.
This should probably be a select with always/never/nightmode.
@ -41,7 +41,7 @@ class Led(SmartModule, LedInterface):
return await self.call("set_led_info", dict(self.data, **{"led_rule": rule}))
@property
def night_mode_settings(self):
def night_mode_settings(self) -> dict:
"""Night mode settings."""
return {
"start": self.data["start_time"],

View File

@ -96,7 +96,7 @@ class Light(SmartModule, LightInterface):
return await self._device.modules[Module.Color].set_hsv(hue, saturation, value)
async def set_color_temp(
self, temp: int, *, brightness=None, transition: int | None = None
self, temp: int, *, brightness: int | None = None, transition: int | None = None
) -> dict:
"""Set the color temperature of the device in kelvin.

View File

@ -81,7 +81,7 @@ class LightEffect(SmartModule, SmartLightEffect):
*,
brightness: int | None = None,
transition: int | None = None,
) -> None:
) -> dict:
"""Set an effect for the device.
Calling this will modify the brightness of the effect on the device.
@ -107,7 +107,7 @@ class LightEffect(SmartModule, SmartLightEffect):
)
await self.set_brightness(brightness, effect_id=effect_id)
await self.call("set_dynamic_light_effect_rule_enable", params)
return await self.call("set_dynamic_light_effect_rule_enable", params)
@property
def is_active(self) -> bool:
@ -139,11 +139,11 @@ class LightEffect(SmartModule, SmartLightEffect):
*,
transition: int | None = None,
effect_id: str | None = None,
):
) -> dict:
"""Set effect brightness."""
new_effect = self._get_effect_data(effect_id=effect_id).copy()
def _replace_brightness(data, new_brightness):
def _replace_brightness(data: list[int], new_brightness: int) -> list[int]:
"""Replace brightness.
The first element is the brightness, the rest are unknown.
@ -163,7 +163,7 @@ class LightEffect(SmartModule, SmartLightEffect):
async def set_custom_effect(
self,
effect_dict: dict,
) -> None:
) -> dict:
"""Set a custom effect on the device.
:param str effect_dict: The custom effect dict to set

View File

@ -29,12 +29,12 @@ class LightPreset(SmartModule, LightPresetInterface):
_presets: dict[str, LightState]
_preset_list: list[str]
def __init__(self, device: SmartDevice, module: str):
def __init__(self, device: SmartDevice, module: str) -> None:
super().__init__(device, module)
self._state_in_sysinfo = self.SYS_INFO_STATE_KEY in device.sys_info
self._brightness_only: bool = False
async def _post_update_hook(self):
async def _post_update_hook(self) -> None:
"""Update the internal presets."""
index = 0
self._presets = {}
@ -113,7 +113,7 @@ class LightPreset(SmartModule, LightPresetInterface):
async def set_preset(
self,
preset_name: str,
) -> None:
) -> dict:
"""Set a light preset for the device."""
light = self._device.modules[SmartModule.Light]
if preset_name == self.PRESET_NOT_SET:
@ -123,14 +123,14 @@ class LightPreset(SmartModule, LightPresetInterface):
preset = LightState(brightness=100)
elif (preset := self._presets.get(preset_name)) is None: # type: ignore[assignment]
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
await self._device.modules[SmartModule.Light].set_state(preset)
return await self._device.modules[SmartModule.Light].set_state(preset)
@allow_update_after
async def save_preset(
self,
preset_name: str,
preset_state: LightState,
) -> None:
) -> dict:
"""Update the preset with preset_name with the new preset_info."""
if preset_name not in self._presets:
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
@ -138,11 +138,13 @@ class LightPreset(SmartModule, LightPresetInterface):
if self._brightness_only:
bright_list = [state.brightness for state in self._presets.values()]
bright_list[index] = preset_state.brightness
await self.call("set_preset_rules", {"brightness": bright_list})
return await self.call("set_preset_rules", {"brightness": bright_list})
else:
state_params = asdict(preset_state)
new_info = {k: v for k, v in state_params.items() if v is not None}
await self.call("edit_preset_rules", {"index": index, "state": new_info})
return await self.call(
"edit_preset_rules", {"index": index, "state": new_info}
)
@property
def has_save_preset(self) -> bool:
@ -158,7 +160,7 @@ class LightPreset(SmartModule, LightPresetInterface):
return {self.QUERY_GETTER_NAME: {"start_index": 0}}
async def _check_supported(self):
async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device.
Parent devices that report components of children such as ks240 will not have

View File

@ -16,7 +16,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
REQUIRED_COMPONENT = "light_strip_lighting_effect"
def __init__(self, device: SmartDevice, module: str):
def __init__(self, device: SmartDevice, module: str) -> None:
super().__init__(device, module)
effect_list = [self.LIGHT_EFFECTS_OFF]
effect_list.extend(EFFECT_NAMES)
@ -66,7 +66,9 @@ class LightStripEffect(SmartModule, SmartLightEffect):
eff = self.data["lighting_effect"]
return eff["brightness"]
async def set_brightness(self, brightness: int, *, transition: int | None = None):
async def set_brightness(
self, brightness: int, *, transition: int | None = None
) -> dict:
"""Set effect brightness."""
if brightness <= 0:
return await self.set_effect(self.LIGHT_EFFECTS_OFF)
@ -91,7 +93,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
*,
brightness: int | None = None,
transition: int | None = None,
) -> None:
) -> dict:
"""Set an effect on the device.
If brightness or transition is defined,
@ -115,8 +117,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
effect_dict = self._effect_mapping["Aurora"]
effect_dict = {**effect_dict}
effect_dict["enable"] = 0
await self.set_custom_effect(effect_dict)
return
return await self.set_custom_effect(effect_dict)
if effect not in self._effect_mapping:
raise ValueError(f"The effect {effect} is not a built in effect.")
@ -134,13 +135,13 @@ class LightStripEffect(SmartModule, SmartLightEffect):
if transition is not None:
effect_dict["transition"] = transition
await self.set_custom_effect(effect_dict)
return await self.set_custom_effect(effect_dict)
@allow_update_after
async def set_custom_effect(
self,
effect_dict: dict,
) -> None:
) -> dict:
"""Set a custom effect on the device.
:param str effect_dict: The custom effect dict to set
@ -155,7 +156,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
"""Return True if the device supports setting custom effects."""
return True
def query(self):
def query(self) -> dict:
"""Return the base query."""
return {}

View File

@ -39,14 +39,14 @@ class LightTransition(SmartModule):
_off_state: _State
_enabled: bool
def __init__(self, device: SmartDevice, module: str):
def __init__(self, device: SmartDevice, module: str) -> None:
super().__init__(device, module)
self._state_in_sysinfo = all(
key in device.sys_info for key in self.SYS_INFO_STATE_KEYS
)
self._supports_on_and_off: bool = self.supported_version > 1
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
icon = "mdi:transition"
if not self._supports_on_and_off:
@ -138,7 +138,7 @@ class LightTransition(SmartModule):
}
@allow_update_after
async def set_enabled(self, enable: bool):
async def set_enabled(self, enable: bool) -> dict:
"""Enable gradual on/off."""
if not self._supports_on_and_off:
return await self.call("set_on_off_gradually_info", {"enable": enable})
@ -171,7 +171,7 @@ class LightTransition(SmartModule):
return self._on_state["max_duration"]
@allow_update_after
async def set_turn_on_transition(self, seconds: int):
async def set_turn_on_transition(self, seconds: int) -> dict:
"""Set turn on transition in seconds.
Setting to 0 turns the feature off.
@ -207,7 +207,7 @@ class LightTransition(SmartModule):
return self._off_state["max_duration"]
@allow_update_after
async def set_turn_off_transition(self, seconds: int):
async def set_turn_off_transition(self, seconds: int) -> dict:
"""Set turn on transition in seconds.
Setting to 0 turns the feature off.
@ -236,7 +236,7 @@ class LightTransition(SmartModule):
else:
return {self.QUERY_GETTER_NAME: None}
async def _check_supported(self):
async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device."""
# For devices that report child components on the parent that are not
# actually supported by the parent.

View File

@ -11,7 +11,7 @@ class MotionSensor(SmartModule):
REQUIRED_COMPONENT = "sensitivity"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features."""
self._add_feature(
Feature(
@ -31,6 +31,6 @@ class MotionSensor(SmartModule):
return {}
@property
def motion_detected(self):
def motion_detected(self) -> bool:
"""Return True if the motion has been detected."""
return self._device.sys_info["detected"]

View File

@ -12,7 +12,7 @@ class ReportMode(SmartModule):
REQUIRED_COMPONENT = "report_mode"
QUERY_GETTER_NAME = "get_report_mode"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -32,6 +32,6 @@ class ReportMode(SmartModule):
return {}
@property
def report_interval(self):
def report_interval(self) -> int:
"""Reporting interval of a sensor device."""
return self._device.sys_info["report_interval"]

View File

@ -26,7 +26,7 @@ class TemperatureControl(SmartModule):
REQUIRED_COMPONENT = "temp_control"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -92,7 +92,7 @@ class TemperatureControl(SmartModule):
"""Return thermostat state."""
return self._device.sys_info["frost_protection_on"] is False
async def set_state(self, enabled: bool):
async def set_state(self, enabled: bool) -> dict:
"""Set thermostat state."""
return await self.call("set_device_info", {"frost_protection_on": not enabled})
@ -147,7 +147,7 @@ class TemperatureControl(SmartModule):
"""Return thermostat states."""
return set(self._device.sys_info["trv_states"])
async def set_target_temperature(self, target: float):
async def set_target_temperature(self, target: float) -> dict:
"""Set target temperature."""
if (
target < self.minimum_target_temperature
@ -170,7 +170,7 @@ class TemperatureControl(SmartModule):
"""Return temperature offset."""
return self._device.sys_info["temp_offset"]
async def set_temperature_offset(self, offset: int):
async def set_temperature_offset(self, offset: int) -> dict:
"""Set temperature offset."""
if offset < -10 or offset > 10:
raise ValueError("Temperature offset must be [-10, 10]")

View File

@ -14,7 +14,7 @@ class TemperatureSensor(SmartModule):
REQUIRED_COMPONENT = "temperature"
QUERY_GETTER_NAME = "get_comfort_temp_config"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -60,7 +60,7 @@ class TemperatureSensor(SmartModule):
return {}
@property
def temperature(self):
def temperature(self) -> float:
"""Return current humidity in percentage."""
return self._device.sys_info["current_temp"]
@ -74,6 +74,8 @@ class TemperatureSensor(SmartModule):
"""Return current temperature unit."""
return self._device.sys_info["temp_unit"]
async def set_temperature_unit(self, unit: Literal["celsius", "fahrenheit"]):
async def set_temperature_unit(
self, unit: Literal["celsius", "fahrenheit"]
) -> dict:
"""Set the device temperature unit."""
return await self.call("set_temperature_unit", {"temp_unit": unit})

View File

@ -21,7 +21,7 @@ class Time(SmartModule, TimeInterface):
_timezone: tzinfo = timezone.utc
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
@ -35,7 +35,7 @@ class Time(SmartModule, TimeInterface):
)
)
async def _post_update_hook(self):
async def _post_update_hook(self) -> None:
"""Perform actions after a device update."""
td = timedelta(minutes=cast(float, self.data.get("time_diff")))
if region := self.data.get("region"):
@ -84,7 +84,7 @@ class Time(SmartModule, TimeInterface):
params["region"] = region
return await self.call("set_device_time", params)
async def _check_supported(self):
async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device.
Hub attached sensors report the time module but do return device time.

View File

@ -22,7 +22,7 @@ class WaterleakSensor(SmartModule):
REQUIRED_COMPONENT = "sensor_alarm"
def _initialize_features(self):
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(

View File

@ -49,7 +49,7 @@ class SmartChildDevice(SmartDevice):
self._update_internal_state(info)
self._components = component_info
async def update(self, update_children: bool = True):
async def update(self, update_children: bool = True) -> None:
"""Update child module info.
The parent updates our internal info so just update modules with
@ -57,7 +57,7 @@ class SmartChildDevice(SmartDevice):
"""
await self._update(update_children)
async def _update(self, update_children: bool = True):
async def _update(self, update_children: bool = True) -> None:
"""Update child module info.
Internal implementation to allow patching of public update in the cli
@ -118,5 +118,5 @@ class SmartChildDevice(SmartDevice):
dev_type = DeviceType.Unknown
return dev_type
def __repr__(self):
def __repr__(self) -> str:
return f"<{self.device_type} {self.alias} ({self.model}) of {self._parent}>"

View File

@ -69,7 +69,7 @@ class SmartDevice(Device):
self._on_since: datetime | None = None
self._info: dict[str, Any] = {}
async def _initialize_children(self):
async def _initialize_children(self) -> None:
"""Initialize children for power strips."""
child_info_query = {
"get_child_device_component_list": None,
@ -108,7 +108,9 @@ class SmartDevice(Device):
"""Return the device modules."""
return cast(ModuleMapping[SmartModule], self._modules)
def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
def _try_get_response(
self, responses: dict, request: str, default: Any | None = None
) -> dict:
response = responses.get(request)
if isinstance(response, SmartErrorCode):
_LOGGER.debug(
@ -126,7 +128,7 @@ class SmartDevice(Device):
f"{request} not found in {responses} for device {self.host}"
)
async def _negotiate(self):
async def _negotiate(self) -> None:
"""Perform initialization.
We fetch the device info and the available components as early as possible.
@ -146,7 +148,8 @@ class SmartDevice(Device):
self._info = self._try_get_response(resp, "get_device_info")
# Create our internal presentation of available components
self._components_raw = resp["component_nego"]
self._components_raw = cast(dict, resp["component_nego"])
self._components = {
comp["id"]: int(comp["ver_code"])
for comp in self._components_raw["component_list"]
@ -167,7 +170,7 @@ class SmartDevice(Device):
"""Update the internal device info."""
self._info = self._try_get_response(info_resp, "get_device_info")
async def update(self, update_children: bool = False):
async def update(self, update_children: bool = False) -> None:
"""Update the device."""
if self.credentials is None and self.credentials_hash is None:
raise AuthenticationError("Tapo plug requires authentication.")
@ -206,7 +209,7 @@ class SmartDevice(Device):
async def _handle_module_post_update(
self, module: SmartModule, update_time: float, had_query: bool
):
) -> None:
if module.disabled:
return # pragma: no cover
if had_query:
@ -312,7 +315,7 @@ class SmartDevice(Device):
responses[meth] = SmartErrorCode.INTERNAL_QUERY_ERROR
return responses
async def _initialize_modules(self):
async def _initialize_modules(self) -> None:
"""Initialize modules based on component negotiation response."""
from .smartmodule import SmartModule
@ -324,7 +327,7 @@ class SmartDevice(Device):
# It also ensures that devices like power strips do not add modules such as
# firmware to the child devices.
skip_parent_only_modules = False
child_modules_to_skip = {}
child_modules_to_skip: dict = {} # TODO: this is never non-empty
if self._parent and self._parent.device_type != DeviceType.Hub:
skip_parent_only_modules = True
@ -333,17 +336,18 @@ class SmartDevice(Device):
skip_parent_only_modules and mod in NON_HUB_PARENT_ONLY_MODULES
) or mod.__name__ in child_modules_to_skip:
continue
if (
mod.REQUIRED_COMPONENT in self._components
or self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None
required_component = cast(str, mod.REQUIRED_COMPONENT)
if required_component in self._components or (
mod.REQUIRED_KEY_ON_PARENT
and self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None
):
_LOGGER.debug(
"Device %s, found required %s, adding %s to modules.",
self.host,
mod.REQUIRED_COMPONENT,
required_component,
mod.__name__,
)
module = mod(self, mod.REQUIRED_COMPONENT)
module = mod(self, required_component)
if await module._check_supported():
self._modules[module.name] = module
@ -354,7 +358,7 @@ class SmartDevice(Device):
):
self._modules[Light.__name__] = Light(self, "light")
async def _initialize_features(self):
async def _initialize_features(self) -> None:
"""Initialize device features."""
self._add_feature(
Feature(
@ -575,11 +579,11 @@ class SmartDevice(Device):
return str(self._info.get("device_id"))
@property
def internal_state(self) -> Any:
def internal_state(self) -> dict:
"""Return all the internal state data."""
return self._last_update
def _update_internal_state(self, info: dict) -> None:
def _update_internal_state(self, info: dict[str, Any]) -> None:
"""Update the internal info state.
This is used by the parent to push updates to its children.
@ -587,8 +591,8 @@ class SmartDevice(Device):
self._info = info
async def _query_helper(
self, method: str, params: dict | None = None, child_ids=None
) -> Any:
self, method: str, params: dict | None = None, child_ids: None = None
) -> dict:
res = await self.protocol.query({method: params})
return res
@ -610,22 +614,25 @@ class SmartDevice(Device):
"""Return true if the device is on."""
return bool(self._info.get("device_on"))
async def set_state(self, on: bool): # TODO: better name wanted.
async def set_state(self, on: bool) -> dict:
"""Set the device state.
See :meth:`is_on`.
"""
return await self.protocol.query({"set_device_info": {"device_on": on}})
async def turn_on(self, **kwargs):
async def turn_on(self, **kwargs: Any) -> dict:
"""Turn on the device."""
await self.set_state(True)
return await self.set_state(True)
async def turn_off(self, **kwargs):
async def turn_off(self, **kwargs: Any) -> dict:
"""Turn off the device."""
await self.set_state(False)
return await self.set_state(False)
def update_from_discover_info(self, info):
def update_from_discover_info(
self,
info: dict,
) -> None:
"""Update state from info from the discover call."""
self._discovery_info = info
self._info = info
@ -633,7 +640,7 @@ class SmartDevice(Device):
async def wifi_scan(self) -> list[WifiNetwork]:
"""Scan for available wifi networks."""
def _net_for_scan_info(res):
def _net_for_scan_info(res: dict) -> WifiNetwork:
return WifiNetwork(
ssid=base64.b64decode(res["ssid"]).decode(),
cipher_type=res["cipher_type"],
@ -651,7 +658,9 @@ class SmartDevice(Device):
]
return networks
async def wifi_join(self, ssid: str, password: str, keytype: str = "wpa2_psk"):
async def wifi_join(
self, ssid: str, password: str, keytype: str = "wpa2_psk"
) -> dict:
"""Join the given wifi network.
This method returns nothing as the device tries to activate the new
@ -688,9 +697,12 @@ class SmartDevice(Device):
except DeviceError:
raise # Re-raise on device-reported errors
except KasaException:
_LOGGER.debug("Received an expected for wifi join, but this is expected")
_LOGGER.debug(
"Received a kasa exception for wifi join, but this is expected"
)
return {}
async def update_credentials(self, username: str, password: str):
async def update_credentials(self, username: str, password: str) -> dict:
"""Update device credentials.
This will replace the existing authentication credentials on the device.
@ -705,7 +717,7 @@ class SmartDevice(Device):
}
return await self.protocol.query({"set_qs_info": payload})
async def set_alias(self, alias: str):
async def set_alias(self, alias: str) -> dict:
"""Set the device name (alias)."""
return await self.protocol.query(
{"set_device_info": {"nickname": base64.b64encode(alias.encode()).decode()}}

View File

@ -22,17 +22,17 @@ _R = TypeVar("_R")
def allow_update_after(
func: Callable[Concatenate[_T, _P], Awaitable[None]],
) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, None]]:
func: Callable[Concatenate[_T, _P], Awaitable[dict]],
) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, dict]]:
"""Define a wrapper to set _last_update_time to None.
This will ensure that a module is updated in the next update cycle after
a value has been changed.
"""
async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> None:
async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> dict:
try:
await func(self, *args, **kwargs)
return await func(self, *args, **kwargs)
finally:
self._last_update_time = None
@ -68,21 +68,21 @@ class SmartModule(Module):
DISABLE_AFTER_ERROR_COUNT = 10
def __init__(self, device: SmartDevice, module: str):
def __init__(self, device: SmartDevice, module: str) -> None:
self._device: SmartDevice
super().__init__(device, module)
self._last_update_time: float | None = None
self._last_update_error: KasaException | None = None
self._error_count = 0
def __init_subclass__(cls, **kwargs):
def __init_subclass__(cls, **kwargs) -> None:
# We only want to register submodules in a modules package so that
# other classes can inherit from smartmodule and not be registered
if cls.__module__.split(".")[-2] == "modules":
_LOGGER.debug("Registering %s", cls)
cls.REGISTERED_MODULES[cls._module_name()] = cls
def _set_error(self, err: Exception | None):
def _set_error(self, err: Exception | None) -> None:
if err is None:
self._error_count = 0
self._last_update_error = None
@ -119,7 +119,7 @@ class SmartModule(Module):
return self._error_count >= self.DISABLE_AFTER_ERROR_COUNT
@classmethod
def _module_name(cls):
def _module_name(cls) -> str:
return getattr(cls, "NAME", cls.__name__)
@property
@ -127,7 +127,7 @@ class SmartModule(Module):
"""Name of the module."""
return self._module_name()
async def _post_update_hook(self): # noqa: B027
async def _post_update_hook(self) -> None: # noqa: B027
"""Perform actions after a device update.
Any modules overriding this should ensure that self.data is
@ -142,7 +142,7 @@ class SmartModule(Module):
"""
return {self.QUERY_GETTER_NAME: None}
async def call(self, method, params=None):
async def call(self, method: str, params: dict | None = None) -> dict:
"""Call a method.
Just a helper method.
@ -150,7 +150,7 @@ class SmartModule(Module):
return await self._device._query_helper(method, params)
@property
def data(self):
def data(self) -> dict[str, Any]:
"""Return response data for the module.
If the module performs only a single query, the resulting response is unwrapped.

View File

@ -72,7 +72,7 @@ class SmartProtocol(BaseProtocol):
)
self._redact_data = True
def get_smart_request(self, method, params=None) -> str:
def get_smart_request(self, method: str, params: dict | None = None) -> str:
"""Get a request message as a string."""
request = {
"method": method,
@ -289,8 +289,8 @@ class SmartProtocol(BaseProtocol):
return {smart_method: result}
async def _handle_response_lists(
self, response_result: dict[str, Any], method, retry_count
):
self, response_result: dict[str, Any], method: str, retry_count: int
) -> None:
if (
response_result is None
or isinstance(response_result, SmartErrorCode)
@ -325,7 +325,9 @@ class SmartProtocol(BaseProtocol):
break
response_result[response_list_name].extend(next_batch[response_list_name])
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
def _handle_response_error_code(
self, resp_dict: dict, method: str, raise_on_error: bool = True
) -> None:
error_code_raw = resp_dict.get("error_code")
try:
error_code = SmartErrorCode.from_int(error_code_raw)
@ -369,12 +371,12 @@ class _ChildProtocolWrapper(SmartProtocol):
device responses before returning to the caller.
"""
def __init__(self, device_id: str, base_protocol: SmartProtocol):
def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
self._device_id = device_id
self._protocol = base_protocol
self._transport = base_protocol._transport
def _get_method_and_params_for_request(self, request):
def _get_method_and_params_for_request(self, request: dict[str, Any] | str) -> Any:
"""Return payload for wrapping.
TODO: this does not support batches and requires refactoring in the future.

View File

@ -310,9 +310,7 @@ class FakeSmartTransport(BaseTransport):
}
return retval
raise NotImplementedError(
"Method %s not implemented for children" % child_method
)
raise NotImplementedError(f"Method {child_method} not implemented for children")
def _get_on_off_gradually_info(self, info, params):
if self.components["on_off_gradually"] == 1:

View File

@ -41,7 +41,7 @@ async def test_firmware_features(
await fw.check_latest_firmware()
if fw.supported_version < required_version:
pytest.skip("Feature %s requires newer version" % feature)
pytest.skip(f"Feature {feature} requires newer version")
prop = getattr(fw, prop_name)
assert isinstance(prop, type)

View File

@ -48,7 +48,7 @@ class XorTransport(BaseTransport):
self.loop: asyncio.AbstractEventLoop | None = None
@property
def default_port(self):
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT

View File

@ -139,10 +139,15 @@ select = [
"PT", # flake8-pytest-style
"LOG", # flake8-logging
"G", # flake8-logging-format
"ANN", # annotations
]
ignore = [
"D105", # Missing docstring in magic method
"D107", # Missing docstring in `__init__`
"ANN101", # Missing type annotation for `self`
"ANN102", # Missing type annotation for `cls` in classmethod
"ANN003", # Missing type annotation for `**kwargs`
"ANN401", # allow any
]
[tool.ruff.lint.pydocstyle]
@ -157,11 +162,21 @@ convention = "pep257"
"D104",
"S101", # allow asserts
"E501", # ignore line-too-longs
"ANN", # skip for now
]
"docs/source/conf.py" = [
"D100",
"D103",
]
# Temporary ANN disable
"kasa/cli/*.py" = [
"ANN",
]
# Temporary ANN disable
"devtools/*.py" = [
"ANN",
]
[tool.mypy]
warn_unused_configs = true # warns if overrides sections unused/mis-spelled