mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-10-30 20:21:54 +00:00 
			
		
		
		
	Add retries to protocol queries (#65)
* Add retries to query(), defaults to 3 + add tests * Catch also json decoding errors for retries * add missing exceptions file, fix old protocol tests
This commit is contained in:
		| @@ -13,9 +13,10 @@ to be handled by the user of the library. | ||||
| """ | ||||
| from importlib_metadata import version  # type: ignore | ||||
| from kasa.discover import Discover | ||||
| from kasa.exceptions import SmartDeviceException | ||||
| from kasa.protocol import TPLinkSmartHomeProtocol | ||||
| from kasa.smartbulb import SmartBulb | ||||
| from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice, SmartDeviceException | ||||
| from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice | ||||
| from kasa.smartdimmer import SmartDimmer | ||||
| from kasa.smartplug import SmartPlug | ||||
| from kasa.smartstrip import SmartStrip | ||||
|   | ||||
							
								
								
									
										5
									
								
								kasa/exceptions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								kasa/exceptions.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| """python-kasa exceptions.""" | ||||
|  | ||||
|  | ||||
| class SmartDeviceException(Exception): | ||||
|     """Base exception for device errors.""" | ||||
| @@ -16,6 +16,8 @@ import struct | ||||
| from pprint import pformat as pf | ||||
| from typing import Dict, Union | ||||
|  | ||||
| from .exceptions import SmartDeviceException | ||||
|  | ||||
| _LOGGER = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| @@ -27,12 +29,13 @@ class TPLinkSmartHomeProtocol: | ||||
|     DEFAULT_TIMEOUT = 5 | ||||
|  | ||||
|     @staticmethod | ||||
|     async def query(host: str, request: Union[str, Dict]) -> Dict: | ||||
|     async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict: | ||||
|         """Request information from a TP-Link SmartHome Device. | ||||
|  | ||||
|         :param str host: host name or ip address of the device | ||||
|         :param request: command to send to the device (can be either dict or | ||||
|         json string) | ||||
|         :param retry_count: how many retries to do in case of failure | ||||
|         :return: response dict | ||||
|         """ | ||||
|         if isinstance(request, dict): | ||||
| @@ -40,35 +43,51 @@ class TPLinkSmartHomeProtocol: | ||||
|  | ||||
|         timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT | ||||
|         writer = None | ||||
|         try: | ||||
|             task = asyncio.open_connection(host, TPLinkSmartHomeProtocol.DEFAULT_PORT) | ||||
|             reader, writer = await asyncio.wait_for(task, timeout=timeout) | ||||
|             _LOGGER.debug("> (%i) %s", len(request), request) | ||||
|             writer.write(TPLinkSmartHomeProtocol.encrypt(request)) | ||||
|             await writer.drain() | ||||
|         for retry in range(retry_count + 1): | ||||
|             try: | ||||
|                 task = asyncio.open_connection( | ||||
|                     host, TPLinkSmartHomeProtocol.DEFAULT_PORT | ||||
|                 ) | ||||
|                 reader, writer = await asyncio.wait_for(task, timeout=timeout) | ||||
|                 _LOGGER.debug("> (%i) %s", len(request), request) | ||||
|                 writer.write(TPLinkSmartHomeProtocol.encrypt(request)) | ||||
|                 await writer.drain() | ||||
|  | ||||
|             buffer = bytes() | ||||
|             # Some devices send responses with a length header of 0 and | ||||
|             # terminate with a zero size chunk. Others send the length and | ||||
|             # will hang if we attempt to read more data. | ||||
|             length = -1 | ||||
|             while True: | ||||
|                 chunk = await reader.read(4096) | ||||
|                 if length == -1: | ||||
|                     length = struct.unpack(">I", chunk[0:4])[0] | ||||
|                 buffer += chunk | ||||
|                 if (length > 0 and len(buffer) >= length + 4) or not chunk: | ||||
|                     break | ||||
|         finally: | ||||
|             if writer: | ||||
|                 writer.close() | ||||
|                 await writer.wait_closed() | ||||
|                 buffer = bytes() | ||||
|                 # Some devices send responses with a length header of 0 and | ||||
|                 # terminate with a zero size chunk. Others send the length and | ||||
|                 # will hang if we attempt to read more data. | ||||
|                 length = -1 | ||||
|                 while True: | ||||
|                     chunk = await reader.read(4096) | ||||
|                     if length == -1: | ||||
|                         length = struct.unpack(">I", chunk[0:4])[0] | ||||
|                     buffer += chunk | ||||
|                     if (length > 0 and len(buffer) >= length + 4) or not chunk: | ||||
|                         break | ||||
|  | ||||
|         response = TPLinkSmartHomeProtocol.decrypt(buffer[4:]) | ||||
|         json_payload = json.loads(response) | ||||
|         _LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) | ||||
|                 response = TPLinkSmartHomeProtocol.decrypt(buffer[4:]) | ||||
|                 json_payload = json.loads(response) | ||||
|                 _LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) | ||||
|  | ||||
|         return json_payload | ||||
|                 return json_payload | ||||
|  | ||||
|             except Exception as ex: | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up after %s retries", retry) | ||||
|                     raise SmartDeviceException( | ||||
|                         "Unable to query the device: %s" % ex | ||||
|                     ) from ex | ||||
|  | ||||
|                 _LOGGER.debug("Unable to query the device, retrying: %s", ex) | ||||
|  | ||||
|             finally: | ||||
|                 if writer: | ||||
|                     writer.close() | ||||
|                     await writer.wait_closed() | ||||
|  | ||||
|         # make mypy happy, this should never be reached.. | ||||
|         raise SmartDeviceException("Query reached somehow to unreachable") | ||||
|  | ||||
|     @staticmethod | ||||
|     def encrypt(request: str) -> bytes: | ||||
|   | ||||
| @@ -19,7 +19,8 @@ from datetime import datetime, timedelta | ||||
| from enum import Enum | ||||
| from typing import Any, Dict, List, Optional | ||||
|  | ||||
| from kasa.protocol import TPLinkSmartHomeProtocol | ||||
| from .exceptions import SmartDeviceException | ||||
| from .protocol import TPLinkSmartHomeProtocol | ||||
|  | ||||
| _LOGGER = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -47,10 +48,6 @@ class WifiNetwork: | ||||
|     rssi: Optional[int] = None | ||||
|  | ||||
|  | ||||
| class SmartDeviceException(Exception): | ||||
|     """Base exception for device errors.""" | ||||
|  | ||||
|  | ||||
| class EmeterStatus(dict): | ||||
|     """Container for converting different representations of emeter data. | ||||
|  | ||||
|   | ||||
| @@ -3,6 +3,7 @@ import glob | ||||
| import json | ||||
| import os | ||||
| from os.path import basename | ||||
| from unittest.mock import MagicMock | ||||
|  | ||||
| import pytest  # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342 | ||||
|  | ||||
| @@ -151,3 +152,14 @@ def pytest_collection_modifyitems(config, items): | ||||
|         return | ||||
|     else: | ||||
|         print("Running against ip %s" % config.getoption("--ip")) | ||||
|  | ||||
|  | ||||
| # allow mocks to be awaited | ||||
| # https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression/51399767#51399767 | ||||
|  | ||||
|  | ||||
| async def async_magic(): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| MagicMock.__await__ = lambda x: async_magic().__await__() | ||||
|   | ||||
| @@ -1,73 +1,95 @@ | ||||
| import json | ||||
| from unittest import TestCase | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from ..exceptions import SmartDeviceException | ||||
| from ..protocol import TPLinkSmartHomeProtocol | ||||
|  | ||||
|  | ||||
| class TestTPLinkSmartHomeProtocol(TestCase): | ||||
|     def test_encrypt(self): | ||||
|         d = json.dumps({"foo": 1, "bar": 2}) | ||||
|         encrypted = TPLinkSmartHomeProtocol.encrypt(d) | ||||
|         # encrypt adds a 4 byte header | ||||
|         encrypted = encrypted[4:] | ||||
|         self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted)) | ||||
| @pytest.mark.parametrize("retry_count", [1, 3, 5]) | ||||
| async def test_protocol_retries(mocker, retry_count): | ||||
|     def aio_mock_writer(_, __): | ||||
|         reader = mocker.patch("asyncio.StreamReader") | ||||
|         writer = mocker.patch("asyncio.StreamWriter") | ||||
|  | ||||
|     def test_encrypt_unicode(self): | ||||
|         d = "{'snowman': '\u2603'}" | ||||
|  | ||||
|         e = bytes( | ||||
|             [ | ||||
|                 208, | ||||
|                 247, | ||||
|                 132, | ||||
|                 234, | ||||
|                 133, | ||||
|                 242, | ||||
|                 159, | ||||
|                 254, | ||||
|                 144, | ||||
|                 183, | ||||
|                 141, | ||||
|                 173, | ||||
|                 138, | ||||
|                 104, | ||||
|                 240, | ||||
|                 115, | ||||
|                 84, | ||||
|                 41, | ||||
|             ] | ||||
|         mocker.patch( | ||||
|             "asyncio.StreamWriter.write", side_effect=Exception("dummy exception") | ||||
|         ) | ||||
|  | ||||
|         encrypted = TPLinkSmartHomeProtocol.encrypt(d) | ||||
|         # encrypt adds a 4 byte header | ||||
|         encrypted = encrypted[4:] | ||||
|         return reader, writer | ||||
|  | ||||
|         self.assertEqual(e, encrypted) | ||||
|     conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) | ||||
|     with pytest.raises(SmartDeviceException): | ||||
|         await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count) | ||||
|  | ||||
|     def test_decrypt_unicode(self): | ||||
|         e = bytes( | ||||
|             [ | ||||
|                 208, | ||||
|                 247, | ||||
|                 132, | ||||
|                 234, | ||||
|                 133, | ||||
|                 242, | ||||
|                 159, | ||||
|                 254, | ||||
|                 144, | ||||
|                 183, | ||||
|                 141, | ||||
|                 173, | ||||
|                 138, | ||||
|                 104, | ||||
|                 240, | ||||
|                 115, | ||||
|                 84, | ||||
|                 41, | ||||
|             ] | ||||
|         ) | ||||
|     assert conn.call_count == retry_count + 1 | ||||
|  | ||||
|         d = "{'snowman': '\u2603'}" | ||||
|  | ||||
|         self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e)) | ||||
| def test_encrypt(): | ||||
|     d = json.dumps({"foo": 1, "bar": 2}) | ||||
|     encrypted = TPLinkSmartHomeProtocol.encrypt(d) | ||||
|     # encrypt adds a 4 byte header | ||||
|     encrypted = encrypted[4:] | ||||
|     assert d == TPLinkSmartHomeProtocol.decrypt(encrypted) | ||||
|  | ||||
|  | ||||
| def test_encrypt_unicode(): | ||||
|     d = "{'snowman': '\u2603'}" | ||||
|  | ||||
|     e = bytes( | ||||
|         [ | ||||
|             208, | ||||
|             247, | ||||
|             132, | ||||
|             234, | ||||
|             133, | ||||
|             242, | ||||
|             159, | ||||
|             254, | ||||
|             144, | ||||
|             183, | ||||
|             141, | ||||
|             173, | ||||
|             138, | ||||
|             104, | ||||
|             240, | ||||
|             115, | ||||
|             84, | ||||
|             41, | ||||
|         ] | ||||
|     ) | ||||
|  | ||||
|     encrypted = TPLinkSmartHomeProtocol.encrypt(d) | ||||
|     # encrypt adds a 4 byte header | ||||
|     encrypted = encrypted[4:] | ||||
|  | ||||
|     assert e == encrypted | ||||
|  | ||||
|  | ||||
| def test_decrypt_unicode(): | ||||
|     e = bytes( | ||||
|         [ | ||||
|             208, | ||||
|             247, | ||||
|             132, | ||||
|             234, | ||||
|             133, | ||||
|             242, | ||||
|             159, | ||||
|             254, | ||||
|             144, | ||||
|             183, | ||||
|             141, | ||||
|             173, | ||||
|             138, | ||||
|             104, | ||||
|             240, | ||||
|             115, | ||||
|             84, | ||||
|             41, | ||||
|         ] | ||||
|     ) | ||||
|  | ||||
|     d = "{'snowman': '\u2603'}" | ||||
|  | ||||
|     assert d == TPLinkSmartHomeProtocol.decrypt(e) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Teemu R
					Teemu R