From ec1082a2287190266b086ae1d6a9c789f9eca8eb Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Fri, 28 Jun 2024 19:25:39 +0100 Subject: [PATCH] Allow calling update directly on child devices and skipping updates on the parent --- kasa/device.py | 2 +- kasa/iot/iotstrip.py | 33 ++++++++++++++--- kasa/smart/smartchilddevice.py | 18 +++++++++- kasa/smart/smartdevice.py | 18 +++++++--- kasa/tests/test_childdevice.py | 66 +++++++++++++++++++++++++++++++++- 5 files changed, 124 insertions(+), 13 deletions(-) diff --git a/kasa/device.py b/kasa/device.py index 9bf0903e..c9be2ab3 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -233,7 +233,7 @@ class Device(ABC): return await connect(host=host, config=config) # type: ignore[arg-type] @abstractmethod - async def update(self, update_children: bool = True): + async def update(self, update_children: bool = True, update_parent: bool = True): """Update the device.""" async def disconnect(self): diff --git a/kasa/iot/iotstrip.py b/kasa/iot/iotstrip.py index e64ace05..da9765f8 100755 --- a/kasa/iot/iotstrip.py +++ b/kasa/iot/iotstrip.py @@ -121,7 +121,18 @@ class IotStrip(IotDevice): """Return if any of the outlets are on.""" return any(plug.is_on for plug in self.children) - async def update(self, update_children: bool = True): + async def update(self, update_children: bool = True, update_parent: bool = True): + """Update some of the attributes. + + Needed for methods that are decorated with `requires_update`. + """ + await self._update(update_children) + + async def _update( + self, + update_children: bool = True, + called_from_child: IotStripPlug | None = None, + ): """Update some of the attributes. Needed for methods that are decorated with `requires_update`. @@ -143,9 +154,11 @@ class IotStrip(IotDevice): for child in self._children.values(): await child._initialize_modules() - if update_children: - for plug in self.children: - await plug.update() + if called_from_child: + await called_from_child._update() + elif update_children: + for child in self._children.values(): + await child._update() if not self.features: await self._initialize_features() @@ -355,7 +368,17 @@ class IotStripPlug(IotPlug): for module_feat in module._module_features.values(): self._add_feature(module_feat) - async def update(self, update_children: bool = True): + async def update(self, update_children: bool = True, update_parent: bool = True): + """Query the device to update the data. + + Needed for properties that are decorated with `requires_update`. + """ + if update_parent: + await self.parent._update(update_children=False, called_from_child=self) + else: + await self._update() + + async def _update(self): """Query the device to update the data. Needed for properties that are decorated with `requires_update`. diff --git a/kasa/smart/smartchilddevice.py b/kasa/smart/smartchilddevice.py index c6596b96..6c364b0c 100644 --- a/kasa/smart/smartchilddevice.py +++ b/kasa/smart/smartchilddevice.py @@ -19,6 +19,8 @@ class SmartChildDevice(SmartDevice): This wraps the protocol communications and sets internal data for the child. """ + _parent: SmartDevice + def __init__( self, parent: SmartDevice, @@ -34,12 +36,26 @@ class SmartChildDevice(SmartDevice): self._id = info["device_id"] self.protocol = _ChildProtocolWrapper(self._id, parent.protocol) - async def update(self, update_children: bool = True): + async def update(self, update_children: bool = True, update_parent: bool = True): + """Update the device. + + Calling update directly on a child device will update the parent + and only this child. + """ + if update_parent: + await self._parent._update(update_children=False, called_from_child=self) + else: + await self._update() + + async def _update(self): """Update child module info. The parent updates our internal info so just update modules with their own queries. """ + # Hubs attached devices only update via the parent hub + if self._parent.device_type == DeviceType.Hub: + return req: dict[str, Any] = {} for module in self.modules.values(): if mod_query := module.query(): diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index a5b64e52..f4fa69c9 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -147,8 +147,14 @@ class SmartDevice(Device): if "child_device" in self._components and not self.children: await self._initialize_children() - async def update(self, update_children: bool = False): + async def update(self, update_children: bool = True, update_parent: bool = True): """Update the device.""" + await self._update(update_children) + + async def _update( + self, update_children: bool = True, called_from_child: SmartDevice | None = None + ): + """If called from a child device will only update that child.""" if self.credentials is None and self.credentials_hash is None: raise AuthenticationError("Tapo plug requires authentication.") @@ -167,11 +173,13 @@ class SmartDevice(Device): self._info = self._try_get_response(resp, "get_device_info") # Call child update which will only update module calls, info is updated - # from get_child_device_list. update_children only affects hub devices, other - # devices will always update children to prevent errors on module access. - if update_children or self.device_type != DeviceType.Hub: + # from get_child_device_list. If this method is being called by a child + # it will only call update on that child + if called_from_child: + await called_from_child._update() + elif update_children: for child in self._children.values(): - await child.update() + await child._update() if child_info := self._try_get_response(resp, "get_child_device_list", {}): for info in child_info["child_device_list"]: self._children[info["device_id"]]._update_internal_state(info) diff --git a/kasa/tests/test_childdevice.py b/kasa/tests/test_childdevice.py index 26568c24..3cdf774f 100644 --- a/kasa/tests/test_childdevice.py +++ b/kasa/tests/test_childdevice.py @@ -2,13 +2,21 @@ import inspect import sys import pytest +from pytest_mock import MockerFixture +from kasa import Device from kasa.device_type import DeviceType from kasa.smart.smartchilddevice import SmartChildDevice from kasa.smart.smartdevice import NON_HUB_PARENT_ONLY_MODULES from kasa.smartprotocol import _ChildProtocolWrapper -from .conftest import parametrize, parametrize_subtract, strip_smart +from .conftest import ( + parametrize, + parametrize_combine, + parametrize_subtract, + strip_iot, + strip_smart, +) has_children_smart = parametrize( "has children", component_filter="control_child", protocol_filter={"SMART"} @@ -18,6 +26,8 @@ hub_smart = parametrize( ) non_hub_parent_smart = parametrize_subtract(has_children_smart, hub_smart) +has_children = parametrize_combine([has_children_smart, strip_iot]) + @strip_smart def test_childdevice_init(dev, dummy_protocol, mocker): @@ -100,3 +110,57 @@ async def test_parent_only_modules(dev, dummy_protocol, mocker): for child in dev.children: for module in NON_HUB_PARENT_ONLY_MODULES: assert module not in [type(module) for module in child.modules.values()] + + +@has_children +async def test_device_updates(dev: Device, mocker: MockerFixture): + if not dev.children and dev.device_type is Device.Type.Hub: + pytest.skip(f"Fixture for hub device {dev} does not have any children") + assert dev.children + parent_spy = mocker.spy(dev, "_update") + child_spies = {child: mocker.spy(child, "_update") for child in dev.children} + + # update children + await dev.update(update_children=True) + parent_spy.assert_called_once() + for child_spy in child_spies.values(): + child_spy.assert_called_once() + + # do not update children + parent_spy.reset_mock() + for child_spy in child_spies.values(): + child_spy.reset_mock() + + await dev.update(update_children=False) + parent_spy.assert_called_once() + for child_spy in child_spies.values(): + child_spy.assert_not_called() + + # update parent + parent_spy.reset_mock() + for child_spy in child_spies.values(): + child_spy.reset_mock() + + child_to_update = dev.children[0] + await child_to_update.update(update_parent=True) + parent_spy.assert_called_once() + assert child_to_update + for child, child_spy in child_spies.items(): + if child == child_to_update: + child_spy.assert_called_once() + else: + child_spy.assert_not_called() + + # do not update parent + parent_spy.reset_mock() + for child_spy in child_spies.values(): + child_spy.reset_mock() + + await child_to_update.update(update_parent=False) + parent_spy.assert_not_called() + assert child_to_update + for child, child_spy in child_spies.items(): + if child == child_to_update: + child_spy.assert_called_once() + else: + child_spy.assert_not_called()