Allow serializing and passing of credentials_hashes in DeviceConfig (#607)

* Allow passing of credentials_hashes in DeviceConfig

* Update following review
This commit is contained in:
sdb9696
2024-01-03 21:46:08 +00:00
committed by GitHub
parent 3692e4812f
commit e9bf9f58ee
13 changed files with 183 additions and 34 deletions

View File

@@ -16,6 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
@@ -62,6 +63,16 @@ class AesTransport(BaseTransport):
) -> None:
super().__init__(config=config)
self._login_version = config.connection_type.login_version
if not self._credentials and not self._credentials_hash:
self._credentials = Credentials()
if self._credentials:
self._login_params = self._get_login_params()
else:
self._login_params = json_loads(
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_http_client: Optional[httpx.AsyncClient] = None
self._handshake_done = False
@@ -80,6 +91,11 @@ class AesTransport(BaseTransport):
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
@@ -88,6 +104,12 @@ class AesTransport(BaseTransport):
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
def _get_login_params(self):
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2)
password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un}
def hash_credentials(self, login_v2):
"""Hash the credentials."""
if login_v2:
@@ -171,14 +193,12 @@ class AesTransport(BaseTransport):
resp_dict = json_loads(response)
return resp_dict
async def _perform_login_for_version(self, *, login_version: int = 1):
async def perform_login(self):
"""Login to the device."""
self._login_token = None
un, pw = self.hash_credentials(login_version == 2)
password_field_name = "password2" if login_version == 2 else "password"
login_request = {
"method": "login_device",
"params": {password_field_name: pw, "username": un},
"params": self._login_params,
"request_time_milis": round(time.time() * 1000),
}
request = json_dumps(login_request)
@@ -187,15 +207,6 @@ class AesTransport(BaseTransport):
self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"]
async def perform_login(self) -> None:
"""Login to the device."""
try:
await self._perform_login_for_version(login_version=2)
except AuthenticationException:
_LOGGER.warning("Login version 2 failed, trying version 1")
await self.perform_handshake()
await self._perform_login_for_version(login_version=1)
async def perform_handshake(self):
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")