mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-10-23 15:58:02 +00:00
Improve smartprotocol error handling and retries (#578)
* Improve smartprotocol error handling and retries * Update after review * Enum to IntEnum and SLEEP_SECONDS_AFTER_TIMEOUT
This commit is contained in:
@@ -17,7 +17,16 @@ from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .exceptions import (
|
||||
SMART_AUTHENTICATION_ERRORS,
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
SMART_TIMEOUT_ERRORS,
|
||||
AuthenticationException,
|
||||
RetryableException,
|
||||
SmartDeviceException,
|
||||
SmartErrorCode,
|
||||
TimeoutException,
|
||||
)
|
||||
from .json import dumps as json_dumps
|
||||
from .json import loads as json_loads
|
||||
from .protocol import BaseTransport
|
||||
@@ -110,6 +119,21 @@ class AesTransport(BaseTransport):
|
||||
|
||||
return resp.status_code, response_data
|
||||
|
||||
def _handle_response_error_code(self, resp_dict: dict, msg: str):
|
||||
if (
|
||||
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||
) != SmartErrorCode.SUCCESS:
|
||||
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
|
||||
if error_code in SMART_TIMEOUT_ERRORS:
|
||||
raise TimeoutException(msg)
|
||||
if error_code in SMART_RETRYABLE_ERRORS:
|
||||
raise RetryableException(msg)
|
||||
if error_code in SMART_AUTHENTICATION_ERRORS:
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
raise AuthenticationException(msg)
|
||||
raise SmartDeviceException(msg)
|
||||
|
||||
async def send_secure_passthrough(self, request: str):
|
||||
"""Send encrypted message as passthrough."""
|
||||
url = f"http://{self.host}/app"
|
||||
@@ -123,17 +147,22 @@ class AesTransport(BaseTransport):
|
||||
}
|
||||
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
||||
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
|
||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||
response = self._encryption_session.decrypt( # type: ignore
|
||||
resp_dict["result"]["response"].encode()
|
||||
|
||||
if status_code != 200:
|
||||
raise SmartDeviceException(
|
||||
f"{self.host} responded with an unexpected "
|
||||
+ f"status code {status_code} to passthrough"
|
||||
)
|
||||
_LOGGER.debug(f"decrypted secure_passthrough response is {response}")
|
||||
resp_dict = json_loads(response)
|
||||
return resp_dict
|
||||
else:
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
raise AuthenticationException("Could not complete send")
|
||||
|
||||
self._handle_response_error_code(
|
||||
resp_dict, "Error sending secure_passthrough message"
|
||||
)
|
||||
|
||||
response = self._encryption_session.decrypt( # type: ignore
|
||||
resp_dict["result"]["response"].encode()
|
||||
)
|
||||
resp_dict = json_loads(response)
|
||||
return resp_dict
|
||||
|
||||
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
|
||||
"""Login to the device."""
|
||||
@@ -207,29 +236,32 @@ class AesTransport(BaseTransport):
|
||||
|
||||
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
||||
|
||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||
_LOGGER.debug("Decoding handshake key...")
|
||||
handshake_key = resp_dict["result"]["key"]
|
||||
if status_code != 200:
|
||||
raise SmartDeviceException(
|
||||
f"{self.host} responded with an unexpected "
|
||||
+ f"status code {status_code} to handshake"
|
||||
)
|
||||
|
||||
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
|
||||
|
||||
handshake_key = resp_dict["result"]["key"]
|
||||
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
if not self._session_cookie:
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
if not self._session_cookie:
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
"SESSIONID"
|
||||
)
|
||||
|
||||
self._session_expire_at = time.time() + 86400
|
||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||
handshake_key, key_pair
|
||||
"SESSIONID"
|
||||
)
|
||||
|
||||
self._handshake_done = True
|
||||
self._session_expire_at = time.time() + 86400
|
||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||
handshake_key, key_pair
|
||||
)
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||
self._handshake_done = True
|
||||
|
||||
else:
|
||||
raise AuthenticationException("Could not complete handshake")
|
||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||
|
||||
def _handshake_session_expired(self):
|
||||
"""Return true if session has expired."""
|
||||
@@ -247,19 +279,14 @@ class AesTransport(BaseTransport):
|
||||
if self.needs_login:
|
||||
raise SmartDeviceException("Login must be complete before trying to send")
|
||||
|
||||
resp_dict = await self.send_secure_passthrough(request)
|
||||
if resp_dict["error_code"] != 0:
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
raise SmartDeviceException(
|
||||
f"Could not complete send, response was {resp_dict}",
|
||||
)
|
||||
return resp_dict
|
||||
return await self.send_secure_passthrough(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
client = self._http_client
|
||||
self._http_client = None
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
if client:
|
||||
await client.aclose()
|
||||
|
||||
|
Reference in New Issue
Block a user