Add retries to protocol queries (#65)

* Add retries to query(), defaults to 3 + add tests

* Catch also json decoding errors for retries

* add missing exceptions file, fix old protocol tests
This commit is contained in:
Teemu R 2020-05-27 19:02:09 +02:00 committed by GitHub
parent 644a10a0d1
commit 9dc0cbaece
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 150 additions and 94 deletions

View File

@ -13,9 +13,10 @@ to be handled by the user of the library.
""" """
from importlib_metadata import version # type: ignore from importlib_metadata import version # type: ignore
from kasa.discover import Discover from kasa.discover import Discover
from kasa.exceptions import SmartDeviceException
from kasa.protocol import TPLinkSmartHomeProtocol from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb from kasa.smartbulb import SmartBulb
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice, SmartDeviceException from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice
from kasa.smartdimmer import SmartDimmer from kasa.smartdimmer import SmartDimmer
from kasa.smartplug import SmartPlug from kasa.smartplug import SmartPlug
from kasa.smartstrip import SmartStrip from kasa.smartstrip import SmartStrip

5
kasa/exceptions.py Normal file
View File

@ -0,0 +1,5 @@
"""python-kasa exceptions."""
class SmartDeviceException(Exception):
"""Base exception for device errors."""

View File

@ -16,6 +16,8 @@ import struct
from pprint import pformat as pf from pprint import pformat as pf
from typing import Dict, Union from typing import Dict, Union
from .exceptions import SmartDeviceException
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -27,12 +29,13 @@ class TPLinkSmartHomeProtocol:
DEFAULT_TIMEOUT = 5 DEFAULT_TIMEOUT = 5
@staticmethod @staticmethod
async def query(host: str, request: Union[str, Dict]) -> Dict: async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device. """Request information from a TP-Link SmartHome Device.
:param str host: host name or ip address of the device :param str host: host name or ip address of the device
:param request: command to send to the device (can be either dict or :param request: command to send to the device (can be either dict or
json string) json string)
:param retry_count: how many retries to do in case of failure
:return: response dict :return: response dict
""" """
if isinstance(request, dict): if isinstance(request, dict):
@ -40,35 +43,51 @@ class TPLinkSmartHomeProtocol:
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
writer = None writer = None
try: for retry in range(retry_count + 1):
task = asyncio.open_connection(host, TPLinkSmartHomeProtocol.DEFAULT_PORT) try:
reader, writer = await asyncio.wait_for(task, timeout=timeout) task = asyncio.open_connection(
_LOGGER.debug("> (%i) %s", len(request), request) host, TPLinkSmartHomeProtocol.DEFAULT_PORT
writer.write(TPLinkSmartHomeProtocol.encrypt(request)) )
await writer.drain() reader, writer = await asyncio.wait_for(task, timeout=timeout)
_LOGGER.debug("> (%i) %s", len(request), request)
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await writer.drain()
buffer = bytes() buffer = bytes()
# Some devices send responses with a length header of 0 and # Some devices send responses with a length header of 0 and
# terminate with a zero size chunk. Others send the length and # terminate with a zero size chunk. Others send the length and
# will hang if we attempt to read more data. # will hang if we attempt to read more data.
length = -1 length = -1
while True: while True:
chunk = await reader.read(4096) chunk = await reader.read(4096)
if length == -1: if length == -1:
length = struct.unpack(">I", chunk[0:4])[0] length = struct.unpack(">I", chunk[0:4])[0]
buffer += chunk buffer += chunk
if (length > 0 and len(buffer) >= length + 4) or not chunk: if (length > 0 and len(buffer) >= length + 4) or not chunk:
break break
finally:
if writer:
writer.close()
await writer.wait_closed()
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:]) response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
json_payload = json.loads(response) json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) _LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
return json_payload return json_payload
except Exception as ex:
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
"Unable to query the device: %s" % ex
) from ex
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
finally:
if writer:
writer.close()
await writer.wait_closed()
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
@staticmethod @staticmethod
def encrypt(request: str) -> bytes: def encrypt(request: str) -> bytes:

View File

@ -19,7 +19,8 @@ from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from kasa.protocol import TPLinkSmartHomeProtocol from .exceptions import SmartDeviceException
from .protocol import TPLinkSmartHomeProtocol
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -47,10 +48,6 @@ class WifiNetwork:
rssi: Optional[int] = None rssi: Optional[int] = None
class SmartDeviceException(Exception):
"""Base exception for device errors."""
class EmeterStatus(dict): class EmeterStatus(dict):
"""Container for converting different representations of emeter data. """Container for converting different representations of emeter data.

View File

@ -3,6 +3,7 @@ import glob
import json import json
import os import os
from os.path import basename from os.path import basename
from unittest.mock import MagicMock
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
@ -151,3 +152,14 @@ def pytest_collection_modifyitems(config, items):
return return
else: else:
print("Running against ip %s" % config.getoption("--ip")) print("Running against ip %s" % config.getoption("--ip"))
# allow mocks to be awaited
# https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression/51399767#51399767
async def async_magic():
pass
MagicMock.__await__ = lambda x: async_magic().__await__()

View File

@ -1,73 +1,95 @@
import json import json
from unittest import TestCase
import pytest
from ..exceptions import SmartDeviceException
from ..protocol import TPLinkSmartHomeProtocol from ..protocol import TPLinkSmartHomeProtocol
class TestTPLinkSmartHomeProtocol(TestCase): @pytest.mark.parametrize("retry_count", [1, 3, 5])
def test_encrypt(self): async def test_protocol_retries(mocker, retry_count):
d = json.dumps({"foo": 1, "bar": 2}) def aio_mock_writer(_, __):
encrypted = TPLinkSmartHomeProtocol.encrypt(d) reader = mocker.patch("asyncio.StreamReader")
# encrypt adds a 4 byte header writer = mocker.patch("asyncio.StreamWriter")
encrypted = encrypted[4:]
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted))
def test_encrypt_unicode(self): mocker.patch(
d = "{'snowman': '\u2603'}" "asyncio.StreamWriter.write", side_effect=Exception("dummy exception")
e = bytes(
[
208,
247,
132,
234,
133,
242,
159,
254,
144,
183,
141,
173,
138,
104,
240,
115,
84,
41,
]
) )
encrypted = TPLinkSmartHomeProtocol.encrypt(d) return reader, writer
# encrypt adds a 4 byte header
encrypted = encrypted[4:]
self.assertEqual(e, encrypted) conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)
def test_decrypt_unicode(self): assert conn.call_count == retry_count + 1
e = bytes(
[
208,
247,
132,
234,
133,
242,
159,
254,
144,
183,
141,
173,
138,
104,
240,
115,
84,
41,
]
)
d = "{'snowman': '\u2603'}"
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e)) def test_encrypt():
d = json.dumps({"foo": 1, "bar": 2})
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
# encrypt adds a 4 byte header
encrypted = encrypted[4:]
assert d == TPLinkSmartHomeProtocol.decrypt(encrypted)
def test_encrypt_unicode():
d = "{'snowman': '\u2603'}"
e = bytes(
[
208,
247,
132,
234,
133,
242,
159,
254,
144,
183,
141,
173,
138,
104,
240,
115,
84,
41,
]
)
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
# encrypt adds a 4 byte header
encrypted = encrypted[4:]
assert e == encrypted
def test_decrypt_unicode():
e = bytes(
[
208,
247,
132,
234,
133,
242,
159,
254,
144,
183,
141,
173,
138,
104,
240,
115,
84,
41,
]
)
d = "{'snowman': '\u2603'}"
assert d == TPLinkSmartHomeProtocol.decrypt(e)