mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-08-06 10:44:04 +00:00
Avoid rebuilding urls for every request (#715)
* Avoid rebuilding urls for every request * more fixes * more fixes * make mypy happy * reduce * tweak * fix tests * fix tests * tweak * tweak * lint * fix type
This commit is contained in:
@@ -12,6 +12,7 @@ import aiohttp
|
||||
import pytest
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
from yarl import URL
|
||||
|
||||
from ..aestransport import AesEncyptionSession, AesTransport, TransportState
|
||||
from ..credentials import Credentials
|
||||
@@ -89,10 +90,10 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
|
||||
assert transport._login_token is None
|
||||
assert transport._token_url is None
|
||||
with expectation:
|
||||
await transport.perform_login()
|
||||
assert transport._login_token == mock_aes_device.token
|
||||
assert mock_aes_device.token in str(transport._token_url)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -136,7 +137,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
|
||||
assert transport._login_token is None
|
||||
assert transport._token_url is None
|
||||
|
||||
request = {
|
||||
"method": "get_device_info",
|
||||
@@ -148,7 +149,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
|
||||
|
||||
with expectation:
|
||||
await transport.send(json_dumps(request))
|
||||
assert transport._login_token == mock_aes_device.token
|
||||
assert mock_aes_device.token in str(transport._token_url)
|
||||
assert post_mock.call_count == call_count # Login, Handshake, Login
|
||||
await transport.close()
|
||||
|
||||
@@ -165,7 +166,9 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
transport._login_token = mock_aes_device.token
|
||||
transport._token_url = transport._app_url.with_query(
|
||||
f"token={mock_aes_device.token}"
|
||||
)
|
||||
|
||||
request = {
|
||||
"method": "get_device_info",
|
||||
@@ -193,7 +196,9 @@ async def test_passthrough_errors(mocker, error_code):
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
transport._login_token = mock_aes_device.token
|
||||
transport._token_url = transport._app_url.with_query(
|
||||
f"token={mock_aes_device.token}"
|
||||
)
|
||||
|
||||
request = {
|
||||
"method": "get_device_info",
|
||||
@@ -239,13 +244,13 @@ class MockAesDevice:
|
||||
else:
|
||||
return self._inner_error_code
|
||||
|
||||
async def post(self, url, params=None, json=None, data=None, *_, **__):
|
||||
async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
|
||||
if data:
|
||||
async for item in data:
|
||||
json = json_loads(item.decode())
|
||||
return await self._post(url, json)
|
||||
|
||||
async def _post(self, url: str, json: Dict[str, Any]):
|
||||
async def _post(self, url: URL, json: Dict[str, Any]):
|
||||
if json["method"] == "handshake":
|
||||
return await self._return_handshake_response(url, json)
|
||||
elif json["method"] == "securePassthrough":
|
||||
@@ -253,10 +258,10 @@ class MockAesDevice:
|
||||
elif json["method"] == "login_device":
|
||||
return await self._return_login_response(url, json)
|
||||
else:
|
||||
assert url == f"http://{self.host}/app?token={self.token}"
|
||||
assert str(url) == f"http://{self.host}/app?token={self.token}"
|
||||
return await self._return_send_response(url, json)
|
||||
|
||||
async def _return_handshake_response(self, url: str, json: Dict[str, Any]):
|
||||
async def _return_handshake_response(self, url: URL, json: Dict[str, Any]):
|
||||
start = len("-----BEGIN PUBLIC KEY-----\n")
|
||||
end = len("\n-----END PUBLIC KEY-----\n")
|
||||
client_pub_key = json["params"]["key"][start:-end]
|
||||
@@ -269,7 +274,7 @@ class MockAesDevice:
|
||||
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
|
||||
)
|
||||
|
||||
async def _return_secure_passthrough_response(self, url: str, json: Dict[str, Any]):
|
||||
async def _return_secure_passthrough_response(self, url: URL, json: Dict[str, Any]):
|
||||
encrypted_request = json["params"]["request"]
|
||||
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
|
||||
decrypted_request_dict = json_loads(decrypted_request)
|
||||
@@ -286,15 +291,15 @@ class MockAesDevice:
|
||||
}
|
||||
return self._mock_response(self.status_code, result)
|
||||
|
||||
async def _return_login_response(self, url: str, json: Dict[str, Any]):
|
||||
if "token=" in url:
|
||||
async def _return_login_response(self, url: URL, json: Dict[str, Any]):
|
||||
if "token=" in str(url):
|
||||
raise Exception("token should not be in url for a login request")
|
||||
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
|
||||
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
|
||||
self.inner_call_count += 1
|
||||
return self._mock_response(self.status_code, result)
|
||||
|
||||
async def _return_send_response(self, url: str, json: Dict[str, Any]):
|
||||
async def _return_send_response(self, url: URL, json: Dict[str, Any]):
|
||||
result = {"result": {"method": None}, "error_code": self.inner_error_code}
|
||||
self.inner_call_count += 1
|
||||
return self._mock_response(self.status_code, result)
|
||||
|
@@ -84,7 +84,7 @@ async def test_httpclient_errors(mocker, error, error_raises, error_message, moc
|
||||
client = HttpClient(DeviceConfig(host))
|
||||
# Exceptions with parameters print with double quotes, without use single quotes
|
||||
full_msg = (
|
||||
"\("
|
||||
"\(" # type: ignore
|
||||
+ "['\"]"
|
||||
+ re.escape(f"{error_message}{host}: {error}")
|
||||
+ "['\"]"
|
||||
|
@@ -10,6 +10,7 @@ from unittest.mock import PropertyMock
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from yarl import URL
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
@@ -318,28 +319,28 @@ async def test_handshake1(
|
||||
async def test_handshake(
|
||||
mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
|
||||
):
|
||||
async def _return_handshake_response(url, params=None, data=None, *_, **__):
|
||||
client_seed = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
|
||||
|
||||
async def _return_handshake_response(url: URL, params=None, data=None, *_, **__):
|
||||
nonlocal client_seed, server_seed, device_auth_hash
|
||||
|
||||
if url == "http://127.0.0.1/app/handshake1":
|
||||
if str(url) == "http://127.0.0.1/app/handshake1":
|
||||
client_seed = data
|
||||
seed_auth_hash = _sha256(
|
||||
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
|
||||
)
|
||||
|
||||
return _mock_response(200, server_seed + seed_auth_hash)
|
||||
elif url == "http://127.0.0.1/app/handshake2":
|
||||
elif str(url) == "http://127.0.0.1/app/handshake2":
|
||||
seed_auth_hash = _sha256(
|
||||
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
|
||||
)
|
||||
assert data == seed_auth_hash
|
||||
return _mock_response(response_status, b"")
|
||||
|
||||
client_seed = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
aiohttp.ClientSession, "post", side_effect=_return_handshake_response
|
||||
)
|
||||
@@ -360,17 +361,24 @@ async def test_handshake(
|
||||
|
||||
|
||||
async def test_query(mocker):
|
||||
async def _return_response(url, params=None, data=None, *_, **__):
|
||||
client_seed = None
|
||||
last_seq = None
|
||||
seq = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
|
||||
async def _return_response(url: URL, params=None, data=None, *_, **__):
|
||||
nonlocal client_seed, server_seed, device_auth_hash, seq
|
||||
|
||||
if url == "http://127.0.0.1/app/handshake1":
|
||||
if str(url) == "http://127.0.0.1/app/handshake1":
|
||||
client_seed = data
|
||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||
|
||||
return _mock_response(200, server_seed + client_seed_auth_hash)
|
||||
elif url == "http://127.0.0.1/app/handshake2":
|
||||
elif str(url) == "http://127.0.0.1/app/handshake2":
|
||||
return _mock_response(200, b"")
|
||||
elif url == "http://127.0.0.1/app/request":
|
||||
elif str(url) == "http://127.0.0.1/app/request":
|
||||
encryption_session = KlapEncryptionSession(
|
||||
protocol._transport._encryption_session.local_seed,
|
||||
protocol._transport._encryption_session.remote_seed,
|
||||
@@ -382,13 +390,6 @@ async def test_query(mocker):
|
||||
seq = seq
|
||||
return _mock_response(200, encrypted)
|
||||
|
||||
client_seed = None
|
||||
last_seq = None
|
||||
seq = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
|
||||
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
@@ -413,26 +414,26 @@ async def test_query(mocker):
|
||||
ids=("handshake1", "handshake2", "request", "non_auth_error"),
|
||||
)
|
||||
async def test_authentication_failures(mocker, response_status, expectation):
|
||||
async def _return_response(url, params=None, data=None, *_, **__):
|
||||
client_seed = None
|
||||
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
|
||||
async def _return_response(url: URL, params=None, data=None, *_, **__):
|
||||
nonlocal client_seed, server_seed, device_auth_hash, response_status
|
||||
|
||||
if url == "http://127.0.0.1/app/handshake1":
|
||||
if str(url) == "http://127.0.0.1/app/handshake1":
|
||||
client_seed = data
|
||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||
|
||||
return _mock_response(
|
||||
response_status[0], server_seed + client_seed_auth_hash
|
||||
)
|
||||
elif url == "http://127.0.0.1/app/handshake2":
|
||||
elif str(url) == "http://127.0.0.1/app/handshake2":
|
||||
return _mock_response(response_status[1], b"")
|
||||
elif url == "http://127.0.0.1/app/request":
|
||||
return _mock_response(response_status[2], None)
|
||||
|
||||
client_seed = None
|
||||
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
elif str(url) == "http://127.0.0.1/app/request":
|
||||
return _mock_response(response_status[2], b"")
|
||||
|
||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
|
||||
|
||||
|
Reference in New Issue
Block a user