Update transport close/reset behaviour (#689)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Steven B 2024-01-23 22:15:18 +00:00 committed by GitHub
parent e576fcdb46
commit 1788c50146
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 66 additions and 43 deletions

View File

@ -315,10 +315,12 @@ class AesTransport(BaseTransport):
return await self.send_secure_passthrough(request)
async def close(self) -> None:
"""Mark the handshake and login as not done.
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()
Since we likely lost the connection.
"""
async def reset(self) -> None:
"""Reset internal handshake and login state."""
self._handshake_done = False
self._login_token = None

View File

@ -5,7 +5,11 @@ from typing import Any, Dict, Optional, Tuple, Union
import aiohttp
from .deviceconfig import DeviceConfig
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
from .exceptions import (
ConnectionException,
SmartDeviceException,
TimeoutException,
)
from .json import loads as json_loads
@ -78,7 +82,7 @@ class HttpClient:
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
raise ConnectionException(
f"Unable to connect to the device: {self._config.host}: {ex}", ex
f"Device connection error: {self._config.host}: {ex}", ex
) from ex
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as ex:
raise TimeoutException(

View File

@ -45,32 +45,31 @@ class IotProtocol(BaseProtocol):
try:
return await self._execute_query(request, retry)
except ConnectionException as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex
continue
except AuthenticationException as auex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except RetryableException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except SmartDeviceException as ex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to query the device: %s, not retrying: %s",
self._host,
@ -85,10 +84,5 @@ class IotProtocol(BaseProtocol):
return await self._transport.send(request)
async def close(self) -> None:
"""Close the underlying transport.
Some transports may close the connection, and some may
use this as a hint that they need to reconnect, or
reauthenticate.
"""
"""Close the underlying transport."""
await self._transport.close()

View File

@ -348,7 +348,12 @@ class KlapTransport(BaseTransport):
return json_payload
async def close(self) -> None:
"""Mark the handshake as not done since we likely lost the connection."""
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()
async def reset(self) -> None:
"""Reset internal handshake state."""
self._handshake_done = False
@staticmethod

View File

@ -80,6 +80,10 @@ class BaseTransport(ABC):
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
@abstractmethod
async def reset(self) -> None:
"""Reset internal state."""
class BaseProtocol(ABC):
"""Base class for all TP-Link Smart Home communication."""
@ -139,7 +143,10 @@ class _XorTransport(BaseTransport):
return {}
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
"""Close the transport."""
async def reset(self) -> None:
"""Reset internal state.."""
class TPLinkSmartHomeProtocol(BaseProtocol):
@ -233,9 +240,9 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
if writer:
writer.close()
def _reset(self) -> None:
"""Clear any varibles that should not survive between loops."""
self.reader = self.writer = None
async def reset(self) -> None:
"""Reset the transport."""
await self.close()
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device."""
@ -252,12 +259,12 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
try:
await self._connect(timeout)
except ConnectionRefusedError as ex:
await self.close()
await self.reset()
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
) from ex
except OSError as ex:
await self.close()
await self.reset()
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
raise SmartDeviceException(
f"Unable to connect to the device:"
@ -265,7 +272,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
) from ex
continue
except Exception as ex:
await self.close()
await self.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
@ -290,7 +297,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
async with asyncio_timeout(timeout):
return await self._execute_query(request)
except Exception as ex:
await self.close()
await self.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
@ -312,7 +319,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
raise
# make mypy happy, this should never be reached..
await self.close()
await self.reset()
raise SmartDeviceException("Query reached somehow to unreachable")
def __del__(self) -> None:
@ -322,7 +329,6 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
# or in another thread so we need to make sure the call to
# close is called safely with call_soon_threadsafe
self.loop.call_soon_threadsafe(self.writer.close)
self._reset()
@staticmethod
def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]:

View File

@ -806,6 +806,10 @@ class SmartDevice:
"""Return the device configuration."""
return self.protocol.config
async def disconnect(self):
"""Disconnect and close any underlying connection resources."""
await self.protocol.close()
@staticmethod
async def connect(
*,

View File

@ -66,32 +66,31 @@ class SmartProtocol(BaseProtocol):
try:
return await self._execute_query(request, retry)
except ConnectionException as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex
continue
except AuthenticationException as auex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except RetryableException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except SmartDeviceException as ex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to query the device: %s, not retrying: %s",
self._host,
@ -167,12 +166,7 @@ class SmartProtocol(BaseProtocol):
raise SmartDeviceException(msg, error_code=error_code)
async def close(self) -> None:
"""Close the underlying transport.
Some transports may close the connection, and some may
use this as a hint that they need to reconnect, or
reauthenticate.
"""
"""Close the underlying transport."""
await self._transport.close()

View File

@ -15,6 +15,7 @@ from kasa import (
Credentials,
Discover,
SmartBulb,
SmartDevice,
SmartDimmer,
SmartLightStrip,
SmartPlug,
@ -416,9 +417,15 @@ async def dev(request):
IP_MODEL_CACHE[ip] = model = d.model
if model not in file:
pytest.skip(f"skipping file {file}")
return d if d else await _discover_update_and_close(ip, username, password)
dev: SmartDevice = (
d if d else await _discover_update_and_close(ip, username, password)
)
else:
dev: SmartDevice = await get_device_for_file(file, protocol)
return await get_device_for_file(file, protocol)
yield dev
await dev.disconnect()
@pytest.fixture

View File

@ -377,6 +377,9 @@ class FakeSmartTransport(BaseTransport):
async def close(self) -> None:
pass
async def reset(self) -> None:
pass
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info):

View File

@ -147,6 +147,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
await transport.send(json_dumps(request))
assert transport._login_token == mock_aes_device.token
assert post_mock.call_count == call_count # Login, Handshake, Login
await transport.close()
@status_parameters

View File

@ -69,6 +69,8 @@ async def test_connect(
assert dev.config == config
await dev.disconnect()
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):

View File

@ -19,12 +19,12 @@ from ..httpclient import HttpClient
(
aiohttp.ServerDisconnectedError(),
ConnectionException,
"Unable to connect to the device: ",
"Device connection error: ",
),
(
aiohttp.ClientOSError(),
ConnectionException,
"Unable to connect to the device: ",
"Device connection error: ",
),
(
aiohttp.ServerTimeoutError(),

View File

@ -54,9 +54,10 @@ class _mock_response:
[
(Exception("dummy exception"), False),
(aiohttp.ServerTimeoutError("dummy exception"), True),
(aiohttp.ServerDisconnectedError("dummy exception"), True),
(aiohttp.ClientOSError("dummy exception"), True),
],
ids=("Exception", "SmartDeviceException", "ConnectError"),
ids=("Exception", "ServerTimeoutError", "ServerDisconnectedError", "ClientOSError"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])