Enable dynamic hub child creation and deletion on update (#1454)

This commit is contained in:
Steven B.
2025-01-15 19:10:32 +00:00
committed by GitHub
parent 17356c10f1
commit b23019e748
8 changed files with 445 additions and 115 deletions

View File

@@ -548,6 +548,37 @@ class FakeSmartTransport(BaseTransport):
return {"error_code": 0}
def get_child_device_queries(self, method, params):
return self._get_method_from_info(method, params)
def _get_method_from_info(self, method, params):
result = copy.deepcopy(self.info[method])
if result and "start_index" in result and "sum" in result:
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
start_index = (
start_index
if (params and (start_index := params.get("start_index")))
else 0
)
# Fixtures generated before _handle_response_lists was implemented
# could have incomplete lists.
if (
len(result[list_key]) < result["sum"]
and self.fix_incomplete_fixture_lists
):
result["sum"] = len(result[list_key])
if self.warn_fixture_missing_methods:
pytest.fixtures_missing_methods.setdefault( # type: ignore[attr-defined]
self.fixture_name, set()
).add(f"{method} (incomplete '{list_key}' list)")
result[list_key] = result[list_key][
start_index : start_index + self.list_return_size
]
return {"result": result, "error_code": 0}
async def _send_request(self, request_dict: dict):
method = request_dict["method"]
@@ -557,33 +588,16 @@ class FakeSmartTransport(BaseTransport):
params = request_dict.get("params", {})
if method in {"component_nego", "qs_component_nego"} or method[:3] == "get":
if method in info:
result = copy.deepcopy(info[method])
if result and "start_index" in result and "sum" in result:
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
start_index = (
start_index
if (params and (start_index := params.get("start_index")))
else 0
)
# Fixtures generated before _handle_response_lists was implemented
# could have incomplete lists.
if (
len(result[list_key]) < result["sum"]
and self.fix_incomplete_fixture_lists
):
result["sum"] = len(result[list_key])
if self.warn_fixture_missing_methods:
pytest.fixtures_missing_methods.setdefault( # type: ignore[attr-defined]
self.fixture_name, set()
).add(f"{method} (incomplete '{list_key}' list)")
# These methods are handled in get_child_device_query so it can be
# patched for tests to simulate dynamic devices.
if (
method in ("get_child_device_list", "get_child_device_component_list")
and method in info
):
return self.get_child_device_queries(method, params)
result[list_key] = result[list_key][
start_index : start_index + self.list_return_size
]
return {"result": result, "error_code": 0}
if method in info:
return self._get_method_from_info(method, params)
if self.verbatim:
return {

View File

@@ -188,6 +188,33 @@ class FakeSmartCamTransport(BaseTransport):
next(it, None)
return next(it)
def get_child_device_queries(self, method, params):
return self._get_method_from_info(method, params)
def _get_method_from_info(self, method, params):
result = copy.deepcopy(self.info[method])
if "start_index" in result and "sum" in result:
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
assert isinstance(params, dict)
module_name = next(iter(params))
start_index = (
start_index
if (
params
and module_name
and (start_index := params[module_name].get("start_index"))
)
else 0
)
result[list_key] = result[list_key][
start_index : start_index + self.list_return_size
]
return {"result": result, "error_code": 0}
async def _send_request(self, request_dict: dict):
method = request_dict["method"]
@@ -257,30 +284,18 @@ class FakeSmartCamTransport(BaseTransport):
result = {"device_info": {"basic_info": mapped}}
return {"result": result, "error_code": 0}
# These methods are handled in get_child_device_query so it can be
# patched for tests to simulate dynamic devices.
if (
method in ("getChildDeviceList", "getChildDeviceComponentList")
and method in info
):
params = request_dict.get("params")
return self.get_child_device_queries(method, params)
if method in info:
params = request_dict.get("params")
result = copy.deepcopy(info[method])
if "start_index" in result and "sum" in result:
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
assert isinstance(params, dict)
module_name = next(iter(params))
start_index = (
start_index
if (
params
and module_name
and (start_index := params[module_name].get("start_index"))
)
else 0
)
result[list_key] = result[list_key][
start_index : start_index + self.list_return_size
]
return {"result": result, "error_code": 0}
return self._get_method_from_info(method, params)
if self.verbatim:
return {"error_code": -1}

View File

@@ -17,6 +17,7 @@ from kasa.exceptions import DeviceError, SmartErrorCode
from kasa.smart import SmartDevice
from kasa.smart.modules.energy import Energy
from kasa.smart.smartmodule import SmartModule
from kasa.smartcam import SmartCamDevice
from tests.conftest import (
DISCOVERY_MOCK_IP,
device_smart,
@@ -31,6 +32,9 @@ from tests.device_fixtures import (
variable_temp_smart,
)
from ..fakeprotocol_smart import FakeSmartTransport
from ..fakeprotocol_smartcam import FakeSmartCamTransport
DUMMY_CHILD_REQUEST_PREFIX = "get_dummy_"
hub_all = parametrize_combine([hubs_smart, hub_smartcam])
@@ -148,6 +152,7 @@ async def test_negotiate(dev: SmartDevice, mocker: MockerFixture):
"get_child_device_list": None,
}
)
await dev.update()
assert len(dev._children) == dev.internal_state["get_child_device_list"]["sum"]
@@ -488,7 +493,12 @@ async def test_update_module_query_errors(
if (
not raise_error
or "component_nego" in request
or "get_child_device_component_list" in request
# allow the initial child device query
or (
"get_child_device_component_list" in request
and "get_child_device_list" in request
and len(request) == 2
)
):
if child_id: # child single query
child_protocol = dev.protocol._transport.child_protocols[child_id]
@@ -763,3 +773,218 @@ async def test_smartmodule_query():
)
mod = DummyModule(dummy_device, "dummy")
assert mod.query() == {}
@hub_all
@pytest.mark.xdist_group(name="caplog")
@pytest.mark.requires_dummy
async def test_dynamic_devices(dev: Device, caplog: pytest.LogCaptureFixture):
"""Test dynamic child devices."""
if not dev.children:
pytest.skip(f"Device {dev.model} does not have children.")
transport = dev.protocol._transport
assert isinstance(transport, FakeSmartCamTransport | FakeSmartTransport)
lu = dev._last_update
assert lu
child_device_info = lu.get("getChildDeviceList", lu.get("get_child_device_list"))
assert child_device_info
child_device_components = lu.get(
"getChildDeviceComponentList", lu.get("get_child_device_component_list")
)
assert child_device_components
mock_child_device_info = copy.deepcopy(child_device_info)
mock_child_device_components = copy.deepcopy(child_device_components)
first_child = child_device_info["child_device_list"][0]
first_child_device_id = first_child["device_id"]
first_child_components = next(
iter(
[
cc
for cc in child_device_components["child_component_list"]
if cc["device_id"] == first_child_device_id
]
)
)
first_child_fake_transport = transport.child_protocols[first_child_device_id]
# Test adding devices
start_child_count = len(dev.children)
added_ids = []
for i in range(1, 3):
new_child = copy.deepcopy(first_child)
new_child_components = copy.deepcopy(first_child_components)
mock_device_id = f"mock_child_device_id_{i}"
transport.child_protocols[mock_device_id] = first_child_fake_transport
new_child["device_id"] = mock_device_id
new_child_components["device_id"] = mock_device_id
added_ids.append(mock_device_id)
mock_child_device_info["child_device_list"].append(new_child)
mock_child_device_components["child_component_list"].append(
new_child_components
)
def mock_get_child_device_queries(method, params):
if method in {"getChildDeviceList", "get_child_device_list"}:
result = mock_child_device_info
if method in {"getChildDeviceComponentList", "get_child_device_component_list"}:
result = mock_child_device_components
return {"result": result, "error_code": 0}
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
for added_id in added_ids:
assert added_id in dev._children
expected_new_length = start_child_count + len(added_ids)
assert len(dev.children) == expected_new_length
# Test removing devices
mock_child_device_info["child_device_list"] = [
info
for info in mock_child_device_info["child_device_list"]
if info["device_id"] != first_child_device_id
]
mock_child_device_components["child_component_list"] = [
cc
for cc in mock_child_device_components["child_component_list"]
if cc["device_id"] != first_child_device_id
]
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
expected_new_length -= 1
assert len(dev.children) == expected_new_length
# Test no child devices
mock_child_device_info["child_device_list"] = []
mock_child_device_components["child_component_list"] = []
mock_child_device_info["sum"] = 0
mock_child_device_components["sum"] = 0
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert len(dev.children) == 0
# Logging tests are only for smartcam hubs as smart hubs do not test categories
if not isinstance(dev, SmartCamDevice):
return
# setup
mock_child = copy.deepcopy(first_child)
mock_components = copy.deepcopy(first_child_components)
mock_child_device_info["child_device_list"] = [mock_child]
mock_child_device_components["child_component_list"] = [mock_components]
mock_child_device_info["sum"] = 1
mock_child_device_components["sum"] = 1
# Test can't find matching components
mock_child["device_id"] = "no_comps_1"
mock_components["device_id"] = "no_comps_2"
caplog.set_level("DEBUG")
caplog.clear()
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Could not find child components for device" in caplog.text
caplog.clear()
# Test doesn't log multiple
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Could not find child components for device" not in caplog.text
# Test invalid category
mock_child["device_id"] = "invalid_cat"
mock_components["device_id"] = "invalid_cat"
mock_child["category"] = "foobar"
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Child device type not supported" in caplog.text
caplog.clear()
# Test doesn't log multiple
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Child device type not supported" not in caplog.text
# Test no category
mock_child["device_id"] = "no_cat"
mock_components["device_id"] = "no_cat"
mock_child.pop("category")
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Child device type not supported" in caplog.text
# Test only log once
caplog.clear()
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Child device type not supported" not in caplog.text
# Test no device_id
mock_child.pop("device_id")
caplog.clear()
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Could not find child id for device" in caplog.text
# Test only log once
caplog.clear()
with patch.object(
transport, "get_child_device_queries", side_effect=mock_get_child_device_queries
):
await dev.update()
assert "Could not find child id for device" not in caplog.text