From 0f3e4fc67520cfca29f09ebbe499882583d0da96 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Thu, 25 Jan 2024 18:55:13 +0000 Subject: [PATCH] Switch from TPLinkSmartHomeProtocol to IotProtocol/XorTransport --- devtools/parse_pcap.py | 4 +- kasa/__init__.py | 13 +- kasa/device_factory.py | 9 +- kasa/discover.py | 6 +- kasa/iotprotocol.py | 44 ++++- kasa/protocol.py | 275 +----------------------------- kasa/smartdevice.py | 8 +- kasa/tests/conftest.py | 5 +- kasa/tests/newfakes.py | 8 +- kasa/tests/test_cli.py | 1 - kasa/tests/test_device_factory.py | 5 - kasa/tests/test_discovery.py | 14 +- kasa/tests/test_protocol.py | 84 +++++---- kasa/xortransport.py | 20 ++- 14 files changed, 148 insertions(+), 348 deletions(-) diff --git a/devtools/parse_pcap.py b/devtools/parse_pcap.py index 5e741623..7a55bf54 100644 --- a/devtools/parse_pcap.py +++ b/devtools/parse_pcap.py @@ -9,7 +9,7 @@ import dpkt from dpkt.ethernet import ETH_TYPE_IP, Ethernet from kasa.cli import echo -from kasa.protocol import TPLinkSmartHomeProtocol +from kasa.xortransport import XorEncryption def read_payloads_from_file(file): @@ -34,7 +34,7 @@ def read_payloads_from_file(file): data = transport.data try: - decrypted = TPLinkSmartHomeProtocol.decrypt(data[4:]) + decrypted = XorEncryption.decrypt(data[4:]) except Exception as ex: echo(f"[red]Unable to decrypt the data, ignoring: {ex}[/red]") continue diff --git a/kasa/__init__.py b/kasa/__init__.py index a8101ae3..a4bd5546 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -12,6 +12,7 @@ Module-specific errors are raised as `SmartDeviceException` and are expected to be handled by the user of the library. """ from importlib.metadata import version +from warnings import warn from kasa.credentials import Credentials from kasa.deviceconfig import ( @@ -29,7 +30,7 @@ from kasa.exceptions import ( UnsupportedDeviceException, ) from kasa.iotprotocol import IotProtocol -from kasa.protocol import BaseProtocol, TPLinkSmartHomeProtocol +from kasa.protocol import BaseProtocol from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors from kasa.smartdevice import DeviceType, SmartDevice from kasa.smartdimmer import SmartDimmer @@ -43,7 +44,6 @@ __version__ = version("python-kasa") __all__ = [ "Discover", - "TPLinkSmartHomeProtocol", "BaseProtocol", "IotProtocol", "SmartProtocol", @@ -68,3 +68,12 @@ __all__ = [ "EncryptType", "DeviceFamilyType", ] + +deprecated_names = ["TPLinkSmartHomeProtocol"] + + +def __getattr__(name): + if name in deprecated_names: + warn(f"{name} is deprecated", DeprecationWarning, stacklevel=1) + return globals()[f"_deprecated_{name}"] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/kasa/device_factory.py b/kasa/device_factory.py index d216e0ef..fdb5b1b4 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -11,8 +11,6 @@ from .klaptransport import KlapTransport, KlapTransportV2 from .protocol import ( BaseProtocol, BaseTransport, - TPLinkSmartHomeProtocol, - _XorTransport, ) from .smartbulb import SmartBulb from .smartdevice import SmartDevice @@ -22,6 +20,7 @@ from .smartplug import SmartPlug from .smartprotocol import SmartProtocol from .smartstrip import SmartStrip from .tapo import TapoBulb, TapoPlug +from .xortransport import XorTransport _LOGGER = logging.getLogger(__name__) @@ -76,7 +75,9 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Smart device_class: Optional[Type[SmartDevice]] - if isinstance(protocol, TPLinkSmartHomeProtocol): + if isinstance(protocol, IotProtocol) and isinstance( + protocol._transport, XorTransport + ): info = await protocol.query(GET_SYSINFO_QUERY) _perf_log(True, "get_sysinfo") device_class = get_device_class_from_sys_info(info) @@ -151,7 +152,7 @@ def get_protocol( supported_device_protocols: Dict[ str, Tuple[Type[BaseProtocol], Type[BaseTransport]] ] = { - "IOT.XOR": (TPLinkSmartHomeProtocol, _XorTransport), + "IOT.XOR": (IotProtocol, XorTransport), "IOT.KLAP": (IotProtocol, KlapTransport), "SMART.AES": (SmartProtocol, AesTransport), "SMART.KLAP": (SmartProtocol, KlapTransportV2), diff --git a/kasa/discover.py b/kasa/discover.py index 8b58d4bd..8286387a 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -25,8 +25,8 @@ from kasa.deviceconfig import ConnectionType, DeviceConfig, EncryptType from kasa.exceptions import TimeoutException, UnsupportedDeviceException from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads -from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartdevice import SmartDevice, SmartDeviceException +from kasa.xortransport import XorEncryption _LOGGER = logging.getLogger(__name__) @@ -103,7 +103,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): """Send number of discovery datagrams.""" req = json_dumps(Discover.DISCOVERY_QUERY) _LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY) - encrypted_req = TPLinkSmartHomeProtocol.encrypt(req) + encrypted_req = XorEncryption.encrypt(req) sleep_between_packets = self.discovery_timeout / self.discovery_packets for i in range(self.discovery_packets): if self.target in self.seen_hosts: # Stop sending for discover_single @@ -400,7 +400,7 @@ class Discover: def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> SmartDevice: """Get SmartDevice from legacy 9999 response.""" try: - info = json_loads(TPLinkSmartHomeProtocol.decrypt(data)) + info = json_loads(XorEncryption.decrypt(data)) except Exception as ex: raise SmartDeviceException( f"Unable to read response from device: {config.host}: {ex}" diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index ed926101..f74e56f4 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -1,8 +1,9 @@ """Module for the IOT legacy IOT KASA protocol.""" import asyncio import logging -from typing import Dict, Union +from typing import Dict, Optional, Union +from .deviceconfig import DeviceConfig from .exceptions import ( AuthenticationException, ConnectionException, @@ -12,6 +13,7 @@ from .exceptions import ( ) from .json import dumps as json_dumps from .protocol import BaseProtocol, BaseTransport +from .xortransport import XorEncryption, XorTransport _LOGGER = logging.getLogger(__name__) @@ -86,3 +88,43 @@ class IotProtocol(BaseProtocol): async def close(self) -> None: """Close the underlying transport.""" await self._transport.close() + + +class _deprecated_TPLinkSmartHomeProtocol(IotProtocol): + def __init__( + self, + host: Optional[str] = None, + *, + port: Optional[int] = None, + timeout: Optional[int] = None, + transport: Optional[BaseTransport] = None, + ) -> None: + """Create a protocol object.""" + if not host and not transport: + raise SmartDeviceException("host or transport must be supplied") + if not transport: + config = DeviceConfig( + host=host, # type: ignore[arg-type] + port_override=port, + timeout=timeout or XorTransport.DEFAULT_TIMEOUT, + ) + transport = XorTransport(config=config) + super().__init__(transport=transport) + + @staticmethod + def encrypt(request: str) -> bytes: + """Encrypt a request for a TP-Link Smart Home Device. + + :param request: plaintext request data + :return: ciphertext to be send over wire, in bytes + """ + return XorEncryption.encrypt(request) + + @staticmethod + def decrypt(ciphertext: bytes) -> str: + """Decrypt a response of a TP-Link Smart Home Device. + + :param ciphertext: encrypted response data + :return: plaintext response + """ + return XorEncryption.decrypt(ciphertext) diff --git a/kasa/protocol.py b/kasa/protocol.py index b7ef3dea..60b3d7ca 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -9,27 +9,19 @@ https://github.com/softScheck/tplink-smartplug/ which are licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 """ -import asyncio import base64 -import contextlib import errno import logging -import socket import struct from abc import ABC, abstractmethod -from pprint import pformat as pf -from typing import Dict, Generator, Optional, Tuple, Union +from typing import Dict, Tuple, Union # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout -from async_timeout import timeout as asyncio_timeout from cryptography.hazmat.primitives import hashes from .credentials import Credentials from .deviceconfig import DeviceConfig -from .exceptions import SmartDeviceException -from .json import dumps as json_dumps -from .json import loads as json_loads _LOGGER = logging.getLogger(__name__) _NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED} @@ -114,262 +106,6 @@ class BaseProtocol(ABC): """Close the protocol. Abstract method to be overriden.""" -class _XorTransport(BaseTransport): - """Implementation of the Xor encryption transport. - - WIP, currently only to ensure consistent __init__ method signatures - for protocol classes. Will eventually incorporate the logic from - TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol - class. - """ - - DEFAULT_PORT: int = 9999 - BLOCK_SIZE = 4 - - def __init__(self, *, config: DeviceConfig) -> None: - super().__init__(config=config) - - @property - def default_port(self): - """Default port for the transport.""" - return self.DEFAULT_PORT - - @property - def credentials_hash(self) -> str: - """The hashed credentials used by the transport.""" - return "" - - async def send(self, request: str) -> Dict: - """Send a message to the device and return a response.""" - return {} - - async def close(self) -> None: - """Close the transport.""" - - async def reset(self) -> None: - """Reset internal state..""" - - -class TPLinkSmartHomeProtocol(BaseProtocol): - """Implementation of the TP-Link Smart Home protocol.""" - - INITIALIZATION_VECTOR = 171 - DEFAULT_PORT = 9999 - BLOCK_SIZE = 4 - - def __init__( - self, - *, - transport: BaseTransport, - ) -> None: - """Create a protocol object.""" - super().__init__(transport=transport) - - self.reader: Optional[asyncio.StreamReader] = None - self.writer: Optional[asyncio.StreamWriter] = None - self.query_lock = asyncio.Lock() - self.loop: Optional[asyncio.AbstractEventLoop] = None - - self._timeout = self._transport._timeout - self._port = self._transport._port - - async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: - """Request information from a TP-Link SmartHome Device. - - :param str host: host name or ip address of the device - :param request: command to send to the device (can be either dict or - json string) - :param retry_count: how many retries to do in case of failure - :return: response dict - """ - if isinstance(request, dict): - request = json_dumps(request) - assert isinstance(request, str) # noqa: S101 - - async with self.query_lock: - return await self._query(request, retry_count, self._timeout) # type: ignore[arg-type] - - async def _connect(self, timeout: int) -> None: - """Try to connect or reconnect to the device.""" - if self.writer: - return - self.reader = self.writer = None - - task = asyncio.open_connection(self._host, self._port) - async with asyncio_timeout(timeout): - self.reader, self.writer = await task - sock: socket.socket = self.writer.get_extra_info("socket") - # Ensure our packets get sent without delay as we do all - # our writes in a single go and we do not want any buffering - # which would needlessly delay the request or risk overloading - # the buffer on the device - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - async def _execute_query(self, request: str) -> Dict: - """Execute a query on the device and wait for the response.""" - assert self.writer is not None # noqa: S101 - assert self.reader is not None # noqa: S101 - debug_log = _LOGGER.isEnabledFor(logging.DEBUG) - if debug_log: - _LOGGER.debug("%s >> %s", self._host, request) - self.writer.write(TPLinkSmartHomeProtocol.encrypt(request)) - await self.writer.drain() - - packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE) - length = _UNSIGNED_INT_NETWORK_ORDER.unpack(packed_block_size)[0] - - buffer = await self.reader.readexactly(length) - response = TPLinkSmartHomeProtocol.decrypt(buffer) - json_payload = json_loads(response) - if debug_log: - _LOGGER.debug("%s << %s", self._host, pf(json_payload)) - - return json_payload - - async def close(self) -> None: - """Close the connection.""" - writer = self.writer - self.close_without_wait() - if writer: - with contextlib.suppress(Exception): - await writer.wait_closed() - - def close_without_wait(self) -> None: - """Close the connection without waiting for the connection to close.""" - writer = self.writer - self.reader = self.writer = None - if writer: - writer.close() - - async def reset(self) -> None: - """Reset the transport.""" - await self.close() - - async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: - """Try to query a device.""" - # - # Most of the time we will already be connected if the device is online - # and the connect call will do nothing and return right away - # - # However, if we get an unrecoverable error (_NO_RETRY_ERRORS and - # ConnectionRefusedError) we do not want to keep trying since many - # connection open/close operations in the same time frame can block - # the event loop. - # This is especially import when there are multiple tplink devices being polled. - for retry in range(retry_count + 1): - try: - await self._connect(timeout) - except ConnectionRefusedError as ex: - await self.reset() - raise SmartDeviceException( - f"Unable to connect to the device: {self._host}:{self._port}: {ex}" - ) from ex - except OSError as ex: - await self.reset() - if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count: - raise SmartDeviceException( - f"Unable to connect to the device:" - f" {self._host}:{self._port}: {ex}" - ) from ex - continue - except Exception as ex: - await self.reset() - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) - raise SmartDeviceException( - f"Unable to connect to the device:" - f" {self._host}:{self._port}: {ex}" - ) from ex - continue - except BaseException as ex: - # Likely something cancelled the task so we need to close the connection - # as we are not in an indeterminate state - self.close_without_wait() - _LOGGER.debug( - "%s: BaseException during connect, closing connection: %s", - self._host, - ex, - ) - raise - - try: - assert self.reader is not None # noqa: S101 - assert self.writer is not None # noqa: S101 - async with asyncio_timeout(timeout): - return await self._execute_query(request) - except Exception as ex: - await self.reset() - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) - raise SmartDeviceException( - f"Unable to query the device {self._host}:{self._port}: {ex}" - ) from ex - - _LOGGER.debug( - "Unable to query the device %s, retrying: %s", self._host, ex - ) - except BaseException as ex: - # Likely something cancelled the task so we need to close the connection - # as we are not in an indeterminate state - self.close_without_wait() - _LOGGER.debug( - "%s: BaseException during query, closing connection: %s", - self._host, - ex, - ) - raise - - # make mypy happy, this should never be reached.. - await self.reset() - raise SmartDeviceException("Query reached somehow to unreachable") - - def __del__(self) -> None: - if self.writer and self.loop and self.loop.is_running(): - # Since __del__ will be called when python does - # garbage collection is can happen in the event loop thread - # or in another thread so we need to make sure the call to - # close is called safely with call_soon_threadsafe - self.loop.call_soon_threadsafe(self.writer.close) - - @staticmethod - def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]: - key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR - for unencryptedbyte in unencrypted: - key = key ^ unencryptedbyte - yield key - - @staticmethod - def encrypt(request: str) -> bytes: - """Encrypt a request for a TP-Link Smart Home Device. - - :param request: plaintext request data - :return: ciphertext to be send over wire, in bytes - """ - plainbytes = request.encode() - return _UNSIGNED_INT_NETWORK_ORDER.pack(len(plainbytes)) + bytes( - TPLinkSmartHomeProtocol._xor_payload(plainbytes) - ) - - @staticmethod - def _xor_encrypted_payload(ciphertext: bytes) -> Generator[int, None, None]: - key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR - for cipherbyte in ciphertext: - plainbyte = key ^ cipherbyte - key = cipherbyte - yield plainbyte - - @staticmethod - def decrypt(ciphertext: bytes) -> str: - """Decrypt a response of a TP-Link Smart Home Device. - - :param ciphertext: encrypted response data - :return: plaintext response - """ - return bytes( - TPLinkSmartHomeProtocol._xor_encrypted_payload(ciphertext) - ).decode() - - def get_default_credentials(tuple: Tuple[str, str]) -> Credentials: """Return decoded default credentials.""" un = base64.b64decode(tuple[0].encode()).decode() @@ -381,12 +117,3 @@ DEFAULT_CREDENTIALS = { "KASA": ("a2FzYUB0cC1saW5rLm5ldA==", "a2FzYVNldHVw"), "TAPO": ("dGVzdEB0cC1saW5rLm5ldA==", "dGVzdA=="), } - -# Try to load the kasa_crypt module and if it is available -try: - from kasa_crypt import decrypt, encrypt - - TPLinkSmartHomeProtocol.decrypt = decrypt # type: ignore[method-assign] - TPLinkSmartHomeProtocol.encrypt = encrypt # type: ignore[method-assign] -except ImportError: - pass diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 31418afc..01ca382d 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -24,8 +24,10 @@ from .device_type import DeviceType from .deviceconfig import DeviceConfig from .emeterstatus import EmeterStatus from .exceptions import SmartDeviceException +from .iotprotocol import IotProtocol from .modules import Emeter, Module -from .protocol import BaseProtocol, TPLinkSmartHomeProtocol, _XorTransport +from .protocol import BaseProtocol +from .xortransport import XorTransport _LOGGER = logging.getLogger(__name__) @@ -204,8 +206,8 @@ class SmartDevice: """ if config and protocol: protocol._transport._config = config - self.protocol: BaseProtocol = protocol or TPLinkSmartHomeProtocol( - transport=_XorTransport(config=config or DeviceConfig(host=host)), + self.protocol: BaseProtocol = protocol or IotProtocol( + transport=XorTransport(config=config or DeviceConfig(host=host)), ) _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) self._device_type = DeviceType.Unknown diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 9b573186..24bc3372 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -20,9 +20,9 @@ from kasa import ( SmartLightStrip, SmartPlug, SmartStrip, - TPLinkSmartHomeProtocol, ) from kasa.tapo import TapoBulb, TapoDevice, TapoPlug +from kasa.xortransport import XorEncryption from .newfakes import FakeSmartProtocol, FakeTransportProtocol @@ -478,7 +478,7 @@ def discovery_mock(all_fixture_data, mocker): device_type = sys_info.get("mic_type") or sys_info.get("type") encrypt_type = "XOR" login_version = None - datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:] + datagram = XorEncryption.encrypt(json_dumps(discovery_data))[4:] dm = _DiscoveryMock( "127.0.0.123", 9999, @@ -517,7 +517,6 @@ def discovery_mock(all_fixture_data, mocker): mocker.patch("kasa.IotProtocol.query", side_effect=_query) mocker.patch("kasa.SmartProtocol.query", side_effect=_query) - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=_query) yield dm diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index 625a4994..aa3d42be 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -19,8 +19,10 @@ from voluptuous import ( from ..credentials import Credentials from ..deviceconfig import DeviceConfig from ..exceptions import SmartDeviceException -from ..protocol import BaseTransport, TPLinkSmartHomeProtocol, _XorTransport +from ..iotprotocol import IotProtocol +from ..protocol import BaseTransport from ..smartprotocol import SmartProtocol +from ..xortransport import XorTransport _LOGGER = logging.getLogger(__name__) @@ -381,10 +383,10 @@ class FakeSmartTransport(BaseTransport): pass -class FakeTransportProtocol(TPLinkSmartHomeProtocol): +class FakeTransportProtocol(IotProtocol): def __init__(self, info): super().__init__( - transport=_XorTransport( + transport=XorTransport( config=DeviceConfig("127.0.0.123"), ) ) diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index fa2d5c69..14dbb4bd 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -11,7 +11,6 @@ from kasa import ( EmeterStatus, SmartDevice, SmartDeviceException, - TPLinkSmartHomeProtocol, UnsupportedDeviceException, ) from kasa.cli import ( diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index 8e3e2ed6..9a068cd9 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -54,7 +54,6 @@ async def test_connect( mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) config = DeviceConfig( host=host, credentials=Credentials("foor", "bar"), connection_type=ctype @@ -87,7 +86,6 @@ async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port): default_port = 80 if "discovery_result" in all_fixture_data else 9999 ctype, _ = _get_connection_type_device_class(all_fixture_data) - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) dev = await connect(config=config) @@ -102,7 +100,6 @@ async def test_connect_logs_connect_time( ctype, _ = _get_connection_type_device_class(all_fixture_data) mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) host = "127.0.0.1" config = DeviceConfig( @@ -118,7 +115,6 @@ async def test_connect_logs_connect_time( async def test_connect_query_fails(all_fixture_data: dict, mocker): """Make sure that connect fails when query fails.""" host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException) mocker.patch("kasa.IotProtocol.query", side_effect=SmartDeviceException) mocker.patch("kasa.SmartProtocol.query", side_effect=SmartDeviceException) @@ -138,7 +134,6 @@ async def test_connect_http_client(all_fixture_data, mocker): mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) http_client = aiohttp.ClientSession() diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 2916e60a..db4d8fc1 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -15,7 +15,6 @@ from kasa import ( Discover, SmartDevice, SmartDeviceException, - TPLinkSmartHomeProtocol, protocol, ) from kasa.deviceconfig import ( @@ -26,6 +25,7 @@ from kasa.deviceconfig import ( ) from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps from kasa.exceptions import AuthenticationException, UnsupportedDeviceException +from kasa.xortransport import XorEncryption from .conftest import bulb, bulb_iot, dimmer, lightstrip, new_discovery, plug, strip @@ -189,7 +189,7 @@ async def test_discover_invalid_info(msg, data, mocker): def mock_discover(self): self.datagram_received( - protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(data))[4:], (host, 9999) + XorEncryption.encrypt(json_dumps(data))[4:], (host, 9999) ) mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) @@ -212,7 +212,7 @@ async def test_discover_datagram_received(mocker, discovery_data): """Verify that datagram received fills discovered_devices.""" proto = _DiscoverProtocol() - mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") + mocker.patch.object(XorEncryption, "decrypt") addr = "127.0.0.1" port = 20002 if "result" in discovery_data else 9999 @@ -238,8 +238,8 @@ async def test_discover_invalid_responses(msg, data, mocker): """Verify that we don't crash whole discovery if some devices in the network are sending unexpected data.""" proto = _DiscoverProtocol() mocker.patch("kasa.discover.json_loads", return_value=data) - mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt") - mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") + mocker.patch.object(XorEncryption, "encrypt") + mocker.patch.object(XorEncryption, "decrypt") proto.datagram_received(data, ("127.0.0.1", 9999)) assert len(proto.discovered_devices) == 0 @@ -375,9 +375,7 @@ class FakeDatagramTransport(asyncio.DatagramTransport): self.do_not_reply_count = do_not_reply_count self.send_count = 0 if port == 9999: - self.datagram = TPLinkSmartHomeProtocol.encrypt( - json_dumps(LEGACY_DISCOVER_DATA) - )[4:] + self.datagram = XorEncryption.encrypt(json_dumps(LEGACY_DISCOVER_DATA))[4:] elif port == 20002: discovery_data = UNSUPPORTED if unsupported else AUTHENTICATION_DATA_KLAP self.datagram = ( diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 34f2507e..78a8ea1e 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -15,13 +15,11 @@ from ..aestransport import AesTransport from ..credentials import Credentials from ..deviceconfig import DeviceConfig from ..exceptions import SmartDeviceException -from ..iotprotocol import IotProtocol +from ..iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol from ..klaptransport import KlapTransport, KlapTransportV2 from ..protocol import ( BaseProtocol, BaseTransport, - TPLinkSmartHomeProtocol, - _XorTransport, ) from ..xortransport import XorEncryption, XorTransport @@ -29,10 +27,10 @@ from ..xortransport import XorEncryption, XorTransport @pytest.mark.parametrize( "protocol_class, transport_class", [ - (TPLinkSmartHomeProtocol, _XorTransport), + (_deprecated_TPLinkSmartHomeProtocol, XorTransport), (IotProtocol, XorTransport), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) @pytest.mark.parametrize("retry_count", [1, 3, 5]) async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class): @@ -59,10 +57,10 @@ async def test_protocol_retries(mocker, retry_count, protocol_class, transport_c @pytest.mark.parametrize( "protocol_class, transport_class", [ - (TPLinkSmartHomeProtocol, _XorTransport), + (_deprecated_TPLinkSmartHomeProtocol, XorTransport), (IotProtocol, XorTransport), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) async def test_protocol_no_retry_on_unreachable( mocker, protocol_class, transport_class @@ -83,10 +81,10 @@ async def test_protocol_no_retry_on_unreachable( @pytest.mark.parametrize( "protocol_class, transport_class", [ - (TPLinkSmartHomeProtocol, _XorTransport), + (_deprecated_TPLinkSmartHomeProtocol, XorTransport), (IotProtocol, XorTransport), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) async def test_protocol_no_retry_connection_refused( mocker, protocol_class, transport_class @@ -107,10 +105,10 @@ async def test_protocol_no_retry_connection_refused( @pytest.mark.parametrize( "protocol_class, transport_class", [ - (TPLinkSmartHomeProtocol, _XorTransport), + (_deprecated_TPLinkSmartHomeProtocol, XorTransport), (IotProtocol, XorTransport), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) async def test_protocol_retry_recoverable_error( mocker, protocol_class, transport_class @@ -131,10 +129,14 @@ async def test_protocol_retry_recoverable_error( @pytest.mark.parametrize( "protocol_class, transport_class, encryption_class", [ - (TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol), + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), (IotProtocol, XorTransport, XorEncryption), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) @pytest.mark.parametrize("retry_count", [1, 3, 5]) async def test_protocol_reconnect( @@ -177,10 +179,14 @@ async def test_protocol_reconnect( @pytest.mark.parametrize( "protocol_class, transport_class, encryption_class", [ - (TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol), + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), (IotProtocol, XorTransport, XorEncryption), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) async def test_protocol_handles_cancellation_during_write( mocker, protocol_class, transport_class, encryption_class @@ -227,10 +233,14 @@ async def test_protocol_handles_cancellation_during_write( @pytest.mark.parametrize( "protocol_class, transport_class, encryption_class", [ - (TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol), + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), (IotProtocol, XorTransport, XorEncryption), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) async def test_protocol_handles_cancellation_during_connection( mocker, protocol_class, transport_class, encryption_class @@ -275,10 +285,14 @@ async def test_protocol_handles_cancellation_during_connection( @pytest.mark.parametrize( "protocol_class, transport_class, encryption_class", [ - (TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol), + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), (IotProtocol, XorTransport, XorEncryption), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) @pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG]) async def test_protocol_logging( @@ -318,10 +332,14 @@ async def test_protocol_logging( @pytest.mark.parametrize( "protocol_class, transport_class, encryption_class", [ - (TPLinkSmartHomeProtocol, _XorTransport, TPLinkSmartHomeProtocol), + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), (IotProtocol, XorTransport, XorEncryption), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) @pytest.mark.parametrize("custom_port", [123, None]) async def test_protocol_custom_port( @@ -358,11 +376,11 @@ async def test_protocol_custom_port( @pytest.mark.parametrize( "encrypt_class", - [TPLinkSmartHomeProtocol, XorEncryption], + [_deprecated_TPLinkSmartHomeProtocol, XorEncryption], ) @pytest.mark.parametrize( "decrypt_class", - [TPLinkSmartHomeProtocol, XorEncryption], + [_deprecated_TPLinkSmartHomeProtocol, XorEncryption], ) def test_encrypt(encrypt_class, decrypt_class): d = json.dumps({"foo": 1, "bar": 2}) @@ -374,7 +392,7 @@ def test_encrypt(encrypt_class, decrypt_class): @pytest.mark.parametrize( "encrypt_class", - [TPLinkSmartHomeProtocol, XorEncryption], + [_deprecated_TPLinkSmartHomeProtocol, XorEncryption], ) def test_encrypt_unicode(encrypt_class): d = "{'snowman': '\u2603'}" @@ -411,7 +429,7 @@ def test_encrypt_unicode(encrypt_class): @pytest.mark.parametrize( "decrypt_class", - [TPLinkSmartHomeProtocol, XorEncryption], + [_deprecated_TPLinkSmartHomeProtocol, XorEncryption], ) def test_decrypt_unicode(decrypt_class): e = bytes( @@ -451,7 +469,11 @@ def _get_subclasses(of_class): importlib.import_module("." + modname, package="kasa") module = sys.modules["kasa." + modname] for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, of_class): + if ( + inspect.isclass(obj) + and issubclass(obj, of_class) + and name != "_deprecated_TPLinkSmartHomeProtocol" + ): subclasses.add((name, obj)) return subclasses @@ -491,7 +513,7 @@ def test_transport_init_signature(class_name_obj): @pytest.mark.parametrize( "transport_class", - [AesTransport, KlapTransport, KlapTransportV2, _XorTransport, XorTransport], + [AesTransport, KlapTransport, KlapTransportV2, XorTransport, XorTransport], ) async def test_transport_credentials_hash(mocker, transport_class): host = "127.0.0.1" @@ -519,10 +541,10 @@ async def test_transport_credentials_hash(mocker, transport_class): @pytest.mark.parametrize( "protocol_class, transport_class", [ - (TPLinkSmartHomeProtocol, _XorTransport), + (_deprecated_TPLinkSmartHomeProtocol, XorTransport), (IotProtocol, XorTransport), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) async def test_protocol_will_retry_on_connect( mocker, protocol_class, transport_class, error, retry_expectation @@ -551,10 +573,10 @@ async def test_protocol_will_retry_on_connect( @pytest.mark.parametrize( "protocol_class, transport_class", [ - (TPLinkSmartHomeProtocol, _XorTransport), + (_deprecated_TPLinkSmartHomeProtocol, XorTransport), (IotProtocol, XorTransport), ], - ids=("TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), ) async def test_protocol_will_retry_on_write( mocker, protocol_class, transport_class, error, retry_expectation diff --git a/kasa/xortransport.py b/kasa/xortransport.py index bed62ea8..95e78c20 100644 --- a/kasa/xortransport.py +++ b/kasa/xortransport.py @@ -1,4 +1,14 @@ -"""Module for the XorTransport.""" +"""Implementation of the legacy TP-Link Smart Home Protocol. + +Encryption/Decryption methods based on the works of +Lubomir Stroetmann and Tobias Esser + +https://www.softscheck.com/en/reverse-engineering-tp-link-hs110/ +https://github.com/softScheck/tplink-smartplug/ + +which are licensed under the Apache License, Version 2.0 +http://www.apache.org/licenses/LICENSE-2.0 +""" import asyncio import contextlib import errno @@ -23,13 +33,7 @@ _UNSIGNED_INT_NETWORK_ORDER = struct.Struct(">I") class XorTransport(BaseTransport): - """Implementation of the Xor encryption transport. - - WIP, currently only to ensure consistent __init__ method signatures - for protocol classes. Will eventually incorporate the logic from - TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol - class. - """ + """XorTransport class.""" DEFAULT_PORT: int = 9999 BLOCK_SIZE = 4