mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-04-26 16:46:23 +00:00
Fix connection indeterminate state on cancellation (#636)
* Fix connection indeterminate state on cancellation If the task the query is running in it cancelled, we do know the state of the connection so we must close. Previously we would not close on BaseException which could result in reading the previous response if the previous query was cancelled after the request had been sent * add test for cancellation
This commit is contained in:
parent
816053fc6e
commit
aed67dad16
@ -200,7 +200,6 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
assert self.writer is not None # noqa: S101
|
||||
assert self.reader is not None # noqa: S101
|
||||
debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
|
||||
if debug_log:
|
||||
_LOGGER.debug("%s >> %s", self._host, request)
|
||||
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||
@ -220,11 +219,17 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
async def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
writer = self.writer
|
||||
self.close_without_wait()
|
||||
if writer:
|
||||
with contextlib.suppress(Exception):
|
||||
await writer.wait_closed()
|
||||
|
||||
def close_without_wait(self) -> None:
|
||||
"""Close the connection without waiting for the connection to close."""
|
||||
writer = self.writer
|
||||
self.reader = self.writer = None
|
||||
if writer:
|
||||
writer.close()
|
||||
with contextlib.suppress(Exception):
|
||||
await writer.wait_closed()
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Clear any varibles that should not survive between loops."""
|
||||
@ -266,6 +271,16 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
f" {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
except BaseException as ex:
|
||||
# Likely something cancelled the task so we need to close the connection
|
||||
# as we are not in an indeterminate state
|
||||
self.close_without_wait()
|
||||
_LOGGER.debug(
|
||||
"%s: BaseException during connect, closing connection: %s",
|
||||
self._host,
|
||||
ex,
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
assert self.reader is not None # noqa: S101
|
||||
@ -283,6 +298,16 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
_LOGGER.debug(
|
||||
"Unable to query the device %s, retrying: %s", self._host, ex
|
||||
)
|
||||
except BaseException as ex:
|
||||
# Likely something cancelled the task so we need to close the connection
|
||||
# as we are not in an indeterminate state
|
||||
self.close_without_wait()
|
||||
_LOGGER.debug(
|
||||
"%s: BaseException during query, closing connection: %s",
|
||||
self._host,
|
||||
ex,
|
||||
)
|
||||
raise
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
await self.close()
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import errno
|
||||
import importlib
|
||||
import inspect
|
||||
@ -122,6 +123,79 @@ async def test_protocol_reconnect(mocker, retry_count):
|
||||
assert response == {"great": "success"}
|
||||
|
||||
|
||||
async def test_protocol_handles_cancellation_during_write(mocker):
|
||||
attempts = 0
|
||||
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
|
||||
TPLinkSmartHomeProtocol.BLOCK_SIZE :
|
||||
]
|
||||
|
||||
def _cancel_first_attempt(*_):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise asyncio.CancelledError("Simulated task cancel")
|
||||
|
||||
async def _mock_read(byte_count):
|
||||
nonlocal encrypted
|
||||
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
|
||||
return struct.pack(">I", len(encrypted))
|
||||
if byte_count == len(encrypted):
|
||||
return encrypted
|
||||
|
||||
raise ValueError(f"No mock for {byte_count}")
|
||||
|
||||
def aio_mock_writer(_, __):
|
||||
reader = mocker.patch("asyncio.StreamReader")
|
||||
writer = mocker.patch("asyncio.StreamWriter")
|
||||
mocker.patch.object(writer, "write", _cancel_first_attempt)
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await protocol.query({})
|
||||
assert protocol.writer is None
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
|
||||
|
||||
async def test_protocol_handles_cancellation_during_connection(mocker):
|
||||
attempts = 0
|
||||
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
|
||||
TPLinkSmartHomeProtocol.BLOCK_SIZE :
|
||||
]
|
||||
|
||||
async def _mock_read(byte_count):
|
||||
nonlocal encrypted
|
||||
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
|
||||
return struct.pack(">I", len(encrypted))
|
||||
if byte_count == len(encrypted):
|
||||
return encrypted
|
||||
|
||||
raise ValueError(f"No mock for {byte_count}")
|
||||
|
||||
def aio_mock_writer(_, __):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise asyncio.CancelledError("Simulated task cancel")
|
||||
reader = mocker.patch("asyncio.StreamReader")
|
||||
writer = mocker.patch("asyncio.StreamWriter")
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await protocol.query({})
|
||||
assert protocol.writer is None
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
|
||||
async def test_protocol_logging(mocker, caplog, log_level):
|
||||
caplog.set_level(log_level)
|
||||
|
Loading…
x
Reference in New Issue
Block a user