Encapsulate http client dependency (#642)

* Encapsulate http client dependency

* Store cookie dict as variable

* Update post-review
This commit is contained in:
Steven B
2024-01-18 09:57:33 +00:00
committed by GitHub
parent 4623434eb4
commit 3b1b0a3c21
11 changed files with 194 additions and 156 deletions

View File

@@ -8,9 +8,8 @@ import base64
import hashlib
import logging
import time
from typing import Optional
from typing import Dict, Optional, cast
import httpx
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
@@ -28,6 +27,7 @@ from .exceptions import (
SmartErrorCode,
TimeoutException,
)
from .httpclient import HttpClient
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import BaseTransport
@@ -75,14 +75,14 @@ class AesTransport(BaseTransport):
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_http_client: Optional[httpx.AsyncClient] = None
self._http_client: HttpClient = HttpClient(config)
self._handshake_done = False
self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None
self._session_cookie = None
self._session_cookie: Optional[Dict[str, str]] = None
self._login_token = None
@@ -98,14 +98,6 @@ class AesTransport(BaseTransport):
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
return self._config.http_client
if not self._default_http_client:
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
def _get_login_params(self):
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2)
@@ -128,28 +120,6 @@ class AesTransport(BaseTransport):
pw = base64.b64encode(self._credentials.password.encode()).decode()
return un, pw
async def client_post(self, url, params=None, data=None, json=None, headers=None):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self._session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
self._http_client.cookies.clear()
resp = await self._http_client.post(
url,
params=params,
data=data,
json=json,
timeout=self._timeout,
cookies=cookies,
headers=self.COMMON_HEADERS,
)
if resp.status_code == 200:
response_data = resp.json()
return resp.status_code, response_data
def _handle_response_error_code(self, resp_dict: dict, msg: str):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
@@ -176,7 +146,12 @@ class AesTransport(BaseTransport):
"method": "securePassthrough",
"params": {"request": encrypted_payload.decode()},
}
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
status_code, resp_dict = await self._http_client.post(
url,
json=passthrough_request,
headers=self.COMMON_HEADERS,
cookies_dict=self._session_cookie,
)
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
if status_code != 200:
@@ -185,6 +160,7 @@ class AesTransport(BaseTransport):
+ f"status code {status_code} to passthrough"
)
resp_dict = cast(Dict, resp_dict)
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)
@@ -233,7 +209,12 @@ class AesTransport(BaseTransport):
_LOGGER.debug(f"Request {request_body}")
status_code, resp_dict = await self.client_post(url, json=request_body)
status_code, resp_dict = await self._http_client.post(
url,
json=request_body,
headers=self.COMMON_HEADERS,
cookies_dict=self._session_cookie,
)
_LOGGER.debug(f"Device responded with: {resp_dict}")
@@ -247,13 +228,16 @@ class AesTransport(BaseTransport):
handshake_key = resp_dict["result"]["key"]
self._session_cookie = self._http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME
)
if not self._session_cookie:
self._session_cookie = self._http_client.cookies.get( # type: ignore
if (
cookie := self._http_client.get_cookie( # type: ignore
self.SESSION_COOKIE_NAME
)
) or (
cookie := self._http_client.get_cookie( # type: ignore
"SESSIONID"
)
):
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
self._session_expire_at = time.time() + 86400
self._encryption_session = AesEncyptionSession.create_from_keypair(
@@ -281,13 +265,10 @@ class AesTransport(BaseTransport):
return await self.send_secure_passthrough(request)
async def close(self) -> None:
"""Close the protocol."""
client = self._default_http_client
self._default_http_client = None
"""Close the transport."""
self._handshake_done = False
self._login_token = None
if client:
await client.aclose()
await self._http_client.close()
class AesEncyptionSession: