diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py index aac2f96d..deb28b4d 100644 --- a/kasa/cli/discover.py +++ b/kasa/cli/discover.py @@ -17,7 +17,7 @@ from kasa import ( ) from kasa.discover import DiscoveryResult -from .common import echo +from .common import echo, error @click.group(invoke_without_command=True) @@ -145,6 +145,41 @@ async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True): return discovered_devices +@discover.command() +@click.pass_context +async def config(ctx): + """Bypass udp discovery and try to show connection config for a device. + + Bypasses udp discovery and shows the parameters required to connect + directly to the device. + """ + params = ctx.parent.parent.params + username = params["username"] + password = params["password"] + timeout = params["timeout"] + host = params["host"] + port = params["port"] + + if not host: + error("--host option must be supplied to discover config") + + credentials = Credentials(username, password) if username and password else None + + dev = await Discover.try_connect_all( + host, credentials=credentials, timeout=timeout, port=port + ) + if dev: + cparams = dev.config.connection_type + echo("Managed to connect, cli options to connect are:") + echo( + f"--device-family {cparams.device_family.value} " + f"--encrypt-type {cparams.encryption_type.value} " + f"{'--https' if cparams.https else '--no-https'}" + ) + else: + error(f"Unable to connect to {host}") + + def _echo_dictionary(discovery_info: dict): echo("\t[bold]== Discovery information ==[/bold]") for key, value in discovery_info.items(): diff --git a/kasa/cli/main.py b/kasa/cli/main.py index 7ba65155..b721e984 100755 --- a/kasa/cli/main.py +++ b/kasa/cli/main.py @@ -39,6 +39,7 @@ TYPES = [ ] ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in DeviceEncryptionType] +DEFAULT_TARGET = "255.255.255.255" def _legacy_type_to_class(_type): @@ -115,7 +116,7 @@ def _legacy_type_to_class(_type): @click.option( "--target", envvar="KASA_TARGET", - default="255.255.255.255", + default=DEFAULT_TARGET, required=False, show_default=True, help="The broadcast address to be used for discovery.", @@ -256,6 +257,9 @@ async def cli( ctx.obj = object() return + if target != DEFAULT_TARGET and host: + error("--target is not a valid option for single host discovery") + if experimental: from kasa.experimental.enabled import Enabled diff --git a/kasa/discover.py b/kasa/discover.py index 79c16216..e7a3946c 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -526,6 +526,66 @@ class Discover: else: raise TimeoutError(f"Timed out getting discovery response for {host}") + @staticmethod + async def try_connect_all( + host: str, + *, + port: int | None = None, + timeout: int | None = None, + credentials: Credentials | None = None, + ) -> Device | None: + """Try to connect directly to a device with all possible parameters. + + This method can be used when udp is not working due to network issues. + After succesfully connecting use the device config and + :meth:`Device.connect()` for future connections. + + :param host: Hostname of device to query + :param port: Optionally set a different port for legacy devices using port 9999 + :param timeout: Timeout in seconds device for devices queries + :param credentials: Credentials for devices that require authentication. + username and password are ignored if provided. + """ + from .device_factory import _connect + + candidates = { + (type(protocol), type(protocol._transport), device_class): ( + protocol, + config, + ) + for encrypt in Device.EncryptionType + for device_family in Device.Family + for https in (True, False) + if ( + conn_params := DeviceConnectionParameters( + device_family=device_family, + encryption_type=encrypt, + https=https, + ) + ) + and ( + config := DeviceConfig( + host=host, + connection_type=conn_params, + timeout=timeout, + port_override=port, + credentials=credentials, + ) + ) + and (protocol := get_protocol(config)) + and (device_class := get_device_class_from_family(device_family.value)) + } + for protocol, config in candidates.values(): + try: + dev = await _connect(config, protocol) + except Exception: + _LOGGER.debug("Unable to connect with %s", protocol) + else: + return dev + finally: + await protocol.close() + return None + @staticmethod def _get_device_class(info: dict) -> type[Device]: """Find SmartDevice subclass for device described by passed data.""" diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 8d830f08..e1861a29 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -1158,3 +1158,78 @@ async def test_cli_child_commands( assert res.exit_code == 0 parent_update_spy.assert_called_once() assert dev.children[0].update == child_update_method + + +async def test_discover_config(dev: Device, mocker, runner): + """Test that device config is returned.""" + host = "127.0.0.1" + mocker.patch("kasa.discover.Discover.try_connect_all", return_value=dev) + + res = await runner.invoke( + cli, + [ + "--username", + "foo", + "--password", + "bar", + "--host", + host, + "discover", + "config", + ], + catch_exceptions=False, + ) + assert res.exit_code == 0 + cparam = dev.config.connection_type + expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}" + assert expected in res.output + + +async def test_discover_config_invalid(mocker, runner): + """Test the device config command with invalids.""" + host = "127.0.0.1" + mocker.patch("kasa.discover.Discover.try_connect_all", return_value=None) + + res = await runner.invoke( + cli, + [ + "--username", + "foo", + "--password", + "bar", + "--host", + host, + "discover", + "config", + ], + catch_exceptions=False, + ) + assert res.exit_code == 1 + assert f"Unable to connect to {host}" in res.output + + res = await runner.invoke( + cli, + ["--username", "foo", "--password", "bar", "discover", "config"], + catch_exceptions=False, + ) + assert res.exit_code == 1 + assert "--host option must be supplied to discover config" in res.output + + res = await runner.invoke( + cli, + [ + "--username", + "foo", + "--password", + "bar", + "--host", + host, + "--target", + "127.0.0.2", + "discover", + "config", + ], + catch_exceptions=False, + ) + assert res.exit_code == 1 + assert "--target is not a valid option for single host discovery" in res.output diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 8163d4c1..d6e0a0db 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -20,9 +20,15 @@ from kasa import ( Device, DeviceType, Discover, + IotProtocol, KasaException, ) from kasa.aestransport import AesEncyptionSession +from kasa.device_factory import ( + get_device_class_from_family, + get_device_class_from_sys_info, + get_protocol, +) from kasa.deviceconfig import ( DeviceConfig, DeviceConnectionParameters, @@ -35,7 +41,7 @@ from kasa.discover import ( ) from kasa.exceptions import AuthenticationError, UnsupportedDeviceError from kasa.iot import IotDevice -from kasa.xortransport import XorEncryption +from kasa.xortransport import XorEncryption, XorTransport from .conftest import ( bulb_iot, @@ -647,3 +653,51 @@ async def test_discovery_decryption(): dr = DiscoveryResult(**info) Discover._decrypt_discovery_data(dr) assert dr.decrypted_data == data_dict + + +async def test_discover_try_connect_all(discovery_mock, mocker): + """Test that device update is called on main.""" + if "result" in discovery_mock.discovery_data: + dev_class = get_device_class_from_family(discovery_mock.device_type) + cparams = DeviceConnectionParameters.from_values( + discovery_mock.device_type, + discovery_mock.encrypt_type, + discovery_mock.login_version, + False, + ) + protocol = get_protocol( + DeviceConfig(discovery_mock.ip, connection_type=cparams) + ) + protocol_class = protocol.__class__ + transport_class = protocol._transport.__class__ + else: + dev_class = get_device_class_from_sys_info(discovery_mock.discovery_data) + protocol_class = IotProtocol + transport_class = XorTransport + + async def _query(self, *args, **kwargs): + if ( + self.__class__ is protocol_class + and self._transport.__class__ is transport_class + ): + return discovery_mock.query_data + raise KasaException() + + async def _update(self, *args, **kwargs): + if ( + self.protocol.__class__ is protocol_class + and self.protocol._transport.__class__ is transport_class + ): + return + raise KasaException() + + mocker.patch("kasa.IotProtocol.query", new=_query) + mocker.patch("kasa.SmartProtocol.query", new=_query) + mocker.patch.object(dev_class, "update", new=_update) + + dev = await Discover.try_connect_all(discovery_mock.ip) + + assert dev + assert isinstance(dev, dev_class) + assert isinstance(dev.protocol, protocol_class) + assert isinstance(dev.protocol._transport, transport_class)