from __future__ import annotations import base64 import json import logging import random import string import time from contextlib import nullcontext as does_not_raise from json import dumps as json_dumps from json import loads as json_loads from typing import Any import aiohttp import pytest from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from freezegun.api import FrozenDateTimeFactory from yarl import URL from kasa.credentials import Credentials from kasa.deviceconfig import DeviceConfig from kasa.exceptions import ( AuthenticationError, KasaException, SmartErrorCode, _ConnectionError, ) from kasa.httpclient import HttpClient from kasa.transports.aestransport import ( AesEncyptionSession, AesTransport, TransportState, ) pytestmark = [pytest.mark.requires_dummy] DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t" iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa" KEY_IV = key + iv def test_encrypt(): encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:]) d = json.dumps({"foo": 1, "bar": 2}) encrypted = encryption_session.encrypt(d.encode()) assert d == encryption_session.decrypt(encrypted) # test encrypt unicode d = "{'snowman': '\u2603'}" encrypted = encryption_session.encrypt(d.encode()) assert d == encryption_session.decrypt(encrypted) status_parameters = pytest.mark.parametrize( "status_code, error_code, inner_error_code, expectation", [ (200, 0, 0, does_not_raise()), (400, 0, 0, pytest.raises(KasaException)), (200, -1, 0, pytest.raises(KasaException)), ], ids=("success", "status_code", "error_code"), ) @status_parameters async def test_handshake( mocker, status_code, error_code, inner_error_code, expectation ): host = "127.0.0.1" mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) assert transport._encryption_session is None assert transport._state is TransportState.HANDSHAKE_REQUIRED with expectation: await transport.perform_handshake() assert transport._encryption_session is not None assert transport._state is TransportState.LOGIN_REQUIRED async def test_handshake_with_keys(mocker): host = "127.0.0.1" mock_aes_device = MockAesDevice(host) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) test_keys = { "private": "MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAMo/JQpXIbP2M3bLOKyfEVCURFCxHIXv4HDME8J58AL4BwGDXf0oQycgj9nV+T/MzgEd/4iVysYuYfLuIEKXADP7Lby6AfA/dbcinZZ7bLUNMNa7TaylIvVKtSfR0LV8AmG0jdQYkr4cTzLAEd+AEs/wG3nMQNEcoQRVY+svLPDjAgMBAAECgYBCsDOch0KbvrEVmMklUoY5Fcq4+M249HIDf6d8VwznTbWxsAmL8nzCKCCG6eF4QiYjhCrAdPQaCS1PF2oXywbLhngid/9W9gz4CKKDJChs1X8KvLi+TLg1jgJUXvq9yVNh1CB+lS2ho4gdDDCbVmiVOZR5TDfEf0xeJ+Zz3zlUEQJBAPkhuNdc3yRue8huFZbrWwikURQPYBxLOYfVTDsfV9mZGSkGoWS1FPDsxrqSXugTmcTRuw+lrXKDabJ72kqywA8CQQDP0oaGh5r7F12Xzcwb7X9JkTvyr+rO8YgVtKNBaNVOPabAzysNwOlvH/sNCVQcRj8rn5LNXitgLx6T+Q5uqa3tAkA7J0elUzbkhps7ju/vYri9x448zh3K+g2R9BJio2GPmCuCM0HVEK4FOqNBH4oLXsQPGKFq6LLTUuKg74l4XRL/AkBHBO6r8pNn0yhMxCtIL/UbsuIFoVBgv/F9WWmg5K5gOnlN0n4oCRC8xPUKE3IG54qW4cVNIS05hWCxuJ7R+nJRAkByt/+kX1nQxis2wIXj90fztXG3oSmoVaieYxaXPxlWvX3/Q5kslFF5UsGy9gcK0v2PXhqjTbhud3/X0Er6YP4v", "public": "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDKPyUKVyGz9jN2yzisnxFQlERQsRyF7+BwzBPCefAC+AcBg139KEMnII/Z1fk/zM4BHf+IlcrGLmHy7iBClwAz+y28ugHwP3W3Ip2We2y1DTDWu02spSL1SrUn0dC1fAJhtI3UGJK+HE8ywBHfgBLP8Bt5zEDRHKEEVWPrLyzw4wIDAQAB", } transport = AesTransport( config=DeviceConfig( host, credentials=Credentials("foo", "bar"), aes_keys=test_keys ) ) assert transport._encryption_session is None assert transport._state is TransportState.HANDSHAKE_REQUIRED await transport.perform_handshake() assert transport._key_pair.private_key_der_b64 == test_keys["private"] assert transport._key_pair.public_key_der_b64 == test_keys["public"] @status_parameters async def test_login(mocker, status_code, error_code, inner_error_code, expectation): host = "127.0.0.1" mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) transport._state = TransportState.LOGIN_REQUIRED transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session assert transport._token_url is None with expectation: await transport.perform_login() assert mock_aes_device.token in str(transport._token_url) assert transport._config.aes_keys == transport._key_pair @pytest.mark.parametrize( ("inner_error_codes", "expectation", "call_count"), [ ([SmartErrorCode.LOGIN_ERROR, 0, 0, 0], does_not_raise(), 4), ( [SmartErrorCode.LOGIN_ERROR, SmartErrorCode.LOGIN_ERROR], pytest.raises(AuthenticationError), 3, ), ( [SmartErrorCode.LOGIN_FAILED_ERROR], pytest.raises(AuthenticationError), 1, ), ( [SmartErrorCode.LOGIN_ERROR, SmartErrorCode.SESSION_TIMEOUT_ERROR], pytest.raises(KasaException), 3, ), ], ids=( "LOGIN_ERROR-success", "LOGIN_ERROR-LOGIN_ERROR", "LOGIN_FAILED_ERROR", "LOGIN_ERROR-SESSION_TIMEOUT_ERROR", ), ) async def test_login_errors(mocker, inner_error_codes, expectation, call_count): host = "127.0.0.1" mock_aes_device = MockAesDevice(host, 200, 0, inner_error_codes) post_mock = mocker.patch.object( aiohttp.ClientSession, "post", side_effect=mock_aes_device.post ) transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) transport._state = TransportState.LOGIN_REQUIRED transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session mocker.patch.object(transport._http_client, "WAIT_BETWEEN_REQUESTS_ON_OSERROR", 0) assert transport._token_url is None request = { "method": "get_device_info", "params": None, "request_time_milis": round(time.time() * 1000), "requestID": 1, "terminal_uuid": "foobar", } with expectation: await transport.send(json_dumps(request)) assert mock_aes_device.token in str(transport._token_url) assert post_mock.call_count == call_count # Login, Handshake, Login await transport.close() @status_parameters async def test_send(mocker, status_code, error_code, inner_error_code, expectation): host = "127.0.0.1" mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session transport._token_url = transport._app_url.with_query( f"token={mock_aes_device.token}" ) request = { "method": "get_device_info", "params": None, "request_time_milis": round(time.time() * 1000), "requestID": 1, "terminal_uuid": "foobar", } with expectation: res = await transport.send(json_dumps(request)) assert "result" in res @pytest.mark.xdist_group(name="caplog") async def test_unencrypted_response(mocker, caplog): host = "127.0.0.1" mock_aes_device = MockAesDevice(host, 200, 0, 0, do_not_encrypt_response=True) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) transport._state = TransportState.ESTABLISHED transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session transport._token_url = transport._app_url.with_query( f"token={mock_aes_device.token}" ) request = { "method": "get_device_info", "params": None, "request_time_milis": round(time.time() * 1000), "requestID": 1, "terminal_uuid": "foobar", } caplog.set_level(logging.DEBUG) res = await transport.send(json_dumps(request)) assert "result" in res assert ( "Received unencrypted response over secure passthrough from 127.0.0.1" in caplog.text ) async def test_unencrypted_response_invalid_json(mocker, caplog): host = "127.0.0.1" mock_aes_device = MockAesDevice( host, 200, 0, 0, do_not_encrypt_response=True, send_response=b"Foobar" ) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) transport._state = TransportState.ESTABLISHED transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session transport._token_url = transport._app_url.with_query( f"token={mock_aes_device.token}" ) request = { "method": "get_device_info", "params": None, "request_time_milis": round(time.time() * 1000), "requestID": 1, "terminal_uuid": "foobar", } caplog.set_level(logging.DEBUG) msg = f"Unable to decrypt response from {host}, error: Incorrect padding, response: Foobar" with pytest.raises(KasaException, match=msg): await transport.send(json_dumps(request)) ERRORS = [e for e in SmartErrorCode if e != 0] @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) async def test_passthrough_errors(mocker, error_code): host = "127.0.0.1" mock_aes_device = MockAesDevice(host, 200, error_code, 0) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) config = DeviceConfig(host, credentials=Credentials("foo", "bar")) transport = AesTransport(config=config) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session transport._token_url = transport._app_url.with_query( f"token={mock_aes_device.token}" ) request = { "method": "get_device_info", "params": None, "request_time_milis": round(time.time() * 1000), "requestID": 1, "terminal_uuid": "foobar", } with pytest.raises(KasaException): await transport.send(json_dumps(request)) @pytest.mark.parametrize("error_code", [-13333, 13333]) async def test_unknown_errors(mocker, error_code): host = "127.0.0.1" mock_aes_device = MockAesDevice(host, 200, error_code, 0) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) config = DeviceConfig(host, credentials=Credentials("foo", "bar")) transport = AesTransport(config=config) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session transport._token_url = transport._app_url.with_query( f"token={mock_aes_device.token}" ) request = { "method": "get_device_info", "params": None, "request_time_milis": round(time.time() * 1000), "requestID": 1, "terminal_uuid": "foobar", } with pytest.raises(KasaException): # noqa: PT012 res = await transport.send(json_dumps(request)) assert res is SmartErrorCode.INTERNAL_UNKNOWN_ERROR async def test_port_override(): """Test that port override sets the app_url.""" host = "127.0.0.1" config = DeviceConfig( host, credentials=Credentials("foo", "bar"), port_override=12345 ) transport = AesTransport(config=config) assert str(transport._app_url) == "http://127.0.0.1:12345/app" @pytest.mark.parametrize( ("device_delay_required", "should_error", "should_succeed"), [ pytest.param(0, False, True, id="No error"), pytest.param(0.125, True, True, id="Error then succeed"), pytest.param(0.3, True, True, id="Two errors then succeed"), pytest.param(0.7, True, False, id="No succeed"), ], ) async def test_device_closes_connection( mocker, freezer: FrozenDateTimeFactory, device_delay_required, should_error, should_succeed, ): """Test the delay logic in http client to deal with devices that close connections after each request. Currently only the P100 on older firmware. """ host = "127.0.0.1" default_delay = HttpClient.WAIT_BETWEEN_REQUESTS_ON_OSERROR mock_aes_device = MockAesDevice( host, 200, 0, 0, sequential_request_delay=device_delay_required ) mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) async def _asyncio_sleep_mock(delay, result=None): freezer.tick(delay) return result mocker.patch("asyncio.sleep", side_effect=_asyncio_sleep_mock) config = DeviceConfig(host, credentials=Credentials("foo", "bar")) transport = AesTransport(config=config) transport._http_client.WAIT_BETWEEN_REQUESTS_ON_OSERROR = default_delay transport._state = TransportState.LOGIN_REQUIRED transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session transport._token_url = transport._app_url.with_query( f"token={mock_aes_device.token}" ) request = { "method": "get_device_info", "params": None, "request_time_milis": round(time.time() * 1000), "requestID": 1, "terminal_uuid": "foobar", } error_count = 0 success = False # If the device errors without a delay then it should error immedately ( + 1) # and then the number of times the default delay passes within the request delay window expected_error_count = ( 0 if not should_error else int(device_delay_required / default_delay) + 1 ) for _ in range(3): try: await transport.send(json_dumps(request)) except _ConnectionError: error_count += 1 else: success = True assert bool(transport._http_client._wait_between_requests) == should_error assert bool(error_count) == should_error assert error_count == expected_error_count assert success == should_succeed class MockAesDevice: class _mock_response: def __init__(self, status, json: dict): self.status = status self._json = json async def __aenter__(self): return self async def __aexit__(self, exc_t, exc_v, exc_tb): pass async def read(self): if isinstance(self._json, dict): return json_dumps(self._json).encode() return self._json encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:]) def __init__( self, host, status_code=200, error_code=0, inner_error_code=0, *, do_not_encrypt_response=False, send_response=None, sequential_request_delay=0, ): self.host = host self.status_code = status_code self.error_code = error_code self._inner_error_code = inner_error_code self.do_not_encrypt_response = do_not_encrypt_response self.send_response = send_response self.http_client = HttpClient(DeviceConfig(self.host)) self.inner_call_count = 0 self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311 self.sequential_request_delay = sequential_request_delay self.last_request_time = None self.sequential_error_raised = False @property def inner_error_code(self): if isinstance(self._inner_error_code, list): return self._inner_error_code[self.inner_call_count] else: return self._inner_error_code async def post(self, url: URL, params=None, json=None, data=None, *_, **__): if self.sequential_request_delay and self.last_request_time: now = time.time() print(now - self.last_request_time) if (now - self.last_request_time) < self.sequential_request_delay: self.sequential_error_raised = True raise aiohttp.ClientOSError("Test connection closed") if data: async for item in data: json = json_loads(item.decode()) res = await self._post(url, json) if self.sequential_request_delay: self.last_request_time = time.time() return res async def _post(self, url: URL, json: dict[str, Any]): if json["method"] == "handshake": return await self._return_handshake_response(url, json) elif json["method"] == "securePassthrough": return await self._return_secure_passthrough_response(url, json) elif json["method"] == "login_device": return await self._return_login_response(url, json) else: assert url == URL(f"http://{self.host}:80/app?token={self.token}") return await self._return_send_response(url, json) async def _return_handshake_response(self, url: URL, json: dict[str, Any]): start = len("-----BEGIN PUBLIC KEY-----\n") end = len("\n-----END PUBLIC KEY-----\n") client_pub_key = json["params"]["key"][start:-end] client_pub_key_data = base64.b64decode(client_pub_key.encode()) client_pub_key = serialization.load_der_public_key(client_pub_key_data, None) encrypted_key = client_pub_key.encrypt(KEY_IV, asymmetric_padding.PKCS1v15()) key_64 = base64.b64encode(encrypted_key).decode() return self._mock_response( self.status_code, {"result": {"key": key_64}, "error_code": self.error_code} ) async def _return_secure_passthrough_response(self, url: URL, json: dict[str, Any]): encrypted_request = json["params"]["request"] decrypted_request = self.encryption_session.decrypt(encrypted_request.encode()) decrypted_request_dict = json_loads(decrypted_request) decrypted_response = await self._post(url, decrypted_request_dict) async with decrypted_response: decrypted_response_data = await decrypted_response.read() encrypted_response = self.encryption_session.encrypt(decrypted_response_data) response = ( decrypted_response_data if self.do_not_encrypt_response else encrypted_response ) result = { "result": {"response": response.decode()}, "error_code": self.error_code, } return self._mock_response(self.status_code, result) async def _return_login_response(self, url: URL, json: dict[str, Any]): if "token=" in str(url): raise Exception("token should not be in url for a login request") self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311 result = {"result": {"token": self.token}, "error_code": self.inner_error_code} self.inner_call_count += 1 return self._mock_response(self.status_code, result) async def _return_send_response(self, url: URL, json: dict[str, Any]): result = {"result": {"method": None}, "error_code": self.inner_error_code} response = self.send_response if self.send_response else result self.inner_call_count += 1 return self._mock_response(self.status_code, response)