mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-10-20 06:18:01 +00:00
Do login entirely within AesTransport (#580)
* Do login entirely within AesTransport * Remove login and handshake attributes from BaseTransport * Add AesTransport tests * Synchronise transport and protocol __init__ signatures and rename internal variables * Update after review
This commit is contained in:
@@ -301,6 +301,9 @@ class FakeSmartProtocol(SmartProtocol):
|
||||
|
||||
class FakeSmartTransport(BaseTransport):
|
||||
def __init__(self, info):
|
||||
super().__init__(
|
||||
"127.0.0.123",
|
||||
)
|
||||
self.info = info
|
||||
|
||||
@property
|
||||
|
174
kasa/tests/test_aestransport.py
Normal file
174
kasa/tests/test_aestransport.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from contextlib import nullcontext as does_not_raise
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
|
||||
from ..aestransport import AesEncyptionSession, AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..exceptions import SmartDeviceException
|
||||
|
||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||
|
||||
key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t"
|
||||
iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa"
|
||||
KEY_IV = key + iv
|
||||
|
||||
|
||||
def test_encrypt():
|
||||
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
|
||||
|
||||
d = json.dumps({"foo": 1, "bar": 2})
|
||||
encrypted = encryption_session.encrypt(d.encode())
|
||||
assert d == encryption_session.decrypt(encrypted)
|
||||
|
||||
# test encrypt unicode
|
||||
d = "{'snowman': '\u2603'}"
|
||||
encrypted = encryption_session.encrypt(d.encode())
|
||||
assert d == encryption_session.decrypt(encrypted)
|
||||
|
||||
|
||||
status_parameters = pytest.mark.parametrize(
|
||||
"status_code, error_code, inner_error_code, expectation",
|
||||
[
|
||||
(200, 0, 0, does_not_raise()),
|
||||
(400, 0, 0, pytest.raises(SmartDeviceException)),
|
||||
(200, -1, 0, pytest.raises(SmartDeviceException)),
|
||||
],
|
||||
ids=("success", "status_code", "error_code"),
|
||||
)
|
||||
|
||||
|
||||
@status_parameters
|
||||
async def test_handshake(
|
||||
mocker, status_code, error_code, inner_error_code, expectation
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
|
||||
assert transport._encryption_session is None
|
||||
assert transport._handshake_done is False
|
||||
with expectation:
|
||||
await transport.perform_handshake()
|
||||
assert transport._encryption_session is not None
|
||||
assert transport._handshake_done is True
|
||||
|
||||
|
||||
@status_parameters
|
||||
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
|
||||
host = "127.0.0.1"
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
|
||||
assert transport._login_token is None
|
||||
with expectation:
|
||||
await transport.perform_login()
|
||||
assert transport._login_token == mock_aes_device.token
|
||||
|
||||
|
||||
@status_parameters
|
||||
async def test_send(mocker, status_code, error_code, inner_error_code, expectation):
|
||||
host = "127.0.0.1"
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
transport._login_token = mock_aes_device.token
|
||||
|
||||
un, pw = transport.hash_credentials(True)
|
||||
request = {
|
||||
"method": "get_device_info",
|
||||
"params": None,
|
||||
"request_time_milis": round(time.time() * 1000),
|
||||
"requestID": 1,
|
||||
"terminal_uuid": "foobar",
|
||||
}
|
||||
with expectation:
|
||||
res = await transport.send(json_dumps(request))
|
||||
assert "result" in res
|
||||
|
||||
|
||||
class MockAesDevice:
|
||||
class _mock_response:
|
||||
def __init__(self, status_code, json: dict):
|
||||
self.status_code = status_code
|
||||
self._json = json
|
||||
|
||||
def json(self):
|
||||
return self._json
|
||||
|
||||
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
|
||||
token = "test_token" # noqa
|
||||
|
||||
def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
|
||||
self.host = host
|
||||
self.status_code = status_code
|
||||
self.error_code = error_code
|
||||
self.inner_error_code = inner_error_code
|
||||
|
||||
async def post(self, url, params=None, json=None, *_, **__):
|
||||
return await self._post(url, json)
|
||||
|
||||
async def _post(self, url, json):
|
||||
if json["method"] == "handshake":
|
||||
return await self._return_handshake_response(url, json)
|
||||
elif json["method"] == "securePassthrough":
|
||||
return await self._return_secure_passthrough_response(url, json)
|
||||
elif json["method"] == "login_device":
|
||||
return await self._return_login_response(url, json)
|
||||
else:
|
||||
assert url == f"http://{self.host}/app?token={self.token}"
|
||||
return await self._return_send_response(url, json)
|
||||
|
||||
async def _return_handshake_response(self, url, json):
|
||||
start = len("-----BEGIN PUBLIC KEY-----\n")
|
||||
end = len("\n-----END PUBLIC KEY-----\n")
|
||||
client_pub_key = json["params"]["key"][start:-end]
|
||||
|
||||
client_pub_key_data = base64.b64decode(client_pub_key.encode())
|
||||
client_pub_key = serialization.load_der_public_key(client_pub_key_data, None)
|
||||
encrypted_key = client_pub_key.encrypt(KEY_IV, asymmetric_padding.PKCS1v15())
|
||||
key_64 = base64.b64encode(encrypted_key).decode()
|
||||
return self._mock_response(
|
||||
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
|
||||
)
|
||||
|
||||
async def _return_secure_passthrough_response(self, url, json):
|
||||
encrypted_request = json["params"]["request"]
|
||||
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
|
||||
decrypted_request_dict = json_loads(decrypted_request)
|
||||
decrypted_response = await self._post(url, decrypted_request_dict)
|
||||
decrypted_response_dict = decrypted_response.json()
|
||||
encrypted_response = self.encryption_session.encrypt(
|
||||
json_dumps(decrypted_response_dict).encode()
|
||||
)
|
||||
result = {
|
||||
"result": {"response": encrypted_response.decode()},
|
||||
"error_code": self.error_code,
|
||||
}
|
||||
return self._mock_response(self.status_code, result)
|
||||
|
||||
async def _return_login_response(self, url, json):
|
||||
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
|
||||
return self._mock_response(self.status_code, result)
|
||||
|
||||
async def _return_send_response(self, url, json):
|
||||
result = {"result": {"method": None}, "error_code": self.inner_error_code}
|
||||
return self._mock_response(self.status_code, result)
|
@@ -96,10 +96,9 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
|
||||
|
||||
return mock_response
|
||||
|
||||
mocker.patch.object(
|
||||
transport_class, "needs_handshake", property(lambda self: False)
|
||||
)
|
||||
mocker.patch.object(transport_class, "needs_login", property(lambda self: False))
|
||||
mocker.patch.object(transport_class, "perform_handshake")
|
||||
if hasattr(transport_class, "perform_login"):
|
||||
mocker.patch.object(transport_class, "perform_login")
|
||||
|
||||
send_mock = mocker.patch.object(
|
||||
transport_class,
|
||||
@@ -128,7 +127,7 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
protocol = IotProtocol("127.0.0.1")
|
||||
protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1"))
|
||||
|
||||
protocol._transport._handshake_done = True
|
||||
protocol._transport._session_expire_at = time.time() + 86400
|
||||
@@ -206,7 +205,10 @@ async def test_handshake1(mocker, device_credentials, expectation):
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
||||
)
|
||||
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
with expectation:
|
||||
@@ -243,7 +245,10 @@ async def test_handshake(mocker):
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
||||
)
|
||||
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
|
||||
response_status = 200
|
||||
@@ -289,7 +294,10 @@ async def test_query(mocker):
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
|
||||
for _ in range(10):
|
||||
resp = await protocol.query({})
|
||||
@@ -333,7 +341,10 @@ async def test_authentication_failures(mocker, response_status, expectation):
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
|
||||
with expectation:
|
||||
await protocol.query({})
|
||||
|
@@ -1,13 +1,21 @@
|
||||
import errno
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import pkgutil
|
||||
import struct
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from ..exceptions import SmartDeviceException
|
||||
from ..protocol import TPLinkSmartHomeProtocol
|
||||
from ..protocol import (
|
||||
BaseTransport,
|
||||
TPLinkProtocol,
|
||||
TPLinkSmartHomeProtocol,
|
||||
_XorTransport,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||
@@ -24,7 +32,9 @@ async def test_protocol_retries(mocker, retry_count):
|
||||
|
||||
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count)
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=retry_count)
|
||||
|
||||
assert conn.call_count == retry_count + 1
|
||||
|
||||
@@ -35,7 +45,9 @@ async def test_protocol_no_retry_on_unreachable(mocker):
|
||||
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
|
||||
assert conn.call_count == 1
|
||||
|
||||
@@ -46,7 +58,9 @@ async def test_protocol_no_retry_connection_refused(mocker):
|
||||
side_effect=ConnectionRefusedError,
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
|
||||
assert conn.call_count == 1
|
||||
|
||||
@@ -57,7 +71,9 @@ async def test_protocol_retry_recoverable_error(mocker):
|
||||
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
|
||||
assert conn.call_count == 6
|
||||
|
||||
@@ -91,7 +107,9 @@ async def test_protocol_reconnect(mocker, retry_count):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
)
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({}, retry_count=retry_count)
|
||||
assert response == {"great": "success"}
|
||||
@@ -119,7 +137,9 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
)
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
@@ -153,7 +173,9 @@ async def test_protocol_custom_port(mocker, custom_port):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol("127.0.0.1", port=custom_port)
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port)
|
||||
)
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
@@ -227,3 +249,63 @@ def test_decrypt_unicode():
|
||||
d = "{'snowman': '\u2603'}"
|
||||
|
||||
assert d == TPLinkSmartHomeProtocol.decrypt(e)
|
||||
|
||||
|
||||
def _get_subclasses(of_class):
|
||||
import kasa
|
||||
|
||||
package = sys.modules["kasa"]
|
||||
subclasses = set()
|
||||
for _, modname, _ in pkgutil.iter_modules(package.__path__):
|
||||
importlib.import_module("." + modname, package="kasa")
|
||||
module = sys.modules["kasa." + modname]
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, of_class):
|
||||
subclasses.add((name, obj))
|
||||
return subclasses
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"class_name_obj", _get_subclasses(TPLinkProtocol), ids=lambda t: t[0]
|
||||
)
|
||||
def test_protocol_init_signature(class_name_obj):
|
||||
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||
|
||||
assert len(params) == 3
|
||||
assert (
|
||||
params[0].name == "self"
|
||||
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[1].name == "host"
|
||||
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[2].name == "transport"
|
||||
and params[2].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"class_name_obj", _get_subclasses(BaseTransport), ids=lambda t: t[0]
|
||||
)
|
||||
def test_transport_init_signature(class_name_obj):
|
||||
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||
|
||||
assert len(params) == 5
|
||||
assert (
|
||||
params[0].name == "self"
|
||||
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[1].name == "host"
|
||||
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
assert (
|
||||
params[3].name == "credentials"
|
||||
and params[3].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
assert (
|
||||
params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
|
@@ -232,7 +232,7 @@ async def test_modules_preserved(dev: SmartDevice):
|
||||
async def test_create_smart_device_with_timeout():
|
||||
"""Make sure timeout is passed to the protocol."""
|
||||
dev = SmartDevice(host="127.0.0.1", timeout=100)
|
||||
assert dev.protocol.timeout == 100
|
||||
assert dev.protocol._transport._timeout == 100
|
||||
|
||||
|
||||
async def test_create_thin_wrapper():
|
||||
|
Reference in New Issue
Block a user