Fix port-override for aes&klap transports (#734)

* Fix port-override for aes&klap transports

* Add tests for port override
This commit is contained in:
Teemu R 2024-02-03 15:28:20 +01:00 committed by GitHub
parent 414489ff18
commit fae071f0df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 39 additions and 13 deletions

View File

@ -102,7 +102,7 @@ class AesTransport(BaseTransport):
self._session_cookie: Optional[Dict[str, str]] = None self._session_cookie: Optional[Dict[str, str]] = None
self._key_pair: Optional[KeyPair] = None self._key_pair: Optional[KeyPair] = None
self._app_url = URL(f"http://{self._host}/app") self._app_url = URL(f"http://{self._host}:{self._port}/app")
self._token_url: Optional[URL] = None self._token_url: Optional[URL] = None
_LOGGER.debug("Created AES transport for %s", self._host) _LOGGER.debug("Created AES transport for %s", self._host)
@ -257,7 +257,6 @@ class AesTransport(BaseTransport):
self._session_expire_at = None self._session_expire_at = None
self._session_cookie = None self._session_cookie = None
url = f"http://{self._host}/app"
# Device needs the content length or it will response with 500 # Device needs the content length or it will response with 500
headers = { headers = {
**self.COMMON_HEADERS, **self.COMMON_HEADERS,
@ -266,7 +265,7 @@ class AesTransport(BaseTransport):
http_client = self._http_client http_client = self._http_client
status_code, resp_dict = await http_client.post( status_code, resp_dict = await http_client.post(
url, self._app_url,
json=self._generate_key_pair_payload(), json=self._generate_key_pair_payload(),
headers=headers, headers=headers,
cookies_dict=self._session_cookie, cookies_dict=self._session_cookie,

View File

@ -334,6 +334,7 @@ async def cli(
) )
config = DeviceConfig( config = DeviceConfig(
host=host, host=host,
port_override=port,
credentials=credentials, credentials=credentials,
credentials_hash=credentials_hash, credentials_hash=credentials_hash,
timeout=timeout, timeout=timeout,

View File

@ -1,5 +1,6 @@
"""Module for HttpClientSession class.""" """Module for HttpClientSession class."""
import asyncio import asyncio
import logging
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
import aiohttp import aiohttp
@ -13,6 +14,8 @@ from .exceptions import (
) )
from .json import loads as json_loads from .json import loads as json_loads
_LOGGER = logging.getLogger(__name__)
def get_cookie_jar() -> aiohttp.CookieJar: def get_cookie_jar() -> aiohttp.CookieJar:
"""Return a new cookie jar with the correct options for device communication.""" """Return a new cookie jar with the correct options for device communication."""
@ -54,6 +57,7 @@ class HttpClient:
If the request is provided via the json parameter json will be returned. If the request is provided via the json parameter json will be returned.
""" """
_LOGGER.debug("Posting to %s", url)
response_data = None response_data = None
self._last_url = url self._last_url = url
self.client.cookie_jar.clear() self.client.cookie_jar.clear()

View File

@ -121,7 +121,7 @@ class KlapTransport(BaseTransport):
self._session_cookie: Optional[Dict[str, Any]] = None self._session_cookie: Optional[Dict[str, Any]] = None
_LOGGER.debug("Created KLAP transport for %s", self._host) _LOGGER.debug("Created KLAP transport for %s", self._host)
self._app_url = URL(f"http://{self._host}/app") self._app_url = URL(f"http://{self._host}:{self._port}/app")
self._request_url = self._app_url / "request" self._request_url = self._app_url / "request"
@property @property

View File

@ -209,6 +209,17 @@ async def test_passthrough_errors(mocker, error_code):
await transport.send(json_dumps(request)) await transport.send(json_dumps(request))
async def test_port_override():
"""Test that port override sets the app_url."""
host = "127.0.0.1"
config = DeviceConfig(
host, credentials=Credentials("foo", "bar"), port_override=12345
)
transport = AesTransport(config=config)
assert str(transport._app_url) == "http://127.0.0.1:12345/app"
class MockAesDevice: class MockAesDevice:
class _mock_response: class _mock_response:
def __init__(self, status, json: dict): def __init__(self, status, json: dict):
@ -256,7 +267,7 @@ class MockAesDevice:
elif json["method"] == "login_device": elif json["method"] == "login_device":
return await self._return_login_response(url, json) return await self._return_login_response(url, json)
else: else:
assert str(url) == f"http://{self.host}/app?token={self.token}" assert str(url) == f"http://{self.host}:80/app?token={self.token}"
return await self._return_send_response(url, json) return await self._return_send_response(url, json)
async def _return_handshake_response(self, url: URL, json: Dict[str, Any]): async def _return_handshake_response(self, url: URL, json: Dict[str, Any]):

View File

@ -323,14 +323,14 @@ async def test_handshake(
async def _return_handshake_response(url: URL, params=None, data=None, *_, **__): async def _return_handshake_response(url: URL, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash nonlocal client_seed, server_seed, device_auth_hash
if str(url) == "http://127.0.0.1/app/handshake1": if str(url) == "http://127.0.0.1:80/app/handshake1":
client_seed = data client_seed = data
seed_auth_hash = _sha256( seed_auth_hash = _sha256(
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash) seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
) )
return _mock_response(200, server_seed + seed_auth_hash) return _mock_response(200, server_seed + seed_auth_hash)
elif str(url) == "http://127.0.0.1/app/handshake2": elif str(url) == "http://127.0.0.1:80/app/handshake2":
seed_auth_hash = _sha256( seed_auth_hash = _sha256(
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash) seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
) )
@ -367,14 +367,14 @@ async def test_query(mocker):
async def _return_response(url: URL, params=None, data=None, *_, **__): async def _return_response(url: URL, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash, seq nonlocal client_seed, server_seed, device_auth_hash, seq
if str(url) == "http://127.0.0.1/app/handshake1": if str(url) == "http://127.0.0.1:80/app/handshake1":
client_seed = data client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash) client_seed_auth_hash = _sha256(data + device_auth_hash)
return _mock_response(200, server_seed + client_seed_auth_hash) return _mock_response(200, server_seed + client_seed_auth_hash)
elif str(url) == "http://127.0.0.1/app/handshake2": elif str(url) == "http://127.0.0.1:80/app/handshake2":
return _mock_response(200, b"") return _mock_response(200, b"")
elif str(url) == "http://127.0.0.1/app/request": elif str(url) == "http://127.0.0.1:80/app/request":
encryption_session = KlapEncryptionSession( encryption_session = KlapEncryptionSession(
protocol._transport._encryption_session.local_seed, protocol._transport._encryption_session.local_seed,
protocol._transport._encryption_session.remote_seed, protocol._transport._encryption_session.remote_seed,
@ -419,16 +419,16 @@ async def test_authentication_failures(mocker, response_status, expectation):
async def _return_response(url: URL, params=None, data=None, *_, **__): async def _return_response(url: URL, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash, response_status nonlocal client_seed, server_seed, device_auth_hash, response_status
if str(url) == "http://127.0.0.1/app/handshake1": if str(url) == "http://127.0.0.1:80/app/handshake1":
client_seed = data client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash) client_seed_auth_hash = _sha256(data + device_auth_hash)
return _mock_response( return _mock_response(
response_status[0], server_seed + client_seed_auth_hash response_status[0], server_seed + client_seed_auth_hash
) )
elif str(url) == "http://127.0.0.1/app/handshake2": elif str(url) == "http://127.0.0.1:80/app/handshake2":
return _mock_response(response_status[1], b"") return _mock_response(response_status[1], b"")
elif str(url) == "http://127.0.0.1/app/request": elif str(url) == "http://127.0.0.1:80/app/request":
return _mock_response(response_status[2], b"") return _mock_response(response_status[2], b"")
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
@ -438,3 +438,14 @@ async def test_authentication_failures(mocker, response_status, expectation):
with expectation: with expectation:
await protocol.query({}) await protocol.query({})
async def test_port_override():
"""Test that port override sets the app_url."""
host = "127.0.0.1"
config = DeviceConfig(
host, credentials=Credentials("foo", "bar"), port_override=12345
)
transport = KlapTransport(config=config)
assert str(transport._app_url) == "http://127.0.0.1:12345/app"