Handle module errors more robustly and add query params to light preset and transition (#1036)

Ensures that all modules try to access their data in `_post_update_hook` in a safe manner and disable themselves if there's an error.
Also adds parameters to get_preset_rules and get_on_off_gradually_info to fix issues with recent firmware updates.
[#1033](https://github.com/python-kasa/python-kasa/issues/1033)
This commit is contained in:
Steven B
2024-07-04 08:02:50 +01:00
committed by GitHub
parent 9cffbe9e48
commit 905a14895d
17 changed files with 206 additions and 30 deletions

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import logging
from typing import Any
from typing import Any, cast
from unittest.mock import patch
import pytest
@@ -132,6 +132,78 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
spies[device].assert_not_called()
@device_smart
async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture):
"""Test that modules that error are disabled / removed."""
# We need to have some modules initialized by now
assert dev._modules
critical_modules = {Module.DeviceModule, Module.ChildDevice}
not_disabling_modules = {Module.Firmware, Module.Cloud}
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
module_queries = {
modname: q
for modname, module in dev._modules.items()
if (q := module.query()) and modname not in critical_modules
}
child_module_queries = {
modname: q
for child in dev.children
for modname, module in child._modules.items()
if (q := module.query()) and modname not in critical_modules
}
all_queries_names = {
key for mod_query in module_queries.values() for key in mod_query
}
all_child_queries_names = {
key for mod_query in child_module_queries.values() for key in mod_query
}
async def _query(request, *args, **kwargs):
responses = await dev.protocol._query(request, *args, **kwargs)
for k in responses:
if k in all_queries_names:
responses[k] = SmartErrorCode.PARAMS_ERROR
return responses
async def _child_query(self, request, *args, **kwargs):
responses = await child_protocols[self._device_id]._query(
request, *args, **kwargs
)
for k in responses:
if k in all_child_queries_names:
responses[k] = SmartErrorCode.PARAMS_ERROR
return responses
mocker.patch.object(new_dev.protocol, "query", side_effect=_query)
from kasa.smartprotocol import _ChildProtocolWrapper
child_protocols = {
cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol
for child in dev.children
}
# children not created yet so cannot patch.object
mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query)
await new_dev.update()
for modname in module_queries:
no_disable = modname in not_disabling_modules
mod_present = modname in new_dev._modules
assert (
mod_present is no_disable
), f"{modname} present {mod_present} when no_disable {no_disable}"
for modname in child_module_queries:
no_disable = modname in not_disabling_modules
mod_present = any(modname in child._modules for child in new_dev.children)
assert (
mod_present is no_disable
), f"{modname} present {mod_present} when no_disable {no_disable}"
async def test_get_modules():
"""Test getting modules for child and parent modules."""
dummy_device = await get_device_for_fixture_protocol(
@@ -181,6 +253,9 @@ async def test_smartdevice_cloud_connection(dev: SmartDevice, mocker: MockerFixt
assert dev.is_cloud_connected == is_connected
last_update = dev._last_update
for child in dev.children:
mocker.patch.object(child.protocol, "query", return_value=child._last_update)
last_update["get_connect_cloud_state"] = {"status": 0}
with patch.object(dev.protocol, "query", return_value=last_update):
await dev.update()
@@ -207,21 +282,18 @@ async def test_smartdevice_cloud_connection(dev: SmartDevice, mocker: MockerFixt
"get_connect_cloud_state": last_update["get_connect_cloud_state"],
"get_device_info": last_update["get_device_info"],
}
# Child component list is not stored on the device
if "get_child_device_list" in last_update:
child_component_list = await dev.protocol.query(
"get_child_device_component_list"
)
last_update["get_child_device_component_list"] = child_component_list[
"get_child_device_component_list"
]
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
first_call = True
def side_effect_func(*_, **__):
async def side_effect_func(*args, **kwargs):
nonlocal first_call
resp = initial_response if first_call else last_update
resp = (
initial_response
if first_call
else await new_dev.protocol._query(*args, **kwargs)
)
first_call = False
return resp

View File

@@ -1,6 +1,7 @@
import logging
import pytest
import pytest_mock
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
@@ -19,6 +20,21 @@ DUMMY_MULTIPLE_QUERY = {
ERRORS = [e for e in SmartErrorCode if e != 0]
async def test_smart_queries(dummy_protocol, mocker: pytest_mock.MockerFixture):
mock_response = {"result": {"great": "success"}, "error_code": 0}
mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response)
# test sending a method name as a string
resp = await dummy_protocol.query("foobar")
assert "foobar" in resp
assert resp["foobar"] == mock_response["result"]
# test sending a method name as a dict
resp = await dummy_protocol.query(DUMMY_QUERY)
assert "foobar" in resp
assert resp["foobar"] == mock_response["result"]
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_smart_device_errors(dummy_protocol, mocker, error_code):
mock_response = {"result": {"great": "success"}, "error_code": error_code.value}