diff --git a/kasa/exceptions.py b/kasa/exceptions.py index 9172cfc3..b646e514 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -186,6 +186,7 @@ SMART_RETRYABLE_ERRORS = [ SmartErrorCode.UNSPECIFIC_ERROR, SmartErrorCode.SESSION_TIMEOUT_ERROR, SmartErrorCode.SESSION_EXPIRED, + SmartErrorCode.INVALID_NONCE, ] SMART_AUTHENTICATION_ERRORS = [ diff --git a/kasa/experimental/sslaestransport.py b/kasa/experimental/sslaestransport.py index 2a5d12e2..9f891263 100644 --- a/kasa/experimental/sslaestransport.py +++ b/kasa/experimental/sslaestransport.py @@ -8,7 +8,6 @@ import hashlib import logging import secrets import ssl -import time from enum import Enum, auto from typing import TYPE_CHECKING, Any, Dict, cast @@ -29,7 +28,7 @@ from ..exceptions import ( from ..httpclient import HttpClient from ..json import dumps as json_dumps from ..json import loads as json_loads -from ..protocol import BaseTransport +from ..protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials _LOGGER = logging.getLogger(__name__) @@ -71,7 +70,6 @@ class SslAesTransport(BaseTransport): "Accept": "application/json", "Accept-Encoding": "gzip, deflate", "User-Agent": "Tapo CameraClient Android", - "Connection": "close", } CIPHERS = ":".join( [ @@ -96,7 +94,9 @@ class SslAesTransport(BaseTransport): not self._credentials or self._credentials.username is None ) and not self._credentials_hash: self._credentials = Credentials() - self._default_credentials: Credentials | None = None + self._default_credentials: Credentials = get_default_credentials( + DEFAULT_CREDENTIALS["TAPOCAMERA"] + ) if not config.timeout: config.timeout = self.DEFAULT_TIMEOUT @@ -149,7 +149,7 @@ class SslAesTransport(BaseTransport): return base64.b64encode(json_dumps(ch).encode()).decode() return None - def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: + def _get_response_error(self, resp_dict: Any) -> SmartErrorCode: error_code_raw = resp_dict.get("error_code") try: error_code = SmartErrorCode.from_int(error_code_raw) @@ -158,6 +158,10 @@ class SslAesTransport(BaseTransport): "Device %s received unknown error code: %s", self._host, error_code_raw ) error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR + return error_code + + def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: + error_code = self._get_response_error(resp_dict) if error_code is SmartErrorCode.SUCCESS: return msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" @@ -325,6 +329,8 @@ class SslAesTransport(BaseTransport): + f"status code {status_code} to handshake2" ) resp_dict = cast(dict, resp_dict) + self._handle_response_error_code(resp_dict, "Error in handshake2") + self._seq = resp_dict["result"]["start_seq"] stok = resp_dict["result"]["stok"] self._token_url = URL(f"{str(self._app_url)}/stok={stok}/ds") @@ -337,13 +343,80 @@ class SslAesTransport(BaseTransport): _LOGGER.debug("Handshake2 complete ...") async def perform_handshake1(self) -> tuple[str, str, str]: - """Perform the handshake.""" - _LOGGER.debug("Will perform handshaking...") + """Perform the handshake1.""" + resp_dict = None + if self._username: + local_nonce = secrets.token_bytes(8).hex().upper() + resp_dict = await self.try_send_handshake1(self._username, local_nonce) + + # Try the default username. If it fails raise the original error_code + if ( + not resp_dict + or (error_code := self._get_response_error(resp_dict)) + is not SmartErrorCode.INVALID_NONCE + or "nonce" not in resp_dict["result"].get("data", {}) + ): + local_nonce = secrets.token_bytes(8).hex().upper() + default_resp_dict = await self.try_send_handshake1( + self._default_credentials.username, local_nonce + ) + if ( + default_error_code := self._get_response_error(default_resp_dict) + ) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[ + "result" + ].get("data", {}): + _LOGGER.debug("Connected to {self._host} with default username") + self._username = self._default_credentials.username + error_code = default_error_code + resp_dict = default_resp_dict if not self._username: - raise KasaException("Cannot connect to device with no credentials") - local_nonce = secrets.token_bytes(8).hex().upper() - # Device needs the content length or it will response with 500 + raise AuthenticationError( + "Credentials must be supplied to connect to {self._host}" + ) + if error_code is not SmartErrorCode.INVALID_NONCE or ( + resp_dict and "nonce" not in resp_dict["result"].get("data", {}) + ): + raise AuthenticationError("Error trying handshake1: {resp_dict}") + + if TYPE_CHECKING: + resp_dict = cast(Dict[str, Any], resp_dict) + + server_nonce = resp_dict["result"]["data"]["nonce"] + device_confirm = resp_dict["result"]["data"]["device_confirm"] + if self._credentials and self._credentials != Credentials(): + pwd_hash = _sha256_hash(self._credentials.password.encode()) + elif self._username and self._password: + pwd_hash = _sha256_hash(self._password.encode()) + else: + pwd_hash = _sha256_hash(self._default_credentials.password.encode()) + + expected_confirm_sha256 = self.generate_confirm_hash( + local_nonce, server_nonce, pwd_hash + ) + if device_confirm == expected_confirm_sha256: + _LOGGER.debug("Credentials match") + return local_nonce, server_nonce, pwd_hash + + if TYPE_CHECKING: + assert self._credentials + assert self._credentials.password + pwd_hash = _md5_hash(self._credentials.password.encode()) + expected_confirm_md5 = self.generate_confirm_hash( + local_nonce, server_nonce, pwd_hash + ) + if device_confirm == expected_confirm_md5: + _LOGGER.debug("Credentials match") + return local_nonce, server_nonce, pwd_hash + + msg = f"Server response doesn't match our challenge on ip {self._host}" + _LOGGER.debug(msg) + raise AuthenticationError(msg) + + async def try_send_handshake1(self, username: str, local_nonce: str) -> dict: + """Perform the handshake.""" + _LOGGER.debug("Will to send handshake1...") + body = { "method": "login", "params": { @@ -369,58 +442,11 @@ class SslAesTransport(BaseTransport): + f"status code {status_code} to handshake1" ) - resp_dict = cast(dict, resp_dict) - error_code = SmartErrorCode.from_int(resp_dict["error_code"]) - if error_code != SmartErrorCode.INVALID_NONCE: - self._handle_response_error_code(resp_dict, "Unable to complete handshake") - - if TYPE_CHECKING: - resp_dict = cast(Dict[str, Any], resp_dict) - - server_nonce = resp_dict["result"]["data"]["nonce"] - device_confirm = resp_dict["result"]["data"]["device_confirm"] - if self._credentials and self._credentials != Credentials(): - pwd_hash = _sha256_hash(self._credentials.password.encode()) - else: - if TYPE_CHECKING: - assert self._pwd_hash - pwd_hash = self._pwd_hash - - expected_confirm_sha256 = self.generate_confirm_hash( - local_nonce, server_nonce, pwd_hash - ) - if device_confirm == expected_confirm_sha256: - _LOGGER.debug("Credentials match") - return local_nonce, server_nonce, pwd_hash - - if TYPE_CHECKING: - assert self._credentials - assert self._credentials.password - pwd_hash = _md5_hash(self._credentials.password.encode()) - expected_confirm_md5 = self.generate_confirm_hash( - local_nonce, server_nonce, pwd_hash - ) - if device_confirm == expected_confirm_md5: - _LOGGER.debug("Credentials match") - return local_nonce, server_nonce, pwd_hash - - msg = f"Server response doesn't match our challenge on ip {self._host}" - _LOGGER.debug(msg) - raise AuthenticationError(msg) - - def _handshake_session_expired(self): - """Return true if session has expired.""" - return ( - self._session_expire_at is None - or self._session_expire_at - time.time() <= 0 - ) + return cast(dict, resp_dict) async def send(self, request: str) -> dict[str, Any]: """Send the request.""" - if ( - self._state is TransportState.HANDSHAKE_REQUIRED - or self._handshake_session_expired() - ): + if self._state is TransportState.HANDSHAKE_REQUIRED: await self.perform_handshake() return await self.send_secure_passthrough(request) diff --git a/kasa/protocol.py b/kasa/protocol.py index 9b5ffa3d..1107fa1d 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -155,4 +155,5 @@ def get_default_credentials(tuple: tuple[str, str]) -> Credentials: DEFAULT_CREDENTIALS = { "KASA": ("a2FzYUB0cC1saW5rLm5ldA==", "a2FzYVNldHVw"), "TAPO": ("dGVzdEB0cC1saW5rLm5ldA==", "dGVzdA=="), + "TAPOCAMERA": ("YWRtaW4=", "YWRtaW4="), }