mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-08 22:07:06 +00:00
Add retries to protocol queries (#65)
* Add retries to query(), defaults to 3 + add tests * Catch also json decoding errors for retries * add missing exceptions file, fix old protocol tests
This commit is contained in:
parent
644a10a0d1
commit
9dc0cbaece
@ -13,9 +13,10 @@ to be handled by the user of the library.
|
||||
"""
|
||||
from importlib_metadata import version # type: ignore
|
||||
from kasa.discover import Discover
|
||||
from kasa.exceptions import SmartDeviceException
|
||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||
from kasa.smartbulb import SmartBulb
|
||||
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice, SmartDeviceException
|
||||
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice
|
||||
from kasa.smartdimmer import SmartDimmer
|
||||
from kasa.smartplug import SmartPlug
|
||||
from kasa.smartstrip import SmartStrip
|
||||
|
5
kasa/exceptions.py
Normal file
5
kasa/exceptions.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""python-kasa exceptions."""
|
||||
|
||||
|
||||
class SmartDeviceException(Exception):
|
||||
"""Base exception for device errors."""
|
@ -16,6 +16,8 @@ import struct
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Union
|
||||
|
||||
from .exceptions import SmartDeviceException
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -27,12 +29,13 @@ class TPLinkSmartHomeProtocol:
|
||||
DEFAULT_TIMEOUT = 5
|
||||
|
||||
@staticmethod
|
||||
async def query(host: str, request: Union[str, Dict]) -> Dict:
|
||||
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Request information from a TP-Link SmartHome Device.
|
||||
|
||||
:param str host: host name or ip address of the device
|
||||
:param request: command to send to the device (can be either dict or
|
||||
json string)
|
||||
:param retry_count: how many retries to do in case of failure
|
||||
:return: response dict
|
||||
"""
|
||||
if isinstance(request, dict):
|
||||
@ -40,8 +43,11 @@ class TPLinkSmartHomeProtocol:
|
||||
|
||||
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
|
||||
writer = None
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
task = asyncio.open_connection(host, TPLinkSmartHomeProtocol.DEFAULT_PORT)
|
||||
task = asyncio.open_connection(
|
||||
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
|
||||
)
|
||||
reader, writer = await asyncio.wait_for(task, timeout=timeout)
|
||||
_LOGGER.debug("> (%i) %s", len(request), request)
|
||||
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||
@ -59,10 +65,6 @@ class TPLinkSmartHomeProtocol:
|
||||
buffer += chunk
|
||||
if (length > 0 and len(buffer) >= length + 4) or not chunk:
|
||||
break
|
||||
finally:
|
||||
if writer:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
|
||||
json_payload = json.loads(response)
|
||||
@ -70,6 +72,23 @@ class TPLinkSmartHomeProtocol:
|
||||
|
||||
return json_payload
|
||||
|
||||
except Exception as ex:
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up after %s retries", retry)
|
||||
raise SmartDeviceException(
|
||||
"Unable to query the device: %s" % ex
|
||||
) from ex
|
||||
|
||||
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
|
||||
|
||||
finally:
|
||||
if writer:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
@staticmethod
|
||||
def encrypt(request: str) -> bytes:
|
||||
"""Encrypt a request for a TP-Link Smart Home Device.
|
||||
|
@ -19,7 +19,8 @@ from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||
from .exceptions import SmartDeviceException
|
||||
from .protocol import TPLinkSmartHomeProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -47,10 +48,6 @@ class WifiNetwork:
|
||||
rssi: Optional[int] = None
|
||||
|
||||
|
||||
class SmartDeviceException(Exception):
|
||||
"""Base exception for device errors."""
|
||||
|
||||
|
||||
class EmeterStatus(dict):
|
||||
"""Container for converting different representations of emeter data.
|
||||
|
||||
|
@ -3,6 +3,7 @@ import glob
|
||||
import json
|
||||
import os
|
||||
from os.path import basename
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
@ -151,3 +152,14 @@ def pytest_collection_modifyitems(config, items):
|
||||
return
|
||||
else:
|
||||
print("Running against ip %s" % config.getoption("--ip"))
|
||||
|
||||
|
||||
# allow mocks to be awaited
|
||||
# https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression/51399767#51399767
|
||||
|
||||
|
||||
async def async_magic():
|
||||
pass
|
||||
|
||||
|
||||
MagicMock.__await__ = lambda x: async_magic().__await__()
|
||||
|
@ -1,18 +1,39 @@
|
||||
import json
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
|
||||
from ..exceptions import SmartDeviceException
|
||||
from ..protocol import TPLinkSmartHomeProtocol
|
||||
|
||||
|
||||
class TestTPLinkSmartHomeProtocol(TestCase):
|
||||
def test_encrypt(self):
|
||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||
async def test_protocol_retries(mocker, retry_count):
|
||||
def aio_mock_writer(_, __):
|
||||
reader = mocker.patch("asyncio.StreamReader")
|
||||
writer = mocker.patch("asyncio.StreamWriter")
|
||||
|
||||
mocker.patch(
|
||||
"asyncio.StreamWriter.write", side_effect=Exception("dummy exception")
|
||||
)
|
||||
|
||||
return reader, writer
|
||||
|
||||
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)
|
||||
|
||||
assert conn.call_count == retry_count + 1
|
||||
|
||||
|
||||
def test_encrypt():
|
||||
d = json.dumps({"foo": 1, "bar": 2})
|
||||
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
|
||||
# encrypt adds a 4 byte header
|
||||
encrypted = encrypted[4:]
|
||||
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted))
|
||||
assert d == TPLinkSmartHomeProtocol.decrypt(encrypted)
|
||||
|
||||
def test_encrypt_unicode(self):
|
||||
|
||||
def test_encrypt_unicode():
|
||||
d = "{'snowman': '\u2603'}"
|
||||
|
||||
e = bytes(
|
||||
@ -42,9 +63,10 @@ class TestTPLinkSmartHomeProtocol(TestCase):
|
||||
# encrypt adds a 4 byte header
|
||||
encrypted = encrypted[4:]
|
||||
|
||||
self.assertEqual(e, encrypted)
|
||||
assert e == encrypted
|
||||
|
||||
def test_decrypt_unicode(self):
|
||||
|
||||
def test_decrypt_unicode():
|
||||
e = bytes(
|
||||
[
|
||||
208,
|
||||
@ -70,4 +92,4 @@ class TestTPLinkSmartHomeProtocol(TestCase):
|
||||
|
||||
d = "{'snowman': '\u2603'}"
|
||||
|
||||
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e))
|
||||
assert d == TPLinkSmartHomeProtocol.decrypt(e)
|
||||
|
Loading…
Reference in New Issue
Block a user