Enable dynamic hub child creation and deletion on update (#1454)

This commit is contained in:
Steven B. 2025-01-15 19:10:32 +00:00 committed by GitHub
parent 17356c10f1
commit b23019e748
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 445 additions and 115 deletions

View File

@ -38,6 +38,7 @@ Plug 3: False
True
"""
from ...device_type import DeviceType
from ..smartmodule import SmartModule
@ -46,3 +47,10 @@ class ChildDevice(SmartModule):
REQUIRED_COMPONENT = "child_device"
QUERY_GETTER_NAME = "get_child_device_list"
def query(self) -> dict:
"""Query to execute during the update cycle."""
q = super().query()
if self._device.device_type is DeviceType.Hub:
q["get_child_device_component_list"] = None
return q

View File

@ -109,6 +109,11 @@ class SmartChildDevice(SmartDevice):
)
self._last_update_time = now
# We can first initialize the features after the first update.
# We make here an assumption that every device has at least a single feature.
if not self._features:
await self._initialize_features()
@classmethod
async def create(
cls,

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import base64
import logging
import time
from collections.abc import Mapping, Sequence
from collections.abc import Sequence
from datetime import UTC, datetime, timedelta, tzinfo
from typing import TYPE_CHECKING, Any, TypeAlias, cast
@ -68,10 +68,11 @@ class SmartDevice(Device):
self._state_information: dict[str, Any] = {}
self._modules: dict[str | ModuleName[Module], SmartModule] = {}
self._parent: SmartDevice | None = None
self._children: Mapping[str, SmartDevice] = {}
self._children: dict[str, SmartDevice] = {}
self._last_update_time: float | None = None
self._on_since: datetime | None = None
self._info: dict[str, Any] = {}
self._logged_missing_child_ids: set[str] = set()
async def _initialize_children(self) -> None:
"""Initialize children for power strips."""
@ -82,23 +83,86 @@ class SmartDevice(Device):
resp = await self.protocol.query(child_info_query)
self.internal_state.update(resp)
children = self.internal_state["get_child_device_list"]["child_device_list"]
children_components_raw = {
child["device_id"]: child
for child in self.internal_state["get_child_device_component_list"][
"child_component_list"
]
}
async def _try_create_child(
self, info: dict, child_components: dict
) -> SmartDevice | None:
from .smartchilddevice import SmartChildDevice
self._children = {
child_info["device_id"]: await SmartChildDevice.create(
parent=self,
child_info=child_info,
child_components_raw=children_components_raw[child_info["device_id"]],
)
for child_info in children
return await SmartChildDevice.create(
parent=self,
child_info=info,
child_components_raw=child_components,
)
async def _create_delete_children(
self,
child_device_resp: dict[str, list],
child_device_components_resp: dict[str, list],
) -> bool:
"""Create and delete children. Return True if children changed.
Adds newly found children and deletes children that are no longer
reported by the device. It will only log once per child_id that
can't be created to avoid spamming the logs on every update.
"""
changed = False
smart_children_components = {
child["device_id"]: child
for child in child_device_components_resp["child_component_list"]
}
children = self._children
child_ids: set[str] = set()
existing_child_ids = set(self._children.keys())
for info in child_device_resp["child_device_list"]:
if (child_id := info.get("device_id")) and (
child_components := smart_children_components.get(child_id)
):
child_ids.add(child_id)
if child_id in existing_child_ids:
continue
child = await self._try_create_child(info, child_components)
if child:
_LOGGER.debug("Created child device %s for %s", child, self.host)
changed = True
children[child_id] = child
continue
if child_id not in self._logged_missing_child_ids:
self._logged_missing_child_ids.add(child_id)
_LOGGER.debug("Child device type not supported: %s", info)
continue
if child_id:
if child_id not in self._logged_missing_child_ids:
self._logged_missing_child_ids.add(child_id)
_LOGGER.debug(
"Could not find child components for device %s, "
"child_id %s, components: %s: ",
self.host,
child_id,
smart_children_components,
)
continue
# If we couldn't get a child device id we still only want to
# log once to avoid spamming the logs on every update cycle
# so store it under an empty string
if "" not in self._logged_missing_child_ids:
self._logged_missing_child_ids.add("")
_LOGGER.debug(
"Could not find child id for device %s, info: %s", self.host, info
)
removed_ids = existing_child_ids - child_ids
for removed_id in removed_ids:
changed = True
removed = children.pop(removed_id)
_LOGGER.debug("Removed child device %s from %s", removed, self.host)
return changed
@property
def children(self) -> Sequence[SmartDevice]:
@ -164,21 +228,29 @@ class SmartDevice(Device):
if "child_device" in self._components and not self.children:
await self._initialize_children()
def _update_children_info(self) -> None:
"""Update the internal child device info from the parent info."""
async def _update_children_info(self) -> bool:
"""Update the internal child device info from the parent info.
Return true if children added or deleted.
"""
changed = False
if child_info := self._try_get_response(
self._last_update, "get_child_device_list", {}
):
changed = await self._create_delete_children(
child_info, self._last_update["get_child_device_component_list"]
)
for info in child_info["child_device_list"]:
child_id = info["device_id"]
child_id = info.get("device_id")
if child_id not in self._children:
_LOGGER.debug(
"Skipping child update for %s, probably unsupported device",
child_id,
)
# _create_delete_children has already logged a message
continue
self._children[child_id]._update_internal_state(info)
return changed
def _update_internal_info(self, info_resp: dict) -> None:
"""Update the internal device info."""
self._info = self._try_get_response(info_resp, "get_device_info")
@ -201,13 +273,13 @@ class SmartDevice(Device):
resp = await self._modular_update(first_update, now)
self._update_children_info()
children_changed = await self._update_children_info()
# Call child update which will only update module calls, info is updated
# from get_child_device_list. update_children only affects hub devices, other
# devices will always update children to prevent errors on module access.
# This needs to go after updating the internal state of the children so that
# child modules have access to their sysinfo.
if first_update or update_children or self.device_type != DeviceType.Hub:
if children_changed or update_children or self.device_type != DeviceType.Hub:
for child in self._children.values():
if TYPE_CHECKING:
assert isinstance(child, SmartChildDevice)
@ -469,8 +541,6 @@ class SmartDevice(Device):
module._initialize_features()
for feat in module._module_features.values():
self._add_feature(feat)
for child in self._children.values():
await child._initialize_features()
@property
def _is_hub_child(self) -> bool:

View File

@ -19,7 +19,10 @@ class ChildDevice(SmartCamModule):
Default implementation uses the raw query getter w/o parameters.
"""
return {self.QUERY_GETTER_NAME: {"childControl": {"start_index": 0}}}
q = {self.QUERY_GETTER_NAME: {"childControl": {"start_index": 0}}}
if self._device.device_type is DeviceType.Hub:
q["getChildDeviceComponentList"] = {"childControl": {"start_index": 0}}
return q
async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device."""

View File

@ -70,21 +70,29 @@ class SmartCamDevice(SmartDevice):
"""
self._info = self._map_info(info)
def _update_children_info(self) -> None:
"""Update the internal child device info from the parent info."""
async def _update_children_info(self) -> bool:
"""Update the internal child device info from the parent info.
Return true if children added or deleted.
"""
changed = False
if child_info := self._try_get_response(
self._last_update, "getChildDeviceList", {}
):
changed = await self._create_delete_children(
child_info, self._last_update["getChildDeviceComponentList"]
)
for info in child_info["child_device_list"]:
child_id = info["device_id"]
child_id = info.get("device_id")
if child_id not in self._children:
_LOGGER.debug(
"Skipping child update for %s, probably unsupported device",
child_id,
)
# _create_delete_children has already logged a message
continue
self._children[child_id]._update_internal_state(info)
return changed
async def _initialize_smart_child(
self, info: dict, child_components_raw: ComponentsRaw
) -> SmartDevice:
@ -113,7 +121,6 @@ class SmartCamDevice(SmartDevice):
child_id = info["device_id"]
child_protocol = _ChildCameraProtocolWrapper(child_id, self.protocol)
last_update = {"getDeviceInfo": {"device_info": {"basic_info": info}}}
app_component_list = {
"app_component_list": child_components_raw["component_list"]
}
@ -124,7 +131,6 @@ class SmartCamDevice(SmartDevice):
child_info=info,
child_components_raw=app_component_list,
protocol=child_protocol,
last_update=last_update,
)
async def _initialize_children(self) -> None:
@ -136,35 +142,22 @@ class SmartCamDevice(SmartDevice):
resp = await self.protocol.query(child_info_query)
self.internal_state.update(resp)
smart_children_components = {
child["device_id"]: child
for child in resp["getChildDeviceComponentList"]["child_component_list"]
}
children = {}
async def _try_create_child(
self, info: dict, child_components: dict
) -> SmartDevice | None:
if not (category := info.get("category")):
return None
# Smart
if category in SmartChildDevice.CHILD_DEVICE_TYPE_MAP:
return await self._initialize_smart_child(info, child_components)
# Smartcam
from .smartcamchild import SmartCamChild
for info in resp["getChildDeviceList"]["child_device_list"]:
if (
(category := info.get("category"))
and (child_id := info.get("device_id"))
and (child_components := smart_children_components.get(child_id))
):
# Smart
if category in SmartChildDevice.CHILD_DEVICE_TYPE_MAP:
children[child_id] = await self._initialize_smart_child(
info, child_components
)
continue
# Smartcam
if category in SmartCamChild.CHILD_DEVICE_TYPE_MAP:
children[child_id] = await self._initialize_smartcam_child(
info, child_components
)
continue
if category in SmartCamChild.CHILD_DEVICE_TYPE_MAP:
return await self._initialize_smartcam_child(info, child_components)
_LOGGER.debug("Child device type not supported: %s", info)
self._children = children
return None
async def _initialize_modules(self) -> None:
"""Initialize modules based on component negotiation response."""
@ -190,9 +183,6 @@ class SmartCamDevice(SmartDevice):
for feat in module._module_features.values():
self._add_feature(feat)
for child in self._children.values():
await child._initialize_features()
async def _query_setter_helper(
self, method: str, module: str, section: str, params: dict | None = None
) -> dict:

View File

@ -548,6 +548,37 @@ class FakeSmartTransport(BaseTransport):
return {"error_code": 0}
def get_child_device_queries(self, method, params):
return self._get_method_from_info(method, params)
def _get_method_from_info(self, method, params):
result = copy.deepcopy(self.info[method])
if result and "start_index" in result and "sum" in result:
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
start_index = (
start_index
if (params and (start_index := params.get("start_index")))
else 0
)
# Fixtures generated before _handle_response_lists was implemented
# could have incomplete lists.
if (
len(result[list_key]) < result["sum"]
and self.fix_incomplete_fixture_lists
):
result["sum"] = len(result[list_key])
if self.warn_fixture_missing_methods:
pytest.fixtures_missing_methods.setdefault( # type: ignore[attr-defined]
self.fixture_name, set()
).add(f"{method} (incomplete '{list_key}' list)")
result[list_key] = result[list_key][
start_index : start_index + self.list_return_size
]
return {"result": result, "error_code": 0}
async def _send_request(self, request_dict: dict):
method = request_dict["method"]
@ -557,33 +588,16 @@ class FakeSmartTransport(BaseTransport):
params = request_dict.get("params", {})
if method in {"component_nego", "qs_component_nego"} or method[:3] == "get":
if method in info:
result = copy.deepcopy(info[method])
if result and "start_index" in result and "sum" in result:
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
start_index = (
start_index
if (params and (start_index := params.get("start_index")))
else 0
)
# Fixtures generated before _handle_response_lists was implemented
# could have incomplete lists.
if (
len(result[list_key]) < result["sum"]
and self.fix_incomplete_fixture_lists
):
result["sum"] = len(result[list_key])
if self.warn_fixture_missing_methods:
pytest.fixtures_missing_methods.setdefault( # type: ignore[attr-defined]
self.fixture_name, set()
).add(f"{method} (incomplete '{list_key}' list)")
# These methods are handled in get_child_device_query so it can be
# patched for tests to simulate dynamic devices.
if (
method in ("get_child_device_list", "get_child_device_component_list")
and method in info
):
return self.get_child_device_queries(method, params)
result[list_key] = result[list_key][
start_index : start_index + self.list_return_size
]
return {"result": result, "error_code": 0}
if method in info:
return self._get_method_from_info(method, params)
if self.verbatim:
return {

View File

@ -188,6 +188,33 @@ class FakeSmartCamTransport(BaseTransport):
next(it, None)
return next(it)
def get_child_device_queries(self, method, params):
return self._get_method_from_info(method, params)
def _get_method_from_info(self, method, params):
result = copy.deepcopy(self.info[method])
if "start_index" in result and "sum" in result:
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
assert isinstance(params, dict)
module_name = next(iter(params))
start_index = (
start_index
if (
params
and module_name
and (start_index := params[module_name].get("start_index"))
)
else 0
)
result[list_key] = result[list_key][
start_index : start_index + self.list_return_size
]
return {"result": result, "error_code": 0}
async def _send_request(self, request_dict: dict):
method = request_dict["method"]
@ -257,30 +284,18 @@ class FakeSmartCamTransport(BaseTransport):
result = {"device_info": {"basic_info": mapped}}
return {"result": result, "error_code": 0}
# These methods are handled in get_child_device_query so it can be
# patched for tests to simulate dynamic devices.
if (
method in ("getChildDeviceList", "getChildDeviceComponentList")
and method in info
):
params = request_dict.get("params")
return self.get_child_device_queries(method, params)
if method in info:
params = request_dict.get("params")
result = copy.deepcopy(info[method])
if "start_index" in result and "sum" in result:
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
assert isinstance(params, dict)
module_name = next(iter(params))
start_index = (
start_index
if (
params
and module_name
and (start_index := params[module_name].get("start_index"))
)
else 0
)
result[list_key] = result[list_key][
start_index : start_index + self.list_return_size
]
return {"result": result, "error_code": 0}
return self._get_method_from_info(method, params)
if self.verbatim:
return {"error_code": -1}

View File

@ -17,6 +17,7 @@ from kasa.exceptions import DeviceError, SmartErrorCode
from kasa.smart import SmartDevice
from kasa.smart.modules.energy import Energy
from kasa.smart.smartmodule import SmartModule
from kasa.smartcam import SmartCamDevice
from tests.conftest import (
DISCOVERY_MOCK_IP,
device_smart,
@ -31,6 +32,9 @@ from tests.device_fixtures import (
variable_temp_smart,
)
from ..fakeprotocol_smart import FakeSmartTransport
from ..fakeprotocol_smartcam import FakeSmartCamTransport
DUMMY_CHILD_REQUEST_PREFIX = "get_dummy_"
hub_all = parametrize_combine([hubs_smart, hub_smartcam])
@ -148,6 +152,7 @@ async def test_negotiate(dev: SmartDevice, mocker: MockerFixture):
"get_child_device_list": None,
}
)
await dev.update()
assert len(dev._children) == dev.internal_state["get_child_device_list"]["sum"]
@ -488,7 +493,12 @@ async def test_update_module_query_errors(
if (
not raise_error
or "component_nego" in request
or "get_child_device_component_list" in request
# allow the initial child device query
or (
"get_child_device_component_list" in request
and "get_child_device_list" in request
and len(request) == 2
)
):
if child_id: # child single query
child_protocol = dev.protocol._transport.child_protocols[child_id]
@ -763,3 +773,218 @@ async def test_smartmodule_query():
)
mod = DummyModule(dummy_device, "dummy")
assert mod.query() == {}
@hub_all
@pytest.mark.xdist_group(name="caplog")
@pytest.mark.requires_dummy
async def test_dynamic_devices(dev: Device, caplog: pytest.LogCaptureFixture):
"""Test dynamic child devices."""
if not dev.children:
pytest.skip(f"Device {dev.model} does not have children.")
transport = dev.protocol._transport
assert isinstance(transport, FakeSmartCamTransport | FakeSmartTransport)
lu = dev._last_update
assert lu
child_device_info = lu.get("getChildDeviceList", lu.get("get_child_device_list"))
assert child_device_info
child_device_components = lu.get(
"getChildDeviceComponentList", lu.get("get_child_device_component_list")
)
assert child_device_components
mock_child_device_info = copy.deepcopy(child_device_info)
mock_child_device_components = copy.deepcopy(child_device_components)
first_child = child_device_info["child_device_list"][0]
first_child_device_id = first_child["device_id"]
first_child_components = next(
iter(
[
cc
for cc in child_device_components["child_component_list"]
if cc["device_id"] == first_child_device_id
]
)
)
first_child_fake_transport = transport.child_protocols[first_child_device_id]
# Test adding devices
start_child_count = len(dev.children)
added_ids = []
for i in range(1, 3):
new_child = copy.deepcopy(first_child)
new_child_components = copy.deepcopy(first_child_components)
mock_device_id = f"mock_child_device_id_{i}"
transport.child_protocols[mock_device_id] = first_child_fake_transport
new_child["device_id"] = mock_device_id
new_child_components["device_id"] = mock_device_id
added_ids.append(mock_device_id)
mock_child_device_info["child_device_list"].append(new_child)
mock_child_device_components["child_component_list"].append(
new_child_components
)
def mock_get_child_device_queries(method, params):
if method in {"getChildDeviceList", "get_child_device_list"}:
result = mock_child_device_info
if method in {"getChildDeviceComponentList", "get_child_device_component_list"}:
result = mock_child_device_components
return {"result": result, "error_code": 0}
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
for added_id in added_ids:
assert added_id in dev._children
expected_new_length = start_child_count + len(added_ids)
assert len(dev.children) == expected_new_length
# Test removing devices
mock_child_device_info["child_device_list"] = [
info
for info in mock_child_device_info["child_device_list"]
if info["device_id"] != first_child_device_id
]
mock_child_device_components["child_component_list"] = [
cc
for cc in mock_child_device_components["child_component_list"]
if cc["device_id"] != first_child_device_id
]
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
expected_new_length -= 1
assert len(dev.children) == expected_new_length
# Test no child devices
mock_child_device_info["child_device_list"] = []
mock_child_device_components["child_component_list"] = []
mock_child_device_info["sum"] = 0
mock_child_device_components["sum"] = 0
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert len(dev.children) == 0
# Logging tests are only for smartcam hubs as smart hubs do not test categories
if not isinstance(dev, SmartCamDevice):
return
# setup
mock_child = copy.deepcopy(first_child)
mock_components = copy.deepcopy(first_child_components)
mock_child_device_info["child_device_list"] = [mock_child]
mock_child_device_components["child_component_list"] = [mock_components]
mock_child_device_info["sum"] = 1
mock_child_device_components["sum"] = 1
# Test can't find matching components
mock_child["device_id"] = "no_comps_1"
mock_components["device_id"] = "no_comps_2"
caplog.set_level("DEBUG")
caplog.clear()
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Could not find child components for device" in caplog.text
caplog.clear()
# Test doesn't log multiple
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Could not find child components for device" not in caplog.text
# Test invalid category
mock_child["device_id"] = "invalid_cat"
mock_components["device_id"] = "invalid_cat"
mock_child["category"] = "foobar"
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Child device type not supported" in caplog.text
caplog.clear()
# Test doesn't log multiple
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Child device type not supported" not in caplog.text
# Test no category
mock_child["device_id"] = "no_cat"
mock_components["device_id"] = "no_cat"
mock_child.pop("category")
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Child device type not supported" in caplog.text
# Test only log once
caplog.clear()
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Child device type not supported" not in caplog.text
# Test no device_id
mock_child.pop("device_id")
caplog.clear()
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Could not find child id for device" in caplog.text
# Test only log once
caplog.clear()
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Could not find child id for device" not in caplog.text