mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-08 22:07:06 +00:00
Avoid rebuilding urls for every request (#715)
* Avoid rebuilding urls for every request * more fixes * more fixes * make mypy happy * reduce * tweak * fix tests * fix tests * tweak * tweak * lint * fix type
This commit is contained in:
parent
69dcc0d8bb
commit
b479b6d84d
@ -15,6 +15,7 @@ from cryptography.hazmat.primitives import padding, serialization
|
|||||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .deviceconfig import DeviceConfig
|
from .deviceconfig import DeviceConfig
|
||||||
@ -100,9 +101,9 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
self._session_cookie: Optional[Dict[str, str]] = None
|
self._session_cookie: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
self._login_token: Optional[str] = None
|
|
||||||
|
|
||||||
self._key_pair: Optional[KeyPair] = None
|
self._key_pair: Optional[KeyPair] = None
|
||||||
|
self._app_url = URL(f"http://{self._host}/app")
|
||||||
|
self._token_url: Optional[URL] = None
|
||||||
|
|
||||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||||
|
|
||||||
@ -150,9 +151,10 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
|
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
|
||||||
"""Send encrypted message as passthrough."""
|
"""Send encrypted message as passthrough."""
|
||||||
url = f"http://{self._host}/app"
|
if self._state is TransportState.ESTABLISHED and self._token_url:
|
||||||
if self._state is TransportState.ESTABLISHED and self._login_token:
|
url = self._token_url
|
||||||
url += f"?token={self._login_token}"
|
else:
|
||||||
|
url = self._app_url
|
||||||
|
|
||||||
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
||||||
passthrough_request = {
|
passthrough_request = {
|
||||||
@ -223,7 +225,8 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
resp_dict = await self.send_secure_passthrough(request)
|
resp_dict = await self.send_secure_passthrough(request)
|
||||||
self._handle_response_error_code(resp_dict, "Error logging in")
|
self._handle_response_error_code(resp_dict, "Error logging in")
|
||||||
self._login_token = resp_dict["result"]["token"]
|
login_token = resp_dict["result"]["token"]
|
||||||
|
self._token_url = self._app_url.with_query(f"token={login_token}")
|
||||||
self._state = TransportState.ESTABLISHED
|
self._state = TransportState.ESTABLISHED
|
||||||
|
|
||||||
async def _generate_key_pair_payload(self) -> AsyncGenerator:
|
async def _generate_key_pair_payload(self) -> AsyncGenerator:
|
||||||
@ -250,7 +253,7 @@ class AesTransport(BaseTransport):
|
|||||||
_LOGGER.debug("Will perform handshaking...")
|
_LOGGER.debug("Will perform handshaking...")
|
||||||
|
|
||||||
self._key_pair = None
|
self._key_pair = None
|
||||||
self._login_token = None
|
self._token_url = None
|
||||||
self._session_expire_at = None
|
self._session_expire_at = None
|
||||||
self._session_cookie = None
|
self._session_cookie = None
|
||||||
|
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
"""python-kasa exceptions."""
|
"""python-kasa exceptions."""
|
||||||
from asyncio import TimeoutError
|
from asyncio import TimeoutError
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
class SmartDeviceException(Exception):
|
class SmartDeviceException(Exception):
|
||||||
"""Base exception for device errors."""
|
"""Base exception for device errors."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
self.error_code: Optional["SmartErrorCode"] = kwargs.get("error_code", None)
|
self.error_code: Optional["SmartErrorCode"] = kwargs.get("error_code", None)
|
||||||
super().__init__(*args)
|
super().__init__(*args)
|
||||||
|
|
||||||
@ -15,7 +15,7 @@ class SmartDeviceException(Exception):
|
|||||||
class UnsupportedDeviceException(SmartDeviceException):
|
class UnsupportedDeviceException(SmartDeviceException):
|
||||||
"""Exception for trying to connect to unsupported devices."""
|
"""Exception for trying to connect to unsupported devices."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
self.discovery_result = kwargs.get("discovery_result")
|
self.discovery_result = kwargs.get("discovery_result")
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ import asyncio
|
|||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from .deviceconfig import DeviceConfig
|
from .deviceconfig import DeviceConfig
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
@ -25,7 +26,7 @@ class HttpClient:
|
|||||||
self._config = config
|
self._config = config
|
||||||
self._client_session: aiohttp.ClientSession = None
|
self._client_session: aiohttp.ClientSession = None
|
||||||
self._jar = aiohttp.CookieJar(unsafe=True, quote_cookie=False)
|
self._jar = aiohttp.CookieJar(unsafe=True, quote_cookie=False)
|
||||||
self._last_url = f"http://{self._config.host}/"
|
self._last_url = URL(f"http://{self._config.host}/")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self) -> aiohttp.ClientSession:
|
def client(self) -> aiohttp.ClientSession:
|
||||||
@ -41,7 +42,7 @@ class HttpClient:
|
|||||||
|
|
||||||
async def post(
|
async def post(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: URL,
|
||||||
*,
|
*,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
data: Optional[bytes] = None,
|
data: Optional[bytes] = None,
|
||||||
|
@ -53,6 +53,7 @@ from typing import Any, Dict, Optional, Tuple, cast
|
|||||||
|
|
||||||
from cryptography.hazmat.primitives import padding
|
from cryptography.hazmat.primitives import padding
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .deviceconfig import DeviceConfig
|
from .deviceconfig import DeviceConfig
|
||||||
@ -120,6 +121,8 @@ class KlapTransport(BaseTransport):
|
|||||||
self._session_cookie: Optional[Dict[str, Any]] = None
|
self._session_cookie: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
||||||
|
self._app_url = URL(f"http://{self._host}/app")
|
||||||
|
self._request_url = self._app_url / "request"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_port(self):
|
def default_port(self):
|
||||||
@ -141,7 +144,7 @@ class KlapTransport(BaseTransport):
|
|||||||
|
|
||||||
payload = local_seed
|
payload = local_seed
|
||||||
|
|
||||||
url = f"http://{self._host}/app/handshake1"
|
url = self._app_url / "handshake1"
|
||||||
|
|
||||||
response_status, response_data = await self._http_client.post(url, data=payload)
|
response_status, response_data = await self._http_client.post(url, data=payload)
|
||||||
|
|
||||||
@ -236,7 +239,7 @@ class KlapTransport(BaseTransport):
|
|||||||
# Handshake 2 has the following payload:
|
# Handshake 2 has the following payload:
|
||||||
# sha256(serverBytes | authenticator)
|
# sha256(serverBytes | authenticator)
|
||||||
|
|
||||||
url = f"http://{self._host}/app/handshake2"
|
url = self._app_url / "handshake2"
|
||||||
|
|
||||||
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
||||||
|
|
||||||
@ -309,10 +312,8 @@ class KlapTransport(BaseTransport):
|
|||||||
if self._encryption_session is not None:
|
if self._encryption_session is not None:
|
||||||
payload, seq = self._encryption_session.encrypt(request.encode())
|
payload, seq = self._encryption_session.encrypt(request.encode())
|
||||||
|
|
||||||
url = f"http://{self._host}/app/request"
|
|
||||||
|
|
||||||
response_status, response_data = await self._http_client.post(
|
response_status, response_data = await self._http_client.post(
|
||||||
url,
|
self._request_url,
|
||||||
params={"seq": seq},
|
params={"seq": seq},
|
||||||
data=payload,
|
data=payload,
|
||||||
cookies_dict=self._session_cookie,
|
cookies_dict=self._session_cookie,
|
||||||
|
@ -12,6 +12,7 @@ import aiohttp
|
|||||||
import pytest
|
import pytest
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from ..aestransport import AesEncyptionSession, AesTransport, TransportState
|
from ..aestransport import AesEncyptionSession, AesTransport, TransportState
|
||||||
from ..credentials import Credentials
|
from ..credentials import Credentials
|
||||||
@ -89,10 +90,10 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
|
|||||||
transport._session_expire_at = time.time() + 86400
|
transport._session_expire_at = time.time() + 86400
|
||||||
transport._encryption_session = mock_aes_device.encryption_session
|
transport._encryption_session = mock_aes_device.encryption_session
|
||||||
|
|
||||||
assert transport._login_token is None
|
assert transport._token_url is None
|
||||||
with expectation:
|
with expectation:
|
||||||
await transport.perform_login()
|
await transport.perform_login()
|
||||||
assert transport._login_token == mock_aes_device.token
|
assert mock_aes_device.token in str(transport._token_url)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -136,7 +137,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
|
|||||||
transport._session_expire_at = time.time() + 86400
|
transport._session_expire_at = time.time() + 86400
|
||||||
transport._encryption_session = mock_aes_device.encryption_session
|
transport._encryption_session = mock_aes_device.encryption_session
|
||||||
|
|
||||||
assert transport._login_token is None
|
assert transport._token_url is None
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"method": "get_device_info",
|
"method": "get_device_info",
|
||||||
@ -148,7 +149,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
|
|||||||
|
|
||||||
with expectation:
|
with expectation:
|
||||||
await transport.send(json_dumps(request))
|
await transport.send(json_dumps(request))
|
||||||
assert transport._login_token == mock_aes_device.token
|
assert mock_aes_device.token in str(transport._token_url)
|
||||||
assert post_mock.call_count == call_count # Login, Handshake, Login
|
assert post_mock.call_count == call_count # Login, Handshake, Login
|
||||||
await transport.close()
|
await transport.close()
|
||||||
|
|
||||||
@ -165,7 +166,9 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
|
|||||||
transport._handshake_done = True
|
transport._handshake_done = True
|
||||||
transport._session_expire_at = time.time() + 86400
|
transport._session_expire_at = time.time() + 86400
|
||||||
transport._encryption_session = mock_aes_device.encryption_session
|
transport._encryption_session = mock_aes_device.encryption_session
|
||||||
transport._login_token = mock_aes_device.token
|
transport._token_url = transport._app_url.with_query(
|
||||||
|
f"token={mock_aes_device.token}"
|
||||||
|
)
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"method": "get_device_info",
|
"method": "get_device_info",
|
||||||
@ -193,7 +196,9 @@ async def test_passthrough_errors(mocker, error_code):
|
|||||||
transport._handshake_done = True
|
transport._handshake_done = True
|
||||||
transport._session_expire_at = time.time() + 86400
|
transport._session_expire_at = time.time() + 86400
|
||||||
transport._encryption_session = mock_aes_device.encryption_session
|
transport._encryption_session = mock_aes_device.encryption_session
|
||||||
transport._login_token = mock_aes_device.token
|
transport._token_url = transport._app_url.with_query(
|
||||||
|
f"token={mock_aes_device.token}"
|
||||||
|
)
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"method": "get_device_info",
|
"method": "get_device_info",
|
||||||
@ -239,13 +244,13 @@ class MockAesDevice:
|
|||||||
else:
|
else:
|
||||||
return self._inner_error_code
|
return self._inner_error_code
|
||||||
|
|
||||||
async def post(self, url, params=None, json=None, data=None, *_, **__):
|
async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
|
||||||
if data:
|
if data:
|
||||||
async for item in data:
|
async for item in data:
|
||||||
json = json_loads(item.decode())
|
json = json_loads(item.decode())
|
||||||
return await self._post(url, json)
|
return await self._post(url, json)
|
||||||
|
|
||||||
async def _post(self, url: str, json: Dict[str, Any]):
|
async def _post(self, url: URL, json: Dict[str, Any]):
|
||||||
if json["method"] == "handshake":
|
if json["method"] == "handshake":
|
||||||
return await self._return_handshake_response(url, json)
|
return await self._return_handshake_response(url, json)
|
||||||
elif json["method"] == "securePassthrough":
|
elif json["method"] == "securePassthrough":
|
||||||
@ -253,10 +258,10 @@ class MockAesDevice:
|
|||||||
elif json["method"] == "login_device":
|
elif json["method"] == "login_device":
|
||||||
return await self._return_login_response(url, json)
|
return await self._return_login_response(url, json)
|
||||||
else:
|
else:
|
||||||
assert url == f"http://{self.host}/app?token={self.token}"
|
assert str(url) == f"http://{self.host}/app?token={self.token}"
|
||||||
return await self._return_send_response(url, json)
|
return await self._return_send_response(url, json)
|
||||||
|
|
||||||
async def _return_handshake_response(self, url: str, json: Dict[str, Any]):
|
async def _return_handshake_response(self, url: URL, json: Dict[str, Any]):
|
||||||
start = len("-----BEGIN PUBLIC KEY-----\n")
|
start = len("-----BEGIN PUBLIC KEY-----\n")
|
||||||
end = len("\n-----END PUBLIC KEY-----\n")
|
end = len("\n-----END PUBLIC KEY-----\n")
|
||||||
client_pub_key = json["params"]["key"][start:-end]
|
client_pub_key = json["params"]["key"][start:-end]
|
||||||
@ -269,7 +274,7 @@ class MockAesDevice:
|
|||||||
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
|
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _return_secure_passthrough_response(self, url: str, json: Dict[str, Any]):
|
async def _return_secure_passthrough_response(self, url: URL, json: Dict[str, Any]):
|
||||||
encrypted_request = json["params"]["request"]
|
encrypted_request = json["params"]["request"]
|
||||||
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
|
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
|
||||||
decrypted_request_dict = json_loads(decrypted_request)
|
decrypted_request_dict = json_loads(decrypted_request)
|
||||||
@ -286,15 +291,15 @@ class MockAesDevice:
|
|||||||
}
|
}
|
||||||
return self._mock_response(self.status_code, result)
|
return self._mock_response(self.status_code, result)
|
||||||
|
|
||||||
async def _return_login_response(self, url: str, json: Dict[str, Any]):
|
async def _return_login_response(self, url: URL, json: Dict[str, Any]):
|
||||||
if "token=" in url:
|
if "token=" in str(url):
|
||||||
raise Exception("token should not be in url for a login request")
|
raise Exception("token should not be in url for a login request")
|
||||||
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
|
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
|
||||||
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
|
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
|
||||||
self.inner_call_count += 1
|
self.inner_call_count += 1
|
||||||
return self._mock_response(self.status_code, result)
|
return self._mock_response(self.status_code, result)
|
||||||
|
|
||||||
async def _return_send_response(self, url: str, json: Dict[str, Any]):
|
async def _return_send_response(self, url: URL, json: Dict[str, Any]):
|
||||||
result = {"result": {"method": None}, "error_code": self.inner_error_code}
|
result = {"result": {"method": None}, "error_code": self.inner_error_code}
|
||||||
self.inner_call_count += 1
|
self.inner_call_count += 1
|
||||||
return self._mock_response(self.status_code, result)
|
return self._mock_response(self.status_code, result)
|
||||||
|
@ -84,7 +84,7 @@ async def test_httpclient_errors(mocker, error, error_raises, error_message, moc
|
|||||||
client = HttpClient(DeviceConfig(host))
|
client = HttpClient(DeviceConfig(host))
|
||||||
# Exceptions with parameters print with double quotes, without use single quotes
|
# Exceptions with parameters print with double quotes, without use single quotes
|
||||||
full_msg = (
|
full_msg = (
|
||||||
"\("
|
"\(" # type: ignore
|
||||||
+ "['\"]"
|
+ "['\"]"
|
||||||
+ re.escape(f"{error_message}{host}: {error}")
|
+ re.escape(f"{error_message}{host}: {error}")
|
||||||
+ "['\"]"
|
+ "['\"]"
|
||||||
|
@ -10,6 +10,7 @@ from unittest.mock import PropertyMock
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest
|
import pytest
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from ..aestransport import AesTransport
|
from ..aestransport import AesTransport
|
||||||
from ..credentials import Credentials
|
from ..credentials import Credentials
|
||||||
@ -318,28 +319,28 @@ async def test_handshake1(
|
|||||||
async def test_handshake(
|
async def test_handshake(
|
||||||
mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
|
mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
|
||||||
):
|
):
|
||||||
async def _return_handshake_response(url, params=None, data=None, *_, **__):
|
client_seed = None
|
||||||
|
server_seed = secrets.token_bytes(16)
|
||||||
|
client_credentials = Credentials("foo", "bar")
|
||||||
|
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
|
||||||
|
|
||||||
|
async def _return_handshake_response(url: URL, params=None, data=None, *_, **__):
|
||||||
nonlocal client_seed, server_seed, device_auth_hash
|
nonlocal client_seed, server_seed, device_auth_hash
|
||||||
|
|
||||||
if url == "http://127.0.0.1/app/handshake1":
|
if str(url) == "http://127.0.0.1/app/handshake1":
|
||||||
client_seed = data
|
client_seed = data
|
||||||
seed_auth_hash = _sha256(
|
seed_auth_hash = _sha256(
|
||||||
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
|
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
|
||||||
)
|
)
|
||||||
|
|
||||||
return _mock_response(200, server_seed + seed_auth_hash)
|
return _mock_response(200, server_seed + seed_auth_hash)
|
||||||
elif url == "http://127.0.0.1/app/handshake2":
|
elif str(url) == "http://127.0.0.1/app/handshake2":
|
||||||
seed_auth_hash = _sha256(
|
seed_auth_hash = _sha256(
|
||||||
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
|
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
|
||||||
)
|
)
|
||||||
assert data == seed_auth_hash
|
assert data == seed_auth_hash
|
||||||
return _mock_response(response_status, b"")
|
return _mock_response(response_status, b"")
|
||||||
|
|
||||||
client_seed = None
|
|
||||||
server_seed = secrets.token_bytes(16)
|
|
||||||
client_credentials = Credentials("foo", "bar")
|
|
||||||
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
|
|
||||||
|
|
||||||
mocker.patch.object(
|
mocker.patch.object(
|
||||||
aiohttp.ClientSession, "post", side_effect=_return_handshake_response
|
aiohttp.ClientSession, "post", side_effect=_return_handshake_response
|
||||||
)
|
)
|
||||||
@ -360,17 +361,24 @@ async def test_handshake(
|
|||||||
|
|
||||||
|
|
||||||
async def test_query(mocker):
|
async def test_query(mocker):
|
||||||
async def _return_response(url, params=None, data=None, *_, **__):
|
client_seed = None
|
||||||
|
last_seq = None
|
||||||
|
seq = None
|
||||||
|
server_seed = secrets.token_bytes(16)
|
||||||
|
client_credentials = Credentials("foo", "bar")
|
||||||
|
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||||
|
|
||||||
|
async def _return_response(url: URL, params=None, data=None, *_, **__):
|
||||||
nonlocal client_seed, server_seed, device_auth_hash, seq
|
nonlocal client_seed, server_seed, device_auth_hash, seq
|
||||||
|
|
||||||
if url == "http://127.0.0.1/app/handshake1":
|
if str(url) == "http://127.0.0.1/app/handshake1":
|
||||||
client_seed = data
|
client_seed = data
|
||||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||||
|
|
||||||
return _mock_response(200, server_seed + client_seed_auth_hash)
|
return _mock_response(200, server_seed + client_seed_auth_hash)
|
||||||
elif url == "http://127.0.0.1/app/handshake2":
|
elif str(url) == "http://127.0.0.1/app/handshake2":
|
||||||
return _mock_response(200, b"")
|
return _mock_response(200, b"")
|
||||||
elif url == "http://127.0.0.1/app/request":
|
elif str(url) == "http://127.0.0.1/app/request":
|
||||||
encryption_session = KlapEncryptionSession(
|
encryption_session = KlapEncryptionSession(
|
||||||
protocol._transport._encryption_session.local_seed,
|
protocol._transport._encryption_session.local_seed,
|
||||||
protocol._transport._encryption_session.remote_seed,
|
protocol._transport._encryption_session.remote_seed,
|
||||||
@ -382,13 +390,6 @@ async def test_query(mocker):
|
|||||||
seq = seq
|
seq = seq
|
||||||
return _mock_response(200, encrypted)
|
return _mock_response(200, encrypted)
|
||||||
|
|
||||||
client_seed = None
|
|
||||||
last_seq = None
|
|
||||||
seq = None
|
|
||||||
server_seed = secrets.token_bytes(16)
|
|
||||||
client_credentials = Credentials("foo", "bar")
|
|
||||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
|
||||||
|
|
||||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
|
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
|
||||||
|
|
||||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||||
@ -413,26 +414,26 @@ async def test_query(mocker):
|
|||||||
ids=("handshake1", "handshake2", "request", "non_auth_error"),
|
ids=("handshake1", "handshake2", "request", "non_auth_error"),
|
||||||
)
|
)
|
||||||
async def test_authentication_failures(mocker, response_status, expectation):
|
async def test_authentication_failures(mocker, response_status, expectation):
|
||||||
async def _return_response(url, params=None, data=None, *_, **__):
|
client_seed = None
|
||||||
|
|
||||||
|
server_seed = secrets.token_bytes(16)
|
||||||
|
client_credentials = Credentials("foo", "bar")
|
||||||
|
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||||
|
|
||||||
|
async def _return_response(url: URL, params=None, data=None, *_, **__):
|
||||||
nonlocal client_seed, server_seed, device_auth_hash, response_status
|
nonlocal client_seed, server_seed, device_auth_hash, response_status
|
||||||
|
|
||||||
if url == "http://127.0.0.1/app/handshake1":
|
if str(url) == "http://127.0.0.1/app/handshake1":
|
||||||
client_seed = data
|
client_seed = data
|
||||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||||
|
|
||||||
return _mock_response(
|
return _mock_response(
|
||||||
response_status[0], server_seed + client_seed_auth_hash
|
response_status[0], server_seed + client_seed_auth_hash
|
||||||
)
|
)
|
||||||
elif url == "http://127.0.0.1/app/handshake2":
|
elif str(url) == "http://127.0.0.1/app/handshake2":
|
||||||
return _mock_response(response_status[1], b"")
|
return _mock_response(response_status[1], b"")
|
||||||
elif url == "http://127.0.0.1/app/request":
|
elif str(url) == "http://127.0.0.1/app/request":
|
||||||
return _mock_response(response_status[2], None)
|
return _mock_response(response_status[2], b"")
|
||||||
|
|
||||||
client_seed = None
|
|
||||||
|
|
||||||
server_seed = secrets.token_bytes(16)
|
|
||||||
client_credentials = Credentials("foo", "bar")
|
|
||||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
|
||||||
|
|
||||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
|
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user