python-kasa/kasa/cli/common.py
Steven B. 296af3192e
Handle KeyboardInterrupts in the cli better (#1391)
Addresses an issue with how `asyncclick` deals with `KeyboardInterrupt`
errors. Instead of the `click.main` receiving `KeyboardInterrupt` it
receives `CancelledError` because it's a task running inside the loop.

Also ensures that discovery catches the `CancelledError` and closes the
http clients.
2024-12-20 14:21:38 +01:00

259 lines
8.2 KiB
Python

"""Common cli module."""
from __future__ import annotations
import asyncio
import json
import re
import sys
from collections.abc import Callable
from contextlib import contextmanager
from functools import singledispatch, update_wrapper, wraps
from gettext import gettext
from typing import TYPE_CHECKING, Any, Final
import asyncclick as click
from kasa import (
Device,
)
# Value for optional options if passed without a value.
# Used as click's ``flag_value`` for ``--child`` so that "--child" given with
# no argument can be told apart from the option being omitted (default None).
OPTIONAL_VALUE_FLAG: Final = "_FLAG_"

# Block list of commands which require no update.
# For these commands _get_child_device performs its own ``device.update()``
# before it tries to resolve child devices.
SKIP_UPDATE_COMMANDS = ["raw-command", "command"]

# click decorator that injects the ``Device`` object found on the click
# context as the first argument of the decorated command. The ignore
# presumably silences mypy because ``Device`` is an abstract class.
pass_dev = click.make_pass_decorator(Device)  # type: ignore[type-abstract]
try:
    from rich import print as _echo
except ImportError:
    # Strip out rich formatting if rich is not installed
    # but only lower case tags to avoid stripping out
    # raw data from the device that is printed from
    # the device state.
    # Multi-word tags made of space-separated lower case words (e.g. the
    # "[bold red]" / "[/bold red]" pair used by error()) are stripped too.
    rich_formatting = re.compile(r"\[/?[a-z]+(?: [a-z]+)*]")

    def _strip_rich_formatting(echo_func):
        """Strip rich formatting from messages.

        Returns a wrapper around *echo_func* that removes lower case rich
        markup tags from the message before forwarding all arguments.
        """

        @wraps(echo_func)
        def wrapper(message=None, *args, **kwargs) -> None:
            # click.echo writes bytes/bytearray untouched and str()s any other
            # non-None object; mirror that so non-string messages (dicts, ints,
            # bytes) do not crash re.sub with a TypeError.
            if message is not None and not isinstance(message, (bytes, bytearray)):
                message = rich_formatting.sub("", str(message))
            echo_func(message, *args, **kwargs)

        return wrapper

    _echo = _strip_rich_formatting(click.echo)
def echo(*args, **kwargs) -> None:
    """Print a message unless JSON output was requested.

    All arguments are forwarded to ``_echo`` (rich's ``print`` when rich is
    installed, otherwise ``click.echo`` with rich markup stripped).

    Output is suppressed when the root click context has a truthy ``json``
    parameter. This is the same truthiness test ``json_formatter_cb`` uses,
    so exactly one of the two produces output. When no click context is
    active, the message is printed instead of raising ``RuntimeError``. That
    can happen, for example, when an error is reported while the root
    context is still being created.
    """
    ctx = click.get_current_context(silent=True)
    if ctx is not None and ctx.find_root().params.get("json"):
        return
    _echo(*args, **kwargs)
def error(msg: str) -> None:
    """Show *msg* as a bold red message via ``echo`` and exit with status 1.

    NOTE(review): ``echo`` prints nothing when the root command's ``json``
    parameter is set, so in that mode only the exit status signals failure.
    """
    styled_message = f"[bold red]{msg}[/bold red]"
    echo(styled_message)
    sys.exit(1)
def json_formatter_cb(result: Any, **kwargs) -> None:
    """Print *result* as indented JSON when the ``json`` keyword is truthy.

    ``Device`` instances are serialized as their ``internal_state``. Any
    other value that the json module cannot encode natively is converted
    with ``str()``. Nothing is printed when ``json`` is missing or falsy.
    """
    if kwargs.get("json"):

        @singledispatch
        def _fallback(obj):
            """Regular obj-to-string for json serialization.

            The singledispatch trick is from hynek: https://hynek.me/articles/serialization/
            """
            return str(obj)

        @_fallback.register(Device)
        def _serialize_device(obj: Device):
            """Serialize smart device data, just using the last update raw payload."""
            return obj.internal_state

        rendered = json.dumps(result, indent=4, default=_fallback)
        print(rendered)
def pass_dev_or_child(wrapped_function: Callable) -> Callable:
    """Pass the device or child to the click command based on the child options.

    Adds ``--child``/``--name`` and ``--child-index``/``--index`` options to
    the decorated command. When either option selects a child device:

    - that child is passed to *wrapped_function* in place of the parent;
    - the child becomes the click context object;
    - the child's ``update`` is redirected to the parent's ``update`` for the
      rest of the command.
    """
    child_help = (
        "Child ID or alias for controlling sub-devices. "
        "If no value provided will show an interactive prompt allowing you to "
        "select a child."
    )
    child_index_help = "Child index controlling sub-devices"

    @contextmanager
    def patched_device_update(parent: Device, child: Device):
        # Capture the original before entering ``try`` so that ``finally``
        # can never reference an unbound ``orig_update`` and mask the real
        # error.
        orig_update = child.update
        try:
            # patch child update method. Can be removed once update can be called
            # directly on child devices
            child.update = parent.update  # type: ignore[method-assign]
            yield child
        finally:
            child.update = orig_update  # type: ignore[method-assign]

    @click.pass_obj
    @click.pass_context
    @click.option(
        "--child",
        "--name",
        is_flag=False,
        flag_value=OPTIONAL_VALUE_FLAG,
        default=None,
        required=False,
        type=click.STRING,
        help=child_help,
    )
    @click.option(
        "--child-index",
        "--index",
        required=False,
        default=None,
        type=click.INT,
        help=child_index_help,
    )
    async def wrapper(ctx: click.Context, dev, *args, child, child_index, **kwargs):
        # ``child`` is the raw option value (str or None). Keep the resolved
        # Device under its own name instead of rebinding the option name.
        child_device = await _get_child_device(
            dev, child, child_index, ctx.info_name
        )
        if child_device:
            # click's with_resource enters the context manager now and exits
            # it when ctx is torn down, so the update patch covers the whole
            # command invocation.
            ctx.obj = ctx.with_resource(patched_device_update(dev, child_device))
            dev = child_device
        return await ctx.invoke(wrapped_function, dev, *args, **kwargs)

    # Update wrapper function to look like wrapped function
    return update_wrapper(wrapper, wrapped_function)
async def _get_child_device(
    device: Device,
    child_option: str | None,
    child_index_option: int | None,
    info_command: str | None,
) -> Device | None:
    """Resolve which child of *device* the ``--child``/``--child-index`` options select.

    Returns ``None`` when neither option was given. Otherwise returns the
    selected child:

    - by device id or alias for ``--child NAME``;
    - via an interactive index prompt for a bare ``--child``;
    - by position for ``--child-index``.

    Raises ``click.BadOptionUsage`` when both options are supplied. Calls
    ``error`` (which exits) when the device has no children, no child
    matches the name, or the index is out of range.
    """

    def _describe_children() -> str:
        # One "index: device_id (alias)" line per child device.
        return "\n".join(
            f"{position}: {kid.device_id} ({kid.alias})"
            for position, kid in enumerate(device.children)
        )

    if child_option is None and child_index_option is None:
        return None

    if info_command in SKIP_UPDATE_COMMANDS:
        # These commands run without a prior device.update() (e.g. cmd_command),
        # but accessing children requires a ChildDevice to wrap the
        # communications. Doing this properly would need a common interface
        # for IOT and SMART child devices, so as a stop-gap we update here.
        await device.update()

    if not device.children:
        error(f"Device: {device.host} does not have children")

    if child_option is not None and child_index_option is not None:
        raise click.BadOptionUsage(
            "child", "Use either --child or --child-index, not both."
        )

    if child_option is OPTIONAL_VALUE_FLAG:
        # Bare "--child": let the user choose a child by index.
        # NOTE(review): this is an identity check, so presumably only the
        # flag_value object that click substitutes matches. An explicit
        # "_FLAG_" argument would be looked up as a name.
        child_index_option = click.prompt(
            f"\n{_describe_children()}\nEnter the index number of the child device",
            type=click.IntRange(0, len(device.children) - 1),
        )
    elif child_option is not None:
        matched = device.get_child_device(child_option)
        if matched:
            echo(f"Targeting child device {matched.alias}")
            return matched
        error(
            "No child device found with device_id or name: "
            f"{child_option} children are:\n{_describe_children()}"
        )

    if TYPE_CHECKING:
        assert isinstance(child_index_option, int)
    if child_index_option >= len(device.children) or child_index_option < 0:
        error(
            f"Invalid index {child_index_option}, "
            f"device has {len(device.children)} children"
        )

    selected = device.children[child_index_option]
    echo(f"Targeting child device {selected.alias}")
    return selected
def CatchAllExceptions(cls):
    """Return a subclass of the click command class *cls* that reports errors tidily.

    Exceptions raised from ``make_context`` or ``invoke`` are handled as
    follows:

    - ``click.ClickException`` propagates unchanged.
    - click's ``Exit`` becomes ``sys.exit(exit_code)``.
    - click's ``Abort`` becomes ``sys.exit(0)``.
    - Anything else is printed as a one-line "Raised error" message. The
      process then exits with status 1, unless ``--debug``/``-d`` or
      ``--verbose``/``-v`` was on the command line, in which case the
      exception is re-raised so its traceback is shown.

    Idea from https://stackoverflow.com/a/44347763 and
    https://stackoverflow.com/questions/52213375
    """
    debug_flags = ("--debug", "-d", "--verbose", "-v")

    def _report(debug_enabled, err) -> None:
        # Only call this from inside an ``except`` clause: the bare ``raise``
        # statements re-raise the exception currently being handled.
        if isinstance(err, click.ClickException):
            raise
        # Handle exit request from click.
        if isinstance(err, click.exceptions.Exit):
            sys.exit(err.exit_code)
        if isinstance(err, click.exceptions.Abort):
            sys.exit(0)
        echo(f"Raised error: {err}")
        if debug_enabled:
            raise
        echo("Run with --debug enabled to see stacktrace")
        sys.exit(1)

    class _CommandCls(cls):
        # Set from the raw command-line args in make_context; invoke reuses it.
        _debug = False

        async def make_context(self, info_name, args, parent=None, **extra):
            self._debug = any(arg in debug_flags for arg in args)
            try:
                return await super().make_context(
                    info_name, args, parent=parent, **extra
                )
            except Exception as err:
                _report(self._debug, err)

        async def invoke(self, ctx):
            try:
                return await super().invoke(ctx)
            except Exception as err:
                _report(self._debug, err)

        def __call__(self, *args, **kwargs):
            """Run ``main`` in a fresh event loop and turn Ctrl-C into "Aborted!".

            Plain click catches KeyboardInterrupt in ``main``, raises Abort()
            and exits. Under asyncclick the running coroutine sees a
            CancelledError instead, so the KeyboardInterrupt is caught here,
            after asyncio.run has re-raised it. This avoids a large
            stacktrace when the user presses Ctrl-C.
            """
            try:
                asyncio.run(self.main(*args, **kwargs))
            except KeyboardInterrupt:
                click.echo(gettext("\nAborted!"), file=sys.stderr)
                sys.exit(1)

    return _CommandCls