mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-08-09 20:24:02 +00:00
Fix connection indeterminate state on cancellation (#636)
* Fix connection indeterminate state on cancellation If the task the query is running in it cancelled, we do know the state of the connection so we must close. Previously we would not close on BaseException which could result in reading the previous response if the previous query was cancelled after the request had been sent * add test for cancellation
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import errno
|
||||
import importlib
|
||||
import inspect
|
||||
@@ -122,6 +123,79 @@ async def test_protocol_reconnect(mocker, retry_count):
|
||||
assert response == {"great": "success"}
|
||||
|
||||
|
||||
async def test_protocol_handles_cancellation_during_write(mocker):
|
||||
attempts = 0
|
||||
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
|
||||
TPLinkSmartHomeProtocol.BLOCK_SIZE :
|
||||
]
|
||||
|
||||
def _cancel_first_attempt(*_):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise asyncio.CancelledError("Simulated task cancel")
|
||||
|
||||
async def _mock_read(byte_count):
|
||||
nonlocal encrypted
|
||||
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
|
||||
return struct.pack(">I", len(encrypted))
|
||||
if byte_count == len(encrypted):
|
||||
return encrypted
|
||||
|
||||
raise ValueError(f"No mock for {byte_count}")
|
||||
|
||||
def aio_mock_writer(_, __):
|
||||
reader = mocker.patch("asyncio.StreamReader")
|
||||
writer = mocker.patch("asyncio.StreamWriter")
|
||||
mocker.patch.object(writer, "write", _cancel_first_attempt)
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await protocol.query({})
|
||||
assert protocol.writer is None
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
|
||||
|
||||
async def test_protocol_handles_cancellation_during_connection(mocker):
|
||||
attempts = 0
|
||||
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
|
||||
TPLinkSmartHomeProtocol.BLOCK_SIZE :
|
||||
]
|
||||
|
||||
async def _mock_read(byte_count):
|
||||
nonlocal encrypted
|
||||
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
|
||||
return struct.pack(">I", len(encrypted))
|
||||
if byte_count == len(encrypted):
|
||||
return encrypted
|
||||
|
||||
raise ValueError(f"No mock for {byte_count}")
|
||||
|
||||
def aio_mock_writer(_, __):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise asyncio.CancelledError("Simulated task cancel")
|
||||
reader = mocker.patch("asyncio.StreamReader")
|
||||
writer = mocker.patch("asyncio.StreamWriter")
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await protocol.query({})
|
||||
assert protocol.writer is None
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
|
||||
async def test_protocol_logging(mocker, caplog, log_level):
|
||||
caplog.set_level(log_level)
|
||||
|
Reference in New Issue
Block a user