From 93ca3ad2e10194a13dcee843c6deab369930a672 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Fri, 20 Dec 2024 14:55:15 +0000 Subject: [PATCH] Handle smartcam device blocked response (#1393) Devices that have failed authentication multiple times due to bad credentials go into a blocked state for 30 mins. Handle that as a different error type instead of treating it as a normal `AuthenticationError`. --- kasa/transports/sslaestransport.py | 31 +++++++++++++++++++++++- tests/transports/test_sslaestransport.py | 27 +++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/kasa/transports/sslaestransport.py b/kasa/transports/sslaestransport.py index 2061d293..6e6ec0db 100644 --- a/kasa/transports/sslaestransport.py +++ b/kasa/transports/sslaestransport.py @@ -160,6 +160,19 @@ class SslAesTransport(BaseTransport): error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR return error_code + def _get_response_inner_error(self, resp_dict: Any) -> SmartErrorCode | None: + error_code_raw = resp_dict.get("data", {}).get("code") + if error_code_raw is None: + return None + try: + error_code = SmartErrorCode.from_int(error_code_raw) + except ValueError: + _LOGGER.warning( + "Device %s received unknown error code: %s", self._host, error_code_raw + ) + error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR + return error_code + def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: error_code = self._get_response_error(resp_dict) if error_code is SmartErrorCode.SUCCESS: @@ -383,13 +396,29 @@ class SslAesTransport(BaseTransport): error_code = default_error_code resp_dict = default_resp_dict + # If the default login worked it's ok not to provide credentials but if + # it didn't raise auth error here. if not self._username: raise AuthenticationError( f"Credentials must be supplied to connect to {self._host}" ) + + # Device responds with INVALID_NONCE and a "nonce" to indicate ready + # for secure login. Otherwise error. if error_code is not SmartErrorCode.INVALID_NONCE or ( - resp_dict and "nonce" not in resp_dict["result"].get("data", {}) + resp_dict and "nonce" not in resp_dict.get("result", {}).get("data", {}) ): + if ( + resp_dict + and self._get_response_inner_error(resp_dict) + is SmartErrorCode.DEVICE_BLOCKED + ): + sec_left = resp_dict.get("data", {}).get("sec_left") + msg = "Device blocked" + ( + f" for {sec_left} seconds" if sec_left else "" + ) + raise DeviceError(msg, error_code=SmartErrorCode.DEVICE_BLOCKED) + raise AuthenticationError(f"Error trying handshake1: {resp_dict}") if TYPE_CHECKING: diff --git a/tests/transports/test_sslaestransport.py b/tests/transports/test_sslaestransport.py index 6816fa35..00c54a54 100644 --- a/tests/transports/test_sslaestransport.py +++ b/tests/transports/test_sslaestransport.py @@ -15,6 +15,7 @@ from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_crede from kasa.deviceconfig import DeviceConfig from kasa.exceptions import ( AuthenticationError, + DeviceError, KasaException, SmartErrorCode, ) @@ -200,6 +201,22 @@ async def test_unencrypted_response(mocker, caplog): ) +async def test_device_blocked_response(mocker): + host = "127.0.0.1" + mock_ssl_aes_device = MockSslAesDevice(host, device_blocked=True) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + transport = SslAesTransport( + config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD)) + ) + msg = "Device blocked for 1685 seconds" + + with pytest.raises(DeviceError, match=msg): + await transport.perform_handshake() + + async def test_port_override(): """Test that port override sets the app_url.""" host = "127.0.0.1" @@ -235,6 +252,11 @@ class MockSslAesDevice: }, } + DEVICE_BLOCKED_RESP = { + "data": {"code": SmartErrorCode.DEVICE_BLOCKED.value, "sec_left": 1685}, + "error_code": SmartErrorCode.SESSION_EXPIRED.value, + } + class _mock_response: def __init__(self, status, request: dict): self.status = status @@ -263,6 +285,7 @@ class MockSslAesDevice: send_error_code=0, secure_passthrough_error_code=0, digest_password_fail=False, + device_blocked=False, ): self.host = host self.http_client = HttpClient(DeviceConfig(self.host)) @@ -277,6 +300,7 @@ class MockSslAesDevice: self.do_not_encrypt_response = do_not_encrypt_response self.want_default_username = want_default_username self.digest_password_fail = digest_password_fail + self.device_blocked = device_blocked async def post(self, url: URL, params=None, json=None, data=None, *_, **__): if data: @@ -303,6 +327,9 @@ class MockSslAesDevice: request_nonce = request["params"].get("cnonce") request_username = request["params"].get("username") + if self.device_blocked: + return self._mock_response(self.status_code, self.DEVICE_BLOCKED_RESP) + if (self.want_default_username and request_username != MOCK_ADMIN_USER) or ( not self.want_default_username and request_username != MOCK_USER ):