Add retries to protocol queries (#65)

* Add retries to query(), defaults to 3 + add tests

* Catch also json decoding errors for retries

* add missing exceptions file, fix old protocol tests
This commit is contained in:
Teemu R 2020-05-27 19:02:09 +02:00 committed by GitHub
parent 644a10a0d1
commit 9dc0cbaece
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 150 additions and 94 deletions

View File

@ -13,9 +13,10 @@ to be handled by the user of the library.
"""
from importlib_metadata import version # type: ignore
from kasa.discover import Discover
from kasa.exceptions import SmartDeviceException
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice
from kasa.smartdimmer import SmartDimmer
from kasa.smartplug import SmartPlug
from kasa.smartstrip import SmartStrip

5
kasa/exceptions.py Normal file
View File

@ -0,0 +1,5 @@
"""python-kasa exceptions."""
class SmartDeviceException(Exception):
"""Base exception for device errors."""

View File

@ -16,6 +16,8 @@ import struct
from pprint import pformat as pf
from typing import Dict, Union
from .exceptions import SmartDeviceException
_LOGGER = logging.getLogger(__name__)
@ -27,12 +29,13 @@ class TPLinkSmartHomeProtocol:
DEFAULT_TIMEOUT = 5
@staticmethod
async def query(host: str, request: Union[str, Dict]) -> Dict:
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device.
:param str host: host name or ip address of the device
:param request: command to send to the device (can be either dict or
json string)
:param retry_count: how many retries to do in case of failure
:return: response dict
"""
if isinstance(request, dict):
@ -40,8 +43,11 @@ class TPLinkSmartHomeProtocol:
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
writer = None
for retry in range(retry_count + 1):
try:
task = asyncio.open_connection(host, TPLinkSmartHomeProtocol.DEFAULT_PORT)
task = asyncio.open_connection(
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
reader, writer = await asyncio.wait_for(task, timeout=timeout)
_LOGGER.debug("> (%i) %s", len(request), request)
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
@ -59,10 +65,6 @@ class TPLinkSmartHomeProtocol:
buffer += chunk
if (length > 0 and len(buffer) >= length + 4) or not chunk:
break
finally:
if writer:
writer.close()
await writer.wait_closed()
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
json_payload = json.loads(response)
@ -70,6 +72,23 @@ class TPLinkSmartHomeProtocol:
return json_payload
except Exception as ex:
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
"Unable to query the device: %s" % ex
) from ex
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
finally:
if writer:
writer.close()
await writer.wait_closed()
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
@staticmethod
def encrypt(request: str) -> bytes:
"""Encrypt a request for a TP-Link Smart Home Device.

View File

@ -19,7 +19,8 @@ from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional
from kasa.protocol import TPLinkSmartHomeProtocol
from .exceptions import SmartDeviceException
from .protocol import TPLinkSmartHomeProtocol
_LOGGER = logging.getLogger(__name__)
@ -47,10 +48,6 @@ class WifiNetwork:
rssi: Optional[int] = None
class SmartDeviceException(Exception):
"""Base exception for device errors."""
class EmeterStatus(dict):
"""Container for converting different representations of emeter data.

View File

@ -3,6 +3,7 @@ import glob
import json
import os
from os.path import basename
from unittest.mock import MagicMock
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
@ -151,3 +152,14 @@ def pytest_collection_modifyitems(config, items):
return
else:
print("Running against ip %s" % config.getoption("--ip"))
# allow mocks to be awaited
# https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression/51399767#51399767
async def async_magic():
pass
MagicMock.__await__ = lambda x: async_magic().__await__()

View File

@ -1,18 +1,39 @@
import json
from unittest import TestCase
import pytest
from ..exceptions import SmartDeviceException
from ..protocol import TPLinkSmartHomeProtocol
class TestTPLinkSmartHomeProtocol(TestCase):
def test_encrypt(self):
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries(mocker, retry_count):
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch(
"asyncio.StreamWriter.write", side_effect=Exception("dummy exception")
)
return reader, writer
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)
assert conn.call_count == retry_count + 1
def test_encrypt():
d = json.dumps({"foo": 1, "bar": 2})
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
# encrypt adds a 4 byte header
encrypted = encrypted[4:]
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted))
assert d == TPLinkSmartHomeProtocol.decrypt(encrypted)
def test_encrypt_unicode(self):
def test_encrypt_unicode():
d = "{'snowman': '\u2603'}"
e = bytes(
@ -42,9 +63,10 @@ class TestTPLinkSmartHomeProtocol(TestCase):
# encrypt adds a 4 byte header
encrypted = encrypted[4:]
self.assertEqual(e, encrypted)
assert e == encrypted
def test_decrypt_unicode(self):
def test_decrypt_unicode():
e = bytes(
[
208,
@ -70,4 +92,4 @@ class TestTPLinkSmartHomeProtocol(TestCase):
d = "{'snowman': '\u2603'}"
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e))
assert d == TPLinkSmartHomeProtocol.decrypt(e)