more cover

This commit is contained in:
J. Nick Koston 2025-01-05 13:56:35 -10:00
parent 12f7f33880
commit 9fd2f28420
No known key found for this signature in database

View File

@ -16,7 +16,7 @@ import pytest
from kasa.credentials import Credentials from kasa.credentials import Credentials
from kasa.device import Device from kasa.device import Device
from kasa.deviceconfig import DeviceConfig from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import KasaException from kasa.exceptions import KasaException, _RetryableError
from kasa.iot import IotDevice from kasa.iot import IotDevice
from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol
from kasa.protocols.protocol import ( from kasa.protocols.protocol import (
@ -314,7 +314,7 @@ async def test_protocol_handles_timeout_during_write(
transport_class.BLOCK_SIZE : transport_class.BLOCK_SIZE :
] ]
def _cancel_first_attempt(*_): def _timeout_first_attempt(*_):
nonlocal attempts nonlocal attempts
attempts += 1 attempts += 1
if attempts == 1: if attempts == 1:
@ -332,7 +332,7 @@ async def test_protocol_handles_timeout_during_write(
def aio_mock_writer(_, __): def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader") reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter") writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _cancel_first_attempt) mocker.patch.object(writer, "write", _timeout_first_attempt)
mocker.patch.object(reader, "readexactly", _mock_read) mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock) mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer return reader, writer
@ -401,6 +401,103 @@ async def test_protocol_handles_timeout_during_connection(
assert response == {"great": "success"} assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_failure_during_write(
mocker, protocol_class, transport_class, encryption_class
):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
def _timeout_all_attempts(*_):
raise TimeoutError("Simulated timeout")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _timeout_all_attempts)
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(
_RetryableError,
match="Timeout after 5 seconds sending request to the device 127.0.0.1:9999: Simulated timeout",
):
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_failure_during_connection(
mocker, protocol_class, transport_class, encryption_class
):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
raise TimeoutError("Simulated timeout")
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
await writer_obj.close()
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(
_RetryableError,
match="Timeout after 5 seconds connecting to the device: 127.0.0.1:9999: Simulated timeout",
):
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
@pytest.mark.parametrize( @pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"), ("protocol_class", "transport_class", "encryption_class"),
[ [