Backoff after xor timeout and improve error reporting (#1424)

This commit is contained in:
J. Nick Koston
2025-01-06 04:00:23 -10:00
committed by GitHub
parent 48a07a2970
commit 7d508b5092
3 changed files with 232 additions and 1 deletions

View File

@@ -16,7 +16,7 @@ import pytest
from kasa.credentials import Credentials
from kasa.device import Device
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import KasaException
from kasa.exceptions import KasaException, TimeoutError
from kasa.iot import IotDevice
from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol
from kasa.protocols.protocol import (
@@ -294,6 +294,210 @@ async def test_protocol_handles_cancellation_during_connection(
assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_during_write(
mocker, protocol_class, transport_class, encryption_class
):
attempts = 0
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
def _timeout_first_attempt(*_):
nonlocal attempts
attempts += 1
if attempts == 1:
raise TimeoutError("Simulated timeout")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _timeout_first_attempt)
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is not None
response = await protocol.query({})
assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_during_connection(
mocker, protocol_class, transport_class, encryption_class
):
attempts = 0
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
nonlocal attempts
attempts += 1
if attempts == 1:
raise TimeoutError("Simulated timeout")
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
await writer_obj.close()
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
await protocol.query({"any": "thing"})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is not None
response = await protocol.query({})
assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_failure_during_write(
mocker, protocol_class, transport_class, encryption_class
):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
def _timeout_all_attempts(*_):
raise TimeoutError("Simulated timeout")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _timeout_all_attempts)
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(
TimeoutError,
match="Timeout after 5 seconds sending request to the device 127.0.0.1:9999: Simulated timeout",
):
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_failure_during_connection(
mocker, protocol_class, transport_class, encryption_class
):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
raise TimeoutError("Simulated timeout")
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
await writer_obj.close()
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(
TimeoutError,
match="Timeout after 5 seconds connecting to the device: 127.0.0.1:9999: Simulated timeout",
):
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[