diff --git a/tests/protocols/test_iotprotocol.py b/tests/protocols/test_iotprotocol.py index 600a6b51..a185c70a 100644 --- a/tests/protocols/test_iotprotocol.py +++ b/tests/protocols/test_iotprotocol.py @@ -16,7 +16,7 @@ import pytest from kasa.credentials import Credentials from kasa.device import Device from kasa.deviceconfig import DeviceConfig -from kasa.exceptions import KasaException +from kasa.exceptions import KasaException, _RetryableError from kasa.iot import IotDevice from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol from kasa.protocols.protocol import ( @@ -314,7 +314,7 @@ async def test_protocol_handles_timeout_during_write( transport_class.BLOCK_SIZE : ] - def _cancel_first_attempt(*_): + def _timeout_first_attempt(*_): nonlocal attempts attempts += 1 if attempts == 1: @@ -332,7 +332,7 @@ async def test_protocol_handles_timeout_during_write( def aio_mock_writer(_, __): reader = mocker.patch("asyncio.StreamReader") writer = mocker.patch("asyncio.StreamWriter") - mocker.patch.object(writer, "write", _cancel_first_attempt) + mocker.patch.object(writer, "write", _timeout_first_attempt) mocker.patch.object(reader, "readexactly", _mock_read) mocker.patch.object(writer, "drain", new_callable=AsyncMock) return reader, writer @@ -401,6 +401,103 @@ async def test_protocol_handles_timeout_during_connection( assert response == {"great": "success"} +@pytest.mark.parametrize( + ("protocol_class", "transport_class", "encryption_class"), + [ + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_handles_timeout_failure_during_write( + mocker, protocol_class, transport_class, encryption_class +): + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : + ] + + def _timeout_all_attempts(*_): + raise TimeoutError("Simulated timeout") + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == transport_class.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + mocker.patch.object(writer, "write", _timeout_all_attempts) + mocker.patch.object(reader, "readexactly", _mock_read) + mocker.patch.object(writer, "drain", new_callable=AsyncMock) + return reader, writer + + config = DeviceConfig("127.0.0.1") + protocol = protocol_class(transport=transport_class(config=config)) + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + with pytest.raises( + _RetryableError, + match="Timeout after 5 seconds sending request to the device 127.0.0.1:9999: Simulated timeout", + ): + await protocol.query({}) + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + assert writer_obj.writer is None + + +@pytest.mark.parametrize( + ("protocol_class", "transport_class", "encryption_class"), + [ + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_handles_timeout_failure_during_connection( + mocker, protocol_class, transport_class, encryption_class +): + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : + ] + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == transport_class.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + raise TimeoutError("Simulated timeout") + + config = DeviceConfig("127.0.0.1") + protocol = protocol_class(transport=transport_class(config=config)) + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + await writer_obj.close() + + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + with pytest.raises( + _RetryableError, + match="Timeout after 5 seconds connecting to the device: 127.0.0.1:9999: Simulated timeout", + ): + await protocol.query({}) + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + assert writer_obj.writer is None + + @pytest.mark.parametrize( ("protocol_class", "transport_class", "encryption_class"), [