Support child devices in all applicable cli commands (#1020)

Adds a new decorator that adds child options to a command and gets the
child device if the options are set.

- Single definition of options and error handling
- Adds options automatically to command
- Backwards compatible with `--index` and `--name`
- `--child` allows for id and alias for ease of use
- Omitting a value for `--child` gives an interactive prompt

Implements private `_update` to allow the CLI to patch a child `update`
method to call the parent device `update`.

Example help output:
```
$ kasa brightness --help
Usage: kasa brightness [OPTIONS] [BRIGHTNESS]

  Get or set brightness.

Options:
  --transition INTEGER
  --child, --name TEXT            Child ID or alias for controlling sub-
                                  devices. If no value provided will show an
                                  interactive prompt allowing you to select a
                                  child.
  --child-index, --index INTEGER  Child index controlling sub-devices
  --help                          Show this message and exit.
```

Fixes #769
This commit is contained in:
Steven B 2024-07-02 14:11:19 +01:00 committed by GitHub
parent b8a87f1c57
commit 9cffbe9e48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 333 additions and 134 deletions

View File

@ -8,11 +8,11 @@ import json
import logging import logging
import re import re
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager, contextmanager
from datetime import datetime from datetime import datetime
from functools import singledispatch, wraps from functools import singledispatch, update_wrapper, wraps
from pprint import pformat as pf from pprint import pformat as pf
from typing import Any, cast from typing import Any, Final, cast
import asyncclick as click import asyncclick as click
from pydantic.v1 import ValidationError from pydantic.v1 import ValidationError
@ -41,6 +41,7 @@ from kasa.iot import (
IotStrip, IotStrip,
IotWallSwitch, IotWallSwitch,
) )
from kasa.iot.iotstrip import IotStripPlug
from kasa.iot.modules import Usage from kasa.iot.modules import Usage
from kasa.smart import SmartDevice from kasa.smart import SmartDevice
@ -77,6 +78,9 @@ def error(msg: str):
sys.exit(1) sys.exit(1)
# Value for optional options if passed without a value
OPTIONAL_VALUE_FLAG: Final = "_FLAG_"
TYPE_TO_CLASS = { TYPE_TO_CLASS = {
"plug": IotPlug, "plug": IotPlug,
"switch": IotWallSwitch, "switch": IotWallSwitch,
@ -169,6 +173,112 @@ def json_formatter_cb(result, **kwargs):
print(json_content) print(json_content)
def pass_dev_or_child(wrapped_function):
"""Pass the device or child to the click command based on the child options."""
child_help = (
"Child ID or alias for controlling sub-devices. "
"If no value provided will show an interactive prompt allowing you to "
"select a child."
)
child_index_help = "Child index controlling sub-devices"
@contextmanager
def patched_device_update(parent: Device, child: Device):
try:
orig_update = child.update
# patch child update method. Can be removed once update can be called
# directly on child devices
child.update = parent.update # type: ignore[method-assign]
yield child
finally:
child.update = orig_update # type: ignore[method-assign]
@click.pass_obj
@click.pass_context
@click.option(
"--child",
"--name",
is_flag=False,
flag_value=OPTIONAL_VALUE_FLAG,
default=None,
required=False,
type=click.STRING,
help=child_help,
)
@click.option(
"--child-index",
"--index",
required=False,
default=None,
type=click.INT,
help=child_index_help,
)
async def wrapper(ctx: click.Context, dev, *args, child, child_index, **kwargs):
if child := await _get_child_device(dev, child, child_index, ctx.info_name):
ctx.obj = ctx.with_resource(patched_device_update(dev, child))
dev = child
return await ctx.invoke(wrapped_function, dev, *args, **kwargs)
# Update wrapper function to look like wrapped function
return update_wrapper(wrapper, wrapped_function)
async def _get_child_device(
device: Device, child_option, child_index_option, info_command
) -> Device | None:
def _list_children():
return "\n".join(
[
f"{idx}: {child.device_id} ({child.alias})"
for idx, child in enumerate(device.children)
]
)
if child_option is None and child_index_option is None:
return None
if info_command in SKIP_UPDATE_COMMANDS:
# The device hasn't had update called (e.g. for cmd_command)
# The way child devices are accessed requires a ChildDevice to
# wrap the communications. Doing this properly would require creating
# a common interfaces for both IOT and SMART child devices.
# As a stop-gap solution, we perform an update instead.
await device.update()
if not device.children:
error(f"Device: {device.host} does not have children")
if child_option is not None and child_index_option is not None:
raise click.BadOptionUsage(
"child", "Use either --child or --child-index, not both."
)
if child_option is not None:
if child_option is OPTIONAL_VALUE_FLAG:
msg = _list_children()
child_index_option = click.prompt(
f"\n{msg}\nEnter the index number of the child device",
type=click.IntRange(0, len(device.children) - 1),
)
elif child := device.get_child_device(child_option):
echo(f"Targeting child device {child.alias}")
return child
else:
error(
"No child device found with device_id or name: "
f"{child_option} children are:\n{_list_children()}"
)
if child_index_option + 1 > len(device.children) or child_index_option < 0:
error(
f"Invalid index {child_index_option}, "
f"device has {len(device.children)} children"
)
child_by_index = device.children[child_index_option]
echo(f"Targeting child device {child_by_index.alias}")
return child_by_index
@click.group( @click.group(
invoke_without_command=True, invoke_without_command=True,
cls=CatchAllExceptions(click.Group), cls=CatchAllExceptions(click.Group),
@ -232,6 +342,7 @@ def json_formatter_cb(result, **kwargs):
help="Output raw device response as JSON.", help="Output raw device response as JSON.",
) )
@click.option( @click.option(
"-e",
"--encrypt-type", "--encrypt-type",
envvar="KASA_ENCRYPT_TYPE", envvar="KASA_ENCRYPT_TYPE",
default=None, default=None,
@ -240,13 +351,14 @@ def json_formatter_cb(result, **kwargs):
@click.option( @click.option(
"--device-family", "--device-family",
envvar="KASA_DEVICE_FAMILY", envvar="KASA_DEVICE_FAMILY",
default=None, default="SMART.TAPOPLUG",
type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False), type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False),
) )
@click.option( @click.option(
"-lv",
"--login-version", "--login-version",
envvar="KASA_LOGIN_VERSION", envvar="KASA_LOGIN_VERSION",
default=None, default=2,
type=int, type=int,
) )
@click.option( @click.option(
@ -379,7 +491,8 @@ async def cli(
device_updated = False device_updated = False
if type is not None: if type is not None:
dev = TYPE_TO_CLASS[type](host) config = DeviceConfig(host=host, port_override=port, timeout=timeout)
dev = TYPE_TO_CLASS[type](host, config=config)
elif device_family and encrypt_type: elif device_family and encrypt_type:
ctype = DeviceConnectionParameters( ctype = DeviceConnectionParameters(
DeviceFamily(device_family), DeviceFamily(device_family),
@ -397,12 +510,6 @@ async def cli(
dev = await Device.connect(config=config) dev = await Device.connect(config=config)
device_updated = True device_updated = True
else: else:
if device_family or encrypt_type:
echo(
"--device-family and --encrypt-type options must both be "
"provided or they are ignored\n"
f"discovering for {discovery_timeout} seconds.."
)
dev = await Discover.discover_single( dev = await Discover.discover_single(
host, host,
port=port, port=port,
@ -587,7 +694,7 @@ async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attem
@cli.command() @cli.command()
@pass_dev @pass_dev_or_child
async def sysinfo(dev): async def sysinfo(dev):
"""Print out full system information.""" """Print out full system information."""
echo("== System info ==") echo("== System info ==")
@ -624,6 +731,7 @@ def _echo_all_features(features, *, verbose=False, title_prefix=None, indent="")
"""Print out all features by category.""" """Print out all features by category."""
if title_prefix is not None: if title_prefix is not None:
echo(f"[bold]\n{indent}== {title_prefix} ==[/bold]") echo(f"[bold]\n{indent}== {title_prefix} ==[/bold]")
echo()
_echo_features( _echo_features(
features, features,
title="== Primary features ==", title="== Primary features ==",
@ -658,7 +766,7 @@ def _echo_all_features(features, *, verbose=False, title_prefix=None, indent="")
@cli.command() @cli.command()
@pass_dev @pass_dev_or_child
@click.pass_context @click.pass_context
async def state(ctx, dev: Device): async def state(ctx, dev: Device):
"""Print out device state and versions.""" """Print out device state and versions."""
@ -676,11 +784,16 @@ async def state(ctx, dev: Device):
if verbose: if verbose:
echo(f"Location: {dev.location}") echo(f"Location: {dev.location}")
_echo_all_features(dev.features, verbose=verbose)
echo() echo()
_echo_all_features(dev.features, verbose=verbose)
if verbose:
echo("\n[bold]== Modules ==[/bold]")
for module in dev.modules.values():
echo(f"[green]+ {module}[/green]")
if dev.children: if dev.children:
echo("[bold]== Children ==[/bold]") echo("\n[bold]== Children ==[/bold]")
for child in dev.children: for child in dev.children:
_echo_all_features( _echo_all_features(
child.features, child.features,
@ -688,14 +801,13 @@ async def state(ctx, dev: Device):
verbose=verbose, verbose=verbose,
indent="\t", indent="\t",
) )
if verbose:
echo(f"\n\t[bold]== Child {child.alias} Modules ==[/bold]")
for module in child.modules.values():
echo(f"\t[green]+ {module}[/green]")
echo() echo()
if verbose: if verbose:
echo("\n\t[bold]== Modules ==[/bold]")
for module in dev.modules.values():
echo(f"\t[green]+ {module}[/green]")
echo("\n\t[bold]== Protocol information ==[/bold]") echo("\n\t[bold]== Protocol information ==[/bold]")
echo(f"\tCredentials hash: {dev.credentials_hash}") echo(f"\tCredentials hash: {dev.credentials_hash}")
echo() echo()
@ -705,24 +817,19 @@ async def state(ctx, dev: Device):
@cli.command() @cli.command()
@pass_dev
@click.argument("new_alias", required=False, default=None) @click.argument("new_alias", required=False, default=None)
@click.option("--index", type=int) @pass_dev_or_child
async def alias(dev, new_alias, index): async def alias(dev, new_alias):
"""Get or set the device (or plug) alias.""" """Get or set the device (or plug) alias."""
if index is not None:
if not dev.is_strip:
echo("Index can only used for power strips!")
return
dev = dev.get_plug_by_index(index)
if new_alias is not None: if new_alias is not None:
echo(f"Setting alias to {new_alias}") echo(f"Setting alias to {new_alias}")
res = await dev.set_alias(new_alias) res = await dev.set_alias(new_alias)
await dev.update()
echo(f"Alias set to: {dev.alias}")
return res return res
echo(f"Alias: {dev.alias}") echo(f"Alias: {dev.alias}")
if dev.is_strip: if dev.children:
for plug in dev.children: for plug in dev.children:
echo(f" * {plug.alias}") echo(f" * {plug.alias}")
@ -730,36 +837,26 @@ async def alias(dev, new_alias, index):
@cli.command() @cli.command()
@pass_dev
@click.pass_context @click.pass_context
@click.argument("module") @click.argument("module")
@click.argument("command") @click.argument("command")
@click.argument("parameters", default=None, required=False) @click.argument("parameters", default=None, required=False)
async def raw_command(ctx, dev: Device, module, command, parameters): async def raw_command(ctx, module, command, parameters):
"""Run a raw command on the device.""" """Run a raw command on the device."""
logging.warning("Deprecated, use 'kasa command --module %s %s'", module, command) logging.warning("Deprecated, use 'kasa command --module %s %s'", module, command)
return await ctx.forward(cmd_command) return await ctx.forward(cmd_command)
@cli.command(name="command") @cli.command(name="command")
@pass_dev
@click.option("--module", required=False, help="Module for IOT protocol.") @click.option("--module", required=False, help="Module for IOT protocol.")
@click.option("--child", required=False, help="Child ID for controlling sub-devices")
@click.argument("command") @click.argument("command")
@click.argument("parameters", default=None, required=False) @click.argument("parameters", default=None, required=False)
async def cmd_command(dev: Device, module, child, command, parameters): @pass_dev_or_child
async def cmd_command(dev: Device, module, command, parameters):
"""Run a raw command on the device.""" """Run a raw command on the device."""
if parameters is not None: if parameters is not None:
parameters = ast.literal_eval(parameters) parameters = ast.literal_eval(parameters)
if child:
# The way child devices are accessed requires a ChildDevice to
# wrap the communications. Doing this properly would require creating
# a common interfaces for both IOT and SMART child devices.
# As a stop-gap solution, we perform an update instead.
await dev.update()
dev = dev.get_child_device(child)
if isinstance(dev, IotDevice): if isinstance(dev, IotDevice):
res = await dev._query_helper(module, command, parameters) res = await dev._query_helper(module, command, parameters)
elif isinstance(dev, SmartDevice): elif isinstance(dev, SmartDevice):
@ -771,27 +868,30 @@ async def cmd_command(dev: Device, module, child, command, parameters):
@cli.command() @cli.command()
@pass_dev
@click.option("--index", type=int, required=False) @click.option("--index", type=int, required=False)
@click.option("--name", type=str, required=False) @click.option("--name", type=str, required=False)
@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False) @click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False)
@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False) @click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False)
@click.option("--erase", is_flag=True) @click.option("--erase", is_flag=True)
async def emeter(dev: Device, index: int, name: str, year, month, erase): @click.pass_context
"""Query emeter for historical consumption. async def emeter(ctx: click.Context, index, name, year, month, erase):
"""Query emeter for historical consumption."""
logging.warning("Deprecated, use 'kasa energy'")
return await ctx.invoke(
energy, child_index=index, child=name, year=year, month=month, erase=erase
)
@cli.command()
@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False)
@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False)
@click.option("--erase", is_flag=True)
@pass_dev_or_child
async def energy(dev: Device, year, month, erase):
"""Query energy module for historical consumption.
Daily and monthly data provided in CSV format. Daily and monthly data provided in CSV format.
""" """
if index is not None or name is not None:
if not dev.is_strip:
error("Index and name are only for power strips!")
return
if index is not None:
dev = dev.get_plug_by_index(index)
elif name:
dev = dev.get_plug_by_name(name)
echo("[bold]== Emeter ==[/bold]") echo("[bold]== Emeter ==[/bold]")
if not dev.has_emeter: if not dev.has_emeter:
error("Device has no emeter") error("Device has no emeter")
@ -817,7 +917,7 @@ async def emeter(dev: Device, index: int, name: str, year, month, erase):
usage_data = await dev.get_emeter_daily(year=month.year, month=month.month) usage_data = await dev.get_emeter_daily(year=month.year, month=month.month)
else: else:
# Call with no argument outputs summary data and returns # Call with no argument outputs summary data and returns
if index is not None or name is not None: if isinstance(dev, IotStripPlug):
emeter_status = await dev.get_emeter_realtime() emeter_status = await dev.get_emeter_realtime()
else: else:
emeter_status = dev.emeter_realtime emeter_status = dev.emeter_realtime
@ -840,10 +940,10 @@ async def emeter(dev: Device, index: int, name: str, year, month, erase):
@cli.command() @cli.command()
@pass_dev
@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False) @click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False)
@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False) @click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False)
@click.option("--erase", is_flag=True) @click.option("--erase", is_flag=True)
@pass_dev_or_child
async def usage(dev: Device, year, month, erase): async def usage(dev: Device, year, month, erase):
"""Query usage for historical consumption. """Query usage for historical consumption.
@ -881,7 +981,7 @@ async def usage(dev: Device, year, month, erase):
@cli.command() @cli.command()
@click.argument("brightness", type=click.IntRange(0, 100), default=None, required=False) @click.argument("brightness", type=click.IntRange(0, 100), default=None, required=False)
@click.option("--transition", type=int, required=False) @click.option("--transition", type=int, required=False)
@pass_dev @pass_dev_or_child
async def brightness(dev: Device, brightness: int, transition: int): async def brightness(dev: Device, brightness: int, transition: int):
"""Get or set brightness.""" """Get or set brightness."""
if not (light := dev.modules.get(Module.Light)) or not light.is_dimmable: if not (light := dev.modules.get(Module.Light)) or not light.is_dimmable:
@ -901,7 +1001,7 @@ async def brightness(dev: Device, brightness: int, transition: int):
"temperature", type=click.IntRange(2500, 9000), default=None, required=False "temperature", type=click.IntRange(2500, 9000), default=None, required=False
) )
@click.option("--transition", type=int, required=False) @click.option("--transition", type=int, required=False)
@pass_dev @pass_dev_or_child
async def temperature(dev: Device, temperature: int, transition: int): async def temperature(dev: Device, temperature: int, transition: int):
"""Get or set color temperature.""" """Get or set color temperature."""
if not (light := dev.modules.get(Module.Light)) or not light.is_variable_color_temp: if not (light := dev.modules.get(Module.Light)) or not light.is_variable_color_temp:
@ -927,7 +1027,7 @@ async def temperature(dev: Device, temperature: int, transition: int):
@cli.command() @cli.command()
@click.argument("effect", type=click.STRING, default=None, required=False) @click.argument("effect", type=click.STRING, default=None, required=False)
@click.pass_context @click.pass_context
@pass_dev @pass_dev_or_child
async def effect(dev: Device, ctx, effect): async def effect(dev: Device, ctx, effect):
"""Set an effect.""" """Set an effect."""
if not (light_effect := dev.modules.get(Module.LightEffect)): if not (light_effect := dev.modules.get(Module.LightEffect)):
@ -955,7 +1055,7 @@ async def effect(dev: Device, ctx, effect):
@click.argument("v", type=click.IntRange(0, 100), default=None, required=False) @click.argument("v", type=click.IntRange(0, 100), default=None, required=False)
@click.option("--transition", type=int, required=False) @click.option("--transition", type=int, required=False)
@click.pass_context @click.pass_context
@pass_dev @pass_dev_or_child
async def hsv(dev: Device, ctx, h, s, v, transition): async def hsv(dev: Device, ctx, h, s, v, transition):
"""Get or set color in HSV.""" """Get or set color in HSV."""
if not (light := dev.modules.get(Module.Light)) or not light.is_color: if not (light := dev.modules.get(Module.Light)) or not light.is_color:
@ -974,7 +1074,7 @@ async def hsv(dev: Device, ctx, h, s, v, transition):
@cli.command() @cli.command()
@click.argument("state", type=bool, required=False) @click.argument("state", type=bool, required=False)
@pass_dev @pass_dev_or_child
async def led(dev: Device, state): async def led(dev: Device, state):
"""Get or set (Plug's) led state.""" """Get or set (Plug's) led state."""
if not (led := dev.modules.get(Module.Led)): if not (led := dev.modules.get(Module.Led)):
@ -1026,64 +1126,28 @@ async def time_sync(dev: Device):
@cli.command() @cli.command()
@click.option("--index", type=int, required=False)
@click.option("--name", type=str, required=False)
@click.option("--transition", type=int, required=False) @click.option("--transition", type=int, required=False)
@pass_dev @pass_dev_or_child
async def on(dev: Device, index: int, name: str, transition: int): async def on(dev: Device, transition: int):
"""Turn the device on.""" """Turn the device on."""
if index is not None or name is not None:
if not dev.children:
error("Index and name are only for devices with children.")
return
if index is not None:
dev = dev.get_plug_by_index(index)
elif name:
dev = dev.get_plug_by_name(name)
echo(f"Turning on {dev.alias}") echo(f"Turning on {dev.alias}")
return await dev.turn_on(transition=transition) return await dev.turn_on(transition=transition)
@cli.command() @cli.command
@click.option("--index", type=int, required=False)
@click.option("--name", type=str, required=False)
@click.option("--transition", type=int, required=False) @click.option("--transition", type=int, required=False)
@pass_dev @pass_dev_or_child
async def off(dev: Device, index: int, name: str, transition: int): async def off(dev: Device, transition: int):
"""Turn the device off.""" """Turn the device off."""
if index is not None or name is not None:
if not dev.children:
error("Index and name are only for devices with children.")
return
if index is not None:
dev = dev.get_plug_by_index(index)
elif name:
dev = dev.get_plug_by_name(name)
echo(f"Turning off {dev.alias}") echo(f"Turning off {dev.alias}")
return await dev.turn_off(transition=transition) return await dev.turn_off(transition=transition)
@cli.command() @cli.command()
@click.option("--index", type=int, required=False)
@click.option("--name", type=str, required=False)
@click.option("--transition", type=int, required=False) @click.option("--transition", type=int, required=False)
@pass_dev @pass_dev_or_child
async def toggle(dev: Device, index: int, name: str, transition: int): async def toggle(dev: Device, transition: int):
"""Toggle the device on/off.""" """Toggle the device on/off."""
if index is not None or name is not None:
if not dev.children:
error("Index and name are only for devices with children.")
return
if index is not None:
dev = dev.get_plug_by_index(index)
elif name:
dev = dev.get_plug_by_name(name)
if dev.is_on: if dev.is_on:
echo(f"Turning off {dev.alias}") echo(f"Turning off {dev.alias}")
return await dev.turn_off(transition=transition) return await dev.turn_off(transition=transition)
@ -1108,9 +1172,9 @@ async def schedule(dev):
@schedule.command(name="list") @schedule.command(name="list")
@pass_dev @pass_dev_or_child
@click.argument("type", default="schedule") @click.argument("type", default="schedule")
def _schedule_list(dev, type): async def _schedule_list(dev, type):
"""Return the list of schedule actions for the given type.""" """Return the list of schedule actions for the given type."""
sched = dev.modules[type] sched = dev.modules[type]
for rule in sched.rules: for rule in sched.rules:
@ -1122,7 +1186,7 @@ def _schedule_list(dev, type):
@schedule.command(name="delete") @schedule.command(name="delete")
@pass_dev @pass_dev_or_child
@click.option("--id", type=str, required=True) @click.option("--id", type=str, required=True)
async def delete_rule(dev, id): async def delete_rule(dev, id):
"""Delete rule from device.""" """Delete rule from device."""
@ -1136,25 +1200,26 @@ async def delete_rule(dev, id):
@cli.group(invoke_without_command=True) @cli.group(invoke_without_command=True)
@pass_dev_or_child
@click.pass_context @click.pass_context
async def presets(ctx): async def presets(ctx, dev):
"""List and modify bulb setting presets.""" """List and modify bulb setting presets."""
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
return await ctx.invoke(presets_list) return await ctx.invoke(presets_list)
@presets.command(name="list") @presets.command(name="list")
@pass_dev @pass_dev_or_child
def presets_list(dev: Device): def presets_list(dev: Device):
"""List presets.""" """List presets."""
if not dev.is_bulb or not isinstance(dev, IotBulb): if not (light_preset := dev.modules.get(Module.LightPreset)):
error("Presets only supported on iot bulbs") error("Presets not supported on device")
return return
for preset in dev.presets: for preset in light_preset.preset_states_list:
echo(preset) echo(preset)
return dev.presets return light_preset.preset_states_list
@presets.command(name="modify") @presets.command(name="modify")
@ -1163,7 +1228,7 @@ def presets_list(dev: Device):
@click.option("--hue", type=int) @click.option("--hue", type=int)
@click.option("--saturation", type=int) @click.option("--saturation", type=int)
@click.option("--temperature", type=int) @click.option("--temperature", type=int)
@pass_dev @pass_dev_or_child
async def presets_modify(dev: Device, index, brightness, hue, saturation, temperature): async def presets_modify(dev: Device, index, brightness, hue, saturation, temperature):
"""Modify a preset.""" """Modify a preset."""
for preset in dev.presets: for preset in dev.presets:
@ -1188,7 +1253,7 @@ async def presets_modify(dev: Device, index, brightness, hue, saturation, temper
@cli.command() @cli.command()
@pass_dev @pass_dev_or_child
@click.option("--type", type=click.Choice(["soft", "hard"], case_sensitive=False)) @click.option("--type", type=click.Choice(["soft", "hard"], case_sensitive=False))
@click.option("--last", is_flag=True) @click.option("--last", is_flag=True)
@click.option("--preset", type=int) @click.option("--preset", type=int)
@ -1240,7 +1305,7 @@ async def update_credentials(dev, username, password):
@cli.command() @cli.command()
@pass_dev @pass_dev_or_child
async def shell(dev: Device): async def shell(dev: Device):
"""Open interactive shell.""" """Open interactive shell."""
echo("Opening shell for %s" % dev) echo("Opening shell for %s" % dev)
@ -1263,10 +1328,14 @@ async def shell(dev: Device):
@cli.command(name="feature") @cli.command(name="feature")
@click.argument("name", required=False) @click.argument("name", required=False)
@click.argument("value", required=False) @click.argument("value", required=False)
@click.option("--child", required=False) @pass_dev_or_child
@pass_dev
@click.pass_context @click.pass_context
async def feature(ctx: click.Context, dev: Device, child: str, name: str, value): async def feature(
ctx: click.Context,
dev: Device,
name: str,
value,
):
"""Access and modify features. """Access and modify features.
If no *name* is given, lists available features and their values. If no *name* is given, lists available features and their values.
@ -1275,9 +1344,6 @@ async def feature(ctx: click.Context, dev: Device, child: str, name: str, value)
""" """
verbose = ctx.parent.params.get("verbose", False) if ctx.parent else False verbose = ctx.parent.params.get("verbose", False) if ctx.parent else False
if child is not None:
echo(f"Targeting child device {child}")
dev = dev.get_child_device(child)
if not name: if not name:
_echo_all_features(dev.features, verbose=verbose, indent="") _echo_all_features(dev.features, verbose=verbose, indent="")

View File

@ -338,9 +338,15 @@ class Device(ABC):
"""Returns the child devices.""" """Returns the child devices."""
return list(self._children.values()) return list(self._children.values())
def get_child_device(self, id_: str) -> Device: def get_child_device(self, name_or_id: str) -> Device | None:
"""Return child device by its ID.""" """Return child device by its device_id or alias."""
return self._children[id_] if name_or_id in self._children:
return self._children[name_or_id]
name_lower = name_or_id.lower()
for child in self.children:
if child.alias and child.alias.lower() == name_lower:
return child
return None
@property @property
@abstractmethod @abstractmethod

View File

@ -145,7 +145,7 @@ class IotStrip(IotDevice):
if update_children: if update_children:
for plug in self.children: for plug in self.children:
await plug.update() await plug._update()
if not self.features: if not self.features:
await self._initialize_features() await self._initialize_features()
@ -362,6 +362,14 @@ class IotStripPlug(IotPlug):
Needed for properties that are decorated with `requires_update`. Needed for properties that are decorated with `requires_update`.
""" """
await self._update(update_children)
async def _update(self, update_children: bool = True):
"""Query the device to update the data.
Internal implementation to allow patching of public update in the cli
or test framework.
"""
await self._modular_update({}) await self._modular_update({})
for module in self._modules.values(): for module in self._modules.values():
module._post_update_hook() module._post_update_hook()

View File

@ -40,6 +40,14 @@ class SmartChildDevice(SmartDevice):
The parent updates our internal info so just update modules with The parent updates our internal info so just update modules with
their own queries. their own queries.
""" """
await self._update(update_children)
async def _update(self, update_children: bool = True):
"""Update child module info.
Internal implementation to allow patching of public update in the cli
or test framework.
"""
req: dict[str, Any] = {} req: dict[str, Any] = {}
for module in self.modules.values(): for module in self.modules.values():
if mod_query := module.query(): if mod_query := module.query():

View File

@ -171,7 +171,7 @@ class SmartDevice(Device):
# devices will always update children to prevent errors on module access. # devices will always update children to prevent errors on module access.
if update_children or self.device_type != DeviceType.Hub: if update_children or self.device_type != DeviceType.Hub:
for child in self._children.values(): for child in self._children.values():
await child.update() await child._update()
if child_info := self._try_get_response(resp, "get_child_device_list", {}): if child_info := self._try_get_response(resp, "get_child_device_list", {}):
for info in child_info["child_device_list"]: for info in child_info["child_device_list"]:
self._children[info["device_id"]]._update_internal_state(info) self._children[info["device_id"]]._update_internal_state(info)

View File

@ -5,6 +5,7 @@ import re
import asyncclick as click import asyncclick as click
import pytest import pytest
from asyncclick.testing import CliRunner from asyncclick.testing import CliRunner
from pytest_mock import MockerFixture
from kasa import ( from kasa import (
AuthenticationError, AuthenticationError,
@ -24,6 +25,7 @@ from kasa.cli import (
cmd_command, cmd_command,
effect, effect,
emeter, emeter,
energy,
hsv, hsv,
led, led,
raw_command, raw_command,
@ -62,7 +64,6 @@ def runner():
[ [
pytest.param(None, None, id="No connect params"), pytest.param(None, None, id="No connect params"),
pytest.param("SMART.TAPOPLUG", None, id="Only device_family"), pytest.param("SMART.TAPOPLUG", None, id="Only device_family"),
pytest.param(None, "KLAP", id="Only encrypt_type"),
], ],
) )
async def test_update_called_by_cli(dev, mocker, runner, device_family, encrypt_type): async def test_update_called_by_cli(dev, mocker, runner, device_family, encrypt_type):
@ -171,13 +172,16 @@ async def test_command_with_child(dev, mocker, runner):
class DummyDevice(dev.__class__): class DummyDevice(dev.__class__):
def __init__(self): def __init__(self):
super().__init__("127.0.0.1") super().__init__("127.0.0.1")
# device_type and _info initialised for repr
self._device_type = Device.Type.StripSocket
self._info = {}
async def _query_helper(*_, **__): async def _query_helper(*_, **__):
return {"dummy": "response"} return {"dummy": "response"}
dummy_child = DummyDevice() dummy_child = DummyDevice()
mocker.patch.object(dev, "_children", {"XYZ": dummy_child}) mocker.patch.object(dev, "_children", {"XYZ": [dummy_child]})
mocker.patch.object(dev, "get_child_device", return_value=dummy_child) mocker.patch.object(dev, "get_child_device", return_value=dummy_child)
res = await runner.invoke( res = await runner.invoke(
@ -314,9 +318,9 @@ async def test_emeter(dev: Device, mocker, runner):
if not dev.is_strip: if not dev.is_strip:
res = await runner.invoke(emeter, ["--index", "0"], obj=dev) res = await runner.invoke(emeter, ["--index", "0"], obj=dev)
assert "Index and name are only for power strips!" in res.output assert f"Device: {dev.host} does not have children" in res.output
res = await runner.invoke(emeter, ["--name", "mock"], obj=dev) res = await runner.invoke(emeter, ["--name", "mock"], obj=dev)
assert "Index and name are only for power strips!" in res.output assert f"Device: {dev.host} does not have children" in res.output
if dev.is_strip and len(dev.children) > 0: if dev.is_strip and len(dev.children) > 0:
realtime_emeter = mocker.patch.object(dev.children[0], "get_emeter_realtime") realtime_emeter = mocker.patch.object(dev.children[0], "get_emeter_realtime")
@ -930,3 +934,110 @@ async def test_feature_set_child(mocker, runner):
assert f"Targeting child device {child_id}" assert f"Targeting child device {child_id}"
assert "Changing state from False to True" in res.output assert "Changing state from False to True" in res.output
assert res.exit_code == 0 assert res.exit_code == 0
async def test_cli_child_commands(
dev: Device, runner: CliRunner, mocker: MockerFixture
):
if not dev.children:
res = await runner.invoke(alias, ["--child-index", "0"], obj=dev)
assert f"Device: {dev.host} does not have children" in res.output
assert res.exit_code == 1
res = await runner.invoke(alias, ["--index", "0"], obj=dev)
assert f"Device: {dev.host} does not have children" in res.output
assert res.exit_code == 1
res = await runner.invoke(alias, ["--child", "Plug 2"], obj=dev)
assert f"Device: {dev.host} does not have children" in res.output
assert res.exit_code == 1
res = await runner.invoke(alias, ["--name", "Plug 2"], obj=dev)
assert f"Device: {dev.host} does not have children" in res.output
assert res.exit_code == 1
if dev.children:
child_alias = dev.children[0].alias
assert child_alias
child_device_id = dev.children[0].device_id
child_count = len(dev.children)
child_update_method = dev.children[0].update
# Test child retrieval
res = await runner.invoke(alias, ["--child-index", "0"], obj=dev)
assert f"Targeting child device {child_alias}" in res.output
assert res.exit_code == 0
res = await runner.invoke(alias, ["--index", "0"], obj=dev)
assert f"Targeting child device {child_alias}" in res.output
assert res.exit_code == 0
res = await runner.invoke(alias, ["--child", child_alias], obj=dev)
assert f"Targeting child device {child_alias}" in res.output
assert res.exit_code == 0
res = await runner.invoke(alias, ["--name", child_alias], obj=dev)
assert f"Targeting child device {child_alias}" in res.output
assert res.exit_code == 0
res = await runner.invoke(alias, ["--child", child_device_id], obj=dev)
assert f"Targeting child device {child_alias}" in res.output
assert res.exit_code == 0
res = await runner.invoke(alias, ["--name", child_device_id], obj=dev)
assert f"Targeting child device {child_alias}" in res.output
assert res.exit_code == 0
# Test invalid name and index
res = await runner.invoke(alias, ["--child-index", "-1"], obj=dev)
assert f"Invalid index -1, device has {child_count} children" in res.output
assert res.exit_code == 1
res = await runner.invoke(alias, ["--child-index", str(child_count)], obj=dev)
assert (
f"Invalid index {child_count}, device has {child_count} children"
in res.output
)
assert res.exit_code == 1
res = await runner.invoke(alias, ["--child", "foobar"], obj=dev)
assert "No child device found with device_id or name: foobar" in res.output
assert res.exit_code == 1
# Test using both options:
res = await runner.invoke(
alias, ["--child", child_alias, "--child-index", "0"], obj=dev
)
assert "Use either --child or --child-index, not both." in res.output
assert res.exit_code == 2
# Test child with no parameter interactive prompt
res = await runner.invoke(alias, ["--child"], obj=dev, input="0\n")
assert "Enter the index number of the child device:" in res.output
assert f"Alias: {child_alias}" in res.output
assert res.exit_code == 0
# Test values and updates
res = await runner.invoke(alias, ["foo", "--child", child_device_id], obj=dev)
assert "Alias set to: foo" in res.output
assert res.exit_code == 0
# Test help has command options plus child options
res = await runner.invoke(energy, ["--help"], obj=dev)
assert "--year" in res.output
assert "--child" in res.output
assert "--child-index" in res.output
assert res.exit_code == 0
# Test child update patching calls parent and is undone on exit
parent_update_spy = mocker.spy(dev, "update")
res = await runner.invoke(alias, ["bar", "--child", child_device_id], obj=dev)
assert "Alias set to: bar" in res.output
assert res.exit_code == 0
parent_update_spy.assert_called_once()
assert dev.children[0].update == child_update_method