Encapsulate http client dependency (#642)

* Encapsulate http client dependency

* Store cookie dict as variable

* Update post-review
This commit is contained in:
Steven B 2024-01-18 09:57:33 +00:00 committed by GitHub
parent 4623434eb4
commit 3b1b0a3c21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 194 additions and 156 deletions

View File

@ -30,7 +30,6 @@ from dataclasses import asdict, dataclass
from typing import List, Optional, Union from typing import List, Optional, Union
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
class SmartRequest: class SmartRequest:

View File

@ -8,9 +8,8 @@ import base64
import hashlib import hashlib
import logging import logging
import time import time
from typing import Optional from typing import Dict, Optional, cast
import httpx
from cryptography.hazmat.primitives import padding, serialization from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
@ -28,6 +27,7 @@ from .exceptions import (
SmartErrorCode, SmartErrorCode,
TimeoutException, TimeoutException,
) )
from .httpclient import HttpClient
from .json import dumps as json_dumps from .json import dumps as json_dumps
from .json import loads as json_loads from .json import loads as json_loads
from .protocol import BaseTransport from .protocol import BaseTransport
@ -75,14 +75,14 @@ class AesTransport(BaseTransport):
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr] base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
) )
self._default_http_client: Optional[httpx.AsyncClient] = None self._http_client: HttpClient = HttpClient(config)
self._handshake_done = False self._handshake_done = False
self._encryption_session: Optional[AesEncyptionSession] = None self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None self._session_expire_at: Optional[float] = None
self._session_cookie = None self._session_cookie: Optional[Dict[str, str]] = None
self._login_token = None self._login_token = None
@ -98,14 +98,6 @@ class AesTransport(BaseTransport):
"""The hashed credentials used by the transport.""" """The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode() return base64.b64encode(json_dumps(self._login_params).encode()).decode()
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
return self._config.http_client
if not self._default_http_client:
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
def _get_login_params(self): def _get_login_params(self):
"""Get the login parameters based on the login_version.""" """Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2) un, pw = self.hash_credentials(self._login_version == 2)
@ -128,28 +120,6 @@ class AesTransport(BaseTransport):
pw = base64.b64encode(self._credentials.password.encode()).decode() pw = base64.b64encode(self._credentials.password.encode()).decode()
return un, pw return un, pw
async def client_post(self, url, params=None, data=None, json=None, headers=None):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self._session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
self._http_client.cookies.clear()
resp = await self._http_client.post(
url,
params=params,
data=data,
json=json,
timeout=self._timeout,
cookies=cookies,
headers=self.COMMON_HEADERS,
)
if resp.status_code == 200:
response_data = resp.json()
return resp.status_code, response_data
def _handle_response_error_code(self, resp_dict: dict, msg: str): def _handle_response_error_code(self, resp_dict: dict, msg: str):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS: if error_code == SmartErrorCode.SUCCESS:
@ -176,7 +146,12 @@ class AesTransport(BaseTransport):
"method": "securePassthrough", "method": "securePassthrough",
"params": {"request": encrypted_payload.decode()}, "params": {"request": encrypted_payload.decode()},
} }
status_code, resp_dict = await self.client_post(url, json=passthrough_request) status_code, resp_dict = await self._http_client.post(
url,
json=passthrough_request,
headers=self.COMMON_HEADERS,
cookies_dict=self._session_cookie,
)
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}") # _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
if status_code != 200: if status_code != 200:
@ -185,6 +160,7 @@ class AesTransport(BaseTransport):
+ f"status code {status_code} to passthrough" + f"status code {status_code} to passthrough"
) )
resp_dict = cast(Dict, resp_dict)
self._handle_response_error_code( self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message" resp_dict, "Error sending secure_passthrough message"
) )
@ -233,7 +209,12 @@ class AesTransport(BaseTransport):
_LOGGER.debug(f"Request {request_body}") _LOGGER.debug(f"Request {request_body}")
status_code, resp_dict = await self.client_post(url, json=request_body) status_code, resp_dict = await self._http_client.post(
url,
json=request_body,
headers=self.COMMON_HEADERS,
cookies_dict=self._session_cookie,
)
_LOGGER.debug(f"Device responded with: {resp_dict}") _LOGGER.debug(f"Device responded with: {resp_dict}")
@ -247,13 +228,16 @@ class AesTransport(BaseTransport):
handshake_key = resp_dict["result"]["key"] handshake_key = resp_dict["result"]["key"]
self._session_cookie = self._http_client.cookies.get( # type: ignore if (
self.SESSION_COOKIE_NAME cookie := self._http_client.get_cookie( # type: ignore
) self.SESSION_COOKIE_NAME
if not self._session_cookie: )
self._session_cookie = self._http_client.cookies.get( # type: ignore ) or (
cookie := self._http_client.get_cookie( # type: ignore
"SESSIONID" "SESSIONID"
) )
):
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
self._session_expire_at = time.time() + 86400 self._session_expire_at = time.time() + 86400
self._encryption_session = AesEncyptionSession.create_from_keypair( self._encryption_session = AesEncyptionSession.create_from_keypair(
@ -281,13 +265,10 @@ class AesTransport(BaseTransport):
return await self.send_secure_passthrough(request) return await self.send_secure_passthrough(request)
async def close(self) -> None: async def close(self) -> None:
"""Close the protocol.""" """Close the transport."""
client = self._default_http_client
self._default_http_client = None
self._handshake_done = False self._handshake_done = False
self._login_token = None self._login_token = None
if client: await self._http_client.close()
await client.aclose()
class AesEncyptionSession: class AesEncyptionSession:

View File

@ -2,13 +2,14 @@
import logging import logging
from dataclasses import asdict, dataclass, field, fields, is_dataclass from dataclasses import asdict, dataclass, field, fields, is_dataclass
from enum import Enum from enum import Enum
from typing import Dict, Optional, Union from typing import TYPE_CHECKING, Dict, Optional, Union
import httpx
from .credentials import Credentials from .credentials import Credentials
from .exceptions import SmartDeviceException from .exceptions import SmartDeviceException
if TYPE_CHECKING:
from httpx import AsyncClient
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -150,7 +151,7 @@ class DeviceConfig:
# compare=False will be excluded from the serialization and object comparison. # compare=False will be excluded from the serialization and object comparison.
#: Set a custom http_client for the device to use. #: Set a custom http_client for the device to use.
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False) http_client: Optional["AsyncClient"] = field(default=None, compare=False)
def __post_init__(self): def __post_init__(self):
if self.connection_type is None: if self.connection_type is None:

View File

@ -31,6 +31,10 @@ class TimeoutException(SmartDeviceException):
"""Timeout exception for device errors.""" """Timeout exception for device errors."""
class ConnectionException(SmartDeviceException):
"""Connection exception for device errors."""
class SmartErrorCode(IntEnum): class SmartErrorCode(IntEnum):
"""Enum for SMART Error Codes.""" """Enum for SMART Error Codes."""

89
kasa/httpclient.py Normal file
View File

@ -0,0 +1,89 @@
"""Module for HttpClientSession class."""
import logging
from typing import Any, Dict, Optional, Tuple, Type, Union
import httpx
from .deviceconfig import DeviceConfig
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
logging.getLogger("httpx").propagate = False
InnerHttpType = Type[httpx.AsyncClient]
class HttpClient:
"""HttpClient Class."""
def __init__(self, config: DeviceConfig) -> None:
self._config = config
self._client: httpx.AsyncClient = None
@property
def client(self) -> httpx.AsyncClient:
"""Return the underlying http client."""
if self._config.http_client and issubclass(
self._config.http_client.__class__, httpx.AsyncClient
):
return self._config.http_client
if not self._client:
self._client = httpx.AsyncClient()
return self._client
async def post(
self,
url: str,
*,
params: Optional[Dict[str, Any]] = None,
data: Optional[bytes] = None,
json: Optional[Dict] = None,
headers: Optional[Dict[str, str]] = None,
cookies_dict: Optional[Dict[str, str]] = None,
) -> Tuple[int, Optional[Union[Dict, bytes]]]:
"""Send an http post request to the device."""
response_data = None
cookies = None
if cookies_dict:
cookies = httpx.Cookies()
for name, value in cookies_dict.items():
cookies.set(name, value)
self.client.cookies.clear()
try:
resp = await self.client.post(
url,
params=params,
data=data,
json=json,
timeout=self._config.timeout,
cookies=cookies,
headers=headers,
)
except httpx.ConnectError as ex:
raise ConnectionException(
f"Unable to connect to the device: {self._config.host}: {ex}"
) from ex
except httpx.TimeoutException as ex:
raise TimeoutException(
"Unable to query the device, " + f"timed out: {self._config.host}: {ex}"
) from ex
except Exception as ex:
raise SmartDeviceException(
f"Unable to query the device: {self._config.host}: {ex}"
) from ex
if resp.status_code == 200:
response_data = resp.json() if json else resp.content
return resp.status_code, response_data
def get_cookie(self, cookie_name: str) -> str:
"""Return the cookie with cookie_name."""
return self._client.cookies.get(cookie_name)
async def close(self) -> None:
"""Close the protocol."""
client = self._client
self._client = None
if client:
await client.aclose()

View File

@ -3,9 +3,13 @@ import asyncio
import logging import logging
from typing import Dict, Union from typing import Dict, Union
import httpx from .exceptions import (
AuthenticationException,
from .exceptions import AuthenticationException, SmartDeviceException ConnectionException,
RetryableException,
SmartDeviceException,
TimeoutException,
)
from .json import dumps as json_dumps from .json import dumps as json_dumps
from .protocol import BaseTransport, TPLinkProtocol from .protocol import BaseTransport, TPLinkProtocol
@ -15,6 +19,8 @@ _LOGGER = logging.getLogger(__name__)
class IotProtocol(TPLinkProtocol): class IotProtocol(TPLinkProtocol):
"""Class for the legacy TPLink IOT KASA Protocol.""" """Class for the legacy TPLink IOT KASA Protocol."""
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
def __init__( def __init__(
self, self,
*, *,
@ -38,40 +44,39 @@ class IotProtocol(TPLinkProtocol):
for retry in range(retry_count + 1): for retry in range(retry_count + 1):
try: try:
return await self._execute_query(request, retry) return await self._execute_query(request, retry)
except httpx.ConnectError as sdex: except ConnectionException as sdex:
if retry >= retry_count: if retry >= retry_count:
await self.close() await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise sdex
f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex
continue continue
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self._host}: {tex}"
) from tex
except AuthenticationException as auex: except AuthenticationException as auex:
await self.close() await self.close()
_LOGGER.debug( _LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host "Unable to authenticate with %s, not retrying", self._host
) )
raise auex raise auex
except RetryableException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except SmartDeviceException as ex: except SmartDeviceException as ex:
await self.close()
_LOGGER.debug( _LOGGER.debug(
"Unable to connect to the device: %s, not retrying: %s", "Unable to query the device: %s, not retrying: %s",
self._host, self._host,
ex, ex,
) )
raise ex raise ex
except Exception as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached.. # make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable") raise SmartDeviceException("Query reached somehow to unreachable")

View File

@ -48,20 +48,19 @@ import logging
import secrets import secrets
import time import time
from pprint import pformat as pf from pprint import pformat as pf
from typing import Any, Optional, Tuple from typing import Any, Dict, Optional, Tuple, cast
import httpx
from cryptography.hazmat.primitives import hashes, padding from cryptography.hazmat.primitives import hashes, padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials from .credentials import Credentials
from .deviceconfig import DeviceConfig from .deviceconfig import DeviceConfig
from .exceptions import AuthenticationException, SmartDeviceException from .exceptions import AuthenticationException, SmartDeviceException
from .httpclient import HttpClient
from .json import loads as json_loads from .json import loads as json_loads
from .protocol import BaseTransport, md5 from .protocol import BaseTransport, md5
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
def _sha256(payload: bytes) -> bytes: def _sha256(payload: bytes) -> bytes:
@ -98,7 +97,7 @@ class KlapTransport(BaseTransport):
) -> None: ) -> None:
super().__init__(config=config) super().__init__(config=config)
self._default_http_client: Optional[httpx.AsyncClient] = None self._http_client = HttpClient(config)
self._local_seed: Optional[bytes] = None self._local_seed: Optional[bytes] = None
if ( if (
not self._credentials or self._credentials.username is None not self._credentials or self._credentials.username is None
@ -118,7 +117,7 @@ class KlapTransport(BaseTransport):
self._encryption_session: Optional[KlapEncryptionSession] = None self._encryption_session: Optional[KlapEncryptionSession] = None
self._session_expire_at: Optional[float] = None self._session_expire_at: Optional[float] = None
self._session_cookie = None self._session_cookie: Optional[Dict[str, Any]] = None
_LOGGER.debug("Created KLAP transport for %s", self._host) _LOGGER.debug("Created KLAP transport for %s", self._host)
@ -132,34 +131,6 @@ class KlapTransport(BaseTransport):
"""The hashed credentials used by the transport.""" """The hashed credentials used by the transport."""
return base64.b64encode(self._local_auth_hash).decode() return base64.b64encode(self._local_auth_hash).decode()
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
return self._config.http_client
if not self._default_http_client:
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
async def client_post(self, url, params=None, data=None):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self._session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
self._http_client.cookies.clear()
resp = await self._http_client.post(
url,
params=params,
data=data,
timeout=self._timeout,
cookies=cookies,
)
if resp.status_code == 200:
response_data = resp.content
return resp.status_code, response_data
async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]: async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]:
"""Perform handshake1.""" """Perform handshake1."""
local_seed: bytes = secrets.token_bytes(16) local_seed: bytes = secrets.token_bytes(16)
@ -172,7 +143,7 @@ class KlapTransport(BaseTransport):
url = f"http://{self._host}/app/handshake1" url = f"http://{self._host}/app/handshake1"
response_status, response_data = await self.client_post(url, data=payload) response_status, response_data = await self._http_client.post(url, data=payload)
if _LOGGER.isEnabledFor(logging.DEBUG): if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug( _LOGGER.debug(
@ -189,6 +160,7 @@ class KlapTransport(BaseTransport):
f"Device {self._host} responded with {response_status} to handshake1" f"Device {self._host} responded with {response_status} to handshake1"
) )
response_data = cast(bytes, response_data)
remote_seed: bytes = response_data[0:16] remote_seed: bytes = response_data[0:16]
server_hash = response_data[16:] server_hash = response_data[16:]
@ -268,7 +240,11 @@ class KlapTransport(BaseTransport):
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash) payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
response_status, response_data = await self.client_post(url, data=payload) response_status, _ = await self._http_client.post(
url,
data=payload,
cookies_dict=self._session_cookie,
)
if _LOGGER.isEnabledFor(logging.DEBUG): if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug( _LOGGER.debug(
@ -298,9 +274,10 @@ class KlapTransport(BaseTransport):
self._session_cookie = None self._session_cookie = None
local_seed, remote_seed, auth_hash = await self.perform_handshake1() local_seed, remote_seed, auth_hash = await self.perform_handshake1()
self._session_cookie = self._http_client.cookies.get( # type: ignore if cookie := self._http_client.get_cookie( # type: ignore
self.SESSION_COOKIE_NAME self.SESSION_COOKIE_NAME
) ):
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
# The device returns a TIMEOUT cookie on handshake1 which # The device returns a TIMEOUT cookie on handshake1 which
# it doesn't like to get back so we store the one we want # it doesn't like to get back so we store the one we want
@ -330,10 +307,11 @@ class KlapTransport(BaseTransport):
url = f"http://{self._host}/app/request" url = f"http://{self._host}/app/request"
response_status, response_data = await self.client_post( response_status, response_data = await self._http_client.post(
url, url,
params={"seq": seq}, params={"seq": seq},
data=payload, data=payload,
cookies_dict=self._session_cookie,
) )
msg = ( msg = (
@ -374,11 +352,8 @@ class KlapTransport(BaseTransport):
async def close(self) -> None: async def close(self) -> None:
"""Close the transport.""" """Close the transport."""
client = self._default_http_client
self._default_http_client = None
self._handshake_done = False self._handshake_done = False
if client: await self._http_client.close()
await client.aclose()
@staticmethod @staticmethod
def generate_auth_hash(creds: Credentials): def generate_auth_hash(creds: Credentials):

View File

@ -12,13 +12,12 @@ import uuid
from pprint import pformat as pf from pprint import pformat as pf
from typing import Dict, Union from typing import Dict, Union
import httpx
from .exceptions import ( from .exceptions import (
SMART_AUTHENTICATION_ERRORS, SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS, SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS, SMART_TIMEOUT_ERRORS,
AuthenticationException, AuthenticationException,
ConnectionException,
RetryableException, RetryableException,
SmartDeviceException, SmartDeviceException,
SmartErrorCode, SmartErrorCode,
@ -28,13 +27,12 @@ from .json import dumps as json_dumps
from .protocol import BaseTransport, TPLinkProtocol, md5 from .protocol import BaseTransport, TPLinkProtocol, md5
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
class SmartProtocol(TPLinkProtocol): class SmartProtocol(TPLinkProtocol):
"""Class for the new TPLink SMART protocol.""" """Class for the new TPLink SMART protocol."""
SLEEP_SECONDS_AFTER_TIMEOUT = 1 BACKOFF_SECONDS_AFTER_TIMEOUT = 1
def __init__( def __init__(
self, self,
@ -67,22 +65,11 @@ class SmartProtocol(TPLinkProtocol):
for retry in range(retry_count + 1): for retry in range(retry_count + 1):
try: try:
return await self._execute_query(request, retry) return await self._execute_query(request, retry)
except httpx.ConnectError as sdex: except ConnectionException as sdex:
if retry >= retry_count: if retry >= retry_count:
await self.close() await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise sdex
f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex
continue
except TimeoutError as tex:
if retry >= retry_count:
await self.close()
raise SmartDeviceException(
"Unable to connect to the device, "
+ f"timed out: {self._host}: {tex}"
) from tex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue continue
except AuthenticationException as auex: except AuthenticationException as auex:
await self.close() await self.close()
@ -101,24 +88,16 @@ class SmartProtocol(TPLinkProtocol):
await self.close() await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex raise ex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue continue
except SmartDeviceException as ex: except SmartDeviceException as ex:
# Transport would have raised RetryableException if retry makes sense.
await self.close() await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
except Exception as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {ex}"
) from ex
_LOGGER.debug( _LOGGER.debug(
"Unable to query the device %s, retrying: %s", self._host, ex "Unable to query the device: %s, not retrying: %s",
self._host,
ex,
) )
continue raise ex
# make mypy happy, this should never be reached.. # make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable") raise SmartDeviceException("Query reached somehow to unreachable")

View File

@ -145,7 +145,7 @@ async def test_connect_http_client(all_fixture_data, mocker):
) )
dev = await connect(config=config) dev = await connect(config=config)
if ctype.encryption_type != EncryptType.Xor: if ctype.encryption_type != EncryptType.Xor:
assert dev.protocol._transport._http_client != http_client assert dev.protocol._transport._http_client.client != http_client
config = DeviceConfig( config = DeviceConfig(
host=host, host=host,
@ -155,4 +155,4 @@ async def test_connect_http_client(all_fixture_data, mocker):
) )
dev = await connect(config=config) dev = await connect(config=config)
if ctype.encryption_type != EncryptType.Xor: if ctype.encryption_type != EncryptType.Xor:
assert dev.protocol._transport._http_client == http_client assert dev.protocol._transport._http_client.client == http_client

View File

@ -321,9 +321,9 @@ async def test_discover_single_http_client(discovery_mock, mocker):
assert x.config.uses_http == (discovery_mock.default_port == 80) assert x.config.uses_http == (discovery_mock.default_port == 80)
if discovery_mock.default_port == 80: if discovery_mock.default_port == 80:
assert x.protocol._transport._http_client != http_client assert x.protocol._transport._http_client.client != http_client
x.config.http_client = http_client x.config.http_client = http_client
assert x.protocol._transport._http_client == http_client assert x.protocol._transport._http_client.client == http_client
async def test_discover_http_client(discovery_mock, mocker): async def test_discover_http_client(discovery_mock, mocker):
@ -338,6 +338,6 @@ async def test_discover_http_client(discovery_mock, mocker):
assert x.config.uses_http == (discovery_mock.default_port == 80) assert x.config.uses_http == (discovery_mock.default_port == 80)
if discovery_mock.default_port == 80: if discovery_mock.default_port == 80:
assert x.protocol._transport._http_client != http_client assert x.protocol._transport._http_client.client != http_client
x.config.http_client = http_client x.config.http_client = http_client
assert x.protocol._transport._http_client == http_client assert x.protocol._transport._http_client.client == http_client

View File

@ -13,7 +13,12 @@ import pytest
from ..aestransport import AesTransport from ..aestransport import AesTransport
from ..credentials import Credentials from ..credentials import Credentials
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
from ..exceptions import AuthenticationException, SmartDeviceException from ..exceptions import (
AuthenticationException,
ConnectionException,
SmartDeviceException,
)
from ..httpclient import HttpClient
from ..iotprotocol import IotProtocol from ..iotprotocol import IotProtocol
from ..klaptransport import ( from ..klaptransport import (
KlapEncryptionSession, KlapEncryptionSession,
@ -35,8 +40,8 @@ class _mock_response:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"error, retry_expectation", "error, retry_expectation",
[ [
(Exception("dummy exception"), True), (Exception("dummy exception"), False),
(SmartDeviceException("dummy exception"), False), (httpx.TimeoutException("dummy exception"), True),
(httpx.ConnectError("dummy exception"), True), (httpx.ConnectError("dummy exception"), True),
], ],
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"), ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
@ -89,7 +94,7 @@ async def test_protocol_retry_recoverable_error(
conn = mocker.patch.object( conn = mocker.patch.object(
httpx.AsyncClient, httpx.AsyncClient,
"post", "post",
side_effect=httpx.CloseError("foo"), side_effect=httpx.ConnectError("foo"),
) )
config = DeviceConfig(host) config = DeviceConfig(host)
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
@ -112,7 +117,7 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
nonlocal remaining nonlocal remaining
remaining -= 1 remaining -= 1
if remaining: if remaining:
raise Exception("Simulated post failure") raise ConnectionException("Simulated connection failure")
return mock_response return mock_response
@ -155,7 +160,7 @@ async def test_protocol_logging(mocker, caplog, log_level):
protocol._transport._handshake_done = True protocol._transport._handshake_done = True
protocol._transport._session_expire_at = time.time() + 86400 protocol._transport._session_expire_at = time.time() + 86400
protocol._transport._encryption_session = encryption_session protocol._transport._encryption_session = encryption_session
mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted) mocker.patch.object(HttpClient, "post", side_effect=_return_encrypted)
response = await protocol.query({}) response = await protocol.query({})
assert response == {"great": "success"} assert response == {"great": "success"}