Update try_connect_all to be more efficient and report attempts (#1222)

This commit is contained in:
Steven B. 2024-11-01 18:17:18 +00:00 committed by GitHub
parent 70c96b5a5d
commit 77b654a9aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 65 additions and 15 deletions

View File

@ -41,7 +41,7 @@ from kasa.iotprotocol import (
_deprecated_TPLinkSmartHomeProtocol, # noqa: F401 _deprecated_TPLinkSmartHomeProtocol, # noqa: F401
) )
from kasa.module import Module from kasa.module import Module
from kasa.protocol import BaseProtocol from kasa.protocol import BaseProtocol, BaseTransport
from kasa.smartprotocol import SmartProtocol from kasa.smartprotocol import SmartProtocol
__version__ = version("python-kasa") __version__ = version("python-kasa")
@ -50,6 +50,7 @@ __version__ = version("python-kasa")
__all__ = [ __all__ = [
"Discover", "Discover",
"BaseProtocol", "BaseProtocol",
"BaseTransport",
"IotProtocol", "IotProtocol",
"SmartProtocol", "SmartProtocol",
"LightState", "LightState",

View File

@ -15,7 +15,7 @@ from kasa import (
Discover, Discover,
UnsupportedDeviceError, UnsupportedDeviceError,
) )
from kasa.discover import DiscoveryResult from kasa.discover import ConnectAttempt, DiscoveryResult
from .common import echo, error from .common import echo, error
@ -165,8 +165,17 @@ async def config(ctx):
credentials = Credentials(username, password) if username and password else None credentials = Credentials(username, password) if username and password else None
host_port = host + (f":{port}" if port else "")
def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None:
prot, tran, dev = connect_attempt
key_str = f"{prot.__name__} + {tran.__name__} + {dev.__name__}"
result = "succeeded" if success else "failed"
msg = f"Attempt to connect to {host_port} with {key_str} {result}"
echo(msg)
dev = await Discover.try_connect_all( dev = await Discover.try_connect_all(
host, credentials=credentials, timeout=timeout, port=port host, credentials=credentials, timeout=timeout, port=port, on_attempt=on_attempt
) )
if dev: if dev:
cparams = dev.config.connection_type cparams = dev.config.connection_type

View File

@ -167,7 +167,7 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:
def get_device_class_from_family( def get_device_class_from_family(
device_type: str, *, https: bool device_type: str, *, https: bool, require_exact: bool = False
) -> type[Device] | None: ) -> type[Device] | None:
"""Return the device class from the type name.""" """Return the device class from the type name."""
supported_device_types: dict[str, type[Device]] = { supported_device_types: dict[str, type[Device]] = {
@ -185,8 +185,10 @@ def get_device_class_from_family(
} }
lookup_key = f"{device_type}{'.HTTPS' if https else ''}" lookup_key = f"{device_type}{'.HTTPS' if https else ''}"
if ( if (
cls := supported_device_types.get(lookup_key) (cls := supported_device_types.get(lookup_key)) is None
) is None and device_type.startswith("SMART."): and device_type.startswith("SMART.")
and not require_exact
):
_LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type) _LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type)
cls = SmartDevice cls = SmartDevice

View File

@ -91,7 +91,7 @@ import socket
import struct import struct
from collections.abc import Awaitable from collections.abc import Awaitable
from pprint import pformat as pf from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast
from aiohttp import ClientSession from aiohttp import ClientSession
@ -118,6 +118,7 @@ from kasa.exceptions import (
TimeoutError, TimeoutError,
UnsupportedDeviceError, UnsupportedDeviceError,
) )
from kasa.experimental import Experimental
from kasa.iot.iotdevice import IotDevice from kasa.iot.iotdevice import IotDevice
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.json import dumps as json_dumps from kasa.json import dumps as json_dumps
@ -127,9 +128,21 @@ from kasa.xortransport import XorEncryption
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from kasa import BaseProtocol, BaseTransport
class ConnectAttempt(NamedTuple):
"""Try to connect attempt."""
protocol: type
transport: type
device: type
OnDiscoveredCallable = Callable[[Device], Awaitable[None]] OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]] OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = Dict[str, Device] DeviceDict = Dict[str, Device]
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = { NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
@ -535,6 +548,7 @@ class Discover:
timeout: int | None = None, timeout: int | None = None,
credentials: Credentials | None = None, credentials: Credentials | None = None,
http_client: ClientSession | None = None, http_client: ClientSession | None = None,
on_attempt: OnConnectAttemptCallable | None = None,
) -> Device | None: ) -> Device | None:
"""Try to connect directly to a device with all possible parameters. """Try to connect directly to a device with all possible parameters.
@ -551,13 +565,22 @@ class Discover:
""" """
from .device_factory import _connect from .device_factory import _connect
candidates = { main_device_families = {
Device.Family.SmartTapoPlug,
Device.Family.IotSmartPlugSwitch,
}
if Experimental.enabled():
main_device_families.add(Device.Family.SmartIpCamera)
candidates: dict[
tuple[type[BaseProtocol], type[BaseTransport], type[Device]],
tuple[BaseProtocol, DeviceConfig],
] = {
(type(protocol), type(protocol._transport), device_class): ( (type(protocol), type(protocol._transport), device_class): (
protocol, protocol,
config, config,
) )
for encrypt in Device.EncryptionType for encrypt in Device.EncryptionType
for device_family in Device.Family for device_family in main_device_families
for https in (True, False) for https in (True, False)
if ( if (
conn_params := DeviceConnectionParameters( conn_params := DeviceConnectionParameters(
@ -580,19 +603,26 @@ class Discover:
and (protocol := get_protocol(config)) and (protocol := get_protocol(config))
and ( and (
device_class := get_device_class_from_family( device_class := get_device_class_from_family(
device_family.value, https=https device_family.value, https=https, require_exact=True
) )
) )
} }
for protocol, config in candidates.values(): for key, val in candidates.items():
try: try:
dev = await _connect(config, protocol) prot, config = val
dev = await _connect(config, prot)
except Exception: except Exception:
_LOGGER.debug("Unable to connect with %s", protocol) _LOGGER.debug("Unable to connect with %s", prot)
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, False)
else: else:
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, True)
return dev return dev
finally: finally:
await protocol.close() await prot.close()
return None return None
@staticmethod @staticmethod

View File

@ -1162,7 +1162,7 @@ async def test_cli_child_commands(
async def test_discover_config(dev: Device, mocker, runner): async def test_discover_config(dev: Device, mocker, runner):
"""Test that device config is returned.""" """Test that device config is returned."""
host = "127.0.0.1" host = "127.0.0.1"
mocker.patch("kasa.discover.Discover.try_connect_all", return_value=dev) mocker.patch("kasa.device_factory._connect", side_effect=[Exception, dev])
res = await runner.invoke( res = await runner.invoke(
cli, cli,
@ -1182,6 +1182,14 @@ async def test_discover_config(dev: Device, mocker, runner):
cparam = dev.config.connection_type cparam = dev.config.connection_type
expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}" expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}"
assert expected in res.output assert expected in res.output
assert re.search(
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ failed",
res.output.replace("\n", ""),
)
assert re.search(
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ succeeded",
res.output.replace("\n", ""),
)
async def test_discover_config_invalid(mocker, runner): async def test_discover_config_invalid(mocker, runner):