mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-11-03 22:22:06 +00:00 
			
		
		
		
	Allow calling update directly on child devices and skipping updates on the parent
This commit is contained in:
		@@ -233,7 +233,7 @@ class Device(ABC):
 | 
			
		||||
        return await connect(host=host, config=config)  # type: ignore[arg-type]
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    async def update(self, update_children: bool = True):
 | 
			
		||||
    async def update(self, update_children: bool = True, update_parent: bool = True):
 | 
			
		||||
        """Update the device."""
 | 
			
		||||
 | 
			
		||||
    async def disconnect(self):
 | 
			
		||||
 
 | 
			
		||||
@@ -121,7 +121,18 @@ class IotStrip(IotDevice):
 | 
			
		||||
        """Return if any of the outlets are on."""
 | 
			
		||||
        return any(plug.is_on for plug in self.children)
 | 
			
		||||
 | 
			
		||||
    async def update(self, update_children: bool = True):
 | 
			
		||||
    async def update(self, update_children: bool = True, update_parent: bool = True):
 | 
			
		||||
        """Update some of the attributes.
 | 
			
		||||
 | 
			
		||||
        Needed for methods that are decorated with `requires_update`.
 | 
			
		||||
        """
 | 
			
		||||
        await self._update(update_children)
 | 
			
		||||
 | 
			
		||||
    async def _update(
 | 
			
		||||
        self,
 | 
			
		||||
        update_children: bool = True,
 | 
			
		||||
        called_from_child: IotStripPlug | None = None,
 | 
			
		||||
    ):
 | 
			
		||||
        """Update some of the attributes.
 | 
			
		||||
 | 
			
		||||
        Needed for methods that are decorated with `requires_update`.
 | 
			
		||||
@@ -143,9 +154,11 @@ class IotStrip(IotDevice):
 | 
			
		||||
            for child in self._children.values():
 | 
			
		||||
                await child._initialize_modules()
 | 
			
		||||
 | 
			
		||||
        if update_children:
 | 
			
		||||
            for plug in self.children:
 | 
			
		||||
                await plug.update()
 | 
			
		||||
        if called_from_child:
 | 
			
		||||
            await called_from_child._update()
 | 
			
		||||
        elif update_children:
 | 
			
		||||
            for child in self._children.values():
 | 
			
		||||
                await child._update()
 | 
			
		||||
 | 
			
		||||
        if not self.features:
 | 
			
		||||
            await self._initialize_features()
 | 
			
		||||
@@ -355,7 +368,17 @@ class IotStripPlug(IotPlug):
 | 
			
		||||
            for module_feat in module._module_features.values():
 | 
			
		||||
                self._add_feature(module_feat)
 | 
			
		||||
 | 
			
		||||
    async def update(self, update_children: bool = True):
 | 
			
		||||
    async def update(self, update_children: bool = True, update_parent: bool = True):
 | 
			
		||||
        """Query the device to update the data.
 | 
			
		||||
 | 
			
		||||
        Needed for properties that are decorated with `requires_update`.
 | 
			
		||||
        """
 | 
			
		||||
        if update_parent:
 | 
			
		||||
            await self.parent._update(update_children=False, called_from_child=self)
 | 
			
		||||
        else:
 | 
			
		||||
            await self._update()
 | 
			
		||||
 | 
			
		||||
    async def _update(self):
 | 
			
		||||
        """Query the device to update the data.
 | 
			
		||||
 | 
			
		||||
        Needed for properties that are decorated with `requires_update`.
 | 
			
		||||
 
 | 
			
		||||
@@ -19,6 +19,8 @@ class SmartChildDevice(SmartDevice):
 | 
			
		||||
    This wraps the protocol communications and sets internal data for the child.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    _parent: SmartDevice
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        parent: SmartDevice,
 | 
			
		||||
@@ -34,12 +36,26 @@ class SmartChildDevice(SmartDevice):
 | 
			
		||||
        self._id = info["device_id"]
 | 
			
		||||
        self.protocol = _ChildProtocolWrapper(self._id, parent.protocol)
 | 
			
		||||
 | 
			
		||||
    async def update(self, update_children: bool = True):
 | 
			
		||||
    async def update(self, update_children: bool = True, update_parent: bool = True):
 | 
			
		||||
        """Update the device.
 | 
			
		||||
 | 
			
		||||
        Calling update directly on a child device will update the parent
 | 
			
		||||
        and only this child.
 | 
			
		||||
        """
 | 
			
		||||
        if update_parent:
 | 
			
		||||
            await self._parent._update(update_children=False, called_from_child=self)
 | 
			
		||||
        else:
 | 
			
		||||
            await self._update()
 | 
			
		||||
 | 
			
		||||
    async def _update(self):
 | 
			
		||||
        """Update child module info.
 | 
			
		||||
 | 
			
		||||
        The parent updates our internal info so just update modules with
 | 
			
		||||
        their own queries.
 | 
			
		||||
        """
 | 
			
		||||
        # Hubs attached devices only update via the parent hub
 | 
			
		||||
        if self._parent.device_type == DeviceType.Hub:
 | 
			
		||||
            return
 | 
			
		||||
        req: dict[str, Any] = {}
 | 
			
		||||
        for module in self.modules.values():
 | 
			
		||||
            if mod_query := module.query():
 | 
			
		||||
 
 | 
			
		||||
@@ -147,8 +147,14 @@ class SmartDevice(Device):
 | 
			
		||||
        if "child_device" in self._components and not self.children:
 | 
			
		||||
            await self._initialize_children()
 | 
			
		||||
 | 
			
		||||
    async def update(self, update_children: bool = False):
 | 
			
		||||
    async def update(self, update_children: bool = True, update_parent: bool = True):
 | 
			
		||||
        """Update the device."""
 | 
			
		||||
        await self._update(update_children)
 | 
			
		||||
 | 
			
		||||
    async def _update(
 | 
			
		||||
        self, update_children: bool = True, called_from_child: SmartDevice | None = None
 | 
			
		||||
    ):
 | 
			
		||||
        """If called from a child device will only update that child."""
 | 
			
		||||
        if self.credentials is None and self.credentials_hash is None:
 | 
			
		||||
            raise AuthenticationError("Tapo plug requires authentication.")
 | 
			
		||||
 | 
			
		||||
@@ -167,11 +173,13 @@ class SmartDevice(Device):
 | 
			
		||||
        self._info = self._try_get_response(resp, "get_device_info")
 | 
			
		||||
 | 
			
		||||
        # Call child update which will only update module calls, info is updated
 | 
			
		||||
        # from get_child_device_list. update_children only affects hub devices, other
 | 
			
		||||
        # devices will always update children to prevent errors on module access.
 | 
			
		||||
        if update_children or self.device_type != DeviceType.Hub:
 | 
			
		||||
        # from get_child_device_list. If this method is being called by a child
 | 
			
		||||
        # it will only call update on that child
 | 
			
		||||
        if called_from_child:
 | 
			
		||||
            await called_from_child._update()
 | 
			
		||||
        elif update_children:
 | 
			
		||||
            for child in self._children.values():
 | 
			
		||||
                await child.update()
 | 
			
		||||
                await child._update()
 | 
			
		||||
        if child_info := self._try_get_response(resp, "get_child_device_list", {}):
 | 
			
		||||
            for info in child_info["child_device_list"]:
 | 
			
		||||
                self._children[info["device_id"]]._update_internal_state(info)
 | 
			
		||||
 
 | 
			
		||||
@@ -2,13 +2,21 @@ import inspect
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
from pytest_mock import MockerFixture
 | 
			
		||||
 | 
			
		||||
from kasa import Device
 | 
			
		||||
from kasa.device_type import DeviceType
 | 
			
		||||
from kasa.smart.smartchilddevice import SmartChildDevice
 | 
			
		||||
from kasa.smart.smartdevice import NON_HUB_PARENT_ONLY_MODULES
 | 
			
		||||
from kasa.smartprotocol import _ChildProtocolWrapper
 | 
			
		||||
 | 
			
		||||
from .conftest import parametrize, parametrize_subtract, strip_smart
 | 
			
		||||
from .conftest import (
 | 
			
		||||
    parametrize,
 | 
			
		||||
    parametrize_combine,
 | 
			
		||||
    parametrize_subtract,
 | 
			
		||||
    strip_iot,
 | 
			
		||||
    strip_smart,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
has_children_smart = parametrize(
 | 
			
		||||
    "has children", component_filter="control_child", protocol_filter={"SMART"}
 | 
			
		||||
@@ -18,6 +26,8 @@ hub_smart = parametrize(
 | 
			
		||||
)
 | 
			
		||||
non_hub_parent_smart = parametrize_subtract(has_children_smart, hub_smart)
 | 
			
		||||
 | 
			
		||||
has_children = parametrize_combine([has_children_smart, strip_iot])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@strip_smart
 | 
			
		||||
def test_childdevice_init(dev, dummy_protocol, mocker):
 | 
			
		||||
@@ -100,3 +110,57 @@ async def test_parent_only_modules(dev, dummy_protocol, mocker):
 | 
			
		||||
    for child in dev.children:
 | 
			
		||||
        for module in NON_HUB_PARENT_ONLY_MODULES:
 | 
			
		||||
            assert module not in [type(module) for module in child.modules.values()]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@has_children
 | 
			
		||||
async def test_device_updates(dev: Device, mocker: MockerFixture):
 | 
			
		||||
    if not dev.children and dev.device_type is Device.Type.Hub:
 | 
			
		||||
        pytest.skip(f"Fixture for hub device {dev} does not have any children")
 | 
			
		||||
    assert dev.children
 | 
			
		||||
    parent_spy = mocker.spy(dev, "_update")
 | 
			
		||||
    child_spies = {child: mocker.spy(child, "_update") for child in dev.children}
 | 
			
		||||
 | 
			
		||||
    # update children
 | 
			
		||||
    await dev.update(update_children=True)
 | 
			
		||||
    parent_spy.assert_called_once()
 | 
			
		||||
    for child_spy in child_spies.values():
 | 
			
		||||
        child_spy.assert_called_once()
 | 
			
		||||
 | 
			
		||||
    # do not update children
 | 
			
		||||
    parent_spy.reset_mock()
 | 
			
		||||
    for child_spy in child_spies.values():
 | 
			
		||||
        child_spy.reset_mock()
 | 
			
		||||
 | 
			
		||||
    await dev.update(update_children=False)
 | 
			
		||||
    parent_spy.assert_called_once()
 | 
			
		||||
    for child_spy in child_spies.values():
 | 
			
		||||
        child_spy.assert_not_called()
 | 
			
		||||
 | 
			
		||||
    # update parent
 | 
			
		||||
    parent_spy.reset_mock()
 | 
			
		||||
    for child_spy in child_spies.values():
 | 
			
		||||
        child_spy.reset_mock()
 | 
			
		||||
 | 
			
		||||
    child_to_update = dev.children[0]
 | 
			
		||||
    await child_to_update.update(update_parent=True)
 | 
			
		||||
    parent_spy.assert_called_once()
 | 
			
		||||
    assert child_to_update
 | 
			
		||||
    for child, child_spy in child_spies.items():
 | 
			
		||||
        if child == child_to_update:
 | 
			
		||||
            child_spy.assert_called_once()
 | 
			
		||||
        else:
 | 
			
		||||
            child_spy.assert_not_called()
 | 
			
		||||
 | 
			
		||||
    # do not update parent
 | 
			
		||||
    parent_spy.reset_mock()
 | 
			
		||||
    for child_spy in child_spies.values():
 | 
			
		||||
        child_spy.reset_mock()
 | 
			
		||||
 | 
			
		||||
    await child_to_update.update(update_parent=False)
 | 
			
		||||
    parent_spy.assert_not_called()
 | 
			
		||||
    assert child_to_update
 | 
			
		||||
    for child, child_spy in child_spies.items():
 | 
			
		||||
        if child == child_to_update:
 | 
			
		||||
            child_spy.assert_called_once()
 | 
			
		||||
        else:
 | 
			
		||||
            child_spy.assert_not_called()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user