mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-04-26 16:46:23 +00:00
Fix P100 errors on multi-requests (#930)
Fixes an issue reported by @bdraco with the P100 not working in the latest branch: `[Errno None] Can not write request body for HOST_REDACTED, ClientOSError(None, 'Can not write request body for URL_REDACTED'))` Issue caused by the number of multi requests going above the default batch of 5 and the P100 not being able to handle the second multi request happening immediately as it closes the connection after each query (See https://github.com/python-kasa/python-kasa/pull/690 for similar issue). This introduces a small wait time on concurrent requests once the device has raised a ClientOSError.
This commit is contained in:
parent
40f2263770
commit
91de5e20ba
@ -6,7 +6,6 @@ under compatible GNU GPL3 license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
@ -74,7 +73,6 @@ class AesTransport(BaseTransport):
|
||||
}
|
||||
CONTENT_LENGTH = "Content-Length"
|
||||
KEY_PAIR_CONTENT_LENGTH = 314
|
||||
BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -216,7 +214,6 @@ class AesTransport(BaseTransport):
|
||||
self._default_credentials = get_default_credentials(
|
||||
DEFAULT_CREDENTIALS["TAPO"]
|
||||
)
|
||||
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_LOGIN_ERROR)
|
||||
await self.perform_handshake()
|
||||
await self.try_login(self._get_login_params(self._default_credentials))
|
||||
_LOGGER.debug(
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
@ -28,12 +29,20 @@ def get_cookie_jar() -> aiohttp.CookieJar:
|
||||
class HttpClient:
|
||||
"""HttpClient Class."""
|
||||
|
||||
# Some devices (only P100 so far) close the http connection after each request
|
||||
# and aiohttp doesn't seem to handle it. If a Client OS error is received the
|
||||
# http client will start ensuring that sequential requests have a wait delay.
|
||||
WAIT_BETWEEN_REQUESTS_ON_OSERROR = 0.25
|
||||
|
||||
def __init__(self, config: DeviceConfig) -> None:
|
||||
self._config = config
|
||||
self._client_session: aiohttp.ClientSession = None
|
||||
self._jar = aiohttp.CookieJar(unsafe=True, quote_cookie=False)
|
||||
self._last_url = URL(f"http://{self._config.host}/")
|
||||
|
||||
self._wait_between_requests = 0.0
|
||||
self._last_request_time = 0.0
|
||||
|
||||
@property
|
||||
def client(self) -> aiohttp.ClientSession:
|
||||
"""Return the underlying http client."""
|
||||
@ -60,6 +69,14 @@ class HttpClient:
|
||||
|
||||
If the request is provided via the json parameter json will be returned.
|
||||
"""
|
||||
# Once we know a device needs a wait between sequential queries always wait
|
||||
# first rather than keep erroring then waiting.
|
||||
if self._wait_between_requests:
|
||||
now = time.time()
|
||||
gap = now - self._last_request_time
|
||||
if gap < self._wait_between_requests:
|
||||
await asyncio.sleep(self._wait_between_requests - gap)
|
||||
|
||||
_LOGGER.debug("Posting to %s", url)
|
||||
response_data = None
|
||||
self._last_url = url
|
||||
@ -89,6 +106,9 @@ class HttpClient:
|
||||
response_data = json_loads(response_data.decode())
|
||||
|
||||
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
|
||||
if isinstance(ex, aiohttp.ClientOSError):
|
||||
self._wait_between_requests = self.WAIT_BETWEEN_REQUESTS_ON_OSERROR
|
||||
self._last_request_time = time.time()
|
||||
raise _ConnectionError(
|
||||
f"Device connection error: {self._config.host}: {ex}", ex
|
||||
) from ex
|
||||
@ -103,6 +123,10 @@ class HttpClient:
|
||||
f"Unable to query the device: {self._config.host}: {ex}", ex
|
||||
) from ex
|
||||
|
||||
# For performance only request system time if waiting is enabled
|
||||
if self._wait_between_requests:
|
||||
self._last_request_time = time.time()
|
||||
|
||||
return resp.status, response_data
|
||||
|
||||
def get_cookie(self, cookie_name: str) -> str | None:
|
||||
|
@ -24,6 +24,7 @@ from ..exceptions import (
|
||||
AuthenticationError,
|
||||
KasaException,
|
||||
SmartErrorCode,
|
||||
_ConnectionError,
|
||||
)
|
||||
from ..httpclient import HttpClient
|
||||
|
||||
@ -137,7 +138,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
|
||||
transport._state = TransportState.LOGIN_REQUIRED
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
mocker.patch.object(transport, "BACKOFF_SECONDS_AFTER_LOGIN_ERROR", 0)
|
||||
mocker.patch.object(transport._http_client, "WAIT_BETWEEN_REQUESTS_ON_OSERROR", 0)
|
||||
|
||||
assert transport._token_url is None
|
||||
|
||||
@ -285,6 +286,68 @@ async def test_port_override():
|
||||
assert str(transport._app_url) == "http://127.0.0.1:12345/app"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"request_delay, should_error, should_succeed",
|
||||
[(0, False, True), (0.125, True, True), (0.3, True, True), (0.7, True, False)],
|
||||
ids=["No error", "Error then succeed", "Two errors then succeed", "No succeed"],
|
||||
)
|
||||
async def test_device_closes_connection(
|
||||
mocker, request_delay, should_error, should_succeed
|
||||
):
|
||||
"""Test the delay logic in http client to deal with devices that close connections after each request.
|
||||
|
||||
Currently only the P100 on older firmware.
|
||||
"""
|
||||
host = "127.0.0.1"
|
||||
|
||||
# Speed up the test by dividing all times by a factor. Doesn't seem to work on windows
|
||||
# but leaving here as a TODO to manipulate system time for testing.
|
||||
speed_up_factor = 1
|
||||
default_delay = HttpClient.WAIT_BETWEEN_REQUESTS_ON_OSERROR / speed_up_factor
|
||||
request_delay = request_delay / speed_up_factor
|
||||
mock_aes_device = MockAesDevice(
|
||||
host, 200, 0, 0, sequential_request_delay=request_delay
|
||||
)
|
||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
transport = AesTransport(config=config)
|
||||
transport._http_client.WAIT_BETWEEN_REQUESTS_ON_OSERROR = default_delay
|
||||
transport._state = TransportState.LOGIN_REQUIRED
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
transport._token_url = transport._app_url.with_query(
|
||||
f"token={mock_aes_device.token}"
|
||||
)
|
||||
request = {
|
||||
"method": "get_device_info",
|
||||
"params": None,
|
||||
"request_time_milis": round(time.time() * 1000),
|
||||
"requestID": 1,
|
||||
"terminal_uuid": "foobar",
|
||||
}
|
||||
error_count = 0
|
||||
success = False
|
||||
|
||||
# If the device errors without a delay then it should error immedately ( + 1)
|
||||
# and then the number of times the default delay passes within the request delay window
|
||||
expected_error_count = (
|
||||
0 if not should_error else int(request_delay / default_delay) + 1
|
||||
)
|
||||
for _ in range(3):
|
||||
try:
|
||||
await transport.send(json_dumps(request))
|
||||
except _ConnectionError:
|
||||
error_count += 1
|
||||
else:
|
||||
success = True
|
||||
|
||||
assert bool(transport._http_client._wait_between_requests) == should_error
|
||||
assert bool(error_count) == should_error
|
||||
assert error_count == expected_error_count
|
||||
assert success == should_succeed
|
||||
|
||||
|
||||
class MockAesDevice:
|
||||
class _mock_response:
|
||||
def __init__(self, status, json: dict):
|
||||
@ -313,6 +376,7 @@ class MockAesDevice:
|
||||
*,
|
||||
do_not_encrypt_response=False,
|
||||
send_response=None,
|
||||
sequential_request_delay=0,
|
||||
):
|
||||
self.host = host
|
||||
self.status_code = status_code
|
||||
@ -323,6 +387,9 @@ class MockAesDevice:
|
||||
self.http_client = HttpClient(DeviceConfig(self.host))
|
||||
self.inner_call_count = 0
|
||||
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
|
||||
self.sequential_request_delay = sequential_request_delay
|
||||
self.last_request_time = None
|
||||
self.sequential_error_raised = False
|
||||
|
||||
@property
|
||||
def inner_error_code(self):
|
||||
@ -332,10 +399,19 @@ class MockAesDevice:
|
||||
return self._inner_error_code
|
||||
|
||||
async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
|
||||
if self.sequential_request_delay and self.last_request_time:
|
||||
now = time.time()
|
||||
print(now - self.last_request_time)
|
||||
if (now - self.last_request_time) < self.sequential_request_delay:
|
||||
self.sequential_error_raised = True
|
||||
raise aiohttp.ClientOSError("Test connection closed")
|
||||
if data:
|
||||
async for item in data:
|
||||
json = json_loads(item.decode())
|
||||
return await self._post(url, json)
|
||||
res = await self._post(url, json)
|
||||
if self.sequential_request_delay:
|
||||
self.last_request_time = time.time()
|
||||
return res
|
||||
|
||||
async def _post(self, url: URL, json: dict[str, Any]):
|
||||
if json["method"] == "handshake":
|
||||
|
Loading…
x
Reference in New Issue
Block a user