"""Tests for SMART devices.""" from __future__ import annotations import copy import logging import time from typing import TYPE_CHECKING, Any, cast from unittest.mock import patch import pytest from freezegun.api import FrozenDateTimeFactory from pytest_mock import MockerFixture from kasa import Device, DeviceType, KasaException, Module from kasa.exceptions import DeviceError, SmartErrorCode from kasa.smart import SmartDevice from kasa.smart.modules.energy import Energy from kasa.smart.smartmodule import SmartModule from tests.conftest import ( DISCOVERY_MOCK_IP, device_smart, get_device_for_fixture_protocol, get_parent_and_child_modules, smart_discovery, ) from tests.device_fixtures import ( hub_smartcam, hubs_smart, parametrize_combine, variable_temp_smart, ) DUMMY_CHILD_REQUEST_PREFIX = "get_dummy_" hub_all = parametrize_combine([hubs_smart, hub_smartcam]) @device_smart @pytest.mark.requires_dummy @pytest.mark.xdist_group(name="caplog") async def test_try_get_response(dev: SmartDevice, caplog): mock_response: dict = { "get_device_info": SmartErrorCode.PARAMS_ERROR, } caplog.set_level(logging.DEBUG) dev._try_get_response(mock_response, "get_device_info", {}) msg = "Error PARAMS_ERROR(-1008) getting request get_device_info for device 127.0.0.123" assert msg in caplog.text @device_smart @pytest.mark.requires_dummy async def test_update_no_device_info(dev: SmartDevice, mocker: MockerFixture): mock_response: dict = { "get_device_usage": {}, "get_device_time": {}, } msg = f"get_device_info not found in {mock_response} for device 127.0.0.123" mocker.patch.object(dev.protocol, "query", return_value=mock_response) with pytest.raises(KasaException, match=msg): await dev.update() @smart_discovery async def test_device_type_no_update(discovery_mock, caplog: pytest.LogCaptureFixture): """Test device type and repr when device not updated.""" dev = SmartDevice(DISCOVERY_MOCK_IP) assert dev.device_type is DeviceType.Unknown assert repr(dev) == f"" discovery_result = copy.deepcopy(discovery_mock.discovery_data["result"]) disco_model = discovery_result["device_model"] short_model, _, _ = disco_model.partition("(") dev.update_from_discover_info(discovery_result) assert dev.device_type is DeviceType.Unknown assert ( repr(dev) == f"" ) discovery_result["device_type"] = "SMART.FOOBAR" dev.update_from_discover_info(discovery_result) dev._components = {"dummy": 1} assert dev.device_type is DeviceType.Plug assert ( repr(dev) == f"" ) assert "Unknown device type, falling back to plug" in caplog.text @device_smart async def test_initial_update(dev: SmartDevice, mocker: MockerFixture): """Test the initial update cycle.""" # As the fixture data is already initialized, we reset the state for testing dev._components_raw = None dev._components = {} dev._modules = {} dev._features = {} dev._children = {} dev._last_update = {} dev._last_update_time = None negotiate = mocker.spy(dev, "_negotiate") initialize_modules = mocker.spy(dev, "_initialize_modules") initialize_features = mocker.spy(dev, "_initialize_features") # Perform two updates and verify that initialization is only done once await dev.update() await dev.update() negotiate.assert_called_once() assert dev._components_raw is not None initialize_modules.assert_called_once() assert dev.modules initialize_features.assert_called_once() assert dev.features @device_smart async def test_negotiate(dev: SmartDevice, mocker: MockerFixture): """Test that the initial negotiation performs expected steps.""" # As the fixture data is already initialized, we reset the state for testing dev._components_raw = None dev._children = {} query = mocker.spy(dev.protocol, "query") initialize_children = mocker.spy(dev, "_initialize_children") await dev._negotiate() # Check that we got the initial negotiation call query.assert_any_call( { "component_nego": None, "get_device_info": None, "get_connect_cloud_state": None, } ) assert dev._components_raw # Check the children are created, if device supports them if "child_device" in dev._components: initialize_children.assert_called_once() query.assert_any_call( { "get_child_device_component_list": None, "get_child_device_list": None, } ) assert len(dev._children) == dev.internal_state["get_child_device_list"]["sum"] @device_smart async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): """Test that the regular update uses queries from all supported modules.""" # We need to have some modules initialized by now assert dev._modules # Reset last update so all modules will query for mod in dev._modules.values(): mod._last_update_time = None device_queries: dict[SmartDevice, dict[str, Any]] = {} for mod in dev._modules.values(): device_queries.setdefault(mod._device, {}).update(mod.query()) # Hubs do not query child modules by default. if dev.device_type != Device.Type.Hub: for child in dev.children: for mod in child.modules.values(): device_queries.setdefault(mod._device, {}).update(mod.query()) spies = {} for device in device_queries: spies[device] = mocker.spy(device.protocol, "query") await dev.update() for device in device_queries: if device_queries[device]: # Need assert any here because the child device updates use the parent's protocol spies[device].assert_any_call(device_queries[device]) else: spies[device].assert_not_called() @device_smart @pytest.mark.xdist_group(name="caplog") async def test_update_module_update_delays( dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture, freezer: FrozenDateTimeFactory, ): """Test that modules with minimum delays delay.""" # We need to have some modules initialized by now assert dev._modules new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) await new_dev.update() first_update_time = time.monotonic() assert new_dev._last_update_time == first_update_time for module in new_dev.modules.values(): if module.query(): assert module._last_update_time == first_update_time seconds = 0 tick = 30 while seconds <= 180: seconds += tick freezer.tick(tick) now = time.monotonic() await new_dev.update() for module in new_dev.modules.values(): mod_delay = module.MINIMUM_UPDATE_INTERVAL_SECS if module.query(): expected_update_time = ( now if mod_delay == 0 else now - (seconds % mod_delay) ) assert ( module._last_update_time == expected_update_time ), f"Expected update time {expected_update_time} after {seconds} seconds for {module.name} with delay {mod_delay} got {module._last_update_time}" async def _get_child_responses(child_requests: list[dict[str, Any]], child_protocol): """Get dummy responses for testing all child modules. Even if they don't return really return query. """ child_req = {item["method"]: item.get("params") for item in child_requests} child_resp = {k: v for k, v in child_req.items() if k.startswith("get_dummy")} child_req = { k: v for k, v in child_req.items() if k.startswith("get_dummy") is False } resp = await child_protocol._query(child_req) resp = {**child_resp, **resp} return [ {"method": k, "error_code": 0, "result": v or {"dummy": "dummy"}} for k, v in resp.items() ] @hub_all @pytest.mark.xdist_group(name="caplog") async def test_hub_children_update_delays( dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture, freezer: FrozenDateTimeFactory, ): """Test that hub children use the correct delay.""" if not dev.children: pytest.skip(f"Device {dev.model} does not have children.") # We need to have some modules initialized by now assert dev._modules new_dev = type(dev)("127.0.0.1", protocol=dev.protocol) module_queries: dict[str, dict[str, dict]] = {} # children should always update on first update await new_dev.update(update_children=False) if TYPE_CHECKING: from ..fakeprotocol_smart import FakeSmartTransport assert isinstance(dev.protocol._transport, FakeSmartTransport) if dev.protocol._transport.child_protocols: for child in new_dev.children: for modname, module in child._modules.items(): if ( not (q := module.query()) and modname not in {"DeviceModule", "Light"} and not module.SYSINFO_LOOKUP_KEYS ): q = {f"get_dummy_{modname}": {}} mocker.patch.object(module, "query", return_value=q) if q: queries = module_queries.setdefault(child.device_id, {}) queries[cast(str, modname)] = q module._last_update_time = None module_queries[""] = { cast(str, modname): q for modname, module in dev._modules.items() if (q := module.query()) } async def _query(request, *args, **kwargs): # If this is a child multipleRequest query return the error wrapped child_id = None # smart hub if ( (cc := request.get("control_child")) and (child_id := cc.get("device_id")) and (requestData := cc["requestData"]) and requestData["method"] == "multipleRequest" and (child_requests := requestData["params"]["requests"]) ): child_protocol = dev.protocol._transport.child_protocols[child_id] resp = await _get_child_responses(child_requests, child_protocol) return {"control_child": {"responseData": {"result": {"responses": resp}}}} # smartcam hub if ( (mr := request.get("multipleRequest")) and (requests := mr.get("requests")) # assumes all requests for the same child and ( child_id := next(iter(requests)) .get("params", {}) .get("childControl", {}) .get("device_id") ) and ( child_requests := [ cc["request_data"] for req in requests if (cc := req["params"].get("childControl")) ] ) ): child_protocol = dev.protocol._transport.child_protocols[child_id] resp = await _get_child_responses(child_requests, child_protocol) resp = [{"result": {"response_data": resp}} for resp in resp] return {"multipleRequest": {"responses": resp}} if child_id: # child single query child_protocol = dev.protocol._transport.child_protocols[child_id] resp_list = await _get_child_responses([requestData], child_protocol) resp = {"control_child": {"responseData": resp_list[0]}} else: resp = await dev.protocol._query(request, *args, **kwargs) return resp mocker.patch.object(new_dev.protocol, "query", side_effect=_query) first_update_time = time.monotonic() assert new_dev._last_update_time == first_update_time await new_dev.update() for dev_id, modqueries in module_queries.items(): check_dev = new_dev._children[dev_id] if dev_id else new_dev for modname in modqueries: mod = cast(SmartModule, check_dev.modules[modname]) assert mod._last_update_time == first_update_time for mod in new_dev.modules.values(): mod.MINIMUM_UPDATE_INTERVAL_SECS = 5 freezer.tick(180) now = time.monotonic() await new_dev.update() child_tick = max( module.MINIMUM_HUB_CHILD_UPDATE_INTERVAL_SECS for child in new_dev.children for module in child.modules.values() ) for dev_id, modqueries in module_queries.items(): check_dev = new_dev._children[dev_id] if dev_id else new_dev for modname in modqueries: if modname in {"Firmware"}: continue mod = cast(SmartModule, check_dev.modules[modname]) expected_update_time = first_update_time if dev_id else now assert mod._last_update_time == expected_update_time freezer.tick(child_tick) now = time.monotonic() await new_dev.update() for dev_id, modqueries in module_queries.items(): check_dev = new_dev._children[dev_id] if dev_id else new_dev for modname in modqueries: if modname in {"Firmware"}: continue mod = cast(SmartModule, check_dev.modules[modname]) assert mod._last_update_time == now @pytest.mark.parametrize( ("first_update"), [ pytest.param(True, id="First update true"), pytest.param(False, id="First update false"), ], ) @pytest.mark.parametrize( ("error_type"), [ pytest.param(SmartErrorCode.PARAMS_ERROR, id="Device error"), pytest.param(TimeoutError("Dummy timeout"), id="Query error"), ], ) @pytest.mark.parametrize( ("recover"), [ pytest.param(True, id="recover"), pytest.param(False, id="no recover"), ], ) @device_smart @pytest.mark.xdist_group(name="caplog") async def test_update_module_query_errors( dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture, freezer: FrozenDateTimeFactory, first_update, error_type, recover, ): """Test that modules that disabled / removed on query failures. i.e. the whole query times out rather than device returns an error. """ # We need to have some modules initialized by now assert dev._modules SmartModule.DISABLE_AFTER_ERROR_COUNT = 2 first_update_queries = {"get_device_info", "get_connect_cloud_state"} critical_modules = {Module.DeviceModule, Module.ChildDevice} new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) if not first_update: await new_dev.update() freezer.tick(max(module.update_interval for module in dev._modules.values())) module_queries: dict[str, dict[str, dict]] = {} if TYPE_CHECKING: from ..fakeprotocol_smart import FakeSmartTransport assert isinstance(dev.protocol._transport, FakeSmartTransport) if dev.protocol._transport.child_protocols: for child in new_dev.children: for modname, module in child._modules.items(): if ( not (q := module.query()) and modname not in {"DeviceModule", "Light"} and not module.SYSINFO_LOOKUP_KEYS ): q = {f"get_dummy_{modname}": {}} mocker.patch.object(module, "query", return_value=q) if q: queries = module_queries.setdefault(child.device_id, {}) queries[cast(str, modname)] = q module_queries[""] = { cast(str, modname): q for modname, module in dev._modules.items() if (q := module.query()) and modname not in critical_modules } raise_error = True async def _query(request, *args, **kwargs): pass # If this is a childmultipleRequest query return the error wrapped child_id = None if ( (cc := request.get("control_child")) and (child_id := cc.get("device_id")) and (requestData := cc["requestData"]) and requestData["method"] == "multipleRequest" and (child_requests := requestData["params"]["requests"]) ): if raise_error: if not isinstance(error_type, SmartErrorCode): raise TimeoutError() if len(child_requests) > 1: raise TimeoutError() if raise_error: resp = { "method": child_requests[0]["method"], "error_code": error_type.value, } else: child_protocol = dev.protocol._transport.child_protocols[child_id] resp = await _get_child_responses(child_requests, child_protocol) return {"control_child": {"responseData": {"result": {"responses": resp}}}} if ( not raise_error or "component_nego" in request or "get_child_device_component_list" in request ): if child_id: # child single query child_protocol = dev.protocol._transport.child_protocols[child_id] resp_list = await _get_child_responses([requestData], child_protocol) resp = {"control_child": {"responseData": resp_list[0]}} else: resp = await dev.protocol._query(request, *args, **kwargs) if raise_error: resp["get_connect_cloud_state"] = SmartErrorCode.CLOUD_FAILED_ERROR return resp # Don't test for errors on get_device_info as that is likely terminal if len(request) == 1 and "get_device_info" in request: return await dev.protocol._query(request, *args, **kwargs) if isinstance(error_type, SmartErrorCode): if len(request) == 1: raise DeviceError("Dummy device error", error_code=error_type) raise TimeoutError("Dummy timeout") raise error_type mocker.patch.object(new_dev.protocol, "query", side_effect=_query) await new_dev.update() msg = f"Error querying {new_dev.host} for modules" assert msg in caplog.text for dev_id, modqueries in module_queries.items(): check_dev = new_dev._children[dev_id] if dev_id else new_dev for modname in modqueries: mod = cast(SmartModule, check_dev.modules[modname]) if modname in {"DeviceModule"} or ( hasattr(mod, "_state_in_sysinfo") and mod._state_in_sysinfo is True ): continue assert mod.disabled is False, f"{modname} disabled" assert mod.update_interval == mod.UPDATE_INTERVAL_AFTER_ERROR_SECS for mod_query in modqueries[modname]: if not first_update or mod_query not in first_update_queries: msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" assert msg in caplog.text # Query again should not run for the modules caplog.clear() await new_dev.update() for dev_id, modqueries in module_queries.items(): check_dev = new_dev._children[dev_id] if dev_id else new_dev for modname in modqueries: mod = cast(SmartModule, check_dev.modules[modname]) assert mod.disabled is False, f"{modname} disabled" freezer.tick(SmartModule.UPDATE_INTERVAL_AFTER_ERROR_SECS) caplog.clear() if recover: raise_error = False await new_dev.update() msg = f"Error querying {new_dev.host} for modules" if not recover: assert msg in caplog.text for dev_id, modqueries in module_queries.items(): check_dev = new_dev._children[dev_id] if dev_id else new_dev for modname in modqueries: mod = cast(SmartModule, check_dev.modules[modname]) if modname in {"DeviceModule"} or ( hasattr(mod, "_state_in_sysinfo") and mod._state_in_sysinfo is True ): continue if not recover: assert mod.disabled is True, f"{modname} not disabled" assert mod._error_count == 2 assert mod._last_update_error for mod_query in modqueries[modname]: if not first_update or mod_query not in first_update_queries: msg = f"Error querying {new_dev.host} individually for module query '{mod_query}" assert msg in caplog.text # Test one of the raise_if_update_error if mod.name == "Energy": emod = cast(Energy, mod) with pytest.raises(KasaException, match="Module update error"): assert emod.status is not None else: assert mod.disabled is False assert mod._error_count == 0 assert mod._last_update_error is None # Test one of the raise_if_update_error doesn't raise if mod.name == "Energy": emod = cast(Energy, mod) assert emod.status is not None async def test_get_modules(): """Test getting modules for child and parent modules.""" dummy_device = await get_device_for_fixture_protocol( "KS240(US)_1.0_1.0.5.json", "SMART" ) from kasa.smart.modules import Cloud # Modules on device module = dummy_device.modules.get("Cloud") assert module assert module._device == dummy_device assert isinstance(module, Cloud) module = dummy_device.modules.get(Module.Cloud) assert module assert module._device == dummy_device assert isinstance(module, Cloud) # Modules on child module = dummy_device.modules.get("Fan") assert module is None module = next(get_parent_and_child_modules(dummy_device, "Fan")) assert module assert module._device != dummy_device assert module._device._parent == dummy_device # Invalid modules module = dummy_device.modules.get("DummyModule") assert module is None module = dummy_device.modules.get(Module.IotAmbientLight) assert module is None @device_smart async def test_smartdevice_cloud_connection(dev: SmartDevice, mocker: MockerFixture): """Test is_cloud_connected property.""" assert isinstance(dev, SmartDevice) assert "cloud_connect" in dev._components is_connected = ( (cc := dev._last_update.get("get_connect_cloud_state")) and not isinstance(cc, SmartErrorCode) and cc["status"] == 0 ) assert dev.is_cloud_connected == is_connected last_update = dev._last_update for child in dev.children: mocker.patch.object(child.protocol, "query", return_value=child._last_update) last_update["get_connect_cloud_state"] = {"status": 0} with patch.object(dev.protocol, "query", return_value=last_update): await dev.update() assert dev.is_cloud_connected is True last_update["get_connect_cloud_state"] = {"status": 1} with patch.object(dev.protocol, "query", return_value=last_update): await dev.update() assert dev.is_cloud_connected is False last_update["get_connect_cloud_state"] = SmartErrorCode.UNKNOWN_METHOD_ERROR with patch.object(dev.protocol, "query", return_value=last_update): await dev.update() assert dev.is_cloud_connected is False # Test for no cloud_connect component during device initialisation component_list = [ val for val in dev._components_raw["component_list"] if val["id"] not in {"cloud_connect"} ] initial_response = { "component_nego": {"component_list": component_list}, "get_connect_cloud_state": last_update["get_connect_cloud_state"], "get_device_info": last_update["get_device_info"], } new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) first_call = True async def side_effect_func(*args, **kwargs): nonlocal first_call resp = ( initial_response if first_call else await new_dev.protocol._query(*args, **kwargs) ) first_call = False return resp with patch.object( new_dev.protocol, "query", side_effect=side_effect_func, ): await new_dev.update() assert new_dev.is_cloud_connected is False @variable_temp_smart async def test_smart_temp_range(dev: Device): light = dev.modules.get(Module.Light) assert light color_temp_feat = light.get_feature("color_temp") assert color_temp_feat assert color_temp_feat.range @device_smart async def test_initialize_modules_sysinfo_lookup_keys( dev: SmartDevice, mocker: MockerFixture ): """Test that matching modules using SYSINFO_LOOKUP_KEYS are initialized correctly.""" class AvailableKey(SmartModule): SYSINFO_LOOKUP_KEYS = ["device_id"] class NonExistingKey(SmartModule): SYSINFO_LOOKUP_KEYS = ["this_does_not_exist"] # The __init_subclass__ hook in smartmodule checks the path, # so we have to manually add these for testing. mocker.patch.dict( "kasa.smart.smartmodule.SmartModule.REGISTERED_MODULES", { AvailableKey._module_name(): AvailableKey, NonExistingKey._module_name(): NonExistingKey, }, ) # We have an already initialized device, so we try to initialize the modules again await dev._initialize_modules() assert "AvailableKey" in dev.modules assert "NonExistingKey" not in dev.modules @device_smart async def test_initialize_modules_required_component( dev: SmartDevice, mocker: MockerFixture ): """Test that matching modules using REQUIRED_COMPONENT are initialized correctly.""" class AvailableComponent(SmartModule): REQUIRED_COMPONENT = "device" class NonExistingComponent(SmartModule): REQUIRED_COMPONENT = "this_does_not_exist" # The __init_subclass__ hook in smartmodule checks the path, # so we have to manually add these for testing. mocker.patch.dict( "kasa.smart.smartmodule.SmartModule.REGISTERED_MODULES", { AvailableComponent._module_name(): AvailableComponent, NonExistingComponent._module_name(): NonExistingComponent, }, ) # We have an already initialized device, so we try to initialize the modules again await dev._initialize_modules() assert "AvailableComponent" in dev.modules assert "NonExistingComponent" not in dev.modules async def test_smartmodule_query(): """Test that a module that doesn't set QUERY_GETTER_NAME has empty query.""" class DummyModule(SmartModule): pass dummy_device = await get_device_for_fixture_protocol( "KS240(US)_1.0_1.0.5.json", "SMART" ) mod = DummyModule(dummy_device, "dummy") assert mod.query() == {}