Enable and convert to future annotations (#838)

This commit is contained in:
Steven B
2024-04-17 14:39:24 +01:00
committed by GitHub
parent 82d92aeea5
commit 203bd79253
59 changed files with 562 additions and 462 deletions

View File

@@ -4,13 +4,15 @@ Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
import time
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Tuple, cast
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, cast
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
@@ -92,19 +94,19 @@ class AesTransport(BaseTransport):
self._login_params = json_loads(
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_credentials: Optional[Credentials] = None
self._default_credentials: Credentials | None = None
self._http_client: HttpClient = HttpClient(config)
self._state = TransportState.HANDSHAKE_REQUIRED
self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None
self._encryption_session: AesEncyptionSession | None = None
self._session_expire_at: float | None = None
self._session_cookie: Optional[Dict[str, str]] = None
self._session_cookie: dict[str, str] | None = None
self._key_pair: Optional[KeyPair] = None
self._key_pair: KeyPair | None = None
self._app_url = URL(f"http://{self._host}:{self._port}/app")
self._token_url: Optional[URL] = None
self._token_url: URL | None = None
_LOGGER.debug("Created AES transport for %s", self._host)
@@ -118,14 +120,14 @@ class AesTransport(BaseTransport):
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
def _get_login_params(self, credentials: Credentials) -> Dict[str, str]:
def _get_login_params(self, credentials: Credentials) -> dict[str, str]:
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2, credentials)
password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un}
@staticmethod
def hash_credentials(login_v2: bool, credentials: Credentials) -> Tuple[str, str]:
def hash_credentials(login_v2: bool, credentials: Credentials) -> tuple[str, str]:
"""Hash the credentials."""
un = base64.b64encode(_sha1(credentials.username.encode()).encode()).decode()
if login_v2:
@@ -148,7 +150,7 @@ class AesTransport(BaseTransport):
raise AuthenticationError(msg, error_code=error_code)
raise DeviceError(msg, error_code=error_code)
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
"""Send encrypted message as passthrough."""
if self._state is TransportState.ESTABLISHED and self._token_url:
url = self._token_url
@@ -230,7 +232,7 @@ class AesTransport(BaseTransport):
ex,
) from ex
async def try_login(self, login_params: Dict[str, Any]) -> None:
async def try_login(self, login_params: dict[str, Any]) -> None:
"""Try to login with supplied login_params."""
login_request = {
"method": "login_device",
@@ -333,7 +335,7 @@ class AesTransport(BaseTransport):
or self._session_expire_at - time.time() <= 0
)
async def send(self, request: str) -> Dict[str, Any]:
async def send(self, request: str) -> dict[str, Any]:
"""Send the request."""
if (
self._state is TransportState.HANDSHAKE_REQUIRED