From aed67dad16fb59c111d14558b7786d72979979f6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 13 Jan 2024 07:37:24 -1000 Subject: [PATCH] Fix connection indeterminate state on cancellation (#636) * Fix connection indeterminate state on cancellation If the task the query is running in is cancelled, we do not know the state of the connection so we must close. Previously we would not close on BaseException which could result in reading the previous response if the previous query was cancelled after the request had been sent * add test for cancellation --- kasa/protocol.py | 31 ++++++++++++++-- kasa/tests/test_protocol.py | 74 +++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/kasa/protocol.py b/kasa/protocol.py index e86c07aa..74023e01 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -200,7 +200,6 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): assert self.writer is not None # noqa: S101 assert self.reader is not None # noqa: S101 debug_log = _LOGGER.isEnabledFor(logging.DEBUG) - if debug_log: _LOGGER.debug("%s >> %s", self._host, request) self.writer.write(TPLinkSmartHomeProtocol.encrypt(request)) @@ -220,11 +219,17 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): async def close(self) -> None: """Close the connection.""" writer = self.writer + self.close_without_wait() + if writer: + with contextlib.suppress(Exception): + await writer.wait_closed() + + def close_without_wait(self) -> None: + """Close the connection without waiting for the connection to close.""" + writer = self.writer self.reader = self.writer = None if writer: writer.close() - with contextlib.suppress(Exception): - await writer.wait_closed() def _reset(self) -> None: """Clear any varibles that should not survive between loops.""" @@ -266,6 +271,16 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): f" {self._host}:{self._port}: {ex}" ) from ex continue + except BaseException as ex: + # Likely something cancelled the task so we need 
to close the connection + # as we are now in an indeterminate state + self.close_without_wait() + _LOGGER.debug( + "%s: BaseException during connect, closing connection: %s", + self._host, + ex, + ) + raise try: assert self.reader is not None # noqa: S101 @@ -283,6 +298,16 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): _LOGGER.debug( "Unable to query the device %s, retrying: %s", self._host, ex ) + except BaseException as ex: + # Likely something cancelled the task so we need to close the connection + # as we are now in an indeterminate state + self.close_without_wait() + _LOGGER.debug( + "%s: BaseException during query, closing connection: %s", + self._host, + ex, + ) + raise # make mypy happy, this should never be reached.. await self.close() diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 05ae40f3..563b8176 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -1,3 +1,4 @@ +import asyncio import errno import importlib import inspect @@ -122,6 +123,79 @@ async def test_protocol_reconnect(mocker, retry_count): assert response == {"great": "success"} +async def test_protocol_handles_cancellation_during_write(mocker): + attempts = 0 + encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ + TPLinkSmartHomeProtocol.BLOCK_SIZE : + ] + + def _cancel_first_attempt(*_): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise asyncio.CancelledError("Simulated task cancel") + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + mocker.patch.object(writer, "write", _cancel_first_attempt) + mocker.patch.object(reader, "readexactly", _mock_read) + return reader, writer + 
+ config = DeviceConfig("127.0.0.1") + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + with pytest.raises(asyncio.CancelledError): + await protocol.query({}) + assert protocol.writer is None + response = await protocol.query({}) + assert response == {"great": "success"} + + +async def test_protocol_handles_cancellation_during_connection(mocker): + attempts = 0 + encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ + TPLinkSmartHomeProtocol.BLOCK_SIZE : + ] + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise asyncio.CancelledError("Simulated task cancel") + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + mocker.patch.object(reader, "readexactly", _mock_read) + return reader, writer + + config = DeviceConfig("127.0.0.1") + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + with pytest.raises(asyncio.CancelledError): + await protocol.query({}) + assert protocol.writer is None + response = await protocol.query({}) + assert response == {"great": "success"} + + @pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG]) async def test_protocol_logging(mocker, caplog, log_level): caplog.set_level(log_level)