From 852116795c2aa84da2fb2b3a139f08f72502a332 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Tue, 22 Oct 2024 12:15:08 +0100 Subject: [PATCH] Add discovery list command to cli (#1183) Report discovered devices in a concise table format. --- kasa/cli/discover.py | 85 +++++++++++++++++++++++++++++++----------- kasa/tests/test_cli.py | 49 ++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 21 deletions(-) diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py index 78f426f5..aac2f96d 100644 --- a/kasa/cli/discover.py +++ b/kasa/cli/discover.py @@ -20,24 +20,21 @@ from kasa.discover import DiscoveryResult from .common import echo -@click.command() +@click.group(invoke_without_command=True) @click.pass_context async def discover(ctx): """Discover devices in the network.""" - target = ctx.parent.params["target"] - username = ctx.parent.params["username"] - password = ctx.parent.params["password"] - discovery_timeout = ctx.parent.params["discovery_timeout"] - timeout = ctx.parent.params["timeout"] - host = ctx.parent.params["host"] - port = ctx.parent.params["port"] + if ctx.invoked_subcommand is None: + return await ctx.invoke(detail) - credentials = Credentials(username, password) if username and password else None - sem = asyncio.Semaphore() - discovered = dict() +@discover.command() +@click.pass_context +async def detail(ctx): + """Discover devices in the network using udp broadcasts.""" unsupported = [] auth_failed = [] + sem = asyncio.Semaphore() async def print_unsupported(unsupported_exception: UnsupportedDeviceError): unsupported.append(unsupported_exception) @@ -65,9 +62,61 @@ async def discover(ctx): else: ctx.parent.obj = dev await ctx.parent.invoke(state) - discovered[dev.host] = dev.internal_state echo() + discovered = await _discover(ctx, print_discovered, print_unsupported) + if ctx.parent.parent.params["host"]: + return discovered + + echo(f"Found {len(discovered)} devices") + if unsupported: + echo(f"Found {len(unsupported)} unsupported devices") + if auth_failed: + echo(f"Found {len(auth_failed)} devices that failed to authenticate") + + return discovered + + +@discover.command() +@click.pass_context +async def list(ctx): + """List devices in the network in a table using udp broadcasts.""" + sem = asyncio.Semaphore() + + async def print_discovered(dev: Device): + cparams = dev.config.connection_type + infostr = ( + f"{dev.host:<15} {cparams.device_family.value:<20} " + f"{cparams.encryption_type.value:<7}" + ) + async with sem: + try: + await dev.update() + except AuthenticationError: + echo(f"{infostr} - Authentication failed") + else: + echo(f"{infostr} {dev.alias}") + + async def print_unsupported(unsupported_exception: UnsupportedDeviceError): + if res := unsupported_exception.discovery_result: + echo(f"{res.get('ip'):<15} UNSUPPORTED DEVICE") + + echo(f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}") + return await _discover(ctx, print_discovered, print_unsupported, do_echo=False) + + +async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True): + params = ctx.parent.parent.params + target = params["target"] + username = params["username"] + password = params["password"] + discovery_timeout = params["discovery_timeout"] + timeout = params["timeout"] + host = params["host"] + port = params["port"] + + credentials = Credentials(username, password) if username and password else None + if host: echo(f"Discovering device {host} for {discovery_timeout} seconds") return await Discover.discover_single( @@ -78,8 +127,8 @@ async def discover(ctx): discovery_timeout=discovery_timeout, on_unsupported=print_unsupported, ) - - echo(f"Discovering devices on {target} for {discovery_timeout} seconds") + if do_echo: + echo(f"Discovering devices on {target} for {discovery_timeout} seconds") discovered_devices = await Discover.discover( target=target, discovery_timeout=discovery_timeout, @@ -93,13 +142,7 @@ async def discover(ctx): for device in discovered_devices.values(): await device.protocol.close() - echo(f"Found {len(discovered)} devices") - if unsupported: - echo(f"Found {len(unsupported)} unsupported devices") - if auth_failed: - echo(f"Found {len(auth_failed)} devices that failed to authenticate") - - return discovered + return discovered_devices def _echo_dictionary(discovery_info: dict): diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index f22286e5..8d830f08 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -104,6 +104,55 @@ async def test_update_called_by_cli(dev, mocker, runner, device_family, encrypt_ update.assert_called() +async def test_list_devices(discovery_mock, runner): + """Test that device update is called on main.""" + res = await runner.invoke( + cli, + ["--username", "foo", "--password", "bar", "discover", "list"], + catch_exceptions=False, + ) + assert res.exit_code == 0 + header = f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}" + row = f"{discovery_mock.ip:<15} {discovery_mock.device_type:<20} {discovery_mock.encrypt_type:<7}" + assert header in res.output + assert row in res.output + + +@new_discovery +async def test_list_auth_failed(discovery_mock, mocker, runner): + """Test that device update is called on main.""" + device_class = Discover._get_device_class(discovery_mock.discovery_data) + mocker.patch.object( + device_class, + "update", + side_effect=AuthenticationError("Failed to authenticate"), + ) + res = await runner.invoke( + cli, + ["--username", "foo", "--password", "bar", "discover", "list"], + catch_exceptions=False, + ) + assert res.exit_code == 0 + header = f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}" + row = f"{discovery_mock.ip:<15} {discovery_mock.device_type:<20} {discovery_mock.encrypt_type:<7} - Authentication failed" + assert header in res.output + assert row in res.output + + +async def test_list_unsupported(unsupported_device_info, runner): + """Test that device update is called on main.""" + res = await runner.invoke( + cli, + ["--username", "foo", "--password", "bar", "discover", "list"], + catch_exceptions=False, + ) + assert res.exit_code == 0 + header = f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}" + row = f"{'127.0.0.1':<15} UNSUPPORTED DEVICE" + assert header in res.output + assert row in res.output + + async def test_sysinfo(dev: Device, runner): res = await runner.invoke(sysinfo, obj=dev) assert "System info" in res.output