From 4a0019950661116ccf36d45c5c43b29a6f918daa Mon Sep 17 00:00:00 2001 From: sdb9696 <51370195+sdb9696@users.noreply.github.com> Date: Mon, 4 Dec 2023 18:50:05 +0000 Subject: [PATCH] Add klap support for TAPO protocol by splitting out Transports and Protocols (#557) * Add support for TAPO/SMART KLAP and seperate transports from protocols * Add tests and some review changes * Update following review * Updates following review --- kasa/__init__.py | 6 +- kasa/aesprotocol.py | 498 ------------------ kasa/aestransport.py | 338 ++++++++++++ kasa/device_factory.py | 44 +- kasa/discover.py | 44 +- kasa/iotprotocol.py | 100 ++++ kasa/{klapprotocol.py => klaptransport.py} | 299 ++++++----- kasa/protocol.py | 52 ++ kasa/smartdevice.py | 2 +- kasa/smartprotocol.py | 219 ++++++++ kasa/tapo/tapodevice.py | 4 +- kasa/tests/conftest.py | 251 +++++++-- kasa/tests/fixtures/smart/P110_1.0_1.3.0.json | 180 +++++++ kasa/tests/newfakes.py | 39 +- kasa/tests/test_cli.py | 35 +- kasa/tests/test_device_factory.py | 99 ++-- kasa/tests/test_discovery.py | 94 ++-- kasa/tests/test_klapprotocol.py | 137 +++-- kasa/tests/test_plug.py | 13 +- kasa/tests/test_readme_examples.py | 14 +- kasa/tests/test_smartdevice.py | 23 +- 21 files changed, 1604 insertions(+), 887 deletions(-) delete mode 100644 kasa/aesprotocol.py create mode 100644 kasa/aestransport.py create mode 100755 kasa/iotprotocol.py rename kasa/{klapprotocol.py => klaptransport.py} (66%) mode change 100755 => 100644 create mode 100644 kasa/smartprotocol.py create mode 100644 kasa/tests/fixtures/smart/P110_1.0_1.3.0.json diff --git a/kasa/__init__.py b/kasa/__init__.py index 989e507f..7de394c1 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -21,13 +21,14 @@ from kasa.exceptions import ( SmartDeviceException, UnsupportedDeviceException, ) -from kasa.klapprotocol import TPLinkKlap +from kasa.iotprotocol import IotProtocol from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors from kasa.smartdevice import DeviceType, SmartDevice from kasa.smartdimmer import SmartDimmer from kasa.smartlightstrip import SmartLightStrip from kasa.smartplug import SmartPlug +from kasa.smartprotocol import SmartProtocol from kasa.smartstrip import SmartStrip __version__ = version("python-kasa") @@ -37,7 +38,8 @@ __all__ = [ "Discover", "TPLinkSmartHomeProtocol", "TPLinkProtocol", - "TPLinkKlap", + "IotProtocol", + "SmartProtocol", "SmartBulb", "SmartBulbPreset", "TurnOnBehaviors", diff --git a/kasa/aesprotocol.py b/kasa/aesprotocol.py deleted file mode 100644 index 98776ce2..00000000 --- a/kasa/aesprotocol.py +++ /dev/null @@ -1,498 +0,0 @@ -"""Implementation of the TP-Link AES Protocol. - -Based on the work of https://github.com/petretiandrea/plugp100 -under compatible GNU GPL3 license. -""" - -import asyncio -import base64 -import hashlib -import logging -import time -import uuid -from pprint import pformat as pf -from typing import Dict, Optional, Union - -import httpx -from cryptography.hazmat.primitives import hashes, padding, serialization -from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes - -from .credentials import Credentials -from .exceptions import AuthenticationException, SmartDeviceException -from .json import dumps as json_dumps -from .json import loads as json_loads -from .protocol import TPLinkProtocol - -_LOGGER = logging.getLogger(__name__) -logging.getLogger("httpx").propagate = False - - -def _md5(payload: bytes) -> bytes: - digest = hashes.Hash(hashes.MD5()) # noqa: S303 - digest.update(payload) - hash = digest.finalize() - return hash - - -def _sha1(payload: bytes) -> str: - sha1_algo = hashlib.sha1() # noqa: S324 - sha1_algo.update(payload) - return sha1_algo.hexdigest() - - -class TPLinkAes(TPLinkProtocol): - """Implementation of the AES encryption protocol. - - AES is the name used in device discovery for TP-Link's TAPO encryption - protocol, sometimes used by newer firmware versions on kasa devices. - """ - - DEFAULT_PORT = 80 - DEFAULT_TIMEOUT = 5 - SESSION_COOKIE_NAME = "TP_SESSIONID" - COMMON_HEADERS = { - "Content-Type": "application/json", - "requestByApp": "true", - "Accept": "application/json", - } - - def __init__( - self, - host: str, - *, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, - ) -> None: - super().__init__(host=host, port=self.DEFAULT_PORT) - - self.credentials = ( - credentials - if credentials and credentials.username and credentials.password - else Credentials(username="", password="") - ) - - self._local_seed: Optional[bytes] = None - self.local_auth_hash = self.generate_auth_hash(self.credentials) - self.local_auth_owner = self.generate_owner_hash(self.credentials).hex() - self.kasa_setup_auth_hash = None - self.blank_auth_hash = None - self.handshake_lock = asyncio.Lock() - self.query_lock = asyncio.Lock() - self.handshake_done = False - - self.encryption_session: Optional[AesEncyptionSession] = None - self.session_expire_at: Optional[float] = None - - self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT - self.session_cookie = None - self.terminal_uuid = None - self.http_client: Optional[httpx.AsyncClient] = None - self.request_id_generator = SnowflakeId(1, 1) - self.login_token = None - - _LOGGER.debug("Created AES object for %s", self.host) - - def hash_credentials(self, credentials, try_login_version2): - """Hash the credentials.""" - if try_login_version2: - un = base64.b64encode( - _sha1(credentials.username.encode()).encode() - ).decode() - pw = base64.b64encode( - _sha1(credentials.password.encode()).encode() - ).decode() - else: - un = base64.b64encode( - _sha1(credentials.username.encode()).encode() - ).decode() - pw = base64.b64encode(credentials.password.encode()).decode() - return un, pw - - async def client_post(self, url, params=None, data=None, json=None, headers=None): - """Send an http post request to the device.""" - response_data = None - cookies = None - if self.session_cookie: - cookies = httpx.Cookies() - cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie) - self.http_client.cookies.clear() - resp = await self.http_client.post( - url, - params=params, - data=data, - json=json, - timeout=self.timeout, - cookies=cookies, - headers=self.COMMON_HEADERS, - ) - if resp.status_code == 200: - response_data = resp.json() - - return resp.status_code, response_data - - async def send_secure_passthrough(self, request): - """Send encrypted message as passthrough.""" - url = f"http://{self.host}/app" - if self.login_token: - url += f"?token={self.login_token}" - raw_request = json_dumps(request) - encrypted_payload = self.encryption_session.encrypt(raw_request.encode()) - passthrough_request = { - "method": "securePassthrough", - "params": {"request": encrypted_payload.decode()}, - } - status_code, resp_dict = await self.client_post(url, json=passthrough_request) - if status_code == 200 and resp_dict["error_code"] == 0: - response = self.encryption_session.decrypt( - resp_dict["result"]["response"].encode() - ) - resp_dict = json_loads(response) - if resp_dict["error_code"] != 0: - raise SmartDeviceException( - f"Could not complete send, response was {resp_dict}", - ) - if "result" in resp_dict: - return resp_dict["result"] - else: - raise AuthenticationException("Could not complete send") - - def get_aes_request(self, method, params=None): - """Get a request message.""" - request = { - "method": method, - "params": params, - "requestID": self.request_id_generator.generate_id(), - "request_time_milis": round(time.time() * 1000), - "terminal_uuid": self.terminal_uuid, - } - return request - - async def perform_login(self, login_v2): - """Login to the device.""" - self.login_token = None - - un, pw = self.hash_credentials(self.credentials, login_v2) - params = {"password": pw, "username": un} - request = self.get_aes_request("login_device", params) - try: - result = await self.send_secure_passthrough(request) - except SmartDeviceException as ex: - raise AuthenticationException(ex) from ex - self.login_token = result["token"] - - async def perform_handshake(self): - """Perform the handshake.""" - _LOGGER.debug("Will perform handshaking...") - _LOGGER.debug("Generating keypair") - - self.handshake_done = False - self.session_expire_at = None - self.session_cookie = None - - url = f"http://{self.host}/app" - key_pair = KeyPair.create_key_pair() - - pub_key = ( - "-----BEGIN PUBLIC KEY-----\n" - + key_pair.get_public_key() - + "\n-----END PUBLIC KEY-----\n" - ) - handshake_params = {"key": pub_key} - _LOGGER.debug(f"Handshake params: {handshake_params}") - - request_body = {"method": "handshake", "params": handshake_params} - - _LOGGER.debug(f"Request {request_body}") - - status_code, resp_dict = await self.client_post(url, json=request_body) - - _LOGGER.debug(f"Device responded with: {resp_dict}") - - if status_code == 200 and resp_dict["error_code"] == 0: - _LOGGER.debug("Decoding handshake key...") - handshake_key = resp_dict["result"]["key"] - - self.session_cookie = self.http_client.cookies.get( # type: ignore - self.SESSION_COOKIE_NAME - ) - if not self.session_cookie: - self.session_cookie = self.http_client.cookies.get( # type: ignore - "SESSIONID" - ) - - self.session_expire_at = time.time() + 86400 - self.encryption_session = AesEncyptionSession.create_from_keypair( - handshake_key, key_pair - ) - - self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode( - "UTF-8" - ) - self.handshake_done = True - - _LOGGER.debug("Handshake with %s complete", self.host) - - else: - raise AuthenticationException("Could not complete handshake") - - def handshake_session_expired(self): - """Return true if session has expired.""" - return ( - self.session_expire_at is None or self.session_expire_at - time.time() <= 0 - ) - - @staticmethod - def generate_auth_hash(creds: Credentials): - """Generate an md5 auth hash for the protocol on the supplied credentials.""" - un = creds.username or "" - pw = creds.password or "" - return _md5(_md5(un.encode()) + _md5(pw.encode())) - - @staticmethod - def generate_owner_hash(creds: Credentials): - """Return the MD5 hash of the username in this object.""" - un = creds.username or "" - return _md5(un.encode()) - - async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: - """Query the device retrying for retry_count on failure.""" - async with self.query_lock: - return await self._query(request, retry_count) - - async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: - for retry in range(retry_count + 1): - try: - return await self._execute_query(request, retry) - except httpx.CloseError as sdex: - await self.close() - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {sdex}" - ) from sdex - continue - except httpx.ConnectError as cex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {cex}" - ) from cex - except TimeoutError as tex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device, timed out: {self.host}: {tex}" - ) from tex - except AuthenticationException as auex: - _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) - raise auex - except Exception as ex: - await self.close() - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {ex}" - ) from ex - continue - - # make mypy happy, this should never be reached.. - raise SmartDeviceException("Query reached somehow to unreachable") - - async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: - _LOGGER.debug( - "%s >> %s", - self.host, - _LOGGER.isEnabledFor(logging.DEBUG) and pf(request), - ) - - if not self.http_client: - self.http_client = httpx.AsyncClient() - - if not self.handshake_done or self.handshake_session_expired(): - try: - await self.perform_handshake() - await self.perform_login(False) - except AuthenticationException: - await self.perform_handshake() - await self.perform_login(True) - - if isinstance(request, dict): - aes_method = next(iter(request)) - aes_params = request[aes_method] - else: - aes_method = request - aes_params = None - - aes_request = self.get_aes_request(aes_method, aes_params) - response_data = await self.send_secure_passthrough(aes_request) - - _LOGGER.debug( - "%s << %s", - self.host, - _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), - ) - - return response_data - - async def close(self) -> None: - """Close the protocol.""" - client = self.http_client - self.http_client = None - if client: - await client.aclose() - - -class AesEncyptionSession: - """Class for an AES encryption session.""" - - @staticmethod - def create_from_keypair(handshake_key: str, keypair): - """Create the encryption session.""" - handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8")) - private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8")) - - private_key = serialization.load_der_private_key(private_key_data, None, None) - key_and_iv = private_key.decrypt( - handshake_key_bytes, asymmetric_padding.PKCS1v15() - ) - if key_and_iv is None: - raise ValueError("Decryption failed!") - - return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:]) - - def __init__(self, key, iv): - self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv)) - self.padding_strategy = padding.PKCS7(algorithms.AES.block_size) - - def encrypt(self, data) -> bytes: - """Encrypt the message.""" - encryptor = self.cipher.encryptor() - padder = self.padding_strategy.padder() - padded_data = padder.update(data) + padder.finalize() - encrypted = encryptor.update(padded_data) + encryptor.finalize() - return base64.b64encode(encrypted) - - def decrypt(self, data) -> str: - """Decrypt the message.""" - decryptor = self.cipher.decryptor() - unpadder = self.padding_strategy.unpadder() - decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize() - unpadded_data = unpadder.update(decrypted) + unpadder.finalize() - return unpadded_data.decode() - - -class KeyPair: - """Class for generating key pairs.""" - - @staticmethod - def create_key_pair(key_size: int = 1024): - """Create a key pair.""" - private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) - public_key = private_key.public_key() - - private_key_bytes = private_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - public_key_bytes = public_key.public_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) - - return KeyPair( - private_key=base64.b64encode(private_key_bytes).decode("UTF-8"), - public_key=base64.b64encode(public_key_bytes).decode("UTF-8"), - ) - - def __init__(self, private_key: str, public_key: str): - self.private_key = private_key - self.public_key = public_key - - def get_private_key(self) -> str: - """Get the private key.""" - return self.private_key - - def get_public_key(self) -> str: - """Get the public key.""" - return self.public_key - - -class SnowflakeId: - """Class for generating snowflake ids.""" - - EPOCH = 1420041600000 # Custom epoch (in milliseconds) - WORKER_ID_BITS = 5 - DATA_CENTER_ID_BITS = 5 - SEQUENCE_BITS = 12 - - MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1 - MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1 - - SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1 - - def __init__(self, worker_id, data_center_id): - if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0: - raise ValueError( - "Worker ID can't be greater than " - + str(SnowflakeId.MAX_WORKER_ID) - + " or less than 0" - ) - if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0: - raise ValueError( - "Data center ID can't be greater than " - + str(SnowflakeId.MAX_DATA_CENTER_ID) - + " or less than 0" - ) - - self.worker_id = worker_id - self.data_center_id = data_center_id - self.sequence = 0 - self.last_timestamp = -1 - - def generate_id(self): - """Generate a snowflake id.""" - timestamp = self._current_millis() - - if timestamp < self.last_timestamp: - raise ValueError("Clock moved backwards. Refusing to generate ID.") - - if timestamp == self.last_timestamp: - # Within the same millisecond, increment the sequence number - self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK - if self.sequence == 0: - # Sequence exceeds its bit range, wait until the next millisecond - timestamp = self._wait_next_millis(self.last_timestamp) - else: - # New millisecond, reset the sequence number - self.sequence = 0 - - # Update the last timestamp - self.last_timestamp = timestamp - - # Generate and return the final ID - return ( - ( - (timestamp - SnowflakeId.EPOCH) - << ( - SnowflakeId.WORKER_ID_BITS - + SnowflakeId.SEQUENCE_BITS - + SnowflakeId.DATA_CENTER_ID_BITS - ) - ) - | ( - self.data_center_id - << (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS) - ) - | (self.worker_id << SnowflakeId.SEQUENCE_BITS) - | self.sequence - ) - - def _current_millis(self): - return round(time.time() * 1000) - - def _wait_next_millis(self, last_timestamp): - timestamp = self._current_millis() - while timestamp <= last_timestamp: - timestamp = self._current_millis() - return timestamp diff --git a/kasa/aestransport.py b/kasa/aestransport.py new file mode 100644 index 00000000..6757013d --- /dev/null +++ b/kasa/aestransport.py @@ -0,0 +1,338 @@ +"""Implementation of the TP-Link AES transport. + +Based on the work of https://github.com/petretiandrea/plugp100 +under compatible GNU GPL3 license. +""" + +import base64 +import hashlib +import logging +import time +from typing import Optional, Union + +import httpx +from cryptography.hazmat.primitives import padding, serialization +from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + +from .credentials import Credentials +from .exceptions import AuthenticationException, SmartDeviceException +from .json import dumps as json_dumps +from .json import loads as json_loads +from .protocol import BaseTransport + +_LOGGER = logging.getLogger(__name__) + + +def _sha1(payload: bytes) -> str: + sha1_algo = hashlib.sha1() # noqa: S324 + sha1_algo.update(payload) + return sha1_algo.hexdigest() + + +class AesTransport(BaseTransport): + """Implementation of the AES encryption protocol. + + AES is the name used in device discovery for TP-Link's TAPO encryption + protocol, sometimes used by newer firmware versions on kasa devices. + """ + + DEFAULT_TIMEOUT = 5 + SESSION_COOKIE_NAME = "TP_SESSIONID" + COMMON_HEADERS = { + "Content-Type": "application/json", + "requestByApp": "true", + "Accept": "application/json", + } + + def __init__( + self, + host: str, + *, + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, + ) -> None: + super().__init__(host=host) + + self._credentials = credentials or Credentials(username="", password="") + + self._handshake_done = False + + self._encryption_session: Optional[AesEncyptionSession] = None + self._session_expire_at: Optional[float] = None + + self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT + self._session_cookie = None + + self._http_client: httpx.AsyncClient = httpx.AsyncClient() + self._login_token = None + + _LOGGER.debug("Created AES object for %s", self.host) + + def hash_credentials(self, login_v2): + """Hash the credentials.""" + if login_v2: + un = base64.b64encode( + _sha1(self._credentials.username.encode()).encode() + ).decode() + pw = base64.b64encode( + _sha1(self._credentials.password.encode()).encode() + ).decode() + else: + un = base64.b64encode( + _sha1(self._credentials.username.encode()).encode() + ).decode() + pw = base64.b64encode(self._credentials.password.encode()).decode() + return un, pw + + async def client_post(self, url, params=None, data=None, json=None, headers=None): + """Send an http post request to the device.""" + response_data = None + cookies = None + if self._session_cookie: + cookies = httpx.Cookies() + cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie) + self._http_client.cookies.clear() + resp = await self._http_client.post( + url, + params=params, + data=data, + json=json, + timeout=self._timeout, + cookies=cookies, + headers=self.COMMON_HEADERS, + ) + if resp.status_code == 200: + response_data = resp.json() + + return resp.status_code, response_data + + async def send_secure_passthrough(self, request: str): + """Send encrypted message as passthrough.""" + url = f"http://{self.host}/app" + if self._login_token: + url += f"?token={self._login_token}" + + encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore + passthrough_request = { + "method": "securePassthrough", + "params": {"request": encrypted_payload.decode()}, + } + status_code, resp_dict = await self.client_post(url, json=passthrough_request) + _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}") + if status_code == 200 and resp_dict["error_code"] == 0: + response = self._encryption_session.decrypt( # type: ignore + resp_dict["result"]["response"].encode() + ) + _LOGGER.debug(f"decrypted secure_passthrough response is {response}") + resp_dict = json_loads(response) + return resp_dict + else: + self._handshake_done = False + self._login_token = None + raise AuthenticationException("Could not complete send") + + async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool): + """Login to the device.""" + self._login_token = None + + if isinstance(login_request, str): + login_request_dict: dict = json_loads(login_request) + else: + login_request_dict = login_request + + un, pw = self.hash_credentials(login_v2) + login_request_dict["params"] = {"password": pw, "username": un} + request = json_dumps(login_request_dict) + try: + resp_dict = await self.send_secure_passthrough(request) + except SmartDeviceException as ex: + raise AuthenticationException(ex) from ex + self._login_token = resp_dict["result"]["token"] + + @property + def needs_login(self) -> bool: + """Return true if the transport needs to do a login.""" + return self._login_token is None + + async def login(self, request: str) -> None: + """Login to the device.""" + try: + if self.needs_handshake: + raise SmartDeviceException( + "Handshake must be complete before trying to login" + ) + await self.perform_login(request, login_v2=False) + except AuthenticationException: + await self.perform_handshake() + await self.perform_login(request, login_v2=True) + + @property + def needs_handshake(self) -> bool: + """Return true if the transport needs to do a handshake.""" + return not self._handshake_done or self._handshake_session_expired() + + async def handshake(self) -> None: + """Perform the encryption handshake.""" + await self.perform_handshake() + + async def perform_handshake(self): + """Perform the handshake.""" + _LOGGER.debug("Will perform handshaking...") + _LOGGER.debug("Generating keypair") + + self._handshake_done = False + self._session_expire_at = None + self._session_cookie = None + + url = f"http://{self.host}/app" + key_pair = KeyPair.create_key_pair() + + pub_key = ( + "-----BEGIN PUBLIC KEY-----\n" + + key_pair.get_public_key() + + "\n-----END PUBLIC KEY-----\n" + ) + handshake_params = {"key": pub_key} + _LOGGER.debug(f"Handshake params: {handshake_params}") + + request_body = {"method": "handshake", "params": handshake_params} + + _LOGGER.debug(f"Request {request_body}") + + status_code, resp_dict = await self.client_post(url, json=request_body) + + _LOGGER.debug(f"Device responded with: {resp_dict}") + + if status_code == 200 and resp_dict["error_code"] == 0: + _LOGGER.debug("Decoding handshake key...") + handshake_key = resp_dict["result"]["key"] + + self._session_cookie = self._http_client.cookies.get( # type: ignore + self.SESSION_COOKIE_NAME + ) + if not self._session_cookie: + self._session_cookie = self._http_client.cookies.get( # type: ignore + "SESSIONID" + ) + + self._session_expire_at = time.time() + 86400 + self._encryption_session = AesEncyptionSession.create_from_keypair( + handshake_key, key_pair + ) + + self._handshake_done = True + + _LOGGER.debug("Handshake with %s complete", self.host) + + else: + raise AuthenticationException("Could not complete handshake") + + def _handshake_session_expired(self): + """Return true if session has expired.""" + return ( + self._session_expire_at is None + or self._session_expire_at - time.time() <= 0 + ) + + async def send(self, request: str): + """Send the request.""" + if self.needs_handshake: + raise SmartDeviceException( + "Handshake must be complete before trying to send" + ) + if self.needs_login: + raise SmartDeviceException("Login must be complete before trying to send") + + resp_dict = await self.send_secure_passthrough(request) + if resp_dict["error_code"] != 0: + self._handshake_done = False + self._login_token = None + raise SmartDeviceException( + f"Could not complete send, response was {resp_dict}", + ) + return resp_dict + + async def close(self) -> None: + """Close the protocol.""" + client = self._http_client + self._http_client = None + if client: + await client.aclose() + + +class AesEncyptionSession: + """Class for an AES encryption session.""" + + @staticmethod + def create_from_keypair(handshake_key: str, keypair): + """Create the encryption session.""" + handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8")) + private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8")) + + private_key = serialization.load_der_private_key(private_key_data, None, None) + key_and_iv = private_key.decrypt( + handshake_key_bytes, asymmetric_padding.PKCS1v15() + ) + if key_and_iv is None: + raise ValueError("Decryption failed!") + + return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:]) + + def __init__(self, key, iv): + self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv)) + self.padding_strategy = padding.PKCS7(algorithms.AES.block_size) + + def encrypt(self, data) -> bytes: + """Encrypt the message.""" + encryptor = self.cipher.encryptor() + padder = self.padding_strategy.padder() + padded_data = padder.update(data) + padder.finalize() + encrypted = encryptor.update(padded_data) + encryptor.finalize() + return base64.b64encode(encrypted) + + def decrypt(self, data) -> str: + """Decrypt the message.""" + decryptor = self.cipher.decryptor() + unpadder = self.padding_strategy.unpadder() + decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize() + unpadded_data = unpadder.update(decrypted) + unpadder.finalize() + return unpadded_data.decode() + + +class KeyPair: + """Class for generating key pairs.""" + + @staticmethod + def create_key_pair(key_size: int = 1024): + """Create a key pair.""" + private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) + public_key = private_key.public_key() + + private_key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + public_key_bytes = public_key.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + return KeyPair( + private_key=base64.b64encode(private_key_bytes).decode("UTF-8"), + public_key=base64.b64encode(public_key_bytes).decode("UTF-8"), + ) + + def __init__(self, private_key: str, public_key: str): + self.private_key = private_key + self.public_key = public_key + + def get_private_key(self) -> str: + """Get the private key.""" + return self.private_key + + def get_public_key(self) -> str: + """Get the public key.""" + return self.public_key diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 9122003c..be293ee2 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -2,17 +2,21 @@ import logging import time -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Tuple, Type +from .aestransport import AesTransport from .credentials import Credentials from .device_type import DeviceType from .exceptions import UnsupportedDeviceException -from .protocol import TPLinkProtocol +from .iotprotocol import IotProtocol +from .klaptransport import KlapTransport, TPlinkKlapTransportV2 +from .protocol import BaseTransport, TPLinkProtocol from .smartbulb import SmartBulb from .smartdevice import SmartDevice, SmartDeviceException from .smartdimmer import SmartDimmer from .smartlightstrip import SmartLightStrip from .smartplug import SmartPlug +from .smartprotocol import SmartProtocol from .smartstrip import SmartStrip from .tapo.tapoplug import TapoPlug @@ -87,7 +91,7 @@ async def connect( if protocol_class is not None: unknown_dev.protocol = protocol_class(host, credentials=credentials) await unknown_dev.update() - device_class = get_device_class_from_info(unknown_dev.internal_state) + device_class = get_device_class_from_sys_info(unknown_dev.internal_state) dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout) # Reuse the connection from the unknown device # so we don't have to reconnect @@ -104,7 +108,7 @@ async def connect( return dev -def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]: +def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]: """Find SmartDevice subclass for device described by passed data.""" if "system" not in info or "get_sysinfo" not in info["system"]: raise SmartDeviceException("No 'system' or 'get_sysinfo' in response") @@ -129,3 +133,35 @@ def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]: return SmartBulb raise UnsupportedDeviceException("Unknown device type: %s" % type_) + + +def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]: + """Return the device class from the type name.""" + supported_device_types: dict[str, Type[SmartDevice]] = { + "SMART.TAPOPLUG": TapoPlug, + "SMART.KASAPLUG": TapoPlug, + "IOT.SMARTPLUGSWITCH": SmartPlug, + } + return supported_device_types.get(device_type) + + +def get_protocol_from_connection_name( + connection_name: str, host: str, credentials: Optional[Credentials] = None +) -> Optional[TPLinkProtocol]: + """Return the protocol from the connection name.""" + supported_device_protocols: dict[ + str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]] + ] = { + "IOT.KLAP": (IotProtocol, KlapTransport), + "SMART.AES": (SmartProtocol, AesTransport), + "SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2), + } + if connection_name not in supported_device_protocols: + return None + + protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore + transport: BaseTransport = transport_class(host, credentials=credentials) + protocol: TPLinkProtocol = protocol_class( + host, credentials=credentials, transport=transport + ) + return protocol diff --git a/kasa/discover.py b/kasa/discover.py index 59849bc0..2038369b 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -15,18 +15,18 @@ try: except ImportError: from pydantic import BaseModel, Field -from kasa.aesprotocol import TPLinkAes from kasa.credentials import Credentials from kasa.exceptions import UnsupportedDeviceException from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads -from kasa.klapprotocol import TPLinkKlap -from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol +from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartdevice import SmartDevice, SmartDeviceException -from kasa.smartplug import SmartPlug -from kasa.tapo.tapoplug import TapoPlug -from .device_factory import get_device_class_from_info +from .device_factory import ( + get_device_class_from_sys_info, + get_device_class_from_type_name, + get_protocol_from_connection_name, +) _LOGGER = logging.getLogger(__name__) @@ -348,7 +348,16 @@ class Discover: @staticmethod def _get_device_class(info: dict) -> Type[SmartDevice]: """Find SmartDevice subclass for device described by passed data.""" - return get_device_class_from_info(info) + if "result" in info: + discovery_result = DiscoveryResult(**info["result"]) + dev_class = get_device_class_from_type_name(discovery_result.device_type) + if not dev_class: + raise UnsupportedDeviceException( + "Unknown device type: %s" % discovery_result.device_type + ) + return dev_class + else: + return get_device_class_from_sys_info(info) @staticmethod def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice: @@ -384,24 +393,17 @@ class Discover: encrypt_type_ = ( f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}" ) - device_class = None - supported_device_types: dict[str, Type[SmartDevice]] = { - "SMART.TAPOPLUG": TapoPlug, - "SMART.KASAPLUG": TapoPlug, - "IOT.SMARTPLUGSWITCH": SmartPlug, - } - supported_device_protocols: dict[str, Type[TPLinkProtocol]] = { - "IOT.KLAP": TPLinkKlap, - "SMART.AES": TPLinkAes, - } - - if (device_class := supported_device_types.get(type_)) is None: + if (device_class := get_device_class_from_type_name(type_)) is None: _LOGGER.warning("Got unsupported device type: %s", type_) raise UnsupportedDeviceException( f"Unsupported device {ip} of type {type_}: {info}" ) - if (protocol_class := supported_device_protocols.get(encrypt_type_)) is None: + if ( + protocol := get_protocol_from_connection_name( + encrypt_type_, ip, credentials=credentials + ) + ) is None: _LOGGER.warning("Got unsupported device type: %s", encrypt_type_) raise UnsupportedDeviceException( f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}" @@ -409,7 +411,7 @@ class Discover: _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) device = device_class(ip, port=port, credentials=credentials) - device.protocol = protocol_class(ip, credentials=credentials) + device.protocol = protocol device.update_from_discover_info(discovery_result.get_dict()) return device diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py new file mode 100755 index 00000000..2b7f422d --- /dev/null +++ b/kasa/iotprotocol.py @@ -0,0 +1,100 @@ +"""Module for the IOT legacy IOT KASA protocol.""" +import asyncio +import logging +from typing import Dict, Optional, Union + +import httpx + +from .credentials import Credentials +from .exceptions import AuthenticationException, SmartDeviceException +from .json import dumps as json_dumps +from .klaptransport import KlapTransport +from .protocol import BaseTransport, TPLinkProtocol + +_LOGGER = logging.getLogger(__name__) + + +class IotProtocol(TPLinkProtocol): + """Class for the legacy TPLink IOT KASA Protocol.""" + + DEFAULT_PORT = 80 + + def __init__( + self, + host: str, + *, + transport: Optional[BaseTransport] = None, + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, + ) -> None: + super().__init__(host=host, port=self.DEFAULT_PORT) + + self._credentials: Credentials = credentials or Credentials( + username="", password="" + ) + self._transport: BaseTransport = transport or KlapTransport( + host, credentials=self._credentials, timeout=timeout + ) + + self._query_lock = asyncio.Lock() + + async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: + """Query the device retrying for retry_count on failure.""" + if isinstance(request, dict): + request = json_dumps(request) + assert isinstance(request, str) # noqa: S101 + + async with self._query_lock: + return await self._query(request, retry_count) + + async def _query(self, request: str, retry_count: int = 3) -> Dict: + for retry in range(retry_count + 1): + try: + return await self._execute_query(request, retry) + except httpx.CloseError as sdex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {sdex}" + ) from sdex + continue + except httpx.ConnectError as cex: + await self.close() + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {cex}" + ) from cex + except TimeoutError as tex: + await self.close() + raise SmartDeviceException( + f"Unable to connect to the device, timed out: {self.host}: {tex}" + ) from tex + except AuthenticationException as auex: + _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) + raise auex + except Exception as ex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {ex}" + ) from ex + continue + + # make mypy happy, this should never be reached.. + raise SmartDeviceException("Query reached somehow to unreachable") + + async def _execute_query(self, request: str, retry_count: int) -> Dict: + if self._transport.needs_handshake: + await self._transport.handshake() + + if self._transport.needs_login: # This shouln't happen + raise SmartDeviceException( + "IOT Protocol needs to login to transport but is not login aware" + ) + + return await self._transport.send(request) + + async def close(self) -> None: + """Close the protocol.""" + await self._transport.close() diff --git a/kasa/klapprotocol.py b/kasa/klaptransport.py old mode 100755 new mode 100644 similarity index 66% rename from kasa/klapprotocol.py rename to kasa/klaptransport.py index 36a42c58..c28cb035 --- a/kasa/klapprotocol.py +++ b/kasa/klaptransport.py @@ -47,7 +47,7 @@ import logging import secrets import time from pprint import pformat as pf -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Tuple import httpx from cryptography.hazmat.primitives import hashes, padding @@ -55,33 +55,33 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from .credentials import Credentials from .exceptions import AuthenticationException, SmartDeviceException -from .json import dumps as json_dumps from .json import loads as json_loads -from .protocol import TPLinkProtocol +from .protocol import BaseTransport, md5 _LOGGER = logging.getLogger(__name__) logging.getLogger("httpx").propagate = False def _sha256(payload: bytes) -> bytes: - return hashlib.sha256(payload).digest() - - -def _md5(payload: bytes) -> bytes: - digest = hashes.Hash(hashes.MD5()) # noqa: S303 + digest = hashes.Hash(hashes.SHA256()) # noqa: S303 digest.update(payload) hash = digest.finalize() return hash -class TPLinkKlap(TPLinkProtocol): +def _sha1(payload: bytes) -> bytes: + digest = hashes.Hash(hashes.SHA1()) # noqa: S303 + digest.update(payload) + return digest.finalize() + + +class KlapTransport(BaseTransport): """Implementation of the KLAP encryption protocol. KLAP is the name used in device discovery for TP-Link's new encryption protocol, used by newer firmware versions. """ - DEFAULT_PORT = 80 DEFAULT_TIMEOUT = 5 DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} KASA_SETUP_EMAIL = "kasa@tp-link.net" @@ -95,29 +95,24 @@ class TPLinkKlap(TPLinkProtocol): credentials: Optional[Credentials] = None, timeout: Optional[int] = None, ) -> None: - super().__init__(host=host, port=self.DEFAULT_PORT) - - self.credentials = ( - credentials - if credentials and credentials.username and credentials.password - else Credentials(username="", password="") - ) + super().__init__(host=host) + self._credentials = credentials or Credentials(username="", password="") self._local_seed: Optional[bytes] = None - self.local_auth_hash = self.generate_auth_hash(self.credentials) - self.local_auth_owner = self.generate_owner_hash(self.credentials).hex() - self.kasa_setup_auth_hash = None - self.blank_auth_hash = None - self.handshake_lock = asyncio.Lock() - self.query_lock = asyncio.Lock() - self.handshake_done = False + self._local_auth_hash = self.generate_auth_hash(self._credentials) + self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() + self._kasa_setup_auth_hash = None + self._blank_auth_hash = None + self._handshake_lock = asyncio.Lock() + self._query_lock = asyncio.Lock() + self._handshake_done = False - self.encryption_session: Optional[KlapEncryptionSession] = None - self.session_expire_at: Optional[float] = None + self._encryption_session: Optional[KlapEncryptionSession] = None + self._session_expire_at: Optional[float] = None - self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT - self.session_cookie = None - self.http_client: Optional[httpx.AsyncClient] = None + self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT + self._session_cookie = None + self._http_client: httpx.AsyncClient = httpx.AsyncClient() _LOGGER.debug("Created KLAP object for %s", self.host) @@ -125,15 +120,15 @@ class TPLinkKlap(TPLinkProtocol): """Send an http post request to the device.""" response_data = None cookies = None - if self.session_cookie: + if self._session_cookie: cookies = httpx.Cookies() - cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie) - self.http_client.cookies.clear() - resp = await self.http_client.post( + cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie) + self._http_client.cookies.clear() + resp = await self._http_client.post( url, params=params, data=data, - timeout=self.timeout, + timeout=self._timeout, cookies=cookies, ) if resp.status_code == 200: @@ -183,44 +178,55 @@ class TPLinkKlap(TPLinkProtocol): server_hash.hex(), ) - local_seed_auth_hash = _sha256(local_seed + self.local_auth_hash) + local_seed_auth_hash = self.handshake1_seed_auth_hash( + local_seed, remote_seed, self._local_auth_hash + ) # type: ignore # Check the response from the device with local credentials if local_seed_auth_hash == server_hash: _LOGGER.debug("handshake1 hashes match with expected credentials") - return local_seed, remote_seed, self.local_auth_hash # type: ignore + return local_seed, remote_seed, self._local_auth_hash # type: ignore # Now check against the default kasa setup credentials - if not self.kasa_setup_auth_hash: + if not self._kasa_setup_auth_hash: kasa_setup_creds = Credentials( - username=TPLinkKlap.KASA_SETUP_EMAIL, - password=TPLinkKlap.KASA_SETUP_PASSWORD, + username=self.KASA_SETUP_EMAIL, + password=self.KASA_SETUP_PASSWORD, ) - self.kasa_setup_auth_hash = TPLinkKlap.generate_auth_hash(kasa_setup_creds) + self._kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds) - kasa_setup_seed_auth_hash = _sha256( - local_seed + self.kasa_setup_auth_hash # type: ignore + kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash( + local_seed, + remote_seed, + self._kasa_setup_auth_hash, # type: ignore ) + if kasa_setup_seed_auth_hash == server_hash: _LOGGER.debug( "Server response doesn't match our expected hash on ip %s" + " but an authentication with kasa setup credentials matched", self.host, ) - return local_seed, remote_seed, self.kasa_setup_auth_hash # type: ignore + return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore # Finally check against blank credentials if not already blank - if self.credentials != (blank_creds := Credentials(username="", password="")): - if not self.blank_auth_hash: - self.blank_auth_hash = TPLinkKlap.generate_auth_hash(blank_creds) - blank_seed_auth_hash = _sha256(local_seed + self.blank_auth_hash) # type: ignore + if self._credentials != (blank_creds := Credentials(username="", password="")): + if not self._blank_auth_hash: + self._blank_auth_hash = self.generate_auth_hash(blank_creds) + + blank_seed_auth_hash = self.handshake1_seed_auth_hash( + local_seed, + remote_seed, + self._blank_auth_hash, # type: ignore + ) + if blank_seed_auth_hash == server_hash: _LOGGER.debug( "Server response doesn't match our expected hash on ip %s" + " but an authentication with blank credentials matched", self.host, ) - return local_seed, remote_seed, self.blank_auth_hash # type: ignore + return local_seed, remote_seed, self._blank_auth_hash # type: ignore msg = f"Server response doesn't match our challenge on ip {self.host}" _LOGGER.debug(msg) @@ -235,7 +241,7 @@ class TPLinkKlap(TPLinkProtocol): url = f"http://{self.host}/app/handshake2" - payload = _sha256(remote_seed + auth_hash) + payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash) response_status, response_data = await self.client_post(url, data=payload) @@ -256,115 +262,70 @@ class TPLinkKlap(TPLinkProtocol): return KlapEncryptionSession(local_seed, remote_seed, auth_hash) + @property + def needs_login(self) -> bool: + """Will return false as KLAP does not do a login.""" + return False + + async def login(self, request: str) -> None: + """Will raise and exception as KLAP does not do a login.""" + raise SmartDeviceException( + "KLAP does not perform logins and return needs_login == False" + ) + + @property + def needs_handshake(self) -> bool: + """Return true if the transport needs to do a handshake.""" + return not self._handshake_done or self._handshake_session_expired() + + async def handshake(self) -> None: + """Perform the encryption handshake.""" + await self.perform_handshake() + async def perform_handshake(self) -> Any: """Perform handshake1 and handshake2. Sets the encryption_session if successful. """ _LOGGER.debug("Starting handshake with %s", self.host) - self.handshake_done = False - self.session_expire_at = None - self.session_cookie = None + self._handshake_done = False + self._session_expire_at = None + self._session_cookie = None local_seed, remote_seed, auth_hash = await self.perform_handshake1() - self.session_cookie = self.http_client.cookies.get( # type: ignore - TPLinkKlap.SESSION_COOKIE_NAME + self._session_cookie = self._http_client.cookies.get( # type: ignore + self.SESSION_COOKIE_NAME ) # The device returns a TIMEOUT cookie on handshake1 which # it doesn't like to get back so we store the one we want - self.session_expire_at = time.time() + 86400 - self.encryption_session = await self.perform_handshake2( + self._session_expire_at = time.time() + 86400 + self._encryption_session = await self.perform_handshake2( local_seed, remote_seed, auth_hash ) - self.handshake_done = True + self._handshake_done = True _LOGGER.debug("Handshake with %s complete", self.host) - def handshake_session_expired(self): + def _handshake_session_expired(self): """Return true if session has expired.""" return ( - self.session_expire_at is None or self.session_expire_at - time.time() <= 0 + self._session_expire_at is None + or self._session_expire_at - time.time() <= 0 ) - @staticmethod - def generate_auth_hash(creds: Credentials): - """Generate an md5 auth hash for the protocol on the supplied credentials.""" - un = creds.username or "" - pw = creds.password or "" - return _md5(_md5(un.encode()) + _md5(pw.encode())) - - @staticmethod - def generate_owner_hash(creds: Credentials): - """Return the MD5 hash of the username in this object.""" - un = creds.username or "" - return _md5(un.encode()) - - async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: - """Query the device retrying for retry_count on failure.""" - if isinstance(request, dict): - request = json_dumps(request) - assert isinstance(request, str) # noqa: S101 - - async with self.query_lock: - return await self._query(request, retry_count) - - async def _query(self, request: str, retry_count: int = 3) -> Dict: - for retry in range(retry_count + 1): - try: - return await self._execute_query(request, retry) - except httpx.CloseError as sdex: - await self.close() - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {sdex}" - ) from sdex - continue - except httpx.ConnectError as cex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {cex}" - ) from cex - except TimeoutError as tex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device, timed out: {self.host}: {tex}" - ) from tex - except AuthenticationException as auex: - _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) - raise auex - except Exception as ex: - await self.close() - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {ex}" - ) from ex - continue - - # make mypy happy, this should never be reached.. - raise SmartDeviceException("Query reached somehow to unreachable") - - async def _execute_query(self, request: str, retry_count: int) -> Dict: - if not self.http_client: - self.http_client = httpx.AsyncClient() - - if not self.handshake_done or self.handshake_session_expired(): - try: - await self.perform_handshake() - - except AuthenticationException as auex: - _LOGGER.debug( - "Unable to complete handshake for device %s, " - + "authentication failed", - self.host, - ) - raise auex + async def send(self, request: str): + """Send the request.""" + if self.needs_handshake: + raise SmartDeviceException( + "Handshake must be complete before trying to send" + ) + if self.needs_login: + raise SmartDeviceException("Login must be complete before trying to send") # Check for mypy - if self.encryption_session is not None: - payload, seq = self.encryption_session.encrypt(request.encode()) + if self._encryption_session is not None: + payload, seq = self._encryption_session.encrypt(request.encode()) url = f"http://{self.host}/app/request" @@ -376,14 +337,14 @@ class TPLinkKlap(TPLinkProtocol): msg = ( f"at {datetime.datetime.now()}. Host is {self.host}, " - + f"Retry count is {retry_count}, Sequence is {seq}, " + + f"Sequence is {seq}, " + f"Response status is {response_status}, Request was {request}" ) if response_status != 200: _LOGGER.error("Query failed after succesful authentication " + msg) # If we failed with a security error, force a new handshake next time. if response_status == 403: - self.handshake_done = False + self._handshake_done = False raise AuthenticationException( f"Got a security error from {self.host} after handshake " + "completed" @@ -397,8 +358,8 @@ class TPLinkKlap(TPLinkProtocol): _LOGGER.debug("Query posted " + msg) # Check for mypy - if self.encryption_session is not None: - decrypted_response = self.encryption_session.decrypt(response_data) + if self._encryption_session is not None: + decrypted_response = self._encryption_session.decrypt(response_data) json_payload = json_loads(decrypted_response) @@ -411,12 +372,66 @@ class TPLinkKlap(TPLinkProtocol): return json_payload async def close(self) -> None: - """Close the protocol.""" - client = self.http_client - self.http_client = None + """Close the transport.""" + client = self._http_client + self._http_client = None if client: await client.aclose() + @staticmethod + def generate_auth_hash(creds: Credentials): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + un = creds.username or "" + pw = creds.password or "" + + return md5(md5(un.encode()) + md5(pw.encode())) + + @staticmethod + def handshake1_seed_auth_hash( + local_seed: bytes, remote_seed: bytes, auth_hash: bytes + ): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + return _sha256(local_seed + auth_hash) + + @staticmethod + def handshake2_seed_auth_hash( + local_seed: bytes, remote_seed: bytes, auth_hash: bytes + ): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + return _sha256(remote_seed + auth_hash) + + @staticmethod + def generate_owner_hash(creds: Credentials): + """Return the MD5 hash of the username in this object.""" + un = creds.username or "" + return md5(un.encode()) + + +class TPlinkKlapTransportV2(KlapTransport): + """Implementation of the KLAP encryption protocol with v2 hanshake hashes.""" + + @staticmethod + def generate_auth_hash(creds: Credentials): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + un = creds.username or "" + pw = creds.password or "" + + return _sha256(_sha1(un.encode()) + _sha1(pw.encode())) + + @staticmethod + def handshake1_seed_auth_hash( + local_seed: bytes, remote_seed: bytes, auth_hash: bytes + ): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + return _sha256(local_seed + remote_seed + auth_hash) + + @staticmethod + def handshake2_seed_auth_hash( + local_seed: bytes, remote_seed: bytes, auth_hash: bytes + ): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + return _sha256(remote_seed + local_seed + auth_hash) + class KlapEncryptionSession: """Class to represent an encryption session and it's internal state. diff --git a/kasa/protocol.py b/kasa/protocol.py index 6413ba5d..62cd5fb6 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -22,6 +22,7 @@ from typing import Dict, Generator, Optional, Union # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout +from cryptography.hazmat.primitives import hashes from .credentials import Credentials from .exceptions import SmartDeviceException @@ -32,6 +33,56 @@ _LOGGER = logging.getLogger(__name__) _NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED} +def md5(payload: bytes) -> bytes: + """Return an md5 hash of the payload.""" + digest = hashes.Hash(hashes.MD5()) # noqa: S303 + digest.update(payload) + hash = digest.finalize() + return hash + + +class BaseTransport(ABC): + """Base class for all TP-Link protocol transports.""" + + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None, + ) -> None: + """Create a protocol object.""" + self.host = host + self.port = port + self.credentials = credentials + + @property + @abstractmethod + def needs_handshake(self) -> bool: + """Return true if the transport needs to do a handshake.""" + + @property + @abstractmethod + def needs_login(self) -> bool: + """Return true if the transport needs to do a login.""" + + @abstractmethod + async def login(self, request: str) -> None: + """Login to the device.""" + + @abstractmethod + async def handshake(self) -> None: + """Perform the encryption handshake.""" + + @abstractmethod + async def send(self, request: str) -> Dict: + """Send a message to the device and return a response.""" + + @abstractmethod + async def close(self) -> None: + """Close the transport. Abstract method to be overriden.""" + + class TPLinkProtocol(ABC): """Base class for all TP-Link Smart Home communication.""" @@ -41,6 +92,7 @@ class TPLinkProtocol(ABC): *, port: Optional[int] = None, credentials: Optional[Credentials] = None, + transport: Optional[BaseTransport] = None, ) -> None: """Create a protocol object.""" self.host = host diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index af6a2c7f..342d1c4a 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -365,6 +365,7 @@ class SmartDevice: def update_from_discover_info(self, info: Dict[str, Any]) -> None: """Update state from info from the discover call.""" + self._discovery_info = info if "system" in info and (sys_info := info["system"].get("get_sysinfo")): self._last_update = info self._set_sys_info(sys_info) @@ -372,7 +373,6 @@ class SmartDevice: # This allows setting of some info properties directly # from partial discovery info that will then be found # by the requires_update decorator - self._discovery_info = info self._set_sys_info(info) def _set_sys_info(self, sys_info: Dict[str, Any]) -> None: diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py new file mode 100644 index 00000000..98d1a86d --- /dev/null +++ b/kasa/smartprotocol.py @@ -0,0 +1,219 @@ +"""Implementation of the TP-Link AES Protocol. + +Based on the work of https://github.com/petretiandrea/plugp100 +under compatible GNU GPL3 license. +""" + +import asyncio +import base64 +import logging +import time +import uuid +from pprint import pformat as pf +from typing import Dict, Optional, Union + +import httpx + +from .aestransport import AesTransport +from .credentials import Credentials +from .exceptions import AuthenticationException, SmartDeviceException +from .json import dumps as json_dumps +from .protocol import BaseTransport, TPLinkProtocol, md5 + +_LOGGER = logging.getLogger(__name__) +logging.getLogger("httpx").propagate = False + + +class SmartProtocol(TPLinkProtocol): + """Class for the new TPLink SMART protocol.""" + + DEFAULT_PORT = 80 + + def __init__( + self, + host: str, + *, + transport: Optional[BaseTransport] = None, + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, + ) -> None: + super().__init__(host=host, port=self.DEFAULT_PORT) + + self._credentials: Credentials = credentials or Credentials( + username="", password="" + ) + self._transport: BaseTransport = transport or AesTransport( + host, credentials=self._credentials, timeout=timeout + ) + self._terminal_uuid: Optional[str] = None + self._request_id_generator = SnowflakeId(1, 1) + self._query_lock = asyncio.Lock() + + def get_smart_request(self, method, params=None) -> str: + """Get a request message as a string.""" + request = { + "method": method, + "params": params, + "requestID": self._request_id_generator.generate_id(), + "request_time_milis": round(time.time() * 1000), + "terminal_uuid": self._terminal_uuid, + } + return json_dumps(request) + + async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: + """Query the device retrying for retry_count on failure.""" + async with self._query_lock: + resp_dict = await self._query(request, retry_count) + if "result" in resp_dict: + return resp_dict["result"] + return {} + + async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: + for retry in range(retry_count + 1): + try: + return await self._execute_query(request, retry) + except httpx.CloseError as sdex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {sdex}" + ) from sdex + continue + except httpx.ConnectError as cex: + await self.close() + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {cex}" + ) from cex + except TimeoutError as tex: + await self.close() + raise SmartDeviceException( + f"Unable to connect to the device, timed out: {self.host}: {tex}" + ) from tex + except AuthenticationException as auex: + _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) + raise auex + except Exception as ex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {ex}" + ) from ex + continue + + # make mypy happy, this should never be reached.. + raise SmartDeviceException("Query reached somehow to unreachable") + + async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: + if isinstance(request, dict): + smart_method = next(iter(request)) + smart_params = request[smart_method] + else: + smart_method = request + smart_params = None + + if self._transport.needs_handshake: + await self._transport.handshake() + + if self._transport.needs_login: + self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode( + "UTF-8" + ) + login_request = self.get_smart_request("login_device") + await self._transport.login(login_request) + + smart_request = self.get_smart_request(smart_method, smart_params) + response_data = await self._transport.send(smart_request) + + _LOGGER.debug( + "%s << %s", + self.host, + _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), + ) + + return response_data + + async def close(self) -> None: + """Close the protocol.""" + await self._transport.close() + + +class SnowflakeId: + """Class for generating snowflake ids.""" + + EPOCH = 1420041600000 # Custom epoch (in milliseconds) + WORKER_ID_BITS = 5 + DATA_CENTER_ID_BITS = 5 + SEQUENCE_BITS = 12 + + MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1 + MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1 + + SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1 + + def __init__(self, worker_id, data_center_id): + if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0: + raise ValueError( + "Worker ID can't be greater than " + + str(SnowflakeId.MAX_WORKER_ID) + + " or less than 0" + ) + if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0: + raise ValueError( + "Data center ID can't be greater than " + + str(SnowflakeId.MAX_DATA_CENTER_ID) + + " or less than 0" + ) + + self.worker_id = worker_id + self.data_center_id = data_center_id + self.sequence = 0 + self.last_timestamp = -1 + + def generate_id(self): + """Generate a snowflake id.""" + timestamp = self._current_millis() + + if timestamp < self.last_timestamp: + raise ValueError("Clock moved backwards. Refusing to generate ID.") + + if timestamp == self.last_timestamp: + # Within the same millisecond, increment the sequence number + self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK + if self.sequence == 0: + # Sequence exceeds its bit range, wait until the next millisecond + timestamp = self._wait_next_millis(self.last_timestamp) + else: + # New millisecond, reset the sequence number + self.sequence = 0 + + # Update the last timestamp + self.last_timestamp = timestamp + + # Generate and return the final ID + return ( + ( + (timestamp - SnowflakeId.EPOCH) + << ( + SnowflakeId.WORKER_ID_BITS + + SnowflakeId.SEQUENCE_BITS + + SnowflakeId.DATA_CENTER_ID_BITS + ) + ) + | ( + self.data_center_id + << (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS) + ) + | (self.worker_id << SnowflakeId.SEQUENCE_BITS) + | self.sequence + ) + + def _current_millis(self): + return round(time.time() * 1000) + + def _wait_next_millis(self, last_timestamp): + timestamp = self._current_millis() + while timestamp <= last_timestamp: + timestamp = self._current_millis() + return timestamp diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index 2ba03956..6c643a6a 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -4,10 +4,10 @@ import logging from datetime import datetime, timedelta, timezone from typing import Any, Dict, Optional, Set, cast -from ..aesprotocol import TPLinkAes from ..credentials import Credentials from ..exceptions import AuthenticationException from ..smartdevice import SmartDevice +from ..smartprotocol import SmartProtocol _LOGGER = logging.getLogger(__name__) @@ -26,7 +26,7 @@ class TapoDevice(SmartDevice): super().__init__(host, port=port, credentials=credentials, timeout=timeout) self._state_information: Dict[str, Any] = {} self._discovery_info: Optional[Dict[str, Any]] = None - self.protocol = TPLinkAes(host, credentials=credentials, timeout=timeout) + self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout) async def update(self, update_children: bool = True): """Update the device.""" diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 2b2adc7d..50d2f0de 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -2,27 +2,45 @@ import asyncio import glob import json import os +from dataclasses import dataclass +from json import dumps as json_dumps from os.path import basename from pathlib import Path, PurePath -from typing import Dict +from typing import Dict, Optional from unittest.mock import MagicMock import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( + Credentials, Discover, SmartBulb, SmartDimmer, SmartLightStrip, SmartPlug, SmartStrip, + TPLinkSmartHomeProtocol, ) +from kasa.tapo import TapoDevice, TapoPlug -from .newfakes import FakeTransportProtocol +from .newfakes import FakeSmartProtocol, FakeTransportProtocol -SUPPORTED_DEVICES = glob.glob( - os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json" -) +SUPPORTED_IOT_DEVICES = [ + (device, "IOT") + for device in glob.glob( + os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json" + ) +] + +SUPPORTED_SMART_DEVICES = [ + (device, "SMART") + for device in glob.glob( + os.path.dirname(os.path.abspath(__file__)) + "/fixtures/smart/*.json" + ) +] + + +SUPPORTED_DEVICES = SUPPORTED_IOT_DEVICES + SUPPORTED_SMART_DEVICES LIGHT_STRIPS = {"KL400", "KL430", "KL420"} @@ -55,43 +73,59 @@ PLUGS = { "KP401", "KS200M", } + STRIPS = {"HS107", "HS300", "KP303", "KP200", "KP400", "EP40"} DIMMERS = {"ES20M", "HS220", "KS220M", "KS230", "KP405"} DIMMABLE = {*BULBS, *DIMMERS} WITH_EMETER = {"HS110", "HS300", "KP115", "KP125", *BULBS} -ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS) +ALL_DEVICES_IOT = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS) + +PLUGS_SMART = {"P110"} +ALL_DEVICES_SMART = PLUGS_SMART + +ALL_DEVICES = ALL_DEVICES_IOT.union(ALL_DEVICES_SMART) IP_MODEL_CACHE: Dict[str, str] = {} -def filter_model(desc, filter): - filtered = list() - for dev in SUPPORTED_DEVICES: - for filt in filter: - if filt in basename(dev): - filtered.append(dev) +def idgenerator(paramtuple): + return basename(paramtuple[0]) + ( + "" if paramtuple[1] == "IOT" else "-" + paramtuple[1] + ) - filtered_basenames = [basename(f) for f in filtered] + +def filter_model(desc, model_filter, protocol_filter=None): + if not protocol_filter: + protocol_filter = {"IOT"} + filtered = list() + for file, protocol in SUPPORTED_DEVICES: + if protocol in protocol_filter: + file_model = basename(file).split("_")[0] + for model in model_filter: + if model in file_model: + filtered.append((file, protocol)) + + filtered_basenames = [basename(f) + "-" + p for f, p in filtered] print(f"{desc}: {filtered_basenames}") return filtered -def parametrize(desc, devices, ids=None): +def parametrize(desc, devices, protocol_filter=None, ids=None): return pytest.mark.parametrize( - "dev", filter_model(desc, devices), indirect=True, ids=ids + "dev", filter_model(desc, devices, protocol_filter), indirect=True, ids=ids ) has_emeter = parametrize("has emeter", WITH_EMETER) -no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER) +no_emeter = parametrize("no emeter", ALL_DEVICES_IOT - WITH_EMETER) -bulb = parametrize("bulbs", BULBS, ids=basename) -plug = parametrize("plugs", PLUGS, ids=basename) -strip = parametrize("strips", STRIPS, ids=basename) -dimmer = parametrize("dimmers", DIMMERS, ids=basename) -lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename) +bulb = parametrize("bulbs", BULBS, ids=idgenerator) +plug = parametrize("plugs", PLUGS, ids=idgenerator) +strip = parametrize("strips", STRIPS, ids=idgenerator) +dimmer = parametrize("dimmers", DIMMERS, ids=idgenerator) +lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=idgenerator) # bulb types dimmable = parametrize("dimmable", DIMMABLE) @@ -101,6 +135,58 @@ non_variable_temp = parametrize("non-variable color temp", BULBS - VARIABLE_TEMP color_bulb = parametrize("color bulbs", COLOR_BULBS) non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS) +plug_smart = parametrize( + "plug devices smart", PLUGS_SMART, protocol_filter={"SMART"}, ids=idgenerator +) +device_smart = parametrize( + "devices smart", ALL_DEVICES_SMART, protocol_filter={"SMART"}, ids=idgenerator +) +device_iot = parametrize( + "devices iot", ALL_DEVICES_IOT, protocol_filter={"IOT"}, ids=idgenerator +) + + +def get_fixture_data(): + """Return raw discovery file contents as JSON. Used for discovery tests.""" + fixture_data = {} + for file, protocol in SUPPORTED_DEVICES: + p = Path(file) + if not p.is_absolute(): + folder = Path(__file__).parent / "fixtures" + if protocol == "SMART": + folder = folder / "smart" + p = folder / file + + with open(p) as f: + fixture_data[basename(p)] = json.load(f) + return fixture_data + + +FIXTURE_DATA = get_fixture_data() + + +def filter_fixtures(desc, root_filter): + filtered = {} + for key, val in FIXTURE_DATA.items(): + if root_filter in val: + filtered[key] = val + + print(f"{desc}: {filtered.keys()}") + return filtered + + +def parametrize_discovery(desc, root_key): + filtered_fixtures = filter_fixtures(desc, root_key) + return pytest.mark.parametrize( + "discovery_data", + filtered_fixtures.values(), + indirect=True, + ids=filtered_fixtures.keys(), + ) + + +new_discovery = parametrize_discovery("new discovery", "discovery_result") + def check_categories(): """Check that every fixture file is categorized.""" @@ -110,15 +196,15 @@ def check_categories(): + plug.args[1] + bulb.args[1] + lightstrip.args[1] + + plug_smart.args[1] ) diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures) if diff: - for file in diff: + for file, protocol in diff: print( - "No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)" - % file + f"No category for file {file} protocol {protocol}, add to the corresponding set (BULBS, PLUGS, ..)" ) - raise Exception("Missing category for %s" % diff) + raise Exception(f"Missing category for {diff}") check_categories() @@ -134,27 +220,32 @@ async def handle_turn_on(dev, turn_on): await dev.turn_off() -def device_for_file(model): - for d in STRIPS: - if d in model: - return SmartStrip +def device_for_file(model, protocol): + if protocol == "SMART": + for d in PLUGS_SMART: + if d in model: + return TapoPlug + else: + for d in STRIPS: + if d in model: + return SmartStrip - for d in PLUGS: - if d in model: - return SmartPlug + for d in PLUGS: + if d in model: + return SmartPlug - # Light strips are recognized also as bulbs, so this has to go first - for d in LIGHT_STRIPS: - if d in model: - return SmartLightStrip + # Light strips are recognized also as bulbs, so this has to go first + for d in LIGHT_STRIPS: + if d in model: + return SmartLightStrip - for d in BULBS: - if d in model: - return SmartBulb + for d in BULBS: + if d in model: + return SmartBulb - for d in DIMMERS: - if d in model: - return SmartDimmer + for d in DIMMERS: + if d in model: + return SmartDimmer raise Exception("Unable to find type for %s", model) @@ -170,11 +261,14 @@ async def _discover_update_and_close(ip): return await _update_and_close(d) -async def get_device_for_file(file): +async def get_device_for_file(file, protocol): # if the wanted file is not an absolute path, prepend the fixtures directory p = Path(file) if not p.is_absolute(): - p = Path(__file__).parent / "fixtures" / file + folder = Path(__file__).parent / "fixtures" + if protocol == "SMART": + folder = folder / "smart" + p = folder / file def load_file(): with open(p) as f: @@ -184,8 +278,12 @@ async def get_device_for_file(file): sysinfo = await loop.run_in_executor(None, load_file) model = basename(file) - d = device_for_file(model)(host="127.0.0.123") - d.protocol = FakeTransportProtocol(sysinfo) + d = device_for_file(model, protocol)(host="127.0.0.123") + if protocol == "SMART": + d.protocol = FakeSmartProtocol(sysinfo) + d.credentials = Credentials("", "") + else: + d.protocol = FakeTransportProtocol(sysinfo) await _update_and_close(d) return d @@ -197,7 +295,7 @@ async def dev(request): Provides a device (given --ip) or parametrized fixture for the supported devices. The initial update is called automatically before returning the device. """ - file = request.param + file, protocol = request.param ip = request.config.getoption("--ip") if ip: @@ -210,19 +308,62 @@ async def dev(request): pytest.skip(f"skipping file {file}") return d if d else await _discover_update_and_close(ip) - return await get_device_for_file(file) + return await get_device_for_file(file, protocol) -@pytest.fixture(params=SUPPORTED_DEVICES, scope="session") +@pytest.fixture +def discovery_mock(discovery_data, mocker): + @dataclass + class _DiscoveryMock: + ip: str + default_port: int + discovery_data: dict + port_override: Optional[int] = None + + if "result" in discovery_data: + datagram = ( + b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" + + json_dumps(discovery_data).encode() + ) + dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data) + else: + datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:] + dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data) + + def mock_discover(self): + port = ( + dm.port_override + if dm.port_override and dm.default_port != 20002 + else dm.default_port + ) + self.datagram_received( + datagram, + (dm.ip, port), + ) + + mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover) + mocker.patch( + "socket.getaddrinfo", + side_effect=lambda *_, **__: [(None, None, None, None, (dm.ip, 0))], + ) + yield dm + + +@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session") def discovery_data(request): """Return raw discovery file contents as JSON. Used for discovery tests.""" - file = request.param - p = Path(file) - if not p.is_absolute(): - p = Path(__file__).parent / "fixtures" / file + fixture_data = request.param + if "discovery_result" in fixture_data: + return {"result": fixture_data["discovery_result"]} + else: + return {"system": {"get_sysinfo": fixture_data["system"]["get_sysinfo"]}} - with open(p) as f: - return json.load(f) + +@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session") +def all_fixture_data(request): + """Return raw fixture file contents as JSON. Used for discovery tests.""" + fixture_data = request.param + return fixture_data def pytest_addoption(parser): diff --git a/kasa/tests/fixtures/smart/P110_1.0_1.3.0.json b/kasa/tests/fixtures/smart/P110_1.0_1.3.0.json new file mode 100644 index 00000000..99fd3f13 --- /dev/null +++ b/kasa/tests/fixtures/smart/P110_1.0_1.3.0.json @@ -0,0 +1,180 @@ +{ + "component_nego": { + "component_list": [ + { + "id": "device", + "ver_code": 2 + }, + { + "id": "firmware", + "ver_code": 2 + }, + { + "id": "quick_setup", + "ver_code": 3 + }, + { + "id": "time", + "ver_code": 1 + }, + { + "id": "wireless", + "ver_code": 1 + }, + { + "id": "schedule", + "ver_code": 2 + }, + { + "id": "countdown", + "ver_code": 2 + }, + { + "id": "antitheft", + "ver_code": 1 + }, + { + "id": "account", + "ver_code": 1 + }, + { + "id": "synchronize", + "ver_code": 1 + }, + { + "id": "sunrise_sunset", + "ver_code": 1 + }, + { + "id": "led", + "ver_code": 1 + }, + { + "id": "cloud_connect", + "ver_code": 1 + }, + { + "id": "iot_cloud", + "ver_code": 1 + }, + { + "id": "device_local_time", + "ver_code": 1 + }, + { + "id": "default_states", + "ver_code": 1 + }, + { + "id": "auto_off", + "ver_code": 2 + }, + { + "id": "localSmart", + "ver_code": 1 + }, + { + "id": "energy_monitoring", + "ver_code": 2 + }, + { + "id": "power_protection", + "ver_code": 1 + }, + { + "id": "current_protection", + "ver_code": 1 + } + ] + }, + "discovery_result": { + "device_id": "00000000000000000000000000000000", + "device_model": "P110(UK)", + "device_type": "SMART.TAPOPLUG", + "factory_default": false, + "ip": "127.0.0.123", + "is_support_iot_cloud": true, + "mac": "00-00-00-00-00-00", + "mgt_encrypt_schm": { + "encrypt_type": "KLAP", + "http_port": 80, + "is_support_https": false, + "lv": 2 + }, + "obd_src": "tplink", + "owner": "00000000000000000000000000000000" + }, + "get_current_power": { + "current_power": 0 + }, + "get_device_info": { + "auto_off_remain_time": 0, + "auto_off_status": "off", + "avatar": "plug", + "default_states": { + "state": {}, + "type": "last_states" + }, + "device_id": "0000000000000000000000000000000000000000", + "device_on": true, + "fw_id": "00000000000000000000000000000000", + "fw_ver": "1.3.0 Build 230905 Rel.152200", + "has_set_location_info": true, + "hw_id": "00000000000000000000000000000000", + "hw_ver": "1.0", + "ip": "127.0.0.123", + "lang": "en_US", + "latitude": 0, + "longitude": 0, + "mac": "00-00-00-00-00-00", + "model": "P110", + "nickname": "VGFwaSBTbWFydCBQbHVnIDE=", + "oem_id": "00000000000000000000000000000000", + "on_time": 119335, + "overcurrent_status": "normal", + "overheated": false, + "power_protection_status": "normal", + "region": "Europe/London", + "rssi": -57, + "signal_level": 2, + "specs": "", + "ssid": "IyNNQVNLRUROQU1FIyM=", + "time_diff": 0, + "type": "SMART.TAPOPLUG" + }, + "get_device_time": { + "region": "Europe/London", + "time_diff": 0, + "timestamp": 1701370224 + }, + "get_device_usage": { + "power_usage": { + "past30": 75, + "past7": 69, + "today": 0 + }, + "saved_power": { + "past30": 2029, + "past7": 1964, + "today": 1130 + }, + "time_usage": { + "past30": 2104, + "past7": 2033, + "today": 1130 + } + }, + "get_energy_usage": { + "current_power": 0, + "electricity_charge": [ + 0, + 0, + 0 + ], + "local_time": "2023-11-30 18:50:24", + "month_energy": 75, + "month_runtime": 2104, + "today_energy": 0, + "today_runtime": 1130 + } +} diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index ee679cae..c5bf238f 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -1,6 +1,7 @@ import copy import logging import re +from json import loads as json_loads from voluptuous import ( REMOVE_EXTRA, @@ -13,7 +14,8 @@ from voluptuous import ( Schema, ) -from ..protocol import TPLinkSmartHomeProtocol +from ..protocol import BaseTransport, TPLinkSmartHomeProtocol +from ..smartprotocol import SmartProtocol _LOGGER = logging.getLogger(__name__) @@ -285,6 +287,41 @@ TIME_MODULE = { } +class FakeSmartProtocol(SmartProtocol): + def __init__(self, info): + super().__init__("127.0.0.123", transport=FakeSmartTransport(info)) + + +class FakeSmartTransport(BaseTransport): + def __init__(self, info): + self.info = info + + @property + def needs_handshake(self) -> bool: + return False + + @property + def needs_login(self) -> bool: + return False + + async def login(self, request: str) -> None: + pass + + async def handshake(self) -> None: + pass + + async def send(self, request: str): + request_dict = json_loads(request) + method = request_dict["method"] + if method == "component_nego" or method[:4] == "get_": + return self.info[method] + elif method[:4] == "set_": + _LOGGER.debug("Call %s not implemented, doing nothing", method) + + async def close(self) -> None: + pass + + class FakeTransportProtocol(TPLinkSmartHomeProtocol): def __init__(self, info): self.discovery_data = info diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index f590808f..55e3977a 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -6,12 +6,15 @@ from asyncclick.testing import CliRunner from kasa import SmartDevice, TPLinkSmartHomeProtocol from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle +from kasa.device_factory import DEVICE_TYPE_TO_CLASS from kasa.discover import Discover +from kasa.smartprotocol import SmartProtocol -from .conftest import handle_turn_on, turn_on -from .newfakes import FakeTransportProtocol +from .conftest import device_iot, handle_turn_on, new_discovery, turn_on +from .newfakes import FakeSmartProtocol, FakeTransportProtocol +@device_iot async def test_sysinfo(dev): runner = CliRunner() res = await runner.invoke(sysinfo, obj=dev) @@ -19,6 +22,7 @@ async def test_sysinfo(dev): assert dev.alias in res.output +@device_iot @turn_on async def test_state(dev, turn_on): await handle_turn_on(dev, turn_on) @@ -32,6 +36,7 @@ async def test_state(dev, turn_on): assert "Device state: False" in res.output +@device_iot @turn_on async def test_toggle(dev, turn_on, mocker): await handle_turn_on(dev, turn_on) @@ -44,6 +49,7 @@ async def test_toggle(dev, turn_on, mocker): assert dev.is_on +@device_iot async def test_alias(dev): runner = CliRunner() @@ -62,6 +68,7 @@ async def test_alias(dev): await dev.set_alias(old_alias) +@device_iot async def test_raw_command(dev): runner = CliRunner() res = await runner.invoke(raw_command, ["system", "get_sysinfo"], obj=dev) @@ -74,6 +81,7 @@ async def test_raw_command(dev): assert "Usage" in res.output +@device_iot async def test_emeter(dev: SmartDevice, mocker): runner = CliRunner() @@ -99,6 +107,7 @@ async def test_emeter(dev: SmartDevice, mocker): daily.assert_called_with(year=1900, month=12) +@device_iot async def test_brightness(dev): runner = CliRunner() res = await runner.invoke(brightness, obj=dev) @@ -116,6 +125,7 @@ async def test_brightness(dev): assert "Brightness: 12" in res.output +@device_iot async def test_json_output(dev: SmartDevice, mocker): """Test that the json output produces correct output.""" mocker.patch("kasa.Discover.discover", return_value=[dev]) @@ -125,13 +135,9 @@ async def test_json_output(dev: SmartDevice, mocker): assert json.loads(res.output) == dev.internal_state -async def test_credentials(discovery_data: dict, mocker): +@new_discovery +async def test_credentials(discovery_mock, mocker): """Test credentials are passed correctly from cli to device.""" - # As this is testing the device constructor need to explicitly wire in - # the FakeTransportProtocol - ftp = FakeTransportProtocol(discovery_data) - mocker.patch.object(TPLinkSmartHomeProtocol, "query", ftp.query) - # Patch state to echo username and password pass_dev = click.make_pass_decorator(SmartDevice) @@ -143,18 +149,15 @@ async def test_credentials(discovery_data: dict, mocker): ) mocker.patch("kasa.cli.state", new=_state) - cli_device_type = Discover._get_device_class(discovery_data)( - "any" - ).device_type.value + for subclass in DEVICE_TYPE_TO_CLASS.values(): + mocker.patch.object(subclass, "update") runner = CliRunner() res = await runner.invoke( cli, [ "--host", - "127.0.0.1", - "--type", - cli_device_type, + "127.0.0.123", "--username", "foo", "--password", @@ -162,9 +165,11 @@ async def test_credentials(discovery_data: dict, mocker): ], ) assert res.exit_code == 0 - assert res.output == "Username:foo Password:bar\n" + + assert "Username:foo Password:bar\n" in res.output +@device_iot async def test_without_device_type(discovery_data: dict, dev, mocker): """Test connecting without the device type.""" runner = CliRunner() diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index aca38e19..eb12b3b0 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -5,7 +5,9 @@ from typing import Type import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( + Credentials, DeviceType, + Discover, SmartBulb, SmartDevice, SmartDeviceException, @@ -13,8 +15,13 @@ from kasa import ( SmartLightStrip, SmartPlug, ) -from kasa.device_factory import connect -from kasa.klapprotocol import TPLinkKlap +from kasa.device_factory import ( + DEVICE_TYPE_TO_CLASS, + connect, + get_protocol_from_connection_name, +) +from kasa.discover import DiscoveryResult +from kasa.iotprotocol import IotProtocol from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol @@ -22,11 +29,15 @@ from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol async def test_connect(discovery_data: dict, mocker, custom_port): """Make sure that connect returns an initialized SmartDevice instance.""" host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - dev = await connect(host, port=custom_port) - assert issubclass(dev.__class__, SmartDevice) - assert dev.port == custom_port or dev.port == 9999 + if "result" in discovery_data: + with pytest.raises(SmartDeviceException): + dev = await connect(host, port=custom_port) + else: + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + dev = await connect(host, port=custom_port) + assert issubclass(dev.__class__, SmartDevice) + assert dev.port == custom_port or dev.port == 9999 @pytest.mark.parametrize("custom_port", [123, None]) @@ -49,11 +60,15 @@ async def test_connect_passed_device_type( ): """Make sure that connect with a passed device type.""" host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - dev = await connect(host, port=custom_port, device_type=device_type) - assert isinstance(dev, klass) - assert dev.port == custom_port or dev.port == 9999 + if "result" in discovery_data: + with pytest.raises(SmartDeviceException): + dev = await connect(host, port=custom_port) + else: + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + dev = await connect(host, port=custom_port, device_type=device_type) + assert isinstance(dev, klass) + assert dev.port == custom_port or dev.port == 9999 async def test_connect_query_fails(discovery_data: dict, mocker): @@ -70,32 +85,52 @@ async def test_connect_logs_connect_time( ): """Test that the connect time is logged when debug logging is enabled.""" host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - logging.getLogger("kasa").setLevel(logging.DEBUG) - await connect(host) - assert "seconds to connect" in caplog.text + if "result" in discovery_data: + with pytest.raises(SmartDeviceException): + await connect(host) + else: + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + logging.getLogger("kasa").setLevel(logging.DEBUG) + await connect(host) + assert "seconds to connect" in caplog.text -@pytest.mark.parametrize("device_type", [DeviceType.Plug, None]) -@pytest.mark.parametrize( - ("protocol_in", "protocol_result"), - ( - (None, TPLinkSmartHomeProtocol), - (TPLinkKlap, TPLinkKlap), - (TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol), - ), -) async def test_connect_pass_protocol( - discovery_data: dict, + all_fixture_data: dict, mocker, - device_type: DeviceType, - protocol_in: Type[TPLinkProtocol], - protocol_result: Type[TPLinkProtocol], ): """Test that if the protocol is passed in it's gets set correctly.""" - host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - mocker.patch("kasa.TPLinkKlap.query", return_value=discovery_data) + if "discovery_result" in all_fixture_data: + discovery_info = {"result": all_fixture_data["discovery_result"]} + device_class = Discover._get_device_class(discovery_info) + else: + device_class = Discover._get_device_class(all_fixture_data) - dev = await connect(host, device_type=device_type, protocol_class=protocol_in) - assert isinstance(dev.protocol, protocol_result) + device_type = list(DEVICE_TYPE_TO_CLASS.keys())[ + list(DEVICE_TYPE_TO_CLASS.values()).index(device_class) + ] + host = "127.0.0.1" + if "discovery_result" in all_fixture_data: + mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) + + dr = DiscoveryResult(**discovery_info["result"]) + connection_name = ( + dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type + ) + protocol_class = get_protocol_from_connection_name( + connection_name, host + ).__class__ + else: + mocker.patch( + "kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data + ) + protocol_class = TPLinkSmartHomeProtocol + + dev = await connect( + host, + device_type=device_type, + protocol_class=protocol_class, + credentials=Credentials("", ""), + ) + assert isinstance(dev.protocol, protocol_class) diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 626afd18..ea97d94a 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -17,6 +17,27 @@ from kasa.exceptions import AuthenticationException, UnsupportedDeviceException from .conftest import bulb, dimmer, lightstrip, plug, strip +UNSUPPORTED = { + "result": { + "device_id": "xx", + "owner": "xx", + "device_type": "SMART.TAPOXMASTREE", + "device_model": "P110(EU)", + "ip": "127.0.0.1", + "mac": "48-22xxx", + "is_support_iot_cloud": True, + "obd_src": "tplink", + "factory_default": False, + "mgt_encrypt_schm": { + "is_support_https": False, + "encrypt_type": "AES", + "http_port": 80, + "lv": 2, + }, + }, + "error_code": 0, +} + @plug async def test_type_detection_plug(dev: SmartDevice): @@ -62,76 +83,40 @@ async def test_type_unknown(): @pytest.mark.parametrize("custom_port", [123, None]) -async def test_discover_single(discovery_data: dict, mocker, custom_port): +# @pytest.mark.parametrize("discovery_mock", [("127.0.0.1",123), ("127.0.0.1",None)], indirect=True) +async def test_discover_single(discovery_mock, custom_port, mocker): """Make sure that discover_single returns an initialized SmartDevice instance.""" host = "127.0.0.1" - info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}} - query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info) - - def mock_discover(self): - self.datagram_received( - protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:], - (host, custom_port or 9999), - ) - - mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) + discovery_mock.ip = host + discovery_mock.port_override = custom_port + update_mock = mocker.patch.object(SmartStrip, "update") x = await Discover.discover_single(host, port=custom_port) assert issubclass(x.__class__, SmartDevice) - assert x._sys_info is not None - assert x.port == custom_port or x.port == 9999 - assert (query_mock.call_count > 0) == isinstance(x, SmartStrip) + assert x._discovery_info is not None + assert x.port == custom_port or x.port == discovery_mock.default_port + assert (update_mock.call_count > 0) == isinstance(x, SmartStrip) -async def test_discover_single_hostname(discovery_data: dict, mocker): +async def test_discover_single_hostname(discovery_mock, mocker): """Make sure that discover_single returns an initialized SmartDevice instance.""" host = "foobar" ip = "127.0.0.1" - info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}} - query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info) - def mock_discover(self): - self.datagram_received( - protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:], - (ip, 9999), - ) - - mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) - mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))]) + discovery_mock.ip = ip + update_mock = mocker.patch.object(SmartStrip, "update") x = await Discover.discover_single(host) assert issubclass(x.__class__, SmartDevice) - assert x._sys_info is not None + assert x._discovery_info is not None assert x.host == host - assert (query_mock.call_count > 0) == isinstance(x, SmartStrip) + assert (update_mock.call_count > 0) == isinstance(x, SmartStrip) mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror()) with pytest.raises(SmartDeviceException): x = await Discover.discover_single(host) -UNSUPPORTED = { - "result": { - "device_id": "xx", - "owner": "xx", - "device_type": "SMART.TAPOXMASTREE", - "device_model": "P110(EU)", - "ip": "127.0.0.1", - "mac": "48-22xxx", - "is_support_iot_cloud": True, - "obd_src": "tplink", - "factory_default": False, - "mgt_encrypt_schm": { - "is_support_https": False, - "encrypt_type": "AES", - "http_port": 80, - "lv": 2, - }, - }, - "error_code": 0, -} - - async def test_discover_single_unsupported(mocker): """Make sure that discover_single handles unsupported devices correctly.""" host = "127.0.0.1" @@ -201,14 +186,17 @@ async def test_discover_send(mocker): async def test_discover_datagram_received(mocker, discovery_data): """Verify that datagram received fills discovered_devices.""" proto = _DiscoverProtocol() - info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}} - mocker.patch("kasa.discover.json_loads", return_value=info) - mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt") + mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") addr = "127.0.0.1" - proto.datagram_received("", (addr, 9999)) + port = 20002 if "result" in discovery_data else 9999 + + mocker.patch("kasa.discover.json_loads", return_value=discovery_data) + proto.datagram_received("", (addr, port)) + addr2 = "127.0.0.2" + mocker.patch("kasa.discover.json_loads", return_value=UNSUPPORTED) proto.datagram_received("", (addr2, 20002)) # Check that device in discovered_devices is initialized correctly diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 991dbe6f..fe4d1a6c 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -10,9 +10,14 @@ from contextlib import nullcontext as does_not_raise import httpx import pytest +from ..aestransport import AesTransport from ..credentials import Credentials from ..exceptions import AuthenticationException, SmartDeviceException -from ..klapprotocol import KlapEncryptionSession, TPLinkKlap, _sha256 +from ..iotprotocol import IotProtocol +from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256 +from ..smartprotocol import SmartProtocol + +DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} class _mock_response: @@ -21,67 +26,92 @@ class _mock_response: self.content = content +@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) +@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @pytest.mark.parametrize("retry_count", [1, 3, 5]) -async def test_protocol_retries(mocker, retry_count): +async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class): + host = "127.0.0.1" conn = mocker.patch.object( - TPLinkKlap, "client_post", side_effect=Exception("dummy exception") + transport_class, "client_post", side_effect=Exception("dummy exception") ) with pytest.raises(SmartDeviceException): - await TPLinkKlap("127.0.0.1").query({}, retry_count=retry_count) + await protocol_class(host, transport=transport_class(host)).query( + DUMMY_QUERY, retry_count=retry_count + ) assert conn.call_count == retry_count + 1 -async def test_protocol_no_retry_on_connection_error(mocker): +@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) +@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) +async def test_protocol_no_retry_on_connection_error( + mocker, protocol_class, transport_class +): + host = "127.0.0.1" conn = mocker.patch.object( - TPLinkKlap, + transport_class, "client_post", side_effect=httpx.ConnectError("foo"), ) with pytest.raises(SmartDeviceException): - await TPLinkKlap("127.0.0.1").query({}, retry_count=5) + await protocol_class(host, transport=transport_class(host)).query( + DUMMY_QUERY, retry_count=5 + ) assert conn.call_count == 1 -async def test_protocol_retry_recoverable_error(mocker): +@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) +@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) +async def test_protocol_retry_recoverable_error( + mocker, protocol_class, transport_class +): + host = "127.0.0.1" conn = mocker.patch.object( - TPLinkKlap, + transport_class, "client_post", side_effect=httpx.CloseError("foo"), ) with pytest.raises(SmartDeviceException): - await TPLinkKlap("127.0.0.1").query({}, retry_count=5) + await protocol_class(host, transport=transport_class(host)).query( + DUMMY_QUERY, retry_count=5 + ) assert conn.call_count == 6 +@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) +@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @pytest.mark.parametrize("retry_count", [1, 3, 5]) -async def test_protocol_reconnect(mocker, retry_count): +async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class): + host = "127.0.0.1" remaining = retry_count + mock_response = {"result": {"great": "success"}} def _fail_one_less_than_retry_count(*_, **__): - nonlocal remaining, encryption_session + nonlocal remaining remaining -= 1 if remaining: raise Exception("Simulated post failure") - # Do the encrypt just before returning the value so the incrementing sequence number is correct - encrypted, seq = encryption_session.encrypt('{"great":"success"}') - return 200, encrypted - seed = secrets.token_bytes(16) - auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) - encryption_session = KlapEncryptionSession(seed, seed, auth_hash) - protocol = TPLinkKlap("127.0.0.1") - protocol.handshake_done = True - protocol.session_expire_at = time.time() + 86400 - protocol.encryption_session = encryption_session + return mock_response + mocker.patch.object( - TPLinkKlap, "client_post", side_effect=_fail_one_less_than_retry_count + transport_class, "needs_handshake", property(lambda self: False) + ) + mocker.patch.object(transport_class, "needs_login", property(lambda self: False)) + + send_mock = mocker.patch.object( + transport_class, + "send", + side_effect=_fail_one_less_than_retry_count, ) - response = await protocol.query({}, retry_count=retry_count) - assert response == {"great": "success"} + response = await protocol_class(host, transport=transport_class(host)).query( + DUMMY_QUERY, retry_count=retry_count + ) + assert "result" in response or "great" in response + assert send_mock.call_count == retry_count @pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG]) @@ -96,14 +126,14 @@ async def test_protocol_logging(mocker, caplog, log_level): return 200, encrypted seed = secrets.token_bytes(16) - auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) + auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) - protocol = TPLinkKlap("127.0.0.1") + protocol = IotProtocol("127.0.0.1") - protocol.handshake_done = True - protocol.session_expire_at = time.time() + 86400 - protocol.encryption_session = encryption_session - mocker.patch.object(TPLinkKlap, "client_post", side_effect=_return_encrypted) + protocol._transport._handshake_done = True + protocol._transport._session_expire_at = time.time() + 86400 + protocol._transport._encryption_session = encryption_session + mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted) response = await protocol.query({}) assert response == {"great": "success"} @@ -117,7 +147,7 @@ def test_encrypt(): d = json.dumps({"foo": 1, "bar": 2}) seed = secrets.token_bytes(16) - auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) + auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encrypted, seq = encryption_session.encrypt(d) @@ -129,7 +159,7 @@ def test_encrypt_unicode(): d = "{'snowman': '\u2603'}" seed = secrets.token_bytes(16) - auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) + auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encrypted, seq = encryption_session.encrypt(d) @@ -145,7 +175,10 @@ def test_encrypt_unicode(): (Credentials("foo", "bar"), does_not_raise()), (Credentials("", ""), does_not_raise()), ( - Credentials(TPLinkKlap.KASA_SETUP_EMAIL, TPLinkKlap.KASA_SETUP_PASSWORD), + Credentials( + KlapTransport.KASA_SETUP_EMAIL, + KlapTransport.KASA_SETUP_PASSWORD, + ), does_not_raise(), ), ( @@ -167,21 +200,21 @@ async def test_handshake1(mocker, device_credentials, expectation): client_seed = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlap.generate_auth_hash(device_credentials) + device_auth_hash = KlapTransport.generate_auth_hash(device_credentials) mocker.patch.object( httpx.AsyncClient, "post", side_effect=_return_handshake1_response ) - protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol("127.0.0.1", credentials=client_credentials) - protocol.http_client = httpx.AsyncClient() + protocol._transport.http_client = httpx.AsyncClient() with expectation: ( local_seed, device_remote_seed, auth_hash, - ) = await protocol.perform_handshake1() + ) = await protocol._transport.perform_handshake1() assert local_seed == client_seed assert device_remote_seed == server_seed @@ -204,23 +237,23 @@ async def test_handshake(mocker): client_seed = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials) + device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) mocker.patch.object( httpx.AsyncClient, "post", side_effect=_return_handshake_response ) - protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) - protocol.http_client = httpx.AsyncClient() + protocol = IotProtocol("127.0.0.1", credentials=client_credentials) + protocol._transport.http_client = httpx.AsyncClient() response_status = 200 - await protocol.perform_handshake() - assert protocol.handshake_done is True + await protocol._transport.perform_handshake() + assert protocol._transport._handshake_done is True response_status = 403 with pytest.raises(AuthenticationException): - await protocol.perform_handshake() - assert protocol.handshake_done is False + await protocol._transport.perform_handshake() + assert protocol._transport._handshake_done is False await protocol.close() @@ -237,9 +270,9 @@ async def test_query(mocker): return _mock_response(200, b"") elif url == "http://127.0.0.1/app/request": encryption_session = KlapEncryptionSession( - protocol.encryption_session.local_seed, - protocol.encryption_session.remote_seed, - protocol.encryption_session.user_hash, + protocol._transport._encryption_session.local_seed, + protocol._transport._encryption_session.remote_seed, + protocol._transport._encryption_session.user_hash, ) seq = params.get("seq") encryption_session._seq = seq - 1 @@ -252,11 +285,11 @@ async def test_query(mocker): seq = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials) + device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol("127.0.0.1", credentials=client_credentials) for _ in range(10): resp = await protocol.query({}) @@ -296,11 +329,11 @@ async def test_authentication_failures(mocker, response_status, expectation): server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials) + device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol("127.0.0.1", credentials=client_credentials) with expectation: await protocol.query({}) diff --git a/kasa/tests/test_plug.py b/kasa/tests/test_plug.py index e9704310..e9e1592f 100644 --- a/kasa/tests/test_plug.py +++ b/kasa/tests/test_plug.py @@ -1,6 +1,6 @@ from kasa import DeviceType -from .conftest import plug +from .conftest import plug, plug_smart from .newfakes import PLUG_SCHEMA @@ -28,3 +28,14 @@ async def test_led(dev): assert dev.led await dev.set_led(original) + + +@plug_smart +async def test_plug_device_info(dev): + assert dev._info is not None + # PLUG_SCHEMA(dev.sys_info) + + assert dev.model is not None + + assert dev.device_type == DeviceType.Plug or dev.device_type == DeviceType.Strip + # assert dev.is_plug or dev.is_strip diff --git a/kasa/tests/test_readme_examples.py b/kasa/tests/test_readme_examples.py index 13c6e994..5772ba42 100644 --- a/kasa/tests/test_readme_examples.py +++ b/kasa/tests/test_readme_examples.py @@ -9,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file def test_bulb_examples(mocker): """Use KL130 (bulb with all features) to test the doctests.""" - p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json")) + p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json", "IOT")) mocker.patch("kasa.smartbulb.SmartBulb", return_value=p) mocker.patch("kasa.smartbulb.SmartBulb.update") res = xdoctest.doctest_module("kasa.smartbulb", "all") @@ -18,7 +18,7 @@ def test_bulb_examples(mocker): def test_smartdevice_examples(mocker): """Use HS110 for emeter examples.""" - p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json")) + p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT")) mocker.patch("kasa.smartdevice.SmartDevice", return_value=p) mocker.patch("kasa.smartdevice.SmartDevice.update") res = xdoctest.doctest_module("kasa.smartdevice", "all") @@ -27,7 +27,7 @@ def test_smartdevice_examples(mocker): def test_plug_examples(mocker): """Test plug examples.""" - p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json")) + p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT")) mocker.patch("kasa.smartplug.SmartPlug", return_value=p) mocker.patch("kasa.smartplug.SmartPlug.update") res = xdoctest.doctest_module("kasa.smartplug", "all") @@ -36,7 +36,7 @@ def test_plug_examples(mocker): def test_strip_examples(mocker): """Test strip examples.""" - p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json")) + p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT")) mocker.patch("kasa.smartstrip.SmartStrip", return_value=p) mocker.patch("kasa.smartstrip.SmartStrip.update") res = xdoctest.doctest_module("kasa.smartstrip", "all") @@ -45,7 +45,7 @@ def test_strip_examples(mocker): def test_dimmer_examples(mocker): """Test dimmer examples.""" - p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json")) + p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json", "IOT")) mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p) mocker.patch("kasa.smartdimmer.SmartDimmer.update") res = xdoctest.doctest_module("kasa.smartdimmer", "all") @@ -54,7 +54,7 @@ def test_dimmer_examples(mocker): def test_lightstrip_examples(mocker): """Test lightstrip examples.""" - p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json")) + p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json", "IOT")) mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p) mocker.patch("kasa.smartlightstrip.SmartLightStrip.update") res = xdoctest.doctest_module("kasa.smartlightstrip", "all") @@ -63,7 +63,7 @@ def test_lightstrip_examples(mocker): def test_discovery_examples(mocker): """Test discovery examples.""" - p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json")) + p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT")) mocker.patch("kasa.discover.Discover.discover", return_value=[p]) res = xdoctest.doctest_module("kasa.discover", "all") diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 85dc358d..33c9f448 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -8,7 +8,7 @@ import kasa from kasa import Credentials, SmartDevice, SmartDeviceException from kasa.smartdevice import DeviceType -from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on +from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol # List of all SmartXXX classes including the SmartDevice base class @@ -22,11 +22,13 @@ smart_device_classes = [ ] +@device_iot async def test_state_info(dev): assert isinstance(dev.state_information, dict) @pytest.mark.requires_dummy +@device_iot async def test_invalid_connection(dev): with patch.object( FakeTransportProtocol, "query", side_effect=SmartDeviceException @@ -58,12 +60,14 @@ async def test_initial_update_no_emeter(dev, mocker): assert spy.call_count == 2 +@device_iot async def test_query_helper(dev): with pytest.raises(SmartDeviceException): await dev._query_helper("test", "testcmd", {}) # TODO check for unwrapping? +@device_iot @turn_on async def test_state(dev, turn_on): await handle_turn_on(dev, turn_on) @@ -90,6 +94,7 @@ async def test_state(dev, turn_on): assert dev.is_off +@device_iot async def test_alias(dev): test_alias = "TEST1234" original = dev.alias @@ -104,6 +109,7 @@ async def test_alias(dev): assert dev.alias == original +@device_iot @turn_on async def test_on_since(dev, turn_on): await handle_turn_on(dev, turn_on) @@ -116,30 +122,37 @@ async def test_on_since(dev, turn_on): assert dev.on_since is None +@device_iot async def test_time(dev): assert isinstance(await dev.get_time(), datetime) +@device_iot async def test_timezone(dev): TZ_SCHEMA(await dev.get_timezone()) +@device_iot async def test_hw_info(dev): PLUG_SCHEMA(dev.hw_info) +@device_iot async def test_location(dev): PLUG_SCHEMA(dev.location) +@device_iot async def test_rssi(dev): PLUG_SCHEMA({"rssi": dev.rssi}) # wrapping for vol +@device_iot async def test_mac(dev): PLUG_SCHEMA({"mac": dev.mac}) # wrapping for val +@device_iot async def test_representation(dev): import re @@ -147,6 +160,7 @@ async def test_representation(dev): assert pattern.match(str(dev)) +@device_iot async def test_childrens(dev): """Make sure that children property is exposed by every device.""" if dev.is_strip: @@ -155,6 +169,7 @@ async def test_childrens(dev): assert len(dev.children) == 0 +@device_iot async def test_children(dev): """Make sure that children property is exposed by every device.""" if dev.is_strip: @@ -165,11 +180,13 @@ async def test_children(dev): assert dev.has_children is False +@device_iot async def test_internal_state(dev): """Make sure the internal state returns the last update results.""" assert dev.internal_state == dev._last_update +@device_iot async def test_features(dev): """Make sure features is always accessible.""" sysinfo = dev._last_update["system"]["get_sysinfo"] @@ -179,11 +196,13 @@ async def test_features(dev): assert dev.features == set() +@device_iot async def test_max_device_response_size(dev): """Make sure every device return has a set max response size.""" assert dev.max_device_response_size > 0 +@device_iot async def test_estimated_response_sizes(dev): """Make sure every module has an estimated response size set.""" for mod in dev.modules.values(): @@ -202,6 +221,7 @@ def test_device_class_ctors(device_class): assert dev.credentials == credentials +@device_iot async def test_modules_preserved(dev: SmartDevice): """Make modules that are not being updated are preserved between updates.""" dev._last_update["some_module_not_being_updated"] = "should_be_kept" @@ -237,6 +257,7 @@ async def test_create_thin_wrapper(): ) +@device_iot async def test_modules_not_supported(dev: SmartDevice): """Test that unsupported modules do not break the device.""" for module in dev.modules.values():