Backoff after xor timeout and improve error reporting (#1424)

This commit is contained in:
J. Nick Koston 2025-01-06 04:00:23 -10:00 committed by GitHub
parent 48a07a2970
commit 7d508b5092
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 232 additions and 1 deletions

View File

@ -98,12 +98,26 @@ class IotProtocol(BaseProtocol):
)
raise auex
except _RetryableError as ex:
if retry == 0:
_LOGGER.debug(
"Device %s got a retryable error, will retry %s times: %s",
self._host,
retry_count,
ex,
)
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutError as ex:
if retry == 0:
_LOGGER.debug(
"Device %s got a timeout error, will retry %s times: %s",
self._host,
retry_count,
ex,
)
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)

View File

@ -23,6 +23,7 @@ from collections.abc import Generator
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import KasaException, _RetryableError
from kasa.exceptions import TimeoutError as KasaTimeoutError
from kasa.json import loads as json_loads
from .basetransport import BaseTransport
@ -126,6 +127,12 @@ class XorTransport(BaseTransport):
# This is especially import when there are multiple tplink devices being polled.
try:
await self._connect(self._timeout)
except TimeoutError as ex:
await self.reset()
raise KasaTimeoutError(
f"Timeout after {self._timeout} seconds connecting to the device:"
f" {self._host}:{self._port}: {ex}"
) from ex
except ConnectionRefusedError as ex:
await self.reset()
raise KasaException(
@ -159,6 +166,12 @@ class XorTransport(BaseTransport):
assert self.writer is not None # noqa: S101
async with asyncio_timeout(self._timeout):
return await self._execute_send(request)
except TimeoutError as ex:
await self.reset()
raise KasaTimeoutError(
f"Timeout after {self._timeout} seconds sending request to the device"
f" {self._host}:{self._port}: {ex}"
) from ex
except Exception as ex:
await self.reset()
raise _RetryableError(

View File

@ -16,7 +16,7 @@ import pytest
from kasa.credentials import Credentials
from kasa.device import Device
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import KasaException
from kasa.exceptions import KasaException, TimeoutError
from kasa.iot import IotDevice
from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol
from kasa.protocols.protocol import (
@ -294,6 +294,210 @@ async def test_protocol_handles_cancellation_during_connection(
assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_during_write(
mocker, protocol_class, transport_class, encryption_class
):
attempts = 0
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
def _timeout_first_attempt(*_):
nonlocal attempts
attempts += 1
if attempts == 1:
raise TimeoutError("Simulated timeout")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _timeout_first_attempt)
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is not None
response = await protocol.query({})
assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_during_connection(
mocker, protocol_class, transport_class, encryption_class
):
attempts = 0
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
nonlocal attempts
attempts += 1
if attempts == 1:
raise TimeoutError("Simulated timeout")
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
await writer_obj.close()
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
await protocol.query({"any": "thing"})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is not None
response = await protocol.query({})
assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_failure_during_write(
mocker, protocol_class, transport_class, encryption_class
):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
def _timeout_all_attempts(*_):
raise TimeoutError("Simulated timeout")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _timeout_all_attempts)
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(
TimeoutError,
match="Timeout after 5 seconds sending request to the device 127.0.0.1:9999: Simulated timeout",
):
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_failure_during_connection(
mocker, protocol_class, transport_class, encryption_class
):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
raise TimeoutError("Simulated timeout")
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
await writer_obj.close()
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(
TimeoutError,
match="Timeout after 5 seconds connecting to the device: 127.0.0.1:9999: Simulated timeout",
):
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[