Avoid rebuilding urls for every request (#715)

* Avoid rebuilding urls for every request

* more fixes

* more fixes

* make mypy happy

* reduce

* tweak

* fix tests

* fix tests

* tweak

* tweak

* lint

* fix type
This commit is contained in:
J. Nick Koston
2024-01-29 05:26:00 -10:00
committed by GitHub
parent 69dcc0d8bb
commit b479b6d84d
7 changed files with 73 additions and 62 deletions

View File

@@ -12,6 +12,7 @@ import aiohttp
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from yarl import URL
from ..aestransport import AesEncyptionSession, AesTransport, TransportState
from ..credentials import Credentials
@@ -89,10 +90,10 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
assert transport._login_token is None
assert transport._token_url is None
with expectation:
await transport.perform_login()
assert transport._login_token == mock_aes_device.token
assert mock_aes_device.token in str(transport._token_url)
@pytest.mark.parametrize(
@@ -136,7 +137,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
assert transport._login_token is None
assert transport._token_url is None
request = {
"method": "get_device_info",
@@ -148,7 +149,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
with expectation:
await transport.send(json_dumps(request))
assert transport._login_token == mock_aes_device.token
assert mock_aes_device.token in str(transport._token_url)
assert post_mock.call_count == call_count # Login, Handshake, Login
await transport.close()
@@ -165,7 +166,9 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
@@ -193,7 +196,9 @@ async def test_passthrough_errors(mocker, error_code):
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
@@ -239,13 +244,13 @@ class MockAesDevice:
else:
return self._inner_error_code
async def post(self, url, params=None, json=None, data=None, *_, **__):
async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
if data:
async for item in data:
json = json_loads(item.decode())
return await self._post(url, json)
async def _post(self, url: str, json: Dict[str, Any]):
async def _post(self, url: URL, json: Dict[str, Any]):
if json["method"] == "handshake":
return await self._return_handshake_response(url, json)
elif json["method"] == "securePassthrough":
@@ -253,10 +258,10 @@ class MockAesDevice:
elif json["method"] == "login_device":
return await self._return_login_response(url, json)
else:
assert url == f"http://{self.host}/app?token={self.token}"
assert str(url) == f"http://{self.host}/app?token={self.token}"
return await self._return_send_response(url, json)
async def _return_handshake_response(self, url: str, json: Dict[str, Any]):
async def _return_handshake_response(self, url: URL, json: Dict[str, Any]):
start = len("-----BEGIN PUBLIC KEY-----\n")
end = len("\n-----END PUBLIC KEY-----\n")
client_pub_key = json["params"]["key"][start:-end]
@@ -269,7 +274,7 @@ class MockAesDevice:
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
)
async def _return_secure_passthrough_response(self, url: str, json: Dict[str, Any]):
async def _return_secure_passthrough_response(self, url: URL, json: Dict[str, Any]):
encrypted_request = json["params"]["request"]
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
decrypted_request_dict = json_loads(decrypted_request)
@@ -286,15 +291,15 @@ class MockAesDevice:
}
return self._mock_response(self.status_code, result)
async def _return_login_response(self, url: str, json: Dict[str, Any]):
if "token=" in url:
async def _return_login_response(self, url: URL, json: Dict[str, Any]):
if "token=" in str(url):
raise Exception("token should not be in url for a login request")
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
self.inner_call_count += 1
return self._mock_response(self.status_code, result)
async def _return_send_response(self, url: str, json: Dict[str, Any]):
async def _return_send_response(self, url: URL, json: Dict[str, Any]):
result = {"result": {"method": None}, "error_code": self.inner_error_code}
self.inner_call_count += 1
return self._mock_response(self.status_code, result)