Update try_connect_all to be more efficient and report attempts (#1222)

This commit is contained in:
Steven B. 2024-11-01 18:17:18 +00:00 committed by GitHub
parent 70c96b5a5d
commit 77b654a9aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 65 additions and 15 deletions

View File

@ -41,7 +41,7 @@ from kasa.iotprotocol import (
_deprecated_TPLinkSmartHomeProtocol, # noqa: F401
)
from kasa.module import Module
from kasa.protocol import BaseProtocol
from kasa.protocol import BaseProtocol, BaseTransport
from kasa.smartprotocol import SmartProtocol
__version__ = version("python-kasa")
@ -50,6 +50,7 @@ __version__ = version("python-kasa")
__all__ = [
"Discover",
"BaseProtocol",
"BaseTransport",
"IotProtocol",
"SmartProtocol",
"LightState",

View File

@ -15,7 +15,7 @@ from kasa import (
Discover,
UnsupportedDeviceError,
)
from kasa.discover import DiscoveryResult
from kasa.discover import ConnectAttempt, DiscoveryResult
from .common import echo, error
@ -165,8 +165,17 @@ async def config(ctx):
credentials = Credentials(username, password) if username and password else None
host_port = host + (f":{port}" if port else "")
def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None:
prot, tran, dev = connect_attempt
key_str = f"{prot.__name__} + {tran.__name__} + {dev.__name__}"
result = "succeeded" if success else "failed"
msg = f"Attempt to connect to {host_port} with {key_str} {result}"
echo(msg)
dev = await Discover.try_connect_all(
host, credentials=credentials, timeout=timeout, port=port
host, credentials=credentials, timeout=timeout, port=port, on_attempt=on_attempt
)
if dev:
cparams = dev.config.connection_type

View File

@ -167,7 +167,7 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:
def get_device_class_from_family(
device_type: str, *, https: bool
device_type: str, *, https: bool, require_exact: bool = False
) -> type[Device] | None:
"""Return the device class from the type name."""
supported_device_types: dict[str, type[Device]] = {
@ -185,8 +185,10 @@ def get_device_class_from_family(
}
lookup_key = f"{device_type}{'.HTTPS' if https else ''}"
if (
cls := supported_device_types.get(lookup_key)
) is None and device_type.startswith("SMART."):
(cls := supported_device_types.get(lookup_key)) is None
and device_type.startswith("SMART.")
and not require_exact
):
_LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type)
cls = SmartDevice

View File

@ -91,7 +91,7 @@ import socket
import struct
from collections.abc import Awaitable
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast
from aiohttp import ClientSession
@ -118,6 +118,7 @@ from kasa.exceptions import (
TimeoutError,
UnsupportedDeviceError,
)
from kasa.experimental import Experimental
from kasa.iot.iotdevice import IotDevice
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.json import dumps as json_dumps
@ -127,9 +128,21 @@ from kasa.xortransport import XorEncryption
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from kasa import BaseProtocol, BaseTransport
class ConnectAttempt(NamedTuple):
"""Try to connect attempt."""
protocol: type
transport: type
device: type
OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = Dict[str, Device]
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
@ -535,6 +548,7 @@ class Discover:
timeout: int | None = None,
credentials: Credentials | None = None,
http_client: ClientSession | None = None,
on_attempt: OnConnectAttemptCallable | None = None,
) -> Device | None:
"""Try to connect directly to a device with all possible parameters.
@ -551,13 +565,22 @@ class Discover:
"""
from .device_factory import _connect
candidates = {
main_device_families = {
Device.Family.SmartTapoPlug,
Device.Family.IotSmartPlugSwitch,
}
if Experimental.enabled():
main_device_families.add(Device.Family.SmartIpCamera)
candidates: dict[
tuple[type[BaseProtocol], type[BaseTransport], type[Device]],
tuple[BaseProtocol, DeviceConfig],
] = {
(type(protocol), type(protocol._transport), device_class): (
protocol,
config,
)
for encrypt in Device.EncryptionType
for device_family in Device.Family
for device_family in main_device_families
for https in (True, False)
if (
conn_params := DeviceConnectionParameters(
@ -580,19 +603,26 @@ class Discover:
and (protocol := get_protocol(config))
and (
device_class := get_device_class_from_family(
device_family.value, https=https
device_family.value, https=https, require_exact=True
)
)
}
for protocol, config in candidates.values():
for key, val in candidates.items():
try:
dev = await _connect(config, protocol)
prot, config = val
dev = await _connect(config, prot)
except Exception:
_LOGGER.debug("Unable to connect with %s", protocol)
_LOGGER.debug("Unable to connect with %s", prot)
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, False)
else:
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, True)
return dev
finally:
await protocol.close()
await prot.close()
return None
@staticmethod

View File

@ -1162,7 +1162,7 @@ async def test_cli_child_commands(
async def test_discover_config(dev: Device, mocker, runner):
"""Test that device config is returned."""
host = "127.0.0.1"
mocker.patch("kasa.discover.Discover.try_connect_all", return_value=dev)
mocker.patch("kasa.device_factory._connect", side_effect=[Exception, dev])
res = await runner.invoke(
cli,
@ -1182,6 +1182,14 @@ async def test_discover_config(dev: Device, mocker, runner):
cparam = dev.config.connection_type
expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}"
assert expected in res.output
assert re.search(
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ failed",
res.output.replace("\n", ""),
)
assert re.search(
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ succeeded",
res.output.replace("\n", ""),
)
async def test_discover_config_invalid(mocker, runner):