From c318303255289f6929565dcde18a6853a50cdec3 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Thu, 25 Jan 2024 17:37:19 +0000 Subject: [PATCH] Add concrete XorTransport class with full implementation (#646) * Add concrete XorTransport class * Update xortransport reset() docstring --- kasa/protocol.py | 3 +- kasa/tests/test_protocol.py | 279 ++++++++++++++++++++++++++++++------ kasa/xortransport.py | 228 +++++++++++++++++++++++++++++ 3 files changed, 464 insertions(+), 46 deletions(-) create mode 100644 kasa/xortransport.py diff --git a/kasa/protocol.py b/kasa/protocol.py index ae8eb89b..b7ef3dea 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -60,7 +60,7 @@ class BaseTransport(ABC): self._port = config.port_override or self.default_port self._credentials = config.credentials self._credentials_hash = config.credentials_hash - self._timeout = config.timeout + self._timeout = config.timeout or self.DEFAULT_TIMEOUT @property @abstractmethod @@ -124,6 +124,7 @@ class _XorTransport(BaseTransport): """ DEFAULT_PORT: int = 9999 + BLOCK_SIZE = 4 def __init__(self, *, config: DeviceConfig) -> None: super().__init__(config=config) diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index f623b597..34f2507e 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -4,6 +4,7 @@ import importlib import inspect import json import logging +import os import pkgutil import struct import sys @@ -14,6 +15,7 @@ from ..aestransport import AesTransport from ..credentials import Credentials from ..deviceconfig import DeviceConfig from ..exceptions import SmartDeviceException +from ..iotprotocol import IotProtocol from ..klaptransport import KlapTransport, KlapTransportV2 from ..protocol import ( BaseProtocol, @@ -21,10 +23,19 @@ from ..protocol import ( TPLinkSmartHomeProtocol, _XorTransport, ) +from ..xortransport import XorEncryption, XorTransport +@pytest.mark.parametrize( + "protocol_class, transport_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport), + (IotProtocol, XorTransport), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) @pytest.mark.parametrize("retry_count", [1, 3, 5]) -async def test_protocol_retries(mocker, retry_count): +async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class): def aio_mock_writer(_, __): reader = mocker.patch("asyncio.StreamReader") writer = mocker.patch("asyncio.StreamWriter") @@ -38,60 +49,100 @@ async def test_protocol_retries(mocker, retry_count): conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( + await protocol_class(transport=transport_class(config=config)).query( {}, retry_count=retry_count ) assert conn.call_count == retry_count + 1 -async def test_protocol_no_retry_on_unreachable(mocker): +@pytest.mark.parametrize( + "protocol_class, transport_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport), + (IotProtocol, XorTransport), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_no_retry_on_unreachable( + mocker, protocol_class, transport_class +): conn = mocker.patch( "asyncio.open_connection", side_effect=OSError(errno.EHOSTUNREACH, "No route to host"), ) config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( + await protocol_class(transport=transport_class(config=config)).query( {}, retry_count=5 ) assert conn.call_count == 1 -async def test_protocol_no_retry_connection_refused(mocker): +@pytest.mark.parametrize( + "protocol_class, transport_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport), + (IotProtocol, XorTransport), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_no_retry_connection_refused( + mocker, protocol_class, transport_class +): conn = mocker.patch( "asyncio.open_connection", side_effect=ConnectionRefusedError, ) config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( + await protocol_class(transport=transport_class(config=config)).query( {}, retry_count=5 ) assert conn.call_count == 1 -async def test_protocol_retry_recoverable_error(mocker): +@pytest.mark.parametrize( + "protocol_class, transport_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport), + (IotProtocol, XorTransport), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_retry_recoverable_error( + mocker, protocol_class, transport_class +): conn = mocker.patch( "asyncio.open_connection", side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"), ) config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( + await protocol_class(transport=transport_class(config=config)).query( {}, retry_count=5 ) assert conn.call_count == 6 +@pytest.mark.parametrize( + "protocol_class, transport_class, encryption_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) @pytest.mark.parametrize("retry_count", [1, 3, 5]) -async def test_protocol_reconnect(mocker, retry_count): +async def test_protocol_reconnect( + mocker, retry_count, protocol_class, transport_class, encryption_class +): remaining = retry_count - encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ - TPLinkSmartHomeProtocol.BLOCK_SIZE : + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : ] def _fail_one_less_than_retry_count(*_): @@ -102,7 +153,7 @@ async def test_protocol_reconnect(mocker, retry_count): async def _mock_read(byte_count): nonlocal encrypted - if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + if byte_count == transport_class.BLOCK_SIZE: return struct.pack(">I", len(encrypted)) if byte_count == len(encrypted): return encrypted @@ -117,16 +168,26 @@ async def test_protocol_reconnect(mocker, retry_count): return reader, writer config = DeviceConfig("127.0.0.1") - protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) + protocol = protocol_class(transport=transport_class(config=config)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}, retry_count=retry_count) assert response == {"great": "success"} -async def test_protocol_handles_cancellation_during_write(mocker): +@pytest.mark.parametrize( + "protocol_class, transport_class, encryption_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_handles_cancellation_during_write( + mocker, protocol_class, transport_class, encryption_class +): attempts = 0 - encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ - TPLinkSmartHomeProtocol.BLOCK_SIZE : + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : ] def _cancel_first_attempt(*_): @@ -137,7 +198,7 @@ async def test_protocol_handles_cancellation_during_write(mocker): async def _mock_read(byte_count): nonlocal encrypted - if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + if byte_count == transport_class.BLOCK_SIZE: return struct.pack(">I", len(encrypted)) if byte_count == len(encrypted): return encrypted @@ -152,24 +213,36 @@ async def test_protocol_handles_cancellation_during_write(mocker): return reader, writer config = DeviceConfig("127.0.0.1") - protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) - mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + protocol = protocol_class(transport=transport_class(config=config)) + conn_mock = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) with pytest.raises(asyncio.CancelledError): await protocol.query({}) - assert protocol.writer is None + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + assert writer_obj.writer is None + conn_mock.assert_awaited_once() response = await protocol.query({}) assert response == {"great": "success"} -async def test_protocol_handles_cancellation_during_connection(mocker): +@pytest.mark.parametrize( + "protocol_class, transport_class, encryption_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_handles_cancellation_during_connection( + mocker, protocol_class, transport_class, encryption_class +): attempts = 0 - encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ - TPLinkSmartHomeProtocol.BLOCK_SIZE : + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : ] async def _mock_read(byte_count): nonlocal encrypted - if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + if byte_count == transport_class.BLOCK_SIZE: return struct.pack(">I", len(encrypted)) if byte_count == len(encrypted): return encrypted @@ -187,26 +260,39 @@ async def test_protocol_handles_cancellation_during_connection(mocker): return reader, writer config = DeviceConfig("127.0.0.1") - protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) - mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + protocol = protocol_class(transport=transport_class(config=config)) + conn_mock = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) with pytest.raises(asyncio.CancelledError): await protocol.query({}) - assert protocol.writer is None + + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + assert writer_obj.writer is None + conn_mock.assert_awaited_once() response = await protocol.query({}) assert response == {"great": "success"} +@pytest.mark.parametrize( + "protocol_class, transport_class, encryption_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) @pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG]) -async def test_protocol_logging(mocker, caplog, log_level): +async def test_protocol_logging( + mocker, caplog, log_level, protocol_class, transport_class, encryption_class +): caplog.set_level(log_level) logging.getLogger("kasa").setLevel(log_level) - encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ - TPLinkSmartHomeProtocol.BLOCK_SIZE : + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : ] async def _mock_read(byte_count): nonlocal encrypted - if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + if byte_count == transport_class.BLOCK_SIZE: return struct.pack(">I", len(encrypted)) if byte_count == len(encrypted): return encrypted @@ -219,7 +305,7 @@ async def test_protocol_logging(mocker, caplog, log_level): return reader, writer config = DeviceConfig("127.0.0.1") - protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) + protocol = protocol_class(transport=transport_class(config=config)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) assert response == {"great": "success"} @@ -229,15 +315,25 @@ async def test_protocol_logging(mocker, caplog, log_level): assert "success" not in caplog.text +@pytest.mark.parametrize( + "protocol_class, transport_class, encryption_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) @pytest.mark.parametrize("custom_port", [123, None]) -async def test_protocol_custom_port(mocker, custom_port): - encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ - TPLinkSmartHomeProtocol.BLOCK_SIZE : +async def test_protocol_custom_port( + mocker, custom_port, protocol_class, transport_class, encryption_class +): + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : ] async def _mock_read(byte_count): nonlocal encrypted - if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + if byte_count == transport_class.BLOCK_SIZE: return struct.pack(">I", len(encrypted)) if byte_count == len(encrypted): return encrypted @@ -254,21 +350,33 @@ async def test_protocol_custom_port(mocker, custom_port): return reader, writer config = DeviceConfig("127.0.0.1", port_override=custom_port) - protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) + protocol = protocol_class(transport=transport_class(config=config)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) assert response == {"great": "success"} -def test_encrypt(): +@pytest.mark.parametrize( + "encrypt_class", + [TPLinkSmartHomeProtocol, XorEncryption], +) +@pytest.mark.parametrize( + "decrypt_class", + [TPLinkSmartHomeProtocol, XorEncryption], +) +def test_encrypt(encrypt_class, decrypt_class): d = json.dumps({"foo": 1, "bar": 2}) - encrypted = TPLinkSmartHomeProtocol.encrypt(d) + encrypted = encrypt_class.encrypt(d) # encrypt adds a 4 byte header encrypted = encrypted[4:] - assert d == TPLinkSmartHomeProtocol.decrypt(encrypted) + assert d == decrypt_class.decrypt(encrypted) -def test_encrypt_unicode(): +@pytest.mark.parametrize( + "encrypt_class", + [TPLinkSmartHomeProtocol, XorEncryption], +) +def test_encrypt_unicode(encrypt_class): d = "{'snowman': '\u2603'}" e = bytes( @@ -294,14 +402,18 @@ def test_encrypt_unicode(): ] ) - encrypted = TPLinkSmartHomeProtocol.encrypt(d) + encrypted = encrypt_class.encrypt(d) # encrypt adds a 4 byte header encrypted = encrypted[4:] assert e == encrypted -def test_decrypt_unicode(): +@pytest.mark.parametrize( + "decrypt_class", + [TPLinkSmartHomeProtocol, XorEncryption], +) +def test_decrypt_unicode(decrypt_class): e = bytes( [ 208, @@ -327,7 +439,7 @@ def test_decrypt_unicode(): d = "{'snowman': '\u2603'}" - assert d == TPLinkSmartHomeProtocol.decrypt(e) + assert d == decrypt_class.decrypt(e) def _get_subclasses(of_class): @@ -378,7 +490,8 @@ def test_transport_init_signature(class_name_obj): @pytest.mark.parametrize( - "transport_class", [AesTransport, KlapTransport, KlapTransportV2, _XorTransport] + "transport_class", + [AesTransport, KlapTransport, KlapTransportV2, _XorTransport, XorTransport], ) async def test_transport_credentials_hash(mocker, transport_class): host = "127.0.0.1" @@ -391,3 +504,79 @@ async def test_transport_credentials_hash(mocker, transport_class): transport = transport_class(config=config) assert transport.credentials_hash == credentials_hash + + +@pytest.mark.parametrize( + "error, retry_expectation", + [ + (ConnectionRefusedError("dummy exception"), False), + (OSError(errno.EHOSTDOWN, os.strerror(errno.EHOSTDOWN)), False), + (OSError(errno.ECONNRESET, os.strerror(errno.ECONNRESET)), True), + (Exception("dummy exception"), True), + ], + ids=("ConnectionRefusedError", "OSErrorNoRetry", "OSErrorRetry", "Exception"), +) +@pytest.mark.parametrize( + "protocol_class, transport_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport), + (IotProtocol, XorTransport), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_will_retry_on_connect( + mocker, protocol_class, transport_class, error, retry_expectation +): + retry_count = 2 + conn = mocker.patch("asyncio.open_connection", side_effect=error) + config = DeviceConfig("127.0.0.1") + with pytest.raises(SmartDeviceException): + await protocol_class(transport=transport_class(config=config)).query( + {}, retry_count=retry_count + ) + + assert conn.call_count == (retry_count + 1 if retry_expectation else 1) + + +@pytest.mark.parametrize( + "error, retry_expectation", + [ + (ConnectionRefusedError("dummy exception"), True), + (OSError(errno.EHOSTDOWN, os.strerror(errno.EHOSTDOWN)), True), + (OSError(errno.ECONNRESET, os.strerror(errno.ECONNRESET)), True), + (Exception("dummy exception"), True), + ], + ids=("ConnectionRefusedError", "OSErrorNoRetry", "OSErrorRetry", "Exception"), +) +@pytest.mark.parametrize( + "protocol_class, transport_class", + [ + (TPLinkSmartHomeProtocol, _XorTransport), + (IotProtocol, XorTransport), + ], + ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_will_retry_on_write( + mocker, protocol_class, transport_class, error, retry_expectation +): + retry_count = 2 + writer = mocker.patch("asyncio.StreamWriter") + write_mock = mocker.patch.object(writer, "write", side_effect=error) + + def aio_mock_writer(_, __): + nonlocal writer + reader = mocker.patch("asyncio.StreamReader") + + return reader, writer + + conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + write_mock = mocker.patch("asyncio.StreamWriter.write", side_effect=error) + config = DeviceConfig("127.0.0.1") + with pytest.raises(SmartDeviceException): + await protocol_class(transport=transport_class(config=config)).query( + {}, retry_count=retry_count + ) + + expected_call_count = retry_count + 1 if retry_expectation else 1 + assert conn.call_count == expected_call_count + assert write_mock.call_count == expected_call_count diff --git a/kasa/xortransport.py b/kasa/xortransport.py new file mode 100644 index 00000000..bed62ea8 --- /dev/null +++ b/kasa/xortransport.py @@ -0,0 +1,228 @@ +"""Module for the XorTransport.""" +import asyncio +import contextlib +import errno +import logging +import socket +import struct +from pprint import pformat as pf +from typing import Dict, Generator, Optional + +# When support for cpython older than 3.11 is dropped +# async_timeout can be replaced with asyncio.timeout +from async_timeout import timeout as asyncio_timeout + +from .deviceconfig import DeviceConfig +from .exceptions import RetryableException, SmartDeviceException +from .json import loads as json_loads +from .protocol import BaseTransport + +_LOGGER = logging.getLogger(__name__) +_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED} +_UNSIGNED_INT_NETWORK_ORDER = struct.Struct(">I") + + +class XorTransport(BaseTransport): + """Implementation of the Xor encryption transport. + + WIP, currently only to ensure consistent __init__ method signatures + for protocol classes. Will eventually incorporate the logic from + TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol + class. + """ + + DEFAULT_PORT: int = 9999 + BLOCK_SIZE = 4 + + def __init__(self, *, config: DeviceConfig) -> None: + super().__init__(config=config) + self.reader: Optional[asyncio.StreamReader] = None + self.writer: Optional[asyncio.StreamWriter] = None + self.query_lock = asyncio.Lock() + self.loop: Optional[asyncio.AbstractEventLoop] = None + + @property + def default_port(self): + """Default port for the transport.""" + return self.DEFAULT_PORT + + @property + def credentials_hash(self) -> str: + """The hashed credentials used by the transport.""" + return "" + + async def _connect(self, timeout: int) -> None: + """Try to connect or reconnect to the device.""" + if self.writer: + return + self.reader = self.writer = None + + task = asyncio.open_connection(self._host, self._port) + async with asyncio_timeout(timeout): + self.reader, self.writer = await task + sock: socket.socket = self.writer.get_extra_info("socket") + # Ensure our packets get sent without delay as we do all + # our writes in a single go and we do not want any buffering + # which would needlessly delay the request or risk overloading + # the buffer on the device + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + async def _execute_send(self, request: str) -> Dict: + """Execute a query on the device and wait for the response.""" + assert self.writer is not None # noqa: S101 + assert self.reader is not None # noqa: S101 + debug_log = _LOGGER.isEnabledFor(logging.DEBUG) + if debug_log: + _LOGGER.debug("%s >> %s", self._host, request) + self.writer.write(XorEncryption.encrypt(request)) + await self.writer.drain() + + packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE) + length = _UNSIGNED_INT_NETWORK_ORDER.unpack(packed_block_size)[0] + + buffer = await self.reader.readexactly(length) + response = XorEncryption.decrypt(buffer) + json_payload = json_loads(response) + if debug_log: + _LOGGER.debug("%s << %s", self._host, pf(json_payload)) + + return json_payload + + async def close(self) -> None: + """Close the connection.""" + writer = self.writer + self.close_without_wait() + if writer: + with contextlib.suppress(Exception): + await writer.wait_closed() + + def close_without_wait(self) -> None: + """Close the connection without waiting for the connection to close.""" + writer = self.writer + self.reader = self.writer = None + if writer: + writer.close() + + async def reset(self) -> None: + """Reset the transport. + + The transport cannot be reset so we must close instead. + """ + await self.close() + + async def send(self, request: str) -> Dict: + """Send a message to the device and return a response.""" + # + # Most of the time we will already be connected if the device is online + # and the connect call will do nothing and return right away + # + # However, if we get an unrecoverable error (_NO_RETRY_ERRORS and + # ConnectionRefusedError) we do not want to keep trying since many + # connection open/close operations in the same time frame can block + # the event loop. + # This is especially import when there are multiple tplink devices being polled. + try: + await self._connect(self._timeout) + except ConnectionRefusedError as ex: + await self.reset() + raise SmartDeviceException( + f"Unable to connect to the device: {self._host}:{self._port}: {ex}" + ) from ex + except OSError as ex: + await self.reset() + if ex.errno in _NO_RETRY_ERRORS: + raise SmartDeviceException( + f"Unable to connect to the device:" + f" {self._host}:{self._port}: {ex}" + ) from ex + else: + raise RetryableException( + f"Unable to connect to the device:" + f" {self._host}:{self._port}: {ex}" + ) from ex + except Exception as ex: + await self.reset() + raise RetryableException( + f"Unable to connect to the device:" f" {self._host}:{self._port}: {ex}" + ) from ex + except BaseException: + # Likely something cancelled the task so we need to close the connection + # as we are not in an indeterminate state + self.close_without_wait() + raise + + try: + assert self.reader is not None # noqa: S101 + assert self.writer is not None # noqa: S101 + async with asyncio_timeout(self._timeout): + return await self._execute_send(request) + except Exception as ex: + await self.reset() + raise RetryableException( + f"Unable to query the device {self._host}:{self._port}: {ex}" + ) from ex + except BaseException: + # Likely something cancelled the task so we need to close the connection + # as we are not in an indeterminate state + self.close_without_wait() + raise + + def __del__(self) -> None: + if self.writer and self.loop and self.loop.is_running(): + # Since __del__ will be called when python does + # garbage collection is can happen in the event loop thread + # or in another thread so we need to make sure the call to + # close is called safely with call_soon_threadsafe + self.loop.call_soon_threadsafe(self.writer.close) + + +class XorEncryption: + """XorEncryption class.""" + + INITIALIZATION_VECTOR = 171 + + @staticmethod + def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]: + key = XorEncryption.INITIALIZATION_VECTOR + for unencryptedbyte in unencrypted: + key = key ^ unencryptedbyte + yield key + + @staticmethod + def encrypt(request: str) -> bytes: + """Encrypt a request for a TP-Link Smart Home Device. + + :param request: plaintext request data + :return: ciphertext to be send over wire, in bytes + """ + plainbytes = request.encode() + return _UNSIGNED_INT_NETWORK_ORDER.pack(len(plainbytes)) + bytes( + XorEncryption._xor_payload(plainbytes) + ) + + @staticmethod + def _xor_encrypted_payload(ciphertext: bytes) -> Generator[int, None, None]: + key = XorEncryption.INITIALIZATION_VECTOR + for cipherbyte in ciphertext: + plainbyte = key ^ cipherbyte + key = cipherbyte + yield plainbyte + + @staticmethod + def decrypt(ciphertext: bytes) -> str: + """Decrypt a response of a TP-Link Smart Home Device. + + :param ciphertext: encrypted response data + :return: plaintext response + """ + return bytes(XorEncryption._xor_encrypted_payload(ciphertext)).decode() + + +# Try to load the kasa_crypt module and if it is available +try: + from kasa_crypt import decrypt, encrypt + + XorEncryption.decrypt = decrypt # type: ignore[method-assign] + XorEncryption.encrypt = encrypt # type: ignore[method-assign] +except ImportError: + pass