"""Common cli module."""

from __future__ import annotations

import asyncio
import json
import re
import sys
from collections.abc import Callable
from contextlib import contextmanager
from functools import singledispatch, update_wrapper, wraps
from gettext import gettext
from typing import TYPE_CHECKING, Any, Final, NoReturn

import asyncclick as click

from kasa import (
    Device,
)

# Value for optional options if passed without a value
OPTIONAL_VALUE_FLAG: Final = "_FLAG_"

# Block list of commands which require no update
SKIP_UPDATE_COMMANDS = ["raw-command", "command"]

pass_dev = click.make_pass_decorator(Device)  # type: ignore[type-abstract]


try:
    from rich import print as _echo
except ImportError:
    # Strip out rich formatting if rich is not installed
    # but only lower case tags to avoid stripping out
    # raw data from the device that is printed from
    # the device state.
    # Tags may hold several space-separated lowercase words
    # (e.g. "[bold red]" / "[/bold red]" as emitted by ``error``).
    rich_formatting = re.compile(r"\[/?[a-z]+(?: [a-z]+)*]")

    def _strip_rich_formatting(echo_func):
        """Strip rich formatting from messages.

        Return a wrapper around *echo_func* that removes lowercase rich
        markup tags from ``str`` messages before delegating. Non-string
        messages (bytes, arbitrary objects) are passed through untouched,
        because ``re.sub`` cannot operate on them.
        """

        @wraps(echo_func)
        def wrapper(message=None, *args, **kwargs) -> None:
            if isinstance(message, str):
                message = rich_formatting.sub("", message)
            echo_func(message, *args, **kwargs)

        return wrapper

    _echo = _strip_rich_formatting(click.echo)


def echo(*args, **kwargs) -> None:
    """Print a message unless the root command was invoked with JSON output.

    Output is suppressed when the root context has a ``json`` parameter that
    is not exactly ``False``, so human-readable text does not corrupt JSON.
    """
    ctx = click.get_current_context().find_root()
    if "json" not in ctx.params or ctx.params["json"] is False:
        _echo(*args, **kwargs)


def error(msg: str) -> NoReturn:
    """Print *msg* as a red error message and exit with status 1."""
    echo(f"[bold red]{msg}[/bold red]")
    sys.exit(1)


def json_formatter_cb(result: Any, **kwargs) -> None:
    """Format and output the result as JSON, if requested.

    :param result: The command result to serialize.
    :param kwargs: The command's parameters; ``json`` enables output, and
        ``host`` (if set) selects a single device out of a discovery dict.
    """
    if not kwargs.get("json"):
        return

    # Calling the discover command directly always returns a DeviceDict so if host
    # was specified just format the device json
    if (
        (host := kwargs.get("host"))
        and isinstance(result, dict)
        and (dev := result.get(host))
        and isinstance(dev, Device)
    ):
        result = dev

    @singledispatch
    def to_serializable(val):
        """Regular obj-to-string for json serialization.

        The singledispatch trick is from hynek: https://hynek.me/articles/serialization/
        """
        return str(val)

    @to_serializable.register(Device)
    def _device_to_serializable(val: Device):
        """Serialize smart device data, just using the last update raw payload."""
        return val.internal_state

    json_content = json.dumps(result, indent=4, default=to_serializable)
    print(json_content)


async def invoke_subcommand(
    command: click.BaseCommand,
    ctx: click.Context,
    args: list[str] | None = None,
    **extra: Any,
) -> Any:
    """Invoke a click subcommand.

    Calling ctx.Invoke() treats the command like a simple callback and doesn't
    process any result_callbacks so we use this pattern from the click docs
    https://click.palletsprojects.com/en/stable/exceptions/#what-if-i-don-t-want-that.
    """
    if args is None:
        args = []
    sub_ctx = await command.make_context(command.name, args, parent=ctx, **extra)
    async with sub_ctx:
        return await command.invoke(sub_ctx)


def pass_dev_or_child(wrapped_function: Callable) -> Callable:
    """Pass the device or child to the click command based on the child options.

    Adds ``--child/--name`` and ``--child-index/--index`` options to the
    decorated command. When either is given, the selected child device is
    passed in place of the parent device (and set as ``ctx.obj``).
    """
    child_help = (
        "Child ID or alias for controlling sub-devices. "
        "If no value provided will show an interactive prompt allowing you to "
        "select a child."
    )
    child_index_help = "Child index controlling sub-devices"

    @contextmanager
    def patched_device_update(parent: Device, child: Device):
        # Read the original method *before* entering the try block so the
        # finally clause can never reference an unbound name.
        orig_update = child.update
        try:
            # patch child update method. Can be removed once update can be called
            # directly on child devices
            child.update = parent.update  # type: ignore[method-assign]
            yield child
        finally:
            child.update = orig_update  # type: ignore[method-assign]

    @click.pass_obj
    @click.pass_context
    @click.option(
        "--child",
        "--name",
        is_flag=False,
        flag_value=OPTIONAL_VALUE_FLAG,
        default=None,
        required=False,
        type=click.STRING,
        help=child_help,
    )
    @click.option(
        "--child-index",
        "--index",
        required=False,
        default=None,
        type=click.INT,
        help=child_index_help,
    )
    async def wrapper(ctx: click.Context, dev, *args, child, child_index, **kwargs):
        if child := await _get_child_device(dev, child, child_index, ctx.info_name):
            # Keep the child's update patched for the lifetime of the context.
            ctx.obj = ctx.with_resource(patched_device_update(dev, child))
            dev = child
        return await ctx.invoke(wrapped_function, dev, *args, **kwargs)

    # Update wrapper function to look like wrapped function
    return update_wrapper(wrapper, wrapped_function)


async def _get_child_device(
    device: Device,
    child_option: str | None,
    child_index_option: int | None,
    info_command: str | None,
) -> Device | None:
    """Resolve the child device selected by ``--child`` / ``--child-index``.

    :param device: The parent device.
    :param child_option: Child id/alias, ``OPTIONAL_VALUE_FLAG`` to prompt
        interactively, or ``None``.
    :param child_index_option: Index into ``device.children`` or ``None``.
    :param info_command: Name of the invoked command; commands in
        ``SKIP_UPDATE_COMMANDS`` get an explicit ``update()`` first.
    :return: The selected child, or ``None`` if neither option was given.
    :raises click.BadOptionUsage: If both options are given.
    Exits via ``error`` when the device has no children or the selection is
    invalid.
    """

    def _list_children():
        return "\n".join(
            [
                f"{idx}: {child.device_id} ({child.alias})"
                for idx, child in enumerate(device.children)
            ]
        )

    if child_option is None and child_index_option is None:
        return None

    if info_command in SKIP_UPDATE_COMMANDS:
        # The device hasn't had update called (e.g. for cmd_command)
        # The way child devices are accessed requires a ChildDevice to
        # wrap the communications. Doing this properly would require creating
        # a common interfaces for both IOT and SMART child devices.
        # As a stop-gap solution, we perform an update instead.
        await device.update()

    if not device.children:
        error(f"Device: {device.host} does not have children")

    if child_option is not None and child_index_option is not None:
        raise click.BadOptionUsage(
            "child", "Use either --child or --child-index, not both."
        )

    if child_option is not None:
        # Identity check: the sentinel object is what click passes when the
        # option is given without a value.
        if child_option is OPTIONAL_VALUE_FLAG:
            msg = _list_children()
            child_index_option = click.prompt(
                f"\n{msg}\nEnter the index number of the child device",
                type=click.IntRange(0, len(device.children) - 1),
            )
        elif child := device.get_child_device(child_option):
            echo(f"Targeting child device {child.alias}")
            return child
        else:
            error(
                "No child device found with device_id or name: "
                f"{child_option} children are:\n{_list_children()}"
            )

    if TYPE_CHECKING:
        assert isinstance(child_index_option, int)

    if child_index_option + 1 > len(device.children) or child_index_option < 0:
        error(
            f"Invalid index {child_index_option}, "
            f"device has {len(device.children)} children"
        )

    child_by_index = device.children[child_index_option]
    echo(f"Targeting child device {child_by_index.alias}")
    return child_by_index


def CatchAllExceptions(cls):
    """Capture all exceptions and prints them nicely.

    Returns a subclass of the click command/group class *cls* whose context
    creation and invocation print unexpected errors instead of a traceback,
    unless a debug/verbose flag was passed on the command line.

    Idea from https://stackoverflow.com/a/44347763 and
    https://stackoverflow.com/questions/52213375

    """
    debug_flags = ("--debug", "-d", "--verbose", "-v")

    def _handle_exception(debug, exc) -> NoReturn:
        # Must be called from inside an ``except`` block: the bare ``raise``
        # statements below re-raise the exception currently being handled.
        if isinstance(exc, click.ClickException):
            raise
        # Handle exit request from click.
        if isinstance(exc, click.exceptions.Exit):
            sys.exit(exc.exit_code)
        if isinstance(exc, click.exceptions.Abort):
            sys.exit(0)

        echo(f"Raised error: {exc}")
        if debug:
            raise
        echo("Run with --debug enabled to see stacktrace")
        sys.exit(1)

    class _CommandCls(cls):
        _debug = False

        async def make_context(self, info_name, args, parent=None, **extra):
            self._debug = any(arg in debug_flags for arg in args)
            try:
                return await super().make_context(
                    info_name, args, parent=parent, **extra
                )
            except Exception as exc:
                _handle_exception(self._debug, exc)

        async def invoke(self, ctx):
            try:
                return await super().invoke(ctx)
            except Exception as exc:
                _handle_exception(self._debug, exc)

        def __call__(self, *args, **kwargs):
            """Run the coroutine in the event loop and print any exceptions.

            python click catches KeyboardInterrupt in main, raises Abort()
            and does sys.exit. asyncclick doesn't properly handle a coroutine
            receiving CancelledError on a KeyboardInterrupt, so we catch the
            KeyboardInterrupt here once asyncio.run has re-raised it. This
            avoids large stacktraces when a user presses Ctrl-C.
            """
            try:
                asyncio.run(self.main(*args, **kwargs))
            except KeyboardInterrupt:
                click.echo(gettext("\nAborted!"), file=sys.stderr)
                sys.exit(1)

    return _CommandCls