Initial implementation for modularized smartdevice (#757)

The initial steps to modularize the smartdevice. Modules are initialized based on the component negotiation, and each module can indicate which features it supports and which queries should be run during the update cycle.
This commit is contained in:
Teemu R 2024-02-19 18:01:31 +01:00 committed by GitHub
parent e86dcb6bf5
commit 11719991c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 408 additions and 156 deletions

View File

@ -590,10 +590,7 @@ async def state(ctx, dev: Device):
echo("\n\t[bold]== Modules ==[/bold]")
for module in dev.modules.values():
if module.is_supported:
echo(f"\t[green]+ {module}[/green]")
else:
echo(f"\t[red]- {module}[/red]")
echo(f"\t[green]+ {module}[/green]")
if verbose:
echo("\n\t[bold]== Verbose information ==[/bold]")

View File

@ -24,7 +24,8 @@ from ..emeterstatus import EmeterStatus
from ..exceptions import SmartDeviceException
from ..feature import Feature
from ..protocol import BaseProtocol
from .modules import Emeter, IotModule
from .iotmodule import IotModule
from .modules import Emeter
_LOGGER = logging.getLogger(__name__)

View File

@ -1,20 +1,14 @@
"""Base class for all module implementations."""
"""Base class for IOT module implementations."""
import collections
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict
from ...exceptions import SmartDeviceException
from ...feature import Feature
if TYPE_CHECKING:
from kasa.iot import IotDevice
from ..exceptions import SmartDeviceException
from ..module import Module
_LOGGER = logging.getLogger(__name__)
# TODO: This is used for query construcing
# TODO: This is used for query constructing, check for a better place
def merge(d, u):
"""Update dict recursively."""
for k, v in u.items():
@ -25,32 +19,16 @@ def merge(d, u):
return d
class IotModule(ABC):
"""Base class implemention for all modules.
class IotModule(Module):
"""Base class implemention for all IOT modules."""
The base classes should implement `query` to return the query they want to be
executed during the regular update cycle.
"""
def call(self, method, params=None):
"""Call the given method with the given parameters."""
return self._device._query_helper(self._module, method, params)
def __init__(self, device: "IotDevice", module: str):
self._device = device
self._module = module
self._module_features: Dict[str, Feature] = {}
def _add_feature(self, feature: Feature):
"""Add module feature."""
feature_name = f"{self._module}_{feature.name}"
if feature_name in self._module_features:
raise SmartDeviceException("Duplicate name detected %s" % feature_name)
self._module_features[feature_name] = feature
@abstractmethod
def query(self):
"""Query to execute during the update cycle.
The inheriting modules implement this to include their wanted
queries to the query that gets executed when Device.update() gets called.
"""
def query_for_command(self, query, params=None):
"""Create a request object for the given parameters."""
return self._device._create_request(self._module, query, params)
@property
def estimated_query_response_size(self):
@ -80,17 +58,3 @@ class IotModule(ABC):
return True
return "err_code" not in self.data
def call(self, method, params=None):
"""Call the given method with the given parameters."""
return self._device._query_helper(self._module, method, params)
def query_for_command(self, query, params=None):
"""Create a request object for the given parameters."""
return self._device._create_request(self._module, query, params)
def __repr__(self) -> str:
return (
f"<Module {self.__class__.__name__} ({self._module})"
f" for {self._device.host}>"
)

View File

@ -4,7 +4,6 @@ from .antitheft import Antitheft
from .cloud import Cloud
from .countdown import Countdown
from .emeter import Emeter
from .module import IotModule
from .motion import Motion
from .rulemodule import Rule, RuleModule
from .schedule import Schedule
@ -17,7 +16,6 @@ __all__ = [
"Cloud",
"Countdown",
"Emeter",
"IotModule",
"Motion",
"Rule",
"RuleModule",

View File

@ -1,5 +1,5 @@
"""Implementation of the ambient light (LAS) module found in some dimmers."""
from .module import IotModule
from ..iotmodule import IotModule
# TODO create tests and use the config reply there
# [{"hw_id":0,"enable":0,"dark_index":1,"min_adc":0,"max_adc":2450,

View File

@ -5,7 +5,7 @@ except ImportError:
from pydantic import BaseModel
from ...feature import Feature, FeatureType
from .module import IotModule
from ..iotmodule import IotModule
class CloudInfo(BaseModel):

View File

@ -3,7 +3,7 @@ from enum import Enum
from typing import Optional
from ...exceptions import SmartDeviceException
from .module import IotModule
from ..iotmodule import IotModule
class Range(Enum):

View File

@ -9,7 +9,7 @@ except ImportError:
from pydantic import BaseModel
from .module import IotModule, merge
from ..iotmodule import IotModule, merge
class Action(Enum):

View File

@ -2,7 +2,7 @@
from datetime import datetime
from ...exceptions import SmartDeviceException
from .module import IotModule, merge
from ..iotmodule import IotModule, merge
class Time(IotModule):

View File

@ -2,7 +2,7 @@
from datetime import datetime
from typing import Dict
from .module import IotModule, merge
from ..iotmodule import IotModule, merge
class Usage(IotModule):

49
kasa/module.py Normal file
View File

@ -0,0 +1,49 @@
"""Base class for all module implementations."""
import logging
from abc import ABC, abstractmethod
from typing import Dict
from .device import Device
from .exceptions import SmartDeviceException
from .feature import Feature
_LOGGER = logging.getLogger(__name__)
class Module(ABC):
"""Base class implemention for all modules.
The base classes should implement `query` to return the query they want to be
executed during the regular update cycle.
"""
def __init__(self, device: "Device", module: str):
self._device = device
self._module = module
self._module_features: Dict[str, Feature] = {}
@abstractmethod
def query(self):
"""Query to execute during the update cycle.
The inheriting modules implement this to include their wanted
queries to the query that gets executed when Device.update() gets called.
"""
@property
@abstractmethod
def data(self):
"""Return the module specific raw data from the last update."""
def _add_feature(self, feature: Feature):
"""Add module feature."""
feat_name = f"{self._module}_{feature.name}"
if feat_name in self._module_features:
raise SmartDeviceException("Duplicate name detected %s" % feat_name)
self._module_features[feat_name] = feature
def __repr__(self) -> str:
return (
f"<Module {self.__class__.__name__} ({self._module})"
f" for {self._device.host}>"
)

View File

@ -0,0 +1,7 @@
"""Modules for SMART devices."""
from .childdevicemodule import ChildDeviceModule
from .devicemodule import DeviceModule
from .energymodule import EnergyModule
from .timemodule import TimeModule
__all__ = ["TimeModule", "EnergyModule", "DeviceModule", "ChildDeviceModule"]

View File

@ -0,0 +1,9 @@
"""Implementation for child devices."""
from ..smartmodule import SmartModule
class ChildDeviceModule(SmartModule):
"""Implementation for child devices."""
REQUIRED_COMPONENT = "child_device"
QUERY_GETTER_NAME = "get_child_device_list"

View File

@ -0,0 +1,21 @@
"""Implementation of device module."""
from typing import Dict
from ..smartmodule import SmartModule
class DeviceModule(SmartModule):
"""Implementation of device module."""
REQUIRED_COMPONENT = "device"
def query(self) -> Dict:
"""Query to execute during the update cycle."""
query = {
"get_device_info": None,
}
# Device usage is not available on older firmware versions
if self._device._components[self.REQUIRED_COMPONENT] >= 2:
query["get_device_usage"] = None
return query

View File

@ -0,0 +1,88 @@
"""Implementation of energy monitoring module."""
from typing import TYPE_CHECKING, Dict, Optional
from ...emeterstatus import EmeterStatus
from ...feature import Feature
from ..smartmodule import SmartModule
if TYPE_CHECKING:
from ..smartdevice import SmartDevice
class EnergyModule(SmartModule):
"""Implementation of energy monitoring module."""
REQUIRED_COMPONENT = "energy_monitoring"
def __init__(self, device: "SmartDevice", module: str):
super().__init__(device, module)
self._add_feature(
Feature(
device,
name="Current consumption",
attribute_getter="current_power",
container=self,
)
) # W or mW?
self._add_feature(
Feature(
device,
name="Today's consumption",
attribute_getter="emeter_today",
container=self,
)
) # Wh or kWh?
self._add_feature(
Feature(
device,
name="This month's consumption",
attribute_getter="emeter_this_month",
container=self,
)
) # Wh or kWH?
def query(self) -> Dict:
"""Query to execute during the update cycle."""
return {
"get_energy_usage": None,
# The current_power in get_energy_usage is more precise (mw vs. w),
# making this rather useless, but maybe there are version differences?
"get_current_power": None,
}
@property
def current_power(self):
"""Current power."""
return self.emeter_realtime.power
@property
def energy(self):
"""Return get_energy_usage results."""
return self.data["get_energy_usage"]
@property
def emeter_realtime(self):
"""Get the emeter status."""
# TODO: Perhaps we should get rid of emeterstatus altogether for smartdevices
return EmeterStatus(
{
"power_mw": self.energy.get("current_power"),
"total": self._convert_energy_data(
self.energy.get("today_energy"), 1 / 1000
),
}
)
@property
def emeter_this_month(self) -> Optional[float]:
"""Get the emeter value for this month."""
return self._convert_energy_data(self.energy.get("month_energy"), 1 / 1000)
@property
def emeter_today(self) -> Optional[float]:
"""Get the emeter value for today."""
return self._convert_energy_data(self.energy.get("today_energy"), 1 / 1000)
def _convert_energy_data(self, data, scale) -> Optional[float]:
"""Return adjusted emeter information."""
return data if not data else data * scale

View File

@ -0,0 +1,52 @@
"""Implementation of time module."""
from datetime import datetime, timedelta, timezone
from time import mktime
from typing import TYPE_CHECKING, cast
from ...feature import Feature
from ..smartmodule import SmartModule
if TYPE_CHECKING:
from ..smartdevice import SmartDevice
class TimeModule(SmartModule):
"""Implementation of device_local_time."""
REQUIRED_COMPONENT = "time"
QUERY_GETTER_NAME = "get_device_time"
def __init__(self, device: "SmartDevice", module: str):
super().__init__(device, module)
self._add_feature(
Feature(
device=device,
name="Time",
attribute_getter="time",
container=self,
)
)
@property
def time(self) -> datetime:
"""Return device's current datetime."""
td = timedelta(minutes=cast(float, self.data.get("time_diff")))
if self.data.get("region"):
tz = timezone(td, str(self.data.get("region")))
else:
# in case the device returns a blank region this will result in the
# tzname being a UTC offset
tz = timezone(td)
return datetime.fromtimestamp(
cast(float, self.data.get("timestamp")),
tz=tz,
)
async def set_time(self, dt: datetime):
"""Set device time."""
unixtime = mktime(dt.timetuple())
return await self.call(
"set_device_time",
{"timestamp": unixtime, "time_diff": dt.utcoffset(), "region": dt.tzname()},
)

View File

@ -24,9 +24,6 @@ class SmartChildDevice(SmartDevice):
self._parent = parent
self._id = child_id
self.protocol = _ChildProtocolWrapper(child_id, parent.protocol)
# TODO: remove the assignment after modularization is done,
# currently required to allow accessing time-related properties
self._time = parent._time
self._device_type = DeviceType.StripSocket
async def update(self, update_children: bool = True):

View File

@ -1,7 +1,7 @@
"""Module for a SMART device."""
import base64
import logging
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, cast
from ..aestransport import AesTransport
@ -12,6 +12,13 @@ from ..emeterstatus import EmeterStatus
from ..exceptions import AuthenticationException, SmartDeviceException, SmartErrorCode
from ..feature import Feature, FeatureType
from ..smartprotocol import SmartProtocol
from .modules import ( # noqa: F401
ChildDeviceModule,
DeviceModule,
EnergyModule,
TimeModule,
)
from .smartmodule import SmartModule
_LOGGER = logging.getLogger(__name__)
@ -37,9 +44,8 @@ class SmartDevice(Device):
self._components_raw: Optional[Dict[str, Any]] = None
self._components: Dict[str, int] = {}
self._children: Dict[str, "SmartChildDevice"] = {}
self._energy: Dict[str, Any] = {}
self._state_information: Dict[str, Any] = {}
self._time: Dict[str, Any] = {}
self.modules: Dict[str, SmartModule] = {}
async def _initialize_children(self):
"""Initialize children for power strips."""
@ -79,67 +85,43 @@ class SmartDevice(Device):
f"{request} not found in {responses} for device {self.host}"
)
async def _negotiate(self):
resp = await self.protocol.query("component_nego")
self._components_raw = resp["component_nego"]
self._components = {
comp["id"]: int(comp["ver_code"])
for comp in self._components_raw["component_list"]
}
async def update(self, update_children: bool = True):
"""Update the device."""
if self.credentials is None and self.credentials_hash is None:
raise AuthenticationException("Tapo plug requires authentication.")
if self._components_raw is None:
resp = await self.protocol.query("component_nego")
self._components_raw = resp["component_nego"]
self._components = {
comp["id"]: int(comp["ver_code"])
for comp in self._components_raw["component_list"]
}
await self._negotiate()
await self._initialize_modules()
extra_reqs: Dict[str, Any] = {}
req: Dict[str, Any] = {}
if "child_device" in self._components:
extra_reqs = {**extra_reqs, "get_child_device_list": None}
if "energy_monitoring" in self._components:
extra_reqs = {
**extra_reqs,
"get_energy_usage": None,
"get_current_power": None,
}
if self._components.get("device", 0) >= 2:
extra_reqs = {
**extra_reqs,
"get_device_usage": None,
}
req = {
"get_device_info": None,
"get_device_time": None,
**extra_reqs,
}
# TODO: this could be optimized by constructing the query only once
for module in self.modules.values():
req.update(module.query())
resp = await self.protocol.query(req)
self._info = self._try_get_response(resp, "get_device_info")
self._time = self._try_get_response(resp, "get_device_time", {})
# Device usage is not available on older firmware versions
self._usage = self._try_get_response(resp, "get_device_usage", {})
# Emeter is not always available, but we set them still for now.
self._energy = self._try_get_response(resp, "get_energy_usage", {})
self._emeter = self._try_get_response(resp, "get_current_power", {})
self._last_update = {
"components": self._components_raw,
"info": self._info,
"usage": self._usage,
"time": self._time,
"energy": self._energy,
"emeter": self._emeter,
**resp,
"child_info": self._try_get_response(resp, "get_child_device_list", {}),
}
if child_info := self._last_update.get("child_info"):
if not self.children:
await self._initialize_children()
for info in child_info["child_device_list"]:
self._children[info["device_id"]].update_internal_state(info)
@ -152,11 +134,32 @@ class SmartDevice(Device):
async def _initialize_modules(self):
"""Initialize modules based on component negotiation response."""
if "energy_monitoring" in self._components:
self.emeter_type = "emeter"
from .smartmodule import SmartModule
for mod in SmartModule.REGISTERED_MODULES.values():
_LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT)
if mod.REQUIRED_COMPONENT in self._components:
_LOGGER.debug(
"Found required %s, adding %s to modules.",
mod.REQUIRED_COMPONENT,
mod.__name__,
)
module = mod(self, mod.REQUIRED_COMPONENT)
self.modules[module.name] = module
async def _initialize_features(self):
"""Initialize device features."""
if "device_on" in self._info:
self._add_feature(
Feature(
self,
"State",
attribute_getter="is_on",
attribute_setter="set_state",
type=FeatureType.Switch,
)
)
self._add_feature(
Feature(
self,
@ -200,6 +203,10 @@ class SmartDevice(Device):
)
)
for module in self.modules.values():
for feat in module._module_features.values():
self._add_feature(feat)
@property
def sys_info(self) -> Dict[str, Any]:
"""Returns the device info."""
@ -221,17 +228,8 @@ class SmartDevice(Device):
@property
def time(self) -> datetime:
"""Return the time."""
td = timedelta(minutes=cast(float, self._time.get("time_diff")))
if self._time.get("region"):
tz = timezone(td, str(self._time.get("region")))
else:
# in case the device returns a blank region this will result in the
# tzname being a UTC offset
tz = timezone(td)
return datetime.fromtimestamp(
cast(float, self._time.get("timestamp")),
tz=tz,
)
_timemod = cast(TimeModule, self.modules["TimeModule"])
return _timemod.time
@property
def timezone(self) -> Dict:
@ -308,20 +306,27 @@ class SmartDevice(Device):
@property
def has_emeter(self) -> bool:
"""Return if the device has emeter."""
return "energy_monitoring" in self._components
return "EnergyModule" in self.modules
@property
def is_on(self) -> bool:
"""Return true if the device is on."""
return bool(self._info.get("device_on"))
async def set_state(self, on: bool): # TODO: better name wanted.
"""Set the device state.
See :meth:`is_on`.
"""
return await self.protocol.query({"set_device_info": {"device_on": on}})
async def turn_on(self, **kwargs):
"""Turn on the device."""
await self.protocol.query({"set_device_info": {"device_on": True}})
await self.set_state(True)
async def turn_off(self, **kwargs):
"""Turn off the device."""
await self.protocol.query({"set_device_info": {"device_on": False}})
await self.set_state(False)
def update_from_discover_info(self, info):
"""Update state from info from the discover call."""
@ -330,43 +335,28 @@ class SmartDevice(Device):
async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings."""
self._verify_emeter()
resp = await self.protocol.query("get_energy_usage")
self._energy = resp["get_energy_usage"]
return self.emeter_realtime
def _convert_energy_data(self, data, scale) -> Optional[float]:
"""Return adjusted emeter information."""
return data if not data else data * scale
def _verify_emeter(self) -> None:
"""Raise an exception if there is no emeter."""
_LOGGER.warning("Deprecated, use `emeter_realtime`.")
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")
if self.emeter_type not in self._last_update:
raise SmartDeviceException("update() required prior accessing emeter")
return self.emeter_realtime
@property
def emeter_realtime(self) -> EmeterStatus:
"""Get the emeter status."""
return EmeterStatus(
{
"power_mw": self._energy.get("current_power"),
"total": self._convert_energy_data(
self._energy.get("today_energy"), 1 / 1000
),
}
)
energy = cast(EnergyModule, self.modules["EnergyModule"])
return energy.emeter_realtime
@property
def emeter_this_month(self) -> Optional[float]:
"""Get the emeter value for this month."""
return self._convert_energy_data(self._energy.get("month_energy"), 1 / 1000)
energy = cast(EnergyModule, self.modules["EnergyModule"])
return energy.emeter_this_month
@property
def emeter_today(self) -> Optional[float]:
"""Get the emeter value for today."""
return self._convert_energy_data(self._energy.get("today_energy"), 1 / 1000)
energy = cast(EnergyModule, self.modules["EnergyModule"])
return energy.emeter_today
@property
def on_since(self) -> Optional[datetime]:
@ -377,7 +367,11 @@ class SmartDevice(Device):
):
return None
on_time = cast(float, on_time)
return datetime.now().replace(microsecond=0) - timedelta(seconds=on_time)
if (timemod := self.modules.get("TimeModule")) is not None:
timemod = cast(TimeModule, timemod)
return timemod.time - timedelta(seconds=on_time)
else: # We have no device time, use current local time.
return datetime.now().replace(microsecond=0) - timedelta(seconds=on_time)
async def wifi_scan(self) -> List[WifiNetwork]:
"""Scan for available wifi networks."""
@ -439,7 +433,7 @@ class SmartDevice(Device):
"password": base64.b64encode(password.encode()).decode(),
"ssid": base64.b64encode(ssid.encode()).decode(),
},
"time": self.internal_state["time"],
"time": self.internal_state["get_device_time"],
}
# The device does not respond to the request but changes the settings
@ -458,13 +452,13 @@ class SmartDevice(Device):
This will replace the existing authentication credentials on the device.
"""
t = self.internal_state["time"]
time_data = self.internal_state["get_device_time"]
payload = {
"account": {
"username": base64.b64encode(username.encode()).decode(),
"password": base64.b64encode(password.encode()).decode(),
},
"time": t,
"time": time_data,
}
return await self.protocol.query({"set_qs_info": payload})

73
kasa/smart/smartmodule.py Normal file
View File

@ -0,0 +1,73 @@
"""Base implementation for SMART modules."""
import logging
from typing import TYPE_CHECKING, Dict, Type
from ..exceptions import SmartDeviceException
from ..module import Module
if TYPE_CHECKING:
from .smartdevice import SmartDevice
_LOGGER = logging.getLogger(__name__)
class SmartModule(Module):
"""Base class for SMART modules."""
NAME: str
REQUIRED_COMPONENT: str
QUERY_GETTER_NAME: str
REGISTERED_MODULES: Dict[str, Type["SmartModule"]] = {}
def __init__(self, device: "SmartDevice", module: str):
self._device: SmartDevice
super().__init__(device, module)
def __init_subclass__(cls, **kwargs):
assert cls.REQUIRED_COMPONENT is not None # noqa: S101
name = getattr(cls, "NAME", cls.__name__)
_LOGGER.debug("Registering %s" % cls)
cls.REGISTERED_MODULES[name] = cls
@property
def name(self) -> str:
"""Name of the module."""
return getattr(self, "NAME", self.__class__.__name__)
def query(self) -> Dict:
"""Query to execute during the update cycle.
Default implementation uses the raw query getter w/o parameters.
"""
return {self.QUERY_GETTER_NAME: None}
def call(self, method, params=None):
"""Call a method.
Just a helper method.
"""
return self._device._query_helper(method, params)
@property
def data(self):
"""Return response data for the module.
If module performs only a single query, the resulting response is unwrapped.
"""
q = self.query()
q_keys = list(q.keys())
# TODO: hacky way to check if update has been called.
if q_keys[0] not in self._device._last_update:
raise SmartDeviceException(
f"You need to call update() prior accessing module data"
f" for '{self._module}'"
)
filtered_data = {
k: v for k, v in self._device._last_update.items() if k in q_keys
}
if len(filtered_data) == 1:
return next(iter(filtered_data.values()))
return filtered_data

View File

@ -60,6 +60,11 @@ async def test_childdevice_properties(dev: SmartChildDevice):
)
for prop in properties:
name, _ = prop
# Skip emeter and time properties
# TODO: needs API cleanup, emeter* should probably be removed in favor
# of access through features/modules, handling of time* needs decision.
if name.startswith("emeter_") or name.startswith("time"):
continue
try:
_ = getattr(first, name)
except Exception as ex:

View File

@ -310,16 +310,13 @@ async def test_modules_not_supported(dev: IotDevice):
@device_smart
async def test_update_sub_errors(dev: SmartDevice, caplog):
async def test_try_get_response(dev: SmartDevice, caplog):
mock_response: dict = {
"get_device_info": {},
"get_device_usage": SmartErrorCode.PARAMS_ERROR,
"get_device_time": {},
"get_device_info": SmartErrorCode.PARAMS_ERROR,
}
caplog.set_level(logging.DEBUG)
with patch.object(dev.protocol, "query", return_value=mock_response):
await dev.update()
msg = "Error PARAMS_ERROR(-1008) getting request get_device_usage for device 127.0.0.123"
dev._try_get_response(mock_response, "get_device_info", {})
msg = "Error PARAMS_ERROR(-1008) getting request get_device_info for device 127.0.0.123"
assert msg in caplog.text