Backoff after xor timeout and improve error reporting (#1424)

This commit is contained in:
J. Nick Koston 2025-01-06 04:00:23 -10:00 committed by GitHub
parent 48a07a2970
commit 7d508b5092
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 232 additions and 1 deletions

View File

@ -98,12 +98,26 @@ class IotProtocol(BaseProtocol):
) )
raise auex raise auex
except _RetryableError as ex: except _RetryableError as ex:
if retry == 0:
_LOGGER.debug(
"Device %s got a retryable error, will retry %s times: %s",
self._host,
retry_count,
ex,
)
await self._transport.reset() await self._transport.reset()
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex raise ex
continue continue
except TimeoutError as ex: except TimeoutError as ex:
if retry == 0:
_LOGGER.debug(
"Device %s got a timeout error, will retry %s times: %s",
self._host,
retry_count,
ex,
)
await self._transport.reset() await self._transport.reset()
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)

View File

@ -23,6 +23,7 @@ from collections.abc import Generator
from kasa.deviceconfig import DeviceConfig from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import KasaException, _RetryableError from kasa.exceptions import KasaException, _RetryableError
from kasa.exceptions import TimeoutError as KasaTimeoutError
from kasa.json import loads as json_loads from kasa.json import loads as json_loads
from .basetransport import BaseTransport from .basetransport import BaseTransport
@ -126,6 +127,12 @@ class XorTransport(BaseTransport):
# This is especially import when there are multiple tplink devices being polled. # This is especially import when there are multiple tplink devices being polled.
try: try:
await self._connect(self._timeout) await self._connect(self._timeout)
except TimeoutError as ex:
await self.reset()
raise KasaTimeoutError(
f"Timeout after {self._timeout} seconds connecting to the device:"
f" {self._host}:{self._port}: {ex}"
) from ex
except ConnectionRefusedError as ex: except ConnectionRefusedError as ex:
await self.reset() await self.reset()
raise KasaException( raise KasaException(
@ -159,6 +166,12 @@ class XorTransport(BaseTransport):
assert self.writer is not None # noqa: S101 assert self.writer is not None # noqa: S101
async with asyncio_timeout(self._timeout): async with asyncio_timeout(self._timeout):
return await self._execute_send(request) return await self._execute_send(request)
except TimeoutError as ex:
await self.reset()
raise KasaTimeoutError(
f"Timeout after {self._timeout} seconds sending request to the device"
f" {self._host}:{self._port}: {ex}"
) from ex
except Exception as ex: except Exception as ex:
await self.reset() await self.reset()
raise _RetryableError( raise _RetryableError(

View File

@ -16,7 +16,7 @@ import pytest
from kasa.credentials import Credentials from kasa.credentials import Credentials
from kasa.device import Device from kasa.device import Device
from kasa.deviceconfig import DeviceConfig from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import KasaException from kasa.exceptions import KasaException, TimeoutError
from kasa.iot import IotDevice from kasa.iot import IotDevice
from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol
from kasa.protocols.protocol import ( from kasa.protocols.protocol import (
@ -294,6 +294,210 @@ async def test_protocol_handles_cancellation_during_connection(
assert response == {"great": "success"} assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_during_write(
mocker, protocol_class, transport_class, encryption_class
):
attempts = 0
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
def _timeout_first_attempt(*_):
nonlocal attempts
attempts += 1
if attempts == 1:
raise TimeoutError("Simulated timeout")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _timeout_first_attempt)
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is not None
response = await protocol.query({})
assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_during_connection(
mocker, protocol_class, transport_class, encryption_class
):
attempts = 0
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
nonlocal attempts
attempts += 1
if attempts == 1:
raise TimeoutError("Simulated timeout")
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
await writer_obj.close()
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
await protocol.query({"any": "thing"})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is not None
response = await protocol.query({})
assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_failure_during_write(
mocker, protocol_class, transport_class, encryption_class
):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
def _timeout_all_attempts(*_):
raise TimeoutError("Simulated timeout")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _timeout_all_attempts)
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(
TimeoutError,
match="Timeout after 5 seconds sending request to the device 127.0.0.1:9999: Simulated timeout",
):
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_failure_during_connection(
mocker, protocol_class, transport_class, encryption_class
):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
raise TimeoutError("Simulated timeout")
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
await writer_obj.close()
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(
TimeoutError,
match="Timeout after 5 seconds connecting to the device: 127.0.0.1:9999: Simulated timeout",
):
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
@pytest.mark.parametrize( @pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"), ("protocol_class", "transport_class", "encryption_class"),
[ [