Avoid rebuilding urls for every request (#715)

* Avoid rebuilding urls for every request

* more fixes

* more fixes

* make mypy happy

* reduce

* tweak

* fix tests

* fix tests

* tweak

* tweak

* lint

* fix type
This commit is contained in:
J. Nick Koston 2024-01-29 05:26:00 -10:00 committed by GitHub
parent 69dcc0d8bb
commit b479b6d84d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 73 additions and 62 deletions

View File

@ -15,6 +15,7 @@ from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from yarl import URL
from .credentials import Credentials from .credentials import Credentials
from .deviceconfig import DeviceConfig from .deviceconfig import DeviceConfig
@ -100,9 +101,9 @@ class AesTransport(BaseTransport):
self._session_cookie: Optional[Dict[str, str]] = None self._session_cookie: Optional[Dict[str, str]] = None
self._login_token: Optional[str] = None
self._key_pair: Optional[KeyPair] = None self._key_pair: Optional[KeyPair] = None
self._app_url = URL(f"http://{self._host}/app")
self._token_url: Optional[URL] = None
_LOGGER.debug("Created AES transport for %s", self._host) _LOGGER.debug("Created AES transport for %s", self._host)
@ -150,9 +151,10 @@ class AesTransport(BaseTransport):
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]: async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
"""Send encrypted message as passthrough.""" """Send encrypted message as passthrough."""
url = f"http://{self._host}/app" if self._state is TransportState.ESTABLISHED and self._token_url:
if self._state is TransportState.ESTABLISHED and self._login_token: url = self._token_url
url += f"?token={self._login_token}" else:
url = self._app_url
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
passthrough_request = { passthrough_request = {
@ -223,7 +225,8 @@ class AesTransport(BaseTransport):
resp_dict = await self.send_secure_passthrough(request) resp_dict = await self.send_secure_passthrough(request)
self._handle_response_error_code(resp_dict, "Error logging in") self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"] login_token = resp_dict["result"]["token"]
self._token_url = self._app_url.with_query(f"token={login_token}")
self._state = TransportState.ESTABLISHED self._state = TransportState.ESTABLISHED
async def _generate_key_pair_payload(self) -> AsyncGenerator: async def _generate_key_pair_payload(self) -> AsyncGenerator:
@ -250,7 +253,7 @@ class AesTransport(BaseTransport):
_LOGGER.debug("Will perform handshaking...") _LOGGER.debug("Will perform handshaking...")
self._key_pair = None self._key_pair = None
self._login_token = None self._token_url = None
self._session_expire_at = None self._session_expire_at = None
self._session_cookie = None self._session_cookie = None

View File

@ -1,13 +1,13 @@
"""python-kasa exceptions.""" """python-kasa exceptions."""
from asyncio import TimeoutError from asyncio import TimeoutError
from enum import IntEnum from enum import IntEnum
from typing import Optional from typing import Any, Optional
class SmartDeviceException(Exception): class SmartDeviceException(Exception):
"""Base exception for device errors.""" """Base exception for device errors."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
self.error_code: Optional["SmartErrorCode"] = kwargs.get("error_code", None) self.error_code: Optional["SmartErrorCode"] = kwargs.get("error_code", None)
super().__init__(*args) super().__init__(*args)
@ -15,7 +15,7 @@ class SmartDeviceException(Exception):
class UnsupportedDeviceException(SmartDeviceException): class UnsupportedDeviceException(SmartDeviceException):
"""Exception for trying to connect to unsupported devices.""" """Exception for trying to connect to unsupported devices."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
self.discovery_result = kwargs.get("discovery_result") self.discovery_result = kwargs.get("discovery_result")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@ -3,6 +3,7 @@ import asyncio
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
import aiohttp import aiohttp
from yarl import URL
from .deviceconfig import DeviceConfig from .deviceconfig import DeviceConfig
from .exceptions import ( from .exceptions import (
@ -25,7 +26,7 @@ class HttpClient:
self._config = config self._config = config
self._client_session: aiohttp.ClientSession = None self._client_session: aiohttp.ClientSession = None
self._jar = aiohttp.CookieJar(unsafe=True, quote_cookie=False) self._jar = aiohttp.CookieJar(unsafe=True, quote_cookie=False)
self._last_url = f"http://{self._config.host}/" self._last_url = URL(f"http://{self._config.host}/")
@property @property
def client(self) -> aiohttp.ClientSession: def client(self) -> aiohttp.ClientSession:
@ -41,7 +42,7 @@ class HttpClient:
async def post( async def post(
self, self,
url: str, url: URL,
*, *,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
data: Optional[bytes] = None, data: Optional[bytes] = None,

View File

@ -53,6 +53,7 @@ from typing import Any, Dict, Optional, Tuple, cast
from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from yarl import URL
from .credentials import Credentials from .credentials import Credentials
from .deviceconfig import DeviceConfig from .deviceconfig import DeviceConfig
@ -120,6 +121,8 @@ class KlapTransport(BaseTransport):
self._session_cookie: Optional[Dict[str, Any]] = None self._session_cookie: Optional[Dict[str, Any]] = None
_LOGGER.debug("Created KLAP transport for %s", self._host) _LOGGER.debug("Created KLAP transport for %s", self._host)
self._app_url = URL(f"http://{self._host}/app")
self._request_url = self._app_url / "request"
@property @property
def default_port(self): def default_port(self):
@ -141,7 +144,7 @@ class KlapTransport(BaseTransport):
payload = local_seed payload = local_seed
url = f"http://{self._host}/app/handshake1" url = self._app_url / "handshake1"
response_status, response_data = await self._http_client.post(url, data=payload) response_status, response_data = await self._http_client.post(url, data=payload)
@ -236,7 +239,7 @@ class KlapTransport(BaseTransport):
# Handshake 2 has the following payload: # Handshake 2 has the following payload:
# sha256(serverBytes | authenticator) # sha256(serverBytes | authenticator)
url = f"http://{self._host}/app/handshake2" url = self._app_url / "handshake2"
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash) payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
@ -309,10 +312,8 @@ class KlapTransport(BaseTransport):
if self._encryption_session is not None: if self._encryption_session is not None:
payload, seq = self._encryption_session.encrypt(request.encode()) payload, seq = self._encryption_session.encrypt(request.encode())
url = f"http://{self._host}/app/request"
response_status, response_data = await self._http_client.post( response_status, response_data = await self._http_client.post(
url, self._request_url,
params={"seq": seq}, params={"seq": seq},
data=payload, data=payload,
cookies_dict=self._session_cookie, cookies_dict=self._session_cookie,

View File

@ -12,6 +12,7 @@ import aiohttp
import pytest import pytest
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from yarl import URL
from ..aestransport import AesEncyptionSession, AesTransport, TransportState from ..aestransport import AesEncyptionSession, AesTransport, TransportState
from ..credentials import Credentials from ..credentials import Credentials
@ -89,10 +90,10 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
transport._session_expire_at = time.time() + 86400 transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session transport._encryption_session = mock_aes_device.encryption_session
assert transport._login_token is None assert transport._token_url is None
with expectation: with expectation:
await transport.perform_login() await transport.perform_login()
assert transport._login_token == mock_aes_device.token assert mock_aes_device.token in str(transport._token_url)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -136,7 +137,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
transport._session_expire_at = time.time() + 86400 transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session transport._encryption_session = mock_aes_device.encryption_session
assert transport._login_token is None assert transport._token_url is None
request = { request = {
"method": "get_device_info", "method": "get_device_info",
@ -148,7 +149,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
with expectation: with expectation:
await transport.send(json_dumps(request)) await transport.send(json_dumps(request))
assert transport._login_token == mock_aes_device.token assert mock_aes_device.token in str(transport._token_url)
assert post_mock.call_count == call_count # Login, Handshake, Login assert post_mock.call_count == call_count # Login, Handshake, Login
await transport.close() await transport.close()
@ -165,7 +166,9 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
transport._handshake_done = True transport._handshake_done = True
transport._session_expire_at = time.time() + 86400 transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = { request = {
"method": "get_device_info", "method": "get_device_info",
@ -193,7 +196,9 @@ async def test_passthrough_errors(mocker, error_code):
transport._handshake_done = True transport._handshake_done = True
transport._session_expire_at = time.time() + 86400 transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = { request = {
"method": "get_device_info", "method": "get_device_info",
@ -239,13 +244,13 @@ class MockAesDevice:
else: else:
return self._inner_error_code return self._inner_error_code
async def post(self, url, params=None, json=None, data=None, *_, **__): async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
if data: if data:
async for item in data: async for item in data:
json = json_loads(item.decode()) json = json_loads(item.decode())
return await self._post(url, json) return await self._post(url, json)
async def _post(self, url: str, json: Dict[str, Any]): async def _post(self, url: URL, json: Dict[str, Any]):
if json["method"] == "handshake": if json["method"] == "handshake":
return await self._return_handshake_response(url, json) return await self._return_handshake_response(url, json)
elif json["method"] == "securePassthrough": elif json["method"] == "securePassthrough":
@ -253,10 +258,10 @@ class MockAesDevice:
elif json["method"] == "login_device": elif json["method"] == "login_device":
return await self._return_login_response(url, json) return await self._return_login_response(url, json)
else: else:
assert url == f"http://{self.host}/app?token={self.token}" assert str(url) == f"http://{self.host}/app?token={self.token}"
return await self._return_send_response(url, json) return await self._return_send_response(url, json)
async def _return_handshake_response(self, url: str, json: Dict[str, Any]): async def _return_handshake_response(self, url: URL, json: Dict[str, Any]):
start = len("-----BEGIN PUBLIC KEY-----\n") start = len("-----BEGIN PUBLIC KEY-----\n")
end = len("\n-----END PUBLIC KEY-----\n") end = len("\n-----END PUBLIC KEY-----\n")
client_pub_key = json["params"]["key"][start:-end] client_pub_key = json["params"]["key"][start:-end]
@ -269,7 +274,7 @@ class MockAesDevice:
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code} self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
) )
async def _return_secure_passthrough_response(self, url: str, json: Dict[str, Any]): async def _return_secure_passthrough_response(self, url: URL, json: Dict[str, Any]):
encrypted_request = json["params"]["request"] encrypted_request = json["params"]["request"]
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode()) decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
decrypted_request_dict = json_loads(decrypted_request) decrypted_request_dict = json_loads(decrypted_request)
@ -286,15 +291,15 @@ class MockAesDevice:
} }
return self._mock_response(self.status_code, result) return self._mock_response(self.status_code, result)
async def _return_login_response(self, url: str, json: Dict[str, Any]): async def _return_login_response(self, url: URL, json: Dict[str, Any]):
if "token=" in url: if "token=" in str(url):
raise Exception("token should not be in url for a login request") raise Exception("token should not be in url for a login request")
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311 self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
result = {"result": {"token": self.token}, "error_code": self.inner_error_code} result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
self.inner_call_count += 1 self.inner_call_count += 1
return self._mock_response(self.status_code, result) return self._mock_response(self.status_code, result)
async def _return_send_response(self, url: str, json: Dict[str, Any]): async def _return_send_response(self, url: URL, json: Dict[str, Any]):
result = {"result": {"method": None}, "error_code": self.inner_error_code} result = {"result": {"method": None}, "error_code": self.inner_error_code}
self.inner_call_count += 1 self.inner_call_count += 1
return self._mock_response(self.status_code, result) return self._mock_response(self.status_code, result)

View File

@ -84,7 +84,7 @@ async def test_httpclient_errors(mocker, error, error_raises, error_message, moc
client = HttpClient(DeviceConfig(host)) client = HttpClient(DeviceConfig(host))
# Exceptions with parameters print with double quotes, without use single quotes # Exceptions with parameters print with double quotes, without use single quotes
full_msg = ( full_msg = (
"\(" "\(" # type: ignore
+ "['\"]" + "['\"]"
+ re.escape(f"{error_message}{host}: {error}") + re.escape(f"{error_message}{host}: {error}")
+ "['\"]" + "['\"]"

View File

@ -10,6 +10,7 @@ from unittest.mock import PropertyMock
import aiohttp import aiohttp
import pytest import pytest
from yarl import URL
from ..aestransport import AesTransport from ..aestransport import AesTransport
from ..credentials import Credentials from ..credentials import Credentials
@ -318,28 +319,28 @@ async def test_handshake1(
async def test_handshake( async def test_handshake(
mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2 mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
): ):
async def _return_handshake_response(url, params=None, data=None, *_, **__): client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
async def _return_handshake_response(url: URL, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash nonlocal client_seed, server_seed, device_auth_hash
if url == "http://127.0.0.1/app/handshake1": if str(url) == "http://127.0.0.1/app/handshake1":
client_seed = data client_seed = data
seed_auth_hash = _sha256( seed_auth_hash = _sha256(
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash) seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
) )
return _mock_response(200, server_seed + seed_auth_hash) return _mock_response(200, server_seed + seed_auth_hash)
elif url == "http://127.0.0.1/app/handshake2": elif str(url) == "http://127.0.0.1/app/handshake2":
seed_auth_hash = _sha256( seed_auth_hash = _sha256(
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash) seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
) )
assert data == seed_auth_hash assert data == seed_auth_hash
return _mock_response(response_status, b"") return _mock_response(response_status, b"")
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
mocker.patch.object( mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=_return_handshake_response aiohttp.ClientSession, "post", side_effect=_return_handshake_response
) )
@ -360,17 +361,24 @@ async def test_handshake(
async def test_query(mocker): async def test_query(mocker):
async def _return_response(url, params=None, data=None, *_, **__): client_seed = None
last_seq = None
seq = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
async def _return_response(url: URL, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash, seq nonlocal client_seed, server_seed, device_auth_hash, seq
if url == "http://127.0.0.1/app/handshake1": if str(url) == "http://127.0.0.1/app/handshake1":
client_seed = data client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash) client_seed_auth_hash = _sha256(data + device_auth_hash)
return _mock_response(200, server_seed + client_seed_auth_hash) return _mock_response(200, server_seed + client_seed_auth_hash)
elif url == "http://127.0.0.1/app/handshake2": elif str(url) == "http://127.0.0.1/app/handshake2":
return _mock_response(200, b"") return _mock_response(200, b"")
elif url == "http://127.0.0.1/app/request": elif str(url) == "http://127.0.0.1/app/request":
encryption_session = KlapEncryptionSession( encryption_session = KlapEncryptionSession(
protocol._transport._encryption_session.local_seed, protocol._transport._encryption_session.local_seed,
protocol._transport._encryption_session.remote_seed, protocol._transport._encryption_session.remote_seed,
@ -382,13 +390,6 @@ async def test_query(mocker):
seq = seq seq = seq
return _mock_response(200, encrypted) return _mock_response(200, encrypted)
client_seed = None
last_seq = None
seq = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
config = DeviceConfig("127.0.0.1", credentials=client_credentials) config = DeviceConfig("127.0.0.1", credentials=client_credentials)
@ -413,26 +414,26 @@ async def test_query(mocker):
ids=("handshake1", "handshake2", "request", "non_auth_error"), ids=("handshake1", "handshake2", "request", "non_auth_error"),
) )
async def test_authentication_failures(mocker, response_status, expectation): async def test_authentication_failures(mocker, response_status, expectation):
async def _return_response(url, params=None, data=None, *_, **__): client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
async def _return_response(url: URL, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash, response_status nonlocal client_seed, server_seed, device_auth_hash, response_status
if url == "http://127.0.0.1/app/handshake1": if str(url) == "http://127.0.0.1/app/handshake1":
client_seed = data client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash) client_seed_auth_hash = _sha256(data + device_auth_hash)
return _mock_response( return _mock_response(
response_status[0], server_seed + client_seed_auth_hash response_status[0], server_seed + client_seed_auth_hash
) )
elif url == "http://127.0.0.1/app/handshake2": elif str(url) == "http://127.0.0.1/app/handshake2":
return _mock_response(response_status[1], b"") return _mock_response(response_status[1], b"")
elif url == "http://127.0.0.1/app/request": elif str(url) == "http://127.0.0.1/app/request":
return _mock_response(response_status[2], None) return _mock_response(response_status[2], b"")
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)