mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-11-03 22:22:06 +00:00 
			
		
		
		
	Let caller handle SMART errors on multi-requests (#754)
* Fix for missing get_device_usage * Fix coverage and add methods to exceptions * Remove unused caplog fixture
This commit is contained in:
		@@ -45,6 +45,9 @@ class ConnectionException(SmartDeviceException):
 | 
			
		||||
class SmartErrorCode(IntEnum):
 | 
			
		||||
    """Enum for SMART Error Codes."""
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"{self.name}({self.value})"
 | 
			
		||||
 | 
			
		||||
    SUCCESS = 0
 | 
			
		||||
 | 
			
		||||
    # Transport Errors
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ from ..device import Device, WifiNetwork
 | 
			
		||||
from ..device_type import DeviceType
 | 
			
		||||
from ..deviceconfig import DeviceConfig
 | 
			
		||||
from ..emeterstatus import EmeterStatus
 | 
			
		||||
from ..exceptions import AuthenticationException, SmartDeviceException
 | 
			
		||||
from ..exceptions import AuthenticationException, SmartDeviceException, SmartErrorCode
 | 
			
		||||
from ..feature import Feature, FeatureType
 | 
			
		||||
from ..smartprotocol import SmartProtocol
 | 
			
		||||
 | 
			
		||||
@@ -61,6 +61,24 @@ class SmartDevice(Device):
 | 
			
		||||
        """Return list of children."""
 | 
			
		||||
        return list(self._children.values())
 | 
			
		||||
 | 
			
		||||
    def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
 | 
			
		||||
        response = responses.get(request)
 | 
			
		||||
        if isinstance(response, SmartErrorCode):
 | 
			
		||||
            _LOGGER.debug(
 | 
			
		||||
                "Error %s getting request %s for device %s",
 | 
			
		||||
                response,
 | 
			
		||||
                request,
 | 
			
		||||
                self.host,
 | 
			
		||||
            )
 | 
			
		||||
            response = None
 | 
			
		||||
        if response is not None:
 | 
			
		||||
            return response
 | 
			
		||||
        if default is not None:
 | 
			
		||||
            return default
 | 
			
		||||
        raise SmartDeviceException(
 | 
			
		||||
            f"{request} not found in {responses} for device {self.host}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def update(self, update_children: bool = True):
 | 
			
		||||
        """Update the device."""
 | 
			
		||||
        if self.credentials is None and self.credentials_hash is None:
 | 
			
		||||
@@ -87,7 +105,7 @@ class SmartDevice(Device):
 | 
			
		||||
                "get_current_power": None,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if self._components["device"] >= 2:
 | 
			
		||||
        if self._components.get("device", 0) >= 2:
 | 
			
		||||
            extra_reqs = {
 | 
			
		||||
                **extra_reqs,
 | 
			
		||||
                "get_device_usage": None,
 | 
			
		||||
@@ -101,13 +119,13 @@ class SmartDevice(Device):
 | 
			
		||||
 | 
			
		||||
        resp = await self.protocol.query(req)
 | 
			
		||||
 | 
			
		||||
        self._info = resp["get_device_info"]
 | 
			
		||||
        self._time = resp["get_device_time"]
 | 
			
		||||
        self._info = self._try_get_response(resp, "get_device_info")
 | 
			
		||||
        self._time = self._try_get_response(resp, "get_device_time", {})
 | 
			
		||||
        # Device usage is not available on older firmware versions
 | 
			
		||||
        self._usage = resp.get("get_device_usage", {})
 | 
			
		||||
        self._usage = self._try_get_response(resp, "get_device_usage", {})
 | 
			
		||||
        # Emeter is not always available, but we set them still for now.
 | 
			
		||||
        self._energy = resp.get("get_energy_usage", {})
 | 
			
		||||
        self._emeter = resp.get("get_current_power", {})
 | 
			
		||||
        self._energy = self._try_get_response(resp, "get_energy_usage", {})
 | 
			
		||||
        self._emeter = self._try_get_response(resp, "get_current_power", {})
 | 
			
		||||
 | 
			
		||||
        self._last_update = {
 | 
			
		||||
            "components": self._components_raw,
 | 
			
		||||
@@ -116,7 +134,7 @@ class SmartDevice(Device):
 | 
			
		||||
            "time": self._time,
 | 
			
		||||
            "energy": self._energy,
 | 
			
		||||
            "emeter": self._emeter,
 | 
			
		||||
            "child_info": resp.get("get_child_device_list", {}),
 | 
			
		||||
            "child_info": self._try_get_response(resp, "get_child_device_list", {}),
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if child_info := self._last_update.get("child_info"):
 | 
			
		||||
 
 | 
			
		||||
@@ -129,19 +129,21 @@ class SmartProtocol(BaseProtocol):
 | 
			
		||||
                    pf(smart_request),
 | 
			
		||||
                )
 | 
			
		||||
            response_step = await self._transport.send(smart_request)
 | 
			
		||||
            batch_name = f"multi-request-batch-{i+1}"
 | 
			
		||||
            if debug_enabled:
 | 
			
		||||
                _LOGGER.debug(
 | 
			
		||||
                    "%s multi-request-batch-%s << %s",
 | 
			
		||||
                    "%s %s << %s",
 | 
			
		||||
                    self._host,
 | 
			
		||||
                    i + 1,
 | 
			
		||||
                    batch_name,
 | 
			
		||||
                    pf(response_step),
 | 
			
		||||
                )
 | 
			
		||||
            self._handle_response_error_code(response_step)
 | 
			
		||||
            self._handle_response_error_code(response_step, batch_name)
 | 
			
		||||
            responses = response_step["result"]["responses"]
 | 
			
		||||
            for response in responses:
 | 
			
		||||
                self._handle_response_error_code(response)
 | 
			
		||||
                method = response["method"]
 | 
			
		||||
                self._handle_response_error_code(response, method, raise_on_error=False)
 | 
			
		||||
                result = response.get("result", None)
 | 
			
		||||
                multi_result[response["method"]] = result
 | 
			
		||||
                multi_result[method] = result
 | 
			
		||||
        return multi_result
 | 
			
		||||
 | 
			
		||||
    async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
 | 
			
		||||
@@ -173,22 +175,24 @@ class SmartProtocol(BaseProtocol):
 | 
			
		||||
                pf(response_data),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self._handle_response_error_code(response_data)
 | 
			
		||||
        self._handle_response_error_code(response_data, smart_method)
 | 
			
		||||
 | 
			
		||||
        # Single set_ requests do not return a result
 | 
			
		||||
        result = response_data.get("result")
 | 
			
		||||
        return {smart_method: result}
 | 
			
		||||
 | 
			
		||||
    def _handle_response_error_code(self, resp_dict: dict):
 | 
			
		||||
    def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
 | 
			
		||||
        error_code = SmartErrorCode(resp_dict.get("error_code"))  # type: ignore[arg-type]
 | 
			
		||||
        if error_code == SmartErrorCode.SUCCESS:
 | 
			
		||||
            return
 | 
			
		||||
        if not raise_on_error:
 | 
			
		||||
            resp_dict["result"] = error_code
 | 
			
		||||
            return
 | 
			
		||||
        msg = (
 | 
			
		||||
            f"Error querying device: {self._host}: "
 | 
			
		||||
            + f"{error_code.name}({error_code.value})"
 | 
			
		||||
            + f" for method: {method}"
 | 
			
		||||
        )
 | 
			
		||||
        if method := resp_dict.get("method"):
 | 
			
		||||
            msg += f" for method: {method}"
 | 
			
		||||
        if error_code in SMART_TIMEOUT_ERRORS:
 | 
			
		||||
            raise TimeoutException(msg, error_code=error_code)
 | 
			
		||||
        if error_code in SMART_RETRYABLE_ERRORS:
 | 
			
		||||
@@ -338,7 +342,7 @@ class _ChildProtocolWrapper(SmartProtocol):
 | 
			
		||||
        result = response.get("control_child")
 | 
			
		||||
        # Unwrap responseData for control_child
 | 
			
		||||
        if result and (response_data := result.get("responseData")):
 | 
			
		||||
            self._handle_response_error_code(response_data)
 | 
			
		||||
            self._handle_response_error_code(response_data, "control_child")
 | 
			
		||||
            result = response_data.get("result")
 | 
			
		||||
 | 
			
		||||
        # TODO: handle multipleRequest unwrapping
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,6 @@
 | 
			
		||||
import importlib
 | 
			
		||||
import inspect
 | 
			
		||||
import logging
 | 
			
		||||
import pkgutil
 | 
			
		||||
import re
 | 
			
		||||
import sys
 | 
			
		||||
@@ -21,10 +22,18 @@ from voluptuous import (
 | 
			
		||||
 | 
			
		||||
import kasa
 | 
			
		||||
from kasa import Credentials, Device, DeviceConfig, SmartDeviceException
 | 
			
		||||
from kasa.exceptions import SmartErrorCode
 | 
			
		||||
from kasa.iot import IotDevice
 | 
			
		||||
from kasa.smart import SmartChildDevice, SmartDevice
 | 
			
		||||
 | 
			
		||||
from .conftest import device_iot, handle_turn_on, has_emeter_iot, no_emeter_iot, turn_on
 | 
			
		||||
from .conftest import (
 | 
			
		||||
    device_iot,
 | 
			
		||||
    device_smart,
 | 
			
		||||
    handle_turn_on,
 | 
			
		||||
    has_emeter_iot,
 | 
			
		||||
    no_emeter_iot,
 | 
			
		||||
    turn_on,
 | 
			
		||||
)
 | 
			
		||||
from .fakeprotocol_iot import FakeIotProtocol
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -300,6 +309,33 @@ async def test_modules_not_supported(dev: IotDevice):
 | 
			
		||||
        assert module.is_supported is not None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@device_smart
 | 
			
		||||
async def test_update_sub_errors(dev: SmartDevice, caplog):
 | 
			
		||||
    mock_response: dict = {
 | 
			
		||||
        "get_device_info": {},
 | 
			
		||||
        "get_device_usage": SmartErrorCode.PARAMS_ERROR,
 | 
			
		||||
        "get_device_time": {},
 | 
			
		||||
    }
 | 
			
		||||
    caplog.set_level(logging.DEBUG)
 | 
			
		||||
    with patch.object(dev.protocol, "query", return_value=mock_response):
 | 
			
		||||
        await dev.update()
 | 
			
		||||
    msg = "Error PARAMS_ERROR(-1008) getting request get_device_usage for device 127.0.0.123"
 | 
			
		||||
    assert msg in caplog.text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@device_smart
 | 
			
		||||
async def test_update_no_device_info(dev: SmartDevice):
 | 
			
		||||
    mock_response: dict = {
 | 
			
		||||
        "get_device_usage": {},
 | 
			
		||||
        "get_device_time": {},
 | 
			
		||||
    }
 | 
			
		||||
    msg = f"get_device_info not found in {mock_response} for device 127.0.0.123"
 | 
			
		||||
    with patch.object(dev.protocol, "query", return_value=mock_response), pytest.raises(
 | 
			
		||||
        SmartDeviceException, match=msg
 | 
			
		||||
    ):
 | 
			
		||||
        await dev.update()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    "device_class, use_class", kasa.deprecated_smart_devices.items()
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -60,13 +60,10 @@ async def test_smart_device_errors_in_multiple_request(
 | 
			
		||||
    send_mock = mocker.patch.object(
 | 
			
		||||
        dummy_protocol._transport, "send", return_value=mock_response
 | 
			
		||||
    )
 | 
			
		||||
    with pytest.raises(SmartDeviceException):
 | 
			
		||||
        await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
 | 
			
		||||
    if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
 | 
			
		||||
        expected_calls = 3
 | 
			
		||||
    else:
 | 
			
		||||
        expected_calls = 1
 | 
			
		||||
    assert send_mock.call_count == expected_calls
 | 
			
		||||
 | 
			
		||||
    resp_dict = await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
 | 
			
		||||
    assert resp_dict["foobar2"] == error_code
 | 
			
		||||
    assert send_mock.call_count == 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("request_size", [1, 3, 5, 10])
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user