mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-10-31 04:31:54 +00:00 
			
		
		
		
	Enable multiple requests in smartprotocol (#584)
* Enable multiple requests in smartprotocol * Update following review * Remove error_code parameter in exceptions
This commit is contained in:
		| @@ -125,19 +125,19 @@ class AesTransport(BaseTransport): | ||||
|         return resp.status_code, response_data | ||||
|  | ||||
|     def _handle_response_error_code(self, resp_dict: dict, msg: str): | ||||
|         if ( | ||||
|             error_code := SmartErrorCode(resp_dict.get("error_code"))  # type: ignore[arg-type] | ||||
|         ) != SmartErrorCode.SUCCESS: | ||||
|             msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" | ||||
|             if error_code in SMART_TIMEOUT_ERRORS: | ||||
|                 raise TimeoutException(msg) | ||||
|             if error_code in SMART_RETRYABLE_ERRORS: | ||||
|                 raise RetryableException(msg) | ||||
|             if error_code in SMART_AUTHENTICATION_ERRORS: | ||||
|                 self._handshake_done = False | ||||
|                 self._login_token = None | ||||
|                 raise AuthenticationException(msg) | ||||
|             raise SmartDeviceException(msg) | ||||
|         error_code = SmartErrorCode(resp_dict.get("error_code"))  # type: ignore[arg-type] | ||||
|         if error_code == SmartErrorCode.SUCCESS: | ||||
|             return | ||||
|         msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" | ||||
|         if error_code in SMART_TIMEOUT_ERRORS: | ||||
|             raise TimeoutException(msg) | ||||
|         if error_code in SMART_RETRYABLE_ERRORS: | ||||
|             raise RetryableException(msg) | ||||
|         if error_code in SMART_AUTHENTICATION_ERRORS: | ||||
|             self._handshake_done = False | ||||
|             self._login_token = None | ||||
|             raise AuthenticationException(msg) | ||||
|         raise SmartDeviceException(msg) | ||||
|  | ||||
|     async def send_secure_passthrough(self, request: str): | ||||
|         """Send encrypted message as passthrough.""" | ||||
|   | ||||
| @@ -62,6 +62,13 @@ class IotProtocol(TPLinkProtocol): | ||||
|                     "Unable to authenticate with %s, not retrying", self._host | ||||
|                 ) | ||||
|                 raise auex | ||||
|             except SmartDeviceException as ex: | ||||
|                 _LOGGER.debug( | ||||
|                     "Unable to connect to the device: %s, not retrying: %s", | ||||
|                     self._host, | ||||
|                     ex, | ||||
|                 ) | ||||
|                 raise ex | ||||
|             except Exception as ex: | ||||
|                 await self.close() | ||||
|                 if retry >= retry_count: | ||||
|   | ||||
| @@ -62,26 +62,7 @@ class SmartProtocol(TPLinkProtocol): | ||||
|     async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: | ||||
|         """Query the device retrying for retry_count on failure.""" | ||||
|         async with self._query_lock: | ||||
|             resp_dict = await self._query(request, retry_count) | ||||
|  | ||||
|             if ( | ||||
|                 error_code := SmartErrorCode(resp_dict.get("error_code"))  # type: ignore[arg-type] | ||||
|             ) != SmartErrorCode.SUCCESS: | ||||
|                 msg = ( | ||||
|                     f"Error querying device: {self._host}: " | ||||
|                     + f"{error_code.name}({error_code.value})" | ||||
|                 ) | ||||
|                 if error_code in SMART_TIMEOUT_ERRORS: | ||||
|                     raise TimeoutException(msg) | ||||
|                 if error_code in SMART_RETRYABLE_ERRORS: | ||||
|                     raise RetryableException(msg) | ||||
|                 if error_code in SMART_AUTHENTICATION_ERRORS: | ||||
|                     raise AuthenticationException(msg) | ||||
|                 raise SmartDeviceException(msg) | ||||
|  | ||||
|             if "result" in resp_dict: | ||||
|                 return resp_dict["result"] | ||||
|             return {} | ||||
|             return await self._query(request, retry_count) | ||||
|  | ||||
|     async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: | ||||
|         for retry in range(retry_count + 1): | ||||
| @@ -128,6 +109,11 @@ class SmartProtocol(TPLinkProtocol): | ||||
|                     raise ex | ||||
|                 await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) | ||||
|                 continue | ||||
|             except SmartDeviceException as ex: | ||||
|                 # Transport would have raised RetryableException if retry makes sense. | ||||
|                 await self.close() | ||||
|                 _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                 raise ex | ||||
|             except Exception as ex: | ||||
|                 if retry >= retry_count: | ||||
|                     await self.close() | ||||
| @@ -145,8 +131,15 @@ class SmartProtocol(TPLinkProtocol): | ||||
|  | ||||
|     async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: | ||||
|         if isinstance(request, dict): | ||||
|             smart_method = next(iter(request)) | ||||
|             smart_params = request[smart_method] | ||||
|             if len(request) == 1: | ||||
|                 smart_method = next(iter(request)) | ||||
|                 smart_params = request[smart_method] | ||||
|             else: | ||||
|                 requests = [] | ||||
|                 for method, params in request.items(): | ||||
|                     requests.append({"method": method, "params": params}) | ||||
|                 smart_method = "multipleRequest" | ||||
|                 smart_params = {"requests": requests} | ||||
|         else: | ||||
|             smart_method = request | ||||
|             smart_params = None | ||||
| @@ -165,7 +158,40 @@ class SmartProtocol(TPLinkProtocol): | ||||
|             _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), | ||||
|         ) | ||||
|  | ||||
|         return response_data | ||||
|         self._handle_response_error_code(response_data) | ||||
|  | ||||
|         if (result := response_data.get("result")) is None: | ||||
|             # Single set_ requests do not return a result | ||||
|             return {smart_method: None} | ||||
|  | ||||
|         if (responses := result.get("responses")) is None: | ||||
|             return {smart_method: result} | ||||
|  | ||||
|         # responses is returned for multipleRequest | ||||
|         multi_result = {} | ||||
|         for response in responses: | ||||
|             self._handle_response_error_code(response) | ||||
|             result = response.get("result", None) | ||||
|             multi_result[response["method"]] = result | ||||
|         return multi_result | ||||
|  | ||||
|     def _handle_response_error_code(self, resp_dict: dict): | ||||
|         error_code = SmartErrorCode(resp_dict.get("error_code"))  # type: ignore[arg-type] | ||||
|         if error_code == SmartErrorCode.SUCCESS: | ||||
|             return | ||||
|         msg = ( | ||||
|             f"Error querying device: {self._host}: " | ||||
|             + f"{error_code.name}({error_code.value})" | ||||
|         ) | ||||
|         if method := resp_dict.get("method"): | ||||
|             msg += f" for method: {method}" | ||||
|         if error_code in SMART_TIMEOUT_ERRORS: | ||||
|             raise TimeoutException(msg) | ||||
|         if error_code in SMART_RETRYABLE_ERRORS: | ||||
|             raise RetryableException(msg) | ||||
|         if error_code in SMART_AUTHENTICATION_ERRORS: | ||||
|             raise AuthenticationException(msg) | ||||
|         raise SmartDeviceException(msg) | ||||
|  | ||||
|     async def close(self) -> None: | ||||
|         """Close the protocol.""" | ||||
|   | ||||
| @@ -41,11 +41,18 @@ class TapoDevice(SmartDevice): | ||||
|             raise AuthenticationException("Tapo plug requires authentication.") | ||||
|  | ||||
|         if self._components is None: | ||||
|             self._components = await self.protocol.query("component_nego") | ||||
|             resp = await self.protocol.query("component_nego") | ||||
|             self._components = resp["component_nego"] | ||||
|  | ||||
|         self._info = await self.protocol.query("get_device_info") | ||||
|         self._usage = await self.protocol.query("get_device_usage") | ||||
|         self._time = await self.protocol.query("get_device_time") | ||||
|         req = { | ||||
|             "get_device_info": None, | ||||
|             "get_device_usage": None, | ||||
|             "get_device_time": None, | ||||
|         } | ||||
|         resp = await self.protocol.query(req) | ||||
|         self._info = resp["get_device_info"] | ||||
|         self._usage = resp["get_device_usage"] | ||||
|         self._time = resp["get_device_time"] | ||||
|  | ||||
|         self._last_update = self._data = { | ||||
|             "components": self._components, | ||||
|   | ||||
| @@ -39,8 +39,13 @@ class TapoPlug(TapoDevice): | ||||
|         """Call the device endpoint and update the device data.""" | ||||
|         await super().update(update_children) | ||||
|  | ||||
|         self._energy = await self.protocol.query("get_energy_usage") | ||||
|         self._emeter = await self.protocol.query("get_current_power") | ||||
|         req = { | ||||
|             "get_energy_usage": None, | ||||
|             "get_current_power": None, | ||||
|         } | ||||
|         resp = await self.protocol.query(req) | ||||
|         self._energy = resp["get_energy_usage"] | ||||
|         self._emeter = resp["get_current_power"] | ||||
|  | ||||
|         self._data["energy"] = self._energy | ||||
|         self._data["emeter"] = self._emeter | ||||
| @@ -71,6 +76,13 @@ class TapoPlug(TapoDevice): | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     async def get_emeter_realtime(self) -> EmeterStatus: | ||||
|         """Retrieve current energy readings.""" | ||||
|         self._verify_emeter() | ||||
|         resp = await self.protocol.query("get_energy_usage") | ||||
|         self._energy = resp["get_energy_usage"] | ||||
|         return self.emeter_realtime | ||||
|  | ||||
|     @property | ||||
|     def emeter_today(self) -> Optional[float]: | ||||
|         """Get the emeter value for today.""" | ||||
|   | ||||
| @@ -196,9 +196,13 @@ def parametrize(desc, devices, protocol_filter=None, ids=None): | ||||
|     ) | ||||
|  | ||||
|  | ||||
| has_emeter = parametrize("has emeter", WITH_EMETER_IOT, protocol_filter={"IOT"}) | ||||
| has_emeter = parametrize("has emeter", WITH_EMETER, protocol_filter={"SMART", "IOT"}) | ||||
| no_emeter = parametrize( | ||||
|     "no emeter", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"SMART", "IOT"} | ||||
|     "no emeter", ALL_DEVICES - WITH_EMETER, protocol_filter={"SMART", "IOT"} | ||||
| ) | ||||
| has_emeter_iot = parametrize("has emeter iot", WITH_EMETER_IOT, protocol_filter={"IOT"}) | ||||
| no_emeter_iot = parametrize( | ||||
|     "no emeter iot", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"IOT"} | ||||
| ) | ||||
|  | ||||
| bulb = parametrize("bulbs", BULBS, protocol_filter={"SMART", "IOT"}) | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| import copy | ||||
| import logging | ||||
| import re | ||||
| import warnings | ||||
| from json import loads as json_loads | ||||
|  | ||||
| from voluptuous import ( | ||||
| @@ -294,9 +295,7 @@ class FakeSmartProtocol(SmartProtocol): | ||||
|     async def query(self, request, retry_count: int = 3): | ||||
|         """Implement query here so can still patch SmartProtocol.query.""" | ||||
|         resp_dict = await self._query(request, retry_count) | ||||
|         if "result" in resp_dict: | ||||
|             return resp_dict["result"] | ||||
|         return {} | ||||
|         return resp_dict | ||||
|  | ||||
|  | ||||
| class FakeSmartTransport(BaseTransport): | ||||
| @@ -306,26 +305,34 @@ class FakeSmartTransport(BaseTransport): | ||||
|         ) | ||||
|         self.info = info | ||||
|  | ||||
|     @property | ||||
|     def needs_handshake(self) -> bool: | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def needs_login(self) -> bool: | ||||
|         return False | ||||
|  | ||||
|     async def login(self, request: str) -> None: | ||||
|         pass | ||||
|  | ||||
|     async def handshake(self) -> None: | ||||
|         pass | ||||
|  | ||||
|     async def send(self, request: str): | ||||
|         request_dict = json_loads(request) | ||||
|         method = request_dict["method"] | ||||
|         params = request_dict["params"] | ||||
|         if method == "multipleRequest": | ||||
|             responses = [] | ||||
|             for request in params["requests"]: | ||||
|                 response = self._send_request(request)  # type: ignore[arg-type] | ||||
|                 response["method"] = request["method"]  # type: ignore[index] | ||||
|                 responses.append(response) | ||||
|             return {"result": {"responses": responses}, "error_code": 0} | ||||
|         else: | ||||
|             return self._send_request(request_dict) | ||||
|  | ||||
|     def _send_request(self, request_dict: dict): | ||||
|         method = request_dict["method"] | ||||
|         params = request_dict["params"] | ||||
|         if method == "component_nego" or method[:4] == "get_": | ||||
|             return {"result": self.info[method], "error_code": 0} | ||||
|             if method in self.info: | ||||
|                 return {"result": self.info[method], "error_code": 0} | ||||
|             else: | ||||
|                 warnings.warn( | ||||
|                     UserWarning( | ||||
|                         f"Fixture missing expected method {method}, try to regenerate" | ||||
|                     ), | ||||
|                     stacklevel=1, | ||||
|                 ) | ||||
|                 return {"result": {}, "error_code": 0} | ||||
|         elif method[:4] == "set_": | ||||
|             target_method = f"get_{method[4:]}" | ||||
|             self.info[target_method].update(params) | ||||
|   | ||||
| @@ -12,7 +12,12 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd | ||||
|  | ||||
| from ..aestransport import AesEncyptionSession, AesTransport | ||||
| from ..credentials import Credentials | ||||
| from ..exceptions import SmartDeviceException | ||||
| from ..exceptions import ( | ||||
|     SMART_RETRYABLE_ERRORS, | ||||
|     SMART_TIMEOUT_ERRORS, | ||||
|     SmartDeviceException, | ||||
|     SmartErrorCode, | ||||
| ) | ||||
|  | ||||
| DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} | ||||
|  | ||||
| @@ -105,6 +110,32 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati | ||||
|         assert "result" in res | ||||
|  | ||||
|  | ||||
| ERRORS = [e for e in SmartErrorCode if e != 0] | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) | ||||
| async def test_passthrough_errors(mocker, error_code): | ||||
|     host = "127.0.0.1" | ||||
|     mock_aes_device = MockAesDevice(host, 200, error_code, 0) | ||||
|     mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) | ||||
|  | ||||
|     transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) | ||||
|     transport._handshake_done = True | ||||
|     transport._session_expire_at = time.time() + 86400 | ||||
|     transport._encryption_session = mock_aes_device.encryption_session | ||||
|     transport._login_token = mock_aes_device.token | ||||
|  | ||||
|     request = { | ||||
|         "method": "get_device_info", | ||||
|         "params": None, | ||||
|         "request_time_milis": round(time.time() * 1000), | ||||
|         "requestID": 1, | ||||
|         "terminal_uuid": "foobar", | ||||
|     } | ||||
|     with pytest.raises(SmartDeviceException): | ||||
|         await transport.send(json_dumps(request)) | ||||
|  | ||||
|  | ||||
| class MockAesDevice: | ||||
|     class _mock_response: | ||||
|         def __init__(self, status_code, json: dict): | ||||
|   | ||||
| @@ -2,7 +2,7 @@ import pytest | ||||
|  | ||||
| from kasa import EmeterStatus, SmartDeviceException | ||||
|  | ||||
| from .conftest import has_emeter, no_emeter | ||||
| from .conftest import has_emeter, has_emeter_iot, no_emeter | ||||
| from .newfakes import CURRENT_CONSUMPTION_SCHEMA | ||||
|  | ||||
|  | ||||
| @@ -20,7 +20,7 @@ async def test_no_emeter(dev): | ||||
|         await dev.erase_emeter_stats() | ||||
|  | ||||
|  | ||||
| @has_emeter | ||||
| @has_emeter_iot | ||||
| async def test_get_emeter_realtime(dev): | ||||
|     assert dev.has_emeter | ||||
|  | ||||
| @@ -28,7 +28,7 @@ async def test_get_emeter_realtime(dev): | ||||
|     CURRENT_CONSUMPTION_SCHEMA(current_emeter) | ||||
|  | ||||
|  | ||||
| @has_emeter | ||||
| @has_emeter_iot | ||||
| @pytest.mark.requires_dummy | ||||
| async def test_get_emeter_daily(dev): | ||||
|     assert dev.has_emeter | ||||
| @@ -48,7 +48,7 @@ async def test_get_emeter_daily(dev): | ||||
|     assert v * 1000 == v2 | ||||
|  | ||||
|  | ||||
| @has_emeter | ||||
| @has_emeter_iot | ||||
| @pytest.mark.requires_dummy | ||||
| async def test_get_emeter_monthly(dev): | ||||
|     assert dev.has_emeter | ||||
| @@ -68,7 +68,7 @@ async def test_get_emeter_monthly(dev): | ||||
|     assert v * 1000 == v2 | ||||
|  | ||||
|  | ||||
| @has_emeter | ||||
| @has_emeter_iot | ||||
| async def test_emeter_status(dev): | ||||
|     assert dev.has_emeter | ||||
|  | ||||
|   | ||||
| @@ -26,20 +26,29 @@ class _mock_response: | ||||
|         self.content = content | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "error, retry_expectation", | ||||
|     [ | ||||
|         (Exception("dummy exception"), True), | ||||
|         (SmartDeviceException("dummy exception"), False), | ||||
|     ], | ||||
|     ids=("Exception", "SmartDeviceException"), | ||||
| ) | ||||
| @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) | ||||
| @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) | ||||
| @pytest.mark.parametrize("retry_count", [1, 3, 5]) | ||||
| async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class): | ||||
| async def test_protocol_retries( | ||||
|     mocker, retry_count, protocol_class, transport_class, error, retry_expectation | ||||
| ): | ||||
|     host = "127.0.0.1" | ||||
|     conn = mocker.patch.object( | ||||
|         httpx.AsyncClient, "post", side_effect=Exception("dummy exception") | ||||
|     ) | ||||
|     conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error) | ||||
|     with pytest.raises(SmartDeviceException): | ||||
|         await protocol_class(host, transport=transport_class(host)).query( | ||||
|             DUMMY_QUERY, retry_count=retry_count | ||||
|         ) | ||||
|  | ||||
|     assert conn.call_count == retry_count + 1 | ||||
|     expected_count = retry_count + 1 if retry_expectation else 1 | ||||
|     assert conn.call_count == expected_count | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) | ||||
| @@ -109,7 +118,7 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport | ||||
|     response = await protocol_class(host, transport=transport_class(host)).query( | ||||
|         DUMMY_QUERY, retry_count=retry_count | ||||
|     ) | ||||
|     assert "result" in response or "great" in response | ||||
|     assert "result" in response or "foobar" in response | ||||
|     assert send_mock.call_count == retry_count | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -8,7 +8,7 @@ import kasa | ||||
| from kasa import Credentials, SmartDevice, SmartDeviceException | ||||
| from kasa.smartdevice import DeviceType | ||||
|  | ||||
| from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on | ||||
| from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on | ||||
| from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol | ||||
|  | ||||
| # List of all SmartXXX classes including the SmartDevice base class | ||||
| @@ -48,7 +48,7 @@ async def test_initial_update_emeter(dev, mocker): | ||||
|     assert spy.call_count == expected_queries + len(dev.children) | ||||
|  | ||||
|  | ||||
| @no_emeter | ||||
| @no_emeter_iot | ||||
| async def test_initial_update_no_emeter(dev, mocker): | ||||
|     """Test that the initial update performs second query if emeter is available.""" | ||||
|     dev._last_update = None | ||||
|   | ||||
							
								
								
									
										81
									
								
								kasa/tests/test_smartprotocol.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								kasa/tests/test_smartprotocol.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,81 @@ | ||||
| import errno | ||||
| import json | ||||
| import logging | ||||
| import secrets | ||||
| import struct | ||||
| import sys | ||||
| import time | ||||
| from contextlib import nullcontext as does_not_raise | ||||
| from itertools import chain | ||||
|  | ||||
| import httpx | ||||
| import pytest | ||||
|  | ||||
| from ..aestransport import AesTransport | ||||
| from ..credentials import Credentials | ||||
| from ..exceptions import ( | ||||
|     SMART_RETRYABLE_ERRORS, | ||||
|     SMART_TIMEOUT_ERRORS, | ||||
|     SmartDeviceException, | ||||
|     SmartErrorCode, | ||||
| ) | ||||
| from ..iotprotocol import IotProtocol | ||||
| from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256 | ||||
| from ..smartprotocol import SmartProtocol | ||||
|  | ||||
| DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} | ||||
| ERRORS = [e for e in SmartErrorCode if e != 0] | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) | ||||
| async def test_smart_device_errors(mocker, error_code): | ||||
|     host = "127.0.0.1" | ||||
|     mock_response = {"result": {"great": "success"}, "error_code": error_code.value} | ||||
|  | ||||
|     mocker.patch.object(AesTransport, "perform_handshake") | ||||
|     mocker.patch.object(AesTransport, "perform_login") | ||||
|  | ||||
|     send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) | ||||
|  | ||||
|     protocol = SmartProtocol(host, transport=AesTransport(host)) | ||||
|     with pytest.raises(SmartDeviceException): | ||||
|         await protocol.query(DUMMY_QUERY, retry_count=2) | ||||
|  | ||||
|     if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): | ||||
|         expected_calls = 3 | ||||
|     else: | ||||
|         expected_calls = 1 | ||||
|     assert send_mock.call_count == expected_calls | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) | ||||
| async def test_smart_device_errors_in_multiple_request(mocker, error_code): | ||||
|     host = "127.0.0.1" | ||||
|     mock_response = { | ||||
|         "result": { | ||||
|             "responses": [ | ||||
|                 {"method": "foobar1", "result": {"great": "success"}, "error_code": 0}, | ||||
|                 { | ||||
|                     "method": "foobar2", | ||||
|                     "result": {"great": "success"}, | ||||
|                     "error_code": error_code.value, | ||||
|                 }, | ||||
|                 {"method": "foobar3", "result": {"great": "success"}, "error_code": 0}, | ||||
|             ] | ||||
|         }, | ||||
|         "error_code": 0, | ||||
|     } | ||||
|  | ||||
|     mocker.patch.object(AesTransport, "perform_handshake") | ||||
|     mocker.patch.object(AesTransport, "perform_login") | ||||
|  | ||||
|     send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) | ||||
|  | ||||
|     protocol = SmartProtocol(host, transport=AesTransport(host)) | ||||
|     with pytest.raises(SmartDeviceException): | ||||
|         await protocol.query(DUMMY_QUERY, retry_count=2) | ||||
|     if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): | ||||
|         expected_calls = 3 | ||||
|     else: | ||||
|         expected_calls = 1 | ||||
|     assert send_mock.call_count == expected_calls | ||||
		Reference in New Issue
	
	Block a user
	 sdb9696
					sdb9696