mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-10-24 00:08:01 +00:00
Do login entirely within AesTransport (#580)
* Do login entirely within AesTransport * Remove login and handshake attributes from BaseTransport * Add AesTransport tests * Synchronise transport and protocol __init__ signatures and rename internal variables * Update after review
This commit is contained in:
@@ -8,7 +8,7 @@ import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import padding, serialization
|
||||
@@ -47,6 +47,7 @@ class AesTransport(BaseTransport):
|
||||
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
DEFAULT_TIMEOUT = 5
|
||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||
COMMON_HEADERS = {
|
||||
@@ -59,12 +60,16 @@ class AesTransport(BaseTransport):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host)
|
||||
|
||||
self._credentials = credentials or Credentials(username="", password="")
|
||||
super().__init__(
|
||||
host,
|
||||
port=port or self.DEFAULT_PORT,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self._handshake_done = False
|
||||
|
||||
@@ -77,7 +82,7 @@ class AesTransport(BaseTransport):
|
||||
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||
self._login_token = None
|
||||
|
||||
_LOGGER.debug("Created AES object for %s", self.host)
|
||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||
|
||||
def hash_credentials(self, login_v2):
|
||||
"""Hash the credentials."""
|
||||
@@ -123,7 +128,7 @@ class AesTransport(BaseTransport):
|
||||
if (
|
||||
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||
) != SmartErrorCode.SUCCESS:
|
||||
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
|
||||
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
|
||||
if error_code in SMART_TIMEOUT_ERRORS:
|
||||
raise TimeoutException(msg)
|
||||
if error_code in SMART_RETRYABLE_ERRORS:
|
||||
@@ -136,7 +141,7 @@ class AesTransport(BaseTransport):
|
||||
|
||||
async def send_secure_passthrough(self, request: str):
|
||||
"""Send encrypted message as passthrough."""
|
||||
url = f"http://{self.host}/app"
|
||||
url = f"http://{self._host}/app"
|
||||
if self._login_token:
|
||||
url += f"?token={self._login_token}"
|
||||
|
||||
@@ -150,7 +155,7 @@ class AesTransport(BaseTransport):
|
||||
|
||||
if status_code != 200:
|
||||
raise SmartDeviceException(
|
||||
f"{self.host} responded with an unexpected "
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code} to passthrough"
|
||||
)
|
||||
|
||||
@@ -164,49 +169,31 @@ class AesTransport(BaseTransport):
|
||||
resp_dict = json_loads(response)
|
||||
return resp_dict
|
||||
|
||||
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
|
||||
async def _perform_login_for_version(self, *, login_version: int = 1):
|
||||
"""Login to the device."""
|
||||
self._login_token = None
|
||||
|
||||
if isinstance(login_request, str):
|
||||
login_request_dict: dict = json_loads(login_request)
|
||||
else:
|
||||
login_request_dict = login_request
|
||||
|
||||
un, pw = self.hash_credentials(login_v2)
|
||||
login_request_dict["params"] = {"password": pw, "username": un}
|
||||
request = json_dumps(login_request_dict)
|
||||
un, pw = self.hash_credentials(login_version == 2)
|
||||
password_field_name = "password2" if login_version == 2 else "password"
|
||||
login_request = {
|
||||
"method": "login_device",
|
||||
"params": {password_field_name: pw, "username": un},
|
||||
"request_time_milis": round(time.time() * 1000),
|
||||
}
|
||||
request = json_dumps(login_request)
|
||||
try:
|
||||
resp_dict = await self.send_secure_passthrough(request)
|
||||
except SmartDeviceException as ex:
|
||||
raise AuthenticationException(ex) from ex
|
||||
self._login_token = resp_dict["result"]["token"]
|
||||
|
||||
@property
|
||||
def needs_login(self) -> bool:
|
||||
"""Return true if the transport needs to do a login."""
|
||||
return self._login_token is None
|
||||
|
||||
async def login(self, request: str) -> None:
|
||||
async def perform_login(self) -> None:
|
||||
"""Login to the device."""
|
||||
try:
|
||||
if self.needs_handshake:
|
||||
raise SmartDeviceException(
|
||||
"Handshake must be complete before trying to login"
|
||||
)
|
||||
await self.perform_login(request, login_v2=False)
|
||||
await self._perform_login_for_version(login_version=2)
|
||||
except AuthenticationException:
|
||||
_LOGGER.warning("Login version 2 failed, trying version 1")
|
||||
await self.perform_handshake()
|
||||
await self.perform_login(request, login_v2=True)
|
||||
|
||||
@property
|
||||
def needs_handshake(self) -> bool:
|
||||
"""Return true if the transport needs to do a handshake."""
|
||||
return not self._handshake_done or self._handshake_session_expired()
|
||||
|
||||
async def handshake(self) -> None:
|
||||
"""Perform the encryption handshake."""
|
||||
await self.perform_handshake()
|
||||
await self._perform_login_for_version(login_version=1)
|
||||
|
||||
async def perform_handshake(self):
|
||||
"""Perform the handshake."""
|
||||
@@ -217,7 +204,7 @@ class AesTransport(BaseTransport):
|
||||
self._session_expire_at = None
|
||||
self._session_cookie = None
|
||||
|
||||
url = f"http://{self.host}/app"
|
||||
url = f"http://{self._host}/app"
|
||||
key_pair = KeyPair.create_key_pair()
|
||||
|
||||
pub_key = (
|
||||
@@ -238,7 +225,7 @@ class AesTransport(BaseTransport):
|
||||
|
||||
if status_code != 200:
|
||||
raise SmartDeviceException(
|
||||
f"{self.host} responded with an unexpected "
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code} to handshake"
|
||||
)
|
||||
|
||||
@@ -261,7 +248,7 @@ class AesTransport(BaseTransport):
|
||||
|
||||
self._handshake_done = True
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||
_LOGGER.debug("Handshake with %s complete", self._host)
|
||||
|
||||
def _handshake_session_expired(self):
|
||||
"""Return true if session has expired."""
|
||||
@@ -272,12 +259,10 @@ class AesTransport(BaseTransport):
|
||||
|
||||
async def send(self, request: str):
|
||||
"""Send the request."""
|
||||
if self.needs_handshake:
|
||||
raise SmartDeviceException(
|
||||
"Handshake must be complete before trying to send"
|
||||
)
|
||||
if self.needs_login:
|
||||
raise SmartDeviceException("Login must be complete before trying to send")
|
||||
if not self._handshake_done or self._handshake_session_expired():
|
||||
await self.perform_handshake()
|
||||
if not self._login_token:
|
||||
await self.perform_login()
|
||||
|
||||
return await self.send_secure_passthrough(request)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user