Disallow non-targeted device commands (#982)

Prevent the cli from allowing sub commands unless host or alias is specified.
It is unwise to allow commands to be run on an arbitrary set of discovered
devices so this PR shows an error if attempted.
Also consolidates other invalid cli operations to use a single error function
to display the error to the user.
This commit is contained in:
Teemu R 2024-06-17 11:04:46 +02:00 committed by GitHub
parent 867b7b8830
commit 51a972542f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 25 deletions

View File

@ -71,6 +71,12 @@ except ImportError:
echo = _do_echo echo = _do_echo
def error(msg: str):
"""Print an error and exit."""
echo(f"[bold red]{msg}[/bold red]")
sys.exit(1)
TYPE_TO_CLASS = { TYPE_TO_CLASS = {
"plug": IotPlug, "plug": IotPlug,
"switch": IotWallSwitch, "switch": IotWallSwitch,
@ -367,6 +373,9 @@ async def cli(
credentials = None credentials = None
if host is None: if host is None:
if ctx.invoked_subcommand and ctx.invoked_subcommand != "discover":
error("Only discover is available without --host or --alias")
echo("No host name given, trying discovery..") echo("No host name given, trying discovery..")
return await ctx.invoke(discover) return await ctx.invoke(discover)
@ -764,7 +773,7 @@ async def emeter(dev: Device, index: int, name: str, year, month, erase):
""" """
if index is not None or name is not None: if index is not None or name is not None:
if not dev.is_strip: if not dev.is_strip:
echo("Index and name are only for power strips!") error("Index and name are only for power strips!")
return return
if index is not None: if index is not None:
@ -774,11 +783,11 @@ async def emeter(dev: Device, index: int, name: str, year, month, erase):
echo("[bold]== Emeter ==[/bold]") echo("[bold]== Emeter ==[/bold]")
if not dev.has_emeter: if not dev.has_emeter:
echo("Device has no emeter") error("Device has no emeter")
return return
if (year or month or erase) and not isinstance(dev, IotDevice): if (year or month or erase) and not isinstance(dev, IotDevice):
echo("Device has no historical statistics") error("Device has no historical statistics")
return return
else: else:
dev = cast(IotDevice, dev) dev = cast(IotDevice, dev)
@ -865,7 +874,7 @@ async def usage(dev: Device, year, month, erase):
async def brightness(dev: Device, brightness: int, transition: int): async def brightness(dev: Device, brightness: int, transition: int):
"""Get or set brightness.""" """Get or set brightness."""
if not (light := dev.modules.get(Module.Light)) or not light.is_dimmable: if not (light := dev.modules.get(Module.Light)) or not light.is_dimmable:
echo("This device does not support brightness.") error("This device does not support brightness.")
return return
if brightness is None: if brightness is None:
@ -885,7 +894,7 @@ async def brightness(dev: Device, brightness: int, transition: int):
async def temperature(dev: Device, temperature: int, transition: int): async def temperature(dev: Device, temperature: int, transition: int):
"""Get or set color temperature.""" """Get or set color temperature."""
if not (light := dev.modules.get(Module.Light)) or not light.is_variable_color_temp: if not (light := dev.modules.get(Module.Light)) or not light.is_variable_color_temp:
echo("Device does not support color temperature") error("Device does not support color temperature")
return return
if temperature is None: if temperature is None:
@ -911,7 +920,7 @@ async def temperature(dev: Device, temperature: int, transition: int):
async def effect(dev: Device, ctx, effect): async def effect(dev: Device, ctx, effect):
"""Set an effect.""" """Set an effect."""
if not (light_effect := dev.modules.get(Module.LightEffect)): if not (light_effect := dev.modules.get(Module.LightEffect)):
echo("Device does not support effects") error("Device does not support effects")
return return
if effect is None: if effect is None:
echo( echo(
@ -939,7 +948,7 @@ async def effect(dev: Device, ctx, effect):
async def hsv(dev: Device, ctx, h, s, v, transition): async def hsv(dev: Device, ctx, h, s, v, transition):
"""Get or set color in HSV.""" """Get or set color in HSV."""
if not (light := dev.modules.get(Module.Light)) or not light.is_color: if not (light := dev.modules.get(Module.Light)) or not light.is_color:
echo("Device does not support colors") error("Device does not support colors")
return return
if h is None and s is None and v is None: if h is None and s is None and v is None:
@ -958,7 +967,7 @@ async def hsv(dev: Device, ctx, h, s, v, transition):
async def led(dev: Device, state): async def led(dev: Device, state):
"""Get or set (Plug's) led state.""" """Get or set (Plug's) led state."""
if not (led := dev.modules.get(Module.Led)): if not (led := dev.modules.get(Module.Led)):
echo("Device does not support led.") error("Device does not support led.")
return return
if state is not None: if state is not None:
echo(f"Turning led to {state}") echo(f"Turning led to {state}")
@ -1014,7 +1023,7 @@ async def on(dev: Device, index: int, name: str, transition: int):
"""Turn the device on.""" """Turn the device on."""
if index is not None or name is not None: if index is not None or name is not None:
if not dev.children: if not dev.children:
echo("Index and name are only for devices with children.") error("Index and name are only for devices with children.")
return return
if index is not None: if index is not None:
@ -1035,7 +1044,7 @@ async def off(dev: Device, index: int, name: str, transition: int):
"""Turn the device off.""" """Turn the device off."""
if index is not None or name is not None: if index is not None or name is not None:
if not dev.children: if not dev.children:
echo("Index and name are only for devices with children.") error("Index and name are only for devices with children.")
return return
if index is not None: if index is not None:
@ -1056,7 +1065,7 @@ async def toggle(dev: Device, index: int, name: str, transition: int):
"""Toggle the device on/off.""" """Toggle the device on/off."""
if index is not None or name is not None: if index is not None or name is not None:
if not dev.children: if not dev.children:
echo("Index and name are only for devices with children.") error("Index and name are only for devices with children.")
return return
if index is not None: if index is not None:
@ -1096,7 +1105,7 @@ def _schedule_list(dev, type):
for rule in sched.rules: for rule in sched.rules:
print(rule) print(rule)
else: else:
echo(f"No rules of type {type}") error(f"No rules of type {type}")
return sched.rules return sched.rules
@ -1112,7 +1121,7 @@ async def delete_rule(dev, id):
echo(f"Deleting rule id {id}") echo(f"Deleting rule id {id}")
return await schedule.delete_rule(rule_to_delete) return await schedule.delete_rule(rule_to_delete)
else: else:
echo(f"No rule with id {id} was found") error(f"No rule with id {id} was found")
@cli.group(invoke_without_command=True) @cli.group(invoke_without_command=True)
@ -1128,7 +1137,7 @@ async def presets(ctx):
def presets_list(dev: IotBulb): def presets_list(dev: IotBulb):
"""List presets.""" """List presets."""
if not dev.is_bulb or not isinstance(dev, IotBulb): if not dev.is_bulb or not isinstance(dev, IotBulb):
echo("Presets only supported on iot bulbs") error("Presets only supported on iot bulbs")
return return
for preset in dev.presets: for preset in dev.presets:
@ -1150,7 +1159,7 @@ async def presets_modify(dev: IotBulb, index, brightness, hue, saturation, tempe
if preset.index == index: if preset.index == index:
break break
else: else:
echo(f"No preset found for index {index}") error(f"No preset found for index {index}")
return return
if brightness is not None: if brightness is not None:
@ -1175,7 +1184,7 @@ async def presets_modify(dev: IotBulb, index, brightness, hue, saturation, tempe
async def turn_on_behavior(dev: IotBulb, type, last, preset): async def turn_on_behavior(dev: IotBulb, type, last, preset):
"""Modify bulb turn-on behavior.""" """Modify bulb turn-on behavior."""
if not dev.is_bulb or not isinstance(dev, IotBulb): if not dev.is_bulb or not isinstance(dev, IotBulb):
echo("Presets only supported on iot bulbs") error("Presets only supported on iot bulbs")
return return
settings = await dev.get_turn_on_behavior() settings = await dev.get_turn_on_behavior()
echo(f"Current turn on behavior: {settings}") echo(f"Current turn on behavior: {settings}")
@ -1212,9 +1221,7 @@ async def turn_on_behavior(dev: IotBulb, type, last, preset):
async def update_credentials(dev, username, password): async def update_credentials(dev, username, password):
"""Update device credentials for authenticated devices.""" """Update device credentials for authenticated devices."""
if not isinstance(dev, SmartDevice): if not isinstance(dev, SmartDevice):
raise NotImplementedError( error("Credentials can only be updated on authenticated devices.")
"Credentials can only be updated on authenticated devices."
)
click.confirm("Do you really want to replace the existing credentials?", abort=True) click.confirm("Do you really want to replace the existing credentials?", abort=True)
@ -1271,7 +1278,7 @@ async def feature(dev: Device, child: str, name: str, value):
return return
if name not in dev.features: if name not in dev.features:
echo(f"No feature by name '{name}'") error(f"No feature by name '{name}'")
return return
feat = dev.features[name] feat = dev.features[name]

View File

@ -461,12 +461,12 @@ async def test_led(dev: Device, runner: CliRunner):
async def test_json_output(dev: Device, mocker, runner): async def test_json_output(dev: Device, mocker, runner):
"""Test that the json output produces correct output.""" """Test that the json output produces correct output."""
mocker.patch("kasa.Discover.discover", return_value={"127.0.0.1": dev}) mocker.patch("kasa.Discover.discover_single", return_value=dev)
# These will mock the features to avoid accessing non-existing # These will mock the features to avoid accessing non-existing ones
mocker.patch("kasa.device.Device.features", return_value={}) mocker.patch("kasa.device.Device.features", return_value={})
mocker.patch("kasa.iot.iotdevice.IotDevice.features", return_value={}) mocker.patch("kasa.iot.iotdevice.IotDevice.features", return_value={})
res = await runner.invoke(cli, ["--json", "state"], obj=dev) res = await runner.invoke(cli, ["--host", "127.0.0.1", "--json", "state"], obj=dev)
assert res.exit_code == 0 assert res.exit_code == 0
assert json.loads(res.output) == dev.internal_state assert json.loads(res.output) == dev.internal_state
@ -789,7 +789,7 @@ async def test_errors(mocker, runner):
) )
assert res.exit_code == 1 assert res.exit_code == 1
assert ( assert (
"Raised error: Managed to invoke callback without a context object of type 'Device' existing." "Only discover is available without --host or --alias"
in res.output.replace("\n", "") # Remove newlines from rich formatting in res.output.replace("\n", "") # Remove newlines from rich formatting
) )
assert isinstance(res.exception, SystemExit) assert isinstance(res.exception, SystemExit)
@ -860,7 +860,7 @@ async def test_feature_missing(mocker, runner):
) )
assert "No feature by name 'missing'" in res.output assert "No feature by name 'missing'" in res.output
assert "== Features ==" not in res.output assert "== Features ==" not in res.output
assert res.exit_code == 0 assert res.exit_code == 1
async def test_feature_set(mocker, runner): async def test_feature_set(mocker, runner):