mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 11:13:34 +00:00
Add DeviceConfig to allow specifying configuration parameters (#569)
* Add DeviceConfig handling * Update post review * Further update post latest review * Update following latest review * Update docstrings and docs
This commit is contained in:
parent
ec3ea39a37
commit
f6fd898faf
@ -23,9 +23,12 @@ This will return you a list of device instances based on the discovery replies.
|
||||
If the device's host is already known, you can use to construct a device instance with
|
||||
:meth:`~kasa.SmartDevice.connect()`.
|
||||
|
||||
When connecting a device with the :meth:`~kasa.SmartDevice.connect()` method, it is recommended to
|
||||
pass the device type as well as this allows the library to use the correct device class for the
|
||||
device without having to query the device.
|
||||
The :meth:`~kasa.SmartDevice.connect()` also enables support for connecting to new
|
||||
KASA SMART protocol and TAPO devices directly using the parameter :class:`~kasa.DeviceConfig`.
|
||||
Simply serialize the :attr:`~kasa.SmartDevice.config` property via :meth:`~kasa.DeviceConfig.to_dict()`
|
||||
and then deserialize it later with :func:`~kasa.DeviceConfig.from_dict()`
|
||||
and then pass it into :meth:`~kasa.SmartDevice.connect()`.
|
||||
|
||||
|
||||
.. _update_cycle:
|
||||
|
||||
|
18
docs/source/deviceconfig.rst
Normal file
18
docs/source/deviceconfig.rst
Normal file
@ -0,0 +1,18 @@
|
||||
DeviceConfig
|
||||
============
|
||||
|
||||
.. contents:: Contents
|
||||
:local:
|
||||
|
||||
.. note::
|
||||
|
||||
Feel free to open a pull request to improve the documentation!
|
||||
|
||||
|
||||
API documentation
|
||||
*****************
|
||||
|
||||
.. autoclass:: kasa.DeviceConfig
|
||||
:members:
|
||||
:inherited-members:
|
||||
:undoc-members:
|
@ -15,3 +15,4 @@
|
||||
smartdimmer
|
||||
smartstrip
|
||||
smartlightstrip
|
||||
deviceconfig
|
||||
|
@ -14,6 +14,12 @@ to be handled by the user of the library.
|
||||
from importlib.metadata import version
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.deviceconfig import (
|
||||
ConnectionType,
|
||||
DeviceConfig,
|
||||
DeviceFamilyType,
|
||||
EncryptType,
|
||||
)
|
||||
from kasa.discover import Discover
|
||||
from kasa.emeterstatus import EmeterStatus
|
||||
from kasa.exceptions import (
|
||||
@ -55,4 +61,8 @@ __all__ = [
|
||||
"AuthenticationException",
|
||||
"UnsupportedDeviceException",
|
||||
"Credentials",
|
||||
"DeviceConfig",
|
||||
"ConnectionType",
|
||||
"EncryptType",
|
||||
"DeviceFamilyType",
|
||||
]
|
||||
|
@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .exceptions import (
|
||||
SMART_AUTHENTICATION_ERRORS,
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
@ -47,8 +47,7 @@ class AesTransport(BaseTransport):
|
||||
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
DEFAULT_TIMEOUT = 5
|
||||
DEFAULT_PORT: int = 80
|
||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||
COMMON_HEADERS = {
|
||||
"Content-Type": "application/json",
|
||||
@ -58,32 +57,37 @@ class AesTransport(BaseTransport):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: DeviceConfig,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
host,
|
||||
port=port or self.DEFAULT_PORT,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
)
|
||||
super().__init__(config=config)
|
||||
|
||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
self._handshake_done = False
|
||||
|
||||
self._encryption_session: Optional[AesEncyptionSession] = None
|
||||
self._session_expire_at: Optional[float] = None
|
||||
|
||||
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||
self._session_cookie = None
|
||||
|
||||
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||
self._login_token = None
|
||||
|
||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||
|
||||
@property
|
||||
def default_port(self):
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
@property
|
||||
def _http_client(self) -> httpx.AsyncClient:
|
||||
if self._config.http_client:
|
||||
return self._config.http_client
|
||||
if not self._default_http_client:
|
||||
self._default_http_client = httpx.AsyncClient()
|
||||
return self._default_http_client
|
||||
|
||||
def hash_credentials(self, login_v2):
|
||||
"""Hash the credentials."""
|
||||
if login_v2:
|
||||
@ -102,8 +106,6 @@ class AesTransport(BaseTransport):
|
||||
|
||||
async def client_post(self, url, params=None, data=None, json=None, headers=None):
|
||||
"""Send an http post request to the device."""
|
||||
if not self._http_client:
|
||||
self._http_client = httpx.AsyncClient()
|
||||
response_data = None
|
||||
cookies = None
|
||||
if self._session_cookie:
|
||||
@ -268,8 +270,8 @@ class AesTransport(BaseTransport):
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
client = self._http_client
|
||||
self._http_client = None
|
||||
client = self._default_http_client
|
||||
self._default_http_client = None
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
if client:
|
||||
|
67
kasa/cli.py
67
kasa/cli.py
@ -12,15 +12,20 @@ import asyncclick as click
|
||||
|
||||
from kasa import (
|
||||
AuthenticationException,
|
||||
ConnectionType,
|
||||
Credentials,
|
||||
DeviceType,
|
||||
DeviceConfig,
|
||||
DeviceFamilyType,
|
||||
Discover,
|
||||
EncryptType,
|
||||
SmartBulb,
|
||||
SmartDevice,
|
||||
SmartDimmer,
|
||||
SmartLightStrip,
|
||||
SmartPlug,
|
||||
SmartStrip,
|
||||
UnsupportedDeviceException,
|
||||
)
|
||||
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
|
||||
from kasa.discover import DiscoveryResult
|
||||
|
||||
try:
|
||||
@ -49,10 +54,19 @@ except ImportError:
|
||||
# --json has set it to _nop_echo
|
||||
echo = _do_echo
|
||||
|
||||
DEVICE_TYPES = [
|
||||
device_type.value
|
||||
for device_type in DeviceType
|
||||
if device_type in DEVICE_TYPE_TO_CLASS
|
||||
|
||||
TYPE_TO_CLASS = {
|
||||
"plug": SmartPlug,
|
||||
"bulb": SmartBulb,
|
||||
"dimmer": SmartDimmer,
|
||||
"strip": SmartStrip,
|
||||
"lightstrip": SmartLightStrip,
|
||||
}
|
||||
|
||||
ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in EncryptType]
|
||||
|
||||
DEVICE_FAMILY_TYPES = [
|
||||
device_family_type.value for device_family_type in DeviceFamilyType
|
||||
]
|
||||
|
||||
click.anyio_backend = "asyncio"
|
||||
@ -149,7 +163,7 @@ def json_formatter_cb(result, **kwargs):
|
||||
"--type",
|
||||
envvar="KASA_TYPE",
|
||||
default=None,
|
||||
type=click.Choice(DEVICE_TYPES, case_sensitive=False),
|
||||
type=click.Choice(list(TYPE_TO_CLASS), case_sensitive=False),
|
||||
)
|
||||
@click.option(
|
||||
"--json/--no-json",
|
||||
@ -158,6 +172,18 @@ def json_formatter_cb(result, **kwargs):
|
||||
is_flag=True,
|
||||
help="Output raw device response as JSON.",
|
||||
)
|
||||
@click.option(
|
||||
"--encrypt-type",
|
||||
envvar="KASA_ENCRYPT_TYPE",
|
||||
default=None,
|
||||
type=click.Choice(ENCRYPT_TYPES, case_sensitive=False),
|
||||
)
|
||||
@click.option(
|
||||
"--device-family",
|
||||
envvar="KASA_DEVICE_FAMILY",
|
||||
default=None,
|
||||
type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False),
|
||||
)
|
||||
@click.option(
|
||||
"--timeout",
|
||||
envvar="KASA_TIMEOUT",
|
||||
@ -199,6 +225,8 @@ async def cli(
|
||||
verbose,
|
||||
debug,
|
||||
type,
|
||||
encrypt_type,
|
||||
device_family,
|
||||
json,
|
||||
timeout,
|
||||
discovery_timeout,
|
||||
@ -270,12 +298,19 @@ async def cli(
|
||||
return await ctx.invoke(discover)
|
||||
|
||||
if type is not None:
|
||||
device_type = DeviceType.from_value(type)
|
||||
dev = await SmartDevice.connect(
|
||||
host, credentials=credentials, device_type=device_type, timeout=timeout
|
||||
dev = TYPE_TO_CLASS[type](host)
|
||||
await dev.update()
|
||||
elif device_family or encrypt_type:
|
||||
ctype = ConnectionType(
|
||||
DeviceFamilyType(device_family),
|
||||
EncryptType(encrypt_type),
|
||||
)
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=credentials, timeout=timeout, connection_type=ctype
|
||||
)
|
||||
dev = await SmartDevice.connect(config=config)
|
||||
else:
|
||||
echo("No --type defined, discovering..")
|
||||
echo("No --type or --device-family and --encrypt-type defined, discovering..")
|
||||
dev = await Discover.discover_single(
|
||||
host,
|
||||
port=port,
|
||||
@ -332,8 +367,10 @@ async def discover(ctx):
|
||||
target = ctx.parent.params["target"]
|
||||
username = ctx.parent.params["username"]
|
||||
password = ctx.parent.params["password"]
|
||||
timeout = ctx.parent.params["discovery_timeout"]
|
||||
verbose = ctx.parent.params["verbose"]
|
||||
discovery_timeout = ctx.parent.params["discovery_timeout"]
|
||||
timeout = ctx.parent.params["timeout"]
|
||||
port = ctx.parent.params["port"]
|
||||
|
||||
credentials = Credentials(username, password)
|
||||
|
||||
@ -354,7 +391,7 @@ async def discover(ctx):
|
||||
echo(f"\t{unsupported_exception}")
|
||||
echo()
|
||||
|
||||
echo(f"Discovering devices on {target} for {timeout} seconds")
|
||||
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
|
||||
|
||||
async def print_discovered(dev: SmartDevice):
|
||||
async with sem:
|
||||
@ -376,9 +413,11 @@ async def discover(ctx):
|
||||
|
||||
await Discover.discover(
|
||||
target=target,
|
||||
timeout=timeout,
|
||||
discovery_timeout=discovery_timeout,
|
||||
on_discovered=print_discovered,
|
||||
on_unsupported=print_unsupported,
|
||||
port=port,
|
||||
timeout=timeout,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
|
@ -8,5 +8,5 @@ from typing import Optional
|
||||
class Credentials:
|
||||
"""Credentials for authentication."""
|
||||
|
||||
username: Optional[str] = field(default=None, repr=False)
|
||||
password: Optional[str] = field(default=None, repr=False)
|
||||
username: Optional[str] = field(default="", repr=False)
|
||||
password: Optional[str] = field(default="", repr=False)
|
||||
|
@ -1,18 +1,21 @@
|
||||
"""Device creation by type."""
|
||||
|
||||
"""Device creation via DeviceConfig."""
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
from .aestransport import AesTransport
|
||||
from .credentials import Credentials
|
||||
from .device_type import DeviceType
|
||||
from .exceptions import UnsupportedDeviceException
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .exceptions import SmartDeviceException, UnsupportedDeviceException
|
||||
from .iotprotocol import IotProtocol
|
||||
from .klaptransport import KlapTransport, TPlinkKlapTransportV2
|
||||
from .protocol import BaseTransport, TPLinkProtocol
|
||||
from .klaptransport import KlapTransport, KlapTransportV2
|
||||
from .protocol import (
|
||||
BaseTransport,
|
||||
TPLinkProtocol,
|
||||
TPLinkSmartHomeProtocol,
|
||||
_XorTransport,
|
||||
)
|
||||
from .smartbulb import SmartBulb
|
||||
from .smartdevice import SmartDevice, SmartDeviceException
|
||||
from .smartdevice import SmartDevice
|
||||
from .smartdimmer import SmartDimmer
|
||||
from .smartlightstrip import SmartLightStrip
|
||||
from .smartplug import SmartPlug
|
||||
@ -20,104 +23,80 @@ from .smartprotocol import SmartProtocol
|
||||
from .smartstrip import SmartStrip
|
||||
from .tapo import TapoBulb, TapoPlug
|
||||
|
||||
DEVICE_TYPE_TO_CLASS = {
|
||||
DeviceType.Plug: SmartPlug,
|
||||
DeviceType.Bulb: SmartBulb,
|
||||
DeviceType.Strip: SmartStrip,
|
||||
DeviceType.Dimmer: SmartDimmer,
|
||||
DeviceType.LightStrip: SmartLightStrip,
|
||||
DeviceType.TapoPlug: TapoPlug,
|
||||
DeviceType.TapoBulb: TapoBulb,
|
||||
}
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
GET_SYSINFO_QUERY = {
|
||||
"system": {"get_sysinfo": None},
|
||||
}
|
||||
|
||||
async def connect(
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
timeout=5,
|
||||
credentials: Optional[Credentials] = None,
|
||||
device_type: Optional[DeviceType] = None,
|
||||
protocol_class: Optional[Type[TPLinkProtocol]] = None,
|
||||
) -> "SmartDevice":
|
||||
"""Connect to a single device by the given IP address.
|
||||
|
||||
async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "SmartDevice":
|
||||
"""Connect to a single device by the given hostname or device configuration.
|
||||
|
||||
This method avoids the UDP based discovery process and
|
||||
will connect directly to the device to query its type.
|
||||
will connect directly to the device.
|
||||
|
||||
It is generally preferred to avoid :func:`discover_single()` and
|
||||
use this function instead as it should perform better when
|
||||
the WiFi network is congested or the device is not responding
|
||||
to discovery requests.
|
||||
|
||||
The device type is discovered by querying the device.
|
||||
Do not use this function directly, use SmartDevice.connect()
|
||||
|
||||
:param host: Hostname of device to query
|
||||
:param device_type: Device type to use for the device.
|
||||
If not given, the device type is discovered by querying the device.
|
||||
If the device type is already known, it is preferred to pass it
|
||||
to avoid the extra query to the device to discover its type.
|
||||
:param protocol_class: Optionally provide the protocol class
|
||||
to use.
|
||||
:param config: Connection parameters to ensure the correct protocol
|
||||
and connection options are used.
|
||||
:rtype: SmartDevice
|
||||
:return: Object for querying/controlling found device.
|
||||
"""
|
||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
if host and config or (not host and not config):
|
||||
raise SmartDeviceException("One of host or config must be provded and not both")
|
||||
if host:
|
||||
config = DeviceConfig(host=host)
|
||||
|
||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
if debug_enabled:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)):
|
||||
dev: SmartDevice = klass(
|
||||
host=host, port=port, credentials=credentials, timeout=timeout
|
||||
)
|
||||
if protocol_class is not None:
|
||||
dev.protocol = protocol_class(
|
||||
host,
|
||||
transport=AesTransport(
|
||||
host, port=port, credentials=credentials, timeout=timeout
|
||||
),
|
||||
)
|
||||
await dev.update()
|
||||
def _perf_log(has_params, perf_type):
|
||||
nonlocal start_time
|
||||
if debug_enabled:
|
||||
end_time = time.perf_counter()
|
||||
_LOGGER.debug(
|
||||
"Device %s with known type (%s) took %.2f seconds to connect",
|
||||
host,
|
||||
device_type.value,
|
||||
end_time - start_time,
|
||||
f"Device {config.host} with connection params {has_params} "
|
||||
+ f"took {end_time - start_time:.2f} seconds to {perf_type}",
|
||||
)
|
||||
return dev
|
||||
start_time = time.perf_counter()
|
||||
|
||||
unknown_dev = SmartDevice(
|
||||
host=host, port=port, credentials=credentials, timeout=timeout
|
||||
)
|
||||
if protocol_class is not None:
|
||||
# TODO this will be replaced with connection params
|
||||
unknown_dev.protocol = protocol_class(
|
||||
host,
|
||||
transport=AesTransport(
|
||||
host, port=port, credentials=credentials, timeout=timeout
|
||||
),
|
||||
if (protocol := get_protocol(config=config)) is None:
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device for {config.host}: "
|
||||
+ f"{config.connection_type.device_family.value}"
|
||||
)
|
||||
await unknown_dev.update()
|
||||
device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
|
||||
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
|
||||
# Reuse the connection from the unknown device
|
||||
# so we don't have to reconnect
|
||||
dev.protocol = unknown_dev.protocol
|
||||
await dev.update()
|
||||
if debug_enabled:
|
||||
end_time = time.perf_counter()
|
||||
_LOGGER.debug(
|
||||
"Device %s with unknown type (%s) took %.2f seconds to connect",
|
||||
host,
|
||||
dev.device_type.value,
|
||||
end_time - start_time,
|
||||
|
||||
device_class: Optional[Type[SmartDevice]]
|
||||
|
||||
if isinstance(protocol, TPLinkSmartHomeProtocol):
|
||||
info = await protocol.query(GET_SYSINFO_QUERY)
|
||||
_perf_log(True, "get_sysinfo")
|
||||
device_class = get_device_class_from_sys_info(info)
|
||||
device = device_class(config.host, protocol=protocol)
|
||||
device.update_from_discover_info(info)
|
||||
await device.update()
|
||||
_perf_log(True, "update")
|
||||
return device
|
||||
elif device_class := get_device_class_from_family(
|
||||
config.connection_type.device_family.value
|
||||
):
|
||||
device = device_class(host=config.host, protocol=protocol)
|
||||
await device.update()
|
||||
_perf_log(True, "update")
|
||||
return device
|
||||
else:
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device for {config.host}: "
|
||||
+ f"{config.connection_type.device_family.value}"
|
||||
)
|
||||
return dev
|
||||
|
||||
|
||||
def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
|
||||
@ -147,32 +126,38 @@ def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
|
||||
raise UnsupportedDeviceException("Unknown device type: %s" % type_)
|
||||
|
||||
|
||||
def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]:
|
||||
def get_device_class_from_family(device_type: str) -> Optional[Type[SmartDevice]]:
|
||||
"""Return the device class from the type name."""
|
||||
supported_device_types: dict[str, Type[SmartDevice]] = {
|
||||
"SMART.TAPOPLUG": TapoPlug,
|
||||
"SMART.TAPOBULB": TapoBulb,
|
||||
"SMART.KASAPLUG": TapoPlug,
|
||||
"IOT.SMARTPLUGSWITCH": SmartPlug,
|
||||
"IOT.SMARTBULB": SmartBulb,
|
||||
}
|
||||
return supported_device_types.get(device_type)
|
||||
|
||||
|
||||
def get_protocol_from_connection_name(
|
||||
connection_name: str, host: str, credentials: Optional[Credentials] = None
|
||||
def get_protocol(
|
||||
config: DeviceConfig,
|
||||
) -> Optional[TPLinkProtocol]:
|
||||
"""Return the protocol from the connection name."""
|
||||
protocol_name = config.connection_type.device_family.value.split(".")[0]
|
||||
protocol_transport_key = (
|
||||
protocol_name + "." + config.connection_type.encryption_type.value
|
||||
)
|
||||
supported_device_protocols: dict[
|
||||
str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]]
|
||||
] = {
|
||||
"IOT.XOR": (TPLinkSmartHomeProtocol, _XorTransport),
|
||||
"IOT.KLAP": (IotProtocol, KlapTransport),
|
||||
"SMART.AES": (SmartProtocol, AesTransport),
|
||||
"SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2),
|
||||
"SMART.KLAP": (SmartProtocol, KlapTransportV2),
|
||||
}
|
||||
if connection_name not in supported_device_protocols:
|
||||
if protocol_transport_key not in supported_device_protocols:
|
||||
return None
|
||||
|
||||
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
|
||||
transport: BaseTransport = transport_class(host, credentials=credentials)
|
||||
protocol: TPLinkProtocol = protocol_class(host, transport=transport)
|
||||
return protocol
|
||||
protocol_class, transport_class = supported_device_protocols.get(
|
||||
protocol_transport_key
|
||||
) # type: ignore
|
||||
return protocol_class(transport=transport_class(config=config))
|
||||
|
148
kasa/deviceconfig.py
Normal file
148
kasa/deviceconfig.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""Module for holding connection parameters."""
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import SmartDeviceException
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EncryptType(Enum):
|
||||
"""Encrypt type enum."""
|
||||
|
||||
Klap = "KLAP"
|
||||
Aes = "AES"
|
||||
Xor = "XOR"
|
||||
|
||||
|
||||
class DeviceFamilyType(Enum):
|
||||
"""Encrypt type enum."""
|
||||
|
||||
IotSmartPlugSwitch = "IOT.SMARTPLUGSWITCH"
|
||||
IotSmartBulb = "IOT.SMARTBULB"
|
||||
SmartKasaPlug = "SMART.KASAPLUG"
|
||||
SmartTapoPlug = "SMART.TAPOPLUG"
|
||||
SmartTapoBulb = "SMART.TAPOBULB"
|
||||
|
||||
|
||||
def _dataclass_from_dict(klass, in_val):
|
||||
if is_dataclass(klass):
|
||||
fieldtypes = {f.name: f.type for f in fields(klass)}
|
||||
val = {}
|
||||
for dict_key in in_val:
|
||||
if dict_key in fieldtypes and hasattr(fieldtypes[dict_key], "from_dict"):
|
||||
val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key])
|
||||
else:
|
||||
val[dict_key] = _dataclass_from_dict(
|
||||
fieldtypes[dict_key], in_val[dict_key]
|
||||
)
|
||||
return klass(**val)
|
||||
else:
|
||||
return in_val
|
||||
|
||||
|
||||
def _dataclass_to_dict(in_val):
|
||||
fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare}
|
||||
out_val = {}
|
||||
for field_name in fieldtypes:
|
||||
val = getattr(in_val, field_name)
|
||||
if val is None:
|
||||
continue
|
||||
elif hasattr(val, "to_dict"):
|
||||
out_val[field_name] = val.to_dict()
|
||||
elif is_dataclass(fieldtypes[field_name]):
|
||||
out_val[field_name] = asdict(val)
|
||||
else:
|
||||
out_val[field_name] = val
|
||||
return out_val
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionType:
|
||||
"""Class to hold the the parameters determining connection type."""
|
||||
|
||||
device_family: DeviceFamilyType
|
||||
encryption_type: EncryptType
|
||||
|
||||
@staticmethod
|
||||
def from_values(
|
||||
device_family: str,
|
||||
encryption_type: str,
|
||||
) -> "ConnectionType":
|
||||
"""Return connection parameters from string values."""
|
||||
try:
|
||||
return ConnectionType(
|
||||
DeviceFamilyType(device_family),
|
||||
EncryptType(encryption_type),
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise SmartDeviceException(
|
||||
f"Invalid connection parameters for {device_family}.{encryption_type}"
|
||||
) from ex
|
||||
|
||||
@staticmethod
|
||||
def from_dict(connection_type_dict: Dict[str, str]) -> "ConnectionType":
|
||||
"""Return connection parameters from dict."""
|
||||
if (
|
||||
isinstance(connection_type_dict, dict)
|
||||
and (device_family := connection_type_dict.get("device_family"))
|
||||
and (encryption_type := connection_type_dict.get("encryption_type"))
|
||||
):
|
||||
return ConnectionType.from_values(device_family, encryption_type)
|
||||
|
||||
raise SmartDeviceException(
|
||||
f"Invalid connection type data for {connection_type_dict}"
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, str]:
|
||||
"""Convert connection params to dict."""
|
||||
result = {
|
||||
"device_family": self.device_family.value,
|
||||
"encryption_type": self.encryption_type.value,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceConfig:
|
||||
"""Class to represent paramaters that determine how to connect to devices."""
|
||||
|
||||
DEFAULT_TIMEOUT = 5
|
||||
|
||||
host: str
|
||||
timeout: Optional[int] = DEFAULT_TIMEOUT
|
||||
port_override: Optional[int] = None
|
||||
credentials: Credentials = field(
|
||||
default_factory=lambda: Credentials(username="", password="")
|
||||
)
|
||||
connection_type: ConnectionType = field(
|
||||
default_factory=lambda: ConnectionType(
|
||||
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
|
||||
)
|
||||
)
|
||||
|
||||
uses_http: bool = False
|
||||
# compare=False will be excluded from the serialization and object comparison.
|
||||
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.credentials is None:
|
||||
self.credentials = Credentials(username="", password="")
|
||||
if self.connection_type is None:
|
||||
self.connection_type = ConnectionType(
|
||||
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, str]]:
|
||||
"""Convert connection params to dict."""
|
||||
return _dataclass_to_dict(self)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(cparam_dict: Dict[str, Dict[str, str]]) -> "DeviceConfig":
|
||||
"""Return connection parameters from dict."""
|
||||
return _dataclass_from_dict(DeviceConfig, cparam_dict)
|
160
kasa/discover.py
160
kasa/discover.py
@ -1,5 +1,6 @@
|
||||
"""Discovery module for TP-Link Smart Home devices."""
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import ipaddress
|
||||
import logging
|
||||
@ -11,29 +12,32 @@ from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
|
||||
try:
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from pydantic.v1 import BaseModel, ValidationError # pragma: no cover
|
||||
except ImportError:
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ValidationError # pragma: no cover
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.device_factory import (
|
||||
get_device_class_from_family,
|
||||
get_device_class_from_sys_info,
|
||||
get_protocol,
|
||||
)
|
||||
from kasa.deviceconfig import ConnectionType, DeviceConfig, EncryptType
|
||||
from kasa.exceptions import UnsupportedDeviceException
|
||||
from kasa.json import dumps as json_dumps
|
||||
from kasa.json import loads as json_loads
|
||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||
from kasa.smartdevice import SmartDevice, SmartDeviceException
|
||||
|
||||
from .device_factory import (
|
||||
get_device_class_from_sys_info,
|
||||
get_device_class_from_type_name,
|
||||
get_protocol_from_connection_name,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
|
||||
DeviceDict = Dict[str, SmartDevice]
|
||||
|
||||
UNAVAILABLE_ALIAS = "Authentication required"
|
||||
UNAVAILABLE_NICKNAME = base64.b64encode(UNAVAILABLE_ALIAS.encode()).decode()
|
||||
|
||||
|
||||
class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
"""Implementation of the discovery protocol handler.
|
||||
@ -62,9 +66,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
self.discovery_packets = discovery_packets
|
||||
self.interface = interface
|
||||
self.on_discovered = on_discovered
|
||||
|
||||
self.port = port
|
||||
self.discovery_port = port or Discover.DISCOVERY_PORT
|
||||
self.target = (target, self.discovery_port)
|
||||
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
|
||||
|
||||
self.discovered_devices = {}
|
||||
self.unsupported_device_exceptions: Dict = {}
|
||||
self.invalid_device_exceptions: Dict = {}
|
||||
@ -110,13 +117,18 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
self.seen_hosts.add(ip)
|
||||
|
||||
device = None
|
||||
|
||||
config = DeviceConfig(host=ip, port_override=self.port)
|
||||
if self.credentials:
|
||||
config.credentials = self.credentials
|
||||
if self.timeout:
|
||||
config.timeout = self.timeout
|
||||
try:
|
||||
if port == self.discovery_port:
|
||||
device = Discover._get_device_instance_legacy(data, ip, port)
|
||||
device = Discover._get_device_instance_legacy(data, config)
|
||||
elif port == Discover.DISCOVERY_PORT_2:
|
||||
device = Discover._get_device_instance(
|
||||
data, ip, port, self.credentials or Credentials()
|
||||
)
|
||||
config.uses_http = True
|
||||
device = Discover._get_device_instance(data, config)
|
||||
else:
|
||||
return
|
||||
except UnsupportedDeviceException as udex:
|
||||
@ -200,11 +212,13 @@ class Discover:
|
||||
*,
|
||||
target="255.255.255.255",
|
||||
on_discovered=None,
|
||||
timeout=5,
|
||||
discovery_timeout=5,
|
||||
discovery_packets=3,
|
||||
interface=None,
|
||||
on_unsupported=None,
|
||||
credentials=None,
|
||||
port=None,
|
||||
timeout=None,
|
||||
) -> DeviceDict:
|
||||
"""Discover supported devices.
|
||||
|
||||
@ -240,14 +254,15 @@ class Discover:
|
||||
on_unsupported=on_unsupported,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
port=port,
|
||||
),
|
||||
local_addr=("0.0.0.0", 0), # noqa: S104
|
||||
)
|
||||
protocol = cast(_DiscoverProtocol, protocol)
|
||||
|
||||
try:
|
||||
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
|
||||
await asyncio.sleep(timeout)
|
||||
_LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout)
|
||||
await asyncio.sleep(discovery_timeout)
|
||||
finally:
|
||||
transport.close()
|
||||
|
||||
@ -259,10 +274,10 @@ class Discover:
|
||||
async def discover_single(
|
||||
host: str,
|
||||
*,
|
||||
discovery_timeout: int = 5,
|
||||
port: Optional[int] = None,
|
||||
timeout=5,
|
||||
timeout: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
update_parent_devices: bool = True,
|
||||
) -> SmartDevice:
|
||||
"""Discover a single device by the given IP address.
|
||||
|
||||
@ -275,8 +290,6 @@ class Discover:
|
||||
:param port: Optionally set a different port for the device
|
||||
:param timeout: Timeout for discovery
|
||||
:param credentials: Credentials for devices that require authentication
|
||||
:param update_parent_devices: Automatically call device.update() on
|
||||
devices that have children
|
||||
:rtype: SmartDevice
|
||||
:return: Object for querying/controlling found device.
|
||||
"""
|
||||
@ -320,9 +333,11 @@ class Discover:
|
||||
protocol = cast(_DiscoverProtocol, protocol)
|
||||
|
||||
try:
|
||||
_LOGGER.debug("Waiting a total of %s seconds for responses...", timeout)
|
||||
_LOGGER.debug(
|
||||
"Waiting a total of %s seconds for responses...", discovery_timeout
|
||||
)
|
||||
|
||||
async with asyncio_timeout(timeout):
|
||||
async with asyncio_timeout(discovery_timeout):
|
||||
await event.wait()
|
||||
except asyncio.TimeoutError as ex:
|
||||
raise SmartDeviceException(
|
||||
@ -334,9 +349,6 @@ class Discover:
|
||||
if ip in protocol.discovered_devices:
|
||||
dev = protocol.discovered_devices[ip]
|
||||
dev.host = host
|
||||
# Call device update on devices that have children
|
||||
if update_parent_devices and dev.has_children:
|
||||
await dev.update()
|
||||
return dev
|
||||
elif ip in protocol.unsupported_device_exceptions:
|
||||
raise protocol.unsupported_device_exceptions[ip]
|
||||
@ -350,99 +362,121 @@ class Discover:
|
||||
"""Find SmartDevice subclass for device described by passed data."""
|
||||
if "result" in info:
|
||||
discovery_result = DiscoveryResult(**info["result"])
|
||||
dev_class = get_device_class_from_type_name(discovery_result.device_type)
|
||||
dev_class = get_device_class_from_family(discovery_result.device_type)
|
||||
if not dev_class:
|
||||
raise UnsupportedDeviceException(
|
||||
"Unknown device type: %s" % discovery_result.device_type
|
||||
"Unknown device type: %s" % discovery_result.device_type,
|
||||
discovery_result=info,
|
||||
)
|
||||
return dev_class
|
||||
else:
|
||||
return get_device_class_from_sys_info(info)
|
||||
|
||||
@staticmethod
|
||||
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:
|
||||
def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> SmartDevice:
|
||||
"""Get SmartDevice from legacy 9999 response."""
|
||||
try:
|
||||
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
|
||||
except Exception as ex:
|
||||
raise SmartDeviceException(
|
||||
f"Unable to read response from device: {ip}: {ex}"
|
||||
f"Unable to read response from device: {config.host}: {ex}"
|
||||
) from ex
|
||||
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, info)
|
||||
|
||||
device_class = Discover._get_device_class(info)
|
||||
device = device_class(ip, port=port)
|
||||
device = device_class(config.host, config=config)
|
||||
sys_info = info["system"]["get_sysinfo"]
|
||||
if device_type := sys_info.get("mic_type", sys_info.get("type")):
|
||||
config.connection_type = ConnectionType.from_values(
|
||||
device_family=device_type, encryption_type=EncryptType.Xor.value
|
||||
)
|
||||
device.protocol = get_protocol(config) # type: ignore[assignment]
|
||||
device.update_from_discover_info(info)
|
||||
return device
|
||||
|
||||
@staticmethod
|
||||
def _get_device_instance(
|
||||
data: bytes, ip: str, port: int, credentials: Credentials
|
||||
data: bytes,
|
||||
config: DeviceConfig,
|
||||
) -> SmartDevice:
|
||||
"""Get SmartDevice from the new 20002 response."""
|
||||
try:
|
||||
info = json_loads(data[16:])
|
||||
discovery_result = DiscoveryResult(**info["result"])
|
||||
except Exception as ex:
|
||||
_LOGGER.debug("Got invalid response from device %s: %s", config.host, data)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to read response from device: {config.host}: {ex}"
|
||||
) from ex
|
||||
try:
|
||||
discovery_result = DiscoveryResult(**info["result"])
|
||||
except ValidationError as ex:
|
||||
_LOGGER.debug(
|
||||
"Unable to parse discovery from device %s: %s", config.host, info
|
||||
)
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unable to read response from device: {ip}: {ex}"
|
||||
f"Unable to parse discovery from device: {config.host}: {ex}"
|
||||
) from ex
|
||||
|
||||
type_ = discovery_result.device_type
|
||||
encrypt_type_ = (
|
||||
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
|
||||
)
|
||||
|
||||
if (device_class := get_device_class_from_type_name(type_)) is None:
|
||||
try:
|
||||
config.connection_type = ConnectionType.from_values(
|
||||
type_, discovery_result.mgt_encrypt_schm.encrypt_type
|
||||
)
|
||||
except SmartDeviceException as ex:
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device {config.host} of type {type_} "
|
||||
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
|
||||
discovery_result=discovery_result.get_dict(),
|
||||
) from ex
|
||||
if (device_class := get_device_class_from_family(type_)) is None:
|
||||
_LOGGER.warning("Got unsupported device type: %s", type_)
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device {ip} of type {type_}: {info}",
|
||||
f"Unsupported device {config.host} of type {type_}: {info}",
|
||||
discovery_result=discovery_result.get_dict(),
|
||||
)
|
||||
if (
|
||||
protocol := get_protocol_from_connection_name(
|
||||
encrypt_type_, ip, credentials=credentials
|
||||
if (protocol := get_protocol(config)) is None:
|
||||
_LOGGER.warning(
|
||||
"Got unsupported connection type: %s", config.connection_type.to_dict()
|
||||
)
|
||||
) is None:
|
||||
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}",
|
||||
f"Unsupported encryption scheme {config.host} of "
|
||||
+ f"type {config.connection_type.to_dict()}: {info}",
|
||||
discovery_result=discovery_result.get_dict(),
|
||||
)
|
||||
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
device = device_class(ip, port=port, credentials=credentials)
|
||||
device.protocol = protocol
|
||||
device.update_from_discover_info(discovery_result.get_dict())
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, info)
|
||||
device = device_class(config.host, protocol=protocol)
|
||||
|
||||
di = discovery_result.get_dict()
|
||||
di["model"] = discovery_result.device_model
|
||||
di["alias"] = UNAVAILABLE_ALIAS
|
||||
di["nickname"] = UNAVAILABLE_NICKNAME
|
||||
device.update_from_discover_info(di)
|
||||
return device
|
||||
|
||||
|
||||
class DiscoveryResult(BaseModel):
|
||||
"""Base model for discovery result."""
|
||||
|
||||
class Config:
|
||||
"""Class for configuring model behaviour."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
class EncryptionScheme(BaseModel):
|
||||
"""Base model for encryption scheme of discovery result."""
|
||||
|
||||
is_support_https: Optional[bool] = None
|
||||
encrypt_type: Optional[str] = None
|
||||
http_port: Optional[int] = None
|
||||
lv: Optional[int] = 1
|
||||
is_support_https: bool
|
||||
encrypt_type: str
|
||||
http_port: int
|
||||
lv: Optional[int] = None
|
||||
|
||||
device_type: str = Field(alias="device_type_text")
|
||||
device_model: str = Field(alias="model")
|
||||
ip: str = Field(alias="alias")
|
||||
device_type: str
|
||||
device_model: str
|
||||
ip: str
|
||||
mac: str
|
||||
mgt_encrypt_schm: EncryptionScheme
|
||||
device_id: str
|
||||
|
||||
device_id: Optional[str] = Field(default=None, alias="device_id_hash")
|
||||
owner: Optional[str] = Field(default=None, alias="device_owner_hash")
|
||||
hw_ver: Optional[str] = None
|
||||
owner: Optional[str] = None
|
||||
is_support_iot_cloud: Optional[bool] = None
|
||||
obd_src: Optional[str] = None
|
||||
factory_default: Optional[bool] = None
|
||||
@ -453,5 +487,5 @@ class DiscoveryResult(BaseModel):
|
||||
containing only the values actually set and with aliases as field names.
|
||||
"""
|
||||
return self.dict(
|
||||
by_alias=True, exclude_unset=True, exclude_none=True, exclude_defaults=True
|
||||
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
|
||||
)
|
||||
|
@ -17,12 +17,11 @@ class IotProtocol(TPLinkProtocol):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
transport: BaseTransport,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
super().__init__(host, transport=transport)
|
||||
super().__init__(transport=transport)
|
||||
|
||||
self._query_lock = asyncio.Lock()
|
||||
|
||||
@ -39,25 +38,21 @@ class IotProtocol(TPLinkProtocol):
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except httpx.CloseError as sdex:
|
||||
await self.close()
|
||||
except httpx.ConnectError as sdex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except httpx.ConnectError as cex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}: {cex}"
|
||||
) from cex
|
||||
except TimeoutError as tex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device, timed out: {self._host}: {tex}"
|
||||
) from tex
|
||||
except AuthenticationException as auex:
|
||||
await self.close()
|
||||
_LOGGER.debug(
|
||||
"Unable to authenticate with %s, not retrying", self._host
|
||||
)
|
||||
@ -70,8 +65,8 @@ class IotProtocol(TPLinkProtocol):
|
||||
)
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}: {ex}"
|
||||
|
@ -54,6 +54,7 @@ from cryptography.hazmat.primitives import hashes, padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .json import loads as json_loads
|
||||
from .protocol import BaseTransport, md5
|
||||
@ -82,27 +83,21 @@ class KlapTransport(BaseTransport):
|
||||
protocol, used by newer firmware versions.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
DEFAULT_PORT: int = 80
|
||||
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
|
||||
|
||||
KASA_SETUP_EMAIL = "kasa@tp-link.net"
|
||||
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
|
||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: DeviceConfig,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
host,
|
||||
port=port or self.DEFAULT_PORT,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
)
|
||||
super().__init__(config=config)
|
||||
|
||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
||||
self._local_seed: Optional[bytes] = None
|
||||
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
||||
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
||||
@ -116,14 +111,24 @@ class KlapTransport(BaseTransport):
|
||||
self._session_expire_at: Optional[float] = None
|
||||
|
||||
self._session_cookie = None
|
||||
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||
|
||||
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
||||
|
||||
@property
|
||||
def default_port(self):
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
@property
|
||||
def _http_client(self) -> httpx.AsyncClient:
|
||||
if self._config.http_client:
|
||||
return self._config.http_client
|
||||
if not self._default_http_client:
|
||||
self._default_http_client = httpx.AsyncClient()
|
||||
return self._default_http_client
|
||||
|
||||
async def client_post(self, url, params=None, data=None):
|
||||
"""Send an http post request to the device."""
|
||||
if not self._http_client:
|
||||
self._http_client = httpx.AsyncClient()
|
||||
response_data = None
|
||||
cookies = None
|
||||
if self._session_cookie:
|
||||
@ -355,8 +360,8 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport."""
|
||||
client = self._http_client
|
||||
self._http_client = None
|
||||
client = self._default_http_client
|
||||
self._default_http_client = None
|
||||
self._handshake_done = False
|
||||
if client:
|
||||
await client.aclose()
|
||||
@ -390,7 +395,7 @@ class KlapTransport(BaseTransport):
|
||||
return md5(un.encode())
|
||||
|
||||
|
||||
class TPlinkKlapTransportV2(KlapTransport):
|
||||
class KlapTransportV2(KlapTransport):
|
||||
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
|
||||
|
||||
@staticmethod
|
||||
|
@ -24,7 +24,7 @@ from typing import Dict, Generator, Optional, Union
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .exceptions import SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .json import loads as json_loads
|
||||
@ -48,17 +48,20 @@ class BaseTransport(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: DeviceConfig,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._credentials = credentials or Credentials(username="", password="")
|
||||
self._timeout = timeout or self.DEFAULT_TIMEOUT
|
||||
self._config = config
|
||||
self._host = config.host
|
||||
self._port = config.port_override or self.default_port
|
||||
self._credentials = config.credentials
|
||||
self._timeout = config.timeout
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def default_port(self) -> int:
|
||||
"""The default port for the transport."""
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, request: str) -> Dict:
|
||||
@ -74,7 +77,6 @@ class TPLinkProtocol(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
transport: BaseTransport,
|
||||
) -> None:
|
||||
@ -85,6 +87,11 @@ class TPLinkProtocol(ABC):
|
||||
def _host(self):
|
||||
return self._transport._host
|
||||
|
||||
@property
|
||||
def config(self) -> DeviceConfig:
|
||||
"""Return the connection parameters the device is using."""
|
||||
return self._transport._config
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Query the device for the protocol. Abstract method to be overriden."""
|
||||
@ -103,22 +110,15 @@ class _XorTransport(BaseTransport):
|
||||
class.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 9999
|
||||
DEFAULT_PORT: int = 9999
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
host,
|
||||
port=port or self.DEFAULT_PORT,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
)
|
||||
def __init__(self, *, config: DeviceConfig) -> None:
|
||||
super().__init__(config=config)
|
||||
|
||||
@property
|
||||
def default_port(self):
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
async def send(self, request: str) -> Dict:
|
||||
"""Send a message to the device and return a response."""
|
||||
@ -133,17 +133,15 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
|
||||
INITIALIZATION_VECTOR = 171
|
||||
DEFAULT_PORT = 9999
|
||||
DEFAULT_TIMEOUT = 5
|
||||
BLOCK_SIZE = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
transport: BaseTransport,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
super().__init__(host, transport=transport)
|
||||
super().__init__(transport=transport)
|
||||
|
||||
self.reader: Optional[asyncio.StreamReader] = None
|
||||
self.writer: Optional[asyncio.StreamWriter] = None
|
||||
@ -167,7 +165,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
assert isinstance(request, str) # noqa: S101
|
||||
|
||||
async with self.query_lock:
|
||||
return await self._query(request, retry_count, self._timeout)
|
||||
return await self._query(request, retry_count, self._timeout) # type: ignore[arg-type]
|
||||
|
||||
async def _connect(self, timeout: int) -> None:
|
||||
"""Try to connect or reconnect to the device."""
|
||||
|
@ -9,8 +9,9 @@ try:
|
||||
except ImportError:
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
from .credentials import Credentials
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .modules import Antitheft, Cloud, Countdown, Emeter, Schedule, Time, Usage
|
||||
from .protocol import TPLinkProtocol
|
||||
from .smartdevice import DeviceType, SmartDevice, SmartDeviceException, requires_update
|
||||
|
||||
|
||||
@ -220,11 +221,10 @@ class SmartBulb(SmartDevice):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[TPLinkProtocol] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host, port=port, credentials=credentials, timeout=timeout)
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self._device_type = DeviceType.Bulb
|
||||
self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule"))
|
||||
self.add_module("usage", Usage(self, "smartlife.iot.common.schedule"))
|
||||
|
@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from .credentials import Credentials
|
||||
from .device_type import DeviceType
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .emeterstatus import EmeterStatus
|
||||
from .exceptions import SmartDeviceException
|
||||
from .modules import Emeter, Module
|
||||
@ -191,20 +192,18 @@ class SmartDevice:
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[TPLinkProtocol] = None,
|
||||
) -> None:
|
||||
"""Create a new SmartDevice instance.
|
||||
|
||||
:param str host: host name or ip address on which the device listens
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol(
|
||||
host, transport=_XorTransport(host, port=port, timeout=timeout)
|
||||
if config and protocol:
|
||||
protocol._transport._config = config
|
||||
self.protocol: TPLinkProtocol = protocol or TPLinkSmartHomeProtocol(
|
||||
transport=_XorTransport(config=config or DeviceConfig(host=host)),
|
||||
)
|
||||
self.credentials = credentials
|
||||
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
|
||||
self._device_type = DeviceType.Unknown
|
||||
# TODO: typing Any is just as using Optional[Dict] would require separate
|
||||
@ -219,6 +218,30 @@ class SmartDevice:
|
||||
|
||||
self.children: List["SmartDevice"] = []
|
||||
|
||||
@property
|
||||
def host(self) -> str:
|
||||
"""The device host."""
|
||||
return self.protocol._transport._host
|
||||
|
||||
@host.setter
|
||||
def host(self, value):
|
||||
"""Set the device host.
|
||||
|
||||
Generally used by discovery to set the hostname after ip discovery.
|
||||
"""
|
||||
self.protocol._transport._host = value
|
||||
self.protocol._transport._config.host = value
|
||||
|
||||
@property
|
||||
def port(self) -> int:
|
||||
"""The device port."""
|
||||
return self.protocol._transport._port
|
||||
|
||||
@property
|
||||
def credentials(self) -> Optional[Credentials]:
|
||||
"""The device credentials."""
|
||||
return self.protocol._transport._credentials
|
||||
|
||||
def add_module(self, name: str, module: Module):
|
||||
"""Register a module."""
|
||||
if name in self.modules:
|
||||
@ -760,7 +783,7 @@ class SmartDevice:
|
||||
The returned object contains the raw results from the last update call.
|
||||
This should only be used for debugging purposes.
|
||||
"""
|
||||
return self._last_update
|
||||
return self._last_update or self._discovery_info
|
||||
|
||||
def __repr__(self):
|
||||
if self._last_update is None:
|
||||
@ -771,41 +794,33 @@ class SmartDevice:
|
||||
f" - dev specific: {self.state_information}>"
|
||||
)
|
||||
|
||||
@property
|
||||
def config(self) -> DeviceConfig:
|
||||
"""Return the connection parameters the device is using."""
|
||||
return self.protocol.config
|
||||
|
||||
@staticmethod
|
||||
async def connect(
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
timeout=5,
|
||||
credentials: Optional[Credentials] = None,
|
||||
device_type: Optional[DeviceType] = None,
|
||||
host: Optional[str] = None,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
) -> "SmartDevice":
|
||||
"""Connect to a single device by the given IP address.
|
||||
"""Connect to a single device by the given hostname or device configuration.
|
||||
|
||||
This method avoids the UDP based discovery process and
|
||||
will connect directly to the device to query its type.
|
||||
will connect directly to the device.
|
||||
|
||||
It is generally preferred to avoid :func:`discover_single()` and
|
||||
use this function instead as it should perform better when
|
||||
the WiFi network is congested or the device is not responding
|
||||
to discovery requests.
|
||||
|
||||
The device type is discovered by querying the device.
|
||||
|
||||
:param host: Hostname of device to query
|
||||
:param device_type: Device type to use for the device.
|
||||
If not given, the device type is discovered by querying the device.
|
||||
If the device type is already known, it is preferred to pass it
|
||||
to avoid the extra query to the device to discover its type.
|
||||
:param config: Connection parameters to ensure the correct protocol
|
||||
and connection options are used.
|
||||
:rtype: SmartDevice
|
||||
:return: Object for querying/controlling found device.
|
||||
"""
|
||||
from .device_factory import connect # pylint: disable=import-outside-toplevel
|
||||
|
||||
return await connect(
|
||||
host=host,
|
||||
port=port,
|
||||
timeout=timeout,
|
||||
credentials=credentials,
|
||||
device_type=device_type,
|
||||
)
|
||||
return await connect(host=host, config=config) # type: ignore[arg-type]
|
||||
|
@ -2,8 +2,9 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.deviceconfig import DeviceConfig
|
||||
from kasa.modules import AmbientLight, Motion
|
||||
from kasa.protocol import TPLinkProtocol
|
||||
from kasa.smartdevice import DeviceType, SmartDeviceException, requires_update
|
||||
from kasa.smartplug import SmartPlug
|
||||
|
||||
@ -68,11 +69,10 @@ class SmartDimmer(SmartPlug):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[TPLinkProtocol] = None,
|
||||
) -> None:
|
||||
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self._device_type = DeviceType.Dimmer
|
||||
# TODO: need to be verified if it's okay to call these on HS220 w/o these
|
||||
# TODO: need to be figured out what's the best approach to detect support
|
||||
|
@ -1,8 +1,9 @@
|
||||
"""Module for light strips (KL430)."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .credentials import Credentials
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .effects import EFFECT_MAPPING_V1, EFFECT_NAMES_V1
|
||||
from .protocol import TPLinkProtocol
|
||||
from .smartbulb import SmartBulb
|
||||
from .smartdevice import DeviceType, SmartDeviceException, requires_update
|
||||
|
||||
@ -46,11 +47,10 @@ class SmartLightStrip(SmartBulb):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[TPLinkProtocol] = None,
|
||||
) -> None:
|
||||
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self._device_type = DeviceType.LightStrip
|
||||
|
||||
@property # type: ignore
|
||||
|
@ -2,8 +2,9 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.deviceconfig import DeviceConfig
|
||||
from kasa.modules import Antitheft, Cloud, Schedule, Time, Usage
|
||||
from kasa.protocol import TPLinkProtocol
|
||||
from kasa.smartdevice import DeviceType, SmartDevice, requires_update
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -43,11 +44,10 @@ class SmartPlug(SmartDevice):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[TPLinkProtocol] = None,
|
||||
) -> None:
|
||||
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self._device_type = DeviceType.Plug
|
||||
self.add_module("schedule", Schedule(self, "schedule"))
|
||||
self.add_module("usage", Usage(self, "schedule"))
|
||||
|
@ -38,12 +38,11 @@ class SmartProtocol(TPLinkProtocol):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
transport: BaseTransport,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
super().__init__(host, transport=transport)
|
||||
super().__init__(transport=transport)
|
||||
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
|
||||
self._request_id_generator = SnowflakeId(1, 1)
|
||||
self._query_lock = asyncio.Lock()
|
||||
@ -68,19 +67,14 @@ class SmartProtocol(TPLinkProtocol):
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except httpx.CloseError as sdex:
|
||||
await self.close()
|
||||
except httpx.ConnectError as sdex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except httpx.ConnectError as cex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}: {cex}"
|
||||
) from cex
|
||||
except TimeoutError as tex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
|
@ -14,8 +14,9 @@ from kasa.smartdevice import (
|
||||
)
|
||||
from kasa.smartplug import SmartPlug
|
||||
|
||||
from .credentials import Credentials
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .modules import Antitheft, Countdown, Emeter, Schedule, Time, Usage
|
||||
from .protocol import TPLinkProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -85,11 +86,10 @@ class SmartStrip(SmartDevice):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[TPLinkProtocol] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host, port=port, credentials=credentials, timeout=timeout)
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self.emeter_type = "emeter"
|
||||
self._device_type = DeviceType.Strip
|
||||
self.add_module("antitheft", Antitheft(self, "anti_theft"))
|
||||
|
@ -5,8 +5,9 @@ from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Set, cast
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import AuthenticationException
|
||||
from ..protocol import TPLinkProtocol
|
||||
from ..smartdevice import SmartDevice
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
@ -20,20 +21,16 @@ class TapoDevice(SmartDevice):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[TPLinkProtocol] = None,
|
||||
) -> None:
|
||||
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||
_protocol = protocol or SmartProtocol(
|
||||
transport=AesTransport(config=config or DeviceConfig(host=host)),
|
||||
)
|
||||
super().__init__(host=host, config=config, protocol=_protocol)
|
||||
self._components: Optional[Dict[str, Any]] = None
|
||||
self._state_information: Dict[str, Any] = {}
|
||||
self._discovery_info: Optional[Dict[str, Any]] = None
|
||||
self.protocol = SmartProtocol(
|
||||
host,
|
||||
transport=AesTransport(
|
||||
host, credentials=credentials, timeout=timeout, port=port
|
||||
),
|
||||
)
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Update the device."""
|
||||
@ -66,7 +63,7 @@ class TapoDevice(SmartDevice):
|
||||
@property
|
||||
def sys_info(self) -> Dict[str, Any]:
|
||||
"""Returns the device info."""
|
||||
return self._info
|
||||
return self._info # type: ignore
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
@ -180,3 +177,4 @@ class TapoDevice(SmartDevice):
|
||||
def update_from_discover_info(self, info):
|
||||
"""Update state from info from the discover call."""
|
||||
self._discovery_info = info
|
||||
self._info = info
|
||||
|
@ -3,9 +3,10 @@ import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..emeterstatus import EmeterStatus
|
||||
from ..modules import Emeter
|
||||
from ..protocol import TPLinkProtocol
|
||||
from ..smartdevice import DeviceType, requires_update
|
||||
from .tapodevice import TapoDevice
|
||||
|
||||
@ -19,11 +20,10 @@ class TapoPlug(TapoDevice):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: Optional[DeviceConfig] = None,
|
||||
protocol: Optional[TPLinkProtocol] = None,
|
||||
) -> None:
|
||||
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||
super().__init__(host=host, config=config, protocol=protocol)
|
||||
self._device_type = DeviceType.Plug
|
||||
self.modules: Dict[str, Any] = {}
|
||||
self.emeter_type = "emeter"
|
||||
|
@ -388,7 +388,6 @@ async def get_device_for_file(file, protocol):
|
||||
d = device_for_file(model, protocol)(host="127.0.0.123")
|
||||
if protocol == "SMART":
|
||||
d.protocol = FakeSmartProtocol(sysinfo)
|
||||
d.credentials = Credentials("", "")
|
||||
else:
|
||||
d.protocol = FakeTransportProtocol(sysinfo)
|
||||
await _update_and_close(d)
|
||||
@ -426,28 +425,53 @@ def discovery_mock(all_fixture_data, mocker):
|
||||
class _DiscoveryMock:
|
||||
ip: str
|
||||
default_port: int
|
||||
discovery_port: int
|
||||
discovery_data: dict
|
||||
query_data: dict
|
||||
device_type: str
|
||||
encrypt_type: str
|
||||
port_override: Optional[int] = None
|
||||
|
||||
if "discovery_result" in all_fixture_data:
|
||||
discovery_data = {"result": all_fixture_data["discovery_result"]}
|
||||
device_type = all_fixture_data["discovery_result"]["device_type"]
|
||||
encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][
|
||||
"encrypt_type"
|
||||
]
|
||||
datagram = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
+ json_dumps(discovery_data).encode()
|
||||
)
|
||||
dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data, all_fixture_data)
|
||||
dm = _DiscoveryMock(
|
||||
"127.0.0.123",
|
||||
80,
|
||||
20002,
|
||||
discovery_data,
|
||||
all_fixture_data,
|
||||
device_type,
|
||||
encrypt_type,
|
||||
)
|
||||
else:
|
||||
sys_info = all_fixture_data["system"]["get_sysinfo"]
|
||||
discovery_data = {"system": {"get_sysinfo": sys_info}}
|
||||
device_type = sys_info.get("mic_type") or sys_info.get("type")
|
||||
encrypt_type = "XOR"
|
||||
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
|
||||
dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data, all_fixture_data)
|
||||
dm = _DiscoveryMock(
|
||||
"127.0.0.123",
|
||||
9999,
|
||||
9999,
|
||||
discovery_data,
|
||||
all_fixture_data,
|
||||
device_type,
|
||||
encrypt_type,
|
||||
)
|
||||
|
||||
def mock_discover(self):
|
||||
port = (
|
||||
dm.port_override
|
||||
if dm.port_override and dm.default_port != 20002
|
||||
else dm.default_port
|
||||
if dm.port_override and dm.discovery_port != 20002
|
||||
else dm.discovery_port
|
||||
)
|
||||
self.datagram_received(
|
||||
datagram,
|
||||
|
@ -15,7 +15,9 @@ from voluptuous import (
|
||||
Schema,
|
||||
)
|
||||
|
||||
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol, _XorTransport
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -290,7 +292,9 @@ TIME_MODULE = {
|
||||
|
||||
class FakeSmartProtocol(SmartProtocol):
|
||||
def __init__(self, info):
|
||||
super().__init__("127.0.0.123", transport=FakeSmartTransport(info))
|
||||
super().__init__(
|
||||
transport=FakeSmartTransport(info),
|
||||
)
|
||||
|
||||
async def query(self, request, retry_count: int = 3):
|
||||
"""Implement query here so can still patch SmartProtocol.query."""
|
||||
@ -301,10 +305,15 @@ class FakeSmartProtocol(SmartProtocol):
|
||||
class FakeSmartTransport(BaseTransport):
|
||||
def __init__(self, info):
|
||||
super().__init__(
|
||||
"127.0.0.123",
|
||||
config=DeviceConfig("127.0.0.123", credentials=Credentials()),
|
||||
)
|
||||
self.info = info
|
||||
|
||||
@property
|
||||
def default_port(self):
|
||||
"""Default port for the transport."""
|
||||
return 80
|
||||
|
||||
async def send(self, request: str):
|
||||
request_dict = json_loads(request)
|
||||
method = request_dict["method"]
|
||||
@ -344,6 +353,11 @@ class FakeSmartTransport(BaseTransport):
|
||||
|
||||
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
||||
def __init__(self, info):
|
||||
super().__init__(
|
||||
transport=_XorTransport(
|
||||
config=DeviceConfig("127.0.0.123"),
|
||||
)
|
||||
)
|
||||
self.discovery_data = info
|
||||
self.writer = None
|
||||
self.reader = None
|
||||
|
@ -12,6 +12,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
|
||||
|
||||
from ..aestransport import AesEncyptionSession, AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import (
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
SMART_TIMEOUT_ERRORS,
|
||||
@ -58,7 +59,9 @@ async def test_handshake(
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
transport = AesTransport(
|
||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
)
|
||||
|
||||
assert transport._encryption_session is None
|
||||
assert transport._handshake_done is False
|
||||
@ -74,7 +77,9 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
transport = AesTransport(
|
||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
)
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
@ -91,13 +96,14 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
transport = AesTransport(
|
||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
)
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
transport._login_token = mock_aes_device.token
|
||||
|
||||
un, pw = transport.hash_credentials(True)
|
||||
request = {
|
||||
"method": "get_device_info",
|
||||
"params": None,
|
||||
@ -119,7 +125,8 @@ async def test_passthrough_errors(mocker, error_code):
|
||||
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
transport = AesTransport(config=config)
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
|
@ -4,10 +4,26 @@ import asyncclick as click
|
||||
import pytest
|
||||
from asyncclick.testing import CliRunner
|
||||
|
||||
from kasa import AuthenticationException, SmartDevice, UnsupportedDeviceException
|
||||
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
|
||||
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
|
||||
from kasa.discover import Discover
|
||||
from kasa import (
|
||||
AuthenticationException,
|
||||
Credentials,
|
||||
SmartDevice,
|
||||
TPLinkSmartHomeProtocol,
|
||||
UnsupportedDeviceException,
|
||||
)
|
||||
from kasa.cli import (
|
||||
TYPE_TO_CLASS,
|
||||
alias,
|
||||
brightness,
|
||||
cli,
|
||||
emeter,
|
||||
raw_command,
|
||||
state,
|
||||
sysinfo,
|
||||
toggle,
|
||||
)
|
||||
from kasa.discover import Discover, DiscoveryResult
|
||||
from kasa.smartprotocol import SmartProtocol
|
||||
|
||||
from .conftest import device_iot, handle_turn_on, new_discovery, turn_on
|
||||
|
||||
@ -145,9 +161,11 @@ async def test_credentials(discovery_mock, mocker):
|
||||
)
|
||||
|
||||
mocker.patch("kasa.cli.state", new=_state)
|
||||
for subclass in DEVICE_TYPE_TO_CLASS.values():
|
||||
mocker.patch.object(subclass, "update")
|
||||
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=discovery_mock.query_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=discovery_mock.query_data)
|
||||
|
||||
dr = DiscoveryResult(**discovery_mock.discovery_data["result"])
|
||||
runner = CliRunner()
|
||||
res = await runner.invoke(
|
||||
cli,
|
||||
@ -158,6 +176,10 @@ async def test_credentials(discovery_mock, mocker):
|
||||
"foo",
|
||||
"--password",
|
||||
"bar",
|
||||
"--device-family",
|
||||
dr.device_type,
|
||||
"--encrypt-type",
|
||||
dr.mgt_encrypt_schm.encrypt_type,
|
||||
],
|
||||
)
|
||||
assert res.exit_code == 0
|
||||
@ -166,7 +188,7 @@ async def test_credentials(discovery_mock, mocker):
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_without_device_type(discovery_data: dict, dev, mocker):
|
||||
async def test_without_device_type(dev, mocker):
|
||||
"""Test connecting without the device type."""
|
||||
runner = CliRunner()
|
||||
mocker.patch("kasa.discover.Discover.discover_single", return_value=dev)
|
||||
@ -342,3 +364,27 @@ async def test_host_auth_failed(discovery_mock, mocker):
|
||||
|
||||
assert res.exit_code != 0
|
||||
assert isinstance(res.exception, AuthenticationException)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device_type", list(TYPE_TO_CLASS))
|
||||
async def test_type_param(device_type, mocker):
|
||||
"""Test for handling only one of username or password supplied."""
|
||||
runner = CliRunner()
|
||||
|
||||
result_device = FileNotFoundError
|
||||
pass_dev = click.make_pass_decorator(SmartDevice)
|
||||
|
||||
@pass_dev
|
||||
async def _state(dev: SmartDevice):
|
||||
nonlocal result_device
|
||||
result_device = dev
|
||||
|
||||
mocker.patch("kasa.cli.state", new=_state)
|
||||
expected_type = TYPE_TO_CLASS[device_type]
|
||||
mocker.patch.object(expected_type, "update")
|
||||
res = await runner.invoke(
|
||||
cli,
|
||||
["--type", device_type, "--host", "127.0.0.1"],
|
||||
)
|
||||
assert res.exit_code == 0
|
||||
assert isinstance(result_device, expected_type)
|
||||
|
@ -2,6 +2,7 @@
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
import httpx
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import (
|
||||
@ -15,122 +16,138 @@ from kasa import (
|
||||
SmartLightStrip,
|
||||
SmartPlug,
|
||||
)
|
||||
from kasa.device_factory import (
|
||||
DEVICE_TYPE_TO_CLASS,
|
||||
connect,
|
||||
get_protocol_from_connection_name,
|
||||
from kasa.device_factory import connect, get_protocol
|
||||
from kasa.deviceconfig import (
|
||||
ConnectionType,
|
||||
DeviceConfig,
|
||||
DeviceFamilyType,
|
||||
EncryptType,
|
||||
)
|
||||
from kasa.discover import DiscoveryResult
|
||||
from kasa.iotprotocol import IotProtocol
|
||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
async def test_connect(discovery_data: dict, mocker, custom_port):
|
||||
"""Make sure that connect returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
def _get_connection_type_device_class(the_fixture_data):
|
||||
if "discovery_result" in the_fixture_data:
|
||||
discovery_info = {"result": the_fixture_data["discovery_result"]}
|
||||
device_class = Discover._get_device_class(discovery_info)
|
||||
dr = DiscoveryResult(**discovery_info["result"])
|
||||
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
dev = await connect(host, port=custom_port)
|
||||
connection_type = ConnectionType.from_values(
|
||||
dr.device_type, dr.mgt_encrypt_schm.encrypt_type
|
||||
)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
dev = await connect(host, port=custom_port)
|
||||
assert issubclass(dev.__class__, SmartDevice)
|
||||
assert dev.port == custom_port or dev.port == 9999
|
||||
connection_type = ConnectionType.from_values(
|
||||
DeviceFamilyType.IotSmartPlugSwitch.value, EncryptType.Xor.value
|
||||
)
|
||||
device_class = Discover._get_device_class(the_fixture_data)
|
||||
|
||||
return connection_type, device_class
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
@pytest.mark.parametrize(
|
||||
("device_type", "klass"),
|
||||
(
|
||||
(DeviceType.Plug, SmartPlug),
|
||||
(DeviceType.Bulb, SmartBulb),
|
||||
(DeviceType.Dimmer, SmartDimmer),
|
||||
(DeviceType.LightStrip, SmartLightStrip),
|
||||
(DeviceType.Unknown, SmartDevice),
|
||||
),
|
||||
)
|
||||
async def test_connect_passed_device_type(
|
||||
discovery_data: dict,
|
||||
mocker,
|
||||
device_type: DeviceType,
|
||||
klass: Type[SmartDevice],
|
||||
custom_port,
|
||||
):
|
||||
"""Make sure that connect with a passed device type."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
dev = await connect(host, port=custom_port)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
dev = await connect(host, port=custom_port, device_type=device_type)
|
||||
assert isinstance(dev, klass)
|
||||
assert dev.port == custom_port or dev.port == 9999
|
||||
|
||||
|
||||
async def test_connect_query_fails(discovery_data: dict, mocker):
|
||||
"""Make sure that connect fails when query fails."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
|
||||
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(host)
|
||||
|
||||
|
||||
async def test_connect_logs_connect_time(
|
||||
discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker
|
||||
):
|
||||
"""Test that the connect time is logged when debug logging is enabled."""
|
||||
host = "127.0.0.1"
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(host)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
||||
await connect(host)
|
||||
assert "seconds to connect" in caplog.text
|
||||
|
||||
|
||||
async def test_connect_pass_protocol(
|
||||
async def test_connect(
|
||||
all_fixture_data: dict,
|
||||
mocker,
|
||||
):
|
||||
"""Test that if the protocol is passed in it's gets set correctly."""
|
||||
if "discovery_result" in all_fixture_data:
|
||||
discovery_info = {"result": all_fixture_data["discovery_result"]}
|
||||
device_class = Discover._get_device_class(discovery_info)
|
||||
else:
|
||||
device_class = Discover._get_device_class(all_fixture_data)
|
||||
|
||||
device_type = list(DEVICE_TYPE_TO_CLASS.keys())[
|
||||
list(DEVICE_TYPE_TO_CLASS.values()).index(device_class)
|
||||
]
|
||||
"""Test that if the protocol is passed in it gets set correctly."""
|
||||
host = "127.0.0.1"
|
||||
if "discovery_result" in all_fixture_data:
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
ctype, device_class = _get_connection_type_device_class(all_fixture_data)
|
||||
|
||||
dr = DiscoveryResult(**discovery_info["result"])
|
||||
connection_name = (
|
||||
dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type
|
||||
)
|
||||
protocol_class = get_protocol_from_connection_name(
|
||||
connection_name, host
|
||||
).__class__
|
||||
else:
|
||||
mocker.patch(
|
||||
"kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data
|
||||
)
|
||||
protocol_class = TPLinkSmartHomeProtocol
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
protocol_class = get_protocol(config).__class__
|
||||
|
||||
dev = await connect(
|
||||
host,
|
||||
device_type=device_type,
|
||||
protocol_class=protocol_class,
|
||||
credentials=Credentials("", ""),
|
||||
config=config,
|
||||
)
|
||||
assert isinstance(dev, device_class)
|
||||
assert isinstance(dev.protocol, protocol_class)
|
||||
|
||||
assert dev.config == config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):
|
||||
"""Make sure that connect returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
config = DeviceConfig(host=host, port_override=custom_port, connection_type=ctype)
|
||||
default_port = 80 if "discovery_result" in all_fixture_data else 9999
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
dev = await connect(config=config)
|
||||
assert issubclass(dev.__class__, SmartDevice)
|
||||
assert dev.port == custom_port or dev.port == default_port
|
||||
|
||||
|
||||
async def test_connect_logs_connect_time(
|
||||
all_fixture_data: dict, caplog: pytest.LogCaptureFixture, mocker
|
||||
):
|
||||
"""Test that the connect time is logged when debug logging is enabled."""
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
host = "127.0.0.1"
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
||||
await connect(
|
||||
config=config,
|
||||
)
|
||||
assert "seconds to update" in caplog.text
|
||||
|
||||
|
||||
async def test_connect_query_fails(all_fixture_data: dict, mocker):
|
||||
"""Make sure that connect fails when query fails."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
|
||||
mocker.patch("kasa.IotProtocol.query", side_effect=SmartDeviceException)
|
||||
mocker.patch("kasa.SmartProtocol.query", side_effect=SmartDeviceException)
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(config=config)
|
||||
|
||||
|
||||
async def test_connect_http_client(all_fixture_data, mocker):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
dev = await connect(config=config)
|
||||
if ctype.encryption_type != EncryptType.Xor:
|
||||
assert dev.protocol._transport._http_client != http_client
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host,
|
||||
credentials=Credentials("foor", "bar"),
|
||||
connection_type=ctype,
|
||||
http_client=http_client,
|
||||
)
|
||||
dev = await connect(config=config)
|
||||
if ctype.encryption_type != EncryptType.Xor:
|
||||
assert dev.protocol._transport._http_client == http_client
|
||||
|
21
kasa/tests/test_deviceconfig.py
Normal file
21
kasa/tests/test_deviceconfig.py
Normal file
@ -0,0 +1,21 @@
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
|
||||
import httpx
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.deviceconfig import (
|
||||
ConnectionType,
|
||||
DeviceConfig,
|
||||
DeviceFamilyType,
|
||||
EncryptType,
|
||||
)
|
||||
|
||||
|
||||
def test_serialization():
|
||||
config = DeviceConfig(host="Foo", http_client=httpx.AsyncClient())
|
||||
config_dict = config.to_dict()
|
||||
config_json = json_dumps(config_dict)
|
||||
config2_dict = json_loads(config_json)
|
||||
config2 = DeviceConfig.from_dict(config2_dict)
|
||||
assert config == config2
|
@ -1,21 +1,29 @@
|
||||
# type: ignore
|
||||
import logging
|
||||
import re
|
||||
import socket
|
||||
|
||||
import httpx
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import (
|
||||
Credentials,
|
||||
DeviceType,
|
||||
Discover,
|
||||
SmartDevice,
|
||||
SmartDeviceException,
|
||||
SmartStrip,
|
||||
protocol,
|
||||
)
|
||||
from kasa.deviceconfig import (
|
||||
ConnectionType,
|
||||
DeviceConfig,
|
||||
DeviceFamilyType,
|
||||
EncryptType,
|
||||
)
|
||||
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps
|
||||
from kasa.exceptions import AuthenticationException, UnsupportedDeviceException
|
||||
|
||||
from .conftest import bulb, bulb_iot, dimmer, lightstrip, plug, strip
|
||||
from .conftest import bulb, bulb_iot, dimmer, lightstrip, new_discovery, plug, strip
|
||||
|
||||
UNSUPPORTED = {
|
||||
"result": {
|
||||
@ -89,13 +97,26 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
|
||||
host = "127.0.0.1"
|
||||
discovery_mock.ip = host
|
||||
discovery_mock.port_override = custom_port
|
||||
update_mock = mocker.patch.object(SmartStrip, "update")
|
||||
|
||||
x = await Discover.discover_single(host, port=custom_port)
|
||||
device_class = Discover._get_device_class(discovery_mock.discovery_data)
|
||||
update_mock = mocker.patch.object(device_class, "update")
|
||||
|
||||
x = await Discover.discover_single(
|
||||
host, port=custom_port, credentials=Credentials()
|
||||
)
|
||||
assert issubclass(x.__class__, SmartDevice)
|
||||
assert x._discovery_info is not None
|
||||
assert x.port == custom_port or x.port == discovery_mock.default_port
|
||||
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
|
||||
assert update_mock.call_count == 0
|
||||
|
||||
ct = ConnectionType.from_values(
|
||||
discovery_mock.device_type, discovery_mock.encrypt_type
|
||||
)
|
||||
uses_http = discovery_mock.default_port == 80
|
||||
config = DeviceConfig(
|
||||
host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http
|
||||
)
|
||||
assert x.config == config
|
||||
|
||||
|
||||
async def test_discover_single_hostname(discovery_mock, mocker):
|
||||
@ -104,47 +125,39 @@ async def test_discover_single_hostname(discovery_mock, mocker):
|
||||
ip = "127.0.0.1"
|
||||
|
||||
discovery_mock.ip = ip
|
||||
update_mock = mocker.patch.object(SmartStrip, "update")
|
||||
device_class = Discover._get_device_class(discovery_mock.discovery_data)
|
||||
update_mock = mocker.patch.object(device_class, "update")
|
||||
|
||||
x = await Discover.discover_single(host)
|
||||
x = await Discover.discover_single(host, credentials=Credentials())
|
||||
assert issubclass(x.__class__, SmartDevice)
|
||||
assert x._discovery_info is not None
|
||||
assert x.host == host
|
||||
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
|
||||
assert update_mock.call_count == 0
|
||||
|
||||
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
|
||||
with pytest.raises(SmartDeviceException):
|
||||
x = await Discover.discover_single(host)
|
||||
x = await Discover.discover_single(host, credentials=Credentials())
|
||||
|
||||
|
||||
async def test_discover_single_unsupported(mocker):
|
||||
async def test_discover_single_unsupported(unsupported_device_info, mocker):
|
||||
"""Make sure that discover_single handles unsupported devices correctly."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
def mock_discover(self):
|
||||
if discovery_data:
|
||||
data = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
+ json_dumps(discovery_data).encode()
|
||||
)
|
||||
self.datagram_received(data, (host, 20002))
|
||||
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
||||
|
||||
# Test with a valid unsupported response
|
||||
discovery_data = UNSUPPORTED
|
||||
with pytest.raises(
|
||||
UnsupportedDeviceException,
|
||||
match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}",
|
||||
):
|
||||
await Discover.discover_single(host)
|
||||
|
||||
# Test with no response
|
||||
discovery_data = None
|
||||
|
||||
async def test_discover_single_no_response(mocker):
|
||||
"""Make sure that discover_single handles no response correctly."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover")
|
||||
with pytest.raises(
|
||||
SmartDeviceException, match=f"Timed out getting discovery response for {host}"
|
||||
):
|
||||
await Discover.discover_single(host, timeout=0.001)
|
||||
await Discover.discover_single(host, discovery_timeout=0)
|
||||
|
||||
|
||||
INVALIDS = [
|
||||
@ -241,52 +254,82 @@ AUTHENTICATION_DATA_KLAP = {
|
||||
}
|
||||
|
||||
|
||||
async def test_discover_single_authentication(mocker):
|
||||
@new_discovery
|
||||
async def test_discover_single_authentication(discovery_mock, mocker):
|
||||
"""Make sure that discover_single handles authenticating devices correctly."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
def mock_discover(self):
|
||||
if discovery_data:
|
||||
data = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
+ json_dumps(discovery_data).encode()
|
||||
)
|
||||
self.datagram_received(data, (host, 20002))
|
||||
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
||||
discovery_mock.ip = host
|
||||
device_class = Discover._get_device_class(discovery_mock.discovery_data)
|
||||
mocker.patch.object(
|
||||
SmartDevice,
|
||||
device_class,
|
||||
"update",
|
||||
side_effect=AuthenticationException("Failed to authenticate"),
|
||||
)
|
||||
|
||||
# Test with a valid unsupported response
|
||||
discovery_data = AUTHENTICATION_DATA_KLAP
|
||||
with pytest.raises(
|
||||
AuthenticationException,
|
||||
match="Failed to authenticate",
|
||||
):
|
||||
device = await Discover.discover_single(host)
|
||||
device = await Discover.discover_single(
|
||||
host, credentials=Credentials("foo", "bar")
|
||||
)
|
||||
await device.update()
|
||||
|
||||
mocker.patch.object(SmartDevice, "update")
|
||||
device = await Discover.discover_single(host)
|
||||
mocker.patch.object(device_class, "update")
|
||||
device = await Discover.discover_single(host, credentials=Credentials("foo", "bar"))
|
||||
await device.update()
|
||||
assert device.device_type == DeviceType.Plug
|
||||
assert isinstance(device, device_class)
|
||||
|
||||
|
||||
async def test_device_update_from_new_discovery_info():
|
||||
@new_discovery
|
||||
async def test_device_update_from_new_discovery_info(discovery_data):
|
||||
device = SmartDevice("127.0.0.7")
|
||||
discover_info = DiscoveryResult(**AUTHENTICATION_DATA_KLAP["result"])
|
||||
discover_info = DiscoveryResult(**discovery_data["result"])
|
||||
discover_dump = discover_info.get_dict()
|
||||
discover_dump["alias"] = "foobar"
|
||||
discover_dump["model"] = discover_dump["device_model"]
|
||||
device.update_from_discover_info(discover_dump)
|
||||
|
||||
assert device.alias == discover_dump["alias"]
|
||||
assert device.alias == "foobar"
|
||||
assert device.mac == discover_dump["mac"].replace("-", ":")
|
||||
assert device.model == discover_dump["model"]
|
||||
assert device.model == discover_dump["device_model"]
|
||||
|
||||
with pytest.raises(
|
||||
SmartDeviceException,
|
||||
match=re.escape("You need to await update() to access the data"),
|
||||
):
|
||||
assert device.supported_modules
|
||||
|
||||
|
||||
async def test_discover_single_http_client(discovery_mock, mocker):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
discovery_mock.ip = host
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
|
||||
x: SmartDevice = await Discover.discover_single(host)
|
||||
|
||||
assert x.config.uses_http == (discovery_mock.default_port == 80)
|
||||
|
||||
if discovery_mock.default_port == 80:
|
||||
assert x.protocol._transport._http_client != http_client
|
||||
x.config.http_client = http_client
|
||||
assert x.protocol._transport._http_client == http_client
|
||||
|
||||
|
||||
async def test_discover_http_client(discovery_mock, mocker):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
discovery_mock.ip = host
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
|
||||
devices = await Discover.discover(discovery_timeout=0)
|
||||
x: SmartDevice = devices[host]
|
||||
assert x.config.uses_http == (discovery_mock.default_port == 80)
|
||||
|
||||
if discovery_mock.default_port == 80:
|
||||
assert x.protocol._transport._http_client != http_client
|
||||
x.config.http_client = http_client
|
||||
assert x.protocol._transport._http_client == http_client
|
||||
|
@ -12,9 +12,15 @@ import pytest
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import AuthenticationException, SmartDeviceException
|
||||
from ..iotprotocol import IotProtocol
|
||||
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
|
||||
from ..klaptransport import (
|
||||
KlapEncryptionSession,
|
||||
KlapTransport,
|
||||
KlapTransportV2,
|
||||
_sha256,
|
||||
)
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||
@ -31,8 +37,9 @@ class _mock_response:
|
||||
[
|
||||
(Exception("dummy exception"), True),
|
||||
(SmartDeviceException("dummy exception"), False),
|
||||
(httpx.ConnectError("dummy exception"), True),
|
||||
],
|
||||
ids=("Exception", "SmartDeviceException"),
|
||||
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
|
||||
)
|
||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||
@ -42,8 +49,10 @@ async def test_protocol_retries(
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error)
|
||||
|
||||
config = DeviceConfig(host)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol_class(host, transport=transport_class(host)).query(
|
||||
await protocol_class(transport=transport_class(config=config)).query(
|
||||
DUMMY_QUERY, retry_count=retry_count
|
||||
)
|
||||
|
||||
@ -60,10 +69,11 @@ async def test_protocol_no_retry_on_connection_error(
|
||||
conn = mocker.patch.object(
|
||||
httpx.AsyncClient,
|
||||
"post",
|
||||
side_effect=httpx.ConnectError("foo"),
|
||||
side_effect=AuthenticationException("foo"),
|
||||
)
|
||||
config = DeviceConfig(host)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol_class(host, transport=transport_class(host)).query(
|
||||
await protocol_class(transport=transport_class(config=config)).query(
|
||||
DUMMY_QUERY, retry_count=5
|
||||
)
|
||||
|
||||
@ -81,8 +91,9 @@ async def test_protocol_retry_recoverable_error(
|
||||
"post",
|
||||
side_effect=httpx.CloseError("foo"),
|
||||
)
|
||||
config = DeviceConfig(host)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol_class(host, transport=transport_class(host)).query(
|
||||
await protocol_class(transport=transport_class(config=config)).query(
|
||||
DUMMY_QUERY, retry_count=5
|
||||
)
|
||||
|
||||
@ -115,7 +126,8 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
|
||||
side_effect=_fail_one_less_than_retry_count,
|
||||
)
|
||||
|
||||
response = await protocol_class(host, transport=transport_class(host)).query(
|
||||
config = DeviceConfig(host)
|
||||
response = await protocol_class(transport=transport_class(config=config)).query(
|
||||
DUMMY_QUERY, retry_count=retry_count
|
||||
)
|
||||
assert "result" in response or "foobar" in response
|
||||
@ -136,7 +148,9 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1"))
|
||||
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
protocol = IotProtocol(transport=KlapTransport(config=config))
|
||||
|
||||
protocol._transport._handshake_done = True
|
||||
protocol._transport._session_expire_at = time.time() + 86400
|
||||
@ -181,7 +195,7 @@ def test_encrypt_unicode():
|
||||
"device_credentials, expectation",
|
||||
[
|
||||
(Credentials("foo", "bar"), does_not_raise()),
|
||||
(Credentials("", ""), does_not_raise()),
|
||||
(Credentials(), does_not_raise()),
|
||||
(
|
||||
Credentials(
|
||||
KlapTransport.KASA_SETUP_EMAIL,
|
||||
@ -196,30 +210,37 @@ def test_encrypt_unicode():
|
||||
],
|
||||
ids=("client", "blank", "kasa_setup", "shouldfail"),
|
||||
)
|
||||
async def test_handshake1(mocker, device_credentials, expectation):
|
||||
@pytest.mark.parametrize(
|
||||
"transport_class, seed_auth_hash_calc",
|
||||
[
|
||||
pytest.param(KlapTransport, lambda c, s, a: c + a, id="KLAP"),
|
||||
pytest.param(KlapTransportV2, lambda c, s, a: c + s + a, id="KLAPV2"),
|
||||
],
|
||||
)
|
||||
async def test_handshake1(
|
||||
mocker, device_credentials, expectation, transport_class, seed_auth_hash_calc
|
||||
):
|
||||
async def _return_handshake1_response(url, params=None, data=None, *_, **__):
|
||||
nonlocal client_seed, server_seed, device_auth_hash
|
||||
|
||||
client_seed = data
|
||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||
|
||||
return _mock_response(200, server_seed + client_seed_auth_hash)
|
||||
seed_auth_hash = _sha256(
|
||||
seed_auth_hash_calc(client_seed, server_seed, device_auth_hash)
|
||||
)
|
||||
return _mock_response(200, server_seed + seed_auth_hash)
|
||||
|
||||
client_seed = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(device_credentials)
|
||||
device_auth_hash = transport_class.generate_auth_hash(device_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
||||
)
|
||||
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=transport_class(config=config))
|
||||
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
with expectation:
|
||||
(
|
||||
local_seed,
|
||||
@ -233,31 +254,51 @@ async def test_handshake1(mocker, device_credentials, expectation):
|
||||
await protocol.close()
|
||||
|
||||
|
||||
async def test_handshake(mocker):
|
||||
@pytest.mark.parametrize(
|
||||
"transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2",
|
||||
[
|
||||
pytest.param(
|
||||
KlapTransport, lambda c, s, a: c + a, lambda c, s, a: s + a, id="KLAP"
|
||||
),
|
||||
pytest.param(
|
||||
KlapTransportV2,
|
||||
lambda c, s, a: c + s + a,
|
||||
lambda c, s, a: s + c + a,
|
||||
id="KLAPV2",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_handshake(
|
||||
mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
|
||||
):
|
||||
async def _return_handshake_response(url, params=None, data=None, *_, **__):
|
||||
nonlocal response_status, client_seed, server_seed, device_auth_hash
|
||||
nonlocal client_seed, server_seed, device_auth_hash
|
||||
|
||||
if url == "http://127.0.0.1/app/handshake1":
|
||||
client_seed = data
|
||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||
seed_auth_hash = _sha256(
|
||||
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
|
||||
)
|
||||
|
||||
return _mock_response(200, server_seed + client_seed_auth_hash)
|
||||
return _mock_response(200, server_seed + seed_auth_hash)
|
||||
elif url == "http://127.0.0.1/app/handshake2":
|
||||
seed_auth_hash = _sha256(
|
||||
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
|
||||
)
|
||||
assert data == seed_auth_hash
|
||||
return _mock_response(response_status, b"")
|
||||
|
||||
client_seed = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
||||
)
|
||||
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=transport_class(config=config))
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
|
||||
response_status = 200
|
||||
@ -273,7 +314,7 @@ async def test_handshake(mocker):
|
||||
|
||||
async def test_query(mocker):
|
||||
async def _return_response(url, params=None, data=None, *_, **__):
|
||||
nonlocal client_seed, server_seed, device_auth_hash, protocol, seq
|
||||
nonlocal client_seed, server_seed, device_auth_hash, seq
|
||||
|
||||
if url == "http://127.0.0.1/app/handshake1":
|
||||
client_seed = data
|
||||
@ -303,10 +344,8 @@ async def test_query(mocker):
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=KlapTransport(config=config))
|
||||
|
||||
for _ in range(10):
|
||||
resp = await protocol.query({})
|
||||
@ -350,10 +389,8 @@ async def test_authentication_failures(mocker, response_status, expectation):
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=KlapTransport(config=config))
|
||||
|
||||
with expectation:
|
||||
await protocol.query({})
|
||||
|
@ -9,6 +9,7 @@ import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import SmartDeviceException
|
||||
from ..protocol import (
|
||||
BaseTransport,
|
||||
@ -31,10 +32,11 @@ async def test_protocol_retries(mocker, retry_count):
|
||||
return reader, writer
|
||||
|
||||
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=retry_count)
|
||||
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
|
||||
{}, retry_count=retry_count
|
||||
)
|
||||
|
||||
assert conn.call_count == retry_count + 1
|
||||
|
||||
@ -44,10 +46,11 @@ async def test_protocol_no_retry_on_unreachable(mocker):
|
||||
"asyncio.open_connection",
|
||||
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
|
||||
{}, retry_count=5
|
||||
)
|
||||
|
||||
assert conn.call_count == 1
|
||||
|
||||
@ -57,10 +60,11 @@ async def test_protocol_no_retry_connection_refused(mocker):
|
||||
"asyncio.open_connection",
|
||||
side_effect=ConnectionRefusedError,
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
|
||||
{}, retry_count=5
|
||||
)
|
||||
|
||||
assert conn.call_count == 1
|
||||
|
||||
@ -70,10 +74,11 @@ async def test_protocol_retry_recoverable_error(mocker):
|
||||
"asyncio.open_connection",
|
||||
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
|
||||
{}, retry_count=5
|
||||
)
|
||||
|
||||
assert conn.call_count == 6
|
||||
|
||||
@ -107,9 +112,8 @@ async def test_protocol_reconnect(mocker, retry_count):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({}, retry_count=retry_count)
|
||||
assert response == {"great": "success"}
|
||||
@ -137,9 +141,8 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
@ -173,9 +176,8 @@ async def test_protocol_custom_port(mocker, custom_port):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port)
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1", port_override=custom_port)
|
||||
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
@ -271,18 +273,14 @@ def _get_subclasses(of_class):
|
||||
def test_protocol_init_signature(class_name_obj):
|
||||
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||
|
||||
assert len(params) == 3
|
||||
assert len(params) == 2
|
||||
assert (
|
||||
params[0].name == "self"
|
||||
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[1].name == "host"
|
||||
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[2].name == "transport"
|
||||
and params[2].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
params[1].name == "transport"
|
||||
and params[1].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
|
||||
|
||||
@ -292,20 +290,11 @@ def test_protocol_init_signature(class_name_obj):
|
||||
def test_transport_init_signature(class_name_obj):
|
||||
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||
|
||||
assert len(params) == 5
|
||||
assert len(params) == 2
|
||||
assert (
|
||||
params[0].name == "self"
|
||||
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[1].name == "host"
|
||||
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
assert (
|
||||
params[3].name == "credentials"
|
||||
and params[3].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
assert (
|
||||
params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
|
@ -5,8 +5,7 @@ from unittest.mock import Mock, patch
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
import kasa
|
||||
from kasa import Credentials, SmartDevice, SmartDeviceException
|
||||
from kasa.smartdevice import DeviceType
|
||||
from kasa import Credentials, DeviceConfig, SmartDevice, SmartDeviceException
|
||||
|
||||
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on
|
||||
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
|
||||
@ -215,7 +214,8 @@ def test_device_class_ctors(device_class):
|
||||
host = "127.0.0.2"
|
||||
port = 1234
|
||||
credentials = Credentials("foo", "bar")
|
||||
dev = device_class(host, port=port, credentials=credentials)
|
||||
config = DeviceConfig(host, port_override=port, credentials=credentials)
|
||||
dev = device_class(host, config=config)
|
||||
assert dev.host == host
|
||||
assert dev.port == port
|
||||
assert dev.credentials == credentials
|
||||
@ -231,29 +231,27 @@ async def test_modules_preserved(dev: SmartDevice):
|
||||
|
||||
async def test_create_smart_device_with_timeout():
|
||||
"""Make sure timeout is passed to the protocol."""
|
||||
dev = SmartDevice(host="127.0.0.1", timeout=100)
|
||||
host = "127.0.0.1"
|
||||
dev = SmartDevice(host, config=DeviceConfig(host, timeout=100))
|
||||
assert dev.protocol._transport._timeout == 100
|
||||
|
||||
|
||||
async def test_create_thin_wrapper():
|
||||
"""Make sure thin wrapper is created with the correct device type."""
|
||||
mock = Mock()
|
||||
config = DeviceConfig(
|
||||
host="test_host",
|
||||
port_override=1234,
|
||||
timeout=100,
|
||||
credentials=Credentials("username", "password"),
|
||||
)
|
||||
with patch("kasa.device_factory.connect", return_value=mock) as connect:
|
||||
dev = await SmartDevice.connect(
|
||||
host="test_host",
|
||||
port=1234,
|
||||
timeout=100,
|
||||
credentials=Credentials("username", "password"),
|
||||
device_type=DeviceType.Strip,
|
||||
)
|
||||
dev = await SmartDevice.connect(config=config)
|
||||
assert dev is mock
|
||||
|
||||
connect.assert_called_once_with(
|
||||
host="test_host",
|
||||
port=1234,
|
||||
timeout=100,
|
||||
credentials=Credentials("username", "password"),
|
||||
device_type=DeviceType.Strip,
|
||||
host=None,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@ import pytest
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import (
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
SMART_TIMEOUT_ERRORS,
|
||||
@ -37,7 +38,8 @@ async def test_smart_device_errors(mocker, error_code):
|
||||
|
||||
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||
|
||||
protocol = SmartProtocol(host, transport=AesTransport(host))
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol.query(DUMMY_QUERY, retry_count=2)
|
||||
|
||||
@ -70,8 +72,8 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
|
||||
mocker.patch.object(AesTransport, "perform_login")
|
||||
|
||||
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||
|
||||
protocol = SmartProtocol(host, transport=AesTransport(host))
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol.query(DUMMY_QUERY, retry_count=2)
|
||||
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
|
||||
|
Loading…
Reference in New Issue
Block a user