mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Fix SslAesTransport default login and add tests (#1202)
Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
parent
0287606235
commit
440b2d153b
@ -137,6 +137,11 @@ class SslAesTransport(BaseTransport):
|
|||||||
"""Default port for the transport."""
|
"""Default port for the transport."""
|
||||||
return self.DEFAULT_PORT
|
return self.DEFAULT_PORT
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_b64_credentials(credentials: Credentials) -> str:
|
||||||
|
ch = {"un": credentials.username, "pwd": credentials.password}
|
||||||
|
return base64.b64encode(json_dumps(ch).encode()).decode()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def credentials_hash(self) -> str | None:
|
def credentials_hash(self) -> str | None:
|
||||||
"""The hashed credentials used by the transport."""
|
"""The hashed credentials used by the transport."""
|
||||||
@ -145,8 +150,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
if not self._credentials and self._credentials_hash:
|
if not self._credentials and self._credentials_hash:
|
||||||
return self._credentials_hash
|
return self._credentials_hash
|
||||||
if (cred := self._credentials) and cred.password and cred.username:
|
if (cred := self._credentials) and cred.password and cred.username:
|
||||||
ch = {"un": cred.username, "pwd": cred.password}
|
return self._create_b64_credentials(cred)
|
||||||
return base64.b64encode(json_dumps(ch).encode()).decode()
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_response_error(self, resp_dict: Any) -> SmartErrorCode:
|
def _get_response_error(self, resp_dict: Any) -> SmartErrorCode:
|
||||||
@ -329,6 +333,13 @@ class SslAesTransport(BaseTransport):
|
|||||||
+ f"status code {status_code} to handshake2"
|
+ f"status code {status_code} to handshake2"
|
||||||
)
|
)
|
||||||
resp_dict = cast(dict, resp_dict)
|
resp_dict = cast(dict, resp_dict)
|
||||||
|
if (
|
||||||
|
error_code := self._get_response_error(resp_dict)
|
||||||
|
) and error_code is SmartErrorCode.INVALID_NONCE:
|
||||||
|
raise AuthenticationError(
|
||||||
|
f"Invalid password hash in handshake2 for {self._host}"
|
||||||
|
)
|
||||||
|
|
||||||
self._handle_response_error_code(resp_dict, "Error in handshake2")
|
self._handle_response_error_code(resp_dict, "Error in handshake2")
|
||||||
|
|
||||||
self._seq = resp_dict["result"]["start_seq"]
|
self._seq = resp_dict["result"]["start_seq"]
|
||||||
@ -372,12 +383,12 @@ class SslAesTransport(BaseTransport):
|
|||||||
|
|
||||||
if not self._username:
|
if not self._username:
|
||||||
raise AuthenticationError(
|
raise AuthenticationError(
|
||||||
"Credentials must be supplied to connect to {self._host}"
|
f"Credentials must be supplied to connect to {self._host}"
|
||||||
)
|
)
|
||||||
if error_code is not SmartErrorCode.INVALID_NONCE or (
|
if error_code is not SmartErrorCode.INVALID_NONCE or (
|
||||||
resp_dict and "nonce" not in resp_dict["result"].get("data", {})
|
resp_dict and "nonce" not in resp_dict["result"].get("data", {})
|
||||||
):
|
):
|
||||||
raise AuthenticationError("Error trying handshake1: {resp_dict}")
|
raise AuthenticationError(f"Error trying handshake1: {resp_dict}")
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
resp_dict = cast(Dict[str, Any], resp_dict)
|
resp_dict = cast(Dict[str, Any], resp_dict)
|
||||||
@ -422,7 +433,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
"params": {
|
"params": {
|
||||||
"cnonce": local_nonce,
|
"cnonce": local_nonce,
|
||||||
"encrypt_type": "3",
|
"encrypt_type": "3",
|
||||||
"username": self._username,
|
"username": username,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
http_client = self._http_client
|
http_client = self._http_client
|
||||||
|
374
kasa/tests/test_sslaestransport.py
Normal file
374
kasa/tests/test_sslaestransport.py
Normal file
@ -0,0 +1,374 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from contextlib import nullcontext as does_not_raise
|
||||||
|
from json import dumps as json_dumps
|
||||||
|
from json import loads as json_loads
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
|
from kasa.protocol import DEFAULT_CREDENTIALS, get_default_credentials
|
||||||
|
|
||||||
|
from ..aestransport import AesEncyptionSession
|
||||||
|
from ..credentials import Credentials
|
||||||
|
from ..deviceconfig import DeviceConfig
|
||||||
|
from ..exceptions import (
|
||||||
|
AuthenticationError,
|
||||||
|
KasaException,
|
||||||
|
SmartErrorCode,
|
||||||
|
)
|
||||||
|
from ..experimental.sslaestransport import SslAesTransport, TransportState, _sha256_hash
|
||||||
|
from ..httpclient import HttpClient
|
||||||
|
|
||||||
|
MOCK_ADMIN_USER = get_default_credentials(DEFAULT_CREDENTIALS["TAPOCAMERA"]).username
|
||||||
|
MOCK_PWD = "correct_pwd" # noqa: S105
|
||||||
|
MOCK_USER = "mock@example.com"
|
||||||
|
MOCK_STOCK = "abcdefghijklmnopqrstuvwxyz1234)("
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
(
|
||||||
|
"status_code",
|
||||||
|
"username",
|
||||||
|
"password",
|
||||||
|
"wants_default_user",
|
||||||
|
"digest_password_fail",
|
||||||
|
"expectation",
|
||||||
|
),
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
200, MOCK_USER, MOCK_PWD, False, False, does_not_raise(), id="success"
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
MOCK_USER,
|
||||||
|
MOCK_PWD,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
does_not_raise(),
|
||||||
|
id="success-default",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
400,
|
||||||
|
MOCK_USER,
|
||||||
|
MOCK_PWD,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
pytest.raises(KasaException),
|
||||||
|
id="400 error",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
"foobar",
|
||||||
|
MOCK_PWD,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
pytest.raises(AuthenticationError),
|
||||||
|
id="bad-username",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
MOCK_USER,
|
||||||
|
"barfoo",
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
pytest.raises(AuthenticationError),
|
||||||
|
id="bad-password",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
MOCK_USER,
|
||||||
|
MOCK_PWD,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
pytest.raises(AuthenticationError),
|
||||||
|
id="bad-password-digest",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_handshake(
|
||||||
|
mocker,
|
||||||
|
status_code,
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
wants_default_user,
|
||||||
|
digest_password_fail,
|
||||||
|
expectation,
|
||||||
|
):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_ssl_aes_device = MockSslAesDevice(
|
||||||
|
host,
|
||||||
|
status_code=status_code,
|
||||||
|
want_default_username=wants_default_user,
|
||||||
|
digest_password_fail=digest_password_fail,
|
||||||
|
)
|
||||||
|
mocker.patch.object(
|
||||||
|
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SslAesTransport(
|
||||||
|
config=DeviceConfig(host, credentials=Credentials(username, password))
|
||||||
|
)
|
||||||
|
|
||||||
|
assert transport._encryption_session is None
|
||||||
|
assert transport._state is TransportState.HANDSHAKE_REQUIRED
|
||||||
|
with expectation:
|
||||||
|
await transport.perform_handshake()
|
||||||
|
assert transport._encryption_session is not None
|
||||||
|
assert transport._state is TransportState.ESTABLISHED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("wants_default_user"),
|
||||||
|
[pytest.param(False, id="username"), pytest.param(True, id="default")],
|
||||||
|
)
|
||||||
|
async def test_credentials_hash(mocker, wants_default_user):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_ssl_aes_device = MockSslAesDevice(
|
||||||
|
host, want_default_username=wants_default_user
|
||||||
|
)
|
||||||
|
mocker.patch.object(
|
||||||
|
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
|
||||||
|
)
|
||||||
|
creds = Credentials(MOCK_USER, MOCK_PWD)
|
||||||
|
creds_hash = SslAesTransport._create_b64_credentials(creds)
|
||||||
|
|
||||||
|
# Test with credentials input
|
||||||
|
transport = SslAesTransport(config=DeviceConfig(host, credentials=creds))
|
||||||
|
assert transport.credentials_hash == creds_hash
|
||||||
|
await transport.perform_handshake()
|
||||||
|
assert transport.credentials_hash == creds_hash
|
||||||
|
|
||||||
|
# Test with credentials_hash input
|
||||||
|
transport = SslAesTransport(config=DeviceConfig(host, credentials_hash=creds_hash))
|
||||||
|
mock_ssl_aes_device.handshake1_complete = False
|
||||||
|
assert transport.credentials_hash == creds_hash
|
||||||
|
await transport.perform_handshake()
|
||||||
|
assert transport.credentials_hash == creds_hash
|
||||||
|
|
||||||
|
|
||||||
|
async def test_send(mocker):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_ssl_aes_device = MockSslAesDevice(host, want_default_username=False)
|
||||||
|
mocker.patch.object(
|
||||||
|
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SslAesTransport(
|
||||||
|
config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
|
||||||
|
)
|
||||||
|
request = {
|
||||||
|
"method": "getDeviceInfo",
|
||||||
|
"params": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
res = await transport.send(json_dumps(request))
|
||||||
|
assert "result" in res
|
||||||
|
|
||||||
|
|
||||||
|
async def test_unencrypted_response(mocker, caplog):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_ssl_aes_device = MockSslAesDevice(host, do_not_encrypt_response=True)
|
||||||
|
mocker.patch.object(
|
||||||
|
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SslAesTransport(
|
||||||
|
config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
|
||||||
|
)
|
||||||
|
|
||||||
|
request = {
|
||||||
|
"method": "getDeviceInfo",
|
||||||
|
"params": None,
|
||||||
|
}
|
||||||
|
caplog.set_level(logging.DEBUG)
|
||||||
|
res = await transport.send(json_dumps(request))
|
||||||
|
assert "result" in res
|
||||||
|
assert (
|
||||||
|
"Received unencrypted response over secure passthrough from 127.0.0.1"
|
||||||
|
in caplog.text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_port_override():
|
||||||
|
"""Test that port override sets the app_url."""
|
||||||
|
host = "127.0.0.1"
|
||||||
|
port_override = 12345
|
||||||
|
config = DeviceConfig(
|
||||||
|
host, credentials=Credentials("foo", "bar"), port_override=port_override
|
||||||
|
)
|
||||||
|
transport = SslAesTransport(config=config)
|
||||||
|
|
||||||
|
assert str(transport._app_url) == f"https://127.0.0.1:{port_override}"
|
||||||
|
|
||||||
|
|
||||||
|
class MockSslAesDevice:
|
||||||
|
BAD_USER_RESP = {
|
||||||
|
"error_code": SmartErrorCode.SESSION_EXPIRED.value,
|
||||||
|
"result": {
|
||||||
|
"data": {
|
||||||
|
"code": -60502,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
BAD_PWD_RESP = {
|
||||||
|
"error_code": SmartErrorCode.INVALID_NONCE.value,
|
||||||
|
"result": {
|
||||||
|
"data": {
|
||||||
|
"code": SmartErrorCode.SESSION_EXPIRED.value,
|
||||||
|
"encrypt_type": ["3"],
|
||||||
|
"key": "Someb64keyWithUnknownPurpose",
|
||||||
|
"nonce": "1234567890ABCDEF", # Whatever the original nonce was
|
||||||
|
"device_confirm": "",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
class _mock_response:
|
||||||
|
def __init__(self, status, request: dict):
|
||||||
|
self.status = status
|
||||||
|
self._json = request
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_t, exc_v, exc_tb):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def read(self):
|
||||||
|
if isinstance(self._json, dict):
|
||||||
|
return json_dumps(self._json).encode()
|
||||||
|
return self._json
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host,
|
||||||
|
*,
|
||||||
|
status_code=200,
|
||||||
|
want_default_username: bool = False,
|
||||||
|
do_not_encrypt_response=False,
|
||||||
|
send_response=None,
|
||||||
|
sequential_request_delay=0,
|
||||||
|
send_error_code=0,
|
||||||
|
secure_passthrough_error_code=0,
|
||||||
|
digest_password_fail=False,
|
||||||
|
):
|
||||||
|
self.host = host
|
||||||
|
self.http_client = HttpClient(DeviceConfig(self.host))
|
||||||
|
self.encryption_session: AesEncyptionSession | None = None
|
||||||
|
self.server_nonce = secrets.token_bytes(8).hex().upper()
|
||||||
|
self.handshake1_complete = False
|
||||||
|
|
||||||
|
# test behaviour attributes
|
||||||
|
self.status_code = status_code
|
||||||
|
self.send_error_code = send_error_code
|
||||||
|
self.secure_passthrough_error_code = secure_passthrough_error_code
|
||||||
|
self.do_not_encrypt_response = do_not_encrypt_response
|
||||||
|
self.want_default_username = want_default_username
|
||||||
|
self.digest_password_fail = digest_password_fail
|
||||||
|
|
||||||
|
async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
|
||||||
|
if data:
|
||||||
|
json = json_loads(data)
|
||||||
|
res = await self._post(url, json)
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def _post(self, url: URL, json: dict[str, Any]):
|
||||||
|
method = json["method"]
|
||||||
|
|
||||||
|
if method == "login" and not self.handshake1_complete:
|
||||||
|
return await self._return_handshake1_response(url, json)
|
||||||
|
|
||||||
|
if method == "login" and self.handshake1_complete:
|
||||||
|
return await self._return_handshake2_response(url, json)
|
||||||
|
elif method == "securePassthrough":
|
||||||
|
assert url == URL(f"https://{self.host}/stok={MOCK_STOCK}/ds")
|
||||||
|
return await self._return_secure_passthrough_response(url, json)
|
||||||
|
else:
|
||||||
|
assert url == URL(f"https://{self.host}/stok={MOCK_STOCK}/ds")
|
||||||
|
return await self._return_send_response(url, json)
|
||||||
|
|
||||||
|
async def _return_handshake1_response(self, url: URL, request: dict[str, Any]):
|
||||||
|
request_nonce = request["params"].get("cnonce")
|
||||||
|
request_username = request["params"].get("username")
|
||||||
|
|
||||||
|
if (self.want_default_username and request_username != MOCK_ADMIN_USER) or (
|
||||||
|
not self.want_default_username and request_username != MOCK_USER
|
||||||
|
):
|
||||||
|
return self._mock_response(self.status_code, self.BAD_USER_RESP)
|
||||||
|
|
||||||
|
device_confirm = SslAesTransport.generate_confirm_hash(
|
||||||
|
request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode())
|
||||||
|
)
|
||||||
|
self.handshake1_complete = True
|
||||||
|
resp = {
|
||||||
|
"error_code": SmartErrorCode.INVALID_NONCE.value,
|
||||||
|
"result": {
|
||||||
|
"data": {
|
||||||
|
"code": SmartErrorCode.INVALID_NONCE.value,
|
||||||
|
"encrypt_type": ["3"],
|
||||||
|
"key": "Someb64keyWithUnknownPurpose",
|
||||||
|
"nonce": self.server_nonce,
|
||||||
|
"device_confirm": device_confirm,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return self._mock_response(self.status_code, resp)
|
||||||
|
|
||||||
|
async def _return_handshake2_response(self, url: URL, request: dict[str, Any]):
|
||||||
|
request_nonce = request["params"].get("cnonce")
|
||||||
|
request_username = request["params"].get("username")
|
||||||
|
if (self.want_default_username and request_username != MOCK_ADMIN_USER) or (
|
||||||
|
not self.want_default_username and request_username != MOCK_USER
|
||||||
|
):
|
||||||
|
return self._mock_response(self.status_code, self.BAD_USER_RESP)
|
||||||
|
|
||||||
|
request_password = request["params"].get("digest_passwd")
|
||||||
|
expected_pwd = SslAesTransport.generate_digest_password(
|
||||||
|
request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode())
|
||||||
|
)
|
||||||
|
if request_password != expected_pwd or self.digest_password_fail:
|
||||||
|
return self._mock_response(self.status_code, self.BAD_PWD_RESP)
|
||||||
|
|
||||||
|
lsk = SslAesTransport.generate_encryption_token(
|
||||||
|
"lsk", request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode())
|
||||||
|
)
|
||||||
|
ivb = SslAesTransport.generate_encryption_token(
|
||||||
|
"ivb", request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode())
|
||||||
|
)
|
||||||
|
self.encryption_session = AesEncyptionSession(lsk, ivb)
|
||||||
|
resp = {
|
||||||
|
"error_code": 0,
|
||||||
|
"result": {"stok": MOCK_STOCK, "user_group": "root", "start_seq": 100},
|
||||||
|
}
|
||||||
|
return self._mock_response(self.status_code, resp)
|
||||||
|
|
||||||
|
async def _return_secure_passthrough_response(self, url: URL, json: dict[str, Any]):
|
||||||
|
encrypted_request = json["params"]["request"]
|
||||||
|
assert self.encryption_session
|
||||||
|
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
|
||||||
|
decrypted_request_dict = json_loads(decrypted_request)
|
||||||
|
decrypted_response = await self._post(url, decrypted_request_dict)
|
||||||
|
async with decrypted_response:
|
||||||
|
decrypted_response_data = await decrypted_response.read()
|
||||||
|
|
||||||
|
encrypted_response = self.encryption_session.encrypt(decrypted_response_data)
|
||||||
|
response = (
|
||||||
|
decrypted_response_data
|
||||||
|
if self.do_not_encrypt_response
|
||||||
|
else encrypted_response
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"result": {"response": response.decode()},
|
||||||
|
"error_code": self.secure_passthrough_error_code,
|
||||||
|
}
|
||||||
|
return self._mock_response(self.status_code, result)
|
||||||
|
|
||||||
|
async def _return_send_response(self, url: URL, json: dict[str, Any]):
|
||||||
|
result = {"result": {"method": None}, "error_code": self.send_error_code}
|
||||||
|
return self._mock_response(self.status_code, result)
|
Loading…
Reference in New Issue
Block a user