From 2dd621675a99086e400103a7f78b810f64a0d426 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Fri, 18 Oct 2024 10:40:17 +0100 Subject: [PATCH] Drop urllib3 dependency and create ssl context in executor thread (#1175) --- kasa/experimental/sslaestransport.py | 35 +++++++++++++++++++++------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/kasa/experimental/sslaestransport.py b/kasa/experimental/sslaestransport.py index 8936db8d..151cd568 100644 --- a/kasa/experimental/sslaestransport.py +++ b/kasa/experimental/sslaestransport.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import base64 import hashlib import logging @@ -12,7 +13,6 @@ from enum import Enum, IntEnum, auto from functools import cache from typing import TYPE_CHECKING, Any, Dict, cast -from urllib3.util import create_urllib3_context from yarl import URL from ..aestransport import AesEncyptionSession @@ -108,11 +108,7 @@ class SslAesTransport(BaseTransport): self._host_port = f"{self._host}:{self._port}" self._app_url = URL(f"https://{self._host_port}") self._token_url: URL | None = None - self._ssl_context = create_urllib3_context( - ciphers=self.CIPHERS, - cert_reqs=ssl.CERT_NONE, - options=0, - ) + self._ssl_context: ssl.SSLContext | None = None ref = str(self._token_url) if self._token_url else str(self._app_url) self._headers = { **self.COMMON_HEADERS, @@ -168,6 +164,21 @@ class SslAesTransport(BaseTransport): raise AuthenticationError(msg, error_code=error_code) raise DeviceError(msg, error_code=error_code) + def _create_ssl_context(self) -> ssl.SSLContext: + context = ssl.SSLContext() + context.set_ciphers(self.CIPHERS) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + return context + + async def _get_ssl_context(self) -> ssl.SSLContext: + if not self._ssl_context: + loop = asyncio.get_running_loop() + self._ssl_context = await loop.run_in_executor( + None, self._create_ssl_context + ) + return self._ssl_context + async def send_secure_passthrough(self, request: str) -> dict[str, Any]: """Send encrypted message as passthrough.""" if self._state is TransportState.ESTABLISHED and self._token_url: @@ -194,7 +205,7 @@ class SslAesTransport(BaseTransport): url, json=passthrough_request_str, headers=headers, - ssl=self._ssl_context, + ssl=await self._get_ssl_context(), ) if status_code != 200: @@ -299,7 +310,10 @@ class SslAesTransport(BaseTransport): } http_client = self._http_client status_code, resp_dict = await http_client.post( - self._app_url, json=body, headers=self._headers, ssl=self._ssl_context + self._app_url, + json=body, + headers=self._headers, + ssl=await self._get_ssl_context(), ) if status_code != 200: raise KasaException( @@ -337,7 +351,10 @@ class SslAesTransport(BaseTransport): http_client = self._http_client status_code, resp_dict = await http_client.post( - self._app_url, json=body, headers=self._headers, ssl=self._ssl_context + self._app_url, + json=body, + headers=self._headers, + ssl=await self._get_ssl_context(), ) _LOGGER.debug("Device responded with: %s", resp_dict)