mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-10-31 12:41:54 +00:00 
			
		
		
		
	Update try_connect_all to be more efficient and report attempts (#1222)
This commit is contained in:
		| @@ -41,7 +41,7 @@ from kasa.iotprotocol import ( | ||||
|     _deprecated_TPLinkSmartHomeProtocol,  # noqa: F401 | ||||
| ) | ||||
| from kasa.module import Module | ||||
| from kasa.protocol import BaseProtocol | ||||
| from kasa.protocol import BaseProtocol, BaseTransport | ||||
| from kasa.smartprotocol import SmartProtocol | ||||
|  | ||||
| __version__ = version("python-kasa") | ||||
| @@ -50,6 +50,7 @@ __version__ = version("python-kasa") | ||||
| __all__ = [ | ||||
|     "Discover", | ||||
|     "BaseProtocol", | ||||
|     "BaseTransport", | ||||
|     "IotProtocol", | ||||
|     "SmartProtocol", | ||||
|     "LightState", | ||||
|   | ||||
| @@ -15,7 +15,7 @@ from kasa import ( | ||||
|     Discover, | ||||
|     UnsupportedDeviceError, | ||||
| ) | ||||
| from kasa.discover import DiscoveryResult | ||||
| from kasa.discover import ConnectAttempt, DiscoveryResult | ||||
|  | ||||
| from .common import echo, error | ||||
|  | ||||
| @@ -165,8 +165,17 @@ async def config(ctx): | ||||
|  | ||||
|     credentials = Credentials(username, password) if username and password else None | ||||
|  | ||||
|     host_port = host + (f":{port}" if port else "") | ||||
|  | ||||
|     def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None: | ||||
|         prot, tran, dev = connect_attempt | ||||
|         key_str = f"{prot.__name__} + {tran.__name__} + {dev.__name__}" | ||||
|         result = "succeeded" if success else "failed" | ||||
|         msg = f"Attempt to connect to {host_port} with {key_str} {result}" | ||||
|         echo(msg) | ||||
|  | ||||
|     dev = await Discover.try_connect_all( | ||||
|         host, credentials=credentials, timeout=timeout, port=port | ||||
|         host, credentials=credentials, timeout=timeout, port=port, on_attempt=on_attempt | ||||
|     ) | ||||
|     if dev: | ||||
|         cparams = dev.config.connection_type | ||||
|   | ||||
| @@ -167,7 +167,7 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]: | ||||
|  | ||||
|  | ||||
| def get_device_class_from_family( | ||||
|     device_type: str, *, https: bool | ||||
|     device_type: str, *, https: bool, require_exact: bool = False | ||||
| ) -> type[Device] | None: | ||||
|     """Return the device class from the type name.""" | ||||
|     supported_device_types: dict[str, type[Device]] = { | ||||
| @@ -185,8 +185,10 @@ def get_device_class_from_family( | ||||
|     } | ||||
|     lookup_key = f"{device_type}{'.HTTPS' if https else ''}" | ||||
|     if ( | ||||
|         cls := supported_device_types.get(lookup_key) | ||||
|     ) is None and device_type.startswith("SMART."): | ||||
|         (cls := supported_device_types.get(lookup_key)) is None | ||||
|         and device_type.startswith("SMART.") | ||||
|         and not require_exact | ||||
|     ): | ||||
|         _LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type) | ||||
|         cls = SmartDevice | ||||
|  | ||||
|   | ||||
| @@ -91,7 +91,7 @@ import socket | ||||
| import struct | ||||
| from collections.abc import Awaitable | ||||
| from pprint import pformat as pf | ||||
| from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast | ||||
| from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast | ||||
|  | ||||
| from aiohttp import ClientSession | ||||
|  | ||||
| @@ -118,6 +118,7 @@ from kasa.exceptions import ( | ||||
|     TimeoutError, | ||||
|     UnsupportedDeviceError, | ||||
| ) | ||||
| from kasa.experimental import Experimental | ||||
| from kasa.iot.iotdevice import IotDevice | ||||
| from kasa.iotprotocol import REDACTORS as IOT_REDACTORS | ||||
| from kasa.json import dumps as json_dumps | ||||
| @@ -127,9 +128,21 @@ from kasa.xortransport import XorEncryption | ||||
|  | ||||
| _LOGGER = logging.getLogger(__name__) | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from kasa import BaseProtocol, BaseTransport | ||||
|  | ||||
|  | ||||
| class ConnectAttempt(NamedTuple): | ||||
|     """Try to connect attempt.""" | ||||
|  | ||||
|     protocol: type | ||||
|     transport: type | ||||
|     device: type | ||||
|  | ||||
|  | ||||
| OnDiscoveredCallable = Callable[[Device], Awaitable[None]] | ||||
| OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]] | ||||
| OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None] | ||||
| DeviceDict = Dict[str, Device] | ||||
|  | ||||
| NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = { | ||||
| @@ -535,6 +548,7 @@ class Discover: | ||||
|         timeout: int | None = None, | ||||
|         credentials: Credentials | None = None, | ||||
|         http_client: ClientSession | None = None, | ||||
|         on_attempt: OnConnectAttemptCallable | None = None, | ||||
|     ) -> Device | None: | ||||
|         """Try to connect directly to a device with all possible parameters. | ||||
|  | ||||
| @@ -551,13 +565,22 @@ class Discover: | ||||
|         """ | ||||
|         from .device_factory import _connect | ||||
|  | ||||
|         candidates = { | ||||
|         main_device_families = { | ||||
|             Device.Family.SmartTapoPlug, | ||||
|             Device.Family.IotSmartPlugSwitch, | ||||
|         } | ||||
|         if Experimental.enabled(): | ||||
|             main_device_families.add(Device.Family.SmartIpCamera) | ||||
|         candidates: dict[ | ||||
|             tuple[type[BaseProtocol], type[BaseTransport], type[Device]], | ||||
|             tuple[BaseProtocol, DeviceConfig], | ||||
|         ] = { | ||||
|             (type(protocol), type(protocol._transport), device_class): ( | ||||
|                 protocol, | ||||
|                 config, | ||||
|             ) | ||||
|             for encrypt in Device.EncryptionType | ||||
|             for device_family in Device.Family | ||||
|             for device_family in main_device_families | ||||
|             for https in (True, False) | ||||
|             if ( | ||||
|                 conn_params := DeviceConnectionParameters( | ||||
| @@ -580,19 +603,26 @@ class Discover: | ||||
|             and (protocol := get_protocol(config)) | ||||
|             and ( | ||||
|                 device_class := get_device_class_from_family( | ||||
|                     device_family.value, https=https | ||||
|                     device_family.value, https=https, require_exact=True | ||||
|                 ) | ||||
|             ) | ||||
|         } | ||||
|         for protocol, config in candidates.values(): | ||||
|         for key, val in candidates.items(): | ||||
|             try: | ||||
|                 dev = await _connect(config, protocol) | ||||
|                 prot, config = val | ||||
|                 dev = await _connect(config, prot) | ||||
|             except Exception: | ||||
|                 _LOGGER.debug("Unable to connect with %s", protocol) | ||||
|                 _LOGGER.debug("Unable to connect with %s", prot) | ||||
|                 if on_attempt: | ||||
|                     ca = tuple.__new__(ConnectAttempt, key) | ||||
|                     on_attempt(ca, False) | ||||
|             else: | ||||
|                 if on_attempt: | ||||
|                     ca = tuple.__new__(ConnectAttempt, key) | ||||
|                     on_attempt(ca, True) | ||||
|                 return dev | ||||
|             finally: | ||||
|                 await protocol.close() | ||||
|                 await prot.close() | ||||
|         return None | ||||
|  | ||||
|     @staticmethod | ||||
|   | ||||
| @@ -1162,7 +1162,7 @@ async def test_cli_child_commands( | ||||
| async def test_discover_config(dev: Device, mocker, runner): | ||||
|     """Test that device config is returned.""" | ||||
|     host = "127.0.0.1" | ||||
|     mocker.patch("kasa.discover.Discover.try_connect_all", return_value=dev) | ||||
|     mocker.patch("kasa.device_factory._connect", side_effect=[Exception, dev]) | ||||
|  | ||||
|     res = await runner.invoke( | ||||
|         cli, | ||||
| @@ -1182,6 +1182,14 @@ async def test_discover_config(dev: Device, mocker, runner): | ||||
|     cparam = dev.config.connection_type | ||||
|     expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}" | ||||
|     assert expected in res.output | ||||
|     assert re.search( | ||||
|         r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ failed", | ||||
|         res.output.replace("\n", ""), | ||||
|     ) | ||||
|     assert re.search( | ||||
|         r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ succeeded", | ||||
|         res.output.replace("\n", ""), | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def test_discover_config_invalid(mocker, runner): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steven B.
					Steven B.