Encapsulate http client dependency (#642)

* Encapsulate http client dependency

* Store cookie dict as variable

* Update post-review
This commit is contained in:
Steven B 2024-01-18 09:57:33 +00:00 committed by GitHub
parent 4623434eb4
commit 3b1b0a3c21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 194 additions and 156 deletions

View File

@ -30,7 +30,6 @@ from dataclasses import asdict, dataclass
from typing import List, Optional, Union
_LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
class SmartRequest:

View File

@ -8,9 +8,8 @@ import base64
import hashlib
import logging
import time
from typing import Optional
from typing import Dict, Optional, cast
import httpx
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
@ -28,6 +27,7 @@ from .exceptions import (
SmartErrorCode,
TimeoutException,
)
from .httpclient import HttpClient
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import BaseTransport
@ -75,14 +75,14 @@ class AesTransport(BaseTransport):
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_http_client: Optional[httpx.AsyncClient] = None
self._http_client: HttpClient = HttpClient(config)
self._handshake_done = False
self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None
self._session_cookie = None
self._session_cookie: Optional[Dict[str, str]] = None
self._login_token = None
@ -98,14 +98,6 @@ class AesTransport(BaseTransport):
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
return self._config.http_client
if not self._default_http_client:
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
def _get_login_params(self):
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2)
@ -128,28 +120,6 @@ class AesTransport(BaseTransport):
pw = base64.b64encode(self._credentials.password.encode()).decode()
return un, pw
async def client_post(self, url, params=None, data=None, json=None, headers=None):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self._session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
self._http_client.cookies.clear()
resp = await self._http_client.post(
url,
params=params,
data=data,
json=json,
timeout=self._timeout,
cookies=cookies,
headers=self.COMMON_HEADERS,
)
if resp.status_code == 200:
response_data = resp.json()
return resp.status_code, response_data
def _handle_response_error_code(self, resp_dict: dict, msg: str):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
@ -176,7 +146,12 @@ class AesTransport(BaseTransport):
"method": "securePassthrough",
"params": {"request": encrypted_payload.decode()},
}
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
status_code, resp_dict = await self._http_client.post(
url,
json=passthrough_request,
headers=self.COMMON_HEADERS,
cookies_dict=self._session_cookie,
)
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
if status_code != 200:
@ -185,6 +160,7 @@ class AesTransport(BaseTransport):
+ f"status code {status_code} to passthrough"
)
resp_dict = cast(Dict, resp_dict)
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)
@ -233,7 +209,12 @@ class AesTransport(BaseTransport):
_LOGGER.debug(f"Request {request_body}")
status_code, resp_dict = await self.client_post(url, json=request_body)
status_code, resp_dict = await self._http_client.post(
url,
json=request_body,
headers=self.COMMON_HEADERS,
cookies_dict=self._session_cookie,
)
_LOGGER.debug(f"Device responded with: {resp_dict}")
@ -247,13 +228,16 @@ class AesTransport(BaseTransport):
handshake_key = resp_dict["result"]["key"]
self._session_cookie = self._http_client.cookies.get( # type: ignore
if (
cookie := self._http_client.get_cookie( # type: ignore
self.SESSION_COOKIE_NAME
)
if not self._session_cookie:
self._session_cookie = self._http_client.cookies.get( # type: ignore
) or (
cookie := self._http_client.get_cookie( # type: ignore
"SESSIONID"
)
):
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
self._session_expire_at = time.time() + 86400
self._encryption_session = AesEncyptionSession.create_from_keypair(
@ -281,13 +265,10 @@ class AesTransport(BaseTransport):
return await self.send_secure_passthrough(request)
async def close(self) -> None:
"""Close the protocol."""
client = self._default_http_client
self._default_http_client = None
"""Close the transport."""
self._handshake_done = False
self._login_token = None
if client:
await client.aclose()
await self._http_client.close()
class AesEncyptionSession:

View File

@ -2,13 +2,14 @@
import logging
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from enum import Enum
from typing import Dict, Optional, Union
import httpx
from typing import TYPE_CHECKING, Dict, Optional, Union
from .credentials import Credentials
from .exceptions import SmartDeviceException
if TYPE_CHECKING:
from httpx import AsyncClient
_LOGGER = logging.getLogger(__name__)
@ -150,7 +151,7 @@ class DeviceConfig:
# compare=False will be excluded from the serialization and object comparison.
#: Set a custom http_client for the device to use.
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
http_client: Optional["AsyncClient"] = field(default=None, compare=False)
def __post_init__(self):
if self.connection_type is None:

View File

@ -31,6 +31,10 @@ class TimeoutException(SmartDeviceException):
"""Timeout exception for device errors."""
class ConnectionException(SmartDeviceException):
"""Connection exception for device errors."""
class SmartErrorCode(IntEnum):
"""Enum for SMART Error Codes."""

89
kasa/httpclient.py Normal file
View File

@ -0,0 +1,89 @@
"""Module for HttpClientSession class."""
import logging
from typing import Any, Dict, Optional, Tuple, Type, Union
import httpx
from .deviceconfig import DeviceConfig
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
logging.getLogger("httpx").propagate = False
InnerHttpType = Type[httpx.AsyncClient]
class HttpClient:
"""HttpClient Class."""
def __init__(self, config: DeviceConfig) -> None:
self._config = config
self._client: httpx.AsyncClient = None
@property
def client(self) -> httpx.AsyncClient:
"""Return the underlying http client."""
if self._config.http_client and issubclass(
self._config.http_client.__class__, httpx.AsyncClient
):
return self._config.http_client
if not self._client:
self._client = httpx.AsyncClient()
return self._client
async def post(
self,
url: str,
*,
params: Optional[Dict[str, Any]] = None,
data: Optional[bytes] = None,
json: Optional[Dict] = None,
headers: Optional[Dict[str, str]] = None,
cookies_dict: Optional[Dict[str, str]] = None,
) -> Tuple[int, Optional[Union[Dict, bytes]]]:
"""Send an http post request to the device."""
response_data = None
cookies = None
if cookies_dict:
cookies = httpx.Cookies()
for name, value in cookies_dict.items():
cookies.set(name, value)
self.client.cookies.clear()
try:
resp = await self.client.post(
url,
params=params,
data=data,
json=json,
timeout=self._config.timeout,
cookies=cookies,
headers=headers,
)
except httpx.ConnectError as ex:
raise ConnectionException(
f"Unable to connect to the device: {self._config.host}: {ex}"
) from ex
except httpx.TimeoutException as ex:
raise TimeoutException(
"Unable to query the device, " + f"timed out: {self._config.host}: {ex}"
) from ex
except Exception as ex:
raise SmartDeviceException(
f"Unable to query the device: {self._config.host}: {ex}"
) from ex
if resp.status_code == 200:
response_data = resp.json() if json else resp.content
return resp.status_code, response_data
def get_cookie(self, cookie_name: str) -> str:
"""Return the cookie with cookie_name."""
return self._client.cookies.get(cookie_name)
async def close(self) -> None:
"""Close the protocol."""
client = self._client
self._client = None
if client:
await client.aclose()

View File

@ -3,9 +3,13 @@ import asyncio
import logging
from typing import Dict, Union
import httpx
from .exceptions import AuthenticationException, SmartDeviceException
from .exceptions import (
AuthenticationException,
ConnectionException,
RetryableException,
SmartDeviceException,
TimeoutException,
)
from .json import dumps as json_dumps
from .protocol import BaseTransport, TPLinkProtocol
@ -15,6 +19,8 @@ _LOGGER = logging.getLogger(__name__)
class IotProtocol(TPLinkProtocol):
"""Class for the legacy TPLink IOT KASA Protocol."""
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
def __init__(
self,
*,
@ -38,40 +44,39 @@ class IotProtocol(TPLinkProtocol):
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.ConnectError as sdex:
except ConnectionException as sdex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex
raise sdex
continue
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self._host}: {tex}"
) from tex
except AuthenticationException as auex:
await self.close()
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except RetryableException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except SmartDeviceException as ex:
await self.close()
_LOGGER.debug(
"Unable to connect to the device: %s, not retrying: %s",
"Unable to query the device: %s, not retrying: %s",
self._host,
ex,
)
raise ex
except Exception as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")

View File

@ -48,20 +48,19 @@ import logging
import secrets
import time
from pprint import pformat as pf
from typing import Any, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, cast
import httpx
from cryptography.hazmat.primitives import hashes, padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import AuthenticationException, SmartDeviceException
from .httpclient import HttpClient
from .json import loads as json_loads
from .protocol import BaseTransport, md5
_LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
def _sha256(payload: bytes) -> bytes:
@ -98,7 +97,7 @@ class KlapTransport(BaseTransport):
) -> None:
super().__init__(config=config)
self._default_http_client: Optional[httpx.AsyncClient] = None
self._http_client = HttpClient(config)
self._local_seed: Optional[bytes] = None
if (
not self._credentials or self._credentials.username is None
@ -118,7 +117,7 @@ class KlapTransport(BaseTransport):
self._encryption_session: Optional[KlapEncryptionSession] = None
self._session_expire_at: Optional[float] = None
self._session_cookie = None
self._session_cookie: Optional[Dict[str, Any]] = None
_LOGGER.debug("Created KLAP transport for %s", self._host)
@ -132,34 +131,6 @@ class KlapTransport(BaseTransport):
"""The hashed credentials used by the transport."""
return base64.b64encode(self._local_auth_hash).decode()
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
return self._config.http_client
if not self._default_http_client:
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
async def client_post(self, url, params=None, data=None):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self._session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
self._http_client.cookies.clear()
resp = await self._http_client.post(
url,
params=params,
data=data,
timeout=self._timeout,
cookies=cookies,
)
if resp.status_code == 200:
response_data = resp.content
return resp.status_code, response_data
async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]:
"""Perform handshake1."""
local_seed: bytes = secrets.token_bytes(16)
@ -172,7 +143,7 @@ class KlapTransport(BaseTransport):
url = f"http://{self._host}/app/handshake1"
response_status, response_data = await self.client_post(url, data=payload)
response_status, response_data = await self._http_client.post(url, data=payload)
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
@ -189,6 +160,7 @@ class KlapTransport(BaseTransport):
f"Device {self._host} responded with {response_status} to handshake1"
)
response_data = cast(bytes, response_data)
remote_seed: bytes = response_data[0:16]
server_hash = response_data[16:]
@ -268,7 +240,11 @@ class KlapTransport(BaseTransport):
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
response_status, response_data = await self.client_post(url, data=payload)
response_status, _ = await self._http_client.post(
url,
data=payload,
cookies_dict=self._session_cookie,
)
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
@ -298,9 +274,10 @@ class KlapTransport(BaseTransport):
self._session_cookie = None
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
self._session_cookie = self._http_client.cookies.get( # type: ignore
if cookie := self._http_client.get_cookie( # type: ignore
self.SESSION_COOKIE_NAME
)
):
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
# The device returns a TIMEOUT cookie on handshake1 which
# it doesn't like to get back so we store the one we want
@ -330,10 +307,11 @@ class KlapTransport(BaseTransport):
url = f"http://{self._host}/app/request"
response_status, response_data = await self.client_post(
response_status, response_data = await self._http_client.post(
url,
params={"seq": seq},
data=payload,
cookies_dict=self._session_cookie,
)
msg = (
@ -374,11 +352,8 @@ class KlapTransport(BaseTransport):
async def close(self) -> None:
"""Close the transport."""
client = self._default_http_client
self._default_http_client = None
self._handshake_done = False
if client:
await client.aclose()
await self._http_client.close()
@staticmethod
def generate_auth_hash(creds: Credentials):

View File

@ -12,13 +12,12 @@ import uuid
from pprint import pformat as pf
from typing import Dict, Union
import httpx
from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
AuthenticationException,
ConnectionException,
RetryableException,
SmartDeviceException,
SmartErrorCode,
@ -28,13 +27,12 @@ from .json import dumps as json_dumps
from .protocol import BaseTransport, TPLinkProtocol, md5
_LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
class SmartProtocol(TPLinkProtocol):
"""Class for the new TPLink SMART protocol."""
SLEEP_SECONDS_AFTER_TIMEOUT = 1
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
def __init__(
self,
@ -67,22 +65,11 @@ class SmartProtocol(TPLinkProtocol):
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.ConnectError as sdex:
except ConnectionException as sdex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex
continue
except TimeoutError as tex:
if retry >= retry_count:
await self.close()
raise SmartDeviceException(
"Unable to connect to the device, "
+ f"timed out: {self._host}: {tex}"
) from tex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
raise sdex
continue
except AuthenticationException as auex:
await self.close()
@ -101,24 +88,16 @@ class SmartProtocol(TPLinkProtocol):
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except SmartDeviceException as ex:
# Transport would have raised RetryableException if retry makes sense.
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
except Exception as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}: {ex}"
) from ex
_LOGGER.debug(
"Unable to query the device %s, retrying: %s", self._host, ex
"Unable to query the device: %s, not retrying: %s",
self._host,
ex,
)
continue
raise ex
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")

View File

@ -145,7 +145,7 @@ async def test_connect_http_client(all_fixture_data, mocker):
)
dev = await connect(config=config)
if ctype.encryption_type != EncryptType.Xor:
assert dev.protocol._transport._http_client != http_client
assert dev.protocol._transport._http_client.client != http_client
config = DeviceConfig(
host=host,
@ -155,4 +155,4 @@ async def test_connect_http_client(all_fixture_data, mocker):
)
dev = await connect(config=config)
if ctype.encryption_type != EncryptType.Xor:
assert dev.protocol._transport._http_client == http_client
assert dev.protocol._transport._http_client.client == http_client

View File

@ -321,9 +321,9 @@ async def test_discover_single_http_client(discovery_mock, mocker):
assert x.config.uses_http == (discovery_mock.default_port == 80)
if discovery_mock.default_port == 80:
assert x.protocol._transport._http_client != http_client
assert x.protocol._transport._http_client.client != http_client
x.config.http_client = http_client
assert x.protocol._transport._http_client == http_client
assert x.protocol._transport._http_client.client == http_client
async def test_discover_http_client(discovery_mock, mocker):
@ -338,6 +338,6 @@ async def test_discover_http_client(discovery_mock, mocker):
assert x.config.uses_http == (discovery_mock.default_port == 80)
if discovery_mock.default_port == 80:
assert x.protocol._transport._http_client != http_client
assert x.protocol._transport._http_client.client != http_client
x.config.http_client = http_client
assert x.protocol._transport._http_client == http_client
assert x.protocol._transport._http_client.client == http_client

View File

@ -13,7 +13,12 @@ import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import AuthenticationException, SmartDeviceException
from ..exceptions import (
AuthenticationException,
ConnectionException,
SmartDeviceException,
)
from ..httpclient import HttpClient
from ..iotprotocol import IotProtocol
from ..klaptransport import (
KlapEncryptionSession,
@ -35,8 +40,8 @@ class _mock_response:
@pytest.mark.parametrize(
"error, retry_expectation",
[
(Exception("dummy exception"), True),
(SmartDeviceException("dummy exception"), False),
(Exception("dummy exception"), False),
(httpx.TimeoutException("dummy exception"), True),
(httpx.ConnectError("dummy exception"), True),
],
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
@ -89,7 +94,7 @@ async def test_protocol_retry_recoverable_error(
conn = mocker.patch.object(
httpx.AsyncClient,
"post",
side_effect=httpx.CloseError("foo"),
side_effect=httpx.ConnectError("foo"),
)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException):
@ -112,7 +117,7 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
nonlocal remaining
remaining -= 1
if remaining:
raise Exception("Simulated post failure")
raise ConnectionException("Simulated connection failure")
return mock_response
@ -155,7 +160,7 @@ async def test_protocol_logging(mocker, caplog, log_level):
protocol._transport._handshake_done = True
protocol._transport._session_expire_at = time.time() + 86400
protocol._transport._encryption_session = encryption_session
mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted)
mocker.patch.object(HttpClient, "post", side_effect=_return_encrypted)
response = await protocol.query({})
assert response == {"great": "success"}