diff --git a/kasa/tests/device_fixtures.py b/kasa/tests/device_fixtures.py index 5843639e..085bab8e 100644 --- a/kasa/tests/device_fixtures.py +++ b/kasa/tests/device_fixtures.py @@ -1,10 +1,11 @@ -from typing import Dict, Set +from typing import Dict, List, Set import pytest from kasa import ( Credentials, Device, + DeviceType, Discover, ) from kasa.iot import IotBulb, IotDimmer, IotLightStrip, IotPlug, IotStrip, IotWallSwitch @@ -127,6 +128,21 @@ ALL_DEVICES = ALL_DEVICES_IOT.union(ALL_DEVICES_SMART) IP_MODEL_CACHE: Dict[str, str] = {} +def parametrize_combine(parametrized: List[pytest.MarkDecorator]): + """Combine multiple pytest parametrize dev marks into one set of fixtures.""" + fixtures = set() + for param in parametrized: + if param.args[0] != "dev": + raise Exception(f"Supplied mark is not for dev fixture: {param.args[0]}") + fixtures.update(param.args[1]) + return pytest.mark.parametrize( + "dev", + sorted(list(fixtures)), + indirect=True, + ids=idgenerator, + ) + + def parametrize( desc, *, @@ -134,6 +150,7 @@ def parametrize( protocol_filter=None, component_filter=None, data_root_filter=None, + device_type_filter=None, ids=None, ): if ids is None: @@ -146,6 +163,7 @@ def parametrize( protocol_filter=protocol_filter, component_filter=component_filter, data_root_filter=data_root_filter, + device_type_filter=device_type_filter, ), indirect=True, ids=ids, @@ -169,7 +187,6 @@ no_emeter_iot = parametrize( protocol_filter={"IOT"}, ) -bulb = parametrize("bulbs", model_filter=BULBS, protocol_filter={"SMART", "IOT"}) plug = parametrize("plugs", model_filter=PLUGS, protocol_filter={"IOT", "SMART"}) plug_iot = parametrize("plugs iot", model_filter=PLUGS, protocol_filter={"IOT"}) wallswitch = parametrize( @@ -216,9 +233,16 @@ variable_temp_iot = parametrize( model_filter=BULBS_IOT_VARIABLE_TEMP, protocol_filter={"IOT"}, ) + +bulb_smart = parametrize( + "bulb devices smart", + device_type_filter=[DeviceType.Bulb, DeviceType.LightStrip], + protocol_filter={"SMART"}, +) bulb_iot = parametrize( "bulb devices iot", model_filter=BULBS_IOT, protocol_filter={"IOT"} ) +bulb = parametrize_combine([bulb_smart, bulb_iot]) strip_iot = parametrize( "strip devices iot", model_filter=STRIPS_IOT, protocol_filter={"IOT"} @@ -233,9 +257,6 @@ plug_smart = parametrize( switch_smart = parametrize( "switch devices smart", model_filter=SWITCHES_SMART, protocol_filter={"SMART"} ) -bulb_smart = parametrize( - "bulb devices smart", model_filter=BULBS_SMART, protocol_filter={"SMART"} -) dimmers_smart = parametrize( "dimmer devices smart", model_filter=DIMMERS_SMART, protocol_filter={"SMART"} ) diff --git a/kasa/tests/fixtureinfo.py b/kasa/tests/fixtureinfo.py index dc6e5307..70d385f6 100644 --- a/kasa/tests/fixtureinfo.py +++ b/kasa/tests/fixtureinfo.py @@ -4,6 +4,10 @@ import os from pathlib import Path from typing import Dict, List, NamedTuple, Optional, Set +from kasa.device_factory import _get_device_type_from_sys_info +from kasa.device_type import DeviceType +from kasa.smart.smartdevice import SmartDevice + class FixtureInfo(NamedTuple): name: str @@ -83,6 +87,7 @@ def filter_fixtures( protocol_filter: Optional[Set[str]] = None, model_filter: Optional[Set[str]] = None, component_filter: Optional[str] = None, + device_type_filter: Optional[List[DeviceType]] = None, ): """Filter the fixtures based on supplied parameters. @@ -108,6 +113,19 @@ def filter_fixtures( } return component_filter in components + def _device_type_match(fixture_data: FixtureInfo, device_type): + if (component_nego := fixture_data.data.get("component_nego")) is None: + return _get_device_type_from_sys_info(fixture_data.data) in device_type + components = [component["id"] for component in component_nego["component_list"]] + if (info := fixture_data.data.get("get_device_info")) and ( + type_ := info.get("type") + ): + return ( + SmartDevice._get_device_type_from_components(components, type_) + in device_type + ) + return False + filtered = [] if protocol_filter is None: protocol_filter = {"IOT", "SMART"} @@ -120,6 +138,10 @@ def filter_fixtures( continue if component_filter and not _component_match(fixture_data, component_filter): continue + if device_type_filter and not _device_type_match( + fixture_data, device_type_filter + ): + continue filtered.append(fixture_data)