diff --git a/docs/source/design.rst b/docs/source/design.rst index 5679943d..6538c8b8 100644 --- a/docs/source/design.rst +++ b/docs/source/design.rst @@ -23,9 +23,12 @@ This will return you a list of device instances based on the discovery replies. If the device's host is already known, you can use to construct a device instance with :meth:`~kasa.SmartDevice.connect()`. -When connecting a device with the :meth:`~kasa.SmartDevice.connect()` method, it is recommended to -pass the device type as well as this allows the library to use the correct device class for the -device without having to query the device. +The :meth:`~kasa.SmartDevice.connect()` also enables support for connecting to new +KASA SMART protocol and TAPO devices directly using the parameter :class:`~kasa.DeviceConfig`. +Simply serialize the :attr:`~kasa.SmartDevice.config` property via :meth:`~kasa.DeviceConfig.to_dict()` +and then deserialize it later with :func:`~kasa.DeviceConfig.from_dict()` +and then pass it into :meth:`~kasa.SmartDevice.connect()`. + .. _update_cycle: diff --git a/docs/source/deviceconfig.rst b/docs/source/deviceconfig.rst new file mode 100644 index 00000000..25bf077b --- /dev/null +++ b/docs/source/deviceconfig.rst @@ -0,0 +1,18 @@ +DeviceConfig +============ + +.. contents:: Contents + :local: + +.. note:: + + Feel free to open a pull request to improve the documentation! + + +API documentation +***************** + +.. autoclass:: kasa.DeviceConfig + :members: + :inherited-members: + :undoc-members: diff --git a/docs/source/index.rst b/docs/source/index.rst index 346c53d0..16e7cbd0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -15,3 +15,4 @@ smartdimmer smartstrip smartlightstrip + deviceconfig diff --git a/kasa/__init__.py b/kasa/__init__.py index 7de394c1..f5b795bd 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -14,6 +14,12 @@ to be handled by the user of the library. from importlib.metadata import version from kasa.credentials import Credentials +from kasa.deviceconfig import ( + ConnectionType, + DeviceConfig, + DeviceFamilyType, + EncryptType, +) from kasa.discover import Discover from kasa.emeterstatus import EmeterStatus from kasa.exceptions import ( @@ -55,4 +61,8 @@ __all__ = [ "AuthenticationException", "UnsupportedDeviceException", "Credentials", + "DeviceConfig", + "ConnectionType", + "EncryptType", + "DeviceFamilyType", ] diff --git a/kasa/aestransport.py b/kasa/aestransport.py index e7dd5356..b6fa3472 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from .credentials import Credentials +from .deviceconfig import DeviceConfig from .exceptions import ( SMART_AUTHENTICATION_ERRORS, SMART_RETRYABLE_ERRORS, @@ -47,8 +47,7 @@ class AesTransport(BaseTransport): protocol, sometimes used by newer firmware versions on kasa devices. """ - DEFAULT_PORT = 80 - DEFAULT_TIMEOUT = 5 + DEFAULT_PORT: int = 80 SESSION_COOKIE_NAME = "TP_SESSIONID" COMMON_HEADERS = { "Content-Type": "application/json", @@ -58,32 +57,37 @@ class AesTransport(BaseTransport): def __init__( self, - host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: DeviceConfig, ) -> None: - super().__init__( - host, - port=port or self.DEFAULT_PORT, - credentials=credentials, - timeout=timeout, - ) + super().__init__(config=config) + + self._default_http_client: Optional[httpx.AsyncClient] = None self._handshake_done = False self._encryption_session: Optional[AesEncyptionSession] = None self._session_expire_at: Optional[float] = None - self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT self._session_cookie = None - self._http_client: httpx.AsyncClient = httpx.AsyncClient() self._login_token = None _LOGGER.debug("Created AES transport for %s", self._host) + @property + def default_port(self): + """Default port for the transport.""" + return self.DEFAULT_PORT + + @property + def _http_client(self) -> httpx.AsyncClient: + if self._config.http_client: + return self._config.http_client + if not self._default_http_client: + self._default_http_client = httpx.AsyncClient() + return self._default_http_client + def hash_credentials(self, login_v2): """Hash the credentials.""" if login_v2: @@ -102,8 +106,6 @@ class AesTransport(BaseTransport): async def client_post(self, url, params=None, data=None, json=None, headers=None): """Send an http post request to the device.""" - if not self._http_client: - self._http_client = httpx.AsyncClient() response_data = None cookies = None if self._session_cookie: @@ -268,8 +270,8 @@ class AesTransport(BaseTransport): async def close(self) -> None: """Close the protocol.""" - client = self._http_client - self._http_client = None + client = self._default_http_client + self._default_http_client = None self._handshake_done = False self._login_token = None if client: diff --git a/kasa/cli.py b/kasa/cli.py index 3478c35a..13458b0e 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -12,15 +12,20 @@ import asyncclick as click from kasa import ( AuthenticationException, + ConnectionType, Credentials, - DeviceType, + DeviceConfig, + DeviceFamilyType, Discover, + EncryptType, SmartBulb, SmartDevice, + SmartDimmer, + SmartLightStrip, + SmartPlug, SmartStrip, UnsupportedDeviceException, ) -from kasa.device_factory import DEVICE_TYPE_TO_CLASS from kasa.discover import DiscoveryResult try: @@ -49,10 +54,19 @@ except ImportError: # --json has set it to _nop_echo echo = _do_echo -DEVICE_TYPES = [ - device_type.value - for device_type in DeviceType - if device_type in DEVICE_TYPE_TO_CLASS + +TYPE_TO_CLASS = { + "plug": SmartPlug, + "bulb": SmartBulb, + "dimmer": SmartDimmer, + "strip": SmartStrip, + "lightstrip": SmartLightStrip, +} + +ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in EncryptType] + +DEVICE_FAMILY_TYPES = [ + device_family_type.value for device_family_type in DeviceFamilyType ] click.anyio_backend = "asyncio" @@ -149,7 +163,7 @@ def json_formatter_cb(result, **kwargs): "--type", envvar="KASA_TYPE", default=None, - type=click.Choice(DEVICE_TYPES, case_sensitive=False), + type=click.Choice(list(TYPE_TO_CLASS), case_sensitive=False), ) @click.option( "--json/--no-json", @@ -158,6 +172,18 @@ def json_formatter_cb(result, **kwargs): is_flag=True, help="Output raw device response as JSON.", ) +@click.option( + "--encrypt-type", + envvar="KASA_ENCRYPT_TYPE", + default=None, + type=click.Choice(ENCRYPT_TYPES, case_sensitive=False), +) +@click.option( + "--device-family", + envvar="KASA_DEVICE_FAMILY", + default=None, + type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False), +) @click.option( "--timeout", envvar="KASA_TIMEOUT", @@ -199,6 +225,8 @@ async def cli( verbose, debug, type, + encrypt_type, + device_family, json, timeout, discovery_timeout, @@ -270,12 +298,19 @@ async def cli( return await ctx.invoke(discover) if type is not None: - device_type = DeviceType.from_value(type) - dev = await SmartDevice.connect( - host, credentials=credentials, device_type=device_type, timeout=timeout + dev = TYPE_TO_CLASS[type](host) + await dev.update() + elif device_family or encrypt_type: + ctype = ConnectionType( + DeviceFamilyType(device_family), + EncryptType(encrypt_type), ) + config = DeviceConfig( + host=host, credentials=credentials, timeout=timeout, connection_type=ctype + ) + dev = await SmartDevice.connect(config=config) else: - echo("No --type defined, discovering..") + echo("No --type or --device-family and --encrypt-type defined, discovering..") dev = await Discover.discover_single( host, port=port, @@ -332,8 +367,10 @@ async def discover(ctx): target = ctx.parent.params["target"] username = ctx.parent.params["username"] password = ctx.parent.params["password"] - timeout = ctx.parent.params["discovery_timeout"] verbose = ctx.parent.params["verbose"] + discovery_timeout = ctx.parent.params["discovery_timeout"] + timeout = ctx.parent.params["timeout"] + port = ctx.parent.params["port"] credentials = Credentials(username, password) @@ -354,7 +391,7 @@ async def discover(ctx): echo(f"\t{unsupported_exception}") echo() - echo(f"Discovering devices on {target} for {timeout} seconds") + echo(f"Discovering devices on {target} for {discovery_timeout} seconds") async def print_discovered(dev: SmartDevice): async with sem: @@ -376,9 +413,11 @@ async def discover(ctx): await Discover.discover( target=target, - timeout=timeout, + discovery_timeout=discovery_timeout, on_discovered=print_discovered, on_unsupported=print_unsupported, + port=port, + timeout=timeout, credentials=credentials, ) diff --git a/kasa/credentials.py b/kasa/credentials.py index a56f5710..4ae4df35 100644 --- a/kasa/credentials.py +++ b/kasa/credentials.py @@ -8,5 +8,5 @@ from typing import Optional class Credentials: """Credentials for authentication.""" - username: Optional[str] = field(default=None, repr=False) - password: Optional[str] = field(default=None, repr=False) + username: Optional[str] = field(default="", repr=False) + password: Optional[str] = field(default="", repr=False) diff --git a/kasa/device_factory.py b/kasa/device_factory.py index d8a07bee..505b6487 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -1,18 +1,21 @@ -"""Device creation by type.""" - +"""Device creation via DeviceConfig.""" import logging import time from typing import Any, Dict, Optional, Tuple, Type from .aestransport import AesTransport -from .credentials import Credentials -from .device_type import DeviceType -from .exceptions import UnsupportedDeviceException +from .deviceconfig import DeviceConfig +from .exceptions import SmartDeviceException, UnsupportedDeviceException from .iotprotocol import IotProtocol -from .klaptransport import KlapTransport, TPlinkKlapTransportV2 -from .protocol import BaseTransport, TPLinkProtocol +from .klaptransport import KlapTransport, KlapTransportV2 +from .protocol import ( + BaseTransport, + TPLinkProtocol, + TPLinkSmartHomeProtocol, + _XorTransport, +) from .smartbulb import SmartBulb -from .smartdevice import SmartDevice, SmartDeviceException +from .smartdevice import SmartDevice from .smartdimmer import SmartDimmer from .smartlightstrip import SmartLightStrip from .smartplug import SmartPlug @@ -20,104 +23,80 @@ from .smartprotocol import SmartProtocol from .smartstrip import SmartStrip from .tapo import TapoBulb, TapoPlug -DEVICE_TYPE_TO_CLASS = { - DeviceType.Plug: SmartPlug, - DeviceType.Bulb: SmartBulb, - DeviceType.Strip: SmartStrip, - DeviceType.Dimmer: SmartDimmer, - DeviceType.LightStrip: SmartLightStrip, - DeviceType.TapoPlug: TapoPlug, - DeviceType.TapoBulb: TapoBulb, -} - _LOGGER = logging.getLogger(__name__) +GET_SYSINFO_QUERY = { + "system": {"get_sysinfo": None}, +} -async def connect( - host: str, - *, - port: Optional[int] = None, - timeout=5, - credentials: Optional[Credentials] = None, - device_type: Optional[DeviceType] = None, - protocol_class: Optional[Type[TPLinkProtocol]] = None, -) -> "SmartDevice": - """Connect to a single device by the given IP address. + +async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "SmartDevice": + """Connect to a single device by the given hostname or device configuration. This method avoids the UDP based discovery process and - will connect directly to the device to query its type. + will connect directly to the device. It is generally preferred to avoid :func:`discover_single()` and use this function instead as it should perform better when the WiFi network is congested or the device is not responding to discovery requests. - The device type is discovered by querying the device. + Do not use this function directly, use SmartDevice.connect() :param host: Hostname of device to query - :param device_type: Device type to use for the device. - If not given, the device type is discovered by querying the device. - If the device type is already known, it is preferred to pass it - to avoid the extra query to the device to discover its type. - :param protocol_class: Optionally provide the protocol class - to use. + :param config: Connection parameters to ensure the correct protocol + and connection options are used. :rtype: SmartDevice :return: Object for querying/controlling found device. """ - debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) + if host and config or (not host and not config): + raise SmartDeviceException("One of host or config must be provded and not both") + if host: + config = DeviceConfig(host=host) + debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) if debug_enabled: start_time = time.perf_counter() - if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)): - dev: SmartDevice = klass( - host=host, port=port, credentials=credentials, timeout=timeout - ) - if protocol_class is not None: - dev.protocol = protocol_class( - host, - transport=AesTransport( - host, port=port, credentials=credentials, timeout=timeout - ), - ) - await dev.update() + def _perf_log(has_params, perf_type): + nonlocal start_time if debug_enabled: end_time = time.perf_counter() _LOGGER.debug( - "Device %s with known type (%s) took %.2f seconds to connect", - host, - device_type.value, - end_time - start_time, + f"Device {config.host} with connection params {has_params} " + + f"took {end_time - start_time:.2f} seconds to {perf_type}", ) - return dev + start_time = time.perf_counter() - unknown_dev = SmartDevice( - host=host, port=port, credentials=credentials, timeout=timeout - ) - if protocol_class is not None: - # TODO this will be replaced with connection params - unknown_dev.protocol = protocol_class( - host, - transport=AesTransport( - host, port=port, credentials=credentials, timeout=timeout - ), + if (protocol := get_protocol(config=config)) is None: + raise UnsupportedDeviceException( + f"Unsupported device for {config.host}: " + + f"{config.connection_type.device_family.value}" ) - await unknown_dev.update() - device_class = get_device_class_from_sys_info(unknown_dev.internal_state) - dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout) - # Reuse the connection from the unknown device - # so we don't have to reconnect - dev.protocol = unknown_dev.protocol - await dev.update() - if debug_enabled: - end_time = time.perf_counter() - _LOGGER.debug( - "Device %s with unknown type (%s) took %.2f seconds to connect", - host, - dev.device_type.value, - end_time - start_time, + + device_class: Optional[Type[SmartDevice]] + + if isinstance(protocol, TPLinkSmartHomeProtocol): + info = await protocol.query(GET_SYSINFO_QUERY) + _perf_log(True, "get_sysinfo") + device_class = get_device_class_from_sys_info(info) + device = device_class(config.host, protocol=protocol) + device.update_from_discover_info(info) + await device.update() + _perf_log(True, "update") + return device + elif device_class := get_device_class_from_family( + config.connection_type.device_family.value + ): + device = device_class(host=config.host, protocol=protocol) + await device.update() + _perf_log(True, "update") + return device + else: + raise UnsupportedDeviceException( + f"Unsupported device for {config.host}: " + + f"{config.connection_type.device_family.value}" ) - return dev def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]: @@ -147,32 +126,38 @@ def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]: raise UnsupportedDeviceException("Unknown device type: %s" % type_) -def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]: +def get_device_class_from_family(device_type: str) -> Optional[Type[SmartDevice]]: """Return the device class from the type name.""" supported_device_types: dict[str, Type[SmartDevice]] = { "SMART.TAPOPLUG": TapoPlug, "SMART.TAPOBULB": TapoBulb, "SMART.KASAPLUG": TapoPlug, "IOT.SMARTPLUGSWITCH": SmartPlug, + "IOT.SMARTBULB": SmartBulb, } return supported_device_types.get(device_type) -def get_protocol_from_connection_name( - connection_name: str, host: str, credentials: Optional[Credentials] = None +def get_protocol( + config: DeviceConfig, ) -> Optional[TPLinkProtocol]: """Return the protocol from the connection name.""" + protocol_name = config.connection_type.device_family.value.split(".")[0] + protocol_transport_key = ( + protocol_name + "." + config.connection_type.encryption_type.value + ) supported_device_protocols: dict[ str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]] ] = { + "IOT.XOR": (TPLinkSmartHomeProtocol, _XorTransport), "IOT.KLAP": (IotProtocol, KlapTransport), "SMART.AES": (SmartProtocol, AesTransport), - "SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2), + "SMART.KLAP": (SmartProtocol, KlapTransportV2), } - if connection_name not in supported_device_protocols: + if protocol_transport_key not in supported_device_protocols: return None - protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore - transport: BaseTransport = transport_class(host, credentials=credentials) - protocol: TPLinkProtocol = protocol_class(host, transport=transport) - return protocol + protocol_class, transport_class = supported_device_protocols.get( + protocol_transport_key + ) # type: ignore + return protocol_class(transport=transport_class(config=config)) diff --git a/kasa/deviceconfig.py b/kasa/deviceconfig.py new file mode 100644 index 00000000..7a774b2e --- /dev/null +++ b/kasa/deviceconfig.py @@ -0,0 +1,148 @@ +"""Module for holding connection parameters.""" +import logging +from dataclasses import asdict, dataclass, field, fields, is_dataclass +from enum import Enum +from typing import Dict, Optional + +import httpx + +from .credentials import Credentials +from .exceptions import SmartDeviceException + +_LOGGER = logging.getLogger(__name__) + + +class EncryptType(Enum): + """Encrypt type enum.""" + + Klap = "KLAP" + Aes = "AES" + Xor = "XOR" + + +class DeviceFamilyType(Enum): + """Encrypt type enum.""" + + IotSmartPlugSwitch = "IOT.SMARTPLUGSWITCH" + IotSmartBulb = "IOT.SMARTBULB" + SmartKasaPlug = "SMART.KASAPLUG" + SmartTapoPlug = "SMART.TAPOPLUG" + SmartTapoBulb = "SMART.TAPOBULB" + + +def _dataclass_from_dict(klass, in_val): + if is_dataclass(klass): + fieldtypes = {f.name: f.type for f in fields(klass)} + val = {} + for dict_key in in_val: + if dict_key in fieldtypes and hasattr(fieldtypes[dict_key], "from_dict"): + val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) + else: + val[dict_key] = _dataclass_from_dict( + fieldtypes[dict_key], in_val[dict_key] + ) + return klass(**val) + else: + return in_val + + +def _dataclass_to_dict(in_val): + fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare} + out_val = {} + for field_name in fieldtypes: + val = getattr(in_val, field_name) + if val is None: + continue + elif hasattr(val, "to_dict"): + out_val[field_name] = val.to_dict() + elif is_dataclass(fieldtypes[field_name]): + out_val[field_name] = asdict(val) + else: + out_val[field_name] = val + return out_val + + +@dataclass +class ConnectionType: + """Class to hold the the parameters determining connection type.""" + + device_family: DeviceFamilyType + encryption_type: EncryptType + + @staticmethod + def from_values( + device_family: str, + encryption_type: str, + ) -> "ConnectionType": + """Return connection parameters from string values.""" + try: + return ConnectionType( + DeviceFamilyType(device_family), + EncryptType(encryption_type), + ) + except ValueError as ex: + raise SmartDeviceException( + f"Invalid connection parameters for {device_family}.{encryption_type}" + ) from ex + + @staticmethod + def from_dict(connection_type_dict: Dict[str, str]) -> "ConnectionType": + """Return connection parameters from dict.""" + if ( + isinstance(connection_type_dict, dict) + and (device_family := connection_type_dict.get("device_family")) + and (encryption_type := connection_type_dict.get("encryption_type")) + ): + return ConnectionType.from_values(device_family, encryption_type) + + raise SmartDeviceException( + f"Invalid connection type data for {connection_type_dict}" + ) + + def to_dict(self) -> Dict[str, str]: + """Convert connection params to dict.""" + result = { + "device_family": self.device_family.value, + "encryption_type": self.encryption_type.value, + } + return result + + +@dataclass +class DeviceConfig: + """Class to represent paramaters that determine how to connect to devices.""" + + DEFAULT_TIMEOUT = 5 + + host: str + timeout: Optional[int] = DEFAULT_TIMEOUT + port_override: Optional[int] = None + credentials: Credentials = field( + default_factory=lambda: Credentials(username="", password="") + ) + connection_type: ConnectionType = field( + default_factory=lambda: ConnectionType( + DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor + ) + ) + + uses_http: bool = False + # compare=False will be excluded from the serialization and object comparison. + http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False) + + def __post_init__(self): + if self.credentials is None: + self.credentials = Credentials(username="", password="") + if self.connection_type is None: + self.connection_type = ConnectionType( + DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor + ) + + def to_dict(self) -> Dict[str, Dict[str, str]]: + """Convert connection params to dict.""" + return _dataclass_to_dict(self) + + @staticmethod + def from_dict(cparam_dict: Dict[str, Dict[str, str]]) -> "DeviceConfig": + """Return connection parameters from dict.""" + return _dataclass_from_dict(DeviceConfig, cparam_dict) diff --git a/kasa/discover.py b/kasa/discover.py index 4ec3775e..e39122f3 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -1,5 +1,6 @@ """Discovery module for TP-Link Smart Home devices.""" import asyncio +import base64 import binascii import ipaddress import logging @@ -11,29 +12,32 @@ from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast from async_timeout import timeout as asyncio_timeout try: - from pydantic.v1 import BaseModel, Field + from pydantic.v1 import BaseModel, ValidationError # pragma: no cover except ImportError: - from pydantic import BaseModel, Field + from pydantic import BaseModel, ValidationError # pragma: no cover from kasa.credentials import Credentials +from kasa.device_factory import ( + get_device_class_from_family, + get_device_class_from_sys_info, + get_protocol, +) +from kasa.deviceconfig import ConnectionType, DeviceConfig, EncryptType from kasa.exceptions import UnsupportedDeviceException from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartdevice import SmartDevice, SmartDeviceException -from .device_factory import ( - get_device_class_from_sys_info, - get_device_class_from_type_name, - get_protocol_from_connection_name, -) - _LOGGER = logging.getLogger(__name__) OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]] DeviceDict = Dict[str, SmartDevice] +UNAVAILABLE_ALIAS = "Authentication required" +UNAVAILABLE_NICKNAME = base64.b64encode(UNAVAILABLE_ALIAS.encode()).decode() + class _DiscoverProtocol(asyncio.DatagramProtocol): """Implementation of the discovery protocol handler. @@ -62,9 +66,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.discovery_packets = discovery_packets self.interface = interface self.on_discovered = on_discovered + + self.port = port self.discovery_port = port or Discover.DISCOVERY_PORT self.target = (target, self.discovery_port) self.target_2 = (target, Discover.DISCOVERY_PORT_2) + self.discovered_devices = {} self.unsupported_device_exceptions: Dict = {} self.invalid_device_exceptions: Dict = {} @@ -110,13 +117,18 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.seen_hosts.add(ip) device = None + + config = DeviceConfig(host=ip, port_override=self.port) + if self.credentials: + config.credentials = self.credentials + if self.timeout: + config.timeout = self.timeout try: if port == self.discovery_port: - device = Discover._get_device_instance_legacy(data, ip, port) + device = Discover._get_device_instance_legacy(data, config) elif port == Discover.DISCOVERY_PORT_2: - device = Discover._get_device_instance( - data, ip, port, self.credentials or Credentials() - ) + config.uses_http = True + device = Discover._get_device_instance(data, config) else: return except UnsupportedDeviceException as udex: @@ -200,11 +212,13 @@ class Discover: *, target="255.255.255.255", on_discovered=None, - timeout=5, + discovery_timeout=5, discovery_packets=3, interface=None, on_unsupported=None, credentials=None, + port=None, + timeout=None, ) -> DeviceDict: """Discover supported devices. @@ -240,14 +254,15 @@ class Discover: on_unsupported=on_unsupported, credentials=credentials, timeout=timeout, + port=port, ), local_addr=("0.0.0.0", 0), # noqa: S104 ) protocol = cast(_DiscoverProtocol, protocol) try: - _LOGGER.debug("Waiting %s seconds for responses...", timeout) - await asyncio.sleep(timeout) + _LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout) + await asyncio.sleep(discovery_timeout) finally: transport.close() @@ -259,10 +274,10 @@ class Discover: async def discover_single( host: str, *, + discovery_timeout: int = 5, port: Optional[int] = None, - timeout=5, + timeout: Optional[int] = None, credentials: Optional[Credentials] = None, - update_parent_devices: bool = True, ) -> SmartDevice: """Discover a single device by the given IP address. @@ -275,8 +290,6 @@ class Discover: :param port: Optionally set a different port for the device :param timeout: Timeout for discovery :param credentials: Credentials for devices that require authentication - :param update_parent_devices: Automatically call device.update() on - devices that have children :rtype: SmartDevice :return: Object for querying/controlling found device. """ @@ -320,9 +333,11 @@ class Discover: protocol = cast(_DiscoverProtocol, protocol) try: - _LOGGER.debug("Waiting a total of %s seconds for responses...", timeout) + _LOGGER.debug( + "Waiting a total of %s seconds for responses...", discovery_timeout + ) - async with asyncio_timeout(timeout): + async with asyncio_timeout(discovery_timeout): await event.wait() except asyncio.TimeoutError as ex: raise SmartDeviceException( @@ -334,9 +349,6 @@ class Discover: if ip in protocol.discovered_devices: dev = protocol.discovered_devices[ip] dev.host = host - # Call device update on devices that have children - if update_parent_devices and dev.has_children: - await dev.update() return dev elif ip in protocol.unsupported_device_exceptions: raise protocol.unsupported_device_exceptions[ip] @@ -350,99 +362,121 @@ class Discover: """Find SmartDevice subclass for device described by passed data.""" if "result" in info: discovery_result = DiscoveryResult(**info["result"]) - dev_class = get_device_class_from_type_name(discovery_result.device_type) + dev_class = get_device_class_from_family(discovery_result.device_type) if not dev_class: raise UnsupportedDeviceException( - "Unknown device type: %s" % discovery_result.device_type + "Unknown device type: %s" % discovery_result.device_type, + discovery_result=info, ) return dev_class else: return get_device_class_from_sys_info(info) @staticmethod - def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice: + def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> SmartDevice: """Get SmartDevice from legacy 9999 response.""" try: info = json_loads(TPLinkSmartHomeProtocol.decrypt(data)) except Exception as ex: raise SmartDeviceException( - f"Unable to read response from device: {ip}: {ex}" + f"Unable to read response from device: {config.host}: {ex}" ) from ex - _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) + _LOGGER.debug("[DISCOVERY] %s << %s", config.host, info) device_class = Discover._get_device_class(info) - device = device_class(ip, port=port) + device = device_class(config.host, config=config) + sys_info = info["system"]["get_sysinfo"] + if device_type := sys_info.get("mic_type", sys_info.get("type")): + config.connection_type = ConnectionType.from_values( + device_family=device_type, encryption_type=EncryptType.Xor.value + ) + device.protocol = get_protocol(config) # type: ignore[assignment] device.update_from_discover_info(info) return device @staticmethod def _get_device_instance( - data: bytes, ip: str, port: int, credentials: Credentials + data: bytes, + config: DeviceConfig, ) -> SmartDevice: """Get SmartDevice from the new 20002 response.""" try: info = json_loads(data[16:]) - discovery_result = DiscoveryResult(**info["result"]) except Exception as ex: + _LOGGER.debug("Got invalid response from device %s: %s", config.host, data) + raise SmartDeviceException( + f"Unable to read response from device: {config.host}: {ex}" + ) from ex + try: + discovery_result = DiscoveryResult(**info["result"]) + except ValidationError as ex: + _LOGGER.debug( + "Unable to parse discovery from device %s: %s", config.host, info + ) raise UnsupportedDeviceException( - f"Unable to read response from device: {ip}: {ex}" + f"Unable to parse discovery from device: {config.host}: {ex}" ) from ex type_ = discovery_result.device_type - encrypt_type_ = ( - f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}" - ) - if (device_class := get_device_class_from_type_name(type_)) is None: + try: + config.connection_type = ConnectionType.from_values( + type_, discovery_result.mgt_encrypt_schm.encrypt_type + ) + except SmartDeviceException as ex: + raise UnsupportedDeviceException( + f"Unsupported device {config.host} of type {type_} " + + f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}", + discovery_result=discovery_result.get_dict(), + ) from ex + if (device_class := get_device_class_from_family(type_)) is None: _LOGGER.warning("Got unsupported device type: %s", type_) raise UnsupportedDeviceException( - f"Unsupported device {ip} of type {type_}: {info}", + f"Unsupported device {config.host} of type {type_}: {info}", discovery_result=discovery_result.get_dict(), ) - if ( - protocol := get_protocol_from_connection_name( - encrypt_type_, ip, credentials=credentials + if (protocol := get_protocol(config)) is None: + _LOGGER.warning( + "Got unsupported connection type: %s", config.connection_type.to_dict() ) - ) is None: - _LOGGER.warning("Got unsupported device type: %s", encrypt_type_) raise UnsupportedDeviceException( - f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}", + f"Unsupported encryption scheme {config.host} of " + + f"type {config.connection_type.to_dict()}: {info}", discovery_result=discovery_result.get_dict(), ) - _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) - device = device_class(ip, port=port, credentials=credentials) - device.protocol = protocol - device.update_from_discover_info(discovery_result.get_dict()) + _LOGGER.debug("[DISCOVERY] %s << %s", config.host, info) + device = device_class(config.host, protocol=protocol) + + di = discovery_result.get_dict() + di["model"] = discovery_result.device_model + di["alias"] = UNAVAILABLE_ALIAS + di["nickname"] = UNAVAILABLE_NICKNAME + device.update_from_discover_info(di) return device class DiscoveryResult(BaseModel): """Base model for discovery result.""" - class Config: - """Class for configuring model behaviour.""" - - allow_population_by_field_name = True - class EncryptionScheme(BaseModel): """Base model for encryption scheme of discovery result.""" - is_support_https: Optional[bool] = None - encrypt_type: Optional[str] = None - http_port: Optional[int] = None - lv: Optional[int] = 1 + is_support_https: bool + encrypt_type: str + http_port: int + lv: Optional[int] = None - device_type: str = Field(alias="device_type_text") - device_model: str = Field(alias="model") - ip: str = Field(alias="alias") + device_type: str + device_model: str + ip: str mac: str mgt_encrypt_schm: EncryptionScheme + device_id: str - device_id: Optional[str] = Field(default=None, alias="device_id_hash") - owner: Optional[str] = Field(default=None, alias="device_owner_hash") hw_ver: Optional[str] = None + owner: Optional[str] = None is_support_iot_cloud: Optional[bool] = None obd_src: Optional[str] = None factory_default: Optional[bool] = None @@ -453,5 +487,5 @@ class DiscoveryResult(BaseModel): containing only the values actually set and with aliases as field names. """ return self.dict( - by_alias=True, exclude_unset=True, exclude_none=True, exclude_defaults=True + by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True ) diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index fbb37b15..470f4055 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -17,12 +17,11 @@ class IotProtocol(TPLinkProtocol): def __init__( self, - host: str, *, transport: BaseTransport, ) -> None: """Create a protocol object.""" - super().__init__(host, transport=transport) + super().__init__(transport=transport) self._query_lock = asyncio.Lock() @@ -39,25 +38,21 @@ class IotProtocol(TPLinkProtocol): for retry in range(retry_count + 1): try: return await self._execute_query(request, retry) - except httpx.CloseError as sdex: - await self.close() + except httpx.ConnectError as sdex: if retry >= retry_count: + await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device: {self._host}: {sdex}" ) from sdex continue - except httpx.ConnectError as cex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device: {self._host}: {cex}" - ) from cex except TimeoutError as tex: await self.close() raise SmartDeviceException( f"Unable to connect to the device, timed out: {self._host}: {tex}" ) from tex except AuthenticationException as auex: + await self.close() _LOGGER.debug( "Unable to authenticate with %s, not retrying", self._host ) @@ -70,8 +65,8 @@ class IotProtocol(TPLinkProtocol): ) raise ex except Exception as ex: - await self.close() if retry >= retry_count: + await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device: {self._host}: {ex}" diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index e7bb8ae6..0e7ef565 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -54,6 +54,7 @@ from cryptography.hazmat.primitives import hashes, padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from .credentials import Credentials +from .deviceconfig import DeviceConfig from .exceptions import AuthenticationException, SmartDeviceException from .json import loads as json_loads from .protocol import BaseTransport, md5 @@ -82,27 +83,21 @@ class KlapTransport(BaseTransport): protocol, used by newer firmware versions. """ - DEFAULT_PORT = 80 + DEFAULT_PORT: int = 80 DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} + KASA_SETUP_EMAIL = "kasa@tp-link.net" KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105 SESSION_COOKIE_NAME = "TP_SESSIONID" def __init__( self, - host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: DeviceConfig, ) -> None: - super().__init__( - host, - port=port or self.DEFAULT_PORT, - credentials=credentials, - timeout=timeout, - ) + super().__init__(config=config) + self._default_http_client: Optional[httpx.AsyncClient] = None self._local_seed: Optional[bytes] = None self._local_auth_hash = self.generate_auth_hash(self._credentials) self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() @@ -116,14 +111,24 @@ class KlapTransport(BaseTransport): self._session_expire_at: Optional[float] = None self._session_cookie = None - self._http_client: httpx.AsyncClient = httpx.AsyncClient() _LOGGER.debug("Created KLAP transport for %s", self._host) + @property + def default_port(self): + """Default port for the transport.""" + return self.DEFAULT_PORT + + @property + def _http_client(self) -> httpx.AsyncClient: + if self._config.http_client: + return self._config.http_client + if not self._default_http_client: + self._default_http_client = httpx.AsyncClient() + return self._default_http_client + async def client_post(self, url, params=None, data=None): """Send an http post request to the device.""" - if not self._http_client: - self._http_client = httpx.AsyncClient() response_data = None cookies = None if self._session_cookie: @@ -355,8 +360,8 @@ class KlapTransport(BaseTransport): async def close(self) -> None: """Close the transport.""" - client = self._http_client - self._http_client = None + client = self._default_http_client + self._default_http_client = None self._handshake_done = False if client: await client.aclose() @@ -390,7 +395,7 @@ class KlapTransport(BaseTransport): return md5(un.encode()) -class TPlinkKlapTransportV2(KlapTransport): +class KlapTransportV2(KlapTransport): """Implementation of the KLAP encryption protocol with v2 hanshake hashes.""" @staticmethod diff --git a/kasa/protocol.py b/kasa/protocol.py index f73260bf..c998807c 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -24,7 +24,7 @@ from typing import Dict, Generator, Optional, Union from async_timeout import timeout as asyncio_timeout from cryptography.hazmat.primitives import hashes -from .credentials import Credentials +from .deviceconfig import DeviceConfig from .exceptions import SmartDeviceException from .json import dumps as json_dumps from .json import loads as json_loads @@ -48,17 +48,20 @@ class BaseTransport(ABC): def __init__( self, - host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: DeviceConfig, ) -> None: """Create a protocol object.""" - self._host = host - self._port = port - self._credentials = credentials or Credentials(username="", password="") - self._timeout = timeout or self.DEFAULT_TIMEOUT + self._config = config + self._host = config.host + self._port = config.port_override or self.default_port + self._credentials = config.credentials + self._timeout = config.timeout + + @property + @abstractmethod + def default_port(self) -> int: + """The default port for the transport.""" @abstractmethod async def send(self, request: str) -> Dict: @@ -74,7 +77,6 @@ class TPLinkProtocol(ABC): def __init__( self, - host: str, *, transport: BaseTransport, ) -> None: @@ -85,6 +87,11 @@ class TPLinkProtocol(ABC): def _host(self): return self._transport._host + @property + def config(self) -> DeviceConfig: + """Return the connection parameters the device is using.""" + return self._transport._config + @abstractmethod async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: """Query the device for the protocol. Abstract method to be overriden.""" @@ -103,22 +110,15 @@ class _XorTransport(BaseTransport): class. """ - DEFAULT_PORT = 9999 + DEFAULT_PORT: int = 9999 - def __init__( - self, - host: str, - *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, - ) -> None: - super().__init__( - host, - port=port or self.DEFAULT_PORT, - credentials=credentials, - timeout=timeout, - ) + def __init__(self, *, config: DeviceConfig) -> None: + super().__init__(config=config) + + @property + def default_port(self): + """Default port for the transport.""" + return self.DEFAULT_PORT async def send(self, request: str) -> Dict: """Send a message to the device and return a response.""" @@ -133,17 +133,15 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): INITIALIZATION_VECTOR = 171 DEFAULT_PORT = 9999 - DEFAULT_TIMEOUT = 5 BLOCK_SIZE = 4 def __init__( self, - host: str, *, transport: BaseTransport, ) -> None: """Create a protocol object.""" - super().__init__(host, transport=transport) + super().__init__(transport=transport) self.reader: Optional[asyncio.StreamReader] = None self.writer: Optional[asyncio.StreamWriter] = None @@ -167,7 +165,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): assert isinstance(request, str) # noqa: S101 async with self.query_lock: - return await self._query(request, retry_count, self._timeout) + return await self._query(request, retry_count, self._timeout) # type: ignore[arg-type] async def _connect(self, timeout: int) -> None: """Try to connect or reconnect to the device.""" diff --git a/kasa/smartbulb.py b/kasa/smartbulb.py index 6dd4513c..8897cece 100644 --- a/kasa/smartbulb.py +++ b/kasa/smartbulb.py @@ -9,8 +9,9 @@ try: except ImportError: from pydantic import BaseModel, Field, root_validator -from .credentials import Credentials +from .deviceconfig import DeviceConfig from .modules import Antitheft, Cloud, Countdown, Emeter, Schedule, Time, Usage +from .protocol import TPLinkProtocol from .smartdevice import DeviceType, SmartDevice, SmartDeviceException, requires_update @@ -220,11 +221,10 @@ class SmartBulb(SmartDevice): self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host=host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Bulb self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule")) self.add_module("usage", Usage(self, "smartlife.iot.common.schedule")) diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 5ad94a9f..97b46ddc 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Set from .credentials import Credentials from .device_type import DeviceType +from .deviceconfig import DeviceConfig from .emeterstatus import EmeterStatus from .exceptions import SmartDeviceException from .modules import Emeter, Module @@ -191,20 +192,18 @@ class SmartDevice: self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: """Create a new SmartDevice instance. :param str host: host name or ip address on which the device listens """ - self.host = host - self.port = port - self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol( - host, transport=_XorTransport(host, port=port, timeout=timeout) + if config and protocol: + protocol._transport._config = config + self.protocol: TPLinkProtocol = protocol or TPLinkSmartHomeProtocol( + transport=_XorTransport(config=config or DeviceConfig(host=host)), ) - self.credentials = credentials _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) self._device_type = DeviceType.Unknown # TODO: typing Any is just as using Optional[Dict] would require separate @@ -219,6 +218,30 @@ class SmartDevice: self.children: List["SmartDevice"] = [] + @property + def host(self) -> str: + """The device host.""" + return self.protocol._transport._host + + @host.setter + def host(self, value): + """Set the device host. + + Generally used by discovery to set the hostname after ip discovery. + """ + self.protocol._transport._host = value + self.protocol._transport._config.host = value + + @property + def port(self) -> int: + """The device port.""" + return self.protocol._transport._port + + @property + def credentials(self) -> Optional[Credentials]: + """The device credentials.""" + return self.protocol._transport._credentials + def add_module(self, name: str, module: Module): """Register a module.""" if name in self.modules: @@ -760,7 +783,7 @@ class SmartDevice: The returned object contains the raw results from the last update call. This should only be used for debugging purposes. """ - return self._last_update + return self._last_update or self._discovery_info def __repr__(self): if self._last_update is None: @@ -771,41 +794,33 @@ class SmartDevice: f" - dev specific: {self.state_information}>" ) + @property + def config(self) -> DeviceConfig: + """Return the connection parameters the device is using.""" + return self.protocol.config + @staticmethod async def connect( - host: str, *, - port: Optional[int] = None, - timeout=5, - credentials: Optional[Credentials] = None, - device_type: Optional[DeviceType] = None, + host: Optional[str] = None, + config: Optional[DeviceConfig] = None, ) -> "SmartDevice": - """Connect to a single device by the given IP address. + """Connect to a single device by the given hostname or device configuration. This method avoids the UDP based discovery process and - will connect directly to the device to query its type. + will connect directly to the device. It is generally preferred to avoid :func:`discover_single()` and use this function instead as it should perform better when the WiFi network is congested or the device is not responding to discovery requests. - The device type is discovered by querying the device. - :param host: Hostname of device to query - :param device_type: Device type to use for the device. - If not given, the device type is discovered by querying the device. - If the device type is already known, it is preferred to pass it - to avoid the extra query to the device to discover its type. + :param config: Connection parameters to ensure the correct protocol + and connection options are used. :rtype: SmartDevice :return: Object for querying/controlling found device. """ from .device_factory import connect # pylint: disable=import-outside-toplevel - return await connect( - host=host, - port=port, - timeout=timeout, - credentials=credentials, - device_type=device_type, - ) + return await connect(host=host, config=config) # type: ignore[arg-type] diff --git a/kasa/smartdimmer.py b/kasa/smartdimmer.py index 7980319c..ca0960f1 100644 --- a/kasa/smartdimmer.py +++ b/kasa/smartdimmer.py @@ -2,8 +2,9 @@ from enum import Enum from typing import Any, Dict, Optional -from kasa.credentials import Credentials +from kasa.deviceconfig import DeviceConfig from kasa.modules import AmbientLight, Motion +from kasa.protocol import TPLinkProtocol from kasa.smartdevice import DeviceType, SmartDeviceException, requires_update from kasa.smartplug import SmartPlug @@ -68,11 +69,10 @@ class SmartDimmer(SmartPlug): self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Dimmer # TODO: need to be verified if it's okay to call these on HS220 w/o these # TODO: need to be figured out what's the best approach to detect support diff --git a/kasa/smartlightstrip.py b/kasa/smartlightstrip.py index 2990e1fa..27ebf838 100644 --- a/kasa/smartlightstrip.py +++ b/kasa/smartlightstrip.py @@ -1,8 +1,9 @@ """Module for light strips (KL430).""" from typing import Any, Dict, List, Optional -from .credentials import Credentials +from .deviceconfig import DeviceConfig from .effects import EFFECT_MAPPING_V1, EFFECT_NAMES_V1 +from .protocol import TPLinkProtocol from .smartbulb import SmartBulb from .smartdevice import DeviceType, SmartDeviceException, requires_update @@ -46,11 +47,10 @@ class SmartLightStrip(SmartBulb): self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.LightStrip @property # type: ignore diff --git a/kasa/smartplug.py b/kasa/smartplug.py index 4ba230b4..d9ac0c86 100644 --- a/kasa/smartplug.py +++ b/kasa/smartplug.py @@ -2,8 +2,9 @@ import logging from typing import Any, Dict, Optional -from kasa.credentials import Credentials +from kasa.deviceconfig import DeviceConfig from kasa.modules import Antitheft, Cloud, Schedule, Time, Usage +from kasa.protocol import TPLinkProtocol from kasa.smartdevice import DeviceType, SmartDevice, requires_update _LOGGER = logging.getLogger(__name__) @@ -43,11 +44,10 @@ class SmartPlug(SmartDevice): self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Plug self.add_module("schedule", Schedule(self, "schedule")) self.add_module("usage", Usage(self, "schedule")) diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 443d1def..97573d93 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -38,12 +38,11 @@ class SmartProtocol(TPLinkProtocol): def __init__( self, - host: str, *, transport: BaseTransport, ) -> None: """Create a protocol object.""" - super().__init__(host, transport=transport) + super().__init__(transport=transport) self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode() self._request_id_generator = SnowflakeId(1, 1) self._query_lock = asyncio.Lock() @@ -68,19 +67,14 @@ class SmartProtocol(TPLinkProtocol): for retry in range(retry_count + 1): try: return await self._execute_query(request, retry) - except httpx.CloseError as sdex: - await self.close() + except httpx.ConnectError as sdex: if retry >= retry_count: + await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device: {self._host}: {sdex}" ) from sdex continue - except httpx.ConnectError as cex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device: {self._host}: {cex}" - ) from cex except TimeoutError as tex: if retry >= retry_count: await self.close() diff --git a/kasa/smartstrip.py b/kasa/smartstrip.py index 80aa27d1..79393132 100755 --- a/kasa/smartstrip.py +++ b/kasa/smartstrip.py @@ -14,8 +14,9 @@ from kasa.smartdevice import ( ) from kasa.smartplug import SmartPlug -from .credentials import Credentials +from .deviceconfig import DeviceConfig from .modules import Antitheft, Countdown, Emeter, Schedule, Time, Usage +from .protocol import TPLinkProtocol _LOGGER = logging.getLogger(__name__) @@ -85,11 +86,10 @@ class SmartStrip(SmartDevice): self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host=host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self.emeter_type = "emeter" self._device_type = DeviceType.Strip self.add_module("antitheft", Antitheft(self, "anti_theft")) diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index 97405b3f..717de7ef 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -5,8 +5,9 @@ from datetime import datetime, timedelta, timezone from typing import Any, Dict, Optional, Set, cast from ..aestransport import AesTransport -from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..exceptions import AuthenticationException +from ..protocol import TPLinkProtocol from ..smartdevice import SmartDevice from ..smartprotocol import SmartProtocol @@ -20,20 +21,16 @@ class TapoDevice(SmartDevice): self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials, timeout=timeout) + _protocol = protocol or SmartProtocol( + transport=AesTransport(config=config or DeviceConfig(host=host)), + ) + super().__init__(host=host, config=config, protocol=_protocol) self._components: Optional[Dict[str, Any]] = None self._state_information: Dict[str, Any] = {} self._discovery_info: Optional[Dict[str, Any]] = None - self.protocol = SmartProtocol( - host, - transport=AesTransport( - host, credentials=credentials, timeout=timeout, port=port - ), - ) async def update(self, update_children: bool = True): """Update the device.""" @@ -66,7 +63,7 @@ class TapoDevice(SmartDevice): @property def sys_info(self) -> Dict[str, Any]: """Returns the device info.""" - return self._info + return self._info # type: ignore @property def model(self) -> str: @@ -180,3 +177,4 @@ class TapoDevice(SmartDevice): def update_from_discover_info(self, info): """Update state from info from the discover call.""" self._discovery_info = info + self._info = info diff --git a/kasa/tapo/tapoplug.py b/kasa/tapo/tapoplug.py index 9d868253..67aed565 100644 --- a/kasa/tapo/tapoplug.py +++ b/kasa/tapo/tapoplug.py @@ -3,9 +3,10 @@ import logging from datetime import datetime, timedelta from typing import Any, Dict, Optional, cast -from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..emeterstatus import EmeterStatus from ..modules import Emeter +from ..protocol import TPLinkProtocol from ..smartdevice import DeviceType, requires_update from .tapodevice import TapoDevice @@ -19,11 +20,10 @@ class TapoPlug(TapoDevice): self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Plug self.modules: Dict[str, Any] = {} self.emeter_type = "emeter" diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 43bba825..11efe693 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -388,7 +388,6 @@ async def get_device_for_file(file, protocol): d = device_for_file(model, protocol)(host="127.0.0.123") if protocol == "SMART": d.protocol = FakeSmartProtocol(sysinfo) - d.credentials = Credentials("", "") else: d.protocol = FakeTransportProtocol(sysinfo) await _update_and_close(d) @@ -426,28 +425,53 @@ def discovery_mock(all_fixture_data, mocker): class _DiscoveryMock: ip: str default_port: int + discovery_port: int discovery_data: dict query_data: dict + device_type: str + encrypt_type: str port_override: Optional[int] = None if "discovery_result" in all_fixture_data: discovery_data = {"result": all_fixture_data["discovery_result"]} + device_type = all_fixture_data["discovery_result"]["device_type"] + encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][ + "encrypt_type" + ] datagram = ( b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" + json_dumps(discovery_data).encode() ) - dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data, all_fixture_data) + dm = _DiscoveryMock( + "127.0.0.123", + 80, + 20002, + discovery_data, + all_fixture_data, + device_type, + encrypt_type, + ) else: sys_info = all_fixture_data["system"]["get_sysinfo"] discovery_data = {"system": {"get_sysinfo": sys_info}} + device_type = sys_info.get("mic_type") or sys_info.get("type") + encrypt_type = "XOR" datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:] - dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data, all_fixture_data) + dm = _DiscoveryMock( + "127.0.0.123", + 9999, + 9999, + discovery_data, + all_fixture_data, + device_type, + encrypt_type, + ) def mock_discover(self): port = ( dm.port_override - if dm.port_override and dm.default_port != 20002 - else dm.default_port + if dm.port_override and dm.discovery_port != 20002 + else dm.discovery_port ) self.datagram_received( datagram, diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index cd7ad4fd..13d11d3d 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -15,7 +15,9 @@ from voluptuous import ( Schema, ) -from ..protocol import BaseTransport, TPLinkSmartHomeProtocol +from ..credentials import Credentials +from ..deviceconfig import DeviceConfig +from ..protocol import BaseTransport, TPLinkSmartHomeProtocol, _XorTransport from ..smartprotocol import SmartProtocol _LOGGER = logging.getLogger(__name__) @@ -290,7 +292,9 @@ TIME_MODULE = { class FakeSmartProtocol(SmartProtocol): def __init__(self, info): - super().__init__("127.0.0.123", transport=FakeSmartTransport(info)) + super().__init__( + transport=FakeSmartTransport(info), + ) async def query(self, request, retry_count: int = 3): """Implement query here so can still patch SmartProtocol.query.""" @@ -301,10 +305,15 @@ class FakeSmartProtocol(SmartProtocol): class FakeSmartTransport(BaseTransport): def __init__(self, info): super().__init__( - "127.0.0.123", + config=DeviceConfig("127.0.0.123", credentials=Credentials()), ) self.info = info + @property + def default_port(self): + """Default port for the transport.""" + return 80 + async def send(self, request: str): request_dict = json_loads(request) method = request_dict["method"] @@ -344,6 +353,11 @@ class FakeSmartTransport(BaseTransport): class FakeTransportProtocol(TPLinkSmartHomeProtocol): def __init__(self, info): + super().__init__( + transport=_XorTransport( + config=DeviceConfig("127.0.0.123"), + ) + ) self.discovery_data = info self.writer = None self.reader = None diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 198e8f39..faf47a75 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -12,6 +12,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd from ..aestransport import AesEncyptionSession, AesTransport from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..exceptions import ( SMART_RETRYABLE_ERRORS, SMART_TIMEOUT_ERRORS, @@ -58,7 +59,9 @@ async def test_handshake( mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) - transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + transport = AesTransport( + config=DeviceConfig(host, credentials=Credentials("foo", "bar")) + ) assert transport._encryption_session is None assert transport._handshake_done is False @@ -74,7 +77,9 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) - transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + transport = AesTransport( + config=DeviceConfig(host, credentials=Credentials("foo", "bar")) + ) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session @@ -91,13 +96,14 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) - transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + transport = AesTransport( + config=DeviceConfig(host, credentials=Credentials("foo", "bar")) + ) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session transport._login_token = mock_aes_device.token - un, pw = transport.hash_credentials(True) request = { "method": "get_device_info", "params": None, @@ -119,7 +125,8 @@ async def test_passthrough_errors(mocker, error_code): mock_aes_device = MockAesDevice(host, 200, error_code, 0) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) - transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + config = DeviceConfig(host, credentials=Credentials("foo", "bar")) + transport = AesTransport(config=config) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index c46015ea..1983b6cc 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -4,10 +4,26 @@ import asyncclick as click import pytest from asyncclick.testing import CliRunner -from kasa import AuthenticationException, SmartDevice, UnsupportedDeviceException -from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle -from kasa.device_factory import DEVICE_TYPE_TO_CLASS -from kasa.discover import Discover +from kasa import ( + AuthenticationException, + Credentials, + SmartDevice, + TPLinkSmartHomeProtocol, + UnsupportedDeviceException, +) +from kasa.cli import ( + TYPE_TO_CLASS, + alias, + brightness, + cli, + emeter, + raw_command, + state, + sysinfo, + toggle, +) +from kasa.discover import Discover, DiscoveryResult +from kasa.smartprotocol import SmartProtocol from .conftest import device_iot, handle_turn_on, new_discovery, turn_on @@ -145,9 +161,11 @@ async def test_credentials(discovery_mock, mocker): ) mocker.patch("kasa.cli.state", new=_state) - for subclass in DEVICE_TYPE_TO_CLASS.values(): - mocker.patch.object(subclass, "update") + mocker.patch("kasa.IotProtocol.query", return_value=discovery_mock.query_data) + mocker.patch("kasa.SmartProtocol.query", return_value=discovery_mock.query_data) + + dr = DiscoveryResult(**discovery_mock.discovery_data["result"]) runner = CliRunner() res = await runner.invoke( cli, @@ -158,6 +176,10 @@ async def test_credentials(discovery_mock, mocker): "foo", "--password", "bar", + "--device-family", + dr.device_type, + "--encrypt-type", + dr.mgt_encrypt_schm.encrypt_type, ], ) assert res.exit_code == 0 @@ -166,7 +188,7 @@ async def test_credentials(discovery_mock, mocker): @device_iot -async def test_without_device_type(discovery_data: dict, dev, mocker): +async def test_without_device_type(dev, mocker): """Test connecting without the device type.""" runner = CliRunner() mocker.patch("kasa.discover.Discover.discover_single", return_value=dev) @@ -342,3 +364,27 @@ async def test_host_auth_failed(discovery_mock, mocker): assert res.exit_code != 0 assert isinstance(res.exception, AuthenticationException) + + +@pytest.mark.parametrize("device_type", list(TYPE_TO_CLASS)) +async def test_type_param(device_type, mocker): + """Test for handling only one of username or password supplied.""" + runner = CliRunner() + + result_device = FileNotFoundError + pass_dev = click.make_pass_decorator(SmartDevice) + + @pass_dev + async def _state(dev: SmartDevice): + nonlocal result_device + result_device = dev + + mocker.patch("kasa.cli.state", new=_state) + expected_type = TYPE_TO_CLASS[device_type] + mocker.patch.object(expected_type, "update") + res = await runner.invoke( + cli, + ["--type", device_type, "--host", "127.0.0.1"], + ) + assert res.exit_code == 0 + assert isinstance(result_device, expected_type) diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index eb12b3b0..666bd9e9 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -2,6 +2,7 @@ import logging from typing import Type +import httpx import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( @@ -15,122 +16,138 @@ from kasa import ( SmartLightStrip, SmartPlug, ) -from kasa.device_factory import ( - DEVICE_TYPE_TO_CLASS, - connect, - get_protocol_from_connection_name, +from kasa.device_factory import connect, get_protocol +from kasa.deviceconfig import ( + ConnectionType, + DeviceConfig, + DeviceFamilyType, + EncryptType, ) from kasa.discover import DiscoveryResult -from kasa.iotprotocol import IotProtocol -from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol -@pytest.mark.parametrize("custom_port", [123, None]) -async def test_connect(discovery_data: dict, mocker, custom_port): - """Make sure that connect returns an initialized SmartDevice instance.""" - host = "127.0.0.1" +def _get_connection_type_device_class(the_fixture_data): + if "discovery_result" in the_fixture_data: + discovery_info = {"result": the_fixture_data["discovery_result"]} + device_class = Discover._get_device_class(discovery_info) + dr = DiscoveryResult(**discovery_info["result"]) - if "result" in discovery_data: - with pytest.raises(SmartDeviceException): - dev = await connect(host, port=custom_port) + connection_type = ConnectionType.from_values( + dr.device_type, dr.mgt_encrypt_schm.encrypt_type + ) else: - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - dev = await connect(host, port=custom_port) - assert issubclass(dev.__class__, SmartDevice) - assert dev.port == custom_port or dev.port == 9999 + connection_type = ConnectionType.from_values( + DeviceFamilyType.IotSmartPlugSwitch.value, EncryptType.Xor.value + ) + device_class = Discover._get_device_class(the_fixture_data) + + return connection_type, device_class -@pytest.mark.parametrize("custom_port", [123, None]) -@pytest.mark.parametrize( - ("device_type", "klass"), - ( - (DeviceType.Plug, SmartPlug), - (DeviceType.Bulb, SmartBulb), - (DeviceType.Dimmer, SmartDimmer), - (DeviceType.LightStrip, SmartLightStrip), - (DeviceType.Unknown, SmartDevice), - ), -) -async def test_connect_passed_device_type( - discovery_data: dict, - mocker, - device_type: DeviceType, - klass: Type[SmartDevice], - custom_port, -): - """Make sure that connect with a passed device type.""" - host = "127.0.0.1" - - if "result" in discovery_data: - with pytest.raises(SmartDeviceException): - dev = await connect(host, port=custom_port) - else: - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - dev = await connect(host, port=custom_port, device_type=device_type) - assert isinstance(dev, klass) - assert dev.port == custom_port or dev.port == 9999 - - -async def test_connect_query_fails(discovery_data: dict, mocker): - """Make sure that connect fails when query fails.""" - host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException) - - with pytest.raises(SmartDeviceException): - await connect(host) - - -async def test_connect_logs_connect_time( - discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker -): - """Test that the connect time is logged when debug logging is enabled.""" - host = "127.0.0.1" - if "result" in discovery_data: - with pytest.raises(SmartDeviceException): - await connect(host) - else: - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - logging.getLogger("kasa").setLevel(logging.DEBUG) - await connect(host) - assert "seconds to connect" in caplog.text - - -async def test_connect_pass_protocol( +async def test_connect( all_fixture_data: dict, mocker, ): - """Test that if the protocol is passed in it's gets set correctly.""" - if "discovery_result" in all_fixture_data: - discovery_info = {"result": all_fixture_data["discovery_result"]} - device_class = Discover._get_device_class(discovery_info) - else: - device_class = Discover._get_device_class(all_fixture_data) - - device_type = list(DEVICE_TYPE_TO_CLASS.keys())[ - list(DEVICE_TYPE_TO_CLASS.values()).index(device_class) - ] + """Test that if the protocol is passed in it gets set correctly.""" host = "127.0.0.1" - if "discovery_result" in all_fixture_data: - mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) - mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) + ctype, device_class = _get_connection_type_device_class(all_fixture_data) - dr = DiscoveryResult(**discovery_info["result"]) - connection_name = ( - dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type - ) - protocol_class = get_protocol_from_connection_name( - connection_name, host - ).__class__ - else: - mocker.patch( - "kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data - ) - protocol_class = TPLinkSmartHomeProtocol + mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) + + config = DeviceConfig( + host=host, credentials=Credentials("foor", "bar"), connection_type=ctype + ) + protocol_class = get_protocol(config).__class__ dev = await connect( - host, - device_type=device_type, - protocol_class=protocol_class, - credentials=Credentials("", ""), + config=config, ) + assert isinstance(dev, device_class) assert isinstance(dev.protocol, protocol_class) + + assert dev.config == config + + +@pytest.mark.parametrize("custom_port", [123, None]) +async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port): + """Make sure that connect returns an initialized SmartDevice instance.""" + host = "127.0.0.1" + + ctype, _ = _get_connection_type_device_class(all_fixture_data) + config = DeviceConfig(host=host, port_override=custom_port, connection_type=ctype) + default_port = 80 if "discovery_result" in all_fixture_data else 9999 + + ctype, _ = _get_connection_type_device_class(all_fixture_data) + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) + dev = await connect(config=config) + assert issubclass(dev.__class__, SmartDevice) + assert dev.port == custom_port or dev.port == default_port + + +async def test_connect_logs_connect_time( + all_fixture_data: dict, caplog: pytest.LogCaptureFixture, mocker +): + """Test that the connect time is logged when debug logging is enabled.""" + ctype, _ = _get_connection_type_device_class(all_fixture_data) + mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) + + host = "127.0.0.1" + config = DeviceConfig( + host=host, credentials=Credentials("foor", "bar"), connection_type=ctype + ) + logging.getLogger("kasa").setLevel(logging.DEBUG) + await connect( + config=config, + ) + assert "seconds to update" in caplog.text + + +async def test_connect_query_fails(all_fixture_data: dict, mocker): + """Make sure that connect fails when query fails.""" + host = "127.0.0.1" + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException) + mocker.patch("kasa.IotProtocol.query", side_effect=SmartDeviceException) + mocker.patch("kasa.SmartProtocol.query", side_effect=SmartDeviceException) + + ctype, _ = _get_connection_type_device_class(all_fixture_data) + config = DeviceConfig( + host=host, credentials=Credentials("foor", "bar"), connection_type=ctype + ) + with pytest.raises(SmartDeviceException): + await connect(config=config) + + +async def test_connect_http_client(all_fixture_data, mocker): + """Make sure that discover_single returns an initialized SmartDevice instance.""" + host = "127.0.0.1" + + ctype, _ = _get_connection_type_device_class(all_fixture_data) + + mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) + + http_client = httpx.AsyncClient() + + config = DeviceConfig( + host=host, credentials=Credentials("foor", "bar"), connection_type=ctype + ) + dev = await connect(config=config) + if ctype.encryption_type != EncryptType.Xor: + assert dev.protocol._transport._http_client != http_client + + config = DeviceConfig( + host=host, + credentials=Credentials("foor", "bar"), + connection_type=ctype, + http_client=http_client, + ) + dev = await connect(config=config) + if ctype.encryption_type != EncryptType.Xor: + assert dev.protocol._transport._http_client == http_client diff --git a/kasa/tests/test_deviceconfig.py b/kasa/tests/test_deviceconfig.py new file mode 100644 index 00000000..7970449d --- /dev/null +++ b/kasa/tests/test_deviceconfig.py @@ -0,0 +1,21 @@ +from json import dumps as json_dumps +from json import loads as json_loads + +import httpx + +from kasa.credentials import Credentials +from kasa.deviceconfig import ( + ConnectionType, + DeviceConfig, + DeviceFamilyType, + EncryptType, +) + + +def test_serialization(): + config = DeviceConfig(host="Foo", http_client=httpx.AsyncClient()) + config_dict = config.to_dict() + config_json = json_dumps(config_dict) + config2_dict = json_loads(config_json) + config2 = DeviceConfig.from_dict(config2_dict) + assert config == config2 diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 18798ab9..396ef2f2 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -1,21 +1,29 @@ # type: ignore +import logging import re import socket +import httpx import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( + Credentials, DeviceType, Discover, SmartDevice, SmartDeviceException, - SmartStrip, protocol, ) +from kasa.deviceconfig import ( + ConnectionType, + DeviceConfig, + DeviceFamilyType, + EncryptType, +) from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps from kasa.exceptions import AuthenticationException, UnsupportedDeviceException -from .conftest import bulb, bulb_iot, dimmer, lightstrip, plug, strip +from .conftest import bulb, bulb_iot, dimmer, lightstrip, new_discovery, plug, strip UNSUPPORTED = { "result": { @@ -89,13 +97,26 @@ async def test_discover_single(discovery_mock, custom_port, mocker): host = "127.0.0.1" discovery_mock.ip = host discovery_mock.port_override = custom_port - update_mock = mocker.patch.object(SmartStrip, "update") - x = await Discover.discover_single(host, port=custom_port) + device_class = Discover._get_device_class(discovery_mock.discovery_data) + update_mock = mocker.patch.object(device_class, "update") + + x = await Discover.discover_single( + host, port=custom_port, credentials=Credentials() + ) assert issubclass(x.__class__, SmartDevice) assert x._discovery_info is not None assert x.port == custom_port or x.port == discovery_mock.default_port - assert (update_mock.call_count > 0) == isinstance(x, SmartStrip) + assert update_mock.call_count == 0 + + ct = ConnectionType.from_values( + discovery_mock.device_type, discovery_mock.encrypt_type + ) + uses_http = discovery_mock.default_port == 80 + config = DeviceConfig( + host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http + ) + assert x.config == config async def test_discover_single_hostname(discovery_mock, mocker): @@ -104,47 +125,39 @@ async def test_discover_single_hostname(discovery_mock, mocker): ip = "127.0.0.1" discovery_mock.ip = ip - update_mock = mocker.patch.object(SmartStrip, "update") + device_class = Discover._get_device_class(discovery_mock.discovery_data) + update_mock = mocker.patch.object(device_class, "update") - x = await Discover.discover_single(host) + x = await Discover.discover_single(host, credentials=Credentials()) assert issubclass(x.__class__, SmartDevice) assert x._discovery_info is not None assert x.host == host - assert (update_mock.call_count > 0) == isinstance(x, SmartStrip) + assert update_mock.call_count == 0 mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror()) with pytest.raises(SmartDeviceException): - x = await Discover.discover_single(host) + x = await Discover.discover_single(host, credentials=Credentials()) -async def test_discover_single_unsupported(mocker): +async def test_discover_single_unsupported(unsupported_device_info, mocker): """Make sure that discover_single handles unsupported devices correctly.""" host = "127.0.0.1" - def mock_discover(self): - if discovery_data: - data = ( - b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" - + json_dumps(discovery_data).encode() - ) - self.datagram_received(data, (host, 20002)) - - mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) - # Test with a valid unsupported response - discovery_data = UNSUPPORTED with pytest.raises( UnsupportedDeviceException, - match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}", ): await Discover.discover_single(host) - # Test with no response - discovery_data = None + +async def test_discover_single_no_response(mocker): + """Make sure that discover_single handles no response correctly.""" + host = "127.0.0.1" + mocker.patch.object(_DiscoverProtocol, "do_discover") with pytest.raises( SmartDeviceException, match=f"Timed out getting discovery response for {host}" ): - await Discover.discover_single(host, timeout=0.001) + await Discover.discover_single(host, discovery_timeout=0) INVALIDS = [ @@ -241,52 +254,82 @@ AUTHENTICATION_DATA_KLAP = { } -async def test_discover_single_authentication(mocker): +@new_discovery +async def test_discover_single_authentication(discovery_mock, mocker): """Make sure that discover_single handles authenticating devices correctly.""" host = "127.0.0.1" - - def mock_discover(self): - if discovery_data: - data = ( - b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" - + json_dumps(discovery_data).encode() - ) - self.datagram_received(data, (host, 20002)) - - mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) + discovery_mock.ip = host + device_class = Discover._get_device_class(discovery_mock.discovery_data) mocker.patch.object( - SmartDevice, + device_class, "update", side_effect=AuthenticationException("Failed to authenticate"), ) - # Test with a valid unsupported response - discovery_data = AUTHENTICATION_DATA_KLAP with pytest.raises( AuthenticationException, match="Failed to authenticate", ): - device = await Discover.discover_single(host) + device = await Discover.discover_single( + host, credentials=Credentials("foo", "bar") + ) await device.update() - mocker.patch.object(SmartDevice, "update") - device = await Discover.discover_single(host) + mocker.patch.object(device_class, "update") + device = await Discover.discover_single(host, credentials=Credentials("foo", "bar")) await device.update() - assert device.device_type == DeviceType.Plug + assert isinstance(device, device_class) -async def test_device_update_from_new_discovery_info(): +@new_discovery +async def test_device_update_from_new_discovery_info(discovery_data): device = SmartDevice("127.0.0.7") - discover_info = DiscoveryResult(**AUTHENTICATION_DATA_KLAP["result"]) + discover_info = DiscoveryResult(**discovery_data["result"]) discover_dump = discover_info.get_dict() + discover_dump["alias"] = "foobar" + discover_dump["model"] = discover_dump["device_model"] device.update_from_discover_info(discover_dump) - assert device.alias == discover_dump["alias"] + assert device.alias == "foobar" assert device.mac == discover_dump["mac"].replace("-", ":") - assert device.model == discover_dump["model"] + assert device.model == discover_dump["device_model"] with pytest.raises( SmartDeviceException, match=re.escape("You need to await update() to access the data"), ): assert device.supported_modules + + +async def test_discover_single_http_client(discovery_mock, mocker): + """Make sure that discover_single returns an initialized SmartDevice instance.""" + host = "127.0.0.1" + discovery_mock.ip = host + + http_client = httpx.AsyncClient() + + x: SmartDevice = await Discover.discover_single(host) + + assert x.config.uses_http == (discovery_mock.default_port == 80) + + if discovery_mock.default_port == 80: + assert x.protocol._transport._http_client != http_client + x.config.http_client = http_client + assert x.protocol._transport._http_client == http_client + + +async def test_discover_http_client(discovery_mock, mocker): + """Make sure that discover_single returns an initialized SmartDevice instance.""" + host = "127.0.0.1" + discovery_mock.ip = host + + http_client = httpx.AsyncClient() + + devices = await Discover.discover(discovery_timeout=0) + x: SmartDevice = devices[host] + assert x.config.uses_http == (discovery_mock.default_port == 80) + + if discovery_mock.default_port == 80: + assert x.protocol._transport._http_client != http_client + x.config.http_client = http_client + assert x.protocol._transport._http_client == http_client diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 1ed57ef2..5108fef0 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -12,9 +12,15 @@ import pytest from ..aestransport import AesTransport from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..exceptions import AuthenticationException, SmartDeviceException from ..iotprotocol import IotProtocol -from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256 +from ..klaptransport import ( + KlapEncryptionSession, + KlapTransport, + KlapTransportV2, + _sha256, +) from ..smartprotocol import SmartProtocol DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} @@ -31,8 +37,9 @@ class _mock_response: [ (Exception("dummy exception"), True), (SmartDeviceException("dummy exception"), False), + (httpx.ConnectError("dummy exception"), True), ], - ids=("Exception", "SmartDeviceException"), + ids=("Exception", "SmartDeviceException", "httpx.ConnectError"), ) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @@ -42,8 +49,10 @@ async def test_protocol_retries( ): host = "127.0.0.1" conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error) + + config = DeviceConfig(host) with pytest.raises(SmartDeviceException): - await protocol_class(host, transport=transport_class(host)).query( + await protocol_class(transport=transport_class(config=config)).query( DUMMY_QUERY, retry_count=retry_count ) @@ -60,10 +69,11 @@ async def test_protocol_no_retry_on_connection_error( conn = mocker.patch.object( httpx.AsyncClient, "post", - side_effect=httpx.ConnectError("foo"), + side_effect=AuthenticationException("foo"), ) + config = DeviceConfig(host) with pytest.raises(SmartDeviceException): - await protocol_class(host, transport=transport_class(host)).query( + await protocol_class(transport=transport_class(config=config)).query( DUMMY_QUERY, retry_count=5 ) @@ -81,8 +91,9 @@ async def test_protocol_retry_recoverable_error( "post", side_effect=httpx.CloseError("foo"), ) + config = DeviceConfig(host) with pytest.raises(SmartDeviceException): - await protocol_class(host, transport=transport_class(host)).query( + await protocol_class(transport=transport_class(config=config)).query( DUMMY_QUERY, retry_count=5 ) @@ -115,7 +126,8 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport side_effect=_fail_one_less_than_retry_count, ) - response = await protocol_class(host, transport=transport_class(host)).query( + config = DeviceConfig(host) + response = await protocol_class(transport=transport_class(config=config)).query( DUMMY_QUERY, retry_count=retry_count ) assert "result" in response or "foobar" in response @@ -136,7 +148,9 @@ async def test_protocol_logging(mocker, caplog, log_level): seed = secrets.token_bytes(16) auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) - protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1")) + + config = DeviceConfig("127.0.0.1") + protocol = IotProtocol(transport=KlapTransport(config=config)) protocol._transport._handshake_done = True protocol._transport._session_expire_at = time.time() + 86400 @@ -181,7 +195,7 @@ def test_encrypt_unicode(): "device_credentials, expectation", [ (Credentials("foo", "bar"), does_not_raise()), - (Credentials("", ""), does_not_raise()), + (Credentials(), does_not_raise()), ( Credentials( KlapTransport.KASA_SETUP_EMAIL, @@ -196,30 +210,37 @@ def test_encrypt_unicode(): ], ids=("client", "blank", "kasa_setup", "shouldfail"), ) -async def test_handshake1(mocker, device_credentials, expectation): +@pytest.mark.parametrize( + "transport_class, seed_auth_hash_calc", + [ + pytest.param(KlapTransport, lambda c, s, a: c + a, id="KLAP"), + pytest.param(KlapTransportV2, lambda c, s, a: c + s + a, id="KLAPV2"), + ], +) +async def test_handshake1( + mocker, device_credentials, expectation, transport_class, seed_auth_hash_calc +): async def _return_handshake1_response(url, params=None, data=None, *_, **__): nonlocal client_seed, server_seed, device_auth_hash client_seed = data - client_seed_auth_hash = _sha256(data + device_auth_hash) - - return _mock_response(200, server_seed + client_seed_auth_hash) + seed_auth_hash = _sha256( + seed_auth_hash_calc(client_seed, server_seed, device_auth_hash) + ) + return _mock_response(200, server_seed + seed_auth_hash) client_seed = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = KlapTransport.generate_auth_hash(device_credentials) + device_auth_hash = transport_class.generate_auth_hash(device_credentials) mocker.patch.object( httpx.AsyncClient, "post", side_effect=_return_handshake1_response ) - protocol = IotProtocol( - "127.0.0.1", - transport=KlapTransport("127.0.0.1", credentials=client_credentials), - ) + config = DeviceConfig("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=transport_class(config=config)) - protocol._transport.http_client = httpx.AsyncClient() with expectation: ( local_seed, @@ -233,31 +254,51 @@ async def test_handshake1(mocker, device_credentials, expectation): await protocol.close() -async def test_handshake(mocker): +@pytest.mark.parametrize( + "transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2", + [ + pytest.param( + KlapTransport, lambda c, s, a: c + a, lambda c, s, a: s + a, id="KLAP" + ), + pytest.param( + KlapTransportV2, + lambda c, s, a: c + s + a, + lambda c, s, a: s + c + a, + id="KLAPV2", + ), + ], +) +async def test_handshake( + mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2 +): async def _return_handshake_response(url, params=None, data=None, *_, **__): - nonlocal response_status, client_seed, server_seed, device_auth_hash + nonlocal client_seed, server_seed, device_auth_hash if url == "http://127.0.0.1/app/handshake1": client_seed = data - client_seed_auth_hash = _sha256(data + device_auth_hash) + seed_auth_hash = _sha256( + seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash) + ) - return _mock_response(200, server_seed + client_seed_auth_hash) + return _mock_response(200, server_seed + seed_auth_hash) elif url == "http://127.0.0.1/app/handshake2": + seed_auth_hash = _sha256( + seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash) + ) + assert data == seed_auth_hash return _mock_response(response_status, b"") client_seed = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) + device_auth_hash = transport_class.generate_auth_hash(client_credentials) mocker.patch.object( httpx.AsyncClient, "post", side_effect=_return_handshake_response ) - protocol = IotProtocol( - "127.0.0.1", - transport=KlapTransport("127.0.0.1", credentials=client_credentials), - ) + config = DeviceConfig("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=transport_class(config=config)) protocol._transport.http_client = httpx.AsyncClient() response_status = 200 @@ -273,7 +314,7 @@ async def test_handshake(mocker): async def test_query(mocker): async def _return_response(url, params=None, data=None, *_, **__): - nonlocal client_seed, server_seed, device_auth_hash, protocol, seq + nonlocal client_seed, server_seed, device_auth_hash, seq if url == "http://127.0.0.1/app/handshake1": client_seed = data @@ -303,10 +344,8 @@ async def test_query(mocker): mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = IotProtocol( - "127.0.0.1", - transport=KlapTransport("127.0.0.1", credentials=client_credentials), - ) + config = DeviceConfig("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=KlapTransport(config=config)) for _ in range(10): resp = await protocol.query({}) @@ -350,10 +389,8 @@ async def test_authentication_failures(mocker, response_status, expectation): mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = IotProtocol( - "127.0.0.1", - transport=KlapTransport("127.0.0.1", credentials=client_credentials), - ) + config = DeviceConfig("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=KlapTransport(config=config)) with expectation: await protocol.query({}) diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 7bd6342b..0e74da3b 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -9,6 +9,7 @@ import sys import pytest +from ..deviceconfig import DeviceConfig from ..exceptions import SmartDeviceException from ..protocol import ( BaseTransport, @@ -31,10 +32,11 @@ async def test_protocol_retries(mocker, retry_count): return reader, writer conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ).query({}, retry_count=retry_count) + await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( + {}, retry_count=retry_count + ) assert conn.call_count == retry_count + 1 @@ -44,10 +46,11 @@ async def test_protocol_no_retry_on_unreachable(mocker): "asyncio.open_connection", side_effect=OSError(errno.EHOSTUNREACH, "No route to host"), ) + config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ).query({}, retry_count=5) + await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( + {}, retry_count=5 + ) assert conn.call_count == 1 @@ -57,10 +60,11 @@ async def test_protocol_no_retry_connection_refused(mocker): "asyncio.open_connection", side_effect=ConnectionRefusedError, ) + config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ).query({}, retry_count=5) + await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( + {}, retry_count=5 + ) assert conn.call_count == 1 @@ -70,10 +74,11 @@ async def test_protocol_retry_recoverable_error(mocker): "asyncio.open_connection", side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"), ) + config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ).query({}, retry_count=5) + await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( + {}, retry_count=5 + ) assert conn.call_count == 6 @@ -107,9 +112,8 @@ async def test_protocol_reconnect(mocker, retry_count): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - protocol = TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ) + config = DeviceConfig("127.0.0.1") + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}, retry_count=retry_count) assert response == {"great": "success"} @@ -137,9 +141,8 @@ async def test_protocol_logging(mocker, caplog, log_level): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - protocol = TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ) + config = DeviceConfig("127.0.0.1") + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) assert response == {"great": "success"} @@ -173,9 +176,8 @@ async def test_protocol_custom_port(mocker, custom_port): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - protocol = TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port) - ) + config = DeviceConfig("127.0.0.1", port_override=custom_port) + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) assert response == {"great": "success"} @@ -271,18 +273,14 @@ def _get_subclasses(of_class): def test_protocol_init_signature(class_name_obj): params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) - assert len(params) == 3 + assert len(params) == 2 assert ( params[0].name == "self" and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD ) assert ( - params[1].name == "host" - and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ) - assert ( - params[2].name == "transport" - and params[2].kind == inspect.Parameter.KEYWORD_ONLY + params[1].name == "transport" + and params[1].kind == inspect.Parameter.KEYWORD_ONLY ) @@ -292,20 +290,11 @@ def test_protocol_init_signature(class_name_obj): def test_transport_init_signature(class_name_obj): params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) - assert len(params) == 5 + assert len(params) == 2 assert ( params[0].name == "self" and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD ) assert ( - params[1].name == "host" - and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ) - assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY - assert ( - params[3].name == "credentials" - and params[3].kind == inspect.Parameter.KEYWORD_ONLY - ) - assert ( - params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY + params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY ) diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 47f523d0..a3019bff 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -5,8 +5,7 @@ from unittest.mock import Mock, patch import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import kasa -from kasa import Credentials, SmartDevice, SmartDeviceException -from kasa.smartdevice import DeviceType +from kasa import Credentials, DeviceConfig, SmartDevice, SmartDeviceException from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol @@ -215,7 +214,8 @@ def test_device_class_ctors(device_class): host = "127.0.0.2" port = 1234 credentials = Credentials("foo", "bar") - dev = device_class(host, port=port, credentials=credentials) + config = DeviceConfig(host, port_override=port, credentials=credentials) + dev = device_class(host, config=config) assert dev.host == host assert dev.port == port assert dev.credentials == credentials @@ -231,29 +231,27 @@ async def test_modules_preserved(dev: SmartDevice): async def test_create_smart_device_with_timeout(): """Make sure timeout is passed to the protocol.""" - dev = SmartDevice(host="127.0.0.1", timeout=100) + host = "127.0.0.1" + dev = SmartDevice(host, config=DeviceConfig(host, timeout=100)) assert dev.protocol._transport._timeout == 100 async def test_create_thin_wrapper(): """Make sure thin wrapper is created with the correct device type.""" mock = Mock() + config = DeviceConfig( + host="test_host", + port_override=1234, + timeout=100, + credentials=Credentials("username", "password"), + ) with patch("kasa.device_factory.connect", return_value=mock) as connect: - dev = await SmartDevice.connect( - host="test_host", - port=1234, - timeout=100, - credentials=Credentials("username", "password"), - device_type=DeviceType.Strip, - ) + dev = await SmartDevice.connect(config=config) assert dev is mock connect.assert_called_once_with( - host="test_host", - port=1234, - timeout=100, - credentials=Credentials("username", "password"), - device_type=DeviceType.Strip, + host=None, + config=config, ) diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index 5dbbed27..301e367f 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -13,6 +13,7 @@ import pytest from ..aestransport import AesTransport from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..exceptions import ( SMART_RETRYABLE_ERRORS, SMART_TIMEOUT_ERRORS, @@ -37,7 +38,8 @@ async def test_smart_device_errors(mocker, error_code): send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) - protocol = SmartProtocol(host, transport=AesTransport(host)) + config = DeviceConfig(host, credentials=Credentials("foo", "bar")) + protocol = SmartProtocol(transport=AesTransport(config=config)) with pytest.raises(SmartDeviceException): await protocol.query(DUMMY_QUERY, retry_count=2) @@ -70,8 +72,8 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code): mocker.patch.object(AesTransport, "perform_login") send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) - - protocol = SmartProtocol(host, transport=AesTransport(host)) + config = DeviceConfig(host, credentials=Credentials("foo", "bar")) + protocol = SmartProtocol(transport=AesTransport(config=config)) with pytest.raises(SmartDeviceException): await protocol.query(DUMMY_QUERY, retry_count=2) if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):