From 8418ba3eefdfb167c5ecdb2204516b1533b54a89 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Fri, 20 Dec 2024 19:23:18 +0000 Subject: [PATCH] Treat smartcam 500 errors after handshake as retryable (#1395) `smartcam` devices can respond with 500 if another session is created from the same host --- kasa/httpclient.py | 19 +++++-- kasa/transports/sslaestransport.py | 27 +++++++++- tests/transports/test_sslaestransport.py | 69 ++++++++++++++++++++++-- 3 files changed, 107 insertions(+), 8 deletions(-) diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 87e3626a..31d8dfbb 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -113,10 +113,23 @@ class HttpClient: ssl=ssl, ) async with resp: - if resp.status == 200: - response_data = await resp.read() - if return_json: + response_data = await resp.read() + + if resp.status == 200: + if return_json: + response_data = json_loads(response_data.decode()) + else: + _LOGGER.debug( + "Device %s received status code %s with response %s", + self._config.host, + resp.status, + str(response_data), + ) + if response_data and return_json: + try: response_data = json_loads(response_data.decode()) + except Exception: + _LOGGER.debug("Device %s response could not be parsed as json") except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: if not self._wait_between_requests: diff --git a/kasa/transports/sslaestransport.py b/kasa/transports/sslaestransport.py index 6e6ec0db..500d9422 100644 --- a/kasa/transports/sslaestransport.py +++ b/kasa/transports/sslaestransport.py @@ -8,6 +8,7 @@ import hashlib import logging import secrets import ssl +from contextlib import suppress from enum import Enum, auto from typing import TYPE_CHECKING, Any, cast @@ -229,6 +230,31 @@ class SslAesTransport(BaseTransport): ssl=await self._get_ssl_context(), ) + if TYPE_CHECKING: + assert self._encryption_session is not None + + # Devices can respond with 500 if another session is created from + # the same host. Decryption may not succeed after that + if status_code == 500: + msg = ( + f"Device {self._host} replied with status 500 after handshake, " + f"response: " + ) + decrypted = None + if isinstance(resp_dict, dict) and ( + response := resp_dict.get("result", {}).get("response") + ): + with suppress(Exception): + decrypted = self._encryption_session.decrypt(response.encode()) + + if decrypted: + msg += decrypted + else: + msg += str(resp_dict) + + _LOGGER.debug(msg) + raise _RetryableError(msg) + if status_code != 200: raise KasaException( f"{self._host} responded with an unexpected " @@ -241,7 +267,6 @@ class SslAesTransport(BaseTransport): if TYPE_CHECKING: resp_dict = cast(dict[str, Any], resp_dict) - assert self._encryption_session is not None if "result" in resp_dict and "response" in resp_dict["result"]: raw_response: str = resp_dict["result"]["response"] diff --git a/tests/transports/test_sslaestransport.py b/tests/transports/test_sslaestransport.py index 00c54a54..39469967 100644 --- a/tests/transports/test_sslaestransport.py +++ b/tests/transports/test_sslaestransport.py @@ -18,6 +18,7 @@ from kasa.exceptions import ( DeviceError, KasaException, SmartErrorCode, + _RetryableError, ) from kasa.httpclient import HttpClient from kasa.transports.aestransport import AesEncyptionSession @@ -217,6 +218,48 @@ async def test_device_blocked_response(mocker): await transport.perform_handshake() +@pytest.mark.parametrize( + ("response", "expected_msg"), + [ + pytest.param( + {"error_code": -1, "msg": "Check tapo tag failed"}, + '{"error_code": -1, "msg": "Check tapo tag failed"}', + id="can-decrypt", + ), + pytest.param( + b"12345678", + str({"result": {"response": "12345678"}, "error_code": 0}), + id="cannot-decrypt", + ), + ], +) +async def test_device_500_error(mocker, response, expected_msg): + """Test 500 error raises retryable exception.""" + host = "127.0.0.1" + mock_ssl_aes_device = MockSslAesDevice(host) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + transport = SslAesTransport( + config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD)) + ) + + request = { + "method": "getDeviceInfo", + "params": None, + } + + await transport.perform_handshake() + + mock_ssl_aes_device.put_next_response(response) + mock_ssl_aes_device.status_code = 500 + + msg = f"Device 127.0.0.1 replied with status 500 after handshake, response: {expected_msg}" + with pytest.raises(_RetryableError, match=msg): + await transport.send(json_dumps(request)) + + async def test_port_override(): """Test that port override sets the app_url.""" host = "127.0.0.1" @@ -302,6 +345,8 @@ class MockSslAesDevice: self.digest_password_fail = digest_password_fail self.device_blocked = device_blocked + self._next_responses: list[dict | bytes] = [] + async def post(self, url: URL, params=None, json=None, data=None, *_, **__): if data: json = json_loads(data) @@ -386,11 +431,24 @@ class MockSslAesDevice: assert self.encryption_session decrypted_request = self.encryption_session.decrypt(encrypted_request.encode()) decrypted_request_dict = json_loads(decrypted_request) - decrypted_response = await self._post(url, decrypted_request_dict) - async with decrypted_response: - decrypted_response_data = await decrypted_response.read() - encrypted_response = self.encryption_session.encrypt(decrypted_response_data) + if self._next_responses: + next_response = self._next_responses.pop(0) + if isinstance(next_response, dict): + decrypted_response_data = json_dumps(next_response).encode() + encrypted_response = self.encryption_session.encrypt( + decrypted_response_data + ) + else: + encrypted_response = next_response + else: + decrypted_response = await self._post(url, decrypted_request_dict) + async with decrypted_response: + decrypted_response_data = await decrypted_response.read() + encrypted_response = self.encryption_session.encrypt( + decrypted_response_data + ) + response = ( decrypted_response_data if self.do_not_encrypt_response @@ -405,3 +463,6 @@ class MockSslAesDevice: async def _return_send_response(self, url: URL, json: dict[str, Any]): result = {"result": {"method": None}, "error_code": self.send_error_code} return self._mock_response(self.status_code, result) + + def put_next_response(self, request: dict | bytes) -> None: + self._next_responses.append(request)