Improve smartprotocol error handling and retries (#578)

* Improve smartprotocol error handling and retries

* Update after review

* Enum to IntEnum and SLEEP_SECONDS_AFTER_TIMEOUT
This commit is contained in:
sdb9696 2023-12-10 15:41:53 +00:00 committed by GitHub
parent a77af5fb3b
commit 2e6c41d039
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 208 additions and 54 deletions

View File

@ -17,7 +17,16 @@ from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
AuthenticationException,
RetryableException,
SmartDeviceException,
SmartErrorCode,
TimeoutException,
)
from .json import dumps as json_dumps from .json import dumps as json_dumps
from .json import loads as json_loads from .json import loads as json_loads
from .protocol import BaseTransport from .protocol import BaseTransport
@ -110,6 +119,21 @@ class AesTransport(BaseTransport):
return resp.status_code, response_data return resp.status_code, response_data
def _handle_response_error_code(self, resp_dict: dict, msg: str):
if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS:
self._handshake_done = False
self._login_token = None
raise AuthenticationException(msg)
raise SmartDeviceException(msg)
async def send_secure_passthrough(self, request: str): async def send_secure_passthrough(self, request: str):
"""Send encrypted message as passthrough.""" """Send encrypted message as passthrough."""
url = f"http://{self.host}/app" url = f"http://{self.host}/app"
@ -123,17 +147,22 @@ class AesTransport(BaseTransport):
} }
status_code, resp_dict = await self.client_post(url, json=passthrough_request) status_code, resp_dict = await self.client_post(url, json=passthrough_request)
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}") # _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
if status_code == 200 and resp_dict["error_code"] == 0:
response = self._encryption_session.decrypt( # type: ignore if status_code != 200:
resp_dict["result"]["response"].encode() raise SmartDeviceException(
f"{self.host} responded with an unexpected "
+ f"status code {status_code} to passthrough"
) )
_LOGGER.debug(f"decrypted secure_passthrough response is {response}")
resp_dict = json_loads(response) self._handle_response_error_code(
return resp_dict resp_dict, "Error sending secure_passthrough message"
else: )
self._handshake_done = False
self._login_token = None response = self._encryption_session.decrypt( # type: ignore
raise AuthenticationException("Could not complete send") resp_dict["result"]["response"].encode()
)
resp_dict = json_loads(response)
return resp_dict
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool): async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
"""Login to the device.""" """Login to the device."""
@ -207,29 +236,32 @@ class AesTransport(BaseTransport):
_LOGGER.debug(f"Device responded with: {resp_dict}") _LOGGER.debug(f"Device responded with: {resp_dict}")
if status_code == 200 and resp_dict["error_code"] == 0: if status_code != 200:
_LOGGER.debug("Decoding handshake key...") raise SmartDeviceException(
handshake_key = resp_dict["result"]["key"] f"{self.host} responded with an unexpected "
+ f"status code {status_code} to handshake"
)
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
handshake_key = resp_dict["result"]["key"]
self._session_cookie = self._http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME
)
if not self._session_cookie:
self._session_cookie = self._http_client.cookies.get( # type: ignore self._session_cookie = self._http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME "SESSIONID"
)
if not self._session_cookie:
self._session_cookie = self._http_client.cookies.get( # type: ignore
"SESSIONID"
)
self._session_expire_at = time.time() + 86400
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
) )
self._handshake_done = True self._session_expire_at = time.time() + 86400
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
)
_LOGGER.debug("Handshake with %s complete", self.host) self._handshake_done = True
else: _LOGGER.debug("Handshake with %s complete", self.host)
raise AuthenticationException("Could not complete handshake")
def _handshake_session_expired(self): def _handshake_session_expired(self):
"""Return true if session has expired.""" """Return true if session has expired."""
@ -247,19 +279,14 @@ class AesTransport(BaseTransport):
if self.needs_login: if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send") raise SmartDeviceException("Login must be complete before trying to send")
resp_dict = await self.send_secure_passthrough(request) return await self.send_secure_passthrough(request)
if resp_dict["error_code"] != 0:
self._handshake_done = False
self._login_token = None
raise SmartDeviceException(
f"Could not complete send, response was {resp_dict}",
)
return resp_dict
async def close(self) -> None: async def close(self) -> None:
"""Close the protocol.""" """Close the protocol."""
client = self._http_client client = self._http_client
self._http_client = None self._http_client = None
self._handshake_done = False
self._login_token = None
if client: if client:
await client.aclose() await client.aclose()

View File

@ -1,4 +1,5 @@
"""python-kasa exceptions.""" """python-kasa exceptions."""
from enum import IntEnum
class SmartDeviceException(Exception): class SmartDeviceException(Exception):
@ -11,3 +12,87 @@ class UnsupportedDeviceException(SmartDeviceException):
class AuthenticationException(SmartDeviceException): class AuthenticationException(SmartDeviceException):
"""Base exception for device authentication errors.""" """Base exception for device authentication errors."""
class RetryableException(SmartDeviceException):
"""Retryable exception for device errors."""
class TimeoutException(SmartDeviceException):
"""Timeout exception for device errors."""
class SmartErrorCode(IntEnum):
"""Enum for SMART Error Codes."""
SUCCESS = 0
# Transport Errors
SESSION_TIMEOUT_ERROR = 9999
MULTI_REQUEST_FAILED_ERROR = 1200
HTTP_TRANSPORT_FAILED_ERROR = 1112
LOGIN_FAILED_ERROR = 1111
HAND_SHAKE_FAILED_ERROR = 1100
TRANSPORT_NOT_AVAILABLE_ERROR = 1002
CMD_COMMAND_CANCEL_ERROR = 1001
NULL_TRANSPORT_ERROR = 1000
# Common Method Errors
COMMON_FAILED_ERROR = -1
UNSPECIFIC_ERROR = -1001
UNKNOWN_METHOD_ERROR = -1002
JSON_DECODE_FAIL_ERROR = -1003
JSON_ENCODE_FAIL_ERROR = -1004
AES_DECODE_FAIL_ERROR = -1005
REQUEST_LEN_ERROR_ERROR = -1006
CLOUD_FAILED_ERROR = -1007
PARAMS_ERROR = -1008
INVALID_PUBLIC_KEY_ERROR = -1010 # Unverified
SESSION_PARAM_ERROR = -1101
# Method Specific Errors
QUICK_SETUP_ERROR = -1201
DEVICE_ERROR = -1301
DEVICE_NEXT_EVENT_ERROR = -1302
FIRMWARE_ERROR = -1401
FIRMWARE_VER_ERROR_ERROR = -1402
LOGIN_ERROR = -1501
TIME_ERROR = -1601
TIME_SYS_ERROR = -1602
TIME_SAVE_ERROR = -1603
WIRELESS_ERROR = -1701
WIRELESS_UNSUPPORTED_ERROR = -1702
SCHEDULE_ERROR = -1801
SCHEDULE_FULL_ERROR = -1802
SCHEDULE_CONFLICT_ERROR = -1803
SCHEDULE_SAVE_ERROR = -1804
SCHEDULE_INDEX_ERROR = -1805
COUNTDOWN_ERROR = -1901
COUNTDOWN_CONFLICT_ERROR = -1902
COUNTDOWN_SAVE_ERROR = -1903
ANTITHEFT_ERROR = -2001
ANTITHEFT_CONFLICT_ERROR = -2002
ANTITHEFT_SAVE_ERROR = -2003
ACCOUNT_ERROR = -2101
STAT_ERROR = -2201
STAT_SAVE_ERROR = -2202
DST_ERROR = -2301
DST_SAVE_ERROR = -2302
SMART_RETRYABLE_ERRORS = [
SmartErrorCode.TRANSPORT_NOT_AVAILABLE_ERROR,
SmartErrorCode.HTTP_TRANSPORT_FAILED_ERROR,
SmartErrorCode.UNSPECIFIC_ERROR,
]
SMART_AUTHENTICATION_ERRORS = [
SmartErrorCode.LOGIN_ERROR,
SmartErrorCode.LOGIN_FAILED_ERROR,
SmartErrorCode.AES_DECODE_FAIL_ERROR,
SmartErrorCode.HAND_SHAKE_FAILED_ERROR,
]
SMART_TIMEOUT_ERRORS = [
SmartErrorCode.SESSION_TIMEOUT_ERROR,
]

View File

@ -377,6 +377,7 @@ class KlapTransport(BaseTransport):
"""Close the transport.""" """Close the transport."""
client = self._http_client client = self._http_client
self._http_client = None self._http_client = None
self._handshake_done = False
if client: if client:
await client.aclose() await client.aclose()

View File

@ -16,7 +16,16 @@ import httpx
from .aestransport import AesTransport from .aestransport import AesTransport
from .credentials import Credentials from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
AuthenticationException,
RetryableException,
SmartDeviceException,
SmartErrorCode,
TimeoutException,
)
from .json import dumps as json_dumps from .json import dumps as json_dumps
from .protocol import BaseTransport, TPLinkProtocol, md5 from .protocol import BaseTransport, TPLinkProtocol, md5
@ -28,6 +37,7 @@ class SmartProtocol(TPLinkProtocol):
"""Class for the new TPLink SMART protocol.""" """Class for the new TPLink SMART protocol."""
DEFAULT_PORT = 80 DEFAULT_PORT = 80
SLEEP_SECONDS_AFTER_TIMEOUT = 1
def __init__( def __init__(
self, self,
@ -64,6 +74,22 @@ class SmartProtocol(TPLinkProtocol):
"""Query the device retrying for retry_count on failure.""" """Query the device retrying for retry_count on failure."""
async with self._query_lock: async with self._query_lock:
resp_dict = await self._query(request, retry_count) resp_dict = await self._query(request, retry_count)
if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = (
f"Error querying device: {self.host}: "
+ f"{error_code.name}({error_code.value})"
)
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS:
raise AuthenticationException(msg)
raise SmartDeviceException(msg)
if "result" in resp_dict: if "result" in resp_dict:
return resp_dict["result"] return resp_dict["result"]
return {} return {}
@ -86,20 +112,41 @@ class SmartProtocol(TPLinkProtocol):
f"Unable to connect to the device: {self.host}: {cex}" f"Unable to connect to the device: {self.host}: {cex}"
) from cex ) from cex
except TimeoutError as tex: except TimeoutError as tex:
await self.close() if retry >= retry_count:
raise SmartDeviceException( await self.close()
f"Unable to connect to the device, timed out: {self.host}: {tex}" raise SmartDeviceException(
) from tex "Unable to connect to the device, "
+ f"timed out: {self.host}: {tex}"
) from tex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue
except AuthenticationException as auex: except AuthenticationException as auex:
await self.close()
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex raise auex
except Exception as ex: except RetryableException as ex:
await self.close()
if retry >= retry_count: if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise ex
continue
except TimeoutException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise ex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue
except Exception as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry) _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}" f"Unable to query the device {self.host}:{self.port}: {ex}"
) from ex ) from ex
_LOGGER.debug(
"Unable to query the device %s, retrying: %s", self.host, ex
)
continue continue
# make mypy happy, this should never be reached.. # make mypy happy, this should never be reached..

View File

@ -166,13 +166,7 @@ class TapoBulb(TapoDevice, SmartBulb):
if value is not None: if value is not None:
request_payload["brightness"] = value request_payload["brightness"] = value
return await self.protocol.query( return await self.protocol.query({"set_device_info": {**request_payload}})
{
"set_device_info": {
**request_payload
}
}
)
async def set_color_temp( async def set_color_temp(
self, temp: int, *, brightness=None, transition: Optional[int] = None self, temp: int, *, brightness=None, transition: Optional[int] = None

View File

@ -315,11 +315,11 @@ class FakeSmartTransport(BaseTransport):
method = request_dict["method"] method = request_dict["method"]
params = request_dict["params"] params = request_dict["params"]
if method == "component_nego" or method[:4] == "get_": if method == "component_nego" or method[:4] == "get_":
return {"result": self.info[method]} return {"result": self.info[method], "error_code": 0}
elif method[:4] == "set_": elif method[:4] == "set_":
target_method = f"get_{method[4:]}" target_method = f"get_{method[4:]}"
self.info[target_method].update(params) self.info[target_method].update(params)
return {"result": ""} return {"error_code": 0}
async def close(self) -> None: async def close(self) -> None:
pass pass

View File

@ -86,7 +86,7 @@ async def test_protocol_retry_recoverable_error(
async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class): async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
host = "127.0.0.1" host = "127.0.0.1"
remaining = retry_count remaining = retry_count
mock_response = {"result": {"great": "success"}} mock_response = {"result": {"great": "success"}, "error_code": 0}
def _fail_one_less_than_retry_count(*_, **__): def _fail_one_less_than_retry_count(*_, **__):
nonlocal remaining nonlocal remaining