diff --git a/kasa/transports/sslaestransport.py b/kasa/transports/sslaestransport.py index 15c2bc01..bbbb841d 100644 --- a/kasa/transports/sslaestransport.py +++ b/kasa/transports/sslaestransport.py @@ -7,7 +7,6 @@ import base64 import hashlib import logging import secrets -import socket import ssl from enum import Enum, auto from typing import TYPE_CHECKING, Any, cast @@ -108,7 +107,12 @@ class SslAesTransport(BaseTransport): self._app_url = URL(f"https://{self._host_port}") self._token_url: URL | None = None self._ssl_context: ssl.SSLContext | None = None - self._headers: dict | None = None + ref = str(self._token_url) if self._token_url else str(self._app_url) + self._headers = { + **self.COMMON_HEADERS, + "Host": self._host_port, + "Referer": ref, + } self._seq: int | None = None self._pwd_hash: str | None = None self._username: str | None = None @@ -157,19 +161,6 @@ class SslAesTransport(BaseTransport): error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR return error_code - def _get_response_inner_error(self, resp_dict: Any) -> SmartErrorCode | None: - error_code_raw = resp_dict.get("data", {}).get("error_code") - if error_code_raw is None: - return None - try: - error_code = SmartErrorCode.from_int(error_code_raw) - except ValueError: - _LOGGER.warning( - "Device %s received unknown error code: %s", self._host, error_code_raw - ) - error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR - return error_code - def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: error_code = self._get_response_error(resp_dict) if error_code is SmartErrorCode.SUCCESS: @@ -197,34 +188,6 @@ class SslAesTransport(BaseTransport): ) return self._ssl_context - async def _get_host_ip(self) -> str: - def get_ip() -> str: - # From https://stackoverflow.com/a/28950776 - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - s.settimeout(0) - try: - # doesn't even have to be reachable - s.connect(("10.254.254.254", 1)) - ip = s.getsockname()[0] - except Exception: - ip = "127.0.0.1" - finally: - s.close() - return ip - - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, get_ip) - - async def _get_headers(self) -> dict: - if not self._headers: - this_ip = await self._get_host_ip() - self._headers = { - **self.COMMON_HEADERS, - "Referer": f"https://{this_ip}", - "Host": self._host_port, - } - return self._headers - async def send_secure_passthrough(self, request: str) -> dict[str, Any]: """Send encrypted message as passthrough.""" if self._state is TransportState.ESTABLISHED and self._token_url: @@ -249,7 +212,7 @@ class SslAesTransport(BaseTransport): tag = self.generate_tag( passthrough_request_str, self._local_nonce, self._pwd_hash, self._seq ) - headers = {**await self._get_headers(), "Seq": str(self._seq), "Tapo_tag": tag} + headers = {**self._headers, "Seq": str(self._seq), "Tapo_tag": tag} self._seq += 1 status_code, resp_dict = await self._http_client.post( url, @@ -311,7 +274,7 @@ class SslAesTransport(BaseTransport): status_code, resp_dict = await self._http_client.post( url, json=request, - headers=await self._get_headers(), + headers=self._headers, ssl=await self._get_ssl_context(), ) @@ -399,7 +362,7 @@ class SslAesTransport(BaseTransport): status_code, resp_dict = await http_client.post( self._app_url, json=body, - headers=await self._get_headers(), + headers=self._headers, ssl=await self._get_ssl_context(), ) if status_code != 200: @@ -443,7 +406,7 @@ class SslAesTransport(BaseTransport): status_code, resp_dict = await http_client.post( self._app_url, json=body, - headers=await self._get_headers(), + headers=self._headers, ssl=await self._get_ssl_context(), ) if status_code != 200: @@ -500,8 +463,8 @@ class SslAesTransport(BaseTransport): async def perform_handshake1(self) -> tuple[str, str, str] | None: """Perform the handshake1.""" resp_dict = None - local_nonce = secrets.token_bytes(8).hex().upper() if self._username: + local_nonce = secrets.token_bytes(8).hex().upper() resp_dict = await self.try_send_handshake1(self._username, local_nonce) if ( @@ -519,7 +482,7 @@ class SslAesTransport(BaseTransport): or "nonce" not in resp_dict["result"].get("data", {}) ): _LOGGER.debug("Trying default credentials to %s", self._host) - # local_nonce = secrets.token_bytes(8).hex().upper() + local_nonce = secrets.token_bytes(8).hex().upper() default_resp_dict = await self.try_send_handshake1( self._default_credentials.username, local_nonce ) @@ -545,18 +508,8 @@ class SslAesTransport(BaseTransport): f"Credentials must be supplied to connect to {self._host}" ) if error_code is not SmartErrorCode.INVALID_NONCE or ( - resp_dict and "nonce" not in resp_dict.get("result", {}).get("data", {}) + resp_dict and "nonce" not in resp_dict["result"].get("data", {}) ): - if ( - resp_dict - and self._get_response_inner_error(resp_dict) - is SmartErrorCode.DEVICE_BLOCKED - ): - secs_left = resp_dict.get("data", {}).get("secs_left") - msg = "Device blocked" + ( - f" for {secs_left} seconds" if secs_left else "" - ) - raise DeviceError(msg) raise AuthenticationError( f"Error trying handshake1 for {self._host}: {resp_dict}" ) @@ -611,7 +564,7 @@ class SslAesTransport(BaseTransport): status_code, resp_dict = await http_client.post( self._app_url, json=body, - headers=await self._get_headers(), + headers=self._headers, ssl=await self._get_ssl_context(), )