Add tests and refactor

This commit is contained in:
Teemu Rytilahti 2024-06-15 02:43:48 +02:00
parent 1a758a6a53
commit 803e88d508
9 changed files with 159 additions and 14 deletions

View File

@ -1279,12 +1279,11 @@ async def child_list(dev):
@pass_dev
async def child_pair(dev, timeout):
"""Pair new device."""
if "ChildSetupModule" not in dev.modules:
echo("%s is not a hub.")
if (cs := dev.modules.get(Module.ChildSetup)) is None:
echo("%s does not support pairing." % dev)
return
echo("Finding new devices for %s" % timeout)
cs = dev.modules["ChildSetupModule"]
echo("Finding new devices for %s seconds" % timeout)
return await cs.pair(timeout=timeout)
@ -1293,11 +1292,10 @@ async def child_pair(dev, timeout):
@pass_dev
async def child_unpair(dev, device_id: str):
"""Unpair given device."""
if "ChildSetupModule" not in dev.modules:
echo("%s is not a hub.")
if (cs := dev.modules.get(Module.ChildSetup)) is None:
echo("%s does not support pairing." % dev)
return
cs = dev.modules["ChildSetupModule"]
res = await cs.unpair(device_id=device_id)
echo("Unpaired %s (if it was paired)" % device_id)
return res

View File

@ -85,6 +85,7 @@ class Module(ABC):
WaterleakSensor: Final[ModuleName[smart.WaterleakSensor]] = ModuleName(
"WaterleakSensor"
)
ChildSetup: Final[ModuleName[smart.ChildSetup]] = ModuleName("ChildSetup")
def __init__(self, device: Device, module: str):
self._device = device

View File

@ -5,8 +5,8 @@ from .autooff import AutoOff
from .batterysensor import BatterySensor
from .brightness import Brightness
from .childdevice import ChildDevice
from .childsetup import ChildSetup
from .cloud import Cloud
from .childsetup import ChildSetupModule
from .color import Color
from .colortemperature import ColorTemperature
from .contactsensor import ContactSensor
@ -34,7 +34,7 @@ __all__ = [
"Energy",
"DeviceModule",
"ChildDevice",
"ChildSetupModule",
"ChildSetup",
"BatterySensor",
"HumiditySensor",
"TemperatureSensor",

View File

@ -22,7 +22,7 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__)
class ChildSetupModule(SmartModule):
class ChildSetup(SmartModule):
"""Implementation for child device setup."""
REQUIRED_COMPONENT = "child_quick_setup"
@ -33,6 +33,7 @@ class ChildSetupModule(SmartModule):
self._add_feature(
Feature(
device,
id="pair",
name="Pair",
container=self,
attribute_setter="pair",
@ -77,13 +78,13 @@ class ChildSetupModule(SmartModule):
"""Remove device from the hub."""
payload = {"child_device_list": [{"device_id": device_id}]}
res = await self._device._query_helper("remove_child_device_list", payload)
await self._device._initialize_children()
self._device.request_renegotiation()
return res
async def add_devices(self, devices: dict):
"""Add devices."""
res = await self._device._query_helper("add_child_device_list", devices)
await self._device._initialize_children()
self._device.request_renegotiation()
return res
async def get_detected_devices(self) -> dict:

View File

@ -55,7 +55,6 @@ class SmartDevice(Device):
self.protocol: SmartProtocol
self._components_raw: dict[str, Any] | None = None
self._components: dict[str, int] = {}
self._state_information: dict[str, Any] = {}
self._modules: dict[str | ModuleName[Module], SmartModule] = {}
self._parent: SmartDevice | None = None
self._children: Mapping[str, SmartDevice] = {}
@ -149,6 +148,21 @@ class SmartDevice(Device):
if "child_device" in self._components and not self.children:
await self._initialize_children()
def request_renegotiation(self) -> None:
"""Request renegotiation on the next update.
This is used by childsetup to inform about new or removed children.
"""
self._modules.clear()
self._features.clear()
self._last_update.clear()
self._components.clear()
if self._components_raw is not None:
self._components_raw.clear()
self._components_raw = None
# we cannot use clear here, as mapping doesn't have it...
self._children = {}
async def update(self, update_children: bool = False):
"""Update the device."""
if self.credentials is None and self.credentials_hash is None:

View File

@ -115,6 +115,16 @@ class FakeSmartTransport(BaseTransport):
},
),
"get_device_usage": ("device", {}),
# child setup
"get_support_child_device_category": (
"child_quick_setup",
{"device_category_list": [{"category": "subg.trv"}]},
),
# no devices found
"get_scan_child_device_list": (
"child_quick_setup",
{"child_device_list": [{"dummy": "response"}], "scan_status": "idle"},
),
}
async def send(self, request: str):
@ -324,6 +334,13 @@ class FakeSmartTransport(BaseTransport):
return self._set_preset_rules(info, params)
elif method == "edit_preset_rules":
return self._edit_preset_rules(info, params)
# childsetup methods
if method in [
"begin_scanning_child_device",
"add_child_device_list",
"remove_child_device_list",
]:
return {"error_code": 0}
elif method[:4] == "set_":
target_method = f"get_{method[4:]}"
info[target_method].update(params)

View File

@ -0,0 +1,68 @@
from __future__ import annotations
import logging
import pytest
from pytest_mock import MockerFixture
from kasa import Feature, Module
from kasa.smart import SmartDevice
from kasa.tests.device_fixtures import parametrize
childsetup = parametrize(
"supports pairing", component_filter="child_quick_setup", protocol_filter={"SMART"}
)
@childsetup
async def test_childsetup_features(dev: SmartDevice):
"""Test the exposed features."""
cs = dev.modules.get(Module.ChildSetup)
assert cs
assert "pair" in cs._module_features
pair = cs._module_features["pair"]
assert pair.type == Feature.Type.Action
@childsetup
async def test_childsetup_pair(
dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test device pairing."""
caplog.set_level(logging.INFO)
mock_query_helper = mocker.spy(dev, "_query_helper")
mocker.patch("asyncio.sleep")
cs = dev.modules.get(Module.ChildSetup)
assert cs
await cs.pair()
mock_query_helper.assert_has_awaits(
[
mocker.call("begin_scanning_child_device", None),
mocker.call("get_scan_child_device_list", params=mocker.ANY),
mocker.call("add_child_device_list", params=mocker.ANY),
]
)
assert "Discovery done" in caplog.text
@childsetup
async def test_childsetup_unpair(
dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test unpair."""
mock_query_helper = mocker.spy(dev, "_query_helper")
DUMMY_ID = "dummy_id"
cs = dev.modules.get(Module.ChildSetup)
assert cs
await cs.unpair(DUMMY_ID)
mock_query_helper.assert_awaited_with(
"remove_child_device_list",
params={"child_device_list": [{"device_id": DUMMY_ID}]},
)

View File

@ -5,6 +5,7 @@ import re
import asyncclick as click
import pytest
from asyncclick.testing import CliRunner
from pytest_mock import MockerFixture
from kasa import (
AuthenticationError,
@ -20,6 +21,7 @@ from kasa.cli import (
TYPE_TO_CLASS,
alias,
brightness,
child,
cli,
cmd_command,
effect,
@ -244,6 +246,46 @@ async def test_wifi_join_exception(dev, mocker, runner):
assert isinstance(res.exception, KasaException)
@device_smart
async def test_child_pair(dev, mocker: MockerFixture, runner, caplog):
"""Test that pair calls the expected methods."""
cs = dev.modules.get(Module.ChildSetup)
# Patch if the device supports the module
if cs is not None:
mock_pair = mocker.patch.object(cs, "pair")
res = await runner.invoke(child, ["pair"], obj=dev, catch_exceptions=False)
if cs is None:
assert "does not support pairing" in res.output.replace("\n", "")
return
mock_pair.assert_awaited()
assert "Finding new devices for 10 seconds" in res.output
assert res.exit_code == 0
@device_smart
async def test_child_unpair(dev, mocker: MockerFixture, runner):
"""Test that unpair calls the expected method."""
DUMMY_ID = "dummy_id"
cs = dev.modules.get(Module.ChildSetup)
# Patch if the device supports the module
if cs is not None:
mock_unpair = mocker.patch.object(cs, "unpair")
res = await runner.invoke(
child, ["unpair", DUMMY_ID], obj=dev, catch_exceptions=False
)
if cs is None:
assert "does not support pairing" in res.output.replace("\n", "")
return
mock_unpair.assert_awaited()
assert f"Unpaired {DUMMY_ID} (if it was paired)" in res.output
assert res.exit_code == 0
@device_smart
async def test_update_credentials(dev, runner):
res = await runner.invoke(

View File

@ -181,7 +181,8 @@ async def test_feature_setters(dev: Device, mocker: MockerFixture):
async def _test_features(dev):
exceptions = []
for feat in dev.features.values():
feats = dev.features.copy()
for feat in feats.values():
try:
with patch.object(feat.device.protocol, "query") as query:
await _test_feature(feat, query)
@ -194,6 +195,9 @@ async def test_feature_setters(dev: Device, mocker: MockerFixture):
return exceptions
# We mock the device state reset
mocker.patch.object(dev, "request_renegotiation")
exceptions = await _test_features(dev)
for child in dev.children: