diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 412dbbf2..4e1ccb7d 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -151,7 +151,7 @@ class AesTransport(BaseTransport): async def send_secure_passthrough(self, request: str) -> Dict[str, Any]: """Send encrypted message as passthrough.""" url = f"http://{self._host}/app" - if self._login_token: + if self._state is TransportState.ESTABLISHED and self._login_token: url += f"?token={self._login_token}" encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore @@ -250,6 +250,7 @@ class AesTransport(BaseTransport): _LOGGER.debug("Will perform handshaking...") self._key_pair = None + self._login_token = None self._session_expire_at = None self._session_cookie = None @@ -284,9 +285,7 @@ class AesTransport(BaseTransport): handshake_key = resp_dict["result"]["key"] if ( - cookie := http_client.get_cookie( # type: ignore - self.SESSION_COOKIE_NAME - ) + cookie := http_client.get_cookie(self.SESSION_COOKIE_NAME) # type: ignore ) or ( cookie := http_client.get_cookie("SESSIONID") # type: ignore ): diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 086f6ea6..9fe5cabd 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -1,9 +1,12 @@ import base64 import json +import random +import string import time from contextlib import nullcontext as does_not_raise from json import dumps as json_dumps from json import loads as json_loads +from typing import Any, Dict, Optional import aiohttp import pytest @@ -219,7 +222,6 @@ class MockAesDevice: return json_dumps(self._json).encode() encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:]) - token = "test_token" # noqa def __init__(self, host, status_code=200, error_code=0, inner_error_code=0): self.host = host @@ -228,6 +230,7 @@ class MockAesDevice: self._inner_error_code = inner_error_code self.http_client = HttpClient(DeviceConfig(self.host)) self.inner_call_count = 0 + self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311 @property def inner_error_code(self): @@ -242,7 +245,7 @@ class MockAesDevice: json = json_loads(item.decode()) return await self._post(url, json) - async def _post(self, url, json): + async def _post(self, url: str, json: Dict[str, Any]): if json["method"] == "handshake": return await self._return_handshake_response(url, json) elif json["method"] == "securePassthrough": @@ -253,7 +256,7 @@ class MockAesDevice: assert url == f"http://{self.host}/app?token={self.token}" return await self._return_send_response(url, json) - async def _return_handshake_response(self, url, json): + async def _return_handshake_response(self, url: str, json: Dict[str, Any]): start = len("-----BEGIN PUBLIC KEY-----\n") end = len("\n-----END PUBLIC KEY-----\n") client_pub_key = json["params"]["key"][start:-end] @@ -266,7 +269,7 @@ class MockAesDevice: self.status_code, {"result": {"key": key_64}, "error_code": self.error_code} ) - async def _return_secure_passthrough_response(self, url, json): + async def _return_secure_passthrough_response(self, url: str, json: Dict[str, Any]): encrypted_request = json["params"]["request"] decrypted_request = self.encryption_session.decrypt(encrypted_request.encode()) decrypted_request_dict = json_loads(decrypted_request) @@ -283,12 +286,15 @@ class MockAesDevice: } return self._mock_response(self.status_code, result) - async def _return_login_response(self, url, json): + async def _return_login_response(self, url: str, json: Dict[str, Any]): + if "token=" in url: + raise Exception("token should not be in url for a login request") + self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311 result = {"result": {"token": self.token}, "error_code": self.inner_error_code} self.inner_call_count += 1 return self._mock_response(self.status_code, result) - async def _return_send_response(self, url, json): + async def _return_send_response(self, url: str, json: Dict[str, Any]): result = {"result": {"method": None}, "error_code": self.inner_error_code} self.inner_call_count += 1 return self._mock_response(self.status_code, result)