mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-05-30 21:51:24 +00:00
more cover
This commit is contained in:
parent
12f7f33880
commit
9fd2f28420
@ -16,7 +16,7 @@ import pytest
|
|||||||
from kasa.credentials import Credentials
|
from kasa.credentials import Credentials
|
||||||
from kasa.device import Device
|
from kasa.device import Device
|
||||||
from kasa.deviceconfig import DeviceConfig
|
from kasa.deviceconfig import DeviceConfig
|
||||||
from kasa.exceptions import KasaException
|
from kasa.exceptions import KasaException, _RetryableError
|
||||||
from kasa.iot import IotDevice
|
from kasa.iot import IotDevice
|
||||||
from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol
|
from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol
|
||||||
from kasa.protocols.protocol import (
|
from kasa.protocols.protocol import (
|
||||||
@ -314,7 +314,7 @@ async def test_protocol_handles_timeout_during_write(
|
|||||||
transport_class.BLOCK_SIZE :
|
transport_class.BLOCK_SIZE :
|
||||||
]
|
]
|
||||||
|
|
||||||
def _cancel_first_attempt(*_):
|
def _timeout_first_attempt(*_):
|
||||||
nonlocal attempts
|
nonlocal attempts
|
||||||
attempts += 1
|
attempts += 1
|
||||||
if attempts == 1:
|
if attempts == 1:
|
||||||
@ -332,7 +332,7 @@ async def test_protocol_handles_timeout_during_write(
|
|||||||
def aio_mock_writer(_, __):
|
def aio_mock_writer(_, __):
|
||||||
reader = mocker.patch("asyncio.StreamReader")
|
reader = mocker.patch("asyncio.StreamReader")
|
||||||
writer = mocker.patch("asyncio.StreamWriter")
|
writer = mocker.patch("asyncio.StreamWriter")
|
||||||
mocker.patch.object(writer, "write", _cancel_first_attempt)
|
mocker.patch.object(writer, "write", _timeout_first_attempt)
|
||||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||||
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
|
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
|
||||||
return reader, writer
|
return reader, writer
|
||||||
@ -401,6 +401,103 @@ async def test_protocol_handles_timeout_during_connection(
|
|||||||
assert response == {"great": "success"}
|
assert response == {"great": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("protocol_class", "transport_class", "encryption_class"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
_deprecated_TPLinkSmartHomeProtocol,
|
||||||
|
XorTransport,
|
||||||
|
_deprecated_TPLinkSmartHomeProtocol,
|
||||||
|
),
|
||||||
|
(IotProtocol, XorTransport, XorEncryption),
|
||||||
|
],
|
||||||
|
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
|
||||||
|
)
|
||||||
|
async def test_protocol_handles_timeout_failure_during_write(
|
||||||
|
mocker, protocol_class, transport_class, encryption_class
|
||||||
|
):
|
||||||
|
encrypted = encryption_class.encrypt('{"great":"success"}')[
|
||||||
|
transport_class.BLOCK_SIZE :
|
||||||
|
]
|
||||||
|
|
||||||
|
def _timeout_all_attempts(*_):
|
||||||
|
raise TimeoutError("Simulated timeout")
|
||||||
|
|
||||||
|
async def _mock_read(byte_count):
|
||||||
|
nonlocal encrypted
|
||||||
|
if byte_count == transport_class.BLOCK_SIZE:
|
||||||
|
return struct.pack(">I", len(encrypted))
|
||||||
|
if byte_count == len(encrypted):
|
||||||
|
return encrypted
|
||||||
|
|
||||||
|
raise ValueError(f"No mock for {byte_count}")
|
||||||
|
|
||||||
|
def aio_mock_writer(_, __):
|
||||||
|
reader = mocker.patch("asyncio.StreamReader")
|
||||||
|
writer = mocker.patch("asyncio.StreamWriter")
|
||||||
|
mocker.patch.object(writer, "write", _timeout_all_attempts)
|
||||||
|
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||||
|
mocker.patch.object(writer, "drain", new_callable=AsyncMock)
|
||||||
|
return reader, writer
|
||||||
|
|
||||||
|
config = DeviceConfig("127.0.0.1")
|
||||||
|
protocol = protocol_class(transport=transport_class(config=config))
|
||||||
|
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||||
|
with pytest.raises(
|
||||||
|
_RetryableError,
|
||||||
|
match="Timeout after 5 seconds sending request to the device 127.0.0.1:9999: Simulated timeout",
|
||||||
|
):
|
||||||
|
await protocol.query({})
|
||||||
|
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
|
||||||
|
assert writer_obj.writer is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("protocol_class", "transport_class", "encryption_class"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
_deprecated_TPLinkSmartHomeProtocol,
|
||||||
|
XorTransport,
|
||||||
|
_deprecated_TPLinkSmartHomeProtocol,
|
||||||
|
),
|
||||||
|
(IotProtocol, XorTransport, XorEncryption),
|
||||||
|
],
|
||||||
|
ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"),
|
||||||
|
)
|
||||||
|
async def test_protocol_handles_timeout_failure_during_connection(
|
||||||
|
mocker, protocol_class, transport_class, encryption_class
|
||||||
|
):
|
||||||
|
encrypted = encryption_class.encrypt('{"great":"success"}')[
|
||||||
|
transport_class.BLOCK_SIZE :
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _mock_read(byte_count):
|
||||||
|
nonlocal encrypted
|
||||||
|
if byte_count == transport_class.BLOCK_SIZE:
|
||||||
|
return struct.pack(">I", len(encrypted))
|
||||||
|
if byte_count == len(encrypted):
|
||||||
|
return encrypted
|
||||||
|
|
||||||
|
raise ValueError(f"No mock for {byte_count}")
|
||||||
|
|
||||||
|
def aio_mock_writer(_, __):
|
||||||
|
raise TimeoutError("Simulated timeout")
|
||||||
|
|
||||||
|
config = DeviceConfig("127.0.0.1")
|
||||||
|
protocol = protocol_class(transport=transport_class(config=config))
|
||||||
|
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
|
||||||
|
await writer_obj.close()
|
||||||
|
|
||||||
|
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||||
|
with pytest.raises(
|
||||||
|
_RetryableError,
|
||||||
|
match="Timeout after 5 seconds connecting to the device: 127.0.0.1:9999: Simulated timeout",
|
||||||
|
):
|
||||||
|
await protocol.query({})
|
||||||
|
writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport
|
||||||
|
assert writer_obj.writer is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("protocol_class", "transport_class", "encryption_class"),
|
("protocol_class", "transport_class", "encryption_class"),
|
||||||
[
|
[
|
||||||
|
Loading…
x
Reference in New Issue
Block a user