Enable ruff check for ANN (#1139)

This commit is contained in:
Teemu R. 2024-11-10 19:55:13 +01:00 committed by GitHub
parent 6b44fe6242
commit 66eb17057e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
89 changed files with 596 additions and 452 deletions

View File

@ -66,6 +66,6 @@ todo_include_todos = True
myst_heading_anchors = 3 myst_heading_anchors = 3
def setup(app): def setup(app): # noqa: ANN201,ANN001
# add copybutton to hide the >>> prompts, see https://github.com/readthedocs/sphinx_rtd_theme/issues/167 # add copybutton to hide the >>> prompts, see https://github.com/readthedocs/sphinx_rtd_theme/issues/167
app.add_js_file("copybutton.js") app.add_js_file("copybutton.js")

View File

@ -13,7 +13,7 @@ to be handled by the user of the library.
""" """
from importlib.metadata import version from importlib.metadata import version
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
from warnings import warn from warnings import warn
from kasa.credentials import Credentials from kasa.credentials import Credentials
@ -101,7 +101,7 @@ deprecated_classes = {
} }
def __getattr__(name): def __getattr__(name: str) -> Any:
if name in deprecated_names: if name in deprecated_names:
warn(f"{name} is deprecated", DeprecationWarning, stacklevel=2) warn(f"{name} is deprecated", DeprecationWarning, stacklevel=2)
return globals()[f"_deprecated_{name}"] return globals()[f"_deprecated_{name}"]
@ -117,7 +117,7 @@ def __getattr__(name):
) )
return new_class return new_class
if name in deprecated_classes: if name in deprecated_classes:
new_class = deprecated_classes[name] new_class = deprecated_classes[name] # type: ignore[assignment]
msg = f"{name} is deprecated, use {new_class.__name__} instead" msg = f"{name} is deprecated, use {new_class.__name__} instead"
warn(msg, DeprecationWarning, stacklevel=2) warn(msg, DeprecationWarning, stacklevel=2)
return new_class return new_class

View File

@ -146,7 +146,7 @@ class AesTransport(BaseTransport):
pw = base64.b64encode(credentials.password.encode()).decode() pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw return un, pw
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: def _handle_response_error_code(self, resp_dict: dict, msg: str) -> None:
error_code_raw = resp_dict.get("error_code") error_code_raw = resp_dict.get("error_code")
try: try:
error_code = SmartErrorCode.from_int(error_code_raw) error_code = SmartErrorCode.from_int(error_code_raw)
@ -191,14 +191,14 @@ class AesTransport(BaseTransport):
+ f"status code {status_code} to passthrough" + f"status code {status_code} to passthrough"
) )
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)
if TYPE_CHECKING: if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict) resp_dict = cast(Dict[str, Any], resp_dict)
assert self._encryption_session is not None assert self._encryption_session is not None
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)
raw_response: str = resp_dict["result"]["response"] raw_response: str = resp_dict["result"]["response"]
try: try:
@ -219,7 +219,7 @@ class AesTransport(BaseTransport):
) from ex ) from ex
return ret_val # type: ignore[return-value] return ret_val # type: ignore[return-value]
async def perform_login(self): async def perform_login(self) -> None:
"""Login to the device.""" """Login to the device."""
try: try:
await self.try_login(self._login_params) await self.try_login(self._login_params)
@ -324,11 +324,11 @@ class AesTransport(BaseTransport):
+ f"status code {status_code} to handshake" + f"status code {status_code} to handshake"
) )
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
if TYPE_CHECKING: if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict) resp_dict = cast(Dict[str, Any], resp_dict)
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
handshake_key = resp_dict["result"]["key"] handshake_key = resp_dict["result"]["key"]
if ( if (
@ -355,7 +355,7 @@ class AesTransport(BaseTransport):
_LOGGER.debug("Handshake with %s complete", self._host) _LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self): def _handshake_session_expired(self) -> bool:
"""Return true if session has expired.""" """Return true if session has expired."""
return ( return (
self._session_expire_at is None self._session_expire_at is None
@ -394,7 +394,9 @@ class AesEncyptionSession:
"""Class for an AES encryption session.""" """Class for an AES encryption session."""
@staticmethod @staticmethod
def create_from_keypair(handshake_key: str, keypair: KeyPair): def create_from_keypair(
handshake_key: str, keypair: KeyPair
) -> AesEncyptionSession:
"""Create the encryption session.""" """Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode()) handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())
@ -404,11 +406,11 @@ class AesEncyptionSession:
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:]) return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
def __init__(self, key, iv): def __init__(self, key: bytes, iv: bytes) -> None:
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv)) self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size) self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
def encrypt(self, data) -> bytes: def encrypt(self, data: bytes) -> bytes:
"""Encrypt the message.""" """Encrypt the message."""
encryptor = self.cipher.encryptor() encryptor = self.cipher.encryptor()
padder = self.padding_strategy.padder() padder = self.padding_strategy.padder()
@ -416,7 +418,7 @@ class AesEncyptionSession:
encrypted = encryptor.update(padded_data) + encryptor.finalize() encrypted = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(encrypted) return base64.b64encode(encrypted)
def decrypt(self, data) -> str: def decrypt(self, data: str | bytes) -> str:
"""Decrypt the message.""" """Decrypt the message."""
decryptor = self.cipher.decryptor() decryptor = self.cipher.decryptor()
unpadder = self.padding_strategy.unpadder() unpadder = self.padding_strategy.unpadder()
@ -429,14 +431,16 @@ class KeyPair:
"""Class for generating key pairs.""" """Class for generating key pairs."""
@staticmethod @staticmethod
def create_key_pair(key_size: int = 1024): def create_key_pair(key_size: int = 1024) -> KeyPair:
"""Create a key pair.""" """Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key() public_key = private_key.public_key()
return KeyPair(private_key, public_key) return KeyPair(private_key, public_key)
@staticmethod @staticmethod
def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str): def create_from_der_keys(
private_key_der_b64: str, public_key_der_b64: str
) -> KeyPair:
"""Create a key pair.""" """Create a key pair."""
key_bytes = base64.b64decode(private_key_der_b64.encode()) key_bytes = base64.b64decode(private_key_der_b64.encode())
private_key = cast( private_key = cast(
@ -449,7 +453,9 @@ class KeyPair:
return KeyPair(private_key, public_key) return KeyPair(private_key, public_key)
def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey): def __init__(
self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey
) -> None:
self.private_key = private_key self.private_key = private_key
self.public_key = public_key self.public_key = public_key
self.private_key_der_bytes = self.private_key.private_bytes( self.private_key_der_bytes = self.private_key.private_bytes(

View File

@ -7,7 +7,7 @@ import re
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from functools import singledispatch, update_wrapper, wraps from functools import singledispatch, update_wrapper, wraps
from typing import Final from typing import TYPE_CHECKING, Any, Callable, Final
import asyncclick as click import asyncclick as click
@ -37,7 +37,7 @@ except ImportError:
"""Strip rich formatting from messages.""" """Strip rich formatting from messages."""
@wraps(echo_func) @wraps(echo_func)
def wrapper(message=None, *args, **kwargs): def wrapper(message=None, *args, **kwargs) -> None:
if message is not None: if message is not None:
message = rich_formatting.sub("", message) message = rich_formatting.sub("", message)
echo_func(message, *args, **kwargs) echo_func(message, *args, **kwargs)
@ -47,20 +47,20 @@ except ImportError:
_echo = _strip_rich_formatting(click.echo) _echo = _strip_rich_formatting(click.echo)
def echo(*args, **kwargs): def echo(*args, **kwargs) -> None:
"""Print a message.""" """Print a message."""
ctx = click.get_current_context().find_root() ctx = click.get_current_context().find_root()
if "json" not in ctx.params or ctx.params["json"] is False: if "json" not in ctx.params or ctx.params["json"] is False:
_echo(*args, **kwargs) _echo(*args, **kwargs)
def error(msg: str): def error(msg: str) -> None:
"""Print an error and exit.""" """Print an error and exit."""
echo(f"[bold red]{msg}[/bold red]") echo(f"[bold red]{msg}[/bold red]")
sys.exit(1) sys.exit(1)
def json_formatter_cb(result, **kwargs): def json_formatter_cb(result: Any, **kwargs) -> None:
"""Format and output the result as JSON, if requested.""" """Format and output the result as JSON, if requested."""
if not kwargs.get("json"): if not kwargs.get("json"):
return return
@ -82,7 +82,7 @@ def json_formatter_cb(result, **kwargs):
print(json_content) print(json_content)
def pass_dev_or_child(wrapped_function): def pass_dev_or_child(wrapped_function: Callable) -> Callable:
"""Pass the device or child to the click command based on the child options.""" """Pass the device or child to the click command based on the child options."""
child_help = ( child_help = (
"Child ID or alias for controlling sub-devices. " "Child ID or alias for controlling sub-devices. "
@ -133,7 +133,10 @@ def pass_dev_or_child(wrapped_function):
async def _get_child_device( async def _get_child_device(
device: Device, child_option, child_index_option, info_command device: Device,
child_option: str | None,
child_index_option: int | None,
info_command: str | None,
) -> Device | None: ) -> Device | None:
def _list_children(): def _list_children():
return "\n".join( return "\n".join(
@ -178,11 +181,15 @@ async def _get_child_device(
f"{child_option} children are:\n{_list_children()}" f"{child_option} children are:\n{_list_children()}"
) )
if TYPE_CHECKING:
assert isinstance(child_index_option, int)
if child_index_option + 1 > len(device.children) or child_index_option < 0: if child_index_option + 1 > len(device.children) or child_index_option < 0:
error( error(
f"Invalid index {child_index_option}, " f"Invalid index {child_index_option}, "
f"device has {len(device.children)} children" f"device has {len(device.children)} children"
) )
child_by_index = device.children[child_index_option] child_by_index = device.children[child_index_option]
echo(f"Targeting child device {child_by_index.alias}") echo(f"Targeting child device {child_by_index.alias}")
return child_by_index return child_by_index
@ -195,7 +202,7 @@ def CatchAllExceptions(cls):
https://stackoverflow.com/questions/52213375 https://stackoverflow.com/questions/52213375
""" """
def _handle_exception(debug, exc): def _handle_exception(debug, exc) -> None:
if isinstance(exc, click.ClickException): if isinstance(exc, click.ClickException):
raise raise
# Handle exit request from click. # Handle exit request from click.

View File

@ -22,7 +22,7 @@ from .common import (
@click.group() @click.group()
@pass_dev_or_child @pass_dev_or_child
def device(dev): def device(dev) -> None:
"""Commands to control basic device settings.""" """Commands to control basic device settings."""

View File

@ -36,7 +36,7 @@ async def detail(ctx):
auth_failed = [] auth_failed = []
sem = asyncio.Semaphore() sem = asyncio.Semaphore()
async def print_unsupported(unsupported_exception: UnsupportedDeviceError): async def print_unsupported(unsupported_exception: UnsupportedDeviceError) -> None:
unsupported.append(unsupported_exception) unsupported.append(unsupported_exception)
async with sem: async with sem:
if unsupported_exception.discovery_result: if unsupported_exception.discovery_result:
@ -50,7 +50,7 @@ async def detail(ctx):
from .device import state from .device import state
async def print_discovered(dev: Device): async def print_discovered(dev: Device) -> None:
async with sem: async with sem:
try: try:
await dev.update() await dev.update()
@ -189,7 +189,7 @@ async def config(ctx):
error(f"Unable to connect to {host}") error(f"Unable to connect to {host}")
def _echo_dictionary(discovery_info: dict): def _echo_dictionary(discovery_info: dict) -> None:
echo("\t[bold]== Discovery information ==[/bold]") echo("\t[bold]== Discovery information ==[/bold]")
for key, value in discovery_info.items(): for key, value in discovery_info.items():
key_name = " ".join(x.capitalize() or "_" for x in key.split("_")) key_name = " ".join(x.capitalize() or "_" for x in key.split("_"))
@ -197,7 +197,7 @@ def _echo_dictionary(discovery_info: dict):
echo(f"\t{key_name_and_spaces}{value}") echo(f"\t{key_name_and_spaces}{value}")
def _echo_discovery_info(discovery_info): def _echo_discovery_info(discovery_info) -> None:
# We don't have discovery info when all connection params are passed manually # We don't have discovery info when all connection params are passed manually
if discovery_info is None: if discovery_info is None:
return return

View File

@ -24,7 +24,7 @@ def _echo_features(
category: Feature.Category | None = None, category: Feature.Category | None = None,
verbose: bool = False, verbose: bool = False,
indent: str = "\t", indent: str = "\t",
): ) -> None:
"""Print out a listing of features and their values.""" """Print out a listing of features and their values."""
if category is not None: if category is not None:
features = { features = {
@ -43,7 +43,9 @@ def _echo_features(
echo(f"{indent}{feat.name} ({feat.id}): [red]got exception ({ex})[/red]") echo(f"{indent}{feat.name} ({feat.id}): [red]got exception ({ex})[/red]")
def _echo_all_features(features, *, verbose=False, title_prefix=None, indent=""): def _echo_all_features(
features, *, verbose=False, title_prefix=None, indent=""
) -> None:
"""Print out all features by category.""" """Print out all features by category."""
if title_prefix is not None: if title_prefix is not None:
echo(f"[bold]\n{indent}== {title_prefix} ==[/bold]") echo(f"[bold]\n{indent}== {title_prefix} ==[/bold]")

View File

@ -3,6 +3,8 @@
Taken from the click help files. Taken from the click help files.
""" """
from __future__ import annotations
import importlib import importlib
import asyncclick as click import asyncclick as click
@ -11,7 +13,7 @@ import asyncclick as click
class LazyGroup(click.Group): class LazyGroup(click.Group):
"""Lazy group class.""" """Lazy group class."""
def __init__(self, *args, lazy_subcommands=None, **kwargs): def __init__(self, *args, lazy_subcommands=None, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# lazy_subcommands is a map of the form: # lazy_subcommands is a map of the form:
# #
@ -31,9 +33,9 @@ class LazyGroup(click.Group):
return self._lazy_load(cmd_name) return self._lazy_load(cmd_name)
return super().get_command(ctx, cmd_name) return super().get_command(ctx, cmd_name)
def format_commands(self, ctx, formatter): def format_commands(self, ctx, formatter) -> None:
"""Format the top level help output.""" """Format the top level help output."""
sections = {} sections: dict[str, list] = {}
for cmd, parent in self.lazy_subcommands.items(): for cmd, parent in self.lazy_subcommands.items():
sections.setdefault(parent, []) sections.setdefault(parent, [])
cmd_obj = self.get_command(ctx, cmd) cmd_obj = self.get_command(ctx, cmd)

View File

@ -15,7 +15,7 @@ from .common import echo, error, pass_dev_or_child
@click.group() @click.group()
@pass_dev_or_child @pass_dev_or_child
def light(dev): def light(dev) -> None:
"""Commands to control light settings.""" """Commands to control light settings."""

View File

@ -43,7 +43,7 @@ ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in DeviceEncryptionType]
DEFAULT_TARGET = "255.255.255.255" DEFAULT_TARGET = "255.255.255.255"
def _legacy_type_to_class(_type): def _legacy_type_to_class(_type: str) -> Any:
from kasa.iot import ( from kasa.iot import (
IotBulb, IotBulb,
IotDimmer, IotDimmer,
@ -396,9 +396,9 @@ async def cli(
@cli.command() @cli.command()
@pass_dev_or_child @pass_dev_or_child
async def shell(dev: Device): async def shell(dev: Device) -> None:
"""Open interactive shell.""" """Open interactive shell."""
echo("Opening shell for %s" % dev) echo(f"Opening shell for {dev}")
from ptpython.repl import embed from ptpython.repl import embed
logging.getLogger("parso").setLevel(logging.WARNING) # prompt parsing logging.getLogger("parso").setLevel(logging.WARNING) # prompt parsing

View File

@ -14,7 +14,7 @@ from .common import (
@click.group() @click.group()
@pass_dev @pass_dev
async def schedule(dev): async def schedule(dev) -> None:
"""Scheduling commands.""" """Scheduling commands."""

View File

@ -23,7 +23,7 @@ from .common import (
@click.group(invoke_without_command=True) @click.group(invoke_without_command=True)
@click.pass_context @click.pass_context
async def time(ctx: click.Context): async def time(ctx: click.Context) -> None:
"""Get and set time.""" """Get and set time."""
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
await ctx.invoke(time_get) await ctx.invoke(time_get)

View File

@ -78,13 +78,13 @@ async def energy(dev: Device, year, month, erase):
else: else:
emeter_status = dev.emeter_realtime emeter_status = dev.emeter_realtime
echo("Current: %s A" % emeter_status["current"]) echo("Current: {} A".format(emeter_status["current"]))
echo("Voltage: %s V" % emeter_status["voltage"]) echo("Voltage: {} V".format(emeter_status["voltage"]))
echo("Power: %s W" % emeter_status["power"]) echo("Power: {} W".format(emeter_status["power"]))
echo("Total consumption: %s kWh" % emeter_status["total"]) echo("Total consumption: {} kWh".format(emeter_status["total"]))
echo("Today: %s kWh" % dev.emeter_today) echo(f"Today: {dev.emeter_today} kWh")
echo("This month: %s kWh" % dev.emeter_this_month) echo(f"This month: {dev.emeter_this_month} kWh")
return emeter_status return emeter_status
@ -122,8 +122,8 @@ async def usage(dev: Device, year, month, erase):
usage_data = await usage.get_daystat(year=month.year, month=month.month) usage_data = await usage.get_daystat(year=month.year, month=month.month)
else: else:
# Call with no argument outputs summary data and returns # Call with no argument outputs summary data and returns
echo("Today: %s minutes" % usage.usage_today) echo(f"Today: {usage.usage_today} minutes")
echo("This month: %s minutes" % usage.usage_this_month) echo(f"This month: {usage.usage_this_month} minutes")
return usage return usage

View File

@ -16,7 +16,7 @@ from .common import (
@click.group() @click.group()
@pass_dev @pass_dev
def wifi(dev): def wifi(dev) -> None:
"""Commands to control wifi settings.""" """Commands to control wifi settings."""

View File

@ -234,10 +234,10 @@ class Device(ABC):
return await connect(host=host, config=config) # type: ignore[arg-type] return await connect(host=host, config=config) # type: ignore[arg-type]
@abstractmethod @abstractmethod
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True) -> None:
"""Update the device.""" """Update the device."""
async def disconnect(self): async def disconnect(self) -> None:
"""Disconnect and close any underlying connection resources.""" """Disconnect and close any underlying connection resources."""
await self.protocol.close() await self.protocol.close()
@ -257,15 +257,15 @@ class Device(ABC):
return not self.is_on return not self.is_on
@abstractmethod @abstractmethod
async def turn_on(self, **kwargs) -> dict | None: async def turn_on(self, **kwargs) -> dict:
"""Turn on the device.""" """Turn on the device."""
@abstractmethod @abstractmethod
async def turn_off(self, **kwargs) -> dict | None: async def turn_off(self, **kwargs) -> dict:
"""Turn off the device.""" """Turn off the device."""
@abstractmethod @abstractmethod
async def set_state(self, on: bool): async def set_state(self, on: bool) -> dict:
"""Set the device state to *on*. """Set the device state to *on*.
This allows turning the device on and off. This allows turning the device on and off.
@ -278,7 +278,7 @@ class Device(ABC):
return self.protocol._transport._host return self.protocol._transport._host
@host.setter @host.setter
def host(self, value): def host(self, value: str) -> None:
"""Set the device host. """Set the device host.
Generally used by discovery to set the hostname after ip discovery. Generally used by discovery to set the hostname after ip discovery.
@ -307,7 +307,7 @@ class Device(ABC):
return self._device_type return self._device_type
@abstractmethod @abstractmethod
def update_from_discover_info(self, info): def update_from_discover_info(self, info: dict) -> None:
"""Update state from info from the discover call.""" """Update state from info from the discover call."""
@property @property
@ -325,7 +325,7 @@ class Device(ABC):
def alias(self) -> str | None: def alias(self) -> str | None:
"""Returns the device alias or nickname.""" """Returns the device alias or nickname."""
async def _raw_query(self, request: str | dict) -> Any: async def _raw_query(self, request: str | dict) -> dict:
"""Send a raw query to the device.""" """Send a raw query to the device."""
return await self.protocol.query(request=request) return await self.protocol.query(request=request)
@ -407,7 +407,7 @@ class Device(ABC):
@property @property
@abstractmethod @abstractmethod
def internal_state(self) -> Any: def internal_state(self) -> dict:
"""Return all the internal state data.""" """Return all the internal state data."""
@property @property
@ -420,10 +420,10 @@ class Device(ABC):
"""Return the list of supported features.""" """Return the list of supported features."""
return self._features return self._features
def _add_feature(self, feature: Feature): def _add_feature(self, feature: Feature) -> None:
"""Add a new feature to the device.""" """Add a new feature to the device."""
if feature.id in self._features: if feature.id in self._features:
raise KasaException("Duplicate feature id %s" % feature.id) raise KasaException(f"Duplicate feature id {feature.id}")
assert feature.id is not None # TODO: hack for typing # noqa: S101 assert feature.id is not None # TODO: hack for typing # noqa: S101
self._features[feature.id] = feature self._features[feature.id] = feature
@ -446,11 +446,13 @@ class Device(ABC):
"""Scan for available wifi networks.""" """Scan for available wifi networks."""
@abstractmethod @abstractmethod
async def wifi_join(self, ssid: str, password: str, keytype: str = "wpa2_psk"): async def wifi_join(
self, ssid: str, password: str, keytype: str = "wpa2_psk"
) -> dict:
"""Join the given wifi network.""" """Join the given wifi network."""
@abstractmethod @abstractmethod
async def set_alias(self, alias: str): async def set_alias(self, alias: str) -> dict:
"""Set the device name (alias).""" """Set the device name (alias)."""
@abstractmethod @abstractmethod
@ -468,7 +470,7 @@ class Device(ABC):
Note, this does not downgrade the firmware. Note, this does not downgrade the firmware.
""" """
def __repr__(self): def __repr__(self) -> str:
update_needed = " - update() needed" if not self._last_update else "" update_needed = " - update() needed" if not self._last_update else ""
return ( return (
f"<{self.device_type} at {self.host} -" f"<{self.device_type} at {self.host} -"
@ -486,7 +488,9 @@ class Device(ABC):
"is_strip_socket": (None, DeviceType.StripSocket), "is_strip_socket": (None, DeviceType.StripSocket),
} }
def _get_replacing_attr(self, module_name: ModuleName, *attrs): def _get_replacing_attr(
self, module_name: ModuleName | None, *attrs: Any
) -> str | None:
# If module name is None check self # If module name is None check self
if not module_name: if not module_name:
check = self check = self
@ -540,7 +544,7 @@ class Device(ABC):
"supported_modules": (None, ["modules"]), "supported_modules": (None, ["modules"]),
} }
def __getattr__(self, name): def __getattr__(self, name: str) -> Any:
# is_device_type # is_device_type
if dep_device_type_attr := self._deprecated_device_type_attributes.get(name): if dep_device_type_attr := self._deprecated_device_type_attributes.get(name):
module = dep_device_type_attr[0] module = dep_device_type_attr[0]

View File

@ -83,7 +83,7 @@ async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> Device:
if debug_enabled: if debug_enabled:
start_time = time.perf_counter() start_time = time.perf_counter()
def _perf_log(has_params, perf_type): def _perf_log(has_params: bool, perf_type: str) -> None:
nonlocal start_time nonlocal start_time
if debug_enabled: if debug_enabled:
end_time = time.perf_counter() end_time = time.perf_counter()
@ -150,7 +150,7 @@ def _get_device_type_from_sys_info(info: dict[str, Any]) -> DeviceType:
return DeviceType.LightStrip return DeviceType.LightStrip
return DeviceType.Bulb return DeviceType.Bulb
raise UnsupportedDeviceError("Unknown device type: %s" % type_) raise UnsupportedDeviceError(f"Unknown device type: {type_}")
def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]: def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:

View File

@ -75,14 +75,14 @@ class DeviceFamily(Enum):
SmartIpCamera = "SMART.IPCAMERA" SmartIpCamera = "SMART.IPCAMERA"
def _dataclass_from_dict(klass, in_val): def _dataclass_from_dict(klass: Any, in_val: dict) -> Any:
if is_dataclass(klass): if is_dataclass(klass):
fieldtypes = {f.name: f.type for f in fields(klass)} fieldtypes = {f.name: f.type for f in fields(klass)}
val = {} val = {}
for dict_key in in_val: for dict_key in in_val:
if dict_key in fieldtypes: if dict_key in fieldtypes:
if hasattr(fieldtypes[dict_key], "from_dict"): if hasattr(fieldtypes[dict_key], "from_dict"):
val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) # type: ignore[union-attr]
else: else:
val[dict_key] = _dataclass_from_dict( val[dict_key] = _dataclass_from_dict(
fieldtypes[dict_key], in_val[dict_key] fieldtypes[dict_key], in_val[dict_key]
@ -91,12 +91,12 @@ def _dataclass_from_dict(klass, in_val):
raise KasaException( raise KasaException(
f"Cannot create dataclass from dict, unknown key: {dict_key}" f"Cannot create dataclass from dict, unknown key: {dict_key}"
) )
return klass(**val) return klass(**val) # type: ignore[operator]
else: else:
return in_val return in_val
def _dataclass_to_dict(in_val): def _dataclass_to_dict(in_val: Any) -> dict:
fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare} fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare}
out_val = {} out_val = {}
for field_name in fieldtypes: for field_name in fieldtypes:
@ -210,7 +210,7 @@ class DeviceConfig:
aes_keys: Optional[KeyPairDict] = None aes_keys: Optional[KeyPairDict] = None
def __post_init__(self): def __post_init__(self) -> None:
if self.connection_type is None: if self.connection_type is None:
self.connection_type = DeviceConnectionParameters( self.connection_type = DeviceConnectionParameters(
DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor

View File

@ -89,9 +89,19 @@ import logging
import secrets import secrets
import socket import socket
import struct import struct
from collections.abc import Awaitable from asyncio.transports import DatagramTransport
from pprint import pformat as pf from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
NamedTuple,
Optional,
Type,
cast,
)
from aiohttp import ClientSession from aiohttp import ClientSession
@ -140,8 +150,8 @@ class ConnectAttempt(NamedTuple):
device: type device: type
OnDiscoveredCallable = Callable[[Device], Awaitable[None]] OnDiscoveredCallable = Callable[[Device], Coroutine]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]] OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Coroutine]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None] OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = Dict[str, Device] DeviceDict = Dict[str, Device]
@ -156,7 +166,7 @@ class _AesDiscoveryQuery:
keypair: KeyPair | None = None keypair: KeyPair | None = None
@classmethod @classmethod
def generate_query(cls): def generate_query(cls) -> bytearray:
if not cls.keypair: if not cls.keypair:
cls.keypair = KeyPair.create_key_pair(key_size=2048) cls.keypair = KeyPair.create_key_pair(key_size=2048)
secret = secrets.token_bytes(4) secret = secrets.token_bytes(4)
@ -215,7 +225,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
credentials: Credentials | None = None, credentials: Credentials | None = None,
timeout: int | None = None, timeout: int | None = None,
) -> None: ) -> None:
self.transport = None self.transport: DatagramTransport | None = None
self.discovery_packets = discovery_packets self.discovery_packets = discovery_packets
self.interface = interface self.interface = interface
self.on_discovered = on_discovered self.on_discovered = on_discovered
@ -239,16 +249,19 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.target_discovered: bool = False self.target_discovered: bool = False
self._started_event = asyncio.Event() self._started_event = asyncio.Event()
def _run_callback_task(self, coro): def _run_callback_task(self, coro: Coroutine) -> None:
task = asyncio.create_task(coro) task: asyncio.Task = asyncio.create_task(coro)
self.callback_tasks.append(task) self.callback_tasks.append(task)
async def wait_for_discovery_to_complete(self): async def wait_for_discovery_to_complete(self) -> None:
"""Wait for the discovery task to complete.""" """Wait for the discovery task to complete."""
# Give some time for connection_made event to be received # Give some time for connection_made event to be received
async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT): async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT):
await self._started_event.wait() await self._started_event.wait()
try: try:
if TYPE_CHECKING:
assert isinstance(self.discover_task, asyncio.Task)
await self.discover_task await self.discover_task
except asyncio.CancelledError: except asyncio.CancelledError:
# if target_discovered then cancel was called internally # if target_discovered then cancel was called internally
@ -257,11 +270,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
# Wait for any pending callbacks to complete # Wait for any pending callbacks to complete
await asyncio.gather(*self.callback_tasks) await asyncio.gather(*self.callback_tasks)
def connection_made(self, transport) -> None: def connection_made(self, transport: DatagramTransport) -> None: # type: ignore[override]
"""Set socket options for broadcasting.""" """Set socket options for broadcasting."""
self.transport = transport self.transport = cast(DatagramTransport, transport)
sock = transport.get_extra_info("socket") sock = self.transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
try: try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -292,7 +305,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore
await asyncio.sleep(sleep_between_packets) await asyncio.sleep(sleep_between_packets)
def datagram_received(self, data, addr) -> None: def datagram_received(
self,
data: bytes,
addr: tuple[str, int],
) -> None:
"""Handle discovery responses.""" """Handle discovery responses."""
if TYPE_CHECKING: if TYPE_CHECKING:
assert _AesDiscoveryQuery.keypair assert _AesDiscoveryQuery.keypair
@ -338,18 +355,18 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self._handle_discovered_event() self._handle_discovered_event()
def _handle_discovered_event(self): def _handle_discovered_event(self) -> None:
"""If target is in seen_hosts cancel discover_task.""" """If target is in seen_hosts cancel discover_task."""
if self.target in self.seen_hosts: if self.target in self.seen_hosts:
self.target_discovered = True self.target_discovered = True
if self.discover_task: if self.discover_task:
self.discover_task.cancel() self.discover_task.cancel()
def error_received(self, ex): def error_received(self, ex: Exception) -> None:
"""Handle asyncio.Protocol errors.""" """Handle asyncio.Protocol errors."""
_LOGGER.error("Got error: %s", ex) _LOGGER.error("Got error: %s", ex)
def connection_lost(self, ex): # pragma: no cover def connection_lost(self, ex: Exception | None) -> None: # pragma: no cover
"""Cancel the discover task if running.""" """Cancel the discover task if running."""
if self.discover_task: if self.discover_task:
self.discover_task.cancel() self.discover_task.cancel()
@ -372,17 +389,17 @@ class Discover:
@staticmethod @staticmethod
async def discover( async def discover(
*, *,
target="255.255.255.255", target: str = "255.255.255.255",
on_discovered=None, on_discovered: OnDiscoveredCallable | None = None,
discovery_timeout=5, discovery_timeout: int = 5,
discovery_packets=3, discovery_packets: int = 3,
interface=None, interface: str | None = None,
on_unsupported=None, on_unsupported: OnUnsupportedCallable | None = None,
credentials=None, credentials: Credentials | None = None,
username: str | None = None, username: str | None = None,
password: str | None = None, password: str | None = None,
port=None, port: int | None = None,
timeout=None, timeout: int | None = None,
) -> DeviceDict: ) -> DeviceDict:
"""Discover supported devices. """Discover supported devices.
@ -636,7 +653,7 @@ class Discover:
) )
if not dev_class: if not dev_class:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
"Unknown device type: %s" % discovery_result.device_type, f"Unknown device type: {discovery_result.device_type}",
discovery_result=info, discovery_result=info,
) )
return dev_class return dev_class

View File

@ -49,13 +49,13 @@ class EmeterStatus(dict):
except ValueError: except ValueError:
return None return None
def __repr__(self): def __repr__(self) -> str:
return ( return (
f"<EmeterStatus power={self.power} voltage={self.voltage}" f"<EmeterStatus power={self.power} voltage={self.voltage}"
f" current={self.current} total={self.total}>" f" current={self.current} total={self.total}>"
) )
def __getitem__(self, item): def __getitem__(self, item: str) -> float | None:
"""Return value in wanted units.""" """Return value in wanted units."""
valid_keys = [ valid_keys = [
"voltage_mv", "voltage_mv",

View File

@ -15,10 +15,10 @@ class KasaException(Exception):
class TimeoutError(KasaException, _asyncioTimeoutError): class TimeoutError(KasaException, _asyncioTimeoutError):
"""Timeout exception for device errors.""" """Timeout exception for device errors."""
def __repr__(self): def __repr__(self) -> str:
return KasaException.__repr__(self) return KasaException.__repr__(self)
def __str__(self): def __str__(self) -> str:
return KasaException.__str__(self) return KasaException.__str__(self)
@ -42,11 +42,11 @@ class DeviceError(KasaException):
self.error_code: SmartErrorCode | None = kwargs.get("error_code", None) self.error_code: SmartErrorCode | None = kwargs.get("error_code", None)
super().__init__(*args) super().__init__(*args)
def __repr__(self): def __repr__(self) -> str:
err_code = self.error_code.__repr__() if self.error_code else "" err_code = self.error_code.__repr__() if self.error_code else ""
return f"{self.__class__.__name__}({err_code})" return f"{self.__class__.__name__}({err_code})"
def __str__(self): def __str__(self) -> str:
err_code = f" (error_code={self.error_code.name})" if self.error_code else "" err_code = f" (error_code={self.error_code.name})" if self.error_code else ""
return super().__str__() + err_code return super().__str__() + err_code
@ -62,7 +62,7 @@ class _RetryableError(DeviceError):
class SmartErrorCode(IntEnum): class SmartErrorCode(IntEnum):
"""Enum for SMART Error Codes.""" """Enum for SMART Error Codes."""
def __str__(self): def __str__(self) -> str:
return f"{self.name}({self.value})" return f"{self.name}({self.value})"
@staticmethod @staticmethod

View File

@ -12,12 +12,12 @@ class Experimental:
ENV_VAR = "KASA_EXPERIMENTAL" ENV_VAR = "KASA_EXPERIMENTAL"
@classmethod @classmethod
def set_enabled(cls, enabled): def set_enabled(cls, enabled: bool) -> None:
"""Set the enabled value.""" """Set the enabled value."""
cls._enabled = enabled cls._enabled = enabled
@classmethod @classmethod
def enabled(cls): def enabled(cls) -> bool:
"""Get the enabled value.""" """Get the enabled value."""
if cls._enabled is not None: if cls._enabled is not None:
return cls._enabled return cls._enabled

View File

@ -50,11 +50,13 @@ class SmartCameraProtocol(SmartProtocol):
"""Class for SmartCamera Protocol.""" """Class for SmartCamera Protocol."""
async def _handle_response_lists( async def _handle_response_lists(
self, response_result: dict[str, Any], method, retry_count self, response_result: dict[str, Any], method: str, retry_count: int
): ) -> None:
pass pass
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True): def _handle_response_error_code(
self, resp_dict: dict, method: str, raise_on_error: bool = True
) -> None:
error_code_raw = resp_dict.get("error_code") error_code_raw = resp_dict.get("error_code")
try: try:
error_code = SmartErrorCode.from_int(error_code_raw) error_code = SmartErrorCode.from_int(error_code_raw)
@ -203,7 +205,7 @@ class _ChildCameraProtocolWrapper(SmartProtocol):
device responses before returning to the caller. device responses before returning to the caller.
""" """
def __init__(self, device_id: str, base_protocol: SmartProtocol): def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
self._device_id = device_id self._device_id = device_id
self._protocol = base_protocol self._protocol = base_protocol
self._transport = base_protocol._transport self._transport = base_protocol._transport

View File

@ -256,7 +256,9 @@ class SslAesTransport(BaseTransport):
return ret_val # type: ignore[return-value] return ret_val # type: ignore[return-value]
@staticmethod @staticmethod
def generate_confirm_hash(local_nonce, server_nonce, pwd_hash): def generate_confirm_hash(
local_nonce: str, server_nonce: str, pwd_hash: str
) -> str:
"""Generate an auth hash for the protocol on the supplied credentials.""" """Generate an auth hash for the protocol on the supplied credentials."""
expected_confirm_bytes = _sha256_hash( expected_confirm_bytes = _sha256_hash(
local_nonce.encode() + pwd_hash.encode() + server_nonce.encode() local_nonce.encode() + pwd_hash.encode() + server_nonce.encode()
@ -264,7 +266,9 @@ class SslAesTransport(BaseTransport):
return expected_confirm_bytes + server_nonce + local_nonce return expected_confirm_bytes + server_nonce + local_nonce
@staticmethod @staticmethod
def generate_digest_password(local_nonce, server_nonce, pwd_hash): def generate_digest_password(
local_nonce: str, server_nonce: str, pwd_hash: str
) -> str:
"""Generate an auth hash for the protocol on the supplied credentials.""" """Generate an auth hash for the protocol on the supplied credentials."""
digest_password_hash = _sha256_hash( digest_password_hash = _sha256_hash(
pwd_hash.encode() + local_nonce.encode() + server_nonce.encode() pwd_hash.encode() + local_nonce.encode() + server_nonce.encode()
@ -275,7 +279,7 @@ class SslAesTransport(BaseTransport):
@staticmethod @staticmethod
def generate_encryption_token( def generate_encryption_token(
token_type, local_nonce, server_nonce, pwd_hash token_type: str, local_nonce: str, server_nonce: str, pwd_hash: str
) -> bytes: ) -> bytes:
"""Generate encryption token.""" """Generate encryption token."""
hashedKey = _sha256_hash( hashedKey = _sha256_hash(
@ -302,7 +306,9 @@ class SslAesTransport(BaseTransport):
local_nonce, server_nonce, pwd_hash = await self.perform_handshake1() local_nonce, server_nonce, pwd_hash = await self.perform_handshake1()
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash) await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
async def perform_handshake2(self, local_nonce, server_nonce, pwd_hash) -> None: async def perform_handshake2(
self, local_nonce: str, server_nonce: str, pwd_hash: str
) -> None:
"""Perform the handshake.""" """Perform the handshake."""
_LOGGER.debug("Performing handshake2 ...") _LOGGER.debug("Performing handshake2 ...")
digest_password = self.generate_digest_password( digest_password = self.generate_digest_password(

View File

@ -162,7 +162,7 @@ class Feature:
#: If set, this property will be used to get *choices*. #: If set, this property will be used to get *choices*.
choices_getter: str | Callable[[], list[str]] | None = None choices_getter: str | Callable[[], list[str]] | None = None
def __post_init__(self): def __post_init__(self) -> None:
"""Handle late-binding of members.""" """Handle late-binding of members."""
# Populate minimum & maximum values, if range_getter is given # Populate minimum & maximum values, if range_getter is given
self._container = self.container if self.container is not None else self.device self._container = self.container if self.container is not None else self.device
@ -188,7 +188,7 @@ class Feature:
f"Read-only feat defines attribute_setter: {self.name} ({self.id}):" f"Read-only feat defines attribute_setter: {self.name} ({self.id}):"
) )
def _get_property_value(self, getter): def _get_property_value(self, getter: str | Callable | None) -> Any:
if getter is None: if getter is None:
return None return None
if isinstance(getter, str): if isinstance(getter, str):
@ -227,7 +227,7 @@ class Feature:
return 0 return 0
@property @property
def value(self): def value(self) -> int | float | bool | str | Enum | None:
"""Return the current value.""" """Return the current value."""
if self.type == Feature.Type.Action: if self.type == Feature.Type.Action:
return "<Action>" return "<Action>"
@ -264,7 +264,7 @@ class Feature:
return await getattr(container, self.attribute_setter)(value) return await getattr(container, self.attribute_setter)(value)
def __repr__(self): def __repr__(self) -> str:
try: try:
value = self.value value = self.value
choices = self.choices choices = self.choices
@ -286,8 +286,8 @@ class Feature:
value = " ".join( value = " ".join(
[f"*{choice}*" if choice == value else choice for choice in choices] [f"*{choice}*" if choice == value else choice for choice in choices]
) )
if self.precision_hint is not None and value is not None: if self.precision_hint is not None and isinstance(value, float):
value = round(self.value, self.precision_hint) value = round(value, self.precision_hint)
s = f"{self.name} ({self.id}): {value}" s = f"{self.name} ({self.id}): {value}"
if self.unit is not None: if self.unit is not None:

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import ssl
import time import time
from typing import Any, Dict from typing import Any, Dict
@ -64,7 +65,7 @@ class HttpClient:
json: dict | Any | None = None, json: dict | Any | None = None,
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
cookies_dict: dict[str, str] | None = None, cookies_dict: dict[str, str] | None = None,
ssl=False, ssl: ssl.SSLContext | bool = False,
) -> tuple[int, dict | bytes | None]: ) -> tuple[int, dict | bytes | None]:
"""Send an http post request to the device. """Send an http post request to the device.

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import IntFlag, auto from enum import IntFlag, auto
from typing import Any
from warnings import warn from warnings import warn
from ..emeterstatus import EmeterStatus from ..emeterstatus import EmeterStatus
@ -31,7 +32,7 @@ class Energy(Module, ABC):
"""Return True if module supports the feature.""" """Return True if module supports the feature."""
return module_feature in self._supported return module_feature in self._supported
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
device = self._device device = self._device
self._add_feature( self._add_feature(
@ -151,22 +152,26 @@ class Energy(Module, ABC):
"""Get the current voltage in V.""" """Get the current voltage in V."""
@abstractmethod @abstractmethod
async def get_status(self): async def get_status(self) -> EmeterStatus:
"""Return real-time statistics.""" """Return real-time statistics."""
@abstractmethod @abstractmethod
async def erase_stats(self): async def erase_stats(self) -> dict:
"""Erase all stats.""" """Erase all stats."""
@abstractmethod @abstractmethod
async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict: async def get_daily_stats(
self, *, year: int | None = None, month: int | None = None, kwh: bool = True
) -> dict:
"""Return daily stats for the given year & month. """Return daily stats for the given year & month.
The return value is a dictionary of {day: energy, ...}. The return value is a dictionary of {day: energy, ...}.
""" """
@abstractmethod @abstractmethod
async def get_monthly_stats(self, *, year=None, kwh=True) -> dict: async def get_monthly_stats(
self, *, year: int | None = None, kwh: bool = True
) -> dict:
"""Return monthly stats for the given year.""" """Return monthly stats for the given year."""
_deprecated_attributes = { _deprecated_attributes = {
@ -179,7 +184,7 @@ class Energy(Module, ABC):
"get_monthstat": "get_monthly_stats", "get_monthstat": "get_monthly_stats",
} }
def __getattr__(self, name): def __getattr__(self, name: str) -> Any:
if attr := self._deprecated_attributes.get(name): if attr := self._deprecated_attributes.get(name):
msg = f"{name} is deprecated, use {attr} instead" msg = f"{name} is deprecated, use {attr} instead"
warn(msg, DeprecationWarning, stacklevel=2) warn(msg, DeprecationWarning, stacklevel=2)

View File

@ -16,5 +16,5 @@ class Fan(Module, ABC):
"""Return fan speed level.""" """Return fan speed level."""
@abstractmethod @abstractmethod
async def set_fan_speed_level(self, level: int): async def set_fan_speed_level(self, level: int) -> dict:
"""Set fan speed level.""" """Set fan speed level."""

View File

@ -11,7 +11,7 @@ from ..module import Module
class Led(Module, ABC): class Led(Module, ABC):
"""Base interface to represent a LED module.""" """Base interface to represent a LED module."""
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
device = self._device device = self._device
self._add_feature( self._add_feature(
@ -34,5 +34,5 @@ class Led(Module, ABC):
"""Return current led status.""" """Return current led status."""
@abstractmethod @abstractmethod
async def set_led(self, enable: bool) -> None: async def set_led(self, enable: bool) -> dict:
"""Set led.""" """Set led."""

View File

@ -166,7 +166,7 @@ class Light(Module, ABC):
@abstractmethod @abstractmethod
async def set_color_temp( async def set_color_temp(
self, temp: int, *, brightness=None, transition: int | None = None self, temp: int, *, brightness: int | None = None, transition: int | None = None
) -> dict: ) -> dict:
"""Set the color temperature of the device in kelvin. """Set the color temperature of the device in kelvin.

View File

@ -53,7 +53,7 @@ class LightEffect(Module, ABC):
LIGHT_EFFECTS_OFF = "Off" LIGHT_EFFECTS_OFF = "Off"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
device = self._device device = self._device
self._add_feature( self._add_feature(
@ -96,7 +96,7 @@ class LightEffect(Module, ABC):
*, *,
brightness: int | None = None, brightness: int | None = None,
transition: int | None = None, transition: int | None = None,
) -> None: ) -> dict:
"""Set an effect on the device. """Set an effect on the device.
If brightness or transition is defined, If brightness or transition is defined,
@ -110,10 +110,11 @@ class LightEffect(Module, ABC):
:param int transition: The wanted transition time :param int transition: The wanted transition time
""" """
@abstractmethod
async def set_custom_effect( async def set_custom_effect(
self, self,
effect_dict: dict, effect_dict: dict,
) -> None: ) -> dict:
"""Set a custom effect on the device. """Set a custom effect on the device.
:param str effect_dict: The custom effect dict to set :param str effect_dict: The custom effect dict to set

View File

@ -83,7 +83,7 @@ class LightPreset(Module):
PRESET_NOT_SET = "Not set" PRESET_NOT_SET = "Not set"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
device = self._device device = self._device
self._add_feature( self._add_feature(
@ -127,7 +127,7 @@ class LightPreset(Module):
async def set_preset( async def set_preset(
self, self,
preset_name: str, preset_name: str,
) -> None: ) -> dict:
"""Set a light preset for the device.""" """Set a light preset for the device."""
@abstractmethod @abstractmethod
@ -135,7 +135,7 @@ class LightPreset(Module):
self, self,
preset_name: str, preset_name: str,
preset_info: LightState, preset_info: LightState,
) -> None: ) -> dict:
"""Update the preset with *preset_name* with the new *preset_info*.""" """Update the preset with *preset_name* with the new *preset_info*."""
@property @property

View File

@ -54,7 +54,7 @@ class TurnOnBehavior(BaseModel):
mode: BehaviorMode mode: BehaviorMode
@root_validator @root_validator
def _mode_based_on_preset(cls, values): def _mode_based_on_preset(cls, values: dict) -> dict:
"""Set the mode based on the preset value.""" """Set the mode based on the preset value."""
if values["preset"] is not None: if values["preset"] is not None:
values["mode"] = BehaviorMode.Preset values["mode"] = BehaviorMode.Preset
@ -209,7 +209,7 @@ class IotBulb(IotDevice):
super().__init__(host=host, config=config, protocol=protocol) super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Bulb self._device_type = DeviceType.Bulb
async def _initialize_modules(self): async def _initialize_modules(self) -> None:
"""Initialize modules not added in init.""" """Initialize modules not added in init."""
await super()._initialize_modules() await super()._initialize_modules()
self.add_module( self.add_module(
@ -307,7 +307,7 @@ class IotBulb(IotDevice):
await self._query_helper(self.LIGHT_SERVICE, "get_default_behavior") await self._query_helper(self.LIGHT_SERVICE, "get_default_behavior")
) )
async def set_turn_on_behavior(self, behavior: TurnOnBehaviors): async def set_turn_on_behavior(self, behavior: TurnOnBehaviors) -> dict:
"""Set the behavior for turning the bulb on. """Set the behavior for turning the bulb on.
If you do not want to manually construct the behavior object, If you do not want to manually construct the behavior object,
@ -426,7 +426,7 @@ class IotBulb(IotDevice):
@requires_update @requires_update
async def _set_color_temp( async def _set_color_temp(
self, temp: int, *, brightness=None, transition: int | None = None self, temp: int, *, brightness: int | None = None, transition: int | None = None
) -> dict: ) -> dict:
"""Set the color temperature of the device in kelvin. """Set the color temperature of the device in kelvin.
@ -450,7 +450,7 @@ class IotBulb(IotDevice):
return await self._set_light_state(light_state, transition=transition) return await self._set_light_state(light_state, transition=transition)
def _raise_for_invalid_brightness(self, value): def _raise_for_invalid_brightness(self, value: int) -> None:
if not isinstance(value, int): if not isinstance(value, int):
raise TypeError("Brightness must be an integer") raise TypeError("Brightness must be an integer")
if not (0 <= value <= 100): if not (0 <= value <= 100):
@ -517,7 +517,7 @@ class IotBulb(IotDevice):
"""Return that the bulb has an emeter.""" """Return that the bulb has an emeter."""
return True return True
async def set_alias(self, alias: str) -> None: async def set_alias(self, alias: str) -> dict:
"""Set the device name (alias). """Set the device name (alias).
Overridden to use a different module name. Overridden to use a different module name.

View File

@ -19,7 +19,7 @@ import inspect
import logging import logging
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta, tzinfo from datetime import datetime, timedelta, tzinfo
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, Callable, cast
from warnings import warn from warnings import warn
from ..device import Device, WifiNetwork from ..device import Device, WifiNetwork
@ -35,12 +35,12 @@ from .modules import Emeter
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def requires_update(f): def requires_update(f: Callable) -> Any:
"""Indicate that `update` should be called before accessing this method.""" # noqa: D202 """Indicate that `update` should be called before accessing this method.""" # noqa: D202
if inspect.iscoroutinefunction(f): if inspect.iscoroutinefunction(f):
@functools.wraps(f) @functools.wraps(f)
async def wrapped(*args, **kwargs): async def wrapped(*args: Any, **kwargs: Any) -> Any:
self = args[0] self = args[0]
if self._last_update is None and f.__name__ not in self._sys_info: if self._last_update is None and f.__name__ not in self._sys_info:
raise KasaException("You need to await update() to access the data") raise KasaException("You need to await update() to access the data")
@ -49,13 +49,13 @@ def requires_update(f):
else: else:
@functools.wraps(f) @functools.wraps(f)
def wrapped(*args, **kwargs): def wrapped(*args: Any, **kwargs: Any) -> Any:
self = args[0] self = args[0]
if self._last_update is None and f.__name__ not in self._sys_info: if self._last_update is None and f.__name__ not in self._sys_info:
raise KasaException("You need to await update() to access the data") raise KasaException("You need to await update() to access the data")
return f(*args, **kwargs) return f(*args, **kwargs)
f.requires_update = True f.requires_update = True # type: ignore[attr-defined]
return wrapped return wrapped
@ -197,7 +197,7 @@ class IotDevice(Device):
return cast(ModuleMapping[IotModule], self._supported_modules) return cast(ModuleMapping[IotModule], self._supported_modules)
return self._supported_modules return self._supported_modules
def add_module(self, name: str | ModuleName[Module], module: IotModule): def add_module(self, name: str | ModuleName[Module], module: IotModule) -> None:
"""Register a module.""" """Register a module."""
if name in self._modules: if name in self._modules:
_LOGGER.debug("Module %s already registered, ignoring...", name) _LOGGER.debug("Module %s already registered, ignoring...", name)
@ -207,8 +207,12 @@ class IotDevice(Device):
self._modules[name] = module self._modules[name] = module
def _create_request( def _create_request(
self, target: str, cmd: str, arg: dict | None = None, child_ids=None self,
): target: str,
cmd: str,
arg: dict | None = None,
child_ids: list | None = None,
) -> dict:
if arg is None: if arg is None:
arg = {} arg = {}
request: dict[str, Any] = {target: {cmd: arg}} request: dict[str, Any] = {target: {cmd: arg}}
@ -225,8 +229,12 @@ class IotDevice(Device):
raise KasaException("update() required prior accessing emeter") raise KasaException("update() required prior accessing emeter")
async def _query_helper( async def _query_helper(
self, target: str, cmd: str, arg: dict | None = None, child_ids=None self,
) -> Any: target: str,
cmd: str,
arg: dict | None = None,
child_ids: list | None = None,
) -> dict:
"""Query device, return results or raise an exception. """Query device, return results or raise an exception.
:param target: Target system {system, time, emeter, ..} :param target: Target system {system, time, emeter, ..}
@ -276,7 +284,7 @@ class IotDevice(Device):
"""Retrieve system information.""" """Retrieve system information."""
return await self._query_helper("system", "get_sysinfo") return await self._query_helper("system", "get_sysinfo")
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True) -> None:
"""Query the device to update the data. """Query the device to update the data.
Needed for properties that are decorated with `requires_update`. Needed for properties that are decorated with `requires_update`.
@ -305,7 +313,7 @@ class IotDevice(Device):
if not self._features: if not self._features:
await self._initialize_features() await self._initialize_features()
async def _initialize_modules(self): async def _initialize_modules(self) -> None:
"""Initialize modules not added in init.""" """Initialize modules not added in init."""
if self.has_emeter: if self.has_emeter:
_LOGGER.debug( _LOGGER.debug(
@ -313,7 +321,7 @@ class IotDevice(Device):
) )
self.add_module(Module.Energy, Emeter(self, self.emeter_type)) self.add_module(Module.Energy, Emeter(self, self.emeter_type))
async def _initialize_features(self): async def _initialize_features(self) -> None:
"""Initialize common features.""" """Initialize common features."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -364,7 +372,7 @@ class IotDevice(Device):
) )
) )
for module in self._supported_modules.values(): for module in self.modules.values():
module._initialize_features() module._initialize_features()
for module_feat in module._module_features.values(): for module_feat in module._module_features.values():
self._add_feature(module_feat) self._add_feature(module_feat)
@ -453,7 +461,7 @@ class IotDevice(Device):
sys_info = self._sys_info sys_info = self._sys_info
return sys_info.get("alias") if sys_info else None return sys_info.get("alias") if sys_info else None
async def set_alias(self, alias: str) -> None: async def set_alias(self, alias: str) -> dict:
"""Set the device name (alias).""" """Set the device name (alias)."""
return await self._query_helper("system", "set_dev_alias", {"alias": alias}) return await self._query_helper("system", "set_dev_alias", {"alias": alias})
@ -550,7 +558,7 @@ class IotDevice(Device):
return mac return mac
async def set_mac(self, mac): async def set_mac(self, mac: str) -> dict:
"""Set the mac address. """Set the mac address.
:param str mac: mac in hexadecimal with colons, e.g. 01:23:45:67:89:ab :param str mac: mac in hexadecimal with colons, e.g. 01:23:45:67:89:ab
@ -576,7 +584,7 @@ class IotDevice(Device):
"""Turn off the device.""" """Turn off the device."""
raise NotImplementedError("Device subclass needs to implement this.") raise NotImplementedError("Device subclass needs to implement this.")
async def turn_on(self, **kwargs) -> dict | None: async def turn_on(self, **kwargs) -> dict:
"""Turn device on.""" """Turn device on."""
raise NotImplementedError("Device subclass needs to implement this.") raise NotImplementedError("Device subclass needs to implement this.")
@ -586,7 +594,7 @@ class IotDevice(Device):
"""Return True if the device is on.""" """Return True if the device is on."""
raise NotImplementedError("Device subclass needs to implement this.") raise NotImplementedError("Device subclass needs to implement this.")
async def set_state(self, on: bool): async def set_state(self, on: bool) -> dict:
"""Set the device state.""" """Set the device state."""
if on: if on:
return await self.turn_on() return await self.turn_on()
@ -627,7 +635,7 @@ class IotDevice(Device):
async def wifi_scan(self) -> list[WifiNetwork]: # noqa: D202 async def wifi_scan(self) -> list[WifiNetwork]: # noqa: D202
"""Scan for available wifi networks.""" """Scan for available wifi networks."""
async def _scan(target): async def _scan(target: str) -> dict:
return await self._query_helper(target, "get_scaninfo", {"refresh": 1}) return await self._query_helper(target, "get_scaninfo", {"refresh": 1})
try: try:
@ -639,17 +647,17 @@ class IotDevice(Device):
info = await _scan("smartlife.iot.common.softaponboarding") info = await _scan("smartlife.iot.common.softaponboarding")
if "ap_list" not in info: if "ap_list" not in info:
raise KasaException("Invalid response for wifi scan: %s" % info) raise KasaException(f"Invalid response for wifi scan: {info}")
return [WifiNetwork(**x) for x in info["ap_list"]] return [WifiNetwork(**x) for x in info["ap_list"]]
async def wifi_join(self, ssid: str, password: str, keytype: str = "3"): # noqa: D202 async def wifi_join(self, ssid: str, password: str, keytype: str = "3") -> dict: # noqa: D202
"""Join the given wifi network. """Join the given wifi network.
If joining the network fails, the device will return to AP mode after a while. If joining the network fails, the device will return to AP mode after a while.
""" """
async def _join(target, payload): async def _join(target: str, payload: dict) -> dict:
return await self._query_helper(target, "set_stainfo", payload) return await self._query_helper(target, "set_stainfo", payload)
payload = {"ssid": ssid, "password": password, "key_type": int(keytype)} payload = {"ssid": ssid, "password": password, "key_type": int(keytype)}

View File

@ -80,7 +80,7 @@ class IotDimmer(IotPlug):
super().__init__(host=host, config=config, protocol=protocol) super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Dimmer self._device_type = DeviceType.Dimmer
async def _initialize_modules(self): async def _initialize_modules(self) -> None:
"""Initialize modules.""" """Initialize modules."""
await super()._initialize_modules() await super()._initialize_modules()
# TODO: need to be verified if it's okay to call these on HS220 w/o these # TODO: need to be verified if it's okay to call these on HS220 w/o these
@ -103,7 +103,9 @@ class IotDimmer(IotPlug):
return int(sys_info["brightness"]) return int(sys_info["brightness"])
@requires_update @requires_update
async def _set_brightness(self, brightness: int, *, transition: int | None = None): async def _set_brightness(
self, brightness: int, *, transition: int | None = None
) -> dict:
"""Set the new dimmer brightness level in percentage. """Set the new dimmer brightness level in percentage.
:param int transition: transition duration in milliseconds. :param int transition: transition duration in milliseconds.
@ -134,7 +136,7 @@ class IotDimmer(IotPlug):
self.DIMMER_SERVICE, "set_brightness", {"brightness": brightness} self.DIMMER_SERVICE, "set_brightness", {"brightness": brightness}
) )
async def turn_off(self, *, transition: int | None = None, **kwargs): async def turn_off(self, *, transition: int | None = None, **kwargs) -> dict:
"""Turn the bulb off. """Turn the bulb off.
:param int transition: transition duration in milliseconds. :param int transition: transition duration in milliseconds.
@ -145,7 +147,7 @@ class IotDimmer(IotPlug):
return await super().turn_off() return await super().turn_off()
@requires_update @requires_update
async def turn_on(self, *, transition: int | None = None, **kwargs): async def turn_on(self, *, transition: int | None = None, **kwargs) -> dict:
"""Turn the bulb on. """Turn the bulb on.
:param int transition: transition duration in milliseconds. :param int transition: transition duration in milliseconds.
@ -157,7 +159,7 @@ class IotDimmer(IotPlug):
return await super().turn_on() return await super().turn_on()
async def set_dimmer_transition(self, brightness: int, transition: int): async def set_dimmer_transition(self, brightness: int, transition: int) -> dict:
"""Turn the bulb on to brightness percentage over transition milliseconds. """Turn the bulb on to brightness percentage over transition milliseconds.
A brightness value of 0 will turn off the dimmer. A brightness value of 0 will turn off the dimmer.
@ -176,7 +178,7 @@ class IotDimmer(IotPlug):
if not isinstance(transition, int): if not isinstance(transition, int):
raise TypeError(f"Transition must be integer, not of {type(transition)}.") raise TypeError(f"Transition must be integer, not of {type(transition)}.")
if transition <= 0: if transition <= 0:
raise ValueError("Transition value %s is not valid." % transition) raise ValueError(f"Transition value {transition} is not valid.")
return await self._query_helper( return await self._query_helper(
self.DIMMER_SERVICE, self.DIMMER_SERVICE,
@ -185,7 +187,7 @@ class IotDimmer(IotPlug):
) )
@requires_update @requires_update
async def get_behaviors(self): async def get_behaviors(self) -> dict:
"""Return button behavior settings.""" """Return button behavior settings."""
behaviors = await self._query_helper( behaviors = await self._query_helper(
self.DIMMER_SERVICE, "get_default_behavior", {} self.DIMMER_SERVICE, "get_default_behavior", {}
@ -195,7 +197,7 @@ class IotDimmer(IotPlug):
@requires_update @requires_update
async def set_button_action( async def set_button_action(
self, action_type: ActionType, action: ButtonAction, index: int | None = None self, action_type: ActionType, action: ButtonAction, index: int | None = None
): ) -> dict:
"""Set action to perform on button click/hold. """Set action to perform on button click/hold.
:param action_type ActionType: whether to control double click or hold action. :param action_type ActionType: whether to control double click or hold action.
@ -209,15 +211,17 @@ class IotDimmer(IotPlug):
if index is not None: if index is not None:
payload["index"] = index payload["index"] = index
await self._query_helper(self.DIMMER_SERVICE, action_type_setter, payload) return await self._query_helper(
self.DIMMER_SERVICE, action_type_setter, payload
)
@requires_update @requires_update
async def set_fade_time(self, fade_type: FadeType, time: int): async def set_fade_time(self, fade_type: FadeType, time: int) -> dict:
"""Set time for fade in / fade out.""" """Set time for fade in / fade out."""
fade_type_setter = f"set_{fade_type}_time" fade_type_setter = f"set_{fade_type}_time"
payload = {"fadeTime": time} payload = {"fadeTime": time}
await self._query_helper(self.DIMMER_SERVICE, fade_type_setter, payload) return await self._query_helper(self.DIMMER_SERVICE, fade_type_setter, payload)
@property # type: ignore @property # type: ignore
@requires_update @requires_update

View File

@ -57,7 +57,7 @@ class IotLightStrip(IotBulb):
super().__init__(host=host, config=config, protocol=protocol) super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.LightStrip self._device_type = DeviceType.LightStrip
async def _initialize_modules(self): async def _initialize_modules(self) -> None:
"""Initialize modules not added in init.""" """Initialize modules not added in init."""
await super()._initialize_modules() await super()._initialize_modules()
self.add_module( self.add_module(

View File

@ -1,6 +1,9 @@
"""Base class for IOT module implementations.""" """Base class for IOT module implementations."""
from __future__ import annotations
import logging import logging
from typing import Any
from ..exceptions import KasaException from ..exceptions import KasaException
from ..module import Module from ..module import Module
@ -24,16 +27,16 @@ merge = _merge_dict
class IotModule(Module): class IotModule(Module):
"""Base class implemention for all IOT modules.""" """Base class implemention for all IOT modules."""
def call(self, method, params=None): async def call(self, method: str, params: dict | None = None) -> dict:
"""Call the given method with the given parameters.""" """Call the given method with the given parameters."""
return self._device._query_helper(self._module, method, params) return await self._device._query_helper(self._module, method, params)
def query_for_command(self, query, params=None): def query_for_command(self, query: str, params: dict | None = None) -> dict:
"""Create a request object for the given parameters.""" """Create a request object for the given parameters."""
return self._device._create_request(self._module, query, params) return self._device._create_request(self._module, query, params)
@property @property
def estimated_query_response_size(self): def estimated_query_response_size(self) -> int:
"""Estimated maximum size of query response. """Estimated maximum size of query response.
The inheriting modules implement this to estimate how large a query response The inheriting modules implement this to estimate how large a query response
@ -42,7 +45,7 @@ class IotModule(Module):
return 256 # Estimate for modules that don't specify return 256 # Estimate for modules that don't specify
@property @property
def data(self): def data(self) -> dict[str, Any]:
"""Return the module specific raw data from the last update.""" """Return the module specific raw data from the last update."""
dev = self._device dev = self._device
q = self.query() q = self.query()

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
from ..device_type import DeviceType from ..device_type import DeviceType
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
@ -54,7 +55,7 @@ class IotPlug(IotDevice):
super().__init__(host=host, config=config, protocol=protocol) super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Plug self._device_type = DeviceType.Plug
async def _initialize_modules(self): async def _initialize_modules(self) -> None:
"""Initialize modules.""" """Initialize modules."""
await super()._initialize_modules() await super()._initialize_modules()
self.add_module(Module.IotSchedule, Schedule(self, "schedule")) self.add_module(Module.IotSchedule, Schedule(self, "schedule"))
@ -71,11 +72,11 @@ class IotPlug(IotDevice):
sys_info = self.sys_info sys_info = self.sys_info
return bool(sys_info["relay_state"]) return bool(sys_info["relay_state"])
async def turn_on(self, **kwargs): async def turn_on(self, **kwargs: Any) -> dict:
"""Turn the switch on.""" """Turn the switch on."""
return await self._query_helper("system", "set_relay_state", {"state": 1}) return await self._query_helper("system", "set_relay_state", {"state": 1})
async def turn_off(self, **kwargs): async def turn_off(self, **kwargs: Any) -> dict:
"""Turn the switch off.""" """Turn the switch off."""
return await self._query_helper("system", "set_relay_state", {"state": 0}) return await self._query_helper("system", "set_relay_state", {"state": 0})

View File

@ -26,7 +26,7 @@ from .modules import Antitheft, Cloud, Countdown, Emeter, Led, Schedule, Time, U
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def merge_sums(dicts): def merge_sums(dicts: list[dict]) -> dict:
"""Merge the sum of dicts.""" """Merge the sum of dicts."""
total_dict: defaultdict[int, float] = defaultdict(lambda: 0.0) total_dict: defaultdict[int, float] = defaultdict(lambda: 0.0)
for sum_dict in dicts: for sum_dict in dicts:
@ -99,7 +99,7 @@ class IotStrip(IotDevice):
self.emeter_type = "emeter" self.emeter_type = "emeter"
self._device_type = DeviceType.Strip self._device_type = DeviceType.Strip
async def _initialize_modules(self): async def _initialize_modules(self) -> None:
"""Initialize modules.""" """Initialize modules."""
# Strip has different modules to plug so do not call super # Strip has different modules to plug so do not call super
self.add_module(Module.IotAntitheft, Antitheft(self, "anti_theft")) self.add_module(Module.IotAntitheft, Antitheft(self, "anti_theft"))
@ -121,7 +121,7 @@ class IotStrip(IotDevice):
"""Return if any of the outlets are on.""" """Return if any of the outlets are on."""
return any(plug.is_on for plug in self.children) return any(plug.is_on for plug in self.children)
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True) -> None:
"""Update some of the attributes. """Update some of the attributes.
Needed for methods that are decorated with `requires_update`. Needed for methods that are decorated with `requires_update`.
@ -150,20 +150,20 @@ class IotStrip(IotDevice):
if not self.features: if not self.features:
await self._initialize_features() await self._initialize_features()
async def _initialize_features(self): async def _initialize_features(self) -> None:
"""Initialize common features.""" """Initialize common features."""
# Do not initialize features until children are created # Do not initialize features until children are created
if not self.children: if not self.children:
return return
await super()._initialize_features() await super()._initialize_features()
async def turn_on(self, **kwargs): async def turn_on(self, **kwargs) -> dict:
"""Turn the strip on.""" """Turn the strip on."""
await self._query_helper("system", "set_relay_state", {"state": 1}) return await self._query_helper("system", "set_relay_state", {"state": 1})
async def turn_off(self, **kwargs): async def turn_off(self, **kwargs) -> dict:
"""Turn the strip off.""" """Turn the strip off."""
await self._query_helper("system", "set_relay_state", {"state": 0}) return await self._query_helper("system", "set_relay_state", {"state": 0})
@property # type: ignore @property # type: ignore
@requires_update @requires_update
@ -188,7 +188,7 @@ class StripEmeter(IotModule, Energy):
"""Return True if module supports the feature.""" """Return True if module supports the feature."""
return module_feature in self._supported return module_feature in self._supported
def query(self): def query(self) -> dict:
"""Return the base query.""" """Return the base query."""
return {} return {}
@ -246,11 +246,13 @@ class StripEmeter(IotModule, Energy):
] ]
) )
async def erase_stats(self): async def erase_stats(self) -> dict:
"""Erase energy meter statistics for all plugs.""" """Erase energy meter statistics for all plugs."""
for plug in self._device.children: for plug in self._device.children:
await plug.modules[Module.Energy].erase_stats() await plug.modules[Module.Energy].erase_stats()
return {}
@property # type: ignore @property # type: ignore
def consumption_this_month(self) -> float | None: def consumption_this_month(self) -> float | None:
"""Return this month's energy consumption in kWh.""" """Return this month's energy consumption in kWh."""
@ -320,7 +322,7 @@ class IotStripPlug(IotPlug):
self.protocol = parent.protocol # Must use the same connection as the parent self.protocol = parent.protocol # Must use the same connection as the parent
self._on_since: datetime | None = None self._on_since: datetime | None = None
async def _initialize_modules(self): async def _initialize_modules(self) -> None:
"""Initialize modules not added in init.""" """Initialize modules not added in init."""
if self.has_emeter: if self.has_emeter:
self.add_module(Module.Energy, Emeter(self, self.emeter_type)) self.add_module(Module.Energy, Emeter(self, self.emeter_type))
@ -329,7 +331,7 @@ class IotStripPlug(IotPlug):
self.add_module(Module.IotSchedule, Schedule(self, "schedule")) self.add_module(Module.IotSchedule, Schedule(self, "schedule"))
self.add_module(Module.IotCountdown, Countdown(self, "countdown")) self.add_module(Module.IotCountdown, Countdown(self, "countdown"))
async def _initialize_features(self): async def _initialize_features(self) -> None:
"""Initialize common features.""" """Initialize common features."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -353,19 +355,20 @@ class IotStripPlug(IotPlug):
type=Feature.Type.Sensor, type=Feature.Type.Sensor,
) )
) )
for module in self._supported_modules.values():
for module in self.modules.values():
module._initialize_features() module._initialize_features()
for module_feat in module._module_features.values(): for module_feat in module._module_features.values():
self._add_feature(module_feat) self._add_feature(module_feat)
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True) -> None:
"""Query the device to update the data. """Query the device to update the data.
Needed for properties that are decorated with `requires_update`. Needed for properties that are decorated with `requires_update`.
""" """
await self._update(update_children) await self._update(update_children)
async def _update(self, update_children: bool = True): async def _update(self, update_children: bool = True) -> None:
"""Query the device to update the data. """Query the device to update the data.
Internal implementation to allow patching of public update in the cli Internal implementation to allow patching of public update in the cli
@ -379,8 +382,12 @@ class IotStripPlug(IotPlug):
await self._initialize_features() await self._initialize_features()
def _create_request( def _create_request(
self, target: str, cmd: str, arg: dict | None = None, child_ids=None self,
): target: str,
cmd: str,
arg: dict | None = None,
child_ids: list | None = None,
) -> dict:
request: dict[str, Any] = { request: dict[str, Any] = {
"context": {"child_ids": [self.child_id]}, "context": {"child_ids": [self.child_id]},
target: {cmd: arg}, target: {cmd: arg},
@ -388,8 +395,12 @@ class IotStripPlug(IotPlug):
return request return request
async def _query_helper( async def _query_helper(
self, target: str, cmd: str, arg: dict | None = None, child_ids=None self,
) -> Any: target: str,
cmd: str,
arg: dict | None = None,
child_ids: list | None = None,
) -> dict:
"""Override query helper to include the child_ids.""" """Override query helper to include the child_ids."""
return await self._parent._query_helper( return await self._parent._query_helper(
target, cmd, arg, child_ids=[self.child_id] target, cmd, arg, child_ids=[self.child_id]

View File

@ -11,7 +11,7 @@ _LOGGER = logging.getLogger(__name__)
class AmbientLight(IotModule): class AmbientLight(IotModule):
"""Implements ambient light controls for the motion sensor.""" """Implements ambient light controls for the motion sensor."""
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -40,7 +40,7 @@ class AmbientLight(IotModule):
) )
) )
def query(self): def query(self) -> dict:
"""Request configuration.""" """Request configuration."""
req = merge( req = merge(
self.query_for_command("get_config"), self.query_for_command("get_config"),
@ -74,18 +74,18 @@ class AmbientLight(IotModule):
"""Return True if the module is enabled.""" """Return True if the module is enabled."""
return int(self.data["get_current_brt"]["value"]) return int(self.data["get_current_brt"]["value"])
async def set_enabled(self, state: bool): async def set_enabled(self, state: bool) -> dict:
"""Enable/disable LAS.""" """Enable/disable LAS."""
return await self.call("set_enable", {"enable": int(state)}) return await self.call("set_enable", {"enable": int(state)})
async def current_brightness(self) -> int: async def current_brightness(self) -> dict:
"""Return current brightness. """Return current brightness.
Return value units. Return value units.
""" """
return await self.call("get_current_brt") return await self.call("get_current_brt")
async def set_brightness_limit(self, value: int): async def set_brightness_limit(self, value: int) -> dict:
"""Set the limit when the motion sensor is inactive. """Set the limit when the motion sensor is inactive.
See `presets` for preset values. Custom values are also likely allowed. See `presets` for preset values. Custom values are also likely allowed.

View File

@ -24,7 +24,7 @@ class CloudInfo(BaseModel):
class Cloud(IotModule): class Cloud(IotModule):
"""Module implementing support for cloud services.""" """Module implementing support for cloud services."""
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -44,7 +44,7 @@ class Cloud(IotModule):
"""Return true if device is connected to the cloud.""" """Return true if device is connected to the cloud."""
return self.info.binded return self.info.binded
def query(self): def query(self) -> dict:
"""Request cloud connectivity info.""" """Request cloud connectivity info."""
return self.query_for_command("get_info") return self.query_for_command("get_info")
@ -53,20 +53,20 @@ class Cloud(IotModule):
"""Return information about the cloud connectivity.""" """Return information about the cloud connectivity."""
return CloudInfo.parse_obj(self.data["get_info"]) return CloudInfo.parse_obj(self.data["get_info"])
def get_available_firmwares(self): def get_available_firmwares(self) -> dict:
"""Return list of available firmwares.""" """Return list of available firmwares."""
return self.query_for_command("get_intl_fw_list") return self.query_for_command("get_intl_fw_list")
def set_server(self, url: str): def set_server(self, url: str) -> dict:
"""Set the update server URL.""" """Set the update server URL."""
return self.query_for_command("set_server_url", {"server": url}) return self.query_for_command("set_server_url", {"server": url})
def connect(self, username: str, password: str): def connect(self, username: str, password: str) -> dict:
"""Login to the cloud using given information.""" """Login to the cloud using given information."""
return self.query_for_command( return self.query_for_command(
"bind", {"username": username, "password": password} "bind", {"username": username, "password": password}
) )
def disconnect(self): def disconnect(self) -> dict:
"""Disconnect from the cloud.""" """Disconnect from the cloud."""
return self.query_for_command("unbind") return self.query_for_command("unbind")

View File

@ -70,7 +70,7 @@ class Emeter(Usage, EnergyInterface):
"""Get the current voltage in V.""" """Get the current voltage in V."""
return self.status.voltage return self.status.voltage
async def erase_stats(self): async def erase_stats(self) -> dict:
"""Erase all stats. """Erase all stats.
Uses different query than usage meter. Uses different query than usage meter.
@ -81,7 +81,9 @@ class Emeter(Usage, EnergyInterface):
"""Return real-time statistics.""" """Return real-time statistics."""
return EmeterStatus(await self.call("get_realtime")) return EmeterStatus(await self.call("get_realtime"))
async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict: async def get_daily_stats(
self, *, year: int | None = None, month: int | None = None, kwh: bool = True
) -> dict:
"""Return daily stats for the given year & month. """Return daily stats for the given year & month.
The return value is a dictionary of {day: energy, ...}. The return value is a dictionary of {day: energy, ...}.
@ -90,7 +92,9 @@ class Emeter(Usage, EnergyInterface):
data = self._convert_stat_data(data["day_list"], entry_key="day", kwh=kwh) data = self._convert_stat_data(data["day_list"], entry_key="day", kwh=kwh)
return data return data
async def get_monthly_stats(self, *, year=None, kwh=True) -> dict: async def get_monthly_stats(
self, *, year: int | None = None, kwh: bool = True
) -> dict:
"""Return monthly stats for the given year. """Return monthly stats for the given year.
The return value is a dictionary of {month: energy, ...}. The return value is a dictionary of {month: energy, ...}.

View File

@ -14,7 +14,7 @@ class Led(IotModule, LedInterface):
return {} return {}
@property @property
def mode(self): def mode(self) -> str:
"""LED mode setting. """LED mode setting.
"always", "never" "always", "never"
@ -27,7 +27,7 @@ class Led(IotModule, LedInterface):
sys_info = self.data sys_info = self.data
return bool(1 - sys_info["led_off"]) return bool(1 - sys_info["led_off"])
async def set_led(self, state: bool): async def set_led(self, state: bool) -> dict:
"""Set the state of the led (night mode).""" """Set the state of the led (night mode)."""
return await self.call("set_led_off", {"off": int(not state)}) return await self.call("set_led_off", {"off": int(not state)})

View File

@ -27,7 +27,7 @@ class Light(IotModule, LightInterface):
_device: IotBulb | IotDimmer _device: IotBulb | IotDimmer
_light_state: LightState _light_state: LightState
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
super()._initialize_features() super()._initialize_features()
device = self._device device = self._device
@ -185,7 +185,7 @@ class Light(IotModule, LightInterface):
return bulb._color_temp return bulb._color_temp
async def set_color_temp( async def set_color_temp(
self, temp: int, *, brightness=None, transition: int | None = None self, temp: int, *, brightness: int | None = None, transition: int | None = None
) -> dict: ) -> dict:
"""Set the color temperature of the device in kelvin. """Set the color temperature of the device in kelvin.

View File

@ -50,7 +50,7 @@ class LightEffect(IotModule, LightEffectInterface):
*, *,
brightness: int | None = None, brightness: int | None = None,
transition: int | None = None, transition: int | None = None,
) -> None: ) -> dict:
"""Set an effect on the device. """Set an effect on the device.
If brightness or transition is defined, If brightness or transition is defined,
@ -73,7 +73,7 @@ class LightEffect(IotModule, LightEffectInterface):
effect_dict = EFFECT_MAPPING_V1["Aurora"] effect_dict = EFFECT_MAPPING_V1["Aurora"]
effect_dict = {**effect_dict} effect_dict = {**effect_dict}
effect_dict["enable"] = 0 effect_dict["enable"] = 0
await self.set_custom_effect(effect_dict) return await self.set_custom_effect(effect_dict)
elif effect not in EFFECT_MAPPING_V1: elif effect not in EFFECT_MAPPING_V1:
raise ValueError(f"The effect {effect} is not a built in effect.") raise ValueError(f"The effect {effect} is not a built in effect.")
else: else:
@ -84,12 +84,12 @@ class LightEffect(IotModule, LightEffectInterface):
if transition is not None: if transition is not None:
effect_dict["transition"] = transition effect_dict["transition"] = transition
await self.set_custom_effect(effect_dict) return await self.set_custom_effect(effect_dict)
async def set_custom_effect( async def set_custom_effect(
self, self,
effect_dict: dict, effect_dict: dict,
) -> None: ) -> dict:
"""Set a custom effect on the device. """Set a custom effect on the device.
:param str effect_dict: The custom effect dict to set :param str effect_dict: The custom effect dict to set
@ -104,7 +104,7 @@ class LightEffect(IotModule, LightEffectInterface):
"""Return True if the device supports setting custom effects.""" """Return True if the device supports setting custom effects."""
return True return True
def query(self): def query(self) -> dict:
"""Return the base query.""" """Return the base query."""
return {} return {}

View File

@ -41,7 +41,7 @@ class LightPreset(IotModule, LightPresetInterface):
_presets: dict[str, IotLightPreset] _presets: dict[str, IotLightPreset]
_preset_list: list[str] _preset_list: list[str]
async def _post_update_hook(self): async def _post_update_hook(self) -> None:
"""Update the internal presets.""" """Update the internal presets."""
self._presets = { self._presets = {
f"Light preset {index+1}": IotLightPreset(**vals) f"Light preset {index+1}": IotLightPreset(**vals)
@ -93,7 +93,7 @@ class LightPreset(IotModule, LightPresetInterface):
async def set_preset( async def set_preset(
self, self,
preset_name: str, preset_name: str,
) -> None: ) -> dict:
"""Set a light preset for the device.""" """Set a light preset for the device."""
light = self._device.modules[Module.Light] light = self._device.modules[Module.Light]
if preset_name == self.PRESET_NOT_SET: if preset_name == self.PRESET_NOT_SET:
@ -104,7 +104,7 @@ class LightPreset(IotModule, LightPresetInterface):
elif (preset := self._presets.get(preset_name)) is None: # type: ignore[assignment] elif (preset := self._presets.get(preset_name)) is None: # type: ignore[assignment]
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}") raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
await light.set_state(preset) return await light.set_state(preset)
@property @property
def has_save_preset(self) -> bool: def has_save_preset(self) -> bool:
@ -115,7 +115,7 @@ class LightPreset(IotModule, LightPresetInterface):
self, self,
preset_name: str, preset_name: str,
preset_state: LightState, preset_state: LightState,
) -> None: ) -> dict:
"""Update the preset with preset_name with the new preset_info.""" """Update the preset with preset_name with the new preset_info."""
if len(self._presets) == 0: if len(self._presets) == 0:
raise KasaException("Device does not supported saving presets") raise KasaException("Device does not supported saving presets")
@ -129,7 +129,7 @@ class LightPreset(IotModule, LightPresetInterface):
return await self.call("set_preferred_state", state) return await self.call("set_preferred_state", state)
def query(self): def query(self) -> dict:
"""Return the base query.""" """Return the base query."""
return {} return {}
@ -142,7 +142,7 @@ class LightPreset(IotModule, LightPresetInterface):
if "id" not in vals if "id" not in vals
] ]
async def _deprecated_save_preset(self, preset: IotLightPreset): async def _deprecated_save_preset(self, preset: IotLightPreset) -> dict:
"""Save a setting preset. """Save a setting preset.
You can either construct a preset object manually, or pass an existing one You can either construct a preset object manually, or pass an existing one

View File

@ -24,7 +24,7 @@ class Range(Enum):
class Motion(IotModule): class Motion(IotModule):
"""Implements the motion detection (PIR) module.""" """Implements the motion detection (PIR) module."""
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
# Only add features if the device supports the module # Only add features if the device supports the module
if "get_config" not in self.data: if "get_config" not in self.data:
@ -48,7 +48,7 @@ class Motion(IotModule):
) )
) )
def query(self): def query(self) -> dict:
"""Request PIR configuration.""" """Request PIR configuration."""
return self.query_for_command("get_config") return self.query_for_command("get_config")
@ -67,13 +67,13 @@ class Motion(IotModule):
"""Return True if module is enabled.""" """Return True if module is enabled."""
return bool(self.config["enable"]) return bool(self.config["enable"])
async def set_enabled(self, state: bool): async def set_enabled(self, state: bool) -> dict:
"""Enable/disable PIR.""" """Enable/disable PIR."""
return await self.call("set_enable", {"enable": int(state)}) return await self.call("set_enable", {"enable": int(state)})
async def set_range( async def set_range(
self, *, range: Range | None = None, custom_range: int | None = None self, *, range: Range | None = None, custom_range: int | None = None
): ) -> dict:
"""Set the range for the sensor. """Set the range for the sensor.
:param range: for using standard ranges :param range: for using standard ranges
@ -93,7 +93,7 @@ class Motion(IotModule):
"""Return inactivity timeout in milliseconds.""" """Return inactivity timeout in milliseconds."""
return self.config["cold_time"] return self.config["cold_time"]
async def set_inactivity_timeout(self, timeout: int): async def set_inactivity_timeout(self, timeout: int) -> dict:
"""Set inactivity timeout in milliseconds. """Set inactivity timeout in milliseconds.
Note, that you need to delete the default "Smart Control" rule in the app Note, that you need to delete the default "Smart Control" rule in the app

View File

@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__)
class RuleModule(IotModule): class RuleModule(IotModule):
"""Base class for rule-based modules, such as countdown and antitheft.""" """Base class for rule-based modules, such as countdown and antitheft."""
def query(self): def query(self) -> dict:
"""Prepare the query for rules.""" """Prepare the query for rules."""
q = self.query_for_command("get_rules") q = self.query_for_command("get_rules")
return merge(q, self.query_for_command("get_next_action")) return merge(q, self.query_for_command("get_next_action"))
@ -73,14 +73,14 @@ class RuleModule(IotModule):
_LOGGER.error("Unable to read rule list: %s (data: %s)", ex, self.data) _LOGGER.error("Unable to read rule list: %s (data: %s)", ex, self.data)
return [] return []
async def set_enabled(self, state: bool): async def set_enabled(self, state: bool) -> dict:
"""Enable or disable the service.""" """Enable or disable the service."""
return await self.call("set_overall_enable", state) return await self.call("set_overall_enable", {"enable": state})
async def delete_rule(self, rule: Rule): async def delete_rule(self, rule: Rule) -> dict:
"""Delete the given rule.""" """Delete the given rule."""
return await self.call("delete_rule", {"id": rule.id}) return await self.call("delete_rule", {"id": rule.id})
async def delete_all_rules(self): async def delete_all_rules(self) -> dict:
"""Delete all rules.""" """Delete all rules."""
return await self.call("delete_all_rules") return await self.call("delete_all_rules")

View File

@ -15,14 +15,14 @@ class Time(IotModule, TimeInterface):
_timezone: tzinfo = timezone.utc _timezone: tzinfo = timezone.utc
def query(self): def query(self) -> dict:
"""Request time and timezone.""" """Request time and timezone."""
q = self.query_for_command("get_time") q = self.query_for_command("get_time")
merge(q, self.query_for_command("get_timezone")) merge(q, self.query_for_command("get_timezone"))
return q return q
async def _post_update_hook(self): async def _post_update_hook(self) -> None:
"""Perform actions after a device update.""" """Perform actions after a device update."""
if res := self.data.get("get_timezone"): if res := self.data.get("get_timezone"):
self._timezone = await get_timezone(res.get("index")) self._timezone = await get_timezone(res.get("index"))
@ -47,7 +47,7 @@ class Time(IotModule, TimeInterface):
"""Return current timezone.""" """Return current timezone."""
return self._timezone return self._timezone
async def get_time(self): async def get_time(self) -> datetime | None:
"""Return current device time.""" """Return current device time."""
try: try:
res = await self.call("get_time") res = await self.call("get_time")
@ -88,6 +88,6 @@ class Time(IotModule, TimeInterface):
except Exception as ex: except Exception as ex:
raise KasaException(ex) from ex raise KasaException(ex) from ex
async def get_timezone(self): async def get_timezone(self) -> dict:
"""Request timezone information from the device.""" """Request timezone information from the device."""
return await self.call("get_timezone") return await self.call("get_timezone")

View File

@ -10,7 +10,7 @@ from ..iotmodule import IotModule, merge
class Usage(IotModule): class Usage(IotModule):
"""Baseclass for emeter/usage interfaces.""" """Baseclass for emeter/usage interfaces."""
def query(self): def query(self) -> dict:
"""Return the base query.""" """Return the base query."""
now = datetime.now() now = datetime.now()
year = now.year year = now.year
@ -25,22 +25,22 @@ class Usage(IotModule):
return req return req
@property @property
def estimated_query_response_size(self): def estimated_query_response_size(self) -> int:
"""Estimated maximum query response size.""" """Estimated maximum query response size."""
return 2048 return 2048
@property @property
def daily_data(self): def daily_data(self) -> list[dict]:
"""Return statistics on daily basis.""" """Return statistics on daily basis."""
return self.data["get_daystat"]["day_list"] return self.data["get_daystat"]["day_list"]
@property @property
def monthly_data(self): def monthly_data(self) -> list[dict]:
"""Return statistics on monthly basis.""" """Return statistics on monthly basis."""
return self.data["get_monthstat"]["month_list"] return self.data["get_monthstat"]["month_list"]
@property @property
def usage_today(self): def usage_today(self) -> int | None:
"""Return today's usage in minutes.""" """Return today's usage in minutes."""
today = datetime.now().day today = datetime.now().day
# Traverse the list in reverse order to find the latest entry. # Traverse the list in reverse order to find the latest entry.
@ -50,7 +50,7 @@ class Usage(IotModule):
return None return None
@property @property
def usage_this_month(self): def usage_this_month(self) -> int | None:
"""Return usage in this month in minutes.""" """Return usage in this month in minutes."""
this_month = datetime.now().month this_month = datetime.now().month
# Traverse the list in reverse order to find the latest entry. # Traverse the list in reverse order to find the latest entry.
@ -59,7 +59,9 @@ class Usage(IotModule):
return entry["time"] return entry["time"]
return None return None
async def get_raw_daystat(self, *, year=None, month=None) -> dict: async def get_raw_daystat(
self, *, year: int | None = None, month: int | None = None
) -> dict:
"""Return raw daily stats for the given year & month.""" """Return raw daily stats for the given year & month."""
if year is None: if year is None:
year = datetime.now().year year = datetime.now().year
@ -68,14 +70,16 @@ class Usage(IotModule):
return await self.call("get_daystat", {"year": year, "month": month}) return await self.call("get_daystat", {"year": year, "month": month})
async def get_raw_monthstat(self, *, year=None) -> dict: async def get_raw_monthstat(self, *, year: int | None = None) -> dict:
"""Return raw monthly stats for the given year.""" """Return raw monthly stats for the given year."""
if year is None: if year is None:
year = datetime.now().year year = datetime.now().year
return await self.call("get_monthstat", {"year": year}) return await self.call("get_monthstat", {"year": year})
async def get_daystat(self, *, year=None, month=None) -> dict: async def get_daystat(
self, *, year: int | None = None, month: int | None = None
) -> dict:
"""Return daily stats for the given year & month. """Return daily stats for the given year & month.
The return value is a dictionary of {day: time, ...}. The return value is a dictionary of {day: time, ...}.
@ -84,7 +88,7 @@ class Usage(IotModule):
data = self._convert_stat_data(data["day_list"], entry_key="day") data = self._convert_stat_data(data["day_list"], entry_key="day")
return data return data
async def get_monthstat(self, *, year=None) -> dict: async def get_monthstat(self, *, year: int | None = None) -> dict:
"""Return monthly stats for the given year. """Return monthly stats for the given year.
The return value is a dictionary of {month: time, ...}. The return value is a dictionary of {month: time, ...}.
@ -93,11 +97,11 @@ class Usage(IotModule):
data = self._convert_stat_data(data["month_list"], entry_key="month") data = self._convert_stat_data(data["month_list"], entry_key="month")
return data return data
async def erase_stats(self): async def erase_stats(self) -> dict:
"""Erase all stats.""" """Erase all stats."""
return await self.call("erase_runtime_stat") return await self.call("erase_runtime_stat")
def _convert_stat_data(self, data, entry_key) -> dict: def _convert_stat_data(self, data: list[dict], entry_key: str) -> dict:
"""Return usage information keyed with the day/month. """Return usage information keyed with the day/month.
The incoming data is a list of dictionaries:: The incoming data is a list of dictionaries::
@ -113,6 +117,6 @@ class Usage(IotModule):
if not data: if not data:
return {} return {}
data = {entry[entry_key]: entry["time"] for entry in data} res = {entry[entry_key]: entry["time"] for entry in data}
return data return res

View File

@ -1,9 +1,13 @@
"""JSON abstraction.""" """JSON abstraction."""
from __future__ import annotations
from typing import Any, Callable
try: try:
import orjson import orjson
def dumps(obj, *, default=None): def dumps(obj: Any, *, default: Callable | None = None) -> str:
"""Dump JSON.""" """Dump JSON."""
return orjson.dumps(obj).decode() return orjson.dumps(obj).decode()
@ -11,7 +15,7 @@ try:
except ImportError: except ImportError:
import json import json
def dumps(obj, *, default=None): def dumps(obj: Any, *, default: Callable | None = None) -> str:
"""Dump JSON.""" """Dump JSON."""
# Separators specified for consistency with orjson # Separators specified for consistency with orjson
return json.dumps(obj, separators=(",", ":")) return json.dumps(obj, separators=(",", ":"))

View File

@ -50,7 +50,8 @@ import logging
import secrets import secrets
import struct import struct
import time import time
from typing import TYPE_CHECKING, Any, cast from asyncio import Future
from typing import TYPE_CHECKING, Any, Generator, cast
from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
@ -110,10 +111,10 @@ class KlapTransport(BaseTransport):
else: else:
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr] self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
self._default_credentials_auth_hash: dict[str, bytes] = {} self._default_credentials_auth_hash: dict[str, bytes] = {}
self._blank_auth_hash = None self._blank_auth_hash: bytes | None = None
self._handshake_lock = asyncio.Lock() self._handshake_lock = asyncio.Lock()
self._query_lock = asyncio.Lock() self._query_lock = asyncio.Lock()
self._handshake_done = False self._handshake_done: bool = False
self._encryption_session: KlapEncryptionSession | None = None self._encryption_session: KlapEncryptionSession | None = None
self._session_expire_at: float | None = None self._session_expire_at: float | None = None
@ -125,7 +126,7 @@ class KlapTransport(BaseTransport):
self._request_url = self._app_url / "request" self._request_url = self._app_url / "request"
@property @property
def default_port(self): def default_port(self) -> int:
"""Default port for the transport.""" """Default port for the transport."""
return self.DEFAULT_PORT return self.DEFAULT_PORT
@ -242,7 +243,7 @@ class KlapTransport(BaseTransport):
raise AuthenticationError(msg) raise AuthenticationError(msg)
async def perform_handshake2( async def perform_handshake2(
self, local_seed, remote_seed, auth_hash self, local_seed: bytes, remote_seed: bytes, auth_hash: bytes
) -> KlapEncryptionSession: ) -> KlapEncryptionSession:
"""Perform handshake2.""" """Perform handshake2."""
# Handshake 2 has the following payload: # Handshake 2 has the following payload:
@ -277,7 +278,7 @@ class KlapTransport(BaseTransport):
return KlapEncryptionSession(local_seed, remote_seed, auth_hash) return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
async def perform_handshake(self) -> Any: async def perform_handshake(self) -> None:
"""Perform handshake1 and handshake2. """Perform handshake1 and handshake2.
Sets the encryption_session if successful. Sets the encryption_session if successful.
@ -309,14 +310,14 @@ class KlapTransport(BaseTransport):
_LOGGER.debug("Handshake with %s complete", self._host) _LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self): def _handshake_session_expired(self) -> bool:
"""Return true if session has expired.""" """Return true if session has expired."""
return ( return (
self._session_expire_at is None self._session_expire_at is None
or self._session_expire_at - time.monotonic() <= 0 or self._session_expire_at - time.monotonic() <= 0
) )
async def send(self, request: str): async def send(self, request: str) -> Generator[Future, None, dict[str, str]]: # type: ignore[override]
"""Send the request.""" """Send the request."""
if not self._handshake_done or self._handshake_session_expired(): if not self._handshake_done or self._handshake_session_expired():
await self.perform_handshake() await self.perform_handshake()
@ -355,6 +356,7 @@ class KlapTransport(BaseTransport):
if TYPE_CHECKING: if TYPE_CHECKING:
assert self._encryption_session assert self._encryption_session
assert isinstance(response_data, bytes)
try: try:
decrypted_response = self._encryption_session.decrypt(response_data) decrypted_response = self._encryption_session.decrypt(response_data)
except Exception as ex: except Exception as ex:
@ -378,7 +380,7 @@ class KlapTransport(BaseTransport):
self._handshake_done = False self._handshake_done = False
@staticmethod @staticmethod
def generate_auth_hash(creds: Credentials): def generate_auth_hash(creds: Credentials) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials.""" """Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username un = creds.username
pw = creds.password pw = creds.password
@ -388,19 +390,19 @@ class KlapTransport(BaseTransport):
@staticmethod @staticmethod
def handshake1_seed_auth_hash( def handshake1_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes local_seed: bytes, remote_seed: bytes, auth_hash: bytes
): ) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials.""" """Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(local_seed + auth_hash) return _sha256(local_seed + auth_hash)
@staticmethod @staticmethod
def handshake2_seed_auth_hash( def handshake2_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes local_seed: bytes, remote_seed: bytes, auth_hash: bytes
): ) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials.""" """Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(remote_seed + auth_hash) return _sha256(remote_seed + auth_hash)
@staticmethod @staticmethod
def generate_owner_hash(creds: Credentials): def generate_owner_hash(creds: Credentials) -> bytes:
"""Return the MD5 hash of the username in this object.""" """Return the MD5 hash of the username in this object."""
un = creds.username un = creds.username
return md5(un.encode()) return md5(un.encode())
@ -410,7 +412,7 @@ class KlapTransportV2(KlapTransport):
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes.""" """Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
@staticmethod @staticmethod
def generate_auth_hash(creds: Credentials): def generate_auth_hash(creds: Credentials) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials.""" """Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username un = creds.username
pw = creds.password pw = creds.password
@ -420,14 +422,14 @@ class KlapTransportV2(KlapTransport):
@staticmethod @staticmethod
def handshake1_seed_auth_hash( def handshake1_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes local_seed: bytes, remote_seed: bytes, auth_hash: bytes
): ) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials.""" """Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(local_seed + remote_seed + auth_hash) return _sha256(local_seed + remote_seed + auth_hash)
@staticmethod @staticmethod
def handshake2_seed_auth_hash( def handshake2_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes local_seed: bytes, remote_seed: bytes, auth_hash: bytes
): ) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials.""" """Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(remote_seed + local_seed + auth_hash) return _sha256(remote_seed + local_seed + auth_hash)
@ -440,7 +442,7 @@ class KlapEncryptionSession:
_cipher: Cipher _cipher: Cipher
def __init__(self, local_seed, remote_seed, user_hash): def __init__(self, local_seed: bytes, remote_seed: bytes, user_hash: bytes) -> None:
self.local_seed = local_seed self.local_seed = local_seed
self.remote_seed = remote_seed self.remote_seed = remote_seed
self.user_hash = user_hash self.user_hash = user_hash
@ -449,11 +451,15 @@ class KlapEncryptionSession:
self._aes = algorithms.AES(self._key) self._aes = algorithms.AES(self._key)
self._sig = self._sig_derive(local_seed, remote_seed, user_hash) self._sig = self._sig_derive(local_seed, remote_seed, user_hash)
def _key_derive(self, local_seed, remote_seed, user_hash): def _key_derive(
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
) -> bytes:
payload = b"lsk" + local_seed + remote_seed + user_hash payload = b"lsk" + local_seed + remote_seed + user_hash
return hashlib.sha256(payload).digest()[:16] return hashlib.sha256(payload).digest()[:16]
def _iv_derive(self, local_seed, remote_seed, user_hash): def _iv_derive(
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
) -> tuple[bytes, int]:
# iv is first 16 bytes of sha256, where the last 4 bytes forms the # iv is first 16 bytes of sha256, where the last 4 bytes forms the
# sequence number used in requests and is incremented on each request # sequence number used in requests and is incremented on each request
payload = b"iv" + local_seed + remote_seed + user_hash payload = b"iv" + local_seed + remote_seed + user_hash
@ -461,17 +467,19 @@ class KlapEncryptionSession:
seq = int.from_bytes(fulliv[-4:], "big", signed=True) seq = int.from_bytes(fulliv[-4:], "big", signed=True)
return (fulliv[:12], seq) return (fulliv[:12], seq)
def _sig_derive(self, local_seed, remote_seed, user_hash): def _sig_derive(
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
) -> bytes:
# used to create a hash with which to prefix each request # used to create a hash with which to prefix each request
payload = b"ldk" + local_seed + remote_seed + user_hash payload = b"ldk" + local_seed + remote_seed + user_hash
return hashlib.sha256(payload).digest()[:28] return hashlib.sha256(payload).digest()[:28]
def _generate_cipher(self): def _generate_cipher(self) -> None:
iv_seq = self._iv + PACK_SIGNED_LONG(self._seq) iv_seq = self._iv + PACK_SIGNED_LONG(self._seq)
cbc = modes.CBC(iv_seq) cbc = modes.CBC(iv_seq)
self._cipher = Cipher(self._aes, cbc) self._cipher = Cipher(self._aes, cbc)
def encrypt(self, msg): def encrypt(self, msg: bytes | str) -> tuple[bytes, int]:
"""Encrypt the data and increment the sequence number.""" """Encrypt the data and increment the sequence number."""
self._seq += 1 self._seq += 1
self._generate_cipher() self._generate_cipher()
@ -488,7 +496,7 @@ class KlapEncryptionSession:
).digest() ).digest()
return (signature + ciphertext, self._seq) return (signature + ciphertext, self._seq)
def decrypt(self, msg): def decrypt(self, msg: bytes) -> str:
"""Decrypt the data.""" """Decrypt the data."""
decryptor = self._cipher.decryptor() decryptor = self._cipher.decryptor()
dp = decryptor.update(msg[32:]) + decryptor.finalize() dp = decryptor.update(msg[32:]) + decryptor.finalize()

View File

@ -135,13 +135,13 @@ class Module(ABC):
# SMARTCAMERA only modules # SMARTCAMERA only modules
Camera: Final[ModuleName[experimental.Camera]] = ModuleName("Camera") Camera: Final[ModuleName[experimental.Camera]] = ModuleName("Camera")
def __init__(self, device: Device, module: str): def __init__(self, device: Device, module: str) -> None:
self._device = device self._device = device
self._module = module self._module = module
self._module_features: dict[str, Feature] = {} self._module_features: dict[str, Feature] = {}
@abstractmethod @abstractmethod
def query(self): def query(self) -> dict:
"""Query to execute during the update cycle. """Query to execute during the update cycle.
The inheriting modules implement this to include their wanted The inheriting modules implement this to include their wanted
@ -150,10 +150,10 @@ class Module(ABC):
@property @property
@abstractmethod @abstractmethod
def data(self): def data(self) -> dict:
"""Return the module specific raw data from the last update.""" """Return the module specific raw data from the last update."""
def _initialize_features(self): # noqa: B027 def _initialize_features(self) -> None: # noqa: B027
"""Initialize features after the initial update. """Initialize features after the initial update.
This can be implemented if features depend on module query responses. This can be implemented if features depend on module query responses.
@ -162,7 +162,7 @@ class Module(ABC):
children's modules. children's modules.
""" """
async def _post_update_hook(self): # noqa: B027 async def _post_update_hook(self) -> None: # noqa: B027
"""Perform actions after a device update. """Perform actions after a device update.
This can be implemented if a module needs to perform actions each time This can be implemented if a module needs to perform actions each time
@ -171,11 +171,11 @@ class Module(ABC):
*_initialize_features* on the first update. *_initialize_features* on the first update.
""" """
def _add_feature(self, feature: Feature): def _add_feature(self, feature: Feature) -> None:
"""Add module feature.""" """Add module feature."""
id_ = feature.id id_ = feature.id
if id_ in self._module_features: if id_ in self._module_features:
raise KasaException("Duplicate id detected %s" % id_) raise KasaException(f"Duplicate id detected {id_}")
self._module_features[id_] = feature self._module_features[id_] = feature
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@ -130,7 +130,7 @@ class BaseProtocol(ABC):
self._transport = transport self._transport = transport
@property @property
def _host(self): def _host(self) -> str:
return self._transport._host return self._transport._host
@property @property

View File

@ -15,7 +15,9 @@ class SmartLightEffect(LightEffectInterface, ABC):
""" """
@abstractmethod @abstractmethod
async def set_brightness(self, brightness: int, *, transition: int | None = None): async def set_brightness(
self, brightness: int, *, transition: int | None = None
) -> dict:
"""Set effect brightness.""" """Set effect brightness."""
@property @property

View File

@ -20,7 +20,7 @@ class Alarm(SmartModule):
"get_support_alarm_type_list": None, # This should be needed only once "get_support_alarm_type_list": None, # This should be needed only once
} }
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features. """Initialize features.
This is implemented as some features depend on device responses. This is implemented as some features depend on device responses.
@ -100,7 +100,7 @@ class Alarm(SmartModule):
"""Return current alarm sound.""" """Return current alarm sound."""
return self.data["get_alarm_configure"]["type"] return self.data["get_alarm_configure"]["type"]
async def set_alarm_sound(self, sound: str): async def set_alarm_sound(self, sound: str) -> dict:
"""Set alarm sound. """Set alarm sound.
See *alarm_sounds* for list of available sounds. See *alarm_sounds* for list of available sounds.
@ -119,7 +119,7 @@ class Alarm(SmartModule):
"""Return alarm volume.""" """Return alarm volume."""
return self.data["get_alarm_configure"]["volume"] return self.data["get_alarm_configure"]["volume"]
async def set_alarm_volume(self, volume: Literal["low", "normal", "high"]): async def set_alarm_volume(self, volume: Literal["low", "normal", "high"]) -> dict:
"""Set alarm volume.""" """Set alarm volume."""
payload = self.data["get_alarm_configure"].copy() payload = self.data["get_alarm_configure"].copy()
payload["volume"] = volume payload["volume"] = volume

View File

@ -17,7 +17,7 @@ class AutoOff(SmartModule):
REQUIRED_COMPONENT = "auto_off" REQUIRED_COMPONENT = "auto_off"
QUERY_GETTER_NAME = "get_auto_off_config" QUERY_GETTER_NAME = "get_auto_off_config"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -63,7 +63,7 @@ class AutoOff(SmartModule):
"""Return True if enabled.""" """Return True if enabled."""
return self.data["enable"] return self.data["enable"]
async def set_enabled(self, enable: bool): async def set_enabled(self, enable: bool) -> dict:
"""Enable/disable auto off.""" """Enable/disable auto off."""
return await self.call( return await self.call(
"set_auto_off_config", "set_auto_off_config",
@ -75,7 +75,7 @@ class AutoOff(SmartModule):
"""Return time until auto off.""" """Return time until auto off."""
return self.data["delay_min"] return self.data["delay_min"]
async def set_delay(self, delay: int): async def set_delay(self, delay: int) -> dict:
"""Set time until auto off.""" """Set time until auto off."""
return await self.call( return await self.call(
"set_auto_off_config", {"delay_min": delay, "enable": self.data["enable"]} "set_auto_off_config", {"delay_min": delay, "enable": self.data["enable"]}
@ -96,7 +96,7 @@ class AutoOff(SmartModule):
return self._device.time + timedelta(seconds=sysinfo["auto_off_remain_time"]) return self._device.time + timedelta(seconds=sysinfo["auto_off_remain_time"])
async def _check_supported(self): async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device. """Additional check to see if the module is supported by the device.
Parent devices that report components of children such as P300 will not have Parent devices that report components of children such as P300 will not have

View File

@ -12,7 +12,7 @@ class BatterySensor(SmartModule):
REQUIRED_COMPONENT = "battery_detect" REQUIRED_COMPONENT = "battery_detect"
QUERY_GETTER_NAME = "get_battery_detect_info" QUERY_GETTER_NAME = "get_battery_detect_info"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -48,11 +48,11 @@ class BatterySensor(SmartModule):
return {} return {}
@property @property
def battery(self): def battery(self) -> int:
"""Return battery level.""" """Return battery level."""
return self._device.sys_info["battery_percentage"] return self._device.sys_info["battery_percentage"]
@property @property
def battery_low(self): def battery_low(self) -> bool:
"""Return True if battery is low.""" """Return True if battery is low."""
return self._device.sys_info["at_low_battery"] return self._device.sys_info["at_low_battery"]

View File

@ -14,7 +14,7 @@ class Brightness(SmartModule):
REQUIRED_COMPONENT = "brightness" REQUIRED_COMPONENT = "brightness"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
super()._initialize_features() super()._initialize_features()
@ -39,7 +39,7 @@ class Brightness(SmartModule):
return {} return {}
@property @property
def brightness(self): def brightness(self) -> int:
"""Return current brightness.""" """Return current brightness."""
# If the device supports effects and one is active, use its brightness # If the device supports effects and one is active, use its brightness
if ( if (
@ -49,7 +49,9 @@ class Brightness(SmartModule):
return self.data["brightness"] return self.data["brightness"]
async def set_brightness(self, brightness: int, *, transition: int | None = None): async def set_brightness(
self, brightness: int, *, transition: int | None = None
) -> dict:
"""Set the brightness. A brightness value of 0 will turn off the light. """Set the brightness. A brightness value of 0 will turn off the light.
Note, transition is not supported and will be ignored. Note, transition is not supported and will be ignored.
@ -73,6 +75,6 @@ class Brightness(SmartModule):
return await self.call("set_device_info", {"brightness": brightness}) return await self.call("set_device_info", {"brightness": brightness})
async def _check_supported(self): async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device.""" """Additional check to see if the module is supported by the device."""
return "brightness" in self.data return "brightness" in self.data

View File

@ -12,7 +12,7 @@ class ChildProtection(SmartModule):
REQUIRED_COMPONENT = "child_protection" REQUIRED_COMPONENT = "child_protection"
QUERY_GETTER_NAME = "get_child_protection" QUERY_GETTER_NAME = "get_child_protection"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(

View File

@ -13,7 +13,7 @@ class Cloud(SmartModule):
REQUIRED_COMPONENT = "cloud_connect" REQUIRED_COMPONENT = "cloud_connect"
MINIMUM_UPDATE_INTERVAL_SECS = 60 MINIMUM_UPDATE_INTERVAL_SECS = 60
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -29,7 +29,7 @@ class Cloud(SmartModule):
) )
@property @property
def is_connected(self): def is_connected(self) -> bool:
"""Return True if device is connected to the cloud.""" """Return True if device is connected to the cloud."""
if self._has_data_error(): if self._has_data_error():
return False return False

View File

@ -12,7 +12,7 @@ class Color(SmartModule):
REQUIRED_COMPONENT = "color" REQUIRED_COMPONENT = "color"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -48,7 +48,7 @@ class Color(SmartModule):
# due to the cpython implementation. # due to the cpython implementation.
return tuple.__new__(HSV, (h, s, v)) return tuple.__new__(HSV, (h, s, v))
def _raise_for_invalid_brightness(self, value): def _raise_for_invalid_brightness(self, value: int) -> None:
"""Raise error on invalid brightness value.""" """Raise error on invalid brightness value."""
if not isinstance(value, int): if not isinstance(value, int):
raise TypeError("Brightness must be an integer") raise TypeError("Brightness must be an integer")

View File

@ -18,7 +18,7 @@ class ColorTemperature(SmartModule):
REQUIRED_COMPONENT = "color_temperature" REQUIRED_COMPONENT = "color_temperature"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -52,11 +52,11 @@ class ColorTemperature(SmartModule):
return ColorTempRange(*ct_range) return ColorTempRange(*ct_range)
@property @property
def color_temp(self): def color_temp(self) -> int:
"""Return current color temperature.""" """Return current color temperature."""
return self.data["color_temp"] return self.data["color_temp"]
async def set_color_temp(self, temp: int, *, brightness=None): async def set_color_temp(self, temp: int, *, brightness: int | None = None) -> dict:
"""Set the color temperature.""" """Set the color temperature."""
valid_temperature_range = self.valid_temperature_range valid_temperature_range = self.valid_temperature_range
if temp < valid_temperature_range[0] or temp > valid_temperature_range[1]: if temp < valid_temperature_range[0] or temp > valid_temperature_range[1]:

View File

@ -12,7 +12,7 @@ class ContactSensor(SmartModule):
REQUIRED_COMPONENT = None # we depend on availability of key REQUIRED_COMPONENT = None # we depend on availability of key
REQUIRED_KEY_ON_PARENT = "open" REQUIRED_KEY_ON_PARENT = "open"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -32,6 +32,6 @@ class ContactSensor(SmartModule):
return {} return {}
@property @property
def is_open(self): def is_open(self) -> bool:
"""Return True if the contact sensor is open.""" """Return True if the contact sensor is open."""
return self._device.sys_info["open"] return self._device.sys_info["open"]

View File

@ -10,7 +10,7 @@ class DeviceModule(SmartModule):
REQUIRED_COMPONENT = "device" REQUIRED_COMPONENT = "device"
async def _post_update_hook(self): async def _post_update_hook(self) -> None:
"""Perform actions after a device update. """Perform actions after a device update.
Overrides the default behaviour to disable a module if the query returns Overrides the default behaviour to disable a module if the query returns

View File

@ -2,6 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import NoReturn
from ...emeterstatus import EmeterStatus from ...emeterstatus import EmeterStatus
from ...exceptions import KasaException from ...exceptions import KasaException
from ...interfaces.energy import Energy as EnergyInterface from ...interfaces.energy import Energy as EnergyInterface
@ -31,34 +33,34 @@ class Energy(SmartModule, EnergyInterface):
# Fallback if get_energy_usage does not provide current_power, # Fallback if get_energy_usage does not provide current_power,
# which can happen on some newer devices (e.g. P304M). # which can happen on some newer devices (e.g. P304M).
elif ( elif (
power := self.data.get("get_current_power").get("current_power") power := self.data.get("get_current_power", {}).get("current_power")
) is not None: ) is not None:
return power return power
return None return None
@property @property
@raise_if_update_error @raise_if_update_error
def energy(self): def energy(self) -> dict:
"""Return get_energy_usage results.""" """Return get_energy_usage results."""
if en := self.data.get("get_energy_usage"): if en := self.data.get("get_energy_usage"):
return en return en
return self.data return self.data
def _get_status_from_energy(self, energy) -> EmeterStatus: def _get_status_from_energy(self, energy: dict) -> EmeterStatus:
return EmeterStatus( return EmeterStatus(
{ {
"power_mw": energy.get("current_power"), "power_mw": energy.get("current_power", 0),
"total": energy.get("today_energy") / 1_000, "total": energy.get("today_energy", 0) / 1_000,
} }
) )
@property @property
@raise_if_update_error @raise_if_update_error
def status(self): def status(self) -> EmeterStatus:
"""Get the emeter status.""" """Get the emeter status."""
return self._get_status_from_energy(self.energy) return self._get_status_from_energy(self.energy)
async def get_status(self): async def get_status(self) -> EmeterStatus:
"""Return real-time statistics.""" """Return real-time statistics."""
res = await self.call("get_energy_usage") res = await self.call("get_energy_usage")
return self._get_status_from_energy(res["get_energy_usage"]) return self._get_status_from_energy(res["get_energy_usage"])
@ -67,13 +69,13 @@ class Energy(SmartModule, EnergyInterface):
@raise_if_update_error @raise_if_update_error
def consumption_this_month(self) -> float | None: def consumption_this_month(self) -> float | None:
"""Get the emeter value for this month in kWh.""" """Get the emeter value for this month in kWh."""
return self.energy.get("month_energy") / 1_000 return self.energy.get("month_energy", 0) / 1_000
@property @property
@raise_if_update_error @raise_if_update_error
def consumption_today(self) -> float | None: def consumption_today(self) -> float | None:
"""Get the emeter value for today in kWh.""" """Get the emeter value for today in kWh."""
return self.energy.get("today_energy") / 1_000 return self.energy.get("today_energy", 0) / 1_000
@property @property
@raise_if_update_error @raise_if_update_error
@ -97,22 +99,26 @@ class Energy(SmartModule, EnergyInterface):
"""Retrieve current energy readings.""" """Retrieve current energy readings."""
return self.status return self.status
async def erase_stats(self): async def erase_stats(self) -> NoReturn:
"""Erase all stats.""" """Erase all stats."""
raise KasaException("Device does not support periodic statistics") raise KasaException("Device does not support periodic statistics")
async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict: async def get_daily_stats(
self, *, year: int | None = None, month: int | None = None, kwh: bool = True
) -> dict:
"""Return daily stats for the given year & month. """Return daily stats for the given year & month.
The return value is a dictionary of {day: energy, ...}. The return value is a dictionary of {day: energy, ...}.
""" """
raise KasaException("Device does not support periodic statistics") raise KasaException("Device does not support periodic statistics")
async def get_monthly_stats(self, *, year=None, kwh=True) -> dict: async def get_monthly_stats(
self, *, year: int | None = None, kwh: bool = True
) -> dict:
"""Return monthly stats for the given year.""" """Return monthly stats for the given year."""
raise KasaException("Device does not support periodic statistics") raise KasaException("Device does not support periodic statistics")
async def _check_supported(self): async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device.""" """Additional check to see if the module is supported by the device."""
# Energy module is not supported on P304M parent device # Energy module is not supported on P304M parent device
return "device_on" in self._device.sys_info return "device_on" in self._device.sys_info

View File

@ -12,7 +12,7 @@ class Fan(SmartModule, FanInterface):
REQUIRED_COMPONENT = "fan_control" REQUIRED_COMPONENT = "fan_control"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -50,7 +50,7 @@ class Fan(SmartModule, FanInterface):
"""Return fan speed level.""" """Return fan speed level."""
return 0 if self.data["device_on"] is False else self.data["fan_speed_level"] return 0 if self.data["device_on"] is False else self.data["fan_speed_level"]
async def set_fan_speed_level(self, level: int): async def set_fan_speed_level(self, level: int) -> dict:
"""Set fan speed level, 0 for off, 1-4 for on.""" """Set fan speed level, 0 for off, 1-4 for on."""
if level < 0 or level > 4: if level < 0 or level > 4:
raise ValueError("Invalid level, should be in range 0-4.") raise ValueError("Invalid level, should be in range 0-4.")
@ -65,10 +65,10 @@ class Fan(SmartModule, FanInterface):
"""Return sleep mode status.""" """Return sleep mode status."""
return self.data["fan_sleep_mode_on"] return self.data["fan_sleep_mode_on"]
async def set_sleep_mode(self, on: bool): async def set_sleep_mode(self, on: bool) -> dict:
"""Set sleep mode.""" """Set sleep mode."""
return await self.call("set_device_info", {"fan_sleep_mode_on": on}) return await self.call("set_device_info", {"fan_sleep_mode_on": on})
async def _check_supported(self): async def _check_supported(self) -> bool:
"""Is the module available on this device.""" """Is the module available on this device."""
return "fan_speed_level" in self.data return "fan_speed_level" in self.data

View File

@ -49,14 +49,14 @@ class UpdateInfo(BaseModel):
needs_upgrade: bool = Field(alias="need_to_upgrade") needs_upgrade: bool = Field(alias="need_to_upgrade")
@validator("release_date", pre=True) @validator("release_date", pre=True)
def _release_date_optional(cls, v): def _release_date_optional(cls, v: str) -> str | None:
if not v: if not v:
return None return None
return v return v
@property @property
def update_available(self): def update_available(self) -> bool:
"""Return True if update available.""" """Return True if update available."""
if self.status != 0: if self.status != 0:
return True return True
@ -69,11 +69,11 @@ class Firmware(SmartModule):
REQUIRED_COMPONENT = "firmware" REQUIRED_COMPONENT = "firmware"
MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60 * 24 MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60 * 24
def __init__(self, device: SmartDevice, module: str): def __init__(self, device: SmartDevice, module: str) -> None:
super().__init__(device, module) super().__init__(device, module)
self._firmware_update_info: UpdateInfo | None = None self._firmware_update_info: UpdateInfo | None = None
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
device = self._device device = self._device
if self.supported_version > 1: if self.supported_version > 1:
@ -183,7 +183,7 @@ class Firmware(SmartModule):
@allow_update_after @allow_update_after
async def update( async def update(
self, progress_cb: Callable[[DownloadState], Coroutine] | None = None self, progress_cb: Callable[[DownloadState], Coroutine] | None = None
): ) -> dict:
"""Update the device firmware.""" """Update the device firmware."""
if not self._firmware_update_info: if not self._firmware_update_info:
raise KasaException( raise KasaException(
@ -236,13 +236,15 @@ class Firmware(SmartModule):
else: else:
_LOGGER.warning("Unhandled state code: %s", state) _LOGGER.warning("Unhandled state code: %s", state)
return state.dict()
@property @property
def auto_update_enabled(self) -> bool: def auto_update_enabled(self) -> bool:
"""Return True if autoupdate is enabled.""" """Return True if autoupdate is enabled."""
return "enable" in self.data and self.data["enable"] return "enable" in self.data and self.data["enable"]
@allow_update_after @allow_update_after
async def set_auto_update_enabled(self, enabled: bool): async def set_auto_update_enabled(self, enabled: bool) -> dict:
"""Change autoupdate setting.""" """Change autoupdate setting."""
data = {**self.data, "enable": enabled} data = {**self.data, "enable": enabled}
await self.call("set_auto_update_info", data) return await self.call("set_auto_update_info", data)

View File

@ -23,7 +23,7 @@ class FrostProtection(SmartModule):
"""Return True if frost protection is on.""" """Return True if frost protection is on."""
return self._device.sys_info["frost_protection_on"] return self._device.sys_info["frost_protection_on"]
async def set_enabled(self, enable: bool): async def set_enabled(self, enable: bool) -> dict:
"""Enable/disable frost protection.""" """Enable/disable frost protection."""
return await self.call( return await self.call(
"set_device_info", "set_device_info",

View File

@ -12,7 +12,7 @@ class HumiditySensor(SmartModule):
REQUIRED_COMPONENT = "humidity" REQUIRED_COMPONENT = "humidity"
QUERY_GETTER_NAME = "get_comfort_humidity_config" QUERY_GETTER_NAME = "get_comfort_humidity_config"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -45,7 +45,7 @@ class HumiditySensor(SmartModule):
return {} return {}
@property @property
def humidity(self): def humidity(self) -> int:
"""Return current humidity in percentage.""" """Return current humidity in percentage."""
return self._device.sys_info["current_humidity"] return self._device.sys_info["current_humidity"]

View File

@ -19,7 +19,7 @@ class Led(SmartModule, LedInterface):
return {self.QUERY_GETTER_NAME: None} return {self.QUERY_GETTER_NAME: None}
@property @property
def mode(self): def mode(self) -> str:
"""LED mode setting. """LED mode setting.
"always", "never", "night_mode" "always", "never", "night_mode"
@ -27,12 +27,12 @@ class Led(SmartModule, LedInterface):
return self.data["led_rule"] return self.data["led_rule"]
@property @property
def led(self): def led(self) -> bool:
"""Return current led status.""" """Return current led status."""
return self.data["led_rule"] != "never" return self.data["led_rule"] != "never"
@allow_update_after @allow_update_after
async def set_led(self, enable: bool): async def set_led(self, enable: bool) -> dict:
"""Set led. """Set led.
This should probably be a select with always/never/nightmode. This should probably be a select with always/never/nightmode.
@ -41,7 +41,7 @@ class Led(SmartModule, LedInterface):
return await self.call("set_led_info", dict(self.data, **{"led_rule": rule})) return await self.call("set_led_info", dict(self.data, **{"led_rule": rule}))
@property @property
def night_mode_settings(self): def night_mode_settings(self) -> dict:
"""Night mode settings.""" """Night mode settings."""
return { return {
"start": self.data["start_time"], "start": self.data["start_time"],

View File

@ -96,7 +96,7 @@ class Light(SmartModule, LightInterface):
return await self._device.modules[Module.Color].set_hsv(hue, saturation, value) return await self._device.modules[Module.Color].set_hsv(hue, saturation, value)
async def set_color_temp( async def set_color_temp(
self, temp: int, *, brightness=None, transition: int | None = None self, temp: int, *, brightness: int | None = None, transition: int | None = None
) -> dict: ) -> dict:
"""Set the color temperature of the device in kelvin. """Set the color temperature of the device in kelvin.

View File

@ -81,7 +81,7 @@ class LightEffect(SmartModule, SmartLightEffect):
*, *,
brightness: int | None = None, brightness: int | None = None,
transition: int | None = None, transition: int | None = None,
) -> None: ) -> dict:
"""Set an effect for the device. """Set an effect for the device.
Calling this will modify the brightness of the effect on the device. Calling this will modify the brightness of the effect on the device.
@ -107,7 +107,7 @@ class LightEffect(SmartModule, SmartLightEffect):
) )
await self.set_brightness(brightness, effect_id=effect_id) await self.set_brightness(brightness, effect_id=effect_id)
await self.call("set_dynamic_light_effect_rule_enable", params) return await self.call("set_dynamic_light_effect_rule_enable", params)
@property @property
def is_active(self) -> bool: def is_active(self) -> bool:
@ -139,11 +139,11 @@ class LightEffect(SmartModule, SmartLightEffect):
*, *,
transition: int | None = None, transition: int | None = None,
effect_id: str | None = None, effect_id: str | None = None,
): ) -> dict:
"""Set effect brightness.""" """Set effect brightness."""
new_effect = self._get_effect_data(effect_id=effect_id).copy() new_effect = self._get_effect_data(effect_id=effect_id).copy()
def _replace_brightness(data, new_brightness): def _replace_brightness(data: list[int], new_brightness: int) -> list[int]:
"""Replace brightness. """Replace brightness.
The first element is the brightness, the rest are unknown. The first element is the brightness, the rest are unknown.
@ -163,7 +163,7 @@ class LightEffect(SmartModule, SmartLightEffect):
async def set_custom_effect( async def set_custom_effect(
self, self,
effect_dict: dict, effect_dict: dict,
) -> None: ) -> dict:
"""Set a custom effect on the device. """Set a custom effect on the device.
:param str effect_dict: The custom effect dict to set :param str effect_dict: The custom effect dict to set

View File

@ -29,12 +29,12 @@ class LightPreset(SmartModule, LightPresetInterface):
_presets: dict[str, LightState] _presets: dict[str, LightState]
_preset_list: list[str] _preset_list: list[str]
def __init__(self, device: SmartDevice, module: str): def __init__(self, device: SmartDevice, module: str) -> None:
super().__init__(device, module) super().__init__(device, module)
self._state_in_sysinfo = self.SYS_INFO_STATE_KEY in device.sys_info self._state_in_sysinfo = self.SYS_INFO_STATE_KEY in device.sys_info
self._brightness_only: bool = False self._brightness_only: bool = False
async def _post_update_hook(self): async def _post_update_hook(self) -> None:
"""Update the internal presets.""" """Update the internal presets."""
index = 0 index = 0
self._presets = {} self._presets = {}
@ -113,7 +113,7 @@ class LightPreset(SmartModule, LightPresetInterface):
async def set_preset( async def set_preset(
self, self,
preset_name: str, preset_name: str,
) -> None: ) -> dict:
"""Set a light preset for the device.""" """Set a light preset for the device."""
light = self._device.modules[SmartModule.Light] light = self._device.modules[SmartModule.Light]
if preset_name == self.PRESET_NOT_SET: if preset_name == self.PRESET_NOT_SET:
@ -123,14 +123,14 @@ class LightPreset(SmartModule, LightPresetInterface):
preset = LightState(brightness=100) preset = LightState(brightness=100)
elif (preset := self._presets.get(preset_name)) is None: # type: ignore[assignment] elif (preset := self._presets.get(preset_name)) is None: # type: ignore[assignment]
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}") raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
await self._device.modules[SmartModule.Light].set_state(preset) return await self._device.modules[SmartModule.Light].set_state(preset)
@allow_update_after @allow_update_after
async def save_preset( async def save_preset(
self, self,
preset_name: str, preset_name: str,
preset_state: LightState, preset_state: LightState,
) -> None: ) -> dict:
"""Update the preset with preset_name with the new preset_info.""" """Update the preset with preset_name with the new preset_info."""
if preset_name not in self._presets: if preset_name not in self._presets:
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}") raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
@ -138,11 +138,13 @@ class LightPreset(SmartModule, LightPresetInterface):
if self._brightness_only: if self._brightness_only:
bright_list = [state.brightness for state in self._presets.values()] bright_list = [state.brightness for state in self._presets.values()]
bright_list[index] = preset_state.brightness bright_list[index] = preset_state.brightness
await self.call("set_preset_rules", {"brightness": bright_list}) return await self.call("set_preset_rules", {"brightness": bright_list})
else: else:
state_params = asdict(preset_state) state_params = asdict(preset_state)
new_info = {k: v for k, v in state_params.items() if v is not None} new_info = {k: v for k, v in state_params.items() if v is not None}
await self.call("edit_preset_rules", {"index": index, "state": new_info}) return await self.call(
"edit_preset_rules", {"index": index, "state": new_info}
)
@property @property
def has_save_preset(self) -> bool: def has_save_preset(self) -> bool:
@ -158,7 +160,7 @@ class LightPreset(SmartModule, LightPresetInterface):
return {self.QUERY_GETTER_NAME: {"start_index": 0}} return {self.QUERY_GETTER_NAME: {"start_index": 0}}
async def _check_supported(self): async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device. """Additional check to see if the module is supported by the device.
Parent devices that report components of children such as ks240 will not have Parent devices that report components of children such as ks240 will not have

View File

@ -16,7 +16,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
REQUIRED_COMPONENT = "light_strip_lighting_effect" REQUIRED_COMPONENT = "light_strip_lighting_effect"
def __init__(self, device: SmartDevice, module: str): def __init__(self, device: SmartDevice, module: str) -> None:
super().__init__(device, module) super().__init__(device, module)
effect_list = [self.LIGHT_EFFECTS_OFF] effect_list = [self.LIGHT_EFFECTS_OFF]
effect_list.extend(EFFECT_NAMES) effect_list.extend(EFFECT_NAMES)
@ -66,7 +66,9 @@ class LightStripEffect(SmartModule, SmartLightEffect):
eff = self.data["lighting_effect"] eff = self.data["lighting_effect"]
return eff["brightness"] return eff["brightness"]
async def set_brightness(self, brightness: int, *, transition: int | None = None): async def set_brightness(
self, brightness: int, *, transition: int | None = None
) -> dict:
"""Set effect brightness.""" """Set effect brightness."""
if brightness <= 0: if brightness <= 0:
return await self.set_effect(self.LIGHT_EFFECTS_OFF) return await self.set_effect(self.LIGHT_EFFECTS_OFF)
@ -91,7 +93,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
*, *,
brightness: int | None = None, brightness: int | None = None,
transition: int | None = None, transition: int | None = None,
) -> None: ) -> dict:
"""Set an effect on the device. """Set an effect on the device.
If brightness or transition is defined, If brightness or transition is defined,
@ -115,8 +117,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
effect_dict = self._effect_mapping["Aurora"] effect_dict = self._effect_mapping["Aurora"]
effect_dict = {**effect_dict} effect_dict = {**effect_dict}
effect_dict["enable"] = 0 effect_dict["enable"] = 0
await self.set_custom_effect(effect_dict) return await self.set_custom_effect(effect_dict)
return
if effect not in self._effect_mapping: if effect not in self._effect_mapping:
raise ValueError(f"The effect {effect} is not a built in effect.") raise ValueError(f"The effect {effect} is not a built in effect.")
@ -134,13 +135,13 @@ class LightStripEffect(SmartModule, SmartLightEffect):
if transition is not None: if transition is not None:
effect_dict["transition"] = transition effect_dict["transition"] = transition
await self.set_custom_effect(effect_dict) return await self.set_custom_effect(effect_dict)
@allow_update_after @allow_update_after
async def set_custom_effect( async def set_custom_effect(
self, self,
effect_dict: dict, effect_dict: dict,
) -> None: ) -> dict:
"""Set a custom effect on the device. """Set a custom effect on the device.
:param str effect_dict: The custom effect dict to set :param str effect_dict: The custom effect dict to set
@ -155,7 +156,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
"""Return True if the device supports setting custom effects.""" """Return True if the device supports setting custom effects."""
return True return True
def query(self): def query(self) -> dict:
"""Return the base query.""" """Return the base query."""
return {} return {}

View File

@ -39,14 +39,14 @@ class LightTransition(SmartModule):
_off_state: _State _off_state: _State
_enabled: bool _enabled: bool
def __init__(self, device: SmartDevice, module: str): def __init__(self, device: SmartDevice, module: str) -> None:
super().__init__(device, module) super().__init__(device, module)
self._state_in_sysinfo = all( self._state_in_sysinfo = all(
key in device.sys_info for key in self.SYS_INFO_STATE_KEYS key in device.sys_info for key in self.SYS_INFO_STATE_KEYS
) )
self._supports_on_and_off: bool = self.supported_version > 1 self._supports_on_and_off: bool = self.supported_version > 1
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
icon = "mdi:transition" icon = "mdi:transition"
if not self._supports_on_and_off: if not self._supports_on_and_off:
@ -138,7 +138,7 @@ class LightTransition(SmartModule):
} }
@allow_update_after @allow_update_after
async def set_enabled(self, enable: bool): async def set_enabled(self, enable: bool) -> dict:
"""Enable gradual on/off.""" """Enable gradual on/off."""
if not self._supports_on_and_off: if not self._supports_on_and_off:
return await self.call("set_on_off_gradually_info", {"enable": enable}) return await self.call("set_on_off_gradually_info", {"enable": enable})
@ -171,7 +171,7 @@ class LightTransition(SmartModule):
return self._on_state["max_duration"] return self._on_state["max_duration"]
@allow_update_after @allow_update_after
async def set_turn_on_transition(self, seconds: int): async def set_turn_on_transition(self, seconds: int) -> dict:
"""Set turn on transition in seconds. """Set turn on transition in seconds.
Setting to 0 turns the feature off. Setting to 0 turns the feature off.
@ -207,7 +207,7 @@ class LightTransition(SmartModule):
return self._off_state["max_duration"] return self._off_state["max_duration"]
@allow_update_after @allow_update_after
async def set_turn_off_transition(self, seconds: int): async def set_turn_off_transition(self, seconds: int) -> dict:
"""Set turn on transition in seconds. """Set turn on transition in seconds.
Setting to 0 turns the feature off. Setting to 0 turns the feature off.
@ -236,7 +236,7 @@ class LightTransition(SmartModule):
else: else:
return {self.QUERY_GETTER_NAME: None} return {self.QUERY_GETTER_NAME: None}
async def _check_supported(self): async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device.""" """Additional check to see if the module is supported by the device."""
# For devices that report child components on the parent that are not # For devices that report child components on the parent that are not
# actually supported by the parent. # actually supported by the parent.

View File

@ -11,7 +11,7 @@ class MotionSensor(SmartModule):
REQUIRED_COMPONENT = "sensitivity" REQUIRED_COMPONENT = "sensitivity"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features.""" """Initialize features."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -31,6 +31,6 @@ class MotionSensor(SmartModule):
return {} return {}
@property @property
def motion_detected(self): def motion_detected(self) -> bool:
"""Return True if the motion has been detected.""" """Return True if the motion has been detected."""
return self._device.sys_info["detected"] return self._device.sys_info["detected"]

View File

@ -12,7 +12,7 @@ class ReportMode(SmartModule):
REQUIRED_COMPONENT = "report_mode" REQUIRED_COMPONENT = "report_mode"
QUERY_GETTER_NAME = "get_report_mode" QUERY_GETTER_NAME = "get_report_mode"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -32,6 +32,6 @@ class ReportMode(SmartModule):
return {} return {}
@property @property
def report_interval(self): def report_interval(self) -> int:
"""Reporting interval of a sensor device.""" """Reporting interval of a sensor device."""
return self._device.sys_info["report_interval"] return self._device.sys_info["report_interval"]

View File

@ -26,7 +26,7 @@ class TemperatureControl(SmartModule):
REQUIRED_COMPONENT = "temp_control" REQUIRED_COMPONENT = "temp_control"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -92,7 +92,7 @@ class TemperatureControl(SmartModule):
"""Return thermostat state.""" """Return thermostat state."""
return self._device.sys_info["frost_protection_on"] is False return self._device.sys_info["frost_protection_on"] is False
async def set_state(self, enabled: bool): async def set_state(self, enabled: bool) -> dict:
"""Set thermostat state.""" """Set thermostat state."""
return await self.call("set_device_info", {"frost_protection_on": not enabled}) return await self.call("set_device_info", {"frost_protection_on": not enabled})
@ -147,7 +147,7 @@ class TemperatureControl(SmartModule):
"""Return thermostat states.""" """Return thermostat states."""
return set(self._device.sys_info["trv_states"]) return set(self._device.sys_info["trv_states"])
async def set_target_temperature(self, target: float): async def set_target_temperature(self, target: float) -> dict:
"""Set target temperature.""" """Set target temperature."""
if ( if (
target < self.minimum_target_temperature target < self.minimum_target_temperature
@ -170,7 +170,7 @@ class TemperatureControl(SmartModule):
"""Return temperature offset.""" """Return temperature offset."""
return self._device.sys_info["temp_offset"] return self._device.sys_info["temp_offset"]
async def set_temperature_offset(self, offset: int): async def set_temperature_offset(self, offset: int) -> dict:
"""Set temperature offset.""" """Set temperature offset."""
if offset < -10 or offset > 10: if offset < -10 or offset > 10:
raise ValueError("Temperature offset must be [-10, 10]") raise ValueError("Temperature offset must be [-10, 10]")

View File

@ -14,7 +14,7 @@ class TemperatureSensor(SmartModule):
REQUIRED_COMPONENT = "temperature" REQUIRED_COMPONENT = "temperature"
QUERY_GETTER_NAME = "get_comfort_temp_config" QUERY_GETTER_NAME = "get_comfort_temp_config"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -60,7 +60,7 @@ class TemperatureSensor(SmartModule):
return {} return {}
@property @property
def temperature(self): def temperature(self) -> float:
"""Return current humidity in percentage.""" """Return current humidity in percentage."""
return self._device.sys_info["current_temp"] return self._device.sys_info["current_temp"]
@ -74,6 +74,8 @@ class TemperatureSensor(SmartModule):
"""Return current temperature unit.""" """Return current temperature unit."""
return self._device.sys_info["temp_unit"] return self._device.sys_info["temp_unit"]
async def set_temperature_unit(self, unit: Literal["celsius", "fahrenheit"]): async def set_temperature_unit(
self, unit: Literal["celsius", "fahrenheit"]
) -> dict:
"""Set the device temperature unit.""" """Set the device temperature unit."""
return await self.call("set_temperature_unit", {"temp_unit": unit}) return await self.call("set_temperature_unit", {"temp_unit": unit})

View File

@ -21,7 +21,7 @@ class Time(SmartModule, TimeInterface):
_timezone: tzinfo = timezone.utc _timezone: tzinfo = timezone.utc
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -35,7 +35,7 @@ class Time(SmartModule, TimeInterface):
) )
) )
async def _post_update_hook(self): async def _post_update_hook(self) -> None:
"""Perform actions after a device update.""" """Perform actions after a device update."""
td = timedelta(minutes=cast(float, self.data.get("time_diff"))) td = timedelta(minutes=cast(float, self.data.get("time_diff")))
if region := self.data.get("region"): if region := self.data.get("region"):
@ -84,7 +84,7 @@ class Time(SmartModule, TimeInterface):
params["region"] = region params["region"] = region
return await self.call("set_device_time", params) return await self.call("set_device_time", params)
async def _check_supported(self): async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device. """Additional check to see if the module is supported by the device.
Hub attached sensors report the time module but do return device time. Hub attached sensors report the time module but do return device time.

View File

@ -22,7 +22,7 @@ class WaterleakSensor(SmartModule):
REQUIRED_COMPONENT = "sensor_alarm" REQUIRED_COMPONENT = "sensor_alarm"
def _initialize_features(self): def _initialize_features(self) -> None:
"""Initialize features after the initial update.""" """Initialize features after the initial update."""
self._add_feature( self._add_feature(
Feature( Feature(

View File

@ -49,7 +49,7 @@ class SmartChildDevice(SmartDevice):
self._update_internal_state(info) self._update_internal_state(info)
self._components = component_info self._components = component_info
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True) -> None:
"""Update child module info. """Update child module info.
The parent updates our internal info so just update modules with The parent updates our internal info so just update modules with
@ -57,7 +57,7 @@ class SmartChildDevice(SmartDevice):
""" """
await self._update(update_children) await self._update(update_children)
async def _update(self, update_children: bool = True): async def _update(self, update_children: bool = True) -> None:
"""Update child module info. """Update child module info.
Internal implementation to allow patching of public update in the cli Internal implementation to allow patching of public update in the cli
@ -118,5 +118,5 @@ class SmartChildDevice(SmartDevice):
dev_type = DeviceType.Unknown dev_type = DeviceType.Unknown
return dev_type return dev_type
def __repr__(self): def __repr__(self) -> str:
return f"<{self.device_type} {self.alias} ({self.model}) of {self._parent}>" return f"<{self.device_type} {self.alias} ({self.model}) of {self._parent}>"

View File

@ -69,7 +69,7 @@ class SmartDevice(Device):
self._on_since: datetime | None = None self._on_since: datetime | None = None
self._info: dict[str, Any] = {} self._info: dict[str, Any] = {}
async def _initialize_children(self): async def _initialize_children(self) -> None:
"""Initialize children for power strips.""" """Initialize children for power strips."""
child_info_query = { child_info_query = {
"get_child_device_component_list": None, "get_child_device_component_list": None,
@ -108,7 +108,9 @@ class SmartDevice(Device):
"""Return the device modules.""" """Return the device modules."""
return cast(ModuleMapping[SmartModule], self._modules) return cast(ModuleMapping[SmartModule], self._modules)
def _try_get_response(self, responses: dict, request: str, default=None) -> dict: def _try_get_response(
self, responses: dict, request: str, default: Any | None = None
) -> dict:
response = responses.get(request) response = responses.get(request)
if isinstance(response, SmartErrorCode): if isinstance(response, SmartErrorCode):
_LOGGER.debug( _LOGGER.debug(
@ -126,7 +128,7 @@ class SmartDevice(Device):
f"{request} not found in {responses} for device {self.host}" f"{request} not found in {responses} for device {self.host}"
) )
async def _negotiate(self): async def _negotiate(self) -> None:
"""Perform initialization. """Perform initialization.
We fetch the device info and the available components as early as possible. We fetch the device info and the available components as early as possible.
@ -146,7 +148,8 @@ class SmartDevice(Device):
self._info = self._try_get_response(resp, "get_device_info") self._info = self._try_get_response(resp, "get_device_info")
# Create our internal presentation of available components # Create our internal presentation of available components
self._components_raw = resp["component_nego"] self._components_raw = cast(dict, resp["component_nego"])
self._components = { self._components = {
comp["id"]: int(comp["ver_code"]) comp["id"]: int(comp["ver_code"])
for comp in self._components_raw["component_list"] for comp in self._components_raw["component_list"]
@ -167,7 +170,7 @@ class SmartDevice(Device):
"""Update the internal device info.""" """Update the internal device info."""
self._info = self._try_get_response(info_resp, "get_device_info") self._info = self._try_get_response(info_resp, "get_device_info")
async def update(self, update_children: bool = False): async def update(self, update_children: bool = False) -> None:
"""Update the device.""" """Update the device."""
if self.credentials is None and self.credentials_hash is None: if self.credentials is None and self.credentials_hash is None:
raise AuthenticationError("Tapo plug requires authentication.") raise AuthenticationError("Tapo plug requires authentication.")
@ -206,7 +209,7 @@ class SmartDevice(Device):
async def _handle_module_post_update( async def _handle_module_post_update(
self, module: SmartModule, update_time: float, had_query: bool self, module: SmartModule, update_time: float, had_query: bool
): ) -> None:
if module.disabled: if module.disabled:
return # pragma: no cover return # pragma: no cover
if had_query: if had_query:
@ -312,7 +315,7 @@ class SmartDevice(Device):
responses[meth] = SmartErrorCode.INTERNAL_QUERY_ERROR responses[meth] = SmartErrorCode.INTERNAL_QUERY_ERROR
return responses return responses
async def _initialize_modules(self): async def _initialize_modules(self) -> None:
"""Initialize modules based on component negotiation response.""" """Initialize modules based on component negotiation response."""
from .smartmodule import SmartModule from .smartmodule import SmartModule
@ -324,7 +327,7 @@ class SmartDevice(Device):
# It also ensures that devices like power strips do not add modules such as # It also ensures that devices like power strips do not add modules such as
# firmware to the child devices. # firmware to the child devices.
skip_parent_only_modules = False skip_parent_only_modules = False
child_modules_to_skip = {} child_modules_to_skip: dict = {} # TODO: this is never non-empty
if self._parent and self._parent.device_type != DeviceType.Hub: if self._parent and self._parent.device_type != DeviceType.Hub:
skip_parent_only_modules = True skip_parent_only_modules = True
@ -333,17 +336,18 @@ class SmartDevice(Device):
skip_parent_only_modules and mod in NON_HUB_PARENT_ONLY_MODULES skip_parent_only_modules and mod in NON_HUB_PARENT_ONLY_MODULES
) or mod.__name__ in child_modules_to_skip: ) or mod.__name__ in child_modules_to_skip:
continue continue
if ( required_component = cast(str, mod.REQUIRED_COMPONENT)
mod.REQUIRED_COMPONENT in self._components if required_component in self._components or (
or self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None mod.REQUIRED_KEY_ON_PARENT
and self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None
): ):
_LOGGER.debug( _LOGGER.debug(
"Device %s, found required %s, adding %s to modules.", "Device %s, found required %s, adding %s to modules.",
self.host, self.host,
mod.REQUIRED_COMPONENT, required_component,
mod.__name__, mod.__name__,
) )
module = mod(self, mod.REQUIRED_COMPONENT) module = mod(self, required_component)
if await module._check_supported(): if await module._check_supported():
self._modules[module.name] = module self._modules[module.name] = module
@ -354,7 +358,7 @@ class SmartDevice(Device):
): ):
self._modules[Light.__name__] = Light(self, "light") self._modules[Light.__name__] = Light(self, "light")
async def _initialize_features(self): async def _initialize_features(self) -> None:
"""Initialize device features.""" """Initialize device features."""
self._add_feature( self._add_feature(
Feature( Feature(
@ -575,11 +579,11 @@ class SmartDevice(Device):
return str(self._info.get("device_id")) return str(self._info.get("device_id"))
@property @property
def internal_state(self) -> Any: def internal_state(self) -> dict:
"""Return all the internal state data.""" """Return all the internal state data."""
return self._last_update return self._last_update
def _update_internal_state(self, info: dict) -> None: def _update_internal_state(self, info: dict[str, Any]) -> None:
"""Update the internal info state. """Update the internal info state.
This is used by the parent to push updates to its children. This is used by the parent to push updates to its children.
@ -587,8 +591,8 @@ class SmartDevice(Device):
self._info = info self._info = info
async def _query_helper( async def _query_helper(
self, method: str, params: dict | None = None, child_ids=None self, method: str, params: dict | None = None, child_ids: None = None
) -> Any: ) -> dict:
res = await self.protocol.query({method: params}) res = await self.protocol.query({method: params})
return res return res
@ -610,22 +614,25 @@ class SmartDevice(Device):
"""Return true if the device is on.""" """Return true if the device is on."""
return bool(self._info.get("device_on")) return bool(self._info.get("device_on"))
async def set_state(self, on: bool): # TODO: better name wanted. async def set_state(self, on: bool) -> dict:
"""Set the device state. """Set the device state.
See :meth:`is_on`. See :meth:`is_on`.
""" """
return await self.protocol.query({"set_device_info": {"device_on": on}}) return await self.protocol.query({"set_device_info": {"device_on": on}})
async def turn_on(self, **kwargs): async def turn_on(self, **kwargs: Any) -> dict:
"""Turn on the device.""" """Turn on the device."""
await self.set_state(True) return await self.set_state(True)
async def turn_off(self, **kwargs): async def turn_off(self, **kwargs: Any) -> dict:
"""Turn off the device.""" """Turn off the device."""
await self.set_state(False) return await self.set_state(False)
def update_from_discover_info(self, info): def update_from_discover_info(
self,
info: dict,
) -> None:
"""Update state from info from the discover call.""" """Update state from info from the discover call."""
self._discovery_info = info self._discovery_info = info
self._info = info self._info = info
@ -633,7 +640,7 @@ class SmartDevice(Device):
async def wifi_scan(self) -> list[WifiNetwork]: async def wifi_scan(self) -> list[WifiNetwork]:
"""Scan for available wifi networks.""" """Scan for available wifi networks."""
def _net_for_scan_info(res): def _net_for_scan_info(res: dict) -> WifiNetwork:
return WifiNetwork( return WifiNetwork(
ssid=base64.b64decode(res["ssid"]).decode(), ssid=base64.b64decode(res["ssid"]).decode(),
cipher_type=res["cipher_type"], cipher_type=res["cipher_type"],
@ -651,7 +658,9 @@ class SmartDevice(Device):
] ]
return networks return networks
async def wifi_join(self, ssid: str, password: str, keytype: str = "wpa2_psk"): async def wifi_join(
self, ssid: str, password: str, keytype: str = "wpa2_psk"
) -> dict:
"""Join the given wifi network. """Join the given wifi network.
This method returns nothing as the device tries to activate the new This method returns nothing as the device tries to activate the new
@ -688,9 +697,12 @@ class SmartDevice(Device):
except DeviceError: except DeviceError:
raise # Re-raise on device-reported errors raise # Re-raise on device-reported errors
except KasaException: except KasaException:
_LOGGER.debug("Received an expected for wifi join, but this is expected") _LOGGER.debug(
"Received a kasa exception for wifi join, but this is expected"
)
return {}
async def update_credentials(self, username: str, password: str): async def update_credentials(self, username: str, password: str) -> dict:
"""Update device credentials. """Update device credentials.
This will replace the existing authentication credentials on the device. This will replace the existing authentication credentials on the device.
@ -705,7 +717,7 @@ class SmartDevice(Device):
} }
return await self.protocol.query({"set_qs_info": payload}) return await self.protocol.query({"set_qs_info": payload})
async def set_alias(self, alias: str): async def set_alias(self, alias: str) -> dict:
"""Set the device name (alias).""" """Set the device name (alias)."""
return await self.protocol.query( return await self.protocol.query(
{"set_device_info": {"nickname": base64.b64encode(alias.encode()).decode()}} {"set_device_info": {"nickname": base64.b64encode(alias.encode()).decode()}}

View File

@ -22,17 +22,17 @@ _R = TypeVar("_R")
def allow_update_after( def allow_update_after(
func: Callable[Concatenate[_T, _P], Awaitable[None]], func: Callable[Concatenate[_T, _P], Awaitable[dict]],
) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, None]]: ) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, dict]]:
"""Define a wrapper to set _last_update_time to None. """Define a wrapper to set _last_update_time to None.
This will ensure that a module is updated in the next update cycle after This will ensure that a module is updated in the next update cycle after
a value has been changed. a value has been changed.
""" """
async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> None: async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> dict:
try: try:
await func(self, *args, **kwargs) return await func(self, *args, **kwargs)
finally: finally:
self._last_update_time = None self._last_update_time = None
@ -68,21 +68,21 @@ class SmartModule(Module):
DISABLE_AFTER_ERROR_COUNT = 10 DISABLE_AFTER_ERROR_COUNT = 10
def __init__(self, device: SmartDevice, module: str): def __init__(self, device: SmartDevice, module: str) -> None:
self._device: SmartDevice self._device: SmartDevice
super().__init__(device, module) super().__init__(device, module)
self._last_update_time: float | None = None self._last_update_time: float | None = None
self._last_update_error: KasaException | None = None self._last_update_error: KasaException | None = None
self._error_count = 0 self._error_count = 0
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs) -> None:
# We only want to register submodules in a modules package so that # We only want to register submodules in a modules package so that
# other classes can inherit from smartmodule and not be registered # other classes can inherit from smartmodule and not be registered
if cls.__module__.split(".")[-2] == "modules": if cls.__module__.split(".")[-2] == "modules":
_LOGGER.debug("Registering %s", cls) _LOGGER.debug("Registering %s", cls)
cls.REGISTERED_MODULES[cls._module_name()] = cls cls.REGISTERED_MODULES[cls._module_name()] = cls
def _set_error(self, err: Exception | None): def _set_error(self, err: Exception | None) -> None:
if err is None: if err is None:
self._error_count = 0 self._error_count = 0
self._last_update_error = None self._last_update_error = None
@ -119,7 +119,7 @@ class SmartModule(Module):
return self._error_count >= self.DISABLE_AFTER_ERROR_COUNT return self._error_count >= self.DISABLE_AFTER_ERROR_COUNT
@classmethod @classmethod
def _module_name(cls): def _module_name(cls) -> str:
return getattr(cls, "NAME", cls.__name__) return getattr(cls, "NAME", cls.__name__)
@property @property
@ -127,7 +127,7 @@ class SmartModule(Module):
"""Name of the module.""" """Name of the module."""
return self._module_name() return self._module_name()
async def _post_update_hook(self): # noqa: B027 async def _post_update_hook(self) -> None: # noqa: B027
"""Perform actions after a device update. """Perform actions after a device update.
Any modules overriding this should ensure that self.data is Any modules overriding this should ensure that self.data is
@ -142,7 +142,7 @@ class SmartModule(Module):
""" """
return {self.QUERY_GETTER_NAME: None} return {self.QUERY_GETTER_NAME: None}
async def call(self, method, params=None): async def call(self, method: str, params: dict | None = None) -> dict:
"""Call a method. """Call a method.
Just a helper method. Just a helper method.
@ -150,7 +150,7 @@ class SmartModule(Module):
return await self._device._query_helper(method, params) return await self._device._query_helper(method, params)
@property @property
def data(self): def data(self) -> dict[str, Any]:
"""Return response data for the module. """Return response data for the module.
If the module performs only a single query, the resulting response is unwrapped. If the module performs only a single query, the resulting response is unwrapped.

View File

@ -72,7 +72,7 @@ class SmartProtocol(BaseProtocol):
) )
self._redact_data = True self._redact_data = True
def get_smart_request(self, method, params=None) -> str: def get_smart_request(self, method: str, params: dict | None = None) -> str:
"""Get a request message as a string.""" """Get a request message as a string."""
request = { request = {
"method": method, "method": method,
@ -289,8 +289,8 @@ class SmartProtocol(BaseProtocol):
return {smart_method: result} return {smart_method: result}
async def _handle_response_lists( async def _handle_response_lists(
self, response_result: dict[str, Any], method, retry_count self, response_result: dict[str, Any], method: str, retry_count: int
): ) -> None:
if ( if (
response_result is None response_result is None
or isinstance(response_result, SmartErrorCode) or isinstance(response_result, SmartErrorCode)
@ -325,7 +325,9 @@ class SmartProtocol(BaseProtocol):
break break
response_result[response_list_name].extend(next_batch[response_list_name]) response_result[response_list_name].extend(next_batch[response_list_name])
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True): def _handle_response_error_code(
self, resp_dict: dict, method: str, raise_on_error: bool = True
) -> None:
error_code_raw = resp_dict.get("error_code") error_code_raw = resp_dict.get("error_code")
try: try:
error_code = SmartErrorCode.from_int(error_code_raw) error_code = SmartErrorCode.from_int(error_code_raw)
@ -369,12 +371,12 @@ class _ChildProtocolWrapper(SmartProtocol):
device responses before returning to the caller. device responses before returning to the caller.
""" """
def __init__(self, device_id: str, base_protocol: SmartProtocol): def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
self._device_id = device_id self._device_id = device_id
self._protocol = base_protocol self._protocol = base_protocol
self._transport = base_protocol._transport self._transport = base_protocol._transport
def _get_method_and_params_for_request(self, request): def _get_method_and_params_for_request(self, request: dict[str, Any] | str) -> Any:
"""Return payload for wrapping. """Return payload for wrapping.
TODO: this does not support batches and requires refactoring in the future. TODO: this does not support batches and requires refactoring in the future.

View File

@ -310,9 +310,7 @@ class FakeSmartTransport(BaseTransport):
} }
return retval return retval
raise NotImplementedError( raise NotImplementedError(f"Method {child_method} not implemented for children")
"Method %s not implemented for children" % child_method
)
def _get_on_off_gradually_info(self, info, params): def _get_on_off_gradually_info(self, info, params):
if self.components["on_off_gradually"] == 1: if self.components["on_off_gradually"] == 1:

View File

@ -41,7 +41,7 @@ async def test_firmware_features(
await fw.check_latest_firmware() await fw.check_latest_firmware()
if fw.supported_version < required_version: if fw.supported_version < required_version:
pytest.skip("Feature %s requires newer version" % feature) pytest.skip(f"Feature {feature} requires newer version")
prop = getattr(fw, prop_name) prop = getattr(fw, prop_name)
assert isinstance(prop, type) assert isinstance(prop, type)

View File

@ -48,7 +48,7 @@ class XorTransport(BaseTransport):
self.loop: asyncio.AbstractEventLoop | None = None self.loop: asyncio.AbstractEventLoop | None = None
@property @property
def default_port(self): def default_port(self) -> int:
"""Default port for the transport.""" """Default port for the transport."""
return self.DEFAULT_PORT return self.DEFAULT_PORT

View File

@ -139,10 +139,15 @@ select = [
"PT", # flake8-pytest-style "PT", # flake8-pytest-style
"LOG", # flake8-logging "LOG", # flake8-logging
"G", # flake8-logging-format "G", # flake8-logging-format
"ANN", # annotations
] ]
ignore = [ ignore = [
"D105", # Missing docstring in magic method "D105", # Missing docstring in magic method
"D107", # Missing docstring in `__init__` "D107", # Missing docstring in `__init__`
"ANN101", # Missing type annotation for `self`
"ANN102", # Missing type annotation for `cls` in classmethod
"ANN003", # Missing type annotation for `**kwargs`
"ANN401", # allow any
] ]
[tool.ruff.lint.pydocstyle] [tool.ruff.lint.pydocstyle]
@ -157,11 +162,21 @@ convention = "pep257"
"D104", "D104",
"S101", # allow asserts "S101", # allow asserts
"E501", # ignore line-too-longs "E501", # ignore line-too-longs
"ANN", # skip for now
] ]
"docs/source/conf.py" = [ "docs/source/conf.py" = [
"D100", "D100",
"D103", "D103",
] ]
# Temporary ANN disable
"kasa/cli/*.py" = [
"ANN",
]
# Temporary ANN disable
"devtools/*.py" = [
"ANN",
]
[tool.mypy] [tool.mypy]
warn_unused_configs = true # warns if overrides sections unused/mis-spelled warn_unused_configs = true # warns if overrides sections unused/mis-spelled