diff --git a/kasa/protocols/iotprotocol.py b/kasa/protocols/iotprotocol.py index b58e57ae..1af4ae59 100755 --- a/kasa/protocols/iotprotocol.py +++ b/kasa/protocols/iotprotocol.py @@ -98,12 +98,26 @@ class IotProtocol(BaseProtocol): ) raise auex except _RetryableError as ex: + if retry == 0: + _LOGGER.debug( + "Device %s got a retryable error, will retry %s times: %s", + self._host, + retry_count, + ex, + ) await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutError as ex: + if retry == 0: + _LOGGER.debug( + "Device %s got a timeout error, will retry %s times: %s", + self._host, + retry_count, + ex, + ) await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) diff --git a/kasa/transports/xortransport.py b/kasa/transports/xortransport.py index 77a232f0..8cce6eb5 100644 --- a/kasa/transports/xortransport.py +++ b/kasa/transports/xortransport.py @@ -23,6 +23,7 @@ from collections.abc import Generator from kasa.deviceconfig import DeviceConfig from kasa.exceptions import KasaException, _RetryableError +from kasa.exceptions import TimeoutError as KasaTimeoutError from kasa.json import loads as json_loads from .basetransport import BaseTransport @@ -126,6 +127,12 @@ class XorTransport(BaseTransport): # This is especially import when there are multiple tplink devices being polled. try: await self._connect(self._timeout) + except TimeoutError as ex: + await self.reset() + raise KasaTimeoutError( + f"Timeout after {self._timeout} seconds connecting to the device:" + f" {self._host}:{self._port}: {ex}" + ) from ex except ConnectionRefusedError as ex: await self.reset() raise KasaException( @@ -159,6 +166,12 @@ class XorTransport(BaseTransport): assert self.writer is not None # noqa: S101 async with asyncio_timeout(self._timeout): return await self._execute_send(request) + except TimeoutError as ex: + await self.reset() + raise KasaTimeoutError( + f"Timeout after {self._timeout} seconds sending request to the device" + f" {self._host}:{self._port}: {ex}" + ) from ex except Exception as ex: await self.reset() raise _RetryableError( diff --git a/tests/protocols/test_iotprotocol.py b/tests/protocols/test_iotprotocol.py index a2feaae3..fd8facc9 100644 --- a/tests/protocols/test_iotprotocol.py +++ b/tests/protocols/test_iotprotocol.py @@ -16,7 +16,7 @@ import pytest from kasa.credentials import Credentials from kasa.device import Device from kasa.deviceconfig import DeviceConfig -from kasa.exceptions import KasaException +from kasa.exceptions import KasaException, TimeoutError from kasa.iot import IotDevice from kasa.protocols.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol from kasa.protocols.protocol import ( @@ -294,6 +294,210 @@ async def test_protocol_handles_cancellation_during_connection( assert response == {"great": "success"} +@pytest.mark.parametrize( + ("protocol_class", "transport_class", "encryption_class"), + [ + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_handles_timeout_during_write( + mocker, protocol_class, transport_class, encryption_class +): + attempts = 0 + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : + ] + + def _timeout_first_attempt(*_): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise TimeoutError("Simulated timeout") + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == transport_class.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + mocker.patch.object(writer, "write", _timeout_first_attempt) + mocker.patch.object(reader, "readexactly", _mock_read) + mocker.patch.object(writer, "drain", new_callable=AsyncMock) + return reader, writer + + config = DeviceConfig("127.0.0.1") + protocol = protocol_class(transport=transport_class(config=config)) + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + await protocol.query({}) + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + assert writer_obj.writer is not None + response = await protocol.query({}) + assert response == {"great": "success"} + + +@pytest.mark.parametrize( + ("protocol_class", "transport_class", "encryption_class"), + [ + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_handles_timeout_during_connection( + mocker, protocol_class, transport_class, encryption_class +): + attempts = 0 + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : + ] + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == transport_class.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise TimeoutError("Simulated timeout") + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + mocker.patch.object(reader, "readexactly", _mock_read) + mocker.patch.object(writer, "drain", new_callable=AsyncMock) + return reader, writer + + config = DeviceConfig("127.0.0.1") + protocol = protocol_class(transport=transport_class(config=config)) + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + await writer_obj.close() + + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + await protocol.query({"any": "thing"}) + + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + assert writer_obj.writer is not None + response = await protocol.query({}) + assert response == {"great": "success"} + + +@pytest.mark.parametrize( + ("protocol_class", "transport_class", "encryption_class"), + [ + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_handles_timeout_failure_during_write( + mocker, protocol_class, transport_class, encryption_class +): + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : + ] + + def _timeout_all_attempts(*_): + raise TimeoutError("Simulated timeout") + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == transport_class.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + mocker.patch.object(writer, "write", _timeout_all_attempts) + mocker.patch.object(reader, "readexactly", _mock_read) + mocker.patch.object(writer, "drain", new_callable=AsyncMock) + return reader, writer + + config = DeviceConfig("127.0.0.1") + protocol = protocol_class(transport=transport_class(config=config)) + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + with pytest.raises( + TimeoutError, + match="Timeout after 5 seconds sending request to the device 127.0.0.1:9999: Simulated timeout", + ): + await protocol.query({}) + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + assert writer_obj.writer is None + + +@pytest.mark.parametrize( + ("protocol_class", "transport_class", "encryption_class"), + [ + ( + _deprecated_TPLinkSmartHomeProtocol, + XorTransport, + _deprecated_TPLinkSmartHomeProtocol, + ), + (IotProtocol, XorTransport, XorEncryption), + ], + ids=("_deprecated_TPLinkSmartHomeProtocol", "IotProtocol-XorTransport"), +) +async def test_protocol_handles_timeout_failure_during_connection( + mocker, protocol_class, transport_class, encryption_class +): + encrypted = encryption_class.encrypt('{"great":"success"}')[ + transport_class.BLOCK_SIZE : + ] + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == transport_class.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + raise TimeoutError("Simulated timeout") + + config = DeviceConfig("127.0.0.1") + protocol = protocol_class(transport=transport_class(config=config)) + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + await writer_obj.close() + + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + with pytest.raises( + TimeoutError, + match="Timeout after 5 seconds connecting to the device: 127.0.0.1:9999: Simulated timeout", + ): + await protocol.query({}) + writer_obj = protocol if hasattr(protocol, "writer") else protocol._transport + assert writer_obj.writer is None + + @pytest.mark.parametrize( ("protocol_class", "transport_class", "encryption_class"), [