mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 11:13:34 +00:00
Enable ruff check for ANN (#1139)
This commit is contained in:
parent
6b44fe6242
commit
66eb17057e
@ -66,6 +66,6 @@ todo_include_todos = True
|
||||
myst_heading_anchors = 3
|
||||
|
||||
|
||||
def setup(app):
|
||||
def setup(app): # noqa: ANN201,ANN001
|
||||
# add copybutton to hide the >>> prompts, see https://github.com/readthedocs/sphinx_rtd_theme/issues/167
|
||||
app.add_js_file("copybutton.js")
|
||||
|
@ -13,7 +13,7 @@ to be handled by the user of the library.
|
||||
"""
|
||||
|
||||
from importlib.metadata import version
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from warnings import warn
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
@ -101,7 +101,7 @@ deprecated_classes = {
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in deprecated_names:
|
||||
warn(f"{name} is deprecated", DeprecationWarning, stacklevel=2)
|
||||
return globals()[f"_deprecated_{name}"]
|
||||
@ -117,7 +117,7 @@ def __getattr__(name):
|
||||
)
|
||||
return new_class
|
||||
if name in deprecated_classes:
|
||||
new_class = deprecated_classes[name]
|
||||
new_class = deprecated_classes[name] # type: ignore[assignment]
|
||||
msg = f"{name} is deprecated, use {new_class.__name__} instead"
|
||||
warn(msg, DeprecationWarning, stacklevel=2)
|
||||
return new_class
|
||||
|
@ -146,7 +146,7 @@ class AesTransport(BaseTransport):
|
||||
pw = base64.b64encode(credentials.password.encode()).decode()
|
||||
return un, pw
|
||||
|
||||
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
|
||||
def _handle_response_error_code(self, resp_dict: dict, msg: str) -> None:
|
||||
error_code_raw = resp_dict.get("error_code")
|
||||
try:
|
||||
error_code = SmartErrorCode.from_int(error_code_raw)
|
||||
@ -191,14 +191,14 @@ class AesTransport(BaseTransport):
|
||||
+ f"status code {status_code} to passthrough"
|
||||
)
|
||||
|
||||
self._handle_response_error_code(
|
||||
resp_dict, "Error sending secure_passthrough message"
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
resp_dict = cast(Dict[str, Any], resp_dict)
|
||||
assert self._encryption_session is not None
|
||||
|
||||
self._handle_response_error_code(
|
||||
resp_dict, "Error sending secure_passthrough message"
|
||||
)
|
||||
|
||||
raw_response: str = resp_dict["result"]["response"]
|
||||
|
||||
try:
|
||||
@ -219,7 +219,7 @@ class AesTransport(BaseTransport):
|
||||
) from ex
|
||||
return ret_val # type: ignore[return-value]
|
||||
|
||||
async def perform_login(self):
|
||||
async def perform_login(self) -> None:
|
||||
"""Login to the device."""
|
||||
try:
|
||||
await self.try_login(self._login_params)
|
||||
@ -324,11 +324,11 @@ class AesTransport(BaseTransport):
|
||||
+ f"status code {status_code} to handshake"
|
||||
)
|
||||
|
||||
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
resp_dict = cast(Dict[str, Any], resp_dict)
|
||||
|
||||
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
|
||||
|
||||
handshake_key = resp_dict["result"]["key"]
|
||||
|
||||
if (
|
||||
@ -355,7 +355,7 @@ class AesTransport(BaseTransport):
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self._host)
|
||||
|
||||
def _handshake_session_expired(self):
|
||||
def _handshake_session_expired(self) -> bool:
|
||||
"""Return true if session has expired."""
|
||||
return (
|
||||
self._session_expire_at is None
|
||||
@ -394,7 +394,9 @@ class AesEncyptionSession:
|
||||
"""Class for an AES encryption session."""
|
||||
|
||||
@staticmethod
|
||||
def create_from_keypair(handshake_key: str, keypair: KeyPair):
|
||||
def create_from_keypair(
|
||||
handshake_key: str, keypair: KeyPair
|
||||
) -> AesEncyptionSession:
|
||||
"""Create the encryption session."""
|
||||
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())
|
||||
|
||||
@ -404,11 +406,11 @@ class AesEncyptionSession:
|
||||
|
||||
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
|
||||
|
||||
def __init__(self, key, iv):
|
||||
def __init__(self, key: bytes, iv: bytes) -> None:
|
||||
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
||||
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
|
||||
|
||||
def encrypt(self, data) -> bytes:
|
||||
def encrypt(self, data: bytes) -> bytes:
|
||||
"""Encrypt the message."""
|
||||
encryptor = self.cipher.encryptor()
|
||||
padder = self.padding_strategy.padder()
|
||||
@ -416,7 +418,7 @@ class AesEncyptionSession:
|
||||
encrypted = encryptor.update(padded_data) + encryptor.finalize()
|
||||
return base64.b64encode(encrypted)
|
||||
|
||||
def decrypt(self, data) -> str:
|
||||
def decrypt(self, data: str | bytes) -> str:
|
||||
"""Decrypt the message."""
|
||||
decryptor = self.cipher.decryptor()
|
||||
unpadder = self.padding_strategy.unpadder()
|
||||
@ -429,14 +431,16 @@ class KeyPair:
|
||||
"""Class for generating key pairs."""
|
||||
|
||||
@staticmethod
|
||||
def create_key_pair(key_size: int = 1024):
|
||||
def create_key_pair(key_size: int = 1024) -> KeyPair:
|
||||
"""Create a key pair."""
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
||||
public_key = private_key.public_key()
|
||||
return KeyPair(private_key, public_key)
|
||||
|
||||
@staticmethod
|
||||
def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str):
|
||||
def create_from_der_keys(
|
||||
private_key_der_b64: str, public_key_der_b64: str
|
||||
) -> KeyPair:
|
||||
"""Create a key pair."""
|
||||
key_bytes = base64.b64decode(private_key_der_b64.encode())
|
||||
private_key = cast(
|
||||
@ -449,7 +453,9 @@ class KeyPair:
|
||||
|
||||
return KeyPair(private_key, public_key)
|
||||
|
||||
def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey):
|
||||
def __init__(
|
||||
self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey
|
||||
) -> None:
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
self.private_key_der_bytes = self.private_key.private_bytes(
|
||||
|
@ -7,7 +7,7 @@ import re
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from functools import singledispatch, update_wrapper, wraps
|
||||
from typing import Final
|
||||
from typing import TYPE_CHECKING, Any, Callable, Final
|
||||
|
||||
import asyncclick as click
|
||||
|
||||
@ -37,7 +37,7 @@ except ImportError:
|
||||
"""Strip rich formatting from messages."""
|
||||
|
||||
@wraps(echo_func)
|
||||
def wrapper(message=None, *args, **kwargs):
|
||||
def wrapper(message=None, *args, **kwargs) -> None:
|
||||
if message is not None:
|
||||
message = rich_formatting.sub("", message)
|
||||
echo_func(message, *args, **kwargs)
|
||||
@ -47,20 +47,20 @@ except ImportError:
|
||||
_echo = _strip_rich_formatting(click.echo)
|
||||
|
||||
|
||||
def echo(*args, **kwargs):
|
||||
def echo(*args, **kwargs) -> None:
|
||||
"""Print a message."""
|
||||
ctx = click.get_current_context().find_root()
|
||||
if "json" not in ctx.params or ctx.params["json"] is False:
|
||||
_echo(*args, **kwargs)
|
||||
|
||||
|
||||
def error(msg: str):
|
||||
def error(msg: str) -> None:
|
||||
"""Print an error and exit."""
|
||||
echo(f"[bold red]{msg}[/bold red]")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def json_formatter_cb(result, **kwargs):
|
||||
def json_formatter_cb(result: Any, **kwargs) -> None:
|
||||
"""Format and output the result as JSON, if requested."""
|
||||
if not kwargs.get("json"):
|
||||
return
|
||||
@ -82,7 +82,7 @@ def json_formatter_cb(result, **kwargs):
|
||||
print(json_content)
|
||||
|
||||
|
||||
def pass_dev_or_child(wrapped_function):
|
||||
def pass_dev_or_child(wrapped_function: Callable) -> Callable:
|
||||
"""Pass the device or child to the click command based on the child options."""
|
||||
child_help = (
|
||||
"Child ID or alias for controlling sub-devices. "
|
||||
@ -133,7 +133,10 @@ def pass_dev_or_child(wrapped_function):
|
||||
|
||||
|
||||
async def _get_child_device(
|
||||
device: Device, child_option, child_index_option, info_command
|
||||
device: Device,
|
||||
child_option: str | None,
|
||||
child_index_option: int | None,
|
||||
info_command: str | None,
|
||||
) -> Device | None:
|
||||
def _list_children():
|
||||
return "\n".join(
|
||||
@ -178,11 +181,15 @@ async def _get_child_device(
|
||||
f"{child_option} children are:\n{_list_children()}"
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(child_index_option, int)
|
||||
|
||||
if child_index_option + 1 > len(device.children) or child_index_option < 0:
|
||||
error(
|
||||
f"Invalid index {child_index_option}, "
|
||||
f"device has {len(device.children)} children"
|
||||
)
|
||||
|
||||
child_by_index = device.children[child_index_option]
|
||||
echo(f"Targeting child device {child_by_index.alias}")
|
||||
return child_by_index
|
||||
@ -195,7 +202,7 @@ def CatchAllExceptions(cls):
|
||||
https://stackoverflow.com/questions/52213375
|
||||
"""
|
||||
|
||||
def _handle_exception(debug, exc):
|
||||
def _handle_exception(debug, exc) -> None:
|
||||
if isinstance(exc, click.ClickException):
|
||||
raise
|
||||
# Handle exit request from click.
|
||||
|
@ -22,7 +22,7 @@ from .common import (
|
||||
|
||||
@click.group()
|
||||
@pass_dev_or_child
|
||||
def device(dev):
|
||||
def device(dev) -> None:
|
||||
"""Commands to control basic device settings."""
|
||||
|
||||
|
||||
|
@ -36,7 +36,7 @@ async def detail(ctx):
|
||||
auth_failed = []
|
||||
sem = asyncio.Semaphore()
|
||||
|
||||
async def print_unsupported(unsupported_exception: UnsupportedDeviceError):
|
||||
async def print_unsupported(unsupported_exception: UnsupportedDeviceError) -> None:
|
||||
unsupported.append(unsupported_exception)
|
||||
async with sem:
|
||||
if unsupported_exception.discovery_result:
|
||||
@ -50,7 +50,7 @@ async def detail(ctx):
|
||||
|
||||
from .device import state
|
||||
|
||||
async def print_discovered(dev: Device):
|
||||
async def print_discovered(dev: Device) -> None:
|
||||
async with sem:
|
||||
try:
|
||||
await dev.update()
|
||||
@ -189,7 +189,7 @@ async def config(ctx):
|
||||
error(f"Unable to connect to {host}")
|
||||
|
||||
|
||||
def _echo_dictionary(discovery_info: dict):
|
||||
def _echo_dictionary(discovery_info: dict) -> None:
|
||||
echo("\t[bold]== Discovery information ==[/bold]")
|
||||
for key, value in discovery_info.items():
|
||||
key_name = " ".join(x.capitalize() or "_" for x in key.split("_"))
|
||||
@ -197,7 +197,7 @@ def _echo_dictionary(discovery_info: dict):
|
||||
echo(f"\t{key_name_and_spaces}{value}")
|
||||
|
||||
|
||||
def _echo_discovery_info(discovery_info):
|
||||
def _echo_discovery_info(discovery_info) -> None:
|
||||
# We don't have discovery info when all connection params are passed manually
|
||||
if discovery_info is None:
|
||||
return
|
||||
|
@ -24,7 +24,7 @@ def _echo_features(
|
||||
category: Feature.Category | None = None,
|
||||
verbose: bool = False,
|
||||
indent: str = "\t",
|
||||
):
|
||||
) -> None:
|
||||
"""Print out a listing of features and their values."""
|
||||
if category is not None:
|
||||
features = {
|
||||
@ -43,7 +43,9 @@ def _echo_features(
|
||||
echo(f"{indent}{feat.name} ({feat.id}): [red]got exception ({ex})[/red]")
|
||||
|
||||
|
||||
def _echo_all_features(features, *, verbose=False, title_prefix=None, indent=""):
|
||||
def _echo_all_features(
|
||||
features, *, verbose=False, title_prefix=None, indent=""
|
||||
) -> None:
|
||||
"""Print out all features by category."""
|
||||
if title_prefix is not None:
|
||||
echo(f"[bold]\n{indent}== {title_prefix} ==[/bold]")
|
||||
|
@ -3,6 +3,8 @@
|
||||
Taken from the click help files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
|
||||
import asyncclick as click
|
||||
@ -11,7 +13,7 @@ import asyncclick as click
|
||||
class LazyGroup(click.Group):
|
||||
"""Lazy group class."""
|
||||
|
||||
def __init__(self, *args, lazy_subcommands=None, **kwargs):
|
||||
def __init__(self, *args, lazy_subcommands=None, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
# lazy_subcommands is a map of the form:
|
||||
#
|
||||
@ -31,9 +33,9 @@ class LazyGroup(click.Group):
|
||||
return self._lazy_load(cmd_name)
|
||||
return super().get_command(ctx, cmd_name)
|
||||
|
||||
def format_commands(self, ctx, formatter):
|
||||
def format_commands(self, ctx, formatter) -> None:
|
||||
"""Format the top level help output."""
|
||||
sections = {}
|
||||
sections: dict[str, list] = {}
|
||||
for cmd, parent in self.lazy_subcommands.items():
|
||||
sections.setdefault(parent, [])
|
||||
cmd_obj = self.get_command(ctx, cmd)
|
||||
|
@ -15,7 +15,7 @@ from .common import echo, error, pass_dev_or_child
|
||||
|
||||
@click.group()
|
||||
@pass_dev_or_child
|
||||
def light(dev):
|
||||
def light(dev) -> None:
|
||||
"""Commands to control light settings."""
|
||||
|
||||
|
||||
|
@ -43,7 +43,7 @@ ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in DeviceEncryptionType]
|
||||
DEFAULT_TARGET = "255.255.255.255"
|
||||
|
||||
|
||||
def _legacy_type_to_class(_type):
|
||||
def _legacy_type_to_class(_type: str) -> Any:
|
||||
from kasa.iot import (
|
||||
IotBulb,
|
||||
IotDimmer,
|
||||
@ -396,9 +396,9 @@ async def cli(
|
||||
|
||||
@cli.command()
|
||||
@pass_dev_or_child
|
||||
async def shell(dev: Device):
|
||||
async def shell(dev: Device) -> None:
|
||||
"""Open interactive shell."""
|
||||
echo("Opening shell for %s" % dev)
|
||||
echo(f"Opening shell for {dev}")
|
||||
from ptpython.repl import embed
|
||||
|
||||
logging.getLogger("parso").setLevel(logging.WARNING) # prompt parsing
|
||||
|
@ -14,7 +14,7 @@ from .common import (
|
||||
|
||||
@click.group()
|
||||
@pass_dev
|
||||
async def schedule(dev):
|
||||
async def schedule(dev) -> None:
|
||||
"""Scheduling commands."""
|
||||
|
||||
|
||||
|
@ -23,7 +23,7 @@ from .common import (
|
||||
|
||||
@click.group(invoke_without_command=True)
|
||||
@click.pass_context
|
||||
async def time(ctx: click.Context):
|
||||
async def time(ctx: click.Context) -> None:
|
||||
"""Get and set time."""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await ctx.invoke(time_get)
|
||||
|
@ -78,13 +78,13 @@ async def energy(dev: Device, year, month, erase):
|
||||
else:
|
||||
emeter_status = dev.emeter_realtime
|
||||
|
||||
echo("Current: %s A" % emeter_status["current"])
|
||||
echo("Voltage: %s V" % emeter_status["voltage"])
|
||||
echo("Power: %s W" % emeter_status["power"])
|
||||
echo("Total consumption: %s kWh" % emeter_status["total"])
|
||||
echo("Current: {} A".format(emeter_status["current"]))
|
||||
echo("Voltage: {} V".format(emeter_status["voltage"]))
|
||||
echo("Power: {} W".format(emeter_status["power"]))
|
||||
echo("Total consumption: {} kWh".format(emeter_status["total"]))
|
||||
|
||||
echo("Today: %s kWh" % dev.emeter_today)
|
||||
echo("This month: %s kWh" % dev.emeter_this_month)
|
||||
echo(f"Today: {dev.emeter_today} kWh")
|
||||
echo(f"This month: {dev.emeter_this_month} kWh")
|
||||
|
||||
return emeter_status
|
||||
|
||||
@ -122,8 +122,8 @@ async def usage(dev: Device, year, month, erase):
|
||||
usage_data = await usage.get_daystat(year=month.year, month=month.month)
|
||||
else:
|
||||
# Call with no argument outputs summary data and returns
|
||||
echo("Today: %s minutes" % usage.usage_today)
|
||||
echo("This month: %s minutes" % usage.usage_this_month)
|
||||
echo(f"Today: {usage.usage_today} minutes")
|
||||
echo(f"This month: {usage.usage_this_month} minutes")
|
||||
|
||||
return usage
|
||||
|
||||
|
@ -16,7 +16,7 @@ from .common import (
|
||||
|
||||
@click.group()
|
||||
@pass_dev
|
||||
def wifi(dev):
|
||||
def wifi(dev) -> None:
|
||||
"""Commands to control wifi settings."""
|
||||
|
||||
|
||||
|
@ -234,10 +234,10 @@ class Device(ABC):
|
||||
return await connect(host=host, config=config) # type: ignore[arg-type]
|
||||
|
||||
@abstractmethod
|
||||
async def update(self, update_children: bool = True):
|
||||
async def update(self, update_children: bool = True) -> None:
|
||||
"""Update the device."""
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect and close any underlying connection resources."""
|
||||
await self.protocol.close()
|
||||
|
||||
@ -257,15 +257,15 @@ class Device(ABC):
|
||||
return not self.is_on
|
||||
|
||||
@abstractmethod
|
||||
async def turn_on(self, **kwargs) -> dict | None:
|
||||
async def turn_on(self, **kwargs) -> dict:
|
||||
"""Turn on the device."""
|
||||
|
||||
@abstractmethod
|
||||
async def turn_off(self, **kwargs) -> dict | None:
|
||||
async def turn_off(self, **kwargs) -> dict:
|
||||
"""Turn off the device."""
|
||||
|
||||
@abstractmethod
|
||||
async def set_state(self, on: bool):
|
||||
async def set_state(self, on: bool) -> dict:
|
||||
"""Set the device state to *on*.
|
||||
|
||||
This allows turning the device on and off.
|
||||
@ -278,7 +278,7 @@ class Device(ABC):
|
||||
return self.protocol._transport._host
|
||||
|
||||
@host.setter
|
||||
def host(self, value):
|
||||
def host(self, value: str) -> None:
|
||||
"""Set the device host.
|
||||
|
||||
Generally used by discovery to set the hostname after ip discovery.
|
||||
@ -307,7 +307,7 @@ class Device(ABC):
|
||||
return self._device_type
|
||||
|
||||
@abstractmethod
|
||||
def update_from_discover_info(self, info):
|
||||
def update_from_discover_info(self, info: dict) -> None:
|
||||
"""Update state from info from the discover call."""
|
||||
|
||||
@property
|
||||
@ -325,7 +325,7 @@ class Device(ABC):
|
||||
def alias(self) -> str | None:
|
||||
"""Returns the device alias or nickname."""
|
||||
|
||||
async def _raw_query(self, request: str | dict) -> Any:
|
||||
async def _raw_query(self, request: str | dict) -> dict:
|
||||
"""Send a raw query to the device."""
|
||||
return await self.protocol.query(request=request)
|
||||
|
||||
@ -407,7 +407,7 @@ class Device(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def internal_state(self) -> Any:
|
||||
def internal_state(self) -> dict:
|
||||
"""Return all the internal state data."""
|
||||
|
||||
@property
|
||||
@ -420,10 +420,10 @@ class Device(ABC):
|
||||
"""Return the list of supported features."""
|
||||
return self._features
|
||||
|
||||
def _add_feature(self, feature: Feature):
|
||||
def _add_feature(self, feature: Feature) -> None:
|
||||
"""Add a new feature to the device."""
|
||||
if feature.id in self._features:
|
||||
raise KasaException("Duplicate feature id %s" % feature.id)
|
||||
raise KasaException(f"Duplicate feature id {feature.id}")
|
||||
assert feature.id is not None # TODO: hack for typing # noqa: S101
|
||||
self._features[feature.id] = feature
|
||||
|
||||
@ -446,11 +446,13 @@ class Device(ABC):
|
||||
"""Scan for available wifi networks."""
|
||||
|
||||
@abstractmethod
|
||||
async def wifi_join(self, ssid: str, password: str, keytype: str = "wpa2_psk"):
|
||||
async def wifi_join(
|
||||
self, ssid: str, password: str, keytype: str = "wpa2_psk"
|
||||
) -> dict:
|
||||
"""Join the given wifi network."""
|
||||
|
||||
@abstractmethod
|
||||
async def set_alias(self, alias: str):
|
||||
async def set_alias(self, alias: str) -> dict:
|
||||
"""Set the device name (alias)."""
|
||||
|
||||
@abstractmethod
|
||||
@ -468,7 +470,7 @@ class Device(ABC):
|
||||
Note, this does not downgrade the firmware.
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
update_needed = " - update() needed" if not self._last_update else ""
|
||||
return (
|
||||
f"<{self.device_type} at {self.host} -"
|
||||
@ -486,7 +488,9 @@ class Device(ABC):
|
||||
"is_strip_socket": (None, DeviceType.StripSocket),
|
||||
}
|
||||
|
||||
def _get_replacing_attr(self, module_name: ModuleName, *attrs):
|
||||
def _get_replacing_attr(
|
||||
self, module_name: ModuleName | None, *attrs: Any
|
||||
) -> str | None:
|
||||
# If module name is None check self
|
||||
if not module_name:
|
||||
check = self
|
||||
@ -540,7 +544,7 @@ class Device(ABC):
|
||||
"supported_modules": (None, ["modules"]),
|
||||
}
|
||||
|
||||
def __getattr__(self, name):
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
# is_device_type
|
||||
if dep_device_type_attr := self._deprecated_device_type_attributes.get(name):
|
||||
module = dep_device_type_attr[0]
|
||||
|
@ -83,7 +83,7 @@ async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> Device:
|
||||
if debug_enabled:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
def _perf_log(has_params, perf_type):
|
||||
def _perf_log(has_params: bool, perf_type: str) -> None:
|
||||
nonlocal start_time
|
||||
if debug_enabled:
|
||||
end_time = time.perf_counter()
|
||||
@ -150,7 +150,7 @@ def _get_device_type_from_sys_info(info: dict[str, Any]) -> DeviceType:
|
||||
return DeviceType.LightStrip
|
||||
|
||||
return DeviceType.Bulb
|
||||
raise UnsupportedDeviceError("Unknown device type: %s" % type_)
|
||||
raise UnsupportedDeviceError(f"Unknown device type: {type_}")
|
||||
|
||||
|
||||
def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:
|
||||
|
@ -75,14 +75,14 @@ class DeviceFamily(Enum):
|
||||
SmartIpCamera = "SMART.IPCAMERA"
|
||||
|
||||
|
||||
def _dataclass_from_dict(klass, in_val):
|
||||
def _dataclass_from_dict(klass: Any, in_val: dict) -> Any:
|
||||
if is_dataclass(klass):
|
||||
fieldtypes = {f.name: f.type for f in fields(klass)}
|
||||
val = {}
|
||||
for dict_key in in_val:
|
||||
if dict_key in fieldtypes:
|
||||
if hasattr(fieldtypes[dict_key], "from_dict"):
|
||||
val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key])
|
||||
val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) # type: ignore[union-attr]
|
||||
else:
|
||||
val[dict_key] = _dataclass_from_dict(
|
||||
fieldtypes[dict_key], in_val[dict_key]
|
||||
@ -91,12 +91,12 @@ def _dataclass_from_dict(klass, in_val):
|
||||
raise KasaException(
|
||||
f"Cannot create dataclass from dict, unknown key: {dict_key}"
|
||||
)
|
||||
return klass(**val)
|
||||
return klass(**val) # type: ignore[operator]
|
||||
else:
|
||||
return in_val
|
||||
|
||||
|
||||
def _dataclass_to_dict(in_val):
|
||||
def _dataclass_to_dict(in_val: Any) -> dict:
|
||||
fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare}
|
||||
out_val = {}
|
||||
for field_name in fieldtypes:
|
||||
@ -210,7 +210,7 @@ class DeviceConfig:
|
||||
|
||||
aes_keys: Optional[KeyPairDict] = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.connection_type is None:
|
||||
self.connection_type = DeviceConnectionParameters(
|
||||
DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor
|
||||
|
@ -89,9 +89,19 @@ import logging
|
||||
import secrets
|
||||
import socket
|
||||
import struct
|
||||
from collections.abc import Awaitable
|
||||
from asyncio.transports import DatagramTransport
|
||||
from pprint import pformat as pf
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
from aiohttp import ClientSession
|
||||
|
||||
@ -140,8 +150,8 @@ class ConnectAttempt(NamedTuple):
|
||||
device: type
|
||||
|
||||
|
||||
OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
|
||||
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
|
||||
OnDiscoveredCallable = Callable[[Device], Coroutine]
|
||||
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Coroutine]
|
||||
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
|
||||
DeviceDict = Dict[str, Device]
|
||||
|
||||
@ -156,7 +166,7 @@ class _AesDiscoveryQuery:
|
||||
keypair: KeyPair | None = None
|
||||
|
||||
@classmethod
|
||||
def generate_query(cls):
|
||||
def generate_query(cls) -> bytearray:
|
||||
if not cls.keypair:
|
||||
cls.keypair = KeyPair.create_key_pair(key_size=2048)
|
||||
secret = secrets.token_bytes(4)
|
||||
@ -215,7 +225,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
credentials: Credentials | None = None,
|
||||
timeout: int | None = None,
|
||||
) -> None:
|
||||
self.transport = None
|
||||
self.transport: DatagramTransport | None = None
|
||||
self.discovery_packets = discovery_packets
|
||||
self.interface = interface
|
||||
self.on_discovered = on_discovered
|
||||
@ -239,16 +249,19 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
self.target_discovered: bool = False
|
||||
self._started_event = asyncio.Event()
|
||||
|
||||
def _run_callback_task(self, coro):
|
||||
task = asyncio.create_task(coro)
|
||||
def _run_callback_task(self, coro: Coroutine) -> None:
|
||||
task: asyncio.Task = asyncio.create_task(coro)
|
||||
self.callback_tasks.append(task)
|
||||
|
||||
async def wait_for_discovery_to_complete(self):
|
||||
async def wait_for_discovery_to_complete(self) -> None:
|
||||
"""Wait for the discovery task to complete."""
|
||||
# Give some time for connection_made event to be received
|
||||
async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT):
|
||||
await self._started_event.wait()
|
||||
try:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(self.discover_task, asyncio.Task)
|
||||
|
||||
await self.discover_task
|
||||
except asyncio.CancelledError:
|
||||
# if target_discovered then cancel was called internally
|
||||
@ -257,11 +270,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
# Wait for any pending callbacks to complete
|
||||
await asyncio.gather(*self.callback_tasks)
|
||||
|
||||
def connection_made(self, transport) -> None:
|
||||
def connection_made(self, transport: DatagramTransport) -> None: # type: ignore[override]
|
||||
"""Set socket options for broadcasting."""
|
||||
self.transport = transport
|
||||
self.transport = cast(DatagramTransport, transport)
|
||||
|
||||
sock = transport.get_extra_info("socket")
|
||||
sock = self.transport.get_extra_info("socket")
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
||||
try:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
@ -292,7 +305,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore
|
||||
await asyncio.sleep(sleep_between_packets)
|
||||
|
||||
def datagram_received(self, data, addr) -> None:
|
||||
def datagram_received(
|
||||
self,
|
||||
data: bytes,
|
||||
addr: tuple[str, int],
|
||||
) -> None:
|
||||
"""Handle discovery responses."""
|
||||
if TYPE_CHECKING:
|
||||
assert _AesDiscoveryQuery.keypair
|
||||
@ -338,18 +355,18 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
|
||||
self._handle_discovered_event()
|
||||
|
||||
def _handle_discovered_event(self):
|
||||
def _handle_discovered_event(self) -> None:
|
||||
"""If target is in seen_hosts cancel discover_task."""
|
||||
if self.target in self.seen_hosts:
|
||||
self.target_discovered = True
|
||||
if self.discover_task:
|
||||
self.discover_task.cancel()
|
||||
|
||||
def error_received(self, ex):
|
||||
def error_received(self, ex: Exception) -> None:
|
||||
"""Handle asyncio.Protocol errors."""
|
||||
_LOGGER.error("Got error: %s", ex)
|
||||
|
||||
def connection_lost(self, ex): # pragma: no cover
|
||||
def connection_lost(self, ex: Exception | None) -> None: # pragma: no cover
|
||||
"""Cancel the discover task if running."""
|
||||
if self.discover_task:
|
||||
self.discover_task.cancel()
|
||||
@ -372,17 +389,17 @@ class Discover:
|
||||
@staticmethod
|
||||
async def discover(
|
||||
*,
|
||||
target="255.255.255.255",
|
||||
on_discovered=None,
|
||||
discovery_timeout=5,
|
||||
discovery_packets=3,
|
||||
interface=None,
|
||||
on_unsupported=None,
|
||||
credentials=None,
|
||||
target: str = "255.255.255.255",
|
||||
on_discovered: OnDiscoveredCallable | None = None,
|
||||
discovery_timeout: int = 5,
|
||||
discovery_packets: int = 3,
|
||||
interface: str | None = None,
|
||||
on_unsupported: OnUnsupportedCallable | None = None,
|
||||
credentials: Credentials | None = None,
|
||||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
port=None,
|
||||
timeout=None,
|
||||
port: int | None = None,
|
||||
timeout: int | None = None,
|
||||
) -> DeviceDict:
|
||||
"""Discover supported devices.
|
||||
|
||||
@ -636,7 +653,7 @@ class Discover:
|
||||
)
|
||||
if not dev_class:
|
||||
raise UnsupportedDeviceError(
|
||||
"Unknown device type: %s" % discovery_result.device_type,
|
||||
f"Unknown device type: {discovery_result.device_type}",
|
||||
discovery_result=info,
|
||||
)
|
||||
return dev_class
|
||||
|
@ -49,13 +49,13 @@ class EmeterStatus(dict):
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<EmeterStatus power={self.power} voltage={self.voltage}"
|
||||
f" current={self.current} total={self.total}>"
|
||||
)
|
||||
|
||||
def __getitem__(self, item):
|
||||
def __getitem__(self, item: str) -> float | None:
|
||||
"""Return value in wanted units."""
|
||||
valid_keys = [
|
||||
"voltage_mv",
|
||||
|
@ -15,10 +15,10 @@ class KasaException(Exception):
|
||||
class TimeoutError(KasaException, _asyncioTimeoutError):
|
||||
"""Timeout exception for device errors."""
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return KasaException.__repr__(self)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return KasaException.__str__(self)
|
||||
|
||||
|
||||
@ -42,11 +42,11 @@ class DeviceError(KasaException):
|
||||
self.error_code: SmartErrorCode | None = kwargs.get("error_code", None)
|
||||
super().__init__(*args)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
err_code = self.error_code.__repr__() if self.error_code else ""
|
||||
return f"{self.__class__.__name__}({err_code})"
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
err_code = f" (error_code={self.error_code.name})" if self.error_code else ""
|
||||
return super().__str__() + err_code
|
||||
|
||||
@ -62,7 +62,7 @@ class _RetryableError(DeviceError):
|
||||
class SmartErrorCode(IntEnum):
|
||||
"""Enum for SMART Error Codes."""
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}({self.value})"
|
||||
|
||||
@staticmethod
|
||||
|
@ -12,12 +12,12 @@ class Experimental:
|
||||
ENV_VAR = "KASA_EXPERIMENTAL"
|
||||
|
||||
@classmethod
|
||||
def set_enabled(cls, enabled):
|
||||
def set_enabled(cls, enabled: bool) -> None:
|
||||
"""Set the enabled value."""
|
||||
cls._enabled = enabled
|
||||
|
||||
@classmethod
|
||||
def enabled(cls):
|
||||
def enabled(cls) -> bool:
|
||||
"""Get the enabled value."""
|
||||
if cls._enabled is not None:
|
||||
return cls._enabled
|
||||
|
@ -50,11 +50,13 @@ class SmartCameraProtocol(SmartProtocol):
|
||||
"""Class for SmartCamera Protocol."""
|
||||
|
||||
async def _handle_response_lists(
|
||||
self, response_result: dict[str, Any], method, retry_count
|
||||
):
|
||||
self, response_result: dict[str, Any], method: str, retry_count: int
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
|
||||
def _handle_response_error_code(
|
||||
self, resp_dict: dict, method: str, raise_on_error: bool = True
|
||||
) -> None:
|
||||
error_code_raw = resp_dict.get("error_code")
|
||||
try:
|
||||
error_code = SmartErrorCode.from_int(error_code_raw)
|
||||
@ -203,7 +205,7 @@ class _ChildCameraProtocolWrapper(SmartProtocol):
|
||||
device responses before returning to the caller.
|
||||
"""
|
||||
|
||||
def __init__(self, device_id: str, base_protocol: SmartProtocol):
|
||||
def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
|
||||
self._device_id = device_id
|
||||
self._protocol = base_protocol
|
||||
self._transport = base_protocol._transport
|
||||
|
@ -256,7 +256,9 @@ class SslAesTransport(BaseTransport):
|
||||
return ret_val # type: ignore[return-value]
|
||||
|
||||
@staticmethod
|
||||
def generate_confirm_hash(local_nonce, server_nonce, pwd_hash):
|
||||
def generate_confirm_hash(
|
||||
local_nonce: str, server_nonce: str, pwd_hash: str
|
||||
) -> str:
|
||||
"""Generate an auth hash for the protocol on the supplied credentials."""
|
||||
expected_confirm_bytes = _sha256_hash(
|
||||
local_nonce.encode() + pwd_hash.encode() + server_nonce.encode()
|
||||
@ -264,7 +266,9 @@ class SslAesTransport(BaseTransport):
|
||||
return expected_confirm_bytes + server_nonce + local_nonce
|
||||
|
||||
@staticmethod
|
||||
def generate_digest_password(local_nonce, server_nonce, pwd_hash):
|
||||
def generate_digest_password(
|
||||
local_nonce: str, server_nonce: str, pwd_hash: str
|
||||
) -> str:
|
||||
"""Generate an auth hash for the protocol on the supplied credentials."""
|
||||
digest_password_hash = _sha256_hash(
|
||||
pwd_hash.encode() + local_nonce.encode() + server_nonce.encode()
|
||||
@ -275,7 +279,7 @@ class SslAesTransport(BaseTransport):
|
||||
|
||||
@staticmethod
|
||||
def generate_encryption_token(
|
||||
token_type, local_nonce, server_nonce, pwd_hash
|
||||
token_type: str, local_nonce: str, server_nonce: str, pwd_hash: str
|
||||
) -> bytes:
|
||||
"""Generate encryption token."""
|
||||
hashedKey = _sha256_hash(
|
||||
@ -302,7 +306,9 @@ class SslAesTransport(BaseTransport):
|
||||
local_nonce, server_nonce, pwd_hash = await self.perform_handshake1()
|
||||
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
|
||||
|
||||
async def perform_handshake2(self, local_nonce, server_nonce, pwd_hash) -> None:
|
||||
async def perform_handshake2(
|
||||
self, local_nonce: str, server_nonce: str, pwd_hash: str
|
||||
) -> None:
|
||||
"""Perform the handshake."""
|
||||
_LOGGER.debug("Performing handshake2 ...")
|
||||
digest_password = self.generate_digest_password(
|
||||
|
@ -162,7 +162,7 @@ class Feature:
|
||||
#: If set, this property will be used to get *choices*.
|
||||
choices_getter: str | Callable[[], list[str]] | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
"""Handle late-binding of members."""
|
||||
# Populate minimum & maximum values, if range_getter is given
|
||||
self._container = self.container if self.container is not None else self.device
|
||||
@ -188,7 +188,7 @@ class Feature:
|
||||
f"Read-only feat defines attribute_setter: {self.name} ({self.id}):"
|
||||
)
|
||||
|
||||
def _get_property_value(self, getter):
|
||||
def _get_property_value(self, getter: str | Callable | None) -> Any:
|
||||
if getter is None:
|
||||
return None
|
||||
if isinstance(getter, str):
|
||||
@ -227,7 +227,7 @@ class Feature:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
def value(self) -> int | float | bool | str | Enum | None:
|
||||
"""Return the current value."""
|
||||
if self.type == Feature.Type.Action:
|
||||
return "<Action>"
|
||||
@ -264,7 +264,7 @@ class Feature:
|
||||
|
||||
return await getattr(container, self.attribute_setter)(value)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
try:
|
||||
value = self.value
|
||||
choices = self.choices
|
||||
@ -286,8 +286,8 @@ class Feature:
|
||||
value = " ".join(
|
||||
[f"*{choice}*" if choice == value else choice for choice in choices]
|
||||
)
|
||||
if self.precision_hint is not None and value is not None:
|
||||
value = round(self.value, self.precision_hint)
|
||||
if self.precision_hint is not None and isinstance(value, float):
|
||||
value = round(value, self.precision_hint)
|
||||
|
||||
s = f"{self.name} ({self.id}): {value}"
|
||||
if self.unit is not None:
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import ssl
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
@ -64,7 +65,7 @@ class HttpClient:
|
||||
json: dict | Any | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
cookies_dict: dict[str, str] | None = None,
|
||||
ssl=False,
|
||||
ssl: ssl.SSLContext | bool = False,
|
||||
) -> tuple[int, dict | bytes | None]:
|
||||
"""Send an http post request to the device.
|
||||
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import IntFlag, auto
|
||||
from typing import Any
|
||||
from warnings import warn
|
||||
|
||||
from ..emeterstatus import EmeterStatus
|
||||
@ -31,7 +32,7 @@ class Energy(Module, ABC):
|
||||
"""Return True if module supports the feature."""
|
||||
return module_feature in self._supported
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
device = self._device
|
||||
self._add_feature(
|
||||
@ -151,22 +152,26 @@ class Energy(Module, ABC):
|
||||
"""Get the current voltage in V."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_status(self):
|
||||
async def get_status(self) -> EmeterStatus:
|
||||
"""Return real-time statistics."""
|
||||
|
||||
@abstractmethod
|
||||
async def erase_stats(self):
|
||||
async def erase_stats(self) -> dict:
|
||||
"""Erase all stats."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict:
|
||||
async def get_daily_stats(
|
||||
self, *, year: int | None = None, month: int | None = None, kwh: bool = True
|
||||
) -> dict:
|
||||
"""Return daily stats for the given year & month.
|
||||
|
||||
The return value is a dictionary of {day: energy, ...}.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_monthly_stats(self, *, year=None, kwh=True) -> dict:
|
||||
async def get_monthly_stats(
|
||||
self, *, year: int | None = None, kwh: bool = True
|
||||
) -> dict:
|
||||
"""Return monthly stats for the given year."""
|
||||
|
||||
_deprecated_attributes = {
|
||||
@ -179,7 +184,7 @@ class Energy(Module, ABC):
|
||||
"get_monthstat": "get_monthly_stats",
|
||||
}
|
||||
|
||||
def __getattr__(self, name):
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if attr := self._deprecated_attributes.get(name):
|
||||
msg = f"{name} is deprecated, use {attr} instead"
|
||||
warn(msg, DeprecationWarning, stacklevel=2)
|
||||
|
@ -16,5 +16,5 @@ class Fan(Module, ABC):
|
||||
"""Return fan speed level."""
|
||||
|
||||
@abstractmethod
|
||||
async def set_fan_speed_level(self, level: int):
|
||||
async def set_fan_speed_level(self, level: int) -> dict:
|
||||
"""Set fan speed level."""
|
||||
|
@ -11,7 +11,7 @@ from ..module import Module
|
||||
class Led(Module, ABC):
|
||||
"""Base interface to represent a LED module."""
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
device = self._device
|
||||
self._add_feature(
|
||||
@ -34,5 +34,5 @@ class Led(Module, ABC):
|
||||
"""Return current led status."""
|
||||
|
||||
@abstractmethod
|
||||
async def set_led(self, enable: bool) -> None:
|
||||
async def set_led(self, enable: bool) -> dict:
|
||||
"""Set led."""
|
||||
|
@ -166,7 +166,7 @@ class Light(Module, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def set_color_temp(
|
||||
self, temp: int, *, brightness=None, transition: int | None = None
|
||||
self, temp: int, *, brightness: int | None = None, transition: int | None = None
|
||||
) -> dict:
|
||||
"""Set the color temperature of the device in kelvin.
|
||||
|
||||
|
@ -53,7 +53,7 @@ class LightEffect(Module, ABC):
|
||||
|
||||
LIGHT_EFFECTS_OFF = "Off"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
device = self._device
|
||||
self._add_feature(
|
||||
@ -96,7 +96,7 @@ class LightEffect(Module, ABC):
|
||||
*,
|
||||
brightness: int | None = None,
|
||||
transition: int | None = None,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set an effect on the device.
|
||||
|
||||
If brightness or transition is defined,
|
||||
@ -110,10 +110,11 @@ class LightEffect(Module, ABC):
|
||||
:param int transition: The wanted transition time
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def set_custom_effect(
|
||||
self,
|
||||
effect_dict: dict,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set a custom effect on the device.
|
||||
|
||||
:param str effect_dict: The custom effect dict to set
|
||||
|
@ -83,7 +83,7 @@ class LightPreset(Module):
|
||||
|
||||
PRESET_NOT_SET = "Not set"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
device = self._device
|
||||
self._add_feature(
|
||||
@ -127,7 +127,7 @@ class LightPreset(Module):
|
||||
async def set_preset(
|
||||
self,
|
||||
preset_name: str,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set a light preset for the device."""
|
||||
|
||||
@abstractmethod
|
||||
@ -135,7 +135,7 @@ class LightPreset(Module):
|
||||
self,
|
||||
preset_name: str,
|
||||
preset_info: LightState,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Update the preset with *preset_name* with the new *preset_info*."""
|
||||
|
||||
@property
|
||||
|
@ -54,7 +54,7 @@ class TurnOnBehavior(BaseModel):
|
||||
mode: BehaviorMode
|
||||
|
||||
@root_validator
|
||||
def _mode_based_on_preset(cls, values):
|
||||
def _mode_based_on_preset(cls, values: dict) -> dict:
|
||||
"""Set the mode based on the preset value."""
|
||||
if values["preset"] is not None:
|
||||
values["mode"] = BehaviorMode.Preset
|
||||
@ -209,7 +209,7 @@ class IotBulb(IotDevice):
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self._device_type = DeviceType.Bulb
|
||||
|
||||
async def _initialize_modules(self):
|
||||
async def _initialize_modules(self) -> None:
|
||||
"""Initialize modules not added in init."""
|
||||
await super()._initialize_modules()
|
||||
self.add_module(
|
||||
@ -307,7 +307,7 @@ class IotBulb(IotDevice):
|
||||
await self._query_helper(self.LIGHT_SERVICE, "get_default_behavior")
|
||||
)
|
||||
|
||||
async def set_turn_on_behavior(self, behavior: TurnOnBehaviors):
|
||||
async def set_turn_on_behavior(self, behavior: TurnOnBehaviors) -> dict:
|
||||
"""Set the behavior for turning the bulb on.
|
||||
|
||||
If you do not want to manually construct the behavior object,
|
||||
@ -426,7 +426,7 @@ class IotBulb(IotDevice):
|
||||
|
||||
@requires_update
|
||||
async def _set_color_temp(
|
||||
self, temp: int, *, brightness=None, transition: int | None = None
|
||||
self, temp: int, *, brightness: int | None = None, transition: int | None = None
|
||||
) -> dict:
|
||||
"""Set the color temperature of the device in kelvin.
|
||||
|
||||
@ -450,7 +450,7 @@ class IotBulb(IotDevice):
|
||||
|
||||
return await self._set_light_state(light_state, transition=transition)
|
||||
|
||||
def _raise_for_invalid_brightness(self, value):
|
||||
def _raise_for_invalid_brightness(self, value: int) -> None:
|
||||
if not isinstance(value, int):
|
||||
raise TypeError("Brightness must be an integer")
|
||||
if not (0 <= value <= 100):
|
||||
@ -517,7 +517,7 @@ class IotBulb(IotDevice):
|
||||
"""Return that the bulb has an emeter."""
|
||||
return True
|
||||
|
||||
async def set_alias(self, alias: str) -> None:
|
||||
async def set_alias(self, alias: str) -> dict:
|
||||
"""Set the device name (alias).
|
||||
|
||||
Overridden to use a different module name.
|
||||
|
@ -19,7 +19,7 @@ import inspect
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timedelta, tzinfo
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, cast
|
||||
from warnings import warn
|
||||
|
||||
from ..device import Device, WifiNetwork
|
||||
@ -35,12 +35,12 @@ from .modules import Emeter
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def requires_update(f):
|
||||
def requires_update(f: Callable) -> Any:
|
||||
"""Indicate that `update` should be called before accessing this method.""" # noqa: D202
|
||||
if inspect.iscoroutinefunction(f):
|
||||
|
||||
@functools.wraps(f)
|
||||
async def wrapped(*args, **kwargs):
|
||||
async def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
self = args[0]
|
||||
if self._last_update is None and f.__name__ not in self._sys_info:
|
||||
raise KasaException("You need to await update() to access the data")
|
||||
@ -49,13 +49,13 @@ def requires_update(f):
|
||||
else:
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
self = args[0]
|
||||
if self._last_update is None and f.__name__ not in self._sys_info:
|
||||
raise KasaException("You need to await update() to access the data")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
f.requires_update = True
|
||||
f.requires_update = True # type: ignore[attr-defined]
|
||||
return wrapped
|
||||
|
||||
|
||||
@ -197,7 +197,7 @@ class IotDevice(Device):
|
||||
return cast(ModuleMapping[IotModule], self._supported_modules)
|
||||
return self._supported_modules
|
||||
|
||||
def add_module(self, name: str | ModuleName[Module], module: IotModule):
|
||||
def add_module(self, name: str | ModuleName[Module], module: IotModule) -> None:
|
||||
"""Register a module."""
|
||||
if name in self._modules:
|
||||
_LOGGER.debug("Module %s already registered, ignoring...", name)
|
||||
@ -207,8 +207,12 @@ class IotDevice(Device):
|
||||
self._modules[name] = module
|
||||
|
||||
def _create_request(
|
||||
self, target: str, cmd: str, arg: dict | None = None, child_ids=None
|
||||
):
|
||||
self,
|
||||
target: str,
|
||||
cmd: str,
|
||||
arg: dict | None = None,
|
||||
child_ids: list | None = None,
|
||||
) -> dict:
|
||||
if arg is None:
|
||||
arg = {}
|
||||
request: dict[str, Any] = {target: {cmd: arg}}
|
||||
@ -225,8 +229,12 @@ class IotDevice(Device):
|
||||
raise KasaException("update() required prior accessing emeter")
|
||||
|
||||
async def _query_helper(
|
||||
self, target: str, cmd: str, arg: dict | None = None, child_ids=None
|
||||
) -> Any:
|
||||
self,
|
||||
target: str,
|
||||
cmd: str,
|
||||
arg: dict | None = None,
|
||||
child_ids: list | None = None,
|
||||
) -> dict:
|
||||
"""Query device, return results or raise an exception.
|
||||
|
||||
:param target: Target system {system, time, emeter, ..}
|
||||
@ -276,7 +284,7 @@ class IotDevice(Device):
|
||||
"""Retrieve system information."""
|
||||
return await self._query_helper("system", "get_sysinfo")
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
async def update(self, update_children: bool = True) -> None:
|
||||
"""Query the device to update the data.
|
||||
|
||||
Needed for properties that are decorated with `requires_update`.
|
||||
@ -305,7 +313,7 @@ class IotDevice(Device):
|
||||
if not self._features:
|
||||
await self._initialize_features()
|
||||
|
||||
async def _initialize_modules(self):
|
||||
async def _initialize_modules(self) -> None:
|
||||
"""Initialize modules not added in init."""
|
||||
if self.has_emeter:
|
||||
_LOGGER.debug(
|
||||
@ -313,7 +321,7 @@ class IotDevice(Device):
|
||||
)
|
||||
self.add_module(Module.Energy, Emeter(self, self.emeter_type))
|
||||
|
||||
async def _initialize_features(self):
|
||||
async def _initialize_features(self) -> None:
|
||||
"""Initialize common features."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -364,7 +372,7 @@ class IotDevice(Device):
|
||||
)
|
||||
)
|
||||
|
||||
for module in self._supported_modules.values():
|
||||
for module in self.modules.values():
|
||||
module._initialize_features()
|
||||
for module_feat in module._module_features.values():
|
||||
self._add_feature(module_feat)
|
||||
@ -453,7 +461,7 @@ class IotDevice(Device):
|
||||
sys_info = self._sys_info
|
||||
return sys_info.get("alias") if sys_info else None
|
||||
|
||||
async def set_alias(self, alias: str) -> None:
|
||||
async def set_alias(self, alias: str) -> dict:
|
||||
"""Set the device name (alias)."""
|
||||
return await self._query_helper("system", "set_dev_alias", {"alias": alias})
|
||||
|
||||
@ -550,7 +558,7 @@ class IotDevice(Device):
|
||||
|
||||
return mac
|
||||
|
||||
async def set_mac(self, mac):
|
||||
async def set_mac(self, mac: str) -> dict:
|
||||
"""Set the mac address.
|
||||
|
||||
:param str mac: mac in hexadecimal with colons, e.g. 01:23:45:67:89:ab
|
||||
@ -576,7 +584,7 @@ class IotDevice(Device):
|
||||
"""Turn off the device."""
|
||||
raise NotImplementedError("Device subclass needs to implement this.")
|
||||
|
||||
async def turn_on(self, **kwargs) -> dict | None:
|
||||
async def turn_on(self, **kwargs) -> dict:
|
||||
"""Turn device on."""
|
||||
raise NotImplementedError("Device subclass needs to implement this.")
|
||||
|
||||
@ -586,7 +594,7 @@ class IotDevice(Device):
|
||||
"""Return True if the device is on."""
|
||||
raise NotImplementedError("Device subclass needs to implement this.")
|
||||
|
||||
async def set_state(self, on: bool):
|
||||
async def set_state(self, on: bool) -> dict:
|
||||
"""Set the device state."""
|
||||
if on:
|
||||
return await self.turn_on()
|
||||
@ -627,7 +635,7 @@ class IotDevice(Device):
|
||||
async def wifi_scan(self) -> list[WifiNetwork]: # noqa: D202
|
||||
"""Scan for available wifi networks."""
|
||||
|
||||
async def _scan(target):
|
||||
async def _scan(target: str) -> dict:
|
||||
return await self._query_helper(target, "get_scaninfo", {"refresh": 1})
|
||||
|
||||
try:
|
||||
@ -639,17 +647,17 @@ class IotDevice(Device):
|
||||
info = await _scan("smartlife.iot.common.softaponboarding")
|
||||
|
||||
if "ap_list" not in info:
|
||||
raise KasaException("Invalid response for wifi scan: %s" % info)
|
||||
raise KasaException(f"Invalid response for wifi scan: {info}")
|
||||
|
||||
return [WifiNetwork(**x) for x in info["ap_list"]]
|
||||
|
||||
async def wifi_join(self, ssid: str, password: str, keytype: str = "3"): # noqa: D202
|
||||
async def wifi_join(self, ssid: str, password: str, keytype: str = "3") -> dict: # noqa: D202
|
||||
"""Join the given wifi network.
|
||||
|
||||
If joining the network fails, the device will return to AP mode after a while.
|
||||
"""
|
||||
|
||||
async def _join(target, payload):
|
||||
async def _join(target: str, payload: dict) -> dict:
|
||||
return await self._query_helper(target, "set_stainfo", payload)
|
||||
|
||||
payload = {"ssid": ssid, "password": password, "key_type": int(keytype)}
|
||||
|
@ -80,7 +80,7 @@ class IotDimmer(IotPlug):
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self._device_type = DeviceType.Dimmer
|
||||
|
||||
async def _initialize_modules(self):
|
||||
async def _initialize_modules(self) -> None:
|
||||
"""Initialize modules."""
|
||||
await super()._initialize_modules()
|
||||
# TODO: need to be verified if it's okay to call these on HS220 w/o these
|
||||
@ -103,7 +103,9 @@ class IotDimmer(IotPlug):
|
||||
return int(sys_info["brightness"])
|
||||
|
||||
@requires_update
|
||||
async def _set_brightness(self, brightness: int, *, transition: int | None = None):
|
||||
async def _set_brightness(
|
||||
self, brightness: int, *, transition: int | None = None
|
||||
) -> dict:
|
||||
"""Set the new dimmer brightness level in percentage.
|
||||
|
||||
:param int transition: transition duration in milliseconds.
|
||||
@ -134,7 +136,7 @@ class IotDimmer(IotPlug):
|
||||
self.DIMMER_SERVICE, "set_brightness", {"brightness": brightness}
|
||||
)
|
||||
|
||||
async def turn_off(self, *, transition: int | None = None, **kwargs):
|
||||
async def turn_off(self, *, transition: int | None = None, **kwargs) -> dict:
|
||||
"""Turn the bulb off.
|
||||
|
||||
:param int transition: transition duration in milliseconds.
|
||||
@ -145,7 +147,7 @@ class IotDimmer(IotPlug):
|
||||
return await super().turn_off()
|
||||
|
||||
@requires_update
|
||||
async def turn_on(self, *, transition: int | None = None, **kwargs):
|
||||
async def turn_on(self, *, transition: int | None = None, **kwargs) -> dict:
|
||||
"""Turn the bulb on.
|
||||
|
||||
:param int transition: transition duration in milliseconds.
|
||||
@ -157,7 +159,7 @@ class IotDimmer(IotPlug):
|
||||
|
||||
return await super().turn_on()
|
||||
|
||||
async def set_dimmer_transition(self, brightness: int, transition: int):
|
||||
async def set_dimmer_transition(self, brightness: int, transition: int) -> dict:
|
||||
"""Turn the bulb on to brightness percentage over transition milliseconds.
|
||||
|
||||
A brightness value of 0 will turn off the dimmer.
|
||||
@ -176,7 +178,7 @@ class IotDimmer(IotPlug):
|
||||
if not isinstance(transition, int):
|
||||
raise TypeError(f"Transition must be integer, not of {type(transition)}.")
|
||||
if transition <= 0:
|
||||
raise ValueError("Transition value %s is not valid." % transition)
|
||||
raise ValueError(f"Transition value {transition} is not valid.")
|
||||
|
||||
return await self._query_helper(
|
||||
self.DIMMER_SERVICE,
|
||||
@ -185,7 +187,7 @@ class IotDimmer(IotPlug):
|
||||
)
|
||||
|
||||
@requires_update
|
||||
async def get_behaviors(self):
|
||||
async def get_behaviors(self) -> dict:
|
||||
"""Return button behavior settings."""
|
||||
behaviors = await self._query_helper(
|
||||
self.DIMMER_SERVICE, "get_default_behavior", {}
|
||||
@ -195,7 +197,7 @@ class IotDimmer(IotPlug):
|
||||
@requires_update
|
||||
async def set_button_action(
|
||||
self, action_type: ActionType, action: ButtonAction, index: int | None = None
|
||||
):
|
||||
) -> dict:
|
||||
"""Set action to perform on button click/hold.
|
||||
|
||||
:param action_type ActionType: whether to control double click or hold action.
|
||||
@ -209,15 +211,17 @@ class IotDimmer(IotPlug):
|
||||
if index is not None:
|
||||
payload["index"] = index
|
||||
|
||||
await self._query_helper(self.DIMMER_SERVICE, action_type_setter, payload)
|
||||
return await self._query_helper(
|
||||
self.DIMMER_SERVICE, action_type_setter, payload
|
||||
)
|
||||
|
||||
@requires_update
|
||||
async def set_fade_time(self, fade_type: FadeType, time: int):
|
||||
async def set_fade_time(self, fade_type: FadeType, time: int) -> dict:
|
||||
"""Set time for fade in / fade out."""
|
||||
fade_type_setter = f"set_{fade_type}_time"
|
||||
payload = {"fadeTime": time}
|
||||
|
||||
await self._query_helper(self.DIMMER_SERVICE, fade_type_setter, payload)
|
||||
return await self._query_helper(self.DIMMER_SERVICE, fade_type_setter, payload)
|
||||
|
||||
@property # type: ignore
|
||||
@requires_update
|
||||
|
@ -57,7 +57,7 @@ class IotLightStrip(IotBulb):
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self._device_type = DeviceType.LightStrip
|
||||
|
||||
async def _initialize_modules(self):
|
||||
async def _initialize_modules(self) -> None:
|
||||
"""Initialize modules not added in init."""
|
||||
await super()._initialize_modules()
|
||||
self.add_module(
|
||||
|
@ -1,6 +1,9 @@
|
||||
"""Base class for IOT module implementations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import KasaException
|
||||
from ..module import Module
|
||||
@ -24,16 +27,16 @@ merge = _merge_dict
|
||||
class IotModule(Module):
|
||||
"""Base class implemention for all IOT modules."""
|
||||
|
||||
def call(self, method, params=None):
|
||||
async def call(self, method: str, params: dict | None = None) -> dict:
|
||||
"""Call the given method with the given parameters."""
|
||||
return self._device._query_helper(self._module, method, params)
|
||||
return await self._device._query_helper(self._module, method, params)
|
||||
|
||||
def query_for_command(self, query, params=None):
|
||||
def query_for_command(self, query: str, params: dict | None = None) -> dict:
|
||||
"""Create a request object for the given parameters."""
|
||||
return self._device._create_request(self._module, query, params)
|
||||
|
||||
@property
|
||||
def estimated_query_response_size(self):
|
||||
def estimated_query_response_size(self) -> int:
|
||||
"""Estimated maximum size of query response.
|
||||
|
||||
The inheriting modules implement this to estimate how large a query response
|
||||
@ -42,7 +45,7 @@ class IotModule(Module):
|
||||
return 256 # Estimate for modules that don't specify
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> dict[str, Any]:
|
||||
"""Return the module specific raw data from the last update."""
|
||||
dev = self._device
|
||||
q = self.query()
|
||||
|
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from ..device_type import DeviceType
|
||||
from ..deviceconfig import DeviceConfig
|
||||
@ -54,7 +55,7 @@ class IotPlug(IotDevice):
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self._device_type = DeviceType.Plug
|
||||
|
||||
async def _initialize_modules(self):
|
||||
async def _initialize_modules(self) -> None:
|
||||
"""Initialize modules."""
|
||||
await super()._initialize_modules()
|
||||
self.add_module(Module.IotSchedule, Schedule(self, "schedule"))
|
||||
@ -71,11 +72,11 @@ class IotPlug(IotDevice):
|
||||
sys_info = self.sys_info
|
||||
return bool(sys_info["relay_state"])
|
||||
|
||||
async def turn_on(self, **kwargs):
|
||||
async def turn_on(self, **kwargs: Any) -> dict:
|
||||
"""Turn the switch on."""
|
||||
return await self._query_helper("system", "set_relay_state", {"state": 1})
|
||||
|
||||
async def turn_off(self, **kwargs):
|
||||
async def turn_off(self, **kwargs: Any) -> dict:
|
||||
"""Turn the switch off."""
|
||||
return await self._query_helper("system", "set_relay_state", {"state": 0})
|
||||
|
||||
|
@ -26,7 +26,7 @@ from .modules import Antitheft, Cloud, Countdown, Emeter, Led, Schedule, Time, U
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def merge_sums(dicts):
|
||||
def merge_sums(dicts: list[dict]) -> dict:
|
||||
"""Merge the sum of dicts."""
|
||||
total_dict: defaultdict[int, float] = defaultdict(lambda: 0.0)
|
||||
for sum_dict in dicts:
|
||||
@ -99,7 +99,7 @@ class IotStrip(IotDevice):
|
||||
self.emeter_type = "emeter"
|
||||
self._device_type = DeviceType.Strip
|
||||
|
||||
async def _initialize_modules(self):
|
||||
async def _initialize_modules(self) -> None:
|
||||
"""Initialize modules."""
|
||||
# Strip has different modules to plug so do not call super
|
||||
self.add_module(Module.IotAntitheft, Antitheft(self, "anti_theft"))
|
||||
@ -121,7 +121,7 @@ class IotStrip(IotDevice):
|
||||
"""Return if any of the outlets are on."""
|
||||
return any(plug.is_on for plug in self.children)
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
async def update(self, update_children: bool = True) -> None:
|
||||
"""Update some of the attributes.
|
||||
|
||||
Needed for methods that are decorated with `requires_update`.
|
||||
@ -150,20 +150,20 @@ class IotStrip(IotDevice):
|
||||
if not self.features:
|
||||
await self._initialize_features()
|
||||
|
||||
async def _initialize_features(self):
|
||||
async def _initialize_features(self) -> None:
|
||||
"""Initialize common features."""
|
||||
# Do not initialize features until children are created
|
||||
if not self.children:
|
||||
return
|
||||
await super()._initialize_features()
|
||||
|
||||
async def turn_on(self, **kwargs):
|
||||
async def turn_on(self, **kwargs) -> dict:
|
||||
"""Turn the strip on."""
|
||||
await self._query_helper("system", "set_relay_state", {"state": 1})
|
||||
return await self._query_helper("system", "set_relay_state", {"state": 1})
|
||||
|
||||
async def turn_off(self, **kwargs):
|
||||
async def turn_off(self, **kwargs) -> dict:
|
||||
"""Turn the strip off."""
|
||||
await self._query_helper("system", "set_relay_state", {"state": 0})
|
||||
return await self._query_helper("system", "set_relay_state", {"state": 0})
|
||||
|
||||
@property # type: ignore
|
||||
@requires_update
|
||||
@ -188,7 +188,7 @@ class StripEmeter(IotModule, Energy):
|
||||
"""Return True if module supports the feature."""
|
||||
return module_feature in self._supported
|
||||
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Return the base query."""
|
||||
return {}
|
||||
|
||||
@ -246,11 +246,13 @@ class StripEmeter(IotModule, Energy):
|
||||
]
|
||||
)
|
||||
|
||||
async def erase_stats(self):
|
||||
async def erase_stats(self) -> dict:
|
||||
"""Erase energy meter statistics for all plugs."""
|
||||
for plug in self._device.children:
|
||||
await plug.modules[Module.Energy].erase_stats()
|
||||
|
||||
return {}
|
||||
|
||||
@property # type: ignore
|
||||
def consumption_this_month(self) -> float | None:
|
||||
"""Return this month's energy consumption in kWh."""
|
||||
@ -320,7 +322,7 @@ class IotStripPlug(IotPlug):
|
||||
self.protocol = parent.protocol # Must use the same connection as the parent
|
||||
self._on_since: datetime | None = None
|
||||
|
||||
async def _initialize_modules(self):
|
||||
async def _initialize_modules(self) -> None:
|
||||
"""Initialize modules not added in init."""
|
||||
if self.has_emeter:
|
||||
self.add_module(Module.Energy, Emeter(self, self.emeter_type))
|
||||
@ -329,7 +331,7 @@ class IotStripPlug(IotPlug):
|
||||
self.add_module(Module.IotSchedule, Schedule(self, "schedule"))
|
||||
self.add_module(Module.IotCountdown, Countdown(self, "countdown"))
|
||||
|
||||
async def _initialize_features(self):
|
||||
async def _initialize_features(self) -> None:
|
||||
"""Initialize common features."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -353,19 +355,20 @@ class IotStripPlug(IotPlug):
|
||||
type=Feature.Type.Sensor,
|
||||
)
|
||||
)
|
||||
for module in self._supported_modules.values():
|
||||
|
||||
for module in self.modules.values():
|
||||
module._initialize_features()
|
||||
for module_feat in module._module_features.values():
|
||||
self._add_feature(module_feat)
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
async def update(self, update_children: bool = True) -> None:
|
||||
"""Query the device to update the data.
|
||||
|
||||
Needed for properties that are decorated with `requires_update`.
|
||||
"""
|
||||
await self._update(update_children)
|
||||
|
||||
async def _update(self, update_children: bool = True):
|
||||
async def _update(self, update_children: bool = True) -> None:
|
||||
"""Query the device to update the data.
|
||||
|
||||
Internal implementation to allow patching of public update in the cli
|
||||
@ -379,8 +382,12 @@ class IotStripPlug(IotPlug):
|
||||
await self._initialize_features()
|
||||
|
||||
def _create_request(
|
||||
self, target: str, cmd: str, arg: dict | None = None, child_ids=None
|
||||
):
|
||||
self,
|
||||
target: str,
|
||||
cmd: str,
|
||||
arg: dict | None = None,
|
||||
child_ids: list | None = None,
|
||||
) -> dict:
|
||||
request: dict[str, Any] = {
|
||||
"context": {"child_ids": [self.child_id]},
|
||||
target: {cmd: arg},
|
||||
@ -388,8 +395,12 @@ class IotStripPlug(IotPlug):
|
||||
return request
|
||||
|
||||
async def _query_helper(
|
||||
self, target: str, cmd: str, arg: dict | None = None, child_ids=None
|
||||
) -> Any:
|
||||
self,
|
||||
target: str,
|
||||
cmd: str,
|
||||
arg: dict | None = None,
|
||||
child_ids: list | None = None,
|
||||
) -> dict:
|
||||
"""Override query helper to include the child_ids."""
|
||||
return await self._parent._query_helper(
|
||||
target, cmd, arg, child_ids=[self.child_id]
|
||||
|
@ -11,7 +11,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
class AmbientLight(IotModule):
|
||||
"""Implements ambient light controls for the motion sensor."""
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -40,7 +40,7 @@ class AmbientLight(IotModule):
|
||||
)
|
||||
)
|
||||
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Request configuration."""
|
||||
req = merge(
|
||||
self.query_for_command("get_config"),
|
||||
@ -74,18 +74,18 @@ class AmbientLight(IotModule):
|
||||
"""Return True if the module is enabled."""
|
||||
return int(self.data["get_current_brt"]["value"])
|
||||
|
||||
async def set_enabled(self, state: bool):
|
||||
async def set_enabled(self, state: bool) -> dict:
|
||||
"""Enable/disable LAS."""
|
||||
return await self.call("set_enable", {"enable": int(state)})
|
||||
|
||||
async def current_brightness(self) -> int:
|
||||
async def current_brightness(self) -> dict:
|
||||
"""Return current brightness.
|
||||
|
||||
Return value units.
|
||||
"""
|
||||
return await self.call("get_current_brt")
|
||||
|
||||
async def set_brightness_limit(self, value: int):
|
||||
async def set_brightness_limit(self, value: int) -> dict:
|
||||
"""Set the limit when the motion sensor is inactive.
|
||||
|
||||
See `presets` for preset values. Custom values are also likely allowed.
|
||||
|
@ -24,7 +24,7 @@ class CloudInfo(BaseModel):
|
||||
class Cloud(IotModule):
|
||||
"""Module implementing support for cloud services."""
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -44,7 +44,7 @@ class Cloud(IotModule):
|
||||
"""Return true if device is connected to the cloud."""
|
||||
return self.info.binded
|
||||
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Request cloud connectivity info."""
|
||||
return self.query_for_command("get_info")
|
||||
|
||||
@ -53,20 +53,20 @@ class Cloud(IotModule):
|
||||
"""Return information about the cloud connectivity."""
|
||||
return CloudInfo.parse_obj(self.data["get_info"])
|
||||
|
||||
def get_available_firmwares(self):
|
||||
def get_available_firmwares(self) -> dict:
|
||||
"""Return list of available firmwares."""
|
||||
return self.query_for_command("get_intl_fw_list")
|
||||
|
||||
def set_server(self, url: str):
|
||||
def set_server(self, url: str) -> dict:
|
||||
"""Set the update server URL."""
|
||||
return self.query_for_command("set_server_url", {"server": url})
|
||||
|
||||
def connect(self, username: str, password: str):
|
||||
def connect(self, username: str, password: str) -> dict:
|
||||
"""Login to the cloud using given information."""
|
||||
return self.query_for_command(
|
||||
"bind", {"username": username, "password": password}
|
||||
)
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> dict:
|
||||
"""Disconnect from the cloud."""
|
||||
return self.query_for_command("unbind")
|
||||
|
@ -70,7 +70,7 @@ class Emeter(Usage, EnergyInterface):
|
||||
"""Get the current voltage in V."""
|
||||
return self.status.voltage
|
||||
|
||||
async def erase_stats(self):
|
||||
async def erase_stats(self) -> dict:
|
||||
"""Erase all stats.
|
||||
|
||||
Uses different query than usage meter.
|
||||
@ -81,7 +81,9 @@ class Emeter(Usage, EnergyInterface):
|
||||
"""Return real-time statistics."""
|
||||
return EmeterStatus(await self.call("get_realtime"))
|
||||
|
||||
async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict:
|
||||
async def get_daily_stats(
|
||||
self, *, year: int | None = None, month: int | None = None, kwh: bool = True
|
||||
) -> dict:
|
||||
"""Return daily stats for the given year & month.
|
||||
|
||||
The return value is a dictionary of {day: energy, ...}.
|
||||
@ -90,7 +92,9 @@ class Emeter(Usage, EnergyInterface):
|
||||
data = self._convert_stat_data(data["day_list"], entry_key="day", kwh=kwh)
|
||||
return data
|
||||
|
||||
async def get_monthly_stats(self, *, year=None, kwh=True) -> dict:
|
||||
async def get_monthly_stats(
|
||||
self, *, year: int | None = None, kwh: bool = True
|
||||
) -> dict:
|
||||
"""Return monthly stats for the given year.
|
||||
|
||||
The return value is a dictionary of {month: energy, ...}.
|
||||
|
@ -14,7 +14,7 @@ class Led(IotModule, LedInterface):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
def mode(self) -> str:
|
||||
"""LED mode setting.
|
||||
|
||||
"always", "never"
|
||||
@ -27,7 +27,7 @@ class Led(IotModule, LedInterface):
|
||||
sys_info = self.data
|
||||
return bool(1 - sys_info["led_off"])
|
||||
|
||||
async def set_led(self, state: bool):
|
||||
async def set_led(self, state: bool) -> dict:
|
||||
"""Set the state of the led (night mode)."""
|
||||
return await self.call("set_led_off", {"off": int(not state)})
|
||||
|
||||
|
@ -27,7 +27,7 @@ class Light(IotModule, LightInterface):
|
||||
_device: IotBulb | IotDimmer
|
||||
_light_state: LightState
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
super()._initialize_features()
|
||||
device = self._device
|
||||
@ -185,7 +185,7 @@ class Light(IotModule, LightInterface):
|
||||
return bulb._color_temp
|
||||
|
||||
async def set_color_temp(
|
||||
self, temp: int, *, brightness=None, transition: int | None = None
|
||||
self, temp: int, *, brightness: int | None = None, transition: int | None = None
|
||||
) -> dict:
|
||||
"""Set the color temperature of the device in kelvin.
|
||||
|
||||
|
@ -50,7 +50,7 @@ class LightEffect(IotModule, LightEffectInterface):
|
||||
*,
|
||||
brightness: int | None = None,
|
||||
transition: int | None = None,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set an effect on the device.
|
||||
|
||||
If brightness or transition is defined,
|
||||
@ -73,7 +73,7 @@ class LightEffect(IotModule, LightEffectInterface):
|
||||
effect_dict = EFFECT_MAPPING_V1["Aurora"]
|
||||
effect_dict = {**effect_dict}
|
||||
effect_dict["enable"] = 0
|
||||
await self.set_custom_effect(effect_dict)
|
||||
return await self.set_custom_effect(effect_dict)
|
||||
elif effect not in EFFECT_MAPPING_V1:
|
||||
raise ValueError(f"The effect {effect} is not a built in effect.")
|
||||
else:
|
||||
@ -84,12 +84,12 @@ class LightEffect(IotModule, LightEffectInterface):
|
||||
if transition is not None:
|
||||
effect_dict["transition"] = transition
|
||||
|
||||
await self.set_custom_effect(effect_dict)
|
||||
return await self.set_custom_effect(effect_dict)
|
||||
|
||||
async def set_custom_effect(
|
||||
self,
|
||||
effect_dict: dict,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set a custom effect on the device.
|
||||
|
||||
:param str effect_dict: The custom effect dict to set
|
||||
@ -104,7 +104,7 @@ class LightEffect(IotModule, LightEffectInterface):
|
||||
"""Return True if the device supports setting custom effects."""
|
||||
return True
|
||||
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Return the base query."""
|
||||
return {}
|
||||
|
||||
|
@ -41,7 +41,7 @@ class LightPreset(IotModule, LightPresetInterface):
|
||||
_presets: dict[str, IotLightPreset]
|
||||
_preset_list: list[str]
|
||||
|
||||
async def _post_update_hook(self):
|
||||
async def _post_update_hook(self) -> None:
|
||||
"""Update the internal presets."""
|
||||
self._presets = {
|
||||
f"Light preset {index+1}": IotLightPreset(**vals)
|
||||
@ -93,7 +93,7 @@ class LightPreset(IotModule, LightPresetInterface):
|
||||
async def set_preset(
|
||||
self,
|
||||
preset_name: str,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set a light preset for the device."""
|
||||
light = self._device.modules[Module.Light]
|
||||
if preset_name == self.PRESET_NOT_SET:
|
||||
@ -104,7 +104,7 @@ class LightPreset(IotModule, LightPresetInterface):
|
||||
elif (preset := self._presets.get(preset_name)) is None: # type: ignore[assignment]
|
||||
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
|
||||
|
||||
await light.set_state(preset)
|
||||
return await light.set_state(preset)
|
||||
|
||||
@property
|
||||
def has_save_preset(self) -> bool:
|
||||
@ -115,7 +115,7 @@ class LightPreset(IotModule, LightPresetInterface):
|
||||
self,
|
||||
preset_name: str,
|
||||
preset_state: LightState,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Update the preset with preset_name with the new preset_info."""
|
||||
if len(self._presets) == 0:
|
||||
raise KasaException("Device does not supported saving presets")
|
||||
@ -129,7 +129,7 @@ class LightPreset(IotModule, LightPresetInterface):
|
||||
|
||||
return await self.call("set_preferred_state", state)
|
||||
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Return the base query."""
|
||||
return {}
|
||||
|
||||
@ -142,7 +142,7 @@ class LightPreset(IotModule, LightPresetInterface):
|
||||
if "id" not in vals
|
||||
]
|
||||
|
||||
async def _deprecated_save_preset(self, preset: IotLightPreset):
|
||||
async def _deprecated_save_preset(self, preset: IotLightPreset) -> dict:
|
||||
"""Save a setting preset.
|
||||
|
||||
You can either construct a preset object manually, or pass an existing one
|
||||
|
@ -24,7 +24,7 @@ class Range(Enum):
|
||||
class Motion(IotModule):
|
||||
"""Implements the motion detection (PIR) module."""
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
# Only add features if the device supports the module
|
||||
if "get_config" not in self.data:
|
||||
@ -48,7 +48,7 @@ class Motion(IotModule):
|
||||
)
|
||||
)
|
||||
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Request PIR configuration."""
|
||||
return self.query_for_command("get_config")
|
||||
|
||||
@ -67,13 +67,13 @@ class Motion(IotModule):
|
||||
"""Return True if module is enabled."""
|
||||
return bool(self.config["enable"])
|
||||
|
||||
async def set_enabled(self, state: bool):
|
||||
async def set_enabled(self, state: bool) -> dict:
|
||||
"""Enable/disable PIR."""
|
||||
return await self.call("set_enable", {"enable": int(state)})
|
||||
|
||||
async def set_range(
|
||||
self, *, range: Range | None = None, custom_range: int | None = None
|
||||
):
|
||||
) -> dict:
|
||||
"""Set the range for the sensor.
|
||||
|
||||
:param range: for using standard ranges
|
||||
@ -93,7 +93,7 @@ class Motion(IotModule):
|
||||
"""Return inactivity timeout in milliseconds."""
|
||||
return self.config["cold_time"]
|
||||
|
||||
async def set_inactivity_timeout(self, timeout: int):
|
||||
async def set_inactivity_timeout(self, timeout: int) -> dict:
|
||||
"""Set inactivity timeout in milliseconds.
|
||||
|
||||
Note, that you need to delete the default "Smart Control" rule in the app
|
||||
|
@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
class RuleModule(IotModule):
|
||||
"""Base class for rule-based modules, such as countdown and antitheft."""
|
||||
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Prepare the query for rules."""
|
||||
q = self.query_for_command("get_rules")
|
||||
return merge(q, self.query_for_command("get_next_action"))
|
||||
@ -73,14 +73,14 @@ class RuleModule(IotModule):
|
||||
_LOGGER.error("Unable to read rule list: %s (data: %s)", ex, self.data)
|
||||
return []
|
||||
|
||||
async def set_enabled(self, state: bool):
|
||||
async def set_enabled(self, state: bool) -> dict:
|
||||
"""Enable or disable the service."""
|
||||
return await self.call("set_overall_enable", state)
|
||||
return await self.call("set_overall_enable", {"enable": state})
|
||||
|
||||
async def delete_rule(self, rule: Rule):
|
||||
async def delete_rule(self, rule: Rule) -> dict:
|
||||
"""Delete the given rule."""
|
||||
return await self.call("delete_rule", {"id": rule.id})
|
||||
|
||||
async def delete_all_rules(self):
|
||||
async def delete_all_rules(self) -> dict:
|
||||
"""Delete all rules."""
|
||||
return await self.call("delete_all_rules")
|
||||
|
@ -15,14 +15,14 @@ class Time(IotModule, TimeInterface):
|
||||
|
||||
_timezone: tzinfo = timezone.utc
|
||||
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Request time and timezone."""
|
||||
q = self.query_for_command("get_time")
|
||||
|
||||
merge(q, self.query_for_command("get_timezone"))
|
||||
return q
|
||||
|
||||
async def _post_update_hook(self):
|
||||
async def _post_update_hook(self) -> None:
|
||||
"""Perform actions after a device update."""
|
||||
if res := self.data.get("get_timezone"):
|
||||
self._timezone = await get_timezone(res.get("index"))
|
||||
@ -47,7 +47,7 @@ class Time(IotModule, TimeInterface):
|
||||
"""Return current timezone."""
|
||||
return self._timezone
|
||||
|
||||
async def get_time(self):
|
||||
async def get_time(self) -> datetime | None:
|
||||
"""Return current device time."""
|
||||
try:
|
||||
res = await self.call("get_time")
|
||||
@ -88,6 +88,6 @@ class Time(IotModule, TimeInterface):
|
||||
except Exception as ex:
|
||||
raise KasaException(ex) from ex
|
||||
|
||||
async def get_timezone(self):
|
||||
async def get_timezone(self) -> dict:
|
||||
"""Request timezone information from the device."""
|
||||
return await self.call("get_timezone")
|
||||
|
@ -10,7 +10,7 @@ from ..iotmodule import IotModule, merge
|
||||
class Usage(IotModule):
|
||||
"""Baseclass for emeter/usage interfaces."""
|
||||
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Return the base query."""
|
||||
now = datetime.now()
|
||||
year = now.year
|
||||
@ -25,22 +25,22 @@ class Usage(IotModule):
|
||||
return req
|
||||
|
||||
@property
|
||||
def estimated_query_response_size(self):
|
||||
def estimated_query_response_size(self) -> int:
|
||||
"""Estimated maximum query response size."""
|
||||
return 2048
|
||||
|
||||
@property
|
||||
def daily_data(self):
|
||||
def daily_data(self) -> list[dict]:
|
||||
"""Return statistics on daily basis."""
|
||||
return self.data["get_daystat"]["day_list"]
|
||||
|
||||
@property
|
||||
def monthly_data(self):
|
||||
def monthly_data(self) -> list[dict]:
|
||||
"""Return statistics on monthly basis."""
|
||||
return self.data["get_monthstat"]["month_list"]
|
||||
|
||||
@property
|
||||
def usage_today(self):
|
||||
def usage_today(self) -> int | None:
|
||||
"""Return today's usage in minutes."""
|
||||
today = datetime.now().day
|
||||
# Traverse the list in reverse order to find the latest entry.
|
||||
@ -50,7 +50,7 @@ class Usage(IotModule):
|
||||
return None
|
||||
|
||||
@property
|
||||
def usage_this_month(self):
|
||||
def usage_this_month(self) -> int | None:
|
||||
"""Return usage in this month in minutes."""
|
||||
this_month = datetime.now().month
|
||||
# Traverse the list in reverse order to find the latest entry.
|
||||
@ -59,7 +59,9 @@ class Usage(IotModule):
|
||||
return entry["time"]
|
||||
return None
|
||||
|
||||
async def get_raw_daystat(self, *, year=None, month=None) -> dict:
|
||||
async def get_raw_daystat(
|
||||
self, *, year: int | None = None, month: int | None = None
|
||||
) -> dict:
|
||||
"""Return raw daily stats for the given year & month."""
|
||||
if year is None:
|
||||
year = datetime.now().year
|
||||
@ -68,14 +70,16 @@ class Usage(IotModule):
|
||||
|
||||
return await self.call("get_daystat", {"year": year, "month": month})
|
||||
|
||||
async def get_raw_monthstat(self, *, year=None) -> dict:
|
||||
async def get_raw_monthstat(self, *, year: int | None = None) -> dict:
|
||||
"""Return raw monthly stats for the given year."""
|
||||
if year is None:
|
||||
year = datetime.now().year
|
||||
|
||||
return await self.call("get_monthstat", {"year": year})
|
||||
|
||||
async def get_daystat(self, *, year=None, month=None) -> dict:
|
||||
async def get_daystat(
|
||||
self, *, year: int | None = None, month: int | None = None
|
||||
) -> dict:
|
||||
"""Return daily stats for the given year & month.
|
||||
|
||||
The return value is a dictionary of {day: time, ...}.
|
||||
@ -84,7 +88,7 @@ class Usage(IotModule):
|
||||
data = self._convert_stat_data(data["day_list"], entry_key="day")
|
||||
return data
|
||||
|
||||
async def get_monthstat(self, *, year=None) -> dict:
|
||||
async def get_monthstat(self, *, year: int | None = None) -> dict:
|
||||
"""Return monthly stats for the given year.
|
||||
|
||||
The return value is a dictionary of {month: time, ...}.
|
||||
@ -93,11 +97,11 @@ class Usage(IotModule):
|
||||
data = self._convert_stat_data(data["month_list"], entry_key="month")
|
||||
return data
|
||||
|
||||
async def erase_stats(self):
|
||||
async def erase_stats(self) -> dict:
|
||||
"""Erase all stats."""
|
||||
return await self.call("erase_runtime_stat")
|
||||
|
||||
def _convert_stat_data(self, data, entry_key) -> dict:
|
||||
def _convert_stat_data(self, data: list[dict], entry_key: str) -> dict:
|
||||
"""Return usage information keyed with the day/month.
|
||||
|
||||
The incoming data is a list of dictionaries::
|
||||
@ -113,6 +117,6 @@ class Usage(IotModule):
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
data = {entry[entry_key]: entry["time"] for entry in data}
|
||||
res = {entry[entry_key]: entry["time"] for entry in data}
|
||||
|
||||
return data
|
||||
return res
|
||||
|
@ -1,9 +1,13 @@
|
||||
"""JSON abstraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
try:
|
||||
import orjson
|
||||
|
||||
def dumps(obj, *, default=None):
|
||||
def dumps(obj: Any, *, default: Callable | None = None) -> str:
|
||||
"""Dump JSON."""
|
||||
return orjson.dumps(obj).decode()
|
||||
|
||||
@ -11,7 +15,7 @@ try:
|
||||
except ImportError:
|
||||
import json
|
||||
|
||||
def dumps(obj, *, default=None):
|
||||
def dumps(obj: Any, *, default: Callable | None = None) -> str:
|
||||
"""Dump JSON."""
|
||||
# Separators specified for consistency with orjson
|
||||
return json.dumps(obj, separators=(",", ":"))
|
||||
|
@ -50,7 +50,8 @@ import logging
|
||||
import secrets
|
||||
import struct
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from asyncio import Future
|
||||
from typing import TYPE_CHECKING, Any, Generator, cast
|
||||
|
||||
from cryptography.hazmat.primitives import padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
@ -110,10 +111,10 @@ class KlapTransport(BaseTransport):
|
||||
else:
|
||||
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
|
||||
self._default_credentials_auth_hash: dict[str, bytes] = {}
|
||||
self._blank_auth_hash = None
|
||||
self._blank_auth_hash: bytes | None = None
|
||||
self._handshake_lock = asyncio.Lock()
|
||||
self._query_lock = asyncio.Lock()
|
||||
self._handshake_done = False
|
||||
self._handshake_done: bool = False
|
||||
|
||||
self._encryption_session: KlapEncryptionSession | None = None
|
||||
self._session_expire_at: float | None = None
|
||||
@ -125,7 +126,7 @@ class KlapTransport(BaseTransport):
|
||||
self._request_url = self._app_url / "request"
|
||||
|
||||
@property
|
||||
def default_port(self):
|
||||
def default_port(self) -> int:
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
@ -242,7 +243,7 @@ class KlapTransport(BaseTransport):
|
||||
raise AuthenticationError(msg)
|
||||
|
||||
async def perform_handshake2(
|
||||
self, local_seed, remote_seed, auth_hash
|
||||
self, local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
) -> KlapEncryptionSession:
|
||||
"""Perform handshake2."""
|
||||
# Handshake 2 has the following payload:
|
||||
@ -277,7 +278,7 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
|
||||
|
||||
async def perform_handshake(self) -> Any:
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform handshake1 and handshake2.
|
||||
|
||||
Sets the encryption_session if successful.
|
||||
@ -309,14 +310,14 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self._host)
|
||||
|
||||
def _handshake_session_expired(self):
|
||||
def _handshake_session_expired(self) -> bool:
|
||||
"""Return true if session has expired."""
|
||||
return (
|
||||
self._session_expire_at is None
|
||||
or self._session_expire_at - time.monotonic() <= 0
|
||||
)
|
||||
|
||||
async def send(self, request: str):
|
||||
async def send(self, request: str) -> Generator[Future, None, dict[str, str]]: # type: ignore[override]
|
||||
"""Send the request."""
|
||||
if not self._handshake_done or self._handshake_session_expired():
|
||||
await self.perform_handshake()
|
||||
@ -355,6 +356,7 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert self._encryption_session
|
||||
assert isinstance(response_data, bytes)
|
||||
try:
|
||||
decrypted_response = self._encryption_session.decrypt(response_data)
|
||||
except Exception as ex:
|
||||
@ -378,7 +380,7 @@ class KlapTransport(BaseTransport):
|
||||
self._handshake_done = False
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials):
|
||||
def generate_auth_hash(creds: Credentials) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
un = creds.username
|
||||
pw = creds.password
|
||||
@ -388,19 +390,19 @@ class KlapTransport(BaseTransport):
|
||||
@staticmethod
|
||||
def handshake1_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
):
|
||||
) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(local_seed + auth_hash)
|
||||
|
||||
@staticmethod
|
||||
def handshake2_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
):
|
||||
) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(remote_seed + auth_hash)
|
||||
|
||||
@staticmethod
|
||||
def generate_owner_hash(creds: Credentials):
|
||||
def generate_owner_hash(creds: Credentials) -> bytes:
|
||||
"""Return the MD5 hash of the username in this object."""
|
||||
un = creds.username
|
||||
return md5(un.encode())
|
||||
@ -410,7 +412,7 @@ class KlapTransportV2(KlapTransport):
|
||||
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials):
|
||||
def generate_auth_hash(creds: Credentials) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
un = creds.username
|
||||
pw = creds.password
|
||||
@ -420,14 +422,14 @@ class KlapTransportV2(KlapTransport):
|
||||
@staticmethod
|
||||
def handshake1_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
):
|
||||
) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(local_seed + remote_seed + auth_hash)
|
||||
|
||||
@staticmethod
|
||||
def handshake2_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
):
|
||||
) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(remote_seed + local_seed + auth_hash)
|
||||
|
||||
@ -440,7 +442,7 @@ class KlapEncryptionSession:
|
||||
|
||||
_cipher: Cipher
|
||||
|
||||
def __init__(self, local_seed, remote_seed, user_hash):
|
||||
def __init__(self, local_seed: bytes, remote_seed: bytes, user_hash: bytes) -> None:
|
||||
self.local_seed = local_seed
|
||||
self.remote_seed = remote_seed
|
||||
self.user_hash = user_hash
|
||||
@ -449,11 +451,15 @@ class KlapEncryptionSession:
|
||||
self._aes = algorithms.AES(self._key)
|
||||
self._sig = self._sig_derive(local_seed, remote_seed, user_hash)
|
||||
|
||||
def _key_derive(self, local_seed, remote_seed, user_hash):
|
||||
def _key_derive(
|
||||
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
|
||||
) -> bytes:
|
||||
payload = b"lsk" + local_seed + remote_seed + user_hash
|
||||
return hashlib.sha256(payload).digest()[:16]
|
||||
|
||||
def _iv_derive(self, local_seed, remote_seed, user_hash):
|
||||
def _iv_derive(
|
||||
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
|
||||
) -> tuple[bytes, int]:
|
||||
# iv is first 16 bytes of sha256, where the last 4 bytes forms the
|
||||
# sequence number used in requests and is incremented on each request
|
||||
payload = b"iv" + local_seed + remote_seed + user_hash
|
||||
@ -461,17 +467,19 @@ class KlapEncryptionSession:
|
||||
seq = int.from_bytes(fulliv[-4:], "big", signed=True)
|
||||
return (fulliv[:12], seq)
|
||||
|
||||
def _sig_derive(self, local_seed, remote_seed, user_hash):
|
||||
def _sig_derive(
|
||||
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
|
||||
) -> bytes:
|
||||
# used to create a hash with which to prefix each request
|
||||
payload = b"ldk" + local_seed + remote_seed + user_hash
|
||||
return hashlib.sha256(payload).digest()[:28]
|
||||
|
||||
def _generate_cipher(self):
|
||||
def _generate_cipher(self) -> None:
|
||||
iv_seq = self._iv + PACK_SIGNED_LONG(self._seq)
|
||||
cbc = modes.CBC(iv_seq)
|
||||
self._cipher = Cipher(self._aes, cbc)
|
||||
|
||||
def encrypt(self, msg):
|
||||
def encrypt(self, msg: bytes | str) -> tuple[bytes, int]:
|
||||
"""Encrypt the data and increment the sequence number."""
|
||||
self._seq += 1
|
||||
self._generate_cipher()
|
||||
@ -488,7 +496,7 @@ class KlapEncryptionSession:
|
||||
).digest()
|
||||
return (signature + ciphertext, self._seq)
|
||||
|
||||
def decrypt(self, msg):
|
||||
def decrypt(self, msg: bytes) -> str:
|
||||
"""Decrypt the data."""
|
||||
decryptor = self._cipher.decryptor()
|
||||
dp = decryptor.update(msg[32:]) + decryptor.finalize()
|
||||
|
@ -135,13 +135,13 @@ class Module(ABC):
|
||||
# SMARTCAMERA only modules
|
||||
Camera: Final[ModuleName[experimental.Camera]] = ModuleName("Camera")
|
||||
|
||||
def __init__(self, device: Device, module: str):
|
||||
def __init__(self, device: Device, module: str) -> None:
|
||||
self._device = device
|
||||
self._module = module
|
||||
self._module_features: dict[str, Feature] = {}
|
||||
|
||||
@abstractmethod
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Query to execute during the update cycle.
|
||||
|
||||
The inheriting modules implement this to include their wanted
|
||||
@ -150,10 +150,10 @@ class Module(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def data(self):
|
||||
def data(self) -> dict:
|
||||
"""Return the module specific raw data from the last update."""
|
||||
|
||||
def _initialize_features(self): # noqa: B027
|
||||
def _initialize_features(self) -> None: # noqa: B027
|
||||
"""Initialize features after the initial update.
|
||||
|
||||
This can be implemented if features depend on module query responses.
|
||||
@ -162,7 +162,7 @@ class Module(ABC):
|
||||
children's modules.
|
||||
"""
|
||||
|
||||
async def _post_update_hook(self): # noqa: B027
|
||||
async def _post_update_hook(self) -> None: # noqa: B027
|
||||
"""Perform actions after a device update.
|
||||
|
||||
This can be implemented if a module needs to perform actions each time
|
||||
@ -171,11 +171,11 @@ class Module(ABC):
|
||||
*_initialize_features* on the first update.
|
||||
"""
|
||||
|
||||
def _add_feature(self, feature: Feature):
|
||||
def _add_feature(self, feature: Feature) -> None:
|
||||
"""Add module feature."""
|
||||
id_ = feature.id
|
||||
if id_ in self._module_features:
|
||||
raise KasaException("Duplicate id detected %s" % id_)
|
||||
raise KasaException(f"Duplicate id detected {id_}")
|
||||
self._module_features[id_] = feature
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
@ -130,7 +130,7 @@ class BaseProtocol(ABC):
|
||||
self._transport = transport
|
||||
|
||||
@property
|
||||
def _host(self):
|
||||
def _host(self) -> str:
|
||||
return self._transport._host
|
||||
|
||||
@property
|
||||
|
@ -15,7 +15,9 @@ class SmartLightEffect(LightEffectInterface, ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def set_brightness(self, brightness: int, *, transition: int | None = None):
|
||||
async def set_brightness(
|
||||
self, brightness: int, *, transition: int | None = None
|
||||
) -> dict:
|
||||
"""Set effect brightness."""
|
||||
|
||||
@property
|
||||
|
@ -20,7 +20,7 @@ class Alarm(SmartModule):
|
||||
"get_support_alarm_type_list": None, # This should be needed only once
|
||||
}
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features.
|
||||
|
||||
This is implemented as some features depend on device responses.
|
||||
@ -100,7 +100,7 @@ class Alarm(SmartModule):
|
||||
"""Return current alarm sound."""
|
||||
return self.data["get_alarm_configure"]["type"]
|
||||
|
||||
async def set_alarm_sound(self, sound: str):
|
||||
async def set_alarm_sound(self, sound: str) -> dict:
|
||||
"""Set alarm sound.
|
||||
|
||||
See *alarm_sounds* for list of available sounds.
|
||||
@ -119,7 +119,7 @@ class Alarm(SmartModule):
|
||||
"""Return alarm volume."""
|
||||
return self.data["get_alarm_configure"]["volume"]
|
||||
|
||||
async def set_alarm_volume(self, volume: Literal["low", "normal", "high"]):
|
||||
async def set_alarm_volume(self, volume: Literal["low", "normal", "high"]) -> dict:
|
||||
"""Set alarm volume."""
|
||||
payload = self.data["get_alarm_configure"].copy()
|
||||
payload["volume"] = volume
|
||||
|
@ -17,7 +17,7 @@ class AutoOff(SmartModule):
|
||||
REQUIRED_COMPONENT = "auto_off"
|
||||
QUERY_GETTER_NAME = "get_auto_off_config"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -63,7 +63,7 @@ class AutoOff(SmartModule):
|
||||
"""Return True if enabled."""
|
||||
return self.data["enable"]
|
||||
|
||||
async def set_enabled(self, enable: bool):
|
||||
async def set_enabled(self, enable: bool) -> dict:
|
||||
"""Enable/disable auto off."""
|
||||
return await self.call(
|
||||
"set_auto_off_config",
|
||||
@ -75,7 +75,7 @@ class AutoOff(SmartModule):
|
||||
"""Return time until auto off."""
|
||||
return self.data["delay_min"]
|
||||
|
||||
async def set_delay(self, delay: int):
|
||||
async def set_delay(self, delay: int) -> dict:
|
||||
"""Set time until auto off."""
|
||||
return await self.call(
|
||||
"set_auto_off_config", {"delay_min": delay, "enable": self.data["enable"]}
|
||||
@ -96,7 +96,7 @@ class AutoOff(SmartModule):
|
||||
|
||||
return self._device.time + timedelta(seconds=sysinfo["auto_off_remain_time"])
|
||||
|
||||
async def _check_supported(self):
|
||||
async def _check_supported(self) -> bool:
|
||||
"""Additional check to see if the module is supported by the device.
|
||||
|
||||
Parent devices that report components of children such as P300 will not have
|
||||
|
@ -12,7 +12,7 @@ class BatterySensor(SmartModule):
|
||||
REQUIRED_COMPONENT = "battery_detect"
|
||||
QUERY_GETTER_NAME = "get_battery_detect_info"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -48,11 +48,11 @@ class BatterySensor(SmartModule):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def battery(self):
|
||||
def battery(self) -> int:
|
||||
"""Return battery level."""
|
||||
return self._device.sys_info["battery_percentage"]
|
||||
|
||||
@property
|
||||
def battery_low(self):
|
||||
def battery_low(self) -> bool:
|
||||
"""Return True if battery is low."""
|
||||
return self._device.sys_info["at_low_battery"]
|
||||
|
@ -14,7 +14,7 @@ class Brightness(SmartModule):
|
||||
|
||||
REQUIRED_COMPONENT = "brightness"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
super()._initialize_features()
|
||||
|
||||
@ -39,7 +39,7 @@ class Brightness(SmartModule):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def brightness(self):
|
||||
def brightness(self) -> int:
|
||||
"""Return current brightness."""
|
||||
# If the device supports effects and one is active, use its brightness
|
||||
if (
|
||||
@ -49,7 +49,9 @@ class Brightness(SmartModule):
|
||||
|
||||
return self.data["brightness"]
|
||||
|
||||
async def set_brightness(self, brightness: int, *, transition: int | None = None):
|
||||
async def set_brightness(
|
||||
self, brightness: int, *, transition: int | None = None
|
||||
) -> dict:
|
||||
"""Set the brightness. A brightness value of 0 will turn off the light.
|
||||
|
||||
Note, transition is not supported and will be ignored.
|
||||
@ -73,6 +75,6 @@ class Brightness(SmartModule):
|
||||
|
||||
return await self.call("set_device_info", {"brightness": brightness})
|
||||
|
||||
async def _check_supported(self):
|
||||
async def _check_supported(self) -> bool:
|
||||
"""Additional check to see if the module is supported by the device."""
|
||||
return "brightness" in self.data
|
||||
|
@ -12,7 +12,7 @@ class ChildProtection(SmartModule):
|
||||
REQUIRED_COMPONENT = "child_protection"
|
||||
QUERY_GETTER_NAME = "get_child_protection"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
|
@ -13,7 +13,7 @@ class Cloud(SmartModule):
|
||||
REQUIRED_COMPONENT = "cloud_connect"
|
||||
MINIMUM_UPDATE_INTERVAL_SECS = 60
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -29,7 +29,7 @@ class Cloud(SmartModule):
|
||||
)
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
def is_connected(self) -> bool:
|
||||
"""Return True if device is connected to the cloud."""
|
||||
if self._has_data_error():
|
||||
return False
|
||||
|
@ -12,7 +12,7 @@ class Color(SmartModule):
|
||||
|
||||
REQUIRED_COMPONENT = "color"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -48,7 +48,7 @@ class Color(SmartModule):
|
||||
# due to the cpython implementation.
|
||||
return tuple.__new__(HSV, (h, s, v))
|
||||
|
||||
def _raise_for_invalid_brightness(self, value):
|
||||
def _raise_for_invalid_brightness(self, value: int) -> None:
|
||||
"""Raise error on invalid brightness value."""
|
||||
if not isinstance(value, int):
|
||||
raise TypeError("Brightness must be an integer")
|
||||
|
@ -18,7 +18,7 @@ class ColorTemperature(SmartModule):
|
||||
|
||||
REQUIRED_COMPONENT = "color_temperature"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -52,11 +52,11 @@ class ColorTemperature(SmartModule):
|
||||
return ColorTempRange(*ct_range)
|
||||
|
||||
@property
|
||||
def color_temp(self):
|
||||
def color_temp(self) -> int:
|
||||
"""Return current color temperature."""
|
||||
return self.data["color_temp"]
|
||||
|
||||
async def set_color_temp(self, temp: int, *, brightness=None):
|
||||
async def set_color_temp(self, temp: int, *, brightness: int | None = None) -> dict:
|
||||
"""Set the color temperature."""
|
||||
valid_temperature_range = self.valid_temperature_range
|
||||
if temp < valid_temperature_range[0] or temp > valid_temperature_range[1]:
|
||||
|
@ -12,7 +12,7 @@ class ContactSensor(SmartModule):
|
||||
REQUIRED_COMPONENT = None # we depend on availability of key
|
||||
REQUIRED_KEY_ON_PARENT = "open"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -32,6 +32,6 @@ class ContactSensor(SmartModule):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_open(self):
|
||||
def is_open(self) -> bool:
|
||||
"""Return True if the contact sensor is open."""
|
||||
return self._device.sys_info["open"]
|
||||
|
@ -10,7 +10,7 @@ class DeviceModule(SmartModule):
|
||||
|
||||
REQUIRED_COMPONENT = "device"
|
||||
|
||||
async def _post_update_hook(self):
|
||||
async def _post_update_hook(self) -> None:
|
||||
"""Perform actions after a device update.
|
||||
|
||||
Overrides the default behaviour to disable a module if the query returns
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NoReturn
|
||||
|
||||
from ...emeterstatus import EmeterStatus
|
||||
from ...exceptions import KasaException
|
||||
from ...interfaces.energy import Energy as EnergyInterface
|
||||
@ -31,34 +33,34 @@ class Energy(SmartModule, EnergyInterface):
|
||||
# Fallback if get_energy_usage does not provide current_power,
|
||||
# which can happen on some newer devices (e.g. P304M).
|
||||
elif (
|
||||
power := self.data.get("get_current_power").get("current_power")
|
||||
power := self.data.get("get_current_power", {}).get("current_power")
|
||||
) is not None:
|
||||
return power
|
||||
return None
|
||||
|
||||
@property
|
||||
@raise_if_update_error
|
||||
def energy(self):
|
||||
def energy(self) -> dict:
|
||||
"""Return get_energy_usage results."""
|
||||
if en := self.data.get("get_energy_usage"):
|
||||
return en
|
||||
return self.data
|
||||
|
||||
def _get_status_from_energy(self, energy) -> EmeterStatus:
|
||||
def _get_status_from_energy(self, energy: dict) -> EmeterStatus:
|
||||
return EmeterStatus(
|
||||
{
|
||||
"power_mw": energy.get("current_power"),
|
||||
"total": energy.get("today_energy") / 1_000,
|
||||
"power_mw": energy.get("current_power", 0),
|
||||
"total": energy.get("today_energy", 0) / 1_000,
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
@raise_if_update_error
|
||||
def status(self):
|
||||
def status(self) -> EmeterStatus:
|
||||
"""Get the emeter status."""
|
||||
return self._get_status_from_energy(self.energy)
|
||||
|
||||
async def get_status(self):
|
||||
async def get_status(self) -> EmeterStatus:
|
||||
"""Return real-time statistics."""
|
||||
res = await self.call("get_energy_usage")
|
||||
return self._get_status_from_energy(res["get_energy_usage"])
|
||||
@ -67,13 +69,13 @@ class Energy(SmartModule, EnergyInterface):
|
||||
@raise_if_update_error
|
||||
def consumption_this_month(self) -> float | None:
|
||||
"""Get the emeter value for this month in kWh."""
|
||||
return self.energy.get("month_energy") / 1_000
|
||||
return self.energy.get("month_energy", 0) / 1_000
|
||||
|
||||
@property
|
||||
@raise_if_update_error
|
||||
def consumption_today(self) -> float | None:
|
||||
"""Get the emeter value for today in kWh."""
|
||||
return self.energy.get("today_energy") / 1_000
|
||||
return self.energy.get("today_energy", 0) / 1_000
|
||||
|
||||
@property
|
||||
@raise_if_update_error
|
||||
@ -97,22 +99,26 @@ class Energy(SmartModule, EnergyInterface):
|
||||
"""Retrieve current energy readings."""
|
||||
return self.status
|
||||
|
||||
async def erase_stats(self):
|
||||
async def erase_stats(self) -> NoReturn:
|
||||
"""Erase all stats."""
|
||||
raise KasaException("Device does not support periodic statistics")
|
||||
|
||||
async def get_daily_stats(self, *, year=None, month=None, kwh=True) -> dict:
|
||||
async def get_daily_stats(
|
||||
self, *, year: int | None = None, month: int | None = None, kwh: bool = True
|
||||
) -> dict:
|
||||
"""Return daily stats for the given year & month.
|
||||
|
||||
The return value is a dictionary of {day: energy, ...}.
|
||||
"""
|
||||
raise KasaException("Device does not support periodic statistics")
|
||||
|
||||
async def get_monthly_stats(self, *, year=None, kwh=True) -> dict:
|
||||
async def get_monthly_stats(
|
||||
self, *, year: int | None = None, kwh: bool = True
|
||||
) -> dict:
|
||||
"""Return monthly stats for the given year."""
|
||||
raise KasaException("Device does not support periodic statistics")
|
||||
|
||||
async def _check_supported(self):
|
||||
async def _check_supported(self) -> bool:
|
||||
"""Additional check to see if the module is supported by the device."""
|
||||
# Energy module is not supported on P304M parent device
|
||||
return "device_on" in self._device.sys_info
|
||||
|
@ -12,7 +12,7 @@ class Fan(SmartModule, FanInterface):
|
||||
|
||||
REQUIRED_COMPONENT = "fan_control"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -50,7 +50,7 @@ class Fan(SmartModule, FanInterface):
|
||||
"""Return fan speed level."""
|
||||
return 0 if self.data["device_on"] is False else self.data["fan_speed_level"]
|
||||
|
||||
async def set_fan_speed_level(self, level: int):
|
||||
async def set_fan_speed_level(self, level: int) -> dict:
|
||||
"""Set fan speed level, 0 for off, 1-4 for on."""
|
||||
if level < 0 or level > 4:
|
||||
raise ValueError("Invalid level, should be in range 0-4.")
|
||||
@ -65,10 +65,10 @@ class Fan(SmartModule, FanInterface):
|
||||
"""Return sleep mode status."""
|
||||
return self.data["fan_sleep_mode_on"]
|
||||
|
||||
async def set_sleep_mode(self, on: bool):
|
||||
async def set_sleep_mode(self, on: bool) -> dict:
|
||||
"""Set sleep mode."""
|
||||
return await self.call("set_device_info", {"fan_sleep_mode_on": on})
|
||||
|
||||
async def _check_supported(self):
|
||||
async def _check_supported(self) -> bool:
|
||||
"""Is the module available on this device."""
|
||||
return "fan_speed_level" in self.data
|
||||
|
@ -49,14 +49,14 @@ class UpdateInfo(BaseModel):
|
||||
needs_upgrade: bool = Field(alias="need_to_upgrade")
|
||||
|
||||
@validator("release_date", pre=True)
|
||||
def _release_date_optional(cls, v):
|
||||
def _release_date_optional(cls, v: str) -> str | None:
|
||||
if not v:
|
||||
return None
|
||||
|
||||
return v
|
||||
|
||||
@property
|
||||
def update_available(self):
|
||||
def update_available(self) -> bool:
|
||||
"""Return True if update available."""
|
||||
if self.status != 0:
|
||||
return True
|
||||
@ -69,11 +69,11 @@ class Firmware(SmartModule):
|
||||
REQUIRED_COMPONENT = "firmware"
|
||||
MINIMUM_UPDATE_INTERVAL_SECS = 60 * 60 * 24
|
||||
|
||||
def __init__(self, device: SmartDevice, module: str):
|
||||
def __init__(self, device: SmartDevice, module: str) -> None:
|
||||
super().__init__(device, module)
|
||||
self._firmware_update_info: UpdateInfo | None = None
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
device = self._device
|
||||
if self.supported_version > 1:
|
||||
@ -183,7 +183,7 @@ class Firmware(SmartModule):
|
||||
@allow_update_after
|
||||
async def update(
|
||||
self, progress_cb: Callable[[DownloadState], Coroutine] | None = None
|
||||
):
|
||||
) -> dict:
|
||||
"""Update the device firmware."""
|
||||
if not self._firmware_update_info:
|
||||
raise KasaException(
|
||||
@ -236,13 +236,15 @@ class Firmware(SmartModule):
|
||||
else:
|
||||
_LOGGER.warning("Unhandled state code: %s", state)
|
||||
|
||||
return state.dict()
|
||||
|
||||
@property
|
||||
def auto_update_enabled(self) -> bool:
|
||||
"""Return True if autoupdate is enabled."""
|
||||
return "enable" in self.data and self.data["enable"]
|
||||
|
||||
@allow_update_after
|
||||
async def set_auto_update_enabled(self, enabled: bool):
|
||||
async def set_auto_update_enabled(self, enabled: bool) -> dict:
|
||||
"""Change autoupdate setting."""
|
||||
data = {**self.data, "enable": enabled}
|
||||
await self.call("set_auto_update_info", data)
|
||||
return await self.call("set_auto_update_info", data)
|
||||
|
@ -23,7 +23,7 @@ class FrostProtection(SmartModule):
|
||||
"""Return True if frost protection is on."""
|
||||
return self._device.sys_info["frost_protection_on"]
|
||||
|
||||
async def set_enabled(self, enable: bool):
|
||||
async def set_enabled(self, enable: bool) -> dict:
|
||||
"""Enable/disable frost protection."""
|
||||
return await self.call(
|
||||
"set_device_info",
|
||||
|
@ -12,7 +12,7 @@ class HumiditySensor(SmartModule):
|
||||
REQUIRED_COMPONENT = "humidity"
|
||||
QUERY_GETTER_NAME = "get_comfort_humidity_config"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -45,7 +45,7 @@ class HumiditySensor(SmartModule):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def humidity(self):
|
||||
def humidity(self) -> int:
|
||||
"""Return current humidity in percentage."""
|
||||
return self._device.sys_info["current_humidity"]
|
||||
|
||||
|
@ -19,7 +19,7 @@ class Led(SmartModule, LedInterface):
|
||||
return {self.QUERY_GETTER_NAME: None}
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
def mode(self) -> str:
|
||||
"""LED mode setting.
|
||||
|
||||
"always", "never", "night_mode"
|
||||
@ -27,12 +27,12 @@ class Led(SmartModule, LedInterface):
|
||||
return self.data["led_rule"]
|
||||
|
||||
@property
|
||||
def led(self):
|
||||
def led(self) -> bool:
|
||||
"""Return current led status."""
|
||||
return self.data["led_rule"] != "never"
|
||||
|
||||
@allow_update_after
|
||||
async def set_led(self, enable: bool):
|
||||
async def set_led(self, enable: bool) -> dict:
|
||||
"""Set led.
|
||||
|
||||
This should probably be a select with always/never/nightmode.
|
||||
@ -41,7 +41,7 @@ class Led(SmartModule, LedInterface):
|
||||
return await self.call("set_led_info", dict(self.data, **{"led_rule": rule}))
|
||||
|
||||
@property
|
||||
def night_mode_settings(self):
|
||||
def night_mode_settings(self) -> dict:
|
||||
"""Night mode settings."""
|
||||
return {
|
||||
"start": self.data["start_time"],
|
||||
|
@ -96,7 +96,7 @@ class Light(SmartModule, LightInterface):
|
||||
return await self._device.modules[Module.Color].set_hsv(hue, saturation, value)
|
||||
|
||||
async def set_color_temp(
|
||||
self, temp: int, *, brightness=None, transition: int | None = None
|
||||
self, temp: int, *, brightness: int | None = None, transition: int | None = None
|
||||
) -> dict:
|
||||
"""Set the color temperature of the device in kelvin.
|
||||
|
||||
|
@ -81,7 +81,7 @@ class LightEffect(SmartModule, SmartLightEffect):
|
||||
*,
|
||||
brightness: int | None = None,
|
||||
transition: int | None = None,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set an effect for the device.
|
||||
|
||||
Calling this will modify the brightness of the effect on the device.
|
||||
@ -107,7 +107,7 @@ class LightEffect(SmartModule, SmartLightEffect):
|
||||
)
|
||||
await self.set_brightness(brightness, effect_id=effect_id)
|
||||
|
||||
await self.call("set_dynamic_light_effect_rule_enable", params)
|
||||
return await self.call("set_dynamic_light_effect_rule_enable", params)
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
@ -139,11 +139,11 @@ class LightEffect(SmartModule, SmartLightEffect):
|
||||
*,
|
||||
transition: int | None = None,
|
||||
effect_id: str | None = None,
|
||||
):
|
||||
) -> dict:
|
||||
"""Set effect brightness."""
|
||||
new_effect = self._get_effect_data(effect_id=effect_id).copy()
|
||||
|
||||
def _replace_brightness(data, new_brightness):
|
||||
def _replace_brightness(data: list[int], new_brightness: int) -> list[int]:
|
||||
"""Replace brightness.
|
||||
|
||||
The first element is the brightness, the rest are unknown.
|
||||
@ -163,7 +163,7 @@ class LightEffect(SmartModule, SmartLightEffect):
|
||||
async def set_custom_effect(
|
||||
self,
|
||||
effect_dict: dict,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set a custom effect on the device.
|
||||
|
||||
:param str effect_dict: The custom effect dict to set
|
||||
|
@ -29,12 +29,12 @@ class LightPreset(SmartModule, LightPresetInterface):
|
||||
_presets: dict[str, LightState]
|
||||
_preset_list: list[str]
|
||||
|
||||
def __init__(self, device: SmartDevice, module: str):
|
||||
def __init__(self, device: SmartDevice, module: str) -> None:
|
||||
super().__init__(device, module)
|
||||
self._state_in_sysinfo = self.SYS_INFO_STATE_KEY in device.sys_info
|
||||
self._brightness_only: bool = False
|
||||
|
||||
async def _post_update_hook(self):
|
||||
async def _post_update_hook(self) -> None:
|
||||
"""Update the internal presets."""
|
||||
index = 0
|
||||
self._presets = {}
|
||||
@ -113,7 +113,7 @@ class LightPreset(SmartModule, LightPresetInterface):
|
||||
async def set_preset(
|
||||
self,
|
||||
preset_name: str,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set a light preset for the device."""
|
||||
light = self._device.modules[SmartModule.Light]
|
||||
if preset_name == self.PRESET_NOT_SET:
|
||||
@ -123,14 +123,14 @@ class LightPreset(SmartModule, LightPresetInterface):
|
||||
preset = LightState(brightness=100)
|
||||
elif (preset := self._presets.get(preset_name)) is None: # type: ignore[assignment]
|
||||
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
|
||||
await self._device.modules[SmartModule.Light].set_state(preset)
|
||||
return await self._device.modules[SmartModule.Light].set_state(preset)
|
||||
|
||||
@allow_update_after
|
||||
async def save_preset(
|
||||
self,
|
||||
preset_name: str,
|
||||
preset_state: LightState,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Update the preset with preset_name with the new preset_info."""
|
||||
if preset_name not in self._presets:
|
||||
raise ValueError(f"{preset_name} is not a valid preset: {self.preset_list}")
|
||||
@ -138,11 +138,13 @@ class LightPreset(SmartModule, LightPresetInterface):
|
||||
if self._brightness_only:
|
||||
bright_list = [state.brightness for state in self._presets.values()]
|
||||
bright_list[index] = preset_state.brightness
|
||||
await self.call("set_preset_rules", {"brightness": bright_list})
|
||||
return await self.call("set_preset_rules", {"brightness": bright_list})
|
||||
else:
|
||||
state_params = asdict(preset_state)
|
||||
new_info = {k: v for k, v in state_params.items() if v is not None}
|
||||
await self.call("edit_preset_rules", {"index": index, "state": new_info})
|
||||
return await self.call(
|
||||
"edit_preset_rules", {"index": index, "state": new_info}
|
||||
)
|
||||
|
||||
@property
|
||||
def has_save_preset(self) -> bool:
|
||||
@ -158,7 +160,7 @@ class LightPreset(SmartModule, LightPresetInterface):
|
||||
|
||||
return {self.QUERY_GETTER_NAME: {"start_index": 0}}
|
||||
|
||||
async def _check_supported(self):
|
||||
async def _check_supported(self) -> bool:
|
||||
"""Additional check to see if the module is supported by the device.
|
||||
|
||||
Parent devices that report components of children such as ks240 will not have
|
||||
|
@ -16,7 +16,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
|
||||
|
||||
REQUIRED_COMPONENT = "light_strip_lighting_effect"
|
||||
|
||||
def __init__(self, device: SmartDevice, module: str):
|
||||
def __init__(self, device: SmartDevice, module: str) -> None:
|
||||
super().__init__(device, module)
|
||||
effect_list = [self.LIGHT_EFFECTS_OFF]
|
||||
effect_list.extend(EFFECT_NAMES)
|
||||
@ -66,7 +66,9 @@ class LightStripEffect(SmartModule, SmartLightEffect):
|
||||
eff = self.data["lighting_effect"]
|
||||
return eff["brightness"]
|
||||
|
||||
async def set_brightness(self, brightness: int, *, transition: int | None = None):
|
||||
async def set_brightness(
|
||||
self, brightness: int, *, transition: int | None = None
|
||||
) -> dict:
|
||||
"""Set effect brightness."""
|
||||
if brightness <= 0:
|
||||
return await self.set_effect(self.LIGHT_EFFECTS_OFF)
|
||||
@ -91,7 +93,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
|
||||
*,
|
||||
brightness: int | None = None,
|
||||
transition: int | None = None,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set an effect on the device.
|
||||
|
||||
If brightness or transition is defined,
|
||||
@ -115,8 +117,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
|
||||
effect_dict = self._effect_mapping["Aurora"]
|
||||
effect_dict = {**effect_dict}
|
||||
effect_dict["enable"] = 0
|
||||
await self.set_custom_effect(effect_dict)
|
||||
return
|
||||
return await self.set_custom_effect(effect_dict)
|
||||
|
||||
if effect not in self._effect_mapping:
|
||||
raise ValueError(f"The effect {effect} is not a built in effect.")
|
||||
@ -134,13 +135,13 @@ class LightStripEffect(SmartModule, SmartLightEffect):
|
||||
if transition is not None:
|
||||
effect_dict["transition"] = transition
|
||||
|
||||
await self.set_custom_effect(effect_dict)
|
||||
return await self.set_custom_effect(effect_dict)
|
||||
|
||||
@allow_update_after
|
||||
async def set_custom_effect(
|
||||
self,
|
||||
effect_dict: dict,
|
||||
) -> None:
|
||||
) -> dict:
|
||||
"""Set a custom effect on the device.
|
||||
|
||||
:param str effect_dict: The custom effect dict to set
|
||||
@ -155,7 +156,7 @@ class LightStripEffect(SmartModule, SmartLightEffect):
|
||||
"""Return True if the device supports setting custom effects."""
|
||||
return True
|
||||
|
||||
def query(self):
|
||||
def query(self) -> dict:
|
||||
"""Return the base query."""
|
||||
return {}
|
||||
|
||||
|
@ -39,14 +39,14 @@ class LightTransition(SmartModule):
|
||||
_off_state: _State
|
||||
_enabled: bool
|
||||
|
||||
def __init__(self, device: SmartDevice, module: str):
|
||||
def __init__(self, device: SmartDevice, module: str) -> None:
|
||||
super().__init__(device, module)
|
||||
self._state_in_sysinfo = all(
|
||||
key in device.sys_info for key in self.SYS_INFO_STATE_KEYS
|
||||
)
|
||||
self._supports_on_and_off: bool = self.supported_version > 1
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
icon = "mdi:transition"
|
||||
if not self._supports_on_and_off:
|
||||
@ -138,7 +138,7 @@ class LightTransition(SmartModule):
|
||||
}
|
||||
|
||||
@allow_update_after
|
||||
async def set_enabled(self, enable: bool):
|
||||
async def set_enabled(self, enable: bool) -> dict:
|
||||
"""Enable gradual on/off."""
|
||||
if not self._supports_on_and_off:
|
||||
return await self.call("set_on_off_gradually_info", {"enable": enable})
|
||||
@ -171,7 +171,7 @@ class LightTransition(SmartModule):
|
||||
return self._on_state["max_duration"]
|
||||
|
||||
@allow_update_after
|
||||
async def set_turn_on_transition(self, seconds: int):
|
||||
async def set_turn_on_transition(self, seconds: int) -> dict:
|
||||
"""Set turn on transition in seconds.
|
||||
|
||||
Setting to 0 turns the feature off.
|
||||
@ -207,7 +207,7 @@ class LightTransition(SmartModule):
|
||||
return self._off_state["max_duration"]
|
||||
|
||||
@allow_update_after
|
||||
async def set_turn_off_transition(self, seconds: int):
|
||||
async def set_turn_off_transition(self, seconds: int) -> dict:
|
||||
"""Set turn on transition in seconds.
|
||||
|
||||
Setting to 0 turns the feature off.
|
||||
@ -236,7 +236,7 @@ class LightTransition(SmartModule):
|
||||
else:
|
||||
return {self.QUERY_GETTER_NAME: None}
|
||||
|
||||
async def _check_supported(self):
|
||||
async def _check_supported(self) -> bool:
|
||||
"""Additional check to see if the module is supported by the device."""
|
||||
# For devices that report child components on the parent that are not
|
||||
# actually supported by the parent.
|
||||
|
@ -11,7 +11,7 @@ class MotionSensor(SmartModule):
|
||||
|
||||
REQUIRED_COMPONENT = "sensitivity"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -31,6 +31,6 @@ class MotionSensor(SmartModule):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def motion_detected(self):
|
||||
def motion_detected(self) -> bool:
|
||||
"""Return True if the motion has been detected."""
|
||||
return self._device.sys_info["detected"]
|
||||
|
@ -12,7 +12,7 @@ class ReportMode(SmartModule):
|
||||
REQUIRED_COMPONENT = "report_mode"
|
||||
QUERY_GETTER_NAME = "get_report_mode"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -32,6 +32,6 @@ class ReportMode(SmartModule):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def report_interval(self):
|
||||
def report_interval(self) -> int:
|
||||
"""Reporting interval of a sensor device."""
|
||||
return self._device.sys_info["report_interval"]
|
||||
|
@ -26,7 +26,7 @@ class TemperatureControl(SmartModule):
|
||||
|
||||
REQUIRED_COMPONENT = "temp_control"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -92,7 +92,7 @@ class TemperatureControl(SmartModule):
|
||||
"""Return thermostat state."""
|
||||
return self._device.sys_info["frost_protection_on"] is False
|
||||
|
||||
async def set_state(self, enabled: bool):
|
||||
async def set_state(self, enabled: bool) -> dict:
|
||||
"""Set thermostat state."""
|
||||
return await self.call("set_device_info", {"frost_protection_on": not enabled})
|
||||
|
||||
@ -147,7 +147,7 @@ class TemperatureControl(SmartModule):
|
||||
"""Return thermostat states."""
|
||||
return set(self._device.sys_info["trv_states"])
|
||||
|
||||
async def set_target_temperature(self, target: float):
|
||||
async def set_target_temperature(self, target: float) -> dict:
|
||||
"""Set target temperature."""
|
||||
if (
|
||||
target < self.minimum_target_temperature
|
||||
@ -170,7 +170,7 @@ class TemperatureControl(SmartModule):
|
||||
"""Return temperature offset."""
|
||||
return self._device.sys_info["temp_offset"]
|
||||
|
||||
async def set_temperature_offset(self, offset: int):
|
||||
async def set_temperature_offset(self, offset: int) -> dict:
|
||||
"""Set temperature offset."""
|
||||
if offset < -10 or offset > 10:
|
||||
raise ValueError("Temperature offset must be [-10, 10]")
|
||||
|
@ -14,7 +14,7 @@ class TemperatureSensor(SmartModule):
|
||||
REQUIRED_COMPONENT = "temperature"
|
||||
QUERY_GETTER_NAME = "get_comfort_temp_config"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -60,7 +60,7 @@ class TemperatureSensor(SmartModule):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def temperature(self):
|
||||
def temperature(self) -> float:
|
||||
"""Return current humidity in percentage."""
|
||||
return self._device.sys_info["current_temp"]
|
||||
|
||||
@ -74,6 +74,8 @@ class TemperatureSensor(SmartModule):
|
||||
"""Return current temperature unit."""
|
||||
return self._device.sys_info["temp_unit"]
|
||||
|
||||
async def set_temperature_unit(self, unit: Literal["celsius", "fahrenheit"]):
|
||||
async def set_temperature_unit(
|
||||
self, unit: Literal["celsius", "fahrenheit"]
|
||||
) -> dict:
|
||||
"""Set the device temperature unit."""
|
||||
return await self.call("set_temperature_unit", {"temp_unit": unit})
|
||||
|
@ -21,7 +21,7 @@ class Time(SmartModule, TimeInterface):
|
||||
|
||||
_timezone: tzinfo = timezone.utc
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -35,7 +35,7 @@ class Time(SmartModule, TimeInterface):
|
||||
)
|
||||
)
|
||||
|
||||
async def _post_update_hook(self):
|
||||
async def _post_update_hook(self) -> None:
|
||||
"""Perform actions after a device update."""
|
||||
td = timedelta(minutes=cast(float, self.data.get("time_diff")))
|
||||
if region := self.data.get("region"):
|
||||
@ -84,7 +84,7 @@ class Time(SmartModule, TimeInterface):
|
||||
params["region"] = region
|
||||
return await self.call("set_device_time", params)
|
||||
|
||||
async def _check_supported(self):
|
||||
async def _check_supported(self) -> bool:
|
||||
"""Additional check to see if the module is supported by the device.
|
||||
|
||||
Hub attached sensors report the time module but do return device time.
|
||||
|
@ -22,7 +22,7 @@ class WaterleakSensor(SmartModule):
|
||||
|
||||
REQUIRED_COMPONENT = "sensor_alarm"
|
||||
|
||||
def _initialize_features(self):
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
|
@ -49,7 +49,7 @@ class SmartChildDevice(SmartDevice):
|
||||
self._update_internal_state(info)
|
||||
self._components = component_info
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
async def update(self, update_children: bool = True) -> None:
|
||||
"""Update child module info.
|
||||
|
||||
The parent updates our internal info so just update modules with
|
||||
@ -57,7 +57,7 @@ class SmartChildDevice(SmartDevice):
|
||||
"""
|
||||
await self._update(update_children)
|
||||
|
||||
async def _update(self, update_children: bool = True):
|
||||
async def _update(self, update_children: bool = True) -> None:
|
||||
"""Update child module info.
|
||||
|
||||
Internal implementation to allow patching of public update in the cli
|
||||
@ -118,5 +118,5 @@ class SmartChildDevice(SmartDevice):
|
||||
dev_type = DeviceType.Unknown
|
||||
return dev_type
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.device_type} {self.alias} ({self.model}) of {self._parent}>"
|
||||
|
@ -69,7 +69,7 @@ class SmartDevice(Device):
|
||||
self._on_since: datetime | None = None
|
||||
self._info: dict[str, Any] = {}
|
||||
|
||||
async def _initialize_children(self):
|
||||
async def _initialize_children(self) -> None:
|
||||
"""Initialize children for power strips."""
|
||||
child_info_query = {
|
||||
"get_child_device_component_list": None,
|
||||
@ -108,7 +108,9 @@ class SmartDevice(Device):
|
||||
"""Return the device modules."""
|
||||
return cast(ModuleMapping[SmartModule], self._modules)
|
||||
|
||||
def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
|
||||
def _try_get_response(
|
||||
self, responses: dict, request: str, default: Any | None = None
|
||||
) -> dict:
|
||||
response = responses.get(request)
|
||||
if isinstance(response, SmartErrorCode):
|
||||
_LOGGER.debug(
|
||||
@ -126,7 +128,7 @@ class SmartDevice(Device):
|
||||
f"{request} not found in {responses} for device {self.host}"
|
||||
)
|
||||
|
||||
async def _negotiate(self):
|
||||
async def _negotiate(self) -> None:
|
||||
"""Perform initialization.
|
||||
|
||||
We fetch the device info and the available components as early as possible.
|
||||
@ -146,7 +148,8 @@ class SmartDevice(Device):
|
||||
self._info = self._try_get_response(resp, "get_device_info")
|
||||
|
||||
# Create our internal presentation of available components
|
||||
self._components_raw = resp["component_nego"]
|
||||
self._components_raw = cast(dict, resp["component_nego"])
|
||||
|
||||
self._components = {
|
||||
comp["id"]: int(comp["ver_code"])
|
||||
for comp in self._components_raw["component_list"]
|
||||
@ -167,7 +170,7 @@ class SmartDevice(Device):
|
||||
"""Update the internal device info."""
|
||||
self._info = self._try_get_response(info_resp, "get_device_info")
|
||||
|
||||
async def update(self, update_children: bool = False):
|
||||
async def update(self, update_children: bool = False) -> None:
|
||||
"""Update the device."""
|
||||
if self.credentials is None and self.credentials_hash is None:
|
||||
raise AuthenticationError("Tapo plug requires authentication.")
|
||||
@ -206,7 +209,7 @@ class SmartDevice(Device):
|
||||
|
||||
async def _handle_module_post_update(
|
||||
self, module: SmartModule, update_time: float, had_query: bool
|
||||
):
|
||||
) -> None:
|
||||
if module.disabled:
|
||||
return # pragma: no cover
|
||||
if had_query:
|
||||
@ -312,7 +315,7 @@ class SmartDevice(Device):
|
||||
responses[meth] = SmartErrorCode.INTERNAL_QUERY_ERROR
|
||||
return responses
|
||||
|
||||
async def _initialize_modules(self):
|
||||
async def _initialize_modules(self) -> None:
|
||||
"""Initialize modules based on component negotiation response."""
|
||||
from .smartmodule import SmartModule
|
||||
|
||||
@ -324,7 +327,7 @@ class SmartDevice(Device):
|
||||
# It also ensures that devices like power strips do not add modules such as
|
||||
# firmware to the child devices.
|
||||
skip_parent_only_modules = False
|
||||
child_modules_to_skip = {}
|
||||
child_modules_to_skip: dict = {} # TODO: this is never non-empty
|
||||
if self._parent and self._parent.device_type != DeviceType.Hub:
|
||||
skip_parent_only_modules = True
|
||||
|
||||
@ -333,17 +336,18 @@ class SmartDevice(Device):
|
||||
skip_parent_only_modules and mod in NON_HUB_PARENT_ONLY_MODULES
|
||||
) or mod.__name__ in child_modules_to_skip:
|
||||
continue
|
||||
if (
|
||||
mod.REQUIRED_COMPONENT in self._components
|
||||
or self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None
|
||||
required_component = cast(str, mod.REQUIRED_COMPONENT)
|
||||
if required_component in self._components or (
|
||||
mod.REQUIRED_KEY_ON_PARENT
|
||||
and self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None
|
||||
):
|
||||
_LOGGER.debug(
|
||||
"Device %s, found required %s, adding %s to modules.",
|
||||
self.host,
|
||||
mod.REQUIRED_COMPONENT,
|
||||
required_component,
|
||||
mod.__name__,
|
||||
)
|
||||
module = mod(self, mod.REQUIRED_COMPONENT)
|
||||
module = mod(self, required_component)
|
||||
if await module._check_supported():
|
||||
self._modules[module.name] = module
|
||||
|
||||
@ -354,7 +358,7 @@ class SmartDevice(Device):
|
||||
):
|
||||
self._modules[Light.__name__] = Light(self, "light")
|
||||
|
||||
async def _initialize_features(self):
|
||||
async def _initialize_features(self) -> None:
|
||||
"""Initialize device features."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
@ -575,11 +579,11 @@ class SmartDevice(Device):
|
||||
return str(self._info.get("device_id"))
|
||||
|
||||
@property
|
||||
def internal_state(self) -> Any:
|
||||
def internal_state(self) -> dict:
|
||||
"""Return all the internal state data."""
|
||||
return self._last_update
|
||||
|
||||
def _update_internal_state(self, info: dict) -> None:
|
||||
def _update_internal_state(self, info: dict[str, Any]) -> None:
|
||||
"""Update the internal info state.
|
||||
|
||||
This is used by the parent to push updates to its children.
|
||||
@ -587,8 +591,8 @@ class SmartDevice(Device):
|
||||
self._info = info
|
||||
|
||||
async def _query_helper(
|
||||
self, method: str, params: dict | None = None, child_ids=None
|
||||
) -> Any:
|
||||
self, method: str, params: dict | None = None, child_ids: None = None
|
||||
) -> dict:
|
||||
res = await self.protocol.query({method: params})
|
||||
|
||||
return res
|
||||
@ -610,22 +614,25 @@ class SmartDevice(Device):
|
||||
"""Return true if the device is on."""
|
||||
return bool(self._info.get("device_on"))
|
||||
|
||||
async def set_state(self, on: bool): # TODO: better name wanted.
|
||||
async def set_state(self, on: bool) -> dict:
|
||||
"""Set the device state.
|
||||
|
||||
See :meth:`is_on`.
|
||||
"""
|
||||
return await self.protocol.query({"set_device_info": {"device_on": on}})
|
||||
|
||||
async def turn_on(self, **kwargs):
|
||||
async def turn_on(self, **kwargs: Any) -> dict:
|
||||
"""Turn on the device."""
|
||||
await self.set_state(True)
|
||||
return await self.set_state(True)
|
||||
|
||||
async def turn_off(self, **kwargs):
|
||||
async def turn_off(self, **kwargs: Any) -> dict:
|
||||
"""Turn off the device."""
|
||||
await self.set_state(False)
|
||||
return await self.set_state(False)
|
||||
|
||||
def update_from_discover_info(self, info):
|
||||
def update_from_discover_info(
|
||||
self,
|
||||
info: dict,
|
||||
) -> None:
|
||||
"""Update state from info from the discover call."""
|
||||
self._discovery_info = info
|
||||
self._info = info
|
||||
@ -633,7 +640,7 @@ class SmartDevice(Device):
|
||||
async def wifi_scan(self) -> list[WifiNetwork]:
|
||||
"""Scan for available wifi networks."""
|
||||
|
||||
def _net_for_scan_info(res):
|
||||
def _net_for_scan_info(res: dict) -> WifiNetwork:
|
||||
return WifiNetwork(
|
||||
ssid=base64.b64decode(res["ssid"]).decode(),
|
||||
cipher_type=res["cipher_type"],
|
||||
@ -651,7 +658,9 @@ class SmartDevice(Device):
|
||||
]
|
||||
return networks
|
||||
|
||||
async def wifi_join(self, ssid: str, password: str, keytype: str = "wpa2_psk"):
|
||||
async def wifi_join(
|
||||
self, ssid: str, password: str, keytype: str = "wpa2_psk"
|
||||
) -> dict:
|
||||
"""Join the given wifi network.
|
||||
|
||||
This method returns nothing as the device tries to activate the new
|
||||
@ -688,9 +697,12 @@ class SmartDevice(Device):
|
||||
except DeviceError:
|
||||
raise # Re-raise on device-reported errors
|
||||
except KasaException:
|
||||
_LOGGER.debug("Received an expected for wifi join, but this is expected")
|
||||
_LOGGER.debug(
|
||||
"Received a kasa exception for wifi join, but this is expected"
|
||||
)
|
||||
return {}
|
||||
|
||||
async def update_credentials(self, username: str, password: str):
|
||||
async def update_credentials(self, username: str, password: str) -> dict:
|
||||
"""Update device credentials.
|
||||
|
||||
This will replace the existing authentication credentials on the device.
|
||||
@ -705,7 +717,7 @@ class SmartDevice(Device):
|
||||
}
|
||||
return await self.protocol.query({"set_qs_info": payload})
|
||||
|
||||
async def set_alias(self, alias: str):
|
||||
async def set_alias(self, alias: str) -> dict:
|
||||
"""Set the device name (alias)."""
|
||||
return await self.protocol.query(
|
||||
{"set_device_info": {"nickname": base64.b64encode(alias.encode()).decode()}}
|
||||
|
@ -22,17 +22,17 @@ _R = TypeVar("_R")
|
||||
|
||||
|
||||
def allow_update_after(
|
||||
func: Callable[Concatenate[_T, _P], Awaitable[None]],
|
||||
) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, None]]:
|
||||
func: Callable[Concatenate[_T, _P], Awaitable[dict]],
|
||||
) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, dict]]:
|
||||
"""Define a wrapper to set _last_update_time to None.
|
||||
|
||||
This will ensure that a module is updated in the next update cycle after
|
||||
a value has been changed.
|
||||
"""
|
||||
|
||||
async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> None:
|
||||
async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> dict:
|
||||
try:
|
||||
await func(self, *args, **kwargs)
|
||||
return await func(self, *args, **kwargs)
|
||||
finally:
|
||||
self._last_update_time = None
|
||||
|
||||
@ -68,21 +68,21 @@ class SmartModule(Module):
|
||||
|
||||
DISABLE_AFTER_ERROR_COUNT = 10
|
||||
|
||||
def __init__(self, device: SmartDevice, module: str):
|
||||
def __init__(self, device: SmartDevice, module: str) -> None:
|
||||
self._device: SmartDevice
|
||||
super().__init__(device, module)
|
||||
self._last_update_time: float | None = None
|
||||
self._last_update_error: KasaException | None = None
|
||||
self._error_count = 0
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
# We only want to register submodules in a modules package so that
|
||||
# other classes can inherit from smartmodule and not be registered
|
||||
if cls.__module__.split(".")[-2] == "modules":
|
||||
_LOGGER.debug("Registering %s", cls)
|
||||
cls.REGISTERED_MODULES[cls._module_name()] = cls
|
||||
|
||||
def _set_error(self, err: Exception | None):
|
||||
def _set_error(self, err: Exception | None) -> None:
|
||||
if err is None:
|
||||
self._error_count = 0
|
||||
self._last_update_error = None
|
||||
@ -119,7 +119,7 @@ class SmartModule(Module):
|
||||
return self._error_count >= self.DISABLE_AFTER_ERROR_COUNT
|
||||
|
||||
@classmethod
|
||||
def _module_name(cls):
|
||||
def _module_name(cls) -> str:
|
||||
return getattr(cls, "NAME", cls.__name__)
|
||||
|
||||
@property
|
||||
@ -127,7 +127,7 @@ class SmartModule(Module):
|
||||
"""Name of the module."""
|
||||
return self._module_name()
|
||||
|
||||
async def _post_update_hook(self): # noqa: B027
|
||||
async def _post_update_hook(self) -> None: # noqa: B027
|
||||
"""Perform actions after a device update.
|
||||
|
||||
Any modules overriding this should ensure that self.data is
|
||||
@ -142,7 +142,7 @@ class SmartModule(Module):
|
||||
"""
|
||||
return {self.QUERY_GETTER_NAME: None}
|
||||
|
||||
async def call(self, method, params=None):
|
||||
async def call(self, method: str, params: dict | None = None) -> dict:
|
||||
"""Call a method.
|
||||
|
||||
Just a helper method.
|
||||
@ -150,7 +150,7 @@ class SmartModule(Module):
|
||||
return await self._device._query_helper(method, params)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> dict[str, Any]:
|
||||
"""Return response data for the module.
|
||||
|
||||
If the module performs only a single query, the resulting response is unwrapped.
|
||||
|
@ -72,7 +72,7 @@ class SmartProtocol(BaseProtocol):
|
||||
)
|
||||
self._redact_data = True
|
||||
|
||||
def get_smart_request(self, method, params=None) -> str:
|
||||
def get_smart_request(self, method: str, params: dict | None = None) -> str:
|
||||
"""Get a request message as a string."""
|
||||
request = {
|
||||
"method": method,
|
||||
@ -289,8 +289,8 @@ class SmartProtocol(BaseProtocol):
|
||||
return {smart_method: result}
|
||||
|
||||
async def _handle_response_lists(
|
||||
self, response_result: dict[str, Any], method, retry_count
|
||||
):
|
||||
self, response_result: dict[str, Any], method: str, retry_count: int
|
||||
) -> None:
|
||||
if (
|
||||
response_result is None
|
||||
or isinstance(response_result, SmartErrorCode)
|
||||
@ -325,7 +325,9 @@ class SmartProtocol(BaseProtocol):
|
||||
break
|
||||
response_result[response_list_name].extend(next_batch[response_list_name])
|
||||
|
||||
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
|
||||
def _handle_response_error_code(
|
||||
self, resp_dict: dict, method: str, raise_on_error: bool = True
|
||||
) -> None:
|
||||
error_code_raw = resp_dict.get("error_code")
|
||||
try:
|
||||
error_code = SmartErrorCode.from_int(error_code_raw)
|
||||
@ -369,12 +371,12 @@ class _ChildProtocolWrapper(SmartProtocol):
|
||||
device responses before returning to the caller.
|
||||
"""
|
||||
|
||||
def __init__(self, device_id: str, base_protocol: SmartProtocol):
|
||||
def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
|
||||
self._device_id = device_id
|
||||
self._protocol = base_protocol
|
||||
self._transport = base_protocol._transport
|
||||
|
||||
def _get_method_and_params_for_request(self, request):
|
||||
def _get_method_and_params_for_request(self, request: dict[str, Any] | str) -> Any:
|
||||
"""Return payload for wrapping.
|
||||
|
||||
TODO: this does not support batches and requires refactoring in the future.
|
||||
|
@ -310,9 +310,7 @@ class FakeSmartTransport(BaseTransport):
|
||||
}
|
||||
return retval
|
||||
|
||||
raise NotImplementedError(
|
||||
"Method %s not implemented for children" % child_method
|
||||
)
|
||||
raise NotImplementedError(f"Method {child_method} not implemented for children")
|
||||
|
||||
def _get_on_off_gradually_info(self, info, params):
|
||||
if self.components["on_off_gradually"] == 1:
|
||||
|
@ -41,7 +41,7 @@ async def test_firmware_features(
|
||||
|
||||
await fw.check_latest_firmware()
|
||||
if fw.supported_version < required_version:
|
||||
pytest.skip("Feature %s requires newer version" % feature)
|
||||
pytest.skip(f"Feature {feature} requires newer version")
|
||||
|
||||
prop = getattr(fw, prop_name)
|
||||
assert isinstance(prop, type)
|
||||
|
@ -48,7 +48,7 @@ class XorTransport(BaseTransport):
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
@property
|
||||
def default_port(self):
|
||||
def default_port(self) -> int:
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
|
@ -139,10 +139,15 @@ select = [
|
||||
"PT", # flake8-pytest-style
|
||||
"LOG", # flake8-logging
|
||||
"G", # flake8-logging-format
|
||||
"ANN", # annotations
|
||||
]
|
||||
ignore = [
|
||||
"D105", # Missing docstring in magic method
|
||||
"D107", # Missing docstring in `__init__`
|
||||
"ANN101", # Missing type annotation for `self`
|
||||
"ANN102", # Missing type annotation for `cls` in classmethod
|
||||
"ANN003", # Missing type annotation for `**kwargs`
|
||||
"ANN401", # allow any
|
||||
]
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
@ -157,11 +162,21 @@ convention = "pep257"
|
||||
"D104",
|
||||
"S101", # allow asserts
|
||||
"E501", # ignore line-too-longs
|
||||
"ANN", # skip for now
|
||||
]
|
||||
"docs/source/conf.py" = [
|
||||
"D100",
|
||||
"D103",
|
||||
]
|
||||
# Temporary ANN disable
|
||||
"kasa/cli/*.py" = [
|
||||
"ANN",
|
||||
]
|
||||
# Temporary ANN disable
|
||||
"devtools/*.py" = [
|
||||
"ANN",
|
||||
]
|
||||
|
||||
|
||||
[tool.mypy]
|
||||
warn_unused_configs = true # warns if overrides sections unused/mis-spelled
|
||||
|
Loading…
Reference in New Issue
Block a user