Get child emeters with CLI (#623)

* Get child emeters with CLI

* Avoid extra IO when not que querying the child emeter
This commit is contained in:
Nathan Wreggit 2024-01-04 17:25:24 -08:00 committed by GitHub
parent 2d8a8d9511
commit cfe694e5de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 4 deletions

View File

@ -602,14 +602,27 @@ async def raw_command(dev: SmartDevice, module, command, parameters):
@cli.command() @cli.command()
@pass_dev @pass_dev
@click.option("--index", type=int, required=False)
@click.option("--name", type=str, required=False)
@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False) @click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False)
@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False) @click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False)
@click.option("--erase", is_flag=True) @click.option("--erase", is_flag=True)
async def emeter(dev: SmartDevice, year, month, erase): async def emeter(dev: SmartDevice, index: int, name: str, year, month, erase):
"""Query emeter for historical consumption. """Query emeter for historical consumption.
Daily and monthly data provided in CSV format. Daily and monthly data provided in CSV format.
""" """
if index is not None or name is not None:
if not dev.is_strip:
echo("Index and name are only for power strips!")
return
dev = cast(SmartStrip, dev)
if index is not None:
dev = dev.get_plug_by_index(index)
elif name:
dev = dev.get_plug_by_name(name)
echo("[bold]== Emeter ==[/bold]") echo("[bold]== Emeter ==[/bold]")
if not dev.has_emeter: if not dev.has_emeter:
echo("Device has no emeter") echo("Device has no emeter")
@ -629,7 +642,10 @@ async def emeter(dev: SmartDevice, year, month, erase):
usage_data = await dev.get_emeter_daily(year=month.year, month=month.month) usage_data = await dev.get_emeter_daily(year=month.year, month=month.month)
else: else:
# Call with no argument outputs summary data and returns # Call with no argument outputs summary data and returns
emeter_status = dev.emeter_realtime if index is not None or name is not None:
emeter_status = await dev.get_emeter_realtime()
else:
emeter_status = dev.emeter_realtime
echo("Current: %s A" % emeter_status["current"]) echo("Current: %s A" % emeter_status["current"])
echo("Voltage: %s V" % emeter_status["voltage"]) echo("Voltage: %s V" % emeter_status["voltage"])

View File

@ -128,7 +128,7 @@ def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
def get_device_class_from_family(device_type: str) -> Optional[Type[SmartDevice]]: def get_device_class_from_family(device_type: str) -> Optional[Type[SmartDevice]]:
"""Return the device class from the type name.""" """Return the device class from the type name."""
supported_device_types: dict[str, Type[SmartDevice]] = { supported_device_types: Dict[str, Type[SmartDevice]] = {
"SMART.TAPOPLUG": TapoPlug, "SMART.TAPOPLUG": TapoPlug,
"SMART.TAPOBULB": TapoBulb, "SMART.TAPOBULB": TapoBulb,
"SMART.KASAPLUG": TapoPlug, "SMART.KASAPLUG": TapoPlug,
@ -147,7 +147,7 @@ def get_protocol(
protocol_transport_key = ( protocol_transport_key = (
protocol_name + "." + config.connection_type.encryption_type.value protocol_name + "." + config.connection_type.encryption_type.value
) )
supported_device_protocols: dict[ supported_device_protocols: Dict[
str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]] str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]]
] = { ] = {
"IOT.XOR": (TPLinkSmartHomeProtocol, _XorTransport), "IOT.XOR": (TPLinkSmartHomeProtocol, _XorTransport),

View File

@ -7,6 +7,7 @@ from asyncclick.testing import CliRunner
from kasa import ( from kasa import (
AuthenticationException, AuthenticationException,
Credentials, Credentials,
EmeterStatus,
SmartDevice, SmartDevice,
TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol,
UnsupportedDeviceException, UnsupportedDeviceException,
@ -104,6 +105,25 @@ async def test_emeter(dev: SmartDevice, mocker):
assert "== Emeter ==" in res.output assert "== Emeter ==" in res.output
if not dev.is_strip:
res = await runner.invoke(emeter, ["--index", "0"], obj=dev)
assert "Index and name are only for power strips!" in res.output
res = await runner.invoke(emeter, ["--name", "mock"], obj=dev)
assert "Index and name are only for power strips!" in res.output
if dev.is_strip and len(dev.children) > 0:
realtime_emeter = mocker.patch.object(dev.children[0], "get_emeter_realtime")
realtime_emeter.return_value = EmeterStatus({"voltage_mv": 122066})
res = await runner.invoke(emeter, ["--index", "0"], obj=dev)
assert "Voltage: 122.066 V" in res.output
realtime_emeter.assert_called()
assert realtime_emeter.call_count == 1
res = await runner.invoke(emeter, ["--name", dev.children[0].alias], obj=dev)
assert "Voltage: 122.066 V" in res.output
assert realtime_emeter.call_count == 2
monthly = mocker.patch.object(dev, "get_emeter_monthly") monthly = mocker.patch.object(dev, "get_emeter_monthly")
monthly.return_value = {1: 1234} monthly.return_value = {1: 1234}
res = await runner.invoke(emeter, ["--year", "1900"], obj=dev) res = await runner.invoke(emeter, ["--year", "1900"], obj=dev)