Keep connection open and lock to prevent duplicate requests (#213)

* Keep connection open and lock to prevent duplicate requests

* option to not update children

* tweaks

* typing

* tweaks

* run tests in the same event loop

* memorize model

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* dry

* tweaks

* warn when the event loop gets switched out from under us

* raise on unable to connect multiple times

* fix patch target

* tweaks

* isrot

* reconnect test

* prune

* fix mocking

* fix mocking

* fix test under python 3.7

* fix test under python 3.7

* less patching

* isort

* use mocker to patch

* disable on old python since mocking doesnt work

* avoid disconnect/reconnect cycles

* isort

* Fix hue validation

* Fix latitude_i/longitude_i units

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
J. Nick Koston
2021-09-24 16:25:43 -05:00
committed by GitHub
parent f1b28e79b9
commit e31cc6662c
11 changed files with 241 additions and 96 deletions

View File

@@ -10,11 +10,12 @@ which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""
import asyncio
import contextlib
import json
import logging
import struct
from pprint import pformat as pf
from typing import Dict, Union
from typing import Dict, Optional, Union
from .exceptions import SmartDeviceException
@@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol:
DEFAULT_PORT = 9999
DEFAULT_TIMEOUT = 5
@staticmethod
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
BLOCK_SIZE = 4
def __init__(self, host: str) -> None:
"""Create a protocol object."""
self.host = host
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.query_lock: Optional[asyncio.Lock] = None
self.loop: Optional[asyncio.AbstractEventLoop] = None
def _detect_event_loop_change(self) -> None:
"""Check if this object has been reused betwen event loops."""
loop = asyncio.get_running_loop()
if not self.loop:
self.loop = loop
elif self.loop != loop:
_LOGGER.warning("Detected protocol reuse between different event loop")
self._reset()
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device.
:param str host: host name or ip address of the device
@@ -38,57 +57,106 @@ class TPLinkSmartHomeProtocol:
:param retry_count: how many retries to do in case of failure
:return: response dict
"""
self._detect_event_loop_change()
if not self.query_lock:
self.query_lock = asyncio.Lock()
if isinstance(request, dict):
request = json.dumps(request)
assert isinstance(request, str)
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
writer = None
async with self.query_lock:
return await self._query(request, retry_count, timeout)
async def _connect(self, timeout: int) -> bool:
"""Try to connect or reconnect to the device."""
if self.writer:
return True
with contextlib.suppress(Exception):
self.reader = self.writer = None
task = asyncio.open_connection(
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
return True
return False
async def _execute_query(self, request: str) -> Dict:
"""Execute a query on the device and wait for the response."""
assert self.writer is not None
assert self.reader is not None
_LOGGER.debug("> (%i) %s", len(request), request)
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await self.writer.drain()
packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
length = struct.unpack(">I", packed_block_size)[0]
buffer = await self.reader.readexactly(length)
response = TPLinkSmartHomeProtocol.decrypt(buffer)
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
return json_payload
async def close(self):
"""Close the connection."""
writer = self.writer
self._reset()
if writer:
writer.close()
with contextlib.suppress(Exception):
await writer.wait_closed()
def _reset(self):
"""Clear any varibles that should not survive between loops."""
self.writer = None
self.reader = None
self.query_lock = None
self.loop = None
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device."""
for retry in range(retry_count + 1):
try:
task = asyncio.open_connection(
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
reader, writer = await asyncio.wait_for(task, timeout=timeout)
_LOGGER.debug("> (%i) %s", len(request), request)
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await writer.drain()
buffer = bytes()
# Some devices send responses with a length header of 0 and
# terminate with a zero size chunk. Others send the length and
# will hang if we attempt to read more data.
length = -1
while True:
chunk = await reader.read(4096)
if length == -1:
length = struct.unpack(">I", chunk[0:4])[0]
buffer += chunk
if (length > 0 and len(buffer) >= length + 4) or not chunk:
break
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
return json_payload
except Exception as ex:
if not await self._connect(timeout):
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
"Unable to query the device: %s" % ex
f"Unable to connect to the device: {self.host}"
)
continue
try:
assert self.reader is not None
assert self.writer is not None
return await asyncio.wait_for(
self._execute_query(request), timeout=timeout
)
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
f"Unable to query the device: {ex}"
) from ex
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
finally:
if writer:
writer.close()
await writer.wait_closed()
# make mypy happy, this should never be reached..
await self.close()
raise SmartDeviceException("Query reached somehow to unreachable")
def __del__(self):
if self.writer and self.loop and self.loop.is_running():
self.writer.close()
self._reset()
@staticmethod
def _xor_payload(unencrypted):
key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR