From 20ea6700a51c8476bb68b56d33558921b1d45890 Mon Sep 17 00:00:00 2001 From: sdb9696 <51370195+sdb9696@users.noreply.github.com> Date: Tue, 19 Dec 2023 14:11:59 +0000 Subject: [PATCH] Do login entirely within AesTransport (#580) * Do login entirely within AesTransport * Remove login and handshake attributes from BaseTransport * Add AesTransport tests * Synchronise transport and protocol __init__ signatures and rename internal variables * Update after review --- kasa/aestransport.py | 81 ++++++--------- kasa/device_factory.py | 19 +++- kasa/iotprotocol.py | 44 +++----- kasa/klaptransport.py | 76 +++++--------- kasa/protocol.py | 108 +++++++++++--------- kasa/smartdevice.py | 4 +- kasa/smartprotocol.py | 59 ++++------- kasa/tapo/tapodevice.py | 8 +- kasa/tests/newfakes.py | 3 + kasa/tests/test_aestransport.py | 174 ++++++++++++++++++++++++++++++++ kasa/tests/test_klapprotocol.py | 29 ++++-- kasa/tests/test_protocol.py | 98 ++++++++++++++++-- kasa/tests/test_smartdevice.py | 2 +- 13 files changed, 468 insertions(+), 237 deletions(-) create mode 100644 kasa/tests/test_aestransport.py diff --git a/kasa/aestransport.py b/kasa/aestransport.py index dc982b61..9db0db4f 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -8,7 +8,7 @@ import base64 import hashlib import logging import time -from typing import Optional, Union +from typing import Optional import httpx from cryptography.hazmat.primitives import padding, serialization @@ -47,6 +47,7 @@ class AesTransport(BaseTransport): protocol, sometimes used by newer firmware versions on kasa devices. """ + DEFAULT_PORT = 80 DEFAULT_TIMEOUT = 5 SESSION_COOKIE_NAME = "TP_SESSIONID" COMMON_HEADERS = { @@ -59,12 +60,16 @@ class AesTransport(BaseTransport): self, host: str, *, + port: Optional[int] = None, credentials: Optional[Credentials] = None, timeout: Optional[int] = None, ) -> None: - super().__init__(host=host) - - self._credentials = credentials or Credentials(username="", password="") + super().__init__( + host, + port=port or self.DEFAULT_PORT, + credentials=credentials, + timeout=timeout, + ) self._handshake_done = False @@ -77,7 +82,7 @@ class AesTransport(BaseTransport): self._http_client: httpx.AsyncClient = httpx.AsyncClient() self._login_token = None - _LOGGER.debug("Created AES object for %s", self.host) + _LOGGER.debug("Created AES transport for %s", self._host) def hash_credentials(self, login_v2): """Hash the credentials.""" @@ -123,7 +128,7 @@ class AesTransport(BaseTransport): if ( error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] ) != SmartErrorCode.SUCCESS: - msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})" + msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" if error_code in SMART_TIMEOUT_ERRORS: raise TimeoutException(msg) if error_code in SMART_RETRYABLE_ERRORS: @@ -136,7 +141,7 @@ class AesTransport(BaseTransport): async def send_secure_passthrough(self, request: str): """Send encrypted message as passthrough.""" - url = f"http://{self.host}/app" + url = f"http://{self._host}/app" if self._login_token: url += f"?token={self._login_token}" @@ -150,7 +155,7 @@ class AesTransport(BaseTransport): if status_code != 200: raise SmartDeviceException( - f"{self.host} responded with an unexpected " + f"{self._host} responded with an unexpected " + f"status code {status_code} to passthrough" ) @@ -164,49 +169,31 @@ class AesTransport(BaseTransport): resp_dict = json_loads(response) return resp_dict - async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool): + async def _perform_login_for_version(self, *, login_version: int = 1): """Login to the device.""" self._login_token = None - - if isinstance(login_request, str): - login_request_dict: dict = json_loads(login_request) - else: - login_request_dict = login_request - - un, pw = self.hash_credentials(login_v2) - login_request_dict["params"] = {"password": pw, "username": un} - request = json_dumps(login_request_dict) + un, pw = self.hash_credentials(login_version == 2) + password_field_name = "password2" if login_version == 2 else "password" + login_request = { + "method": "login_device", + "params": {password_field_name: pw, "username": un}, + "request_time_milis": round(time.time() * 1000), + } + request = json_dumps(login_request) try: resp_dict = await self.send_secure_passthrough(request) except SmartDeviceException as ex: raise AuthenticationException(ex) from ex self._login_token = resp_dict["result"]["token"] - @property - def needs_login(self) -> bool: - """Return true if the transport needs to do a login.""" - return self._login_token is None - - async def login(self, request: str) -> None: + async def perform_login(self) -> None: """Login to the device.""" try: - if self.needs_handshake: - raise SmartDeviceException( - "Handshake must be complete before trying to login" - ) - await self.perform_login(request, login_v2=False) + await self._perform_login_for_version(login_version=2) except AuthenticationException: + _LOGGER.warning("Login version 2 failed, trying version 1") await self.perform_handshake() - await self.perform_login(request, login_v2=True) - - @property - def needs_handshake(self) -> bool: - """Return true if the transport needs to do a handshake.""" - return not self._handshake_done or self._handshake_session_expired() - - async def handshake(self) -> None: - """Perform the encryption handshake.""" - await self.perform_handshake() + await self._perform_login_for_version(login_version=1) async def perform_handshake(self): """Perform the handshake.""" @@ -217,7 +204,7 @@ class AesTransport(BaseTransport): self._session_expire_at = None self._session_cookie = None - url = f"http://{self.host}/app" + url = f"http://{self._host}/app" key_pair = KeyPair.create_key_pair() pub_key = ( @@ -238,7 +225,7 @@ class AesTransport(BaseTransport): if status_code != 200: raise SmartDeviceException( - f"{self.host} responded with an unexpected " + f"{self._host} responded with an unexpected " + f"status code {status_code} to handshake" ) @@ -261,7 +248,7 @@ class AesTransport(BaseTransport): self._handshake_done = True - _LOGGER.debug("Handshake with %s complete", self.host) + _LOGGER.debug("Handshake with %s complete", self._host) def _handshake_session_expired(self): """Return true if session has expired.""" @@ -272,12 +259,10 @@ class AesTransport(BaseTransport): async def send(self, request: str): """Send the request.""" - if self.needs_handshake: - raise SmartDeviceException( - "Handshake must be complete before trying to send" - ) - if self.needs_login: - raise SmartDeviceException("Login must be complete before trying to send") + if not self._handshake_done or self._handshake_session_expired(): + await self.perform_handshake() + if not self._login_token: + await self.perform_login() return await self.send_secure_passthrough(request) diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 15896e06..d8a07bee 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -74,7 +74,12 @@ async def connect( host=host, port=port, credentials=credentials, timeout=timeout ) if protocol_class is not None: - dev.protocol = protocol_class(host, credentials=credentials) + dev.protocol = protocol_class( + host, + transport=AesTransport( + host, port=port, credentials=credentials, timeout=timeout + ), + ) await dev.update() if debug_enabled: end_time = time.perf_counter() @@ -90,7 +95,13 @@ async def connect( host=host, port=port, credentials=credentials, timeout=timeout ) if protocol_class is not None: - unknown_dev.protocol = protocol_class(host, credentials=credentials) + # TODO this will be replaced with connection params + unknown_dev.protocol = protocol_class( + host, + transport=AesTransport( + host, port=port, credentials=credentials, timeout=timeout + ), + ) await unknown_dev.update() device_class = get_device_class_from_sys_info(unknown_dev.internal_state) dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout) @@ -163,7 +174,5 @@ def get_protocol_from_connection_name( protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore transport: BaseTransport = transport_class(host, credentials=credentials) - protocol: TPLinkProtocol = protocol_class( - host, credentials=credentials, transport=transport - ) + protocol: TPLinkProtocol = protocol_class(host, transport=transport) return protocol diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index 2b7f422d..d942d060 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -1,14 +1,12 @@ """Module for the IOT legacy IOT KASA protocol.""" import asyncio import logging -from typing import Dict, Optional, Union +from typing import Dict, Union import httpx -from .credentials import Credentials from .exceptions import AuthenticationException, SmartDeviceException from .json import dumps as json_dumps -from .klaptransport import KlapTransport from .protocol import BaseTransport, TPLinkProtocol _LOGGER = logging.getLogger(__name__) @@ -17,24 +15,14 @@ _LOGGER = logging.getLogger(__name__) class IotProtocol(TPLinkProtocol): """Class for the legacy TPLink IOT KASA Protocol.""" - DEFAULT_PORT = 80 - def __init__( self, host: str, *, - transport: Optional[BaseTransport] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + transport: BaseTransport, ) -> None: - super().__init__(host=host, port=self.DEFAULT_PORT) - - self._credentials: Credentials = credentials or Credentials( - username="", password="" - ) - self._transport: BaseTransport = transport or KlapTransport( - host, credentials=self._credentials, timeout=timeout - ) + """Create a protocol object.""" + super().__init__(host, transport=transport) self._query_lock = asyncio.Lock() @@ -54,30 +42,32 @@ class IotProtocol(TPLinkProtocol): except httpx.CloseError as sdex: await self.close() if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {sdex}" + f"Unable to connect to the device: {self._host}: {sdex}" ) from sdex continue except httpx.ConnectError as cex: await self.close() raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {cex}" + f"Unable to connect to the device: {self._host}: {cex}" ) from cex except TimeoutError as tex: await self.close() raise SmartDeviceException( - f"Unable to connect to the device, timed out: {self.host}: {tex}" + f"Unable to connect to the device, timed out: {self._host}: {tex}" ) from tex except AuthenticationException as auex: - _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) + _LOGGER.debug( + "Unable to authenticate with %s, not retrying", self._host + ) raise auex except Exception as ex: await self.close() if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {ex}" + f"Unable to connect to the device: {self._host}: {ex}" ) from ex continue @@ -85,14 +75,6 @@ class IotProtocol(TPLinkProtocol): raise SmartDeviceException("Query reached somehow to unreachable") async def _execute_query(self, request: str, retry_count: int) -> Dict: - if self._transport.needs_handshake: - await self._transport.handshake() - - if self._transport.needs_login: # This shouln't happen - raise SmartDeviceException( - "IOT Protocol needs to login to transport but is not login aware" - ) - return await self._transport.send(request) async def close(self) -> None: diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index d578ef84..e7bb8ae6 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -82,7 +82,7 @@ class KlapTransport(BaseTransport): protocol, used by newer firmware versions. """ - DEFAULT_TIMEOUT = 5 + DEFAULT_PORT = 80 DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} KASA_SETUP_EMAIL = "kasa@tp-link.net" KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105 @@ -92,12 +92,17 @@ class KlapTransport(BaseTransport): self, host: str, *, + port: Optional[int] = None, credentials: Optional[Credentials] = None, timeout: Optional[int] = None, ) -> None: - super().__init__(host=host) + super().__init__( + host, + port=port or self.DEFAULT_PORT, + credentials=credentials, + timeout=timeout, + ) - self._credentials = credentials or Credentials(username="", password="") self._local_seed: Optional[bytes] = None self._local_auth_hash = self.generate_auth_hash(self._credentials) self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() @@ -110,11 +115,10 @@ class KlapTransport(BaseTransport): self._encryption_session: Optional[KlapEncryptionSession] = None self._session_expire_at: Optional[float] = None - self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT self._session_cookie = None self._http_client: httpx.AsyncClient = httpx.AsyncClient() - _LOGGER.debug("Created KLAP object for %s", self.host) + _LOGGER.debug("Created KLAP transport for %s", self._host) async def client_post(self, url, params=None, data=None): """Send an http post request to the device.""" @@ -148,7 +152,7 @@ class KlapTransport(BaseTransport): payload = local_seed - url = f"http://{self.host}/app/handshake1" + url = f"http://{self._host}/app/handshake1" response_status, response_data = await self.client_post(url, data=payload) @@ -157,14 +161,14 @@ class KlapTransport(BaseTransport): "Handshake1 posted at %s. Host is %s, Response" + "status is %s, Request was %s", datetime.datetime.now(), - self.host, + self._host, response_status, payload.hex(), ) if response_status != 200: raise AuthenticationException( - f"Device {self.host} responded with {response_status} to handshake1" + f"Device {self._host} responded with {response_status} to handshake1" ) remote_seed: bytes = response_data[0:16] @@ -175,7 +179,7 @@ class KlapTransport(BaseTransport): "Handshake1 success at %s. Host is %s, " + "Server remote_seed is: %s, server hash is: %s", datetime.datetime.now(), - self.host, + self._host, remote_seed.hex(), server_hash.hex(), ) @@ -207,7 +211,7 @@ class KlapTransport(BaseTransport): _LOGGER.debug( "Server response doesn't match our expected hash on ip %s" + " but an authentication with kasa setup credentials matched", - self.host, + self._host, ) return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore @@ -226,11 +230,11 @@ class KlapTransport(BaseTransport): _LOGGER.debug( "Server response doesn't match our expected hash on ip %s" + " but an authentication with blank credentials matched", - self.host, + self._host, ) return local_seed, remote_seed, self._blank_auth_hash # type: ignore - msg = f"Server response doesn't match our challenge on ip {self.host}" + msg = f"Server response doesn't match our challenge on ip {self._host}" _LOGGER.debug(msg) raise AuthenticationException(msg) @@ -241,7 +245,7 @@ class KlapTransport(BaseTransport): # Handshake 2 has the following payload: # sha256(serverBytes | authenticator) - url = f"http://{self.host}/app/handshake2" + url = f"http://{self._host}/app/handshake2" payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash) @@ -252,44 +256,24 @@ class KlapTransport(BaseTransport): "Handshake2 posted %s. Host is %s, Response status is %s, " + "Request was %s", datetime.datetime.now(), - self.host, + self._host, response_status, payload.hex(), ) if response_status != 200: raise AuthenticationException( - f"Device {self.host} responded with {response_status} to handshake2" + f"Device {self._host} responded with {response_status} to handshake2" ) return KlapEncryptionSession(local_seed, remote_seed, auth_hash) - @property - def needs_login(self) -> bool: - """Will return false as KLAP does not do a login.""" - return False - - async def login(self, request: str) -> None: - """Will raise and exception as KLAP does not do a login.""" - raise SmartDeviceException( - "KLAP does not perform logins and return needs_login == False" - ) - - @property - def needs_handshake(self) -> bool: - """Return true if the transport needs to do a handshake.""" - return not self._handshake_done or self._handshake_session_expired() - - async def handshake(self) -> None: - """Perform the encryption handshake.""" - await self.perform_handshake() - async def perform_handshake(self) -> Any: """Perform handshake1 and handshake2. Sets the encryption_session if successful. """ - _LOGGER.debug("Starting handshake with %s", self.host) + _LOGGER.debug("Starting handshake with %s", self._host) self._handshake_done = False self._session_expire_at = None self._session_cookie = None @@ -307,7 +291,7 @@ class KlapTransport(BaseTransport): ) self._handshake_done = True - _LOGGER.debug("Handshake with %s complete", self.host) + _LOGGER.debug("Handshake with %s complete", self._host) def _handshake_session_expired(self): """Return true if session has expired.""" @@ -318,18 +302,14 @@ class KlapTransport(BaseTransport): async def send(self, request: str): """Send the request.""" - if self.needs_handshake: - raise SmartDeviceException( - "Handshake must be complete before trying to send" - ) - if self.needs_login: - raise SmartDeviceException("Login must be complete before trying to send") + if not self._handshake_done or self._handshake_session_expired(): + await self.perform_handshake() # Check for mypy if self._encryption_session is not None: payload, seq = self._encryption_session.encrypt(request.encode()) - url = f"http://{self.host}/app/request" + url = f"http://{self._host}/app/request" response_status, response_data = await self.client_post( url, @@ -338,7 +318,7 @@ class KlapTransport(BaseTransport): ) msg = ( - f"at {datetime.datetime.now()}. Host is {self.host}, " + f"at {datetime.datetime.now()}. Host is {self._host}, " + f"Sequence is {seq}, " + f"Response status is {response_status}, Request was {request}" ) @@ -348,12 +328,12 @@ class KlapTransport(BaseTransport): if response_status == 403: self._handshake_done = False raise AuthenticationException( - f"Got a security error from {self.host} after handshake " + f"Got a security error from {self._host} after handshake " + "completed" ) else: raise SmartDeviceException( - f"Device {self.host} responded with {response_status} to" + f"Device {self._host} responded with {response_status} to" + f"request with seq {seq}" ) else: @@ -367,7 +347,7 @@ class KlapTransport(BaseTransport): _LOGGER.debug( "%s << %s", - self.host, + self._host, _LOGGER.isEnabledFor(logging.DEBUG) and pf(json_payload), ) diff --git a/kasa/protocol.py b/kasa/protocol.py index 62cd5fb6..f73260bf 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -44,35 +44,21 @@ def md5(payload: bytes) -> bytes: class BaseTransport(ABC): """Base class for all TP-Link protocol transports.""" + DEFAULT_TIMEOUT = 5 + def __init__( self, host: str, *, port: Optional[int] = None, credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, ) -> None: """Create a protocol object.""" - self.host = host - self.port = port - self.credentials = credentials - - @property - @abstractmethod - def needs_handshake(self) -> bool: - """Return true if the transport needs to do a handshake.""" - - @property - @abstractmethod - def needs_login(self) -> bool: - """Return true if the transport needs to do a login.""" - - @abstractmethod - async def login(self, request: str) -> None: - """Login to the device.""" - - @abstractmethod - async def handshake(self) -> None: - """Perform the encryption handshake.""" + self._host = host + self._port = port + self._credentials = credentials or Credentials(username="", password="") + self._timeout = timeout or self.DEFAULT_TIMEOUT @abstractmethod async def send(self, request: str) -> Dict: @@ -90,14 +76,14 @@ class TPLinkProtocol(ABC): self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - transport: Optional[BaseTransport] = None, + transport: BaseTransport, ) -> None: """Create a protocol object.""" - self.host = host - self.port = port - self.credentials = credentials + self._transport = transport + + @property + def _host(self): + return self._transport._host @abstractmethod async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: @@ -108,6 +94,40 @@ class TPLinkProtocol(ABC): """Close the protocol. Abstract method to be overriden.""" +class _XorTransport(BaseTransport): + """Implementation of the Xor encryption transport. + + WIP, currently only to ensure consistent __init__ method signatures + for protocol classes. Will eventually incorporate the logic from + TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol + class. + """ + + DEFAULT_PORT = 9999 + + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, + ) -> None: + super().__init__( + host, + port=port or self.DEFAULT_PORT, + credentials=credentials, + timeout=timeout, + ) + + async def send(self, request: str) -> Dict: + """Send a message to the device and return a response.""" + return {} + + async def close(self) -> None: + """Close the transport. Abstract method to be overriden.""" + + class TPLinkSmartHomeProtocol(TPLinkProtocol): """Implementation of the TP-Link Smart Home protocol.""" @@ -120,20 +140,18 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): self, host: str, *, - port: Optional[int] = None, - timeout: Optional[int] = None, - credentials: Optional[Credentials] = None, + transport: BaseTransport, ) -> None: """Create a protocol object.""" - super().__init__( - host=host, port=port or self.DEFAULT_PORT, credentials=credentials - ) + super().__init__(host, transport=transport) self.reader: Optional[asyncio.StreamReader] = None self.writer: Optional[asyncio.StreamWriter] = None self.query_lock = asyncio.Lock() self.loop: Optional[asyncio.AbstractEventLoop] = None - self.timeout = timeout or TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT + + self._timeout = self._transport._timeout + self._port = self._transport._port async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: """Request information from a TP-Link SmartHome Device. @@ -149,7 +167,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): assert isinstance(request, str) # noqa: S101 async with self.query_lock: - return await self._query(request, retry_count, self.timeout) + return await self._query(request, retry_count, self._timeout) async def _connect(self, timeout: int) -> None: """Try to connect or reconnect to the device.""" @@ -157,7 +175,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): return self.reader = self.writer = None - task = asyncio.open_connection(self.host, self.port) + task = asyncio.open_connection(self._host, self._port) async with asyncio_timeout(timeout): self.reader, self.writer = await task sock: socket.socket = self.writer.get_extra_info("socket") @@ -174,7 +192,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): debug_log = _LOGGER.isEnabledFor(logging.DEBUG) if debug_log: - _LOGGER.debug("%s >> %s", self.host, request) + _LOGGER.debug("%s >> %s", self._host, request) self.writer.write(TPLinkSmartHomeProtocol.encrypt(request)) await self.writer.drain() @@ -185,7 +203,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): response = TPLinkSmartHomeProtocol.decrypt(buffer) json_payload = json_loads(response) if debug_log: - _LOGGER.debug("%s << %s", self.host, pf(json_payload)) + _LOGGER.debug("%s << %s", self._host, pf(json_payload)) return json_payload @@ -219,23 +237,23 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): except ConnectionRefusedError as ex: await self.close() raise SmartDeviceException( - f"Unable to connect to the device: {self.host}:{self.port}: {ex}" + f"Unable to connect to the device: {self._host}:{self._port}: {ex}" ) from ex except OSError as ex: await self.close() if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count: raise SmartDeviceException( f"Unable to connect to the device:" - f" {self.host}:{self.port}: {ex}" + f" {self._host}:{self._port}: {ex}" ) from ex continue except Exception as ex: await self.close() if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device:" - f" {self.host}:{self.port}: {ex}" + f" {self._host}:{self._port}: {ex}" ) from ex continue @@ -247,13 +265,13 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): except Exception as ex: await self.close() if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( - f"Unable to query the device {self.host}:{self.port}: {ex}" + f"Unable to query the device {self._host}:{self._port}: {ex}" ) from ex _LOGGER.debug( - "Unable to query the device %s, retrying: %s", self.host, ex + "Unable to query the device %s, retrying: %s", self._host, ex ) # make mypy happy, this should never be reached.. diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 342d1c4a..5ad94a9f 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -24,7 +24,7 @@ from .device_type import DeviceType from .emeterstatus import EmeterStatus from .exceptions import SmartDeviceException from .modules import Emeter, Module -from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol +from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol, _XorTransport _LOGGER = logging.getLogger(__name__) @@ -202,7 +202,7 @@ class SmartDevice: self.host = host self.port = port self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol( - host, port=port, timeout=timeout + host, transport=_XorTransport(host, port=port, timeout=timeout) ) self.credentials = credentials _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index eb661317..a344cf66 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -10,12 +10,10 @@ import logging import time import uuid from pprint import pformat as pf -from typing import Dict, Optional, Union +from typing import Dict, Union import httpx -from .aestransport import AesTransport -from .credentials import Credentials from .exceptions import ( SMART_AUTHENTICATION_ERRORS, SMART_RETRYABLE_ERRORS, @@ -36,26 +34,17 @@ logging.getLogger("httpx").propagate = False class SmartProtocol(TPLinkProtocol): """Class for the new TPLink SMART protocol.""" - DEFAULT_PORT = 80 SLEEP_SECONDS_AFTER_TIMEOUT = 1 def __init__( self, host: str, *, - transport: Optional[BaseTransport] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + transport: BaseTransport, ) -> None: - super().__init__(host=host, port=self.DEFAULT_PORT) - - self._credentials: Credentials = credentials or Credentials( - username="", password="" - ) - self._transport: BaseTransport = transport or AesTransport( - host, credentials=self._credentials, timeout=timeout - ) - self._terminal_uuid: Optional[str] = None + """Create a protocol object.""" + super().__init__(host, transport=transport) + self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode() self._request_id_generator = SnowflakeId(1, 1) self._query_lock = asyncio.Lock() @@ -79,7 +68,7 @@ class SmartProtocol(TPLinkProtocol): error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] ) != SmartErrorCode.SUCCESS: msg = ( - f"Error querying device: {self.host}: " + f"Error querying device: {self._host}: " + f"{error_code.name}({error_code.value})" ) if error_code in SMART_TIMEOUT_ERRORS: @@ -101,51 +90,53 @@ class SmartProtocol(TPLinkProtocol): except httpx.CloseError as sdex: await self.close() if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {sdex}" + f"Unable to connect to the device: {self._host}: {sdex}" ) from sdex continue except httpx.ConnectError as cex: await self.close() raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {cex}" + f"Unable to connect to the device: {self._host}: {cex}" ) from cex except TimeoutError as tex: if retry >= retry_count: await self.close() raise SmartDeviceException( "Unable to connect to the device, " - + f"timed out: {self.host}: {tex}" + + f"timed out: {self._host}: {tex}" ) from tex await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) continue except AuthenticationException as auex: await self.close() - _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) + _LOGGER.debug( + "Unable to authenticate with %s, not retrying", self._host + ) raise auex except RetryableException as ex: if retry >= retry_count: await self.close() - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutException as ex: if retry >= retry_count: await self.close() - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) continue except Exception as ex: if retry >= retry_count: await self.close() - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( - f"Unable to query the device {self.host}:{self.port}: {ex}" + f"Unable to connect to the device: {self._host}: {ex}" ) from ex _LOGGER.debug( - "Unable to query the device %s, retrying: %s", self.host, ex + "Unable to query the device %s, retrying: %s", self._host, ex ) continue @@ -160,27 +151,17 @@ class SmartProtocol(TPLinkProtocol): smart_method = request smart_params = None - if self._transport.needs_handshake: - await self._transport.handshake() - - if self._transport.needs_login: - self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode( - "UTF-8" - ) - login_request = self.get_smart_request("login_device") - await self._transport.login(login_request) - smart_request = self.get_smart_request(smart_method, smart_params) _LOGGER.debug( "%s >> %s", - self.host, + self._host, _LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request), ) response_data = await self._transport.send(smart_request) _LOGGER.debug( "%s << %s", - self.host, + self._host, _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), ) diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index 291e1744..e5d9effe 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -4,6 +4,7 @@ import logging from datetime import datetime, timedelta, timezone from typing import Any, Dict, Optional, Set, cast +from ..aestransport import AesTransport from ..credentials import Credentials from ..exceptions import AuthenticationException from ..smartdevice import SmartDevice @@ -27,7 +28,12 @@ class TapoDevice(SmartDevice): self._components: Optional[Dict[str, Any]] = None self._state_information: Dict[str, Any] = {} self._discovery_info: Optional[Dict[str, Any]] = None - self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout) + self.protocol = SmartProtocol( + host, + transport=AesTransport( + host, credentials=credentials, timeout=timeout, port=port + ), + ) async def update(self, update_children: bool = True): """Update the device.""" diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index 284f4e2b..c01c8ee3 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -301,6 +301,9 @@ class FakeSmartProtocol(SmartProtocol): class FakeSmartTransport(BaseTransport): def __init__(self, info): + super().__init__( + "127.0.0.123", + ) self.info = info @property diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py new file mode 100644 index 00000000..b018b497 --- /dev/null +++ b/kasa/tests/test_aestransport.py @@ -0,0 +1,174 @@ +import base64 +import json +import time +from contextlib import nullcontext as does_not_raise +from json import dumps as json_dumps +from json import loads as json_loads + +import httpx +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding + +from ..aestransport import AesEncyptionSession, AesTransport +from ..credentials import Credentials +from ..exceptions import SmartDeviceException + +DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} + +key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t" +iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa" +KEY_IV = key + iv + + +def test_encrypt(): + encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:]) + + d = json.dumps({"foo": 1, "bar": 2}) + encrypted = encryption_session.encrypt(d.encode()) + assert d == encryption_session.decrypt(encrypted) + + # test encrypt unicode + d = "{'snowman': '\u2603'}" + encrypted = encryption_session.encrypt(d.encode()) + assert d == encryption_session.decrypt(encrypted) + + +status_parameters = pytest.mark.parametrize( + "status_code, error_code, inner_error_code, expectation", + [ + (200, 0, 0, does_not_raise()), + (400, 0, 0, pytest.raises(SmartDeviceException)), + (200, -1, 0, pytest.raises(SmartDeviceException)), + ], + ids=("success", "status_code", "error_code"), +) + + +@status_parameters +async def test_handshake( + mocker, status_code, error_code, inner_error_code, expectation +): + host = "127.0.0.1" + mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) + mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) + + transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + + assert transport._encryption_session is None + assert transport._handshake_done is False + with expectation: + await transport.perform_handshake() + assert transport._encryption_session is not None + assert transport._handshake_done is True + + +@status_parameters +async def test_login(mocker, status_code, error_code, inner_error_code, expectation): + host = "127.0.0.1" + mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) + mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) + + transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + transport._handshake_done = True + transport._session_expire_at = time.time() + 86400 + transport._encryption_session = mock_aes_device.encryption_session + + assert transport._login_token is None + with expectation: + await transport.perform_login() + assert transport._login_token == mock_aes_device.token + + +@status_parameters +async def test_send(mocker, status_code, error_code, inner_error_code, expectation): + host = "127.0.0.1" + mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) + mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) + + transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + transport._handshake_done = True + transport._session_expire_at = time.time() + 86400 + transport._encryption_session = mock_aes_device.encryption_session + transport._login_token = mock_aes_device.token + + un, pw = transport.hash_credentials(True) + request = { + "method": "get_device_info", + "params": None, + "request_time_milis": round(time.time() * 1000), + "requestID": 1, + "terminal_uuid": "foobar", + } + with expectation: + res = await transport.send(json_dumps(request)) + assert "result" in res + + +class MockAesDevice: + class _mock_response: + def __init__(self, status_code, json: dict): + self.status_code = status_code + self._json = json + + def json(self): + return self._json + + encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:]) + token = "test_token" # noqa + + def __init__(self, host, status_code=200, error_code=0, inner_error_code=0): + self.host = host + self.status_code = status_code + self.error_code = error_code + self.inner_error_code = inner_error_code + + async def post(self, url, params=None, json=None, *_, **__): + return await self._post(url, json) + + async def _post(self, url, json): + if json["method"] == "handshake": + return await self._return_handshake_response(url, json) + elif json["method"] == "securePassthrough": + return await self._return_secure_passthrough_response(url, json) + elif json["method"] == "login_device": + return await self._return_login_response(url, json) + else: + assert url == f"http://{self.host}/app?token={self.token}" + return await self._return_send_response(url, json) + + async def _return_handshake_response(self, url, json): + start = len("-----BEGIN PUBLIC KEY-----\n") + end = len("\n-----END PUBLIC KEY-----\n") + client_pub_key = json["params"]["key"][start:-end] + + client_pub_key_data = base64.b64decode(client_pub_key.encode()) + client_pub_key = serialization.load_der_public_key(client_pub_key_data, None) + encrypted_key = client_pub_key.encrypt(KEY_IV, asymmetric_padding.PKCS1v15()) + key_64 = base64.b64encode(encrypted_key).decode() + return self._mock_response( + self.status_code, {"result": {"key": key_64}, "error_code": self.error_code} + ) + + async def _return_secure_passthrough_response(self, url, json): + encrypted_request = json["params"]["request"] + decrypted_request = self.encryption_session.decrypt(encrypted_request.encode()) + decrypted_request_dict = json_loads(decrypted_request) + decrypted_response = await self._post(url, decrypted_request_dict) + decrypted_response_dict = decrypted_response.json() + encrypted_response = self.encryption_session.encrypt( + json_dumps(decrypted_response_dict).encode() + ) + result = { + "result": {"response": encrypted_response.decode()}, + "error_code": self.error_code, + } + return self._mock_response(self.status_code, result) + + async def _return_login_response(self, url, json): + result = {"result": {"token": self.token}, "error_code": self.inner_error_code} + return self._mock_response(self.status_code, result) + + async def _return_send_response(self, url, json): + result = {"result": {"method": None}, "error_code": self.inner_error_code} + return self._mock_response(self.status_code, result) diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 8ad46b6e..d29f4e30 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -96,10 +96,9 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport return mock_response - mocker.patch.object( - transport_class, "needs_handshake", property(lambda self: False) - ) - mocker.patch.object(transport_class, "needs_login", property(lambda self: False)) + mocker.patch.object(transport_class, "perform_handshake") + if hasattr(transport_class, "perform_login"): + mocker.patch.object(transport_class, "perform_login") send_mock = mocker.patch.object( transport_class, @@ -128,7 +127,7 @@ async def test_protocol_logging(mocker, caplog, log_level): seed = secrets.token_bytes(16) auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) - protocol = IotProtocol("127.0.0.1") + protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1")) protocol._transport._handshake_done = True protocol._transport._session_expire_at = time.time() + 86400 @@ -206,7 +205,10 @@ async def test_handshake1(mocker, device_credentials, expectation): httpx.AsyncClient, "post", side_effect=_return_handshake1_response ) - protocol = IotProtocol("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol( + "127.0.0.1", + transport=KlapTransport("127.0.0.1", credentials=client_credentials), + ) protocol._transport.http_client = httpx.AsyncClient() with expectation: @@ -243,7 +245,10 @@ async def test_handshake(mocker): httpx.AsyncClient, "post", side_effect=_return_handshake_response ) - protocol = IotProtocol("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol( + "127.0.0.1", + transport=KlapTransport("127.0.0.1", credentials=client_credentials), + ) protocol._transport.http_client = httpx.AsyncClient() response_status = 200 @@ -289,7 +294,10 @@ async def test_query(mocker): mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = IotProtocol("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol( + "127.0.0.1", + transport=KlapTransport("127.0.0.1", credentials=client_credentials), + ) for _ in range(10): resp = await protocol.query({}) @@ -333,7 +341,10 @@ async def test_authentication_failures(mocker, response_status, expectation): mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = IotProtocol("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol( + "127.0.0.1", + transport=KlapTransport("127.0.0.1", credentials=client_credentials), + ) with expectation: await protocol.query({}) diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index b438f498..7bd6342b 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -1,13 +1,21 @@ import errno +import importlib +import inspect import json import logging +import pkgutil import struct import sys import pytest from ..exceptions import SmartDeviceException -from ..protocol import TPLinkSmartHomeProtocol +from ..protocol import ( + BaseTransport, + TPLinkProtocol, + TPLinkSmartHomeProtocol, + _XorTransport, +) @pytest.mark.parametrize("retry_count", [1, 3, 5]) @@ -24,7 +32,9 @@ async def test_protocol_retries(mocker, retry_count): conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count) + await TPLinkSmartHomeProtocol( + "127.0.0.1", transport=_XorTransport("127.0.0.1") + ).query({}, retry_count=retry_count) assert conn.call_count == retry_count + 1 @@ -35,7 +45,9 @@ async def test_protocol_no_retry_on_unreachable(mocker): side_effect=OSError(errno.EHOSTUNREACH, "No route to host"), ) with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5) + await TPLinkSmartHomeProtocol( + "127.0.0.1", transport=_XorTransport("127.0.0.1") + ).query({}, retry_count=5) assert conn.call_count == 1 @@ -46,7 +58,9 @@ async def test_protocol_no_retry_connection_refused(mocker): side_effect=ConnectionRefusedError, ) with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5) + await TPLinkSmartHomeProtocol( + "127.0.0.1", transport=_XorTransport("127.0.0.1") + ).query({}, retry_count=5) assert conn.call_count == 1 @@ -57,7 +71,9 @@ async def test_protocol_retry_recoverable_error(mocker): side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"), ) with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5) + await TPLinkSmartHomeProtocol( + "127.0.0.1", transport=_XorTransport("127.0.0.1") + ).query({}, retry_count=5) assert conn.call_count == 6 @@ -91,7 +107,9 @@ async def test_protocol_reconnect(mocker, retry_count): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - protocol = TPLinkSmartHomeProtocol("127.0.0.1") + protocol = TPLinkSmartHomeProtocol( + "127.0.0.1", transport=_XorTransport("127.0.0.1") + ) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}, retry_count=retry_count) assert response == {"great": "success"} @@ -119,7 +137,9 @@ async def test_protocol_logging(mocker, caplog, log_level): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - protocol = TPLinkSmartHomeProtocol("127.0.0.1") + protocol = TPLinkSmartHomeProtocol( + "127.0.0.1", transport=_XorTransport("127.0.0.1") + ) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) assert response == {"great": "success"} @@ -153,7 +173,9 @@ async def test_protocol_custom_port(mocker, custom_port): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - protocol = TPLinkSmartHomeProtocol("127.0.0.1", port=custom_port) + protocol = TPLinkSmartHomeProtocol( + "127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port) + ) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) assert response == {"great": "success"} @@ -227,3 +249,63 @@ def test_decrypt_unicode(): d = "{'snowman': '\u2603'}" assert d == TPLinkSmartHomeProtocol.decrypt(e) + + +def _get_subclasses(of_class): + import kasa + + package = sys.modules["kasa"] + subclasses = set() + for _, modname, _ in pkgutil.iter_modules(package.__path__): + importlib.import_module("." + modname, package="kasa") + module = sys.modules["kasa." + modname] + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and issubclass(obj, of_class): + subclasses.add((name, obj)) + return subclasses + + +@pytest.mark.parametrize( + "class_name_obj", _get_subclasses(TPLinkProtocol), ids=lambda t: t[0] +) +def test_protocol_init_signature(class_name_obj): + params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) + + assert len(params) == 3 + assert ( + params[0].name == "self" + and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ) + assert ( + params[1].name == "host" + and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ) + assert ( + params[2].name == "transport" + and params[2].kind == inspect.Parameter.KEYWORD_ONLY + ) + + +@pytest.mark.parametrize( + "class_name_obj", _get_subclasses(BaseTransport), ids=lambda t: t[0] +) +def test_transport_init_signature(class_name_obj): + params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) + + assert len(params) == 5 + assert ( + params[0].name == "self" + and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ) + assert ( + params[1].name == "host" + and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ) + assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY + assert ( + params[3].name == "credentials" + and params[3].kind == inspect.Parameter.KEYWORD_ONLY + ) + assert ( + params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY + ) diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 33c9f448..90eae16f 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -232,7 +232,7 @@ async def test_modules_preserved(dev: SmartDevice): async def test_create_smart_device_with_timeout(): """Make sure timeout is passed to the protocol.""" dev = SmartDevice(host="127.0.0.1", timeout=100) - assert dev.protocol.timeout == 100 + assert dev.protocol._transport._timeout == 100 async def test_create_thin_wrapper():