Generate AES KeyPair lazily (#687)

* Generate AES KeyPair lazily

* Fix coverage

* Update post-review

* Fix pragma

* Make json dumps consistent between python and orjson

* Add comment

* Add comments re json parameter in HttpClient
This commit is contained in:
Steven B 2024-01-23 15:29:27 +00:00 committed by GitHub
parent 718983c401
commit e233e377ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 58 additions and 24 deletions

View File

@ -8,7 +8,7 @@ import base64
import hashlib import hashlib
import logging import logging
import time import time
from typing import Dict, Optional, cast from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast
from cryptography.hazmat.primitives import padding, serialization from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
@ -55,6 +55,8 @@ class AesTransport(BaseTransport):
"requestByApp": "true", "requestByApp": "true",
"Accept": "application/json", "Accept": "application/json",
} }
CONTENT_LENGTH = "Content-Length"
KEY_PAIR_CONTENT_LENGTH = 314
def __init__( def __init__(
self, self,
@ -86,6 +88,8 @@ class AesTransport(BaseTransport):
self._login_token = None self._login_token = None
self._key_pair: Optional[KeyPair] = None
_LOGGER.debug("Created AES transport for %s", self._host) _LOGGER.debug("Created AES transport for %s", self._host)
@property @property
@ -204,34 +208,44 @@ class AesTransport(BaseTransport):
self._handle_response_error_code(resp_dict, "Error logging in") self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"] self._login_token = resp_dict["result"]["token"]
async def _generate_key_pair_payload(self) -> AsyncGenerator:
"""Generate the request body and return an ascyn_generator.
This prevents the key pair being generated unless a connection
can be made to the device.
"""
_LOGGER.debug("Generating keypair")
self._key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ self._key_pair.get_public_key() # type: ignore[union-attr]
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
yield json_dumps(request_body).encode()
async def perform_handshake(self): async def perform_handshake(self):
"""Perform the handshake.""" """Perform the handshake."""
_LOGGER.debug("Will perform handshaking...") _LOGGER.debug("Will perform handshaking...")
_LOGGER.debug("Generating keypair")
self._key_pair = None
self._handshake_done = False self._handshake_done = False
self._session_expire_at = None self._session_expire_at = None
self._session_cookie = None self._session_cookie = None
url = f"http://{self._host}/app" url = f"http://{self._host}/app"
key_pair = KeyPair.create_key_pair() # Device needs the content length or it will response with 500
headers = {
pub_key = ( **self.COMMON_HEADERS,
"-----BEGIN PUBLIC KEY-----\n" self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH),
+ key_pair.get_public_key() }
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
status_code, resp_dict = await self._http_client.post( status_code, resp_dict = await self._http_client.post(
url, url,
json=request_body, json=self._generate_key_pair_payload(),
headers=self.COMMON_HEADERS, headers=headers,
cookies_dict=self._session_cookie, cookies_dict=self._session_cookie,
) )
@ -259,8 +273,10 @@ class AesTransport(BaseTransport):
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie} self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
self._session_expire_at = time.time() + 86400 self._session_expire_at = time.time() + 86400
if TYPE_CHECKING:
assert self._key_pair is not None # pragma: no cover
self._encryption_session = AesEncyptionSession.create_from_keypair( self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair handshake_key, self._key_pair
) )
self._handshake_done = True self._handshake_done = True

View File

@ -41,14 +41,25 @@ class HttpClient:
*, *,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
data: Optional[bytes] = None, data: Optional[bytes] = None,
json: Optional[Dict] = None, json: Optional[Union[Dict, Any]] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
cookies_dict: Optional[Dict[str, str]] = None, cookies_dict: Optional[Dict[str, str]] = None,
) -> Tuple[int, Optional[Union[Dict, bytes]]]: ) -> Tuple[int, Optional[Union[Dict, bytes]]]:
"""Send an http post request to the device.""" """Send an http post request to the device.
If the request is provided via the json parameter json will be returned.
"""
response_data = None response_data = None
self._last_url = url self._last_url = url
self.client.cookie_jar.clear() self.client.cookie_jar.clear()
return_json = bool(json)
# If json is not a dict send as data.
# This allows the json parameter to be used to pass other
# types of data such as async_generator and still have json
# returned.
if json and not isinstance(json, Dict):
data = json
json = None
try: try:
resp = await self.client.post( resp = await self.client.post(
url, url,
@ -62,7 +73,7 @@ class HttpClient:
async with resp: async with resp:
if resp.status == 200: if resp.status == 200:
response_data = await resp.read() response_data = await resp.read()
if json: if return_json:
response_data = json_loads(response_data.decode()) response_data = json_loads(response_data.decode())
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:

View File

@ -11,5 +11,9 @@ try:
except ImportError: except ImportError:
import json import json
dumps = json.dumps def dumps(obj, *, default=None):
"""Dump JSON."""
# Separators specified for consistency with orjson
return json.dumps(obj, separators=(",", ":"))
loads = json.loads loads = json.loads

View File

@ -225,7 +225,10 @@ class MockAesDevice:
else: else:
return self._inner_error_code return self._inner_error_code
async def post(self, url, params=None, json=None, *_, **__): async def post(self, url, params=None, json=None, data=None, *_, **__):
if data:
async for item in data:
json = json_loads(item.decode())
return await self._post(url, json) return await self._post(url, json)
async def _post(self, url, json): async def _post(self, url, json):