mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Add retries to protocol queries (#65)
* Add retries to query(), defaults to 3 + add tests * Catch also json decoding errors for retries * add missing exceptions file, fix old protocol tests
This commit is contained in:
parent
644a10a0d1
commit
9dc0cbaece
@ -13,9 +13,10 @@ to be handled by the user of the library.
|
|||||||
"""
|
"""
|
||||||
from importlib_metadata import version # type: ignore
|
from importlib_metadata import version # type: ignore
|
||||||
from kasa.discover import Discover
|
from kasa.discover import Discover
|
||||||
|
from kasa.exceptions import SmartDeviceException
|
||||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||||
from kasa.smartbulb import SmartBulb
|
from kasa.smartbulb import SmartBulb
|
||||||
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice, SmartDeviceException
|
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice
|
||||||
from kasa.smartdimmer import SmartDimmer
|
from kasa.smartdimmer import SmartDimmer
|
||||||
from kasa.smartplug import SmartPlug
|
from kasa.smartplug import SmartPlug
|
||||||
from kasa.smartstrip import SmartStrip
|
from kasa.smartstrip import SmartStrip
|
||||||
|
5
kasa/exceptions.py
Normal file
5
kasa/exceptions.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
"""python-kasa exceptions."""
|
||||||
|
|
||||||
|
|
||||||
|
class SmartDeviceException(Exception):
|
||||||
|
"""Base exception for device errors."""
|
@ -16,6 +16,8 @@ import struct
|
|||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
from .exceptions import SmartDeviceException
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -27,12 +29,13 @@ class TPLinkSmartHomeProtocol:
|
|||||||
DEFAULT_TIMEOUT = 5
|
DEFAULT_TIMEOUT = 5
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def query(host: str, request: Union[str, Dict]) -> Dict:
|
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
"""Request information from a TP-Link SmartHome Device.
|
"""Request information from a TP-Link SmartHome Device.
|
||||||
|
|
||||||
:param str host: host name or ip address of the device
|
:param str host: host name or ip address of the device
|
||||||
:param request: command to send to the device (can be either dict or
|
:param request: command to send to the device (can be either dict or
|
||||||
json string)
|
json string)
|
||||||
|
:param retry_count: how many retries to do in case of failure
|
||||||
:return: response dict
|
:return: response dict
|
||||||
"""
|
"""
|
||||||
if isinstance(request, dict):
|
if isinstance(request, dict):
|
||||||
@ -40,35 +43,51 @@ class TPLinkSmartHomeProtocol:
|
|||||||
|
|
||||||
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
|
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
|
||||||
writer = None
|
writer = None
|
||||||
try:
|
for retry in range(retry_count + 1):
|
||||||
task = asyncio.open_connection(host, TPLinkSmartHomeProtocol.DEFAULT_PORT)
|
try:
|
||||||
reader, writer = await asyncio.wait_for(task, timeout=timeout)
|
task = asyncio.open_connection(
|
||||||
_LOGGER.debug("> (%i) %s", len(request), request)
|
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
|
||||||
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
)
|
||||||
await writer.drain()
|
reader, writer = await asyncio.wait_for(task, timeout=timeout)
|
||||||
|
_LOGGER.debug("> (%i) %s", len(request), request)
|
||||||
|
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||||
|
await writer.drain()
|
||||||
|
|
||||||
buffer = bytes()
|
buffer = bytes()
|
||||||
# Some devices send responses with a length header of 0 and
|
# Some devices send responses with a length header of 0 and
|
||||||
# terminate with a zero size chunk. Others send the length and
|
# terminate with a zero size chunk. Others send the length and
|
||||||
# will hang if we attempt to read more data.
|
# will hang if we attempt to read more data.
|
||||||
length = -1
|
length = -1
|
||||||
while True:
|
while True:
|
||||||
chunk = await reader.read(4096)
|
chunk = await reader.read(4096)
|
||||||
if length == -1:
|
if length == -1:
|
||||||
length = struct.unpack(">I", chunk[0:4])[0]
|
length = struct.unpack(">I", chunk[0:4])[0]
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
if (length > 0 and len(buffer) >= length + 4) or not chunk:
|
if (length > 0 and len(buffer) >= length + 4) or not chunk:
|
||||||
break
|
break
|
||||||
finally:
|
|
||||||
if writer:
|
|
||||||
writer.close()
|
|
||||||
await writer.wait_closed()
|
|
||||||
|
|
||||||
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
|
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
|
||||||
json_payload = json.loads(response)
|
json_payload = json.loads(response)
|
||||||
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
|
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
|
||||||
|
|
||||||
return json_payload
|
return json_payload
|
||||||
|
|
||||||
|
except Exception as ex:
|
||||||
|
if retry >= retry_count:
|
||||||
|
_LOGGER.debug("Giving up after %s retries", retry)
|
||||||
|
raise SmartDeviceException(
|
||||||
|
"Unable to query the device: %s" % ex
|
||||||
|
) from ex
|
||||||
|
|
||||||
|
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if writer:
|
||||||
|
writer.close()
|
||||||
|
await writer.wait_closed()
|
||||||
|
|
||||||
|
# make mypy happy, this should never be reached..
|
||||||
|
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def encrypt(request: str) -> bytes:
|
def encrypt(request: str) -> bytes:
|
||||||
|
@ -19,7 +19,8 @@ from datetime import datetime, timedelta
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
from .exceptions import SmartDeviceException
|
||||||
|
from .protocol import TPLinkSmartHomeProtocol
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -47,10 +48,6 @@ class WifiNetwork:
|
|||||||
rssi: Optional[int] = None
|
rssi: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class SmartDeviceException(Exception):
|
|
||||||
"""Base exception for device errors."""
|
|
||||||
|
|
||||||
|
|
||||||
class EmeterStatus(dict):
|
class EmeterStatus(dict):
|
||||||
"""Container for converting different representations of emeter data.
|
"""Container for converting different representations of emeter data.
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ import glob
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from os.path import basename
|
from os.path import basename
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
|
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
|
||||||
|
|
||||||
@ -151,3 +152,14 @@ def pytest_collection_modifyitems(config, items):
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
print("Running against ip %s" % config.getoption("--ip"))
|
print("Running against ip %s" % config.getoption("--ip"))
|
||||||
|
|
||||||
|
|
||||||
|
# allow mocks to be awaited
|
||||||
|
# https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression/51399767#51399767
|
||||||
|
|
||||||
|
|
||||||
|
async def async_magic():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
MagicMock.__await__ = lambda x: async_magic().__await__()
|
||||||
|
@ -1,73 +1,95 @@
|
|||||||
import json
|
import json
|
||||||
from unittest import TestCase
|
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ..exceptions import SmartDeviceException
|
||||||
from ..protocol import TPLinkSmartHomeProtocol
|
from ..protocol import TPLinkSmartHomeProtocol
|
||||||
|
|
||||||
|
|
||||||
class TestTPLinkSmartHomeProtocol(TestCase):
|
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||||
def test_encrypt(self):
|
async def test_protocol_retries(mocker, retry_count):
|
||||||
d = json.dumps({"foo": 1, "bar": 2})
|
def aio_mock_writer(_, __):
|
||||||
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
|
reader = mocker.patch("asyncio.StreamReader")
|
||||||
# encrypt adds a 4 byte header
|
writer = mocker.patch("asyncio.StreamWriter")
|
||||||
encrypted = encrypted[4:]
|
|
||||||
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted))
|
|
||||||
|
|
||||||
def test_encrypt_unicode(self):
|
mocker.patch(
|
||||||
d = "{'snowman': '\u2603'}"
|
"asyncio.StreamWriter.write", side_effect=Exception("dummy exception")
|
||||||
|
|
||||||
e = bytes(
|
|
||||||
[
|
|
||||||
208,
|
|
||||||
247,
|
|
||||||
132,
|
|
||||||
234,
|
|
||||||
133,
|
|
||||||
242,
|
|
||||||
159,
|
|
||||||
254,
|
|
||||||
144,
|
|
||||||
183,
|
|
||||||
141,
|
|
||||||
173,
|
|
||||||
138,
|
|
||||||
104,
|
|
||||||
240,
|
|
||||||
115,
|
|
||||||
84,
|
|
||||||
41,
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
|
return reader, writer
|
||||||
# encrypt adds a 4 byte header
|
|
||||||
encrypted = encrypted[4:]
|
|
||||||
|
|
||||||
self.assertEqual(e, encrypted)
|
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||||
|
with pytest.raises(SmartDeviceException):
|
||||||
|
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)
|
||||||
|
|
||||||
def test_decrypt_unicode(self):
|
assert conn.call_count == retry_count + 1
|
||||||
e = bytes(
|
|
||||||
[
|
|
||||||
208,
|
|
||||||
247,
|
|
||||||
132,
|
|
||||||
234,
|
|
||||||
133,
|
|
||||||
242,
|
|
||||||
159,
|
|
||||||
254,
|
|
||||||
144,
|
|
||||||
183,
|
|
||||||
141,
|
|
||||||
173,
|
|
||||||
138,
|
|
||||||
104,
|
|
||||||
240,
|
|
||||||
115,
|
|
||||||
84,
|
|
||||||
41,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
d = "{'snowman': '\u2603'}"
|
|
||||||
|
|
||||||
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e))
|
def test_encrypt():
|
||||||
|
d = json.dumps({"foo": 1, "bar": 2})
|
||||||
|
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
|
||||||
|
# encrypt adds a 4 byte header
|
||||||
|
encrypted = encrypted[4:]
|
||||||
|
assert d == TPLinkSmartHomeProtocol.decrypt(encrypted)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encrypt_unicode():
|
||||||
|
d = "{'snowman': '\u2603'}"
|
||||||
|
|
||||||
|
e = bytes(
|
||||||
|
[
|
||||||
|
208,
|
||||||
|
247,
|
||||||
|
132,
|
||||||
|
234,
|
||||||
|
133,
|
||||||
|
242,
|
||||||
|
159,
|
||||||
|
254,
|
||||||
|
144,
|
||||||
|
183,
|
||||||
|
141,
|
||||||
|
173,
|
||||||
|
138,
|
||||||
|
104,
|
||||||
|
240,
|
||||||
|
115,
|
||||||
|
84,
|
||||||
|
41,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
|
||||||
|
# encrypt adds a 4 byte header
|
||||||
|
encrypted = encrypted[4:]
|
||||||
|
|
||||||
|
assert e == encrypted
|
||||||
|
|
||||||
|
|
||||||
|
def test_decrypt_unicode():
|
||||||
|
e = bytes(
|
||||||
|
[
|
||||||
|
208,
|
||||||
|
247,
|
||||||
|
132,
|
||||||
|
234,
|
||||||
|
133,
|
||||||
|
242,
|
||||||
|
159,
|
||||||
|
254,
|
||||||
|
144,
|
||||||
|
183,
|
||||||
|
141,
|
||||||
|
173,
|
||||||
|
138,
|
||||||
|
104,
|
||||||
|
240,
|
||||||
|
115,
|
||||||
|
84,
|
||||||
|
41,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
d = "{'snowman': '\u2603'}"
|
||||||
|
|
||||||
|
assert d == TPLinkSmartHomeProtocol.decrypt(e)
|
||||||
|
Loading…
Reference in New Issue
Block a user