mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-08-06 10:44:04 +00:00
Migrate http client to use aiohttp instead of httpx (#643)
This commit is contained in:
@@ -8,7 +8,7 @@ from .credentials import Credentials
|
||||
from .exceptions import SmartDeviceException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from httpx import AsyncClient
|
||||
from aiohttp import ClientSession
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -151,7 +151,7 @@ class DeviceConfig:
|
||||
|
||||
# compare=False will be excluded from the serialization and object comparison.
|
||||
#: Set a custom http_client for the device to use.
|
||||
http_client: Optional["AsyncClient"] = field(default=None, compare=False)
|
||||
http_client: Optional["ClientSession"] = field(default=None, compare=False)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.connection_type is None:
|
||||
|
@@ -1,15 +1,16 @@
|
||||
"""Module for HttpClientSession class."""
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
import aiohttp
|
||||
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
|
||||
from .json import loads as json_loads
|
||||
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
InnerHttpType = Type[httpx.AsyncClient]
|
||||
def get_cookie_jar() -> aiohttp.CookieJar:
|
||||
"""Return a new cookie jar with the correct options for device communication."""
|
||||
return aiohttp.CookieJar(unsafe=True, quote_cookie=False)
|
||||
|
||||
|
||||
class HttpClient:
|
||||
@@ -17,18 +18,20 @@ class HttpClient:
|
||||
|
||||
def __init__(self, config: DeviceConfig) -> None:
|
||||
self._config = config
|
||||
self._client: httpx.AsyncClient = None
|
||||
self._client: aiohttp.ClientSession = None
|
||||
self._jar = aiohttp.CookieJar(unsafe=True, quote_cookie=False)
|
||||
self._last_url = f"http://{self._config.host}/"
|
||||
|
||||
@property
|
||||
def client(self) -> httpx.AsyncClient:
|
||||
def client(self) -> aiohttp.ClientSession:
|
||||
"""Return the underlying http client."""
|
||||
if self._config.http_client and issubclass(
|
||||
self._config.http_client.__class__, httpx.AsyncClient
|
||||
self._config.http_client.__class__, aiohttp.ClientSession
|
||||
):
|
||||
return self._config.http_client
|
||||
|
||||
if not self._client:
|
||||
self._client = httpx.AsyncClient()
|
||||
self._client = aiohttp.ClientSession(cookie_jar=get_cookie_jar())
|
||||
return self._client
|
||||
|
||||
async def post(
|
||||
@@ -43,12 +46,8 @@ class HttpClient:
|
||||
) -> Tuple[int, Optional[Union[Dict, bytes]]]:
|
||||
"""Send an http post request to the device."""
|
||||
response_data = None
|
||||
cookies = None
|
||||
if cookies_dict:
|
||||
cookies = httpx.Cookies()
|
||||
for name, value in cookies_dict.items():
|
||||
cookies.set(name, value)
|
||||
self.client.cookies.clear()
|
||||
self._last_url = url
|
||||
self.client.cookie_jar.clear()
|
||||
try:
|
||||
resp = await self.client.post(
|
||||
url,
|
||||
@@ -56,14 +55,14 @@ class HttpClient:
|
||||
data=data,
|
||||
json=json,
|
||||
timeout=self._config.timeout,
|
||||
cookies=cookies,
|
||||
cookies=cookies_dict,
|
||||
headers=headers,
|
||||
)
|
||||
except httpx.ConnectError as ex:
|
||||
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
|
||||
raise ConnectionException(
|
||||
f"Unable to connect to the device: {self._config.host}: {ex}"
|
||||
) from ex
|
||||
except httpx.TimeoutException as ex:
|
||||
except aiohttp.ServerTimeoutError as ex:
|
||||
raise TimeoutException(
|
||||
"Unable to query the device, " + f"timed out: {self._config.host}: {ex}"
|
||||
) from ex
|
||||
@@ -72,18 +71,25 @@ class HttpClient:
|
||||
f"Unable to query the device: {self._config.host}: {ex}"
|
||||
) from ex
|
||||
|
||||
if resp.status_code == 200:
|
||||
response_data = resp.json() if json else resp.content
|
||||
async with resp:
|
||||
if resp.status == 200:
|
||||
response_data = await resp.read()
|
||||
if json:
|
||||
response_data = json_loads(response_data.decode())
|
||||
|
||||
return resp.status_code, response_data
|
||||
return resp.status, response_data
|
||||
|
||||
def get_cookie(self, cookie_name: str) -> str:
|
||||
def get_cookie(self, cookie_name: str) -> Optional[str]:
|
||||
"""Return the cookie with cookie_name."""
|
||||
return self._client.cookies.get(cookie_name)
|
||||
if cookie := self.client.cookie_jar.filter_cookies(self._last_url).get(
|
||||
cookie_name
|
||||
):
|
||||
return cookie.value
|
||||
return None
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
"""Close the client."""
|
||||
client = self._client
|
||||
self._client = None
|
||||
if client:
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
|
@@ -5,7 +5,7 @@ from contextlib import nullcontext as does_not_raise
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
|
||||
import httpx
|
||||
import aiohttp
|
||||
import pytest
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
@@ -19,6 +19,7 @@ from ..exceptions import (
|
||||
SmartDeviceException,
|
||||
SmartErrorCode,
|
||||
)
|
||||
from ..httpclient import HttpClient
|
||||
|
||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||
|
||||
@@ -57,7 +58,7 @@ async def test_handshake(
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(
|
||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
@@ -75,7 +76,7 @@ async def test_handshake(
|
||||
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
|
||||
host = "127.0.0.1"
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(
|
||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
@@ -94,7 +95,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
|
||||
async def test_send(mocker, status_code, error_code, inner_error_code, expectation):
|
||||
host = "127.0.0.1"
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(
|
||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
@@ -123,7 +124,7 @@ ERRORS = [e for e in SmartErrorCode if e != 0]
|
||||
async def test_passthrough_errors(mocker, error_code):
|
||||
host = "127.0.0.1"
|
||||
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
transport = AesTransport(config=config)
|
||||
@@ -145,12 +146,18 @@ async def test_passthrough_errors(mocker, error_code):
|
||||
|
||||
class MockAesDevice:
|
||||
class _mock_response:
|
||||
def __init__(self, status_code, json: dict):
|
||||
self.status_code = status_code
|
||||
def __init__(self, status, json: dict):
|
||||
self.status = status
|
||||
self._json = json
|
||||
|
||||
def json(self):
|
||||
return self._json
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_t, exc_v, exc_tb):
|
||||
pass
|
||||
|
||||
async def read(self):
|
||||
return json_dumps(self._json).encode()
|
||||
|
||||
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
|
||||
token = "test_token" # noqa
|
||||
@@ -160,6 +167,7 @@ class MockAesDevice:
|
||||
self.status_code = status_code
|
||||
self.error_code = error_code
|
||||
self.inner_error_code = inner_error_code
|
||||
self.http_client = HttpClient(DeviceConfig(self.host))
|
||||
|
||||
async def post(self, url, params=None, json=None, *_, **__):
|
||||
return await self._post(url, json)
|
||||
@@ -193,7 +201,9 @@ class MockAesDevice:
|
||||
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
|
||||
decrypted_request_dict = json_loads(decrypted_request)
|
||||
decrypted_response = await self._post(url, decrypted_request_dict)
|
||||
decrypted_response_dict = decrypted_response.json()
|
||||
async with decrypted_response:
|
||||
response_data = await decrypted_response.read()
|
||||
decrypted_response_dict = json_loads(response_data.decode())
|
||||
encrypted_response = self.encryption_session.encrypt(
|
||||
json_dumps(decrypted_response_dict).encode()
|
||||
)
|
||||
|
@@ -2,7 +2,7 @@
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
import httpx
|
||||
import aiohttp
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import (
|
||||
@@ -138,7 +138,7 @@ async def test_connect_http_client(all_fixture_data, mocker):
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
http_client = aiohttp.ClientSession()
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
|
||||
import httpx
|
||||
import aiohttp
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.deviceconfig import (
|
||||
@@ -12,8 +12,8 @@ from kasa.deviceconfig import (
|
||||
)
|
||||
|
||||
|
||||
def test_serialization():
|
||||
config = DeviceConfig(host="Foo", http_client=httpx.AsyncClient())
|
||||
async def test_serialization():
|
||||
config = DeviceConfig(host="Foo", http_client=aiohttp.ClientSession())
|
||||
config_dict = config.to_dict()
|
||||
config_json = json_dumps(config_dict)
|
||||
config2_dict = json_loads(config_json)
|
||||
@@ -21,10 +21,10 @@ def test_serialization():
|
||||
assert config == config2
|
||||
|
||||
|
||||
def test_credentials_hash():
|
||||
async def test_credentials_hash():
|
||||
config = DeviceConfig(
|
||||
host="Foo",
|
||||
http_client=httpx.AsyncClient(),
|
||||
http_client=aiohttp.ClientSession(),
|
||||
credentials=Credentials("foo", "bar"),
|
||||
)
|
||||
config_dict = config.to_dict(credentials_hash="credhash")
|
||||
@@ -35,10 +35,10 @@ def test_credentials_hash():
|
||||
assert config2.credentials is None
|
||||
|
||||
|
||||
def test_blank_credentials_hash():
|
||||
async def test_blank_credentials_hash():
|
||||
config = DeviceConfig(
|
||||
host="Foo",
|
||||
http_client=httpx.AsyncClient(),
|
||||
http_client=aiohttp.ClientSession(),
|
||||
credentials=Credentials("foo", "bar"),
|
||||
)
|
||||
config_dict = config.to_dict(credentials_hash="")
|
||||
@@ -49,10 +49,10 @@ def test_blank_credentials_hash():
|
||||
assert config2.credentials is None
|
||||
|
||||
|
||||
def test_exclude_credentials():
|
||||
async def test_exclude_credentials():
|
||||
config = DeviceConfig(
|
||||
host="Foo",
|
||||
http_client=httpx.AsyncClient(),
|
||||
http_client=aiohttp.ClientSession(),
|
||||
credentials=Credentials("foo", "bar"),
|
||||
)
|
||||
config_dict = config.to_dict(exclude_credentials=True)
|
||||
|
@@ -3,7 +3,7 @@ import logging
|
||||
import re
|
||||
import socket
|
||||
|
||||
import httpx
|
||||
import aiohttp
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import (
|
||||
@@ -314,7 +314,7 @@ async def test_discover_single_http_client(discovery_mock, mocker):
|
||||
host = "127.0.0.1"
|
||||
discovery_mock.ip = host
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
http_client = aiohttp.ClientSession()
|
||||
|
||||
x: SmartDevice = await Discover.discover_single(host)
|
||||
|
||||
@@ -331,7 +331,7 @@ async def test_discover_http_client(discovery_mock, mocker):
|
||||
host = "127.0.0.1"
|
||||
discovery_mock.ip = host
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
http_client = aiohttp.ClientSession()
|
||||
|
||||
devices = await Discover.discover(discovery_timeout=0)
|
||||
x: SmartDevice = devices[host]
|
||||
|
@@ -7,7 +7,7 @@ import sys
|
||||
import time
|
||||
from contextlib import nullcontext as does_not_raise
|
||||
|
||||
import httpx
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
@@ -32,19 +32,28 @@ DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||
|
||||
|
||||
class _mock_response:
|
||||
def __init__(self, status_code, content: bytes):
|
||||
self.status_code = status_code
|
||||
def __init__(self, status, content: bytes):
|
||||
self.status = status
|
||||
self.content = content
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_t, exc_v, exc_tb):
|
||||
pass
|
||||
|
||||
async def read(self):
|
||||
return self.content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error, retry_expectation",
|
||||
[
|
||||
(Exception("dummy exception"), False),
|
||||
(httpx.TimeoutException("dummy exception"), True),
|
||||
(httpx.ConnectError("dummy exception"), True),
|
||||
(aiohttp.ServerTimeoutError("dummy exception"), True),
|
||||
(aiohttp.ClientOSError("dummy exception"), True),
|
||||
],
|
||||
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
|
||||
ids=("Exception", "SmartDeviceException", "ConnectError"),
|
||||
)
|
||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||
@@ -53,7 +62,7 @@ async def test_protocol_retries(
|
||||
mocker, retry_count, protocol_class, transport_class, error, retry_expectation
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error)
|
||||
conn = mocker.patch.object(aiohttp.ClientSession, "post", side_effect=error)
|
||||
|
||||
config = DeviceConfig(host)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
@@ -72,7 +81,7 @@ async def test_protocol_no_retry_on_connection_error(
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
conn = mocker.patch.object(
|
||||
httpx.AsyncClient,
|
||||
aiohttp.ClientSession,
|
||||
"post",
|
||||
side_effect=AuthenticationException("foo"),
|
||||
)
|
||||
@@ -92,9 +101,9 @@ async def test_protocol_retry_recoverable_error(
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
conn = mocker.patch.object(
|
||||
httpx.AsyncClient,
|
||||
aiohttp.ClientSession,
|
||||
"post",
|
||||
side_effect=httpx.ConnectError("foo"),
|
||||
side_effect=aiohttp.ClientOSError("foo"),
|
||||
)
|
||||
config = DeviceConfig(host)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
@@ -240,7 +249,7 @@ async def test_handshake1(
|
||||
device_auth_hash = transport_class.generate_auth_hash(device_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
||||
aiohttp.ClientSession, "post", side_effect=_return_handshake1_response
|
||||
)
|
||||
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
@@ -299,12 +308,12 @@ async def test_handshake(
|
||||
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
||||
aiohttp.ClientSession, "post", side_effect=_return_handshake_response
|
||||
)
|
||||
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=transport_class(config=config))
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
protocol._transport.http_client = aiohttp.ClientSession()
|
||||
|
||||
response_status = 200
|
||||
await protocol._transport.perform_handshake()
|
||||
@@ -347,7 +356,7 @@ async def test_query(mocker):
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
|
||||
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=KlapTransport(config=config))
|
||||
@@ -392,7 +401,7 @@ async def test_authentication_failures(mocker, response_status, expectation):
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
|
||||
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=KlapTransport(config=config))
|
||||
|
@@ -8,7 +8,6 @@ import time
|
||||
from contextlib import nullcontext as does_not_raise
|
||||
from itertools import chain
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
|
Reference in New Issue
Block a user