mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 11:13:34 +00:00
Update try_connect_all to be more efficient and report attempts (#1222)
This commit is contained in:
parent
70c96b5a5d
commit
77b654a9aa
@ -41,7 +41,7 @@ from kasa.iotprotocol import (
|
||||
_deprecated_TPLinkSmartHomeProtocol, # noqa: F401
|
||||
)
|
||||
from kasa.module import Module
|
||||
from kasa.protocol import BaseProtocol
|
||||
from kasa.protocol import BaseProtocol, BaseTransport
|
||||
from kasa.smartprotocol import SmartProtocol
|
||||
|
||||
__version__ = version("python-kasa")
|
||||
@ -50,6 +50,7 @@ __version__ = version("python-kasa")
|
||||
__all__ = [
|
||||
"Discover",
|
||||
"BaseProtocol",
|
||||
"BaseTransport",
|
||||
"IotProtocol",
|
||||
"SmartProtocol",
|
||||
"LightState",
|
||||
|
@ -15,7 +15,7 @@ from kasa import (
|
||||
Discover,
|
||||
UnsupportedDeviceError,
|
||||
)
|
||||
from kasa.discover import DiscoveryResult
|
||||
from kasa.discover import ConnectAttempt, DiscoveryResult
|
||||
|
||||
from .common import echo, error
|
||||
|
||||
@ -165,8 +165,17 @@ async def config(ctx):
|
||||
|
||||
credentials = Credentials(username, password) if username and password else None
|
||||
|
||||
host_port = host + (f":{port}" if port else "")
|
||||
|
||||
def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None:
|
||||
prot, tran, dev = connect_attempt
|
||||
key_str = f"{prot.__name__} + {tran.__name__} + {dev.__name__}"
|
||||
result = "succeeded" if success else "failed"
|
||||
msg = f"Attempt to connect to {host_port} with {key_str} {result}"
|
||||
echo(msg)
|
||||
|
||||
dev = await Discover.try_connect_all(
|
||||
host, credentials=credentials, timeout=timeout, port=port
|
||||
host, credentials=credentials, timeout=timeout, port=port, on_attempt=on_attempt
|
||||
)
|
||||
if dev:
|
||||
cparams = dev.config.connection_type
|
||||
|
@ -167,7 +167,7 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:
|
||||
|
||||
|
||||
def get_device_class_from_family(
|
||||
device_type: str, *, https: bool
|
||||
device_type: str, *, https: bool, require_exact: bool = False
|
||||
) -> type[Device] | None:
|
||||
"""Return the device class from the type name."""
|
||||
supported_device_types: dict[str, type[Device]] = {
|
||||
@ -185,8 +185,10 @@ def get_device_class_from_family(
|
||||
}
|
||||
lookup_key = f"{device_type}{'.HTTPS' if https else ''}"
|
||||
if (
|
||||
cls := supported_device_types.get(lookup_key)
|
||||
) is None and device_type.startswith("SMART."):
|
||||
(cls := supported_device_types.get(lookup_key)) is None
|
||||
and device_type.startswith("SMART.")
|
||||
and not require_exact
|
||||
):
|
||||
_LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type)
|
||||
cls = SmartDevice
|
||||
|
||||
|
@ -91,7 +91,7 @@ import socket
|
||||
import struct
|
||||
from collections.abc import Awaitable
|
||||
from pprint import pformat as pf
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast
|
||||
|
||||
from aiohttp import ClientSession
|
||||
|
||||
@ -118,6 +118,7 @@ from kasa.exceptions import (
|
||||
TimeoutError,
|
||||
UnsupportedDeviceError,
|
||||
)
|
||||
from kasa.experimental import Experimental
|
||||
from kasa.iot.iotdevice import IotDevice
|
||||
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
|
||||
from kasa.json import dumps as json_dumps
|
||||
@ -127,9 +128,21 @@ from kasa.xortransport import XorEncryption
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kasa import BaseProtocol, BaseTransport
|
||||
|
||||
|
||||
class ConnectAttempt(NamedTuple):
|
||||
"""Try to connect attempt."""
|
||||
|
||||
protocol: type
|
||||
transport: type
|
||||
device: type
|
||||
|
||||
|
||||
OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
|
||||
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
|
||||
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
|
||||
DeviceDict = Dict[str, Device]
|
||||
|
||||
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
|
||||
@ -535,6 +548,7 @@ class Discover:
|
||||
timeout: int | None = None,
|
||||
credentials: Credentials | None = None,
|
||||
http_client: ClientSession | None = None,
|
||||
on_attempt: OnConnectAttemptCallable | None = None,
|
||||
) -> Device | None:
|
||||
"""Try to connect directly to a device with all possible parameters.
|
||||
|
||||
@ -551,13 +565,22 @@ class Discover:
|
||||
"""
|
||||
from .device_factory import _connect
|
||||
|
||||
candidates = {
|
||||
main_device_families = {
|
||||
Device.Family.SmartTapoPlug,
|
||||
Device.Family.IotSmartPlugSwitch,
|
||||
}
|
||||
if Experimental.enabled():
|
||||
main_device_families.add(Device.Family.SmartIpCamera)
|
||||
candidates: dict[
|
||||
tuple[type[BaseProtocol], type[BaseTransport], type[Device]],
|
||||
tuple[BaseProtocol, DeviceConfig],
|
||||
] = {
|
||||
(type(protocol), type(protocol._transport), device_class): (
|
||||
protocol,
|
||||
config,
|
||||
)
|
||||
for encrypt in Device.EncryptionType
|
||||
for device_family in Device.Family
|
||||
for device_family in main_device_families
|
||||
for https in (True, False)
|
||||
if (
|
||||
conn_params := DeviceConnectionParameters(
|
||||
@ -580,19 +603,26 @@ class Discover:
|
||||
and (protocol := get_protocol(config))
|
||||
and (
|
||||
device_class := get_device_class_from_family(
|
||||
device_family.value, https=https
|
||||
device_family.value, https=https, require_exact=True
|
||||
)
|
||||
)
|
||||
}
|
||||
for protocol, config in candidates.values():
|
||||
for key, val in candidates.items():
|
||||
try:
|
||||
dev = await _connect(config, protocol)
|
||||
prot, config = val
|
||||
dev = await _connect(config, prot)
|
||||
except Exception:
|
||||
_LOGGER.debug("Unable to connect with %s", protocol)
|
||||
_LOGGER.debug("Unable to connect with %s", prot)
|
||||
if on_attempt:
|
||||
ca = tuple.__new__(ConnectAttempt, key)
|
||||
on_attempt(ca, False)
|
||||
else:
|
||||
if on_attempt:
|
||||
ca = tuple.__new__(ConnectAttempt, key)
|
||||
on_attempt(ca, True)
|
||||
return dev
|
||||
finally:
|
||||
await protocol.close()
|
||||
await prot.close()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
@ -1162,7 +1162,7 @@ async def test_cli_child_commands(
|
||||
async def test_discover_config(dev: Device, mocker, runner):
|
||||
"""Test that device config is returned."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch("kasa.discover.Discover.try_connect_all", return_value=dev)
|
||||
mocker.patch("kasa.device_factory._connect", side_effect=[Exception, dev])
|
||||
|
||||
res = await runner.invoke(
|
||||
cli,
|
||||
@ -1182,6 +1182,14 @@ async def test_discover_config(dev: Device, mocker, runner):
|
||||
cparam = dev.config.connection_type
|
||||
expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}"
|
||||
assert expected in res.output
|
||||
assert re.search(
|
||||
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ failed",
|
||||
res.output.replace("\n", ""),
|
||||
)
|
||||
assert re.search(
|
||||
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ succeeded",
|
||||
res.output.replace("\n", ""),
|
||||
)
|
||||
|
||||
|
||||
async def test_discover_config_invalid(mocker, runner):
|
||||
|
Loading…
Reference in New Issue
Block a user