Add tests and refactor

This commit is contained in:
Teemu Rytilahti 2024-06-15 02:43:48 +02:00
parent 1a758a6a53
commit 803e88d508
9 changed files with 159 additions and 14 deletions

View File

@ -1279,12 +1279,11 @@ async def child_list(dev):
@pass_dev @pass_dev
async def child_pair(dev, timeout): async def child_pair(dev, timeout):
"""Pair new device.""" """Pair new device."""
if "ChildSetupModule" not in dev.modules: if (cs := dev.modules.get(Module.ChildSetup)) is None:
echo("%s is not a hub.") echo("%s does not support pairing." % dev)
return return
echo("Finding new devices for %s" % timeout) echo("Finding new devices for %s seconds" % timeout)
cs = dev.modules["ChildSetupModule"]
return await cs.pair(timeout=timeout) return await cs.pair(timeout=timeout)
@ -1293,11 +1292,10 @@ async def child_pair(dev, timeout):
@pass_dev @pass_dev
async def child_unpair(dev, device_id: str): async def child_unpair(dev, device_id: str):
"""Unpair given device.""" """Unpair given device."""
if "ChildSetupModule" not in dev.modules: if (cs := dev.modules.get(Module.ChildSetup)) is None:
echo("%s is not a hub.") echo("%s does not support pairing." % dev)
return return
cs = dev.modules["ChildSetupModule"]
res = await cs.unpair(device_id=device_id) res = await cs.unpair(device_id=device_id)
echo("Unpaired %s (if it was paired)" % device_id) echo("Unpaired %s (if it was paired)" % device_id)
return res return res

View File

@ -85,6 +85,7 @@ class Module(ABC):
WaterleakSensor: Final[ModuleName[smart.WaterleakSensor]] = ModuleName( WaterleakSensor: Final[ModuleName[smart.WaterleakSensor]] = ModuleName(
"WaterleakSensor" "WaterleakSensor"
) )
ChildSetup: Final[ModuleName[smart.ChildSetup]] = ModuleName("ChildSetup")
def __init__(self, device: Device, module: str): def __init__(self, device: Device, module: str):
self._device = device self._device = device

View File

@ -5,8 +5,8 @@ from .autooff import AutoOff
from .batterysensor import BatterySensor from .batterysensor import BatterySensor
from .brightness import Brightness from .brightness import Brightness
from .childdevice import ChildDevice from .childdevice import ChildDevice
from .childsetup import ChildSetup
from .cloud import Cloud from .cloud import Cloud
from .childsetup import ChildSetupModule
from .color import Color from .color import Color
from .colortemperature import ColorTemperature from .colortemperature import ColorTemperature
from .contactsensor import ContactSensor from .contactsensor import ContactSensor
@ -34,7 +34,7 @@ __all__ = [
"Energy", "Energy",
"DeviceModule", "DeviceModule",
"ChildDevice", "ChildDevice",
"ChildSetupModule", "ChildSetup",
"BatterySensor", "BatterySensor",
"HumiditySensor", "HumiditySensor",
"TemperatureSensor", "TemperatureSensor",

View File

@ -22,7 +22,7 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class ChildSetupModule(SmartModule): class ChildSetup(SmartModule):
"""Implementation for child device setup.""" """Implementation for child device setup."""
REQUIRED_COMPONENT = "child_quick_setup" REQUIRED_COMPONENT = "child_quick_setup"
@ -33,6 +33,7 @@ class ChildSetupModule(SmartModule):
self._add_feature( self._add_feature(
Feature( Feature(
device, device,
id="pair",
name="Pair", name="Pair",
container=self, container=self,
attribute_setter="pair", attribute_setter="pair",
@ -77,13 +78,13 @@ class ChildSetupModule(SmartModule):
"""Remove device from the hub.""" """Remove device from the hub."""
payload = {"child_device_list": [{"device_id": device_id}]} payload = {"child_device_list": [{"device_id": device_id}]}
res = await self._device._query_helper("remove_child_device_list", payload) res = await self._device._query_helper("remove_child_device_list", payload)
await self._device._initialize_children() self._device.request_renegotiation()
return res return res
async def add_devices(self, devices: dict): async def add_devices(self, devices: dict):
"""Add devices.""" """Add devices."""
res = await self._device._query_helper("add_child_device_list", devices) res = await self._device._query_helper("add_child_device_list", devices)
await self._device._initialize_children() self._device.request_renegotiation()
return res return res
async def get_detected_devices(self) -> dict: async def get_detected_devices(self) -> dict:

View File

@ -55,7 +55,6 @@ class SmartDevice(Device):
self.protocol: SmartProtocol self.protocol: SmartProtocol
self._components_raw: dict[str, Any] | None = None self._components_raw: dict[str, Any] | None = None
self._components: dict[str, int] = {} self._components: dict[str, int] = {}
self._state_information: dict[str, Any] = {}
self._modules: dict[str | ModuleName[Module], SmartModule] = {} self._modules: dict[str | ModuleName[Module], SmartModule] = {}
self._parent: SmartDevice | None = None self._parent: SmartDevice | None = None
self._children: Mapping[str, SmartDevice] = {} self._children: Mapping[str, SmartDevice] = {}
@ -149,6 +148,21 @@ class SmartDevice(Device):
if "child_device" in self._components and not self.children: if "child_device" in self._components and not self.children:
await self._initialize_children() await self._initialize_children()
def request_renegotiation(self) -> None:
"""Request renegotiation on the next update.
This is used by childsetup to inform about new or removed children.
"""
self._modules.clear()
self._features.clear()
self._last_update.clear()
self._components.clear()
if self._components_raw is not None:
self._components_raw.clear()
self._components_raw = None
# we cannot use clear here, as mapping doesn't have it...
self._children = {}
async def update(self, update_children: bool = False): async def update(self, update_children: bool = False):
"""Update the device.""" """Update the device."""
if self.credentials is None and self.credentials_hash is None: if self.credentials is None and self.credentials_hash is None:

View File

@ -115,6 +115,16 @@ class FakeSmartTransport(BaseTransport):
}, },
), ),
"get_device_usage": ("device", {}), "get_device_usage": ("device", {}),
# child setup
"get_support_child_device_category": (
"child_quick_setup",
{"device_category_list": [{"category": "subg.trv"}]},
),
# no devices found
"get_scan_child_device_list": (
"child_quick_setup",
{"child_device_list": [{"dummy": "response"}], "scan_status": "idle"},
),
} }
async def send(self, request: str): async def send(self, request: str):
@ -324,6 +334,13 @@ class FakeSmartTransport(BaseTransport):
return self._set_preset_rules(info, params) return self._set_preset_rules(info, params)
elif method == "edit_preset_rules": elif method == "edit_preset_rules":
return self._edit_preset_rules(info, params) return self._edit_preset_rules(info, params)
# childsetup methods
if method in [
"begin_scanning_child_device",
"add_child_device_list",
"remove_child_device_list",
]:
return {"error_code": 0}
elif method[:4] == "set_": elif method[:4] == "set_":
target_method = f"get_{method[4:]}" target_method = f"get_{method[4:]}"
info[target_method].update(params) info[target_method].update(params)

View File

@ -0,0 +1,68 @@
from __future__ import annotations
import logging
import pytest
from pytest_mock import MockerFixture
from kasa import Feature, Module
from kasa.smart import SmartDevice
from kasa.tests.device_fixtures import parametrize
childsetup = parametrize(
"supports pairing", component_filter="child_quick_setup", protocol_filter={"SMART"}
)
@childsetup
async def test_childsetup_features(dev: SmartDevice):
"""Test the exposed features."""
cs = dev.modules.get(Module.ChildSetup)
assert cs
assert "pair" in cs._module_features
pair = cs._module_features["pair"]
assert pair.type == Feature.Type.Action
@childsetup
async def test_childsetup_pair(
dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test device pairing."""
caplog.set_level(logging.INFO)
mock_query_helper = mocker.spy(dev, "_query_helper")
mocker.patch("asyncio.sleep")
cs = dev.modules.get(Module.ChildSetup)
assert cs
await cs.pair()
mock_query_helper.assert_has_awaits(
[
mocker.call("begin_scanning_child_device", None),
mocker.call("get_scan_child_device_list", params=mocker.ANY),
mocker.call("add_child_device_list", params=mocker.ANY),
]
)
assert "Discovery done" in caplog.text
@childsetup
async def test_childsetup_unpair(
dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test unpair."""
mock_query_helper = mocker.spy(dev, "_query_helper")
DUMMY_ID = "dummy_id"
cs = dev.modules.get(Module.ChildSetup)
assert cs
await cs.unpair(DUMMY_ID)
mock_query_helper.assert_awaited_with(
"remove_child_device_list",
params={"child_device_list": [{"device_id": DUMMY_ID}]},
)

View File

@ -5,6 +5,7 @@ import re
import asyncclick as click import asyncclick as click
import pytest import pytest
from asyncclick.testing import CliRunner from asyncclick.testing import CliRunner
from pytest_mock import MockerFixture
from kasa import ( from kasa import (
AuthenticationError, AuthenticationError,
@ -20,6 +21,7 @@ from kasa.cli import (
TYPE_TO_CLASS, TYPE_TO_CLASS,
alias, alias,
brightness, brightness,
child,
cli, cli,
cmd_command, cmd_command,
effect, effect,
@ -244,6 +246,46 @@ async def test_wifi_join_exception(dev, mocker, runner):
assert isinstance(res.exception, KasaException) assert isinstance(res.exception, KasaException)
@device_smart
async def test_child_pair(dev, mocker: MockerFixture, runner, caplog):
"""Test that pair calls the expected methods."""
cs = dev.modules.get(Module.ChildSetup)
# Patch if the device supports the module
if cs is not None:
mock_pair = mocker.patch.object(cs, "pair")
res = await runner.invoke(child, ["pair"], obj=dev, catch_exceptions=False)
if cs is None:
assert "does not support pairing" in res.output.replace("\n", "")
return
mock_pair.assert_awaited()
assert "Finding new devices for 10 seconds" in res.output
assert res.exit_code == 0
@device_smart
async def test_child_unpair(dev, mocker: MockerFixture, runner):
"""Test that unpair calls the expected method."""
DUMMY_ID = "dummy_id"
cs = dev.modules.get(Module.ChildSetup)
# Patch if the device supports the module
if cs is not None:
mock_unpair = mocker.patch.object(cs, "unpair")
res = await runner.invoke(
child, ["unpair", DUMMY_ID], obj=dev, catch_exceptions=False
)
if cs is None:
assert "does not support pairing" in res.output.replace("\n", "")
return
mock_unpair.assert_awaited()
assert f"Unpaired {DUMMY_ID} (if it was paired)" in res.output
assert res.exit_code == 0
@device_smart @device_smart
async def test_update_credentials(dev, runner): async def test_update_credentials(dev, runner):
res = await runner.invoke( res = await runner.invoke(

View File

@ -181,7 +181,8 @@ async def test_feature_setters(dev: Device, mocker: MockerFixture):
async def _test_features(dev): async def _test_features(dev):
exceptions = [] exceptions = []
for feat in dev.features.values(): feats = dev.features.copy()
for feat in feats.values():
try: try:
with patch.object(feat.device.protocol, "query") as query: with patch.object(feat.device.protocol, "query") as query:
await _test_feature(feat, query) await _test_feature(feat, query)
@ -194,6 +195,9 @@ async def test_feature_setters(dev: Device, mocker: MockerFixture):
return exceptions return exceptions
# We mock the device state reset
mocker.patch.object(dev, "request_renegotiation")
exceptions = await _test_features(dev) exceptions = await _test_features(dev)
for child in dev.children: for child in dev.children: