From 91de5e20ba3c8bbf9f2ce41d21c15aef3dda22f6 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Tue, 4 Jun 2024 20:49:01 +0300 Subject: [PATCH] Fix P100 errors on multi-requests (#930) Fixes an issue reported by @bdraco with the P100 not working in the latest branch: `[Errno None] Can not write request body for HOST_REDACTED, ClientOSError(None, 'Can not write request body for URL_REDACTED'))` Issue caused by the number of multi requests going above the default batch of 5 and the P100 not being able to handle the second multi request happening immediately as it closes the connection after each query (See https://github.com/python-kasa/python-kasa/pull/690 for similar issue). This introduces a small wait time on concurrent requests once the device has raised a ClientOSError. --- kasa/aestransport.py | 3 -- kasa/httpclient.py | 24 ++++++++++ kasa/tests/test_aestransport.py | 80 ++++++++++++++++++++++++++++++++- 3 files changed, 102 insertions(+), 5 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 85624abc..427801e1 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -6,7 +6,6 @@ under compatible GNU GPL3 license. from __future__ import annotations -import asyncio import base64 import hashlib import logging @@ -74,7 +73,6 @@ class AesTransport(BaseTransport): } CONTENT_LENGTH = "Content-Length" KEY_PAIR_CONTENT_LENGTH = 314 - BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1 def __init__( self, @@ -216,7 +214,6 @@ class AesTransport(BaseTransport): self._default_credentials = get_default_credentials( DEFAULT_CREDENTIALS["TAPO"] ) - await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_LOGIN_ERROR) await self.perform_handshake() await self.try_login(self._get_login_params(self._default_credentials)) _LOGGER.debug( diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 55ac5a8e..d1f4936e 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio import logging +import time from typing import Any, Dict import aiohttp @@ -28,12 +29,20 @@ def get_cookie_jar() -> aiohttp.CookieJar: class HttpClient: """HttpClient Class.""" + # Some devices (only P100 so far) close the http connection after each request + # and aiohttp doesn't seem to handle it. If a Client OS error is received the + # http client will start ensuring that sequential requests have a wait delay. + WAIT_BETWEEN_REQUESTS_ON_OSERROR = 0.25 + def __init__(self, config: DeviceConfig) -> None: self._config = config self._client_session: aiohttp.ClientSession = None self._jar = aiohttp.CookieJar(unsafe=True, quote_cookie=False) self._last_url = URL(f"http://{self._config.host}/") + self._wait_between_requests = 0.0 + self._last_request_time = 0.0 + @property def client(self) -> aiohttp.ClientSession: """Return the underlying http client.""" @@ -60,6 +69,14 @@ class HttpClient: If the request is provided via the json parameter json will be returned. """ + # Once we know a device needs a wait between sequential queries always wait + # first rather than keep erroring then waiting. + if self._wait_between_requests: + now = time.time() + gap = now - self._last_request_time + if gap < self._wait_between_requests: + await asyncio.sleep(self._wait_between_requests - gap) + _LOGGER.debug("Posting to %s", url) response_data = None self._last_url = url @@ -89,6 +106,9 @@ class HttpClient: response_data = json_loads(response_data.decode()) except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: + if isinstance(ex, aiohttp.ClientOSError): + self._wait_between_requests = self.WAIT_BETWEEN_REQUESTS_ON_OSERROR + self._last_request_time = time.time() raise _ConnectionError( f"Device connection error: {self._config.host}: {ex}", ex ) from ex @@ -103,6 +123,10 @@ class HttpClient: f"Unable to query the device: {self._config.host}: {ex}", ex ) from ex + # For performance only request system time if waiting is enabled + if self._wait_between_requests: + self._last_request_time = time.time() + return resp.status, response_data def get_cookie(self, cookie_name: str) -> str | None: diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index ffd32cb1..00bcb953 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -24,6 +24,7 @@ from ..exceptions import ( AuthenticationError, KasaException, SmartErrorCode, + _ConnectionError, ) from ..httpclient import HttpClient @@ -137,7 +138,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count): transport._state = TransportState.LOGIN_REQUIRED transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session - mocker.patch.object(transport, "BACKOFF_SECONDS_AFTER_LOGIN_ERROR", 0) + mocker.patch.object(transport._http_client, "WAIT_BETWEEN_REQUESTS_ON_OSERROR", 0) assert transport._token_url is None @@ -285,6 +286,68 @@ async def test_port_override(): assert str(transport._app_url) == "http://127.0.0.1:12345/app" +@pytest.mark.parametrize( + "request_delay, should_error, should_succeed", + [(0, False, True), (0.125, True, True), (0.3, True, True), (0.7, True, False)], + ids=["No error", "Error then succeed", "Two errors then succeed", "No succeed"], +) +async def test_device_closes_connection( + mocker, request_delay, should_error, should_succeed +): + """Test the delay logic in http client to deal with devices that close connections after each request. + + Currently only the P100 on older firmware. + """ + host = "127.0.0.1" + + # Speed up the test by dividing all times by a factor. Doesn't seem to work on windows + # but leaving here as a TODO to manipulate system time for testing. + speed_up_factor = 1 + default_delay = HttpClient.WAIT_BETWEEN_REQUESTS_ON_OSERROR / speed_up_factor + request_delay = request_delay / speed_up_factor + mock_aes_device = MockAesDevice( + host, 200, 0, 0, sequential_request_delay=request_delay + ) + mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) + + config = DeviceConfig(host, credentials=Credentials("foo", "bar")) + transport = AesTransport(config=config) + transport._http_client.WAIT_BETWEEN_REQUESTS_ON_OSERROR = default_delay + transport._state = TransportState.LOGIN_REQUIRED + transport._session_expire_at = time.time() + 86400 + transport._encryption_session = mock_aes_device.encryption_session + transport._token_url = transport._app_url.with_query( + f"token={mock_aes_device.token}" + ) + request = { + "method": "get_device_info", + "params": None, + "request_time_milis": round(time.time() * 1000), + "requestID": 1, + "terminal_uuid": "foobar", + } + error_count = 0 + success = False + + # If the device errors without a delay then it should error immedately ( + 1) + # and then the number of times the default delay passes within the request delay window + expected_error_count = ( + 0 if not should_error else int(request_delay / default_delay) + 1 + ) + for _ in range(3): + try: + await transport.send(json_dumps(request)) + except _ConnectionError: + error_count += 1 + else: + success = True + + assert bool(transport._http_client._wait_between_requests) == should_error + assert bool(error_count) == should_error + assert error_count == expected_error_count + assert success == should_succeed + + class MockAesDevice: class _mock_response: def __init__(self, status, json: dict): @@ -313,6 +376,7 @@ class MockAesDevice: *, do_not_encrypt_response=False, send_response=None, + sequential_request_delay=0, ): self.host = host self.status_code = status_code @@ -323,6 +387,9 @@ class MockAesDevice: self.http_client = HttpClient(DeviceConfig(self.host)) self.inner_call_count = 0 self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311 + self.sequential_request_delay = sequential_request_delay + self.last_request_time = None + self.sequential_error_raised = False @property def inner_error_code(self): @@ -332,10 +399,19 @@ class MockAesDevice: return self._inner_error_code async def post(self, url: URL, params=None, json=None, data=None, *_, **__): + if self.sequential_request_delay and self.last_request_time: + now = time.time() + print(now - self.last_request_time) + if (now - self.last_request_time) < self.sequential_request_delay: + self.sequential_error_raised = True + raise aiohttp.ClientOSError("Test connection closed") if data: async for item in data: json = json_loads(item.decode()) - return await self._post(url, json) + res = await self._post(url, json) + if self.sequential_request_delay: + self.last_request_time = time.time() + return res async def _post(self, url: URL, json: dict[str, Any]): if json["method"] == "handshake":