Fix connection indeterminate state on cancellation (#636)

* Fix connection indeterminate state on cancellation

If the task the query is running in it cancelled, we do
know the state of the connection so we must close. Previously
we would not close on BaseException which could result
in reading the previous response if the previous query was
cancelled after the request had been sent

* add test for cancellation
This commit is contained in:
J. Nick Koston 2024-01-13 07:37:24 -10:00 committed by GitHub
parent 816053fc6e
commit aed67dad16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 102 additions and 3 deletions

View File

@ -200,7 +200,6 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
assert self.writer is not None # noqa: S101
assert self.reader is not None # noqa: S101
debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_log:
_LOGGER.debug("%s >> %s", self._host, request)
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
@ -220,11 +219,17 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
async def close(self) -> None:
"""Close the connection."""
writer = self.writer
self.close_without_wait()
if writer:
with contextlib.suppress(Exception):
await writer.wait_closed()
def close_without_wait(self) -> None:
"""Close the connection without waiting for the connection to close."""
writer = self.writer
self.reader = self.writer = None
if writer:
writer.close()
with contextlib.suppress(Exception):
await writer.wait_closed()
def _reset(self) -> None:
"""Clear any varibles that should not survive between loops."""
@ -266,6 +271,16 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
f" {self._host}:{self._port}: {ex}"
) from ex
continue
except BaseException as ex:
# Likely something cancelled the task so we need to close the connection
# as we are not in an indeterminate state
self.close_without_wait()
_LOGGER.debug(
"%s: BaseException during connect, closing connection: %s",
self._host,
ex,
)
raise
try:
assert self.reader is not None # noqa: S101
@ -283,6 +298,16 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
_LOGGER.debug(
"Unable to query the device %s, retrying: %s", self._host, ex
)
except BaseException as ex:
# Likely something cancelled the task so we need to close the connection
# as we are not in an indeterminate state
self.close_without_wait()
_LOGGER.debug(
"%s: BaseException during query, closing connection: %s",
self._host,
ex,
)
raise
# make mypy happy, this should never be reached..
await self.close()

View File

@ -1,3 +1,4 @@
import asyncio
import errno
import importlib
import inspect
@ -122,6 +123,79 @@ async def test_protocol_reconnect(mocker, retry_count):
assert response == {"great": "success"}
async def test_protocol_handles_cancellation_during_write(mocker):
attempts = 0
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE :
]
def _cancel_first_attempt(*_):
nonlocal attempts
attempts += 1
if attempts == 1:
raise asyncio.CancelledError("Simulated task cancel")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _cancel_first_attempt)
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(asyncio.CancelledError):
await protocol.query({})
assert protocol.writer is None
response = await protocol.query({})
assert response == {"great": "success"}
async def test_protocol_handles_cancellation_during_connection(mocker):
attempts = 0
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
nonlocal attempts
attempts += 1
if attempts == 1:
raise asyncio.CancelledError("Simulated task cancel")
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(asyncio.CancelledError):
await protocol.query({})
assert protocol.writer is None
response = await protocol.query({})
assert response == {"great": "success"}
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
async def test_protocol_logging(mocker, caplog, log_level):
caplog.set_level(log_level)