mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-10-14 11:28:02 +00:00
Keep connection open and lock to prevent duplicate requests (#213)
* Keep connection open and lock to prevent duplicate requests * option to not update children * tweaks * typing * tweaks * run tests in the same event loop * memorize model * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * dry * tweaks * warn when the event loop gets switched out from under us * raise on unable to connect multiple times * fix patch target * tweaks * isrot * reconnect test * prune * fix mocking * fix mocking * fix test under python 3.7 * fix test under python 3.7 * less patching * isort * use mocker to patch * disable on old python since mocking doesnt work * avoid disconnect/reconnect cycles * isort * Fix hue validation * Fix latitude_i/longitude_i units Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
146
kasa/protocol.py
146
kasa/protocol.py
@@ -10,11 +10,12 @@ which are licensed under the Apache License, Version 2.0
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import struct
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from .exceptions import SmartDeviceException
|
||||
|
||||
@@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol:
|
||||
DEFAULT_PORT = 9999
|
||||
DEFAULT_TIMEOUT = 5
|
||||
|
||||
@staticmethod
|
||||
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
BLOCK_SIZE = 4
|
||||
|
||||
def __init__(self, host: str) -> None:
|
||||
"""Create a protocol object."""
|
||||
self.host = host
|
||||
self.reader: Optional[asyncio.StreamReader] = None
|
||||
self.writer: Optional[asyncio.StreamWriter] = None
|
||||
self.query_lock: Optional[asyncio.Lock] = None
|
||||
self.loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
def _detect_event_loop_change(self) -> None:
|
||||
"""Check if this object has been reused betwen event loops."""
|
||||
loop = asyncio.get_running_loop()
|
||||
if not self.loop:
|
||||
self.loop = loop
|
||||
elif self.loop != loop:
|
||||
_LOGGER.warning("Detected protocol reuse between different event loop")
|
||||
self._reset()
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Request information from a TP-Link SmartHome Device.
|
||||
|
||||
:param str host: host name or ip address of the device
|
||||
@@ -38,57 +57,106 @@ class TPLinkSmartHomeProtocol:
|
||||
:param retry_count: how many retries to do in case of failure
|
||||
:return: response dict
|
||||
"""
|
||||
self._detect_event_loop_change()
|
||||
|
||||
if not self.query_lock:
|
||||
self.query_lock = asyncio.Lock()
|
||||
|
||||
if isinstance(request, dict):
|
||||
request = json.dumps(request)
|
||||
assert isinstance(request, str)
|
||||
|
||||
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
|
||||
writer = None
|
||||
|
||||
async with self.query_lock:
|
||||
return await self._query(request, retry_count, timeout)
|
||||
|
||||
async def _connect(self, timeout: int) -> bool:
|
||||
"""Try to connect or reconnect to the device."""
|
||||
if self.writer:
|
||||
return True
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
self.reader = self.writer = None
|
||||
task = asyncio.open_connection(
|
||||
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
|
||||
)
|
||||
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _execute_query(self, request: str) -> Dict:
|
||||
"""Execute a query on the device and wait for the response."""
|
||||
assert self.writer is not None
|
||||
assert self.reader is not None
|
||||
|
||||
_LOGGER.debug("> (%i) %s", len(request), request)
|
||||
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||
await self.writer.drain()
|
||||
|
||||
packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
|
||||
length = struct.unpack(">I", packed_block_size)[0]
|
||||
|
||||
buffer = await self.reader.readexactly(length)
|
||||
response = TPLinkSmartHomeProtocol.decrypt(buffer)
|
||||
json_payload = json.loads(response)
|
||||
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
|
||||
return json_payload
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection."""
|
||||
writer = self.writer
|
||||
self._reset()
|
||||
if writer:
|
||||
writer.close()
|
||||
with contextlib.suppress(Exception):
|
||||
await writer.wait_closed()
|
||||
|
||||
def _reset(self):
|
||||
"""Clear any varibles that should not survive between loops."""
|
||||
self.writer = None
|
||||
self.reader = None
|
||||
self.query_lock = None
|
||||
self.loop = None
|
||||
|
||||
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
|
||||
"""Try to query a device."""
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
task = asyncio.open_connection(
|
||||
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
|
||||
)
|
||||
reader, writer = await asyncio.wait_for(task, timeout=timeout)
|
||||
_LOGGER.debug("> (%i) %s", len(request), request)
|
||||
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||
await writer.drain()
|
||||
|
||||
buffer = bytes()
|
||||
# Some devices send responses with a length header of 0 and
|
||||
# terminate with a zero size chunk. Others send the length and
|
||||
# will hang if we attempt to read more data.
|
||||
length = -1
|
||||
while True:
|
||||
chunk = await reader.read(4096)
|
||||
if length == -1:
|
||||
length = struct.unpack(">I", chunk[0:4])[0]
|
||||
buffer += chunk
|
||||
if (length > 0 and len(buffer) >= length + 4) or not chunk:
|
||||
break
|
||||
|
||||
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
|
||||
json_payload = json.loads(response)
|
||||
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
|
||||
|
||||
return json_payload
|
||||
|
||||
except Exception as ex:
|
||||
if not await self._connect(timeout):
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up after %s retries", retry)
|
||||
raise SmartDeviceException(
|
||||
"Unable to query the device: %s" % ex
|
||||
f"Unable to connect to the device: {self.host}"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
assert self.reader is not None
|
||||
assert self.writer is not None
|
||||
return await asyncio.wait_for(
|
||||
self._execute_query(request), timeout=timeout
|
||||
)
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up after %s retries", retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to query the device: {ex}"
|
||||
) from ex
|
||||
|
||||
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
|
||||
|
||||
finally:
|
||||
if writer:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
await self.close()
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
def __del__(self):
|
||||
if self.writer and self.loop and self.loop.is_running():
|
||||
self.writer.close()
|
||||
self._reset()
|
||||
|
||||
@staticmethod
|
||||
def _xor_payload(unencrypted):
|
||||
key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR
|
||||
|
Reference in New Issue
Block a user