mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Revert changes moved into seperate PRs
This commit is contained in:
parent
7f8f823eac
commit
5d33a66341
@ -7,7 +7,6 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import socket
|
|
||||||
import ssl
|
import ssl
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
@ -108,7 +107,12 @@ class SslAesTransport(BaseTransport):
|
|||||||
self._app_url = URL(f"https://{self._host_port}")
|
self._app_url = URL(f"https://{self._host_port}")
|
||||||
self._token_url: URL | None = None
|
self._token_url: URL | None = None
|
||||||
self._ssl_context: ssl.SSLContext | None = None
|
self._ssl_context: ssl.SSLContext | None = None
|
||||||
self._headers: dict | None = None
|
ref = str(self._token_url) if self._token_url else str(self._app_url)
|
||||||
|
self._headers = {
|
||||||
|
**self.COMMON_HEADERS,
|
||||||
|
"Host": self._host_port,
|
||||||
|
"Referer": ref,
|
||||||
|
}
|
||||||
self._seq: int | None = None
|
self._seq: int | None = None
|
||||||
self._pwd_hash: str | None = None
|
self._pwd_hash: str | None = None
|
||||||
self._username: str | None = None
|
self._username: str | None = None
|
||||||
@ -157,19 +161,6 @@ class SslAesTransport(BaseTransport):
|
|||||||
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
|
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
|
||||||
return error_code
|
return error_code
|
||||||
|
|
||||||
def _get_response_inner_error(self, resp_dict: Any) -> SmartErrorCode | None:
|
|
||||||
error_code_raw = resp_dict.get("data", {}).get("error_code")
|
|
||||||
if error_code_raw is None:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
error_code = SmartErrorCode.from_int(error_code_raw)
|
|
||||||
except ValueError:
|
|
||||||
_LOGGER.warning(
|
|
||||||
"Device %s received unknown error code: %s", self._host, error_code_raw
|
|
||||||
)
|
|
||||||
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
|
|
||||||
return error_code
|
|
||||||
|
|
||||||
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
|
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
|
||||||
error_code = self._get_response_error(resp_dict)
|
error_code = self._get_response_error(resp_dict)
|
||||||
if error_code is SmartErrorCode.SUCCESS:
|
if error_code is SmartErrorCode.SUCCESS:
|
||||||
@ -197,34 +188,6 @@ class SslAesTransport(BaseTransport):
|
|||||||
)
|
)
|
||||||
return self._ssl_context
|
return self._ssl_context
|
||||||
|
|
||||||
async def _get_host_ip(self) -> str:
|
|
||||||
def get_ip() -> str:
|
|
||||||
# From https://stackoverflow.com/a/28950776
|
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
||||||
s.settimeout(0)
|
|
||||||
try:
|
|
||||||
# doesn't even have to be reachable
|
|
||||||
s.connect(("10.254.254.254", 1))
|
|
||||||
ip = s.getsockname()[0]
|
|
||||||
except Exception:
|
|
||||||
ip = "127.0.0.1"
|
|
||||||
finally:
|
|
||||||
s.close()
|
|
||||||
return ip
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
return await loop.run_in_executor(None, get_ip)
|
|
||||||
|
|
||||||
async def _get_headers(self) -> dict:
|
|
||||||
if not self._headers:
|
|
||||||
this_ip = await self._get_host_ip()
|
|
||||||
self._headers = {
|
|
||||||
**self.COMMON_HEADERS,
|
|
||||||
"Referer": f"https://{this_ip}",
|
|
||||||
"Host": self._host_port,
|
|
||||||
}
|
|
||||||
return self._headers
|
|
||||||
|
|
||||||
async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
|
async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
|
||||||
"""Send encrypted message as passthrough."""
|
"""Send encrypted message as passthrough."""
|
||||||
if self._state is TransportState.ESTABLISHED and self._token_url:
|
if self._state is TransportState.ESTABLISHED and self._token_url:
|
||||||
@ -249,7 +212,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
tag = self.generate_tag(
|
tag = self.generate_tag(
|
||||||
passthrough_request_str, self._local_nonce, self._pwd_hash, self._seq
|
passthrough_request_str, self._local_nonce, self._pwd_hash, self._seq
|
||||||
)
|
)
|
||||||
headers = {**await self._get_headers(), "Seq": str(self._seq), "Tapo_tag": tag}
|
headers = {**self._headers, "Seq": str(self._seq), "Tapo_tag": tag}
|
||||||
self._seq += 1
|
self._seq += 1
|
||||||
status_code, resp_dict = await self._http_client.post(
|
status_code, resp_dict = await self._http_client.post(
|
||||||
url,
|
url,
|
||||||
@ -311,7 +274,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
status_code, resp_dict = await self._http_client.post(
|
status_code, resp_dict = await self._http_client.post(
|
||||||
url,
|
url,
|
||||||
json=request,
|
json=request,
|
||||||
headers=await self._get_headers(),
|
headers=self._headers,
|
||||||
ssl=await self._get_ssl_context(),
|
ssl=await self._get_ssl_context(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -399,7 +362,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
status_code, resp_dict = await http_client.post(
|
status_code, resp_dict = await http_client.post(
|
||||||
self._app_url,
|
self._app_url,
|
||||||
json=body,
|
json=body,
|
||||||
headers=await self._get_headers(),
|
headers=self._headers,
|
||||||
ssl=await self._get_ssl_context(),
|
ssl=await self._get_ssl_context(),
|
||||||
)
|
)
|
||||||
if status_code != 200:
|
if status_code != 200:
|
||||||
@ -443,7 +406,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
status_code, resp_dict = await http_client.post(
|
status_code, resp_dict = await http_client.post(
|
||||||
self._app_url,
|
self._app_url,
|
||||||
json=body,
|
json=body,
|
||||||
headers=await self._get_headers(),
|
headers=self._headers,
|
||||||
ssl=await self._get_ssl_context(),
|
ssl=await self._get_ssl_context(),
|
||||||
)
|
)
|
||||||
if status_code != 200:
|
if status_code != 200:
|
||||||
@ -500,8 +463,8 @@ class SslAesTransport(BaseTransport):
|
|||||||
async def perform_handshake1(self) -> tuple[str, str, str] | None:
|
async def perform_handshake1(self) -> tuple[str, str, str] | None:
|
||||||
"""Perform the handshake1."""
|
"""Perform the handshake1."""
|
||||||
resp_dict = None
|
resp_dict = None
|
||||||
local_nonce = secrets.token_bytes(8).hex().upper()
|
|
||||||
if self._username:
|
if self._username:
|
||||||
|
local_nonce = secrets.token_bytes(8).hex().upper()
|
||||||
resp_dict = await self.try_send_handshake1(self._username, local_nonce)
|
resp_dict = await self.try_send_handshake1(self._username, local_nonce)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -519,7 +482,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
or "nonce" not in resp_dict["result"].get("data", {})
|
or "nonce" not in resp_dict["result"].get("data", {})
|
||||||
):
|
):
|
||||||
_LOGGER.debug("Trying default credentials to %s", self._host)
|
_LOGGER.debug("Trying default credentials to %s", self._host)
|
||||||
# local_nonce = secrets.token_bytes(8).hex().upper()
|
local_nonce = secrets.token_bytes(8).hex().upper()
|
||||||
default_resp_dict = await self.try_send_handshake1(
|
default_resp_dict = await self.try_send_handshake1(
|
||||||
self._default_credentials.username, local_nonce
|
self._default_credentials.username, local_nonce
|
||||||
)
|
)
|
||||||
@ -545,18 +508,8 @@ class SslAesTransport(BaseTransport):
|
|||||||
f"Credentials must be supplied to connect to {self._host}"
|
f"Credentials must be supplied to connect to {self._host}"
|
||||||
)
|
)
|
||||||
if error_code is not SmartErrorCode.INVALID_NONCE or (
|
if error_code is not SmartErrorCode.INVALID_NONCE or (
|
||||||
resp_dict and "nonce" not in resp_dict.get("result", {}).get("data", {})
|
resp_dict and "nonce" not in resp_dict["result"].get("data", {})
|
||||||
):
|
):
|
||||||
if (
|
|
||||||
resp_dict
|
|
||||||
and self._get_response_inner_error(resp_dict)
|
|
||||||
is SmartErrorCode.DEVICE_BLOCKED
|
|
||||||
):
|
|
||||||
secs_left = resp_dict.get("data", {}).get("secs_left")
|
|
||||||
msg = "Device blocked" + (
|
|
||||||
f" for {secs_left} seconds" if secs_left else ""
|
|
||||||
)
|
|
||||||
raise DeviceError(msg)
|
|
||||||
raise AuthenticationError(
|
raise AuthenticationError(
|
||||||
f"Error trying handshake1 for {self._host}: {resp_dict}"
|
f"Error trying handshake1 for {self._host}: {resp_dict}"
|
||||||
)
|
)
|
||||||
@ -611,7 +564,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
status_code, resp_dict = await http_client.post(
|
status_code, resp_dict = await http_client.post(
|
||||||
self._app_url,
|
self._app_url,
|
||||||
json=body,
|
json=body,
|
||||||
headers=await self._get_headers(),
|
headers=self._headers,
|
||||||
ssl=await self._get_ssl_context(),
|
ssl=await self._get_ssl_context(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user