Generate AES KeyPair lazily (#687)

* Generate AES KeyPair lazily

* Fix coverage

* Update post-review

* Fix pragma

* Make json dumps consistent between python and orjson

* Add comment

* Add comments re json parameter in HttpClient
This commit is contained in:
Steven B 2024-01-23 15:29:27 +00:00 committed by GitHub
parent 718983c401
commit e233e377ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 58 additions and 24 deletions

View File

@ -8,7 +8,7 @@ import base64
import hashlib
import logging
import time
from typing import Dict, Optional, cast
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
@ -55,6 +55,8 @@ class AesTransport(BaseTransport):
"requestByApp": "true",
"Accept": "application/json",
}
CONTENT_LENGTH = "Content-Length"
KEY_PAIR_CONTENT_LENGTH = 314
def __init__(
self,
@ -86,6 +88,8 @@ class AesTransport(BaseTransport):
self._login_token = None
self._key_pair: Optional[KeyPair] = None
_LOGGER.debug("Created AES transport for %s", self._host)
@property
@ -204,34 +208,44 @@ class AesTransport(BaseTransport):
self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"]
async def _generate_key_pair_payload(self) -> AsyncGenerator:
"""Generate the request body and return an ascyn_generator.
This prevents the key pair being generated unless a connection
can be made to the device.
"""
_LOGGER.debug("Generating keypair")
self._key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ self._key_pair.get_public_key() # type: ignore[union-attr]
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
yield json_dumps(request_body).encode()
async def perform_handshake(self):
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")
_LOGGER.debug("Generating keypair")
self._key_pair = None
self._handshake_done = False
self._session_expire_at = None
self._session_cookie = None
url = f"http://{self._host}/app"
key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ key_pair.get_public_key()
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
# Device needs the content length or it will response with 500
headers = {
**self.COMMON_HEADERS,
self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH),
}
status_code, resp_dict = await self._http_client.post(
url,
json=request_body,
headers=self.COMMON_HEADERS,
json=self._generate_key_pair_payload(),
headers=headers,
cookies_dict=self._session_cookie,
)
@ -259,8 +273,10 @@ class AesTransport(BaseTransport):
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
self._session_expire_at = time.time() + 86400
if TYPE_CHECKING:
assert self._key_pair is not None # pragma: no cover
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
handshake_key, self._key_pair
)
self._handshake_done = True

View File

@ -41,14 +41,25 @@ class HttpClient:
*,
params: Optional[Dict[str, Any]] = None,
data: Optional[bytes] = None,
json: Optional[Dict] = None,
json: Optional[Union[Dict, Any]] = None,
headers: Optional[Dict[str, str]] = None,
cookies_dict: Optional[Dict[str, str]] = None,
) -> Tuple[int, Optional[Union[Dict, bytes]]]:
"""Send an http post request to the device."""
"""Send an http post request to the device.
If the request is provided via the json parameter json will be returned.
"""
response_data = None
self._last_url = url
self.client.cookie_jar.clear()
return_json = bool(json)
# If json is not a dict send as data.
# This allows the json parameter to be used to pass other
# types of data such as async_generator and still have json
# returned.
if json and not isinstance(json, Dict):
data = json
json = None
try:
resp = await self.client.post(
url,
@ -62,7 +73,7 @@ class HttpClient:
async with resp:
if resp.status == 200:
response_data = await resp.read()
if json:
if return_json:
response_data = json_loads(response_data.decode())
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:

View File

@ -11,5 +11,9 @@ try:
except ImportError:
import json
dumps = json.dumps
def dumps(obj, *, default=None):
"""Dump JSON."""
# Separators specified for consistency with orjson
return json.dumps(obj, separators=(",", ":"))
loads = json.loads

View File

@ -225,7 +225,10 @@ class MockAesDevice:
else:
return self._inner_error_code
async def post(self, url, params=None, json=None, *_, **__):
async def post(self, url, params=None, json=None, data=None, *_, **__):
if data:
async for item in data:
json = json_loads(item.decode())
return await self._post(url, json)
async def _post(self, url, json):