mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-10-12 10:28:01 +00:00
Update try_connect_all to be more efficient and report attempts (#1222)
This commit is contained in:
@@ -91,7 +91,7 @@ import socket
|
||||
import struct
|
||||
from collections.abc import Awaitable
|
||||
from pprint import pformat as pf
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast
|
||||
|
||||
from aiohttp import ClientSession
|
||||
|
||||
@@ -118,6 +118,7 @@ from kasa.exceptions import (
|
||||
TimeoutError,
|
||||
UnsupportedDeviceError,
|
||||
)
|
||||
from kasa.experimental import Experimental
|
||||
from kasa.iot.iotdevice import IotDevice
|
||||
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
|
||||
from kasa.json import dumps as json_dumps
|
||||
@@ -127,9 +128,21 @@ from kasa.xortransport import XorEncryption
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kasa import BaseProtocol, BaseTransport
|
||||
|
||||
|
||||
class ConnectAttempt(NamedTuple):
|
||||
"""Try to connect attempt."""
|
||||
|
||||
protocol: type
|
||||
transport: type
|
||||
device: type
|
||||
|
||||
|
||||
OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
|
||||
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
|
||||
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
|
||||
DeviceDict = Dict[str, Device]
|
||||
|
||||
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
|
||||
@@ -535,6 +548,7 @@ class Discover:
|
||||
timeout: int | None = None,
|
||||
credentials: Credentials | None = None,
|
||||
http_client: ClientSession | None = None,
|
||||
on_attempt: OnConnectAttemptCallable | None = None,
|
||||
) -> Device | None:
|
||||
"""Try to connect directly to a device with all possible parameters.
|
||||
|
||||
@@ -551,13 +565,22 @@ class Discover:
|
||||
"""
|
||||
from .device_factory import _connect
|
||||
|
||||
candidates = {
|
||||
main_device_families = {
|
||||
Device.Family.SmartTapoPlug,
|
||||
Device.Family.IotSmartPlugSwitch,
|
||||
}
|
||||
if Experimental.enabled():
|
||||
main_device_families.add(Device.Family.SmartIpCamera)
|
||||
candidates: dict[
|
||||
tuple[type[BaseProtocol], type[BaseTransport], type[Device]],
|
||||
tuple[BaseProtocol, DeviceConfig],
|
||||
] = {
|
||||
(type(protocol), type(protocol._transport), device_class): (
|
||||
protocol,
|
||||
config,
|
||||
)
|
||||
for encrypt in Device.EncryptionType
|
||||
for device_family in Device.Family
|
||||
for device_family in main_device_families
|
||||
for https in (True, False)
|
||||
if (
|
||||
conn_params := DeviceConnectionParameters(
|
||||
@@ -580,19 +603,26 @@ class Discover:
|
||||
and (protocol := get_protocol(config))
|
||||
and (
|
||||
device_class := get_device_class_from_family(
|
||||
device_family.value, https=https
|
||||
device_family.value, https=https, require_exact=True
|
||||
)
|
||||
)
|
||||
}
|
||||
for protocol, config in candidates.values():
|
||||
for key, val in candidates.items():
|
||||
try:
|
||||
dev = await _connect(config, protocol)
|
||||
prot, config = val
|
||||
dev = await _connect(config, prot)
|
||||
except Exception:
|
||||
_LOGGER.debug("Unable to connect with %s", protocol)
|
||||
_LOGGER.debug("Unable to connect with %s", prot)
|
||||
if on_attempt:
|
||||
ca = tuple.__new__(ConnectAttempt, key)
|
||||
on_attempt(ca, False)
|
||||
else:
|
||||
if on_attempt:
|
||||
ca = tuple.__new__(ConnectAttempt, key)
|
||||
on_attempt(ca, True)
|
||||
return dev
|
||||
finally:
|
||||
await protocol.close()
|
||||
await prot.close()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
Reference in New Issue
Block a user