diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 018176ad..73d02b0e 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -315,10 +315,12 @@ class AesTransport(BaseTransport): return await self.send_secure_passthrough(request) async def close(self) -> None: - """Mark the handshake and login as not done. + """Close the http client and reset internal state.""" + await self.reset() + await self._http_client.close() - Since we likely lost the connection. - """ + async def reset(self) -> None: + """Reset internal handshake and login state.""" self._handshake_done = False self._login_token = None diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 28a19e8b..7fe0b2c3 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -5,7 +5,11 @@ from typing import Any, Dict, Optional, Tuple, Union import aiohttp from .deviceconfig import DeviceConfig -from .exceptions import ConnectionException, SmartDeviceException, TimeoutException +from .exceptions import ( + ConnectionException, + SmartDeviceException, + TimeoutException, +) from .json import loads as json_loads @@ -78,7 +82,7 @@ class HttpClient: except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: raise ConnectionException( - f"Unable to connect to the device: {self._config.host}: {ex}", ex + f"Device connection error: {self._config.host}: {ex}", ex ) from ex except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as ex: raise TimeoutException( diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index c58cc880..ed926101 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -45,32 +45,31 @@ class IotProtocol(BaseProtocol): try: return await self._execute_query(request, retry) except ConnectionException as sdex: - await self.close() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise sdex continue except AuthenticationException as auex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to authenticate with %s, not retrying", self._host ) raise auex except RetryableException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) continue except SmartDeviceException as ex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to query the device: %s, not retrying: %s", self._host, @@ -85,10 +84,5 @@ class IotProtocol(BaseProtocol): return await self._transport.send(request) async def close(self) -> None: - """Close the underlying transport. - - Some transports may close the connection, and some may - use this as a hint that they need to reconnect, or - reauthenticate. - """ + """Close the underlying transport.""" await self._transport.close() diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 5411314a..c678e448 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -348,7 +348,12 @@ class KlapTransport(BaseTransport): return json_payload async def close(self) -> None: - """Mark the handshake as not done since we likely lost the connection.""" + """Close the http client and reset internal state.""" + await self.reset() + await self._http_client.close() + + async def reset(self) -> None: + """Reset internal handshake state.""" self._handshake_done = False @staticmethod diff --git a/kasa/protocol.py b/kasa/protocol.py index 59fea4a8..ae8eb89b 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -80,6 +80,10 @@ class BaseTransport(ABC): async def close(self) -> None: """Close the transport. Abstract method to be overriden.""" + @abstractmethod + async def reset(self) -> None: + """Reset internal state.""" + class BaseProtocol(ABC): """Base class for all TP-Link Smart Home communication.""" @@ -139,7 +143,10 @@ class _XorTransport(BaseTransport): return {} async def close(self) -> None: - """Close the transport. Abstract method to be overriden.""" + """Close the transport.""" + + async def reset(self) -> None: + """Reset internal state..""" class TPLinkSmartHomeProtocol(BaseProtocol): @@ -233,9 +240,9 @@ class TPLinkSmartHomeProtocol(BaseProtocol): if writer: writer.close() - def _reset(self) -> None: - """Clear any varibles that should not survive between loops.""" - self.reader = self.writer = None + async def reset(self) -> None: + """Reset the transport.""" + await self.close() async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: """Try to query a device.""" @@ -252,12 +259,12 @@ class TPLinkSmartHomeProtocol(BaseProtocol): try: await self._connect(timeout) except ConnectionRefusedError as ex: - await self.close() + await self.reset() raise SmartDeviceException( f"Unable to connect to the device: {self._host}:{self._port}: {ex}" ) from ex except OSError as ex: - await self.close() + await self.reset() if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count: raise SmartDeviceException( f"Unable to connect to the device:" @@ -265,7 +272,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol): ) from ex continue except Exception as ex: - await self.close() + await self.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( @@ -290,7 +297,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol): async with asyncio_timeout(timeout): return await self._execute_query(request) except Exception as ex: - await self.close() + await self.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( @@ -312,7 +319,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol): raise # make mypy happy, this should never be reached.. - await self.close() + await self.reset() raise SmartDeviceException("Query reached somehow to unreachable") def __del__(self) -> None: @@ -322,7 +329,6 @@ class TPLinkSmartHomeProtocol(BaseProtocol): # or in another thread so we need to make sure the call to # close is called safely with call_soon_threadsafe self.loop.call_soon_threadsafe(self.writer.close) - self._reset() @staticmethod def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]: diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 08a6bfb6..31418afc 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -806,6 +806,10 @@ class SmartDevice: """Return the device configuration.""" return self.protocol.config + async def disconnect(self): + """Disconnect and close any underlying connection resources.""" + await self.protocol.close() + @staticmethod async def connect( *, diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index c28db948..6f0648ea 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -66,32 +66,31 @@ class SmartProtocol(BaseProtocol): try: return await self._execute_query(request, retry) except ConnectionException as sdex: - await self.close() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise sdex continue except AuthenticationException as auex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to authenticate with %s, not retrying", self._host ) raise auex except RetryableException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) continue except SmartDeviceException as ex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to query the device: %s, not retrying: %s", self._host, @@ -167,12 +166,7 @@ class SmartProtocol(BaseProtocol): raise SmartDeviceException(msg, error_code=error_code) async def close(self) -> None: - """Close the underlying transport. - - Some transports may close the connection, and some may - use this as a hint that they need to reconnect, or - reauthenticate. - """ + """Close the underlying transport.""" await self._transport.close() diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 12f9c276..7addbe72 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -15,6 +15,7 @@ from kasa import ( Credentials, Discover, SmartBulb, + SmartDevice, SmartDimmer, SmartLightStrip, SmartPlug, @@ -416,9 +417,15 @@ async def dev(request): IP_MODEL_CACHE[ip] = model = d.model if model not in file: pytest.skip(f"skipping file {file}") - return d if d else await _discover_update_and_close(ip, username, password) + dev: SmartDevice = ( + d if d else await _discover_update_and_close(ip, username, password) + ) + else: + dev: SmartDevice = await get_device_for_file(file, protocol) - return await get_device_for_file(file, protocol) + yield dev + + await dev.disconnect() @pytest.fixture diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index 78bea334..625a4994 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -377,6 +377,9 @@ class FakeSmartTransport(BaseTransport): async def close(self) -> None: pass + async def reset(self) -> None: + pass + class FakeTransportProtocol(TPLinkSmartHomeProtocol): def __init__(self, info): diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index c58aad4e..cfd29284 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -147,6 +147,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count): await transport.send(json_dumps(request)) assert transport._login_token == mock_aes_device.token assert post_mock.call_count == call_count # Login, Handshake, Login + await transport.close() @status_parameters diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index 25a13aea..8e3e2ed6 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -69,6 +69,8 @@ async def test_connect( assert dev.config == config + await dev.disconnect() + @pytest.mark.parametrize("custom_port", [123, None]) async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port): diff --git a/kasa/tests/test_httpclient.py b/kasa/tests/test_httpclient.py index 0a6c2beb..e178b818 100644 --- a/kasa/tests/test_httpclient.py +++ b/kasa/tests/test_httpclient.py @@ -19,12 +19,12 @@ from ..httpclient import HttpClient ( aiohttp.ServerDisconnectedError(), ConnectionException, - "Unable to connect to the device: ", + "Device connection error: ", ), ( aiohttp.ClientOSError(), ConnectionException, - "Unable to connect to the device: ", + "Device connection error: ", ), ( aiohttp.ServerTimeoutError(), diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 54f4a4be..09ceccae 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -54,9 +54,10 @@ class _mock_response: [ (Exception("dummy exception"), False), (aiohttp.ServerTimeoutError("dummy exception"), True), + (aiohttp.ServerDisconnectedError("dummy exception"), True), (aiohttp.ClientOSError("dummy exception"), True), ], - ids=("Exception", "SmartDeviceException", "ConnectError"), + ids=("Exception", "ServerTimeoutError", "ServerDisconnectedError", "ClientOSError"), ) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])