Add DeviceConfig to allow specifying configuration parameters (#569)

* Add DeviceConfig handling

* Update post review

* Further update post latest review

* Update following latest review

* Update docstrings and docs
This commit is contained in:
sdb9696 2023-12-29 19:17:15 +00:00 committed by GitHub
parent ec3ea39a37
commit f6fd898faf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 1032 additions and 589 deletions

View File

@ -23,9 +23,12 @@ This will return you a list of device instances based on the discovery replies.
If the device's host is already known, you can use to construct a device instance with
:meth:`~kasa.SmartDevice.connect()`.
When connecting a device with the :meth:`~kasa.SmartDevice.connect()` method, it is recommended to
pass the device type as well as this allows the library to use the correct device class for the
device without having to query the device.
The :meth:`~kasa.SmartDevice.connect()` also enables support for connecting to new
KASA SMART protocol and TAPO devices directly using the parameter :class:`~kasa.DeviceConfig`.
Simply serialize the :attr:`~kasa.SmartDevice.config` property via :meth:`~kasa.DeviceConfig.to_dict()`
and then deserialize it later with :func:`~kasa.DeviceConfig.from_dict()`
and then pass it into :meth:`~kasa.SmartDevice.connect()`.
.. _update_cycle:

View File

@ -0,0 +1,18 @@
DeviceConfig
============
.. contents:: Contents
:local:
.. note::
Feel free to open a pull request to improve the documentation!
API documentation
*****************
.. autoclass:: kasa.DeviceConfig
:members:
:inherited-members:
:undoc-members:

View File

@ -15,3 +15,4 @@
smartdimmer
smartstrip
smartlightstrip
deviceconfig

View File

@ -14,6 +14,12 @@ to be handled by the user of the library.
from importlib.metadata import version
from kasa.credentials import Credentials
from kasa.deviceconfig import (
ConnectionType,
DeviceConfig,
DeviceFamilyType,
EncryptType,
)
from kasa.discover import Discover
from kasa.emeterstatus import EmeterStatus
from kasa.exceptions import (
@ -55,4 +61,8 @@ __all__ = [
"AuthenticationException",
"UnsupportedDeviceException",
"Credentials",
"DeviceConfig",
"ConnectionType",
"EncryptType",
"DeviceFamilyType",
]

View File

@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
@ -47,8 +47,7 @@ class AesTransport(BaseTransport):
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_PORT = 80
DEFAULT_TIMEOUT = 5
DEFAULT_PORT: int = 80
SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = {
"Content-Type": "application/json",
@ -58,32 +57,37 @@ class AesTransport(BaseTransport):
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: DeviceConfig,
) -> None:
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
super().__init__(config=config)
self._default_http_client: Optional[httpx.AsyncClient] = None
self._handshake_done = False
self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self._session_cookie = None
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
self._login_token = None
_LOGGER.debug("Created AES transport for %s", self._host)
@property
def default_port(self):
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
return self._config.http_client
if not self._default_http_client:
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
def hash_credentials(self, login_v2):
"""Hash the credentials."""
if login_v2:
@ -102,8 +106,6 @@ class AesTransport(BaseTransport):
async def client_post(self, url, params=None, data=None, json=None, headers=None):
"""Send an http post request to the device."""
if not self._http_client:
self._http_client = httpx.AsyncClient()
response_data = None
cookies = None
if self._session_cookie:
@ -268,8 +270,8 @@ class AesTransport(BaseTransport):
async def close(self) -> None:
"""Close the protocol."""
client = self._http_client
self._http_client = None
client = self._default_http_client
self._default_http_client = None
self._handshake_done = False
self._login_token = None
if client:

View File

@ -12,15 +12,20 @@ import asyncclick as click
from kasa import (
AuthenticationException,
ConnectionType,
Credentials,
DeviceType,
DeviceConfig,
DeviceFamilyType,
Discover,
EncryptType,
SmartBulb,
SmartDevice,
SmartDimmer,
SmartLightStrip,
SmartPlug,
SmartStrip,
UnsupportedDeviceException,
)
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
from kasa.discover import DiscoveryResult
try:
@ -49,10 +54,19 @@ except ImportError:
# --json has set it to _nop_echo
echo = _do_echo
DEVICE_TYPES = [
device_type.value
for device_type in DeviceType
if device_type in DEVICE_TYPE_TO_CLASS
TYPE_TO_CLASS = {
"plug": SmartPlug,
"bulb": SmartBulb,
"dimmer": SmartDimmer,
"strip": SmartStrip,
"lightstrip": SmartLightStrip,
}
ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in EncryptType]
DEVICE_FAMILY_TYPES = [
device_family_type.value for device_family_type in DeviceFamilyType
]
click.anyio_backend = "asyncio"
@ -149,7 +163,7 @@ def json_formatter_cb(result, **kwargs):
"--type",
envvar="KASA_TYPE",
default=None,
type=click.Choice(DEVICE_TYPES, case_sensitive=False),
type=click.Choice(list(TYPE_TO_CLASS), case_sensitive=False),
)
@click.option(
"--json/--no-json",
@ -158,6 +172,18 @@ def json_formatter_cb(result, **kwargs):
is_flag=True,
help="Output raw device response as JSON.",
)
@click.option(
"--encrypt-type",
envvar="KASA_ENCRYPT_TYPE",
default=None,
type=click.Choice(ENCRYPT_TYPES, case_sensitive=False),
)
@click.option(
"--device-family",
envvar="KASA_DEVICE_FAMILY",
default=None,
type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False),
)
@click.option(
"--timeout",
envvar="KASA_TIMEOUT",
@ -199,6 +225,8 @@ async def cli(
verbose,
debug,
type,
encrypt_type,
device_family,
json,
timeout,
discovery_timeout,
@ -270,12 +298,19 @@ async def cli(
return await ctx.invoke(discover)
if type is not None:
device_type = DeviceType.from_value(type)
dev = await SmartDevice.connect(
host, credentials=credentials, device_type=device_type, timeout=timeout
dev = TYPE_TO_CLASS[type](host)
await dev.update()
elif device_family or encrypt_type:
ctype = ConnectionType(
DeviceFamilyType(device_family),
EncryptType(encrypt_type),
)
config = DeviceConfig(
host=host, credentials=credentials, timeout=timeout, connection_type=ctype
)
dev = await SmartDevice.connect(config=config)
else:
echo("No --type defined, discovering..")
echo("No --type or --device-family and --encrypt-type defined, discovering..")
dev = await Discover.discover_single(
host,
port=port,
@ -332,8 +367,10 @@ async def discover(ctx):
target = ctx.parent.params["target"]
username = ctx.parent.params["username"]
password = ctx.parent.params["password"]
timeout = ctx.parent.params["discovery_timeout"]
verbose = ctx.parent.params["verbose"]
discovery_timeout = ctx.parent.params["discovery_timeout"]
timeout = ctx.parent.params["timeout"]
port = ctx.parent.params["port"]
credentials = Credentials(username, password)
@ -354,7 +391,7 @@ async def discover(ctx):
echo(f"\t{unsupported_exception}")
echo()
echo(f"Discovering devices on {target} for {timeout} seconds")
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
async def print_discovered(dev: SmartDevice):
async with sem:
@ -376,9 +413,11 @@ async def discover(ctx):
await Discover.discover(
target=target,
timeout=timeout,
discovery_timeout=discovery_timeout,
on_discovered=print_discovered,
on_unsupported=print_unsupported,
port=port,
timeout=timeout,
credentials=credentials,
)

View File

@ -8,5 +8,5 @@ from typing import Optional
class Credentials:
"""Credentials for authentication."""
username: Optional[str] = field(default=None, repr=False)
password: Optional[str] = field(default=None, repr=False)
username: Optional[str] = field(default="", repr=False)
password: Optional[str] = field(default="", repr=False)

View File

@ -1,18 +1,21 @@
"""Device creation by type."""
"""Device creation via DeviceConfig."""
import logging
import time
from typing import Any, Dict, Optional, Tuple, Type
from .aestransport import AesTransport
from .credentials import Credentials
from .device_type import DeviceType
from .exceptions import UnsupportedDeviceException
from .deviceconfig import DeviceConfig
from .exceptions import SmartDeviceException, UnsupportedDeviceException
from .iotprotocol import IotProtocol
from .klaptransport import KlapTransport, TPlinkKlapTransportV2
from .protocol import BaseTransport, TPLinkProtocol
from .klaptransport import KlapTransport, KlapTransportV2
from .protocol import (
BaseTransport,
TPLinkProtocol,
TPLinkSmartHomeProtocol,
_XorTransport,
)
from .smartbulb import SmartBulb
from .smartdevice import SmartDevice, SmartDeviceException
from .smartdevice import SmartDevice
from .smartdimmer import SmartDimmer
from .smartlightstrip import SmartLightStrip
from .smartplug import SmartPlug
@ -20,104 +23,80 @@ from .smartprotocol import SmartProtocol
from .smartstrip import SmartStrip
from .tapo import TapoBulb, TapoPlug
DEVICE_TYPE_TO_CLASS = {
DeviceType.Plug: SmartPlug,
DeviceType.Bulb: SmartBulb,
DeviceType.Strip: SmartStrip,
DeviceType.Dimmer: SmartDimmer,
DeviceType.LightStrip: SmartLightStrip,
DeviceType.TapoPlug: TapoPlug,
DeviceType.TapoBulb: TapoBulb,
}
_LOGGER = logging.getLogger(__name__)
GET_SYSINFO_QUERY = {
"system": {"get_sysinfo": None},
}
async def connect(
host: str,
*,
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
protocol_class: Optional[Type[TPLinkProtocol]] = None,
) -> "SmartDevice":
"""Connect to a single device by the given IP address.
async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "SmartDevice":
"""Connect to a single device by the given hostname or device configuration.
This method avoids the UDP based discovery process and
will connect directly to the device to query its type.
will connect directly to the device.
It is generally preferred to avoid :func:`discover_single()` and
use this function instead as it should perform better when
the WiFi network is congested or the device is not responding
to discovery requests.
The device type is discovered by querying the device.
Do not use this function directly, use SmartDevice.connect()
:param host: Hostname of device to query
:param device_type: Device type to use for the device.
If not given, the device type is discovered by querying the device.
If the device type is already known, it is preferred to pass it
to avoid the extra query to the device to discover its type.
:param protocol_class: Optionally provide the protocol class
to use.
:param config: Connection parameters to ensure the correct protocol
and connection options are used.
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if host and config or (not host and not config):
raise SmartDeviceException("One of host or config must be provded and not both")
if host:
config = DeviceConfig(host=host)
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_enabled:
start_time = time.perf_counter()
if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)):
dev: SmartDevice = klass(
host=host, port=port, credentials=credentials, timeout=timeout
)
if protocol_class is not None:
dev.protocol = protocol_class(
host,
transport=AesTransport(
host, port=port, credentials=credentials, timeout=timeout
),
)
await dev.update()
def _perf_log(has_params, perf_type):
nonlocal start_time
if debug_enabled:
end_time = time.perf_counter()
_LOGGER.debug(
"Device %s with known type (%s) took %.2f seconds to connect",
host,
device_type.value,
end_time - start_time,
f"Device {config.host} with connection params {has_params} "
+ f"took {end_time - start_time:.2f} seconds to {perf_type}",
)
return dev
start_time = time.perf_counter()
unknown_dev = SmartDevice(
host=host, port=port, credentials=credentials, timeout=timeout
)
if protocol_class is not None:
# TODO this will be replaced with connection params
unknown_dev.protocol = protocol_class(
host,
transport=AesTransport(
host, port=port, credentials=credentials, timeout=timeout
),
if (protocol := get_protocol(config=config)) is None:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
)
await unknown_dev.update()
device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
# Reuse the connection from the unknown device
# so we don't have to reconnect
dev.protocol = unknown_dev.protocol
await dev.update()
if debug_enabled:
end_time = time.perf_counter()
_LOGGER.debug(
"Device %s with unknown type (%s) took %.2f seconds to connect",
host,
dev.device_type.value,
end_time - start_time,
device_class: Optional[Type[SmartDevice]]
if isinstance(protocol, TPLinkSmartHomeProtocol):
info = await protocol.query(GET_SYSINFO_QUERY)
_perf_log(True, "get_sysinfo")
device_class = get_device_class_from_sys_info(info)
device = device_class(config.host, protocol=protocol)
device.update_from_discover_info(info)
await device.update()
_perf_log(True, "update")
return device
elif device_class := get_device_class_from_family(
config.connection_type.device_family.value
):
device = device_class(host=config.host, protocol=protocol)
await device.update()
_perf_log(True, "update")
return device
else:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
)
return dev
def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
@ -147,32 +126,38 @@ def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
raise UnsupportedDeviceException("Unknown device type: %s" % type_)
def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]:
def get_device_class_from_family(device_type: str) -> Optional[Type[SmartDevice]]:
"""Return the device class from the type name."""
supported_device_types: dict[str, Type[SmartDevice]] = {
"SMART.TAPOPLUG": TapoPlug,
"SMART.TAPOBULB": TapoBulb,
"SMART.KASAPLUG": TapoPlug,
"IOT.SMARTPLUGSWITCH": SmartPlug,
"IOT.SMARTBULB": SmartBulb,
}
return supported_device_types.get(device_type)
def get_protocol_from_connection_name(
connection_name: str, host: str, credentials: Optional[Credentials] = None
def get_protocol(
config: DeviceConfig,
) -> Optional[TPLinkProtocol]:
"""Return the protocol from the connection name."""
protocol_name = config.connection_type.device_family.value.split(".")[0]
protocol_transport_key = (
protocol_name + "." + config.connection_type.encryption_type.value
)
supported_device_protocols: dict[
str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]]
] = {
"IOT.XOR": (TPLinkSmartHomeProtocol, _XorTransport),
"IOT.KLAP": (IotProtocol, KlapTransport),
"SMART.AES": (SmartProtocol, AesTransport),
"SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2),
"SMART.KLAP": (SmartProtocol, KlapTransportV2),
}
if connection_name not in supported_device_protocols:
if protocol_transport_key not in supported_device_protocols:
return None
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
transport: BaseTransport = transport_class(host, credentials=credentials)
protocol: TPLinkProtocol = protocol_class(host, transport=transport)
return protocol
protocol_class, transport_class = supported_device_protocols.get(
protocol_transport_key
) # type: ignore
return protocol_class(transport=transport_class(config=config))

148
kasa/deviceconfig.py Normal file
View File

@ -0,0 +1,148 @@
"""Module for holding connection parameters."""
import logging
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from enum import Enum
from typing import Dict, Optional
import httpx
from .credentials import Credentials
from .exceptions import SmartDeviceException
_LOGGER = logging.getLogger(__name__)
class EncryptType(Enum):
"""Encrypt type enum."""
Klap = "KLAP"
Aes = "AES"
Xor = "XOR"
class DeviceFamilyType(Enum):
"""Encrypt type enum."""
IotSmartPlugSwitch = "IOT.SMARTPLUGSWITCH"
IotSmartBulb = "IOT.SMARTBULB"
SmartKasaPlug = "SMART.KASAPLUG"
SmartTapoPlug = "SMART.TAPOPLUG"
SmartTapoBulb = "SMART.TAPOBULB"
def _dataclass_from_dict(klass, in_val):
if is_dataclass(klass):
fieldtypes = {f.name: f.type for f in fields(klass)}
val = {}
for dict_key in in_val:
if dict_key in fieldtypes and hasattr(fieldtypes[dict_key], "from_dict"):
val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key])
else:
val[dict_key] = _dataclass_from_dict(
fieldtypes[dict_key], in_val[dict_key]
)
return klass(**val)
else:
return in_val
def _dataclass_to_dict(in_val):
fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare}
out_val = {}
for field_name in fieldtypes:
val = getattr(in_val, field_name)
if val is None:
continue
elif hasattr(val, "to_dict"):
out_val[field_name] = val.to_dict()
elif is_dataclass(fieldtypes[field_name]):
out_val[field_name] = asdict(val)
else:
out_val[field_name] = val
return out_val
@dataclass
class ConnectionType:
"""Class to hold the the parameters determining connection type."""
device_family: DeviceFamilyType
encryption_type: EncryptType
@staticmethod
def from_values(
device_family: str,
encryption_type: str,
) -> "ConnectionType":
"""Return connection parameters from string values."""
try:
return ConnectionType(
DeviceFamilyType(device_family),
EncryptType(encryption_type),
)
except ValueError as ex:
raise SmartDeviceException(
f"Invalid connection parameters for {device_family}.{encryption_type}"
) from ex
@staticmethod
def from_dict(connection_type_dict: Dict[str, str]) -> "ConnectionType":
"""Return connection parameters from dict."""
if (
isinstance(connection_type_dict, dict)
and (device_family := connection_type_dict.get("device_family"))
and (encryption_type := connection_type_dict.get("encryption_type"))
):
return ConnectionType.from_values(device_family, encryption_type)
raise SmartDeviceException(
f"Invalid connection type data for {connection_type_dict}"
)
def to_dict(self) -> Dict[str, str]:
"""Convert connection params to dict."""
result = {
"device_family": self.device_family.value,
"encryption_type": self.encryption_type.value,
}
return result
@dataclass
class DeviceConfig:
"""Class to represent paramaters that determine how to connect to devices."""
DEFAULT_TIMEOUT = 5
host: str
timeout: Optional[int] = DEFAULT_TIMEOUT
port_override: Optional[int] = None
credentials: Credentials = field(
default_factory=lambda: Credentials(username="", password="")
)
connection_type: ConnectionType = field(
default_factory=lambda: ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
)
)
uses_http: bool = False
# compare=False will be excluded from the serialization and object comparison.
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
def __post_init__(self):
if self.credentials is None:
self.credentials = Credentials(username="", password="")
if self.connection_type is None:
self.connection_type = ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
)
def to_dict(self) -> Dict[str, Dict[str, str]]:
"""Convert connection params to dict."""
return _dataclass_to_dict(self)
@staticmethod
def from_dict(cparam_dict: Dict[str, Dict[str, str]]) -> "DeviceConfig":
"""Return connection parameters from dict."""
return _dataclass_from_dict(DeviceConfig, cparam_dict)

View File

@ -1,5 +1,6 @@
"""Discovery module for TP-Link Smart Home devices."""
import asyncio
import base64
import binascii
import ipaddress
import logging
@ -11,29 +12,32 @@ from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast
from async_timeout import timeout as asyncio_timeout
try:
from pydantic.v1 import BaseModel, Field
from pydantic.v1 import BaseModel, ValidationError # pragma: no cover
except ImportError:
from pydantic import BaseModel, Field
from pydantic import BaseModel, ValidationError # pragma: no cover
from kasa.credentials import Credentials
from kasa.device_factory import (
get_device_class_from_family,
get_device_class_from_sys_info,
get_protocol,
)
from kasa.deviceconfig import ConnectionType, DeviceConfig, EncryptType
from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartdevice import SmartDevice, SmartDeviceException
from .device_factory import (
get_device_class_from_sys_info,
get_device_class_from_type_name,
get_protocol_from_connection_name,
)
_LOGGER = logging.getLogger(__name__)
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
DeviceDict = Dict[str, SmartDevice]
UNAVAILABLE_ALIAS = "Authentication required"
UNAVAILABLE_NICKNAME = base64.b64encode(UNAVAILABLE_ALIAS.encode()).decode()
class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler.
@ -62,9 +66,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.discovery_packets = discovery_packets
self.interface = interface
self.on_discovered = on_discovered
self.port = port
self.discovery_port = port or Discover.DISCOVERY_PORT
self.target = (target, self.discovery_port)
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
self.discovered_devices = {}
self.unsupported_device_exceptions: Dict = {}
self.invalid_device_exceptions: Dict = {}
@ -110,13 +117,18 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.seen_hosts.add(ip)
device = None
config = DeviceConfig(host=ip, port_override=self.port)
if self.credentials:
config.credentials = self.credentials
if self.timeout:
config.timeout = self.timeout
try:
if port == self.discovery_port:
device = Discover._get_device_instance_legacy(data, ip, port)
device = Discover._get_device_instance_legacy(data, config)
elif port == Discover.DISCOVERY_PORT_2:
device = Discover._get_device_instance(
data, ip, port, self.credentials or Credentials()
)
config.uses_http = True
device = Discover._get_device_instance(data, config)
else:
return
except UnsupportedDeviceException as udex:
@ -200,11 +212,13 @@ class Discover:
*,
target="255.255.255.255",
on_discovered=None,
timeout=5,
discovery_timeout=5,
discovery_packets=3,
interface=None,
on_unsupported=None,
credentials=None,
port=None,
timeout=None,
) -> DeviceDict:
"""Discover supported devices.
@ -240,14 +254,15 @@ class Discover:
on_unsupported=on_unsupported,
credentials=credentials,
timeout=timeout,
port=port,
),
local_addr=("0.0.0.0", 0), # noqa: S104
)
protocol = cast(_DiscoverProtocol, protocol)
try:
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
await asyncio.sleep(timeout)
_LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout)
await asyncio.sleep(discovery_timeout)
finally:
transport.close()
@ -259,10 +274,10 @@ class Discover:
async def discover_single(
host: str,
*,
discovery_timeout: int = 5,
port: Optional[int] = None,
timeout=5,
timeout: Optional[int] = None,
credentials: Optional[Credentials] = None,
update_parent_devices: bool = True,
) -> SmartDevice:
"""Discover a single device by the given IP address.
@ -275,8 +290,6 @@ class Discover:
:param port: Optionally set a different port for the device
:param timeout: Timeout for discovery
:param credentials: Credentials for devices that require authentication
:param update_parent_devices: Automatically call device.update() on
devices that have children
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
@ -320,9 +333,11 @@ class Discover:
protocol = cast(_DiscoverProtocol, protocol)
try:
_LOGGER.debug("Waiting a total of %s seconds for responses...", timeout)
_LOGGER.debug(
"Waiting a total of %s seconds for responses...", discovery_timeout
)
async with asyncio_timeout(timeout):
async with asyncio_timeout(discovery_timeout):
await event.wait()
except asyncio.TimeoutError as ex:
raise SmartDeviceException(
@ -334,9 +349,6 @@ class Discover:
if ip in protocol.discovered_devices:
dev = protocol.discovered_devices[ip]
dev.host = host
# Call device update on devices that have children
if update_parent_devices and dev.has_children:
await dev.update()
return dev
elif ip in protocol.unsupported_device_exceptions:
raise protocol.unsupported_device_exceptions[ip]
@ -350,99 +362,121 @@ class Discover:
"""Find SmartDevice subclass for device described by passed data."""
if "result" in info:
discovery_result = DiscoveryResult(**info["result"])
dev_class = get_device_class_from_type_name(discovery_result.device_type)
dev_class = get_device_class_from_family(discovery_result.device_type)
if not dev_class:
raise UnsupportedDeviceException(
"Unknown device type: %s" % discovery_result.device_type
"Unknown device type: %s" % discovery_result.device_type,
discovery_result=info,
)
return dev_class
else:
return get_device_class_from_sys_info(info)
@staticmethod
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:
def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> SmartDevice:
"""Get SmartDevice from legacy 9999 response."""
try:
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
except Exception as ex:
raise SmartDeviceException(
f"Unable to read response from device: {ip}: {ex}"
f"Unable to read response from device: {config.host}: {ex}"
) from ex
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, info)
device_class = Discover._get_device_class(info)
device = device_class(ip, port=port)
device = device_class(config.host, config=config)
sys_info = info["system"]["get_sysinfo"]
if device_type := sys_info.get("mic_type", sys_info.get("type")):
config.connection_type = ConnectionType.from_values(
device_family=device_type, encryption_type=EncryptType.Xor.value
)
device.protocol = get_protocol(config) # type: ignore[assignment]
device.update_from_discover_info(info)
return device
@staticmethod
def _get_device_instance(
data: bytes, ip: str, port: int, credentials: Credentials
data: bytes,
config: DeviceConfig,
) -> SmartDevice:
"""Get SmartDevice from the new 20002 response."""
try:
info = json_loads(data[16:])
discovery_result = DiscoveryResult(**info["result"])
except Exception as ex:
_LOGGER.debug("Got invalid response from device %s: %s", config.host, data)
raise SmartDeviceException(
f"Unable to read response from device: {config.host}: {ex}"
) from ex
try:
discovery_result = DiscoveryResult(**info["result"])
except ValidationError as ex:
_LOGGER.debug(
"Unable to parse discovery from device %s: %s", config.host, info
)
raise UnsupportedDeviceException(
f"Unable to read response from device: {ip}: {ex}"
f"Unable to parse discovery from device: {config.host}: {ex}"
) from ex
type_ = discovery_result.device_type
encrypt_type_ = (
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
)
if (device_class := get_device_class_from_type_name(type_)) is None:
try:
config.connection_type = ConnectionType.from_values(
type_, discovery_result.mgt_encrypt_schm.encrypt_type
)
except SmartDeviceException as ex:
raise UnsupportedDeviceException(
f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
discovery_result=discovery_result.get_dict(),
) from ex
if (device_class := get_device_class_from_family(type_)) is None:
_LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceException(
f"Unsupported device {ip} of type {type_}: {info}",
f"Unsupported device {config.host} of type {type_}: {info}",
discovery_result=discovery_result.get_dict(),
)
if (
protocol := get_protocol_from_connection_name(
encrypt_type_, ip, credentials=credentials
if (protocol := get_protocol(config)) is None:
_LOGGER.warning(
"Got unsupported connection type: %s", config.connection_type.to_dict()
)
) is None:
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
raise UnsupportedDeviceException(
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}",
f"Unsupported encryption scheme {config.host} of "
+ f"type {config.connection_type.to_dict()}: {info}",
discovery_result=discovery_result.get_dict(),
)
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
device = device_class(ip, port=port, credentials=credentials)
device.protocol = protocol
device.update_from_discover_info(discovery_result.get_dict())
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, info)
device = device_class(config.host, protocol=protocol)
di = discovery_result.get_dict()
di["model"] = discovery_result.device_model
di["alias"] = UNAVAILABLE_ALIAS
di["nickname"] = UNAVAILABLE_NICKNAME
device.update_from_discover_info(di)
return device
class DiscoveryResult(BaseModel):
"""Base model for discovery result."""
class Config:
"""Class for configuring model behaviour."""
allow_population_by_field_name = True
class EncryptionScheme(BaseModel):
"""Base model for encryption scheme of discovery result."""
is_support_https: Optional[bool] = None
encrypt_type: Optional[str] = None
http_port: Optional[int] = None
lv: Optional[int] = 1
is_support_https: bool
encrypt_type: str
http_port: int
lv: Optional[int] = None
device_type: str = Field(alias="device_type_text")
device_model: str = Field(alias="model")
ip: str = Field(alias="alias")
device_type: str
device_model: str
ip: str
mac: str
mgt_encrypt_schm: EncryptionScheme
device_id: str
device_id: Optional[str] = Field(default=None, alias="device_id_hash")
owner: Optional[str] = Field(default=None, alias="device_owner_hash")
hw_ver: Optional[str] = None
owner: Optional[str] = None
is_support_iot_cloud: Optional[bool] = None
obd_src: Optional[str] = None
factory_default: Optional[bool] = None
@ -453,5 +487,5 @@ class DiscoveryResult(BaseModel):
containing only the values actually set and with aliases as field names.
"""
return self.dict(
by_alias=True, exclude_unset=True, exclude_none=True, exclude_defaults=True
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)

View File

@ -17,12 +17,11 @@ class IotProtocol(TPLinkProtocol):
def __init__(
self,
host: str,
*,
transport: BaseTransport,
) -> None:
"""Create a protocol object."""
super().__init__(host, transport=transport)
super().__init__(transport=transport)
self._query_lock = asyncio.Lock()
@ -39,25 +38,21 @@ class IotProtocol(TPLinkProtocol):
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
except httpx.ConnectError as sdex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self._host}: {tex}"
) from tex
except AuthenticationException as auex:
await self.close()
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
@ -70,8 +65,8 @@ class IotProtocol(TPLinkProtocol):
)
raise ex
except Exception as ex:
await self.close()
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {ex}"

View File

@ -54,6 +54,7 @@ from cryptography.hazmat.primitives import hashes, padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import AuthenticationException, SmartDeviceException
from .json import loads as json_loads
from .protocol import BaseTransport, md5
@ -82,27 +83,21 @@ class KlapTransport(BaseTransport):
protocol, used by newer firmware versions.
"""
DEFAULT_PORT = 80
DEFAULT_PORT: int = 80
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
KASA_SETUP_EMAIL = "kasa@tp-link.net"
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
SESSION_COOKIE_NAME = "TP_SESSIONID"
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: DeviceConfig,
) -> None:
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
super().__init__(config=config)
self._default_http_client: Optional[httpx.AsyncClient] = None
self._local_seed: Optional[bytes] = None
self._local_auth_hash = self.generate_auth_hash(self._credentials)
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
@ -116,14 +111,24 @@ class KlapTransport(BaseTransport):
self._session_expire_at: Optional[float] = None
self._session_cookie = None
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
_LOGGER.debug("Created KLAP transport for %s", self._host)
@property
def default_port(self):
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
return self._config.http_client
if not self._default_http_client:
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
async def client_post(self, url, params=None, data=None):
"""Send an http post request to the device."""
if not self._http_client:
self._http_client = httpx.AsyncClient()
response_data = None
cookies = None
if self._session_cookie:
@ -355,8 +360,8 @@ class KlapTransport(BaseTransport):
async def close(self) -> None:
"""Close the transport."""
client = self._http_client
self._http_client = None
client = self._default_http_client
self._default_http_client = None
self._handshake_done = False
if client:
await client.aclose()
@ -390,7 +395,7 @@ class KlapTransport(BaseTransport):
return md5(un.encode())
class TPlinkKlapTransportV2(KlapTransport):
class KlapTransportV2(KlapTransport):
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
@staticmethod

View File

@ -24,7 +24,7 @@ from typing import Dict, Generator, Optional, Union
from async_timeout import timeout as asyncio_timeout
from cryptography.hazmat.primitives import hashes
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import SmartDeviceException
from .json import dumps as json_dumps
from .json import loads as json_loads
@ -48,17 +48,20 @@ class BaseTransport(ABC):
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: DeviceConfig,
) -> None:
"""Create a protocol object."""
self._host = host
self._port = port
self._credentials = credentials or Credentials(username="", password="")
self._timeout = timeout or self.DEFAULT_TIMEOUT
self._config = config
self._host = config.host
self._port = config.port_override or self.default_port
self._credentials = config.credentials
self._timeout = config.timeout
@property
@abstractmethod
def default_port(self) -> int:
"""The default port for the transport."""
@abstractmethod
async def send(self, request: str) -> Dict:
@ -74,7 +77,6 @@ class TPLinkProtocol(ABC):
def __init__(
self,
host: str,
*,
transport: BaseTransport,
) -> None:
@ -85,6 +87,11 @@ class TPLinkProtocol(ABC):
def _host(self):
return self._transport._host
@property
def config(self) -> DeviceConfig:
"""Return the connection parameters the device is using."""
return self._transport._config
@abstractmethod
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device for the protocol. Abstract method to be overriden."""
@ -103,22 +110,15 @@ class _XorTransport(BaseTransport):
class.
"""
DEFAULT_PORT = 9999
DEFAULT_PORT: int = 9999
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
def __init__(self, *, config: DeviceConfig) -> None:
super().__init__(config=config)
@property
def default_port(self):
"""Default port for the transport."""
return self.DEFAULT_PORT
async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response."""
@ -133,17 +133,15 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
INITIALIZATION_VECTOR = 171
DEFAULT_PORT = 9999
DEFAULT_TIMEOUT = 5
BLOCK_SIZE = 4
def __init__(
self,
host: str,
*,
transport: BaseTransport,
) -> None:
"""Create a protocol object."""
super().__init__(host, transport=transport)
super().__init__(transport=transport)
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
@ -167,7 +165,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
assert isinstance(request, str) # noqa: S101
async with self.query_lock:
return await self._query(request, retry_count, self._timeout)
return await self._query(request, retry_count, self._timeout) # type: ignore[arg-type]
async def _connect(self, timeout: int) -> None:
"""Try to connect or reconnect to the device."""

View File

@ -9,8 +9,9 @@ try:
except ImportError:
from pydantic import BaseModel, Field, root_validator
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .modules import Antitheft, Cloud, Countdown, Emeter, Schedule, Time, Usage
from .protocol import TPLinkProtocol
from .smartdevice import DeviceType, SmartDevice, SmartDeviceException, requires_update
@ -220,11 +221,10 @@ class SmartBulb(SmartDevice):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: Optional[DeviceConfig] = None,
protocol: Optional[TPLinkProtocol] = None,
) -> None:
super().__init__(host=host, port=port, credentials=credentials, timeout=timeout)
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Bulb
self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule"))
self.add_module("usage", Usage(self, "smartlife.iot.common.schedule"))

View File

@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Set
from .credentials import Credentials
from .device_type import DeviceType
from .deviceconfig import DeviceConfig
from .emeterstatus import EmeterStatus
from .exceptions import SmartDeviceException
from .modules import Emeter, Module
@ -191,20 +192,18 @@ class SmartDevice:
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: Optional[DeviceConfig] = None,
protocol: Optional[TPLinkProtocol] = None,
) -> None:
"""Create a new SmartDevice instance.
:param str host: host name or ip address on which the device listens
"""
self.host = host
self.port = port
self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol(
host, transport=_XorTransport(host, port=port, timeout=timeout)
if config and protocol:
protocol._transport._config = config
self.protocol: TPLinkProtocol = protocol or TPLinkSmartHomeProtocol(
transport=_XorTransport(config=config or DeviceConfig(host=host)),
)
self.credentials = credentials
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
self._device_type = DeviceType.Unknown
# TODO: typing Any is just as using Optional[Dict] would require separate
@ -219,6 +218,30 @@ class SmartDevice:
self.children: List["SmartDevice"] = []
@property
def host(self) -> str:
"""The device host."""
return self.protocol._transport._host
@host.setter
def host(self, value):
"""Set the device host.
Generally used by discovery to set the hostname after ip discovery.
"""
self.protocol._transport._host = value
self.protocol._transport._config.host = value
@property
def port(self) -> int:
"""The device port."""
return self.protocol._transport._port
@property
def credentials(self) -> Optional[Credentials]:
"""The device credentials."""
return self.protocol._transport._credentials
def add_module(self, name: str, module: Module):
"""Register a module."""
if name in self.modules:
@ -760,7 +783,7 @@ class SmartDevice:
The returned object contains the raw results from the last update call.
This should only be used for debugging purposes.
"""
return self._last_update
return self._last_update or self._discovery_info
def __repr__(self):
if self._last_update is None:
@ -771,41 +794,33 @@ class SmartDevice:
f" - dev specific: {self.state_information}>"
)
@property
def config(self) -> DeviceConfig:
"""Return the connection parameters the device is using."""
return self.protocol.config
@staticmethod
async def connect(
host: str,
*,
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
host: Optional[str] = None,
config: Optional[DeviceConfig] = None,
) -> "SmartDevice":
"""Connect to a single device by the given IP address.
"""Connect to a single device by the given hostname or device configuration.
This method avoids the UDP based discovery process and
will connect directly to the device to query its type.
will connect directly to the device.
It is generally preferred to avoid :func:`discover_single()` and
use this function instead as it should perform better when
the WiFi network is congested or the device is not responding
to discovery requests.
The device type is discovered by querying the device.
:param host: Hostname of device to query
:param device_type: Device type to use for the device.
If not given, the device type is discovered by querying the device.
If the device type is already known, it is preferred to pass it
to avoid the extra query to the device to discover its type.
:param config: Connection parameters to ensure the correct protocol
and connection options are used.
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
from .device_factory import connect # pylint: disable=import-outside-toplevel
return await connect(
host=host,
port=port,
timeout=timeout,
credentials=credentials,
device_type=device_type,
)
return await connect(host=host, config=config) # type: ignore[arg-type]

View File

@ -2,8 +2,9 @@
from enum import Enum
from typing import Any, Dict, Optional
from kasa.credentials import Credentials
from kasa.deviceconfig import DeviceConfig
from kasa.modules import AmbientLight, Motion
from kasa.protocol import TPLinkProtocol
from kasa.smartdevice import DeviceType, SmartDeviceException, requires_update
from kasa.smartplug import SmartPlug
@ -68,11 +69,10 @@ class SmartDimmer(SmartPlug):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: Optional[DeviceConfig] = None,
protocol: Optional[TPLinkProtocol] = None,
) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Dimmer
# TODO: need to be verified if it's okay to call these on HS220 w/o these
# TODO: need to be figured out what's the best approach to detect support

View File

@ -1,8 +1,9 @@
"""Module for light strips (KL430)."""
from typing import Any, Dict, List, Optional
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .effects import EFFECT_MAPPING_V1, EFFECT_NAMES_V1
from .protocol import TPLinkProtocol
from .smartbulb import SmartBulb
from .smartdevice import DeviceType, SmartDeviceException, requires_update
@ -46,11 +47,10 @@ class SmartLightStrip(SmartBulb):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: Optional[DeviceConfig] = None,
protocol: Optional[TPLinkProtocol] = None,
) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.LightStrip
@property # type: ignore

View File

@ -2,8 +2,9 @@
import logging
from typing import Any, Dict, Optional
from kasa.credentials import Credentials
from kasa.deviceconfig import DeviceConfig
from kasa.modules import Antitheft, Cloud, Schedule, Time, Usage
from kasa.protocol import TPLinkProtocol
from kasa.smartdevice import DeviceType, SmartDevice, requires_update
_LOGGER = logging.getLogger(__name__)
@ -43,11 +44,10 @@ class SmartPlug(SmartDevice):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: Optional[DeviceConfig] = None,
protocol: Optional[TPLinkProtocol] = None,
) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Plug
self.add_module("schedule", Schedule(self, "schedule"))
self.add_module("usage", Usage(self, "schedule"))

View File

@ -38,12 +38,11 @@ class SmartProtocol(TPLinkProtocol):
def __init__(
self,
host: str,
*,
transport: BaseTransport,
) -> None:
"""Create a protocol object."""
super().__init__(host, transport=transport)
super().__init__(transport=transport)
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
self._request_id_generator = SnowflakeId(1, 1)
self._query_lock = asyncio.Lock()
@ -68,19 +67,14 @@ class SmartProtocol(TPLinkProtocol):
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
except httpx.ConnectError as sdex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {cex}"
) from cex
except TimeoutError as tex:
if retry >= retry_count:
await self.close()

View File

@ -14,8 +14,9 @@ from kasa.smartdevice import (
)
from kasa.smartplug import SmartPlug
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .modules import Antitheft, Countdown, Emeter, Schedule, Time, Usage
from .protocol import TPLinkProtocol
_LOGGER = logging.getLogger(__name__)
@ -85,11 +86,10 @@ class SmartStrip(SmartDevice):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: Optional[DeviceConfig] = None,
protocol: Optional[TPLinkProtocol] = None,
) -> None:
super().__init__(host=host, port=port, credentials=credentials, timeout=timeout)
super().__init__(host=host, config=config, protocol=protocol)
self.emeter_type = "emeter"
self._device_type = DeviceType.Strip
self.add_module("antitheft", Antitheft(self, "anti_theft"))

View File

@ -5,8 +5,9 @@ from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Set, cast
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import AuthenticationException
from ..protocol import TPLinkProtocol
from ..smartdevice import SmartDevice
from ..smartprotocol import SmartProtocol
@ -20,20 +21,16 @@ class TapoDevice(SmartDevice):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: Optional[DeviceConfig] = None,
protocol: Optional[TPLinkProtocol] = None,
) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
_protocol = protocol or SmartProtocol(
transport=AesTransport(config=config or DeviceConfig(host=host)),
)
super().__init__(host=host, config=config, protocol=_protocol)
self._components: Optional[Dict[str, Any]] = None
self._state_information: Dict[str, Any] = {}
self._discovery_info: Optional[Dict[str, Any]] = None
self.protocol = SmartProtocol(
host,
transport=AesTransport(
host, credentials=credentials, timeout=timeout, port=port
),
)
async def update(self, update_children: bool = True):
"""Update the device."""
@ -66,7 +63,7 @@ class TapoDevice(SmartDevice):
@property
def sys_info(self) -> Dict[str, Any]:
"""Returns the device info."""
return self._info
return self._info # type: ignore
@property
def model(self) -> str:
@ -180,3 +177,4 @@ class TapoDevice(SmartDevice):
def update_from_discover_info(self, info):
"""Update state from info from the discover call."""
self._discovery_info = info
self._info = info

View File

@ -3,9 +3,10 @@ import logging
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, cast
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..emeterstatus import EmeterStatus
from ..modules import Emeter
from ..protocol import TPLinkProtocol
from ..smartdevice import DeviceType, requires_update
from .tapodevice import TapoDevice
@ -19,11 +20,10 @@ class TapoPlug(TapoDevice):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
config: Optional[DeviceConfig] = None,
protocol: Optional[TPLinkProtocol] = None,
) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Plug
self.modules: Dict[str, Any] = {}
self.emeter_type = "emeter"

View File

@ -388,7 +388,6 @@ async def get_device_for_file(file, protocol):
d = device_for_file(model, protocol)(host="127.0.0.123")
if protocol == "SMART":
d.protocol = FakeSmartProtocol(sysinfo)
d.credentials = Credentials("", "")
else:
d.protocol = FakeTransportProtocol(sysinfo)
await _update_and_close(d)
@ -426,28 +425,53 @@ def discovery_mock(all_fixture_data, mocker):
class _DiscoveryMock:
ip: str
default_port: int
discovery_port: int
discovery_data: dict
query_data: dict
device_type: str
encrypt_type: str
port_override: Optional[int] = None
if "discovery_result" in all_fixture_data:
discovery_data = {"result": all_fixture_data["discovery_result"]}
device_type = all_fixture_data["discovery_result"]["device_type"]
encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][
"encrypt_type"
]
datagram = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data, all_fixture_data)
dm = _DiscoveryMock(
"127.0.0.123",
80,
20002,
discovery_data,
all_fixture_data,
device_type,
encrypt_type,
)
else:
sys_info = all_fixture_data["system"]["get_sysinfo"]
discovery_data = {"system": {"get_sysinfo": sys_info}}
device_type = sys_info.get("mic_type") or sys_info.get("type")
encrypt_type = "XOR"
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data, all_fixture_data)
dm = _DiscoveryMock(
"127.0.0.123",
9999,
9999,
discovery_data,
all_fixture_data,
device_type,
encrypt_type,
)
def mock_discover(self):
port = (
dm.port_override
if dm.port_override and dm.default_port != 20002
else dm.default_port
if dm.port_override and dm.discovery_port != 20002
else dm.discovery_port
)
self.datagram_received(
datagram,

View File

@ -15,7 +15,9 @@ from voluptuous import (
Schema,
)
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol, _XorTransport
from ..smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__)
@ -290,7 +292,9 @@ TIME_MODULE = {
class FakeSmartProtocol(SmartProtocol):
def __init__(self, info):
super().__init__("127.0.0.123", transport=FakeSmartTransport(info))
super().__init__(
transport=FakeSmartTransport(info),
)
async def query(self, request, retry_count: int = 3):
"""Implement query here so can still patch SmartProtocol.query."""
@ -301,10 +305,15 @@ class FakeSmartProtocol(SmartProtocol):
class FakeSmartTransport(BaseTransport):
def __init__(self, info):
super().__init__(
"127.0.0.123",
config=DeviceConfig("127.0.0.123", credentials=Credentials()),
)
self.info = info
@property
def default_port(self):
"""Default port for the transport."""
return 80
async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]
@ -344,6 +353,11 @@ class FakeSmartTransport(BaseTransport):
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info):
super().__init__(
transport=_XorTransport(
config=DeviceConfig("127.0.0.123"),
)
)
self.discovery_data = info
self.writer = None
self.reader = None

View File

@ -12,6 +12,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
from ..aestransport import AesEncyptionSession, AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
@ -58,7 +59,9 @@ async def test_handshake(
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
assert transport._encryption_session is None
assert transport._handshake_done is False
@ -74,7 +77,9 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
@ -91,13 +96,14 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token
un, pw = transport.hash_credentials(True)
request = {
"method": "get_device_info",
"params": None,
@ -119,7 +125,8 @@ async def test_passthrough_errors(mocker, error_code):
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
transport = AesTransport(config=config)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session

View File

@ -4,10 +4,26 @@ import asyncclick as click
import pytest
from asyncclick.testing import CliRunner
from kasa import AuthenticationException, SmartDevice, UnsupportedDeviceException
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
from kasa.discover import Discover
from kasa import (
AuthenticationException,
Credentials,
SmartDevice,
TPLinkSmartHomeProtocol,
UnsupportedDeviceException,
)
from kasa.cli import (
TYPE_TO_CLASS,
alias,
brightness,
cli,
emeter,
raw_command,
state,
sysinfo,
toggle,
)
from kasa.discover import Discover, DiscoveryResult
from kasa.smartprotocol import SmartProtocol
from .conftest import device_iot, handle_turn_on, new_discovery, turn_on
@ -145,9 +161,11 @@ async def test_credentials(discovery_mock, mocker):
)
mocker.patch("kasa.cli.state", new=_state)
for subclass in DEVICE_TYPE_TO_CLASS.values():
mocker.patch.object(subclass, "update")
mocker.patch("kasa.IotProtocol.query", return_value=discovery_mock.query_data)
mocker.patch("kasa.SmartProtocol.query", return_value=discovery_mock.query_data)
dr = DiscoveryResult(**discovery_mock.discovery_data["result"])
runner = CliRunner()
res = await runner.invoke(
cli,
@ -158,6 +176,10 @@ async def test_credentials(discovery_mock, mocker):
"foo",
"--password",
"bar",
"--device-family",
dr.device_type,
"--encrypt-type",
dr.mgt_encrypt_schm.encrypt_type,
],
)
assert res.exit_code == 0
@ -166,7 +188,7 @@ async def test_credentials(discovery_mock, mocker):
@device_iot
async def test_without_device_type(discovery_data: dict, dev, mocker):
async def test_without_device_type(dev, mocker):
"""Test connecting without the device type."""
runner = CliRunner()
mocker.patch("kasa.discover.Discover.discover_single", return_value=dev)
@ -342,3 +364,27 @@ async def test_host_auth_failed(discovery_mock, mocker):
assert res.exit_code != 0
assert isinstance(res.exception, AuthenticationException)
@pytest.mark.parametrize("device_type", list(TYPE_TO_CLASS))
async def test_type_param(device_type, mocker):
"""Test for handling only one of username or password supplied."""
runner = CliRunner()
result_device = FileNotFoundError
pass_dev = click.make_pass_decorator(SmartDevice)
@pass_dev
async def _state(dev: SmartDevice):
nonlocal result_device
result_device = dev
mocker.patch("kasa.cli.state", new=_state)
expected_type = TYPE_TO_CLASS[device_type]
mocker.patch.object(expected_type, "update")
res = await runner.invoke(
cli,
["--type", device_type, "--host", "127.0.0.1"],
)
assert res.exit_code == 0
assert isinstance(result_device, expected_type)

View File

@ -2,6 +2,7 @@
import logging
from typing import Type
import httpx
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
@ -15,122 +16,138 @@ from kasa import (
SmartLightStrip,
SmartPlug,
)
from kasa.device_factory import (
DEVICE_TYPE_TO_CLASS,
connect,
get_protocol_from_connection_name,
from kasa.device_factory import connect, get_protocol
from kasa.deviceconfig import (
ConnectionType,
DeviceConfig,
DeviceFamilyType,
EncryptType,
)
from kasa.discover import DiscoveryResult
from kasa.iotprotocol import IotProtocol
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect(discovery_data: dict, mocker, custom_port):
"""Make sure that connect returns an initialized SmartDevice instance."""
host = "127.0.0.1"
def _get_connection_type_device_class(the_fixture_data):
if "discovery_result" in the_fixture_data:
discovery_info = {"result": the_fixture_data["discovery_result"]}
device_class = Discover._get_device_class(discovery_info)
dr = DiscoveryResult(**discovery_info["result"])
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
dev = await connect(host, port=custom_port)
connection_type = ConnectionType.from_values(
dr.device_type, dr.mgt_encrypt_schm.encrypt_type
)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == 9999
connection_type = ConnectionType.from_values(
DeviceFamilyType.IotSmartPlugSwitch.value, EncryptType.Xor.value
)
device_class = Discover._get_device_class(the_fixture_data)
return connection_type, device_class
@pytest.mark.parametrize("custom_port", [123, None])
@pytest.mark.parametrize(
("device_type", "klass"),
(
(DeviceType.Plug, SmartPlug),
(DeviceType.Bulb, SmartBulb),
(DeviceType.Dimmer, SmartDimmer),
(DeviceType.LightStrip, SmartLightStrip),
(DeviceType.Unknown, SmartDevice),
),
)
async def test_connect_passed_device_type(
discovery_data: dict,
mocker,
device_type: DeviceType,
klass: Type[SmartDevice],
custom_port,
):
"""Make sure that connect with a passed device type."""
host = "127.0.0.1"
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
dev = await connect(host, port=custom_port)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port, device_type=device_type)
assert isinstance(dev, klass)
assert dev.port == custom_port or dev.port == 9999
async def test_connect_query_fails(discovery_data: dict, mocker):
"""Make sure that connect fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
with pytest.raises(SmartDeviceException):
await connect(host)
async def test_connect_logs_connect_time(
discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker
):
"""Test that the connect time is logged when debug logging is enabled."""
host = "127.0.0.1"
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
await connect(host)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
logging.getLogger("kasa").setLevel(logging.DEBUG)
await connect(host)
assert "seconds to connect" in caplog.text
async def test_connect_pass_protocol(
async def test_connect(
all_fixture_data: dict,
mocker,
):
"""Test that if the protocol is passed in it's gets set correctly."""
if "discovery_result" in all_fixture_data:
discovery_info = {"result": all_fixture_data["discovery_result"]}
device_class = Discover._get_device_class(discovery_info)
else:
device_class = Discover._get_device_class(all_fixture_data)
device_type = list(DEVICE_TYPE_TO_CLASS.keys())[
list(DEVICE_TYPE_TO_CLASS.values()).index(device_class)
]
"""Test that if the protocol is passed in it gets set correctly."""
host = "127.0.0.1"
if "discovery_result" in all_fixture_data:
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
ctype, device_class = _get_connection_type_device_class(all_fixture_data)
dr = DiscoveryResult(**discovery_info["result"])
connection_name = (
dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type
)
protocol_class = get_protocol_from_connection_name(
connection_name, host
).__class__
else:
mocker.patch(
"kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data
)
protocol_class = TPLinkSmartHomeProtocol
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
protocol_class = get_protocol(config).__class__
dev = await connect(
host,
device_type=device_type,
protocol_class=protocol_class,
credentials=Credentials("", ""),
config=config,
)
assert isinstance(dev, device_class)
assert isinstance(dev.protocol, protocol_class)
assert dev.config == config
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):
"""Make sure that connect returns an initialized SmartDevice instance."""
host = "127.0.0.1"
ctype, _ = _get_connection_type_device_class(all_fixture_data)
config = DeviceConfig(host=host, port_override=custom_port, connection_type=ctype)
default_port = 80 if "discovery_result" in all_fixture_data else 9999
ctype, _ = _get_connection_type_device_class(all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
dev = await connect(config=config)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == default_port
async def test_connect_logs_connect_time(
all_fixture_data: dict, caplog: pytest.LogCaptureFixture, mocker
):
"""Test that the connect time is logged when debug logging is enabled."""
ctype, _ = _get_connection_type_device_class(all_fixture_data)
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
host = "127.0.0.1"
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
logging.getLogger("kasa").setLevel(logging.DEBUG)
await connect(
config=config,
)
assert "seconds to update" in caplog.text
async def test_connect_query_fails(all_fixture_data: dict, mocker):
"""Make sure that connect fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
mocker.patch("kasa.IotProtocol.query", side_effect=SmartDeviceException)
mocker.patch("kasa.SmartProtocol.query", side_effect=SmartDeviceException)
ctype, _ = _get_connection_type_device_class(all_fixture_data)
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
with pytest.raises(SmartDeviceException):
await connect(config=config)
async def test_connect_http_client(all_fixture_data, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
ctype, _ = _get_connection_type_device_class(all_fixture_data)
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
http_client = httpx.AsyncClient()
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
dev = await connect(config=config)
if ctype.encryption_type != EncryptType.Xor:
assert dev.protocol._transport._http_client != http_client
config = DeviceConfig(
host=host,
credentials=Credentials("foor", "bar"),
connection_type=ctype,
http_client=http_client,
)
dev = await connect(config=config)
if ctype.encryption_type != EncryptType.Xor:
assert dev.protocol._transport._http_client == http_client

View File

@ -0,0 +1,21 @@
from json import dumps as json_dumps
from json import loads as json_loads
import httpx
from kasa.credentials import Credentials
from kasa.deviceconfig import (
ConnectionType,
DeviceConfig,
DeviceFamilyType,
EncryptType,
)
def test_serialization():
config = DeviceConfig(host="Foo", http_client=httpx.AsyncClient())
config_dict = config.to_dict()
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
assert config == config2

View File

@ -1,21 +1,29 @@
# type: ignore
import logging
import re
import socket
import httpx
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
Credentials,
DeviceType,
Discover,
SmartDevice,
SmartDeviceException,
SmartStrip,
protocol,
)
from kasa.deviceconfig import (
ConnectionType,
DeviceConfig,
DeviceFamilyType,
EncryptType,
)
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps
from kasa.exceptions import AuthenticationException, UnsupportedDeviceException
from .conftest import bulb, bulb_iot, dimmer, lightstrip, plug, strip
from .conftest import bulb, bulb_iot, dimmer, lightstrip, new_discovery, plug, strip
UNSUPPORTED = {
"result": {
@ -89,13 +97,26 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
host = "127.0.0.1"
discovery_mock.ip = host
discovery_mock.port_override = custom_port
update_mock = mocker.patch.object(SmartStrip, "update")
x = await Discover.discover_single(host, port=custom_port)
device_class = Discover._get_device_class(discovery_mock.discovery_data)
update_mock = mocker.patch.object(device_class, "update")
x = await Discover.discover_single(
host, port=custom_port, credentials=Credentials()
)
assert issubclass(x.__class__, SmartDevice)
assert x._discovery_info is not None
assert x.port == custom_port or x.port == discovery_mock.default_port
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
assert update_mock.call_count == 0
ct = ConnectionType.from_values(
discovery_mock.device_type, discovery_mock.encrypt_type
)
uses_http = discovery_mock.default_port == 80
config = DeviceConfig(
host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http
)
assert x.config == config
async def test_discover_single_hostname(discovery_mock, mocker):
@ -104,47 +125,39 @@ async def test_discover_single_hostname(discovery_mock, mocker):
ip = "127.0.0.1"
discovery_mock.ip = ip
update_mock = mocker.patch.object(SmartStrip, "update")
device_class = Discover._get_device_class(discovery_mock.discovery_data)
update_mock = mocker.patch.object(device_class, "update")
x = await Discover.discover_single(host)
x = await Discover.discover_single(host, credentials=Credentials())
assert issubclass(x.__class__, SmartDevice)
assert x._discovery_info is not None
assert x.host == host
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
assert update_mock.call_count == 0
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
with pytest.raises(SmartDeviceException):
x = await Discover.discover_single(host)
x = await Discover.discover_single(host, credentials=Credentials())
async def test_discover_single_unsupported(mocker):
async def test_discover_single_unsupported(unsupported_device_info, mocker):
"""Make sure that discover_single handles unsupported devices correctly."""
host = "127.0.0.1"
def mock_discover(self):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
self.datagram_received(data, (host, 20002))
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
# Test with a valid unsupported response
discovery_data = UNSUPPORTED
with pytest.raises(
UnsupportedDeviceException,
match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}",
):
await Discover.discover_single(host)
# Test with no response
discovery_data = None
async def test_discover_single_no_response(mocker):
"""Make sure that discover_single handles no response correctly."""
host = "127.0.0.1"
mocker.patch.object(_DiscoverProtocol, "do_discover")
with pytest.raises(
SmartDeviceException, match=f"Timed out getting discovery response for {host}"
):
await Discover.discover_single(host, timeout=0.001)
await Discover.discover_single(host, discovery_timeout=0)
INVALIDS = [
@ -241,52 +254,82 @@ AUTHENTICATION_DATA_KLAP = {
}
async def test_discover_single_authentication(mocker):
@new_discovery
async def test_discover_single_authentication(discovery_mock, mocker):
"""Make sure that discover_single handles authenticating devices correctly."""
host = "127.0.0.1"
def mock_discover(self):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
self.datagram_received(data, (host, 20002))
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
discovery_mock.ip = host
device_class = Discover._get_device_class(discovery_mock.discovery_data)
mocker.patch.object(
SmartDevice,
device_class,
"update",
side_effect=AuthenticationException("Failed to authenticate"),
)
# Test with a valid unsupported response
discovery_data = AUTHENTICATION_DATA_KLAP
with pytest.raises(
AuthenticationException,
match="Failed to authenticate",
):
device = await Discover.discover_single(host)
device = await Discover.discover_single(
host, credentials=Credentials("foo", "bar")
)
await device.update()
mocker.patch.object(SmartDevice, "update")
device = await Discover.discover_single(host)
mocker.patch.object(device_class, "update")
device = await Discover.discover_single(host, credentials=Credentials("foo", "bar"))
await device.update()
assert device.device_type == DeviceType.Plug
assert isinstance(device, device_class)
async def test_device_update_from_new_discovery_info():
@new_discovery
async def test_device_update_from_new_discovery_info(discovery_data):
device = SmartDevice("127.0.0.7")
discover_info = DiscoveryResult(**AUTHENTICATION_DATA_KLAP["result"])
discover_info = DiscoveryResult(**discovery_data["result"])
discover_dump = discover_info.get_dict()
discover_dump["alias"] = "foobar"
discover_dump["model"] = discover_dump["device_model"]
device.update_from_discover_info(discover_dump)
assert device.alias == discover_dump["alias"]
assert device.alias == "foobar"
assert device.mac == discover_dump["mac"].replace("-", ":")
assert device.model == discover_dump["model"]
assert device.model == discover_dump["device_model"]
with pytest.raises(
SmartDeviceException,
match=re.escape("You need to await update() to access the data"),
):
assert device.supported_modules
async def test_discover_single_http_client(discovery_mock, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
discovery_mock.ip = host
http_client = httpx.AsyncClient()
x: SmartDevice = await Discover.discover_single(host)
assert x.config.uses_http == (discovery_mock.default_port == 80)
if discovery_mock.default_port == 80:
assert x.protocol._transport._http_client != http_client
x.config.http_client = http_client
assert x.protocol._transport._http_client == http_client
async def test_discover_http_client(discovery_mock, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
discovery_mock.ip = host
http_client = httpx.AsyncClient()
devices = await Discover.discover(discovery_timeout=0)
x: SmartDevice = devices[host]
assert x.config.uses_http == (discovery_mock.default_port == 80)
if discovery_mock.default_port == 80:
assert x.protocol._transport._http_client != http_client
x.config.http_client = http_client
assert x.protocol._transport._http_client == http_client

View File

@ -12,9 +12,15 @@ import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import AuthenticationException, SmartDeviceException
from ..iotprotocol import IotProtocol
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
from ..klaptransport import (
KlapEncryptionSession,
KlapTransport,
KlapTransportV2,
_sha256,
)
from ..smartprotocol import SmartProtocol
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
@ -31,8 +37,9 @@ class _mock_response:
[
(Exception("dummy exception"), True),
(SmartDeviceException("dummy exception"), False),
(httpx.ConnectError("dummy exception"), True),
],
ids=("Exception", "SmartDeviceException"),
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@ -42,8 +49,10 @@ async def test_protocol_retries(
):
host = "127.0.0.1"
conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query(
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=retry_count
)
@ -60,10 +69,11 @@ async def test_protocol_no_retry_on_connection_error(
conn = mocker.patch.object(
httpx.AsyncClient,
"post",
side_effect=httpx.ConnectError("foo"),
side_effect=AuthenticationException("foo"),
)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query(
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=5
)
@ -81,8 +91,9 @@ async def test_protocol_retry_recoverable_error(
"post",
side_effect=httpx.CloseError("foo"),
)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query(
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=5
)
@ -115,7 +126,8 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
side_effect=_fail_one_less_than_retry_count,
)
response = await protocol_class(host, transport=transport_class(host)).query(
config = DeviceConfig(host)
response = await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=retry_count
)
assert "result" in response or "foobar" in response
@ -136,7 +148,9 @@ async def test_protocol_logging(mocker, caplog, log_level):
seed = secrets.token_bytes(16)
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1"))
config = DeviceConfig("127.0.0.1")
protocol = IotProtocol(transport=KlapTransport(config=config))
protocol._transport._handshake_done = True
protocol._transport._session_expire_at = time.time() + 86400
@ -181,7 +195,7 @@ def test_encrypt_unicode():
"device_credentials, expectation",
[
(Credentials("foo", "bar"), does_not_raise()),
(Credentials("", ""), does_not_raise()),
(Credentials(), does_not_raise()),
(
Credentials(
KlapTransport.KASA_SETUP_EMAIL,
@ -196,30 +210,37 @@ def test_encrypt_unicode():
],
ids=("client", "blank", "kasa_setup", "shouldfail"),
)
async def test_handshake1(mocker, device_credentials, expectation):
@pytest.mark.parametrize(
"transport_class, seed_auth_hash_calc",
[
pytest.param(KlapTransport, lambda c, s, a: c + a, id="KLAP"),
pytest.param(KlapTransportV2, lambda c, s, a: c + s + a, id="KLAPV2"),
],
)
async def test_handshake1(
mocker, device_credentials, expectation, transport_class, seed_auth_hash_calc
):
async def _return_handshake1_response(url, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash
client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash)
return _mock_response(200, server_seed + client_seed_auth_hash)
seed_auth_hash = _sha256(
seed_auth_hash_calc(client_seed, server_seed, device_auth_hash)
)
return _mock_response(200, server_seed + seed_auth_hash)
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(device_credentials)
device_auth_hash = transport_class.generate_auth_hash(device_credentials)
mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=transport_class(config=config))
protocol._transport.http_client = httpx.AsyncClient()
with expectation:
(
local_seed,
@ -233,31 +254,51 @@ async def test_handshake1(mocker, device_credentials, expectation):
await protocol.close()
async def test_handshake(mocker):
@pytest.mark.parametrize(
"transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2",
[
pytest.param(
KlapTransport, lambda c, s, a: c + a, lambda c, s, a: s + a, id="KLAP"
),
pytest.param(
KlapTransportV2,
lambda c, s, a: c + s + a,
lambda c, s, a: s + c + a,
id="KLAPV2",
),
],
)
async def test_handshake(
mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
):
async def _return_handshake_response(url, params=None, data=None, *_, **__):
nonlocal response_status, client_seed, server_seed, device_auth_hash
nonlocal client_seed, server_seed, device_auth_hash
if url == "http://127.0.0.1/app/handshake1":
client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash)
seed_auth_hash = _sha256(
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
)
return _mock_response(200, server_seed + client_seed_auth_hash)
return _mock_response(200, server_seed + seed_auth_hash)
elif url == "http://127.0.0.1/app/handshake2":
seed_auth_hash = _sha256(
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
)
assert data == seed_auth_hash
return _mock_response(response_status, b"")
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake_response
)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=transport_class(config=config))
protocol._transport.http_client = httpx.AsyncClient()
response_status = 200
@ -273,7 +314,7 @@ async def test_handshake(mocker):
async def test_query(mocker):
async def _return_response(url, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash, protocol, seq
nonlocal client_seed, server_seed, device_auth_hash, seq
if url == "http://127.0.0.1/app/handshake1":
client_seed = data
@ -303,10 +344,8 @@ async def test_query(mocker):
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=KlapTransport(config=config))
for _ in range(10):
resp = await protocol.query({})
@ -350,10 +389,8 @@ async def test_authentication_failures(mocker, response_status, expectation):
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=KlapTransport(config=config))
with expectation:
await protocol.query({})

View File

@ -9,6 +9,7 @@ import sys
import pytest
from ..deviceconfig import DeviceConfig
from ..exceptions import SmartDeviceException
from ..protocol import (
BaseTransport,
@ -31,10 +32,11 @@ async def test_protocol_retries(mocker, retry_count):
return reader, writer
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=retry_count)
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
{}, retry_count=retry_count
)
assert conn.call_count == retry_count + 1
@ -44,10 +46,11 @@ async def test_protocol_no_retry_on_unreachable(mocker):
"asyncio.open_connection",
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
{}, retry_count=5
)
assert conn.call_count == 1
@ -57,10 +60,11 @@ async def test_protocol_no_retry_connection_refused(mocker):
"asyncio.open_connection",
side_effect=ConnectionRefusedError,
)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
{}, retry_count=5
)
assert conn.call_count == 1
@ -70,10 +74,11 @@ async def test_protocol_retry_recoverable_error(mocker):
"asyncio.open_connection",
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
{}, retry_count=5
)
assert conn.call_count == 6
@ -107,9 +112,8 @@ async def test_protocol_reconnect(mocker, retry_count):
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
)
config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}, retry_count=retry_count)
assert response == {"great": "success"}
@ -137,9 +141,8 @@ async def test_protocol_logging(mocker, caplog, log_level):
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
)
config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({})
assert response == {"great": "success"}
@ -173,9 +176,8 @@ async def test_protocol_custom_port(mocker, custom_port):
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port)
)
config = DeviceConfig("127.0.0.1", port_override=custom_port)
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({})
assert response == {"great": "success"}
@ -271,18 +273,14 @@ def _get_subclasses(of_class):
def test_protocol_init_signature(class_name_obj):
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 3
assert len(params) == 2
assert (
params[0].name == "self"
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[1].name == "host"
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[2].name == "transport"
and params[2].kind == inspect.Parameter.KEYWORD_ONLY
params[1].name == "transport"
and params[1].kind == inspect.Parameter.KEYWORD_ONLY
)
@ -292,20 +290,11 @@ def test_protocol_init_signature(class_name_obj):
def test_transport_init_signature(class_name_obj):
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 5
assert len(params) == 2
assert (
params[0].name == "self"
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[1].name == "host"
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY
assert (
params[3].name == "credentials"
and params[3].kind == inspect.Parameter.KEYWORD_ONLY
)
assert (
params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY
params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY
)

View File

@ -5,8 +5,7 @@ from unittest.mock import Mock, patch
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType
from kasa import Credentials, DeviceConfig, SmartDevice, SmartDeviceException
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
@ -215,7 +214,8 @@ def test_device_class_ctors(device_class):
host = "127.0.0.2"
port = 1234
credentials = Credentials("foo", "bar")
dev = device_class(host, port=port, credentials=credentials)
config = DeviceConfig(host, port_override=port, credentials=credentials)
dev = device_class(host, config=config)
assert dev.host == host
assert dev.port == port
assert dev.credentials == credentials
@ -231,29 +231,27 @@ async def test_modules_preserved(dev: SmartDevice):
async def test_create_smart_device_with_timeout():
"""Make sure timeout is passed to the protocol."""
dev = SmartDevice(host="127.0.0.1", timeout=100)
host = "127.0.0.1"
dev = SmartDevice(host, config=DeviceConfig(host, timeout=100))
assert dev.protocol._transport._timeout == 100
async def test_create_thin_wrapper():
"""Make sure thin wrapper is created with the correct device type."""
mock = Mock()
config = DeviceConfig(
host="test_host",
port_override=1234,
timeout=100,
credentials=Credentials("username", "password"),
)
with patch("kasa.device_factory.connect", return_value=mock) as connect:
dev = await SmartDevice.connect(
host="test_host",
port=1234,
timeout=100,
credentials=Credentials("username", "password"),
device_type=DeviceType.Strip,
)
dev = await SmartDevice.connect(config=config)
assert dev is mock
connect.assert_called_once_with(
host="test_host",
port=1234,
timeout=100,
credentials=Credentials("username", "password"),
device_type=DeviceType.Strip,
host=None,
config=config,
)

View File

@ -13,6 +13,7 @@ import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
@ -37,7 +38,8 @@ async def test_smart_device_errors(mocker, error_code):
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
protocol = SmartProtocol(host, transport=AesTransport(host))
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
protocol = SmartProtocol(transport=AesTransport(config=config))
with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2)
@ -70,8 +72,8 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
protocol = SmartProtocol(host, transport=AesTransport(host))
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
protocol = SmartProtocol(transport=AesTransport(config=config))
with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):