From e233e377ad8748dc5a4ec8d706ad5f70209825f5 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Tue, 23 Jan 2024 15:29:27 +0000 Subject: [PATCH] Generate AES KeyPair lazily (#687) * Generate AES KeyPair lazily * Fix coverage * Update post-review * Fix pragma * Make json dumps consistent between python and orjson * Add comment * Add comments re json parameter in HttpClient --- kasa/aestransport.py | 54 +++++++++++++++++++++------------ kasa/httpclient.py | 17 +++++++++-- kasa/json.py | 6 +++- kasa/tests/test_aestransport.py | 5 ++- 4 files changed, 58 insertions(+), 24 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index cd810b8f..14a9ee6a 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -8,7 +8,7 @@ import base64 import hashlib import logging import time -from typing import Dict, Optional, cast +from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast from cryptography.hazmat.primitives import padding, serialization from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding @@ -55,6 +55,8 @@ class AesTransport(BaseTransport): "requestByApp": "true", "Accept": "application/json", } + CONTENT_LENGTH = "Content-Length" + KEY_PAIR_CONTENT_LENGTH = 314 def __init__( self, @@ -86,6 +88,8 @@ class AesTransport(BaseTransport): self._login_token = None + self._key_pair: Optional[KeyPair] = None + _LOGGER.debug("Created AES transport for %s", self._host) @property @@ -204,34 +208,44 @@ class AesTransport(BaseTransport): self._handle_response_error_code(resp_dict, "Error logging in") self._login_token = resp_dict["result"]["token"] + async def _generate_key_pair_payload(self) -> AsyncGenerator: + """Generate the request body and return an ascyn_generator. + + This prevents the key pair being generated unless a connection + can be made to the device. + """ + _LOGGER.debug("Generating keypair") + self._key_pair = KeyPair.create_key_pair() + pub_key = ( + "-----BEGIN PUBLIC KEY-----\n" + + self._key_pair.get_public_key() # type: ignore[union-attr] + + "\n-----END PUBLIC KEY-----\n" + ) + handshake_params = {"key": pub_key} + _LOGGER.debug(f"Handshake params: {handshake_params}") + request_body = {"method": "handshake", "params": handshake_params} + _LOGGER.debug(f"Request {request_body}") + yield json_dumps(request_body).encode() + async def perform_handshake(self): """Perform the handshake.""" _LOGGER.debug("Will perform handshaking...") - _LOGGER.debug("Generating keypair") + self._key_pair = None self._handshake_done = False self._session_expire_at = None self._session_cookie = None url = f"http://{self._host}/app" - key_pair = KeyPair.create_key_pair() - - pub_key = ( - "-----BEGIN PUBLIC KEY-----\n" - + key_pair.get_public_key() - + "\n-----END PUBLIC KEY-----\n" - ) - handshake_params = {"key": pub_key} - _LOGGER.debug(f"Handshake params: {handshake_params}") - - request_body = {"method": "handshake", "params": handshake_params} - - _LOGGER.debug(f"Request {request_body}") - + # Device needs the content length or it will response with 500 + headers = { + **self.COMMON_HEADERS, + self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH), + } status_code, resp_dict = await self._http_client.post( url, - json=request_body, - headers=self.COMMON_HEADERS, + json=self._generate_key_pair_payload(), + headers=headers, cookies_dict=self._session_cookie, ) @@ -259,8 +273,10 @@ class AesTransport(BaseTransport): self._session_cookie = {self.SESSION_COOKIE_NAME: cookie} self._session_expire_at = time.time() + 86400 + if TYPE_CHECKING: + assert self._key_pair is not None # pragma: no cover self._encryption_session = AesEncyptionSession.create_from_keypair( - handshake_key, key_pair + handshake_key, self._key_pair ) self._handshake_done = True diff --git a/kasa/httpclient.py b/kasa/httpclient.py index a4bd84a3..28a19e8b 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -41,14 +41,25 @@ class HttpClient: *, params: Optional[Dict[str, Any]] = None, data: Optional[bytes] = None, - json: Optional[Dict] = None, + json: Optional[Union[Dict, Any]] = None, headers: Optional[Dict[str, str]] = None, cookies_dict: Optional[Dict[str, str]] = None, ) -> Tuple[int, Optional[Union[Dict, bytes]]]: - """Send an http post request to the device.""" + """Send an http post request to the device. + + If the request is provided via the json parameter json will be returned. + """ response_data = None self._last_url = url self.client.cookie_jar.clear() + return_json = bool(json) + # If json is not a dict send as data. + # This allows the json parameter to be used to pass other + # types of data such as async_generator and still have json + # returned. + if json and not isinstance(json, Dict): + data = json + json = None try: resp = await self.client.post( url, @@ -62,7 +73,7 @@ class HttpClient: async with resp: if resp.status == 200: response_data = await resp.read() - if json: + if return_json: response_data = json_loads(response_data.decode()) except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: diff --git a/kasa/json.py b/kasa/json.py index 4acc865f..aed8cd56 100755 --- a/kasa/json.py +++ b/kasa/json.py @@ -11,5 +11,9 @@ try: except ImportError: import json - dumps = json.dumps + def dumps(obj, *, default=None): + """Dump JSON.""" + # Separators specified for consistency with orjson + return json.dumps(obj, separators=(",", ":")) + loads = json.loads diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 748dae9a..4694e363 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -225,7 +225,10 @@ class MockAesDevice: else: return self._inner_error_code - async def post(self, url, params=None, json=None, *_, **__): + async def post(self, url, params=None, json=None, data=None, *_, **__): + if data: + async for item in data: + json = json_loads(item.decode()) return await self._post(url, json) async def _post(self, url, json):