mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-10-21 23:08:01 +00:00 
			
		
		
		
	Update transport close/reset behaviour (#689)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
		| @@ -315,10 +315,12 @@ class AesTransport(BaseTransport): | ||||
|         return await self.send_secure_passthrough(request) | ||||
|  | ||||
|     async def close(self) -> None: | ||||
|         """Mark the handshake and login as not done. | ||||
|         """Close the http client and reset internal state.""" | ||||
|         await self.reset() | ||||
|         await self._http_client.close() | ||||
|  | ||||
|         Since we likely lost the connection. | ||||
|         """ | ||||
|     async def reset(self) -> None: | ||||
|         """Reset internal handshake and login state.""" | ||||
|         self._handshake_done = False | ||||
|         self._login_token = None | ||||
|  | ||||
|   | ||||
| @@ -5,7 +5,11 @@ from typing import Any, Dict, Optional, Tuple, Union | ||||
| import aiohttp | ||||
|  | ||||
| from .deviceconfig import DeviceConfig | ||||
| from .exceptions import ConnectionException, SmartDeviceException, TimeoutException | ||||
| from .exceptions import ( | ||||
|     ConnectionException, | ||||
|     SmartDeviceException, | ||||
|     TimeoutException, | ||||
| ) | ||||
| from .json import loads as json_loads | ||||
|  | ||||
|  | ||||
| @@ -78,7 +82,7 @@ class HttpClient: | ||||
|  | ||||
|         except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: | ||||
|             raise ConnectionException( | ||||
|                 f"Unable to connect to the device: {self._config.host}: {ex}", ex | ||||
|                 f"Device connection error: {self._config.host}: {ex}", ex | ||||
|             ) from ex | ||||
|         except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as ex: | ||||
|             raise TimeoutException( | ||||
|   | ||||
| @@ -45,32 +45,31 @@ class IotProtocol(BaseProtocol): | ||||
|             try: | ||||
|                 return await self._execute_query(request, retry) | ||||
|             except ConnectionException as sdex: | ||||
|                 await self.close() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise sdex | ||||
|                 continue | ||||
|             except AuthenticationException as auex: | ||||
|                 await self.close() | ||||
|                 await self._transport.reset() | ||||
|                 _LOGGER.debug( | ||||
|                     "Unable to authenticate with %s, not retrying", self._host | ||||
|                 ) | ||||
|                 raise auex | ||||
|             except RetryableException as ex: | ||||
|                 await self.close() | ||||
|                 await self._transport.reset() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise ex | ||||
|                 continue | ||||
|             except TimeoutException as ex: | ||||
|                 await self.close() | ||||
|                 await self._transport.reset() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise ex | ||||
|                 await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) | ||||
|                 continue | ||||
|             except SmartDeviceException as ex: | ||||
|                 await self.close() | ||||
|                 await self._transport.reset() | ||||
|                 _LOGGER.debug( | ||||
|                     "Unable to query the device: %s, not retrying: %s", | ||||
|                     self._host, | ||||
| @@ -85,10 +84,5 @@ class IotProtocol(BaseProtocol): | ||||
|         return await self._transport.send(request) | ||||
|  | ||||
|     async def close(self) -> None: | ||||
|         """Close the underlying transport. | ||||
|  | ||||
|         Some transports may close the connection, and some may | ||||
|         use this as a hint that they need to reconnect, or | ||||
|         reauthenticate. | ||||
|         """ | ||||
|         """Close the underlying transport.""" | ||||
|         await self._transport.close() | ||||
|   | ||||
| @@ -348,7 +348,12 @@ class KlapTransport(BaseTransport): | ||||
|             return json_payload | ||||
|  | ||||
|     async def close(self) -> None: | ||||
|         """Mark the handshake as not done since we likely lost the connection.""" | ||||
|         """Close the http client and reset internal state.""" | ||||
|         await self.reset() | ||||
|         await self._http_client.close() | ||||
|  | ||||
|     async def reset(self) -> None: | ||||
|         """Reset internal handshake state.""" | ||||
|         self._handshake_done = False | ||||
|  | ||||
|     @staticmethod | ||||
|   | ||||
| @@ -80,6 +80,10 @@ class BaseTransport(ABC): | ||||
|     async def close(self) -> None: | ||||
|         """Close the transport.  Abstract method to be overriden.""" | ||||
|  | ||||
|     @abstractmethod | ||||
|     async def reset(self) -> None: | ||||
|         """Reset internal state.""" | ||||
|  | ||||
|  | ||||
| class BaseProtocol(ABC): | ||||
|     """Base class for all TP-Link Smart Home communication.""" | ||||
| @@ -139,7 +143,10 @@ class _XorTransport(BaseTransport): | ||||
|         return {} | ||||
|  | ||||
|     async def close(self) -> None: | ||||
|         """Close the transport.  Abstract method to be overriden.""" | ||||
|         """Close the transport.""" | ||||
|  | ||||
|     async def reset(self) -> None: | ||||
|         """Reset internal state..""" | ||||
|  | ||||
|  | ||||
| class TPLinkSmartHomeProtocol(BaseProtocol): | ||||
| @@ -233,9 +240,9 @@ class TPLinkSmartHomeProtocol(BaseProtocol): | ||||
|         if writer: | ||||
|             writer.close() | ||||
|  | ||||
|     def _reset(self) -> None: | ||||
|         """Clear any varibles that should not survive between loops.""" | ||||
|         self.reader = self.writer = None | ||||
|     async def reset(self) -> None: | ||||
|         """Reset the transport.""" | ||||
|         await self.close() | ||||
|  | ||||
|     async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: | ||||
|         """Try to query a device.""" | ||||
| @@ -252,12 +259,12 @@ class TPLinkSmartHomeProtocol(BaseProtocol): | ||||
|             try: | ||||
|                 await self._connect(timeout) | ||||
|             except ConnectionRefusedError as ex: | ||||
|                 await self.close() | ||||
|                 await self.reset() | ||||
|                 raise SmartDeviceException( | ||||
|                     f"Unable to connect to the device: {self._host}:{self._port}: {ex}" | ||||
|                 ) from ex | ||||
|             except OSError as ex: | ||||
|                 await self.close() | ||||
|                 await self.reset() | ||||
|                 if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count: | ||||
|                     raise SmartDeviceException( | ||||
|                         f"Unable to connect to the device:" | ||||
| @@ -265,7 +272,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol): | ||||
|                     ) from ex | ||||
|                 continue | ||||
|             except Exception as ex: | ||||
|                 await self.close() | ||||
|                 await self.reset() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise SmartDeviceException( | ||||
| @@ -290,7 +297,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol): | ||||
|                 async with asyncio_timeout(timeout): | ||||
|                     return await self._execute_query(request) | ||||
|             except Exception as ex: | ||||
|                 await self.close() | ||||
|                 await self.reset() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise SmartDeviceException( | ||||
| @@ -312,7 +319,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol): | ||||
|                 raise | ||||
|  | ||||
|         # make mypy happy, this should never be reached.. | ||||
|         await self.close() | ||||
|         await self.reset() | ||||
|         raise SmartDeviceException("Query reached somehow to unreachable") | ||||
|  | ||||
|     def __del__(self) -> None: | ||||
| @@ -322,7 +329,6 @@ class TPLinkSmartHomeProtocol(BaseProtocol): | ||||
|             # or in another thread so we need to make sure the call to | ||||
|             # close is called safely with call_soon_threadsafe | ||||
|             self.loop.call_soon_threadsafe(self.writer.close) | ||||
|         self._reset() | ||||
|  | ||||
|     @staticmethod | ||||
|     def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]: | ||||
|   | ||||
| @@ -806,6 +806,10 @@ class SmartDevice: | ||||
|         """Return the device configuration.""" | ||||
|         return self.protocol.config | ||||
|  | ||||
|     async def disconnect(self): | ||||
|         """Disconnect and close any underlying connection resources.""" | ||||
|         await self.protocol.close() | ||||
|  | ||||
|     @staticmethod | ||||
|     async def connect( | ||||
|         *, | ||||
|   | ||||
| @@ -66,32 +66,31 @@ class SmartProtocol(BaseProtocol): | ||||
|             try: | ||||
|                 return await self._execute_query(request, retry) | ||||
|             except ConnectionException as sdex: | ||||
|                 await self.close() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise sdex | ||||
|                 continue | ||||
|             except AuthenticationException as auex: | ||||
|                 await self.close() | ||||
|                 await self._transport.reset() | ||||
|                 _LOGGER.debug( | ||||
|                     "Unable to authenticate with %s, not retrying", self._host | ||||
|                 ) | ||||
|                 raise auex | ||||
|             except RetryableException as ex: | ||||
|                 await self.close() | ||||
|                 await self._transport.reset() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise ex | ||||
|                 continue | ||||
|             except TimeoutException as ex: | ||||
|                 await self.close() | ||||
|                 await self._transport.reset() | ||||
|                 if retry >= retry_count: | ||||
|                     _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) | ||||
|                     raise ex | ||||
|                 await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) | ||||
|                 continue | ||||
|             except SmartDeviceException as ex: | ||||
|                 await self.close() | ||||
|                 await self._transport.reset() | ||||
|                 _LOGGER.debug( | ||||
|                     "Unable to query the device: %s, not retrying: %s", | ||||
|                     self._host, | ||||
| @@ -167,12 +166,7 @@ class SmartProtocol(BaseProtocol): | ||||
|         raise SmartDeviceException(msg, error_code=error_code) | ||||
|  | ||||
|     async def close(self) -> None: | ||||
|         """Close the underlying transport. | ||||
|  | ||||
|         Some transports may close the connection, and some may | ||||
|         use this as a hint that they need to reconnect, or | ||||
|         reauthenticate. | ||||
|         """ | ||||
|         """Close the underlying transport.""" | ||||
|         await self._transport.close() | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -15,6 +15,7 @@ from kasa import ( | ||||
|     Credentials, | ||||
|     Discover, | ||||
|     SmartBulb, | ||||
|     SmartDevice, | ||||
|     SmartDimmer, | ||||
|     SmartLightStrip, | ||||
|     SmartPlug, | ||||
| @@ -416,9 +417,15 @@ async def dev(request): | ||||
|             IP_MODEL_CACHE[ip] = model = d.model | ||||
|         if model not in file: | ||||
|             pytest.skip(f"skipping file {file}") | ||||
|         return d if d else await _discover_update_and_close(ip, username, password) | ||||
|         dev: SmartDevice = ( | ||||
|             d if d else await _discover_update_and_close(ip, username, password) | ||||
|         ) | ||||
|     else: | ||||
|         dev: SmartDevice = await get_device_for_file(file, protocol) | ||||
|  | ||||
|     return await get_device_for_file(file, protocol) | ||||
|     yield dev | ||||
|  | ||||
|     await dev.disconnect() | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
|   | ||||
| @@ -377,6 +377,9 @@ class FakeSmartTransport(BaseTransport): | ||||
|     async def close(self) -> None: | ||||
|         pass | ||||
|  | ||||
|     async def reset(self) -> None: | ||||
|         pass | ||||
|  | ||||
|  | ||||
| class FakeTransportProtocol(TPLinkSmartHomeProtocol): | ||||
|     def __init__(self, info): | ||||
|   | ||||
| @@ -147,6 +147,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count): | ||||
|         await transport.send(json_dumps(request)) | ||||
|         assert transport._login_token == mock_aes_device.token | ||||
|         assert post_mock.call_count == call_count  # Login, Handshake, Login | ||||
|         await transport.close() | ||||
|  | ||||
|  | ||||
| @status_parameters | ||||
|   | ||||
| @@ -69,6 +69,8 @@ async def test_connect( | ||||
|  | ||||
|     assert dev.config == config | ||||
|  | ||||
|     await dev.disconnect() | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("custom_port", [123, None]) | ||||
| async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port): | ||||
|   | ||||
| @@ -19,12 +19,12 @@ from ..httpclient import HttpClient | ||||
|         ( | ||||
|             aiohttp.ServerDisconnectedError(), | ||||
|             ConnectionException, | ||||
|             "Unable to connect to the device: ", | ||||
|             "Device connection error: ", | ||||
|         ), | ||||
|         ( | ||||
|             aiohttp.ClientOSError(), | ||||
|             ConnectionException, | ||||
|             "Unable to connect to the device: ", | ||||
|             "Device connection error: ", | ||||
|         ), | ||||
|         ( | ||||
|             aiohttp.ServerTimeoutError(), | ||||
|   | ||||
| @@ -54,9 +54,10 @@ class _mock_response: | ||||
|     [ | ||||
|         (Exception("dummy exception"), False), | ||||
|         (aiohttp.ServerTimeoutError("dummy exception"), True), | ||||
|         (aiohttp.ServerDisconnectedError("dummy exception"), True), | ||||
|         (aiohttp.ClientOSError("dummy exception"), True), | ||||
|     ], | ||||
|     ids=("Exception", "SmartDeviceException", "ConnectError"), | ||||
|     ids=("Exception", "ServerTimeoutError", "ServerDisconnectedError", "ClientOSError"), | ||||
| ) | ||||
| @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) | ||||
| @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steven B
					Steven B