Allow calling update directly on child devices and skipping updates on the parent

This commit is contained in:
sdb9696 2024-06-28 19:25:39 +01:00
parent 2a62849987
commit ec1082a228
5 changed files with 124 additions and 13 deletions

View File

@ -233,7 +233,7 @@ class Device(ABC):
return await connect(host=host, config=config) # type: ignore[arg-type] return await connect(host=host, config=config) # type: ignore[arg-type]
@abstractmethod @abstractmethod
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True, update_parent: bool = True):
"""Update the device.""" """Update the device."""
async def disconnect(self): async def disconnect(self):

View File

@ -121,7 +121,18 @@ class IotStrip(IotDevice):
"""Return if any of the outlets are on.""" """Return if any of the outlets are on."""
return any(plug.is_on for plug in self.children) return any(plug.is_on for plug in self.children)
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True, update_parent: bool = True):
"""Update some of the attributes.
Needed for methods that are decorated with `requires_update`.
"""
await self._update(update_children)
async def _update(
self,
update_children: bool = True,
called_from_child: IotStripPlug | None = None,
):
"""Update some of the attributes. """Update some of the attributes.
Needed for methods that are decorated with `requires_update`. Needed for methods that are decorated with `requires_update`.
@ -143,9 +154,11 @@ class IotStrip(IotDevice):
for child in self._children.values(): for child in self._children.values():
await child._initialize_modules() await child._initialize_modules()
if update_children: if called_from_child:
for plug in self.children: await called_from_child._update()
await plug.update() elif update_children:
for child in self._children.values():
await child._update()
if not self.features: if not self.features:
await self._initialize_features() await self._initialize_features()
@ -355,7 +368,17 @@ class IotStripPlug(IotPlug):
for module_feat in module._module_features.values(): for module_feat in module._module_features.values():
self._add_feature(module_feat) self._add_feature(module_feat)
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True, update_parent: bool = True):
"""Query the device to update the data.
Needed for properties that are decorated with `requires_update`.
"""
if update_parent:
await self.parent._update(update_children=False, called_from_child=self)
else:
await self._update()
async def _update(self):
"""Query the device to update the data. """Query the device to update the data.
Needed for properties that are decorated with `requires_update`. Needed for properties that are decorated with `requires_update`.

View File

@ -19,6 +19,8 @@ class SmartChildDevice(SmartDevice):
This wraps the protocol communications and sets internal data for the child. This wraps the protocol communications and sets internal data for the child.
""" """
_parent: SmartDevice
def __init__( def __init__(
self, self,
parent: SmartDevice, parent: SmartDevice,
@ -34,12 +36,26 @@ class SmartChildDevice(SmartDevice):
self._id = info["device_id"] self._id = info["device_id"]
self.protocol = _ChildProtocolWrapper(self._id, parent.protocol) self.protocol = _ChildProtocolWrapper(self._id, parent.protocol)
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True, update_parent: bool = True):
"""Update the device.
Calling update directly on a child device will update the parent
and only this child.
"""
if update_parent:
await self._parent._update(update_children=False, called_from_child=self)
else:
await self._update()
async def _update(self):
"""Update child module info. """Update child module info.
The parent updates our internal info so just update modules with The parent updates our internal info so just update modules with
their own queries. their own queries.
""" """
# Hubs attached devices only update via the parent hub
if self._parent.device_type == DeviceType.Hub:
return
req: dict[str, Any] = {} req: dict[str, Any] = {}
for module in self.modules.values(): for module in self.modules.values():
if mod_query := module.query(): if mod_query := module.query():

View File

@ -147,8 +147,14 @@ class SmartDevice(Device):
if "child_device" in self._components and not self.children: if "child_device" in self._components and not self.children:
await self._initialize_children() await self._initialize_children()
async def update(self, update_children: bool = False): async def update(self, update_children: bool = True, update_parent: bool = True):
"""Update the device.""" """Update the device."""
await self._update(update_children)
async def _update(
self, update_children: bool = True, called_from_child: SmartDevice | None = None
):
"""If called from a child device will only update that child."""
if self.credentials is None and self.credentials_hash is None: if self.credentials is None and self.credentials_hash is None:
raise AuthenticationError("Tapo plug requires authentication.") raise AuthenticationError("Tapo plug requires authentication.")
@ -167,11 +173,13 @@ class SmartDevice(Device):
self._info = self._try_get_response(resp, "get_device_info") self._info = self._try_get_response(resp, "get_device_info")
# Call child update which will only update module calls, info is updated # Call child update which will only update module calls, info is updated
# from get_child_device_list. update_children only affects hub devices, other # from get_child_device_list. If this method is being called by a child
# devices will always update children to prevent errors on module access. # it will only call update on that child
if update_children or self.device_type != DeviceType.Hub: if called_from_child:
await called_from_child._update()
elif update_children:
for child in self._children.values(): for child in self._children.values():
await child.update() await child._update()
if child_info := self._try_get_response(resp, "get_child_device_list", {}): if child_info := self._try_get_response(resp, "get_child_device_list", {}):
for info in child_info["child_device_list"]: for info in child_info["child_device_list"]:
self._children[info["device_id"]]._update_internal_state(info) self._children[info["device_id"]]._update_internal_state(info)

View File

@ -2,13 +2,21 @@ import inspect
import sys import sys
import pytest import pytest
from pytest_mock import MockerFixture
from kasa import Device
from kasa.device_type import DeviceType from kasa.device_type import DeviceType
from kasa.smart.smartchilddevice import SmartChildDevice from kasa.smart.smartchilddevice import SmartChildDevice
from kasa.smart.smartdevice import NON_HUB_PARENT_ONLY_MODULES from kasa.smart.smartdevice import NON_HUB_PARENT_ONLY_MODULES
from kasa.smartprotocol import _ChildProtocolWrapper from kasa.smartprotocol import _ChildProtocolWrapper
from .conftest import parametrize, parametrize_subtract, strip_smart from .conftest import (
parametrize,
parametrize_combine,
parametrize_subtract,
strip_iot,
strip_smart,
)
has_children_smart = parametrize( has_children_smart = parametrize(
"has children", component_filter="control_child", protocol_filter={"SMART"} "has children", component_filter="control_child", protocol_filter={"SMART"}
@ -18,6 +26,8 @@ hub_smart = parametrize(
) )
non_hub_parent_smart = parametrize_subtract(has_children_smart, hub_smart) non_hub_parent_smart = parametrize_subtract(has_children_smart, hub_smart)
has_children = parametrize_combine([has_children_smart, strip_iot])
@strip_smart @strip_smart
def test_childdevice_init(dev, dummy_protocol, mocker): def test_childdevice_init(dev, dummy_protocol, mocker):
@ -100,3 +110,57 @@ async def test_parent_only_modules(dev, dummy_protocol, mocker):
for child in dev.children: for child in dev.children:
for module in NON_HUB_PARENT_ONLY_MODULES: for module in NON_HUB_PARENT_ONLY_MODULES:
assert module not in [type(module) for module in child.modules.values()] assert module not in [type(module) for module in child.modules.values()]
@has_children
async def test_device_updates(dev: Device, mocker: MockerFixture):
if not dev.children and dev.device_type is Device.Type.Hub:
pytest.skip(f"Fixture for hub device {dev} does not have any children")
assert dev.children
parent_spy = mocker.spy(dev, "_update")
child_spies = {child: mocker.spy(child, "_update") for child in dev.children}
# update children
await dev.update(update_children=True)
parent_spy.assert_called_once()
for child_spy in child_spies.values():
child_spy.assert_called_once()
# do not update children
parent_spy.reset_mock()
for child_spy in child_spies.values():
child_spy.reset_mock()
await dev.update(update_children=False)
parent_spy.assert_called_once()
for child_spy in child_spies.values():
child_spy.assert_not_called()
# update parent
parent_spy.reset_mock()
for child_spy in child_spies.values():
child_spy.reset_mock()
child_to_update = dev.children[0]
await child_to_update.update(update_parent=True)
parent_spy.assert_called_once()
assert child_to_update
for child, child_spy in child_spies.items():
if child == child_to_update:
child_spy.assert_called_once()
else:
child_spy.assert_not_called()
# do not update parent
parent_spy.reset_mock()
for child_spy in child_spies.values():
child_spy.reset_mock()
await child_to_update.update(update_parent=False)
parent_spy.assert_not_called()
assert child_to_update
for child, child_spy in child_spies.items():
if child == child_to_update:
child_spy.assert_called_once()
else:
child_spy.assert_not_called()