mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-04-26 16:46:23 +00:00
Update transport close/reset behaviour
This commit is contained in:
parent
e233e377ad
commit
0e874a35f1
@ -306,10 +306,12 @@ class AesTransport(BaseTransport):
|
||||
return await self.send_secure_passthrough(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Mark the handshake and login as not done.
|
||||
"""Close the http client and reset internal state."""
|
||||
await self.reset()
|
||||
await self._http_client.close()
|
||||
|
||||
Since we likely lost the connection.
|
||||
"""
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal handshake and login state."""
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
|
||||
|
@ -42,6 +42,10 @@ class ConnectionException(SmartDeviceException):
|
||||
"""Connection exception for device errors."""
|
||||
|
||||
|
||||
class DisconnectedException(SmartDeviceException):
|
||||
"""Disconnected exception for device errors."""
|
||||
|
||||
|
||||
class SmartErrorCode(IntEnum):
|
||||
"""Enum for SMART Error Codes."""
|
||||
|
||||
|
@ -5,7 +5,12 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
import aiohttp
|
||||
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
|
||||
from .exceptions import (
|
||||
ConnectionException,
|
||||
DisconnectedException,
|
||||
SmartDeviceException,
|
||||
TimeoutException,
|
||||
)
|
||||
from .json import loads as json_loads
|
||||
|
||||
|
||||
@ -76,7 +81,11 @@ class HttpClient:
|
||||
if return_json:
|
||||
response_data = json_loads(response_data.decode())
|
||||
|
||||
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
|
||||
except aiohttp.ServerDisconnectedError as ex:
|
||||
raise DisconnectedException(
|
||||
f"Disconnected from the device: {self._config.host}: {ex}", ex
|
||||
) from ex
|
||||
except aiohttp.ClientOSError as ex:
|
||||
raise ConnectionException(
|
||||
f"Unable to connect to the device: {self._config.host}: {ex}", ex
|
||||
) from ex
|
||||
|
@ -6,6 +6,7 @@ from typing import Dict, Union
|
||||
from .exceptions import (
|
||||
AuthenticationException,
|
||||
ConnectionException,
|
||||
DisconnectedException,
|
||||
RetryableException,
|
||||
SmartDeviceException,
|
||||
TimeoutException,
|
||||
@ -44,33 +45,38 @@ class IotProtocol(BaseProtocol):
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except DisconnectedException as sdex:
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise sdex
|
||||
continue
|
||||
except ConnectionException as sdex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise sdex
|
||||
continue
|
||||
except AuthenticationException as auex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
_LOGGER.debug(
|
||||
"Unable to authenticate with %s, not retrying", self._host
|
||||
)
|
||||
raise auex
|
||||
except RetryableException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
continue
|
||||
except TimeoutException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
|
||||
continue
|
||||
except SmartDeviceException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
_LOGGER.debug(
|
||||
"Unable to query the device: %s, not retrying: %s",
|
||||
self._host,
|
||||
@ -85,10 +91,5 @@ class IotProtocol(BaseProtocol):
|
||||
return await self._transport.send(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying transport.
|
||||
|
||||
Some transports may close the connection, and some may
|
||||
use this as a hint that they need to reconnect, or
|
||||
reauthenticate.
|
||||
"""
|
||||
"""Close the underlying transport."""
|
||||
await self._transport.close()
|
||||
|
@ -348,7 +348,12 @@ class KlapTransport(BaseTransport):
|
||||
return json_payload
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Mark the handshake as not done since we likely lost the connection."""
|
||||
"""Close the http client and reset internal state."""
|
||||
await self.reset()
|
||||
await self._http_client.close()
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal handshake state."""
|
||||
self._handshake_done = False
|
||||
|
||||
@staticmethod
|
||||
|
@ -80,6 +80,10 @@ class BaseTransport(ABC):
|
||||
async def close(self) -> None:
|
||||
"""Close the transport. Abstract method to be overriden."""
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal state."""
|
||||
|
||||
|
||||
class BaseProtocol(ABC):
|
||||
"""Base class for all TP-Link Smart Home communication."""
|
||||
@ -139,7 +143,10 @@ class _XorTransport(BaseTransport):
|
||||
return {}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport. Abstract method to be overriden."""
|
||||
"""Close the transport."""
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal state.."""
|
||||
|
||||
|
||||
class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
@ -233,9 +240,9 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
if writer:
|
||||
writer.close()
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Clear any varibles that should not survive between loops."""
|
||||
self.reader = self.writer = None
|
||||
async def reset(self) -> None:
|
||||
"""Reset the transport."""
|
||||
await self.close()
|
||||
|
||||
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
|
||||
"""Try to query a device."""
|
||||
@ -252,12 +259,12 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
try:
|
||||
await self._connect(timeout)
|
||||
except ConnectionRefusedError as ex:
|
||||
await self.close()
|
||||
await self.reset()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
except OSError as ex:
|
||||
await self.close()
|
||||
await self.reset()
|
||||
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device:"
|
||||
@ -265,7 +272,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
) from ex
|
||||
continue
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
await self.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
@ -290,7 +297,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
async with asyncio_timeout(timeout):
|
||||
return await self._execute_query(request)
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
await self.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
@ -312,7 +319,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
raise
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
await self.close()
|
||||
await self.reset()
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
def __del__(self) -> None:
|
||||
@ -322,7 +329,6 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
# or in another thread so we need to make sure the call to
|
||||
# close is called safely with call_soon_threadsafe
|
||||
self.loop.call_soon_threadsafe(self.writer.close)
|
||||
self._reset()
|
||||
|
||||
@staticmethod
|
||||
def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]:
|
||||
|
@ -806,6 +806,10 @@ class SmartDevice:
|
||||
"""Return the device configuration."""
|
||||
return self.protocol.config
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect and close any underlying connection resources."""
|
||||
await self.protocol.close()
|
||||
|
||||
@staticmethod
|
||||
async def connect(
|
||||
*,
|
||||
|
@ -18,6 +18,7 @@ from .exceptions import (
|
||||
SMART_TIMEOUT_ERRORS,
|
||||
AuthenticationException,
|
||||
ConnectionException,
|
||||
DisconnectedException,
|
||||
RetryableException,
|
||||
SmartDeviceException,
|
||||
SmartErrorCode,
|
||||
@ -65,33 +66,38 @@ class SmartProtocol(BaseProtocol):
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except DisconnectedException as sdex:
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise sdex
|
||||
continue
|
||||
except ConnectionException as sdex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise sdex
|
||||
continue
|
||||
except AuthenticationException as auex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
_LOGGER.debug(
|
||||
"Unable to authenticate with %s, not retrying", self._host
|
||||
)
|
||||
raise auex
|
||||
except RetryableException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
continue
|
||||
except TimeoutException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
|
||||
continue
|
||||
except SmartDeviceException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
_LOGGER.debug(
|
||||
"Unable to query the device: %s, not retrying: %s",
|
||||
self._host,
|
||||
@ -167,12 +173,7 @@ class SmartProtocol(BaseProtocol):
|
||||
raise SmartDeviceException(msg, error_code=error_code)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying transport.
|
||||
|
||||
Some transports may close the connection, and some may
|
||||
use this as a hint that they need to reconnect, or
|
||||
reauthenticate.
|
||||
"""
|
||||
"""Close the underlying transport."""
|
||||
await self._transport.close()
|
||||
|
||||
|
||||
|
@ -15,6 +15,7 @@ from kasa import (
|
||||
Credentials,
|
||||
Discover,
|
||||
SmartBulb,
|
||||
SmartDevice,
|
||||
SmartDimmer,
|
||||
SmartLightStrip,
|
||||
SmartPlug,
|
||||
@ -416,9 +417,15 @@ async def dev(request):
|
||||
IP_MODEL_CACHE[ip] = model = d.model
|
||||
if model not in file:
|
||||
pytest.skip(f"skipping file {file}")
|
||||
return d if d else await _discover_update_and_close(ip, username, password)
|
||||
dev: SmartDevice = (
|
||||
d if d else await _discover_update_and_close(ip, username, password)
|
||||
)
|
||||
else:
|
||||
dev: SmartDevice = await get_device_for_file(file, protocol)
|
||||
|
||||
return await get_device_for_file(file, protocol)
|
||||
yield dev
|
||||
|
||||
await dev.disconnect()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -377,6 +377,9 @@ class FakeSmartTransport(BaseTransport):
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
||||
def __init__(self, info):
|
||||
|
@ -7,6 +7,7 @@ import pytest
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import (
|
||||
ConnectionException,
|
||||
DisconnectedException,
|
||||
SmartDeviceException,
|
||||
TimeoutException,
|
||||
)
|
||||
@ -18,8 +19,8 @@ from ..httpclient import HttpClient
|
||||
[
|
||||
(
|
||||
aiohttp.ServerDisconnectedError(),
|
||||
ConnectionException,
|
||||
"Unable to connect to the device: ",
|
||||
DisconnectedException,
|
||||
"Disconnected from the device: ",
|
||||
),
|
||||
(
|
||||
aiohttp.ClientOSError(),
|
||||
|
@ -54,9 +54,10 @@ class _mock_response:
|
||||
[
|
||||
(Exception("dummy exception"), False),
|
||||
(aiohttp.ServerTimeoutError("dummy exception"), True),
|
||||
(aiohttp.ServerDisconnectedError("dummy exception"), True),
|
||||
(aiohttp.ClientOSError("dummy exception"), True),
|
||||
],
|
||||
ids=("Exception", "SmartDeviceException", "ConnectError"),
|
||||
ids=("Exception", "SmartDeviceException", "DisconnectError", "ConnectError"),
|
||||
)
|
||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||
|
Loading…
x
Reference in New Issue
Block a user