Fix discover cli command with host (#1437)

This commit is contained in:
Steven B. 2025-01-14 14:47:52 +00:00 committed by GitHub
parent 1be87674bf
commit d03f535568
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 103 additions and 24 deletions

View File

@ -10,7 +10,7 @@ from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from functools import singledispatch, update_wrapper, wraps from functools import singledispatch, update_wrapper, wraps
from gettext import gettext from gettext import gettext
from typing import TYPE_CHECKING, Any, Final from typing import TYPE_CHECKING, Any, Final, NoReturn
import asyncclick as click import asyncclick as click
@ -57,7 +57,7 @@ def echo(*args, **kwargs) -> None:
_echo(*args, **kwargs) _echo(*args, **kwargs)
def error(msg: str) -> None: def error(msg: str) -> NoReturn:
"""Print an error and exit.""" """Print an error and exit."""
echo(f"[bold red]{msg}[/bold red]") echo(f"[bold red]{msg}[/bold red]")
sys.exit(1) sys.exit(1)
@ -68,6 +68,16 @@ def json_formatter_cb(result: Any, **kwargs) -> None:
if not kwargs.get("json"): if not kwargs.get("json"):
return return
# Calling the discover command directly always returns a DeviceDict so if host
# was specified just format the device json
if (
(host := kwargs.get("host"))
and isinstance(result, dict)
and (dev := result.get(host))
and isinstance(dev, Device)
):
result = dev
@singledispatch @singledispatch
def to_serializable(val): def to_serializable(val):
"""Regular obj-to-string for json serialization. """Regular obj-to-string for json serialization.
@ -85,6 +95,25 @@ def json_formatter_cb(result: Any, **kwargs) -> None:
print(json_content) print(json_content)
async def invoke_subcommand(
command: click.BaseCommand,
ctx: click.Context,
args: list[str] | None = None,
**extra: Any,
) -> Any:
"""Invoke a click subcommand.
Calling ctx.Invoke() treats the command like a simple callback and doesn't
process any result_callbacks so we use this pattern from the click docs
https://click.palletsprojects.com/en/stable/exceptions/#what-if-i-don-t-want-that.
"""
if args is None:
args = []
sub_ctx = await command.make_context(command.name, args, parent=ctx, **extra)
async with sub_ctx:
return await command.invoke(sub_ctx)
def pass_dev_or_child(wrapped_function: Callable) -> Callable: def pass_dev_or_child(wrapped_function: Callable) -> Callable:
"""Pass the device or child to the click command based on the child options.""" """Pass the device or child to the click command based on the child options."""
child_help = ( child_help = (

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from pprint import pformat as pf from pprint import pformat as pf
from typing import TYPE_CHECKING
import asyncclick as click import asyncclick as click
@ -82,6 +83,8 @@ async def state(ctx, dev: Device):
echo() echo()
from .discover import _echo_discovery_info from .discover import _echo_discovery_info
if TYPE_CHECKING:
assert dev._discovery_info
_echo_discovery_info(dev._discovery_info) _echo_discovery_info(dev._discovery_info)
return dev.internal_state return dev.internal_state

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from pprint import pformat as pf from pprint import pformat as pf
from typing import TYPE_CHECKING, cast
import asyncclick as click import asyncclick as click
@ -17,8 +18,12 @@ from kasa import (
from kasa.discover import ( from kasa.discover import (
NEW_DISCOVERY_REDACTORS, NEW_DISCOVERY_REDACTORS,
ConnectAttempt, ConnectAttempt,
DeviceDict,
DiscoveredRaw, DiscoveredRaw,
DiscoveryResult, DiscoveryResult,
OnDiscoveredCallable,
OnDiscoveredRawCallable,
OnUnsupportedCallable,
) )
from kasa.iot.iotdevice import _extract_sys_info from kasa.iot.iotdevice import _extract_sys_info
from kasa.protocols.iotprotocol import REDACTORS as IOT_REDACTORS from kasa.protocols.iotprotocol import REDACTORS as IOT_REDACTORS
@ -30,15 +35,33 @@ from .common import echo, error
@click.group(invoke_without_command=True) @click.group(invoke_without_command=True)
@click.pass_context @click.pass_context
async def discover(ctx): async def discover(ctx: click.Context):
"""Discover devices in the network.""" """Discover devices in the network."""
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
return await ctx.invoke(detail) return await ctx.invoke(detail)
@discover.result_callback()
@click.pass_context
async def _close_protocols(ctx: click.Context, discovered: DeviceDict):
"""Close all the device protocols if discover was invoked directly by the user."""
if _discover_is_root_cmd(ctx):
for dev in discovered.values():
await dev.disconnect()
return discovered
def _discover_is_root_cmd(ctx: click.Context) -> bool:
"""Will return true if discover was invoked directly by the user."""
root_ctx = ctx.find_root()
return (
root_ctx.invoked_subcommand is None or root_ctx.invoked_subcommand == "discover"
)
@discover.command() @discover.command()
@click.pass_context @click.pass_context
async def detail(ctx): async def detail(ctx: click.Context) -> DeviceDict:
"""Discover devices in the network using udp broadcasts.""" """Discover devices in the network using udp broadcasts."""
unsupported = [] unsupported = []
auth_failed = [] auth_failed = []
@ -59,10 +82,14 @@ async def detail(ctx):
from .device import state from .device import state
async def print_discovered(dev: Device) -> None: async def print_discovered(dev: Device) -> None:
if TYPE_CHECKING:
assert ctx.parent
async with sem: async with sem:
try: try:
await dev.update() await dev.update()
except AuthenticationError: except AuthenticationError:
if TYPE_CHECKING:
assert dev._discovery_info
auth_failed.append(dev._discovery_info) auth_failed.append(dev._discovery_info)
echo("== Authentication failed for device ==") echo("== Authentication failed for device ==")
_echo_discovery_info(dev._discovery_info) _echo_discovery_info(dev._discovery_info)
@ -73,9 +100,11 @@ async def detail(ctx):
echo() echo()
discovered = await _discover( discovered = await _discover(
ctx, print_discovered=print_discovered, print_unsupported=print_unsupported ctx,
print_discovered=print_discovered if _discover_is_root_cmd(ctx) else None,
print_unsupported=print_unsupported,
) )
if ctx.parent.parent.params["host"]: if ctx.find_root().params["host"]:
return discovered return discovered
echo(f"Found {len(discovered)} devices") echo(f"Found {len(discovered)} devices")
@ -96,7 +125,7 @@ async def detail(ctx):
help="Set flag to redact sensitive data from raw output.", help="Set flag to redact sensitive data from raw output.",
) )
@click.pass_context @click.pass_context
async def raw(ctx, redact: bool): async def raw(ctx: click.Context, redact: bool) -> DeviceDict:
"""Return raw discovery data returned from devices.""" """Return raw discovery data returned from devices."""
def print_raw(discovered: DiscoveredRaw): def print_raw(discovered: DiscoveredRaw):
@ -116,7 +145,7 @@ async def raw(ctx, redact: bool):
@discover.command() @discover.command()
@click.pass_context @click.pass_context
async def list(ctx): async def list(ctx: click.Context) -> DeviceDict:
"""List devices in the network in a table using udp broadcasts.""" """List devices in the network in a table using udp broadcasts."""
sem = asyncio.Semaphore() sem = asyncio.Semaphore()
@ -147,18 +176,24 @@ async def list(ctx):
f"{'HOST':<15} {'MODEL':<9} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} " f"{'HOST':<15} {'MODEL':<9} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} "
f"{'HTTPS':<5} {'LV':<3} {'ALIAS'}" f"{'HTTPS':<5} {'LV':<3} {'ALIAS'}"
) )
return await _discover( discovered = await _discover(
ctx, ctx,
print_discovered=print_discovered, print_discovered=print_discovered,
print_unsupported=print_unsupported, print_unsupported=print_unsupported,
do_echo=False, do_echo=False,
) )
return discovered
async def _discover( async def _discover(
ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True ctx: click.Context,
): *,
params = ctx.parent.parent.params print_discovered: OnDiscoveredCallable | None = None,
print_unsupported: OnUnsupportedCallable | None = None,
print_raw: OnDiscoveredRawCallable | None = None,
do_echo=True,
) -> DeviceDict:
params = ctx.find_root().params
target = params["target"] target = params["target"]
username = params["username"] username = params["username"]
password = params["password"] password = params["password"]
@ -170,8 +205,9 @@ async def _discover(
credentials = Credentials(username, password) if username and password else None credentials = Credentials(username, password) if username and password else None
if host: if host:
host = cast(str, host)
echo(f"Discovering device {host} for {discovery_timeout} seconds") echo(f"Discovering device {host} for {discovery_timeout} seconds")
return await Discover.discover_single( dev = await Discover.discover_single(
host, host,
port=port, port=port,
credentials=credentials, credentials=credentials,
@ -180,6 +216,12 @@ async def _discover(
on_unsupported=print_unsupported, on_unsupported=print_unsupported,
on_discovered_raw=print_raw, on_discovered_raw=print_raw,
) )
if dev:
if print_discovered:
await print_discovered(dev)
return {host: dev}
else:
return {}
if do_echo: if do_echo:
echo(f"Discovering devices on {target} for {discovery_timeout} seconds") echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
discovered_devices = await Discover.discover( discovered_devices = await Discover.discover(
@ -193,21 +235,18 @@ async def _discover(
on_discovered_raw=print_raw, on_discovered_raw=print_raw,
) )
for device in discovered_devices.values():
await device.protocol.close()
return discovered_devices return discovered_devices
@discover.command() @discover.command()
@click.pass_context @click.pass_context
async def config(ctx): async def config(ctx: click.Context) -> DeviceDict:
"""Bypass udp discovery and try to show connection config for a device. """Bypass udp discovery and try to show connection config for a device.
Bypasses udp discovery and shows the parameters required to connect Bypasses udp discovery and shows the parameters required to connect
directly to the device. directly to the device.
""" """
params = ctx.parent.parent.params params = ctx.find_root().params
username = params["username"] username = params["username"]
password = params["password"] password = params["password"]
timeout = params["timeout"] timeout = params["timeout"]
@ -239,6 +278,7 @@ async def config(ctx):
f"--encrypt-type {cparams.encryption_type.value} " f"--encrypt-type {cparams.encryption_type.value} "
f"{'--https' if cparams.https else '--no-https'}" f"{'--https' if cparams.https else '--no-https'}"
) )
return {host: dev}
else: else:
error(f"Unable to connect to {host}") error(f"Unable to connect to {host}")
@ -251,7 +291,7 @@ def _echo_dictionary(discovery_info: dict) -> None:
echo(f"\t{key_name_and_spaces}{value}") echo(f"\t{key_name_and_spaces}{value}")
def _echo_discovery_info(discovery_info) -> None: def _echo_discovery_info(discovery_info: dict) -> None:
# We don't have discovery info when all connection params are passed manually # We don't have discovery info when all connection params are passed manually
if discovery_info is None: if discovery_info is None:
return return

View File

@ -22,6 +22,7 @@ from .common import (
CatchAllExceptions, CatchAllExceptions,
echo, echo,
error, error,
invoke_subcommand,
json_formatter_cb, json_formatter_cb,
pass_dev_or_child, pass_dev_or_child,
) )
@ -295,9 +296,10 @@ async def cli(
echo("No host name given, trying discovery..") echo("No host name given, trying discovery..")
from .discover import discover from .discover import discover
return await ctx.invoke(discover) return await invoke_subcommand(discover, ctx)
device_updated = False device_updated = False
device_discovered = False
if type is not None and type not in {"smart", "camera"}: if type is not None and type not in {"smart", "camera"}:
from kasa.deviceconfig import DeviceConfig from kasa.deviceconfig import DeviceConfig
@ -351,12 +353,14 @@ async def cli(
return return
echo(f"Found hostname by alias: {dev.host}") echo(f"Found hostname by alias: {dev.host}")
device_updated = True device_updated = True
else: else: # host will be set
from .discover import discover from .discover import discover
dev = await ctx.invoke(discover) discovered = await invoke_subcommand(discover, ctx)
if not dev: if not discovered:
error(f"Unable to create device for {host}") error(f"Unable to create device for {host}")
dev = discovered[host]
device_discovered = True
# Skip update on specific commands, or if device factory, # Skip update on specific commands, or if device factory,
# that performs an update was used for the device. # that performs an update was used for the device.
@ -372,11 +376,14 @@ async def cli(
ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev)) ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))
if ctx.invoked_subcommand is None: # discover command has already invoked state
if ctx.invoked_subcommand is None and not device_discovered:
from .device import state from .device import state
return await ctx.invoke(state) return await ctx.invoke(state)
return dev
@cli.command() @cli.command()
@pass_dev_or_child @pass_dev_or_child