mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Encapsulate http client dependency (#642)
* Encapsulate http client dependency * Store cookie dict as variable * Update post-review
This commit is contained in:
parent
4623434eb4
commit
3b1b0a3c21
@ -30,7 +30,6 @@ from dataclasses import asdict, dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
|
||||
class SmartRequest:
|
||||
|
@ -8,9 +8,8 @@ import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional, cast
|
||||
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import padding, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
@ -28,6 +27,7 @@ from .exceptions import (
|
||||
SmartErrorCode,
|
||||
TimeoutException,
|
||||
)
|
||||
from .httpclient import HttpClient
|
||||
from .json import dumps as json_dumps
|
||||
from .json import loads as json_loads
|
||||
from .protocol import BaseTransport
|
||||
@ -75,14 +75,14 @@ class AesTransport(BaseTransport):
|
||||
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
|
||||
)
|
||||
|
||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
||||
self._http_client: HttpClient = HttpClient(config)
|
||||
|
||||
self._handshake_done = False
|
||||
|
||||
self._encryption_session: Optional[AesEncyptionSession] = None
|
||||
self._session_expire_at: Optional[float] = None
|
||||
|
||||
self._session_cookie = None
|
||||
self._session_cookie: Optional[Dict[str, str]] = None
|
||||
|
||||
self._login_token = None
|
||||
|
||||
@ -98,14 +98,6 @@ class AesTransport(BaseTransport):
|
||||
"""The hashed credentials used by the transport."""
|
||||
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
|
||||
|
||||
@property
|
||||
def _http_client(self) -> httpx.AsyncClient:
|
||||
if self._config.http_client:
|
||||
return self._config.http_client
|
||||
if not self._default_http_client:
|
||||
self._default_http_client = httpx.AsyncClient()
|
||||
return self._default_http_client
|
||||
|
||||
def _get_login_params(self):
|
||||
"""Get the login parameters based on the login_version."""
|
||||
un, pw = self.hash_credentials(self._login_version == 2)
|
||||
@ -128,28 +120,6 @@ class AesTransport(BaseTransport):
|
||||
pw = base64.b64encode(self._credentials.password.encode()).decode()
|
||||
return un, pw
|
||||
|
||||
async def client_post(self, url, params=None, data=None, json=None, headers=None):
|
||||
"""Send an http post request to the device."""
|
||||
response_data = None
|
||||
cookies = None
|
||||
if self._session_cookie:
|
||||
cookies = httpx.Cookies()
|
||||
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
|
||||
self._http_client.cookies.clear()
|
||||
resp = await self._http_client.post(
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
json=json,
|
||||
timeout=self._timeout,
|
||||
cookies=cookies,
|
||||
headers=self.COMMON_HEADERS,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
response_data = resp.json()
|
||||
|
||||
return resp.status_code, response_data
|
||||
|
||||
def _handle_response_error_code(self, resp_dict: dict, msg: str):
|
||||
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||
if error_code == SmartErrorCode.SUCCESS:
|
||||
@ -176,7 +146,12 @@ class AesTransport(BaseTransport):
|
||||
"method": "securePassthrough",
|
||||
"params": {"request": encrypted_payload.decode()},
|
||||
}
|
||||
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
||||
status_code, resp_dict = await self._http_client.post(
|
||||
url,
|
||||
json=passthrough_request,
|
||||
headers=self.COMMON_HEADERS,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
|
||||
|
||||
if status_code != 200:
|
||||
@ -185,6 +160,7 @@ class AesTransport(BaseTransport):
|
||||
+ f"status code {status_code} to passthrough"
|
||||
)
|
||||
|
||||
resp_dict = cast(Dict, resp_dict)
|
||||
self._handle_response_error_code(
|
||||
resp_dict, "Error sending secure_passthrough message"
|
||||
)
|
||||
@ -233,7 +209,12 @@ class AesTransport(BaseTransport):
|
||||
|
||||
_LOGGER.debug(f"Request {request_body}")
|
||||
|
||||
status_code, resp_dict = await self.client_post(url, json=request_body)
|
||||
status_code, resp_dict = await self._http_client.post(
|
||||
url,
|
||||
json=request_body,
|
||||
headers=self.COMMON_HEADERS,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
|
||||
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
||||
|
||||
@ -247,13 +228,16 @@ class AesTransport(BaseTransport):
|
||||
|
||||
handshake_key = resp_dict["result"]["key"]
|
||||
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
if not self._session_cookie:
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
if (
|
||||
cookie := self._http_client.get_cookie( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
) or (
|
||||
cookie := self._http_client.get_cookie( # type: ignore
|
||||
"SESSIONID"
|
||||
)
|
||||
):
|
||||
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
||||
|
||||
self._session_expire_at = time.time() + 86400
|
||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||
@ -281,13 +265,10 @@ class AesTransport(BaseTransport):
|
||||
return await self.send_secure_passthrough(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
client = self._default_http_client
|
||||
self._default_http_client = None
|
||||
"""Close the transport."""
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
if client:
|
||||
await client.aclose()
|
||||
await self._http_client.close()
|
||||
|
||||
|
||||
class AesEncyptionSession:
|
||||
|
@ -2,13 +2,14 @@
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Union
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import SmartDeviceException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from httpx import AsyncClient
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -150,7 +151,7 @@ class DeviceConfig:
|
||||
|
||||
# compare=False will be excluded from the serialization and object comparison.
|
||||
#: Set a custom http_client for the device to use.
|
||||
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
|
||||
http_client: Optional["AsyncClient"] = field(default=None, compare=False)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.connection_type is None:
|
||||
|
@ -31,6 +31,10 @@ class TimeoutException(SmartDeviceException):
|
||||
"""Timeout exception for device errors."""
|
||||
|
||||
|
||||
class ConnectionException(SmartDeviceException):
|
||||
"""Connection exception for device errors."""
|
||||
|
||||
|
||||
class SmartErrorCode(IntEnum):
|
||||
"""Enum for SMART Error Codes."""
|
||||
|
||||
|
89
kasa/httpclient.py
Normal file
89
kasa/httpclient.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""Module for HttpClientSession class."""
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
|
||||
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
InnerHttpType = Type[httpx.AsyncClient]
|
||||
|
||||
|
||||
class HttpClient:
|
||||
"""HttpClient Class."""
|
||||
|
||||
def __init__(self, config: DeviceConfig) -> None:
|
||||
self._config = config
|
||||
self._client: httpx.AsyncClient = None
|
||||
|
||||
@property
|
||||
def client(self) -> httpx.AsyncClient:
|
||||
"""Return the underlying http client."""
|
||||
if self._config.http_client and issubclass(
|
||||
self._config.http_client.__class__, httpx.AsyncClient
|
||||
):
|
||||
return self._config.http_client
|
||||
|
||||
if not self._client:
|
||||
self._client = httpx.AsyncClient()
|
||||
return self._client
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[bytes] = None,
|
||||
json: Optional[Dict] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
cookies_dict: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[int, Optional[Union[Dict, bytes]]]:
|
||||
"""Send an http post request to the device."""
|
||||
response_data = None
|
||||
cookies = None
|
||||
if cookies_dict:
|
||||
cookies = httpx.Cookies()
|
||||
for name, value in cookies_dict.items():
|
||||
cookies.set(name, value)
|
||||
self.client.cookies.clear()
|
||||
try:
|
||||
resp = await self.client.post(
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
json=json,
|
||||
timeout=self._config.timeout,
|
||||
cookies=cookies,
|
||||
headers=headers,
|
||||
)
|
||||
except httpx.ConnectError as ex:
|
||||
raise ConnectionException(
|
||||
f"Unable to connect to the device: {self._config.host}: {ex}"
|
||||
) from ex
|
||||
except httpx.TimeoutException as ex:
|
||||
raise TimeoutException(
|
||||
"Unable to query the device, " + f"timed out: {self._config.host}: {ex}"
|
||||
) from ex
|
||||
except Exception as ex:
|
||||
raise SmartDeviceException(
|
||||
f"Unable to query the device: {self._config.host}: {ex}"
|
||||
) from ex
|
||||
|
||||
if resp.status_code == 200:
|
||||
response_data = resp.json() if json else resp.content
|
||||
|
||||
return resp.status_code, response_data
|
||||
|
||||
def get_cookie(self, cookie_name: str) -> str:
|
||||
"""Return the cookie with cookie_name."""
|
||||
return self._client.cookies.get(cookie_name)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
client = self._client
|
||||
self._client = None
|
||||
if client:
|
||||
await client.aclose()
|
@ -3,9 +3,13 @@ import asyncio
|
||||
import logging
|
||||
from typing import Dict, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .exceptions import (
|
||||
AuthenticationException,
|
||||
ConnectionException,
|
||||
RetryableException,
|
||||
SmartDeviceException,
|
||||
TimeoutException,
|
||||
)
|
||||
from .json import dumps as json_dumps
|
||||
from .protocol import BaseTransport, TPLinkProtocol
|
||||
|
||||
@ -15,6 +19,8 @@ _LOGGER = logging.getLogger(__name__)
|
||||
class IotProtocol(TPLinkProtocol):
|
||||
"""Class for the legacy TPLink IOT KASA Protocol."""
|
||||
|
||||
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -38,40 +44,39 @@ class IotProtocol(TPLinkProtocol):
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except httpx.ConnectError as sdex:
|
||||
except ConnectionException as sdex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}: {sdex}"
|
||||
) from sdex
|
||||
raise sdex
|
||||
continue
|
||||
except TimeoutError as tex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device, timed out: {self._host}: {tex}"
|
||||
) from tex
|
||||
except AuthenticationException as auex:
|
||||
await self.close()
|
||||
_LOGGER.debug(
|
||||
"Unable to authenticate with %s, not retrying", self._host
|
||||
)
|
||||
raise auex
|
||||
except RetryableException as ex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
continue
|
||||
except TimeoutException as ex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
|
||||
continue
|
||||
except SmartDeviceException as ex:
|
||||
await self.close()
|
||||
_LOGGER.debug(
|
||||
"Unable to connect to the device: %s, not retrying: %s",
|
||||
"Unable to query the device: %s, not retrying: %s",
|
||||
self._host,
|
||||
ex,
|
||||
)
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
@ -48,20 +48,19 @@ import logging
|
||||
import secrets
|
||||
import time
|
||||
from pprint import pformat as pf
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, cast
|
||||
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import hashes, padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .httpclient import HttpClient
|
||||
from .json import loads as json_loads
|
||||
from .protocol import BaseTransport, md5
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
|
||||
def _sha256(payload: bytes) -> bytes:
|
||||
@ -98,7 +97,7 @@ class KlapTransport(BaseTransport):
|
||||
) -> None:
|
||||
super().__init__(config=config)
|
||||
|
||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
||||
self._http_client = HttpClient(config)
|
||||
self._local_seed: Optional[bytes] = None
|
||||
if (
|
||||
not self._credentials or self._credentials.username is None
|
||||
@ -118,7 +117,7 @@ class KlapTransport(BaseTransport):
|
||||
self._encryption_session: Optional[KlapEncryptionSession] = None
|
||||
self._session_expire_at: Optional[float] = None
|
||||
|
||||
self._session_cookie = None
|
||||
self._session_cookie: Optional[Dict[str, Any]] = None
|
||||
|
||||
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
||||
|
||||
@ -132,34 +131,6 @@ class KlapTransport(BaseTransport):
|
||||
"""The hashed credentials used by the transport."""
|
||||
return base64.b64encode(self._local_auth_hash).decode()
|
||||
|
||||
@property
|
||||
def _http_client(self) -> httpx.AsyncClient:
|
||||
if self._config.http_client:
|
||||
return self._config.http_client
|
||||
if not self._default_http_client:
|
||||
self._default_http_client = httpx.AsyncClient()
|
||||
return self._default_http_client
|
||||
|
||||
async def client_post(self, url, params=None, data=None):
|
||||
"""Send an http post request to the device."""
|
||||
response_data = None
|
||||
cookies = None
|
||||
if self._session_cookie:
|
||||
cookies = httpx.Cookies()
|
||||
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
|
||||
self._http_client.cookies.clear()
|
||||
resp = await self._http_client.post(
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
timeout=self._timeout,
|
||||
cookies=cookies,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
response_data = resp.content
|
||||
|
||||
return resp.status_code, response_data
|
||||
|
||||
async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]:
|
||||
"""Perform handshake1."""
|
||||
local_seed: bytes = secrets.token_bytes(16)
|
||||
@ -172,7 +143,7 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
url = f"http://{self._host}/app/handshake1"
|
||||
|
||||
response_status, response_data = await self.client_post(url, data=payload)
|
||||
response_status, response_data = await self._http_client.post(url, data=payload)
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
@ -189,6 +160,7 @@ class KlapTransport(BaseTransport):
|
||||
f"Device {self._host} responded with {response_status} to handshake1"
|
||||
)
|
||||
|
||||
response_data = cast(bytes, response_data)
|
||||
remote_seed: bytes = response_data[0:16]
|
||||
server_hash = response_data[16:]
|
||||
|
||||
@ -268,7 +240,11 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
||||
|
||||
response_status, response_data = await self.client_post(url, data=payload)
|
||||
response_status, _ = await self._http_client.post(
|
||||
url,
|
||||
data=payload,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
@ -298,9 +274,10 @@ class KlapTransport(BaseTransport):
|
||||
self._session_cookie = None
|
||||
|
||||
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
if cookie := self._http_client.get_cookie( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
):
|
||||
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
||||
# The device returns a TIMEOUT cookie on handshake1 which
|
||||
# it doesn't like to get back so we store the one we want
|
||||
|
||||
@ -330,10 +307,11 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
url = f"http://{self._host}/app/request"
|
||||
|
||||
response_status, response_data = await self.client_post(
|
||||
response_status, response_data = await self._http_client.post(
|
||||
url,
|
||||
params={"seq": seq},
|
||||
data=payload,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
|
||||
msg = (
|
||||
@ -374,11 +352,8 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport."""
|
||||
client = self._default_http_client
|
||||
self._default_http_client = None
|
||||
self._handshake_done = False
|
||||
if client:
|
||||
await client.aclose()
|
||||
await self._http_client.close()
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials):
|
||||
|
@ -12,13 +12,12 @@ import uuid
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from .exceptions import (
|
||||
SMART_AUTHENTICATION_ERRORS,
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
SMART_TIMEOUT_ERRORS,
|
||||
AuthenticationException,
|
||||
ConnectionException,
|
||||
RetryableException,
|
||||
SmartDeviceException,
|
||||
SmartErrorCode,
|
||||
@ -28,13 +27,12 @@ from .json import dumps as json_dumps
|
||||
from .protocol import BaseTransport, TPLinkProtocol, md5
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
|
||||
class SmartProtocol(TPLinkProtocol):
|
||||
"""Class for the new TPLink SMART protocol."""
|
||||
|
||||
SLEEP_SECONDS_AFTER_TIMEOUT = 1
|
||||
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -67,22 +65,11 @@ class SmartProtocol(TPLinkProtocol):
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except httpx.ConnectError as sdex:
|
||||
except ConnectionException as sdex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except TimeoutError as tex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
"Unable to connect to the device, "
|
||||
+ f"timed out: {self._host}: {tex}"
|
||||
) from tex
|
||||
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
||||
raise sdex
|
||||
continue
|
||||
except AuthenticationException as auex:
|
||||
await self.close()
|
||||
@ -101,24 +88,16 @@ class SmartProtocol(TPLinkProtocol):
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
||||
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
|
||||
continue
|
||||
except SmartDeviceException as ex:
|
||||
# Transport would have raised RetryableException if retry makes sense.
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self._host}: {ex}"
|
||||
) from ex
|
||||
_LOGGER.debug(
|
||||
"Unable to query the device %s, retrying: %s", self._host, ex
|
||||
"Unable to query the device: %s, not retrying: %s",
|
||||
self._host,
|
||||
ex,
|
||||
)
|
||||
continue
|
||||
raise ex
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
@ -145,7 +145,7 @@ async def test_connect_http_client(all_fixture_data, mocker):
|
||||
)
|
||||
dev = await connect(config=config)
|
||||
if ctype.encryption_type != EncryptType.Xor:
|
||||
assert dev.protocol._transport._http_client != http_client
|
||||
assert dev.protocol._transport._http_client.client != http_client
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host,
|
||||
@ -155,4 +155,4 @@ async def test_connect_http_client(all_fixture_data, mocker):
|
||||
)
|
||||
dev = await connect(config=config)
|
||||
if ctype.encryption_type != EncryptType.Xor:
|
||||
assert dev.protocol._transport._http_client == http_client
|
||||
assert dev.protocol._transport._http_client.client == http_client
|
||||
|
@ -321,9 +321,9 @@ async def test_discover_single_http_client(discovery_mock, mocker):
|
||||
assert x.config.uses_http == (discovery_mock.default_port == 80)
|
||||
|
||||
if discovery_mock.default_port == 80:
|
||||
assert x.protocol._transport._http_client != http_client
|
||||
assert x.protocol._transport._http_client.client != http_client
|
||||
x.config.http_client = http_client
|
||||
assert x.protocol._transport._http_client == http_client
|
||||
assert x.protocol._transport._http_client.client == http_client
|
||||
|
||||
|
||||
async def test_discover_http_client(discovery_mock, mocker):
|
||||
@ -338,6 +338,6 @@ async def test_discover_http_client(discovery_mock, mocker):
|
||||
assert x.config.uses_http == (discovery_mock.default_port == 80)
|
||||
|
||||
if discovery_mock.default_port == 80:
|
||||
assert x.protocol._transport._http_client != http_client
|
||||
assert x.protocol._transport._http_client.client != http_client
|
||||
x.config.http_client = http_client
|
||||
assert x.protocol._transport._http_client == http_client
|
||||
assert x.protocol._transport._http_client.client == http_client
|
||||
|
@ -13,7 +13,12 @@ import pytest
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import AuthenticationException, SmartDeviceException
|
||||
from ..exceptions import (
|
||||
AuthenticationException,
|
||||
ConnectionException,
|
||||
SmartDeviceException,
|
||||
)
|
||||
from ..httpclient import HttpClient
|
||||
from ..iotprotocol import IotProtocol
|
||||
from ..klaptransport import (
|
||||
KlapEncryptionSession,
|
||||
@ -35,8 +40,8 @@ class _mock_response:
|
||||
@pytest.mark.parametrize(
|
||||
"error, retry_expectation",
|
||||
[
|
||||
(Exception("dummy exception"), True),
|
||||
(SmartDeviceException("dummy exception"), False),
|
||||
(Exception("dummy exception"), False),
|
||||
(httpx.TimeoutException("dummy exception"), True),
|
||||
(httpx.ConnectError("dummy exception"), True),
|
||||
],
|
||||
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
|
||||
@ -89,7 +94,7 @@ async def test_protocol_retry_recoverable_error(
|
||||
conn = mocker.patch.object(
|
||||
httpx.AsyncClient,
|
||||
"post",
|
||||
side_effect=httpx.CloseError("foo"),
|
||||
side_effect=httpx.ConnectError("foo"),
|
||||
)
|
||||
config = DeviceConfig(host)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
@ -112,7 +117,7 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
|
||||
nonlocal remaining
|
||||
remaining -= 1
|
||||
if remaining:
|
||||
raise Exception("Simulated post failure")
|
||||
raise ConnectionException("Simulated connection failure")
|
||||
|
||||
return mock_response
|
||||
|
||||
@ -155,7 +160,7 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
||||
protocol._transport._handshake_done = True
|
||||
protocol._transport._session_expire_at = time.time() + 86400
|
||||
protocol._transport._encryption_session = encryption_session
|
||||
mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted)
|
||||
mocker.patch.object(HttpClient, "post", side_effect=_return_encrypted)
|
||||
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
|
Loading…
Reference in New Issue
Block a user