mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-23 03:33:35 +00:00
Update try_connect_all to be more efficient and report attempts (#1222)
This commit is contained in:
parent
70c96b5a5d
commit
77b654a9aa
@ -41,7 +41,7 @@ from kasa.iotprotocol import (
|
|||||||
_deprecated_TPLinkSmartHomeProtocol, # noqa: F401
|
_deprecated_TPLinkSmartHomeProtocol, # noqa: F401
|
||||||
)
|
)
|
||||||
from kasa.module import Module
|
from kasa.module import Module
|
||||||
from kasa.protocol import BaseProtocol
|
from kasa.protocol import BaseProtocol, BaseTransport
|
||||||
from kasa.smartprotocol import SmartProtocol
|
from kasa.smartprotocol import SmartProtocol
|
||||||
|
|
||||||
__version__ = version("python-kasa")
|
__version__ = version("python-kasa")
|
||||||
@ -50,6 +50,7 @@ __version__ = version("python-kasa")
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"Discover",
|
"Discover",
|
||||||
"BaseProtocol",
|
"BaseProtocol",
|
||||||
|
"BaseTransport",
|
||||||
"IotProtocol",
|
"IotProtocol",
|
||||||
"SmartProtocol",
|
"SmartProtocol",
|
||||||
"LightState",
|
"LightState",
|
||||||
|
@ -15,7 +15,7 @@ from kasa import (
|
|||||||
Discover,
|
Discover,
|
||||||
UnsupportedDeviceError,
|
UnsupportedDeviceError,
|
||||||
)
|
)
|
||||||
from kasa.discover import DiscoveryResult
|
from kasa.discover import ConnectAttempt, DiscoveryResult
|
||||||
|
|
||||||
from .common import echo, error
|
from .common import echo, error
|
||||||
|
|
||||||
@ -165,8 +165,17 @@ async def config(ctx):
|
|||||||
|
|
||||||
credentials = Credentials(username, password) if username and password else None
|
credentials = Credentials(username, password) if username and password else None
|
||||||
|
|
||||||
|
host_port = host + (f":{port}" if port else "")
|
||||||
|
|
||||||
|
def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None:
|
||||||
|
prot, tran, dev = connect_attempt
|
||||||
|
key_str = f"{prot.__name__} + {tran.__name__} + {dev.__name__}"
|
||||||
|
result = "succeeded" if success else "failed"
|
||||||
|
msg = f"Attempt to connect to {host_port} with {key_str} {result}"
|
||||||
|
echo(msg)
|
||||||
|
|
||||||
dev = await Discover.try_connect_all(
|
dev = await Discover.try_connect_all(
|
||||||
host, credentials=credentials, timeout=timeout, port=port
|
host, credentials=credentials, timeout=timeout, port=port, on_attempt=on_attempt
|
||||||
)
|
)
|
||||||
if dev:
|
if dev:
|
||||||
cparams = dev.config.connection_type
|
cparams = dev.config.connection_type
|
||||||
|
@ -167,7 +167,7 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:
|
|||||||
|
|
||||||
|
|
||||||
def get_device_class_from_family(
|
def get_device_class_from_family(
|
||||||
device_type: str, *, https: bool
|
device_type: str, *, https: bool, require_exact: bool = False
|
||||||
) -> type[Device] | None:
|
) -> type[Device] | None:
|
||||||
"""Return the device class from the type name."""
|
"""Return the device class from the type name."""
|
||||||
supported_device_types: dict[str, type[Device]] = {
|
supported_device_types: dict[str, type[Device]] = {
|
||||||
@ -185,8 +185,10 @@ def get_device_class_from_family(
|
|||||||
}
|
}
|
||||||
lookup_key = f"{device_type}{'.HTTPS' if https else ''}"
|
lookup_key = f"{device_type}{'.HTTPS' if https else ''}"
|
||||||
if (
|
if (
|
||||||
cls := supported_device_types.get(lookup_key)
|
(cls := supported_device_types.get(lookup_key)) is None
|
||||||
) is None and device_type.startswith("SMART."):
|
and device_type.startswith("SMART.")
|
||||||
|
and not require_exact
|
||||||
|
):
|
||||||
_LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type)
|
_LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type)
|
||||||
cls = SmartDevice
|
cls = SmartDevice
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ import socket
|
|||||||
import struct
|
import struct
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
|
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
@ -118,6 +118,7 @@ from kasa.exceptions import (
|
|||||||
TimeoutError,
|
TimeoutError,
|
||||||
UnsupportedDeviceError,
|
UnsupportedDeviceError,
|
||||||
)
|
)
|
||||||
|
from kasa.experimental import Experimental
|
||||||
from kasa.iot.iotdevice import IotDevice
|
from kasa.iot.iotdevice import IotDevice
|
||||||
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
|
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
|
||||||
from kasa.json import dumps as json_dumps
|
from kasa.json import dumps as json_dumps
|
||||||
@ -127,9 +128,21 @@ from kasa.xortransport import XorEncryption
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from kasa import BaseProtocol, BaseTransport
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectAttempt(NamedTuple):
|
||||||
|
"""Try to connect attempt."""
|
||||||
|
|
||||||
|
protocol: type
|
||||||
|
transport: type
|
||||||
|
device: type
|
||||||
|
|
||||||
|
|
||||||
OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
|
OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
|
||||||
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
|
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
|
||||||
|
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
|
||||||
DeviceDict = Dict[str, Device]
|
DeviceDict = Dict[str, Device]
|
||||||
|
|
||||||
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
|
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
|
||||||
@ -535,6 +548,7 @@ class Discover:
|
|||||||
timeout: int | None = None,
|
timeout: int | None = None,
|
||||||
credentials: Credentials | None = None,
|
credentials: Credentials | None = None,
|
||||||
http_client: ClientSession | None = None,
|
http_client: ClientSession | None = None,
|
||||||
|
on_attempt: OnConnectAttemptCallable | None = None,
|
||||||
) -> Device | None:
|
) -> Device | None:
|
||||||
"""Try to connect directly to a device with all possible parameters.
|
"""Try to connect directly to a device with all possible parameters.
|
||||||
|
|
||||||
@ -551,13 +565,22 @@ class Discover:
|
|||||||
"""
|
"""
|
||||||
from .device_factory import _connect
|
from .device_factory import _connect
|
||||||
|
|
||||||
candidates = {
|
main_device_families = {
|
||||||
|
Device.Family.SmartTapoPlug,
|
||||||
|
Device.Family.IotSmartPlugSwitch,
|
||||||
|
}
|
||||||
|
if Experimental.enabled():
|
||||||
|
main_device_families.add(Device.Family.SmartIpCamera)
|
||||||
|
candidates: dict[
|
||||||
|
tuple[type[BaseProtocol], type[BaseTransport], type[Device]],
|
||||||
|
tuple[BaseProtocol, DeviceConfig],
|
||||||
|
] = {
|
||||||
(type(protocol), type(protocol._transport), device_class): (
|
(type(protocol), type(protocol._transport), device_class): (
|
||||||
protocol,
|
protocol,
|
||||||
config,
|
config,
|
||||||
)
|
)
|
||||||
for encrypt in Device.EncryptionType
|
for encrypt in Device.EncryptionType
|
||||||
for device_family in Device.Family
|
for device_family in main_device_families
|
||||||
for https in (True, False)
|
for https in (True, False)
|
||||||
if (
|
if (
|
||||||
conn_params := DeviceConnectionParameters(
|
conn_params := DeviceConnectionParameters(
|
||||||
@ -580,19 +603,26 @@ class Discover:
|
|||||||
and (protocol := get_protocol(config))
|
and (protocol := get_protocol(config))
|
||||||
and (
|
and (
|
||||||
device_class := get_device_class_from_family(
|
device_class := get_device_class_from_family(
|
||||||
device_family.value, https=https
|
device_family.value, https=https, require_exact=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
for protocol, config in candidates.values():
|
for key, val in candidates.items():
|
||||||
try:
|
try:
|
||||||
dev = await _connect(config, protocol)
|
prot, config = val
|
||||||
|
dev = await _connect(config, prot)
|
||||||
except Exception:
|
except Exception:
|
||||||
_LOGGER.debug("Unable to connect with %s", protocol)
|
_LOGGER.debug("Unable to connect with %s", prot)
|
||||||
|
if on_attempt:
|
||||||
|
ca = tuple.__new__(ConnectAttempt, key)
|
||||||
|
on_attempt(ca, False)
|
||||||
else:
|
else:
|
||||||
|
if on_attempt:
|
||||||
|
ca = tuple.__new__(ConnectAttempt, key)
|
||||||
|
on_attempt(ca, True)
|
||||||
return dev
|
return dev
|
||||||
finally:
|
finally:
|
||||||
await protocol.close()
|
await prot.close()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -1162,7 +1162,7 @@ async def test_cli_child_commands(
|
|||||||
async def test_discover_config(dev: Device, mocker, runner):
|
async def test_discover_config(dev: Device, mocker, runner):
|
||||||
"""Test that device config is returned."""
|
"""Test that device config is returned."""
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
mocker.patch("kasa.discover.Discover.try_connect_all", return_value=dev)
|
mocker.patch("kasa.device_factory._connect", side_effect=[Exception, dev])
|
||||||
|
|
||||||
res = await runner.invoke(
|
res = await runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
@ -1182,6 +1182,14 @@ async def test_discover_config(dev: Device, mocker, runner):
|
|||||||
cparam = dev.config.connection_type
|
cparam = dev.config.connection_type
|
||||||
expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}"
|
expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}"
|
||||||
assert expected in res.output
|
assert expected in res.output
|
||||||
|
assert re.search(
|
||||||
|
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ failed",
|
||||||
|
res.output.replace("\n", ""),
|
||||||
|
)
|
||||||
|
assert re.search(
|
||||||
|
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ succeeded",
|
||||||
|
res.output.replace("\n", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_discover_config_invalid(mocker, runner):
|
async def test_discover_config_invalid(mocker, runner):
|
||||||
|
Loading…
Reference in New Issue
Block a user