Avoid crashing on childdevice property accesses (#732)

* Avoid crashing on childdevice property accesses

* Push updates from parent to child
This commit is contained in:
Teemu R 2024-02-02 17:29:14 +01:00 committed by GitHub
parent 1f62aee7b6
commit 1f15bcda7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 100 additions and 32 deletions

View File

@ -79,8 +79,11 @@ class EmeterStatus(dict):
return super().__getitem__(item[: item.find("_")]) * 1000 return super().__getitem__(item[: item.find("_")]) * 1000
else: # downscale else: # downscale
for i in super().keys(): # noqa: SIM118 for i in super().keys(): # noqa: SIM118
if i.startswith(item): if (
return self.__getitem__(i) / 1000 i.startswith(item)
and (value := self.__getitem__(i)) is not None
):
return value / 1000
_LOGGER.debug(f"Unable to find value for '{item}'") _LOGGER.debug(f"Unable to find value for '{item}'")
return None return None

View File

@ -1,8 +1,8 @@
"""Child device implementation.""" """Child device implementation."""
from typing import Dict, Optional from typing import Optional
from ..device_type import DeviceType
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
from ..exceptions import SmartDeviceException
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
from .tapodevice import TapoDevice from .tapodevice import TapoDevice
@ -24,21 +24,18 @@ class ChildDevice(TapoDevice):
self._parent = parent self._parent = parent
self._id = child_id self._id = child_id
self.protocol = _ChildProtocolWrapper(child_id, parent.protocol) self.protocol = _ChildProtocolWrapper(child_id, parent.protocol)
# TODO: remove the assignment after modularization is done,
# currently required to allow accessing time-related properties
self._time = parent._time
self._device_type = DeviceType.StripSocket
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""We just set the info here accordingly.""" """Noop update. The parent updates our internals."""
def _get_child_info() -> Dict: def update_internal_state(self, info):
"""Return the subdevice information for this device.""" """Set internal state for the child."""
for child in self._parent._last_update["child_info"]["child_device_list"]: # TODO: cleanup the _last_update, _sys_info, _info, _data mess.
if child["device_id"] == self._id: self._last_update = self._sys_info = self._info = info
return child
raise SmartDeviceException(
f"Unable to find child device with id {self._id}"
)
self._last_update = self._sys_info = self._info = _get_child_info()
def __repr__(self): def __repr__(self):
return f"<ChildDevice {self.alias} of {self._parent}>" return f"<ChildDevice {self.alias} of {self._parent}>"

View File

@ -2,7 +2,7 @@
import base64 import base64
import logging import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Set, cast from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, cast
from ..aestransport import AesTransport from ..aestransport import AesTransport
from ..device_type import DeviceType from ..device_type import DeviceType
@ -15,6 +15,9 @@ from ..smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from .childdevice import ChildDevice
class TapoDevice(SmartDevice): class TapoDevice(SmartDevice):
"""Base class to represent a TAPO device.""" """Base class to represent a TAPO device."""
@ -32,20 +35,40 @@ class TapoDevice(SmartDevice):
super().__init__(host=host, config=config, protocol=_protocol) super().__init__(host=host, config=config, protocol=_protocol)
self.protocol: SmartProtocol self.protocol: SmartProtocol
self._components_raw: Optional[Dict[str, Any]] = None self._components_raw: Optional[Dict[str, Any]] = None
self._components: Dict[str, int] self._components: Dict[str, int] = {}
self._children: Dict[str, "ChildDevice"] = {}
self._energy: Dict[str, Any] = {}
self._state_information: Dict[str, Any] = {} self._state_information: Dict[str, Any] = {}
async def _initialize_children(self): async def _initialize_children(self):
"""Initialize children for power strips."""
children = self._last_update["child_info"]["child_device_list"] children = self._last_update["child_info"]["child_device_list"]
# TODO: Use the type information to construct children, # TODO: Use the type information to construct children,
# as hubs can also have them. # as hubs can also have them.
from .childdevice import ChildDevice from .childdevice import ChildDevice
self.children = [ self._children = {
ChildDevice(parent=self, child_id=child["device_id"]) for child in children child["device_id"]: ChildDevice(parent=self, child_id=child["device_id"])
] for child in children
}
self._device_type = DeviceType.Strip self._device_type = DeviceType.Strip
@property
def children(self):
"""Return list of children.
This is just to keep the existing SmartDevice API intact.
"""
return list(self._children.values())
@children.setter
def children(self, children):
"""Initialize from a list of children.
This is just to keep the existing SmartDevice API intact.
"""
self._children = {child["device_id"]: child for child in children}
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""Update the device.""" """Update the device."""
if self.credentials is None and self.credentials_hash is None: if self.credentials is None and self.credentials_hash is None:
@ -88,7 +111,7 @@ class TapoDevice(SmartDevice):
self._energy = resp.get("get_energy_usage", {}) self._energy = resp.get("get_energy_usage", {})
self._emeter = resp.get("get_current_power", {}) self._emeter = resp.get("get_current_power", {})
self._last_update = self._data = { self._last_update = {
"components": self._components_raw, "components": self._components_raw,
"info": self._info, "info": self._info,
"usage": self._usage, "usage": self._usage,
@ -98,13 +121,13 @@ class TapoDevice(SmartDevice):
"child_info": resp.get("get_child_device_list", {}), "child_info": resp.get("get_child_device_list", {}),
} }
if self._last_update["child_info"]: if child_info := self._last_update.get("child_info"):
if not self.children: if not self.children:
await self._initialize_children() await self._initialize_children()
for child in self.children: for info in child_info["child_device_list"]:
await child.update() self._children[info["device_id"]].update_internal_state(info)
_LOGGER.debug("Got an update: %s", self._data) _LOGGER.debug("Got an update: %s", self._last_update)
async def _initialize_modules(self): async def _initialize_modules(self):
"""Initialize modules based on component negotiation response.""" """Initialize modules based on component negotiation response."""
@ -192,7 +215,7 @@ class TapoDevice(SmartDevice):
@property @property
def internal_state(self) -> Any: def internal_state(self) -> Any:
"""Return all the internal state data.""" """Return all the internal state data."""
return self._data return self._last_update
async def _query_helper( async def _query_helper(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
@ -204,10 +227,13 @@ class TapoDevice(SmartDevice):
@property @property
def state_information(self) -> Dict[str, Any]: def state_information(self) -> Dict[str, Any]:
"""Return the key state information.""" """Return the key state information."""
ssid = self._info.get("ssid")
ssid = base64.b64decode(ssid).decode() if ssid else "No SSID"
return { return {
"overheated": self._info.get("overheated"), "overheated": self._info.get("overheated"),
"signal_level": self._info.get("signal_level"), "signal_level": self._info.get("signal_level"),
"SSID": base64.b64decode(str(self._info.get("ssid"))).decode(), "SSID": ssid,
} }
@property @property

View File

@ -1,4 +1,10 @@
import inspect
import sys
import pytest
from kasa.smartprotocol import _ChildProtocolWrapper from kasa.smartprotocol import _ChildProtocolWrapper
from kasa.tapo.childdevice import ChildDevice
from .conftest import strip_smart from .conftest import strip_smart
@ -19,12 +25,48 @@ def test_childdevice_init(dev, dummy_protocol, mocker):
@strip_smart @strip_smart
async def test_childdevice_update(dev, dummy_protocol, mocker): async def test_childdevice_update(dev, dummy_protocol, mocker):
"""Test that parent update updates children.""" """Test that parent update updates children."""
assert len(dev.children) > 0 child_info = dev._last_update["child_info"]
child_list = child_info["child_device_list"]
assert len(dev.children) == child_info["sum"]
first = dev.children[0] first = dev.children[0]
child_update = mocker.patch.object(first, "update")
await dev.update() await dev.update()
child_update.assert_called()
assert dev._last_update != first._last_update assert dev._last_update != first._last_update
assert dev._last_update["child_info"]["child_device_list"][0] == first._last_update assert child_list[0] == first._last_update
@strip_smart
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="exceptiongroup requires python3.11+",
)
async def test_childdevice_properties(dev: ChildDevice):
"""Check that accessing childdevice properties do not raise exceptions."""
assert len(dev.children) > 0
first = dev.children[0]
assert first.is_strip_socket
# children do not have children
assert not first.children
def _test_property_getters():
"""Try accessing all properties and return a list of encountered exceptions."""
exceptions = []
properties = inspect.getmembers(
first.__class__, lambda o: isinstance(o, property)
)
for prop in properties:
name, _ = prop
try:
_ = getattr(first, name)
except Exception as ex:
exceptions.append(ex)
return exceptions
exceptions = list(_test_property_getters())
if exceptions:
raise ExceptionGroup("Accessing child properties caused exceptions", exceptions)