Fix SslAesTransport default login and add tests (#1202)

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
Steven B.
2024-10-28 16:36:34 +00:00
committed by GitHub
parent 0287606235
commit 440b2d153b
2 changed files with 390 additions and 5 deletions

View File

@@ -137,6 +137,11 @@ class SslAesTransport(BaseTransport):
"""Default port for the transport."""
return self.DEFAULT_PORT
@staticmethod
def _create_b64_credentials(credentials: Credentials) -> str:
ch = {"un": credentials.username, "pwd": credentials.password}
return base64.b64encode(json_dumps(ch).encode()).decode()
@property
def credentials_hash(self) -> str | None:
"""The hashed credentials used by the transport."""
@@ -145,8 +150,7 @@ class SslAesTransport(BaseTransport):
if not self._credentials and self._credentials_hash:
return self._credentials_hash
if (cred := self._credentials) and cred.password and cred.username:
ch = {"un": cred.username, "pwd": cred.password}
return base64.b64encode(json_dumps(ch).encode()).decode()
return self._create_b64_credentials(cred)
return None
def _get_response_error(self, resp_dict: Any) -> SmartErrorCode:
@@ -329,6 +333,13 @@ class SslAesTransport(BaseTransport):
+ f"status code {status_code} to handshake2"
)
resp_dict = cast(dict, resp_dict)
if (
error_code := self._get_response_error(resp_dict)
) and error_code is SmartErrorCode.INVALID_NONCE:
raise AuthenticationError(
f"Invalid password hash in handshake2 for {self._host}"
)
self._handle_response_error_code(resp_dict, "Error in handshake2")
self._seq = resp_dict["result"]["start_seq"]
@@ -372,12 +383,12 @@ class SslAesTransport(BaseTransport):
if not self._username:
raise AuthenticationError(
"Credentials must be supplied to connect to {self._host}"
f"Credentials must be supplied to connect to {self._host}"
)
if error_code is not SmartErrorCode.INVALID_NONCE or (
resp_dict and "nonce" not in resp_dict["result"].get("data", {})
):
raise AuthenticationError("Error trying handshake1: {resp_dict}")
raise AuthenticationError(f"Error trying handshake1: {resp_dict}")
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
@@ -422,7 +433,7 @@ class SslAesTransport(BaseTransport):
"params": {
"cnonce": local_nonce,
"encrypt_type": "3",
"username": self._username,
"username": username,
},
}
http_client = self._http_client