From b479b6d84da51c5ab3e0894d53a9e3b60131a3df Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 29 Jan 2024 05:26:00 -1000 Subject: [PATCH] Avoid rebuilding urls for every request (#715) * Avoid rebuilding urls for every request * more fixes * more fixes * make mypy happy * reduce * tweak * fix tests * fix tests * tweak * tweak * lint * fix type --- kasa/aestransport.py | 17 +++++---- kasa/exceptions.py | 6 ++-- kasa/httpclient.py | 5 +-- kasa/klaptransport.py | 11 +++--- kasa/tests/test_aestransport.py | 33 ++++++++++-------- kasa/tests/test_httpclient.py | 2 +- kasa/tests/test_klapprotocol.py | 61 +++++++++++++++++---------------- 7 files changed, 73 insertions(+), 62 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 4e1ccb7d..f784390b 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -15,6 +15,7 @@ from cryptography.hazmat.primitives import padding, serialization from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from yarl import URL from .credentials import Credentials from .deviceconfig import DeviceConfig @@ -100,9 +101,9 @@ class AesTransport(BaseTransport): self._session_cookie: Optional[Dict[str, str]] = None - self._login_token: Optional[str] = None - self._key_pair: Optional[KeyPair] = None + self._app_url = URL(f"http://{self._host}/app") + self._token_url: Optional[URL] = None _LOGGER.debug("Created AES transport for %s", self._host) @@ -150,9 +151,10 @@ class AesTransport(BaseTransport): async def send_secure_passthrough(self, request: str) -> Dict[str, Any]: """Send encrypted message as passthrough.""" - url = f"http://{self._host}/app" - if self._state is TransportState.ESTABLISHED and self._login_token: - url += f"?token={self._login_token}" + if self._state is TransportState.ESTABLISHED and self._token_url: + url = self._token_url + else: + url = self._app_url encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore passthrough_request = { @@ -223,7 +225,8 @@ class AesTransport(BaseTransport): resp_dict = await self.send_secure_passthrough(request) self._handle_response_error_code(resp_dict, "Error logging in") - self._login_token = resp_dict["result"]["token"] + login_token = resp_dict["result"]["token"] + self._token_url = self._app_url.with_query(f"token={login_token}") self._state = TransportState.ESTABLISHED async def _generate_key_pair_payload(self) -> AsyncGenerator: @@ -250,7 +253,7 @@ class AesTransport(BaseTransport): _LOGGER.debug("Will perform handshaking...") self._key_pair = None - self._login_token = None + self._token_url = None self._session_expire_at = None self._session_cookie = None diff --git a/kasa/exceptions.py b/kasa/exceptions.py index fb86ef14..75f09169 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -1,13 +1,13 @@ """python-kasa exceptions.""" from asyncio import TimeoutError from enum import IntEnum -from typing import Optional +from typing import Any, Optional class SmartDeviceException(Exception): """Base exception for device errors.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.error_code: Optional["SmartErrorCode"] = kwargs.get("error_code", None) super().__init__(*args) @@ -15,7 +15,7 @@ class SmartDeviceException(Exception): class UnsupportedDeviceException(SmartDeviceException): """Exception for trying to connect to unsupported devices.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.discovery_result = kwargs.get("discovery_result") super().__init__(*args, **kwargs) diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 7fe0b2c3..659ebdcf 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -3,6 +3,7 @@ import asyncio from typing import Any, Dict, Optional, Tuple, Union import aiohttp +from yarl import URL from .deviceconfig import DeviceConfig from .exceptions import ( @@ -25,7 +26,7 @@ class HttpClient: self._config = config self._client_session: aiohttp.ClientSession = None self._jar = aiohttp.CookieJar(unsafe=True, quote_cookie=False) - self._last_url = f"http://{self._config.host}/" + self._last_url = URL(f"http://{self._config.host}/") @property def client(self) -> aiohttp.ClientSession: @@ -41,7 +42,7 @@ class HttpClient: async def post( self, - url: str, + url: URL, *, params: Optional[Dict[str, Any]] = None, data: Optional[bytes] = None, diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 898444c2..0e585f2c 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -53,6 +53,7 @@ from typing import Any, Dict, Optional, Tuple, cast from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from yarl import URL from .credentials import Credentials from .deviceconfig import DeviceConfig @@ -120,6 +121,8 @@ class KlapTransport(BaseTransport): self._session_cookie: Optional[Dict[str, Any]] = None _LOGGER.debug("Created KLAP transport for %s", self._host) + self._app_url = URL(f"http://{self._host}/app") + self._request_url = self._app_url / "request" @property def default_port(self): @@ -141,7 +144,7 @@ class KlapTransport(BaseTransport): payload = local_seed - url = f"http://{self._host}/app/handshake1" + url = self._app_url / "handshake1" response_status, response_data = await self._http_client.post(url, data=payload) @@ -236,7 +239,7 @@ class KlapTransport(BaseTransport): # Handshake 2 has the following payload: # sha256(serverBytes | authenticator) - url = f"http://{self._host}/app/handshake2" + url = self._app_url / "handshake2" payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash) @@ -309,10 +312,8 @@ class KlapTransport(BaseTransport): if self._encryption_session is not None: payload, seq = self._encryption_session.encrypt(request.encode()) - url = f"http://{self._host}/app/request" - response_status, response_data = await self._http_client.post( - url, + self._request_url, params={"seq": seq}, data=payload, cookies_dict=self._session_cookie, diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 9fe5cabd..151952bd 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -12,6 +12,7 @@ import aiohttp import pytest from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding +from yarl import URL from ..aestransport import AesEncyptionSession, AesTransport, TransportState from ..credentials import Credentials @@ -89,10 +90,10 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session - assert transport._login_token is None + assert transport._token_url is None with expectation: await transport.perform_login() - assert transport._login_token == mock_aes_device.token + assert mock_aes_device.token in str(transport._token_url) @pytest.mark.parametrize( @@ -136,7 +137,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count): transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session - assert transport._login_token is None + assert transport._token_url is None request = { "method": "get_device_info", @@ -148,7 +149,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count): with expectation: await transport.send(json_dumps(request)) - assert transport._login_token == mock_aes_device.token + assert mock_aes_device.token in str(transport._token_url) assert post_mock.call_count == call_count # Login, Handshake, Login await transport.close() @@ -165,7 +166,9 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session - transport._login_token = mock_aes_device.token + transport._token_url = transport._app_url.with_query( + f"token={mock_aes_device.token}" + ) request = { "method": "get_device_info", @@ -193,7 +196,9 @@ async def test_passthrough_errors(mocker, error_code): transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session - transport._login_token = mock_aes_device.token + transport._token_url = transport._app_url.with_query( + f"token={mock_aes_device.token}" + ) request = { "method": "get_device_info", @@ -239,13 +244,13 @@ class MockAesDevice: else: return self._inner_error_code - async def post(self, url, params=None, json=None, data=None, *_, **__): + async def post(self, url: URL, params=None, json=None, data=None, *_, **__): if data: async for item in data: json = json_loads(item.decode()) return await self._post(url, json) - async def _post(self, url: str, json: Dict[str, Any]): + async def _post(self, url: URL, json: Dict[str, Any]): if json["method"] == "handshake": return await self._return_handshake_response(url, json) elif json["method"] == "securePassthrough": @@ -253,10 +258,10 @@ class MockAesDevice: elif json["method"] == "login_device": return await self._return_login_response(url, json) else: - assert url == f"http://{self.host}/app?token={self.token}" + assert str(url) == f"http://{self.host}/app?token={self.token}" return await self._return_send_response(url, json) - async def _return_handshake_response(self, url: str, json: Dict[str, Any]): + async def _return_handshake_response(self, url: URL, json: Dict[str, Any]): start = len("-----BEGIN PUBLIC KEY-----\n") end = len("\n-----END PUBLIC KEY-----\n") client_pub_key = json["params"]["key"][start:-end] @@ -269,7 +274,7 @@ class MockAesDevice: self.status_code, {"result": {"key": key_64}, "error_code": self.error_code} ) - async def _return_secure_passthrough_response(self, url: str, json: Dict[str, Any]): + async def _return_secure_passthrough_response(self, url: URL, json: Dict[str, Any]): encrypted_request = json["params"]["request"] decrypted_request = self.encryption_session.decrypt(encrypted_request.encode()) decrypted_request_dict = json_loads(decrypted_request) @@ -286,15 +291,15 @@ class MockAesDevice: } return self._mock_response(self.status_code, result) - async def _return_login_response(self, url: str, json: Dict[str, Any]): - if "token=" in url: + async def _return_login_response(self, url: URL, json: Dict[str, Any]): + if "token=" in str(url): raise Exception("token should not be in url for a login request") self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311 result = {"result": {"token": self.token}, "error_code": self.inner_error_code} self.inner_call_count += 1 return self._mock_response(self.status_code, result) - async def _return_send_response(self, url: str, json: Dict[str, Any]): + async def _return_send_response(self, url: URL, json: Dict[str, Any]): result = {"result": {"method": None}, "error_code": self.inner_error_code} self.inner_call_count += 1 return self._mock_response(self.status_code, result) diff --git a/kasa/tests/test_httpclient.py b/kasa/tests/test_httpclient.py index e178b818..2afabba0 100644 --- a/kasa/tests/test_httpclient.py +++ b/kasa/tests/test_httpclient.py @@ -84,7 +84,7 @@ async def test_httpclient_errors(mocker, error, error_raises, error_message, moc client = HttpClient(DeviceConfig(host)) # Exceptions with parameters print with double quotes, without use single quotes full_msg = ( - "\(" + "\(" # type: ignore + "['\"]" + re.escape(f"{error_message}{host}: {error}") + "['\"]" diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 4d711f03..b69d5070 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -10,6 +10,7 @@ from unittest.mock import PropertyMock import aiohttp import pytest +from yarl import URL from ..aestransport import AesTransport from ..credentials import Credentials @@ -318,28 +319,28 @@ async def test_handshake1( async def test_handshake( mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2 ): - async def _return_handshake_response(url, params=None, data=None, *_, **__): + client_seed = None + server_seed = secrets.token_bytes(16) + client_credentials = Credentials("foo", "bar") + device_auth_hash = transport_class.generate_auth_hash(client_credentials) + + async def _return_handshake_response(url: URL, params=None, data=None, *_, **__): nonlocal client_seed, server_seed, device_auth_hash - if url == "http://127.0.0.1/app/handshake1": + if str(url) == "http://127.0.0.1/app/handshake1": client_seed = data seed_auth_hash = _sha256( seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash) ) return _mock_response(200, server_seed + seed_auth_hash) - elif url == "http://127.0.0.1/app/handshake2": + elif str(url) == "http://127.0.0.1/app/handshake2": seed_auth_hash = _sha256( seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash) ) assert data == seed_auth_hash return _mock_response(response_status, b"") - client_seed = None - server_seed = secrets.token_bytes(16) - client_credentials = Credentials("foo", "bar") - device_auth_hash = transport_class.generate_auth_hash(client_credentials) - mocker.patch.object( aiohttp.ClientSession, "post", side_effect=_return_handshake_response ) @@ -360,17 +361,24 @@ async def test_handshake( async def test_query(mocker): - async def _return_response(url, params=None, data=None, *_, **__): + client_seed = None + last_seq = None + seq = None + server_seed = secrets.token_bytes(16) + client_credentials = Credentials("foo", "bar") + device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) + + async def _return_response(url: URL, params=None, data=None, *_, **__): nonlocal client_seed, server_seed, device_auth_hash, seq - if url == "http://127.0.0.1/app/handshake1": + if str(url) == "http://127.0.0.1/app/handshake1": client_seed = data client_seed_auth_hash = _sha256(data + device_auth_hash) return _mock_response(200, server_seed + client_seed_auth_hash) - elif url == "http://127.0.0.1/app/handshake2": + elif str(url) == "http://127.0.0.1/app/handshake2": return _mock_response(200, b"") - elif url == "http://127.0.0.1/app/request": + elif str(url) == "http://127.0.0.1/app/request": encryption_session = KlapEncryptionSession( protocol._transport._encryption_session.local_seed, protocol._transport._encryption_session.remote_seed, @@ -382,13 +390,6 @@ async def test_query(mocker): seq = seq return _mock_response(200, encrypted) - client_seed = None - last_seq = None - seq = None - server_seed = secrets.token_bytes(16) - client_credentials = Credentials("foo", "bar") - device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) - mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response) config = DeviceConfig("127.0.0.1", credentials=client_credentials) @@ -413,26 +414,26 @@ async def test_query(mocker): ids=("handshake1", "handshake2", "request", "non_auth_error"), ) async def test_authentication_failures(mocker, response_status, expectation): - async def _return_response(url, params=None, data=None, *_, **__): + client_seed = None + + server_seed = secrets.token_bytes(16) + client_credentials = Credentials("foo", "bar") + device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) + + async def _return_response(url: URL, params=None, data=None, *_, **__): nonlocal client_seed, server_seed, device_auth_hash, response_status - if url == "http://127.0.0.1/app/handshake1": + if str(url) == "http://127.0.0.1/app/handshake1": client_seed = data client_seed_auth_hash = _sha256(data + device_auth_hash) return _mock_response( response_status[0], server_seed + client_seed_auth_hash ) - elif url == "http://127.0.0.1/app/handshake2": + elif str(url) == "http://127.0.0.1/app/handshake2": return _mock_response(response_status[1], b"") - elif url == "http://127.0.0.1/app/request": - return _mock_response(response_status[2], None) - - client_seed = None - - server_seed = secrets.token_bytes(16) - client_credentials = Credentials("foo", "bar") - device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) + elif str(url) == "http://127.0.0.1/app/request": + return _mock_response(response_status[2], b"") mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)