Add emeter support for strip sockets (#203)

* Add support for plugs with emeters.

* Tweaks for emeter

* black

* tweaks

* tweaks

* more tweaks

* dry

* flake8

* flake8

* legacy typing

* Update kasa/smartstrip.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* reduce

* remove useless delegation

* tweaks

* tweaks

* dry

* tweak

* tweak

* tweak

* tweak

* update tests

* wrap

* preen

* prune

* prune

* prune

* guard

* adjust

* robust

* prune

* prune

* reduce dict lookups by 1

* Update kasa/smartstrip.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* delete utils

* isort

Co-authored-by: Brendan Burns <brendan.d.burns@gmail.com>
Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
J. Nick Koston 2021-09-23 17:24:44 -05:00 committed by GitHub
parent d7202883e9
commit 94e5a90ac4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 112 additions and 101 deletions

View File

@ -11,13 +11,14 @@ Stroetmann which is licensed under the Apache License, Version 2.0.
You may obtain a copy of the license at You may obtain a copy of the license at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
""" """
import collections.abc
import functools import functools
import inspect import inspect
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum, auto from enum import Enum, auto
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Set
from .emeterstatus import EmeterStatus from .emeterstatus import EmeterStatus
from .exceptions import SmartDeviceException from .exceptions import SmartDeviceException
@ -51,6 +52,16 @@ class WifiNetwork:
rssi: Optional[int] = None rssi: Optional[int] = None
def merge(d, u):
"""Update dict recursively."""
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = merge(d.get(k, {}), v)
else:
d[k] = v
return d
def requires_update(f): def requires_update(f):
"""Indicate that `update` should be called before accessing this method.""" # noqa: D202 """Indicate that `update` should be called before accessing this method.""" # noqa: D202
if inspect.iscoroutinefunction(f): if inspect.iscoroutinefunction(f):
@ -204,6 +215,11 @@ class SmartDevice:
return request return request
def _verify_emeter(self) -> None:
"""Raise an exception if there is no emeter."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
async def _query_helper( async def _query_helper(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
) -> Any: ) -> Any:
@ -240,13 +256,17 @@ class SmartDevice:
return result return result
@property # type: ignore
@requires_update
def features(self) -> Set[str]:
"""Return a set of features that the device supports."""
return set(self.sys_info["feature"].split(":"))
@property # type: ignore @property # type: ignore
@requires_update @requires_update
def has_emeter(self) -> bool: def has_emeter(self) -> bool:
"""Return True if device has an energy meter.""" """Return True if device has an energy meter."""
sys_info = self.sys_info return "ENE" in self.features
features = sys_info["feature"].split(":")
return "ENE" in features
async def get_sys_info(self) -> Dict[str, Any]: async def get_sys_info(self) -> Dict[str, Any]:
"""Retrieve system information.""" """Retrieve system information."""
@ -374,10 +394,8 @@ class SmartDevice:
@requires_update @requires_update
def rssi(self) -> Optional[int]: def rssi(self) -> Optional[int]:
"""Return WiFi signal strenth (rssi).""" """Return WiFi signal strenth (rssi)."""
sys_info = self.sys_info rssi = self.sys_info.get("rssi")
if "rssi" in sys_info: return None if rssi is None else int(rssi)
return int(sys_info["rssi"])
return None
@property # type: ignore @property # type: ignore
@requires_update @requires_update
@ -410,16 +428,12 @@ class SmartDevice:
@requires_update @requires_update
def emeter_realtime(self) -> EmeterStatus: def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings.""" """Return current energy readings."""
if not self.has_emeter: self._verify_emeter()
raise SmartDeviceException("Device has no emeter")
return EmeterStatus(self._last_update[self.emeter_type]["get_realtime"]) return EmeterStatus(self._last_update[self.emeter_type]["get_realtime"])
async def get_emeter_realtime(self) -> EmeterStatus: async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings.""" """Retrieve current energy readings."""
if not self.has_emeter: self._verify_emeter()
raise SmartDeviceException("Device has no emeter")
return EmeterStatus(await self._query_helper(self.emeter_type, "get_realtime")) return EmeterStatus(await self._query_helper(self.emeter_type, "get_realtime"))
def _create_emeter_request(self, year: int = None, month: int = None): def _create_emeter_request(self, year: int = None, month: int = None):
@ -429,23 +443,12 @@ class SmartDevice:
if month is None: if month is None:
month = datetime.now().month month = datetime.now().month
import collections.abc
def update(d, u):
"""Update dict recursively."""
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = update(d.get(k, {}), v)
else:
d[k] = v
return d
req: Dict[str, Any] = {} req: Dict[str, Any] = {}
update(req, self._create_request(self.emeter_type, "get_realtime")) merge(req, self._create_request(self.emeter_type, "get_realtime"))
update( merge(
req, self._create_request(self.emeter_type, "get_monthstat", {"year": year}) req, self._create_request(self.emeter_type, "get_monthstat", {"year": year})
) )
update( merge(
req, req,
self._create_request( self._create_request(
self.emeter_type, "get_daystat", {"month": month, "year": year} self.emeter_type, "get_daystat", {"month": month, "year": year}
@ -458,9 +461,7 @@ class SmartDevice:
@requires_update @requires_update
def emeter_today(self) -> Optional[float]: def emeter_today(self) -> Optional[float]:
"""Return today's energy consumption in kWh.""" """Return today's energy consumption in kWh."""
if not self.has_emeter: self._verify_emeter()
raise SmartDeviceException("Device has no emeter")
raw_data = self._last_update[self.emeter_type]["get_daystat"]["day_list"] raw_data = self._last_update[self.emeter_type]["get_daystat"]["day_list"]
data = self._emeter_convert_emeter_data(raw_data) data = self._emeter_convert_emeter_data(raw_data)
today = datetime.now().day today = datetime.now().day
@ -474,9 +475,7 @@ class SmartDevice:
@requires_update @requires_update
def emeter_this_month(self) -> Optional[float]: def emeter_this_month(self) -> Optional[float]:
"""Return this month's energy consumption in kWh.""" """Return this month's energy consumption in kWh."""
if not self.has_emeter: self._verify_emeter()
raise SmartDeviceException("Device has no emeter")
raw_data = self._last_update[self.emeter_type]["get_monthstat"]["month_list"] raw_data = self._last_update[self.emeter_type]["get_monthstat"]["month_list"]
data = self._emeter_convert_emeter_data(raw_data) data = self._emeter_convert_emeter_data(raw_data)
current_month = datetime.now().month current_month = datetime.now().month
@ -516,9 +515,7 @@ class SmartDevice:
:param kwh: return usage in kWh (default: True) :param kwh: return usage in kWh (default: True)
:return: mapping of day of month to value :return: mapping of day of month to value
""" """
if not self.has_emeter: self._verify_emeter()
raise SmartDeviceException("Device has no emeter")
if year is None: if year is None:
year = datetime.now().year year = datetime.now().year
if month is None: if month is None:
@ -538,9 +535,7 @@ class SmartDevice:
:param kwh: return usage in kWh (default: True) :param kwh: return usage in kWh (default: True)
:return: dict: mapping of month to value :return: dict: mapping of month to value
""" """
if not self.has_emeter: self._verify_emeter()
raise SmartDeviceException("Device has no emeter")
if year is None: if year is None:
year = datetime.now().year year = datetime.now().year
@ -553,17 +548,13 @@ class SmartDevice:
@requires_update @requires_update
async def erase_emeter_stats(self) -> Dict: async def erase_emeter_stats(self) -> Dict:
"""Erase energy meter statistics.""" """Erase energy meter statistics."""
if not self.has_emeter: self._verify_emeter()
raise SmartDeviceException("Device has no emeter")
return await self._query_helper(self.emeter_type, "erase_emeter_stat", None) return await self._query_helper(self.emeter_type, "erase_emeter_stat", None)
@requires_update @requires_update
async def current_consumption(self) -> float: async def current_consumption(self) -> float:
"""Get the current power consumption in Watt.""" """Get the current power consumption in Watt."""
if not self.has_emeter: self._verify_emeter()
raise SmartDeviceException("Device has no emeter")
response = EmeterStatus(await self.get_emeter_realtime()) response = EmeterStatus(await self.get_emeter_realtime())
return float(response["power"]) return float(response["power"])

View File

@ -6,6 +6,7 @@ from typing import Any, DefaultDict, Dict, Optional
from kasa.smartdevice import ( from kasa.smartdevice import (
DeviceType, DeviceType,
EmeterStatus,
SmartDevice, SmartDevice,
SmartDeviceException, SmartDeviceException,
requires_update, requires_update,
@ -15,6 +16,15 @@ from kasa.smartplug import SmartPlug
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def merge_sums(dicts):
"""Merge the sum of dicts."""
total_dict: DefaultDict[int, float] = defaultdict(lambda: 0.0)
for sum_dict in dicts:
for day, value in sum_dict.items():
total_dict[day] += value
return total_dict
class SmartStrip(SmartDevice): class SmartStrip(SmartDevice):
"""Representation of a TP-Link Smart Power Strip. """Representation of a TP-Link Smart Power Strip.
@ -75,11 +85,7 @@ class SmartStrip(SmartDevice):
@requires_update @requires_update
def is_on(self) -> bool: def is_on(self) -> bool:
"""Return if any of the outlets are on.""" """Return if any of the outlets are on."""
for plug in self.children: return any(plug.is_on for plug in self.children)
is_on = plug.is_on
if is_on:
return True
return False
async def update(self): async def update(self):
"""Update some of the attributes. """Update some of the attributes.
@ -97,6 +103,10 @@ class SmartStrip(SmartDevice):
SmartStripPlug(self.host, parent=self, child_id=child["id"]) SmartStripPlug(self.host, parent=self, child_id=child["id"])
) )
if self.has_emeter:
for plug in self.children:
await plug.update()
async def turn_on(self, **kwargs): async def turn_on(self, **kwargs):
"""Turn the strip on.""" """Turn the strip on."""
await self._query_helper("system", "set_relay_state", {"state": 1}) await self._query_helper("system", "set_relay_state", {"state": 1})
@ -140,16 +150,16 @@ class SmartStrip(SmartDevice):
async def current_consumption(self) -> float: async def current_consumption(self) -> float:
"""Get the current power consumption in watts.""" """Get the current power consumption in watts."""
consumption = sum(await plug.current_consumption() for plug in self.children) return sum([await plug.current_consumption() for plug in self.children])
return consumption @requires_update
async def get_emeter_realtime(self) -> EmeterStatus:
async def set_alias(self, alias: str) -> None: """Retrieve current energy readings."""
"""Set the alias for the strip. emeter_rt = await self._async_get_emeter_sum("get_emeter_realtime", {})
# Voltage is averaged since each read will result
:param alias: new alias # in a slightly different voltage since they are not atomic
""" emeter_rt["voltage_mv"] = int(emeter_rt["voltage_mv"] / len(self.children))
return await super().set_alias(alias) return EmeterStatus(emeter_rt)
@requires_update @requires_update
async def get_emeter_daily( async def get_emeter_daily(
@ -163,14 +173,9 @@ class SmartStrip(SmartDevice):
:param kwh: return usage in kWh (default: True) :param kwh: return usage in kWh (default: True)
:return: mapping of day of month to value :return: mapping of day of month to value
""" """
emeter_daily: DefaultDict[int, float] = defaultdict(lambda: 0.0) return await self._async_get_emeter_sum(
for plug in self.children: "get_emeter_daily", {"year": year, "month": month, "kwh": kwh}
plug_emeter_daily = await plug.get_emeter_daily( )
year=year, month=month, kwh=kwh
)
for day, value in plug_emeter_daily.items():
emeter_daily[day] += value
return emeter_daily
@requires_update @requires_update
async def get_emeter_monthly(self, year: int = None, kwh: bool = True) -> Dict: async def get_emeter_monthly(self, year: int = None, kwh: bool = True) -> Dict:
@ -179,13 +184,16 @@ class SmartStrip(SmartDevice):
:param year: year for which to retrieve statistics (default: this year) :param year: year for which to retrieve statistics (default: this year)
:param kwh: return usage in kWh (default: True) :param kwh: return usage in kWh (default: True)
""" """
emeter_monthly: DefaultDict[int, float] = defaultdict(lambda: 0.0) return await self._async_get_emeter_sum(
for plug in self.children: "get_emeter_monthly", {"year": year, "kwh": kwh}
plug_emeter_monthly = await plug.get_emeter_monthly(year=year, kwh=kwh) )
for month, value in plug_emeter_monthly:
emeter_monthly[month] += value
return emeter_monthly async def _async_get_emeter_sum(self, func: str, kwargs: Dict[str, Any]) -> Dict:
"""Retreive emeter stats for a time period from children."""
self._verify_emeter()
return merge_sums(
[await getattr(plug, func)(**kwargs) for plug in self.children]
)
@requires_update @requires_update
async def erase_emeter_stats(self): async def erase_emeter_stats(self):
@ -193,6 +201,28 @@ class SmartStrip(SmartDevice):
for plug in self.children: for plug in self.children:
await plug.erase_emeter_stats() await plug.erase_emeter_stats()
@property # type: ignore
@requires_update
def emeter_this_month(self) -> Optional[float]:
"""Return this month's energy consumption in kWh."""
return sum([plug.emeter_this_month for plug in self.children])
@property # type: ignore
@requires_update
def emeter_today(self) -> Optional[float]:
"""Return this month's energy consumption in kWh."""
return sum([plug.emeter_today for plug in self.children])
@property # type: ignore
@requires_update
def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings."""
emeter = merge_sums([plug.emeter_realtime for plug in self.children])
# Voltage is averaged since each read will result
# in a slightly different voltage since they are not atomic
emeter["voltage_mv"] = int(emeter["voltage_mv"] / len(self.children))
return EmeterStatus(emeter)
class SmartStripPlug(SmartPlug): class SmartStripPlug(SmartPlug):
"""Representation of a single socket in a power strip. """Representation of a single socket in a power strip.
@ -214,12 +244,22 @@ class SmartStripPlug(SmartPlug):
self._device_type = DeviceType.StripSocket self._device_type = DeviceType.StripSocket
async def update(self): async def update(self):
"""Override the update to no-op and inform the user.""" """Query the device to update the data.
_LOGGER.warning(
"You called update() on a child device, which has no effect." Needed for properties that are decorated with `requires_update`.
"Call update() on the parent device instead." """
self._last_update = await self.parent.protocol.query(
self.host, self._create_emeter_request()
) )
return
def _create_request(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
):
request: Dict[str, Any] = {
"context": {"child_ids": [self.child_id]},
target: {cmd: arg},
}
return request
async def _query_helper( async def _query_helper(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
@ -245,12 +285,6 @@ class SmartStripPlug(SmartPlug):
""" """
return False return False
@property # type: ignore
@requires_update
def has_emeter(self) -> bool:
"""Children have no emeter to my knowledge."""
return False
@property # type: ignore @property # type: ignore
@requires_update @requires_update
def device_id(self) -> str: def device_id(self) -> str:

View File

@ -22,9 +22,6 @@ async def test_no_emeter(dev):
@has_emeter @has_emeter
async def test_get_emeter_realtime(dev): async def test_get_emeter_realtime(dev):
if dev.is_strip:
pytest.skip("Disabled for strips temporarily")
assert dev.has_emeter assert dev.has_emeter
current_emeter = await dev.get_emeter_realtime() current_emeter = await dev.get_emeter_realtime()
@ -34,9 +31,6 @@ async def test_get_emeter_realtime(dev):
@has_emeter @has_emeter
@pytest.mark.requires_dummy @pytest.mark.requires_dummy
async def test_get_emeter_daily(dev): async def test_get_emeter_daily(dev):
if dev.is_strip:
pytest.skip("Disabled for strips temporarily")
assert dev.has_emeter assert dev.has_emeter
assert await dev.get_emeter_daily(year=1900, month=1) == {} assert await dev.get_emeter_daily(year=1900, month=1) == {}
@ -57,9 +51,6 @@ async def test_get_emeter_daily(dev):
@has_emeter @has_emeter
@pytest.mark.requires_dummy @pytest.mark.requires_dummy
async def test_get_emeter_monthly(dev): async def test_get_emeter_monthly(dev):
if dev.is_strip:
pytest.skip("Disabled for strips temporarily")
assert dev.has_emeter assert dev.has_emeter
assert await dev.get_emeter_monthly(year=1900) == {} assert await dev.get_emeter_monthly(year=1900) == {}
@ -79,9 +70,6 @@ async def test_get_emeter_monthly(dev):
@has_emeter @has_emeter
async def test_emeter_status(dev): async def test_emeter_status(dev):
if dev.is_strip:
pytest.skip("Disabled for strips temporarily")
assert dev.has_emeter assert dev.has_emeter
d = await dev.get_emeter_realtime() d = await dev.get_emeter_realtime()
@ -108,9 +96,6 @@ async def test_erase_emeter_stats(dev):
@has_emeter @has_emeter
async def test_current_consumption(dev): async def test_current_consumption(dev):
if dev.is_strip:
pytest.skip("Disabled for strips temporarily")
if dev.has_emeter: if dev.has_emeter:
x = await dev.current_consumption() x = await dev.current_consumption()
assert isinstance(x, float) assert isinstance(x, float)

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import SmartDeviceException from kasa import SmartDeviceException
from kasa.smartstrip import SmartStripPlug
from .conftest import handle_turn_on, has_emeter, no_emeter, pytestmark, turn_on from .conftest import handle_turn_on, has_emeter, no_emeter, pytestmark, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
@ -26,7 +27,7 @@ async def test_initial_update_emeter(dev, mocker):
dev._last_update = None dev._last_update = None
spy = mocker.spy(dev.protocol, "query") spy = mocker.spy(dev.protocol, "query")
await dev.update() await dev.update()
assert spy.call_count == 2 assert spy.call_count == 2 + len(dev.children)
@no_emeter @no_emeter