Make close behaviour consistent across new protocols and transports (#660)

This commit is contained in:
Steven B 2024-01-20 12:35:05 +00:00 committed by GitHub
parent e94cd118a4
commit 49cfef087c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 46 additions and 9 deletions

View File

@ -265,10 +265,12 @@ class AesTransport(BaseTransport):
return await self.send_secure_passthrough(request) return await self.send_secure_passthrough(request)
async def close(self) -> None: async def close(self) -> None:
"""Close the transport.""" """Mark the handshake and login as not done.
Since we likely lost the connection.
"""
self._handshake_done = False self._handshake_done = False
self._login_token = None self._login_token = None
await self._http_client.close()
class AesEncyptionSession: class AesEncyptionSession:

View File

@ -45,8 +45,8 @@ class IotProtocol(TPLinkProtocol):
try: try:
return await self._execute_query(request, retry) return await self._execute_query(request, retry)
except ConnectionException as sdex: except ConnectionException as sdex:
if retry >= retry_count:
await self.close() await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex raise sdex
continue continue
@ -57,14 +57,14 @@ class IotProtocol(TPLinkProtocol):
) )
raise auex raise auex
except RetryableException as ex: except RetryableException as ex:
if retry >= retry_count:
await self.close() await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex raise ex
continue continue
except TimeoutException as ex: except TimeoutException as ex:
if retry >= retry_count:
await self.close() await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
@ -85,5 +85,10 @@ class IotProtocol(TPLinkProtocol):
return await self._transport.send(request) return await self._transport.send(request)
async def close(self) -> None: async def close(self) -> None:
"""Close the protocol.""" """Close the underlying transport.
Some transports may close the connection, and some may
use this as a hint that they need to reconnect, or
reauthenticate.
"""
await self._transport.close() await self._transport.close()

View File

@ -66,8 +66,8 @@ class SmartProtocol(TPLinkProtocol):
try: try:
return await self._execute_query(request, retry) return await self._execute_query(request, retry)
except ConnectionException as sdex: except ConnectionException as sdex:
if retry >= retry_count:
await self.close() await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex raise sdex
continue continue
@ -78,8 +78,8 @@ class SmartProtocol(TPLinkProtocol):
) )
raise auex raise auex
except RetryableException as ex: except RetryableException as ex:
if retry >= retry_count:
await self.close() await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex raise ex
continue continue

View File

@ -16,7 +16,9 @@ from ..deviceconfig import DeviceConfig
from ..exceptions import ( from ..exceptions import (
AuthenticationException, AuthenticationException,
ConnectionException, ConnectionException,
RetryableException,
SmartDeviceException, SmartDeviceException,
TimeoutException,
) )
from ..httpclient import HttpClient from ..httpclient import HttpClient
from ..iotprotocol import IotProtocol from ..iotprotocol import IotProtocol
@ -58,7 +60,7 @@ class _mock_response:
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5]) @pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries( async def test_protocol_retries_via_client_session(
mocker, retry_count, protocol_class, transport_class, error, retry_expectation mocker, retry_count, protocol_class, transport_class, error, retry_expectation
): ):
host = "127.0.0.1" host = "127.0.0.1"
@ -74,6 +76,34 @@ async def test_protocol_retries(
assert conn.call_count == expected_count assert conn.call_count == expected_count
@pytest.mark.parametrize(
"error, retry_expectation",
[
(SmartDeviceException("dummy exception"), False),
(RetryableException("dummy exception"), True),
(TimeoutException("dummy exception"), True),
],
ids=("SmartDeviceException", "RetryableException", "TimeoutException"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries_via_httpclient(
mocker, retry_count, protocol_class, transport_class, error, retry_expectation
):
host = "127.0.0.1"
conn = mocker.patch.object(HttpClient, "post", side_effect=error)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException):
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=retry_count
)
expected_count = retry_count + 1 if retry_expectation else 1
assert conn.call_count == expected_count
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
async def test_protocol_no_retry_on_connection_error( async def test_protocol_no_retry_on_connection_error(