Handle module errors more robustly and add query params to light preset and transition (#1036)

Ensures that all modules try to access their data in `_post_update_hook` in a safe manner and disable themselves if there's an error.
Also adds parameters to get_preset_rules and get_on_off_gradually_info to fix issues with recent firmware updates.
[#1033](https://github.com/python-kasa/python-kasa/issues/1033)
This commit is contained in:
Steven B 2024-07-04 08:02:50 +01:00 committed by GitHub
parent 9cffbe9e48
commit 905a14895d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 206 additions and 30 deletions

View File

@ -284,6 +284,15 @@ class SmartRequest:
"""Get preset rules."""
return SmartRequest("get_preset_rules", params or SmartRequest.GetRulesParams())
@staticmethod
def get_on_off_gradually_info(
params: SmartRequestParams | None = None,
) -> SmartRequest:
"""Get preset rules."""
return SmartRequest(
"get_on_off_gradually_info", params or SmartRequest.SmartRequestParams()
)
@staticmethod
def get_auto_light_info() -> SmartRequest:
"""Get auto light info."""
@ -382,7 +391,7 @@ COMPONENT_REQUESTS = {
"auto_light": [SmartRequest.get_auto_light_info()],
"light_effect": [SmartRequest.get_dynamic_light_effect_rules()],
"bulb_quick_control": [],
"on_off_gradually": [SmartRequest.get_raw_request("get_on_off_gradually_info")],
"on_off_gradually": [SmartRequest.get_on_off_gradually_info()],
"light_strip": [],
"light_strip_lighting_effect": [
SmartRequest.get_raw_request("get_lighting_effect")

View File

@ -19,12 +19,6 @@ class AutoOff(SmartModule):
def _initialize_features(self):
"""Initialize features after the initial update."""
if not isinstance(self.data, dict):
_LOGGER.warning(
"No data available for module, skipping %s: %s", self, self.data
)
return
self._add_feature(
Feature(
self._device,

View File

@ -43,6 +43,10 @@ class BatterySensor(SmartModule):
)
)
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {}
@property
def battery(self):
"""Return battery level."""

View File

@ -4,7 +4,6 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from ...exceptions import SmartErrorCode
from ...feature import Feature
from ..smartmodule import SmartModule
@ -18,6 +17,13 @@ class Cloud(SmartModule):
QUERY_GETTER_NAME = "get_connect_cloud_state"
REQUIRED_COMPONENT = "cloud_connect"
def _post_update_hook(self):
"""Perform actions after a device update.
Overrides the default behaviour to disable a module if the query returns
an error because the logic here is to treat that as not connected.
"""
def __init__(self, device: SmartDevice, module: str):
super().__init__(device, module)
@ -37,6 +43,6 @@ class Cloud(SmartModule):
@property
def is_connected(self):
"""Return True if device is connected to the cloud."""
if isinstance(self.data, SmartErrorCode):
if self._has_data_error():
return False
return self.data["status"] == 0

View File

@ -10,6 +10,13 @@ class DeviceModule(SmartModule):
REQUIRED_COMPONENT = "device"
def _post_update_hook(self):
"""Perform actions after a device update.
Overrides the default behaviour to disable a module if the query returns
an error because this module is critical.
"""
def query(self) -> dict:
"""Query to execute during the update cycle."""
query = {

View File

@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Any, Callable, Optional
from async_timeout import timeout as asyncio_timeout
from pydantic.v1 import BaseModel, Field, validator
from ...exceptions import SmartErrorCode
from ...feature import Feature
from ..smartmodule import SmartModule
@ -123,6 +122,13 @@ class Firmware(SmartModule):
req["get_auto_update_info"] = None
return req
def _post_update_hook(self):
"""Perform actions after a device update.
Overrides the default behaviour to disable a module if the query returns
an error because some of the module still functions.
"""
@property
def current_firmware(self) -> str:
"""Return the current firmware version."""
@ -136,11 +142,11 @@ class Firmware(SmartModule):
@property
def firmware_update_info(self):
"""Return latest firmware information."""
fw = self.data.get("get_latest_fw") or self.data
if not self._device.is_cloud_connected or isinstance(fw, SmartErrorCode):
if not self._device.is_cloud_connected or self._has_data_error():
# Error in response, probably disconnected from the cloud.
return UpdateInfo(type=0, need_to_upgrade=False)
fw = self.data.get("get_latest_fw") or self.data
return UpdateInfo.parse_obj(fw)
@property

View File

@ -14,6 +14,10 @@ class FrostProtection(SmartModule):
REQUIRED_COMPONENT = "frost_protection"
QUERY_GETTER_NAME = "get_frost_protection"
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {}
@property
def enabled(self) -> bool:
"""Return True if frost protection is on."""

View File

@ -45,6 +45,10 @@ class HumiditySensor(SmartModule):
)
)
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {}
@property
def humidity(self):
"""Return current humidity in percentage."""

View File

@ -140,7 +140,7 @@ class LightPreset(SmartModule, LightPresetInterface):
"""Query to execute during the update cycle."""
if self._state_in_sysinfo: # Child lights can have states in the child info
return {}
return {self.QUERY_GETTER_NAME: None}
return {self.QUERY_GETTER_NAME: {"start_index": 0}}
async def _check_supported(self):
"""Additional check to see if the module is supported by the device.

View File

@ -230,7 +230,7 @@ class LightTransition(SmartModule):
if self._state_in_sysinfo:
return {}
else:
return {self.QUERY_GETTER_NAME: None}
return {self.QUERY_GETTER_NAME: {}}
async def _check_supported(self):
"""Additional check to see if the module is supported by the device."""

View File

@ -32,6 +32,10 @@ class ReportMode(SmartModule):
)
)
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {}
@property
def report_interval(self):
"""Reporting interval of a sensor device."""

View File

@ -58,6 +58,10 @@ class TemperatureSensor(SmartModule):
)
)
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {}
@property
def temperature(self):
"""Return current humidity in percentage."""

View File

@ -177,11 +177,20 @@ class SmartDevice(Device):
self._children[info["device_id"]]._update_internal_state(info)
# Call handle update for modules that want to update internal data
for module in self._modules.values():
module._post_update_hook()
errors = []
for module_name, module in self._modules.items():
if not self._handle_module_post_update_hook(module):
errors.append(module_name)
for error in errors:
self._modules.pop(error)
for child in self._children.values():
for child_module in child._modules.values():
child_module._post_update_hook()
errors = []
for child_module_name, child_module in child._modules.items():
if not self._handle_module_post_update_hook(child_module):
errors.append(child_module_name)
for error in errors:
child._modules.pop(error)
# We can first initialize the features after the first update.
# We make here an assumption that every device has at least a single feature.
@ -190,6 +199,19 @@ class SmartDevice(Device):
_LOGGER.debug("Got an update: %s", self._last_update)
def _handle_module_post_update_hook(self, module: SmartModule) -> bool:
try:
module._post_update_hook()
return True
except Exception as ex:
_LOGGER.error(
"Error processing %s for device %s, module will be unavailable: %s",
module.name,
self.host,
ex,
)
return False
async def _initialize_modules(self):
"""Initialize modules based on component negotiation response."""
from .smartmodule import SmartModule

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from ..exceptions import KasaException
from ..exceptions import DeviceError, KasaException, SmartErrorCode
from ..module import Module
if TYPE_CHECKING:
@ -41,6 +41,14 @@ class SmartModule(Module):
"""Name of the module."""
return getattr(self, "NAME", self.__class__.__name__)
def _post_update_hook(self): # noqa: B027
"""Perform actions after a device update.
Any modules overriding this should ensure that self.data is
accessed unless the module should remain active despite errors.
"""
assert self.data # noqa: S101
def query(self) -> dict:
"""Query to execute during the update cycle.
@ -87,6 +95,11 @@ class SmartModule(Module):
filtered_data = {k: v for k, v in dev._last_update.items() if k in q_keys}
for data_item in filtered_data:
if isinstance(filtered_data[data_item], SmartErrorCode):
raise DeviceError(
f"{data_item} for {self.name}", error_code=filtered_data[data_item]
)
if len(filtered_data) == 1:
return next(iter(filtered_data.values()))
@ -110,3 +123,10 @@ class SmartModule(Module):
color_temp_range but only supports one value.
"""
return True
def _has_data_error(self) -> bool:
try:
assert self.data # noqa: S101
return False
except DeviceError:
return True

View File

@ -416,6 +416,10 @@ class _ChildProtocolWrapper(SmartProtocol):
return smart_method, smart_params
async def query(self, request: str | dict, retry_count: int = 3) -> dict:
"""Wrap request inside control_child envelope."""
return await self._query(request, retry_count)
async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
"""Wrap request inside control_child envelope."""
method, params = self._get_method_and_params_for_request(request)
request_data = {

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import logging
from typing import Any
from typing import Any, cast
from unittest.mock import patch
import pytest
@ -132,6 +132,78 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
spies[device].assert_not_called()
@device_smart
async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture):
"""Test that modules that error are disabled / removed."""
# We need to have some modules initialized by now
assert dev._modules
critical_modules = {Module.DeviceModule, Module.ChildDevice}
not_disabling_modules = {Module.Firmware, Module.Cloud}
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
module_queries = {
modname: q
for modname, module in dev._modules.items()
if (q := module.query()) and modname not in critical_modules
}
child_module_queries = {
modname: q
for child in dev.children
for modname, module in child._modules.items()
if (q := module.query()) and modname not in critical_modules
}
all_queries_names = {
key for mod_query in module_queries.values() for key in mod_query
}
all_child_queries_names = {
key for mod_query in child_module_queries.values() for key in mod_query
}
async def _query(request, *args, **kwargs):
responses = await dev.protocol._query(request, *args, **kwargs)
for k in responses:
if k in all_queries_names:
responses[k] = SmartErrorCode.PARAMS_ERROR
return responses
async def _child_query(self, request, *args, **kwargs):
responses = await child_protocols[self._device_id]._query(
request, *args, **kwargs
)
for k in responses:
if k in all_child_queries_names:
responses[k] = SmartErrorCode.PARAMS_ERROR
return responses
mocker.patch.object(new_dev.protocol, "query", side_effect=_query)
from kasa.smartprotocol import _ChildProtocolWrapper
child_protocols = {
cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol
for child in dev.children
}
# children not created yet so cannot patch.object
mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query)
await new_dev.update()
for modname in module_queries:
no_disable = modname in not_disabling_modules
mod_present = modname in new_dev._modules
assert (
mod_present is no_disable
), f"{modname} present {mod_present} when no_disable {no_disable}"
for modname in child_module_queries:
no_disable = modname in not_disabling_modules
mod_present = any(modname in child._modules for child in new_dev.children)
assert (
mod_present is no_disable
), f"{modname} present {mod_present} when no_disable {no_disable}"
async def test_get_modules():
"""Test getting modules for child and parent modules."""
dummy_device = await get_device_for_fixture_protocol(
@ -181,6 +253,9 @@ async def test_smartdevice_cloud_connection(dev: SmartDevice, mocker: MockerFixt
assert dev.is_cloud_connected == is_connected
last_update = dev._last_update
for child in dev.children:
mocker.patch.object(child.protocol, "query", return_value=child._last_update)
last_update["get_connect_cloud_state"] = {"status": 0}
with patch.object(dev.protocol, "query", return_value=last_update):
await dev.update()
@ -207,21 +282,18 @@ async def test_smartdevice_cloud_connection(dev: SmartDevice, mocker: MockerFixt
"get_connect_cloud_state": last_update["get_connect_cloud_state"],
"get_device_info": last_update["get_device_info"],
}
# Child component list is not stored on the device
if "get_child_device_list" in last_update:
child_component_list = await dev.protocol.query(
"get_child_device_component_list"
)
last_update["get_child_device_component_list"] = child_component_list[
"get_child_device_component_list"
]
new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol)
first_call = True
def side_effect_func(*_, **__):
async def side_effect_func(*args, **kwargs):
nonlocal first_call
resp = initial_response if first_call else last_update
resp = (
initial_response
if first_call
else await new_dev.protocol._query(*args, **kwargs)
)
first_call = False
return resp

View File

@ -1,6 +1,7 @@
import logging
import pytest
import pytest_mock
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
@ -19,6 +20,21 @@ DUMMY_MULTIPLE_QUERY = {
ERRORS = [e for e in SmartErrorCode if e != 0]
async def test_smart_queries(dummy_protocol, mocker: pytest_mock.MockerFixture):
mock_response = {"result": {"great": "success"}, "error_code": 0}
mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response)
# test sending a method name as a string
resp = await dummy_protocol.query("foobar")
assert "foobar" in resp
assert resp["foobar"] == mock_response["result"]
# test sending a method name as a dict
resp = await dummy_protocol.query(DUMMY_QUERY)
assert "foobar" in resp
assert resp["foobar"] == mock_response["result"]
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_smart_device_errors(dummy_protocol, mocker, error_code):
mock_response = {"result": {"great": "success"}, "error_code": error_code.value}