Handle blocked session and try less secure login for default

This commit is contained in:
Steven B 2024-12-11 08:56:32 +00:00
parent 36a9823b63
commit 22e4f48efa
No known key found for this signature in database
GPG Key ID: 6D5B46B3679F2A43

View File

@ -43,10 +43,12 @@ def _sha256(payload: bytes) -> bytes:
def _md5_hash(payload: bytes) -> str: def _md5_hash(payload: bytes) -> str:
return hashlib.md5(payload).hexdigest().upper() # noqa: S324 return hashlib.md5(payload).hexdigest().upper() # noqa: S324
def _sha256_hash32(payload: bytes) -> str: def _sha256_hash32(payload: bytes) -> str:
digest = hashlib.sha256(payload).digest() # noqa: S324 digest = hashlib.sha256(payload).digest() # noqa: S324
return base64.b32hexencode(digest).decode().upper() return base64.b32hexencode(digest).decode().upper()
def _sha256_hash(payload: bytes) -> str: def _sha256_hash(payload: bytes) -> str:
return hashlib.sha256(payload).hexdigest().upper() # noqa: S324 return hashlib.sha256(payload).hexdigest().upper() # noqa: S324
@ -168,6 +170,19 @@ class SslAesTransport(BaseTransport):
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
return error_code return error_code
def _get_response_inner_error(self, resp_dict: Any) -> SmartErrorCode | None:
error_code_raw = resp_dict.get("data", {}).get("error_code")
if error_code_raw is None:
return None
try:
error_code = SmartErrorCode.from_int(error_code_raw)
except ValueError:
_LOGGER.warning(
"Device %s received unknown error code: %s", self._host, error_code_raw
)
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
return error_code
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
error_code = self._get_response_error(resp_dict) error_code = self._get_response_error(resp_dict)
if error_code is SmartErrorCode.SUCCESS: if error_code is SmartErrorCode.SUCCESS:
@ -458,6 +473,21 @@ class SslAesTransport(BaseTransport):
return self._default_credentials.password return self._default_credentials.password
def _is_less_secure_login(self, resp_dict: dict[str, Any]) -> bool:
result = (
self._get_response_error(resp_dict) is SmartErrorCode.SESSION_EXPIRED
and (data := resp_dict.get("result", {}).get("data", {}))
and (encrypt_type := data.get("encrypt_type"))
and (encrypt_type != ["3"])
)
if result:
_LOGGER.debug(
"Received encrypt_type %s for %s, trying less secure login",
encrypt_type,
self._host,
)
return result
async def perform_handshake1(self) -> tuple[str, str, str] | None: async def perform_handshake1(self) -> tuple[str, str, str] | None:
"""Perform the handshake1.""" """Perform the handshake1."""
resp_dict = None resp_dict = None
@ -467,20 +497,11 @@ class SslAesTransport(BaseTransport):
if ( if (
resp_dict resp_dict
and (error_code := self._get_response_error(resp_dict)) and self._is_less_secure_login(resp_dict)
is SmartErrorCode.SESSION_EXPIRED and await self.try_perform_login(
and (data := resp_dict.get("result", {}).get("data", {})) resp_dict.get("data", {}).get("nonce"), local_nonce
and (
encrypt_type := data.get("encrypt_type")
) )
and (encrypt_type != ["3"])
): ):
_LOGGER.debug(
"Received encrypt_type %s for %s, trying less secure login",
encrypt_type,
self._host,
)
if await self.try_perform_login(data.get("nonce"), local_nonce):
return None return None
# Try the default username. If it fails raise the original error_code # Try the default username. If it fails raise the original error_code
@ -495,6 +516,7 @@ class SslAesTransport(BaseTransport):
default_resp_dict = await self.try_send_handshake1( default_resp_dict = await self.try_send_handshake1(
self._default_credentials.username, local_nonce self._default_credentials.username, local_nonce
) )
# INVALID_NONCE means device should perform secure login
if ( if (
default_error_code := self._get_response_error(default_resp_dict) default_error_code := self._get_response_error(default_resp_dict)
) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[ ) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[
@ -504,15 +526,34 @@ class SslAesTransport(BaseTransport):
self._username = self._default_credentials.username self._username = self._default_credentials.username
error_code = default_error_code error_code = default_error_code
resp_dict = default_resp_dict resp_dict = default_resp_dict
# Otherwise could be less secure login
elif self._is_less_secure_login(
default_resp_dict
) and await self.try_perform_login(
default_resp_dict.get("data", {}).get("nonce"), local_nonce
):
return None
if not self._username: if not self._username:
raise AuthenticationError( raise AuthenticationError(
f"Credentials must be supplied to connect to {self._host}" f"Credentials must be supplied to connect to {self._host}"
) )
if error_code is not SmartErrorCode.INVALID_NONCE or ( if error_code is not SmartErrorCode.INVALID_NONCE or (
resp_dict and "nonce" not in resp_dict["result"].get("data", {}) resp_dict and "nonce" not in resp_dict.get("result", {}).get("data", {})
): ):
raise AuthenticationError(f"Error trying handshake1: {resp_dict}") if (
resp_dict
and self._get_response_inner_error(resp_dict)
is SmartErrorCode.DEVICE_BLOCKED
):
secs_left = resp_dict.get("data", {}).get("secs_left")
msg = "Device blocked" + (
f" for {secs_left} seconds" if secs_left else ""
)
raise DeviceError(msg)
raise AuthenticationError(
f"Error trying handshake1 for {self._host}: {resp_dict}"
)
if TYPE_CHECKING: if TYPE_CHECKING:
resp_dict = cast(dict[str, Any], resp_dict) resp_dict = cast(dict[str, Any], resp_dict)
@ -544,8 +585,12 @@ class SslAesTransport(BaseTransport):
# For testing purposes only. # For testing purposes only.
from ..credentials import DEFAULT_CREDENTIALS, get_default_credentials from ..credentials import DEFAULT_CREDENTIALS, get_default_credentials
device_or_wifi_mac = "12:34:56:AB:CD:EF" device_or_wifi_mac = "12:34:56:AB:CD:EF"
default_passes = {get_default_credentials(cred).password for cred in DEFAULT_CREDENTIALS.values()} default_passes = {
get_default_credentials(cred).password
for cred in DEFAULT_CREDENTIALS.values()
}
vals = { vals = {
"admin", "admin",
"tpadmin", "tpadmin",
@ -557,7 +602,13 @@ class SslAesTransport(BaseTransport):
} }
vals.update(default_passes) vals.update(default_passes)
for val in vals: for val in vals:
for func in {_sha256_hash, _md5_hash, _sha1_hash, _sha256_hash32, lambda x: x.decode()}: for func in {
_sha256_hash,
_md5_hash,
_sha1_hash,
_sha256_hash32,
lambda x: x.decode(),
}:
pwd_hash = func(val.encode()) pwd_hash = func(val.encode())
ec = self.generate_confirm_hash(local_nonce, server_nonce, pwd_hash) ec = self.generate_confirm_hash(local_nonce, server_nonce, pwd_hash)
if device_confirm == ec: if device_confirm == ec:
@ -590,7 +641,7 @@ class SslAesTransport(BaseTransport):
ssl=await self._get_ssl_context(), ssl=await self._get_ssl_context(),
) )
_LOGGER.debug("Device responded with: %s", resp_dict) _LOGGER.debug("Device responded with status %s: %s", status_code, resp_dict)
if status_code != 200: if status_code != 200:
raise KasaException( raise KasaException(