mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Update transport close/reset behaviour (#689)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
e576fcdb46
commit
1788c50146
@ -315,10 +315,12 @@ class AesTransport(BaseTransport):
|
||||
return await self.send_secure_passthrough(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Mark the handshake and login as not done.
|
||||
"""Close the http client and reset internal state."""
|
||||
await self.reset()
|
||||
await self._http_client.close()
|
||||
|
||||
Since we likely lost the connection.
|
||||
"""
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal handshake and login state."""
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
|
||||
|
@ -5,7 +5,11 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
import aiohttp
|
||||
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
|
||||
from .exceptions import (
|
||||
ConnectionException,
|
||||
SmartDeviceException,
|
||||
TimeoutException,
|
||||
)
|
||||
from .json import loads as json_loads
|
||||
|
||||
|
||||
@ -78,7 +82,7 @@ class HttpClient:
|
||||
|
||||
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
|
||||
raise ConnectionException(
|
||||
f"Unable to connect to the device: {self._config.host}: {ex}", ex
|
||||
f"Device connection error: {self._config.host}: {ex}", ex
|
||||
) from ex
|
||||
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as ex:
|
||||
raise TimeoutException(
|
||||
|
@ -45,32 +45,31 @@ class IotProtocol(BaseProtocol):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except ConnectionException as sdex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise sdex
|
||||
continue
|
||||
except AuthenticationException as auex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
_LOGGER.debug(
|
||||
"Unable to authenticate with %s, not retrying", self._host
|
||||
)
|
||||
raise auex
|
||||
except RetryableException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
continue
|
||||
except TimeoutException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
|
||||
continue
|
||||
except SmartDeviceException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
_LOGGER.debug(
|
||||
"Unable to query the device: %s, not retrying: %s",
|
||||
self._host,
|
||||
@ -85,10 +84,5 @@ class IotProtocol(BaseProtocol):
|
||||
return await self._transport.send(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying transport.
|
||||
|
||||
Some transports may close the connection, and some may
|
||||
use this as a hint that they need to reconnect, or
|
||||
reauthenticate.
|
||||
"""
|
||||
"""Close the underlying transport."""
|
||||
await self._transport.close()
|
||||
|
@ -348,7 +348,12 @@ class KlapTransport(BaseTransport):
|
||||
return json_payload
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Mark the handshake as not done since we likely lost the connection."""
|
||||
"""Close the http client and reset internal state."""
|
||||
await self.reset()
|
||||
await self._http_client.close()
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal handshake state."""
|
||||
self._handshake_done = False
|
||||
|
||||
@staticmethod
|
||||
|
@ -80,6 +80,10 @@ class BaseTransport(ABC):
|
||||
async def close(self) -> None:
|
||||
"""Close the transport. Abstract method to be overriden."""
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal state."""
|
||||
|
||||
|
||||
class BaseProtocol(ABC):
|
||||
"""Base class for all TP-Link Smart Home communication."""
|
||||
@ -139,7 +143,10 @@ class _XorTransport(BaseTransport):
|
||||
return {}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport. Abstract method to be overriden."""
|
||||
"""Close the transport."""
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal state.."""
|
||||
|
||||
|
||||
class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
@ -233,9 +240,9 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
if writer:
|
||||
writer.close()
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Clear any varibles that should not survive between loops."""
|
||||
self.reader = self.writer = None
|
||||
async def reset(self) -> None:
|
||||
"""Reset the transport."""
|
||||
await self.close()
|
||||
|
||||
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
|
||||
"""Try to query a device."""
|
||||
@ -252,12 +259,12 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
try:
|
||||
await self._connect(timeout)
|
||||
except ConnectionRefusedError as ex:
|
||||
await self.close()
|
||||
await self.reset()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
except OSError as ex:
|
||||
await self.close()
|
||||
await self.reset()
|
||||
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device:"
|
||||
@ -265,7 +272,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
) from ex
|
||||
continue
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
await self.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
@ -290,7 +297,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
async with asyncio_timeout(timeout):
|
||||
return await self._execute_query(request)
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
await self.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
@ -312,7 +319,7 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
raise
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
await self.close()
|
||||
await self.reset()
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
def __del__(self) -> None:
|
||||
@ -322,7 +329,6 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
|
||||
# or in another thread so we need to make sure the call to
|
||||
# close is called safely with call_soon_threadsafe
|
||||
self.loop.call_soon_threadsafe(self.writer.close)
|
||||
self._reset()
|
||||
|
||||
@staticmethod
|
||||
def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]:
|
||||
|
@ -806,6 +806,10 @@ class SmartDevice:
|
||||
"""Return the device configuration."""
|
||||
return self.protocol.config
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect and close any underlying connection resources."""
|
||||
await self.protocol.close()
|
||||
|
||||
@staticmethod
|
||||
async def connect(
|
||||
*,
|
||||
|
@ -66,32 +66,31 @@ class SmartProtocol(BaseProtocol):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except ConnectionException as sdex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise sdex
|
||||
continue
|
||||
except AuthenticationException as auex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
_LOGGER.debug(
|
||||
"Unable to authenticate with %s, not retrying", self._host
|
||||
)
|
||||
raise auex
|
||||
except RetryableException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
continue
|
||||
except TimeoutException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
|
||||
continue
|
||||
except SmartDeviceException as ex:
|
||||
await self.close()
|
||||
await self._transport.reset()
|
||||
_LOGGER.debug(
|
||||
"Unable to query the device: %s, not retrying: %s",
|
||||
self._host,
|
||||
@ -167,12 +166,7 @@ class SmartProtocol(BaseProtocol):
|
||||
raise SmartDeviceException(msg, error_code=error_code)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying transport.
|
||||
|
||||
Some transports may close the connection, and some may
|
||||
use this as a hint that they need to reconnect, or
|
||||
reauthenticate.
|
||||
"""
|
||||
"""Close the underlying transport."""
|
||||
await self._transport.close()
|
||||
|
||||
|
||||
|
@ -15,6 +15,7 @@ from kasa import (
|
||||
Credentials,
|
||||
Discover,
|
||||
SmartBulb,
|
||||
SmartDevice,
|
||||
SmartDimmer,
|
||||
SmartLightStrip,
|
||||
SmartPlug,
|
||||
@ -416,9 +417,15 @@ async def dev(request):
|
||||
IP_MODEL_CACHE[ip] = model = d.model
|
||||
if model not in file:
|
||||
pytest.skip(f"skipping file {file}")
|
||||
return d if d else await _discover_update_and_close(ip, username, password)
|
||||
dev: SmartDevice = (
|
||||
d if d else await _discover_update_and_close(ip, username, password)
|
||||
)
|
||||
else:
|
||||
dev: SmartDevice = await get_device_for_file(file, protocol)
|
||||
|
||||
return await get_device_for_file(file, protocol)
|
||||
yield dev
|
||||
|
||||
await dev.disconnect()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -377,6 +377,9 @@ class FakeSmartTransport(BaseTransport):
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
||||
def __init__(self, info):
|
||||
|
@ -147,6 +147,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
|
||||
await transport.send(json_dumps(request))
|
||||
assert transport._login_token == mock_aes_device.token
|
||||
assert post_mock.call_count == call_count # Login, Handshake, Login
|
||||
await transport.close()
|
||||
|
||||
|
||||
@status_parameters
|
||||
|
@ -69,6 +69,8 @@ async def test_connect(
|
||||
|
||||
assert dev.config == config
|
||||
|
||||
await dev.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):
|
||||
|
@ -19,12 +19,12 @@ from ..httpclient import HttpClient
|
||||
(
|
||||
aiohttp.ServerDisconnectedError(),
|
||||
ConnectionException,
|
||||
"Unable to connect to the device: ",
|
||||
"Device connection error: ",
|
||||
),
|
||||
(
|
||||
aiohttp.ClientOSError(),
|
||||
ConnectionException,
|
||||
"Unable to connect to the device: ",
|
||||
"Device connection error: ",
|
||||
),
|
||||
(
|
||||
aiohttp.ServerTimeoutError(),
|
||||
|
@ -54,9 +54,10 @@ class _mock_response:
|
||||
[
|
||||
(Exception("dummy exception"), False),
|
||||
(aiohttp.ServerTimeoutError("dummy exception"), True),
|
||||
(aiohttp.ServerDisconnectedError("dummy exception"), True),
|
||||
(aiohttp.ClientOSError("dummy exception"), True),
|
||||
],
|
||||
ids=("Exception", "SmartDeviceException", "ConnectError"),
|
||||
ids=("Exception", "ServerTimeoutError", "ServerDisconnectedError", "ClientOSError"),
|
||||
)
|
||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||
|
Loading…
Reference in New Issue
Block a user