From d03f535568ca22e32b51c1e1f9f703d12489130e Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Tue, 14 Jan 2025 14:47:52 +0000 Subject: [PATCH] Fix discover cli command with host (#1437) --- kasa/cli/common.py | 33 ++++++++++++++++++-- kasa/cli/device.py | 3 ++ kasa/cli/discover.py | 74 ++++++++++++++++++++++++++++++++++---------- kasa/cli/main.py | 17 +++++++--- 4 files changed, 103 insertions(+), 24 deletions(-) diff --git a/kasa/cli/common.py b/kasa/cli/common.py index 5114f7af..d0ef9dc3 100644 --- a/kasa/cli/common.py +++ b/kasa/cli/common.py @@ -10,7 +10,7 @@ from collections.abc import Callable from contextlib import contextmanager from functools import singledispatch, update_wrapper, wraps from gettext import gettext -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, NoReturn import asyncclick as click @@ -57,7 +57,7 @@ def echo(*args, **kwargs) -> None: _echo(*args, **kwargs) -def error(msg: str) -> None: +def error(msg: str) -> NoReturn: """Print an error and exit.""" echo(f"[bold red]{msg}[/bold red]") sys.exit(1) @@ -68,6 +68,16 @@ def json_formatter_cb(result: Any, **kwargs) -> None: if not kwargs.get("json"): return + # Calling the discover command directly always returns a DeviceDict so if host + # was specified just format the device json + if ( + (host := kwargs.get("host")) + and isinstance(result, dict) + and (dev := result.get(host)) + and isinstance(dev, Device) + ): + result = dev + @singledispatch def to_serializable(val): """Regular obj-to-string for json serialization. @@ -85,6 +95,25 @@ def json_formatter_cb(result: Any, **kwargs) -> None: print(json_content) +async def invoke_subcommand( + command: click.BaseCommand, + ctx: click.Context, + args: list[str] | None = None, + **extra: Any, +) -> Any: + """Invoke a click subcommand. + + Calling ctx.Invoke() treats the command like a simple callback and doesn't + process any result_callbacks so we use this pattern from the click docs + https://click.palletsprojects.com/en/stable/exceptions/#what-if-i-don-t-want-that. + """ + if args is None: + args = [] + sub_ctx = await command.make_context(command.name, args, parent=ctx, **extra) + async with sub_ctx: + return await command.invoke(sub_ctx) + + def pass_dev_or_child(wrapped_function: Callable) -> Callable: """Pass the device or child to the click command based on the child options.""" child_help = ( diff --git a/kasa/cli/device.py b/kasa/cli/device.py index 0ef8a76f..a10f485d 100644 --- a/kasa/cli/device.py +++ b/kasa/cli/device.py @@ -3,6 +3,7 @@ from __future__ import annotations from pprint import pformat as pf +from typing import TYPE_CHECKING import asyncclick as click @@ -82,6 +83,8 @@ async def state(ctx, dev: Device): echo() from .discover import _echo_discovery_info + if TYPE_CHECKING: + assert dev._discovery_info _echo_discovery_info(dev._discovery_info) return dev.internal_state diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py index ff201ce6..07500f3b 100644 --- a/kasa/cli/discover.py +++ b/kasa/cli/discover.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from pprint import pformat as pf +from typing import TYPE_CHECKING, cast import asyncclick as click @@ -17,8 +18,12 @@ from kasa import ( from kasa.discover import ( NEW_DISCOVERY_REDACTORS, ConnectAttempt, + DeviceDict, DiscoveredRaw, DiscoveryResult, + OnDiscoveredCallable, + OnDiscoveredRawCallable, + OnUnsupportedCallable, ) from kasa.iot.iotdevice import _extract_sys_info from kasa.protocols.iotprotocol import REDACTORS as IOT_REDACTORS @@ -30,15 +35,33 @@ from .common import echo, error @click.group(invoke_without_command=True) @click.pass_context -async def discover(ctx): +async def discover(ctx: click.Context): """Discover devices in the network.""" if ctx.invoked_subcommand is None: return await ctx.invoke(detail) +@discover.result_callback() +@click.pass_context +async def _close_protocols(ctx: click.Context, discovered: DeviceDict): + """Close all the device protocols if discover was invoked directly by the user.""" + if _discover_is_root_cmd(ctx): + for dev in discovered.values(): + await dev.disconnect() + return discovered + + +def _discover_is_root_cmd(ctx: click.Context) -> bool: + """Will return true if discover was invoked directly by the user.""" + root_ctx = ctx.find_root() + return ( + root_ctx.invoked_subcommand is None or root_ctx.invoked_subcommand == "discover" + ) + + @discover.command() @click.pass_context -async def detail(ctx): +async def detail(ctx: click.Context) -> DeviceDict: """Discover devices in the network using udp broadcasts.""" unsupported = [] auth_failed = [] @@ -59,10 +82,14 @@ async def detail(ctx): from .device import state async def print_discovered(dev: Device) -> None: + if TYPE_CHECKING: + assert ctx.parent async with sem: try: await dev.update() except AuthenticationError: + if TYPE_CHECKING: + assert dev._discovery_info auth_failed.append(dev._discovery_info) echo("== Authentication failed for device ==") _echo_discovery_info(dev._discovery_info) @@ -73,9 +100,11 @@ async def detail(ctx): echo() discovered = await _discover( - ctx, print_discovered=print_discovered, print_unsupported=print_unsupported + ctx, + print_discovered=print_discovered if _discover_is_root_cmd(ctx) else None, + print_unsupported=print_unsupported, ) - if ctx.parent.parent.params["host"]: + if ctx.find_root().params["host"]: return discovered echo(f"Found {len(discovered)} devices") @@ -96,7 +125,7 @@ async def detail(ctx): help="Set flag to redact sensitive data from raw output.", ) @click.pass_context -async def raw(ctx, redact: bool): +async def raw(ctx: click.Context, redact: bool) -> DeviceDict: """Return raw discovery data returned from devices.""" def print_raw(discovered: DiscoveredRaw): @@ -116,7 +145,7 @@ async def raw(ctx, redact: bool): @discover.command() @click.pass_context -async def list(ctx): +async def list(ctx: click.Context) -> DeviceDict: """List devices in the network in a table using udp broadcasts.""" sem = asyncio.Semaphore() @@ -147,18 +176,24 @@ async def list(ctx): f"{'HOST':<15} {'MODEL':<9} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} " f"{'HTTPS':<5} {'LV':<3} {'ALIAS'}" ) - return await _discover( + discovered = await _discover( ctx, print_discovered=print_discovered, print_unsupported=print_unsupported, do_echo=False, ) + return discovered async def _discover( - ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True -): - params = ctx.parent.parent.params + ctx: click.Context, + *, + print_discovered: OnDiscoveredCallable | None = None, + print_unsupported: OnUnsupportedCallable | None = None, + print_raw: OnDiscoveredRawCallable | None = None, + do_echo=True, +) -> DeviceDict: + params = ctx.find_root().params target = params["target"] username = params["username"] password = params["password"] @@ -170,8 +205,9 @@ async def _discover( credentials = Credentials(username, password) if username and password else None if host: + host = cast(str, host) echo(f"Discovering device {host} for {discovery_timeout} seconds") - return await Discover.discover_single( + dev = await Discover.discover_single( host, port=port, credentials=credentials, @@ -180,6 +216,12 @@ async def _discover( on_unsupported=print_unsupported, on_discovered_raw=print_raw, ) + if dev: + if print_discovered: + await print_discovered(dev) + return {host: dev} + else: + return {} if do_echo: echo(f"Discovering devices on {target} for {discovery_timeout} seconds") discovered_devices = await Discover.discover( @@ -193,21 +235,18 @@ async def _discover( on_discovered_raw=print_raw, ) - for device in discovered_devices.values(): - await device.protocol.close() - return discovered_devices @discover.command() @click.pass_context -async def config(ctx): +async def config(ctx: click.Context) -> DeviceDict: """Bypass udp discovery and try to show connection config for a device. Bypasses udp discovery and shows the parameters required to connect directly to the device. """ - params = ctx.parent.parent.params + params = ctx.find_root().params username = params["username"] password = params["password"] timeout = params["timeout"] @@ -239,6 +278,7 @@ async def config(ctx): f"--encrypt-type {cparams.encryption_type.value} " f"{'--https' if cparams.https else '--no-https'}" ) + return {host: dev} else: error(f"Unable to connect to {host}") @@ -251,7 +291,7 @@ def _echo_dictionary(discovery_info: dict) -> None: echo(f"\t{key_name_and_spaces}{value}") -def _echo_discovery_info(discovery_info) -> None: +def _echo_discovery_info(discovery_info: dict) -> None: # We don't have discovery info when all connection params are passed manually if discovery_info is None: return diff --git a/kasa/cli/main.py b/kasa/cli/main.py index fbcdf391..debde60c 100755 --- a/kasa/cli/main.py +++ b/kasa/cli/main.py @@ -22,6 +22,7 @@ from .common import ( CatchAllExceptions, echo, error, + invoke_subcommand, json_formatter_cb, pass_dev_or_child, ) @@ -295,9 +296,10 @@ async def cli( echo("No host name given, trying discovery..") from .discover import discover - return await ctx.invoke(discover) + return await invoke_subcommand(discover, ctx) device_updated = False + device_discovered = False if type is not None and type not in {"smart", "camera"}: from kasa.deviceconfig import DeviceConfig @@ -351,12 +353,14 @@ async def cli( return echo(f"Found hostname by alias: {dev.host}") device_updated = True - else: + else: # host will be set from .discover import discover - dev = await ctx.invoke(discover) - if not dev: + discovered = await invoke_subcommand(discover, ctx) + if not discovered: error(f"Unable to create device for {host}") + dev = discovered[host] + device_discovered = True # Skip update on specific commands, or if device factory, # that performs an update was used for the device. @@ -372,11 +376,14 @@ async def cli( ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev)) - if ctx.invoked_subcommand is None: + # discover command has already invoked state + if ctx.invoked_subcommand is None and not device_discovered: from .device import state return await ctx.invoke(state) + return dev + @cli.command() @pass_dev_or_child