Add emeter support for strip sockets (#203)

* Add support for plugs with emeters.

* Tweaks for emeter

* black

* tweaks

* tweaks

* more tweaks

* dry

* flake8

* flake8

* legacy typing

* Update kasa/smartstrip.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* reduce

* remove useless delegation

* tweaks

* tweaks

* dry

* tweak

* tweak

* tweak

* tweak

* update tests

* wrap

* preen

* prune

* prune

* prune

* guard

* adjust

* robust

* prune

* prune

* reduce dict lookups by 1

* Update kasa/smartstrip.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* delete utils

* isort

Co-authored-by: Brendan Burns <brendan.d.burns@gmail.com>
Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
J. Nick Koston 2021-09-23 17:24:44 -05:00 committed by GitHub
parent d7202883e9
commit 94e5a90ac4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 112 additions and 101 deletions

View File

@ -11,13 +11,14 @@ Stroetmann which is licensed under the Apache License, Version 2.0.
You may obtain a copy of the license at
http://www.apache.org/licenses/LICENSE-2.0
"""
import collections.abc
import functools
import inspect
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum, auto
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set
from .emeterstatus import EmeterStatus
from .exceptions import SmartDeviceException
@ -51,6 +52,16 @@ class WifiNetwork:
rssi: Optional[int] = None
def merge(d, u):
"""Update dict recursively."""
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = merge(d.get(k, {}), v)
else:
d[k] = v
return d
def requires_update(f):
"""Indicate that `update` should be called before accessing this method.""" # noqa: D202
if inspect.iscoroutinefunction(f):
@ -204,6 +215,11 @@ class SmartDevice:
return request
def _verify_emeter(self) -> None:
"""Raise an exception if there is no emeter."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
async def _query_helper(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
) -> Any:
@ -240,13 +256,17 @@ class SmartDevice:
return result
@property # type: ignore
@requires_update
def features(self) -> Set[str]:
"""Return a set of features that the device supports."""
return set(self.sys_info["feature"].split(":"))
@property # type: ignore
@requires_update
def has_emeter(self) -> bool:
"""Return True if device has an energy meter."""
sys_info = self.sys_info
features = sys_info["feature"].split(":")
return "ENE" in features
return "ENE" in self.features
async def get_sys_info(self) -> Dict[str, Any]:
"""Retrieve system information."""
@ -374,10 +394,8 @@ class SmartDevice:
@requires_update
def rssi(self) -> Optional[int]:
"""Return WiFi signal strenth (rssi)."""
sys_info = self.sys_info
if "rssi" in sys_info:
return int(sys_info["rssi"])
return None
rssi = self.sys_info.get("rssi")
return None if rssi is None else int(rssi)
@property # type: ignore
@requires_update
@ -410,16 +428,12 @@ class SmartDevice:
@requires_update
def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
self._verify_emeter()
return EmeterStatus(self._last_update[self.emeter_type]["get_realtime"])
async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
self._verify_emeter()
return EmeterStatus(await self._query_helper(self.emeter_type, "get_realtime"))
def _create_emeter_request(self, year: int = None, month: int = None):
@ -429,23 +443,12 @@ class SmartDevice:
if month is None:
month = datetime.now().month
import collections.abc
def update(d, u):
"""Update dict recursively."""
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = update(d.get(k, {}), v)
else:
d[k] = v
return d
req: Dict[str, Any] = {}
update(req, self._create_request(self.emeter_type, "get_realtime"))
update(
merge(req, self._create_request(self.emeter_type, "get_realtime"))
merge(
req, self._create_request(self.emeter_type, "get_monthstat", {"year": year})
)
update(
merge(
req,
self._create_request(
self.emeter_type, "get_daystat", {"month": month, "year": year}
@ -458,9 +461,7 @@ class SmartDevice:
@requires_update
def emeter_today(self) -> Optional[float]:
"""Return today's energy consumption in kWh."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
self._verify_emeter()
raw_data = self._last_update[self.emeter_type]["get_daystat"]["day_list"]
data = self._emeter_convert_emeter_data(raw_data)
today = datetime.now().day
@ -474,9 +475,7 @@ class SmartDevice:
@requires_update
def emeter_this_month(self) -> Optional[float]:
"""Return this month's energy consumption in kWh."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
self._verify_emeter()
raw_data = self._last_update[self.emeter_type]["get_monthstat"]["month_list"]
data = self._emeter_convert_emeter_data(raw_data)
current_month = datetime.now().month
@ -516,9 +515,7 @@ class SmartDevice:
:param kwh: return usage in kWh (default: True)
:return: mapping of day of month to value
"""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
self._verify_emeter()
if year is None:
year = datetime.now().year
if month is None:
@ -538,9 +535,7 @@ class SmartDevice:
:param kwh: return usage in kWh (default: True)
:return: dict: mapping of month to value
"""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
self._verify_emeter()
if year is None:
year = datetime.now().year
@ -553,17 +548,13 @@ class SmartDevice:
@requires_update
async def erase_emeter_stats(self) -> Dict:
"""Erase energy meter statistics."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
self._verify_emeter()
return await self._query_helper(self.emeter_type, "erase_emeter_stat", None)
@requires_update
async def current_consumption(self) -> float:
"""Get the current power consumption in Watt."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
self._verify_emeter()
response = EmeterStatus(await self.get_emeter_realtime())
return float(response["power"])

View File

@ -6,6 +6,7 @@ from typing import Any, DefaultDict, Dict, Optional
from kasa.smartdevice import (
DeviceType,
EmeterStatus,
SmartDevice,
SmartDeviceException,
requires_update,
@ -15,6 +16,15 @@ from kasa.smartplug import SmartPlug
_LOGGER = logging.getLogger(__name__)
def merge_sums(dicts):
"""Merge the sum of dicts."""
total_dict: DefaultDict[int, float] = defaultdict(lambda: 0.0)
for sum_dict in dicts:
for day, value in sum_dict.items():
total_dict[day] += value
return total_dict
class SmartStrip(SmartDevice):
"""Representation of a TP-Link Smart Power Strip.
@ -75,11 +85,7 @@ class SmartStrip(SmartDevice):
@requires_update
def is_on(self) -> bool:
"""Return if any of the outlets are on."""
for plug in self.children:
is_on = plug.is_on
if is_on:
return True
return False
return any(plug.is_on for plug in self.children)
async def update(self):
"""Update some of the attributes.
@ -97,6 +103,10 @@ class SmartStrip(SmartDevice):
SmartStripPlug(self.host, parent=self, child_id=child["id"])
)
if self.has_emeter:
for plug in self.children:
await plug.update()
async def turn_on(self, **kwargs):
"""Turn the strip on."""
await self._query_helper("system", "set_relay_state", {"state": 1})
@ -140,16 +150,16 @@ class SmartStrip(SmartDevice):
async def current_consumption(self) -> float:
"""Get the current power consumption in watts."""
consumption = sum(await plug.current_consumption() for plug in self.children)
return sum([await plug.current_consumption() for plug in self.children])
return consumption
async def set_alias(self, alias: str) -> None:
"""Set the alias for the strip.
:param alias: new alias
"""
return await super().set_alias(alias)
@requires_update
async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings."""
emeter_rt = await self._async_get_emeter_sum("get_emeter_realtime", {})
# Voltage is averaged since each read will result
# in a slightly different voltage since they are not atomic
emeter_rt["voltage_mv"] = int(emeter_rt["voltage_mv"] / len(self.children))
return EmeterStatus(emeter_rt)
@requires_update
async def get_emeter_daily(
@ -163,14 +173,9 @@ class SmartStrip(SmartDevice):
:param kwh: return usage in kWh (default: True)
:return: mapping of day of month to value
"""
emeter_daily: DefaultDict[int, float] = defaultdict(lambda: 0.0)
for plug in self.children:
plug_emeter_daily = await plug.get_emeter_daily(
year=year, month=month, kwh=kwh
return await self._async_get_emeter_sum(
"get_emeter_daily", {"year": year, "month": month, "kwh": kwh}
)
for day, value in plug_emeter_daily.items():
emeter_daily[day] += value
return emeter_daily
@requires_update
async def get_emeter_monthly(self, year: int = None, kwh: bool = True) -> Dict:
@ -179,13 +184,16 @@ class SmartStrip(SmartDevice):
:param year: year for which to retrieve statistics (default: this year)
:param kwh: return usage in kWh (default: True)
"""
emeter_monthly: DefaultDict[int, float] = defaultdict(lambda: 0.0)
for plug in self.children:
plug_emeter_monthly = await plug.get_emeter_monthly(year=year, kwh=kwh)
for month, value in plug_emeter_monthly:
emeter_monthly[month] += value
return await self._async_get_emeter_sum(
"get_emeter_monthly", {"year": year, "kwh": kwh}
)
return emeter_monthly
async def _async_get_emeter_sum(self, func: str, kwargs: Dict[str, Any]) -> Dict:
"""Retreive emeter stats for a time period from children."""
self._verify_emeter()
return merge_sums(
[await getattr(plug, func)(**kwargs) for plug in self.children]
)
@requires_update
async def erase_emeter_stats(self):
@ -193,6 +201,28 @@ class SmartStrip(SmartDevice):
for plug in self.children:
await plug.erase_emeter_stats()
@property # type: ignore
@requires_update
def emeter_this_month(self) -> Optional[float]:
"""Return this month's energy consumption in kWh."""
return sum([plug.emeter_this_month for plug in self.children])
@property # type: ignore
@requires_update
def emeter_today(self) -> Optional[float]:
"""Return this month's energy consumption in kWh."""
return sum([plug.emeter_today for plug in self.children])
@property # type: ignore
@requires_update
def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings."""
emeter = merge_sums([plug.emeter_realtime for plug in self.children])
# Voltage is averaged since each read will result
# in a slightly different voltage since they are not atomic
emeter["voltage_mv"] = int(emeter["voltage_mv"] / len(self.children))
return EmeterStatus(emeter)
class SmartStripPlug(SmartPlug):
"""Representation of a single socket in a power strip.
@ -214,12 +244,22 @@ class SmartStripPlug(SmartPlug):
self._device_type = DeviceType.StripSocket
async def update(self):
"""Override the update to no-op and inform the user."""
_LOGGER.warning(
"You called update() on a child device, which has no effect."
"Call update() on the parent device instead."
"""Query the device to update the data.
Needed for properties that are decorated with `requires_update`.
"""
self._last_update = await self.parent.protocol.query(
self.host, self._create_emeter_request()
)
return
def _create_request(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
):
request: Dict[str, Any] = {
"context": {"child_ids": [self.child_id]},
target: {cmd: arg},
}
return request
async def _query_helper(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
@ -245,12 +285,6 @@ class SmartStripPlug(SmartPlug):
"""
return False
@property # type: ignore
@requires_update
def has_emeter(self) -> bool:
"""Children have no emeter to my knowledge."""
return False
@property # type: ignore
@requires_update
def device_id(self) -> str:

View File

@ -22,9 +22,6 @@ async def test_no_emeter(dev):
@has_emeter
async def test_get_emeter_realtime(dev):
if dev.is_strip:
pytest.skip("Disabled for strips temporarily")
assert dev.has_emeter
current_emeter = await dev.get_emeter_realtime()
@ -34,9 +31,6 @@ async def test_get_emeter_realtime(dev):
@has_emeter
@pytest.mark.requires_dummy
async def test_get_emeter_daily(dev):
if dev.is_strip:
pytest.skip("Disabled for strips temporarily")
assert dev.has_emeter
assert await dev.get_emeter_daily(year=1900, month=1) == {}
@ -57,9 +51,6 @@ async def test_get_emeter_daily(dev):
@has_emeter
@pytest.mark.requires_dummy
async def test_get_emeter_monthly(dev):
if dev.is_strip:
pytest.skip("Disabled for strips temporarily")
assert dev.has_emeter
assert await dev.get_emeter_monthly(year=1900) == {}
@ -79,9 +70,6 @@ async def test_get_emeter_monthly(dev):
@has_emeter
async def test_emeter_status(dev):
if dev.is_strip:
pytest.skip("Disabled for strips temporarily")
assert dev.has_emeter
d = await dev.get_emeter_realtime()
@ -108,9 +96,6 @@ async def test_erase_emeter_stats(dev):
@has_emeter
async def test_current_consumption(dev):
if dev.is_strip:
pytest.skip("Disabled for strips temporarily")
if dev.has_emeter:
x = await dev.current_consumption()
assert isinstance(x, float)

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import SmartDeviceException
from kasa.smartstrip import SmartStripPlug
from .conftest import handle_turn_on, has_emeter, no_emeter, pytestmark, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
@ -26,7 +27,7 @@ async def test_initial_update_emeter(dev, mocker):
dev._last_update = None
spy = mocker.spy(dev.protocol, "query")
await dev.update()
assert spy.call_count == 2
assert spy.call_count == 2 + len(dev.children)
@no_emeter