Add concrete XorTransport class with full implementation (#646)

* Add concrete XorTransport class

* Update xortransport reset() docstring
This commit is contained in:
Steven B 2024-01-25 17:37:19 +00:00 committed by GitHub
parent c01c3c679c
commit c318303255
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 464 additions and 46 deletions

View File

@ -60,7 +60,7 @@ class BaseTransport(ABC):
self._port = config.port_override or self.default_port self._port = config.port_override or self.default_port
self._credentials = config.credentials self._credentials = config.credentials
self._credentials_hash = config.credentials_hash self._credentials_hash = config.credentials_hash
self._timeout = config.timeout self._timeout = config.timeout or self.DEFAULT_TIMEOUT
@property @property
@abstractmethod @abstractmethod
@ -124,6 +124,7 @@ class _XorTransport(BaseTransport):
""" """
DEFAULT_PORT: int = 9999 DEFAULT_PORT: int = 9999
BLOCK_SIZE = 4
def __init__(self, *, config: DeviceConfig) -> None: def __init__(self, *, config: DeviceConfig) -> None:
super().__init__(config=config) super().__init__(config=config)

View File

@ -4,6 +4,7 @@ import importlib
import inspect import inspect
import json import json
import logging import logging
import os
import pkgutil import pkgutil
import struct import struct
import sys import sys
@ -14,6 +15,7 @@ from ..aestransport import AesTransport
from ..credentials import Credentials from ..credentials import Credentials
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
from ..exceptions import SmartDeviceException from ..exceptions import SmartDeviceException
from ..iotprotocol import IotProtocol
from ..klaptransport import KlapTransport, KlapTransportV2 from ..klaptransport import KlapTransport, KlapTransportV2
from ..protocol import ( from ..protocol import (
BaseProtocol, BaseProtocol,
@ -21,10 +23,19 @@ from ..protocol import (
TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol,
_XorTransport, _XorTransport,
) )
from ..xortransport import XorEncryption, XorTransport
@pytest.mark.parametrize(
"protocol_class, transport_class",
[
(TPLinkSmartHomeProtocol, _XorTransport),
(IotProtocol, XorTransport),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
@pytest.mark.parametrize("retry_count", [1, 3, 5]) @pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries(mocker, retry_count): async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class):
def aio_mock_writer(_, __): def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader") reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter") writer = mocker.patch("asyncio.StreamWriter")
@ -38,60 +49,100 @@ async def test_protocol_retries(mocker, retry_count):
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
config = DeviceConfig("127.0.0.1") config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( await protocol_class(transport=transport_class(config=config)).query(
{}, retry_count=retry_count {}, retry_count=retry_count
) )
assert conn.call_count == retry_count + 1 assert conn.call_count == retry_count + 1
async def test_protocol_no_retry_on_unreachable(mocker): @pytest.mark.parametrize(
"protocol_class, transport_class",
[
(TPLinkSmartHomeProtocol, _XorTransport),
(IotProtocol, XorTransport),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_no_retry_on_unreachable(
mocker, protocol_class, transport_class
):
conn = mocker.patch( conn = mocker.patch(
"asyncio.open_connection", "asyncio.open_connection",
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"), side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
) )
config = DeviceConfig("127.0.0.1") config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( await protocol_class(transport=transport_class(config=config)).query(
{}, retry_count=5 {}, retry_count=5
) )
assert conn.call_count == 1 assert conn.call_count == 1
async def test_protocol_no_retry_connection_refused(mocker): @pytest.mark.parametrize(
"protocol_class, transport_class",
[
(TPLinkSmartHomeProtocol, _XorTransport),
(IotProtocol, XorTransport),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_no_retry_connection_refused(
mocker, protocol_class, transport_class
):
conn = mocker.patch( conn = mocker.patch(
"asyncio.open_connection", "asyncio.open_connection",
side_effect=ConnectionRefusedError, side_effect=ConnectionRefusedError,
) )
config = DeviceConfig("127.0.0.1") config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( await protocol_class(transport=transport_class(config=config)).query(
{}, retry_count=5 {}, retry_count=5
) )
assert conn.call_count == 1 assert conn.call_count == 1
async def test_protocol_retry_recoverable_error(mocker): @pytest.mark.parametrize(
"protocol_class, transport_class",
[
(TPLinkSmartHomeProtocol, _XorTransport),
(IotProtocol, XorTransport),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_retry_recoverable_error(
mocker, protocol_class, transport_class
):
conn = mocker.patch( conn = mocker.patch(
"asyncio.open_connection", "asyncio.open_connection",
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"), side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
) )
config = DeviceConfig("127.0.0.1") config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( await protocol_class(transport=transport_class(config=config)).query(
{}, retry_count=5 {}, retry_count=5
) )
assert conn.call_count == 6 assert conn.call_count == 6
@pytest.mark.parametrize(
"protocol_class, transport_class, encryption_class",
[
(TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
@pytest.mark.parametrize("retry_count", [1, 3, 5]) @pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_reconnect(mocker, retry_count): async def test_protocol_reconnect(
mocker, retry_count, protocol_class, transport_class, encryption_class
):
remaining = retry_count remaining = retry_count
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ encrypted = encryption_class.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE : transport_class.BLOCK_SIZE :
] ]
def _fail_one_less_than_retry_count(*_): def _fail_one_less_than_retry_count(*_):
@ -102,7 +153,7 @@ async def test_protocol_reconnect(mocker, retry_count):
async def _mock_read(byte_count): async def _mock_read(byte_count):
nonlocal encrypted nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted)) return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted): if byte_count == len(encrypted):
return encrypted return encrypted
@ -117,16 +168,26 @@ async def test_protocol_reconnect(mocker, retry_count):
return reader, writer return reader, writer
config = DeviceConfig("127.0.0.1") config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}, retry_count=retry_count) response = await protocol.query({}, retry_count=retry_count)
assert response == {"great": "success"} assert response == {"great": "success"}
async def test_protocol_handles_cancellation_during_write(mocker): @pytest.mark.parametrize(
"protocol_class, transport_class, encryption_class",
[
(TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_cancellation_during_write(
mocker, protocol_class, transport_class, encryption_class
):
attempts = 0 attempts = 0
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ encrypted = encryption_class.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE : transport_class.BLOCK_SIZE :
] ]
def _cancel_first_attempt(*_): def _cancel_first_attempt(*_):
@ -137,7 +198,7 @@ async def test_protocol_handles_cancellation_during_write(mocker):
async def _mock_read(byte_count): async def _mock_read(byte_count):
nonlocal encrypted nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted)) return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted): if byte_count == len(encrypted):
return encrypted return encrypted
@ -152,24 +213,36 @@ async def test_protocol_handles_cancellation_during_write(mocker):
return reader, writer return reader, writer
config = DeviceConfig("127.0.0.1") config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) conn_mock = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await protocol.query({}) await protocol.query({})
assert protocol.writer is None writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
conn_mock.assert_awaited_once()
response = await protocol.query({}) response = await protocol.query({})
assert response == {"great": "success"} assert response == {"great": "success"}
async def test_protocol_handles_cancellation_during_connection(mocker): @pytest.mark.parametrize(
"protocol_class, transport_class, encryption_class",
[
(TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_cancellation_during_connection(
mocker, protocol_class, transport_class, encryption_class
):
attempts = 0 attempts = 0
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ encrypted = encryption_class.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE : transport_class.BLOCK_SIZE :
] ]
async def _mock_read(byte_count): async def _mock_read(byte_count):
nonlocal encrypted nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted)) return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted): if byte_count == len(encrypted):
return encrypted return encrypted
@ -187,26 +260,39 @@ async def test_protocol_handles_cancellation_during_connection(mocker):
return reader, writer return reader, writer
config = DeviceConfig("127.0.0.1") config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) conn_mock = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await protocol.query({}) await protocol.query({})
assert protocol.writer is None
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
conn_mock.assert_awaited_once()
response = await protocol.query({}) response = await protocol.query({})
assert response == {"great": "success"} assert response == {"great": "success"}
@pytest.mark.parametrize(
"protocol_class, transport_class, encryption_class",
[
(TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG]) @pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
async def test_protocol_logging(mocker, caplog, log_level): async def test_protocol_logging(
mocker, caplog, log_level, protocol_class, transport_class, encryption_class
):
caplog.set_level(log_level) caplog.set_level(log_level)
logging.getLogger("kasa").setLevel(log_level) logging.getLogger("kasa").setLevel(log_level)
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ encrypted = encryption_class.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE : transport_class.BLOCK_SIZE :
] ]
async def _mock_read(byte_count): async def _mock_read(byte_count):
nonlocal encrypted nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted)) return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted): if byte_count == len(encrypted):
return encrypted return encrypted
@ -219,7 +305,7 @@ async def test_protocol_logging(mocker, caplog, log_level):
return reader, writer return reader, writer
config = DeviceConfig("127.0.0.1") config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}) response = await protocol.query({})
assert response == {"great": "success"} assert response == {"great": "success"}
@ -229,15 +315,25 @@ async def test_protocol_logging(mocker, caplog, log_level):
assert "success" not in caplog.text assert "success" not in caplog.text
@pytest.mark.parametrize(
"protocol_class, transport_class, encryption_class",
[
(TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
@pytest.mark.parametrize("custom_port", [123, None]) @pytest.mark.parametrize("custom_port", [123, None])
async def test_protocol_custom_port(mocker, custom_port): async def test_protocol_custom_port(
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ mocker, custom_port, protocol_class, transport_class, encryption_class
TPLinkSmartHomeProtocol.BLOCK_SIZE : ):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
] ]
async def _mock_read(byte_count): async def _mock_read(byte_count):
nonlocal encrypted nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted)) return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted): if byte_count == len(encrypted):
return encrypted return encrypted
@ -254,21 +350,33 @@ async def test_protocol_custom_port(mocker, custom_port):
return reader, writer return reader, writer
config = DeviceConfig("127.0.0.1", port_override=custom_port) config = DeviceConfig("127.0.0.1", port_override=custom_port)
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}) response = await protocol.query({})
assert response == {"great": "success"} assert response == {"great": "success"}
def test_encrypt(): @pytest.mark.parametrize(
"encrypt_class",
[TPLinkSmartHomeProtocol, XorEncryption],
)
@pytest.mark.parametrize(
"decrypt_class",
[TPLinkSmartHomeProtocol, XorEncryption],
)
def test_encrypt(encrypt_class, decrypt_class):
d = json.dumps({"foo": 1, "bar": 2}) d = json.dumps({"foo": 1, "bar": 2})
encrypted = TPLinkSmartHomeProtocol.encrypt(d) encrypted = encrypt_class.encrypt(d)
# encrypt adds a 4 byte header # encrypt adds a 4 byte header
encrypted = encrypted[4:] encrypted = encrypted[4:]
assert d == TPLinkSmartHomeProtocol.decrypt(encrypted) assert d == decrypt_class.decrypt(encrypted)
def test_encrypt_unicode(): @pytest.mark.parametrize(
"encrypt_class",
[TPLinkSmartHomeProtocol, XorEncryption],
)
def test_encrypt_unicode(encrypt_class):
d = "{'snowman': '\u2603'}" d = "{'snowman': '\u2603'}"
e = bytes( e = bytes(
@ -294,14 +402,18 @@ def test_encrypt_unicode():
] ]
) )
encrypted = TPLinkSmartHomeProtocol.encrypt(d) encrypted = encrypt_class.encrypt(d)
# encrypt adds a 4 byte header # encrypt adds a 4 byte header
encrypted = encrypted[4:] encrypted = encrypted[4:]
assert e == encrypted assert e == encrypted
def test_decrypt_unicode(): @pytest.mark.parametrize(
"decrypt_class",
[TPLinkSmartHomeProtocol, XorEncryption],
)
def test_decrypt_unicode(decrypt_class):
e = bytes( e = bytes(
[ [
208, 208,
@ -327,7 +439,7 @@ def test_decrypt_unicode():
d = "{'snowman': '\u2603'}" d = "{'snowman': '\u2603'}"
assert d == TPLinkSmartHomeProtocol.decrypt(e) assert d == decrypt_class.decrypt(e)
def _get_subclasses(of_class): def _get_subclasses(of_class):
@ -378,7 +490,8 @@ def test_transport_init_signature(class_name_obj):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"transport_class", [AesTransport, KlapTransport, KlapTransportV2, _XorTransport] "transport_class",
[AesTransport, KlapTransport, KlapTransportV2, _XorTransport, XorTransport],
) )
async def test_transport_credentials_hash(mocker, transport_class): async def test_transport_credentials_hash(mocker, transport_class):
host = "127.0.0.1" host = "127.0.0.1"
@ -391,3 +504,79 @@ async def test_transport_credentials_hash(mocker, transport_class):
transport = transport_class(config=config) transport = transport_class(config=config)
assert transport.credentials_hash == credentials_hash assert transport.credentials_hash == credentials_hash
@pytest.mark.parametrize(
"error, retry_expectation",
[
(ConnectionRefusedError("dummy exception"), False),
(OSError(errno.EHOSTDOWN, os.strerror(errno.EHOSTDOWN)), False),
(OSError(errno.ECONNRESET, os.strerror(errno.ECONNRESET)), True),
(Exception("dummy exception"), True),
],
ids=("ConnectionRefusedError", "OSErrorNoRetry", "OSErrorRetry", "Exception"),
)
@pytest.mark.parametrize(
"protocol_class, transport_class",
[
(TPLinkSmartHomeProtocol, _XorTransport),
(IotProtocol, XorTransport),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_will_retry_on_connect(
mocker, protocol_class, transport_class, error, retry_expectation
):
retry_count = 2
conn = mocker.patch("asyncio.open_connection", side_effect=error)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException):
await protocol_class(transport=transport_class(config=config)).query(
{}, retry_count=retry_count
)
assert conn.call_count == (retry_count + 1 if retry_expectation else 1)
@pytest.mark.parametrize(
"error, retry_expectation",
[
(ConnectionRefusedError("dummy exception"), True),
(OSError(errno.EHOSTDOWN, os.strerror(errno.EHOSTDOWN)), True),
(OSError(errno.ECONNRESET, os.strerror(errno.ECONNRESET)), True),
(Exception("dummy exception"), True),
],
ids=("ConnectionRefusedError", "OSErrorNoRetry", "OSErrorRetry", "Exception"),
)
@pytest.mark.parametrize(
"protocol_class, transport_class",
[
(TPLinkSmartHomeProtocol, _XorTransport),
(IotProtocol, XorTransport),
],
ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_will_retry_on_write(
mocker, protocol_class, transport_class, error, retry_expectation
):
retry_count = 2
writer = mocker.patch("asyncio.StreamWriter")
write_mock = mocker.patch.object(writer, "write", side_effect=error)
def aio_mock_writer(_, __):
nonlocal writer
reader = mocker.patch("asyncio.StreamReader")
return reader, writer
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
write_mock = mocker.patch("asyncio.StreamWriter.write", side_effect=error)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException):
await protocol_class(transport=transport_class(config=config)).query(
{}, retry_count=retry_count
)
expected_call_count = retry_count + 1 if retry_expectation else 1
assert conn.call_count == expected_call_count
assert write_mock.call_count == expected_call_count

228
kasa/xortransport.py Normal file
View File

@ -0,0 +1,228 @@
"""Module for the XorTransport."""
import asyncio
import contextlib
import errno
import logging
import socket
import struct
from pprint import pformat as pf
from typing import Dict, Generator, Optional
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from .deviceconfig import DeviceConfig
from .exceptions import RetryableException, SmartDeviceException
from .json import loads as json_loads
from .protocol import BaseTransport
_LOGGER = logging.getLogger(__name__)
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
_UNSIGNED_INT_NETWORK_ORDER = struct.Struct(">I")
class XorTransport(BaseTransport):
"""Implementation of the Xor encryption transport.
WIP, currently only to ensure consistent __init__ method signatures
for protocol classes. Will eventually incorporate the logic from
TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol
class.
"""
DEFAULT_PORT: int = 9999
BLOCK_SIZE = 4
def __init__(self, *, config: DeviceConfig) -> None:
super().__init__(config=config)
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.query_lock = asyncio.Lock()
self.loop: Optional[asyncio.AbstractEventLoop] = None
@property
def default_port(self):
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return ""
async def _connect(self, timeout: int) -> None:
"""Try to connect or reconnect to the device."""
if self.writer:
return
self.reader = self.writer = None
task = asyncio.open_connection(self._host, self._port)
async with asyncio_timeout(timeout):
self.reader, self.writer = await task
sock: socket.socket = self.writer.get_extra_info("socket")
# Ensure our packets get sent without delay as we do all
# our writes in a single go and we do not want any buffering
# which would needlessly delay the request or risk overloading
# the buffer on the device
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
async def _execute_send(self, request: str) -> Dict:
"""Execute a query on the device and wait for the response."""
assert self.writer is not None # noqa: S101
assert self.reader is not None # noqa: S101
debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_log:
_LOGGER.debug("%s >> %s", self._host, request)
self.writer.write(XorEncryption.encrypt(request))
await self.writer.drain()
packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
length = _UNSIGNED_INT_NETWORK_ORDER.unpack(packed_block_size)[0]
buffer = await self.reader.readexactly(length)
response = XorEncryption.decrypt(buffer)
json_payload = json_loads(response)
if debug_log:
_LOGGER.debug("%s << %s", self._host, pf(json_payload))
return json_payload
async def close(self) -> None:
"""Close the connection."""
writer = self.writer
self.close_without_wait()
if writer:
with contextlib.suppress(Exception):
await writer.wait_closed()
def close_without_wait(self) -> None:
"""Close the connection without waiting for the connection to close."""
writer = self.writer
self.reader = self.writer = None
if writer:
writer.close()
async def reset(self) -> None:
"""Reset the transport.
The transport cannot be reset so we must close instead.
"""
await self.close()
async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response."""
#
# Most of the time we will already be connected if the device is online
# and the connect call will do nothing and return right away
#
# However, if we get an unrecoverable error (_NO_RETRY_ERRORS and
# ConnectionRefusedError) we do not want to keep trying since many
# connection open/close operations in the same time frame can block
# the event loop.
# This is especially import when there are multiple tplink devices being polled.
try:
await self._connect(self._timeout)
except ConnectionRefusedError as ex:
await self.reset()
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
) from ex
except OSError as ex:
await self.reset()
if ex.errno in _NO_RETRY_ERRORS:
raise SmartDeviceException(
f"Unable to connect to the device:"
f" {self._host}:{self._port}: {ex}"
) from ex
else:
raise RetryableException(
f"Unable to connect to the device:"
f" {self._host}:{self._port}: {ex}"
) from ex
except Exception as ex:
await self.reset()
raise RetryableException(
f"Unable to connect to the device:" f" {self._host}:{self._port}: {ex}"
) from ex
except BaseException:
# Likely something cancelled the task so we need to close the connection
# as we are not in an indeterminate state
self.close_without_wait()
raise
try:
assert self.reader is not None # noqa: S101
assert self.writer is not None # noqa: S101
async with asyncio_timeout(self._timeout):
return await self._execute_send(request)
except Exception as ex:
await self.reset()
raise RetryableException(
f"Unable to query the device {self._host}:{self._port}: {ex}"
) from ex
except BaseException:
# Likely something cancelled the task so we need to close the connection
# as we are not in an indeterminate state
self.close_without_wait()
raise
def __del__(self) -> None:
if self.writer and self.loop and self.loop.is_running():
# Since __del__ will be called when python does
# garbage collection is can happen in the event loop thread
# or in another thread so we need to make sure the call to
# close is called safely with call_soon_threadsafe
self.loop.call_soon_threadsafe(self.writer.close)
class XorEncryption:
"""XorEncryption class."""
INITIALIZATION_VECTOR = 171
@staticmethod
def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]:
key = XorEncryption.INITIALIZATION_VECTOR
for unencryptedbyte in unencrypted:
key = key ^ unencryptedbyte
yield key
@staticmethod
def encrypt(request: str) -> bytes:
"""Encrypt a request for a TP-Link Smart Home Device.
:param request: plaintext request data
:return: ciphertext to be send over wire, in bytes
"""
plainbytes = request.encode()
return _UNSIGNED_INT_NETWORK_ORDER.pack(len(plainbytes)) + bytes(
XorEncryption._xor_payload(plainbytes)
)
@staticmethod
def _xor_encrypted_payload(ciphertext: bytes) -> Generator[int, None, None]:
key = XorEncryption.INITIALIZATION_VECTOR
for cipherbyte in ciphertext:
plainbyte = key ^ cipherbyte
key = cipherbyte
yield plainbyte
@staticmethod
def decrypt(ciphertext: bytes) -> str:
"""Decrypt a response of a TP-Link Smart Home Device.
:param ciphertext: encrypted response data
:return: plaintext response
"""
return bytes(XorEncryption._xor_encrypted_payload(ciphertext)).decode()
# Try to load the kasa_crypt module and if it is available
try:
from kasa_crypt import decrypt, encrypt
XorEncryption.decrypt = decrypt # type: ignore[method-assign]
XorEncryption.encrypt = encrypt # type: ignore[method-assign]
except ImportError:
pass