mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-10-23 15:58:02 +00:00
Generate AES KeyPair lazily (#687)
* Generate AES KeyPair lazily * Fix coverage * Update post-review * Fix pragma * Make json dumps consistent between python and orjson * Add comment * Add comments re json parameter in HttpClient
This commit is contained in:
@@ -8,7 +8,7 @@ import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional, cast
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast
|
||||
|
||||
from cryptography.hazmat.primitives import padding, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
@@ -55,6 +55,8 @@ class AesTransport(BaseTransport):
|
||||
"requestByApp": "true",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
CONTENT_LENGTH = "Content-Length"
|
||||
KEY_PAIR_CONTENT_LENGTH = 314
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -86,6 +88,8 @@ class AesTransport(BaseTransport):
|
||||
|
||||
self._login_token = None
|
||||
|
||||
self._key_pair: Optional[KeyPair] = None
|
||||
|
||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||
|
||||
@property
|
||||
@@ -204,34 +208,44 @@ class AesTransport(BaseTransport):
|
||||
self._handle_response_error_code(resp_dict, "Error logging in")
|
||||
self._login_token = resp_dict["result"]["token"]
|
||||
|
||||
async def _generate_key_pair_payload(self) -> AsyncGenerator:
|
||||
"""Generate the request body and return an ascyn_generator.
|
||||
|
||||
This prevents the key pair being generated unless a connection
|
||||
can be made to the device.
|
||||
"""
|
||||
_LOGGER.debug("Generating keypair")
|
||||
self._key_pair = KeyPair.create_key_pair()
|
||||
pub_key = (
|
||||
"-----BEGIN PUBLIC KEY-----\n"
|
||||
+ self._key_pair.get_public_key() # type: ignore[union-attr]
|
||||
+ "\n-----END PUBLIC KEY-----\n"
|
||||
)
|
||||
handshake_params = {"key": pub_key}
|
||||
_LOGGER.debug(f"Handshake params: {handshake_params}")
|
||||
request_body = {"method": "handshake", "params": handshake_params}
|
||||
_LOGGER.debug(f"Request {request_body}")
|
||||
yield json_dumps(request_body).encode()
|
||||
|
||||
async def perform_handshake(self):
|
||||
"""Perform the handshake."""
|
||||
_LOGGER.debug("Will perform handshaking...")
|
||||
_LOGGER.debug("Generating keypair")
|
||||
|
||||
self._key_pair = None
|
||||
self._handshake_done = False
|
||||
self._session_expire_at = None
|
||||
self._session_cookie = None
|
||||
|
||||
url = f"http://{self._host}/app"
|
||||
key_pair = KeyPair.create_key_pair()
|
||||
|
||||
pub_key = (
|
||||
"-----BEGIN PUBLIC KEY-----\n"
|
||||
+ key_pair.get_public_key()
|
||||
+ "\n-----END PUBLIC KEY-----\n"
|
||||
)
|
||||
handshake_params = {"key": pub_key}
|
||||
_LOGGER.debug(f"Handshake params: {handshake_params}")
|
||||
|
||||
request_body = {"method": "handshake", "params": handshake_params}
|
||||
|
||||
_LOGGER.debug(f"Request {request_body}")
|
||||
|
||||
# Device needs the content length or it will response with 500
|
||||
headers = {
|
||||
**self.COMMON_HEADERS,
|
||||
self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH),
|
||||
}
|
||||
status_code, resp_dict = await self._http_client.post(
|
||||
url,
|
||||
json=request_body,
|
||||
headers=self.COMMON_HEADERS,
|
||||
json=self._generate_key_pair_payload(),
|
||||
headers=headers,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
|
||||
@@ -259,8 +273,10 @@ class AesTransport(BaseTransport):
|
||||
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
||||
|
||||
self._session_expire_at = time.time() + 86400
|
||||
if TYPE_CHECKING:
|
||||
assert self._key_pair is not None # pragma: no cover
|
||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||
handshake_key, key_pair
|
||||
handshake_key, self._key_pair
|
||||
)
|
||||
|
||||
self._handshake_done = True
|
||||
|
Reference in New Issue
Block a user