From ed033679e5f7e570129c1c6437562c670fa1bc26 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Tue, 23 Jul 2024 19:13:52 +0100 Subject: [PATCH] Split out main cli module into lazily loaded submodules (#1039) --- kasa/cli/__main__.py | 3 +- kasa/cli/common.py | 231 ++++++++ kasa/cli/device.py | 184 ++++++ kasa/cli/discover.py | 142 +++++ kasa/cli/feature.py | 134 +++++ kasa/cli/lazygroup.py | 70 +++ kasa/cli/light.py | 200 +++++++ kasa/cli/main.py | 1228 +++++----------------------------------- kasa/cli/schedule.py | 46 ++ kasa/cli/time.py | 55 ++ kasa/cli/usage.py | 134 +++++ kasa/cli/wifi.py | 50 ++ kasa/tests/test_cli.py | 42 +- pyproject.toml | 2 +- 14 files changed, 1403 insertions(+), 1118 deletions(-) create mode 100644 kasa/cli/common.py create mode 100644 kasa/cli/device.py create mode 100644 kasa/cli/discover.py create mode 100644 kasa/cli/feature.py create mode 100644 kasa/cli/lazygroup.py create mode 100644 kasa/cli/light.py create mode 100644 kasa/cli/schedule.py create mode 100644 kasa/cli/time.py create mode 100644 kasa/cli/usage.py create mode 100644 kasa/cli/wifi.py diff --git a/kasa/cli/__main__.py b/kasa/cli/__main__.py index 5d4ca6a0..1cf92da1 100644 --- a/kasa/cli/__main__.py +++ b/kasa/cli/__main__.py @@ -2,4 +2,5 @@ from kasa.cli.main import cli -cli() +if __name__ == "__main__": + cli() diff --git a/kasa/cli/common.py b/kasa/cli/common.py new file mode 100644 index 00000000..1977d0c8 --- /dev/null +++ b/kasa/cli/common.py @@ -0,0 +1,231 @@ +"""Common cli module.""" + +from __future__ import annotations + +import json +import re +import sys +from contextlib import contextmanager +from functools import singledispatch, update_wrapper, wraps +from typing import Final + +import asyncclick as click + +from kasa import ( + Device, +) + +# Value for optional options if passed without a value +OPTIONAL_VALUE_FLAG: Final = "_FLAG_" + +# Block list of commands which require no update +SKIP_UPDATE_COMMANDS = ["raw-command", "command"] + +pass_dev = click.make_pass_decorator(Device) # type: ignore[type-abstract] + + +try: + from rich import print as _echo +except ImportError: + # Strip out rich formatting if rich is not installed + # but only lower case tags to avoid stripping out + # raw data from the device that is printed from + # the device state. + rich_formatting = re.compile(r"\[/?[a-z]+]") + + def _strip_rich_formatting(echo_func): + """Strip rich formatting from messages.""" + + @wraps(echo_func) + def wrapper(message=None, *args, **kwargs): + if message is not None: + message = rich_formatting.sub("", message) + echo_func(message, *args, **kwargs) + + return wrapper + + _echo = _strip_rich_formatting(click.echo) + + +def echo(*args, **kwargs): + """Print a message.""" + ctx = click.get_current_context().find_root() + if "json" not in ctx.params or ctx.params["json"] is False: + _echo(*args, **kwargs) + + +def error(msg: str): + """Print an error and exit.""" + echo(f"[bold red]{msg}[/bold red]") + sys.exit(1) + + +def json_formatter_cb(result, **kwargs): + """Format and output the result as JSON, if requested.""" + if not kwargs.get("json"): + return + + @singledispatch + def to_serializable(val): + """Regular obj-to-string for json serialization. + + The singledispatch trick is from hynek: https://hynek.me/articles/serialization/ + """ + return str(val) + + @to_serializable.register(Device) + def _device_to_serializable(val: Device): + """Serialize smart device data, just using the last update raw payload.""" + return val.internal_state + + json_content = json.dumps(result, indent=4, default=to_serializable) + print(json_content) + + +def pass_dev_or_child(wrapped_function): + """Pass the device or child to the click command based on the child options.""" + child_help = ( + "Child ID or alias for controlling sub-devices. " + "If no value provided will show an interactive prompt allowing you to " + "select a child." + ) + child_index_help = "Child index controlling sub-devices" + + @contextmanager + def patched_device_update(parent: Device, child: Device): + try: + orig_update = child.update + # patch child update method. Can be removed once update can be called + # directly on child devices + child.update = parent.update # type: ignore[method-assign] + yield child + finally: + child.update = orig_update # type: ignore[method-assign] + + @click.pass_obj + @click.pass_context + @click.option( + "--child", + "--name", + is_flag=False, + flag_value=OPTIONAL_VALUE_FLAG, + default=None, + required=False, + type=click.STRING, + help=child_help, + ) + @click.option( + "--child-index", + "--index", + required=False, + default=None, + type=click.INT, + help=child_index_help, + ) + async def wrapper(ctx: click.Context, dev, *args, child, child_index, **kwargs): + if child := await _get_child_device(dev, child, child_index, ctx.info_name): + ctx.obj = ctx.with_resource(patched_device_update(dev, child)) + dev = child + return await ctx.invoke(wrapped_function, dev, *args, **kwargs) + + # Update wrapper function to look like wrapped function + return update_wrapper(wrapper, wrapped_function) + + +async def _get_child_device( + device: Device, child_option, child_index_option, info_command +) -> Device | None: + def _list_children(): + return "\n".join( + [ + f"{idx}: {child.device_id} ({child.alias})" + for idx, child in enumerate(device.children) + ] + ) + + if child_option is None and child_index_option is None: + return None + + if info_command in SKIP_UPDATE_COMMANDS: + # The device hasn't had update called (e.g. for cmd_command) + # The way child devices are accessed requires a ChildDevice to + # wrap the communications. Doing this properly would require creating + # a common interfaces for both IOT and SMART child devices. + # As a stop-gap solution, we perform an update instead. + await device.update() + + if not device.children: + error(f"Device: {device.host} does not have children") + + if child_option is not None and child_index_option is not None: + raise click.BadOptionUsage( + "child", "Use either --child or --child-index, not both." + ) + + if child_option is not None: + if child_option is OPTIONAL_VALUE_FLAG: + msg = _list_children() + child_index_option = click.prompt( + f"\n{msg}\nEnter the index number of the child device", + type=click.IntRange(0, len(device.children) - 1), + ) + elif child := device.get_child_device(child_option): + echo(f"Targeting child device {child.alias}") + return child + else: + error( + "No child device found with device_id or name: " + f"{child_option} children are:\n{_list_children()}" + ) + + if child_index_option + 1 > len(device.children) or child_index_option < 0: + error( + f"Invalid index {child_index_option}, " + f"device has {len(device.children)} children" + ) + child_by_index = device.children[child_index_option] + echo(f"Targeting child device {child_by_index.alias}") + return child_by_index + + +def CatchAllExceptions(cls): + """Capture all exceptions and prints them nicely. + + Idea from https://stackoverflow.com/a/44347763 and + https://stackoverflow.com/questions/52213375 + """ + + def _handle_exception(debug, exc): + if isinstance(exc, click.ClickException): + raise + # Handle exit request from click. + if isinstance(exc, click.exceptions.Exit): + sys.exit(exc.exit_code) + + echo(f"Raised error: {exc}") + if debug: + raise + echo("Run with --debug enabled to see stacktrace") + sys.exit(1) + + class _CommandCls(cls): + _debug = False + + async def make_context(self, info_name, args, parent=None, **extra): + self._debug = any( + [arg for arg in args if arg in ["--debug", "-d", "--verbose", "-v"]] + ) + try: + return await super().make_context( + info_name, args, parent=parent, **extra + ) + except Exception as exc: + _handle_exception(self._debug, exc) + + async def invoke(self, ctx): + try: + return await super().invoke(ctx) + except Exception as exc: + _handle_exception(self._debug, exc) + + return _CommandCls diff --git a/kasa/cli/device.py b/kasa/cli/device.py new file mode 100644 index 00000000..60438035 --- /dev/null +++ b/kasa/cli/device.py @@ -0,0 +1,184 @@ +"""Module for cli device commands.""" + +from __future__ import annotations + +from pprint import pformat as pf + +import asyncclick as click + +from kasa import ( + Device, + Module, +) +from kasa.smart import SmartDevice + +from .common import ( + echo, + error, + pass_dev, + pass_dev_or_child, +) + + +@click.group() +@pass_dev_or_child +def device(dev): + """Commands to control basic device settings.""" + + +@device.command() +@pass_dev_or_child +@click.pass_context +async def state(ctx, dev: Device): + """Print out device state and versions.""" + from .feature import _echo_all_features + + verbose = ctx.parent.params.get("verbose", False) if ctx.parent else False + + echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]") + echo(f"Host: {dev.host}") + echo(f"Port: {dev.port}") + echo(f"Device state: {dev.is_on}") + + echo(f"Time: {dev.time} (tz: {dev.timezone}") + echo(f"Hardware: {dev.hw_info['hw_ver']}") + echo(f"Software: {dev.hw_info['sw_ver']}") + echo(f"MAC (rssi): {dev.mac} ({dev.rssi})") + if verbose: + echo(f"Location: {dev.location}") + + echo() + _echo_all_features(dev.features, verbose=verbose) + + if verbose: + echo("\n[bold]== Modules ==[/bold]") + for module in dev.modules.values(): + echo(f"[green]+ {module}[/green]") + + if dev.children: + echo("\n[bold]== Children ==[/bold]") + for child in dev.children: + _echo_all_features( + child.features, + title_prefix=f"{child.alias} ({child.model})", + verbose=verbose, + indent="\t", + ) + if verbose: + echo(f"\n\t[bold]== Child {child.alias} Modules ==[/bold]") + for module in child.modules.values(): + echo(f"\t[green]+ {module}[/green]") + echo() + + if verbose: + echo("\n\t[bold]== Protocol information ==[/bold]") + echo(f"\tCredentials hash: {dev.credentials_hash}") + echo() + from .discover import _echo_discovery_info + + _echo_discovery_info(dev._discovery_info) + + return dev.internal_state + + +@device.command() +@pass_dev_or_child +async def sysinfo(dev): + """Print out full system information.""" + echo("== System info ==") + echo(pf(dev.sys_info)) + return dev.sys_info + + +@device.command() +@click.option("--transition", type=int, required=False) +@pass_dev_or_child +async def on(dev: Device, transition: int): + """Turn the device on.""" + echo(f"Turning on {dev.alias}") + return await dev.turn_on(transition=transition) + + +@click.command +@click.option("--transition", type=int, required=False) +@pass_dev_or_child +async def off(dev: Device, transition: int): + """Turn the device off.""" + echo(f"Turning off {dev.alias}") + return await dev.turn_off(transition=transition) + + +@device.command() +@click.option("--transition", type=int, required=False) +@pass_dev_or_child +async def toggle(dev: Device, transition: int): + """Toggle the device on/off.""" + if dev.is_on: + echo(f"Turning off {dev.alias}") + return await dev.turn_off(transition=transition) + + echo(f"Turning on {dev.alias}") + return await dev.turn_on(transition=transition) + + +@device.command() +@click.argument("state", type=bool, required=False) +@pass_dev_or_child +async def led(dev: Device, state): + """Get or set (Plug's) led state.""" + if not (led := dev.modules.get(Module.Led)): + error("Device does not support led.") + return + if state is not None: + echo(f"Turning led to {state}") + return await led.set_led(state) + else: + echo(f"LED state: {led.led}") + return led.led + + +@device.command() +@click.argument("new_alias", required=False, default=None) +@pass_dev_or_child +async def alias(dev, new_alias): + """Get or set the device (or plug) alias.""" + if new_alias is not None: + echo(f"Setting alias to {new_alias}") + res = await dev.set_alias(new_alias) + await dev.update() + echo(f"Alias set to: {dev.alias}") + return res + + echo(f"Alias: {dev.alias}") + if dev.children: + for plug in dev.children: + echo(f" * {plug.alias}") + + return dev.alias + + +@device.command() +@click.option("--delay", default=1) +@pass_dev +async def reboot(plug, delay): + """Reboot the device.""" + echo("Rebooting the device..") + return await plug.reboot(delay) + + +@device.command() +@pass_dev +@click.option( + "--username", required=True, prompt=True, help="New username to set on the device" +) +@click.option( + "--password", required=True, prompt=True, help="New password to set on the device" +) +async def update_credentials(dev, username, password): + """Update device credentials for authenticated devices.""" + if not isinstance(dev, SmartDevice): + error("Credentials can only be updated on authenticated devices.") + + click.confirm("Do you really want to replace the existing credentials?", abort=True) + + return await dev.update_credentials(username, password) diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py new file mode 100644 index 00000000..6bf58e72 --- /dev/null +++ b/kasa/cli/discover.py @@ -0,0 +1,142 @@ +"""Module for cli discovery commands.""" + +from __future__ import annotations + +import asyncio + +import asyncclick as click +from pydantic.v1 import ValidationError + +from kasa import ( + AuthenticationError, + Credentials, + Device, + Discover, + UnsupportedDeviceError, +) +from kasa.discover import DiscoveryResult + +from .common import echo + + +@click.command() +@click.pass_context +async def discover(ctx): + """Discover devices in the network.""" + target = ctx.parent.params["target"] + username = ctx.parent.params["username"] + password = ctx.parent.params["password"] + discovery_timeout = ctx.parent.params["discovery_timeout"] + timeout = ctx.parent.params["timeout"] + port = ctx.parent.params["port"] + + credentials = Credentials(username, password) if username and password else None + + sem = asyncio.Semaphore() + discovered = dict() + unsupported = [] + auth_failed = [] + + async def print_unsupported(unsupported_exception: UnsupportedDeviceError): + unsupported.append(unsupported_exception) + async with sem: + if unsupported_exception.discovery_result: + echo("== Unsupported device ==") + _echo_discovery_info(unsupported_exception.discovery_result) + echo() + else: + echo("== Unsupported device ==") + echo(f"\t{unsupported_exception}") + echo() + + echo(f"Discovering devices on {target} for {discovery_timeout} seconds") + + from .device import state + + async def print_discovered(dev: Device): + async with sem: + try: + await dev.update() + except AuthenticationError: + auth_failed.append(dev._discovery_info) + echo("== Authentication failed for device ==") + _echo_discovery_info(dev._discovery_info) + echo() + else: + ctx.parent.obj = dev + await ctx.parent.invoke(state) + discovered[dev.host] = dev.internal_state + echo() + + discovered_devices = await Discover.discover( + target=target, + discovery_timeout=discovery_timeout, + on_discovered=print_discovered, + on_unsupported=print_unsupported, + port=port, + timeout=timeout, + credentials=credentials, + ) + + for device in discovered_devices.values(): + await device.protocol.close() + + echo(f"Found {len(discovered)} devices") + if unsupported: + echo(f"Found {len(unsupported)} unsupported devices") + if auth_failed: + echo(f"Found {len(auth_failed)} devices that failed to authenticate") + + return discovered + + +def _echo_dictionary(discovery_info: dict): + echo("\t[bold]== Discovery information ==[/bold]") + for key, value in discovery_info.items(): + key_name = " ".join(x.capitalize() or "_" for x in key.split("_")) + key_name_and_spaces = "{:<15}".format(key_name + ":") + echo(f"\t{key_name_and_spaces}{value}") + + +def _echo_discovery_info(discovery_info): + # We don't have discovery info when all connection params are passed manually + if discovery_info is None: + return + + if "system" in discovery_info and "get_sysinfo" in discovery_info["system"]: + _echo_dictionary(discovery_info["system"]["get_sysinfo"]) + return + + try: + dr = DiscoveryResult(**discovery_info) + except ValidationError: + _echo_dictionary(discovery_info) + return + + echo("\t[bold]== Discovery Result ==[/bold]") + echo(f"\tDevice Type: {dr.device_type}") + echo(f"\tDevice Model: {dr.device_model}") + echo(f"\tIP: {dr.ip}") + echo(f"\tMAC: {dr.mac}") + echo(f"\tDevice Id (hash): {dr.device_id}") + echo(f"\tOwner (hash): {dr.owner}") + echo(f"\tHW Ver: {dr.hw_ver}") + echo(f"\tSupports IOT Cloud: {dr.is_support_iot_cloud}") + echo(f"\tOBD Src: {dr.obd_src}") + echo(f"\tFactory Default: {dr.factory_default}") + echo(f"\tEncrypt Type: {dr.mgt_encrypt_schm.encrypt_type}") + echo(f"\tSupports HTTPS: {dr.mgt_encrypt_schm.is_support_https}") + echo(f"\tHTTP Port: {dr.mgt_encrypt_schm.http_port}") + echo(f"\tLV (Login Level): {dr.mgt_encrypt_schm.lv}") + + +async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3): + """Discover a device identified by its alias.""" + for _attempt in range(1, attempts): + found_devs = await Discover.discover(target=target, timeout=timeout) + for _ip, dev in found_devs.items(): + if dev.alias.lower() == alias.lower(): + host = dev.host + return host + + return None diff --git a/kasa/cli/feature.py b/kasa/cli/feature.py new file mode 100644 index 00000000..f8cba4e3 --- /dev/null +++ b/kasa/cli/feature.py @@ -0,0 +1,134 @@ +"""Module for cli feature commands.""" + +from __future__ import annotations + +import ast + +import asyncclick as click + +from kasa import ( + Device, + Feature, +) + +from .common import ( + echo, + error, + pass_dev_or_child, +) + + +def _echo_features( + features: dict[str, Feature], + title: str, + category: Feature.Category | None = None, + verbose: bool = False, + indent: str = "\t", +): + """Print out a listing of features and their values.""" + if category is not None: + features = { + id_: feat for id_, feat in features.items() if feat.category == category + } + + echo(f"{indent}[bold]{title}[/bold]") + for _, feat in features.items(): + try: + echo(f"{indent}{feat}") + if verbose: + echo(f"{indent}\tType: {feat.type}") + echo(f"{indent}\tCategory: {feat.category}") + echo(f"{indent}\tIcon: {feat.icon}") + except Exception as ex: + echo(f"{indent}{feat.name} ({feat.id}): [red]got exception ({ex})[/red]") + + +def _echo_all_features(features, *, verbose=False, title_prefix=None, indent=""): + """Print out all features by category.""" + if title_prefix is not None: + echo(f"[bold]\n{indent}== {title_prefix} ==[/bold]") + echo() + _echo_features( + features, + title="== Primary features ==", + category=Feature.Category.Primary, + verbose=verbose, + indent=indent, + ) + echo() + _echo_features( + features, + title="== Information ==", + category=Feature.Category.Info, + verbose=verbose, + indent=indent, + ) + echo() + _echo_features( + features, + title="== Configuration ==", + category=Feature.Category.Config, + verbose=verbose, + indent=indent, + ) + echo() + _echo_features( + features, + title="== Debug ==", + category=Feature.Category.Debug, + verbose=verbose, + indent=indent, + ) + + +@click.command(name="feature") +@click.argument("name", required=False) +@click.argument("value", required=False) +@pass_dev_or_child +@click.pass_context +async def feature( + ctx: click.Context, + dev: Device, + name: str, + value, +): + """Access and modify features. + + If no *name* is given, lists available features and their values. + If only *name* is given, the value of named feature is returned. + If both *name* and *value* are set, the described setting is changed. + """ + verbose = ctx.parent.params.get("verbose", False) if ctx.parent else False + + if not name: + _echo_all_features(dev.features, verbose=verbose, indent="") + + if dev.children: + for child_dev in dev.children: + _echo_all_features( + child_dev.features, + verbose=verbose, + title_prefix=f"Child {child_dev.alias}", + indent="\t", + ) + + return + + if name not in dev.features: + error(f"No feature by name '{name}'") + return + + feat = dev.features[name] + + if value is None: + unit = f" {feat.unit}" if feat.unit else "" + echo(f"{feat.name} ({name}): {feat.value}{unit}") + return feat.value + + value = ast.literal_eval(value) + echo(f"Changing {name} from {feat.value} to {value}") + response = await dev.features[name].set_value(value) + await dev.update() + echo(f"New state: {feat.value}") + + return response diff --git a/kasa/cli/lazygroup.py b/kasa/cli/lazygroup.py new file mode 100644 index 00000000..9e9724aa --- /dev/null +++ b/kasa/cli/lazygroup.py @@ -0,0 +1,70 @@ +"""Module for lazily instantiating sub modules. + +Taken from the click help files. +""" + +import importlib + +import asyncclick as click + + +class LazyGroup(click.Group): + """Lazy group class.""" + + def __init__(self, *args, lazy_subcommands=None, **kwargs): + super().__init__(*args, **kwargs) + # lazy_subcommands is a map of the form: + # + # {command-name} -> {module-name}.{command-object-name} + # + self.lazy_subcommands = lazy_subcommands or {} + + def list_commands(self, ctx): + """List click commands.""" + base = super().list_commands(ctx) + lazy = list(self.lazy_subcommands.keys()) + return lazy + base + + def get_command(self, ctx, cmd_name): + """Get click command.""" + if cmd_name in self.lazy_subcommands: + return self._lazy_load(cmd_name) + return super().get_command(ctx, cmd_name) + + def format_commands(self, ctx, formatter): + """Format the top level help output.""" + sections = {} + for cmd, parent in self.lazy_subcommands.items(): + sections.setdefault(parent, []) + cmd_obj = self.get_command(ctx, cmd) + help = cmd_obj.get_short_help_str() + sections[parent].append((cmd, help)) + for section in sections: + if section: + header = ( + f"Common {section} commands (also available " + f"under the `{section}` subcommand)" + ) + else: + header = "Subcommands" + with formatter.section(header): + formatter.write_dl(sections[section]) + + def _lazy_load(self, cmd_name): + # lazily loading a command, first get the module name and attribute name + if not (import_path := self.lazy_subcommands[cmd_name]): + import_path = f".{cmd_name}.{cmd_name}" + else: + import_path = f".{import_path}.{cmd_name}" + modname, cmd_object_name = import_path.rsplit(".", 1) + # do the import + mod = importlib.import_module(modname, package=__package__) + # get the Command object from that module + cmd_object = getattr(mod, cmd_object_name) + # check the result to make debugging easier + if not isinstance(cmd_object, click.BaseCommand): + raise ValueError( + f"Lazy loading of {cmd_name} failed by returning " + "a non-command object" + ) + return cmd_object diff --git a/kasa/cli/light.py b/kasa/cli/light.py new file mode 100644 index 00000000..06c46907 --- /dev/null +++ b/kasa/cli/light.py @@ -0,0 +1,200 @@ +"""Module for cli light control commands.""" + +import asyncclick as click + +from kasa import ( + Device, + Module, +) +from kasa.iot import ( + IotBulb, +) + +from .common import echo, error, pass_dev_or_child + + +@click.group() +@pass_dev_or_child +def light(dev): + """Commands to control light settings.""" + + +@light.command() +@click.argument("brightness", type=click.IntRange(0, 100), default=None, required=False) +@click.option("--transition", type=int, required=False) +@pass_dev_or_child +async def brightness(dev: Device, brightness: int, transition: int): + """Get or set brightness.""" + if not (light := dev.modules.get(Module.Light)) or not light.is_dimmable: + error("This device does not support brightness.") + return + + if brightness is None: + echo(f"Brightness: {light.brightness}") + return light.brightness + else: + echo(f"Setting brightness to {brightness}") + return await light.set_brightness(brightness, transition=transition) + + +@light.command() +@click.argument( + "temperature", type=click.IntRange(2500, 9000), default=None, required=False +) +@click.option("--transition", type=int, required=False) +@pass_dev_or_child +async def temperature(dev: Device, temperature: int, transition: int): + """Get or set color temperature.""" + if not (light := dev.modules.get(Module.Light)) or not light.is_variable_color_temp: + error("Device does not support color temperature") + return + + if temperature is None: + echo(f"Color temperature: {light.color_temp}") + valid_temperature_range = light.valid_temperature_range + if valid_temperature_range != (0, 0): + echo("(min: {}, max: {})".format(*valid_temperature_range)) + else: + echo( + "Temperature range unknown, please open a github issue" + f" or a pull request for model '{dev.model}'" + ) + return light.valid_temperature_range + else: + echo(f"Setting color temperature to {temperature}") + return await light.set_color_temp(temperature, transition=transition) + + +@light.command() +@click.argument("effect", type=click.STRING, default=None, required=False) +@click.pass_context +@pass_dev_or_child +async def effect(dev: Device, ctx, effect): + """Set an effect.""" + if not (light_effect := dev.modules.get(Module.LightEffect)): + error("Device does not support effects") + return + if effect is None: + echo( + f"Light effect: {light_effect.effect}\n" + + f"Available Effects: {light_effect.effect_list}" + ) + return light_effect.effect + + if effect not in light_effect.effect_list: + raise click.BadArgumentUsage( + f"Effect must be one of: {light_effect.effect_list}", ctx + ) + + echo(f"Setting Effect: {effect}") + return await light_effect.set_effect(effect) + + +@light.command() +@click.argument("h", type=click.IntRange(0, 360), default=None, required=False) +@click.argument("s", type=click.IntRange(0, 100), default=None, required=False) +@click.argument("v", type=click.IntRange(0, 100), default=None, required=False) +@click.option("--transition", type=int, required=False) +@click.pass_context +@pass_dev_or_child +async def hsv(dev: Device, ctx, h, s, v, transition): + """Get or set color in HSV.""" + if not (light := dev.modules.get(Module.Light)) or not light.is_color: + error("Device does not support colors") + return + + if h is None and s is None and v is None: + echo(f"Current HSV: {light.hsv}") + return light.hsv + elif s is None or v is None: + raise click.BadArgumentUsage("Setting a color requires 3 values.", ctx) + else: + echo(f"Setting HSV: {h} {s} {v}") + return await light.set_hsv(h, s, v, transition=transition) + + +@light.group(invoke_without_command=True) +@pass_dev_or_child +@click.pass_context +async def presets(ctx, dev): + """List and modify bulb setting presets.""" + if ctx.invoked_subcommand is None: + return await ctx.invoke(presets_list) + + +@presets.command(name="list") +@pass_dev_or_child +def presets_list(dev: Device): + """List presets.""" + if not (light_preset := dev.modules.get(Module.LightPreset)): + error("Presets not supported on device") + return + + for preset in light_preset.preset_states_list: + echo(preset) + + return light_preset.preset_states_list + + +@presets.command(name="modify") +@click.argument("index", type=int) +@click.option("--brightness", type=int) +@click.option("--hue", type=int) +@click.option("--saturation", type=int) +@click.option("--temperature", type=int) +@pass_dev_or_child +async def presets_modify(dev: Device, index, brightness, hue, saturation, temperature): + """Modify a preset.""" + for preset in dev.presets: + if preset.index == index: + break + else: + error(f"No preset found for index {index}") + return + + if brightness is not None: + preset.brightness = brightness + if hue is not None: + preset.hue = hue + if saturation is not None: + preset.saturation = saturation + if temperature is not None: + preset.color_temp = temperature + + echo(f"Going to save preset: {preset}") + + return await dev.save_preset(preset) + + +@light.command() +@pass_dev_or_child +@click.option("--type", type=click.Choice(["soft", "hard"], case_sensitive=False)) +@click.option("--last", is_flag=True) +@click.option("--preset", type=int) +async def turn_on_behavior(dev: Device, type, last, preset): + """Modify bulb turn-on behavior.""" + if not dev.is_bulb or not isinstance(dev, IotBulb): + error("Presets only supported on iot bulbs") + return + settings = await dev.get_turn_on_behavior() + echo(f"Current turn on behavior: {settings}") + + # Return if we are not setting the value + if not type and not last and not preset: + return settings + + # If we are setting the value, the type has to be specified + if (last or preset) and type is None: + echo("To set the behavior, you need to define --type") + return + + behavior = getattr(settings, type) + + if last: + echo(f"Going to set {type} to last") + behavior.preset = None + elif preset is not None: + echo(f"Going to set {type} to preset {preset}") + behavior.preset = preset + + return await dev.set_turn_on_behavior(settings) diff --git a/kasa/cli/main.py b/kasa/cli/main.py index 10c42297..88b768c4 100755 --- a/kasa/cli/main.py +++ b/kasa/cli/main.py @@ -1,4 +1,4 @@ -"""python-kasa cli tool.""" +"""Main module for cli tool.""" from __future__ import annotations @@ -6,282 +6,90 @@ import ast import asyncio import json import logging -import re import sys -from contextlib import asynccontextmanager, contextmanager -from datetime import datetime -from functools import singledispatch, update_wrapper, wraps -from pprint import pformat as pf -from typing import Any, Final, cast +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any import asyncclick as click -from pydantic.v1 import ValidationError -from kasa import ( - AuthenticationError, - Credentials, - Device, - DeviceConfig, - DeviceConnectionParameters, - DeviceEncryptionType, - DeviceFamily, - Discover, - Feature, - KasaException, - Module, - UnsupportedDeviceError, +if TYPE_CHECKING: + from kasa import Device + +from kasa.deviceconfig import DeviceEncryptionType + +from .common import ( + SKIP_UPDATE_COMMANDS, + CatchAllExceptions, + echo, + error, + json_formatter_cb, + pass_dev_or_child, ) -from kasa.discover import DiscoveryResult -from kasa.iot import ( - IotBulb, - IotDevice, - IotDimmer, - IotLightStrip, - IotPlug, - IotStrip, - IotWallSwitch, -) -from kasa.iot.iotstrip import IotStripPlug -from kasa.iot.modules import Usage -from kasa.smart import SmartDevice +from .lazygroup import LazyGroup -try: - from rich import print as _do_echo -except ImportError: - # Strip out rich formatting if rich is not installed - # but only lower case tags to avoid stripping out - # raw data from the device that is printed from - # the device state. - rich_formatting = re.compile(r"\[/?[a-z]+]") - - def _strip_rich_formatting(echo_func): - """Strip rich formatting from messages.""" - - @wraps(echo_func) - def wrapper(message=None, *args, **kwargs): - if message is not None: - message = rich_formatting.sub("", message) - echo_func(message, *args, **kwargs) - - return wrapper - - _do_echo = _strip_rich_formatting(click.echo) - -# echo is set to _do_echo so that it can be reset to _do_echo later after -# --json has set it to _nop_echo -echo = _do_echo - - -def error(msg: str): - """Print an error and exit.""" - echo(f"[bold red]{msg}[/bold red]") - sys.exit(1) - - -# Value for optional options if passed without a value -OPTIONAL_VALUE_FLAG: Final = "_FLAG_" - -TYPE_TO_CLASS = { - "plug": IotPlug, - "switch": IotWallSwitch, - "bulb": IotBulb, - "dimmer": IotDimmer, - "strip": IotStrip, - "lightstrip": IotLightStrip, - "iot.plug": IotPlug, - "iot.switch": IotWallSwitch, - "iot.bulb": IotBulb, - "iot.dimmer": IotDimmer, - "iot.strip": IotStrip, - "iot.lightstrip": IotLightStrip, - "smart.plug": SmartDevice, - "smart.bulb": SmartDevice, -} +TYPES = [ + "plug", + "switch", + "bulb", + "dimmer", + "strip", + "lightstrip", + "smart", +] ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in DeviceEncryptionType] -DEVICE_FAMILY_TYPES = [device_family_type.value for device_family_type in DeviceFamily] -# Block list of commands which require no update -SKIP_UPDATE_COMMANDS = ["raw-command", "command"] - -pass_dev = click.make_pass_decorator(Device) # type: ignore[type-abstract] - - -def CatchAllExceptions(cls): - """Capture all exceptions and prints them nicely. - - Idea from https://stackoverflow.com/a/44347763 and - https://stackoverflow.com/questions/52213375 - """ - - def _handle_exception(debug, exc): - if isinstance(exc, click.ClickException): - raise - # Handle exit request from click. - if isinstance(exc, click.exceptions.Exit): - sys.exit(exc.exit_code) - - echo(f"Raised error: {exc}") - if debug: - raise - echo("Run with --debug enabled to see stacktrace") - sys.exit(1) - - class _CommandCls(cls): - _debug = False - - async def make_context(self, info_name, args, parent=None, **extra): - self._debug = any( - [arg for arg in args if arg in ["--debug", "-d", "--verbose", "-v"]] - ) - try: - return await super().make_context( - info_name, args, parent=parent, **extra - ) - except Exception as exc: - _handle_exception(self._debug, exc) - - async def invoke(self, ctx): - try: - return await super().invoke(ctx) - except Exception as exc: - _handle_exception(self._debug, exc) - - return _CommandCls - - -def json_formatter_cb(result, **kwargs): - """Format and output the result as JSON, if requested.""" - if not kwargs.get("json"): - return - - @singledispatch - def to_serializable(val): - """Regular obj-to-string for json serialization. - - The singledispatch trick is from hynek: https://hynek.me/articles/serialization/ - """ - return str(val) - - @to_serializable.register(Device) - def _device_to_serializable(val: Device): - """Serialize smart device data, just using the last update raw payload.""" - return val.internal_state - - json_content = json.dumps(result, indent=4, default=to_serializable) - print(json_content) - - -def pass_dev_or_child(wrapped_function): - """Pass the device or child to the click command based on the child options.""" - child_help = ( - "Child ID or alias for controlling sub-devices. " - "If no value provided will show an interactive prompt allowing you to " - "select a child." +def _legacy_type_to_class(_type): + from kasa.iot import ( + IotBulb, + IotDimmer, + IotLightStrip, + IotPlug, + IotStrip, + IotWallSwitch, ) - child_index_help = "Child index controlling sub-devices" - @contextmanager - def patched_device_update(parent: Device, child: Device): - try: - orig_update = child.update - # patch child update method. Can be removed once update can be called - # directly on child devices - child.update = parent.update # type: ignore[method-assign] - yield child - finally: - child.update = orig_update # type: ignore[method-assign] - - @click.pass_obj - @click.pass_context - @click.option( - "--child", - "--name", - is_flag=False, - flag_value=OPTIONAL_VALUE_FLAG, - default=None, - required=False, - type=click.STRING, - help=child_help, - ) - @click.option( - "--child-index", - "--index", - required=False, - default=None, - type=click.INT, - help=child_index_help, - ) - async def wrapper(ctx: click.Context, dev, *args, child, child_index, **kwargs): - if child := await _get_child_device(dev, child, child_index, ctx.info_name): - ctx.obj = ctx.with_resource(patched_device_update(dev, child)) - dev = child - return await ctx.invoke(wrapped_function, dev, *args, **kwargs) - - # Update wrapper function to look like wrapped function - return update_wrapper(wrapper, wrapped_function) - - -async def _get_child_device( - device: Device, child_option, child_index_option, info_command -) -> Device | None: - def _list_children(): - return "\n".join( - [ - f"{idx}: {child.device_id} ({child.alias})" - for idx, child in enumerate(device.children) - ] - ) - - if child_option is None and child_index_option is None: - return None - - if info_command in SKIP_UPDATE_COMMANDS: - # The device hasn't had update called (e.g. for cmd_command) - # The way child devices are accessed requires a ChildDevice to - # wrap the communications. Doing this properly would require creating - # a common interfaces for both IOT and SMART child devices. - # As a stop-gap solution, we perform an update instead. - await device.update() - - if not device.children: - error(f"Device: {device.host} does not have children") - - if child_option is not None and child_index_option is not None: - raise click.BadOptionUsage( - "child", "Use either --child or --child-index, not both." - ) - - if child_option is not None: - if child_option is OPTIONAL_VALUE_FLAG: - msg = _list_children() - child_index_option = click.prompt( - f"\n{msg}\nEnter the index number of the child device", - type=click.IntRange(0, len(device.children) - 1), - ) - elif child := device.get_child_device(child_option): - echo(f"Targeting child device {child.alias}") - return child - else: - error( - "No child device found with device_id or name: " - f"{child_option} children are:\n{_list_children()}" - ) - - if child_index_option + 1 > len(device.children) or child_index_option < 0: - error( - f"Invalid index {child_index_option}, " - f"device has {len(device.children)} children" - ) - child_by_index = device.children[child_index_option] - echo(f"Targeting child device {child_by_index.alias}") - return child_by_index + TYPE_TO_CLASS = { + "plug": IotPlug, + "switch": IotWallSwitch, + "bulb": IotBulb, + "dimmer": IotDimmer, + "strip": IotStrip, + "lightstrip": IotLightStrip, + } + return TYPE_TO_CLASS[_type] @click.group( invoke_without_command=True, - cls=CatchAllExceptions(click.Group), + cls=CatchAllExceptions(LazyGroup), + lazy_subcommands={ + "discover": None, + "device": None, + "feature": None, + "light": None, + "wifi": None, + "time": None, + "schedule": None, + "usage": None, + # device commands runnnable at top level + "state": "device", + "on": "device", + "off": "device", + "toggle": "device", + "led": "device", + "alias": "device", + "reboot": "device", + "update_credentials": "device", + "sysinfo": "device", + # light commands runnnable at top level + "presets": "light", + "brightness": "light", + "hsv": "light", + "temperature": "light", + "effect": "light", + }, result_callback=json_formatter_cb, ) @click.option( @@ -332,7 +140,8 @@ async def _get_child_device( "--type", envvar="KASA_TYPE", default=None, - type=click.Choice(list(TYPE_TO_CLASS), case_sensitive=False), + type=click.Choice(TYPES, case_sensitive=False), + help="The device type in order to bypass discovery. Use `smart` for newer devices", ) @click.option( "--json/--no-json", @@ -352,7 +161,7 @@ async def _get_child_device( "--device-family", envvar="KASA_DEVICE_FAMILY", default="SMART.TAPOPLUG", - type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False), + help="Device family type, e.g. `SMART.KASASWITCH`. Deprecated use `--type smart`", ) @click.option( "-lv", @@ -360,6 +169,7 @@ async def _get_child_device( envvar="KASA_LOGIN_VERSION", default=2, type=int, + help="The login version for device authentication. Defaults to 2", ) @click.option( "--timeout", @@ -426,19 +236,6 @@ async def cli( ctx.obj = object() return - # If JSON output is requested, disable echo - global echo - if json: - - def _nop_echo(*args, **kwargs): - pass - - echo = _nop_echo - else: - # Set back to default is required if running tests with CliRunner - global _do_echo - echo = _do_echo - logging_config: dict[str, Any] = { "level": logging.DEBUG if debug > 0 else logging.INFO } @@ -465,6 +262,9 @@ async def cli( if alias is not None and host is None: echo(f"Alias is given, using discovery to find host {alias}") + + from .discover import find_host_from_alias + host = await find_host_from_alias(alias=alias, target=target) if host: echo(f"Found hostname is {host}") @@ -478,6 +278,8 @@ async def cli( ) if username: + from kasa.credentials import Credentials + credentials = Credentials(username=username, password=password) else: credentials = None @@ -487,13 +289,27 @@ async def cli( error("Only discover is available without --host or --alias") echo("No host name given, trying discovery..") + from .discover import discover + return await ctx.invoke(discover) device_updated = False - if type is not None: + if type is not None and type != "smart": + from kasa.deviceconfig import DeviceConfig + config = DeviceConfig(host=host, port_override=port, timeout=timeout) - dev = TYPE_TO_CLASS[type](host, config=config) - elif device_family and encrypt_type: + dev = _legacy_type_to_class(type)(host, config=config) + elif type == "smart" or (device_family and encrypt_type): + from kasa.device import Device + from kasa.deviceconfig import ( + DeviceConfig, + DeviceConnectionParameters, + DeviceEncryptionType, + DeviceFamily, + ) + + if not encrypt_type: + encrypt_type = "KLAP" ctype = DeviceConnectionParameters( DeviceFamily(device_family), DeviceEncryptionType(encrypt_type), @@ -510,6 +326,8 @@ async def cli( dev = await Device.connect(config=config) device_updated = True else: + from kasa.discover import Discover + dev = await Discover.discover_single( host, port=port, @@ -533,777 +351,11 @@ async def cli( ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev)) if ctx.invoked_subcommand is None: + from .device import state + return await ctx.invoke(state) -@cli.group() -@pass_dev -def wifi(dev): - """Commands to control wifi settings.""" - - -@wifi.command() -@pass_dev -async def scan(dev): - """Scan for available wifi networks.""" - echo("Scanning for wifi networks, wait a second..") - devs = await dev.wifi_scan() - echo(f"Found {len(devs)} wifi networks!") - for dev in devs: - echo(f"\t {dev}") - - return devs - - -@wifi.command() -@click.argument("ssid") -@click.option("--keytype", prompt=True) -@click.option("--password", prompt=True, hide_input=True) -@pass_dev -async def join(dev: Device, ssid: str, password: str, keytype: str): - """Join the given wifi network.""" - echo(f"Asking the device to connect to {ssid}..") - res = await dev.wifi_join(ssid, password, keytype=keytype) - echo( - f"Response: {res} - if the device is not able to join the network, " - f"it will revert back to its previous state." - ) - - return res - - -@cli.command() -@click.pass_context -async def discover(ctx): - """Discover devices in the network.""" - target = ctx.parent.params["target"] - username = ctx.parent.params["username"] - password = ctx.parent.params["password"] - discovery_timeout = ctx.parent.params["discovery_timeout"] - timeout = ctx.parent.params["timeout"] - port = ctx.parent.params["port"] - - credentials = Credentials(username, password) if username and password else None - - sem = asyncio.Semaphore() - discovered = dict() - unsupported = [] - auth_failed = [] - - async def print_unsupported(unsupported_exception: UnsupportedDeviceError): - unsupported.append(unsupported_exception) - async with sem: - if unsupported_exception.discovery_result: - echo("== Unsupported device ==") - _echo_discovery_info(unsupported_exception.discovery_result) - echo() - else: - echo("== Unsupported device ==") - echo(f"\t{unsupported_exception}") - echo() - - echo(f"Discovering devices on {target} for {discovery_timeout} seconds") - - async def print_discovered(dev: Device): - async with sem: - try: - await dev.update() - except AuthenticationError: - auth_failed.append(dev._discovery_info) - echo("== Authentication failed for device ==") - _echo_discovery_info(dev._discovery_info) - echo() - else: - ctx.parent.obj = dev - await ctx.parent.invoke(state) - discovered[dev.host] = dev.internal_state - echo() - - discovered_devices = await Discover.discover( - target=target, - discovery_timeout=discovery_timeout, - on_discovered=print_discovered, - on_unsupported=print_unsupported, - port=port, - timeout=timeout, - credentials=credentials, - ) - - for device in discovered_devices.values(): - await device.protocol.close() - - echo(f"Found {len(discovered)} devices") - if unsupported: - echo(f"Found {len(unsupported)} unsupported devices") - if auth_failed: - echo(f"Found {len(auth_failed)} devices that failed to authenticate") - - return discovered - - -def _echo_dictionary(discovery_info: dict): - echo("\t[bold]== Discovery information ==[/bold]") - for key, value in discovery_info.items(): - key_name = " ".join(x.capitalize() or "_" for x in key.split("_")) - key_name_and_spaces = "{:<15}".format(key_name + ":") - echo(f"\t{key_name_and_spaces}{value}") - - -def _echo_discovery_info(discovery_info): - # We don't have discovery info when all connection params are passed manually - if discovery_info is None: - return - - if "system" in discovery_info and "get_sysinfo" in discovery_info["system"]: - _echo_dictionary(discovery_info["system"]["get_sysinfo"]) - return - - try: - dr = DiscoveryResult(**discovery_info) - except ValidationError: - _echo_dictionary(discovery_info) - return - - echo("\t[bold]== Discovery Result ==[/bold]") - echo(f"\tDevice Type: {dr.device_type}") - echo(f"\tDevice Model: {dr.device_model}") - echo(f"\tIP: {dr.ip}") - echo(f"\tMAC: {dr.mac}") - echo(f"\tDevice Id (hash): {dr.device_id}") - echo(f"\tOwner (hash): {dr.owner}") - echo(f"\tHW Ver: {dr.hw_ver}") - echo(f"\tSupports IOT Cloud: {dr.is_support_iot_cloud}") - echo(f"\tOBD Src: {dr.obd_src}") - echo(f"\tFactory Default: {dr.factory_default}") - echo(f"\tEncrypt Type: {dr.mgt_encrypt_schm.encrypt_type}") - echo(f"\tSupports HTTPS: {dr.mgt_encrypt_schm.is_support_https}") - echo(f"\tHTTP Port: {dr.mgt_encrypt_schm.http_port}") - echo(f"\tLV (Login Level): {dr.mgt_encrypt_schm.lv}") - - -async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3): - """Discover a device identified by its alias.""" - for _attempt in range(1, attempts): - found_devs = await Discover.discover(target=target, timeout=timeout) - for _ip, dev in found_devs.items(): - if dev.alias.lower() == alias.lower(): - host = dev.host - return host - - return None - - -@cli.command() -@pass_dev_or_child -async def sysinfo(dev): - """Print out full system information.""" - echo("== System info ==") - echo(pf(dev.sys_info)) - return dev.sys_info - - -def _echo_features( - features: dict[str, Feature], - title: str, - category: Feature.Category | None = None, - verbose: bool = False, - indent: str = "\t", -): - """Print out a listing of features and their values.""" - if category is not None: - features = { - id_: feat for id_, feat in features.items() if feat.category == category - } - - echo(f"{indent}[bold]{title}[/bold]") - for _, feat in features.items(): - try: - echo(f"{indent}{feat}") - if verbose: - echo(f"{indent}\tType: {feat.type}") - echo(f"{indent}\tCategory: {feat.category}") - echo(f"{indent}\tIcon: {feat.icon}") - except Exception as ex: - echo(f"{indent}{feat.name} ({feat.id}): [red]got exception ({ex})[/red]") - - -def _echo_all_features(features, *, verbose=False, title_prefix=None, indent=""): - """Print out all features by category.""" - if title_prefix is not None: - echo(f"[bold]\n{indent}== {title_prefix} ==[/bold]") - echo() - _echo_features( - features, - title="== Primary features ==", - category=Feature.Category.Primary, - verbose=verbose, - indent=indent, - ) - echo() - _echo_features( - features, - title="== Information ==", - category=Feature.Category.Info, - verbose=verbose, - indent=indent, - ) - echo() - _echo_features( - features, - title="== Configuration ==", - category=Feature.Category.Config, - verbose=verbose, - indent=indent, - ) - echo() - _echo_features( - features, - title="== Debug ==", - category=Feature.Category.Debug, - verbose=verbose, - indent=indent, - ) - - -@cli.command() -@pass_dev_or_child -@click.pass_context -async def state(ctx, dev: Device): - """Print out device state and versions.""" - verbose = ctx.parent.params.get("verbose", False) if ctx.parent else False - - echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]") - echo(f"Host: {dev.host}") - echo(f"Port: {dev.port}") - echo(f"Device state: {dev.is_on}") - - echo(f"Time: {dev.time} (tz: {dev.timezone}") - echo(f"Hardware: {dev.hw_info['hw_ver']}") - echo(f"Software: {dev.hw_info['sw_ver']}") - echo(f"MAC (rssi): {dev.mac} ({dev.rssi})") - if verbose: - echo(f"Location: {dev.location}") - - echo() - _echo_all_features(dev.features, verbose=verbose) - - if verbose: - echo("\n[bold]== Modules ==[/bold]") - for module in dev.modules.values(): - echo(f"[green]+ {module}[/green]") - - if dev.children: - echo("\n[bold]== Children ==[/bold]") - for child in dev.children: - _echo_all_features( - child.features, - title_prefix=f"{child.alias} ({child.model})", - verbose=verbose, - indent="\t", - ) - if verbose: - echo(f"\n\t[bold]== Child {child.alias} Modules ==[/bold]") - for module in child.modules.values(): - echo(f"\t[green]+ {module}[/green]") - echo() - - if verbose: - echo("\n\t[bold]== Protocol information ==[/bold]") - echo(f"\tCredentials hash: {dev.credentials_hash}") - echo() - _echo_discovery_info(dev._discovery_info) - - return dev.internal_state - - -@cli.command() -@click.argument("new_alias", required=False, default=None) -@pass_dev_or_child -async def alias(dev, new_alias): - """Get or set the device (or plug) alias.""" - if new_alias is not None: - echo(f"Setting alias to {new_alias}") - res = await dev.set_alias(new_alias) - await dev.update() - echo(f"Alias set to: {dev.alias}") - return res - - echo(f"Alias: {dev.alias}") - if dev.children: - for plug in dev.children: - echo(f" * {plug.alias}") - - return dev.alias - - -@cli.command() -@click.pass_context -@click.argument("module") -@click.argument("command") -@click.argument("parameters", default=None, required=False) -async def raw_command(ctx, module, command, parameters): - """Run a raw command on the device.""" - logging.warning("Deprecated, use 'kasa command --module %s %s'", module, command) - return await ctx.forward(cmd_command) - - -@cli.command(name="command") -@click.option("--module", required=False, help="Module for IOT protocol.") -@click.argument("command") -@click.argument("parameters", default=None, required=False) -@pass_dev_or_child -async def cmd_command(dev: Device, module, command, parameters): - """Run a raw command on the device.""" - if parameters is not None: - parameters = ast.literal_eval(parameters) - - if isinstance(dev, IotDevice): - res = await dev._query_helper(module, command, parameters) - elif isinstance(dev, SmartDevice): - res = await dev._query_helper(command, parameters) - else: - raise KasaException("Unexpected device type %s.", dev) - echo(json.dumps(res)) - return res - - -@cli.command() -@click.option("--index", type=int, required=False) -@click.option("--name", type=str, required=False) -@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False) -@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False) -@click.option("--erase", is_flag=True) -@click.pass_context -async def emeter(ctx: click.Context, index, name, year, month, erase): - """Query emeter for historical consumption.""" - logging.warning("Deprecated, use 'kasa energy'") - return await ctx.invoke( - energy, child_index=index, child=name, year=year, month=month, erase=erase - ) - - -@cli.command() -@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False) -@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False) -@click.option("--erase", is_flag=True) -@pass_dev_or_child -async def energy(dev: Device, year, month, erase): - """Query energy module for historical consumption. - - Daily and monthly data provided in CSV format. - """ - echo("[bold]== Emeter ==[/bold]") - if not dev.has_emeter: - error("Device has no emeter") - return - - if (year or month or erase) and not isinstance(dev, IotDevice): - error("Device has no historical statistics") - return - else: - dev = cast(IotDevice, dev) - - if erase: - echo("Erasing emeter statistics..") - return await dev.erase_emeter_stats() - - if year: - echo(f"== For year {year.year} ==") - echo("Month, usage (kWh)") - usage_data = await dev.get_emeter_monthly(year=year.year) - elif month: - echo(f"== For month {month.month} of {month.year} ==") - echo("Day, usage (kWh)") - usage_data = await dev.get_emeter_daily(year=month.year, month=month.month) - else: - # Call with no argument outputs summary data and returns - if isinstance(dev, IotStripPlug): - emeter_status = await dev.get_emeter_realtime() - else: - emeter_status = dev.emeter_realtime - - echo("Current: %s A" % emeter_status["current"]) - echo("Voltage: %s V" % emeter_status["voltage"]) - echo("Power: %s W" % emeter_status["power"]) - echo("Total consumption: %s kWh" % emeter_status["total"]) - - echo("Today: %s kWh" % dev.emeter_today) - echo("This month: %s kWh" % dev.emeter_this_month) - - return emeter_status - - # output any detailed usage data - for index, usage in usage_data.items(): - echo(f"{index}, {usage}") - - return usage_data - - -@cli.command() -@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False) -@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False) -@click.option("--erase", is_flag=True) -@pass_dev_or_child -async def usage(dev: Device, year, month, erase): - """Query usage for historical consumption. - - Daily and monthly data provided in CSV format. - """ - echo("[bold]== Usage ==[/bold]") - usage = cast(Usage, dev.modules["usage"]) - - if erase: - echo("Erasing usage statistics..") - return await usage.erase_stats() - - if year: - echo(f"== For year {year.year} ==") - echo("Month, usage (minutes)") - usage_data = await usage.get_monthstat(year=year.year) - elif month: - echo(f"== For month {month.month} of {month.year} ==") - echo("Day, usage (minutes)") - usage_data = await usage.get_daystat(year=month.year, month=month.month) - else: - # Call with no argument outputs summary data and returns - echo("Today: %s minutes" % usage.usage_today) - echo("This month: %s minutes" % usage.usage_this_month) - - return usage - - # output any detailed usage data - for index, usage in usage_data.items(): - echo(f"{index}, {usage}") - - return usage_data - - -@cli.command() -@click.argument("brightness", type=click.IntRange(0, 100), default=None, required=False) -@click.option("--transition", type=int, required=False) -@pass_dev_or_child -async def brightness(dev: Device, brightness: int, transition: int): - """Get or set brightness.""" - if not (light := dev.modules.get(Module.Light)) or not light.is_dimmable: - error("This device does not support brightness.") - return - - if brightness is None: - echo(f"Brightness: {light.brightness}") - return light.brightness - else: - echo(f"Setting brightness to {brightness}") - return await light.set_brightness(brightness, transition=transition) - - -@cli.command() -@click.argument( - "temperature", type=click.IntRange(2500, 9000), default=None, required=False -) -@click.option("--transition", type=int, required=False) -@pass_dev_or_child -async def temperature(dev: Device, temperature: int, transition: int): - """Get or set color temperature.""" - if not (light := dev.modules.get(Module.Light)) or not light.is_variable_color_temp: - error("Device does not support color temperature") - return - - if temperature is None: - echo(f"Color temperature: {light.color_temp}") - valid_temperature_range = light.valid_temperature_range - if valid_temperature_range != (0, 0): - echo("(min: {}, max: {})".format(*valid_temperature_range)) - else: - echo( - "Temperature range unknown, please open a github issue" - f" or a pull request for model '{dev.model}'" - ) - return light.valid_temperature_range - else: - echo(f"Setting color temperature to {temperature}") - return await light.set_color_temp(temperature, transition=transition) - - -@cli.command() -@click.argument("effect", type=click.STRING, default=None, required=False) -@click.pass_context -@pass_dev_or_child -async def effect(dev: Device, ctx, effect): - """Set an effect.""" - if not (light_effect := dev.modules.get(Module.LightEffect)): - error("Device does not support effects") - return - if effect is None: - echo( - f"Light effect: {light_effect.effect}\n" - + f"Available Effects: {light_effect.effect_list}" - ) - return light_effect.effect - - if effect not in light_effect.effect_list: - raise click.BadArgumentUsage( - f"Effect must be one of: {light_effect.effect_list}", ctx - ) - - echo(f"Setting Effect: {effect}") - return await light_effect.set_effect(effect) - - -@cli.command() -@click.argument("h", type=click.IntRange(0, 360), default=None, required=False) -@click.argument("s", type=click.IntRange(0, 100), default=None, required=False) -@click.argument("v", type=click.IntRange(0, 100), default=None, required=False) -@click.option("--transition", type=int, required=False) -@click.pass_context -@pass_dev_or_child -async def hsv(dev: Device, ctx, h, s, v, transition): - """Get or set color in HSV.""" - if not (light := dev.modules.get(Module.Light)) or not light.is_color: - error("Device does not support colors") - return - - if h is None and s is None and v is None: - echo(f"Current HSV: {light.hsv}") - return light.hsv - elif s is None or v is None: - raise click.BadArgumentUsage("Setting a color requires 3 values.", ctx) - else: - echo(f"Setting HSV: {h} {s} {v}") - return await light.set_hsv(h, s, v, transition=transition) - - -@cli.command() -@click.argument("state", type=bool, required=False) -@pass_dev_or_child -async def led(dev: Device, state): - """Get or set (Plug's) led state.""" - if not (led := dev.modules.get(Module.Led)): - error("Device does not support led.") - return - if state is not None: - echo(f"Turning led to {state}") - return await led.set_led(state) - else: - echo(f"LED state: {led.led}") - return led.led - - -@cli.group(invoke_without_command=True) -@click.pass_context -async def time(ctx: click.Context): - """Get and set time.""" - if ctx.invoked_subcommand is None: - await ctx.invoke(time_get) - - -@time.command(name="get") -@pass_dev -async def time_get(dev: Device): - """Get the device time.""" - res = dev.time - echo(f"Current time: {res}") - return res - - -@time.command(name="sync") -@pass_dev -async def time_sync(dev: Device): - """Set the device time to current time.""" - if not isinstance(dev, SmartDevice): - raise NotImplementedError("setting time currently only implemented on smart") - - if (time := dev.modules.get(Module.Time)) is None: - echo("Device does not have time module") - return - - echo("Old time: %s" % time.time) - - local_tz = datetime.now().astimezone().tzinfo - await time.set_time(datetime.now(tz=local_tz)) - - await dev.update() - echo("New time: %s" % time.time) - - -@cli.command() -@click.option("--transition", type=int, required=False) -@pass_dev_or_child -async def on(dev: Device, transition: int): - """Turn the device on.""" - echo(f"Turning on {dev.alias}") - return await dev.turn_on(transition=transition) - - -@cli.command -@click.option("--transition", type=int, required=False) -@pass_dev_or_child -async def off(dev: Device, transition: int): - """Turn the device off.""" - echo(f"Turning off {dev.alias}") - return await dev.turn_off(transition=transition) - - -@cli.command() -@click.option("--transition", type=int, required=False) -@pass_dev_or_child -async def toggle(dev: Device, transition: int): - """Toggle the device on/off.""" - if dev.is_on: - echo(f"Turning off {dev.alias}") - return await dev.turn_off(transition=transition) - - echo(f"Turning on {dev.alias}") - return await dev.turn_on(transition=transition) - - -@cli.command() -@click.option("--delay", default=1) -@pass_dev -async def reboot(plug, delay): - """Reboot the device.""" - echo("Rebooting the device..") - return await plug.reboot(delay) - - -@cli.group() -@pass_dev -async def schedule(dev): - """Scheduling commands.""" - - -@schedule.command(name="list") -@pass_dev_or_child -@click.argument("type", default="schedule") -async def _schedule_list(dev, type): - """Return the list of schedule actions for the given type.""" - sched = dev.modules[type] - for rule in sched.rules: - print(rule) - else: - error(f"No rules of type {type}") - - return sched.rules - - -@schedule.command(name="delete") -@pass_dev_or_child -@click.option("--id", type=str, required=True) -async def delete_rule(dev, id): - """Delete rule from device.""" - schedule = dev.modules["schedule"] - rule_to_delete = next(filter(lambda rule: (rule.id == id), schedule.rules), None) - if rule_to_delete: - echo(f"Deleting rule id {id}") - return await schedule.delete_rule(rule_to_delete) - else: - error(f"No rule with id {id} was found") - - -@cli.group(invoke_without_command=True) -@pass_dev_or_child -@click.pass_context -async def presets(ctx, dev): - """List and modify bulb setting presets.""" - if ctx.invoked_subcommand is None: - return await ctx.invoke(presets_list) - - -@presets.command(name="list") -@pass_dev_or_child -def presets_list(dev: Device): - """List presets.""" - if not (light_preset := dev.modules.get(Module.LightPreset)): - error("Presets not supported on device") - return - - for preset in light_preset.preset_states_list: - echo(preset) - - return light_preset.preset_states_list - - -@presets.command(name="modify") -@click.argument("index", type=int) -@click.option("--brightness", type=int) -@click.option("--hue", type=int) -@click.option("--saturation", type=int) -@click.option("--temperature", type=int) -@pass_dev_or_child -async def presets_modify(dev: Device, index, brightness, hue, saturation, temperature): - """Modify a preset.""" - for preset in dev.presets: - if preset.index == index: - break - else: - error(f"No preset found for index {index}") - return - - if brightness is not None: - preset.brightness = brightness - if hue is not None: - preset.hue = hue - if saturation is not None: - preset.saturation = saturation - if temperature is not None: - preset.color_temp = temperature - - echo(f"Going to save preset: {preset}") - - return await dev.save_preset(preset) - - -@cli.command() -@pass_dev_or_child -@click.option("--type", type=click.Choice(["soft", "hard"], case_sensitive=False)) -@click.option("--last", is_flag=True) -@click.option("--preset", type=int) -async def turn_on_behavior(dev: Device, type, last, preset): - """Modify bulb turn-on behavior.""" - if not dev.is_bulb or not isinstance(dev, IotBulb): - error("Presets only supported on iot bulbs") - return - settings = await dev.get_turn_on_behavior() - echo(f"Current turn on behavior: {settings}") - - # Return if we are not setting the value - if not type and not last and not preset: - return settings - - # If we are setting the value, the type has to be specified - if (last or preset) and type is None: - echo("To set the behavior, you need to define --type") - return - - behavior = getattr(settings, type) - - if last: - echo(f"Going to set {type} to last") - behavior.preset = None - elif preset is not None: - echo(f"Going to set {type} to preset {preset}") - behavior.preset = preset - - return await dev.set_turn_on_behavior(settings) - - -@cli.command() -@pass_dev -@click.option( - "--username", required=True, prompt=True, help="New username to set on the device" -) -@click.option( - "--password", required=True, prompt=True, help="New password to set on the device" -) -async def update_credentials(dev, username, password): - """Update device credentials for authenticated devices.""" - if not isinstance(dev, SmartDevice): - error("Credentials can only be updated on authenticated devices.") - - click.confirm("Do you really want to replace the existing credentials?", abort=True) - - return await dev.update_credentials(username, password) - - @cli.command() @pass_dev_or_child async def shell(dev: Device): @@ -1325,58 +377,36 @@ async def shell(dev: Device): loop.stop() -@cli.command(name="feature") -@click.argument("name", required=False) -@click.argument("value", required=False) -@pass_dev_or_child +@cli.command() @click.pass_context -async def feature( - ctx: click.Context, - dev: Device, - name: str, - value, -): - """Access and modify features. - - If no *name* is given, lists available features and their values. - If only *name* is given, the value of named feature is returned. - If both *name* and *value* are set, the described setting is changed. - """ - verbose = ctx.parent.params.get("verbose", False) if ctx.parent else False - - if not name: - _echo_all_features(dev.features, verbose=verbose, indent="") - - if dev.children: - for child_dev in dev.children: - _echo_all_features( - child_dev.features, - verbose=verbose, - title_prefix=f"Child {child_dev.alias}", - indent="\t", - ) - - return - - if name not in dev.features: - error(f"No feature by name '{name}'") - return - - feat = dev.features[name] - - if value is None: - unit = f" {feat.unit}" if feat.unit else "" - echo(f"{feat.name} ({name}): {feat.value}{unit}") - return feat.value - - value = ast.literal_eval(value) - echo(f"Changing {name} from {feat.value} to {value}") - response = await dev.features[name].set_value(value) - await dev.update() - echo(f"New state: {feat.value}") - - return response +@click.argument("module") +@click.argument("command") +@click.argument("parameters", default=None, required=False) +async def raw_command(ctx, module, command, parameters): + """Run a raw command on the device.""" + logging.warning("Deprecated, use 'kasa command --module %s %s'", module, command) + return await ctx.forward(cmd_command) -if __name__ == "__main__": - cli() +@cli.command(name="command") +@click.option("--module", required=False, help="Module for IOT protocol.") +@click.argument("command") +@click.argument("parameters", default=None, required=False) +@pass_dev_or_child +async def cmd_command(dev: Device, module, command, parameters): + """Run a raw command on the device.""" + if parameters is not None: + parameters = ast.literal_eval(parameters) + + from kasa import KasaException + from kasa.iot import IotDevice + from kasa.smart import SmartDevice + + if isinstance(dev, IotDevice): + res = await dev._query_helper(module, command, parameters) + elif isinstance(dev, SmartDevice): + res = await dev._query_helper(command, parameters) + else: + raise KasaException("Unexpected device type %s.", dev) + echo(json.dumps(res)) + return res diff --git a/kasa/cli/schedule.py b/kasa/cli/schedule.py new file mode 100644 index 00000000..8deda315 --- /dev/null +++ b/kasa/cli/schedule.py @@ -0,0 +1,46 @@ +"""Module for cli schedule commands..""" + +from __future__ import annotations + +import asyncclick as click + +from .common import ( + echo, + error, + pass_dev, + pass_dev_or_child, +) + + +@click.group() +@pass_dev +async def schedule(dev): + """Scheduling commands.""" + + +@schedule.command(name="list") +@pass_dev_or_child +@click.argument("type", default="schedule") +async def _schedule_list(dev, type): + """Return the list of schedule actions for the given type.""" + sched = dev.modules[type] + for rule in sched.rules: + print(rule) + else: + error(f"No rules of type {type}") + + return sched.rules + + +@schedule.command(name="delete") +@pass_dev_or_child +@click.option("--id", type=str, required=True) +async def delete_rule(dev, id): + """Delete rule from device.""" + schedule = dev.modules["schedule"] + rule_to_delete = next(filter(lambda rule: (rule.id == id), schedule.rules), None) + if rule_to_delete: + echo(f"Deleting rule id {id}") + return await schedule.delete_rule(rule_to_delete) + else: + error(f"No rule with id {id} was found") diff --git a/kasa/cli/time.py b/kasa/cli/time.py new file mode 100644 index 00000000..c6681222 --- /dev/null +++ b/kasa/cli/time.py @@ -0,0 +1,55 @@ +"""Module for cli time commands..""" + +from __future__ import annotations + +from datetime import datetime + +import asyncclick as click + +from kasa import ( + Device, + Module, +) +from kasa.smart import SmartDevice + +from .common import ( + echo, + pass_dev, +) + + +@click.group(invoke_without_command=True) +@click.pass_context +async def time(ctx: click.Context): + """Get and set time.""" + if ctx.invoked_subcommand is None: + await ctx.invoke(time_get) + + +@time.command(name="get") +@pass_dev +async def time_get(dev: Device): + """Get the device time.""" + res = dev.time + echo(f"Current time: {res}") + return res + + +@time.command(name="sync") +@pass_dev +async def time_sync(dev: Device): + """Set the device time to current time.""" + if not isinstance(dev, SmartDevice): + raise NotImplementedError("setting time currently only implemented on smart") + + if (time := dev.modules.get(Module.Time)) is None: + echo("Device does not have time module") + return + + echo("Old time: %s" % time.time) + + local_tz = datetime.now().astimezone().tzinfo + await time.set_time(datetime.now(tz=local_tz)) + + await dev.update() + echo("New time: %s" % time.time) diff --git a/kasa/cli/usage.py b/kasa/cli/usage.py new file mode 100644 index 00000000..1a336c74 --- /dev/null +++ b/kasa/cli/usage.py @@ -0,0 +1,134 @@ +"""Module for cli usage commands..""" + +from __future__ import annotations + +import logging +from typing import cast + +import asyncclick as click + +from kasa import ( + Device, +) +from kasa.iot import ( + IotDevice, +) +from kasa.iot.iotstrip import IotStripPlug +from kasa.iot.modules import Usage + +from .common import ( + echo, + error, + pass_dev_or_child, +) + + +@click.command() +@click.option("--index", type=int, required=False) +@click.option("--name", type=str, required=False) +@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False) +@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False) +@click.option("--erase", is_flag=True) +@click.pass_context +async def emeter(ctx: click.Context, index, name, year, month, erase): + """Query emeter for historical consumption.""" + logging.warning("Deprecated, use 'kasa energy'") + return await ctx.invoke( + energy, child_index=index, child=name, year=year, month=month, erase=erase + ) + + +@click.command() +@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False) +@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False) +@click.option("--erase", is_flag=True) +@pass_dev_or_child +async def energy(dev: Device, year, month, erase): + """Query energy module for historical consumption. + + Daily and monthly data provided in CSV format. + """ + echo("[bold]== Emeter ==[/bold]") + if not dev.has_emeter: + error("Device has no emeter") + return + + if (year or month or erase) and not isinstance(dev, IotDevice): + error("Device has no historical statistics") + return + else: + dev = cast(IotDevice, dev) + + if erase: + echo("Erasing emeter statistics..") + return await dev.erase_emeter_stats() + + if year: + echo(f"== For year {year.year} ==") + echo("Month, usage (kWh)") + usage_data = await dev.get_emeter_monthly(year=year.year) + elif month: + echo(f"== For month {month.month} of {month.year} ==") + echo("Day, usage (kWh)") + usage_data = await dev.get_emeter_daily(year=month.year, month=month.month) + else: + # Call with no argument outputs summary data and returns + if isinstance(dev, IotStripPlug): + emeter_status = await dev.get_emeter_realtime() + else: + emeter_status = dev.emeter_realtime + + echo("Current: %s A" % emeter_status["current"]) + echo("Voltage: %s V" % emeter_status["voltage"]) + echo("Power: %s W" % emeter_status["power"]) + echo("Total consumption: %s kWh" % emeter_status["total"]) + + echo("Today: %s kWh" % dev.emeter_today) + echo("This month: %s kWh" % dev.emeter_this_month) + + return emeter_status + + # output any detailed usage data + for index, usage in usage_data.items(): + echo(f"{index}, {usage}") + + return usage_data + + +@click.command() +@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False) +@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False) +@click.option("--erase", is_flag=True) +@pass_dev_or_child +async def usage(dev: Device, year, month, erase): + """Query usage for historical consumption. + + Daily and monthly data provided in CSV format. + """ + echo("[bold]== Usage ==[/bold]") + usage = cast(Usage, dev.modules["usage"]) + + if erase: + echo("Erasing usage statistics..") + return await usage.erase_stats() + + if year: + echo(f"== For year {year.year} ==") + echo("Month, usage (minutes)") + usage_data = await usage.get_monthstat(year=year.year) + elif month: + echo(f"== For month {month.month} of {month.year} ==") + echo("Day, usage (minutes)") + usage_data = await usage.get_daystat(year=month.year, month=month.month) + else: + # Call with no argument outputs summary data and returns + echo("Today: %s minutes" % usage.usage_today) + echo("This month: %s minutes" % usage.usage_this_month) + + return usage + + # output any detailed usage data + for index, usage in usage_data.items(): + echo(f"{index}, {usage}") + + return usage_data diff --git a/kasa/cli/wifi.py b/kasa/cli/wifi.py new file mode 100644 index 00000000..07fb5f20 --- /dev/null +++ b/kasa/cli/wifi.py @@ -0,0 +1,50 @@ +"""Module for cli wifi commands.""" + +from __future__ import annotations + +import asyncclick as click + +from kasa import ( + Device, +) + +from .common import ( + echo, + pass_dev, +) + + +@click.group() +@pass_dev +def wifi(dev): + """Commands to control wifi settings.""" + + +@wifi.command() +@pass_dev +async def scan(dev): + """Scan for available wifi networks.""" + echo("Scanning for wifi networks, wait a second..") + devs = await dev.wifi_scan() + echo(f"Found {len(devs)} wifi networks!") + for dev in devs: + echo(f"\t {dev}") + + return devs + + +@wifi.command() +@click.argument("ssid") +@click.option("--keytype", prompt=True) +@click.option("--password", prompt=True, hide_input=True) +@pass_dev +async def join(dev: Device, ssid: str, password: str, keytype: str): + """Join the given wifi network.""" + echo(f"Asking the device to connect to {ssid}..") + res = await dev.wifi_join(ssid, password, keytype=keytype) + echo( + f"Response: {res} - if the device is not able to join the network, " + f"it will revert back to its previous state." + ) + + return res diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index e6b96cd7..e55f4d01 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -17,29 +17,28 @@ from kasa import ( Module, UnsupportedDeviceError, ) -from kasa.cli.main import ( - TYPE_TO_CLASS, +from kasa.cli.device import ( alias, - brightness, - cli, - cmd_command, - effect, - emeter, - energy, - hsv, led, - raw_command, reboot, state, sysinfo, - temperature, - time, toggle, update_credentials, - wifi, ) +from kasa.cli.light import ( + brightness, + effect, + hsv, + temperature, +) +from kasa.cli.main import TYPES, _legacy_type_to_class, cli, cmd_command, raw_command +from kasa.cli.time import time +from kasa.cli.usage import emeter, energy +from kasa.cli.wifi import wifi from kasa.discover import Discover, DiscoveryResult from kasa.iot import IotDevice +from kasa.smart import SmartDevice from .conftest import ( device_smart, @@ -59,6 +58,12 @@ def runner(): return runner +async def test_help(runner): + """Test that all the lazy modules are correctly names.""" + res = await runner.invoke(cli, ["--help"]) + assert res.exit_code == 0, "--help failed, check lazy module names" + + @pytest.mark.parametrize( ("device_family", "encrypt_type"), [ @@ -500,7 +505,7 @@ async def test_credentials(discovery_mock, mocker, runner): f"Username:{dev.credentials.username} Password:{dev.credentials.password}" ) - mocker.patch("kasa.cli.main.state", new=_state) + mocker.patch("kasa.cli.device.state", new=_state) dr = DiscoveryResult(**discovery_mock.discovery_data["result"]) res = await runner.invoke( @@ -735,7 +740,7 @@ async def test_host_auth_failed(discovery_mock, mocker, runner): assert isinstance(res.exception, AuthenticationError) -@pytest.mark.parametrize("device_type", list(TYPE_TO_CLASS)) +@pytest.mark.parametrize("device_type", TYPES) async def test_type_param(device_type, mocker, runner): """Test for handling only one of username or password supplied.""" result_device = FileNotFoundError @@ -746,8 +751,11 @@ async def test_type_param(device_type, mocker, runner): nonlocal result_device result_device = dev - mocker.patch("kasa.cli.main.state", new=_state) - expected_type = TYPE_TO_CLASS[device_type] + mocker.patch("kasa.cli.device.state", new=_state) + if device_type == "smart": + expected_type = SmartDevice + else: + expected_type = _legacy_type_to_class(device_type) mocker.patch.object(expected_type, "update") res = await runner.invoke( cli, diff --git a/pyproject.toml b/pyproject.toml index 91317f48..c5c87072 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ include = [ "Documentation" = "https://python-kasa.readthedocs.io" [tool.poetry.scripts] -kasa = "kasa.cli:__main__" +kasa = "kasa.cli.__main__:cli" [tool.poetry.dependencies] python = "^3.9"