diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 1efc0773..11c7d1c9 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -11,13 +11,14 @@ Stroetmann which is licensed under the Apache License, Version 2.0. You may obtain a copy of the license at http://www.apache.org/licenses/LICENSE-2.0 """ +import collections.abc import functools import inspect import logging from dataclasses import dataclass from datetime import datetime, timedelta from enum import Enum, auto -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set from .emeterstatus import EmeterStatus from .exceptions import SmartDeviceException @@ -51,6 +52,16 @@ class WifiNetwork: rssi: Optional[int] = None +def merge(d, u): + """Update dict recursively.""" + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = merge(d.get(k, {}), v) + else: + d[k] = v + return d + + def requires_update(f): """Indicate that `update` should be called before accessing this method.""" # noqa: D202 if inspect.iscoroutinefunction(f): @@ -204,6 +215,11 @@ class SmartDevice: return request + def _verify_emeter(self) -> None: + """Raise an exception if there is no emeter.""" + if not self.has_emeter: + raise SmartDeviceException("Device has no emeter") + async def _query_helper( self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None ) -> Any: @@ -240,13 +256,17 @@ class SmartDevice: return result + @property # type: ignore + @requires_update + def features(self) -> Set[str]: + """Return a set of features that the device supports.""" + return set(self.sys_info["feature"].split(":")) + @property # type: ignore @requires_update def has_emeter(self) -> bool: """Return True if device has an energy meter.""" - sys_info = self.sys_info - features = sys_info["feature"].split(":") - return "ENE" in features + return "ENE" in self.features async def get_sys_info(self) -> Dict[str, Any]: """Retrieve system information.""" @@ -374,10 +394,8 @@ class SmartDevice: @requires_update def rssi(self) -> Optional[int]: """Return WiFi signal strenth (rssi).""" - sys_info = self.sys_info - if "rssi" in sys_info: - return int(sys_info["rssi"]) - return None + rssi = self.sys_info.get("rssi") + return None if rssi is None else int(rssi) @property # type: ignore @requires_update @@ -410,16 +428,12 @@ class SmartDevice: @requires_update def emeter_realtime(self) -> EmeterStatus: """Return current energy readings.""" - if not self.has_emeter: - raise SmartDeviceException("Device has no emeter") - + self._verify_emeter() return EmeterStatus(self._last_update[self.emeter_type]["get_realtime"]) async def get_emeter_realtime(self) -> EmeterStatus: """Retrieve current energy readings.""" - if not self.has_emeter: - raise SmartDeviceException("Device has no emeter") - + self._verify_emeter() return EmeterStatus(await self._query_helper(self.emeter_type, "get_realtime")) def _create_emeter_request(self, year: int = None, month: int = None): @@ -429,23 +443,12 @@ class SmartDevice: if month is None: month = datetime.now().month - import collections.abc - - def update(d, u): - """Update dict recursively.""" - for k, v in u.items(): - if isinstance(v, collections.abc.Mapping): - d[k] = update(d.get(k, {}), v) - else: - d[k] = v - return d - req: Dict[str, Any] = {} - update(req, self._create_request(self.emeter_type, "get_realtime")) - update( + merge(req, self._create_request(self.emeter_type, "get_realtime")) + merge( req, self._create_request(self.emeter_type, "get_monthstat", {"year": year}) ) - update( + merge( req, self._create_request( self.emeter_type, "get_daystat", {"month": month, "year": year} @@ -458,9 +461,7 @@ class SmartDevice: @requires_update def emeter_today(self) -> Optional[float]: """Return today's energy consumption in kWh.""" - if not self.has_emeter: - raise SmartDeviceException("Device has no emeter") - + self._verify_emeter() raw_data = self._last_update[self.emeter_type]["get_daystat"]["day_list"] data = self._emeter_convert_emeter_data(raw_data) today = datetime.now().day @@ -474,9 +475,7 @@ class SmartDevice: @requires_update def emeter_this_month(self) -> Optional[float]: """Return this month's energy consumption in kWh.""" - if not self.has_emeter: - raise SmartDeviceException("Device has no emeter") - + self._verify_emeter() raw_data = self._last_update[self.emeter_type]["get_monthstat"]["month_list"] data = self._emeter_convert_emeter_data(raw_data) current_month = datetime.now().month @@ -516,9 +515,7 @@ class SmartDevice: :param kwh: return usage in kWh (default: True) :return: mapping of day of month to value """ - if not self.has_emeter: - raise SmartDeviceException("Device has no emeter") - + self._verify_emeter() if year is None: year = datetime.now().year if month is None: @@ -538,9 +535,7 @@ class SmartDevice: :param kwh: return usage in kWh (default: True) :return: dict: mapping of month to value """ - if not self.has_emeter: - raise SmartDeviceException("Device has no emeter") - + self._verify_emeter() if year is None: year = datetime.now().year @@ -553,17 +548,13 @@ class SmartDevice: @requires_update async def erase_emeter_stats(self) -> Dict: """Erase energy meter statistics.""" - if not self.has_emeter: - raise SmartDeviceException("Device has no emeter") - + self._verify_emeter() return await self._query_helper(self.emeter_type, "erase_emeter_stat", None) @requires_update async def current_consumption(self) -> float: """Get the current power consumption in Watt.""" - if not self.has_emeter: - raise SmartDeviceException("Device has no emeter") - + self._verify_emeter() response = EmeterStatus(await self.get_emeter_realtime()) return float(response["power"]) diff --git a/kasa/smartstrip.py b/kasa/smartstrip.py index a5351c5b..c1235920 100755 --- a/kasa/smartstrip.py +++ b/kasa/smartstrip.py @@ -6,6 +6,7 @@ from typing import Any, DefaultDict, Dict, Optional from kasa.smartdevice import ( DeviceType, + EmeterStatus, SmartDevice, SmartDeviceException, requires_update, @@ -15,6 +16,15 @@ from kasa.smartplug import SmartPlug _LOGGER = logging.getLogger(__name__) +def merge_sums(dicts): + """Merge the sum of dicts.""" + total_dict: DefaultDict[int, float] = defaultdict(lambda: 0.0) + for sum_dict in dicts: + for day, value in sum_dict.items(): + total_dict[day] += value + return total_dict + + class SmartStrip(SmartDevice): """Representation of a TP-Link Smart Power Strip. @@ -75,11 +85,7 @@ class SmartStrip(SmartDevice): @requires_update def is_on(self) -> bool: """Return if any of the outlets are on.""" - for plug in self.children: - is_on = plug.is_on - if is_on: - return True - return False + return any(plug.is_on for plug in self.children) async def update(self): """Update some of the attributes. @@ -97,6 +103,10 @@ class SmartStrip(SmartDevice): SmartStripPlug(self.host, parent=self, child_id=child["id"]) ) + if self.has_emeter: + for plug in self.children: + await plug.update() + async def turn_on(self, **kwargs): """Turn the strip on.""" await self._query_helper("system", "set_relay_state", {"state": 1}) @@ -140,16 +150,16 @@ class SmartStrip(SmartDevice): async def current_consumption(self) -> float: """Get the current power consumption in watts.""" - consumption = sum(await plug.current_consumption() for plug in self.children) + return sum([await plug.current_consumption() for plug in self.children]) - return consumption - - async def set_alias(self, alias: str) -> None: - """Set the alias for the strip. - - :param alias: new alias - """ - return await super().set_alias(alias) + @requires_update + async def get_emeter_realtime(self) -> EmeterStatus: + """Retrieve current energy readings.""" + emeter_rt = await self._async_get_emeter_sum("get_emeter_realtime", {}) + # Voltage is averaged since each read will result + # in a slightly different voltage since they are not atomic + emeter_rt["voltage_mv"] = int(emeter_rt["voltage_mv"] / len(self.children)) + return EmeterStatus(emeter_rt) @requires_update async def get_emeter_daily( @@ -163,14 +173,9 @@ class SmartStrip(SmartDevice): :param kwh: return usage in kWh (default: True) :return: mapping of day of month to value """ - emeter_daily: DefaultDict[int, float] = defaultdict(lambda: 0.0) - for plug in self.children: - plug_emeter_daily = await plug.get_emeter_daily( - year=year, month=month, kwh=kwh - ) - for day, value in plug_emeter_daily.items(): - emeter_daily[day] += value - return emeter_daily + return await self._async_get_emeter_sum( + "get_emeter_daily", {"year": year, "month": month, "kwh": kwh} + ) @requires_update async def get_emeter_monthly(self, year: int = None, kwh: bool = True) -> Dict: @@ -179,13 +184,16 @@ class SmartStrip(SmartDevice): :param year: year for which to retrieve statistics (default: this year) :param kwh: return usage in kWh (default: True) """ - emeter_monthly: DefaultDict[int, float] = defaultdict(lambda: 0.0) - for plug in self.children: - plug_emeter_monthly = await plug.get_emeter_monthly(year=year, kwh=kwh) - for month, value in plug_emeter_monthly: - emeter_monthly[month] += value + return await self._async_get_emeter_sum( + "get_emeter_monthly", {"year": year, "kwh": kwh} + ) - return emeter_monthly + async def _async_get_emeter_sum(self, func: str, kwargs: Dict[str, Any]) -> Dict: + """Retreive emeter stats for a time period from children.""" + self._verify_emeter() + return merge_sums( + [await getattr(plug, func)(**kwargs) for plug in self.children] + ) @requires_update async def erase_emeter_stats(self): @@ -193,6 +201,28 @@ class SmartStrip(SmartDevice): for plug in self.children: await plug.erase_emeter_stats() + @property # type: ignore + @requires_update + def emeter_this_month(self) -> Optional[float]: + """Return this month's energy consumption in kWh.""" + return sum([plug.emeter_this_month for plug in self.children]) + + @property # type: ignore + @requires_update + def emeter_today(self) -> Optional[float]: + """Return this month's energy consumption in kWh.""" + return sum([plug.emeter_today for plug in self.children]) + + @property # type: ignore + @requires_update + def emeter_realtime(self) -> EmeterStatus: + """Return current energy readings.""" + emeter = merge_sums([plug.emeter_realtime for plug in self.children]) + # Voltage is averaged since each read will result + # in a slightly different voltage since they are not atomic + emeter["voltage_mv"] = int(emeter["voltage_mv"] / len(self.children)) + return EmeterStatus(emeter) + class SmartStripPlug(SmartPlug): """Representation of a single socket in a power strip. @@ -214,12 +244,22 @@ class SmartStripPlug(SmartPlug): self._device_type = DeviceType.StripSocket async def update(self): - """Override the update to no-op and inform the user.""" - _LOGGER.warning( - "You called update() on a child device, which has no effect." - "Call update() on the parent device instead." + """Query the device to update the data. + + Needed for properties that are decorated with `requires_update`. + """ + self._last_update = await self.parent.protocol.query( + self.host, self._create_emeter_request() ) - return + + def _create_request( + self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None + ): + request: Dict[str, Any] = { + "context": {"child_ids": [self.child_id]}, + target: {cmd: arg}, + } + return request async def _query_helper( self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None @@ -245,12 +285,6 @@ class SmartStripPlug(SmartPlug): """ return False - @property # type: ignore - @requires_update - def has_emeter(self) -> bool: - """Children have no emeter to my knowledge.""" - return False - @property # type: ignore @requires_update def device_id(self) -> str: diff --git a/kasa/tests/test_emeter.py b/kasa/tests/test_emeter.py index 7f0f95ac..b3d567dd 100644 --- a/kasa/tests/test_emeter.py +++ b/kasa/tests/test_emeter.py @@ -22,9 +22,6 @@ async def test_no_emeter(dev): @has_emeter async def test_get_emeter_realtime(dev): - if dev.is_strip: - pytest.skip("Disabled for strips temporarily") - assert dev.has_emeter current_emeter = await dev.get_emeter_realtime() @@ -34,9 +31,6 @@ async def test_get_emeter_realtime(dev): @has_emeter @pytest.mark.requires_dummy async def test_get_emeter_daily(dev): - if dev.is_strip: - pytest.skip("Disabled for strips temporarily") - assert dev.has_emeter assert await dev.get_emeter_daily(year=1900, month=1) == {} @@ -57,9 +51,6 @@ async def test_get_emeter_daily(dev): @has_emeter @pytest.mark.requires_dummy async def test_get_emeter_monthly(dev): - if dev.is_strip: - pytest.skip("Disabled for strips temporarily") - assert dev.has_emeter assert await dev.get_emeter_monthly(year=1900) == {} @@ -79,9 +70,6 @@ async def test_get_emeter_monthly(dev): @has_emeter async def test_emeter_status(dev): - if dev.is_strip: - pytest.skip("Disabled for strips temporarily") - assert dev.has_emeter d = await dev.get_emeter_realtime() @@ -108,9 +96,6 @@ async def test_erase_emeter_stats(dev): @has_emeter async def test_current_consumption(dev): - if dev.is_strip: - pytest.skip("Disabled for strips temporarily") - if dev.has_emeter: x = await dev.current_consumption() assert isinstance(x, float) diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 002adb90..380cdd1f 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import SmartDeviceException +from kasa.smartstrip import SmartStripPlug from .conftest import handle_turn_on, has_emeter, no_emeter, pytestmark, turn_on from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol @@ -26,7 +27,7 @@ async def test_initial_update_emeter(dev, mocker): dev._last_update = None spy = mocker.spy(dev.protocol, "query") await dev.update() - assert spy.call_count == 2 + assert spy.call_count == 2 + len(dev.children) @no_emeter