Add DeviceConfig to allow specifying configuration parameters (#569)

* Add DeviceConfig handling

* Update post review

* Further update post latest review

* Update following latest review

* Update docstrings and docs
This commit is contained in:
sdb9696 2023-12-29 19:17:15 +00:00 committed by GitHub
parent ec3ea39a37
commit f6fd898faf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 1032 additions and 589 deletions

View File

@ -23,9 +23,12 @@ This will return you a list of device instances based on the discovery replies.
If the device's host is already known, you can use to construct a device instance with If the device's host is already known, you can use to construct a device instance with
:meth:`~kasa.SmartDevice.connect()`. :meth:`~kasa.SmartDevice.connect()`.
When connecting a device with the :meth:`~kasa.SmartDevice.connect()` method, it is recommended to The :meth:`~kasa.SmartDevice.connect()` also enables support for connecting to new
pass the device type as well as this allows the library to use the correct device class for the KASA SMART protocol and TAPO devices directly using the parameter :class:`~kasa.DeviceConfig`.
device without having to query the device. Simply serialize the :attr:`~kasa.SmartDevice.config` property via :meth:`~kasa.DeviceConfig.to_dict()`
and then deserialize it later with :func:`~kasa.DeviceConfig.from_dict()`
and then pass it into :meth:`~kasa.SmartDevice.connect()`.
.. _update_cycle: .. _update_cycle:

View File

@ -0,0 +1,18 @@
DeviceConfig
============
.. contents:: Contents
:local:
.. note::
Feel free to open a pull request to improve the documentation!
API documentation
*****************
.. autoclass:: kasa.DeviceConfig
:members:
:inherited-members:
:undoc-members:

View File

@ -15,3 +15,4 @@
smartdimmer smartdimmer
smartstrip smartstrip
smartlightstrip smartlightstrip
deviceconfig

View File

@ -14,6 +14,12 @@ to be handled by the user of the library.
from importlib.metadata import version from importlib.metadata import version
from kasa.credentials import Credentials from kasa.credentials import Credentials
from kasa.deviceconfig import (
ConnectionType,
DeviceConfig,
DeviceFamilyType,
EncryptType,
)
from kasa.discover import Discover from kasa.discover import Discover
from kasa.emeterstatus import EmeterStatus from kasa.emeterstatus import EmeterStatus
from kasa.exceptions import ( from kasa.exceptions import (
@ -55,4 +61,8 @@ __all__ = [
"AuthenticationException", "AuthenticationException",
"UnsupportedDeviceException", "UnsupportedDeviceException",
"Credentials", "Credentials",
"DeviceConfig",
"ConnectionType",
"EncryptType",
"DeviceFamilyType",
] ]

View File

@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials from .deviceconfig import DeviceConfig
from .exceptions import ( from .exceptions import (
SMART_AUTHENTICATION_ERRORS, SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS, SMART_RETRYABLE_ERRORS,
@ -47,8 +47,7 @@ class AesTransport(BaseTransport):
protocol, sometimes used by newer firmware versions on kasa devices. protocol, sometimes used by newer firmware versions on kasa devices.
""" """
DEFAULT_PORT = 80 DEFAULT_PORT: int = 80
DEFAULT_TIMEOUT = 5
SESSION_COOKIE_NAME = "TP_SESSIONID" SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = { COMMON_HEADERS = {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -58,32 +57,37 @@ class AesTransport(BaseTransport):
def __init__( def __init__(
self, self,
host: str,
*, *,
port: Optional[int] = None, config: DeviceConfig,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__( super().__init__(config=config)
host,
port=port or self.DEFAULT_PORT, self._default_http_client: Optional[httpx.AsyncClient] = None
credentials=credentials,
timeout=timeout,
)
self._handshake_done = False self._handshake_done = False
self._encryption_session: Optional[AesEncyptionSession] = None self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None self._session_expire_at: Optional[float] = None
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self._session_cookie = None self._session_cookie = None
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
self._login_token = None self._login_token = None
_LOGGER.debug("Created AES transport for %s", self._host) _LOGGER.debug("Created AES transport for %s", self._host)
@property
def default_port(self):
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
return self._config.http_client
if not self._default_http_client:
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
def hash_credentials(self, login_v2): def hash_credentials(self, login_v2):
"""Hash the credentials.""" """Hash the credentials."""
if login_v2: if login_v2:
@ -102,8 +106,6 @@ class AesTransport(BaseTransport):
async def client_post(self, url, params=None, data=None, json=None, headers=None): async def client_post(self, url, params=None, data=None, json=None, headers=None):
"""Send an http post request to the device.""" """Send an http post request to the device."""
if not self._http_client:
self._http_client = httpx.AsyncClient()
response_data = None response_data = None
cookies = None cookies = None
if self._session_cookie: if self._session_cookie:
@ -268,8 +270,8 @@ class AesTransport(BaseTransport):
async def close(self) -> None: async def close(self) -> None:
"""Close the protocol.""" """Close the protocol."""
client = self._http_client client = self._default_http_client
self._http_client = None self._default_http_client = None
self._handshake_done = False self._handshake_done = False
self._login_token = None self._login_token = None
if client: if client:

View File

@ -12,15 +12,20 @@ import asyncclick as click
from kasa import ( from kasa import (
AuthenticationException, AuthenticationException,
ConnectionType,
Credentials, Credentials,
DeviceType, DeviceConfig,
DeviceFamilyType,
Discover, Discover,
EncryptType,
SmartBulb, SmartBulb,
SmartDevice, SmartDevice,
SmartDimmer,
SmartLightStrip,
SmartPlug,
SmartStrip, SmartStrip,
UnsupportedDeviceException, UnsupportedDeviceException,
) )
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
from kasa.discover import DiscoveryResult from kasa.discover import DiscoveryResult
try: try:
@ -49,10 +54,19 @@ except ImportError:
# --json has set it to _nop_echo # --json has set it to _nop_echo
echo = _do_echo echo = _do_echo
DEVICE_TYPES = [
device_type.value TYPE_TO_CLASS = {
for device_type in DeviceType "plug": SmartPlug,
if device_type in DEVICE_TYPE_TO_CLASS "bulb": SmartBulb,
"dimmer": SmartDimmer,
"strip": SmartStrip,
"lightstrip": SmartLightStrip,
}
ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in EncryptType]
DEVICE_FAMILY_TYPES = [
device_family_type.value for device_family_type in DeviceFamilyType
] ]
click.anyio_backend = "asyncio" click.anyio_backend = "asyncio"
@ -149,7 +163,7 @@ def json_formatter_cb(result, **kwargs):
"--type", "--type",
envvar="KASA_TYPE", envvar="KASA_TYPE",
default=None, default=None,
type=click.Choice(DEVICE_TYPES, case_sensitive=False), type=click.Choice(list(TYPE_TO_CLASS), case_sensitive=False),
) )
@click.option( @click.option(
"--json/--no-json", "--json/--no-json",
@ -158,6 +172,18 @@ def json_formatter_cb(result, **kwargs):
is_flag=True, is_flag=True,
help="Output raw device response as JSON.", help="Output raw device response as JSON.",
) )
@click.option(
"--encrypt-type",
envvar="KASA_ENCRYPT_TYPE",
default=None,
type=click.Choice(ENCRYPT_TYPES, case_sensitive=False),
)
@click.option(
"--device-family",
envvar="KASA_DEVICE_FAMILY",
default=None,
type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False),
)
@click.option( @click.option(
"--timeout", "--timeout",
envvar="KASA_TIMEOUT", envvar="KASA_TIMEOUT",
@ -199,6 +225,8 @@ async def cli(
verbose, verbose,
debug, debug,
type, type,
encrypt_type,
device_family,
json, json,
timeout, timeout,
discovery_timeout, discovery_timeout,
@ -270,12 +298,19 @@ async def cli(
return await ctx.invoke(discover) return await ctx.invoke(discover)
if type is not None: if type is not None:
device_type = DeviceType.from_value(type) dev = TYPE_TO_CLASS[type](host)
dev = await SmartDevice.connect( await dev.update()
host, credentials=credentials, device_type=device_type, timeout=timeout elif device_family or encrypt_type:
ctype = ConnectionType(
DeviceFamilyType(device_family),
EncryptType(encrypt_type),
) )
config = DeviceConfig(
host=host, credentials=credentials, timeout=timeout, connection_type=ctype
)
dev = await SmartDevice.connect(config=config)
else: else:
echo("No --type defined, discovering..") echo("No --type or --device-family and --encrypt-type defined, discovering..")
dev = await Discover.discover_single( dev = await Discover.discover_single(
host, host,
port=port, port=port,
@ -332,8 +367,10 @@ async def discover(ctx):
target = ctx.parent.params["target"] target = ctx.parent.params["target"]
username = ctx.parent.params["username"] username = ctx.parent.params["username"]
password = ctx.parent.params["password"] password = ctx.parent.params["password"]
timeout = ctx.parent.params["discovery_timeout"]
verbose = ctx.parent.params["verbose"] verbose = ctx.parent.params["verbose"]
discovery_timeout = ctx.parent.params["discovery_timeout"]
timeout = ctx.parent.params["timeout"]
port = ctx.parent.params["port"]
credentials = Credentials(username, password) credentials = Credentials(username, password)
@ -354,7 +391,7 @@ async def discover(ctx):
echo(f"\t{unsupported_exception}") echo(f"\t{unsupported_exception}")
echo() echo()
echo(f"Discovering devices on {target} for {timeout} seconds") echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
async def print_discovered(dev: SmartDevice): async def print_discovered(dev: SmartDevice):
async with sem: async with sem:
@ -376,9 +413,11 @@ async def discover(ctx):
await Discover.discover( await Discover.discover(
target=target, target=target,
timeout=timeout, discovery_timeout=discovery_timeout,
on_discovered=print_discovered, on_discovered=print_discovered,
on_unsupported=print_unsupported, on_unsupported=print_unsupported,
port=port,
timeout=timeout,
credentials=credentials, credentials=credentials,
) )

View File

@ -8,5 +8,5 @@ from typing import Optional
class Credentials: class Credentials:
"""Credentials for authentication.""" """Credentials for authentication."""
username: Optional[str] = field(default=None, repr=False) username: Optional[str] = field(default="", repr=False)
password: Optional[str] = field(default=None, repr=False) password: Optional[str] = field(default="", repr=False)

View File

@ -1,18 +1,21 @@
"""Device creation by type.""" """Device creation via DeviceConfig."""
import logging import logging
import time import time
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, Dict, Optional, Tuple, Type
from .aestransport import AesTransport from .aestransport import AesTransport
from .credentials import Credentials from .deviceconfig import DeviceConfig
from .device_type import DeviceType from .exceptions import SmartDeviceException, UnsupportedDeviceException
from .exceptions import UnsupportedDeviceException
from .iotprotocol import IotProtocol from .iotprotocol import IotProtocol
from .klaptransport import KlapTransport, TPlinkKlapTransportV2 from .klaptransport import KlapTransport, KlapTransportV2
from .protocol import BaseTransport, TPLinkProtocol from .protocol import (
BaseTransport,
TPLinkProtocol,
TPLinkSmartHomeProtocol,
_XorTransport,
)
from .smartbulb import SmartBulb from .smartbulb import SmartBulb
from .smartdevice import SmartDevice, SmartDeviceException from .smartdevice import SmartDevice
from .smartdimmer import SmartDimmer from .smartdimmer import SmartDimmer
from .smartlightstrip import SmartLightStrip from .smartlightstrip import SmartLightStrip
from .smartplug import SmartPlug from .smartplug import SmartPlug
@ -20,104 +23,80 @@ from .smartprotocol import SmartProtocol
from .smartstrip import SmartStrip from .smartstrip import SmartStrip
from .tapo import TapoBulb, TapoPlug from .tapo import TapoBulb, TapoPlug
DEVICE_TYPE_TO_CLASS = {
DeviceType.Plug: SmartPlug,
DeviceType.Bulb: SmartBulb,
DeviceType.Strip: SmartStrip,
DeviceType.Dimmer: SmartDimmer,
DeviceType.LightStrip: SmartLightStrip,
DeviceType.TapoPlug: TapoPlug,
DeviceType.TapoBulb: TapoBulb,
}
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
GET_SYSINFO_QUERY = {
"system": {"get_sysinfo": None},
}
async def connect(
host: str, async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "SmartDevice":
*, """Connect to a single device by the given hostname or device configuration.
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
protocol_class: Optional[Type[TPLinkProtocol]] = None,
) -> "SmartDevice":
"""Connect to a single device by the given IP address.
This method avoids the UDP based discovery process and This method avoids the UDP based discovery process and
will connect directly to the device to query its type. will connect directly to the device.
It is generally preferred to avoid :func:`discover_single()` and It is generally preferred to avoid :func:`discover_single()` and
use this function instead as it should perform better when use this function instead as it should perform better when
the WiFi network is congested or the device is not responding the WiFi network is congested or the device is not responding
to discovery requests. to discovery requests.
The device type is discovered by querying the device. Do not use this function directly, use SmartDevice.connect()
:param host: Hostname of device to query :param host: Hostname of device to query
:param device_type: Device type to use for the device. :param config: Connection parameters to ensure the correct protocol
If not given, the device type is discovered by querying the device. and connection options are used.
If the device type is already known, it is preferred to pass it
to avoid the extra query to the device to discover its type.
:param protocol_class: Optionally provide the protocol class
to use.
:rtype: SmartDevice :rtype: SmartDevice
:return: Object for querying/controlling found device. :return: Object for querying/controlling found device.
""" """
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) if host and config or (not host and not config):
raise SmartDeviceException("One of host or config must be provded and not both")
if host:
config = DeviceConfig(host=host)
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_enabled: if debug_enabled:
start_time = time.perf_counter() start_time = time.perf_counter()
if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)): def _perf_log(has_params, perf_type):
dev: SmartDevice = klass( nonlocal start_time
host=host, port=port, credentials=credentials, timeout=timeout
)
if protocol_class is not None:
dev.protocol = protocol_class(
host,
transport=AesTransport(
host, port=port, credentials=credentials, timeout=timeout
),
)
await dev.update()
if debug_enabled: if debug_enabled:
end_time = time.perf_counter() end_time = time.perf_counter()
_LOGGER.debug( _LOGGER.debug(
"Device %s with known type (%s) took %.2f seconds to connect", f"Device {config.host} with connection params {has_params} "
host, + f"took {end_time - start_time:.2f} seconds to {perf_type}",
device_type.value,
end_time - start_time,
) )
return dev start_time = time.perf_counter()
unknown_dev = SmartDevice( if (protocol := get_protocol(config=config)) is None:
host=host, port=port, credentials=credentials, timeout=timeout raise UnsupportedDeviceException(
) f"Unsupported device for {config.host}: "
if protocol_class is not None: + f"{config.connection_type.device_family.value}"
# TODO this will be replaced with connection params
unknown_dev.protocol = protocol_class(
host,
transport=AesTransport(
host, port=port, credentials=credentials, timeout=timeout
),
) )
await unknown_dev.update()
device_class = get_device_class_from_sys_info(unknown_dev.internal_state) device_class: Optional[Type[SmartDevice]]
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
# Reuse the connection from the unknown device if isinstance(protocol, TPLinkSmartHomeProtocol):
# so we don't have to reconnect info = await protocol.query(GET_SYSINFO_QUERY)
dev.protocol = unknown_dev.protocol _perf_log(True, "get_sysinfo")
await dev.update() device_class = get_device_class_from_sys_info(info)
if debug_enabled: device = device_class(config.host, protocol=protocol)
end_time = time.perf_counter() device.update_from_discover_info(info)
_LOGGER.debug( await device.update()
"Device %s with unknown type (%s) took %.2f seconds to connect", _perf_log(True, "update")
host, return device
dev.device_type.value, elif device_class := get_device_class_from_family(
end_time - start_time, config.connection_type.device_family.value
):
device = device_class(host=config.host, protocol=protocol)
await device.update()
_perf_log(True, "update")
return device
else:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
) )
return dev
def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]: def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
@ -147,32 +126,38 @@ def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
raise UnsupportedDeviceException("Unknown device type: %s" % type_) raise UnsupportedDeviceException("Unknown device type: %s" % type_)
def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]: def get_device_class_from_family(device_type: str) -> Optional[Type[SmartDevice]]:
"""Return the device class from the type name.""" """Return the device class from the type name."""
supported_device_types: dict[str, Type[SmartDevice]] = { supported_device_types: dict[str, Type[SmartDevice]] = {
"SMART.TAPOPLUG": TapoPlug, "SMART.TAPOPLUG": TapoPlug,
"SMART.TAPOBULB": TapoBulb, "SMART.TAPOBULB": TapoBulb,
"SMART.KASAPLUG": TapoPlug, "SMART.KASAPLUG": TapoPlug,
"IOT.SMARTPLUGSWITCH": SmartPlug, "IOT.SMARTPLUGSWITCH": SmartPlug,
"IOT.SMARTBULB": SmartBulb,
} }
return supported_device_types.get(device_type) return supported_device_types.get(device_type)
def get_protocol_from_connection_name( def get_protocol(
connection_name: str, host: str, credentials: Optional[Credentials] = None config: DeviceConfig,
) -> Optional[TPLinkProtocol]: ) -> Optional[TPLinkProtocol]:
"""Return the protocol from the connection name.""" """Return the protocol from the connection name."""
protocol_name = config.connection_type.device_family.value.split(".")[0]
protocol_transport_key = (
protocol_name + "." + config.connection_type.encryption_type.value
)
supported_device_protocols: dict[ supported_device_protocols: dict[
str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]] str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]]
] = { ] = {
"IOT.XOR": (TPLinkSmartHomeProtocol, _XorTransport),
"IOT.KLAP": (IotProtocol, KlapTransport), "IOT.KLAP": (IotProtocol, KlapTransport),
"SMART.AES": (SmartProtocol, AesTransport), "SMART.AES": (SmartProtocol, AesTransport),
"SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2), "SMART.KLAP": (SmartProtocol, KlapTransportV2),
} }
if connection_name not in supported_device_protocols: if protocol_transport_key not in supported_device_protocols:
return None return None
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore protocol_class, transport_class = supported_device_protocols.get(
transport: BaseTransport = transport_class(host, credentials=credentials) protocol_transport_key
protocol: TPLinkProtocol = protocol_class(host, transport=transport) ) # type: ignore
return protocol return protocol_class(transport=transport_class(config=config))

148
kasa/deviceconfig.py Normal file
View File

@ -0,0 +1,148 @@
"""Module for holding connection parameters."""
import logging
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from enum import Enum
from typing import Dict, Optional
import httpx
from .credentials import Credentials
from .exceptions import SmartDeviceException
_LOGGER = logging.getLogger(__name__)
class EncryptType(Enum):
"""Encrypt type enum."""
Klap = "KLAP"
Aes = "AES"
Xor = "XOR"
class DeviceFamilyType(Enum):
"""Encrypt type enum."""
IotSmartPlugSwitch = "IOT.SMARTPLUGSWITCH"
IotSmartBulb = "IOT.SMARTBULB"
SmartKasaPlug = "SMART.KASAPLUG"
SmartTapoPlug = "SMART.TAPOPLUG"
SmartTapoBulb = "SMART.TAPOBULB"
def _dataclass_from_dict(klass, in_val):
if is_dataclass(klass):
fieldtypes = {f.name: f.type for f in fields(klass)}
val = {}
for dict_key in in_val:
if dict_key in fieldtypes and hasattr(fieldtypes[dict_key], "from_dict"):
val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key])
else:
val[dict_key] = _dataclass_from_dict(
fieldtypes[dict_key], in_val[dict_key]
)
return klass(**val)
else:
return in_val
def _dataclass_to_dict(in_val):
fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare}
out_val = {}
for field_name in fieldtypes:
val = getattr(in_val, field_name)
if val is None:
continue
elif hasattr(val, "to_dict"):
out_val[field_name] = val.to_dict()
elif is_dataclass(fieldtypes[field_name]):
out_val[field_name] = asdict(val)
else:
out_val[field_name] = val
return out_val
@dataclass
class ConnectionType:
"""Class to hold the the parameters determining connection type."""
device_family: DeviceFamilyType
encryption_type: EncryptType
@staticmethod
def from_values(
device_family: str,
encryption_type: str,
) -> "ConnectionType":
"""Return connection parameters from string values."""
try:
return ConnectionType(
DeviceFamilyType(device_family),
EncryptType(encryption_type),
)
except ValueError as ex:
raise SmartDeviceException(
f"Invalid connection parameters for {device_family}.{encryption_type}"
) from ex
@staticmethod
def from_dict(connection_type_dict: Dict[str, str]) -> "ConnectionType":
"""Return connection parameters from dict."""
if (
isinstance(connection_type_dict, dict)
and (device_family := connection_type_dict.get("device_family"))
and (encryption_type := connection_type_dict.get("encryption_type"))
):
return ConnectionType.from_values(device_family, encryption_type)
raise SmartDeviceException(
f"Invalid connection type data for {connection_type_dict}"
)
def to_dict(self) -> Dict[str, str]:
"""Convert connection params to dict."""
result = {
"device_family": self.device_family.value,
"encryption_type": self.encryption_type.value,
}
return result
@dataclass
class DeviceConfig:
"""Class to represent paramaters that determine how to connect to devices."""
DEFAULT_TIMEOUT = 5
host: str
timeout: Optional[int] = DEFAULT_TIMEOUT
port_override: Optional[int] = None
credentials: Credentials = field(
default_factory=lambda: Credentials(username="", password="")
)
connection_type: ConnectionType = field(
default_factory=lambda: ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
)
)
uses_http: bool = False
# compare=False will be excluded from the serialization and object comparison.
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
def __post_init__(self):
if self.credentials is None:
self.credentials = Credentials(username="", password="")
if self.connection_type is None:
self.connection_type = ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
)
def to_dict(self) -> Dict[str, Dict[str, str]]:
"""Convert connection params to dict."""
return _dataclass_to_dict(self)
@staticmethod
def from_dict(cparam_dict: Dict[str, Dict[str, str]]) -> "DeviceConfig":
"""Return connection parameters from dict."""
return _dataclass_from_dict(DeviceConfig, cparam_dict)

View File

@ -1,5 +1,6 @@
"""Discovery module for TP-Link Smart Home devices.""" """Discovery module for TP-Link Smart Home devices."""
import asyncio import asyncio
import base64
import binascii import binascii
import ipaddress import ipaddress
import logging import logging
@ -11,29 +12,32 @@ from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast
from async_timeout import timeout as asyncio_timeout from async_timeout import timeout as asyncio_timeout
try: try:
from pydantic.v1 import BaseModel, Field from pydantic.v1 import BaseModel, ValidationError # pragma: no cover
except ImportError: except ImportError:
from pydantic import BaseModel, Field from pydantic import BaseModel, ValidationError # pragma: no cover
from kasa.credentials import Credentials from kasa.credentials import Credentials
from kasa.device_factory import (
get_device_class_from_family,
get_device_class_from_sys_info,
get_protocol,
)
from kasa.deviceconfig import ConnectionType, DeviceConfig, EncryptType
from kasa.exceptions import UnsupportedDeviceException from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads from kasa.json import loads as json_loads
from kasa.protocol import TPLinkSmartHomeProtocol from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartdevice import SmartDevice, SmartDeviceException from kasa.smartdevice import SmartDevice, SmartDeviceException
from .device_factory import (
get_device_class_from_sys_info,
get_device_class_from_type_name,
get_protocol_from_connection_name,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]] OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
DeviceDict = Dict[str, SmartDevice] DeviceDict = Dict[str, SmartDevice]
UNAVAILABLE_ALIAS = "Authentication required"
UNAVAILABLE_NICKNAME = base64.b64encode(UNAVAILABLE_ALIAS.encode()).decode()
class _DiscoverProtocol(asyncio.DatagramProtocol): class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler. """Implementation of the discovery protocol handler.
@ -62,9 +66,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.discovery_packets = discovery_packets self.discovery_packets = discovery_packets
self.interface = interface self.interface = interface
self.on_discovered = on_discovered self.on_discovered = on_discovered
self.port = port
self.discovery_port = port or Discover.DISCOVERY_PORT self.discovery_port = port or Discover.DISCOVERY_PORT
self.target = (target, self.discovery_port) self.target = (target, self.discovery_port)
self.target_2 = (target, Discover.DISCOVERY_PORT_2) self.target_2 = (target, Discover.DISCOVERY_PORT_2)
self.discovered_devices = {} self.discovered_devices = {}
self.unsupported_device_exceptions: Dict = {} self.unsupported_device_exceptions: Dict = {}
self.invalid_device_exceptions: Dict = {} self.invalid_device_exceptions: Dict = {}
@ -110,13 +117,18 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.seen_hosts.add(ip) self.seen_hosts.add(ip)
device = None device = None
config = DeviceConfig(host=ip, port_override=self.port)
if self.credentials:
config.credentials = self.credentials
if self.timeout:
config.timeout = self.timeout
try: try:
if port == self.discovery_port: if port == self.discovery_port:
device = Discover._get_device_instance_legacy(data, ip, port) device = Discover._get_device_instance_legacy(data, config)
elif port == Discover.DISCOVERY_PORT_2: elif port == Discover.DISCOVERY_PORT_2:
device = Discover._get_device_instance( config.uses_http = True
data, ip, port, self.credentials or Credentials() device = Discover._get_device_instance(data, config)
)
else: else:
return return
except UnsupportedDeviceException as udex: except UnsupportedDeviceException as udex:
@ -200,11 +212,13 @@ class Discover:
*, *,
target="255.255.255.255", target="255.255.255.255",
on_discovered=None, on_discovered=None,
timeout=5, discovery_timeout=5,
discovery_packets=3, discovery_packets=3,
interface=None, interface=None,
on_unsupported=None, on_unsupported=None,
credentials=None, credentials=None,
port=None,
timeout=None,
) -> DeviceDict: ) -> DeviceDict:
"""Discover supported devices. """Discover supported devices.
@ -240,14 +254,15 @@ class Discover:
on_unsupported=on_unsupported, on_unsupported=on_unsupported,
credentials=credentials, credentials=credentials,
timeout=timeout, timeout=timeout,
port=port,
), ),
local_addr=("0.0.0.0", 0), # noqa: S104 local_addr=("0.0.0.0", 0), # noqa: S104
) )
protocol = cast(_DiscoverProtocol, protocol) protocol = cast(_DiscoverProtocol, protocol)
try: try:
_LOGGER.debug("Waiting %s seconds for responses...", timeout) _LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout)
await asyncio.sleep(timeout) await asyncio.sleep(discovery_timeout)
finally: finally:
transport.close() transport.close()
@ -259,10 +274,10 @@ class Discover:
async def discover_single( async def discover_single(
host: str, host: str,
*, *,
discovery_timeout: int = 5,
port: Optional[int] = None, port: Optional[int] = None,
timeout=5, timeout: Optional[int] = None,
credentials: Optional[Credentials] = None, credentials: Optional[Credentials] = None,
update_parent_devices: bool = True,
) -> SmartDevice: ) -> SmartDevice:
"""Discover a single device by the given IP address. """Discover a single device by the given IP address.
@ -275,8 +290,6 @@ class Discover:
:param port: Optionally set a different port for the device :param port: Optionally set a different port for the device
:param timeout: Timeout for discovery :param timeout: Timeout for discovery
:param credentials: Credentials for devices that require authentication :param credentials: Credentials for devices that require authentication
:param update_parent_devices: Automatically call device.update() on
devices that have children
:rtype: SmartDevice :rtype: SmartDevice
:return: Object for querying/controlling found device. :return: Object for querying/controlling found device.
""" """
@ -320,9 +333,11 @@ class Discover:
protocol = cast(_DiscoverProtocol, protocol) protocol = cast(_DiscoverProtocol, protocol)
try: try:
_LOGGER.debug("Waiting a total of %s seconds for responses...", timeout) _LOGGER.debug(
"Waiting a total of %s seconds for responses...", discovery_timeout
)
async with asyncio_timeout(timeout): async with asyncio_timeout(discovery_timeout):
await event.wait() await event.wait()
except asyncio.TimeoutError as ex: except asyncio.TimeoutError as ex:
raise SmartDeviceException( raise SmartDeviceException(
@ -334,9 +349,6 @@ class Discover:
if ip in protocol.discovered_devices: if ip in protocol.discovered_devices:
dev = protocol.discovered_devices[ip] dev = protocol.discovered_devices[ip]
dev.host = host dev.host = host
# Call device update on devices that have children
if update_parent_devices and dev.has_children:
await dev.update()
return dev return dev
elif ip in protocol.unsupported_device_exceptions: elif ip in protocol.unsupported_device_exceptions:
raise protocol.unsupported_device_exceptions[ip] raise protocol.unsupported_device_exceptions[ip]
@ -350,99 +362,121 @@ class Discover:
"""Find SmartDevice subclass for device described by passed data.""" """Find SmartDevice subclass for device described by passed data."""
if "result" in info: if "result" in info:
discovery_result = DiscoveryResult(**info["result"]) discovery_result = DiscoveryResult(**info["result"])
dev_class = get_device_class_from_type_name(discovery_result.device_type) dev_class = get_device_class_from_family(discovery_result.device_type)
if not dev_class: if not dev_class:
raise UnsupportedDeviceException( raise UnsupportedDeviceException(
"Unknown device type: %s" % discovery_result.device_type "Unknown device type: %s" % discovery_result.device_type,
discovery_result=info,
) )
return dev_class return dev_class
else: else:
return get_device_class_from_sys_info(info) return get_device_class_from_sys_info(info)
@staticmethod @staticmethod
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice: def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> SmartDevice:
"""Get SmartDevice from legacy 9999 response.""" """Get SmartDevice from legacy 9999 response."""
try: try:
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data)) info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
except Exception as ex: except Exception as ex:
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to read response from device: {ip}: {ex}" f"Unable to read response from device: {config.host}: {ex}"
) from ex ) from ex
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info) _LOGGER.debug("[DISCOVERY] %s << %s", config.host, info)
device_class = Discover._get_device_class(info) device_class = Discover._get_device_class(info)
device = device_class(ip, port=port) device = device_class(config.host, config=config)
sys_info = info["system"]["get_sysinfo"]
if device_type := sys_info.get("mic_type", sys_info.get("type")):
config.connection_type = ConnectionType.from_values(
device_family=device_type, encryption_type=EncryptType.Xor.value
)
device.protocol = get_protocol(config) # type: ignore[assignment]
device.update_from_discover_info(info) device.update_from_discover_info(info)
return device return device
@staticmethod @staticmethod
def _get_device_instance( def _get_device_instance(
data: bytes, ip: str, port: int, credentials: Credentials data: bytes,
config: DeviceConfig,
) -> SmartDevice: ) -> SmartDevice:
"""Get SmartDevice from the new 20002 response.""" """Get SmartDevice from the new 20002 response."""
try: try:
info = json_loads(data[16:]) info = json_loads(data[16:])
discovery_result = DiscoveryResult(**info["result"])
except Exception as ex: except Exception as ex:
_LOGGER.debug("Got invalid response from device %s: %s", config.host, data)
raise SmartDeviceException(
f"Unable to read response from device: {config.host}: {ex}"
) from ex
try:
discovery_result = DiscoveryResult(**info["result"])
except ValidationError as ex:
_LOGGER.debug(
"Unable to parse discovery from device %s: %s", config.host, info
)
raise UnsupportedDeviceException( raise UnsupportedDeviceException(
f"Unable to read response from device: {ip}: {ex}" f"Unable to parse discovery from device: {config.host}: {ex}"
) from ex ) from ex
type_ = discovery_result.device_type type_ = discovery_result.device_type
encrypt_type_ = (
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
)
if (device_class := get_device_class_from_type_name(type_)) is None: try:
config.connection_type = ConnectionType.from_values(
type_, discovery_result.mgt_encrypt_schm.encrypt_type
)
except SmartDeviceException as ex:
raise UnsupportedDeviceException(
f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
discovery_result=discovery_result.get_dict(),
) from ex
if (device_class := get_device_class_from_family(type_)) is None:
_LOGGER.warning("Got unsupported device type: %s", type_) _LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceException( raise UnsupportedDeviceException(
f"Unsupported device {ip} of type {type_}: {info}", f"Unsupported device {config.host} of type {type_}: {info}",
discovery_result=discovery_result.get_dict(), discovery_result=discovery_result.get_dict(),
) )
if ( if (protocol := get_protocol(config)) is None:
protocol := get_protocol_from_connection_name( _LOGGER.warning(
encrypt_type_, ip, credentials=credentials "Got unsupported connection type: %s", config.connection_type.to_dict()
) )
) is None:
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
raise UnsupportedDeviceException( raise UnsupportedDeviceException(
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}", f"Unsupported encryption scheme {config.host} of "
+ f"type {config.connection_type.to_dict()}: {info}",
discovery_result=discovery_result.get_dict(), discovery_result=discovery_result.get_dict(),
) )
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info) _LOGGER.debug("[DISCOVERY] %s << %s", config.host, info)
device = device_class(ip, port=port, credentials=credentials) device = device_class(config.host, protocol=protocol)
device.protocol = protocol
device.update_from_discover_info(discovery_result.get_dict()) di = discovery_result.get_dict()
di["model"] = discovery_result.device_model
di["alias"] = UNAVAILABLE_ALIAS
di["nickname"] = UNAVAILABLE_NICKNAME
device.update_from_discover_info(di)
return device return device
class DiscoveryResult(BaseModel): class DiscoveryResult(BaseModel):
"""Base model for discovery result.""" """Base model for discovery result."""
class Config:
"""Class for configuring model behaviour."""
allow_population_by_field_name = True
class EncryptionScheme(BaseModel): class EncryptionScheme(BaseModel):
"""Base model for encryption scheme of discovery result.""" """Base model for encryption scheme of discovery result."""
is_support_https: Optional[bool] = None is_support_https: bool
encrypt_type: Optional[str] = None encrypt_type: str
http_port: Optional[int] = None http_port: int
lv: Optional[int] = 1 lv: Optional[int] = None
device_type: str = Field(alias="device_type_text") device_type: str
device_model: str = Field(alias="model") device_model: str
ip: str = Field(alias="alias") ip: str
mac: str mac: str
mgt_encrypt_schm: EncryptionScheme mgt_encrypt_schm: EncryptionScheme
device_id: str
device_id: Optional[str] = Field(default=None, alias="device_id_hash")
owner: Optional[str] = Field(default=None, alias="device_owner_hash")
hw_ver: Optional[str] = None hw_ver: Optional[str] = None
owner: Optional[str] = None
is_support_iot_cloud: Optional[bool] = None is_support_iot_cloud: Optional[bool] = None
obd_src: Optional[str] = None obd_src: Optional[str] = None
factory_default: Optional[bool] = None factory_default: Optional[bool] = None
@ -453,5 +487,5 @@ class DiscoveryResult(BaseModel):
containing only the values actually set and with aliases as field names. containing only the values actually set and with aliases as field names.
""" """
return self.dict( return self.dict(
by_alias=True, exclude_unset=True, exclude_none=True, exclude_defaults=True by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
) )

View File

@ -17,12 +17,11 @@ class IotProtocol(TPLinkProtocol):
def __init__( def __init__(
self, self,
host: str,
*, *,
transport: BaseTransport, transport: BaseTransport,
) -> None: ) -> None:
"""Create a protocol object.""" """Create a protocol object."""
super().__init__(host, transport=transport) super().__init__(transport=transport)
self._query_lock = asyncio.Lock() self._query_lock = asyncio.Lock()
@ -39,25 +38,21 @@ class IotProtocol(TPLinkProtocol):
for retry in range(retry_count + 1): for retry in range(retry_count + 1):
try: try:
return await self._execute_query(request, retry) return await self._execute_query(request, retry)
except httpx.CloseError as sdex: except httpx.ConnectError as sdex:
await self.close()
if retry >= retry_count: if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {sdex}" f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex ) from sdex
continue continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {cex}"
) from cex
except TimeoutError as tex: except TimeoutError as tex:
await self.close() await self.close()
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self._host}: {tex}" f"Unable to connect to the device, timed out: {self._host}: {tex}"
) from tex ) from tex
except AuthenticationException as auex: except AuthenticationException as auex:
await self.close()
_LOGGER.debug( _LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host "Unable to authenticate with %s, not retrying", self._host
) )
@ -70,8 +65,8 @@ class IotProtocol(TPLinkProtocol):
) )
raise ex raise ex
except Exception as ex: except Exception as ex:
await self.close()
if retry >= retry_count: if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {ex}" f"Unable to connect to the device: {self._host}: {ex}"

View File

@ -54,6 +54,7 @@ from cryptography.hazmat.primitives import hashes, padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import AuthenticationException, SmartDeviceException from .exceptions import AuthenticationException, SmartDeviceException
from .json import loads as json_loads from .json import loads as json_loads
from .protocol import BaseTransport, md5 from .protocol import BaseTransport, md5
@ -82,27 +83,21 @@ class KlapTransport(BaseTransport):
protocol, used by newer firmware versions. protocol, used by newer firmware versions.
""" """
DEFAULT_PORT = 80 DEFAULT_PORT: int = 80
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
KASA_SETUP_EMAIL = "kasa@tp-link.net" KASA_SETUP_EMAIL = "kasa@tp-link.net"
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105 KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
SESSION_COOKIE_NAME = "TP_SESSIONID" SESSION_COOKIE_NAME = "TP_SESSIONID"
def __init__( def __init__(
self, self,
host: str,
*, *,
port: Optional[int] = None, config: DeviceConfig,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__( super().__init__(config=config)
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
self._default_http_client: Optional[httpx.AsyncClient] = None
self._local_seed: Optional[bytes] = None self._local_seed: Optional[bytes] = None
self._local_auth_hash = self.generate_auth_hash(self._credentials) self._local_auth_hash = self.generate_auth_hash(self._credentials)
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
@ -116,14 +111,24 @@ class KlapTransport(BaseTransport):
self._session_expire_at: Optional[float] = None self._session_expire_at: Optional[float] = None
self._session_cookie = None self._session_cookie = None
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
_LOGGER.debug("Created KLAP transport for %s", self._host) _LOGGER.debug("Created KLAP transport for %s", self._host)
@property
def default_port(self):
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
return self._config.http_client
if not self._default_http_client:
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
async def client_post(self, url, params=None, data=None): async def client_post(self, url, params=None, data=None):
"""Send an http post request to the device.""" """Send an http post request to the device."""
if not self._http_client:
self._http_client = httpx.AsyncClient()
response_data = None response_data = None
cookies = None cookies = None
if self._session_cookie: if self._session_cookie:
@ -355,8 +360,8 @@ class KlapTransport(BaseTransport):
async def close(self) -> None: async def close(self) -> None:
"""Close the transport.""" """Close the transport."""
client = self._http_client client = self._default_http_client
self._http_client = None self._default_http_client = None
self._handshake_done = False self._handshake_done = False
if client: if client:
await client.aclose() await client.aclose()
@ -390,7 +395,7 @@ class KlapTransport(BaseTransport):
return md5(un.encode()) return md5(un.encode())
class TPlinkKlapTransportV2(KlapTransport): class KlapTransportV2(KlapTransport):
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes.""" """Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
@staticmethod @staticmethod

View File

@ -24,7 +24,7 @@ from typing import Dict, Generator, Optional, Union
from async_timeout import timeout as asyncio_timeout from async_timeout import timeout as asyncio_timeout
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from .credentials import Credentials from .deviceconfig import DeviceConfig
from .exceptions import SmartDeviceException from .exceptions import SmartDeviceException
from .json import dumps as json_dumps from .json import dumps as json_dumps
from .json import loads as json_loads from .json import loads as json_loads
@ -48,17 +48,20 @@ class BaseTransport(ABC):
def __init__( def __init__(
self, self,
host: str,
*, *,
port: Optional[int] = None, config: DeviceConfig,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
"""Create a protocol object.""" """Create a protocol object."""
self._host = host self._config = config
self._port = port self._host = config.host
self._credentials = credentials or Credentials(username="", password="") self._port = config.port_override or self.default_port
self._timeout = timeout or self.DEFAULT_TIMEOUT self._credentials = config.credentials
self._timeout = config.timeout
@property
@abstractmethod
def default_port(self) -> int:
"""The default port for the transport."""
@abstractmethod @abstractmethod
async def send(self, request: str) -> Dict: async def send(self, request: str) -> Dict:
@ -74,7 +77,6 @@ class TPLinkProtocol(ABC):
def __init__( def __init__(
self, self,
host: str,
*, *,
transport: BaseTransport, transport: BaseTransport,
) -> None: ) -> None:
@ -85,6 +87,11 @@ class TPLinkProtocol(ABC):
def _host(self): def _host(self):
return self._transport._host return self._transport._host
@property
def config(self) -> DeviceConfig:
"""Return the connection parameters the device is using."""
return self._transport._config
@abstractmethod @abstractmethod
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device for the protocol. Abstract method to be overriden.""" """Query the device for the protocol. Abstract method to be overriden."""
@ -103,22 +110,15 @@ class _XorTransport(BaseTransport):
class. class.
""" """
DEFAULT_PORT = 9999 DEFAULT_PORT: int = 9999
def __init__( def __init__(self, *, config: DeviceConfig) -> None:
self, super().__init__(config=config)
host: str,
*, @property
port: Optional[int] = None, def default_port(self):
credentials: Optional[Credentials] = None, """Default port for the transport."""
timeout: Optional[int] = None, return self.DEFAULT_PORT
) -> None:
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
async def send(self, request: str) -> Dict: async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response.""" """Send a message to the device and return a response."""
@ -133,17 +133,15 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
INITIALIZATION_VECTOR = 171 INITIALIZATION_VECTOR = 171
DEFAULT_PORT = 9999 DEFAULT_PORT = 9999
DEFAULT_TIMEOUT = 5
BLOCK_SIZE = 4 BLOCK_SIZE = 4
def __init__( def __init__(
self, self,
host: str,
*, *,
transport: BaseTransport, transport: BaseTransport,
) -> None: ) -> None:
"""Create a protocol object.""" """Create a protocol object."""
super().__init__(host, transport=transport) super().__init__(transport=transport)
self.reader: Optional[asyncio.StreamReader] = None self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None self.writer: Optional[asyncio.StreamWriter] = None
@ -167,7 +165,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
assert isinstance(request, str) # noqa: S101 assert isinstance(request, str) # noqa: S101
async with self.query_lock: async with self.query_lock:
return await self._query(request, retry_count, self._timeout) return await self._query(request, retry_count, self._timeout) # type: ignore[arg-type]
async def _connect(self, timeout: int) -> None: async def _connect(self, timeout: int) -> None:
"""Try to connect or reconnect to the device.""" """Try to connect or reconnect to the device."""

View File

@ -9,8 +9,9 @@ try:
except ImportError: except ImportError:
from pydantic import BaseModel, Field, root_validator from pydantic import BaseModel, Field, root_validator
from .credentials import Credentials from .deviceconfig import DeviceConfig
from .modules import Antitheft, Cloud, Countdown, Emeter, Schedule, Time, Usage from .modules import Antitheft, Cloud, Countdown, Emeter, Schedule, Time, Usage
from .protocol import TPLinkProtocol
from .smartdevice import DeviceType, SmartDevice, SmartDeviceException, requires_update from .smartdevice import DeviceType, SmartDevice, SmartDeviceException, requires_update
@ -220,11 +221,10 @@ class SmartBulb(SmartDevice):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, config: Optional[DeviceConfig] = None,
credentials: Optional[Credentials] = None, protocol: Optional[TPLinkProtocol] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host=host, port=port, credentials=credentials, timeout=timeout) super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Bulb self._device_type = DeviceType.Bulb
self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule")) self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule"))
self.add_module("usage", Usage(self, "smartlife.iot.common.schedule")) self.add_module("usage", Usage(self, "smartlife.iot.common.schedule"))

View File

@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Set
from .credentials import Credentials from .credentials import Credentials
from .device_type import DeviceType from .device_type import DeviceType
from .deviceconfig import DeviceConfig
from .emeterstatus import EmeterStatus from .emeterstatus import EmeterStatus
from .exceptions import SmartDeviceException from .exceptions import SmartDeviceException
from .modules import Emeter, Module from .modules import Emeter, Module
@ -191,20 +192,18 @@ class SmartDevice:
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, config: Optional[DeviceConfig] = None,
credentials: Optional[Credentials] = None, protocol: Optional[TPLinkProtocol] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
"""Create a new SmartDevice instance. """Create a new SmartDevice instance.
:param str host: host name or ip address on which the device listens :param str host: host name or ip address on which the device listens
""" """
self.host = host if config and protocol:
self.port = port protocol._transport._config = config
self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol( self.protocol: TPLinkProtocol = protocol or TPLinkSmartHomeProtocol(
host, transport=_XorTransport(host, port=port, timeout=timeout) transport=_XorTransport(config=config or DeviceConfig(host=host)),
) )
self.credentials = credentials
_LOGGER.debug("Initializing %s of type %s", self.host, type(self)) _LOGGER.debug("Initializing %s of type %s", self.host, type(self))
self._device_type = DeviceType.Unknown self._device_type = DeviceType.Unknown
# TODO: typing Any is just as using Optional[Dict] would require separate # TODO: typing Any is just as using Optional[Dict] would require separate
@ -219,6 +218,30 @@ class SmartDevice:
self.children: List["SmartDevice"] = [] self.children: List["SmartDevice"] = []
@property
def host(self) -> str:
"""The device host."""
return self.protocol._transport._host
@host.setter
def host(self, value):
"""Set the device host.
Generally used by discovery to set the hostname after ip discovery.
"""
self.protocol._transport._host = value
self.protocol._transport._config.host = value
@property
def port(self) -> int:
"""The device port."""
return self.protocol._transport._port
@property
def credentials(self) -> Optional[Credentials]:
"""The device credentials."""
return self.protocol._transport._credentials
def add_module(self, name: str, module: Module): def add_module(self, name: str, module: Module):
"""Register a module.""" """Register a module."""
if name in self.modules: if name in self.modules:
@ -760,7 +783,7 @@ class SmartDevice:
The returned object contains the raw results from the last update call. The returned object contains the raw results from the last update call.
This should only be used for debugging purposes. This should only be used for debugging purposes.
""" """
return self._last_update return self._last_update or self._discovery_info
def __repr__(self): def __repr__(self):
if self._last_update is None: if self._last_update is None:
@ -771,41 +794,33 @@ class SmartDevice:
f" - dev specific: {self.state_information}>" f" - dev specific: {self.state_information}>"
) )
@property
def config(self) -> DeviceConfig:
"""Return the connection parameters the device is using."""
return self.protocol.config
@staticmethod @staticmethod
async def connect( async def connect(
host: str,
*, *,
port: Optional[int] = None, host: Optional[str] = None,
timeout=5, config: Optional[DeviceConfig] = None,
credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
) -> "SmartDevice": ) -> "SmartDevice":
"""Connect to a single device by the given IP address. """Connect to a single device by the given hostname or device configuration.
This method avoids the UDP based discovery process and This method avoids the UDP based discovery process and
will connect directly to the device to query its type. will connect directly to the device.
It is generally preferred to avoid :func:`discover_single()` and It is generally preferred to avoid :func:`discover_single()` and
use this function instead as it should perform better when use this function instead as it should perform better when
the WiFi network is congested or the device is not responding the WiFi network is congested or the device is not responding
to discovery requests. to discovery requests.
The device type is discovered by querying the device.
:param host: Hostname of device to query :param host: Hostname of device to query
:param device_type: Device type to use for the device. :param config: Connection parameters to ensure the correct protocol
If not given, the device type is discovered by querying the device. and connection options are used.
If the device type is already known, it is preferred to pass it
to avoid the extra query to the device to discover its type.
:rtype: SmartDevice :rtype: SmartDevice
:return: Object for querying/controlling found device. :return: Object for querying/controlling found device.
""" """
from .device_factory import connect # pylint: disable=import-outside-toplevel from .device_factory import connect # pylint: disable=import-outside-toplevel
return await connect( return await connect(host=host, config=config) # type: ignore[arg-type]
host=host,
port=port,
timeout=timeout,
credentials=credentials,
device_type=device_type,
)

View File

@ -2,8 +2,9 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from kasa.credentials import Credentials from kasa.deviceconfig import DeviceConfig
from kasa.modules import AmbientLight, Motion from kasa.modules import AmbientLight, Motion
from kasa.protocol import TPLinkProtocol
from kasa.smartdevice import DeviceType, SmartDeviceException, requires_update from kasa.smartdevice import DeviceType, SmartDeviceException, requires_update
from kasa.smartplug import SmartPlug from kasa.smartplug import SmartPlug
@ -68,11 +69,10 @@ class SmartDimmer(SmartPlug):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, config: Optional[DeviceConfig] = None,
credentials: Optional[Credentials] = None, protocol: Optional[TPLinkProtocol] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout) super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Dimmer self._device_type = DeviceType.Dimmer
# TODO: need to be verified if it's okay to call these on HS220 w/o these # TODO: need to be verified if it's okay to call these on HS220 w/o these
# TODO: need to be figured out what's the best approach to detect support # TODO: need to be figured out what's the best approach to detect support

View File

@ -1,8 +1,9 @@
"""Module for light strips (KL430).""" """Module for light strips (KL430)."""
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from .credentials import Credentials from .deviceconfig import DeviceConfig
from .effects import EFFECT_MAPPING_V1, EFFECT_NAMES_V1 from .effects import EFFECT_MAPPING_V1, EFFECT_NAMES_V1
from .protocol import TPLinkProtocol
from .smartbulb import SmartBulb from .smartbulb import SmartBulb
from .smartdevice import DeviceType, SmartDeviceException, requires_update from .smartdevice import DeviceType, SmartDeviceException, requires_update
@ -46,11 +47,10 @@ class SmartLightStrip(SmartBulb):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, config: Optional[DeviceConfig] = None,
credentials: Optional[Credentials] = None, protocol: Optional[TPLinkProtocol] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout) super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.LightStrip self._device_type = DeviceType.LightStrip
@property # type: ignore @property # type: ignore

View File

@ -2,8 +2,9 @@
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from kasa.credentials import Credentials from kasa.deviceconfig import DeviceConfig
from kasa.modules import Antitheft, Cloud, Schedule, Time, Usage from kasa.modules import Antitheft, Cloud, Schedule, Time, Usage
from kasa.protocol import TPLinkProtocol
from kasa.smartdevice import DeviceType, SmartDevice, requires_update from kasa.smartdevice import DeviceType, SmartDevice, requires_update
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -43,11 +44,10 @@ class SmartPlug(SmartDevice):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, config: Optional[DeviceConfig] = None,
credentials: Optional[Credentials] = None, protocol: Optional[TPLinkProtocol] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout) super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Plug self._device_type = DeviceType.Plug
self.add_module("schedule", Schedule(self, "schedule")) self.add_module("schedule", Schedule(self, "schedule"))
self.add_module("usage", Usage(self, "schedule")) self.add_module("usage", Usage(self, "schedule"))

View File

@ -38,12 +38,11 @@ class SmartProtocol(TPLinkProtocol):
def __init__( def __init__(
self, self,
host: str,
*, *,
transport: BaseTransport, transport: BaseTransport,
) -> None: ) -> None:
"""Create a protocol object.""" """Create a protocol object."""
super().__init__(host, transport=transport) super().__init__(transport=transport)
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode() self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
self._request_id_generator = SnowflakeId(1, 1) self._request_id_generator = SnowflakeId(1, 1)
self._query_lock = asyncio.Lock() self._query_lock = asyncio.Lock()
@ -68,19 +67,14 @@ class SmartProtocol(TPLinkProtocol):
for retry in range(retry_count + 1): for retry in range(retry_count + 1):
try: try:
return await self._execute_query(request, retry) return await self._execute_query(request, retry)
except httpx.CloseError as sdex: except httpx.ConnectError as sdex:
await self.close()
if retry >= retry_count: if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {sdex}" f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex ) from sdex
continue continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {cex}"
) from cex
except TimeoutError as tex: except TimeoutError as tex:
if retry >= retry_count: if retry >= retry_count:
await self.close() await self.close()

View File

@ -14,8 +14,9 @@ from kasa.smartdevice import (
) )
from kasa.smartplug import SmartPlug from kasa.smartplug import SmartPlug
from .credentials import Credentials from .deviceconfig import DeviceConfig
from .modules import Antitheft, Countdown, Emeter, Schedule, Time, Usage from .modules import Antitheft, Countdown, Emeter, Schedule, Time, Usage
from .protocol import TPLinkProtocol
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -85,11 +86,10 @@ class SmartStrip(SmartDevice):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, config: Optional[DeviceConfig] = None,
credentials: Optional[Credentials] = None, protocol: Optional[TPLinkProtocol] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host=host, port=port, credentials=credentials, timeout=timeout) super().__init__(host=host, config=config, protocol=protocol)
self.emeter_type = "emeter" self.emeter_type = "emeter"
self._device_type = DeviceType.Strip self._device_type = DeviceType.Strip
self.add_module("antitheft", Antitheft(self, "anti_theft")) self.add_module("antitheft", Antitheft(self, "anti_theft"))

View File

@ -5,8 +5,9 @@ from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Set, cast from typing import Any, Dict, Optional, Set, cast
from ..aestransport import AesTransport from ..aestransport import AesTransport
from ..credentials import Credentials from ..deviceconfig import DeviceConfig
from ..exceptions import AuthenticationException from ..exceptions import AuthenticationException
from ..protocol import TPLinkProtocol
from ..smartdevice import SmartDevice from ..smartdevice import SmartDevice
from ..smartprotocol import SmartProtocol from ..smartprotocol import SmartProtocol
@ -20,20 +21,16 @@ class TapoDevice(SmartDevice):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, config: Optional[DeviceConfig] = None,
credentials: Optional[Credentials] = None, protocol: Optional[TPLinkProtocol] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout) _protocol = protocol or SmartProtocol(
transport=AesTransport(config=config or DeviceConfig(host=host)),
)
super().__init__(host=host, config=config, protocol=_protocol)
self._components: Optional[Dict[str, Any]] = None self._components: Optional[Dict[str, Any]] = None
self._state_information: Dict[str, Any] = {} self._state_information: Dict[str, Any] = {}
self._discovery_info: Optional[Dict[str, Any]] = None self._discovery_info: Optional[Dict[str, Any]] = None
self.protocol = SmartProtocol(
host,
transport=AesTransport(
host, credentials=credentials, timeout=timeout, port=port
),
)
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""Update the device.""" """Update the device."""
@ -66,7 +63,7 @@ class TapoDevice(SmartDevice):
@property @property
def sys_info(self) -> Dict[str, Any]: def sys_info(self) -> Dict[str, Any]:
"""Returns the device info.""" """Returns the device info."""
return self._info return self._info # type: ignore
@property @property
def model(self) -> str: def model(self) -> str:
@ -180,3 +177,4 @@ class TapoDevice(SmartDevice):
def update_from_discover_info(self, info): def update_from_discover_info(self, info):
"""Update state from info from the discover call.""" """Update state from info from the discover call."""
self._discovery_info = info self._discovery_info = info
self._info = info

View File

@ -3,9 +3,10 @@ import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, Optional, cast from typing import Any, Dict, Optional, cast
from ..credentials import Credentials from ..deviceconfig import DeviceConfig
from ..emeterstatus import EmeterStatus from ..emeterstatus import EmeterStatus
from ..modules import Emeter from ..modules import Emeter
from ..protocol import TPLinkProtocol
from ..smartdevice import DeviceType, requires_update from ..smartdevice import DeviceType, requires_update
from .tapodevice import TapoDevice from .tapodevice import TapoDevice
@ -19,11 +20,10 @@ class TapoPlug(TapoDevice):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, config: Optional[DeviceConfig] = None,
credentials: Optional[Credentials] = None, protocol: Optional[TPLinkProtocol] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout) super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Plug self._device_type = DeviceType.Plug
self.modules: Dict[str, Any] = {} self.modules: Dict[str, Any] = {}
self.emeter_type = "emeter" self.emeter_type = "emeter"

View File

@ -388,7 +388,6 @@ async def get_device_for_file(file, protocol):
d = device_for_file(model, protocol)(host="127.0.0.123") d = device_for_file(model, protocol)(host="127.0.0.123")
if protocol == "SMART": if protocol == "SMART":
d.protocol = FakeSmartProtocol(sysinfo) d.protocol = FakeSmartProtocol(sysinfo)
d.credentials = Credentials("", "")
else: else:
d.protocol = FakeTransportProtocol(sysinfo) d.protocol = FakeTransportProtocol(sysinfo)
await _update_and_close(d) await _update_and_close(d)
@ -426,28 +425,53 @@ def discovery_mock(all_fixture_data, mocker):
class _DiscoveryMock: class _DiscoveryMock:
ip: str ip: str
default_port: int default_port: int
discovery_port: int
discovery_data: dict discovery_data: dict
query_data: dict query_data: dict
device_type: str
encrypt_type: str
port_override: Optional[int] = None port_override: Optional[int] = None
if "discovery_result" in all_fixture_data: if "discovery_result" in all_fixture_data:
discovery_data = {"result": all_fixture_data["discovery_result"]} discovery_data = {"result": all_fixture_data["discovery_result"]}
device_type = all_fixture_data["discovery_result"]["device_type"]
encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][
"encrypt_type"
]
datagram = ( datagram = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode() + json_dumps(discovery_data).encode()
) )
dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data, all_fixture_data) dm = _DiscoveryMock(
"127.0.0.123",
80,
20002,
discovery_data,
all_fixture_data,
device_type,
encrypt_type,
)
else: else:
sys_info = all_fixture_data["system"]["get_sysinfo"] sys_info = all_fixture_data["system"]["get_sysinfo"]
discovery_data = {"system": {"get_sysinfo": sys_info}} discovery_data = {"system": {"get_sysinfo": sys_info}}
device_type = sys_info.get("mic_type") or sys_info.get("type")
encrypt_type = "XOR"
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:] datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data, all_fixture_data) dm = _DiscoveryMock(
"127.0.0.123",
9999,
9999,
discovery_data,
all_fixture_data,
device_type,
encrypt_type,
)
def mock_discover(self): def mock_discover(self):
port = ( port = (
dm.port_override dm.port_override
if dm.port_override and dm.default_port != 20002 if dm.port_override and dm.discovery_port != 20002
else dm.default_port else dm.discovery_port
) )
self.datagram_received( self.datagram_received(
datagram, datagram,

View File

@ -15,7 +15,9 @@ from voluptuous import (
Schema, Schema,
) )
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol, _XorTransport
from ..smartprotocol import SmartProtocol from ..smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -290,7 +292,9 @@ TIME_MODULE = {
class FakeSmartProtocol(SmartProtocol): class FakeSmartProtocol(SmartProtocol):
def __init__(self, info): def __init__(self, info):
super().__init__("127.0.0.123", transport=FakeSmartTransport(info)) super().__init__(
transport=FakeSmartTransport(info),
)
async def query(self, request, retry_count: int = 3): async def query(self, request, retry_count: int = 3):
"""Implement query here so can still patch SmartProtocol.query.""" """Implement query here so can still patch SmartProtocol.query."""
@ -301,10 +305,15 @@ class FakeSmartProtocol(SmartProtocol):
class FakeSmartTransport(BaseTransport): class FakeSmartTransport(BaseTransport):
def __init__(self, info): def __init__(self, info):
super().__init__( super().__init__(
"127.0.0.123", config=DeviceConfig("127.0.0.123", credentials=Credentials()),
) )
self.info = info self.info = info
@property
def default_port(self):
"""Default port for the transport."""
return 80
async def send(self, request: str): async def send(self, request: str):
request_dict = json_loads(request) request_dict = json_loads(request)
method = request_dict["method"] method = request_dict["method"]
@ -344,6 +353,11 @@ class FakeSmartTransport(BaseTransport):
class FakeTransportProtocol(TPLinkSmartHomeProtocol): class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info): def __init__(self, info):
super().__init__(
transport=_XorTransport(
config=DeviceConfig("127.0.0.123"),
)
)
self.discovery_data = info self.discovery_data = info
self.writer = None self.writer = None
self.reader = None self.reader = None

View File

@ -12,6 +12,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
from ..aestransport import AesEncyptionSession, AesTransport from ..aestransport import AesEncyptionSession, AesTransport
from ..credentials import Credentials from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import ( from ..exceptions import (
SMART_RETRYABLE_ERRORS, SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS, SMART_TIMEOUT_ERRORS,
@ -58,7 +59,9 @@ async def test_handshake(
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
assert transport._encryption_session is None assert transport._encryption_session is None
assert transport._handshake_done is False assert transport._handshake_done is False
@ -74,7 +77,9 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True transport._handshake_done = True
transport._session_expire_at = time.time() + 86400 transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session transport._encryption_session = mock_aes_device.encryption_session
@ -91,13 +96,14 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True transport._handshake_done = True
transport._session_expire_at = time.time() + 86400 transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token transport._login_token = mock_aes_device.token
un, pw = transport.hash_credentials(True)
request = { request = {
"method": "get_device_info", "method": "get_device_info",
"params": None, "params": None,
@ -119,7 +125,8 @@ async def test_passthrough_errors(mocker, error_code):
mock_aes_device = MockAesDevice(host, 200, error_code, 0) mock_aes_device = MockAesDevice(host, 200, error_code, 0)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
transport = AesTransport(config=config)
transport._handshake_done = True transport._handshake_done = True
transport._session_expire_at = time.time() + 86400 transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session transport._encryption_session = mock_aes_device.encryption_session

View File

@ -4,10 +4,26 @@ import asyncclick as click
import pytest import pytest
from asyncclick.testing import CliRunner from asyncclick.testing import CliRunner
from kasa import AuthenticationException, SmartDevice, UnsupportedDeviceException from kasa import (
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle AuthenticationException,
from kasa.device_factory import DEVICE_TYPE_TO_CLASS Credentials,
from kasa.discover import Discover SmartDevice,
TPLinkSmartHomeProtocol,
UnsupportedDeviceException,
)
from kasa.cli import (
TYPE_TO_CLASS,
alias,
brightness,
cli,
emeter,
raw_command,
state,
sysinfo,
toggle,
)
from kasa.discover import Discover, DiscoveryResult
from kasa.smartprotocol import SmartProtocol
from .conftest import device_iot, handle_turn_on, new_discovery, turn_on from .conftest import device_iot, handle_turn_on, new_discovery, turn_on
@ -145,9 +161,11 @@ async def test_credentials(discovery_mock, mocker):
) )
mocker.patch("kasa.cli.state", new=_state) mocker.patch("kasa.cli.state", new=_state)
for subclass in DEVICE_TYPE_TO_CLASS.values():
mocker.patch.object(subclass, "update")
mocker.patch("kasa.IotProtocol.query", return_value=discovery_mock.query_data)
mocker.patch("kasa.SmartProtocol.query", return_value=discovery_mock.query_data)
dr = DiscoveryResult(**discovery_mock.discovery_data["result"])
runner = CliRunner() runner = CliRunner()
res = await runner.invoke( res = await runner.invoke(
cli, cli,
@ -158,6 +176,10 @@ async def test_credentials(discovery_mock, mocker):
"foo", "foo",
"--password", "--password",
"bar", "bar",
"--device-family",
dr.device_type,
"--encrypt-type",
dr.mgt_encrypt_schm.encrypt_type,
], ],
) )
assert res.exit_code == 0 assert res.exit_code == 0
@ -166,7 +188,7 @@ async def test_credentials(discovery_mock, mocker):
@device_iot @device_iot
async def test_without_device_type(discovery_data: dict, dev, mocker): async def test_without_device_type(dev, mocker):
"""Test connecting without the device type.""" """Test connecting without the device type."""
runner = CliRunner() runner = CliRunner()
mocker.patch("kasa.discover.Discover.discover_single", return_value=dev) mocker.patch("kasa.discover.Discover.discover_single", return_value=dev)
@ -342,3 +364,27 @@ async def test_host_auth_failed(discovery_mock, mocker):
assert res.exit_code != 0 assert res.exit_code != 0
assert isinstance(res.exception, AuthenticationException) assert isinstance(res.exception, AuthenticationException)
@pytest.mark.parametrize("device_type", list(TYPE_TO_CLASS))
async def test_type_param(device_type, mocker):
"""Test for handling only one of username or password supplied."""
runner = CliRunner()
result_device = FileNotFoundError
pass_dev = click.make_pass_decorator(SmartDevice)
@pass_dev
async def _state(dev: SmartDevice):
nonlocal result_device
result_device = dev
mocker.patch("kasa.cli.state", new=_state)
expected_type = TYPE_TO_CLASS[device_type]
mocker.patch.object(expected_type, "update")
res = await runner.invoke(
cli,
["--type", device_type, "--host", "127.0.0.1"],
)
assert res.exit_code == 0
assert isinstance(result_device, expected_type)

View File

@ -2,6 +2,7 @@
import logging import logging
from typing import Type from typing import Type
import httpx
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import ( from kasa import (
@ -15,122 +16,138 @@ from kasa import (
SmartLightStrip, SmartLightStrip,
SmartPlug, SmartPlug,
) )
from kasa.device_factory import ( from kasa.device_factory import connect, get_protocol
DEVICE_TYPE_TO_CLASS, from kasa.deviceconfig import (
connect, ConnectionType,
get_protocol_from_connection_name, DeviceConfig,
DeviceFamilyType,
EncryptType,
) )
from kasa.discover import DiscoveryResult from kasa.discover import DiscoveryResult
from kasa.iotprotocol import IotProtocol
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
@pytest.mark.parametrize("custom_port", [123, None]) def _get_connection_type_device_class(the_fixture_data):
async def test_connect(discovery_data: dict, mocker, custom_port): if "discovery_result" in the_fixture_data:
"""Make sure that connect returns an initialized SmartDevice instance.""" discovery_info = {"result": the_fixture_data["discovery_result"]}
host = "127.0.0.1" device_class = Discover._get_device_class(discovery_info)
dr = DiscoveryResult(**discovery_info["result"])
if "result" in discovery_data: connection_type = ConnectionType.from_values(
with pytest.raises(SmartDeviceException): dr.device_type, dr.mgt_encrypt_schm.encrypt_type
dev = await connect(host, port=custom_port) )
else: else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) connection_type = ConnectionType.from_values(
dev = await connect(host, port=custom_port) DeviceFamilyType.IotSmartPlugSwitch.value, EncryptType.Xor.value
assert issubclass(dev.__class__, SmartDevice) )
assert dev.port == custom_port or dev.port == 9999 device_class = Discover._get_device_class(the_fixture_data)
return connection_type, device_class
@pytest.mark.parametrize("custom_port", [123, None]) async def test_connect(
@pytest.mark.parametrize(
("device_type", "klass"),
(
(DeviceType.Plug, SmartPlug),
(DeviceType.Bulb, SmartBulb),
(DeviceType.Dimmer, SmartDimmer),
(DeviceType.LightStrip, SmartLightStrip),
(DeviceType.Unknown, SmartDevice),
),
)
async def test_connect_passed_device_type(
discovery_data: dict,
mocker,
device_type: DeviceType,
klass: Type[SmartDevice],
custom_port,
):
"""Make sure that connect with a passed device type."""
host = "127.0.0.1"
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
dev = await connect(host, port=custom_port)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port, device_type=device_type)
assert isinstance(dev, klass)
assert dev.port == custom_port or dev.port == 9999
async def test_connect_query_fails(discovery_data: dict, mocker):
"""Make sure that connect fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
with pytest.raises(SmartDeviceException):
await connect(host)
async def test_connect_logs_connect_time(
discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker
):
"""Test that the connect time is logged when debug logging is enabled."""
host = "127.0.0.1"
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
await connect(host)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
logging.getLogger("kasa").setLevel(logging.DEBUG)
await connect(host)
assert "seconds to connect" in caplog.text
async def test_connect_pass_protocol(
all_fixture_data: dict, all_fixture_data: dict,
mocker, mocker,
): ):
"""Test that if the protocol is passed in it's gets set correctly.""" """Test that if the protocol is passed in it gets set correctly."""
if "discovery_result" in all_fixture_data:
discovery_info = {"result": all_fixture_data["discovery_result"]}
device_class = Discover._get_device_class(discovery_info)
else:
device_class = Discover._get_device_class(all_fixture_data)
device_type = list(DEVICE_TYPE_TO_CLASS.keys())[
list(DEVICE_TYPE_TO_CLASS.values()).index(device_class)
]
host = "127.0.0.1" host = "127.0.0.1"
if "discovery_result" in all_fixture_data: ctype, device_class = _get_connection_type_device_class(all_fixture_data)
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
dr = DiscoveryResult(**discovery_info["result"]) mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
connection_name = ( mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
)
protocol_class = get_protocol_from_connection_name( config = DeviceConfig(
connection_name, host host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
).__class__ )
else: protocol_class = get_protocol(config).__class__
mocker.patch(
"kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data
)
protocol_class = TPLinkSmartHomeProtocol
dev = await connect( dev = await connect(
host, config=config,
device_type=device_type,
protocol_class=protocol_class,
credentials=Credentials("", ""),
) )
assert isinstance(dev, device_class)
assert isinstance(dev.protocol, protocol_class) assert isinstance(dev.protocol, protocol_class)
assert dev.config == config
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):
"""Make sure that connect returns an initialized SmartDevice instance."""
host = "127.0.0.1"
ctype, _ = _get_connection_type_device_class(all_fixture_data)
config = DeviceConfig(host=host, port_override=custom_port, connection_type=ctype)
default_port = 80 if "discovery_result" in all_fixture_data else 9999
ctype, _ = _get_connection_type_device_class(all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
dev = await connect(config=config)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == default_port
async def test_connect_logs_connect_time(
all_fixture_data: dict, caplog: pytest.LogCaptureFixture, mocker
):
"""Test that the connect time is logged when debug logging is enabled."""
ctype, _ = _get_connection_type_device_class(all_fixture_data)
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
host = "127.0.0.1"
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
logging.getLogger("kasa").setLevel(logging.DEBUG)
await connect(
config=config,
)
assert "seconds to update" in caplog.text
async def test_connect_query_fails(all_fixture_data: dict, mocker):
"""Make sure that connect fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
mocker.patch("kasa.IotProtocol.query", side_effect=SmartDeviceException)
mocker.patch("kasa.SmartProtocol.query", side_effect=SmartDeviceException)
ctype, _ = _get_connection_type_device_class(all_fixture_data)
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
with pytest.raises(SmartDeviceException):
await connect(config=config)
async def test_connect_http_client(all_fixture_data, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
ctype, _ = _get_connection_type_device_class(all_fixture_data)
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
http_client = httpx.AsyncClient()
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
dev = await connect(config=config)
if ctype.encryption_type != EncryptType.Xor:
assert dev.protocol._transport._http_client != http_client
config = DeviceConfig(
host=host,
credentials=Credentials("foor", "bar"),
connection_type=ctype,
http_client=http_client,
)
dev = await connect(config=config)
if ctype.encryption_type != EncryptType.Xor:
assert dev.protocol._transport._http_client == http_client

View File

@ -0,0 +1,21 @@
from json import dumps as json_dumps
from json import loads as json_loads
import httpx
from kasa.credentials import Credentials
from kasa.deviceconfig import (
ConnectionType,
DeviceConfig,
DeviceFamilyType,
EncryptType,
)
def test_serialization():
config = DeviceConfig(host="Foo", http_client=httpx.AsyncClient())
config_dict = config.to_dict()
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
assert config == config2

View File

@ -1,21 +1,29 @@
# type: ignore # type: ignore
import logging
import re import re
import socket import socket
import httpx
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import ( from kasa import (
Credentials,
DeviceType, DeviceType,
Discover, Discover,
SmartDevice, SmartDevice,
SmartDeviceException, SmartDeviceException,
SmartStrip,
protocol, protocol,
) )
from kasa.deviceconfig import (
ConnectionType,
DeviceConfig,
DeviceFamilyType,
EncryptType,
)
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps
from kasa.exceptions import AuthenticationException, UnsupportedDeviceException from kasa.exceptions import AuthenticationException, UnsupportedDeviceException
from .conftest import bulb, bulb_iot, dimmer, lightstrip, plug, strip from .conftest import bulb, bulb_iot, dimmer, lightstrip, new_discovery, plug, strip
UNSUPPORTED = { UNSUPPORTED = {
"result": { "result": {
@ -89,13 +97,26 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
host = "127.0.0.1" host = "127.0.0.1"
discovery_mock.ip = host discovery_mock.ip = host
discovery_mock.port_override = custom_port discovery_mock.port_override = custom_port
update_mock = mocker.patch.object(SmartStrip, "update")
x = await Discover.discover_single(host, port=custom_port) device_class = Discover._get_device_class(discovery_mock.discovery_data)
update_mock = mocker.patch.object(device_class, "update")
x = await Discover.discover_single(
host, port=custom_port, credentials=Credentials()
)
assert issubclass(x.__class__, SmartDevice) assert issubclass(x.__class__, SmartDevice)
assert x._discovery_info is not None assert x._discovery_info is not None
assert x.port == custom_port or x.port == discovery_mock.default_port assert x.port == custom_port or x.port == discovery_mock.default_port
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip) assert update_mock.call_count == 0
ct = ConnectionType.from_values(
discovery_mock.device_type, discovery_mock.encrypt_type
)
uses_http = discovery_mock.default_port == 80
config = DeviceConfig(
host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http
)
assert x.config == config
async def test_discover_single_hostname(discovery_mock, mocker): async def test_discover_single_hostname(discovery_mock, mocker):
@ -104,47 +125,39 @@ async def test_discover_single_hostname(discovery_mock, mocker):
ip = "127.0.0.1" ip = "127.0.0.1"
discovery_mock.ip = ip discovery_mock.ip = ip
update_mock = mocker.patch.object(SmartStrip, "update") device_class = Discover._get_device_class(discovery_mock.discovery_data)
update_mock = mocker.patch.object(device_class, "update")
x = await Discover.discover_single(host) x = await Discover.discover_single(host, credentials=Credentials())
assert issubclass(x.__class__, SmartDevice) assert issubclass(x.__class__, SmartDevice)
assert x._discovery_info is not None assert x._discovery_info is not None
assert x.host == host assert x.host == host
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip) assert update_mock.call_count == 0
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror()) mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
x = await Discover.discover_single(host) x = await Discover.discover_single(host, credentials=Credentials())
async def test_discover_single_unsupported(mocker): async def test_discover_single_unsupported(unsupported_device_info, mocker):
"""Make sure that discover_single handles unsupported devices correctly.""" """Make sure that discover_single handles unsupported devices correctly."""
host = "127.0.0.1" host = "127.0.0.1"
def mock_discover(self):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
self.datagram_received(data, (host, 20002))
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
# Test with a valid unsupported response # Test with a valid unsupported response
discovery_data = UNSUPPORTED
with pytest.raises( with pytest.raises(
UnsupportedDeviceException, UnsupportedDeviceException,
match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}",
): ):
await Discover.discover_single(host) await Discover.discover_single(host)
# Test with no response
discovery_data = None async def test_discover_single_no_response(mocker):
"""Make sure that discover_single handles no response correctly."""
host = "127.0.0.1"
mocker.patch.object(_DiscoverProtocol, "do_discover")
with pytest.raises( with pytest.raises(
SmartDeviceException, match=f"Timed out getting discovery response for {host}" SmartDeviceException, match=f"Timed out getting discovery response for {host}"
): ):
await Discover.discover_single(host, timeout=0.001) await Discover.discover_single(host, discovery_timeout=0)
INVALIDS = [ INVALIDS = [
@ -241,52 +254,82 @@ AUTHENTICATION_DATA_KLAP = {
} }
async def test_discover_single_authentication(mocker): @new_discovery
async def test_discover_single_authentication(discovery_mock, mocker):
"""Make sure that discover_single handles authenticating devices correctly.""" """Make sure that discover_single handles authenticating devices correctly."""
host = "127.0.0.1" host = "127.0.0.1"
discovery_mock.ip = host
def mock_discover(self): device_class = Discover._get_device_class(discovery_mock.discovery_data)
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
self.datagram_received(data, (host, 20002))
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch.object( mocker.patch.object(
SmartDevice, device_class,
"update", "update",
side_effect=AuthenticationException("Failed to authenticate"), side_effect=AuthenticationException("Failed to authenticate"),
) )
# Test with a valid unsupported response
discovery_data = AUTHENTICATION_DATA_KLAP
with pytest.raises( with pytest.raises(
AuthenticationException, AuthenticationException,
match="Failed to authenticate", match="Failed to authenticate",
): ):
device = await Discover.discover_single(host) device = await Discover.discover_single(
host, credentials=Credentials("foo", "bar")
)
await device.update() await device.update()
mocker.patch.object(SmartDevice, "update") mocker.patch.object(device_class, "update")
device = await Discover.discover_single(host) device = await Discover.discover_single(host, credentials=Credentials("foo", "bar"))
await device.update() await device.update()
assert device.device_type == DeviceType.Plug assert isinstance(device, device_class)
async def test_device_update_from_new_discovery_info(): @new_discovery
async def test_device_update_from_new_discovery_info(discovery_data):
device = SmartDevice("127.0.0.7") device = SmartDevice("127.0.0.7")
discover_info = DiscoveryResult(**AUTHENTICATION_DATA_KLAP["result"]) discover_info = DiscoveryResult(**discovery_data["result"])
discover_dump = discover_info.get_dict() discover_dump = discover_info.get_dict()
discover_dump["alias"] = "foobar"
discover_dump["model"] = discover_dump["device_model"]
device.update_from_discover_info(discover_dump) device.update_from_discover_info(discover_dump)
assert device.alias == discover_dump["alias"] assert device.alias == "foobar"
assert device.mac == discover_dump["mac"].replace("-", ":") assert device.mac == discover_dump["mac"].replace("-", ":")
assert device.model == discover_dump["model"] assert device.model == discover_dump["device_model"]
with pytest.raises( with pytest.raises(
SmartDeviceException, SmartDeviceException,
match=re.escape("You need to await update() to access the data"), match=re.escape("You need to await update() to access the data"),
): ):
assert device.supported_modules assert device.supported_modules
async def test_discover_single_http_client(discovery_mock, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
discovery_mock.ip = host
http_client = httpx.AsyncClient()
x: SmartDevice = await Discover.discover_single(host)
assert x.config.uses_http == (discovery_mock.default_port == 80)
if discovery_mock.default_port == 80:
assert x.protocol._transport._http_client != http_client
x.config.http_client = http_client
assert x.protocol._transport._http_client == http_client
async def test_discover_http_client(discovery_mock, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
discovery_mock.ip = host
http_client = httpx.AsyncClient()
devices = await Discover.discover(discovery_timeout=0)
x: SmartDevice = devices[host]
assert x.config.uses_http == (discovery_mock.default_port == 80)
if discovery_mock.default_port == 80:
assert x.protocol._transport._http_client != http_client
x.config.http_client = http_client
assert x.protocol._transport._http_client == http_client

View File

@ -12,9 +12,15 @@ import pytest
from ..aestransport import AesTransport from ..aestransport import AesTransport
from ..credentials import Credentials from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import AuthenticationException, SmartDeviceException from ..exceptions import AuthenticationException, SmartDeviceException
from ..iotprotocol import IotProtocol from ..iotprotocol import IotProtocol
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256 from ..klaptransport import (
KlapEncryptionSession,
KlapTransport,
KlapTransportV2,
_sha256,
)
from ..smartprotocol import SmartProtocol from ..smartprotocol import SmartProtocol
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
@ -31,8 +37,9 @@ class _mock_response:
[ [
(Exception("dummy exception"), True), (Exception("dummy exception"), True),
(SmartDeviceException("dummy exception"), False), (SmartDeviceException("dummy exception"), False),
(httpx.ConnectError("dummy exception"), True),
], ],
ids=("Exception", "SmartDeviceException"), ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
) )
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@ -42,8 +49,10 @@ async def test_protocol_retries(
): ):
host = "127.0.0.1" host = "127.0.0.1"
conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error) conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query( await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=retry_count DUMMY_QUERY, retry_count=retry_count
) )
@ -60,10 +69,11 @@ async def test_protocol_no_retry_on_connection_error(
conn = mocker.patch.object( conn = mocker.patch.object(
httpx.AsyncClient, httpx.AsyncClient,
"post", "post",
side_effect=httpx.ConnectError("foo"), side_effect=AuthenticationException("foo"),
) )
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query( await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=5 DUMMY_QUERY, retry_count=5
) )
@ -81,8 +91,9 @@ async def test_protocol_retry_recoverable_error(
"post", "post",
side_effect=httpx.CloseError("foo"), side_effect=httpx.CloseError("foo"),
) )
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query( await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=5 DUMMY_QUERY, retry_count=5
) )
@ -115,7 +126,8 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
side_effect=_fail_one_less_than_retry_count, side_effect=_fail_one_less_than_retry_count,
) )
response = await protocol_class(host, transport=transport_class(host)).query( config = DeviceConfig(host)
response = await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=retry_count DUMMY_QUERY, retry_count=retry_count
) )
assert "result" in response or "foobar" in response assert "result" in response or "foobar" in response
@ -136,7 +148,9 @@ async def test_protocol_logging(mocker, caplog, log_level):
seed = secrets.token_bytes(16) seed = secrets.token_bytes(16)
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1"))
config = DeviceConfig("127.0.0.1")
protocol = IotProtocol(transport=KlapTransport(config=config))
protocol._transport._handshake_done = True protocol._transport._handshake_done = True
protocol._transport._session_expire_at = time.time() + 86400 protocol._transport._session_expire_at = time.time() + 86400
@ -181,7 +195,7 @@ def test_encrypt_unicode():
"device_credentials, expectation", "device_credentials, expectation",
[ [
(Credentials("foo", "bar"), does_not_raise()), (Credentials("foo", "bar"), does_not_raise()),
(Credentials("", ""), does_not_raise()), (Credentials(), does_not_raise()),
( (
Credentials( Credentials(
KlapTransport.KASA_SETUP_EMAIL, KlapTransport.KASA_SETUP_EMAIL,
@ -196,30 +210,37 @@ def test_encrypt_unicode():
], ],
ids=("client", "blank", "kasa_setup", "shouldfail"), ids=("client", "blank", "kasa_setup", "shouldfail"),
) )
async def test_handshake1(mocker, device_credentials, expectation): @pytest.mark.parametrize(
"transport_class, seed_auth_hash_calc",
[
pytest.param(KlapTransport, lambda c, s, a: c + a, id="KLAP"),
pytest.param(KlapTransportV2, lambda c, s, a: c + s + a, id="KLAPV2"),
],
)
async def test_handshake1(
mocker, device_credentials, expectation, transport_class, seed_auth_hash_calc
):
async def _return_handshake1_response(url, params=None, data=None, *_, **__): async def _return_handshake1_response(url, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash nonlocal client_seed, server_seed, device_auth_hash
client_seed = data client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash) seed_auth_hash = _sha256(
seed_auth_hash_calc(client_seed, server_seed, device_auth_hash)
return _mock_response(200, server_seed + client_seed_auth_hash) )
return _mock_response(200, server_seed + seed_auth_hash)
client_seed = None client_seed = None
server_seed = secrets.token_bytes(16) server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar") client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(device_credentials) device_auth_hash = transport_class.generate_auth_hash(device_credentials)
mocker.patch.object( mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake1_response httpx.AsyncClient, "post", side_effect=_return_handshake1_response
) )
protocol = IotProtocol( config = DeviceConfig("127.0.0.1", credentials=client_credentials)
"127.0.0.1", protocol = IotProtocol(transport=transport_class(config=config))
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
protocol._transport.http_client = httpx.AsyncClient()
with expectation: with expectation:
( (
local_seed, local_seed,
@ -233,31 +254,51 @@ async def test_handshake1(mocker, device_credentials, expectation):
await protocol.close() await protocol.close()
async def test_handshake(mocker): @pytest.mark.parametrize(
"transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2",
[
pytest.param(
KlapTransport, lambda c, s, a: c + a, lambda c, s, a: s + a, id="KLAP"
),
pytest.param(
KlapTransportV2,
lambda c, s, a: c + s + a,
lambda c, s, a: s + c + a,
id="KLAPV2",
),
],
)
async def test_handshake(
mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
):
async def _return_handshake_response(url, params=None, data=None, *_, **__): async def _return_handshake_response(url, params=None, data=None, *_, **__):
nonlocal response_status, client_seed, server_seed, device_auth_hash nonlocal client_seed, server_seed, device_auth_hash
if url == "http://127.0.0.1/app/handshake1": if url == "http://127.0.0.1/app/handshake1":
client_seed = data client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash) seed_auth_hash = _sha256(
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
)
return _mock_response(200, server_seed + client_seed_auth_hash) return _mock_response(200, server_seed + seed_auth_hash)
elif url == "http://127.0.0.1/app/handshake2": elif url == "http://127.0.0.1/app/handshake2":
seed_auth_hash = _sha256(
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
)
assert data == seed_auth_hash
return _mock_response(response_status, b"") return _mock_response(response_status, b"")
client_seed = None client_seed = None
server_seed = secrets.token_bytes(16) server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar") client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) device_auth_hash = transport_class.generate_auth_hash(client_credentials)
mocker.patch.object( mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake_response httpx.AsyncClient, "post", side_effect=_return_handshake_response
) )
protocol = IotProtocol( config = DeviceConfig("127.0.0.1", credentials=client_credentials)
"127.0.0.1", protocol = IotProtocol(transport=transport_class(config=config))
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
protocol._transport.http_client = httpx.AsyncClient() protocol._transport.http_client = httpx.AsyncClient()
response_status = 200 response_status = 200
@ -273,7 +314,7 @@ async def test_handshake(mocker):
async def test_query(mocker): async def test_query(mocker):
async def _return_response(url, params=None, data=None, *_, **__): async def _return_response(url, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash, protocol, seq nonlocal client_seed, server_seed, device_auth_hash, seq
if url == "http://127.0.0.1/app/handshake1": if url == "http://127.0.0.1/app/handshake1":
client_seed = data client_seed = data
@ -303,10 +344,8 @@ async def test_query(mocker):
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = IotProtocol( config = DeviceConfig("127.0.0.1", credentials=client_credentials)
"127.0.0.1", protocol = IotProtocol(transport=KlapTransport(config=config))
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
for _ in range(10): for _ in range(10):
resp = await protocol.query({}) resp = await protocol.query({})
@ -350,10 +389,8 @@ async def test_authentication_failures(mocker, response_status, expectation):
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = IotProtocol( config = DeviceConfig("127.0.0.1", credentials=client_credentials)
"127.0.0.1", protocol = IotProtocol(transport=KlapTransport(config=config))
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
with expectation: with expectation:
await protocol.query({}) await protocol.query({})

View File

@ -9,6 +9,7 @@ import sys
import pytest import pytest
from ..deviceconfig import DeviceConfig
from ..exceptions import SmartDeviceException from ..exceptions import SmartDeviceException
from ..protocol import ( from ..protocol import (
BaseTransport, BaseTransport,
@ -31,10 +32,11 @@ async def test_protocol_retries(mocker, retry_count):
return reader, writer return reader, writer
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol( await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
"127.0.0.1", transport=_XorTransport("127.0.0.1") {}, retry_count=retry_count
).query({}, retry_count=retry_count) )
assert conn.call_count == retry_count + 1 assert conn.call_count == retry_count + 1
@ -44,10 +46,11 @@ async def test_protocol_no_retry_on_unreachable(mocker):
"asyncio.open_connection", "asyncio.open_connection",
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"), side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
) )
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol( await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
"127.0.0.1", transport=_XorTransport("127.0.0.1") {}, retry_count=5
).query({}, retry_count=5) )
assert conn.call_count == 1 assert conn.call_count == 1
@ -57,10 +60,11 @@ async def test_protocol_no_retry_connection_refused(mocker):
"asyncio.open_connection", "asyncio.open_connection",
side_effect=ConnectionRefusedError, side_effect=ConnectionRefusedError,
) )
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol( await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
"127.0.0.1", transport=_XorTransport("127.0.0.1") {}, retry_count=5
).query({}, retry_count=5) )
assert conn.call_count == 1 assert conn.call_count == 1
@ -70,10 +74,11 @@ async def test_protocol_retry_recoverable_error(mocker):
"asyncio.open_connection", "asyncio.open_connection",
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"), side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
) )
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol( await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
"127.0.0.1", transport=_XorTransport("127.0.0.1") {}, retry_count=5
).query({}, retry_count=5) )
assert conn.call_count == 6 assert conn.call_count == 6
@ -107,9 +112,8 @@ async def test_protocol_reconnect(mocker, retry_count):
mocker.patch.object(reader, "readexactly", _mock_read) mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer return reader, writer
protocol = TPLinkSmartHomeProtocol( config = DeviceConfig("127.0.0.1")
"127.0.0.1", transport=_XorTransport("127.0.0.1") protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
)
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}, retry_count=retry_count) response = await protocol.query({}, retry_count=retry_count)
assert response == {"great": "success"} assert response == {"great": "success"}
@ -137,9 +141,8 @@ async def test_protocol_logging(mocker, caplog, log_level):
mocker.patch.object(reader, "readexactly", _mock_read) mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer return reader, writer
protocol = TPLinkSmartHomeProtocol( config = DeviceConfig("127.0.0.1")
"127.0.0.1", transport=_XorTransport("127.0.0.1") protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
)
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}) response = await protocol.query({})
assert response == {"great": "success"} assert response == {"great": "success"}
@ -173,9 +176,8 @@ async def test_protocol_custom_port(mocker, custom_port):
mocker.patch.object(reader, "readexactly", _mock_read) mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer return reader, writer
protocol = TPLinkSmartHomeProtocol( config = DeviceConfig("127.0.0.1", port_override=custom_port)
"127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port) protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
)
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}) response = await protocol.query({})
assert response == {"great": "success"} assert response == {"great": "success"}
@ -271,18 +273,14 @@ def _get_subclasses(of_class):
def test_protocol_init_signature(class_name_obj): def test_protocol_init_signature(class_name_obj):
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 3 assert len(params) == 2
assert ( assert (
params[0].name == "self" params[0].name == "self"
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
) )
assert ( assert (
params[1].name == "host" params[1].name == "transport"
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and params[1].kind == inspect.Parameter.KEYWORD_ONLY
)
assert (
params[2].name == "transport"
and params[2].kind == inspect.Parameter.KEYWORD_ONLY
) )
@ -292,20 +290,11 @@ def test_protocol_init_signature(class_name_obj):
def test_transport_init_signature(class_name_obj): def test_transport_init_signature(class_name_obj):
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 5 assert len(params) == 2
assert ( assert (
params[0].name == "self" params[0].name == "self"
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
) )
assert ( assert (
params[1].name == "host" params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY
assert (
params[3].name == "credentials"
and params[3].kind == inspect.Parameter.KEYWORD_ONLY
)
assert (
params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY
) )

View File

@ -5,8 +5,7 @@ from unittest.mock import Mock, patch
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
import kasa import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException from kasa import Credentials, DeviceConfig, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
@ -215,7 +214,8 @@ def test_device_class_ctors(device_class):
host = "127.0.0.2" host = "127.0.0.2"
port = 1234 port = 1234
credentials = Credentials("foo", "bar") credentials = Credentials("foo", "bar")
dev = device_class(host, port=port, credentials=credentials) config = DeviceConfig(host, port_override=port, credentials=credentials)
dev = device_class(host, config=config)
assert dev.host == host assert dev.host == host
assert dev.port == port assert dev.port == port
assert dev.credentials == credentials assert dev.credentials == credentials
@ -231,29 +231,27 @@ async def test_modules_preserved(dev: SmartDevice):
async def test_create_smart_device_with_timeout(): async def test_create_smart_device_with_timeout():
"""Make sure timeout is passed to the protocol.""" """Make sure timeout is passed to the protocol."""
dev = SmartDevice(host="127.0.0.1", timeout=100) host = "127.0.0.1"
dev = SmartDevice(host, config=DeviceConfig(host, timeout=100))
assert dev.protocol._transport._timeout == 100 assert dev.protocol._transport._timeout == 100
async def test_create_thin_wrapper(): async def test_create_thin_wrapper():
"""Make sure thin wrapper is created with the correct device type.""" """Make sure thin wrapper is created with the correct device type."""
mock = Mock() mock = Mock()
config = DeviceConfig(
host="test_host",
port_override=1234,
timeout=100,
credentials=Credentials("username", "password"),
)
with patch("kasa.device_factory.connect", return_value=mock) as connect: with patch("kasa.device_factory.connect", return_value=mock) as connect:
dev = await SmartDevice.connect( dev = await SmartDevice.connect(config=config)
host="test_host",
port=1234,
timeout=100,
credentials=Credentials("username", "password"),
device_type=DeviceType.Strip,
)
assert dev is mock assert dev is mock
connect.assert_called_once_with( connect.assert_called_once_with(
host="test_host", host=None,
port=1234, config=config,
timeout=100,
credentials=Credentials("username", "password"),
device_type=DeviceType.Strip,
) )

View File

@ -13,6 +13,7 @@ import pytest
from ..aestransport import AesTransport from ..aestransport import AesTransport
from ..credentials import Credentials from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import ( from ..exceptions import (
SMART_RETRYABLE_ERRORS, SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS, SMART_TIMEOUT_ERRORS,
@ -37,7 +38,8 @@ async def test_smart_device_errors(mocker, error_code):
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
protocol = SmartProtocol(host, transport=AesTransport(host)) config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
protocol = SmartProtocol(transport=AesTransport(config=config))
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2) await protocol.query(DUMMY_QUERY, retry_count=2)
@ -70,8 +72,8 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
mocker.patch.object(AesTransport, "perform_login") mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
protocol = SmartProtocol(host, transport=AesTransport(host)) protocol = SmartProtocol(transport=AesTransport(config=config))
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2) await protocol.query(DUMMY_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):