Add DeviceConfig to allow specifying configuration parameters (#569)

* Add DeviceConfig handling

* Update post review

* Further update post latest review

* Update following latest review

* Update docstrings and docs
This commit is contained in:
sdb9696
2023-12-29 19:17:15 +00:00
committed by GitHub
parent ec3ea39a37
commit f6fd898faf
33 changed files with 1032 additions and 589 deletions

View File

@@ -1,5 +1,6 @@
"""Discovery module for TP-Link Smart Home devices."""
import asyncio
import base64
import binascii
import ipaddress
import logging
@@ -11,29 +12,32 @@ from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast
from async_timeout import timeout as asyncio_timeout
try:
from pydantic.v1 import BaseModel, Field
from pydantic.v1 import BaseModel, ValidationError # pragma: no cover
except ImportError:
from pydantic import BaseModel, Field
from pydantic import BaseModel, ValidationError # pragma: no cover
from kasa.credentials import Credentials
from kasa.device_factory import (
get_device_class_from_family,
get_device_class_from_sys_info,
get_protocol,
)
from kasa.deviceconfig import ConnectionType, DeviceConfig, EncryptType
from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartdevice import SmartDevice, SmartDeviceException
from .device_factory import (
get_device_class_from_sys_info,
get_device_class_from_type_name,
get_protocol_from_connection_name,
)
_LOGGER = logging.getLogger(__name__)
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
DeviceDict = Dict[str, SmartDevice]
UNAVAILABLE_ALIAS = "Authentication required"
UNAVAILABLE_NICKNAME = base64.b64encode(UNAVAILABLE_ALIAS.encode()).decode()
class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler.
@@ -62,9 +66,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.discovery_packets = discovery_packets
self.interface = interface
self.on_discovered = on_discovered
self.port = port
self.discovery_port = port or Discover.DISCOVERY_PORT
self.target = (target, self.discovery_port)
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
self.discovered_devices = {}
self.unsupported_device_exceptions: Dict = {}
self.invalid_device_exceptions: Dict = {}
@@ -110,13 +117,18 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.seen_hosts.add(ip)
device = None
config = DeviceConfig(host=ip, port_override=self.port)
if self.credentials:
config.credentials = self.credentials
if self.timeout:
config.timeout = self.timeout
try:
if port == self.discovery_port:
device = Discover._get_device_instance_legacy(data, ip, port)
device = Discover._get_device_instance_legacy(data, config)
elif port == Discover.DISCOVERY_PORT_2:
device = Discover._get_device_instance(
data, ip, port, self.credentials or Credentials()
)
config.uses_http = True
device = Discover._get_device_instance(data, config)
else:
return
except UnsupportedDeviceException as udex:
@@ -200,11 +212,13 @@ class Discover:
*,
target="255.255.255.255",
on_discovered=None,
timeout=5,
discovery_timeout=5,
discovery_packets=3,
interface=None,
on_unsupported=None,
credentials=None,
port=None,
timeout=None,
) -> DeviceDict:
"""Discover supported devices.
@@ -240,14 +254,15 @@ class Discover:
on_unsupported=on_unsupported,
credentials=credentials,
timeout=timeout,
port=port,
),
local_addr=("0.0.0.0", 0), # noqa: S104
)
protocol = cast(_DiscoverProtocol, protocol)
try:
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
await asyncio.sleep(timeout)
_LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout)
await asyncio.sleep(discovery_timeout)
finally:
transport.close()
@@ -259,10 +274,10 @@ class Discover:
async def discover_single(
host: str,
*,
discovery_timeout: int = 5,
port: Optional[int] = None,
timeout=5,
timeout: Optional[int] = None,
credentials: Optional[Credentials] = None,
update_parent_devices: bool = True,
) -> SmartDevice:
"""Discover a single device by the given IP address.
@@ -275,8 +290,6 @@ class Discover:
:param port: Optionally set a different port for the device
:param timeout: Timeout for discovery
:param credentials: Credentials for devices that require authentication
:param update_parent_devices: Automatically call device.update() on
devices that have children
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
@@ -320,9 +333,11 @@ class Discover:
protocol = cast(_DiscoverProtocol, protocol)
try:
_LOGGER.debug("Waiting a total of %s seconds for responses...", timeout)
_LOGGER.debug(
"Waiting a total of %s seconds for responses...", discovery_timeout
)
async with asyncio_timeout(timeout):
async with asyncio_timeout(discovery_timeout):
await event.wait()
except asyncio.TimeoutError as ex:
raise SmartDeviceException(
@@ -334,9 +349,6 @@ class Discover:
if ip in protocol.discovered_devices:
dev = protocol.discovered_devices[ip]
dev.host = host
# Call device update on devices that have children
if update_parent_devices and dev.has_children:
await dev.update()
return dev
elif ip in protocol.unsupported_device_exceptions:
raise protocol.unsupported_device_exceptions[ip]
@@ -350,99 +362,121 @@ class Discover:
"""Find SmartDevice subclass for device described by passed data."""
if "result" in info:
discovery_result = DiscoveryResult(**info["result"])
dev_class = get_device_class_from_type_name(discovery_result.device_type)
dev_class = get_device_class_from_family(discovery_result.device_type)
if not dev_class:
raise UnsupportedDeviceException(
"Unknown device type: %s" % discovery_result.device_type
"Unknown device type: %s" % discovery_result.device_type,
discovery_result=info,
)
return dev_class
else:
return get_device_class_from_sys_info(info)
@staticmethod
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:
def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> SmartDevice:
"""Get SmartDevice from legacy 9999 response."""
try:
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
except Exception as ex:
raise SmartDeviceException(
f"Unable to read response from device: {ip}: {ex}"
f"Unable to read response from device: {config.host}: {ex}"
) from ex
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, info)
device_class = Discover._get_device_class(info)
device = device_class(ip, port=port)
device = device_class(config.host, config=config)
sys_info = info["system"]["get_sysinfo"]
if device_type := sys_info.get("mic_type", sys_info.get("type")):
config.connection_type = ConnectionType.from_values(
device_family=device_type, encryption_type=EncryptType.Xor.value
)
device.protocol = get_protocol(config) # type: ignore[assignment]
device.update_from_discover_info(info)
return device
@staticmethod
def _get_device_instance(
data: bytes, ip: str, port: int, credentials: Credentials
data: bytes,
config: DeviceConfig,
) -> SmartDevice:
"""Get SmartDevice from the new 20002 response."""
try:
info = json_loads(data[16:])
discovery_result = DiscoveryResult(**info["result"])
except Exception as ex:
_LOGGER.debug("Got invalid response from device %s: %s", config.host, data)
raise SmartDeviceException(
f"Unable to read response from device: {config.host}: {ex}"
) from ex
try:
discovery_result = DiscoveryResult(**info["result"])
except ValidationError as ex:
_LOGGER.debug(
"Unable to parse discovery from device %s: %s", config.host, info
)
raise UnsupportedDeviceException(
f"Unable to read response from device: {ip}: {ex}"
f"Unable to parse discovery from device: {config.host}: {ex}"
) from ex
type_ = discovery_result.device_type
encrypt_type_ = (
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
)
if (device_class := get_device_class_from_type_name(type_)) is None:
try:
config.connection_type = ConnectionType.from_values(
type_, discovery_result.mgt_encrypt_schm.encrypt_type
)
except SmartDeviceException as ex:
raise UnsupportedDeviceException(
f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
discovery_result=discovery_result.get_dict(),
) from ex
if (device_class := get_device_class_from_family(type_)) is None:
_LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceException(
f"Unsupported device {ip} of type {type_}: {info}",
f"Unsupported device {config.host} of type {type_}: {info}",
discovery_result=discovery_result.get_dict(),
)
if (
protocol := get_protocol_from_connection_name(
encrypt_type_, ip, credentials=credentials
if (protocol := get_protocol(config)) is None:
_LOGGER.warning(
"Got unsupported connection type: %s", config.connection_type.to_dict()
)
) is None:
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
raise UnsupportedDeviceException(
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}",
f"Unsupported encryption scheme {config.host} of "
+ f"type {config.connection_type.to_dict()}: {info}",
discovery_result=discovery_result.get_dict(),
)
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
device = device_class(ip, port=port, credentials=credentials)
device.protocol = protocol
device.update_from_discover_info(discovery_result.get_dict())
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, info)
device = device_class(config.host, protocol=protocol)
di = discovery_result.get_dict()
di["model"] = discovery_result.device_model
di["alias"] = UNAVAILABLE_ALIAS
di["nickname"] = UNAVAILABLE_NICKNAME
device.update_from_discover_info(di)
return device
class DiscoveryResult(BaseModel):
"""Base model for discovery result."""
class Config:
"""Class for configuring model behaviour."""
allow_population_by_field_name = True
class EncryptionScheme(BaseModel):
"""Base model for encryption scheme of discovery result."""
is_support_https: Optional[bool] = None
encrypt_type: Optional[str] = None
http_port: Optional[int] = None
lv: Optional[int] = 1
is_support_https: bool
encrypt_type: str
http_port: int
lv: Optional[int] = None
device_type: str = Field(alias="device_type_text")
device_model: str = Field(alias="model")
ip: str = Field(alias="alias")
device_type: str
device_model: str
ip: str
mac: str
mgt_encrypt_schm: EncryptionScheme
device_id: str
device_id: Optional[str] = Field(default=None, alias="device_id_hash")
owner: Optional[str] = Field(default=None, alias="device_owner_hash")
hw_ver: Optional[str] = None
owner: Optional[str] = None
is_support_iot_cloud: Optional[bool] = None
obd_src: Optional[str] = None
factory_default: Optional[bool] = None
@@ -453,5 +487,5 @@ class DiscoveryResult(BaseModel):
containing only the values actually set and with aliases as field names.
"""
return self.dict(
by_alias=True, exclude_unset=True, exclude_none=True, exclude_defaults=True
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)