diff --git a/kasa/cli.py b/kasa/cli.py index 2f4c2145..1a8ca963 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -1279,12 +1279,11 @@ async def child_list(dev): @pass_dev async def child_pair(dev, timeout): """Pair new device.""" - if "ChildSetupModule" not in dev.modules: - echo("%s is not a hub.") + if (cs := dev.modules.get(Module.ChildSetup)) is None: + echo("%s does not support pairing." % dev) return - echo("Finding new devices for %s" % timeout) - cs = dev.modules["ChildSetupModule"] + echo("Finding new devices for %s seconds" % timeout) return await cs.pair(timeout=timeout) @@ -1293,11 +1292,10 @@ async def child_pair(dev, timeout): @pass_dev async def child_unpair(dev, device_id: str): """Unpair given device.""" - if "ChildSetupModule" not in dev.modules: - echo("%s is not a hub.") + if (cs := dev.modules.get(Module.ChildSetup)) is None: + echo("%s does not support pairing." % dev) return - cs = dev.modules["ChildSetupModule"] res = await cs.unpair(device_id=device_id) echo("Unpaired %s (if it was paired)" % device_id) return res diff --git a/kasa/module.py b/kasa/module.py index a2a9c931..be277dbd 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -85,6 +85,7 @@ class Module(ABC): WaterleakSensor: Final[ModuleName[smart.WaterleakSensor]] = ModuleName( "WaterleakSensor" ) + ChildSetup: Final[ModuleName[smart.ChildSetup]] = ModuleName("ChildSetup") def __init__(self, device: Device, module: str): self._device = device diff --git a/kasa/smart/modules/__init__.py b/kasa/smart/modules/__init__.py index 4bece213..e431e1c6 100644 --- a/kasa/smart/modules/__init__.py +++ b/kasa/smart/modules/__init__.py @@ -5,8 +5,8 @@ from .autooff import AutoOff from .batterysensor import BatterySensor from .brightness import Brightness from .childdevice import ChildDevice +from .childsetup import ChildSetup from .cloud import Cloud -from .childsetup import ChildSetupModule from .color import Color from .colortemperature import ColorTemperature from .contactsensor import ContactSensor @@ -34,7 +34,7 @@ __all__ = [ "Energy", "DeviceModule", "ChildDevice", - "ChildSetupModule", + "ChildSetup", "BatterySensor", "HumiditySensor", "TemperatureSensor", diff --git a/kasa/smart/modules/childsetup.py b/kasa/smart/modules/childsetup.py index c3659826..209712c9 100644 --- a/kasa/smart/modules/childsetup.py +++ b/kasa/smart/modules/childsetup.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) -class ChildSetupModule(SmartModule): +class ChildSetup(SmartModule): """Implementation for child device setup.""" REQUIRED_COMPONENT = "child_quick_setup" @@ -33,6 +33,7 @@ class ChildSetupModule(SmartModule): self._add_feature( Feature( device, + id="pair", name="Pair", container=self, attribute_setter="pair", @@ -77,13 +78,13 @@ class ChildSetupModule(SmartModule): """Remove device from the hub.""" payload = {"child_device_list": [{"device_id": device_id}]} res = await self._device._query_helper("remove_child_device_list", payload) - await self._device._initialize_children() + self._device.request_renegotiation() return res async def add_devices(self, devices: dict): """Add devices.""" res = await self._device._query_helper("add_child_device_list", devices) - await self._device._initialize_children() + self._device.request_renegotiation() return res async def get_detected_devices(self) -> dict: diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index f4e3eb58..6350a00c 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -55,7 +55,6 @@ class SmartDevice(Device): self.protocol: SmartProtocol self._components_raw: dict[str, Any] | None = None self._components: dict[str, int] = {} - self._state_information: dict[str, Any] = {} self._modules: dict[str | ModuleName[Module], SmartModule] = {} self._parent: SmartDevice | None = None self._children: Mapping[str, SmartDevice] = {} @@ -149,6 +148,21 @@ class SmartDevice(Device): if "child_device" in self._components and not self.children: await self._initialize_children() + def request_renegotiation(self) -> None: + """Request renegotiation on the next update. + + This is used by childsetup to inform about new or removed children. + """ + self._modules.clear() + self._features.clear() + self._last_update.clear() + self._components.clear() + if self._components_raw is not None: + self._components_raw.clear() + self._components_raw = None + # we cannot use clear here, as mapping doesn't have it... + self._children = {} + async def update(self, update_children: bool = False): """Update the device.""" if self.credentials is None and self.credentials_hash is None: diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index 533cd648..36326e26 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -115,6 +115,16 @@ class FakeSmartTransport(BaseTransport): }, ), "get_device_usage": ("device", {}), + # child setup + "get_support_child_device_category": ( + "child_quick_setup", + {"device_category_list": [{"category": "subg.trv"}]}, + ), + # no devices found + "get_scan_child_device_list": ( + "child_quick_setup", + {"child_device_list": [{"dummy": "response"}], "scan_status": "idle"}, + ), } async def send(self, request: str): @@ -324,6 +334,13 @@ class FakeSmartTransport(BaseTransport): return self._set_preset_rules(info, params) elif method == "edit_preset_rules": return self._edit_preset_rules(info, params) + # childsetup methods + if method in [ + "begin_scanning_child_device", + "add_child_device_list", + "remove_child_device_list", + ]: + return {"error_code": 0} elif method[:4] == "set_": target_method = f"get_{method[4:]}" info[target_method].update(params) diff --git a/kasa/tests/smart/modules/test_childsetup.py b/kasa/tests/smart/modules/test_childsetup.py new file mode 100644 index 00000000..7673c50f --- /dev/null +++ b/kasa/tests/smart/modules/test_childsetup.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import logging + +import pytest +from pytest_mock import MockerFixture + +from kasa import Feature, Module +from kasa.smart import SmartDevice +from kasa.tests.device_fixtures import parametrize + +childsetup = parametrize( + "supports pairing", component_filter="child_quick_setup", protocol_filter={"SMART"} +) + + +@childsetup +async def test_childsetup_features(dev: SmartDevice): + """Test the exposed features.""" + cs = dev.modules.get(Module.ChildSetup) + assert cs + + assert "pair" in cs._module_features + pair = cs._module_features["pair"] + assert pair.type == Feature.Type.Action + + +@childsetup +async def test_childsetup_pair( + dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture +): + """Test device pairing.""" + caplog.set_level(logging.INFO) + mock_query_helper = mocker.spy(dev, "_query_helper") + mocker.patch("asyncio.sleep") + + cs = dev.modules.get(Module.ChildSetup) + assert cs + + await cs.pair() + + mock_query_helper.assert_has_awaits( + [ + mocker.call("begin_scanning_child_device", None), + mocker.call("get_scan_child_device_list", params=mocker.ANY), + mocker.call("add_child_device_list", params=mocker.ANY), + ] + ) + assert "Discovery done" in caplog.text + + +@childsetup +async def test_childsetup_unpair( + dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture +): + """Test unpair.""" + mock_query_helper = mocker.spy(dev, "_query_helper") + DUMMY_ID = "dummy_id" + + cs = dev.modules.get(Module.ChildSetup) + assert cs + + await cs.unpair(DUMMY_ID) + + mock_query_helper.assert_awaited_with( + "remove_child_device_list", + params={"child_device_list": [{"device_id": DUMMY_ID}]}, + ) diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 2104de05..bfd79bce 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -5,6 +5,7 @@ import re import asyncclick as click import pytest from asyncclick.testing import CliRunner +from pytest_mock import MockerFixture from kasa import ( AuthenticationError, @@ -20,6 +21,7 @@ from kasa.cli import ( TYPE_TO_CLASS, alias, brightness, + child, cli, cmd_command, effect, @@ -244,6 +246,46 @@ async def test_wifi_join_exception(dev, mocker, runner): assert isinstance(res.exception, KasaException) +@device_smart +async def test_child_pair(dev, mocker: MockerFixture, runner, caplog): + """Test that pair calls the expected methods.""" + cs = dev.modules.get(Module.ChildSetup) + # Patch if the device supports the module + if cs is not None: + mock_pair = mocker.patch.object(cs, "pair") + + res = await runner.invoke(child, ["pair"], obj=dev, catch_exceptions=False) + if cs is None: + assert "does not support pairing" in res.output.replace("\n", "") + return + + mock_pair.assert_awaited() + assert "Finding new devices for 10 seconds" in res.output + assert res.exit_code == 0 + + +@device_smart +async def test_child_unpair(dev, mocker: MockerFixture, runner): + """Test that unpair calls the expected method.""" + DUMMY_ID = "dummy_id" + cs = dev.modules.get(Module.ChildSetup) + # Patch if the device supports the module + if cs is not None: + mock_unpair = mocker.patch.object(cs, "unpair") + + res = await runner.invoke( + child, ["unpair", DUMMY_ID], obj=dev, catch_exceptions=False + ) + + if cs is None: + assert "does not support pairing" in res.output.replace("\n", "") + return + + mock_unpair.assert_awaited() + assert f"Unpaired {DUMMY_ID} (if it was paired)" in res.output + assert res.exit_code == 0 + + @device_smart async def test_update_credentials(dev, runner): res = await runner.invoke( diff --git a/kasa/tests/test_feature.py b/kasa/tests/test_feature.py index 0fb7156d..7dcf9672 100644 --- a/kasa/tests/test_feature.py +++ b/kasa/tests/test_feature.py @@ -181,7 +181,8 @@ async def test_feature_setters(dev: Device, mocker: MockerFixture): async def _test_features(dev): exceptions = [] - for feat in dev.features.values(): + feats = dev.features.copy() + for feat in feats.values(): try: with patch.object(feat.device.protocol, "query") as query: await _test_feature(feat, query) @@ -194,6 +195,9 @@ async def test_feature_setters(dev: Device, mocker: MockerFixture): return exceptions + # We mock the device state reset + mocker.patch.object(dev, "request_renegotiation") + exceptions = await _test_features(dev) for child in dev.children: