Revert changes moved into seperate PRs

This commit is contained in:
Steven B 2024-12-20 12:48:48 +00:00
parent 7f8f823eac
commit 5d33a66341
No known key found for this signature in database
GPG Key ID: 6D5B46B3679F2A43

View File

@ -7,7 +7,6 @@ import base64
import hashlib import hashlib
import logging import logging
import secrets import secrets
import socket
import ssl import ssl
from enum import Enum, auto from enum import Enum, auto
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
@ -108,7 +107,12 @@ class SslAesTransport(BaseTransport):
self._app_url = URL(f"https://{self._host_port}") self._app_url = URL(f"https://{self._host_port}")
self._token_url: URL | None = None self._token_url: URL | None = None
self._ssl_context: ssl.SSLContext | None = None self._ssl_context: ssl.SSLContext | None = None
self._headers: dict | None = None ref = str(self._token_url) if self._token_url else str(self._app_url)
self._headers = {
**self.COMMON_HEADERS,
"Host": self._host_port,
"Referer": ref,
}
self._seq: int | None = None self._seq: int | None = None
self._pwd_hash: str | None = None self._pwd_hash: str | None = None
self._username: str | None = None self._username: str | None = None
@ -157,19 +161,6 @@ class SslAesTransport(BaseTransport):
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
return error_code return error_code
def _get_response_inner_error(self, resp_dict: Any) -> SmartErrorCode | None:
error_code_raw = resp_dict.get("data", {}).get("error_code")
if error_code_raw is None:
return None
try:
error_code = SmartErrorCode.from_int(error_code_raw)
except ValueError:
_LOGGER.warning(
"Device %s received unknown error code: %s", self._host, error_code_raw
)
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
return error_code
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
error_code = self._get_response_error(resp_dict) error_code = self._get_response_error(resp_dict)
if error_code is SmartErrorCode.SUCCESS: if error_code is SmartErrorCode.SUCCESS:
@ -197,34 +188,6 @@ class SslAesTransport(BaseTransport):
) )
return self._ssl_context return self._ssl_context
async def _get_host_ip(self) -> str:
def get_ip() -> str:
# From https://stackoverflow.com/a/28950776
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0)
try:
# doesn't even have to be reachable
s.connect(("10.254.254.254", 1))
ip = s.getsockname()[0]
except Exception:
ip = "127.0.0.1"
finally:
s.close()
return ip
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, get_ip)
async def _get_headers(self) -> dict:
if not self._headers:
this_ip = await self._get_host_ip()
self._headers = {
**self.COMMON_HEADERS,
"Referer": f"https://{this_ip}",
"Host": self._host_port,
}
return self._headers
async def send_secure_passthrough(self, request: str) -> dict[str, Any]: async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
"""Send encrypted message as passthrough.""" """Send encrypted message as passthrough."""
if self._state is TransportState.ESTABLISHED and self._token_url: if self._state is TransportState.ESTABLISHED and self._token_url:
@ -249,7 +212,7 @@ class SslAesTransport(BaseTransport):
tag = self.generate_tag( tag = self.generate_tag(
passthrough_request_str, self._local_nonce, self._pwd_hash, self._seq passthrough_request_str, self._local_nonce, self._pwd_hash, self._seq
) )
headers = {**await self._get_headers(), "Seq": str(self._seq), "Tapo_tag": tag} headers = {**self._headers, "Seq": str(self._seq), "Tapo_tag": tag}
self._seq += 1 self._seq += 1
status_code, resp_dict = await self._http_client.post( status_code, resp_dict = await self._http_client.post(
url, url,
@ -311,7 +274,7 @@ class SslAesTransport(BaseTransport):
status_code, resp_dict = await self._http_client.post( status_code, resp_dict = await self._http_client.post(
url, url,
json=request, json=request,
headers=await self._get_headers(), headers=self._headers,
ssl=await self._get_ssl_context(), ssl=await self._get_ssl_context(),
) )
@ -399,7 +362,7 @@ class SslAesTransport(BaseTransport):
status_code, resp_dict = await http_client.post( status_code, resp_dict = await http_client.post(
self._app_url, self._app_url,
json=body, json=body,
headers=await self._get_headers(), headers=self._headers,
ssl=await self._get_ssl_context(), ssl=await self._get_ssl_context(),
) )
if status_code != 200: if status_code != 200:
@ -443,7 +406,7 @@ class SslAesTransport(BaseTransport):
status_code, resp_dict = await http_client.post( status_code, resp_dict = await http_client.post(
self._app_url, self._app_url,
json=body, json=body,
headers=await self._get_headers(), headers=self._headers,
ssl=await self._get_ssl_context(), ssl=await self._get_ssl_context(),
) )
if status_code != 200: if status_code != 200:
@ -500,8 +463,8 @@ class SslAesTransport(BaseTransport):
async def perform_handshake1(self) -> tuple[str, str, str] | None: async def perform_handshake1(self) -> tuple[str, str, str] | None:
"""Perform the handshake1.""" """Perform the handshake1."""
resp_dict = None resp_dict = None
local_nonce = secrets.token_bytes(8).hex().upper()
if self._username: if self._username:
local_nonce = secrets.token_bytes(8).hex().upper()
resp_dict = await self.try_send_handshake1(self._username, local_nonce) resp_dict = await self.try_send_handshake1(self._username, local_nonce)
if ( if (
@ -519,7 +482,7 @@ class SslAesTransport(BaseTransport):
or "nonce" not in resp_dict["result"].get("data", {}) or "nonce" not in resp_dict["result"].get("data", {})
): ):
_LOGGER.debug("Trying default credentials to %s", self._host) _LOGGER.debug("Trying default credentials to %s", self._host)
# local_nonce = secrets.token_bytes(8).hex().upper() local_nonce = secrets.token_bytes(8).hex().upper()
default_resp_dict = await self.try_send_handshake1( default_resp_dict = await self.try_send_handshake1(
self._default_credentials.username, local_nonce self._default_credentials.username, local_nonce
) )
@ -545,18 +508,8 @@ class SslAesTransport(BaseTransport):
f"Credentials must be supplied to connect to {self._host}" f"Credentials must be supplied to connect to {self._host}"
) )
if error_code is not SmartErrorCode.INVALID_NONCE or ( if error_code is not SmartErrorCode.INVALID_NONCE or (
resp_dict and "nonce" not in resp_dict.get("result", {}).get("data", {}) resp_dict and "nonce" not in resp_dict["result"].get("data", {})
): ):
if (
resp_dict
and self._get_response_inner_error(resp_dict)
is SmartErrorCode.DEVICE_BLOCKED
):
secs_left = resp_dict.get("data", {}).get("secs_left")
msg = "Device blocked" + (
f" for {secs_left} seconds" if secs_left else ""
)
raise DeviceError(msg)
raise AuthenticationError( raise AuthenticationError(
f"Error trying handshake1 for {self._host}: {resp_dict}" f"Error trying handshake1 for {self._host}: {resp_dict}"
) )
@ -611,7 +564,7 @@ class SslAesTransport(BaseTransport):
status_code, resp_dict = await http_client.post( status_code, resp_dict = await http_client.post(
self._app_url, self._app_url,
json=body, json=body,
headers=await self._get_headers(), headers=self._headers,
ssl=await self._get_ssl_context(), ssl=await self._get_ssl_context(),
) )