Update transport close/reset behaviour

This commit is contained in:
sdb9696 2024-01-23 17:11:23 +00:00
parent e233e377ad
commit 0e874a35f1
12 changed files with 87 additions and 43 deletions

View File

@ -306,10 +306,12 @@ class AesTransport(BaseTransport):
return await self.send_secure_passthrough(request)
async def close(self) -> None:
"""Mark the handshake and login as not done.
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()
Since we likely lost the connection.
"""
async def reset(self) -> None:
"""Reset internal handshake and login state."""
self._handshake_done = False
self._login_token = None

View File

@ -42,6 +42,10 @@ class ConnectionException(SmartDeviceException):
"""Connection exception for device errors."""
class DisconnectedException(SmartDeviceException):
"""Disconnected exception for device errors."""
class SmartErrorCode(IntEnum):
"""Enum for SMART Error Codes."""

View File

@ -5,7 +5,12 @@ from typing import Any, Dict, Optional, Tuple, Union
import aiohttp
from .deviceconfig import DeviceConfig
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
from .exceptions import (
ConnectionException,
DisconnectedException,
SmartDeviceException,
TimeoutException,
)
from .json import loads as json_loads
@ -76,7 +81,11 @@ class HttpClient:
if return_json:
response_data = json_loads(response_data.decode())
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
except aiohttp.ServerDisconnectedError as ex:
raise DisconnectedException(
f"Disconnected from the device: {self._config.host}: {ex}", ex
) from ex
except aiohttp.ClientOSError as ex:
raise ConnectionException(
f"Unable to connect to the device: {self._config.host}: {ex}", ex
) from ex

View File

@ -6,6 +6,7 @@ from typing import Dict, Union
from .exceptions import (
AuthenticationException,
ConnectionException,
DisconnectedException,
RetryableException,
SmartDeviceException,
TimeoutException,
@ -44,33 +45,38 @@ class IotProtocol(BaseProtocol):
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except DisconnectedException as sdex:
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex
continue
except ConnectionException as sdex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex
continue
except AuthenticationException as auex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except RetryableException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except SmartDeviceException as ex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to query the device: %s, not retrying: %s",
self._host,
@ -85,10 +91,5 @@ class IotProtocol(BaseProtocol):
return await self._transport.send(request)
async def close(self) -> None:
"""Close the underlying transport.
Some transports may close the connection, and some may
use this as a hint that they need to reconnect, or
reauthenticate.
"""
"""Close the underlying transport."""
await self._transport.close()

View File

@ -348,7 +348,12 @@ class KlapTransport(BaseTransport):
return json_payload
async def close(self) -> None:
"""Mark the handshake as not done since we likely lost the connection."""
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()
async def reset(self) -> None:
"""Reset internal handshake state."""
self._handshake_done = False
@staticmethod

View File

@ -80,6 +80,10 @@ class BaseTransport(ABC):
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
@abstractmethod
async def reset(self) -> None:
"""Reset internal state."""
class BaseProtocol(ABC):
"""Base class for all TP-Link Smart Home communication."""
@ -139,7 +143,10 @@ class _XorTransport(BaseTransport):
return {}
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
"""Close the transport."""
async def reset(self) -> None:
"""Reset internal state.."""
class TPLinkSmartHomeProtocol(BaseProtocol):
@ -233,9 +240,9 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
if writer:
writer.close()
def _reset(self) -> None:
"""Clear any varibles that should not survive between loops."""
self.reader = self.writer = None
async def reset(self) -> None:
"""Reset the transport."""
await self.close()
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device."""
@ -252,12 +259,12 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
try:
await self._connect(timeout)
except ConnectionRefusedError as ex:
await self.close()
await self.reset()
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
) from ex
except OSError as ex:
await self.close()
await self.reset()
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
raise SmartDeviceException(
f"Unable to connect to the device:"
@ -265,7 +272,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
) from ex
continue
except Exception as ex:
await self.close()
await self.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
@ -290,7 +297,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
async with asyncio_timeout(timeout):
return await self._execute_query(request)
except Exception as ex:
await self.close()
await self.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
@ -312,7 +319,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
raise
# make mypy happy, this should never be reached..
await self.close()
await self.reset()
raise SmartDeviceException("Query reached somehow to unreachable")
def __del__(self) -> None:
@ -322,7 +329,6 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
# or in another thread so we need to make sure the call to
# close is called safely with call_soon_threadsafe
self.loop.call_soon_threadsafe(self.writer.close)
self._reset()
@staticmethod
def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]:

View File

@ -806,6 +806,10 @@ class SmartDevice:
"""Return the device configuration."""
return self.protocol.config
async def disconnect(self):
"""Disconnect and close any underlying connection resources."""
await self.protocol.close()
@staticmethod
async def connect(
*,

View File

@ -18,6 +18,7 @@ from .exceptions import (
SMART_TIMEOUT_ERRORS,
AuthenticationException,
ConnectionException,
DisconnectedException,
RetryableException,
SmartDeviceException,
SmartErrorCode,
@ -65,33 +66,38 @@ class SmartProtocol(BaseProtocol):
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except DisconnectedException as sdex:
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex
continue
except ConnectionException as sdex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex
continue
except AuthenticationException as auex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except RetryableException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except SmartDeviceException as ex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to query the device: %s, not retrying: %s",
self._host,
@ -167,12 +173,7 @@ class SmartProtocol(BaseProtocol):
raise SmartDeviceException(msg, error_code=error_code)
async def close(self) -> None:
"""Close the underlying transport.
Some transports may close the connection, and some may
use this as a hint that they need to reconnect, or
reauthenticate.
"""
"""Close the underlying transport."""
await self._transport.close()

View File

@ -15,6 +15,7 @@ from kasa import (
Credentials,
Discover,
SmartBulb,
SmartDevice,
SmartDimmer,
SmartLightStrip,
SmartPlug,
@ -416,9 +417,15 @@ async def dev(request):
IP_MODEL_CACHE[ip] = model = d.model
if model not in file:
pytest.skip(f"skipping file {file}")
return d if d else await _discover_update_and_close(ip, username, password)
dev: SmartDevice = (
d if d else await _discover_update_and_close(ip, username, password)
)
else:
dev: SmartDevice = await get_device_for_file(file, protocol)
return await get_device_for_file(file, protocol)
yield dev
await dev.disconnect()
@pytest.fixture

View File

@ -377,6 +377,9 @@ class FakeSmartTransport(BaseTransport):
async def close(self) -> None:
pass
async def reset(self) -> None:
pass
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info):

View File

@ -7,6 +7,7 @@ import pytest
from ..deviceconfig import DeviceConfig
from ..exceptions import (
ConnectionException,
DisconnectedException,
SmartDeviceException,
TimeoutException,
)
@ -18,8 +19,8 @@ from ..httpclient import HttpClient
[
(
aiohttp.ServerDisconnectedError(),
ConnectionException,
"Unable to connect to the device: ",
DisconnectedException,
"Disconnected from the device: ",
),
(
aiohttp.ClientOSError(),

View File

@ -54,9 +54,10 @@ class _mock_response:
[
(Exception("dummy exception"), False),
(aiohttp.ServerTimeoutError("dummy exception"), True),
(aiohttp.ServerDisconnectedError("dummy exception"), True),
(aiohttp.ClientOSError("dummy exception"), True),
],
ids=("Exception", "SmartDeviceException", "ConnectError"),
ids=("Exception", "SmartDeviceException", "DisconnectError", "ConnectError"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])