mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-04-26 16:46:23 +00:00
Fix port-override for aes&klap transports (#734)
* Fix port-override for aes&klap transports * Add tests for port override
This commit is contained in:
parent
414489ff18
commit
fae071f0df
@ -102,7 +102,7 @@ class AesTransport(BaseTransport):
|
|||||||
self._session_cookie: Optional[Dict[str, str]] = None
|
self._session_cookie: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
self._key_pair: Optional[KeyPair] = None
|
self._key_pair: Optional[KeyPair] = None
|
||||||
self._app_url = URL(f"http://{self._host}/app")
|
self._app_url = URL(f"http://{self._host}:{self._port}/app")
|
||||||
self._token_url: Optional[URL] = None
|
self._token_url: Optional[URL] = None
|
||||||
|
|
||||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||||
@ -257,7 +257,6 @@ class AesTransport(BaseTransport):
|
|||||||
self._session_expire_at = None
|
self._session_expire_at = None
|
||||||
self._session_cookie = None
|
self._session_cookie = None
|
||||||
|
|
||||||
url = f"http://{self._host}/app"
|
|
||||||
# Device needs the content length or it will response with 500
|
# Device needs the content length or it will response with 500
|
||||||
headers = {
|
headers = {
|
||||||
**self.COMMON_HEADERS,
|
**self.COMMON_HEADERS,
|
||||||
@ -266,7 +265,7 @@ class AesTransport(BaseTransport):
|
|||||||
http_client = self._http_client
|
http_client = self._http_client
|
||||||
|
|
||||||
status_code, resp_dict = await http_client.post(
|
status_code, resp_dict = await http_client.post(
|
||||||
url,
|
self._app_url,
|
||||||
json=self._generate_key_pair_payload(),
|
json=self._generate_key_pair_payload(),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
cookies_dict=self._session_cookie,
|
cookies_dict=self._session_cookie,
|
||||||
|
@ -334,6 +334,7 @@ async def cli(
|
|||||||
)
|
)
|
||||||
config = DeviceConfig(
|
config = DeviceConfig(
|
||||||
host=host,
|
host=host,
|
||||||
|
port_override=port,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
credentials_hash=credentials_hash,
|
credentials_hash=credentials_hash,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Module for HttpClientSession class."""
|
"""Module for HttpClientSession class."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -13,6 +14,8 @@ from .exceptions import (
|
|||||||
)
|
)
|
||||||
from .json import loads as json_loads
|
from .json import loads as json_loads
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_cookie_jar() -> aiohttp.CookieJar:
|
def get_cookie_jar() -> aiohttp.CookieJar:
|
||||||
"""Return a new cookie jar with the correct options for device communication."""
|
"""Return a new cookie jar with the correct options for device communication."""
|
||||||
@ -54,6 +57,7 @@ class HttpClient:
|
|||||||
|
|
||||||
If the request is provided via the json parameter json will be returned.
|
If the request is provided via the json parameter json will be returned.
|
||||||
"""
|
"""
|
||||||
|
_LOGGER.debug("Posting to %s", url)
|
||||||
response_data = None
|
response_data = None
|
||||||
self._last_url = url
|
self._last_url = url
|
||||||
self.client.cookie_jar.clear()
|
self.client.cookie_jar.clear()
|
||||||
|
@ -121,7 +121,7 @@ class KlapTransport(BaseTransport):
|
|||||||
self._session_cookie: Optional[Dict[str, Any]] = None
|
self._session_cookie: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
||||||
self._app_url = URL(f"http://{self._host}/app")
|
self._app_url = URL(f"http://{self._host}:{self._port}/app")
|
||||||
self._request_url = self._app_url / "request"
|
self._request_url = self._app_url / "request"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -209,6 +209,17 @@ async def test_passthrough_errors(mocker, error_code):
|
|||||||
await transport.send(json_dumps(request))
|
await transport.send(json_dumps(request))
|
||||||
|
|
||||||
|
|
||||||
|
async def test_port_override():
|
||||||
|
"""Test that port override sets the app_url."""
|
||||||
|
host = "127.0.0.1"
|
||||||
|
config = DeviceConfig(
|
||||||
|
host, credentials=Credentials("foo", "bar"), port_override=12345
|
||||||
|
)
|
||||||
|
transport = AesTransport(config=config)
|
||||||
|
|
||||||
|
assert str(transport._app_url) == "http://127.0.0.1:12345/app"
|
||||||
|
|
||||||
|
|
||||||
class MockAesDevice:
|
class MockAesDevice:
|
||||||
class _mock_response:
|
class _mock_response:
|
||||||
def __init__(self, status, json: dict):
|
def __init__(self, status, json: dict):
|
||||||
@ -256,7 +267,7 @@ class MockAesDevice:
|
|||||||
elif json["method"] == "login_device":
|
elif json["method"] == "login_device":
|
||||||
return await self._return_login_response(url, json)
|
return await self._return_login_response(url, json)
|
||||||
else:
|
else:
|
||||||
assert str(url) == f"http://{self.host}/app?token={self.token}"
|
assert str(url) == f"http://{self.host}:80/app?token={self.token}"
|
||||||
return await self._return_send_response(url, json)
|
return await self._return_send_response(url, json)
|
||||||
|
|
||||||
async def _return_handshake_response(self, url: URL, json: Dict[str, Any]):
|
async def _return_handshake_response(self, url: URL, json: Dict[str, Any]):
|
||||||
|
@ -323,14 +323,14 @@ async def test_handshake(
|
|||||||
async def _return_handshake_response(url: URL, params=None, data=None, *_, **__):
|
async def _return_handshake_response(url: URL, params=None, data=None, *_, **__):
|
||||||
nonlocal client_seed, server_seed, device_auth_hash
|
nonlocal client_seed, server_seed, device_auth_hash
|
||||||
|
|
||||||
if str(url) == "http://127.0.0.1/app/handshake1":
|
if str(url) == "http://127.0.0.1:80/app/handshake1":
|
||||||
client_seed = data
|
client_seed = data
|
||||||
seed_auth_hash = _sha256(
|
seed_auth_hash = _sha256(
|
||||||
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
|
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
|
||||||
)
|
)
|
||||||
|
|
||||||
return _mock_response(200, server_seed + seed_auth_hash)
|
return _mock_response(200, server_seed + seed_auth_hash)
|
||||||
elif str(url) == "http://127.0.0.1/app/handshake2":
|
elif str(url) == "http://127.0.0.1:80/app/handshake2":
|
||||||
seed_auth_hash = _sha256(
|
seed_auth_hash = _sha256(
|
||||||
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
|
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
|
||||||
)
|
)
|
||||||
@ -367,14 +367,14 @@ async def test_query(mocker):
|
|||||||
async def _return_response(url: URL, params=None, data=None, *_, **__):
|
async def _return_response(url: URL, params=None, data=None, *_, **__):
|
||||||
nonlocal client_seed, server_seed, device_auth_hash, seq
|
nonlocal client_seed, server_seed, device_auth_hash, seq
|
||||||
|
|
||||||
if str(url) == "http://127.0.0.1/app/handshake1":
|
if str(url) == "http://127.0.0.1:80/app/handshake1":
|
||||||
client_seed = data
|
client_seed = data
|
||||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||||
|
|
||||||
return _mock_response(200, server_seed + client_seed_auth_hash)
|
return _mock_response(200, server_seed + client_seed_auth_hash)
|
||||||
elif str(url) == "http://127.0.0.1/app/handshake2":
|
elif str(url) == "http://127.0.0.1:80/app/handshake2":
|
||||||
return _mock_response(200, b"")
|
return _mock_response(200, b"")
|
||||||
elif str(url) == "http://127.0.0.1/app/request":
|
elif str(url) == "http://127.0.0.1:80/app/request":
|
||||||
encryption_session = KlapEncryptionSession(
|
encryption_session = KlapEncryptionSession(
|
||||||
protocol._transport._encryption_session.local_seed,
|
protocol._transport._encryption_session.local_seed,
|
||||||
protocol._transport._encryption_session.remote_seed,
|
protocol._transport._encryption_session.remote_seed,
|
||||||
@ -419,16 +419,16 @@ async def test_authentication_failures(mocker, response_status, expectation):
|
|||||||
async def _return_response(url: URL, params=None, data=None, *_, **__):
|
async def _return_response(url: URL, params=None, data=None, *_, **__):
|
||||||
nonlocal client_seed, server_seed, device_auth_hash, response_status
|
nonlocal client_seed, server_seed, device_auth_hash, response_status
|
||||||
|
|
||||||
if str(url) == "http://127.0.0.1/app/handshake1":
|
if str(url) == "http://127.0.0.1:80/app/handshake1":
|
||||||
client_seed = data
|
client_seed = data
|
||||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||||
|
|
||||||
return _mock_response(
|
return _mock_response(
|
||||||
response_status[0], server_seed + client_seed_auth_hash
|
response_status[0], server_seed + client_seed_auth_hash
|
||||||
)
|
)
|
||||||
elif str(url) == "http://127.0.0.1/app/handshake2":
|
elif str(url) == "http://127.0.0.1:80/app/handshake2":
|
||||||
return _mock_response(response_status[1], b"")
|
return _mock_response(response_status[1], b"")
|
||||||
elif str(url) == "http://127.0.0.1/app/request":
|
elif str(url) == "http://127.0.0.1:80/app/request":
|
||||||
return _mock_response(response_status[2], b"")
|
return _mock_response(response_status[2], b"")
|
||||||
|
|
||||||
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
|
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
|
||||||
@ -438,3 +438,14 @@ async def test_authentication_failures(mocker, response_status, expectation):
|
|||||||
|
|
||||||
with expectation:
|
with expectation:
|
||||||
await protocol.query({})
|
await protocol.query({})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_port_override():
|
||||||
|
"""Test that port override sets the app_url."""
|
||||||
|
host = "127.0.0.1"
|
||||||
|
config = DeviceConfig(
|
||||||
|
host, credentials=Credentials("foo", "bar"), port_override=12345
|
||||||
|
)
|
||||||
|
transport = KlapTransport(config=config)
|
||||||
|
|
||||||
|
assert str(transport._app_url) == "http://127.0.0.1:12345/app"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user