mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Improve smartprotocol error handling and retries (#578)
* Improve smartprotocol error handling and retries * Update after review * Enum to IntEnum and SLEEP_SECONDS_AFTER_TIMEOUT
This commit is contained in:
parent
a77af5fb3b
commit
2e6c41d039
@ -17,7 +17,16 @@ from cryptography.hazmat.primitives.asymmetric import rsa
|
|||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .exceptions import AuthenticationException, SmartDeviceException
|
from .exceptions import (
|
||||||
|
SMART_AUTHENTICATION_ERRORS,
|
||||||
|
SMART_RETRYABLE_ERRORS,
|
||||||
|
SMART_TIMEOUT_ERRORS,
|
||||||
|
AuthenticationException,
|
||||||
|
RetryableException,
|
||||||
|
SmartDeviceException,
|
||||||
|
SmartErrorCode,
|
||||||
|
TimeoutException,
|
||||||
|
)
|
||||||
from .json import dumps as json_dumps
|
from .json import dumps as json_dumps
|
||||||
from .json import loads as json_loads
|
from .json import loads as json_loads
|
||||||
from .protocol import BaseTransport
|
from .protocol import BaseTransport
|
||||||
@ -110,6 +119,21 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
return resp.status_code, response_data
|
return resp.status_code, response_data
|
||||||
|
|
||||||
|
def _handle_response_error_code(self, resp_dict: dict, msg: str):
|
||||||
|
if (
|
||||||
|
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
|
) != SmartErrorCode.SUCCESS:
|
||||||
|
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
|
||||||
|
if error_code in SMART_TIMEOUT_ERRORS:
|
||||||
|
raise TimeoutException(msg)
|
||||||
|
if error_code in SMART_RETRYABLE_ERRORS:
|
||||||
|
raise RetryableException(msg)
|
||||||
|
if error_code in SMART_AUTHENTICATION_ERRORS:
|
||||||
|
self._handshake_done = False
|
||||||
|
self._login_token = None
|
||||||
|
raise AuthenticationException(msg)
|
||||||
|
raise SmartDeviceException(msg)
|
||||||
|
|
||||||
async def send_secure_passthrough(self, request: str):
|
async def send_secure_passthrough(self, request: str):
|
||||||
"""Send encrypted message as passthrough."""
|
"""Send encrypted message as passthrough."""
|
||||||
url = f"http://{self.host}/app"
|
url = f"http://{self.host}/app"
|
||||||
@ -123,17 +147,22 @@ class AesTransport(BaseTransport):
|
|||||||
}
|
}
|
||||||
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
||||||
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
|
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
|
||||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
|
||||||
response = self._encryption_session.decrypt( # type: ignore
|
if status_code != 200:
|
||||||
resp_dict["result"]["response"].encode()
|
raise SmartDeviceException(
|
||||||
|
f"{self.host} responded with an unexpected "
|
||||||
|
+ f"status code {status_code} to passthrough"
|
||||||
)
|
)
|
||||||
_LOGGER.debug(f"decrypted secure_passthrough response is {response}")
|
|
||||||
resp_dict = json_loads(response)
|
self._handle_response_error_code(
|
||||||
return resp_dict
|
resp_dict, "Error sending secure_passthrough message"
|
||||||
else:
|
)
|
||||||
self._handshake_done = False
|
|
||||||
self._login_token = None
|
response = self._encryption_session.decrypt( # type: ignore
|
||||||
raise AuthenticationException("Could not complete send")
|
resp_dict["result"]["response"].encode()
|
||||||
|
)
|
||||||
|
resp_dict = json_loads(response)
|
||||||
|
return resp_dict
|
||||||
|
|
||||||
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
|
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
|
||||||
"""Login to the device."""
|
"""Login to the device."""
|
||||||
@ -207,29 +236,32 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
||||||
|
|
||||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
if status_code != 200:
|
||||||
_LOGGER.debug("Decoding handshake key...")
|
raise SmartDeviceException(
|
||||||
handshake_key = resp_dict["result"]["key"]
|
f"{self.host} responded with an unexpected "
|
||||||
|
+ f"status code {status_code} to handshake"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
|
||||||
|
|
||||||
|
handshake_key = resp_dict["result"]["key"]
|
||||||
|
|
||||||
|
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||||
|
self.SESSION_COOKIE_NAME
|
||||||
|
)
|
||||||
|
if not self._session_cookie:
|
||||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||||
self.SESSION_COOKIE_NAME
|
"SESSIONID"
|
||||||
)
|
|
||||||
if not self._session_cookie:
|
|
||||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
|
||||||
"SESSIONID"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._session_expire_at = time.time() + 86400
|
|
||||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
|
||||||
handshake_key, key_pair
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._handshake_done = True
|
self._session_expire_at = time.time() + 86400
|
||||||
|
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||||
|
handshake_key, key_pair
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
self._handshake_done = True
|
||||||
|
|
||||||
else:
|
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||||
raise AuthenticationException("Could not complete handshake")
|
|
||||||
|
|
||||||
def _handshake_session_expired(self):
|
def _handshake_session_expired(self):
|
||||||
"""Return true if session has expired."""
|
"""Return true if session has expired."""
|
||||||
@ -247,19 +279,14 @@ class AesTransport(BaseTransport):
|
|||||||
if self.needs_login:
|
if self.needs_login:
|
||||||
raise SmartDeviceException("Login must be complete before trying to send")
|
raise SmartDeviceException("Login must be complete before trying to send")
|
||||||
|
|
||||||
resp_dict = await self.send_secure_passthrough(request)
|
return await self.send_secure_passthrough(request)
|
||||||
if resp_dict["error_code"] != 0:
|
|
||||||
self._handshake_done = False
|
|
||||||
self._login_token = None
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Could not complete send, response was {resp_dict}",
|
|
||||||
)
|
|
||||||
return resp_dict
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the protocol."""
|
"""Close the protocol."""
|
||||||
client = self._http_client
|
client = self._http_client
|
||||||
self._http_client = None
|
self._http_client = None
|
||||||
|
self._handshake_done = False
|
||||||
|
self._login_token = None
|
||||||
if client:
|
if client:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""python-kasa exceptions."""
|
"""python-kasa exceptions."""
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
class SmartDeviceException(Exception):
|
class SmartDeviceException(Exception):
|
||||||
@ -11,3 +12,87 @@ class UnsupportedDeviceException(SmartDeviceException):
|
|||||||
|
|
||||||
class AuthenticationException(SmartDeviceException):
|
class AuthenticationException(SmartDeviceException):
|
||||||
"""Base exception for device authentication errors."""
|
"""Base exception for device authentication errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class RetryableException(SmartDeviceException):
|
||||||
|
"""Retryable exception for device errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class TimeoutException(SmartDeviceException):
|
||||||
|
"""Timeout exception for device errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class SmartErrorCode(IntEnum):
|
||||||
|
"""Enum for SMART Error Codes."""
|
||||||
|
|
||||||
|
SUCCESS = 0
|
||||||
|
|
||||||
|
# Transport Errors
|
||||||
|
SESSION_TIMEOUT_ERROR = 9999
|
||||||
|
MULTI_REQUEST_FAILED_ERROR = 1200
|
||||||
|
HTTP_TRANSPORT_FAILED_ERROR = 1112
|
||||||
|
LOGIN_FAILED_ERROR = 1111
|
||||||
|
HAND_SHAKE_FAILED_ERROR = 1100
|
||||||
|
TRANSPORT_NOT_AVAILABLE_ERROR = 1002
|
||||||
|
CMD_COMMAND_CANCEL_ERROR = 1001
|
||||||
|
NULL_TRANSPORT_ERROR = 1000
|
||||||
|
|
||||||
|
# Common Method Errors
|
||||||
|
COMMON_FAILED_ERROR = -1
|
||||||
|
UNSPECIFIC_ERROR = -1001
|
||||||
|
UNKNOWN_METHOD_ERROR = -1002
|
||||||
|
JSON_DECODE_FAIL_ERROR = -1003
|
||||||
|
JSON_ENCODE_FAIL_ERROR = -1004
|
||||||
|
AES_DECODE_FAIL_ERROR = -1005
|
||||||
|
REQUEST_LEN_ERROR_ERROR = -1006
|
||||||
|
CLOUD_FAILED_ERROR = -1007
|
||||||
|
PARAMS_ERROR = -1008
|
||||||
|
INVALID_PUBLIC_KEY_ERROR = -1010 # Unverified
|
||||||
|
SESSION_PARAM_ERROR = -1101
|
||||||
|
|
||||||
|
# Method Specific Errors
|
||||||
|
QUICK_SETUP_ERROR = -1201
|
||||||
|
DEVICE_ERROR = -1301
|
||||||
|
DEVICE_NEXT_EVENT_ERROR = -1302
|
||||||
|
FIRMWARE_ERROR = -1401
|
||||||
|
FIRMWARE_VER_ERROR_ERROR = -1402
|
||||||
|
LOGIN_ERROR = -1501
|
||||||
|
TIME_ERROR = -1601
|
||||||
|
TIME_SYS_ERROR = -1602
|
||||||
|
TIME_SAVE_ERROR = -1603
|
||||||
|
WIRELESS_ERROR = -1701
|
||||||
|
WIRELESS_UNSUPPORTED_ERROR = -1702
|
||||||
|
SCHEDULE_ERROR = -1801
|
||||||
|
SCHEDULE_FULL_ERROR = -1802
|
||||||
|
SCHEDULE_CONFLICT_ERROR = -1803
|
||||||
|
SCHEDULE_SAVE_ERROR = -1804
|
||||||
|
SCHEDULE_INDEX_ERROR = -1805
|
||||||
|
COUNTDOWN_ERROR = -1901
|
||||||
|
COUNTDOWN_CONFLICT_ERROR = -1902
|
||||||
|
COUNTDOWN_SAVE_ERROR = -1903
|
||||||
|
ANTITHEFT_ERROR = -2001
|
||||||
|
ANTITHEFT_CONFLICT_ERROR = -2002
|
||||||
|
ANTITHEFT_SAVE_ERROR = -2003
|
||||||
|
ACCOUNT_ERROR = -2101
|
||||||
|
STAT_ERROR = -2201
|
||||||
|
STAT_SAVE_ERROR = -2202
|
||||||
|
DST_ERROR = -2301
|
||||||
|
DST_SAVE_ERROR = -2302
|
||||||
|
|
||||||
|
|
||||||
|
SMART_RETRYABLE_ERRORS = [
|
||||||
|
SmartErrorCode.TRANSPORT_NOT_AVAILABLE_ERROR,
|
||||||
|
SmartErrorCode.HTTP_TRANSPORT_FAILED_ERROR,
|
||||||
|
SmartErrorCode.UNSPECIFIC_ERROR,
|
||||||
|
]
|
||||||
|
|
||||||
|
SMART_AUTHENTICATION_ERRORS = [
|
||||||
|
SmartErrorCode.LOGIN_ERROR,
|
||||||
|
SmartErrorCode.LOGIN_FAILED_ERROR,
|
||||||
|
SmartErrorCode.AES_DECODE_FAIL_ERROR,
|
||||||
|
SmartErrorCode.HAND_SHAKE_FAILED_ERROR,
|
||||||
|
]
|
||||||
|
|
||||||
|
SMART_TIMEOUT_ERRORS = [
|
||||||
|
SmartErrorCode.SESSION_TIMEOUT_ERROR,
|
||||||
|
]
|
||||||
|
@ -377,6 +377,7 @@ class KlapTransport(BaseTransport):
|
|||||||
"""Close the transport."""
|
"""Close the transport."""
|
||||||
client = self._http_client
|
client = self._http_client
|
||||||
self._http_client = None
|
self._http_client = None
|
||||||
|
self._handshake_done = False
|
||||||
if client:
|
if client:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
|
@ -16,7 +16,16 @@ import httpx
|
|||||||
|
|
||||||
from .aestransport import AesTransport
|
from .aestransport import AesTransport
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .exceptions import AuthenticationException, SmartDeviceException
|
from .exceptions import (
|
||||||
|
SMART_AUTHENTICATION_ERRORS,
|
||||||
|
SMART_RETRYABLE_ERRORS,
|
||||||
|
SMART_TIMEOUT_ERRORS,
|
||||||
|
AuthenticationException,
|
||||||
|
RetryableException,
|
||||||
|
SmartDeviceException,
|
||||||
|
SmartErrorCode,
|
||||||
|
TimeoutException,
|
||||||
|
)
|
||||||
from .json import dumps as json_dumps
|
from .json import dumps as json_dumps
|
||||||
from .protocol import BaseTransport, TPLinkProtocol, md5
|
from .protocol import BaseTransport, TPLinkProtocol, md5
|
||||||
|
|
||||||
@ -28,6 +37,7 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
"""Class for the new TPLink SMART protocol."""
|
"""Class for the new TPLink SMART protocol."""
|
||||||
|
|
||||||
DEFAULT_PORT = 80
|
DEFAULT_PORT = 80
|
||||||
|
SLEEP_SECONDS_AFTER_TIMEOUT = 1
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -64,6 +74,22 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
"""Query the device retrying for retry_count on failure."""
|
"""Query the device retrying for retry_count on failure."""
|
||||||
async with self._query_lock:
|
async with self._query_lock:
|
||||||
resp_dict = await self._query(request, retry_count)
|
resp_dict = await self._query(request, retry_count)
|
||||||
|
|
||||||
|
if (
|
||||||
|
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
|
) != SmartErrorCode.SUCCESS:
|
||||||
|
msg = (
|
||||||
|
f"Error querying device: {self.host}: "
|
||||||
|
+ f"{error_code.name}({error_code.value})"
|
||||||
|
)
|
||||||
|
if error_code in SMART_TIMEOUT_ERRORS:
|
||||||
|
raise TimeoutException(msg)
|
||||||
|
if error_code in SMART_RETRYABLE_ERRORS:
|
||||||
|
raise RetryableException(msg)
|
||||||
|
if error_code in SMART_AUTHENTICATION_ERRORS:
|
||||||
|
raise AuthenticationException(msg)
|
||||||
|
raise SmartDeviceException(msg)
|
||||||
|
|
||||||
if "result" in resp_dict:
|
if "result" in resp_dict:
|
||||||
return resp_dict["result"]
|
return resp_dict["result"]
|
||||||
return {}
|
return {}
|
||||||
@ -86,20 +112,41 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
f"Unable to connect to the device: {self.host}: {cex}"
|
f"Unable to connect to the device: {self.host}: {cex}"
|
||||||
) from cex
|
) from cex
|
||||||
except TimeoutError as tex:
|
except TimeoutError as tex:
|
||||||
await self.close()
|
if retry >= retry_count:
|
||||||
raise SmartDeviceException(
|
await self.close()
|
||||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
raise SmartDeviceException(
|
||||||
) from tex
|
"Unable to connect to the device, "
|
||||||
|
+ f"timed out: {self.host}: {tex}"
|
||||||
|
) from tex
|
||||||
|
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
||||||
|
continue
|
||||||
except AuthenticationException as auex:
|
except AuthenticationException as auex:
|
||||||
|
await self.close()
|
||||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||||
raise auex
|
raise auex
|
||||||
except Exception as ex:
|
except RetryableException as ex:
|
||||||
await self.close()
|
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
|
await self.close()
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||||
|
raise ex
|
||||||
|
continue
|
||||||
|
except TimeoutException as ex:
|
||||||
|
if retry >= retry_count:
|
||||||
|
await self.close()
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||||
|
raise ex
|
||||||
|
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
||||||
|
continue
|
||||||
|
except Exception as ex:
|
||||||
|
if retry >= retry_count:
|
||||||
|
await self.close()
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device: {self.host}: {ex}"
|
f"Unable to query the device {self.host}:{self.port}: {ex}"
|
||||||
) from ex
|
) from ex
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Unable to query the device %s, retrying: %s", self.host, ex
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# make mypy happy, this should never be reached..
|
# make mypy happy, this should never be reached..
|
||||||
|
@ -166,13 +166,7 @@ class TapoBulb(TapoDevice, SmartBulb):
|
|||||||
if value is not None:
|
if value is not None:
|
||||||
request_payload["brightness"] = value
|
request_payload["brightness"] = value
|
||||||
|
|
||||||
return await self.protocol.query(
|
return await self.protocol.query({"set_device_info": {**request_payload}})
|
||||||
{
|
|
||||||
"set_device_info": {
|
|
||||||
**request_payload
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def set_color_temp(
|
async def set_color_temp(
|
||||||
self, temp: int, *, brightness=None, transition: Optional[int] = None
|
self, temp: int, *, brightness=None, transition: Optional[int] = None
|
||||||
|
@ -315,11 +315,11 @@ class FakeSmartTransport(BaseTransport):
|
|||||||
method = request_dict["method"]
|
method = request_dict["method"]
|
||||||
params = request_dict["params"]
|
params = request_dict["params"]
|
||||||
if method == "component_nego" or method[:4] == "get_":
|
if method == "component_nego" or method[:4] == "get_":
|
||||||
return {"result": self.info[method]}
|
return {"result": self.info[method], "error_code": 0}
|
||||||
elif method[:4] == "set_":
|
elif method[:4] == "set_":
|
||||||
target_method = f"get_{method[4:]}"
|
target_method = f"get_{method[4:]}"
|
||||||
self.info[target_method].update(params)
|
self.info[target_method].update(params)
|
||||||
return {"result": ""}
|
return {"error_code": 0}
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -86,7 +86,7 @@ async def test_protocol_retry_recoverable_error(
|
|||||||
async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
|
async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
remaining = retry_count
|
remaining = retry_count
|
||||||
mock_response = {"result": {"great": "success"}}
|
mock_response = {"result": {"great": "success"}, "error_code": 0}
|
||||||
|
|
||||||
def _fail_one_less_than_retry_count(*_, **__):
|
def _fail_one_less_than_retry_count(*_, **__):
|
||||||
nonlocal remaining
|
nonlocal remaining
|
||||||
|
Loading…
Reference in New Issue
Block a user