more cover

This commit is contained in:
J. Nick Koston 2025-01-05 13:56:35 -10:00
parent 12f7f33880
commit 9fd2f28420
No known key found for this signature in database

View File

@ -16,7 +16,7 @@ import pytest
from kasa.credentials import Credentials
from kasa.device import Device
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import KasaException
from kasa.exceptions import KasaException, _RetryableError
from kasa.iot import IotDevice
from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol
from kasa.protocols.protocol import (
@ -314,7 +314,7 @@ async def test_protocol_handles_timeout_during_write(
transport_class.BLOCK_SIZE :
]
def _cancel_first_attempt(*_):
def _timeout_first_attempt(*_):
nonlocal attempts
attempts += 1
if attempts == 1:
@ -332,7 +332,7 @@ async def test_protocol_handles_timeout_during_write(
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _cancel_first_attempt)
mocker.patch.object(writer, "write", _timeout_first_attempt)
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
@ -401,6 +401,103 @@ async def test_protocol_handles_timeout_during_connection(
assert response == {"great": "success"}
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_failure_during_write(
mocker, protocol_class, transport_class, encryption_class
):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
def _timeout_all_attempts(*_):
raise TimeoutError("Simulated timeout")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _timeout_all_attempts)
mocker.patch.object(reader, "readexactly", _mock_read)
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
return reader, writer
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(
_RetryableError,
match="Timeout after 5 seconds sending request to the device 127.0.0.1:9999: Simulated timeout",
):
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[
(
_deprecated_TPLinkSmartHomeProtocol,
XorTransport,
_deprecated_TPLinkSmartHomeProtocol,
),
(IotProtocol, XorTransport, XorEncryption),
],
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
)
async def test_protocol_handles_timeout_failure_during_connection(
mocker, protocol_class, transport_class, encryption_class
):
encrypted = encryption_class.encrypt('{"great":"success"}')[
transport_class.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == transport_class.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
raise TimeoutError("Simulated timeout")
config = DeviceConfig("127.0.0.1")
protocol = protocol_class(transport=transport_class(config=config))
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
await writer_obj.close()
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(
_RetryableError,
match="Timeout after 5 seconds connecting to the device: 127.0.0.1:9999: Simulated timeout",
):
await protocol.query({})
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
assert writer_obj.writer is None
@pytest.mark.parametrize(
("protocol_class", "transport_class", "encryption_class"),
[