Generate AES KeyPair lazily (#687)

* Generate AES KeyPair lazily

* Fix coverage

* Update post-review

* Fix pragma

* Make json dumps consistent between python and orjson

* Add comment

* Add comments re json parameter in HttpClient
This commit is contained in:
Steven B
2024-01-23 15:29:27 +00:00
committed by GitHub
parent 718983c401
commit e233e377ad
4 changed files with 58 additions and 24 deletions

View File

@@ -8,7 +8,7 @@ import base64
import hashlib
import logging
import time
from typing import Dict, Optional, cast
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
@@ -55,6 +55,8 @@ class AesTransport(BaseTransport):
"requestByApp": "true",
"Accept": "application/json",
}
CONTENT_LENGTH = "Content-Length"
KEY_PAIR_CONTENT_LENGTH = 314
def __init__(
self,
@@ -86,6 +88,8 @@ class AesTransport(BaseTransport):
self._login_token = None
self._key_pair: Optional[KeyPair] = None
_LOGGER.debug("Created AES transport for %s", self._host)
@property
@@ -204,34 +208,44 @@ class AesTransport(BaseTransport):
self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"]
async def _generate_key_pair_payload(self) -> AsyncGenerator:
"""Generate the request body and return an ascyn_generator.
This prevents the key pair being generated unless a connection
can be made to the device.
"""
_LOGGER.debug("Generating keypair")
self._key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ self._key_pair.get_public_key() # type: ignore[union-attr]
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
yield json_dumps(request_body).encode()
async def perform_handshake(self):
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")
_LOGGER.debug("Generating keypair")
self._key_pair = None
self._handshake_done = False
self._session_expire_at = None
self._session_cookie = None
url = f"http://{self._host}/app"
key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ key_pair.get_public_key()
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
# Device needs the content length or it will response with 500
headers = {
**self.COMMON_HEADERS,
self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH),
}
status_code, resp_dict = await self._http_client.post(
url,
json=request_body,
headers=self.COMMON_HEADERS,
json=self._generate_key_pair_payload(),
headers=headers,
cookies_dict=self._session_cookie,
)
@@ -259,8 +273,10 @@ class AesTransport(BaseTransport):
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
self._session_expire_at = time.time() + 86400
if TYPE_CHECKING:
assert self._key_pair is not None # pragma: no cover
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
handshake_key, self._key_pair
)
self._handshake_done = True