Fix P100 errors on multi-requests (#930)

Fixes an issue reported by @bdraco with the P100 not working in the
latest branch:

`[Errno None] Can not write request body for HOST_REDACTED,
ClientOSError(None, 'Can not write request body for
URL_REDACTED'))`

Issue caused by the number of multi requests going above the default
batch of 5 and the P100 not being able to handle the second multi
request happening immediately as it closes the connection after each
query (See https://github.com/python-kasa/python-kasa/pull/690 for
similar issue). This introduces a small wait time on concurrent requests
once the device has raised a ClientOSError.
This commit is contained in:
Steven B 2024-06-04 20:49:01 +03:00 committed by GitHub
parent 40f2263770
commit 91de5e20ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 102 additions and 5 deletions

View File

@ -6,7 +6,6 @@ under compatible GNU GPL3 license.
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
@ -74,7 +73,6 @@ class AesTransport(BaseTransport):
}
CONTENT_LENGTH = "Content-Length"
KEY_PAIR_CONTENT_LENGTH = 314
BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1
def __init__(
self,
@ -216,7 +214,6 @@ class AesTransport(BaseTransport):
self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"]
)
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_LOGIN_ERROR)
await self.perform_handshake()
await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug(

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import logging
import time
from typing import Any, Dict
import aiohttp
@ -28,12 +29,20 @@ def get_cookie_jar() -> aiohttp.CookieJar:
class HttpClient:
"""HttpClient Class."""
# Some devices (only P100 so far) close the http connection after each request
# and aiohttp doesn't seem to handle it. If a Client OS error is received the
# http client will start ensuring that sequential requests have a wait delay.
WAIT_BETWEEN_REQUESTS_ON_OSERROR = 0.25
def __init__(self, config: DeviceConfig) -> None:
self._config = config
self._client_session: aiohttp.ClientSession = None
self._jar = aiohttp.CookieJar(unsafe=True, quote_cookie=False)
self._last_url = URL(f"http://{self._config.host}/")
self._wait_between_requests = 0.0
self._last_request_time = 0.0
@property
def client(self) -> aiohttp.ClientSession:
"""Return the underlying http client."""
@ -60,6 +69,14 @@ class HttpClient:
If the request is provided via the json parameter json will be returned.
"""
# Once we know a device needs a wait between sequential queries always wait
# first rather than keep erroring then waiting.
if self._wait_between_requests:
now = time.time()
gap = now - self._last_request_time
if gap < self._wait_between_requests:
await asyncio.sleep(self._wait_between_requests - gap)
_LOGGER.debug("Posting to %s", url)
response_data = None
self._last_url = url
@ -89,6 +106,9 @@ class HttpClient:
response_data = json_loads(response_data.decode())
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
if isinstance(ex, aiohttp.ClientOSError):
self._wait_between_requests = self.WAIT_BETWEEN_REQUESTS_ON_OSERROR
self._last_request_time = time.time()
raise _ConnectionError(
f"Device connection error: {self._config.host}: {ex}", ex
) from ex
@ -103,6 +123,10 @@ class HttpClient:
f"Unable to query the device: {self._config.host}: {ex}", ex
) from ex
# For performance only request system time if waiting is enabled
if self._wait_between_requests:
self._last_request_time = time.time()
return resp.status, response_data
def get_cookie(self, cookie_name: str) -> str | None:

View File

@ -24,6 +24,7 @@ from ..exceptions import (
AuthenticationError,
KasaException,
SmartErrorCode,
_ConnectionError,
)
from ..httpclient import HttpClient
@ -137,7 +138,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
mocker.patch.object(transport, "BACKOFF_SECONDS_AFTER_LOGIN_ERROR", 0)
mocker.patch.object(transport._http_client, "WAIT_BETWEEN_REQUESTS_ON_OSERROR", 0)
assert transport._token_url is None
@ -285,6 +286,68 @@ async def test_port_override():
assert str(transport._app_url) == "http://127.0.0.1:12345/app"
@pytest.mark.parametrize(
"request_delay, should_error, should_succeed",
[(0, False, True), (0.125, True, True), (0.3, True, True), (0.7, True, False)],
ids=["No error", "Error then succeed", "Two errors then succeed", "No succeed"],
)
async def test_device_closes_connection(
mocker, request_delay, should_error, should_succeed
):
"""Test the delay logic in http client to deal with devices that close connections after each request.
Currently only the P100 on older firmware.
"""
host = "127.0.0.1"
# Speed up the test by dividing all times by a factor. Doesn't seem to work on windows
# but leaving here as a TODO to manipulate system time for testing.
speed_up_factor = 1
default_delay = HttpClient.WAIT_BETWEEN_REQUESTS_ON_OSERROR / speed_up_factor
request_delay = request_delay / speed_up_factor
mock_aes_device = MockAesDevice(
host, 200, 0, 0, sequential_request_delay=request_delay
)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
transport = AesTransport(config=config)
transport._http_client.WAIT_BETWEEN_REQUESTS_ON_OSERROR = default_delay
transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
error_count = 0
success = False
# If the device errors without a delay then it should error immedately ( + 1)
# and then the number of times the default delay passes within the request delay window
expected_error_count = (
0 if not should_error else int(request_delay / default_delay) + 1
)
for _ in range(3):
try:
await transport.send(json_dumps(request))
except _ConnectionError:
error_count += 1
else:
success = True
assert bool(transport._http_client._wait_between_requests) == should_error
assert bool(error_count) == should_error
assert error_count == expected_error_count
assert success == should_succeed
class MockAesDevice:
class _mock_response:
def __init__(self, status, json: dict):
@ -313,6 +376,7 @@ class MockAesDevice:
*,
do_not_encrypt_response=False,
send_response=None,
sequential_request_delay=0,
):
self.host = host
self.status_code = status_code
@ -323,6 +387,9 @@ class MockAesDevice:
self.http_client = HttpClient(DeviceConfig(self.host))
self.inner_call_count = 0
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
self.sequential_request_delay = sequential_request_delay
self.last_request_time = None
self.sequential_error_raised = False
@property
def inner_error_code(self):
@ -332,10 +399,19 @@ class MockAesDevice:
return self._inner_error_code
async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
if self.sequential_request_delay and self.last_request_time:
now = time.time()
print(now - self.last_request_time)
if (now - self.last_request_time) < self.sequential_request_delay:
self.sequential_error_raised = True
raise aiohttp.ClientOSError("Test connection closed")
if data:
async for item in data:
json = json_loads(item.decode())
return await self._post(url, json)
res = await self._post(url, json)
if self.sequential_request_delay:
self.last_request_time = time.time()
return res
async def _post(self, url: URL, json: dict[str, Any]):
if json["method"] == "handshake":