Drop urllib3 dependency and create ssl context in executor thread (#1175)

This commit is contained in:
Steven B. 2024-10-18 10:40:17 +01:00 committed by GitHub
parent c6f2d89d44
commit 2dd621675a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import base64 import base64
import hashlib import hashlib
import logging import logging
@ -12,7 +13,6 @@ from enum import Enum, IntEnum, auto
from functools import cache from functools import cache
from typing import TYPE_CHECKING, Any, Dict, cast from typing import TYPE_CHECKING, Any, Dict, cast
from urllib3.util import create_urllib3_context
from yarl import URL from yarl import URL
from ..aestransport import AesEncyptionSession from ..aestransport import AesEncyptionSession
@ -108,11 +108,7 @@ class SslAesTransport(BaseTransport):
self._host_port = f"{self._host}:{self._port}" self._host_port = f"{self._host}:{self._port}"
self._app_url = URL(f"https://{self._host_port}") self._app_url = URL(f"https://{self._host_port}")
self._token_url: URL | None = None self._token_url: URL | None = None
self._ssl_context = create_urllib3_context( self._ssl_context: ssl.SSLContext | None = None
ciphers=self.CIPHERS,
cert_reqs=ssl.CERT_NONE,
options=0,
)
ref = str(self._token_url) if self._token_url else str(self._app_url) ref = str(self._token_url) if self._token_url else str(self._app_url)
self._headers = { self._headers = {
**self.COMMON_HEADERS, **self.COMMON_HEADERS,
@ -168,6 +164,21 @@ class SslAesTransport(BaseTransport):
raise AuthenticationError(msg, error_code=error_code) raise AuthenticationError(msg, error_code=error_code)
raise DeviceError(msg, error_code=error_code) raise DeviceError(msg, error_code=error_code)
def _create_ssl_context(self) -> ssl.SSLContext:
context = ssl.SSLContext()
context.set_ciphers(self.CIPHERS)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
return context
async def _get_ssl_context(self) -> ssl.SSLContext:
if not self._ssl_context:
loop = asyncio.get_running_loop()
self._ssl_context = await loop.run_in_executor(
None, self._create_ssl_context
)
return self._ssl_context
async def send_secure_passthrough(self, request: str) -> dict[str, Any]: async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
"""Send encrypted message as passthrough.""" """Send encrypted message as passthrough."""
if self._state is TransportState.ESTABLISHED and self._token_url: if self._state is TransportState.ESTABLISHED and self._token_url:
@ -194,7 +205,7 @@ class SslAesTransport(BaseTransport):
url, url,
json=passthrough_request_str, json=passthrough_request_str,
headers=headers, headers=headers,
ssl=self._ssl_context, ssl=await self._get_ssl_context(),
) )
if status_code != 200: if status_code != 200:
@ -299,7 +310,10 @@ class SslAesTransport(BaseTransport):
} }
http_client = self._http_client http_client = self._http_client
status_code, resp_dict = await http_client.post( status_code, resp_dict = await http_client.post(
self._app_url, json=body, headers=self._headers, ssl=self._ssl_context self._app_url,
json=body,
headers=self._headers,
ssl=await self._get_ssl_context(),
) )
if status_code != 200: if status_code != 200:
raise KasaException( raise KasaException(
@ -337,7 +351,10 @@ class SslAesTransport(BaseTransport):
http_client = self._http_client http_client = self._http_client
status_code, resp_dict = await http_client.post( status_code, resp_dict = await http_client.post(
self._app_url, json=body, headers=self._headers, ssl=self._ssl_context self._app_url,
json=body,
headers=self._headers,
ssl=await self._get_ssl_context(),
) )
_LOGGER.debug("Device responded with: %s", resp_dict) _LOGGER.debug("Device responded with: %s", resp_dict)