diff --git a/kasa/protocol.py b/kasa/protocol.py index e86c07aa..74023e01 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -200,7 +200,6 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): assert self.writer is not None # noqa: S101 assert self.reader is not None # noqa: S101 debug_log = _LOGGER.isEnabledFor(logging.DEBUG) - if debug_log: _LOGGER.debug("%s >> %s", self._host, request) self.writer.write(TPLinkSmartHomeProtocol.encrypt(request)) @@ -220,11 +219,17 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): async def close(self) -> None: """Close the connection.""" writer = self.writer + self.close_without_wait() + if writer: + with contextlib.suppress(Exception): + await writer.wait_closed() + + def close_without_wait(self) -> None: + """Close the connection without waiting for the connection to close.""" + writer = self.writer self.reader = self.writer = None if writer: writer.close() - with contextlib.suppress(Exception): - await writer.wait_closed() def _reset(self) -> None: """Clear any varibles that should not survive between loops.""" @@ -266,6 +271,16 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): f" {self._host}:{self._port}: {ex}" ) from ex continue + except BaseException as ex: + # Likely something cancelled the task so we need to close the connection + # as we are not in an indeterminate state + self.close_without_wait() + _LOGGER.debug( + "%s: BaseException during connect, closing connection: %s", + self._host, + ex, + ) + raise try: assert self.reader is not None # noqa: S101 @@ -283,6 +298,16 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): _LOGGER.debug( "Unable to query the device %s, retrying: %s", self._host, ex ) + except BaseException as ex: + # Likely something cancelled the task so we need to close the connection + # as we are not in an indeterminate state + self.close_without_wait() + _LOGGER.debug( + "%s: BaseException during query, closing connection: %s", + self._host, + ex, + ) + raise # make mypy happy, this should never be reached.. await self.close() diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 05ae40f3..563b8176 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -1,3 +1,4 @@ +import asyncio import errno import importlib import inspect @@ -122,6 +123,79 @@ async def test_protocol_reconnect(mocker, retry_count): assert response == {"great": "success"} +async def test_protocol_handles_cancellation_during_write(mocker): + attempts = 0 + encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ + TPLinkSmartHomeProtocol.BLOCK_SIZE : + ] + + def _cancel_first_attempt(*_): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise asyncio.CancelledError("Simulated task cancel") + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + mocker.patch.object(writer, "write", _cancel_first_attempt) + mocker.patch.object(reader, "readexactly", _mock_read) + return reader, writer + + config = DeviceConfig("127.0.0.1") + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + with pytest.raises(asyncio.CancelledError): + await protocol.query({}) + assert protocol.writer is None + response = await protocol.query({}) + assert response == {"great": "success"} + + +async def test_protocol_handles_cancellation_during_connection(mocker): + attempts = 0 + encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ + TPLinkSmartHomeProtocol.BLOCK_SIZE : + ] + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise asyncio.CancelledError("Simulated task cancel") + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + mocker.patch.object(reader, "readexactly", _mock_read) + return reader, writer + + config = DeviceConfig("127.0.0.1") + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + with pytest.raises(asyncio.CancelledError): + await protocol.query({}) + assert protocol.writer is None + response = await protocol.query({}) + assert response == {"great": "success"} + + @pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG]) async def test_protocol_logging(mocker, caplog, log_level): caplog.set_level(log_level)