Update try_connect_all to be more efficient and report attempts (#1222)

This commit is contained in:
Steven B.
2024-11-01 18:17:18 +00:00
committed by GitHub
parent 70c96b5a5d
commit 77b654a9aa
5 changed files with 65 additions and 15 deletions

View File

@@ -91,7 +91,7 @@ import socket
import struct
from collections.abc import Awaitable
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast
from aiohttp import ClientSession
@@ -118,6 +118,7 @@ from kasa.exceptions import (
TimeoutError,
UnsupportedDeviceError,
)
from kasa.experimental import Experimental
from kasa.iot.iotdevice import IotDevice
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.json import dumps as json_dumps
@@ -127,9 +128,21 @@ from kasa.xortransport import XorEncryption
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from kasa import BaseProtocol, BaseTransport
class ConnectAttempt(NamedTuple):
"""Try to connect attempt."""
protocol: type
transport: type
device: type
OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = Dict[str, Device]
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
@@ -535,6 +548,7 @@ class Discover:
timeout: int | None = None,
credentials: Credentials | None = None,
http_client: ClientSession | None = None,
on_attempt: OnConnectAttemptCallable | None = None,
) -> Device | None:
"""Try to connect directly to a device with all possible parameters.
@@ -551,13 +565,22 @@ class Discover:
"""
from .device_factory import _connect
candidates = {
main_device_families = {
Device.Family.SmartTapoPlug,
Device.Family.IotSmartPlugSwitch,
}
if Experimental.enabled():
main_device_families.add(Device.Family.SmartIpCamera)
candidates: dict[
tuple[type[BaseProtocol], type[BaseTransport], type[Device]],
tuple[BaseProtocol, DeviceConfig],
] = {
(type(protocol), type(protocol._transport), device_class): (
protocol,
config,
)
for encrypt in Device.EncryptionType
for device_family in Device.Family
for device_family in main_device_families
for https in (True, False)
if (
conn_params := DeviceConnectionParameters(
@@ -580,19 +603,26 @@ class Discover:
and (protocol := get_protocol(config))
and (
device_class := get_device_class_from_family(
device_family.value, https=https
device_family.value, https=https, require_exact=True
)
)
}
for protocol, config in candidates.values():
for key, val in candidates.items():
try:
dev = await _connect(config, protocol)
prot, config = val
dev = await _connect(config, prot)
except Exception:
_LOGGER.debug("Unable to connect with %s", protocol)
_LOGGER.debug("Unable to connect with %s", prot)
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, False)
else:
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, True)
return dev
finally:
await protocol.close()
await prot.close()
return None
@staticmethod