mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-23 03:33:35 +00:00
Ensure login token is only sent if aes state is ESTABLISHED (#702)
This commit is contained in:
parent
aecf0ecd8a
commit
3df837cc82
@ -151,7 +151,7 @@ class AesTransport(BaseTransport):
|
|||||||
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
|
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
|
||||||
"""Send encrypted message as passthrough."""
|
"""Send encrypted message as passthrough."""
|
||||||
url = f"http://{self._host}/app"
|
url = f"http://{self._host}/app"
|
||||||
if self._login_token:
|
if self._state is TransportState.ESTABLISHED and self._login_token:
|
||||||
url += f"?token={self._login_token}"
|
url += f"?token={self._login_token}"
|
||||||
|
|
||||||
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
||||||
@ -250,6 +250,7 @@ class AesTransport(BaseTransport):
|
|||||||
_LOGGER.debug("Will perform handshaking...")
|
_LOGGER.debug("Will perform handshaking...")
|
||||||
|
|
||||||
self._key_pair = None
|
self._key_pair = None
|
||||||
|
self._login_token = None
|
||||||
self._session_expire_at = None
|
self._session_expire_at = None
|
||||||
self._session_cookie = None
|
self._session_cookie = None
|
||||||
|
|
||||||
@ -284,9 +285,7 @@ class AesTransport(BaseTransport):
|
|||||||
handshake_key = resp_dict["result"]["key"]
|
handshake_key = resp_dict["result"]["key"]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cookie := http_client.get_cookie( # type: ignore
|
cookie := http_client.get_cookie(self.SESSION_COOKIE_NAME) # type: ignore
|
||||||
self.SESSION_COOKIE_NAME
|
|
||||||
)
|
|
||||||
) or (
|
) or (
|
||||||
cookie := http_client.get_cookie("SESSIONID") # type: ignore
|
cookie := http_client.get_cookie("SESSIONID") # type: ignore
|
||||||
):
|
):
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import random
|
||||||
|
import string
|
||||||
import time
|
import time
|
||||||
from contextlib import nullcontext as does_not_raise
|
from contextlib import nullcontext as does_not_raise
|
||||||
from json import dumps as json_dumps
|
from json import dumps as json_dumps
|
||||||
from json import loads as json_loads
|
from json import loads as json_loads
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest
|
import pytest
|
||||||
@ -219,7 +222,6 @@ class MockAesDevice:
|
|||||||
return json_dumps(self._json).encode()
|
return json_dumps(self._json).encode()
|
||||||
|
|
||||||
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
|
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
|
||||||
token = "test_token" # noqa
|
|
||||||
|
|
||||||
def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
|
def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
|
||||||
self.host = host
|
self.host = host
|
||||||
@ -228,6 +230,7 @@ class MockAesDevice:
|
|||||||
self._inner_error_code = inner_error_code
|
self._inner_error_code = inner_error_code
|
||||||
self.http_client = HttpClient(DeviceConfig(self.host))
|
self.http_client = HttpClient(DeviceConfig(self.host))
|
||||||
self.inner_call_count = 0
|
self.inner_call_count = 0
|
||||||
|
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inner_error_code(self):
|
def inner_error_code(self):
|
||||||
@ -242,7 +245,7 @@ class MockAesDevice:
|
|||||||
json = json_loads(item.decode())
|
json = json_loads(item.decode())
|
||||||
return await self._post(url, json)
|
return await self._post(url, json)
|
||||||
|
|
||||||
async def _post(self, url, json):
|
async def _post(self, url: str, json: Dict[str, Any]):
|
||||||
if json["method"] == "handshake":
|
if json["method"] == "handshake":
|
||||||
return await self._return_handshake_response(url, json)
|
return await self._return_handshake_response(url, json)
|
||||||
elif json["method"] == "securePassthrough":
|
elif json["method"] == "securePassthrough":
|
||||||
@ -253,7 +256,7 @@ class MockAesDevice:
|
|||||||
assert url == f"http://{self.host}/app?token={self.token}"
|
assert url == f"http://{self.host}/app?token={self.token}"
|
||||||
return await self._return_send_response(url, json)
|
return await self._return_send_response(url, json)
|
||||||
|
|
||||||
async def _return_handshake_response(self, url, json):
|
async def _return_handshake_response(self, url: str, json: Dict[str, Any]):
|
||||||
start = len("-----BEGIN PUBLIC KEY-----\n")
|
start = len("-----BEGIN PUBLIC KEY-----\n")
|
||||||
end = len("\n-----END PUBLIC KEY-----\n")
|
end = len("\n-----END PUBLIC KEY-----\n")
|
||||||
client_pub_key = json["params"]["key"][start:-end]
|
client_pub_key = json["params"]["key"][start:-end]
|
||||||
@ -266,7 +269,7 @@ class MockAesDevice:
|
|||||||
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
|
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _return_secure_passthrough_response(self, url, json):
|
async def _return_secure_passthrough_response(self, url: str, json: Dict[str, Any]):
|
||||||
encrypted_request = json["params"]["request"]
|
encrypted_request = json["params"]["request"]
|
||||||
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
|
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
|
||||||
decrypted_request_dict = json_loads(decrypted_request)
|
decrypted_request_dict = json_loads(decrypted_request)
|
||||||
@ -283,12 +286,15 @@ class MockAesDevice:
|
|||||||
}
|
}
|
||||||
return self._mock_response(self.status_code, result)
|
return self._mock_response(self.status_code, result)
|
||||||
|
|
||||||
async def _return_login_response(self, url, json):
|
async def _return_login_response(self, url: str, json: Dict[str, Any]):
|
||||||
|
if "token=" in url:
|
||||||
|
raise Exception("token should not be in url for a login request")
|
||||||
|
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
|
||||||
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
|
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
|
||||||
self.inner_call_count += 1
|
self.inner_call_count += 1
|
||||||
return self._mock_response(self.status_code, result)
|
return self._mock_response(self.status_code, result)
|
||||||
|
|
||||||
async def _return_send_response(self, url, json):
|
async def _return_send_response(self, url: str, json: Dict[str, Any]):
|
||||||
result = {"result": {"method": None}, "error_code": self.inner_error_code}
|
result = {"result": {"method": None}, "error_code": self.inner_error_code}
|
||||||
self.inner_call_count += 1
|
self.inner_call_count += 1
|
||||||
return self._mock_response(self.status_code, result)
|
return self._mock_response(self.status_code, result)
|
||||||
|
Loading…
Reference in New Issue
Block a user