Do login entirely within AesTransport (#580)

* Do login entirely within AesTransport

* Remove login and handshake attributes from BaseTransport

* Add AesTransport tests

* Synchronise transport and protocol __init__ signatures and rename internal variables

* Update after review
This commit is contained in:
sdb9696
2023-12-19 14:11:59 +00:00
committed by GitHub
parent 209391c422
commit 20ea6700a5
13 changed files with 468 additions and 237 deletions

View File

@@ -8,7 +8,7 @@ import base64
import hashlib
import logging
import time
from typing import Optional, Union
from typing import Optional
import httpx
from cryptography.hazmat.primitives import padding, serialization
@@ -47,6 +47,7 @@ class AesTransport(BaseTransport):
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_PORT = 80
DEFAULT_TIMEOUT = 5
SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = {
@@ -59,12 +60,16 @@ class AesTransport(BaseTransport):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host)
self._credentials = credentials or Credentials(username="", password="")
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
self._handshake_done = False
@@ -77,7 +82,7 @@ class AesTransport(BaseTransport):
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
self._login_token = None
_LOGGER.debug("Created AES object for %s", self.host)
_LOGGER.debug("Created AES transport for %s", self._host)
def hash_credentials(self, login_v2):
"""Hash the credentials."""
@@ -123,7 +128,7 @@ class AesTransport(BaseTransport):
if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
@@ -136,7 +141,7 @@ class AesTransport(BaseTransport):
async def send_secure_passthrough(self, request: str):
"""Send encrypted message as passthrough."""
url = f"http://{self.host}/app"
url = f"http://{self._host}/app"
if self._login_token:
url += f"?token={self._login_token}"
@@ -150,7 +155,7 @@ class AesTransport(BaseTransport):
if status_code != 200:
raise SmartDeviceException(
f"{self.host} responded with an unexpected "
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to passthrough"
)
@@ -164,49 +169,31 @@ class AesTransport(BaseTransport):
resp_dict = json_loads(response)
return resp_dict
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
async def _perform_login_for_version(self, *, login_version: int = 1):
"""Login to the device."""
self._login_token = None
if isinstance(login_request, str):
login_request_dict: dict = json_loads(login_request)
else:
login_request_dict = login_request
un, pw = self.hash_credentials(login_v2)
login_request_dict["params"] = {"password": pw, "username": un}
request = json_dumps(login_request_dict)
un, pw = self.hash_credentials(login_version == 2)
password_field_name = "password2" if login_version == 2 else "password"
login_request = {
"method": "login_device",
"params": {password_field_name: pw, "username": un},
"request_time_milis": round(time.time() * 1000),
}
request = json_dumps(login_request)
try:
resp_dict = await self.send_secure_passthrough(request)
except SmartDeviceException as ex:
raise AuthenticationException(ex) from ex
self._login_token = resp_dict["result"]["token"]
@property
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
return self._login_token is None
async def login(self, request: str) -> None:
async def perform_login(self) -> None:
"""Login to the device."""
try:
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to login"
)
await self.perform_login(request, login_v2=False)
await self._perform_login_for_version(login_version=2)
except AuthenticationException:
_LOGGER.warning("Login version 2 failed, trying version 1")
await self.perform_handshake()
await self.perform_login(request, login_v2=True)
@property
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
return not self._handshake_done or self._handshake_session_expired()
async def handshake(self) -> None:
"""Perform the encryption handshake."""
await self.perform_handshake()
await self._perform_login_for_version(login_version=1)
async def perform_handshake(self):
"""Perform the handshake."""
@@ -217,7 +204,7 @@ class AesTransport(BaseTransport):
self._session_expire_at = None
self._session_cookie = None
url = f"http://{self.host}/app"
url = f"http://{self._host}/app"
key_pair = KeyPair.create_key_pair()
pub_key = (
@@ -238,7 +225,7 @@ class AesTransport(BaseTransport):
if status_code != 200:
raise SmartDeviceException(
f"{self.host} responded with an unexpected "
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to handshake"
)
@@ -261,7 +248,7 @@ class AesTransport(BaseTransport):
self._handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host)
_LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self):
"""Return true if session has expired."""
@@ -272,12 +259,10 @@ class AesTransport(BaseTransport):
async def send(self, request: str):
"""Send the request."""
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to send"
)
if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")
if not self._handshake_done or self._handshake_session_expired():
await self.perform_handshake()
if not self._login_token:
await self.perform_login()
return await self.send_secure_passthrough(request)