From 49cfef087c7cca214a73c52921bdabd899fce927 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Sat, 20 Jan 2024 12:35:05 +0000 Subject: [PATCH] Make close behaviour consistent across new protocols and transports (#660) --- kasa/aestransport.py | 6 ++++-- kasa/iotprotocol.py | 13 +++++++++---- kasa/smartprotocol.py | 4 ++-- kasa/tests/test_klapprotocol.py | 32 +++++++++++++++++++++++++++++++- 4 files changed, 46 insertions(+), 9 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index c19ead5b..65b0045d 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -265,10 +265,12 @@ class AesTransport(BaseTransport): return await self.send_secure_passthrough(request) async def close(self) -> None: - """Close the transport.""" + """Mark the handshake and login as not done. + + Since we likely lost the connection. + """ self._handshake_done = False self._login_token = None - await self._http_client.close() class AesEncyptionSession: diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index a9001525..9f72bbc0 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -45,8 +45,8 @@ class IotProtocol(TPLinkProtocol): try: return await self._execute_query(request, retry) except ConnectionException as sdex: + await self.close() if retry >= retry_count: - await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise sdex continue @@ -57,14 +57,14 @@ class IotProtocol(TPLinkProtocol): ) raise auex except RetryableException as ex: + await self.close() if retry >= retry_count: - await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutException as ex: + await self.close() if retry >= retry_count: - await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) @@ -85,5 +85,10 @@ class IotProtocol(TPLinkProtocol): return await self._transport.send(request) async def close(self) -> None: - """Close the protocol.""" + """Close the underlying transport. + + Some transports may close the connection, and some may + use this as a hint that they need to reconnect, or + reauthenticate. + """ await self._transport.close() diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index e7143d2e..c50c511f 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -66,8 +66,8 @@ class SmartProtocol(TPLinkProtocol): try: return await self._execute_query(request, retry) except ConnectionException as sdex: + await self.close() if retry >= retry_count: - await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise sdex continue @@ -78,8 +78,8 @@ class SmartProtocol(TPLinkProtocol): ) raise auex except RetryableException as ex: + await self.close() if retry >= retry_count: - await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index a4b12e2c..8ae32e3f 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -16,7 +16,9 @@ from ..deviceconfig import DeviceConfig from ..exceptions import ( AuthenticationException, ConnectionException, + RetryableException, SmartDeviceException, + TimeoutException, ) from ..httpclient import HttpClient from ..iotprotocol import IotProtocol @@ -58,7 +60,7 @@ class _mock_response: @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @pytest.mark.parametrize("retry_count", [1, 3, 5]) -async def test_protocol_retries( +async def test_protocol_retries_via_client_session( mocker, retry_count, protocol_class, transport_class, error, retry_expectation ): host = "127.0.0.1" @@ -74,6 +76,34 @@ async def test_protocol_retries( assert conn.call_count == expected_count +@pytest.mark.parametrize( + "error, retry_expectation", + [ + (SmartDeviceException("dummy exception"), False), + (RetryableException("dummy exception"), True), + (TimeoutException("dummy exception"), True), + ], + ids=("SmartDeviceException", "RetryableException", "TimeoutException"), +) +@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) +@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) +@pytest.mark.parametrize("retry_count", [1, 3, 5]) +async def test_protocol_retries_via_httpclient( + mocker, retry_count, protocol_class, transport_class, error, retry_expectation +): + host = "127.0.0.1" + conn = mocker.patch.object(HttpClient, "post", side_effect=error) + + config = DeviceConfig(host) + with pytest.raises(SmartDeviceException): + await protocol_class(transport=transport_class(config=config)).query( + DUMMY_QUERY, retry_count=retry_count + ) + + expected_count = retry_count + 1 if retry_expectation else 1 + assert conn.call_count == expected_count + + @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) async def test_protocol_no_retry_on_connection_error(