mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-11-04 06:32:07 +00:00 
			
		
		
		
	Generate AES KeyPair lazily (#687)
* Generate AES KeyPair lazily * Fix coverage * Update post-review * Fix pragma * Make json dumps consistent between python and orjson * Add comment * Add comments re json parameter in HttpClient
This commit is contained in:
		@@ -8,7 +8,7 @@ import base64
 | 
			
		||||
import hashlib
 | 
			
		||||
import logging
 | 
			
		||||
import time
 | 
			
		||||
from typing import Dict, Optional, cast
 | 
			
		||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast
 | 
			
		||||
 | 
			
		||||
from cryptography.hazmat.primitives import padding, serialization
 | 
			
		||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
 | 
			
		||||
@@ -55,6 +55,8 @@ class AesTransport(BaseTransport):
 | 
			
		||||
        "requestByApp": "true",
 | 
			
		||||
        "Accept": "application/json",
 | 
			
		||||
    }
 | 
			
		||||
    CONTENT_LENGTH = "Content-Length"
 | 
			
		||||
    KEY_PAIR_CONTENT_LENGTH = 314
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
@@ -86,6 +88,8 @@ class AesTransport(BaseTransport):
 | 
			
		||||
 | 
			
		||||
        self._login_token = None
 | 
			
		||||
 | 
			
		||||
        self._key_pair: Optional[KeyPair] = None
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("Created AES transport for %s", self._host)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
@@ -204,34 +208,44 @@ class AesTransport(BaseTransport):
 | 
			
		||||
        self._handle_response_error_code(resp_dict, "Error logging in")
 | 
			
		||||
        self._login_token = resp_dict["result"]["token"]
 | 
			
		||||
 | 
			
		||||
    async def _generate_key_pair_payload(self) -> AsyncGenerator:
 | 
			
		||||
        """Generate the request body and return an ascyn_generator.
 | 
			
		||||
 | 
			
		||||
        This prevents the key pair being generated unless a connection
 | 
			
		||||
        can be made to the device.
 | 
			
		||||
        """
 | 
			
		||||
        _LOGGER.debug("Generating keypair")
 | 
			
		||||
        self._key_pair = KeyPair.create_key_pair()
 | 
			
		||||
        pub_key = (
 | 
			
		||||
            "-----BEGIN PUBLIC KEY-----\n"
 | 
			
		||||
            + self._key_pair.get_public_key()  # type: ignore[union-attr]
 | 
			
		||||
            + "\n-----END PUBLIC KEY-----\n"
 | 
			
		||||
        )
 | 
			
		||||
        handshake_params = {"key": pub_key}
 | 
			
		||||
        _LOGGER.debug(f"Handshake params: {handshake_params}")
 | 
			
		||||
        request_body = {"method": "handshake", "params": handshake_params}
 | 
			
		||||
        _LOGGER.debug(f"Request {request_body}")
 | 
			
		||||
        yield json_dumps(request_body).encode()
 | 
			
		||||
 | 
			
		||||
    async def perform_handshake(self):
 | 
			
		||||
        """Perform the handshake."""
 | 
			
		||||
        _LOGGER.debug("Will perform handshaking...")
 | 
			
		||||
        _LOGGER.debug("Generating keypair")
 | 
			
		||||
 | 
			
		||||
        self._key_pair = None
 | 
			
		||||
        self._handshake_done = False
 | 
			
		||||
        self._session_expire_at = None
 | 
			
		||||
        self._session_cookie = None
 | 
			
		||||
 | 
			
		||||
        url = f"http://{self._host}/app"
 | 
			
		||||
        key_pair = KeyPair.create_key_pair()
 | 
			
		||||
 | 
			
		||||
        pub_key = (
 | 
			
		||||
            "-----BEGIN PUBLIC KEY-----\n"
 | 
			
		||||
            + key_pair.get_public_key()
 | 
			
		||||
            + "\n-----END PUBLIC KEY-----\n"
 | 
			
		||||
        )
 | 
			
		||||
        handshake_params = {"key": pub_key}
 | 
			
		||||
        _LOGGER.debug(f"Handshake params: {handshake_params}")
 | 
			
		||||
 | 
			
		||||
        request_body = {"method": "handshake", "params": handshake_params}
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug(f"Request {request_body}")
 | 
			
		||||
 | 
			
		||||
        # Device needs the content length or it will response with 500
 | 
			
		||||
        headers = {
 | 
			
		||||
            **self.COMMON_HEADERS,
 | 
			
		||||
            self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH),
 | 
			
		||||
        }
 | 
			
		||||
        status_code, resp_dict = await self._http_client.post(
 | 
			
		||||
            url,
 | 
			
		||||
            json=request_body,
 | 
			
		||||
            headers=self.COMMON_HEADERS,
 | 
			
		||||
            json=self._generate_key_pair_payload(),
 | 
			
		||||
            headers=headers,
 | 
			
		||||
            cookies_dict=self._session_cookie,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@@ -259,8 +273,10 @@ class AesTransport(BaseTransport):
 | 
			
		||||
            self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
 | 
			
		||||
 | 
			
		||||
        self._session_expire_at = time.time() + 86400
 | 
			
		||||
        if TYPE_CHECKING:
 | 
			
		||||
            assert self._key_pair is not None  # pragma: no cover
 | 
			
		||||
        self._encryption_session = AesEncyptionSession.create_from_keypair(
 | 
			
		||||
            handshake_key, key_pair
 | 
			
		||||
            handshake_key, self._key_pair
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self._handshake_done = True
 | 
			
		||||
 
 | 
			
		||||
@@ -41,14 +41,25 @@ class HttpClient:
 | 
			
		||||
        *,
 | 
			
		||||
        params: Optional[Dict[str, Any]] = None,
 | 
			
		||||
        data: Optional[bytes] = None,
 | 
			
		||||
        json: Optional[Dict] = None,
 | 
			
		||||
        json: Optional[Union[Dict, Any]] = None,
 | 
			
		||||
        headers: Optional[Dict[str, str]] = None,
 | 
			
		||||
        cookies_dict: Optional[Dict[str, str]] = None,
 | 
			
		||||
    ) -> Tuple[int, Optional[Union[Dict, bytes]]]:
 | 
			
		||||
        """Send an http post request to the device."""
 | 
			
		||||
        """Send an http post request to the device.
 | 
			
		||||
 | 
			
		||||
        If the request is provided via the json parameter json will be returned.
 | 
			
		||||
        """
 | 
			
		||||
        response_data = None
 | 
			
		||||
        self._last_url = url
 | 
			
		||||
        self.client.cookie_jar.clear()
 | 
			
		||||
        return_json = bool(json)
 | 
			
		||||
        # If json is not a dict send as data.
 | 
			
		||||
        # This allows the json parameter to be used to pass other
 | 
			
		||||
        # types of data such as async_generator and still have json
 | 
			
		||||
        # returned.
 | 
			
		||||
        if json and not isinstance(json, Dict):
 | 
			
		||||
            data = json
 | 
			
		||||
            json = None
 | 
			
		||||
        try:
 | 
			
		||||
            resp = await self.client.post(
 | 
			
		||||
                url,
 | 
			
		||||
@@ -62,7 +73,7 @@ class HttpClient:
 | 
			
		||||
            async with resp:
 | 
			
		||||
                if resp.status == 200:
 | 
			
		||||
                    response_data = await resp.read()
 | 
			
		||||
                    if json:
 | 
			
		||||
                    if return_json:
 | 
			
		||||
                        response_data = json_loads(response_data.decode())
 | 
			
		||||
 | 
			
		||||
        except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
 | 
			
		||||
 
 | 
			
		||||
@@ -11,5 +11,9 @@ try:
 | 
			
		||||
except ImportError:
 | 
			
		||||
    import json
 | 
			
		||||
 | 
			
		||||
    dumps = json.dumps
 | 
			
		||||
    def dumps(obj, *, default=None):
 | 
			
		||||
        """Dump JSON."""
 | 
			
		||||
        # Separators specified for consistency with orjson
 | 
			
		||||
        return json.dumps(obj, separators=(",", ":"))
 | 
			
		||||
 | 
			
		||||
    loads = json.loads
 | 
			
		||||
 
 | 
			
		||||
@@ -225,7 +225,10 @@ class MockAesDevice:
 | 
			
		||||
        else:
 | 
			
		||||
            return self._inner_error_code
 | 
			
		||||
 | 
			
		||||
    async def post(self, url, params=None, json=None, *_, **__):
 | 
			
		||||
    async def post(self, url, params=None, json=None, data=None, *_, **__):
 | 
			
		||||
        if data:
 | 
			
		||||
            async for item in data:
 | 
			
		||||
                json = json_loads(item.decode())
 | 
			
		||||
        return await self._post(url, json)
 | 
			
		||||
 | 
			
		||||
    async def _post(self, url, json):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user