Migrate http client to use aiohttp instead of httpx (#643)

This commit is contained in:
Steven B
2024-01-18 17:32:26 +00:00
committed by GitHub
parent 3b1b0a3c21
commit 642e9a1f5b
10 changed files with 488 additions and 119 deletions

View File

@@ -8,7 +8,7 @@ from .credentials import Credentials
from .exceptions import SmartDeviceException
if TYPE_CHECKING:
from httpx import AsyncClient
from aiohttp import ClientSession
_LOGGER = logging.getLogger(__name__)
@@ -151,7 +151,7 @@ class DeviceConfig:
# compare=False will be excluded from the serialization and object comparison.
#: Set a custom http_client for the device to use.
http_client: Optional["AsyncClient"] = field(default=None, compare=False)
http_client: Optional["ClientSession"] = field(default=None, compare=False)
def __post_init__(self):
if self.connection_type is None:

View File

@@ -1,15 +1,16 @@
"""Module for HttpClientSession class."""
import logging
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Dict, Optional, Tuple, Union
import httpx
import aiohttp
from .deviceconfig import DeviceConfig
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
from .json import loads as json_loads
logging.getLogger("httpx").propagate = False
InnerHttpType = Type[httpx.AsyncClient]
def get_cookie_jar() -> aiohttp.CookieJar:
"""Return a new cookie jar with the correct options for device communication."""
return aiohttp.CookieJar(unsafe=True, quote_cookie=False)
class HttpClient:
@@ -17,18 +18,20 @@ class HttpClient:
def __init__(self, config: DeviceConfig) -> None:
self._config = config
self._client: httpx.AsyncClient = None
self._client: aiohttp.ClientSession = None
self._jar = aiohttp.CookieJar(unsafe=True, quote_cookie=False)
self._last_url = f"http://{self._config.host}/"
@property
def client(self) -> httpx.AsyncClient:
def client(self) -> aiohttp.ClientSession:
"""Return the underlying http client."""
if self._config.http_client and issubclass(
self._config.http_client.__class__, httpx.AsyncClient
self._config.http_client.__class__, aiohttp.ClientSession
):
return self._config.http_client
if not self._client:
self._client = httpx.AsyncClient()
self._client = aiohttp.ClientSession(cookie_jar=get_cookie_jar())
return self._client
async def post(
@@ -43,12 +46,8 @@ class HttpClient:
) -> Tuple[int, Optional[Union[Dict, bytes]]]:
"""Send an http post request to the device."""
response_data = None
cookies = None
if cookies_dict:
cookies = httpx.Cookies()
for name, value in cookies_dict.items():
cookies.set(name, value)
self.client.cookies.clear()
self._last_url = url
self.client.cookie_jar.clear()
try:
resp = await self.client.post(
url,
@@ -56,14 +55,14 @@ class HttpClient:
data=data,
json=json,
timeout=self._config.timeout,
cookies=cookies,
cookies=cookies_dict,
headers=headers,
)
except httpx.ConnectError as ex:
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
raise ConnectionException(
f"Unable to connect to the device: {self._config.host}: {ex}"
) from ex
except httpx.TimeoutException as ex:
except aiohttp.ServerTimeoutError as ex:
raise TimeoutException(
"Unable to query the device, " + f"timed out: {self._config.host}: {ex}"
) from ex
@@ -72,18 +71,25 @@ class HttpClient:
f"Unable to query the device: {self._config.host}: {ex}"
) from ex
if resp.status_code == 200:
response_data = resp.json() if json else resp.content
async with resp:
if resp.status == 200:
response_data = await resp.read()
if json:
response_data = json_loads(response_data.decode())
return resp.status_code, response_data
return resp.status, response_data
def get_cookie(self, cookie_name: str) -> str:
def get_cookie(self, cookie_name: str) -> Optional[str]:
"""Return the cookie with cookie_name."""
return self._client.cookies.get(cookie_name)
if cookie := self.client.cookie_jar.filter_cookies(self._last_url).get(
cookie_name
):
return cookie.value
return None
async def close(self) -> None:
"""Close the protocol."""
"""Close the client."""
client = self._client
self._client = None
if client:
await client.aclose()
await client.close()

View File

@@ -5,7 +5,7 @@ from contextlib import nullcontext as does_not_raise
from json import dumps as json_dumps
from json import loads as json_loads
import httpx
import aiohttp
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
@@ -19,6 +19,7 @@ from ..exceptions import (
SmartDeviceException,
SmartErrorCode,
)
from ..httpclient import HttpClient
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
@@ -57,7 +58,7 @@ async def test_handshake(
):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
@@ -75,7 +76,7 @@ async def test_handshake(
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
@@ -94,7 +95,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
async def test_send(mocker, status_code, error_code, inner_error_code, expectation):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
@@ -123,7 +124,7 @@ ERRORS = [e for e in SmartErrorCode if e != 0]
async def test_passthrough_errors(mocker, error_code):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
transport = AesTransport(config=config)
@@ -145,12 +146,18 @@ async def test_passthrough_errors(mocker, error_code):
class MockAesDevice:
class _mock_response:
def __init__(self, status_code, json: dict):
self.status_code = status_code
def __init__(self, status, json: dict):
self.status = status
self._json = json
def json(self):
return self._json
async def __aenter__(self):
return self
async def __aexit__(self, exc_t, exc_v, exc_tb):
pass
async def read(self):
return json_dumps(self._json).encode()
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
token = "test_token" # noqa
@@ -160,6 +167,7 @@ class MockAesDevice:
self.status_code = status_code
self.error_code = error_code
self.inner_error_code = inner_error_code
self.http_client = HttpClient(DeviceConfig(self.host))
async def post(self, url, params=None, json=None, *_, **__):
return await self._post(url, json)
@@ -193,7 +201,9 @@ class MockAesDevice:
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
decrypted_request_dict = json_loads(decrypted_request)
decrypted_response = await self._post(url, decrypted_request_dict)
decrypted_response_dict = decrypted_response.json()
async with decrypted_response:
response_data = await decrypted_response.read()
decrypted_response_dict = json_loads(response_data.decode())
encrypted_response = self.encryption_session.encrypt(
json_dumps(decrypted_response_dict).encode()
)

View File

@@ -2,7 +2,7 @@
import logging
from typing import Type
import httpx
import aiohttp
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
@@ -138,7 +138,7 @@ async def test_connect_http_client(all_fixture_data, mocker):
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
http_client = httpx.AsyncClient()
http_client = aiohttp.ClientSession()
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype

View File

@@ -1,7 +1,7 @@
from json import dumps as json_dumps
from json import loads as json_loads
import httpx
import aiohttp
from kasa.credentials import Credentials
from kasa.deviceconfig import (
@@ -12,8 +12,8 @@ from kasa.deviceconfig import (
)
def test_serialization():
config = DeviceConfig(host="Foo", http_client=httpx.AsyncClient())
async def test_serialization():
config = DeviceConfig(host="Foo", http_client=aiohttp.ClientSession())
config_dict = config.to_dict()
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
@@ -21,10 +21,10 @@ def test_serialization():
assert config == config2
def test_credentials_hash():
async def test_credentials_hash():
config = DeviceConfig(
host="Foo",
http_client=httpx.AsyncClient(),
http_client=aiohttp.ClientSession(),
credentials=Credentials("foo", "bar"),
)
config_dict = config.to_dict(credentials_hash="credhash")
@@ -35,10 +35,10 @@ def test_credentials_hash():
assert config2.credentials is None
def test_blank_credentials_hash():
async def test_blank_credentials_hash():
config = DeviceConfig(
host="Foo",
http_client=httpx.AsyncClient(),
http_client=aiohttp.ClientSession(),
credentials=Credentials("foo", "bar"),
)
config_dict = config.to_dict(credentials_hash="")
@@ -49,10 +49,10 @@ def test_blank_credentials_hash():
assert config2.credentials is None
def test_exclude_credentials():
async def test_exclude_credentials():
config = DeviceConfig(
host="Foo",
http_client=httpx.AsyncClient(),
http_client=aiohttp.ClientSession(),
credentials=Credentials("foo", "bar"),
)
config_dict = config.to_dict(exclude_credentials=True)

View File

@@ -3,7 +3,7 @@ import logging
import re
import socket
import httpx
import aiohttp
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
@@ -314,7 +314,7 @@ async def test_discover_single_http_client(discovery_mock, mocker):
host = "127.0.0.1"
discovery_mock.ip = host
http_client = httpx.AsyncClient()
http_client = aiohttp.ClientSession()
x: SmartDevice = await Discover.discover_single(host)
@@ -331,7 +331,7 @@ async def test_discover_http_client(discovery_mock, mocker):
host = "127.0.0.1"
discovery_mock.ip = host
http_client = httpx.AsyncClient()
http_client = aiohttp.ClientSession()
devices = await Discover.discover(discovery_timeout=0)
x: SmartDevice = devices[host]

View File

@@ -7,7 +7,7 @@ import sys
import time
from contextlib import nullcontext as does_not_raise
import httpx
import aiohttp
import pytest
from ..aestransport import AesTransport
@@ -32,19 +32,28 @@ DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
class _mock_response:
def __init__(self, status_code, content: bytes):
self.status_code = status_code
def __init__(self, status, content: bytes):
self.status = status
self.content = content
async def __aenter__(self):
return self
async def __aexit__(self, exc_t, exc_v, exc_tb):
pass
async def read(self):
return self.content
@pytest.mark.parametrize(
"error, retry_expectation",
[
(Exception("dummy exception"), False),
(httpx.TimeoutException("dummy exception"), True),
(httpx.ConnectError("dummy exception"), True),
(aiohttp.ServerTimeoutError("dummy exception"), True),
(aiohttp.ClientOSError("dummy exception"), True),
],
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
ids=("Exception", "SmartDeviceException", "ConnectError"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@@ -53,7 +62,7 @@ async def test_protocol_retries(
mocker, retry_count, protocol_class, transport_class, error, retry_expectation
):
host = "127.0.0.1"
conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error)
conn = mocker.patch.object(aiohttp.ClientSession, "post", side_effect=error)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException):
@@ -72,7 +81,7 @@ async def test_protocol_no_retry_on_connection_error(
):
host = "127.0.0.1"
conn = mocker.patch.object(
httpx.AsyncClient,
aiohttp.ClientSession,
"post",
side_effect=AuthenticationException("foo"),
)
@@ -92,9 +101,9 @@ async def test_protocol_retry_recoverable_error(
):
host = "127.0.0.1"
conn = mocker.patch.object(
httpx.AsyncClient,
aiohttp.ClientSession,
"post",
side_effect=httpx.ConnectError("foo"),
side_effect=aiohttp.ClientOSError("foo"),
)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException):
@@ -240,7 +249,7 @@ async def test_handshake1(
device_auth_hash = transport_class.generate_auth_hash(device_credentials)
mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
aiohttp.ClientSession, "post", side_effect=_return_handshake1_response
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
@@ -299,12 +308,12 @@ async def test_handshake(
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake_response
aiohttp.ClientSession, "post", side_effect=_return_handshake_response
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=transport_class(config=config))
protocol._transport.http_client = httpx.AsyncClient()
protocol._transport.http_client = aiohttp.ClientSession()
response_status = 200
await protocol._transport.perform_handshake()
@@ -347,7 +356,7 @@ async def test_query(mocker):
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=KlapTransport(config=config))
@@ -392,7 +401,7 @@ async def test_authentication_failures(mocker, response_status, expectation):
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=KlapTransport(config=config))

View File

@@ -8,7 +8,6 @@ import time
from contextlib import nullcontext as does_not_raise
from itertools import chain
import httpx
import pytest
from ..aestransport import AesTransport