python-kasa/kasa/protocol.py
J. Nick Koston afd54d11d3
Add optional kasa-crypt dependency for speedups (#464)
If installed, use the optimized protocol encryption procedures implemented as a C extension by kasa-crypt (https://pypi.org/project/kasa-crypt/).
2023-06-30 02:43:01 +02:00

232 lines
8.5 KiB
Python
Executable File

"""Implementation of the TP-Link Smart Home Protocol.
Encryption/Decryption methods based on the works of
Lubomir Stroetmann and Tobias Esser
https://www.softscheck.com/en/reverse-engineering-tp-link-hs110/
https://github.com/softScheck/tplink-smartplug/
which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""
import asyncio
import contextlib
import errno
import logging
import struct
from pprint import pformat as pf
from typing import Dict, Generator, Optional, Union
from .exceptions import SmartDeviceException
from .json import dumps as json_dumps
from .json import loads as json_loads
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# errno values for which a failed connection attempt is reported immediately
# instead of being retried: host down, host unreachable, connection refused.
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
class TPLinkSmartHomeProtocol:
    """Implementation of the TP-Link Smart Home protocol.

    A request is a JSON string that is XOR-obfuscated and sent over TCP,
    prefixed with its length as a 4-byte big-endian unsigned integer.  The
    response uses the same framing.  One connection is kept open and reused
    between queries; queries are serialized with an :class:`asyncio.Lock`.
    """

    # Initial XOR key used for the first byte of every payload.
    INITIALIZATION_VECTOR = 171
    # TCP port used when no explicit port is given.
    DEFAULT_PORT = 9999
    # Timeout in seconds applied to connecting and to each query attempt.
    DEFAULT_TIMEOUT = 5
    # Size in bytes of the length header that precedes every payload.
    BLOCK_SIZE = 4

    def __init__(
        self,
        host: str,
        *,
        port: Optional[int] = None,
        timeout: Optional[int] = None,
    ) -> None:
        """Create a protocol object.

        :param host: host name or ip address of the device
        :param port: TCP port to connect to, defaults to ``DEFAULT_PORT``
        :param timeout: timeout in seconds for connecting and for each query
            attempt, defaults to ``DEFAULT_TIMEOUT``
        """
        self.host = host
        self.port = TPLinkSmartHomeProtocol.DEFAULT_PORT if port is None else port
        self.timeout = (
            TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT if timeout is None else timeout
        )
        self.reader: Optional[asyncio.StreamReader] = None
        self.writer: Optional[asyncio.StreamWriter] = None
        self.query_lock: Optional[asyncio.Lock] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None

    def _detect_event_loop_change(self) -> None:
        """Check if this object has been reused between event loops."""
        loop = asyncio.get_running_loop()
        if not self.loop:
            self.loop = loop
        elif self.loop != loop:
            _LOGGER.warning("Detected protocol reuse between different event loop")
            # The streams and the lock belong to the previous loop; drop them
            # so they get re-created on the current one.
            self._reset()

    async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
        """Request information from a TP-Link SmartHome Device.

        :param request: command to send to the device (can be either dict or
            json string)
        :param retry_count: how many retries to do in case of failure
        :return: response dict
        :raises SmartDeviceException: if connecting to or querying the device
            fails and no retries are left
        """
        self._detect_event_loop_change()

        # The lock is created lazily, inside a running loop, rather than in
        # __init__.
        if not self.query_lock:
            self.query_lock = asyncio.Lock()

        if isinstance(request, dict):
            request = json_dumps(request)
            assert isinstance(request, str)

        async with self.query_lock:
            return await self._query(request, retry_count, self.timeout)

    async def _connect(self, timeout: int) -> None:
        """Try to connect or reconnect to the device.

        Does nothing if a writer is already present.

        :param timeout: seconds to wait for the connection to be established
        """
        if self.writer:
            return
        self.reader = self.writer = None
        task = asyncio.open_connection(self.host, self.port)
        self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)

    async def _execute_query(self, request: str) -> Dict:
        """Execute a query on the device and wait for the response.

        :param request: plaintext JSON request
        :return: decoded JSON response
        """
        assert self.writer is not None
        assert self.reader is not None
        # Checked once so pformat is only evaluated when debug logging is on.
        debug_log = _LOGGER.isEnabledFor(logging.DEBUG)

        if debug_log:
            _LOGGER.debug("%s >> %s", self.host, request)
        self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
        await self.writer.drain()

        # The response starts with its payload length as a big-endian uint32.
        packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
        length = struct.unpack(">I", packed_block_size)[0]

        buffer = await self.reader.readexactly(length)
        response = TPLinkSmartHomeProtocol.decrypt(buffer)
        json_payload = json_loads(response)
        if debug_log:
            _LOGGER.debug("%s << %s", self.host, pf(json_payload))

        return json_payload

    async def close(self) -> None:
        """Close the connection."""
        writer = self.writer
        # Clear the references first so the next query reconnects even if
        # closing the old writer fails.
        self.reader = self.writer = None
        if writer:
            writer.close()
            with contextlib.suppress(Exception):
                await writer.wait_closed()

    def _reset(self) -> None:
        """Clear any variables that should not survive between loops."""
        self.reader = self.writer = self.loop = self.query_lock = None

    async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
        """Try to query a device.

        :param request: plaintext JSON request
        :param retry_count: how many retries to do in case of failure
        :param timeout: seconds allowed for connecting and for each attempt
        :return: decoded JSON response
        :raises SmartDeviceException: if the device cannot be reached or
            queried; the triggering exception is attached as ``__cause__``
        """
        #
        # Most of the time we will already be connected if the device is online
        # and the connect call will do nothing and return right away
        #
        # However, if we get an unrecoverable error (_NO_RETRY_ERRORS and
        # ConnectionRefusedError) we do not want to keep trying since many
        # connection open/close operations in the same time frame can block
        # the event loop. This is especially important when there are
        # multiple tplink devices being polled.
        #
        for retry in range(retry_count + 1):
            try:
                await self._connect(timeout)
            except ConnectionRefusedError as ex:
                await self.close()
                raise SmartDeviceException(
                    f"Unable to connect to the device: {self.host}: {ex}"
                ) from ex
            except OSError as ex:
                await self.close()
                if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
                    raise SmartDeviceException(
                        f"Unable to connect to the device: {self.host}: {ex}"
                    ) from ex
                continue
            except Exception as ex:
                await self.close()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
                    raise SmartDeviceException(
                        f"Unable to connect to the device: {self.host}: {ex}"
                    ) from ex
                continue

            try:
                assert self.reader is not None
                assert self.writer is not None
                return await asyncio.wait_for(
                    self._execute_query(request), timeout=timeout
                )
            except Exception as ex:
                # Drop the connection so the next attempt starts from a fresh
                # one.
                await self.close()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
                    raise SmartDeviceException(
                        f"Unable to query the device {self.host}: {ex}"
                    ) from ex

                _LOGGER.debug(
                    "Unable to query the device %s, retrying: %s", self.host, ex
                )

        # make mypy happy, this should never be reached..
        await self.close()
        raise SmartDeviceException("Query reached somehow to unreachable")

    def __del__(self) -> None:
        """Schedule closing of the writer when the object is collected."""
        if self.writer and self.loop and self.loop.is_running():
            # Since __del__ will be called when python does
            # garbage collection it can happen in the event loop thread
            # or in another thread so we need to make sure the call to
            # close is called safely with call_soon_threadsafe
            self.loop.call_soon_threadsafe(self.writer.close)
        self._reset()

    @staticmethod
    def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]:
        """Yield the obfuscated bytes for a plaintext payload.

        Each output byte is the plaintext byte XORed with the previous output
        byte (``INITIALIZATION_VECTOR`` for the first one).
        """
        key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR
        for unencryptedbyte in unencrypted:
            key = key ^ unencryptedbyte
            yield key

    @staticmethod
    def encrypt(request: str) -> bytes:
        """Encrypt a request for a TP-Link Smart Home Device.

        :param request: plaintext request data
        :return: ciphertext to be sent over the wire, in bytes, prefixed with
            the payload length as a 4-byte big-endian unsigned integer
        """
        plainbytes = request.encode()
        return struct.pack(">I", len(plainbytes)) + bytes(
            TPLinkSmartHomeProtocol._xor_payload(plainbytes)
        )

    @staticmethod
    def _xor_encrypted_payload(ciphertext: bytes) -> Generator[int, None, None]:
        """Yield the plaintext bytes for an obfuscated payload.

        Each output byte is the cipher byte XORed with the previous cipher
        byte (``INITIALIZATION_VECTOR`` for the first one).
        """
        key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR
        for cipherbyte in ciphertext:
            plainbyte = key ^ cipherbyte
            key = cipherbyte
            yield plainbyte

    @staticmethod
    def decrypt(ciphertext: bytes) -> str:
        """Decrypt a response of a TP-Link Smart Home Device.

        :param ciphertext: encrypted response data, without the length header
        :return: plaintext response
        """
        return bytes(
            TPLinkSmartHomeProtocol._xor_encrypted_payload(ciphertext)
        ).decode()
# Swap in the kasa_crypt implementations of encrypt/decrypt when that optional
# package is installed; when it cannot be imported the pure-Python methods on
# TPLinkSmartHomeProtocol stay in place.
with contextlib.suppress(ImportError):
    from kasa_crypt import decrypt, encrypt

    TPLinkSmartHomeProtocol.decrypt = decrypt  # type: ignore[method-assign]
    TPLinkSmartHomeProtocol.encrypt = encrypt  # type: ignore[method-assign]