diff --git a/kasa/__init__.py b/kasa/__init__.py index b6c42059..e77aa7dd 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -13,9 +13,10 @@ to be handled by the user of the library. """ from importlib_metadata import version # type: ignore from kasa.discover import Discover +from kasa.exceptions import SmartDeviceException from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartbulb import SmartBulb -from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice, SmartDeviceException +from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice from kasa.smartdimmer import SmartDimmer from kasa.smartplug import SmartPlug from kasa.smartstrip import SmartStrip diff --git a/kasa/exceptions.py b/kasa/exceptions.py new file mode 100644 index 00000000..90d36c9a --- /dev/null +++ b/kasa/exceptions.py @@ -0,0 +1,5 @@ +"""python-kasa exceptions.""" + + +class SmartDeviceException(Exception): + """Base exception for device errors.""" diff --git a/kasa/protocol.py b/kasa/protocol.py index 443a428e..6ee6f72d 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -16,6 +16,8 @@ import struct from pprint import pformat as pf from typing import Dict, Union +from .exceptions import SmartDeviceException + _LOGGER = logging.getLogger(__name__) @@ -27,12 +29,13 @@ class TPLinkSmartHomeProtocol: DEFAULT_TIMEOUT = 5 @staticmethod - async def query(host: str, request: Union[str, Dict]) -> Dict: + async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict: """Request information from a TP-Link SmartHome Device. :param str host: host name or ip address of the device :param request: command to send to the device (can be either dict or json string) + :param retry_count: how many retries to do in case of failure :return: response dict """ if isinstance(request, dict): @@ -40,35 +43,51 @@ class TPLinkSmartHomeProtocol: timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT writer = None - try: - task = asyncio.open_connection(host, TPLinkSmartHomeProtocol.DEFAULT_PORT) - reader, writer = await asyncio.wait_for(task, timeout=timeout) - _LOGGER.debug("> (%i) %s", len(request), request) - writer.write(TPLinkSmartHomeProtocol.encrypt(request)) - await writer.drain() + for retry in range(retry_count + 1): + try: + task = asyncio.open_connection( + host, TPLinkSmartHomeProtocol.DEFAULT_PORT + ) + reader, writer = await asyncio.wait_for(task, timeout=timeout) + _LOGGER.debug("> (%i) %s", len(request), request) + writer.write(TPLinkSmartHomeProtocol.encrypt(request)) + await writer.drain() - buffer = bytes() - # Some devices send responses with a length header of 0 and - # terminate with a zero size chunk. Others send the length and - # will hang if we attempt to read more data. - length = -1 - while True: - chunk = await reader.read(4096) - if length == -1: - length = struct.unpack(">I", chunk[0:4])[0] - buffer += chunk - if (length > 0 and len(buffer) >= length + 4) or not chunk: - break - finally: - if writer: - writer.close() - await writer.wait_closed() + buffer = bytes() + # Some devices send responses with a length header of 0 and + # terminate with a zero size chunk. Others send the length and + # will hang if we attempt to read more data. + length = -1 + while True: + chunk = await reader.read(4096) + if length == -1: + length = struct.unpack(">I", chunk[0:4])[0] + buffer += chunk + if (length > 0 and len(buffer) >= length + 4) or not chunk: + break - response = TPLinkSmartHomeProtocol.decrypt(buffer[4:]) - json_payload = json.loads(response) - _LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) + response = TPLinkSmartHomeProtocol.decrypt(buffer[4:]) + json_payload = json.loads(response) + _LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) - return json_payload + return json_payload + + except Exception as ex: + if retry >= retry_count: + _LOGGER.debug("Giving up after %s retries", retry) + raise SmartDeviceException( + "Unable to query the device: %s" % ex + ) from ex + + _LOGGER.debug("Unable to query the device, retrying: %s", ex) + + finally: + if writer: + writer.close() + await writer.wait_closed() + + # make mypy happy, this should never be reached.. + raise SmartDeviceException("Query reached somehow to unreachable") @staticmethod def encrypt(request: str) -> bytes: diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 49dc6c4a..cd2e8f5f 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -19,7 +19,8 @@ from datetime import datetime, timedelta from enum import Enum from typing import Any, Dict, List, Optional -from kasa.protocol import TPLinkSmartHomeProtocol +from .exceptions import SmartDeviceException +from .protocol import TPLinkSmartHomeProtocol _LOGGER = logging.getLogger(__name__) @@ -47,10 +48,6 @@ class WifiNetwork: rssi: Optional[int] = None -class SmartDeviceException(Exception): - """Base exception for device errors.""" - - class EmeterStatus(dict): """Container for converting different representations of emeter data. diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 30e798be..f2b4c178 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -3,6 +3,7 @@ import glob import json import os from os.path import basename +from unittest.mock import MagicMock import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342 @@ -151,3 +152,14 @@ def pytest_collection_modifyitems(config, items): return else: print("Running against ip %s" % config.getoption("--ip")) + + +# allow mocks to be awaited +# https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression/51399767#51399767 + + +async def async_magic(): + pass + + +MagicMock.__await__ = lambda x: async_magic().__await__() diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 313fd69d..0a8291e1 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -1,73 +1,95 @@ import json -from unittest import TestCase +import pytest + +from ..exceptions import SmartDeviceException from ..protocol import TPLinkSmartHomeProtocol -class TestTPLinkSmartHomeProtocol(TestCase): - def test_encrypt(self): - d = json.dumps({"foo": 1, "bar": 2}) - encrypted = TPLinkSmartHomeProtocol.encrypt(d) - # encrypt adds a 4 byte header - encrypted = encrypted[4:] - self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted)) +@pytest.mark.parametrize("retry_count", [1, 3, 5]) +async def test_protocol_retries(mocker, retry_count): + def aio_mock_writer(_, __): + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") - def test_encrypt_unicode(self): - d = "{'snowman': '\u2603'}" - - e = bytes( - [ - 208, - 247, - 132, - 234, - 133, - 242, - 159, - 254, - 144, - 183, - 141, - 173, - 138, - 104, - 240, - 115, - 84, - 41, - ] + mocker.patch( + "asyncio.StreamWriter.write", side_effect=Exception("dummy exception") ) - encrypted = TPLinkSmartHomeProtocol.encrypt(d) - # encrypt adds a 4 byte header - encrypted = encrypted[4:] + return reader, writer - self.assertEqual(e, encrypted) + conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + with pytest.raises(SmartDeviceException): + await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count) - def test_decrypt_unicode(self): - e = bytes( - [ - 208, - 247, - 132, - 234, - 133, - 242, - 159, - 254, - 144, - 183, - 141, - 173, - 138, - 104, - 240, - 115, - 84, - 41, - ] - ) + assert conn.call_count == retry_count + 1 - d = "{'snowman': '\u2603'}" - self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e)) +def test_encrypt(): + d = json.dumps({"foo": 1, "bar": 2}) + encrypted = TPLinkSmartHomeProtocol.encrypt(d) + # encrypt adds a 4 byte header + encrypted = encrypted[4:] + assert d == TPLinkSmartHomeProtocol.decrypt(encrypted) + + +def test_encrypt_unicode(): + d = "{'snowman': '\u2603'}" + + e = bytes( + [ + 208, + 247, + 132, + 234, + 133, + 242, + 159, + 254, + 144, + 183, + 141, + 173, + 138, + 104, + 240, + 115, + 84, + 41, + ] + ) + + encrypted = TPLinkSmartHomeProtocol.encrypt(d) + # encrypt adds a 4 byte header + encrypted = encrypted[4:] + + assert e == encrypted + + +def test_decrypt_unicode(): + e = bytes( + [ + 208, + 247, + 132, + 234, + 133, + 242, + 159, + 254, + 144, + 183, + 141, + 173, + 138, + 104, + 240, + 115, + 84, + 41, + ] + ) + + d = "{'snowman': '\u2603'}" + + assert d == TPLinkSmartHomeProtocol.decrypt(e)