Ensure login token is only sent if aes state is ESTABLISHED (#702)

This commit is contained in:
J. Nick Koston 2024-01-24 09:43:42 -10:00 committed by GitHub
parent aecf0ecd8a
commit 3df837cc82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 10 deletions

View File

@ -151,7 +151,7 @@ class AesTransport(BaseTransport):
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]: async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
"""Send encrypted message as passthrough.""" """Send encrypted message as passthrough."""
url = f"http://{self._host}/app" url = f"http://{self._host}/app"
if self._login_token: if self._state is TransportState.ESTABLISHED and self._login_token:
url += f"?token={self._login_token}" url += f"?token={self._login_token}"
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
@ -250,6 +250,7 @@ class AesTransport(BaseTransport):
_LOGGER.debug("Will perform handshaking...") _LOGGER.debug("Will perform handshaking...")
self._key_pair = None self._key_pair = None
self._login_token = None
self._session_expire_at = None self._session_expire_at = None
self._session_cookie = None self._session_cookie = None
@ -284,9 +285,7 @@ class AesTransport(BaseTransport):
handshake_key = resp_dict["result"]["key"] handshake_key = resp_dict["result"]["key"]
if ( if (
cookie := http_client.get_cookie( # type: ignore cookie := http_client.get_cookie(self.SESSION_COOKIE_NAME) # type: ignore
self.SESSION_COOKIE_NAME
)
) or ( ) or (
cookie := http_client.get_cookie("SESSIONID") # type: ignore cookie := http_client.get_cookie("SESSIONID") # type: ignore
): ):

View File

@ -1,9 +1,12 @@
import base64 import base64
import json import json
import random
import string
import time import time
from contextlib import nullcontext as does_not_raise from contextlib import nullcontext as does_not_raise
from json import dumps as json_dumps from json import dumps as json_dumps
from json import loads as json_loads from json import loads as json_loads
from typing import Any, Dict, Optional
import aiohttp import aiohttp
import pytest import pytest
@ -219,7 +222,6 @@ class MockAesDevice:
return json_dumps(self._json).encode() return json_dumps(self._json).encode()
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:]) encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
token = "test_token" # noqa
def __init__(self, host, status_code=200, error_code=0, inner_error_code=0): def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
self.host = host self.host = host
@ -228,6 +230,7 @@ class MockAesDevice:
self._inner_error_code = inner_error_code self._inner_error_code = inner_error_code
self.http_client = HttpClient(DeviceConfig(self.host)) self.http_client = HttpClient(DeviceConfig(self.host))
self.inner_call_count = 0 self.inner_call_count = 0
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
@property @property
def inner_error_code(self): def inner_error_code(self):
@ -242,7 +245,7 @@ class MockAesDevice:
json = json_loads(item.decode()) json = json_loads(item.decode())
return await self._post(url, json) return await self._post(url, json)
async def _post(self, url, json): async def _post(self, url: str, json: Dict[str, Any]):
if json["method"] == "handshake": if json["method"] == "handshake":
return await self._return_handshake_response(url, json) return await self._return_handshake_response(url, json)
elif json["method"] == "securePassthrough": elif json["method"] == "securePassthrough":
@ -253,7 +256,7 @@ class MockAesDevice:
assert url == f"http://{self.host}/app?token={self.token}" assert url == f"http://{self.host}/app?token={self.token}"
return await self._return_send_response(url, json) return await self._return_send_response(url, json)
async def _return_handshake_response(self, url, json): async def _return_handshake_response(self, url: str, json: Dict[str, Any]):
start = len("-----BEGIN PUBLIC KEY-----\n") start = len("-----BEGIN PUBLIC KEY-----\n")
end = len("\n-----END PUBLIC KEY-----\n") end = len("\n-----END PUBLIC KEY-----\n")
client_pub_key = json["params"]["key"][start:-end] client_pub_key = json["params"]["key"][start:-end]
@ -266,7 +269,7 @@ class MockAesDevice:
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code} self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
) )
async def _return_secure_passthrough_response(self, url, json): async def _return_secure_passthrough_response(self, url: str, json: Dict[str, Any]):
encrypted_request = json["params"]["request"] encrypted_request = json["params"]["request"]
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode()) decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
decrypted_request_dict = json_loads(decrypted_request) decrypted_request_dict = json_loads(decrypted_request)
@ -283,12 +286,15 @@ class MockAesDevice:
} }
return self._mock_response(self.status_code, result) return self._mock_response(self.status_code, result)
async def _return_login_response(self, url, json): async def _return_login_response(self, url: str, json: Dict[str, Any]):
if "token=" in url:
raise Exception("token should not be in url for a login request")
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
result = {"result": {"token": self.token}, "error_code": self.inner_error_code} result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
self.inner_call_count += 1 self.inner_call_count += 1
return self._mock_response(self.status_code, result) return self._mock_response(self.status_code, result)
async def _return_send_response(self, url, json): async def _return_send_response(self, url: str, json: Dict[str, Any]):
result = {"result": {"method": None}, "error_code": self.inner_error_code} result = {"result": {"method": None}, "error_code": self.inner_error_code}
self.inner_call_count += 1 self.inner_call_count += 1
return self._mock_response(self.status_code, result) return self._mock_response(self.status_code, result)