mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-04-26 08:36:22 +00:00
Generate AES KeyPair lazily (#687)
* Generate AES KeyPair lazily * Fix coverage * Update post-review * Fix pragma * Make json dumps consistent between python and orjson * Add comment * Add comments re json parameter in HttpClient
This commit is contained in:
parent
718983c401
commit
e233e377ad
@ -8,7 +8,7 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Optional, cast
|
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import padding, serialization
|
from cryptography.hazmat.primitives import padding, serialization
|
||||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
@ -55,6 +55,8 @@ class AesTransport(BaseTransport):
|
|||||||
"requestByApp": "true",
|
"requestByApp": "true",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
}
|
}
|
||||||
|
CONTENT_LENGTH = "Content-Length"
|
||||||
|
KEY_PAIR_CONTENT_LENGTH = 314
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -86,6 +88,8 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
self._login_token = None
|
self._login_token = None
|
||||||
|
|
||||||
|
self._key_pair: Optional[KeyPair] = None
|
||||||
|
|
||||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -204,34 +208,44 @@ class AesTransport(BaseTransport):
|
|||||||
self._handle_response_error_code(resp_dict, "Error logging in")
|
self._handle_response_error_code(resp_dict, "Error logging in")
|
||||||
self._login_token = resp_dict["result"]["token"]
|
self._login_token = resp_dict["result"]["token"]
|
||||||
|
|
||||||
|
async def _generate_key_pair_payload(self) -> AsyncGenerator:
|
||||||
|
"""Generate the request body and return an ascyn_generator.
|
||||||
|
|
||||||
|
This prevents the key pair being generated unless a connection
|
||||||
|
can be made to the device.
|
||||||
|
"""
|
||||||
|
_LOGGER.debug("Generating keypair")
|
||||||
|
self._key_pair = KeyPair.create_key_pair()
|
||||||
|
pub_key = (
|
||||||
|
"-----BEGIN PUBLIC KEY-----\n"
|
||||||
|
+ self._key_pair.get_public_key() # type: ignore[union-attr]
|
||||||
|
+ "\n-----END PUBLIC KEY-----\n"
|
||||||
|
)
|
||||||
|
handshake_params = {"key": pub_key}
|
||||||
|
_LOGGER.debug(f"Handshake params: {handshake_params}")
|
||||||
|
request_body = {"method": "handshake", "params": handshake_params}
|
||||||
|
_LOGGER.debug(f"Request {request_body}")
|
||||||
|
yield json_dumps(request_body).encode()
|
||||||
|
|
||||||
async def perform_handshake(self):
|
async def perform_handshake(self):
|
||||||
"""Perform the handshake."""
|
"""Perform the handshake."""
|
||||||
_LOGGER.debug("Will perform handshaking...")
|
_LOGGER.debug("Will perform handshaking...")
|
||||||
_LOGGER.debug("Generating keypair")
|
|
||||||
|
|
||||||
|
self._key_pair = None
|
||||||
self._handshake_done = False
|
self._handshake_done = False
|
||||||
self._session_expire_at = None
|
self._session_expire_at = None
|
||||||
self._session_cookie = None
|
self._session_cookie = None
|
||||||
|
|
||||||
url = f"http://{self._host}/app"
|
url = f"http://{self._host}/app"
|
||||||
key_pair = KeyPair.create_key_pair()
|
# Device needs the content length or it will response with 500
|
||||||
|
headers = {
|
||||||
pub_key = (
|
**self.COMMON_HEADERS,
|
||||||
"-----BEGIN PUBLIC KEY-----\n"
|
self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH),
|
||||||
+ key_pair.get_public_key()
|
}
|
||||||
+ "\n-----END PUBLIC KEY-----\n"
|
|
||||||
)
|
|
||||||
handshake_params = {"key": pub_key}
|
|
||||||
_LOGGER.debug(f"Handshake params: {handshake_params}")
|
|
||||||
|
|
||||||
request_body = {"method": "handshake", "params": handshake_params}
|
|
||||||
|
|
||||||
_LOGGER.debug(f"Request {request_body}")
|
|
||||||
|
|
||||||
status_code, resp_dict = await self._http_client.post(
|
status_code, resp_dict = await self._http_client.post(
|
||||||
url,
|
url,
|
||||||
json=request_body,
|
json=self._generate_key_pair_payload(),
|
||||||
headers=self.COMMON_HEADERS,
|
headers=headers,
|
||||||
cookies_dict=self._session_cookie,
|
cookies_dict=self._session_cookie,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -259,8 +273,10 @@ class AesTransport(BaseTransport):
|
|||||||
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
||||||
|
|
||||||
self._session_expire_at = time.time() + 86400
|
self._session_expire_at = time.time() + 86400
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert self._key_pair is not None # pragma: no cover
|
||||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||||
handshake_key, key_pair
|
handshake_key, self._key_pair
|
||||||
)
|
)
|
||||||
|
|
||||||
self._handshake_done = True
|
self._handshake_done = True
|
||||||
|
@ -41,14 +41,25 @@ class HttpClient:
|
|||||||
*,
|
*,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
data: Optional[bytes] = None,
|
data: Optional[bytes] = None,
|
||||||
json: Optional[Dict] = None,
|
json: Optional[Union[Dict, Any]] = None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[Dict[str, str]] = None,
|
||||||
cookies_dict: Optional[Dict[str, str]] = None,
|
cookies_dict: Optional[Dict[str, str]] = None,
|
||||||
) -> Tuple[int, Optional[Union[Dict, bytes]]]:
|
) -> Tuple[int, Optional[Union[Dict, bytes]]]:
|
||||||
"""Send an http post request to the device."""
|
"""Send an http post request to the device.
|
||||||
|
|
||||||
|
If the request is provided via the json parameter json will be returned.
|
||||||
|
"""
|
||||||
response_data = None
|
response_data = None
|
||||||
self._last_url = url
|
self._last_url = url
|
||||||
self.client.cookie_jar.clear()
|
self.client.cookie_jar.clear()
|
||||||
|
return_json = bool(json)
|
||||||
|
# If json is not a dict send as data.
|
||||||
|
# This allows the json parameter to be used to pass other
|
||||||
|
# types of data such as async_generator and still have json
|
||||||
|
# returned.
|
||||||
|
if json and not isinstance(json, Dict):
|
||||||
|
data = json
|
||||||
|
json = None
|
||||||
try:
|
try:
|
||||||
resp = await self.client.post(
|
resp = await self.client.post(
|
||||||
url,
|
url,
|
||||||
@ -62,7 +73,7 @@ class HttpClient:
|
|||||||
async with resp:
|
async with resp:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
response_data = await resp.read()
|
response_data = await resp.read()
|
||||||
if json:
|
if return_json:
|
||||||
response_data = json_loads(response_data.decode())
|
response_data = json_loads(response_data.decode())
|
||||||
|
|
||||||
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
|
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
|
||||||
|
@ -11,5 +11,9 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
dumps = json.dumps
|
def dumps(obj, *, default=None):
|
||||||
|
"""Dump JSON."""
|
||||||
|
# Separators specified for consistency with orjson
|
||||||
|
return json.dumps(obj, separators=(",", ":"))
|
||||||
|
|
||||||
loads = json.loads
|
loads = json.loads
|
||||||
|
@ -225,7 +225,10 @@ class MockAesDevice:
|
|||||||
else:
|
else:
|
||||||
return self._inner_error_code
|
return self._inner_error_code
|
||||||
|
|
||||||
async def post(self, url, params=None, json=None, *_, **__):
|
async def post(self, url, params=None, json=None, data=None, *_, **__):
|
||||||
|
if data:
|
||||||
|
async for item in data:
|
||||||
|
json = json_loads(item.decode())
|
||||||
return await self._post(url, json)
|
return await self._post(url, json)
|
||||||
|
|
||||||
async def _post(self, url, json):
|
async def _post(self, url, json):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user