diff --git a/kasa/__init__.py b/kasa/__init__.py index 11000419..a74cb4c4 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -41,7 +41,7 @@ from kasa.iotprotocol import ( _deprecated_TPLinkSmartHomeProtocol, # noqa: F401 ) from kasa.module import Module -from kasa.protocol import BaseProtocol +from kasa.protocol import BaseProtocol, BaseTransport from kasa.smartprotocol import SmartProtocol __version__ = version("python-kasa") @@ -50,6 +50,7 @@ __version__ = version("python-kasa") __all__ = [ "Discover", "BaseProtocol", + "BaseTransport", "IotProtocol", "SmartProtocol", "LightState", diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py index 7989dbb1..6a55cb43 100644 --- a/kasa/cli/discover.py +++ b/kasa/cli/discover.py @@ -15,7 +15,7 @@ from kasa import ( Discover, UnsupportedDeviceError, ) -from kasa.discover import DiscoveryResult +from kasa.discover import ConnectAttempt, DiscoveryResult from .common import echo, error @@ -165,8 +165,17 @@ async def config(ctx): credentials = Credentials(username, password) if username and password else None + host_port = host + (f":{port}" if port else "") + + def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None: + prot, tran, dev = connect_attempt + key_str = f"{prot.__name__} + {tran.__name__} + {dev.__name__}" + result = "succeeded" if success else "failed" + msg = f"Attempt to connect to {host_port} with {key_str} {result}" + echo(msg) + dev = await Discover.try_connect_all( - host, credentials=credentials, timeout=timeout, port=port + host, credentials=credentials, timeout=timeout, port=port, on_attempt=on_attempt ) if dev: cparams = dev.config.connection_type diff --git a/kasa/device_factory.py b/kasa/device_factory.py index d7b77843..7f2150d7 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -167,7 +167,7 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]: def get_device_class_from_family( - device_type: str, *, https: bool + device_type: str, *, https: bool, require_exact: bool = False ) -> type[Device] | None: """Return the device class from the type name.""" supported_device_types: dict[str, type[Device]] = { @@ -185,8 +185,10 @@ def get_device_class_from_family( } lookup_key = f"{device_type}{'.HTTPS' if https else ''}" if ( - cls := supported_device_types.get(lookup_key) - ) is None and device_type.startswith("SMART."): + (cls := supported_device_types.get(lookup_key)) is None + and device_type.startswith("SMART.") + and not require_exact + ): _LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type) cls = SmartDevice diff --git a/kasa/discover.py b/kasa/discover.py index 3b8f7c44..a774ebde 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -91,7 +91,7 @@ import socket import struct from collections.abc import Awaitable from pprint import pformat as pf -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast from aiohttp import ClientSession @@ -118,6 +118,7 @@ from kasa.exceptions import ( TimeoutError, UnsupportedDeviceError, ) +from kasa.experimental import Experimental from kasa.iot.iotdevice import IotDevice from kasa.iotprotocol import REDACTORS as IOT_REDACTORS from kasa.json import dumps as json_dumps @@ -127,9 +128,21 @@ from kasa.xortransport import XorEncryption _LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + from kasa import BaseProtocol, BaseTransport + + +class ConnectAttempt(NamedTuple): + """Try to connect attempt.""" + + protocol: type + transport: type + device: type + OnDiscoveredCallable = Callable[[Device], Awaitable[None]] OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]] +OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None] DeviceDict = Dict[str, Device] NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = { @@ -535,6 +548,7 @@ class Discover: timeout: int | None = None, credentials: Credentials | None = None, http_client: ClientSession | None = None, + on_attempt: OnConnectAttemptCallable | None = None, ) -> Device | None: """Try to connect directly to a device with all possible parameters. @@ -551,13 +565,22 @@ class Discover: """ from .device_factory import _connect - candidates = { + main_device_families = { + Device.Family.SmartTapoPlug, + Device.Family.IotSmartPlugSwitch, + } + if Experimental.enabled(): + main_device_families.add(Device.Family.SmartIpCamera) + candidates: dict[ + tuple[type[BaseProtocol], type[BaseTransport], type[Device]], + tuple[BaseProtocol, DeviceConfig], + ] = { (type(protocol), type(protocol._transport), device_class): ( protocol, config, ) for encrypt in Device.EncryptionType - for device_family in Device.Family + for device_family in main_device_families for https in (True, False) if ( conn_params := DeviceConnectionParameters( @@ -580,19 +603,26 @@ class Discover: and (protocol := get_protocol(config)) and ( device_class := get_device_class_from_family( - device_family.value, https=https + device_family.value, https=https, require_exact=True ) ) } - for protocol, config in candidates.values(): + for key, val in candidates.items(): try: - dev = await _connect(config, protocol) + prot, config = val + dev = await _connect(config, prot) except Exception: - _LOGGER.debug("Unable to connect with %s", protocol) + _LOGGER.debug("Unable to connect with %s", prot) + if on_attempt: + ca = tuple.__new__(ConnectAttempt, key) + on_attempt(ca, False) else: + if on_attempt: + ca = tuple.__new__(ConnectAttempt, key) + on_attempt(ca, True) return dev finally: - await protocol.close() + await prot.close() return None @staticmethod diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 80b5daaf..7a0b0dde 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -1162,7 +1162,7 @@ async def test_cli_child_commands( async def test_discover_config(dev: Device, mocker, runner): """Test that device config is returned.""" host = "127.0.0.1" - mocker.patch("kasa.discover.Discover.try_connect_all", return_value=dev) + mocker.patch("kasa.device_factory._connect", side_effect=[Exception, dev]) res = await runner.invoke( cli, @@ -1182,6 +1182,14 @@ async def test_discover_config(dev: Device, mocker, runner): cparam = dev.config.connection_type expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}" assert expected in res.output + assert re.search( + r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ failed", + res.output.replace("\n", ""), + ) + assert re.search( + r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ succeeded", + res.output.replace("\n", ""), + ) async def test_discover_config_invalid(mocker, runner):