Fix connection indeterminate state on cancellation (#636)

* Fix connection indeterminate state on cancellation

If the task the query is running in it cancelled, we do
know the state of the connection so we must close. Previously
we would not close on BaseException which could result
in reading the previous response if the previous query was
cancelled after the request had been sent

* add test for cancellation
This commit is contained in:
J. Nick Koston
2024-01-13 07:37:24 -10:00
committed by GitHub
parent 816053fc6e
commit aed67dad16
2 changed files with 102 additions and 3 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import errno
import importlib
import inspect
@@ -122,6 +123,79 @@ async def test_protocol_reconnect(mocker, retry_count):
assert response == {"great": "success"}
async def test_protocol_handles_cancellation_during_write(mocker):
attempts = 0
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE :
]
def _cancel_first_attempt(*_):
nonlocal attempts
attempts += 1
if attempts == 1:
raise asyncio.CancelledError("Simulated task cancel")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _cancel_first_attempt)
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(asyncio.CancelledError):
await protocol.query({})
assert protocol.writer is None
response = await protocol.query({})
assert response == {"great": "success"}
async def test_protocol_handles_cancellation_during_connection(mocker):
attempts = 0
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
nonlocal attempts
attempts += 1
if attempts == 1:
raise asyncio.CancelledError("Simulated task cancel")
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(asyncio.CancelledError):
await protocol.query({})
assert protocol.writer is None
response = await protocol.query({})
assert response == {"great": "success"}
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
async def test_protocol_logging(mocker, caplog, log_level):
caplog.set_level(log_level)