mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-10-31 04:31:54 +00:00 
			
		
		
		
	Do login entirely within AesTransport (#580)
* Do login entirely within AesTransport * Remove login and handshake attributes from BaseTransport * Add AesTransport tests * Synchronise transport and protocol __init__ signatures and rename internal variables * Update after review
This commit is contained in:
		| @@ -8,7 +8,7 @@ import base64 | ||||
| import hashlib | ||||
| import logging | ||||
| import time | ||||
| from typing import Optional, Union | ||||
| from typing import Optional | ||||
|  | ||||
| import httpx | ||||
| from cryptography.hazmat.primitives import padding, serialization | ||||
| @@ -47,6 +47,7 @@ class AesTransport(BaseTransport): | ||||
|     protocol, sometimes used by newer firmware versions on kasa devices. | ||||
|     """ | ||||
|  | ||||
|     DEFAULT_PORT = 80 | ||||
|     DEFAULT_TIMEOUT = 5 | ||||
|     SESSION_COOKIE_NAME = "TP_SESSIONID" | ||||
|     COMMON_HEADERS = { | ||||
| @@ -59,12 +60,16 @@ class AesTransport(BaseTransport): | ||||
|         self, | ||||
|         host: str, | ||||
|         *, | ||||
|         port: Optional[int] = None, | ||||
|         credentials: Optional[Credentials] = None, | ||||
|         timeout: Optional[int] = None, | ||||
|     ) -> None: | ||||
|         super().__init__(host=host) | ||||
|  | ||||
|         self._credentials = credentials or Credentials(username="", password="") | ||||
|         super().__init__( | ||||
|             host, | ||||
|             port=port or self.DEFAULT_PORT, | ||||
|             credentials=credentials, | ||||
|             timeout=timeout, | ||||
|         ) | ||||
|  | ||||
|         self._handshake_done = False | ||||
|  | ||||
| @@ -77,7 +82,7 @@ class AesTransport(BaseTransport): | ||||
|         self._http_client: httpx.AsyncClient = httpx.AsyncClient() | ||||
|         self._login_token = None | ||||
|  | ||||
|         _LOGGER.debug("Created AES object for %s", self.host) | ||||
|         _LOGGER.debug("Created AES transport for %s", self._host) | ||||
|  | ||||
|     def hash_credentials(self, login_v2): | ||||
|         """Hash the credentials.""" | ||||
| @@ -123,7 +128,7 @@ class AesTransport(BaseTransport): | ||||
|         if ( | ||||
|             error_code := SmartErrorCode(resp_dict.get("error_code"))  # type: ignore[arg-type] | ||||
|         ) != SmartErrorCode.SUCCESS: | ||||
|             msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})" | ||||
|             msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" | ||||
|             if error_code in SMART_TIMEOUT_ERRORS: | ||||
|                 raise TimeoutException(msg) | ||||
|             if error_code in SMART_RETRYABLE_ERRORS: | ||||
| @@ -136,7 +141,7 @@ class AesTransport(BaseTransport): | ||||
|  | ||||
|     async def send_secure_passthrough(self, request: str): | ||||
|         """Send encrypted message as passthrough.""" | ||||
|         url = f"http://{self.host}/app" | ||||
|         url = f"http://{self._host}/app" | ||||
|         if self._login_token: | ||||
|             url += f"?token={self._login_token}" | ||||
|  | ||||
| @@ -150,7 +155,7 @@ class AesTransport(BaseTransport): | ||||
|  | ||||
|         if status_code != 200: | ||||
|             raise SmartDeviceException( | ||||
|                 f"{self.host} responded with an unexpected " | ||||
|                 f"{self._host} responded with an unexpected " | ||||
|                 + f"status code {status_code} to passthrough" | ||||
|             ) | ||||
|  | ||||
| @@ -164,49 +169,31 @@ class AesTransport(BaseTransport): | ||||
|         resp_dict = json_loads(response) | ||||
|         return resp_dict | ||||
|  | ||||
|     async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool): | ||||
|     async def _perform_login_for_version(self, *, login_version: int = 1): | ||||
|         """Login to the device.""" | ||||
|         self._login_token = None | ||||
|  | ||||
|         if isinstance(login_request, str): | ||||
|             login_request_dict: dict = json_loads(login_request) | ||||
|         else: | ||||
|             login_request_dict = login_request | ||||
|  | ||||
|         un, pw = self.hash_credentials(login_v2) | ||||
|         login_request_dict["params"] = {"password": pw, "username": un} | ||||
|         request = json_dumps(login_request_dict) | ||||
|         un, pw = self.hash_credentials(login_version == 2) | ||||
|         password_field_name = "password2" if login_version == 2 else "password" | ||||
|         login_request = { | ||||
|             "method": "login_device", | ||||
|             "params": {password_field_name: pw, "username": un}, | ||||
|             "request_time_milis": round(time.time() * 1000), | ||||
|         } | ||||
|         request = json_dumps(login_request) | ||||
|         try: | ||||
|             resp_dict = await self.send_secure_passthrough(request) | ||||
|         except SmartDeviceException as ex: | ||||
|             raise AuthenticationException(ex) from ex | ||||
|         self._login_token = resp_dict["result"]["token"] | ||||
|  | ||||
|     @property | ||||
|     def needs_login(self) -> bool: | ||||
|         """Return true if the transport needs to do a login.""" | ||||
|         return self._login_token is None | ||||
|  | ||||
|     async def login(self, request: str) -> None: | ||||
|     async def perform_login(self) -> None: | ||||
|         """Login to the device.""" | ||||
|         try: | ||||
|             if self.needs_handshake: | ||||
|                 raise SmartDeviceException( | ||||
|                     "Handshake must be complete before trying to login" | ||||
|                 ) | ||||
|             await self.perform_login(request, login_v2=False) | ||||
|             await self._perform_login_for_version(login_version=2) | ||||
|         except AuthenticationException: | ||||
|             _LOGGER.warning("Login version 2 failed, trying version 1") | ||||
|             await self.perform_handshake() | ||||
|             await self.perform_login(request, login_v2=True) | ||||
|  | ||||
|     @property | ||||
|     def needs_handshake(self) -> bool: | ||||
|         """Return true if the transport needs to do a handshake.""" | ||||
|         return not self._handshake_done or self._handshake_session_expired() | ||||
|  | ||||
|     async def handshake(self) -> None: | ||||
|         """Perform the encryption handshake.""" | ||||
|         await self.perform_handshake() | ||||
|             await self._perform_login_for_version(login_version=1) | ||||
|  | ||||
|     async def perform_handshake(self): | ||||
|         """Perform the handshake.""" | ||||
| @@ -217,7 +204,7 @@ class AesTransport(BaseTransport): | ||||
|         self._session_expire_at = None | ||||
|         self._session_cookie = None | ||||
|  | ||||
|         url = f"http://{self.host}/app" | ||||
|         url = f"http://{self._host}/app" | ||||
|         key_pair = KeyPair.create_key_pair() | ||||
|  | ||||
|         pub_key = ( | ||||
| @@ -238,7 +225,7 @@ class AesTransport(BaseTransport): | ||||
|  | ||||
|         if status_code != 200: | ||||
|             raise SmartDeviceException( | ||||
|                 f"{self.host} responded with an unexpected " | ||||
|                 f"{self._host} responded with an unexpected " | ||||
|                 + f"status code {status_code} to handshake" | ||||
|             ) | ||||
|  | ||||
| @@ -261,7 +248,7 @@ class AesTransport(BaseTransport): | ||||
|  | ||||
|         self._handshake_done = True | ||||
|  | ||||
|         _LOGGER.debug("Handshake with %s complete", self.host) | ||||
|         _LOGGER.debug("Handshake with %s complete", self._host) | ||||
|  | ||||
|     def _handshake_session_expired(self): | ||||
|         """Return true if session has expired.""" | ||||
| @@ -272,12 +259,10 @@ class AesTransport(BaseTransport): | ||||
|  | ||||
|     async def send(self, request: str): | ||||
|         """Send the request.""" | ||||
|         if self.needs_handshake: | ||||
|             raise SmartDeviceException( | ||||
|                 "Handshake must be complete before trying to send" | ||||
|             ) | ||||
|         if self.needs_login: | ||||
|             raise SmartDeviceException("Login must be complete before trying to send") | ||||
|         if not self._handshake_done or self._handshake_session_expired(): | ||||
|             await self.perform_handshake() | ||||
|         if not self._login_token: | ||||
|             await self.perform_login() | ||||
|  | ||||
|         return await self.send_secure_passthrough(request) | ||||
|  | ||||
|   | ||||
| @@ -74,7 +74,12 @@ async def connect( | ||||
|             host=host, port=port, credentials=credentials, timeout=timeout | ||||
|         ) | ||||
|         if protocol_class is not None: | ||||
|             dev.protocol = protocol_class(host, credentials=credentials) | ||||
|             dev.protocol = protocol_class( | ||||
|                 host, | ||||
|                 transport=AesTransport( | ||||
|                     host, port=port, credentials=credentials, timeout=timeout | ||||
|                 ), | ||||
|             ) | ||||
|         await dev.update() | ||||
|         if debug_enabled: | ||||
|             end_time = time.perf_counter() | ||||
| @@ -90,7 +95,13 @@ async def connect( | ||||
|         host=host, port=port, credentials=credentials, timeout=timeout | ||||
|     ) | ||||
|     if protocol_class is not None: | ||||
|         unknown_dev.protocol = protocol_class(host, credentials=credentials) | ||||
|         # TODO this will be replaced with connection params | ||||
|         unknown_dev.protocol = protocol_class( | ||||
|             host, | ||||
|             transport=AesTransport( | ||||
|                 host, port=port, credentials=credentials, timeout=timeout | ||||
|             ), | ||||
|         ) | ||||
|     await unknown_dev.update() | ||||
|     device_class = get_device_class_from_sys_info(unknown_dev.internal_state) | ||||
|     dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout) | ||||
| @@ -163,7 +174,5 @@ def get_protocol_from_connection_name( | ||||
|  | ||||
|     protocol_class, transport_class = supported_device_protocols.get(connection_name)  # type: ignore | ||||
|     transport: BaseTransport = transport_class(host, credentials=credentials) | ||||
|     protocol: TPLinkProtocol = protocol_class( | ||||
|         host, credentials=credentials, transport=transport | ||||
|     ) | ||||
|     protocol: TPLinkProtocol = protocol_class(host, transport=transport) | ||||
|     return protocol | ||||
|   | ||||
| @@ -1,14 +1,12 @@ | ||||
| """Module for the IOT legacy IOT KASA protocol.""" | ||||
| import asyncio | ||||
| import logging | ||||
| from typing import Dict, Optional, Union | ||||
| from typing import Dict, Union | ||||
|  | ||||
| import httpx | ||||
|  | ||||
| from .credentials import Credentials | ||||
| from .exceptions import AuthenticationException, SmartDeviceException | ||||
| from .json import dumps as json_dumps | ||||
| from .klaptransport import KlapTransport | ||||
| from .protocol import BaseTransport, TPLinkProtocol | ||||
|  | ||||
| _LOGGER = logging.getLogger(__name__) | ||||
| @@ -17,24 +15,14 @@ _LOGGER = logging.getLogger(__name__) | ||||
| class IotProtocol(TPLinkProtocol): | ||||
|     """Class for the legacy TPLink IOT KASA Protocol.""" | ||||
|  | ||||
|     DEFAULT_PORT = 80 | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         host: str, | ||||
|         *, | ||||
|         transport: Optional[BaseTransport] = None, | ||||
|         credentials: Optional[Credentials] = None, | ||||
|         timeout: Optional[int] = None, | ||||
|         transport: BaseTransport, | ||||
|     ) -> None: | ||||
|         super().__init__(host=host, port=self.DEFAULT_PORT) | ||||
|  | ||||
|         self._credentials: Credentials = credentials or Credentials( | ||||
|             username="", password="" | ||||
|         ) | ||||
|         self._transport: BaseTransport = transport or KlapTransport( | ||||
|             host, credentials=self._credentials, timeout=timeout | ||||
|         ) | ||||
|         """Create a protocol object.""" | ||||
|         super().__init__(host, transport=transport) | ||||
|  | ||||
|         self._query_lock = asyncio.Lock() | ||||
|  | ||||
| @@ -54,30 +42,32 @@ class IotProtocol(TPLinkProtocol): | ||||
|             except httpx.CloseError as sdex: | ||||
|                 await self.close() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise SmartDeviceException( | ||||
|                         f"Unable to connect to the device: {self.host}: {sdex}" | ||||
|                         f"Unable to connect to the device: {self._host}: {sdex}" | ||||
|                     ) from sdex | ||||
|                 continue | ||||
|             except httpx.ConnectError as cex: | ||||
|                 await self.close() | ||||
|                 raise SmartDeviceException( | ||||
|                     f"Unable to connect to the device: {self.host}: {cex}" | ||||
|                     f"Unable to connect to the device: {self._host}: {cex}" | ||||
|                 ) from cex | ||||
|             except TimeoutError as tex: | ||||
|                 await self.close() | ||||
|                 raise SmartDeviceException( | ||||
|                     f"Unable to connect to the device, timed out: {self.host}: {tex}" | ||||
|                     f"Unable to connect to the device, timed out: {self._host}: {tex}" | ||||
|                 ) from tex | ||||
|             except AuthenticationException as auex: | ||||
|                 _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) | ||||
|                 _LOGGER.debug( | ||||
|                     "Unable to authenticate with %s, not retrying", self._host | ||||
|                 ) | ||||
|                 raise auex | ||||
|             except Exception as ex: | ||||
|                 await self.close() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise SmartDeviceException( | ||||
|                         f"Unable to connect to the device: {self.host}: {ex}" | ||||
|                         f"Unable to connect to the device: {self._host}: {ex}" | ||||
|                     ) from ex | ||||
|                 continue | ||||
|  | ||||
| @@ -85,14 +75,6 @@ class IotProtocol(TPLinkProtocol): | ||||
|         raise SmartDeviceException("Query reached somehow to unreachable") | ||||
|  | ||||
|     async def _execute_query(self, request: str, retry_count: int) -> Dict: | ||||
|         if self._transport.needs_handshake: | ||||
|             await self._transport.handshake() | ||||
|  | ||||
|         if self._transport.needs_login:  # This shouln't happen | ||||
|             raise SmartDeviceException( | ||||
|                 "IOT Protocol needs to login to transport but is not login aware" | ||||
|             ) | ||||
|  | ||||
|         return await self._transport.send(request) | ||||
|  | ||||
|     async def close(self) -> None: | ||||
|   | ||||
| @@ -82,7 +82,7 @@ class KlapTransport(BaseTransport): | ||||
|     protocol, used by newer firmware versions. | ||||
|     """ | ||||
|  | ||||
|     DEFAULT_TIMEOUT = 5 | ||||
|     DEFAULT_PORT = 80 | ||||
|     DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} | ||||
|     KASA_SETUP_EMAIL = "kasa@tp-link.net" | ||||
|     KASA_SETUP_PASSWORD = "kasaSetup"  # noqa: S105 | ||||
| @@ -92,12 +92,17 @@ class KlapTransport(BaseTransport): | ||||
|         self, | ||||
|         host: str, | ||||
|         *, | ||||
|         port: Optional[int] = None, | ||||
|         credentials: Optional[Credentials] = None, | ||||
|         timeout: Optional[int] = None, | ||||
|     ) -> None: | ||||
|         super().__init__(host=host) | ||||
|         super().__init__( | ||||
|             host, | ||||
|             port=port or self.DEFAULT_PORT, | ||||
|             credentials=credentials, | ||||
|             timeout=timeout, | ||||
|         ) | ||||
|  | ||||
|         self._credentials = credentials or Credentials(username="", password="") | ||||
|         self._local_seed: Optional[bytes] = None | ||||
|         self._local_auth_hash = self.generate_auth_hash(self._credentials) | ||||
|         self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() | ||||
| @@ -110,11 +115,10 @@ class KlapTransport(BaseTransport): | ||||
|         self._encryption_session: Optional[KlapEncryptionSession] = None | ||||
|         self._session_expire_at: Optional[float] = None | ||||
|  | ||||
|         self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT | ||||
|         self._session_cookie = None | ||||
|         self._http_client: httpx.AsyncClient = httpx.AsyncClient() | ||||
|  | ||||
|         _LOGGER.debug("Created KLAP object for %s", self.host) | ||||
|         _LOGGER.debug("Created KLAP transport for %s", self._host) | ||||
|  | ||||
|     async def client_post(self, url, params=None, data=None): | ||||
|         """Send an http post request to the device.""" | ||||
| @@ -148,7 +152,7 @@ class KlapTransport(BaseTransport): | ||||
|  | ||||
|         payload = local_seed | ||||
|  | ||||
|         url = f"http://{self.host}/app/handshake1" | ||||
|         url = f"http://{self._host}/app/handshake1" | ||||
|  | ||||
|         response_status, response_data = await self.client_post(url, data=payload) | ||||
|  | ||||
| @@ -157,14 +161,14 @@ class KlapTransport(BaseTransport): | ||||
|                 "Handshake1 posted at %s. Host is %s, Response" | ||||
|                 + "status is %s, Request was %s", | ||||
|                 datetime.datetime.now(), | ||||
|                 self.host, | ||||
|                 self._host, | ||||
|                 response_status, | ||||
|                 payload.hex(), | ||||
|             ) | ||||
|  | ||||
|         if response_status != 200: | ||||
|             raise AuthenticationException( | ||||
|                 f"Device {self.host} responded with {response_status} to handshake1" | ||||
|                 f"Device {self._host} responded with {response_status} to handshake1" | ||||
|             ) | ||||
|  | ||||
|         remote_seed: bytes = response_data[0:16] | ||||
| @@ -175,7 +179,7 @@ class KlapTransport(BaseTransport): | ||||
|                 "Handshake1 success at %s. Host is %s, " | ||||
|                 + "Server remote_seed is: %s, server hash is: %s", | ||||
|                 datetime.datetime.now(), | ||||
|                 self.host, | ||||
|                 self._host, | ||||
|                 remote_seed.hex(), | ||||
|                 server_hash.hex(), | ||||
|             ) | ||||
| @@ -207,7 +211,7 @@ class KlapTransport(BaseTransport): | ||||
|             _LOGGER.debug( | ||||
|                 "Server response doesn't match our expected hash on ip %s" | ||||
|                 + " but an authentication with kasa setup credentials matched", | ||||
|                 self.host, | ||||
|                 self._host, | ||||
|             ) | ||||
|             return local_seed, remote_seed, self._kasa_setup_auth_hash  # type: ignore | ||||
|  | ||||
| @@ -226,11 +230,11 @@ class KlapTransport(BaseTransport): | ||||
|                 _LOGGER.debug( | ||||
|                     "Server response doesn't match our expected hash on ip %s" | ||||
|                     + " but an authentication with blank credentials matched", | ||||
|                     self.host, | ||||
|                     self._host, | ||||
|                 ) | ||||
|                 return local_seed, remote_seed, self._blank_auth_hash  # type: ignore | ||||
|  | ||||
|         msg = f"Server response doesn't match our challenge on ip {self.host}" | ||||
|         msg = f"Server response doesn't match our challenge on ip {self._host}" | ||||
|         _LOGGER.debug(msg) | ||||
|         raise AuthenticationException(msg) | ||||
|  | ||||
| @@ -241,7 +245,7 @@ class KlapTransport(BaseTransport): | ||||
|         # Handshake 2 has the following payload: | ||||
|         #    sha256(serverBytes | authenticator) | ||||
|  | ||||
|         url = f"http://{self.host}/app/handshake2" | ||||
|         url = f"http://{self._host}/app/handshake2" | ||||
|  | ||||
|         payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash) | ||||
|  | ||||
| @@ -252,44 +256,24 @@ class KlapTransport(BaseTransport): | ||||
|                 "Handshake2 posted %s.  Host is %s, Response status is %s, " | ||||
|                 + "Request was %s", | ||||
|                 datetime.datetime.now(), | ||||
|                 self.host, | ||||
|                 self._host, | ||||
|                 response_status, | ||||
|                 payload.hex(), | ||||
|             ) | ||||
|  | ||||
|         if response_status != 200: | ||||
|             raise AuthenticationException( | ||||
|                 f"Device {self.host} responded with {response_status} to handshake2" | ||||
|                 f"Device {self._host} responded with {response_status} to handshake2" | ||||
|             ) | ||||
|  | ||||
|         return KlapEncryptionSession(local_seed, remote_seed, auth_hash) | ||||
|  | ||||
|     @property | ||||
|     def needs_login(self) -> bool: | ||||
|         """Will return false as KLAP does not do a login.""" | ||||
|         return False | ||||
|  | ||||
|     async def login(self, request: str) -> None: | ||||
|         """Will raise and exception as KLAP does not do a login.""" | ||||
|         raise SmartDeviceException( | ||||
|             "KLAP does not perform logins and return needs_login == False" | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def needs_handshake(self) -> bool: | ||||
|         """Return true if the transport needs to do a handshake.""" | ||||
|         return not self._handshake_done or self._handshake_session_expired() | ||||
|  | ||||
|     async def handshake(self) -> None: | ||||
|         """Perform the encryption handshake.""" | ||||
|         await self.perform_handshake() | ||||
|  | ||||
|     async def perform_handshake(self) -> Any: | ||||
|         """Perform handshake1 and handshake2. | ||||
|  | ||||
|         Sets the encryption_session if successful. | ||||
|         """ | ||||
|         _LOGGER.debug("Starting handshake with %s", self.host) | ||||
|         _LOGGER.debug("Starting handshake with %s", self._host) | ||||
|         self._handshake_done = False | ||||
|         self._session_expire_at = None | ||||
|         self._session_cookie = None | ||||
| @@ -307,7 +291,7 @@ class KlapTransport(BaseTransport): | ||||
|         ) | ||||
|         self._handshake_done = True | ||||
|  | ||||
|         _LOGGER.debug("Handshake with %s complete", self.host) | ||||
|         _LOGGER.debug("Handshake with %s complete", self._host) | ||||
|  | ||||
|     def _handshake_session_expired(self): | ||||
|         """Return true if session has expired.""" | ||||
| @@ -318,18 +302,14 @@ class KlapTransport(BaseTransport): | ||||
|  | ||||
|     async def send(self, request: str): | ||||
|         """Send the request.""" | ||||
|         if self.needs_handshake: | ||||
|             raise SmartDeviceException( | ||||
|                 "Handshake must be complete before trying to send" | ||||
|             ) | ||||
|         if self.needs_login: | ||||
|             raise SmartDeviceException("Login must be complete before trying to send") | ||||
|         if not self._handshake_done or self._handshake_session_expired(): | ||||
|             await self.perform_handshake() | ||||
|  | ||||
|         # Check for mypy | ||||
|         if self._encryption_session is not None: | ||||
|             payload, seq = self._encryption_session.encrypt(request.encode()) | ||||
|  | ||||
|         url = f"http://{self.host}/app/request" | ||||
|         url = f"http://{self._host}/app/request" | ||||
|  | ||||
|         response_status, response_data = await self.client_post( | ||||
|             url, | ||||
| @@ -338,7 +318,7 @@ class KlapTransport(BaseTransport): | ||||
|         ) | ||||
|  | ||||
|         msg = ( | ||||
|             f"at {datetime.datetime.now()}.  Host is {self.host}, " | ||||
|             f"at {datetime.datetime.now()}.  Host is {self._host}, " | ||||
|             + f"Sequence is {seq}, " | ||||
|             + f"Response status is {response_status}, Request was {request}" | ||||
|         ) | ||||
| @@ -348,12 +328,12 @@ class KlapTransport(BaseTransport): | ||||
|             if response_status == 403: | ||||
|                 self._handshake_done = False | ||||
|                 raise AuthenticationException( | ||||
|                     f"Got a security error from {self.host} after handshake " | ||||
|                     f"Got a security error from {self._host} after handshake " | ||||
|                     + "completed" | ||||
|                 ) | ||||
|             else: | ||||
|                 raise SmartDeviceException( | ||||
|                     f"Device {self.host} responded with {response_status} to" | ||||
|                     f"Device {self._host} responded with {response_status} to" | ||||
|                     + f"request with seq {seq}" | ||||
|                 ) | ||||
|         else: | ||||
| @@ -367,7 +347,7 @@ class KlapTransport(BaseTransport): | ||||
|  | ||||
|             _LOGGER.debug( | ||||
|                 "%s << %s", | ||||
|                 self.host, | ||||
|                 self._host, | ||||
|                 _LOGGER.isEnabledFor(logging.DEBUG) and pf(json_payload), | ||||
|             ) | ||||
|  | ||||
|   | ||||
							
								
								
									
										108
									
								
								kasa/protocol.py
									
									
									
									
									
								
							
							
						
						
									
										108
									
								
								kasa/protocol.py
									
									
									
									
									
								
							| @@ -44,35 +44,21 @@ def md5(payload: bytes) -> bytes: | ||||
| class BaseTransport(ABC): | ||||
|     """Base class for all TP-Link protocol transports.""" | ||||
|  | ||||
|     DEFAULT_TIMEOUT = 5 | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         host: str, | ||||
|         *, | ||||
|         port: Optional[int] = None, | ||||
|         credentials: Optional[Credentials] = None, | ||||
|         timeout: Optional[int] = None, | ||||
|     ) -> None: | ||||
|         """Create a protocol object.""" | ||||
|         self.host = host | ||||
|         self.port = port | ||||
|         self.credentials = credentials | ||||
|  | ||||
|     @property | ||||
|     @abstractmethod | ||||
|     def needs_handshake(self) -> bool: | ||||
|         """Return true if the transport needs to do a handshake.""" | ||||
|  | ||||
|     @property | ||||
|     @abstractmethod | ||||
|     def needs_login(self) -> bool: | ||||
|         """Return true if the transport needs to do a login.""" | ||||
|  | ||||
|     @abstractmethod | ||||
|     async def login(self, request: str) -> None: | ||||
|         """Login to the device.""" | ||||
|  | ||||
|     @abstractmethod | ||||
|     async def handshake(self) -> None: | ||||
|         """Perform the encryption handshake.""" | ||||
|         self._host = host | ||||
|         self._port = port | ||||
|         self._credentials = credentials or Credentials(username="", password="") | ||||
|         self._timeout = timeout or self.DEFAULT_TIMEOUT | ||||
|  | ||||
|     @abstractmethod | ||||
|     async def send(self, request: str) -> Dict: | ||||
| @@ -90,14 +76,14 @@ class TPLinkProtocol(ABC): | ||||
|         self, | ||||
|         host: str, | ||||
|         *, | ||||
|         port: Optional[int] = None, | ||||
|         credentials: Optional[Credentials] = None, | ||||
|         transport: Optional[BaseTransport] = None, | ||||
|         transport: BaseTransport, | ||||
|     ) -> None: | ||||
|         """Create a protocol object.""" | ||||
|         self.host = host | ||||
|         self.port = port | ||||
|         self.credentials = credentials | ||||
|         self._transport = transport | ||||
|  | ||||
|     @property | ||||
|     def _host(self): | ||||
|         return self._transport._host | ||||
|  | ||||
|     @abstractmethod | ||||
|     async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: | ||||
| @@ -108,6 +94,40 @@ class TPLinkProtocol(ABC): | ||||
|         """Close the protocol.  Abstract method to be overriden.""" | ||||
|  | ||||
|  | ||||
| class _XorTransport(BaseTransport): | ||||
|     """Implementation of the Xor encryption transport. | ||||
|  | ||||
|     WIP, currently only to ensure consistent __init__ method signatures | ||||
|     for protocol classes.  Will eventually incorporate the logic from | ||||
|     TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol | ||||
|     class. | ||||
|     """ | ||||
|  | ||||
|     DEFAULT_PORT = 9999 | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         host: str, | ||||
|         *, | ||||
|         port: Optional[int] = None, | ||||
|         credentials: Optional[Credentials] = None, | ||||
|         timeout: Optional[int] = None, | ||||
|     ) -> None: | ||||
|         super().__init__( | ||||
|             host, | ||||
|             port=port or self.DEFAULT_PORT, | ||||
|             credentials=credentials, | ||||
|             timeout=timeout, | ||||
|         ) | ||||
|  | ||||
|     async def send(self, request: str) -> Dict: | ||||
|         """Send a message to the device and return a response.""" | ||||
|         return {} | ||||
|  | ||||
|     async def close(self) -> None: | ||||
|         """Close the transport.  Abstract method to be overriden.""" | ||||
|  | ||||
|  | ||||
| class TPLinkSmartHomeProtocol(TPLinkProtocol): | ||||
|     """Implementation of the TP-Link Smart Home protocol.""" | ||||
|  | ||||
| @@ -120,20 +140,18 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): | ||||
|         self, | ||||
|         host: str, | ||||
|         *, | ||||
|         port: Optional[int] = None, | ||||
|         timeout: Optional[int] = None, | ||||
|         credentials: Optional[Credentials] = None, | ||||
|         transport: BaseTransport, | ||||
|     ) -> None: | ||||
|         """Create a protocol object.""" | ||||
|         super().__init__( | ||||
|             host=host, port=port or self.DEFAULT_PORT, credentials=credentials | ||||
|         ) | ||||
|         super().__init__(host, transport=transport) | ||||
|  | ||||
|         self.reader: Optional[asyncio.StreamReader] = None | ||||
|         self.writer: Optional[asyncio.StreamWriter] = None | ||||
|         self.query_lock = asyncio.Lock() | ||||
|         self.loop: Optional[asyncio.AbstractEventLoop] = None | ||||
|         self.timeout = timeout or TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT | ||||
|  | ||||
|         self._timeout = self._transport._timeout | ||||
|         self._port = self._transport._port | ||||
|  | ||||
|     async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: | ||||
|         """Request information from a TP-Link SmartHome Device. | ||||
| @@ -149,7 +167,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): | ||||
|             assert isinstance(request, str)  # noqa: S101 | ||||
|  | ||||
|         async with self.query_lock: | ||||
|             return await self._query(request, retry_count, self.timeout) | ||||
|             return await self._query(request, retry_count, self._timeout) | ||||
|  | ||||
|     async def _connect(self, timeout: int) -> None: | ||||
|         """Try to connect or reconnect to the device.""" | ||||
| @@ -157,7 +175,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): | ||||
|             return | ||||
|         self.reader = self.writer = None | ||||
|  | ||||
|         task = asyncio.open_connection(self.host, self.port) | ||||
|         task = asyncio.open_connection(self._host, self._port) | ||||
|         async with asyncio_timeout(timeout): | ||||
|             self.reader, self.writer = await task | ||||
|             sock: socket.socket = self.writer.get_extra_info("socket") | ||||
| @@ -174,7 +192,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): | ||||
|         debug_log = _LOGGER.isEnabledFor(logging.DEBUG) | ||||
|  | ||||
|         if debug_log: | ||||
|             _LOGGER.debug("%s >> %s", self.host, request) | ||||
|             _LOGGER.debug("%s >> %s", self._host, request) | ||||
|         self.writer.write(TPLinkSmartHomeProtocol.encrypt(request)) | ||||
|         await self.writer.drain() | ||||
|  | ||||
| @@ -185,7 +203,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): | ||||
|         response = TPLinkSmartHomeProtocol.decrypt(buffer) | ||||
|         json_payload = json_loads(response) | ||||
|         if debug_log: | ||||
|             _LOGGER.debug("%s << %s", self.host, pf(json_payload)) | ||||
|             _LOGGER.debug("%s << %s", self._host, pf(json_payload)) | ||||
|  | ||||
|         return json_payload | ||||
|  | ||||
| @@ -219,23 +237,23 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): | ||||
|             except ConnectionRefusedError as ex: | ||||
|                 await self.close() | ||||
|                 raise SmartDeviceException( | ||||
|                     f"Unable to connect to the device: {self.host}:{self.port}: {ex}" | ||||
|                     f"Unable to connect to the device: {self._host}:{self._port}: {ex}" | ||||
|                 ) from ex | ||||
|             except OSError as ex: | ||||
|                 await self.close() | ||||
|                 if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count: | ||||
|                     raise SmartDeviceException( | ||||
|                         f"Unable to connect to the device:" | ||||
|                         f" {self.host}:{self.port}: {ex}" | ||||
|                         f" {self._host}:{self._port}: {ex}" | ||||
|                     ) from ex | ||||
|                 continue | ||||
|             except Exception as ex: | ||||
|                 await self.close() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise SmartDeviceException( | ||||
|                         f"Unable to connect to the device:" | ||||
|                         f" {self.host}:{self.port}: {ex}" | ||||
|                         f" {self._host}:{self._port}: {ex}" | ||||
|                     ) from ex | ||||
|                 continue | ||||
|  | ||||
| @@ -247,13 +265,13 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): | ||||
|             except Exception as ex: | ||||
|                 await self.close() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise SmartDeviceException( | ||||
|                         f"Unable to query the device {self.host}:{self.port}: {ex}" | ||||
|                         f"Unable to query the device {self._host}:{self._port}: {ex}" | ||||
|                     ) from ex | ||||
|  | ||||
|                 _LOGGER.debug( | ||||
|                     "Unable to query the device %s, retrying: %s", self.host, ex | ||||
|                     "Unable to query the device %s, retrying: %s", self._host, ex | ||||
|                 ) | ||||
|  | ||||
|         # make mypy happy, this should never be reached.. | ||||
|   | ||||
| @@ -24,7 +24,7 @@ from .device_type import DeviceType | ||||
| from .emeterstatus import EmeterStatus | ||||
| from .exceptions import SmartDeviceException | ||||
| from .modules import Emeter, Module | ||||
| from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol | ||||
| from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol, _XorTransport | ||||
|  | ||||
| _LOGGER = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -202,7 +202,7 @@ class SmartDevice: | ||||
|         self.host = host | ||||
|         self.port = port | ||||
|         self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol( | ||||
|             host, port=port, timeout=timeout | ||||
|             host, transport=_XorTransport(host, port=port, timeout=timeout) | ||||
|         ) | ||||
|         self.credentials = credentials | ||||
|         _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) | ||||
|   | ||||
| @@ -10,12 +10,10 @@ import logging | ||||
| import time | ||||
| import uuid | ||||
| from pprint import pformat as pf | ||||
| from typing import Dict, Optional, Union | ||||
| from typing import Dict, Union | ||||
|  | ||||
| import httpx | ||||
|  | ||||
| from .aestransport import AesTransport | ||||
| from .credentials import Credentials | ||||
| from .exceptions import ( | ||||
|     SMART_AUTHENTICATION_ERRORS, | ||||
|     SMART_RETRYABLE_ERRORS, | ||||
| @@ -36,26 +34,17 @@ logging.getLogger("httpx").propagate = False | ||||
| class SmartProtocol(TPLinkProtocol): | ||||
|     """Class for the new TPLink SMART protocol.""" | ||||
|  | ||||
|     DEFAULT_PORT = 80 | ||||
|     SLEEP_SECONDS_AFTER_TIMEOUT = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         host: str, | ||||
|         *, | ||||
|         transport: Optional[BaseTransport] = None, | ||||
|         credentials: Optional[Credentials] = None, | ||||
|         timeout: Optional[int] = None, | ||||
|         transport: BaseTransport, | ||||
|     ) -> None: | ||||
|         super().__init__(host=host, port=self.DEFAULT_PORT) | ||||
|  | ||||
|         self._credentials: Credentials = credentials or Credentials( | ||||
|             username="", password="" | ||||
|         ) | ||||
|         self._transport: BaseTransport = transport or AesTransport( | ||||
|             host, credentials=self._credentials, timeout=timeout | ||||
|         ) | ||||
|         self._terminal_uuid: Optional[str] = None | ||||
|         """Create a protocol object.""" | ||||
|         super().__init__(host, transport=transport) | ||||
|         self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode() | ||||
|         self._request_id_generator = SnowflakeId(1, 1) | ||||
|         self._query_lock = asyncio.Lock() | ||||
|  | ||||
| @@ -79,7 +68,7 @@ class SmartProtocol(TPLinkProtocol): | ||||
|                 error_code := SmartErrorCode(resp_dict.get("error_code"))  # type: ignore[arg-type] | ||||
|             ) != SmartErrorCode.SUCCESS: | ||||
|                 msg = ( | ||||
|                     f"Error querying device: {self.host}: " | ||||
|                     f"Error querying device: {self._host}: " | ||||
|                     + f"{error_code.name}({error_code.value})" | ||||
|                 ) | ||||
|                 if error_code in SMART_TIMEOUT_ERRORS: | ||||
| @@ -101,51 +90,53 @@ class SmartProtocol(TPLinkProtocol): | ||||
|             except httpx.CloseError as sdex: | ||||
|                 await self.close() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise SmartDeviceException( | ||||
|                         f"Unable to connect to the device: {self.host}: {sdex}" | ||||
|                         f"Unable to connect to the device: {self._host}: {sdex}" | ||||
|                     ) from sdex | ||||
|                 continue | ||||
|             except httpx.ConnectError as cex: | ||||
|                 await self.close() | ||||
|                 raise SmartDeviceException( | ||||
|                     f"Unable to connect to the device: {self.host}: {cex}" | ||||
|                     f"Unable to connect to the device: {self._host}: {cex}" | ||||
|                 ) from cex | ||||
|             except TimeoutError as tex: | ||||
|                 if retry >= retry_count: | ||||
|                     await self.close() | ||||
|                     raise SmartDeviceException( | ||||
|                         "Unable to connect to the device, " | ||||
|                         + f"timed out: {self.host}: {tex}" | ||||
|                         + f"timed out: {self._host}: {tex}" | ||||
|                     ) from tex | ||||
|                 await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) | ||||
|                 continue | ||||
|             except AuthenticationException as auex: | ||||
|                 await self.close() | ||||
|                 _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) | ||||
|                 _LOGGER.debug( | ||||
|                     "Unable to authenticate with %s, not retrying", self._host | ||||
|                 ) | ||||
|                 raise auex | ||||
|             except RetryableException as ex: | ||||
|                 if retry >= retry_count: | ||||
|                     await self.close() | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise ex | ||||
|                 continue | ||||
|             except TimeoutException as ex: | ||||
|                 if retry >= retry_count: | ||||
|                     await self.close() | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise ex | ||||
|                 await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) | ||||
|                 continue | ||||
|             except Exception as ex: | ||||
|                 if retry >= retry_count: | ||||
|                     await self.close() | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise SmartDeviceException( | ||||
|                         f"Unable to query the device {self.host}:{self.port}: {ex}" | ||||
|                         f"Unable to connect to the device: {self._host}: {ex}" | ||||
|                     ) from ex | ||||
|                 _LOGGER.debug( | ||||
|                     "Unable to query the device %s, retrying: %s", self.host, ex | ||||
|                     "Unable to query the device %s, retrying: %s", self._host, ex | ||||
|                 ) | ||||
|                 continue | ||||
|  | ||||
| @@ -160,27 +151,17 @@ class SmartProtocol(TPLinkProtocol): | ||||
|             smart_method = request | ||||
|             smart_params = None | ||||
|  | ||||
|         if self._transport.needs_handshake: | ||||
|             await self._transport.handshake() | ||||
|  | ||||
|         if self._transport.needs_login: | ||||
|             self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode( | ||||
|                 "UTF-8" | ||||
|             ) | ||||
|             login_request = self.get_smart_request("login_device") | ||||
|             await self._transport.login(login_request) | ||||
|  | ||||
|         smart_request = self.get_smart_request(smart_method, smart_params) | ||||
|         _LOGGER.debug( | ||||
|             "%s >> %s", | ||||
|             self.host, | ||||
|             self._host, | ||||
|             _LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request), | ||||
|         ) | ||||
|         response_data = await self._transport.send(smart_request) | ||||
|  | ||||
|         _LOGGER.debug( | ||||
|             "%s << %s", | ||||
|             self.host, | ||||
|             self._host, | ||||
|             _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), | ||||
|         ) | ||||
|  | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import logging | ||||
| from datetime import datetime, timedelta, timezone | ||||
| from typing import Any, Dict, Optional, Set, cast | ||||
|  | ||||
| from ..aestransport import AesTransport | ||||
| from ..credentials import Credentials | ||||
| from ..exceptions import AuthenticationException | ||||
| from ..smartdevice import SmartDevice | ||||
| @@ -27,7 +28,12 @@ class TapoDevice(SmartDevice): | ||||
|         self._components: Optional[Dict[str, Any]] = None | ||||
|         self._state_information: Dict[str, Any] = {} | ||||
|         self._discovery_info: Optional[Dict[str, Any]] = None | ||||
|         self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout) | ||||
|         self.protocol = SmartProtocol( | ||||
|             host, | ||||
|             transport=AesTransport( | ||||
|                 host, credentials=credentials, timeout=timeout, port=port | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     async def update(self, update_children: bool = True): | ||||
|         """Update the device.""" | ||||
|   | ||||
| @@ -301,6 +301,9 @@ class FakeSmartProtocol(SmartProtocol): | ||||
|  | ||||
| class FakeSmartTransport(BaseTransport): | ||||
|     def __init__(self, info): | ||||
|         super().__init__( | ||||
|             "127.0.0.123", | ||||
|         ) | ||||
|         self.info = info | ||||
|  | ||||
|     @property | ||||
|   | ||||
							
								
								
									
										174
									
								
								kasa/tests/test_aestransport.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										174
									
								
								kasa/tests/test_aestransport.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,174 @@ | ||||
| import base64 | ||||
| import json | ||||
| import time | ||||
| from contextlib import nullcontext as does_not_raise | ||||
| from json import dumps as json_dumps | ||||
| from json import loads as json_loads | ||||
|  | ||||
| import httpx | ||||
| import pytest | ||||
| from cryptography.hazmat.primitives import serialization | ||||
| from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding | ||||
|  | ||||
| from ..aestransport import AesEncyptionSession, AesTransport | ||||
| from ..credentials import Credentials | ||||
| from ..exceptions import SmartDeviceException | ||||
|  | ||||
| DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} | ||||
|  | ||||
| key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t" | ||||
| iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa" | ||||
| KEY_IV = key + iv | ||||
|  | ||||
|  | ||||
| def test_encrypt(): | ||||
|     encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:]) | ||||
|  | ||||
|     d = json.dumps({"foo": 1, "bar": 2}) | ||||
|     encrypted = encryption_session.encrypt(d.encode()) | ||||
|     assert d == encryption_session.decrypt(encrypted) | ||||
|  | ||||
|     # test encrypt unicode | ||||
|     d = "{'snowman': '\u2603'}" | ||||
|     encrypted = encryption_session.encrypt(d.encode()) | ||||
|     assert d == encryption_session.decrypt(encrypted) | ||||
|  | ||||
|  | ||||
| status_parameters = pytest.mark.parametrize( | ||||
|     "status_code, error_code, inner_error_code, expectation", | ||||
|     [ | ||||
|         (200, 0, 0, does_not_raise()), | ||||
|         (400, 0, 0, pytest.raises(SmartDeviceException)), | ||||
|         (200, -1, 0, pytest.raises(SmartDeviceException)), | ||||
|     ], | ||||
|     ids=("success", "status_code", "error_code"), | ||||
| ) | ||||
|  | ||||
|  | ||||
| @status_parameters | ||||
| async def test_handshake( | ||||
|     mocker, status_code, error_code, inner_error_code, expectation | ||||
| ): | ||||
|     host = "127.0.0.1" | ||||
|     mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) | ||||
|     mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) | ||||
|  | ||||
|     transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) | ||||
|  | ||||
|     assert transport._encryption_session is None | ||||
|     assert transport._handshake_done is False | ||||
|     with expectation: | ||||
|         await transport.perform_handshake() | ||||
|         assert transport._encryption_session is not None | ||||
|         assert transport._handshake_done is True | ||||
|  | ||||
|  | ||||
| @status_parameters | ||||
| async def test_login(mocker, status_code, error_code, inner_error_code, expectation): | ||||
|     host = "127.0.0.1" | ||||
|     mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) | ||||
|     mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) | ||||
|  | ||||
|     transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) | ||||
|     transport._handshake_done = True | ||||
|     transport._session_expire_at = time.time() + 86400 | ||||
|     transport._encryption_session = mock_aes_device.encryption_session | ||||
|  | ||||
|     assert transport._login_token is None | ||||
|     with expectation: | ||||
|         await transport.perform_login() | ||||
|         assert transport._login_token == mock_aes_device.token | ||||
|  | ||||
|  | ||||
| @status_parameters | ||||
| async def test_send(mocker, status_code, error_code, inner_error_code, expectation): | ||||
|     host = "127.0.0.1" | ||||
|     mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) | ||||
|     mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) | ||||
|  | ||||
|     transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) | ||||
|     transport._handshake_done = True | ||||
|     transport._session_expire_at = time.time() + 86400 | ||||
|     transport._encryption_session = mock_aes_device.encryption_session | ||||
|     transport._login_token = mock_aes_device.token | ||||
|  | ||||
|     un, pw = transport.hash_credentials(True) | ||||
|     request = { | ||||
|         "method": "get_device_info", | ||||
|         "params": None, | ||||
|         "request_time_milis": round(time.time() * 1000), | ||||
|         "requestID": 1, | ||||
|         "terminal_uuid": "foobar", | ||||
|     } | ||||
|     with expectation: | ||||
|         res = await transport.send(json_dumps(request)) | ||||
|         assert "result" in res | ||||
|  | ||||
|  | ||||
| class MockAesDevice: | ||||
|     class _mock_response: | ||||
|         def __init__(self, status_code, json: dict): | ||||
|             self.status_code = status_code | ||||
|             self._json = json | ||||
|  | ||||
|         def json(self): | ||||
|             return self._json | ||||
|  | ||||
|     encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:]) | ||||
|     token = "test_token"  # noqa | ||||
|  | ||||
|     def __init__(self, host, status_code=200, error_code=0, inner_error_code=0): | ||||
|         self.host = host | ||||
|         self.status_code = status_code | ||||
|         self.error_code = error_code | ||||
|         self.inner_error_code = inner_error_code | ||||
|  | ||||
|     async def post(self, url, params=None, json=None, *_, **__): | ||||
|         return await self._post(url, json) | ||||
|  | ||||
|     async def _post(self, url, json): | ||||
|         if json["method"] == "handshake": | ||||
|             return await self._return_handshake_response(url, json) | ||||
|         elif json["method"] == "securePassthrough": | ||||
|             return await self._return_secure_passthrough_response(url, json) | ||||
|         elif json["method"] == "login_device": | ||||
|             return await self._return_login_response(url, json) | ||||
|         else: | ||||
|             assert url == f"http://{self.host}/app?token={self.token}" | ||||
|             return await self._return_send_response(url, json) | ||||
|  | ||||
|     async def _return_handshake_response(self, url, json): | ||||
|         start = len("-----BEGIN PUBLIC KEY-----\n") | ||||
|         end = len("\n-----END PUBLIC KEY-----\n") | ||||
|         client_pub_key = json["params"]["key"][start:-end] | ||||
|  | ||||
|         client_pub_key_data = base64.b64decode(client_pub_key.encode()) | ||||
|         client_pub_key = serialization.load_der_public_key(client_pub_key_data, None) | ||||
|         encrypted_key = client_pub_key.encrypt(KEY_IV, asymmetric_padding.PKCS1v15()) | ||||
|         key_64 = base64.b64encode(encrypted_key).decode() | ||||
|         return self._mock_response( | ||||
|             self.status_code, {"result": {"key": key_64}, "error_code": self.error_code} | ||||
|         ) | ||||
|  | ||||
|     async def _return_secure_passthrough_response(self, url, json): | ||||
|         encrypted_request = json["params"]["request"] | ||||
|         decrypted_request = self.encryption_session.decrypt(encrypted_request.encode()) | ||||
|         decrypted_request_dict = json_loads(decrypted_request) | ||||
|         decrypted_response = await self._post(url, decrypted_request_dict) | ||||
|         decrypted_response_dict = decrypted_response.json() | ||||
|         encrypted_response = self.encryption_session.encrypt( | ||||
|             json_dumps(decrypted_response_dict).encode() | ||||
|         ) | ||||
|         result = { | ||||
|             "result": {"response": encrypted_response.decode()}, | ||||
|             "error_code": self.error_code, | ||||
|         } | ||||
|         return self._mock_response(self.status_code, result) | ||||
|  | ||||
|     async def _return_login_response(self, url, json): | ||||
|         result = {"result": {"token": self.token}, "error_code": self.inner_error_code} | ||||
|         return self._mock_response(self.status_code, result) | ||||
|  | ||||
|     async def _return_send_response(self, url, json): | ||||
|         result = {"result": {"method": None}, "error_code": self.inner_error_code} | ||||
|         return self._mock_response(self.status_code, result) | ||||
| @@ -96,10 +96,9 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport | ||||
|  | ||||
|         return mock_response | ||||
|  | ||||
|     mocker.patch.object( | ||||
|         transport_class, "needs_handshake", property(lambda self: False) | ||||
|     ) | ||||
|     mocker.patch.object(transport_class, "needs_login", property(lambda self: False)) | ||||
|     mocker.patch.object(transport_class, "perform_handshake") | ||||
|     if hasattr(transport_class, "perform_login"): | ||||
|         mocker.patch.object(transport_class, "perform_login") | ||||
|  | ||||
|     send_mock = mocker.patch.object( | ||||
|         transport_class, | ||||
| @@ -128,7 +127,7 @@ async def test_protocol_logging(mocker, caplog, log_level): | ||||
|     seed = secrets.token_bytes(16) | ||||
|     auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) | ||||
|     encryption_session = KlapEncryptionSession(seed, seed, auth_hash) | ||||
|     protocol = IotProtocol("127.0.0.1") | ||||
|     protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1")) | ||||
|  | ||||
|     protocol._transport._handshake_done = True | ||||
|     protocol._transport._session_expire_at = time.time() + 86400 | ||||
| @@ -206,7 +205,10 @@ async def test_handshake1(mocker, device_credentials, expectation): | ||||
|         httpx.AsyncClient, "post", side_effect=_return_handshake1_response | ||||
|     ) | ||||
|  | ||||
|     protocol = IotProtocol("127.0.0.1", credentials=client_credentials) | ||||
|     protocol = IotProtocol( | ||||
|         "127.0.0.1", | ||||
|         transport=KlapTransport("127.0.0.1", credentials=client_credentials), | ||||
|     ) | ||||
|  | ||||
|     protocol._transport.http_client = httpx.AsyncClient() | ||||
|     with expectation: | ||||
| @@ -243,7 +245,10 @@ async def test_handshake(mocker): | ||||
|         httpx.AsyncClient, "post", side_effect=_return_handshake_response | ||||
|     ) | ||||
|  | ||||
|     protocol = IotProtocol("127.0.0.1", credentials=client_credentials) | ||||
|     protocol = IotProtocol( | ||||
|         "127.0.0.1", | ||||
|         transport=KlapTransport("127.0.0.1", credentials=client_credentials), | ||||
|     ) | ||||
|     protocol._transport.http_client = httpx.AsyncClient() | ||||
|  | ||||
|     response_status = 200 | ||||
| @@ -289,7 +294,10 @@ async def test_query(mocker): | ||||
|  | ||||
|     mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) | ||||
|  | ||||
|     protocol = IotProtocol("127.0.0.1", credentials=client_credentials) | ||||
|     protocol = IotProtocol( | ||||
|         "127.0.0.1", | ||||
|         transport=KlapTransport("127.0.0.1", credentials=client_credentials), | ||||
|     ) | ||||
|  | ||||
|     for _ in range(10): | ||||
|         resp = await protocol.query({}) | ||||
| @@ -333,7 +341,10 @@ async def test_authentication_failures(mocker, response_status, expectation): | ||||
|  | ||||
|     mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) | ||||
|  | ||||
|     protocol = IotProtocol("127.0.0.1", credentials=client_credentials) | ||||
|     protocol = IotProtocol( | ||||
|         "127.0.0.1", | ||||
|         transport=KlapTransport("127.0.0.1", credentials=client_credentials), | ||||
|     ) | ||||
|  | ||||
|     with expectation: | ||||
|         await protocol.query({}) | ||||
|   | ||||
| @@ -1,13 +1,21 @@ | ||||
| import errno | ||||
| import importlib | ||||
| import inspect | ||||
| import json | ||||
| import logging | ||||
| import pkgutil | ||||
| import struct | ||||
| import sys | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from ..exceptions import SmartDeviceException | ||||
| from ..protocol import TPLinkSmartHomeProtocol | ||||
| from ..protocol import ( | ||||
|     BaseTransport, | ||||
|     TPLinkProtocol, | ||||
|     TPLinkSmartHomeProtocol, | ||||
|     _XorTransport, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("retry_count", [1, 3, 5]) | ||||
| @@ -24,7 +32,9 @@ async def test_protocol_retries(mocker, retry_count): | ||||
|  | ||||
|     conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) | ||||
|     with pytest.raises(SmartDeviceException): | ||||
|         await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count) | ||||
|         await TPLinkSmartHomeProtocol( | ||||
|             "127.0.0.1", transport=_XorTransport("127.0.0.1") | ||||
|         ).query({}, retry_count=retry_count) | ||||
|  | ||||
|     assert conn.call_count == retry_count + 1 | ||||
|  | ||||
| @@ -35,7 +45,9 @@ async def test_protocol_no_retry_on_unreachable(mocker): | ||||
|         side_effect=OSError(errno.EHOSTUNREACH, "No route to host"), | ||||
|     ) | ||||
|     with pytest.raises(SmartDeviceException): | ||||
|         await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5) | ||||
|         await TPLinkSmartHomeProtocol( | ||||
|             "127.0.0.1", transport=_XorTransport("127.0.0.1") | ||||
|         ).query({}, retry_count=5) | ||||
|  | ||||
|     assert conn.call_count == 1 | ||||
|  | ||||
| @@ -46,7 +58,9 @@ async def test_protocol_no_retry_connection_refused(mocker): | ||||
|         side_effect=ConnectionRefusedError, | ||||
|     ) | ||||
|     with pytest.raises(SmartDeviceException): | ||||
|         await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5) | ||||
|         await TPLinkSmartHomeProtocol( | ||||
|             "127.0.0.1", transport=_XorTransport("127.0.0.1") | ||||
|         ).query({}, retry_count=5) | ||||
|  | ||||
|     assert conn.call_count == 1 | ||||
|  | ||||
| @@ -57,7 +71,9 @@ async def test_protocol_retry_recoverable_error(mocker): | ||||
|         side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"), | ||||
|     ) | ||||
|     with pytest.raises(SmartDeviceException): | ||||
|         await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5) | ||||
|         await TPLinkSmartHomeProtocol( | ||||
|             "127.0.0.1", transport=_XorTransport("127.0.0.1") | ||||
|         ).query({}, retry_count=5) | ||||
|  | ||||
|     assert conn.call_count == 6 | ||||
|  | ||||
| @@ -91,7 +107,9 @@ async def test_protocol_reconnect(mocker, retry_count): | ||||
|         mocker.patch.object(reader, "readexactly", _mock_read) | ||||
|         return reader, writer | ||||
|  | ||||
|     protocol = TPLinkSmartHomeProtocol("127.0.0.1") | ||||
|     protocol = TPLinkSmartHomeProtocol( | ||||
|         "127.0.0.1", transport=_XorTransport("127.0.0.1") | ||||
|     ) | ||||
|     mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) | ||||
|     response = await protocol.query({}, retry_count=retry_count) | ||||
|     assert response == {"great": "success"} | ||||
| @@ -119,7 +137,9 @@ async def test_protocol_logging(mocker, caplog, log_level): | ||||
|         mocker.patch.object(reader, "readexactly", _mock_read) | ||||
|         return reader, writer | ||||
|  | ||||
|     protocol = TPLinkSmartHomeProtocol("127.0.0.1") | ||||
|     protocol = TPLinkSmartHomeProtocol( | ||||
|         "127.0.0.1", transport=_XorTransport("127.0.0.1") | ||||
|     ) | ||||
|     mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) | ||||
|     response = await protocol.query({}) | ||||
|     assert response == {"great": "success"} | ||||
| @@ -153,7 +173,9 @@ async def test_protocol_custom_port(mocker, custom_port): | ||||
|         mocker.patch.object(reader, "readexactly", _mock_read) | ||||
|         return reader, writer | ||||
|  | ||||
|     protocol = TPLinkSmartHomeProtocol("127.0.0.1", port=custom_port) | ||||
|     protocol = TPLinkSmartHomeProtocol( | ||||
|         "127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port) | ||||
|     ) | ||||
|     mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) | ||||
|     response = await protocol.query({}) | ||||
|     assert response == {"great": "success"} | ||||
| @@ -227,3 +249,63 @@ def test_decrypt_unicode(): | ||||
|     d = "{'snowman': '\u2603'}" | ||||
|  | ||||
|     assert d == TPLinkSmartHomeProtocol.decrypt(e) | ||||
|  | ||||
|  | ||||
| def _get_subclasses(of_class): | ||||
|     import kasa | ||||
|  | ||||
|     package = sys.modules["kasa"] | ||||
|     subclasses = set() | ||||
|     for _, modname, _ in pkgutil.iter_modules(package.__path__): | ||||
|         importlib.import_module("." + modname, package="kasa") | ||||
|         module = sys.modules["kasa." + modname] | ||||
|         for name, obj in inspect.getmembers(module): | ||||
|             if inspect.isclass(obj) and issubclass(obj, of_class): | ||||
|                 subclasses.add((name, obj)) | ||||
|     return subclasses | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "class_name_obj", _get_subclasses(TPLinkProtocol), ids=lambda t: t[0] | ||||
| ) | ||||
| def test_protocol_init_signature(class_name_obj): | ||||
|     params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) | ||||
|  | ||||
|     assert len(params) == 3 | ||||
|     assert ( | ||||
|         params[0].name == "self" | ||||
|         and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD | ||||
|     ) | ||||
|     assert ( | ||||
|         params[1].name == "host" | ||||
|         and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD | ||||
|     ) | ||||
|     assert ( | ||||
|         params[2].name == "transport" | ||||
|         and params[2].kind == inspect.Parameter.KEYWORD_ONLY | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "class_name_obj", _get_subclasses(BaseTransport), ids=lambda t: t[0] | ||||
| ) | ||||
| def test_transport_init_signature(class_name_obj): | ||||
|     params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) | ||||
|  | ||||
|     assert len(params) == 5 | ||||
|     assert ( | ||||
|         params[0].name == "self" | ||||
|         and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD | ||||
|     ) | ||||
|     assert ( | ||||
|         params[1].name == "host" | ||||
|         and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD | ||||
|     ) | ||||
|     assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY | ||||
|     assert ( | ||||
|         params[3].name == "credentials" | ||||
|         and params[3].kind == inspect.Parameter.KEYWORD_ONLY | ||||
|     ) | ||||
|     assert ( | ||||
|         params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY | ||||
|     ) | ||||
|   | ||||
| @@ -232,7 +232,7 @@ async def test_modules_preserved(dev: SmartDevice): | ||||
| async def test_create_smart_device_with_timeout(): | ||||
|     """Make sure timeout is passed to the protocol.""" | ||||
|     dev = SmartDevice(host="127.0.0.1", timeout=100) | ||||
|     assert dev.protocol.timeout == 100 | ||||
|     assert dev.protocol._transport._timeout == 100 | ||||
|  | ||||
|  | ||||
| async def test_create_thin_wrapper(): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 sdb9696
					sdb9696