mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-10-23 07:48:01 +00:00
Encapsulate http client dependency (#642)
* Encapsulate http client dependency * Store cookie dict as variable * Update post-review
This commit is contained in:
@@ -8,9 +8,8 @@ import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional, cast
|
||||
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import padding, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
@@ -28,6 +27,7 @@ from .exceptions import (
|
||||
SmartErrorCode,
|
||||
TimeoutException,
|
||||
)
|
||||
from .httpclient import HttpClient
|
||||
from .json import dumps as json_dumps
|
||||
from .json import loads as json_loads
|
||||
from .protocol import BaseTransport
|
||||
@@ -75,14 +75,14 @@ class AesTransport(BaseTransport):
|
||||
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
|
||||
)
|
||||
|
||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
||||
self._http_client: HttpClient = HttpClient(config)
|
||||
|
||||
self._handshake_done = False
|
||||
|
||||
self._encryption_session: Optional[AesEncyptionSession] = None
|
||||
self._session_expire_at: Optional[float] = None
|
||||
|
||||
self._session_cookie = None
|
||||
self._session_cookie: Optional[Dict[str, str]] = None
|
||||
|
||||
self._login_token = None
|
||||
|
||||
@@ -98,14 +98,6 @@ class AesTransport(BaseTransport):
|
||||
"""The hashed credentials used by the transport."""
|
||||
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
|
||||
|
||||
@property
|
||||
def _http_client(self) -> httpx.AsyncClient:
|
||||
if self._config.http_client:
|
||||
return self._config.http_client
|
||||
if not self._default_http_client:
|
||||
self._default_http_client = httpx.AsyncClient()
|
||||
return self._default_http_client
|
||||
|
||||
def _get_login_params(self):
|
||||
"""Get the login parameters based on the login_version."""
|
||||
un, pw = self.hash_credentials(self._login_version == 2)
|
||||
@@ -128,28 +120,6 @@ class AesTransport(BaseTransport):
|
||||
pw = base64.b64encode(self._credentials.password.encode()).decode()
|
||||
return un, pw
|
||||
|
||||
async def client_post(self, url, params=None, data=None, json=None, headers=None):
|
||||
"""Send an http post request to the device."""
|
||||
response_data = None
|
||||
cookies = None
|
||||
if self._session_cookie:
|
||||
cookies = httpx.Cookies()
|
||||
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
|
||||
self._http_client.cookies.clear()
|
||||
resp = await self._http_client.post(
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
json=json,
|
||||
timeout=self._timeout,
|
||||
cookies=cookies,
|
||||
headers=self.COMMON_HEADERS,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
response_data = resp.json()
|
||||
|
||||
return resp.status_code, response_data
|
||||
|
||||
def _handle_response_error_code(self, resp_dict: dict, msg: str):
|
||||
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||
if error_code == SmartErrorCode.SUCCESS:
|
||||
@@ -176,7 +146,12 @@ class AesTransport(BaseTransport):
|
||||
"method": "securePassthrough",
|
||||
"params": {"request": encrypted_payload.decode()},
|
||||
}
|
||||
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
||||
status_code, resp_dict = await self._http_client.post(
|
||||
url,
|
||||
json=passthrough_request,
|
||||
headers=self.COMMON_HEADERS,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
|
||||
|
||||
if status_code != 200:
|
||||
@@ -185,6 +160,7 @@ class AesTransport(BaseTransport):
|
||||
+ f"status code {status_code} to passthrough"
|
||||
)
|
||||
|
||||
resp_dict = cast(Dict, resp_dict)
|
||||
self._handle_response_error_code(
|
||||
resp_dict, "Error sending secure_passthrough message"
|
||||
)
|
||||
@@ -233,7 +209,12 @@ class AesTransport(BaseTransport):
|
||||
|
||||
_LOGGER.debug(f"Request {request_body}")
|
||||
|
||||
status_code, resp_dict = await self.client_post(url, json=request_body)
|
||||
status_code, resp_dict = await self._http_client.post(
|
||||
url,
|
||||
json=request_body,
|
||||
headers=self.COMMON_HEADERS,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
|
||||
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
||||
|
||||
@@ -247,13 +228,16 @@ class AesTransport(BaseTransport):
|
||||
|
||||
handshake_key = resp_dict["result"]["key"]
|
||||
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
if not self._session_cookie:
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
if (
|
||||
cookie := self._http_client.get_cookie( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
) or (
|
||||
cookie := self._http_client.get_cookie( # type: ignore
|
||||
"SESSIONID"
|
||||
)
|
||||
):
|
||||
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
||||
|
||||
self._session_expire_at = time.time() + 86400
|
||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||
@@ -281,13 +265,10 @@ class AesTransport(BaseTransport):
|
||||
return await self.send_secure_passthrough(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
client = self._default_http_client
|
||||
self._default_http_client = None
|
||||
"""Close the transport."""
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
if client:
|
||||
await client.aclose()
|
||||
await self._http_client.close()
|
||||
|
||||
|
||||
class AesEncyptionSession:
|
||||
|
Reference in New Issue
Block a user