From 51611156217b2d1cb12e17455193adf0dd066e13 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Sun, 27 Oct 2024 12:08:02 +0000 Subject: [PATCH] Update SMART test framework to use fake child protocols (#1199) --- kasa/tests/fakeprotocol_smart.py | 155 ++++++++++++++++++++++--- kasa/tests/fakeprotocol_smartcamera.py | 68 ++--------- kasa/tests/fixtureinfo.py | 11 +- kasa/tests/test_emeter.py | 11 ++ 4 files changed, 170 insertions(+), 75 deletions(-) diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index 6c9423ec..c3d8104e 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -1,17 +1,19 @@ import copy from json import loads as json_loads +from warnings import warn import pytest from kasa import Credentials, DeviceConfig, SmartProtocol from kasa.exceptions import SmartErrorCode from kasa.protocol import BaseTransport +from kasa.smart import SmartChildDevice class FakeSmartProtocol(SmartProtocol): - def __init__(self, info, fixture_name): + def __init__(self, info, fixture_name, *, is_child=False): super().__init__( - transport=FakeSmartTransport(info, fixture_name), + transport=FakeSmartTransport(info, fixture_name, is_child=is_child), ) async def query(self, request, retry_count: int = 3): @@ -30,6 +32,7 @@ class FakeSmartTransport(BaseTransport): component_nego_not_included=False, warn_fixture_missing_methods=True, fix_incomplete_fixture_lists=True, + is_child=False, ): super().__init__( config=DeviceConfig( @@ -41,7 +44,15 @@ class FakeSmartTransport(BaseTransport): ), ) self.fixture_name = fixture_name - self.info = copy.deepcopy(info) + # Don't copy the dict if the device is a child so that updates on the + # child are then still reflected on the parent's lis of child device in + if not is_child: + self.info = copy.deepcopy(info) + self.child_protocols = self._get_child_protocols( + self.info, self.fixture_name, "get_child_device_list" + ) + else: + self.info = info if not component_nego_not_included: self.components = { comp["id"]: comp["ver_code"] @@ -125,7 +136,7 @@ class FakeSmartTransport(BaseTransport): params = request_dict["params"] responses = [] for request in params["requests"]: - response = self._send_request(request) # type: ignore[arg-type] + response = await self._send_request(request) # type: ignore[arg-type] # Devices do not continue after error if response["error_code"] != 0: break @@ -133,11 +144,111 @@ class FakeSmartTransport(BaseTransport): responses.append(response) return {"result": {"responses": responses}, "error_code": 0} else: - return self._send_request(request_dict) + return await self._send_request(request_dict) - def _handle_control_child(self, params: dict): + @staticmethod + def _get_child_protocols( + parent_fixture_info, parent_fixture_name, child_devices_key + ): + child_infos = parent_fixture_info.get(child_devices_key, {}).get( + "child_device_list", [] + ) + if not child_infos: + return + found_child_fixture_infos = [] + child_protocols = {} + # imported here to avoid circular import + from .conftest import filter_fixtures + + def try_get_child_fixture_info(child_dev_info): + hw_version = child_dev_info["hw_ver"] + sw_version = child_dev_info["fw_ver"] + sw_version = sw_version.split(" ")[0] + model = child_dev_info["model"] + region = child_dev_info.get("specs", "XX") + child_fixture_name = f"{model}({region})_{hw_version}_{sw_version}" + child_fixtures = filter_fixtures( + "Child fixture", + protocol_filter={"SMART.CHILD"}, + model_filter={child_fixture_name}, + ) + if child_fixtures: + return next(iter(child_fixtures)) + return None + + for child_info in child_infos: + if ( # Is SMART protocol + (device_id := child_info.get("device_id")) + and (category := child_info.get("category")) + and category in SmartChildDevice.CHILD_DEVICE_TYPE_MAP + ): + if fixture_info_tuple := try_get_child_fixture_info(child_info): + child_fixture = copy.deepcopy(fixture_info_tuple.data) + child_fixture["get_device_info"]["device_id"] = device_id + found_child_fixture_infos.append(child_fixture["get_device_info"]) + child_protocols[device_id] = FakeSmartProtocol( + child_fixture, fixture_info_tuple.name, is_child=True + ) + # Look for fixture inline + elif (child_fixtures := parent_fixture_info.get("child_devices")) and ( + child_fixture := child_fixtures.get(device_id) + ): + found_child_fixture_infos.append(child_fixture["get_device_info"]) + child_protocols[device_id] = FakeSmartProtocol( + child_fixture, + f"{parent_fixture_name}-{device_id}", + is_child=True, + ) + else: + warn( + f"Could not find child SMART fixture for {child_info}", + stacklevel=1, + ) + else: + warn( + f"Child is a cameraprotocol which needs to be implemented {child_info}", + stacklevel=1, + ) + # Replace parent child infos with the infos from the child fixtures so + # that updates update both + if child_infos and found_child_fixture_infos: + parent_fixture_info[child_devices_key]["child_device_list"] = ( + found_child_fixture_infos + ) + return child_protocols + + async def _handle_control_child(self, params: dict): """Handle control_child command.""" device_id = params.get("device_id") + if device_id not in self.child_protocols: + warn( + f"Could not find child fixture {device_id} in {self.fixture_name}", + stacklevel=1, + ) + return self._handle_control_child_missing(params) + + child_protocol: SmartProtocol = self.child_protocols[device_id] + + request_data = params.get("requestData", {}) + + child_method = request_data.get("method") + child_params = request_data.get("params") # noqa: F841 + + resp = await child_protocol.query({child_method: child_params}) + resp["error_code"] = 0 + for val in resp.values(): + return { + "result": {"responseData": {"result": val, "error_code": 0}}, + "error_code": 0, + } + + def _handle_control_child_missing(self, params: dict): + """Handle control_child command. + + Used for older fixtures where child info wasn't stored in the fixture. + TODO: Should be removed somehow for future maintanability. + """ + device_id = params.get("device_id") request_data = params.get("requestData", {}) child_method = request_data.get("method") @@ -156,7 +267,7 @@ class FakeSmartTransport(BaseTransport): # Get the method calls made directly on the child devices child_device_calls = self.info["child_devices"].setdefault(device_id, {}) - # We only support get & set device info for now. + # We only support get & set device info in this method for missing. if child_method == "get_device_info": result = copy.deepcopy(info) return {"result": result, "error_code": 0} @@ -216,14 +327,17 @@ class FakeSmartTransport(BaseTransport): def _set_on_off_gradually_info(self, info, params): # Child devices can have the required properties directly in info + # the _handle_control_child_missing directly passes in get_device_info + sys_info = info.get("get_device_info", info) + if self.components["on_off_gradually"] == 1: info["get_on_off_gradually_info"] = {"enable": params["enable"]} elif on_state := params.get("on_state"): - if "fade_on_time" in info and "gradually_on_mode" in info: - info["gradually_on_mode"] = 1 if on_state["enable"] else 0 + if "fade_on_time" in sys_info and "gradually_on_mode" in sys_info: + sys_info["gradually_on_mode"] = 1 if on_state["enable"] else 0 if "duration" in on_state: - info["fade_on_time"] = on_state["duration"] - else: + sys_info["fade_on_time"] = on_state["duration"] + if "get_on_off_gradually_info" in info: info["get_on_off_gradually_info"]["on_state"]["enable"] = on_state[ "enable" ] @@ -232,11 +346,11 @@ class FakeSmartTransport(BaseTransport): on_state["duration"] ) elif off_state := params.get("off_state"): - if "fade_off_time" in info and "gradually_off_mode" in info: - info["gradually_off_mode"] = 1 if off_state["enable"] else 0 + if "fade_off_time" in sys_info and "gradually_off_mode" in sys_info: + sys_info["gradually_off_mode"] = 1 if off_state["enable"] else 0 if "duration" in off_state: - info["fade_off_time"] = off_state["duration"] - else: + sys_info["fade_off_time"] = off_state["duration"] + if "get_on_off_gradually_info" in info: info["get_on_off_gradually_info"]["off_state"]["enable"] = off_state[ "enable" ] @@ -290,6 +404,13 @@ class FakeSmartTransport(BaseTransport): if "brightness" not in info["get_preset_rules"]: return {"error_code": SmartErrorCode.PARAMS_ERROR} info["get_preset_rules"]["brightness"] = params["brightness"] + # So far the only child device with light preset (KS240) also has the + # data available to read in the device_info. + device_info = info["get_device_info"] + if "preset_state" in device_info: + device_info["preset_state"] = [ + {"brightness": b} for b in params["brightness"] + ] return {"error_code": 0} def _set_child_preset_rules(self, info, params): @@ -309,12 +430,12 @@ class FakeSmartTransport(BaseTransport): info["get_preset_rules"]["states"][params["index"]] = params["state"] return {"error_code": 0} - def _send_request(self, request_dict: dict): + async def _send_request(self, request_dict: dict): method = request_dict["method"] info = self.info if method == "control_child": - return self._handle_control_child(request_dict["params"]) + return await self._handle_control_child(request_dict["params"]) params = request_dict.get("params") if method == "component_nego" or method[:4] == "get_": diff --git a/kasa/tests/fakeprotocol_smartcamera.py b/kasa/tests/fakeprotocol_smartcamera.py index a8c49bd4..d7465489 100644 --- a/kasa/tests/fakeprotocol_smartcamera.py +++ b/kasa/tests/fakeprotocol_smartcamera.py @@ -2,20 +2,18 @@ from __future__ import annotations import copy from json import loads as json_loads -from warnings import warn from kasa import Credentials, DeviceConfig, SmartProtocol from kasa.experimental.smartcameraprotocol import SmartCameraProtocol from kasa.protocol import BaseTransport -from kasa.smart import SmartChildDevice -from .fakeprotocol_smart import FakeSmartProtocol +from .fakeprotocol_smart import FakeSmartTransport class FakeSmartCameraProtocol(SmartCameraProtocol): - def __init__(self, info, fixture_name): + def __init__(self, info, fixture_name, *, is_child=False): super().__init__( - transport=FakeSmartCameraTransport(info, fixture_name), + transport=FakeSmartCameraTransport(info, fixture_name, is_child=is_child), ) async def query(self, request, retry_count: int = 3): @@ -31,6 +29,7 @@ class FakeSmartCameraTransport(BaseTransport): fixture_name, *, list_return_size=10, + is_child=False, ): super().__init__( config=DeviceConfig( @@ -42,8 +41,14 @@ class FakeSmartCameraTransport(BaseTransport): ), ) self.fixture_name = fixture_name - self.info = copy.deepcopy(info) - self.child_protocols = self._get_child_protocols() + if not is_child: + self.info = copy.deepcopy(info) + self.child_protocols = FakeSmartTransport._get_child_protocols( + self.info, self.fixture_name, "getChildDeviceList" + ) + else: + self.info = info + # self.child_protocols = self._get_child_protocols() self.list_return_size = list_return_size @property @@ -74,55 +79,6 @@ class FakeSmartCameraTransport(BaseTransport): else: return await self._send_request(request_dict) - def _get_child_protocols(self): - child_infos = self.info.get("getChildDeviceList", {}).get( - "child_device_list", [] - ) - found_child_fixture_infos = [] - child_protocols = {} - # imported here to avoid circular import - from .conftest import filter_fixtures - - for child_info in child_infos: - if ( - (device_id := child_info.get("device_id")) - and (category := child_info.get("category")) - and category in SmartChildDevice.CHILD_DEVICE_TYPE_MAP - ): - hw_version = child_info["hw_ver"] - sw_version = child_info["fw_ver"] - sw_version = sw_version.split(" ")[0] - model = child_info["model"] - region = child_info["specs"] - child_fixture_name = f"{model}({region})_{hw_version}_{sw_version}" - child_fixtures = filter_fixtures( - "Child fixture", - protocol_filter={"SMART.CHILD"}, - model_filter=child_fixture_name, - ) - if child_fixtures: - fixture_info = next(iter(child_fixtures)) - found_child_fixture_infos.append(child_info) - child_protocols[device_id] = FakeSmartProtocol( - fixture_info.data, fixture_info.name - ) - else: - warn( - f"Could not find child fixture {child_fixture_name}", - stacklevel=1, - ) - else: - warn( - f"Child is a cameraprotocol which needs to be implemented {child_info}", - stacklevel=1, - ) - # Replace child infos with the infos that found child fixtures - if child_infos: - self.info["getChildDeviceList"]["child_device_list"] = ( - found_child_fixture_infos - ) - return child_protocols - async def _handle_control_child(self, params: dict): """Handle control_child command.""" device_id = params.get("device_id") diff --git a/kasa/tests/fixtureinfo.py b/kasa/tests/fixtureinfo.py index 8db96024..9f4d3952 100644 --- a/kasa/tests/fixtureinfo.py +++ b/kasa/tests/fixtureinfo.py @@ -118,10 +118,17 @@ def filter_fixtures( """ def _model_match(fixture_data: FixtureInfo, model_filter: set[str]): + if isinstance(model_filter, str): + model_filter = {model_filter} + assert isinstance(model_filter, set), "model filter must be a set" model_filter_list = [mf for mf in model_filter] - if len(model_filter_list) == 1 and model_filter_list[0].split("_") == 3: + if ( + len(model_filter_list) == 1 + and (model := model_filter_list[0]) + and len(model.split("_")) == 3 + ): # return exact match - return fixture_data.name == model_filter_list[0] + return fixture_data.name == f"{model}.json" file_model_region = fixture_data.name.split("_")[0] file_model = file_model_region.split("(")[0] return file_model in model_filter diff --git a/kasa/tests/test_emeter.py b/kasa/tests/test_emeter.py index 3cc69193..d5a35758 100644 --- a/kasa/tests/test_emeter.py +++ b/kasa/tests/test_emeter.py @@ -14,6 +14,8 @@ from kasa import Device, EmeterStatus, Module from kasa.interfaces.energy import Energy from kasa.iot import IotDevice, IotStrip from kasa.iot.modules.emeter import Emeter +from kasa.smart import SmartDevice +from kasa.smart.modules import Energy as SmartEnergyModule from .conftest import has_emeter, has_emeter_iot, no_emeter @@ -54,6 +56,11 @@ async def test_no_emeter(dev): @has_emeter async def test_get_emeter_realtime(dev): + if isinstance(dev, SmartDevice): + mod = SmartEnergyModule(dev, str(Module.Energy)) + if not await mod._check_supported(): + pytest.skip(f"Energy module not supported for {dev}.") + assert dev.has_emeter current_emeter = await dev.get_emeter_realtime() @@ -178,6 +185,10 @@ async def test_emeter_daily(): @has_emeter async def test_supported(dev: Device): + if isinstance(dev, SmartDevice): + mod = SmartEnergyModule(dev, str(Module.Energy)) + if not await mod._check_supported(): + pytest.skip(f"Energy module not supported for {dev}.") energy_module = dev.modules.get(Module.Energy) assert energy_module if isinstance(dev, IotDevice):