Follow main package structure for tests (#1317)

* Transport tests under tests/transports/
* Protocol tests under tests/protocols/
* IOT tests under iot/
* Plus some minor cleanups, most code changes are related to splitting
up smart & iot tests
This commit is contained in:
Teemu R.
2024-11-28 17:56:20 +01:00
committed by GitHub
parent 6adb2b5c28
commit fcb604e435
18 changed files with 393 additions and 392 deletions

View File

View File

@@ -0,0 +1,541 @@
from __future__ import annotations
import base64
import json
import logging
import random
import string
import time
from contextlib import nullcontext as does_not_raise
from json import dumps as json_dumps
from json import loads as json_loads
from typing import Any
import aiohttp
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from freezegun.api import FrozenDateTimeFactory
from yarl import URL
from kasa.credentials import Credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import (
AuthenticationError,
KasaException,
SmartErrorCode,
_ConnectionError,
)
from kasa.httpclient import HttpClient
from kasa.transports.aestransport import (
AesEncyptionSession,
AesTransport,
TransportState,
)
pytestmark = [pytest.mark.requires_dummy]
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t"
iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa"
KEY_IV = key + iv
def test_encrypt():
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
d = json.dumps({"foo": 1, "bar": 2})
encrypted = encryption_session.encrypt(d.encode())
assert d == encryption_session.decrypt(encrypted)
# test encrypt unicode
d = "{'snowman': '\u2603'}"
encrypted = encryption_session.encrypt(d.encode())
assert d == encryption_session.decrypt(encrypted)
status_parameters = pytest.mark.parametrize(
"status_code, error_code, inner_error_code, expectation",
[
(200, 0, 0, does_not_raise()),
(400, 0, 0, pytest.raises(KasaException)),
(200, -1, 0, pytest.raises(KasaException)),
],
ids=("success", "status_code", "error_code"),
)
@status_parameters
async def test_handshake(
mocker, status_code, error_code, inner_error_code, expectation
):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
assert transport._encryption_session is None
assert transport._state is TransportState.HANDSHAKE_REQUIRED
with expectation:
await transport.perform_handshake()
assert transport._encryption_session is not None
assert transport._state is TransportState.LOGIN_REQUIRED
async def test_handshake_with_keys(mocker):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
test_keys = {
"private": "MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAMo/JQpXIbP2M3bLOKyfEVCURFCxHIXv4HDME8J58AL4BwGDXf0oQycgj9nV+T/MzgEd/4iVysYuYfLuIEKXADP7Lby6AfA/dbcinZZ7bLUNMNa7TaylIvVKtSfR0LV8AmG0jdQYkr4cTzLAEd+AEs/wG3nMQNEcoQRVY+svLPDjAgMBAAECgYBCsDOch0KbvrEVmMklUoY5Fcq4+M249HIDf6d8VwznTbWxsAmL8nzCKCCG6eF4QiYjhCrAdPQaCS1PF2oXywbLhngid/9W9gz4CKKDJChs1X8KvLi+TLg1jgJUXvq9yVNh1CB+lS2ho4gdDDCbVmiVOZR5TDfEf0xeJ+Zz3zlUEQJBAPkhuNdc3yRue8huFZbrWwikURQPYBxLOYfVTDsfV9mZGSkGoWS1FPDsxrqSXugTmcTRuw+lrXKDabJ72kqywA8CQQDP0oaGh5r7F12Xzcwb7X9JkTvyr+rO8YgVtKNBaNVOPabAzysNwOlvH/sNCVQcRj8rn5LNXitgLx6T+Q5uqa3tAkA7J0elUzbkhps7ju/vYri9x448zh3K+g2R9BJio2GPmCuCM0HVEK4FOqNBH4oLXsQPGKFq6LLTUuKg74l4XRL/AkBHBO6r8pNn0yhMxCtIL/UbsuIFoVBgv/F9WWmg5K5gOnlN0n4oCRC8xPUKE3IG54qW4cVNIS05hWCxuJ7R+nJRAkByt/+kX1nQxis2wIXj90fztXG3oSmoVaieYxaXPxlWvX3/Q5kslFF5UsGy9gcK0v2PXhqjTbhud3/X0Er6YP4v",
"public": "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDKPyUKVyGz9jN2yzisnxFQlERQsRyF7+BwzBPCefAC+AcBg139KEMnII/Z1fk/zM4BHf+IlcrGLmHy7iBClwAz+y28ugHwP3W3Ip2We2y1DTDWu02spSL1SrUn0dC1fAJhtI3UGJK+HE8ywBHfgBLP8Bt5zEDRHKEEVWPrLyzw4wIDAQAB",
}
transport = AesTransport(
config=DeviceConfig(
host, credentials=Credentials("foo", "bar"), aes_keys=test_keys
)
)
assert transport._encryption_session is None
assert transport._state is TransportState.HANDSHAKE_REQUIRED
await transport.perform_handshake()
assert transport._key_pair.private_key_der_b64 == test_keys["private"]
assert transport._key_pair.public_key_der_b64 == test_keys["public"]
@status_parameters
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
assert transport._token_url is None
with expectation:
await transport.perform_login()
assert mock_aes_device.token in str(transport._token_url)
assert transport._config.aes_keys == transport._key_pair
@pytest.mark.parametrize(
("inner_error_codes", "expectation", "call_count"),
[
([SmartErrorCode.LOGIN_ERROR, 0, 0, 0], does_not_raise(), 4),
(
[SmartErrorCode.LOGIN_ERROR, SmartErrorCode.LOGIN_ERROR],
pytest.raises(AuthenticationError),
3,
),
(
[SmartErrorCode.LOGIN_FAILED_ERROR],
pytest.raises(AuthenticationError),
1,
),
(
[SmartErrorCode.LOGIN_ERROR, SmartErrorCode.SESSION_TIMEOUT_ERROR],
pytest.raises(KasaException),
3,
),
],
ids=(
"LOGIN_ERROR-success",
"LOGIN_ERROR-LOGIN_ERROR",
"LOGIN_FAILED_ERROR",
"LOGIN_ERROR-SESSION_TIMEOUT_ERROR",
),
)
async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, 0, inner_error_codes)
post_mock = mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_aes_device.post
)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
mocker.patch.object(transport._http_client, "WAIT_BETWEEN_REQUESTS_ON_OSERROR", 0)
assert transport._token_url is None
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with expectation:
await transport.send(json_dumps(request))
assert mock_aes_device.token in str(transport._token_url)
assert post_mock.call_count == call_count # Login, Handshake, Login
await transport.close()
@status_parameters
async def test_send(mocker, status_code, error_code, inner_error_code, expectation):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with expectation:
res = await transport.send(json_dumps(request))
assert "result" in res
@pytest.mark.xdist_group(name="caplog")
async def test_unencrypted_response(mocker, caplog):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, 0, 0, do_not_encrypt_response=True)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._state = TransportState.ESTABLISHED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
caplog.set_level(logging.DEBUG)
res = await transport.send(json_dumps(request))
assert "result" in res
assert (
"Received unencrypted response over secure passthrough from 127.0.0.1"
in caplog.text
)
async def test_unencrypted_response_invalid_json(mocker, caplog):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(
host, 200, 0, 0, do_not_encrypt_response=True, send_response=b"Foobar"
)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._state = TransportState.ESTABLISHED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
caplog.set_level(logging.DEBUG)
msg = f"Unable to decrypt response from {host}, error: Incorrect padding, response: Foobar"
with pytest.raises(KasaException, match=msg):
await transport.send(json_dumps(request))
ERRORS = [e for e in SmartErrorCode if e != 0]
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_passthrough_errors(mocker, error_code):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
transport = AesTransport(config=config)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with pytest.raises(KasaException):
await transport.send(json_dumps(request))
@pytest.mark.parametrize("error_code", [-13333, 13333])
async def test_unknown_errors(mocker, error_code):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
transport = AesTransport(config=config)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with pytest.raises(KasaException): # noqa: PT012
res = await transport.send(json_dumps(request))
assert res is SmartErrorCode.INTERNAL_UNKNOWN_ERROR
async def test_port_override():
"""Test that port override sets the app_url."""
host = "127.0.0.1"
config = DeviceConfig(
host, credentials=Credentials("foo", "bar"), port_override=12345
)
transport = AesTransport(config=config)
assert str(transport._app_url) == "http://127.0.0.1:12345/app"
@pytest.mark.parametrize(
("device_delay_required", "should_error", "should_succeed"),
[
pytest.param(0, False, True, id="No error"),
pytest.param(0.125, True, True, id="Error then succeed"),
pytest.param(0.3, True, True, id="Two errors then succeed"),
pytest.param(0.7, True, False, id="No succeed"),
],
)
async def test_device_closes_connection(
mocker,
freezer: FrozenDateTimeFactory,
device_delay_required,
should_error,
should_succeed,
):
"""Test the delay logic in http client to deal with devices that close connections after each request.
Currently only the P100 on older firmware.
"""
host = "127.0.0.1"
default_delay = HttpClient.WAIT_BETWEEN_REQUESTS_ON_OSERROR
mock_aes_device = MockAesDevice(
host, 200, 0, 0, sequential_request_delay=device_delay_required
)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
async def _asyncio_sleep_mock(delay, result=None):
freezer.tick(delay)
return result
mocker.patch("asyncio.sleep", side_effect=_asyncio_sleep_mock)
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
transport = AesTransport(config=config)
transport._http_client.WAIT_BETWEEN_REQUESTS_ON_OSERROR = default_delay
transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
error_count = 0
success = False
# If the device errors without a delay then it should error immedately ( + 1)
# and then the number of times the default delay passes within the request delay window
expected_error_count = (
0 if not should_error else int(device_delay_required / default_delay) + 1
)
for _ in range(3):
try:
await transport.send(json_dumps(request))
except _ConnectionError:
error_count += 1
else:
success = True
assert bool(transport._http_client._wait_between_requests) == should_error
assert bool(error_count) == should_error
assert error_count == expected_error_count
assert success == should_succeed
class MockAesDevice:
class _mock_response:
def __init__(self, status, json: dict):
self.status = status
self._json = json
async def __aenter__(self):
return self
async def __aexit__(self, exc_t, exc_v, exc_tb):
pass
async def read(self):
if isinstance(self._json, dict):
return json_dumps(self._json).encode()
return self._json
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
def __init__(
self,
host,
status_code=200,
error_code=0,
inner_error_code=0,
*,
do_not_encrypt_response=False,
send_response=None,
sequential_request_delay=0,
):
self.host = host
self.status_code = status_code
self.error_code = error_code
self._inner_error_code = inner_error_code
self.do_not_encrypt_response = do_not_encrypt_response
self.send_response = send_response
self.http_client = HttpClient(DeviceConfig(self.host))
self.inner_call_count = 0
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
self.sequential_request_delay = sequential_request_delay
self.last_request_time = None
self.sequential_error_raised = False
@property
def inner_error_code(self):
if isinstance(self._inner_error_code, list):
return self._inner_error_code[self.inner_call_count]
else:
return self._inner_error_code
async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
if self.sequential_request_delay and self.last_request_time:
now = time.time()
print(now - self.last_request_time)
if (now - self.last_request_time) < self.sequential_request_delay:
self.sequential_error_raised = True
raise aiohttp.ClientOSError("Test connection closed")
if data:
async for item in data:
json = json_loads(item.decode())
res = await self._post(url, json)
if self.sequential_request_delay:
self.last_request_time = time.time()
return res
async def _post(self, url: URL, json: dict[str, Any]):
if json["method"] == "handshake":
return await self._return_handshake_response(url, json)
elif json["method"] == "securePassthrough":
return await self._return_secure_passthrough_response(url, json)
elif json["method"] == "login_device":
return await self._return_login_response(url, json)
else:
assert url == URL(f"http://{self.host}:80/app?token={self.token}")
return await self._return_send_response(url, json)
async def _return_handshake_response(self, url: URL, json: dict[str, Any]):
start = len("-----BEGIN PUBLIC KEY-----\n")
end = len("\n-----END PUBLIC KEY-----\n")
client_pub_key = json["params"]["key"][start:-end]
client_pub_key_data = base64.b64decode(client_pub_key.encode())
client_pub_key = serialization.load_der_public_key(client_pub_key_data, None)
encrypted_key = client_pub_key.encrypt(KEY_IV, asymmetric_padding.PKCS1v15())
key_64 = base64.b64encode(encrypted_key).decode()
return self._mock_response(
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
)
async def _return_secure_passthrough_response(self, url: URL, json: dict[str, Any]):
encrypted_request = json["params"]["request"]
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
decrypted_request_dict = json_loads(decrypted_request)
decrypted_response = await self._post(url, decrypted_request_dict)
async with decrypted_response:
decrypted_response_data = await decrypted_response.read()
encrypted_response = self.encryption_session.encrypt(decrypted_response_data)
response = (
decrypted_response_data
if self.do_not_encrypt_response
else encrypted_response
)
result = {
"result": {"response": response.decode()},
"error_code": self.error_code,
}
return self._mock_response(self.status_code, result)
async def _return_login_response(self, url: URL, json: dict[str, Any]):
if "token=" in str(url):
raise Exception("token should not be in url for a login request")
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
self.inner_call_count += 1
return self._mock_response(self.status_code, result)
async def _return_send_response(self, url: URL, json: dict[str, Any]):
result = {"result": {"method": None}, "error_code": self.inner_error_code}
response = self.send_response if self.send_response else result
self.inner_call_count += 1
return self._mock_response(self.status_code, response)

View File

@@ -0,0 +1,565 @@
import json
import logging
import re
import secrets
import time
from contextlib import nullcontext as does_not_raise
import aiohttp
import pytest
from yarl import URL
from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import (
AuthenticationError,
KasaException,
TimeoutError,
_ConnectionError,
_RetryableError,
)
from kasa.httpclient import HttpClient
from kasa.protocols import IotProtocol, SmartProtocol
from kasa.transports.aestransport import AesTransport
from kasa.transports.klaptransport import (
KlapEncryptionSession,
KlapTransport,
KlapTransportV2,
_sha256,
)
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
# Transport tests are not designed for real devices
pytestmark = [pytest.mark.requires_dummy]
class _mock_response:
def __init__(self, status, content: bytes):
self.status = status
self.content = content
async def __aenter__(self):
return self
async def __aexit__(self, exc_t, exc_v, exc_tb):
pass
async def read(self):
return self.content
@pytest.mark.parametrize(
("error", "retry_expectation"),
[
(Exception("dummy exception"), False),
(aiohttp.ServerTimeoutError("dummy exception"), True),
(aiohttp.ServerDisconnectedError("dummy exception"), True),
(aiohttp.ClientOSError("dummy exception"), True),
],
ids=("Exception", "ServerTimeoutError", "ServerDisconnectedError", "ClientOSError"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries_via_client_session(
mocker, retry_count, protocol_class, transport_class, error, retry_expectation
):
host = "127.0.0.1"
conn = mocker.patch.object(aiohttp.ClientSession, "post", side_effect=error)
config = DeviceConfig(host)
with pytest.raises(KasaException):
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=retry_count
)
expected_count = retry_count + 1 if retry_expectation else 1
assert conn.call_count == expected_count
@pytest.mark.parametrize(
("error", "retry_expectation"),
[
(KasaException("dummy exception"), False),
(_RetryableError("dummy exception"), True),
(TimeoutError("dummy exception"), True),
],
ids=("KasaException", "_RetryableError", "TimeoutError"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries_via_httpclient(
mocker, retry_count, protocol_class, transport_class, error, retry_expectation
):
host = "127.0.0.1"
conn = mocker.patch.object(HttpClient, "post", side_effect=error)
mocker.patch.object(protocol_class, "BACKOFF_SECONDS_AFTER_TIMEOUT", 0)
config = DeviceConfig(host)
with pytest.raises(KasaException):
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=retry_count
)
expected_count = retry_count + 1 if retry_expectation else 1
assert conn.call_count == expected_count
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
async def test_protocol_no_retry_on_connection_error(
mocker, protocol_class, transport_class
):
host = "127.0.0.1"
conn = mocker.patch.object(
aiohttp.ClientSession,
"post",
side_effect=AuthenticationError("foo"),
)
mocker.patch.object(protocol_class, "BACKOFF_SECONDS_AFTER_TIMEOUT", 0)
config = DeviceConfig(host)
with pytest.raises(KasaException):
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=5
)
assert conn.call_count == 1
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
async def test_protocol_retry_recoverable_error(
mocker, protocol_class, transport_class
):
host = "127.0.0.1"
conn = mocker.patch.object(
aiohttp.ClientSession,
"post",
side_effect=aiohttp.ClientOSError("foo"),
)
config = DeviceConfig(host)
with pytest.raises(KasaException):
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=5
)
assert conn.call_count == 6
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
host = "127.0.0.1"
remaining = retry_count
mock_response = {"result": {"great": "success"}, "error_code": 0}
def _fail_one_less_than_retry_count(*_, **__):
nonlocal remaining
remaining -= 1
if remaining:
raise _ConnectionError("Simulated connection failure")
return mock_response
mocker.patch.object(transport_class, "perform_handshake")
if hasattr(transport_class, "perform_login"):
mocker.patch.object(transport_class, "perform_login")
send_mock = mocker.patch.object(
transport_class,
"send",
side_effect=_fail_one_less_than_retry_count,
)
config = DeviceConfig(host)
response = await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=retry_count
)
assert "result" in response or "foobar" in response
assert send_mock.call_count == retry_count
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
@pytest.mark.xdist_group(name="caplog")
async def test_protocol_logging(mocker, caplog, log_level):
caplog.set_level(log_level)
logging.getLogger("kasa").setLevel(log_level)
def _return_encrypted(*_, **__):
nonlocal encryption_session
# Do the encrypt just before returning the value so the incrementing sequence number is correct
encrypted, seq = encryption_session.encrypt('{"great":"success"}')
return 200, encrypted
seed = secrets.token_bytes(16)
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
config = DeviceConfig("127.0.0.1")
protocol = IotProtocol(transport=KlapTransport(config=config))
protocol._transport._handshake_done = True
protocol._transport._session_expire_at = time.time() + 86400
protocol._transport._encryption_session = encryption_session
mocker.patch.object(HttpClient, "post", side_effect=_return_encrypted)
response = await protocol.query({})
assert response == {"great": "success"}
if log_level == logging.DEBUG:
assert "success" in caplog.text
else:
assert "success" not in caplog.text
def test_encrypt():
d = json.dumps({"foo": 1, "bar": 2})
seed = secrets.token_bytes(16)
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
encrypted, seq = encryption_session.encrypt(d)
assert d == encryption_session.decrypt(encrypted)
def test_encrypt_unicode():
d = "{'snowman': '\u2603'}"
seed = secrets.token_bytes(16)
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
encrypted, seq = encryption_session.encrypt(d)
decrypted = encryption_session.decrypt(encrypted)
assert d == decrypted
async def test_transport_decrypt(mocker):
"""Test transport decryption."""
d = {"great": "success"}
seed = secrets.token_bytes(16)
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
transport = KlapTransport(config=DeviceConfig(host="127.0.0.1"))
transport._handshake_done = True
transport._session_expire_at = time.monotonic() + 60
transport._encryption_session = encryption_session
async def _return_response(url: URL, params=None, data=None, *_, **__):
encryption_session = KlapEncryptionSession(
transport._encryption_session.local_seed,
transport._encryption_session.remote_seed,
transport._encryption_session.user_hash,
)
seq = params.get("seq")
encryption_session._seq = seq - 1
encrypted, seq = encryption_session.encrypt(json.dumps(d))
seq = seq
return 200, encrypted
mocker.patch.object(HttpClient, "post", side_effect=_return_response)
resp = await transport.send(json.dumps({}))
assert d == resp
async def test_transport_decrypt_error(mocker, caplog):
"""Test that a decryption error raises a kasa exception."""
d = {"great": "success"}
seed = secrets.token_bytes(16)
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
transport = KlapTransport(config=DeviceConfig(host="127.0.0.1"))
transport._handshake_done = True
transport._session_expire_at = time.monotonic() + 60
transport._encryption_session = encryption_session
async def _return_response(url: URL, params=None, data=None, *_, **__):
encryption_session = KlapEncryptionSession(
secrets.token_bytes(16),
transport._encryption_session.remote_seed,
transport._encryption_session.user_hash,
)
seq = params.get("seq")
encryption_session._seq = seq - 1
encrypted, seq = encryption_session.encrypt(json.dumps(d))
seq = seq
return 200, encrypted
mocker.patch.object(HttpClient, "post", side_effect=_return_response)
with pytest.raises(
KasaException,
match=re.escape("Error trying to decrypt device 127.0.0.1 response:"),
):
await transport.send(json.dumps({}))
@pytest.mark.parametrize(
("device_credentials", "expectation"),
[
(Credentials("foo", "bar"), does_not_raise()),
(Credentials(), does_not_raise()),
(
get_default_credentials(DEFAULT_CREDENTIALS["KASA"]),
does_not_raise(),
),
(
Credentials("shouldfail", "shouldfail"),
pytest.raises(AuthenticationError),
),
],
ids=("client", "blank", "kasa_setup", "shouldfail"),
)
@pytest.mark.parametrize(
("transport_class", "seed_auth_hash_calc"),
[
pytest.param(KlapTransport, lambda c, s, a: c + a, id="KLAP"),
pytest.param(KlapTransportV2, lambda c, s, a: c + s + a, id="KLAPV2"),
],
)
async def test_handshake1(
mocker, device_credentials, expectation, transport_class, seed_auth_hash_calc
):
async def _return_handshake1_response(url, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash
client_seed = data
seed_auth_hash = _sha256(
seed_auth_hash_calc(client_seed, server_seed, device_auth_hash)
)
return _mock_response(200, server_seed + seed_auth_hash)
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = transport_class.generate_auth_hash(device_credentials)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=_return_handshake1_response
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=transport_class(config=config))
with expectation:
(
local_seed,
device_remote_seed,
auth_hash,
) = await protocol._transport.perform_handshake1()
assert local_seed == client_seed
assert device_remote_seed == server_seed
assert device_auth_hash == auth_hash
await protocol.close()
@pytest.mark.parametrize(
("transport_class", "seed_auth_hash_calc1", "seed_auth_hash_calc2"),
[
pytest.param(
KlapTransport, lambda c, s, a: c + a, lambda c, s, a: s + a, id="KLAP"
),
pytest.param(
KlapTransportV2,
lambda c, s, a: c + s + a,
lambda c, s, a: s + c + a,
id="KLAPV2",
),
],
)
async def test_handshake(
mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
):
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
async def _return_handshake_response(url: URL, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash
if url == URL("http://127.0.0.1:80/app/handshake1"):
client_seed = data
seed_auth_hash = _sha256(
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
)
return _mock_response(200, server_seed + seed_auth_hash)
elif url == URL("http://127.0.0.1:80/app/handshake2"):
seed_auth_hash = _sha256(
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
)
assert data == seed_auth_hash
return _mock_response(response_status, b"")
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=_return_handshake_response
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=transport_class(config=config))
protocol._transport.http_client = aiohttp.ClientSession()
response_status = 200
await protocol._transport.perform_handshake()
assert protocol._transport._handshake_done is True
response_status = 403
with pytest.raises(KasaException):
await protocol._transport.perform_handshake()
assert protocol._transport._handshake_done is False
await protocol.close()
async def test_query(mocker):
client_seed = None
last_seq = None
seq = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
async def _return_response(url: URL, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash, seq
if url == URL("http://127.0.0.1:80/app/handshake1"):
client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash)
return _mock_response(200, server_seed + client_seed_auth_hash)
elif url == URL("http://127.0.0.1:80/app/handshake2"):
return _mock_response(200, b"")
elif url == URL("http://127.0.0.1:80/app/request"):
encryption_session = KlapEncryptionSession(
protocol._transport._encryption_session.local_seed,
protocol._transport._encryption_session.remote_seed,
protocol._transport._encryption_session.user_hash,
)
seq = params.get("seq")
encryption_session._seq = seq - 1
encrypted, seq = encryption_session.encrypt('{"great": "success"}')
seq = seq
return _mock_response(200, encrypted)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=KlapTransport(config=config))
for _ in range(10):
resp = await protocol.query({})
assert resp == {"great": "success"}
# Check the protocol is incrementing the sequence number
assert last_seq is None or last_seq + 1 == seq
last_seq = seq
@pytest.mark.parametrize(
("response_status", "credentials_match", "expectation"),
[
pytest.param(
(403, 403, 403),
True,
pytest.raises(KasaException),
id="handshake1-403-status",
),
pytest.param(
(200, 403, 403),
True,
pytest.raises(KasaException),
id="handshake2-403-status",
),
pytest.param(
(200, 200, 403),
True,
pytest.raises(_RetryableError),
id="request-403-status",
),
pytest.param(
(200, 200, 400),
True,
pytest.raises(KasaException),
id="request-400-status",
),
pytest.param(
(200, 200, 200),
False,
pytest.raises(AuthenticationError),
id="handshake1-wrong-auth",
),
pytest.param(
(200, 200, 200),
secrets.token_bytes(16),
pytest.raises(KasaException),
id="handshake1-bad-auth-length",
),
],
)
async def test_authentication_failures(
mocker, response_status, credentials_match, expectation
):
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_credentials = (
client_credentials if credentials_match else Credentials("bar", "foo")
)
device_auth_hash = KlapTransport.generate_auth_hash(device_credentials)
async def _return_response(url: URL, params=None, data=None, *_, **__):
nonlocal \
client_seed, \
server_seed, \
device_auth_hash, \
response_status, \
credentials_match
if url == URL("http://127.0.0.1:80/app/handshake1"):
client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash)
if credentials_match is not False and credentials_match is not True:
client_seed_auth_hash += credentials_match
return _mock_response(
response_status[0], server_seed + client_seed_auth_hash
)
elif url == URL("http://127.0.0.1:80/app/handshake2"):
client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash)
return _mock_response(
response_status[1], server_seed + client_seed_auth_hash
)
elif url == URL("http://127.0.0.1:80/app/request"):
return _mock_response(response_status[2], b"")
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=KlapTransport(config=config))
with expectation:
await protocol.query({})
async def test_port_override():
"""Test that port override sets the app_url."""
host = "127.0.0.1"
config = DeviceConfig(
host, credentials=Credentials("foo", "bar"), port_override=12345
)
transport = KlapTransport(config=config)
assert str(transport._app_url) == "http://127.0.0.1:12345/app"

View File

@@ -0,0 +1,380 @@
from __future__ import annotations
import logging
import secrets
from contextlib import nullcontext as does_not_raise
from json import dumps as json_dumps
from json import loads as json_loads
from typing import Any
import aiohttp
import pytest
from yarl import URL
from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import (
AuthenticationError,
KasaException,
SmartErrorCode,
)
from kasa.httpclient import HttpClient
from kasa.transports.aestransport import AesEncyptionSession
from kasa.transports.sslaestransport import (
SslAesTransport,
TransportState,
_sha256_hash,
)
# Transport tests are not designed for real devices
pytestmark = [pytest.mark.requires_dummy]
MOCK_ADMIN_USER = get_default_credentials(DEFAULT_CREDENTIALS["TAPOCAMERA"]).username
MOCK_PWD = "correct_pwd" # noqa: S105
MOCK_USER = "mock@example.com"
MOCK_STOCK = "abcdefghijklmnopqrstuvwxyz1234)("
@pytest.mark.parametrize(
(
"status_code",
"username",
"password",
"wants_default_user",
"digest_password_fail",
"expectation",
),
[
pytest.param(
200, MOCK_USER, MOCK_PWD, False, False, does_not_raise(), id="success"
),
pytest.param(
200,
MOCK_USER,
MOCK_PWD,
True,
False,
does_not_raise(),
id="success-default",
),
pytest.param(
400,
MOCK_USER,
MOCK_PWD,
False,
False,
pytest.raises(KasaException),
id="400 error",
),
pytest.param(
200,
"foobar",
MOCK_PWD,
False,
False,
pytest.raises(AuthenticationError),
id="bad-username",
),
pytest.param(
200,
MOCK_USER,
"barfoo",
False,
False,
pytest.raises(AuthenticationError),
id="bad-password",
),
pytest.param(
200,
MOCK_USER,
MOCK_PWD,
False,
True,
pytest.raises(AuthenticationError),
id="bad-password-digest",
),
],
)
async def test_handshake(
mocker,
status_code,
username,
password,
wants_default_user,
digest_password_fail,
expectation,
):
host = "127.0.0.1"
mock_ssl_aes_device = MockSslAesDevice(
host,
status_code=status_code,
want_default_username=wants_default_user,
digest_password_fail=digest_password_fail,
)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)
transport = SslAesTransport(
config=DeviceConfig(host, credentials=Credentials(username, password))
)
assert transport._encryption_session is None
assert transport._state is TransportState.HANDSHAKE_REQUIRED
with expectation:
await transport.perform_handshake()
assert transport._encryption_session is not None
assert transport._state is TransportState.ESTABLISHED
@pytest.mark.parametrize(
("wants_default_user"),
[pytest.param(False, id="username"), pytest.param(True, id="default")],
)
async def test_credentials_hash(mocker, wants_default_user):
host = "127.0.0.1"
mock_ssl_aes_device = MockSslAesDevice(
host, want_default_username=wants_default_user
)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)
creds = Credentials(MOCK_USER, MOCK_PWD)
creds_hash = SslAesTransport._create_b64_credentials(creds)
# Test with credentials input
transport = SslAesTransport(config=DeviceConfig(host, credentials=creds))
assert transport.credentials_hash == creds_hash
await transport.perform_handshake()
assert transport.credentials_hash == creds_hash
# Test with credentials_hash input
transport = SslAesTransport(config=DeviceConfig(host, credentials_hash=creds_hash))
mock_ssl_aes_device.handshake1_complete = False
assert transport.credentials_hash == creds_hash
await transport.perform_handshake()
assert transport.credentials_hash == creds_hash
async def test_send(mocker):
host = "127.0.0.1"
mock_ssl_aes_device = MockSslAesDevice(host, want_default_username=False)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)
transport = SslAesTransport(
config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
)
request = {
"method": "getDeviceInfo",
"params": None,
}
res = await transport.send(json_dumps(request))
assert "result" in res
@pytest.mark.xdist_group(name="caplog")
async def test_unencrypted_response(mocker, caplog):
host = "127.0.0.1"
mock_ssl_aes_device = MockSslAesDevice(host, do_not_encrypt_response=True)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)
transport = SslAesTransport(
config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
)
request = {
"method": "getDeviceInfo",
"params": None,
}
caplog.set_level(logging.DEBUG)
res = await transport.send(json_dumps(request))
assert "result" in res
assert (
"Received unencrypted response over secure passthrough from 127.0.0.1"
in caplog.text
)
async def test_port_override():
"""Test that port override sets the app_url."""
host = "127.0.0.1"
port_override = 12345
config = DeviceConfig(
host, credentials=Credentials("foo", "bar"), port_override=port_override
)
transport = SslAesTransport(config=config)
assert str(transport._app_url) == f"https://127.0.0.1:{port_override}"
class MockSslAesDevice:
BAD_USER_RESP = {
"error_code": SmartErrorCode.SESSION_EXPIRED.value,
"result": {
"data": {
"code": -60502,
}
},
}
BAD_PWD_RESP = {
"error_code": SmartErrorCode.INVALID_NONCE.value,
"result": {
"data": {
"code": SmartErrorCode.SESSION_EXPIRED.value,
"encrypt_type": ["3"],
"key": "Someb64keyWithUnknownPurpose",
"nonce": "1234567890ABCDEF", # Whatever the original nonce was
"device_confirm": "",
}
},
}
class _mock_response:
def __init__(self, status, request: dict):
self.status = status
self._json = request
async def __aenter__(self):
return self
async def __aexit__(self, exc_t, exc_v, exc_tb):
pass
async def read(self):
if isinstance(self._json, dict):
return json_dumps(self._json).encode()
return self._json
def __init__(
self,
host,
*,
status_code=200,
want_default_username: bool = False,
do_not_encrypt_response=False,
send_response=None,
sequential_request_delay=0,
send_error_code=0,
secure_passthrough_error_code=0,
digest_password_fail=False,
):
self.host = host
self.http_client = HttpClient(DeviceConfig(self.host))
self.encryption_session: AesEncyptionSession | None = None
self.server_nonce = secrets.token_bytes(8).hex().upper()
self.handshake1_complete = False
# test behaviour attributes
self.status_code = status_code
self.send_error_code = send_error_code
self.secure_passthrough_error_code = secure_passthrough_error_code
self.do_not_encrypt_response = do_not_encrypt_response
self.want_default_username = want_default_username
self.digest_password_fail = digest_password_fail
async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
if data:
json = json_loads(data)
res = await self._post(url, json)
return res
async def _post(self, url: URL, json: dict[str, Any]):
method = json["method"]
if method == "login" and not self.handshake1_complete:
return await self._return_handshake1_response(url, json)
if method == "login" and self.handshake1_complete:
return await self._return_handshake2_response(url, json)
elif method == "securePassthrough":
assert url == URL(f"https://{self.host}/stok={MOCK_STOCK}/ds")
return await self._return_secure_passthrough_response(url, json)
else:
assert url == URL(f"https://{self.host}/stok={MOCK_STOCK}/ds")
return await self._return_send_response(url, json)
async def _return_handshake1_response(self, url: URL, request: dict[str, Any]):
request_nonce = request["params"].get("cnonce")
request_username = request["params"].get("username")
if (self.want_default_username and request_username != MOCK_ADMIN_USER) or (
not self.want_default_username and request_username != MOCK_USER
):
return self._mock_response(self.status_code, self.BAD_USER_RESP)
device_confirm = SslAesTransport.generate_confirm_hash(
request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode())
)
self.handshake1_complete = True
resp = {
"error_code": SmartErrorCode.INVALID_NONCE.value,
"result": {
"data": {
"code": SmartErrorCode.INVALID_NONCE.value,
"encrypt_type": ["3"],
"key": "Someb64keyWithUnknownPurpose",
"nonce": self.server_nonce,
"device_confirm": device_confirm,
}
},
}
return self._mock_response(self.status_code, resp)
async def _return_handshake2_response(self, url: URL, request: dict[str, Any]):
request_nonce = request["params"].get("cnonce")
request_username = request["params"].get("username")
if (self.want_default_username and request_username != MOCK_ADMIN_USER) or (
not self.want_default_username and request_username != MOCK_USER
):
return self._mock_response(self.status_code, self.BAD_USER_RESP)
request_password = request["params"].get("digest_passwd")
expected_pwd = SslAesTransport.generate_digest_password(
request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode())
)
if request_password != expected_pwd or self.digest_password_fail:
return self._mock_response(self.status_code, self.BAD_PWD_RESP)
lsk = SslAesTransport.generate_encryption_token(
"lsk", request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode())
)
ivb = SslAesTransport.generate_encryption_token(
"ivb", request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode())
)
self.encryption_session = AesEncyptionSession(lsk, ivb)
resp = {
"error_code": 0,
"result": {"stok": MOCK_STOCK, "user_group": "root", "start_seq": 100},
}
return self._mock_response(self.status_code, resp)
async def _return_secure_passthrough_response(self, url: URL, json: dict[str, Any]):
encrypted_request = json["params"]["request"]
assert self.encryption_session
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
decrypted_request_dict = json_loads(decrypted_request)
decrypted_response = await self._post(url, decrypted_request_dict)
async with decrypted_response:
decrypted_response_data = await decrypted_response.read()
encrypted_response = self.encryption_session.encrypt(decrypted_response_data)
response = (
decrypted_response_data
if self.do_not_encrypt_response
else encrypted_response
)
result = {
"result": {"response": response.decode()},
"error_code": self.secure_passthrough_error_code,
}
return self._mock_response(self.status_code, result)
async def _return_send_response(self, url: URL, json: dict[str, Any]):
result = {"result": {"method": None}, "error_code": self.send_error_code}
return self._mock_response(self.status_code, result)