From 0e874a35f1ff6d4e2bfa2f13fa7333085b8acfdb Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 17:11:23 +0000 Subject: [PATCH] Update transport close/reset behaviour --- kasa/aestransport.py | 8 +++++--- kasa/exceptions.py | 4 ++++ kasa/httpclient.py | 13 +++++++++++-- kasa/iotprotocol.py | 23 ++++++++++++----------- kasa/klaptransport.py | 7 ++++++- kasa/protocol.py | 26 ++++++++++++++++---------- kasa/smartdevice.py | 4 ++++ kasa/smartprotocol.py | 23 ++++++++++++----------- kasa/tests/conftest.py | 11 +++++++++-- kasa/tests/newfakes.py | 3 +++ kasa/tests/test_httpclient.py | 5 +++-- kasa/tests/test_klapprotocol.py | 3 ++- 12 files changed, 87 insertions(+), 43 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 14a9ee6a..c03b6a11 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -306,10 +306,12 @@ class AesTransport(BaseTransport): return await self.send_secure_passthrough(request) async def close(self) -> None: - """Mark the handshake and login as not done. + """Close the http client and reset internal state.""" + await self.reset() + await self._http_client.close() - Since we likely lost the connection. - """ + async def reset(self) -> None: + """Reset internal handshake and login state.""" self._handshake_done = False self._login_token = None diff --git a/kasa/exceptions.py b/kasa/exceptions.py index c0ef23b6..8720d97b 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -42,6 +42,10 @@ class ConnectionException(SmartDeviceException): """Connection exception for device errors.""" +class DisconnectedException(SmartDeviceException): + """Disconnected exception for device errors.""" + + class SmartErrorCode(IntEnum): """Enum for SMART Error Codes.""" diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 28a19e8b..73c91fa4 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -5,7 +5,12 @@ from typing import Any, Dict, Optional, Tuple, Union import aiohttp from .deviceconfig import DeviceConfig -from .exceptions import ConnectionException, SmartDeviceException, TimeoutException +from .exceptions import ( + ConnectionException, + DisconnectedException, + SmartDeviceException, + TimeoutException, +) from .json import loads as json_loads @@ -76,7 +81,11 @@ class HttpClient: if return_json: response_data = json_loads(response_data.decode()) - except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: + except aiohttp.ServerDisconnectedError as ex: + raise DisconnectedException( + f"Disconnected from the device: {self._config.host}: {ex}", ex + ) from ex + except aiohttp.ClientOSError as ex: raise ConnectionException( f"Unable to connect to the device: {self._config.host}: {ex}", ex ) from ex diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index c58cc880..aac21f10 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -6,6 +6,7 @@ from typing import Dict, Union from .exceptions import ( AuthenticationException, ConnectionException, + DisconnectedException, RetryableException, SmartDeviceException, TimeoutException, @@ -44,33 +45,38 @@ class IotProtocol(BaseProtocol): for retry in range(retry_count + 1): try: return await self._execute_query(request, retry) + except DisconnectedException as sdex: + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) + raise sdex + continue except ConnectionException as sdex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise sdex continue except AuthenticationException as auex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to authenticate with %s, not retrying", self._host ) raise auex except RetryableException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) continue except SmartDeviceException as ex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to query the device: %s, not retrying: %s", self._host, @@ -85,10 +91,5 @@ class IotProtocol(BaseProtocol): return await self._transport.send(request) async def close(self) -> None: - """Close the underlying transport. - - Some transports may close the connection, and some may - use this as a hint that they need to reconnect, or - reauthenticate. - """ + """Close the underlying transport.""" await self._transport.close() diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 5411314a..c678e448 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -348,7 +348,12 @@ class KlapTransport(BaseTransport): return json_payload async def close(self) -> None: - """Mark the handshake as not done since we likely lost the connection.""" + """Close the http client and reset internal state.""" + await self.reset() + await self._http_client.close() + + async def reset(self) -> None: + """Reset internal handshake state.""" self._handshake_done = False @staticmethod diff --git a/kasa/protocol.py b/kasa/protocol.py index 59fea4a8..ae8eb89b 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -80,6 +80,10 @@ class BaseTransport(ABC): async def close(self) -> None: """Close the transport. Abstract method to be overriden.""" + @abstractmethod + async def reset(self) -> None: + """Reset internal state.""" + class BaseProtocol(ABC): """Base class for all TP-Link Smart Home communication.""" @@ -139,7 +143,10 @@ class _XorTransport(BaseTransport): return {} async def close(self) -> None: - """Close the transport. Abstract method to be overriden.""" + """Close the transport.""" + + async def reset(self) -> None: + """Reset internal state..""" class TPLinkSmartHomeProtocol(BaseProtocol): @@ -233,9 +240,9 @@ class TPLinkSmartHomeProtocol(BaseProtocol): if writer: writer.close() - def _reset(self) -> None: - """Clear any varibles that should not survive between loops.""" - self.reader = self.writer = None + async def reset(self) -> None: + """Reset the transport.""" + await self.close() async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: """Try to query a device.""" @@ -252,12 +259,12 @@ class TPLinkSmartHomeProtocol(BaseProtocol): try: await self._connect(timeout) except ConnectionRefusedError as ex: - await self.close() + await self.reset() raise SmartDeviceException( f"Unable to connect to the device: {self._host}:{self._port}: {ex}" ) from ex except OSError as ex: - await self.close() + await self.reset() if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count: raise SmartDeviceException( f"Unable to connect to the device:" @@ -265,7 +272,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol): ) from ex continue except Exception as ex: - await self.close() + await self.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( @@ -290,7 +297,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol): async with asyncio_timeout(timeout): return await self._execute_query(request) except Exception as ex: - await self.close() + await self.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( @@ -312,7 +319,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol): raise # make mypy happy, this should never be reached.. - await self.close() + await self.reset() raise SmartDeviceException("Query reached somehow to unreachable") def __del__(self) -> None: @@ -322,7 +329,6 @@ class TPLinkSmartHomeProtocol(BaseProtocol): # or in another thread so we need to make sure the call to # close is called safely with call_soon_threadsafe self.loop.call_soon_threadsafe(self.writer.close) - self._reset() @staticmethod def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]: diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 08a6bfb6..31418afc 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -806,6 +806,10 @@ class SmartDevice: """Return the device configuration.""" return self.protocol.config + async def disconnect(self): + """Disconnect and close any underlying connection resources.""" + await self.protocol.close() + @staticmethod async def connect( *, diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index c28db948..8b876c14 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -18,6 +18,7 @@ from .exceptions import ( SMART_TIMEOUT_ERRORS, AuthenticationException, ConnectionException, + DisconnectedException, RetryableException, SmartDeviceException, SmartErrorCode, @@ -65,33 +66,38 @@ class SmartProtocol(BaseProtocol): for retry in range(retry_count + 1): try: return await self._execute_query(request, retry) + except DisconnectedException as sdex: + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) + raise sdex + continue except ConnectionException as sdex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise sdex continue except AuthenticationException as auex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to authenticate with %s, not retrying", self._host ) raise auex except RetryableException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) continue except SmartDeviceException as ex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to query the device: %s, not retrying: %s", self._host, @@ -167,12 +173,7 @@ class SmartProtocol(BaseProtocol): raise SmartDeviceException(msg, error_code=error_code) async def close(self) -> None: - """Close the underlying transport. - - Some transports may close the connection, and some may - use this as a hint that they need to reconnect, or - reauthenticate. - """ + """Close the underlying transport.""" await self._transport.close() diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 12f9c276..7addbe72 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -15,6 +15,7 @@ from kasa import ( Credentials, Discover, SmartBulb, + SmartDevice, SmartDimmer, SmartLightStrip, SmartPlug, @@ -416,9 +417,15 @@ async def dev(request): IP_MODEL_CACHE[ip] = model = d.model if model not in file: pytest.skip(f"skipping file {file}") - return d if d else await _discover_update_and_close(ip, username, password) + dev: SmartDevice = ( + d if d else await _discover_update_and_close(ip, username, password) + ) + else: + dev: SmartDevice = await get_device_for_file(file, protocol) - return await get_device_for_file(file, protocol) + yield dev + + await dev.disconnect() @pytest.fixture diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index 78bea334..625a4994 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -377,6 +377,9 @@ class FakeSmartTransport(BaseTransport): async def close(self) -> None: pass + async def reset(self) -> None: + pass + class FakeTransportProtocol(TPLinkSmartHomeProtocol): def __init__(self, info): diff --git a/kasa/tests/test_httpclient.py b/kasa/tests/test_httpclient.py index 0a6c2beb..bcf48df2 100644 --- a/kasa/tests/test_httpclient.py +++ b/kasa/tests/test_httpclient.py @@ -7,6 +7,7 @@ import pytest from ..deviceconfig import DeviceConfig from ..exceptions import ( ConnectionException, + DisconnectedException, SmartDeviceException, TimeoutException, ) @@ -18,8 +19,8 @@ from ..httpclient import HttpClient [ ( aiohttp.ServerDisconnectedError(), - ConnectionException, - "Unable to connect to the device: ", + DisconnectedException, + "Disconnected from the device: ", ), ( aiohttp.ClientOSError(), diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 54f4a4be..6bd142ca 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -54,9 +54,10 @@ class _mock_response: [ (Exception("dummy exception"), False), (aiohttp.ServerTimeoutError("dummy exception"), True), + (aiohttp.ServerDisconnectedError("dummy exception"), True), (aiohttp.ClientOSError("dummy exception"), True), ], - ids=("Exception", "SmartDeviceException", "ConnectError"), + ids=("Exception", "SmartDeviceException", "DisconnectError", "ConnectError"), ) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])