import json
import logging
import re
import secrets
import time
from contextlib import nullcontext as does_not_raise

import aiohttp
import pytest
from yarl import URL

from ..aestransport import AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import (
    AuthenticationError,
    KasaException,
    TimeoutError,
    _ConnectionError,
    _RetryableError,
)
from ..httpclient import HttpClient
from ..iotprotocol import IotProtocol
from ..klaptransport import (
    KlapEncryptionSession,
    KlapTransport,
    KlapTransportV2,
    _sha256,
)
from ..protocol import DEFAULT_CREDENTIALS, get_default_credentials
from ..smartprotocol import SmartProtocol

# Arbitrary payload used by the retry tests; its content is never inspected.
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}


class _mock_response:
    """Minimal stand-in for the object returned by ``ClientSession.post``.

    Supports ``async with`` and exposes ``status`` plus an awaitable
    ``read()`` returning the canned ``content`` bytes.
    """

    def __init__(self, status, content: bytes):
        self.status = status
        self.content = content

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_t, exc_v, exc_tb):
        pass

    async def read(self):
        return self.content


@pytest.mark.parametrize(
    ("error", "retry_expectation"),
    [
        (Exception("dummy exception"), False),
        (aiohttp.ServerTimeoutError("dummy exception"), True),
        (aiohttp.ServerDisconnectedError("dummy exception"), True),
        (aiohttp.ClientOSError("dummy exception"), True),
    ],
    ids=("Exception", "ServerTimeoutError", "ServerDisconnectedError", "ClientOSError"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries_via_client_session(
    mocker, retry_count, protocol_class, transport_class, error, retry_expectation
):
    """Check retry counts when ``aiohttp.ClientSession.post`` raises ``error``.

    The query must always end in a ``KasaException``; ``post`` is expected to
    be called ``retry_count + 1`` times for retryable errors and once otherwise.
    """
    host = "127.0.0.1"
    conn = mocker.patch.object(aiohttp.ClientSession, "post", side_effect=error)
    mocker.patch.object(protocol_class, "BACKOFF_SECONDS_AFTER_TIMEOUT", 0)

    config = DeviceConfig(host)
    with pytest.raises(KasaException):
        await protocol_class(transport=transport_class(config=config)).query(
            DUMMY_QUERY, retry_count=retry_count
        )

    expected_count = retry_count + 1 if retry_expectation else 1
    assert conn.call_count == expected_count


@pytest.mark.parametrize(
    ("error", "retry_expectation"),
    [
        (KasaException("dummy exception"), False),
        (_RetryableError("dummy exception"), True),
        (TimeoutError("dummy exception"), True),
    ],
    ids=("KasaException", "_RetryableError", "TimeoutError"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries_via_httpclient(
    mocker, retry_count, protocol_class, transport_class, error, retry_expectation
):
    """Check retry counts when ``HttpClient.post`` raises ``error``.

    Same expectations as the client-session variant, but the failure is
    injected one layer higher, at the ``HttpClient`` wrapper.
    """
    host = "127.0.0.1"
    conn = mocker.patch.object(HttpClient, "post", side_effect=error)
    mocker.patch.object(protocol_class, "BACKOFF_SECONDS_AFTER_TIMEOUT", 0)

    config = DeviceConfig(host)
    with pytest.raises(KasaException):
        await protocol_class(transport=transport_class(config=config)).query(
            DUMMY_QUERY, retry_count=retry_count
        )

    expected_count = retry_count + 1 if retry_expectation else 1
    assert conn.call_count == expected_count


@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
async def test_protocol_no_retry_on_connection_error(
    mocker, protocol_class, transport_class
):
    """Check that a non-retryable error results in exactly one ``post`` call.

    Despite the test name, the simulated failure is an ``AuthenticationError``.
    """
    host = "127.0.0.1"
    conn = mocker.patch.object(
        aiohttp.ClientSession,
        "post",
        side_effect=AuthenticationError("foo"),
    )
    mocker.patch.object(protocol_class, "BACKOFF_SECONDS_AFTER_TIMEOUT", 0)

    config = DeviceConfig(host)
    with pytest.raises(KasaException):
        await protocol_class(transport=transport_class(config=config)).query(
            DUMMY_QUERY, retry_count=5
        )

    assert conn.call_count == 1


@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
async def test_protocol_retry_recoverable_error(
    mocker, protocol_class, transport_class
):
    """Check that ``aiohttp.ClientOSError`` is retried: 5 retries -> 6 calls."""
    host = "127.0.0.1"
    conn = mocker.patch.object(
        aiohttp.ClientSession,
        "post",
        side_effect=aiohttp.ClientOSError("foo"),
    )

    config = DeviceConfig(host)
    with pytest.raises(KasaException):
        await protocol_class(transport=transport_class(config=config)).query(
            DUMMY_QUERY, retry_count=5
        )

    assert conn.call_count == 6


@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
    """Check that a query succeeds once ``send`` stops raising connection errors.

    ``send`` fails ``retry_count - 1`` times and then returns a canned
    response, so it must be called exactly ``retry_count`` times.
    """
    host = "127.0.0.1"
    remaining = retry_count
    mock_response = {"result": {"great": "success"}, "error_code": 0}

    def _fail_one_less_than_retry_count(*_, **__):
        nonlocal remaining
        remaining -= 1
        if remaining:
            raise _ConnectionError("Simulated connection failure")

        return mock_response

    # Bypass the real handshake/login so only ``send`` behaviour is exercised.
    mocker.patch.object(transport_class, "perform_handshake")
    if hasattr(transport_class, "perform_login"):
        mocker.patch.object(transport_class, "perform_login")

    send_mock = mocker.patch.object(
        transport_class,
        "send",
        side_effect=_fail_one_less_than_retry_count,
    )

    config = DeviceConfig(host)
    response = await protocol_class(transport=transport_class(config=config)).query(
        DUMMY_QUERY, retry_count=retry_count
    )
    assert "result" in response or "foobar" in response
    assert send_mock.call_count == retry_count


@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
async def test_protocol_logging(mocker, caplog, log_level):
    """Check that the response payload is only logged at DEBUG level."""
    caplog.set_level(log_level)
    logging.getLogger("kasa").setLevel(log_level)

    def _return_encrypted(*_, **__):
        # Do the encrypt just before returning the value so the incrementing
        # sequence number is correct
        encrypted, _ = encryption_session.encrypt('{"great":"success"}')
        return 200, encrypted

    seed = secrets.token_bytes(16)
    auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
    encryption_session = KlapEncryptionSession(seed, seed, auth_hash)

    config = DeviceConfig("127.0.0.1")
    protocol = IotProtocol(transport=KlapTransport(config=config))

    # Pre-seed the transport state so no handshake is attempted.
    protocol._transport._handshake_done = True
    # NOTE(review): the transport-level tests below seed this field with
    # time.monotonic(); confirm which clock the transport compares against.
    protocol._transport._session_expire_at = time.time() + 86400
    protocol._transport._encryption_session = encryption_session
    mocker.patch.object(HttpClient, "post", side_effect=_return_encrypted)

    response = await protocol.query({})
    assert response == {"great": "success"}
    if log_level == logging.DEBUG:
        assert "success" in caplog.text
    else:
        assert "success" not in caplog.text


def test_encrypt():
    """Check that a JSON string survives an encrypt/decrypt round trip."""
    d = json.dumps({"foo": 1, "bar": 2})

    seed = secrets.token_bytes(16)
    auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
    encryption_session = KlapEncryptionSession(seed, seed, auth_hash)

    encrypted, _ = encryption_session.encrypt(d)

    assert d == encryption_session.decrypt(encrypted)


def test_encrypt_unicode():
    """Check that non-ASCII text survives an encrypt/decrypt round trip."""
    d = "{'snowman': '\u2603'}"

    seed = secrets.token_bytes(16)
    auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
    encryption_session = KlapEncryptionSession(seed, seed, auth_hash)

    encrypted, _ = encryption_session.encrypt(d)

    decrypted = encryption_session.decrypt(encrypted)

    assert d == decrypted


async def test_transport_decrypt(mocker):
    """Test transport decryption."""
    d = {"great": "success"}

    seed = secrets.token_bytes(16)
    auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
    encryption_session = KlapEncryptionSession(seed, seed, auth_hash)

    transport = KlapTransport(config=DeviceConfig(host="127.0.0.1"))
    transport._handshake_done = True
    transport._session_expire_at = time.monotonic() + 60
    transport._encryption_session = encryption_session

    async def _return_response(url: URL, params=None, data=None, *_, **__):
        # Build a device-side session from the transport's own seeds/hash and
        # rewind its counter so encrypt() lands on the requested sequence.
        device_session = KlapEncryptionSession(
            transport._encryption_session.local_seed,
            transport._encryption_session.remote_seed,
            transport._encryption_session.user_hash,
        )
        seq = params.get("seq")
        device_session._seq = seq - 1
        encrypted, _ = device_session.encrypt(json.dumps(d))
        return 200, encrypted

    mocker.patch.object(HttpClient, "post", side_effect=_return_response)

    resp = await transport.send(json.dumps({}))
    assert d == resp


async def test_transport_decrypt_error(mocker, caplog):
    """Test that a decryption error raises a kasa exception."""
    d = {"great": "success"}

    seed = secrets.token_bytes(16)
    auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
    encryption_session = KlapEncryptionSession(seed, seed, auth_hash)

    transport = KlapTransport(config=DeviceConfig(host="127.0.0.1"))
    transport._handshake_done = True
    transport._session_expire_at = time.monotonic() + 60
    transport._encryption_session = encryption_session

    async def _return_response(url: URL, params=None, data=None, *_, **__):
        # A fresh random local seed makes the device-side session differ from
        # the transport's, so the transport cannot decrypt the payload.
        device_session = KlapEncryptionSession(
            secrets.token_bytes(16),
            transport._encryption_session.remote_seed,
            transport._encryption_session.user_hash,
        )
        seq = params.get("seq")
        device_session._seq = seq - 1
        encrypted, _ = device_session.encrypt(json.dumps(d))
        return 200, encrypted

    mocker.patch.object(HttpClient, "post", side_effect=_return_response)

    with pytest.raises(
        KasaException,
        match=re.escape("Error trying to decrypt device 127.0.0.1 response:"),
    ):
        await transport.send(json.dumps({}))


@pytest.mark.parametrize(
    ("device_credentials", "expectation"),
    [
        (Credentials("foo", "bar"), does_not_raise()),
        (Credentials(), does_not_raise()),
        (
            get_default_credentials(DEFAULT_CREDENTIALS["KASA"]),
            does_not_raise(),
        ),
        (
            Credentials("shouldfail", "shouldfail"),
            pytest.raises(AuthenticationError),
        ),
    ],
    ids=("client", "blank", "kasa_setup", "shouldfail"),
)
@pytest.mark.parametrize(
    ("transport_class", "seed_auth_hash_calc"),
    [
        pytest.param(KlapTransport, lambda c, s, a: c + a, id="KLAP"),
        pytest.param(KlapTransportV2, lambda c, s, a: c + s + a, id="KLAPV2"),
    ],
)
async def test_handshake1(
    mocker, device_credentials, expectation, transport_class, seed_auth_hash_calc
):
    """Check ``perform_handshake1`` against several device credential sets.

    The mocked device answers with ``server_seed`` followed by a hash built
    from the device's credentials; ``expectation`` says whether the client
    (configured with foo/bar) should accept it or raise.
    """

    async def _return_handshake1_response(url, params=None, data=None, *_, **__):
        nonlocal client_seed
        client_seed = data
        seed_auth_hash = _sha256(
            seed_auth_hash_calc(client_seed, server_seed, device_auth_hash)
        )
        return _mock_response(200, server_seed + seed_auth_hash)

    client_seed = None
    server_seed = secrets.token_bytes(16)
    client_credentials = Credentials("foo", "bar")
    device_auth_hash = transport_class.generate_auth_hash(device_credentials)

    mocker.patch.object(
        aiohttp.ClientSession, "post", side_effect=_return_handshake1_response
    )

    config = DeviceConfig("127.0.0.1", credentials=client_credentials)
    protocol = IotProtocol(transport=transport_class(config=config))

    with expectation:
        (
            local_seed,
            device_remote_seed,
            auth_hash,
        ) = await protocol._transport.perform_handshake1()

        # Only reached when the handshake did not raise.
        assert local_seed == client_seed
        assert device_remote_seed == server_seed
        assert device_auth_hash == auth_hash
    await protocol.close()


@pytest.mark.parametrize(
    ("transport_class", "seed_auth_hash_calc1", "seed_auth_hash_calc2"),
    [
        pytest.param(
            KlapTransport, lambda c, s, a: c + a, lambda c, s, a: s + a, id="KLAP"
        ),
        pytest.param(
            KlapTransportV2,
            lambda c, s, a: c + s + a,
            lambda c, s, a: s + c + a,
            id="KLAPV2",
        ),
    ],
)
async def test_handshake(
    mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
):
    """Check the full two-step handshake for success and a 403 on handshake2.

    A 200 on handshake2 must set ``_handshake_done``; a 403 must raise
    ``KasaException`` and leave ``_handshake_done`` False.
    """
    client_seed = None
    server_seed = secrets.token_bytes(16)
    client_credentials = Credentials("foo", "bar")
    device_auth_hash = transport_class.generate_auth_hash(client_credentials)

    async def _return_handshake_response(url: URL, params=None, data=None, *_, **__):
        nonlocal client_seed

        if str(url) == "http://127.0.0.1:80/app/handshake1":
            client_seed = data
            seed_auth_hash = _sha256(
                seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
            )

            return _mock_response(200, server_seed + seed_auth_hash)
        elif str(url) == "http://127.0.0.1:80/app/handshake2":
            seed_auth_hash = _sha256(
                seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
            )
            assert data == seed_auth_hash
            # ``response_status`` is rebound by the test body between calls.
            return _mock_response(response_status, b"")

    mocker.patch.object(
        aiohttp.ClientSession, "post", side_effect=_return_handshake_response
    )

    config = DeviceConfig("127.0.0.1", credentials=client_credentials)
    protocol = IotProtocol(transport=transport_class(config=config))
    session = aiohttp.ClientSession()
    protocol._transport.http_client = session

    try:
        response_status = 200
        await protocol._transport.perform_handshake()
        assert protocol._transport._handshake_done is True

        response_status = 403
        with pytest.raises(KasaException):
            await protocol._transport.perform_handshake()
        assert protocol._transport._handshake_done is False
    finally:
        # Close the session created above explicitly so it is never leaked,
        # even when an assertion fails part-way through.
        await session.close()
        await protocol.close()


async def test_query(mocker):
    """Check that repeated queries succeed and increment the sequence number."""
    client_seed = None
    last_seq = None
    seq = None
    server_seed = secrets.token_bytes(16)
    client_credentials = Credentials("foo", "bar")
    device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)

    async def _return_response(url: URL, params=None, data=None, *_, **__):
        nonlocal client_seed, seq

        if str(url) == "http://127.0.0.1:80/app/handshake1":
            client_seed = data
            client_seed_auth_hash = _sha256(data + device_auth_hash)

            return _mock_response(200, server_seed + client_seed_auth_hash)
        elif str(url) == "http://127.0.0.1:80/app/handshake2":
            return _mock_response(200, b"")
        elif str(url) == "http://127.0.0.1:80/app/request":
            # Mirror the client's session and rewind the counter so that
            # encrypt() produces the sequence number the client asked for.
            encryption_session = KlapEncryptionSession(
                protocol._transport._encryption_session.local_seed,
                protocol._transport._encryption_session.remote_seed,
                protocol._transport._encryption_session.user_hash,
            )
            seq = params.get("seq")
            encryption_session._seq = seq - 1
            encrypted, seq = encryption_session.encrypt('{"great": "success"}')
            return _mock_response(200, encrypted)

    mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)

    config = DeviceConfig("127.0.0.1", credentials=client_credentials)
    protocol = IotProtocol(transport=KlapTransport(config=config))

    for _ in range(10):
        resp = await protocol.query({})
        assert resp == {"great": "success"}
        # Check the protocol is incrementing the sequence number
        assert last_seq is None or last_seq + 1 == seq
        last_seq = seq


@pytest.mark.parametrize(
    ("response_status", "credentials_match", "expectation"),
    [
        pytest.param(
            (403, 403, 403),
            True,
            pytest.raises(KasaException),
            id="handshake1-403-status",
        ),
        pytest.param(
            (200, 403, 403),
            True,
            pytest.raises(KasaException),
            id="handshake2-403-status",
        ),
        pytest.param(
            (200, 200, 403),
            True,
            pytest.raises(_RetryableError),
            id="request-403-status",
        ),
        pytest.param(
            (200, 200, 400),
            True,
            pytest.raises(KasaException),
            id="request-400-status",
        ),
        pytest.param(
            (200, 200, 200),
            False,
            pytest.raises(AuthenticationError),
            id="handshake1-wrong-auth",
        ),
        pytest.param(
            (200, 200, 200),
            secrets.token_bytes(16),
            pytest.raises(KasaException),
            id="handshake1-bad-auth-length",
        ),
    ],
)
async def test_authentication_failures(
    mocker, response_status, credentials_match, expectation
):
    """Check the exception raised for each handshake/request failure mode.

    ``response_status`` holds the HTTP status for (handshake1, handshake2,
    request). ``credentials_match`` is True/False to pick matching or
    mismatching device credentials, or a bytes value that is appended to the
    handshake1 hash to make its length invalid.
    """
    client_seed = None

    server_seed = secrets.token_bytes(16)
    client_credentials = Credentials("foo", "bar")
    # A bytes value is truthy, so that case uses matching credentials too.
    device_credentials = (
        client_credentials if credentials_match else Credentials("bar", "foo")
    )
    device_auth_hash = KlapTransport.generate_auth_hash(device_credentials)

    async def _return_response(url: URL, params=None, data=None, *_, **__):
        nonlocal client_seed

        if str(url) == "http://127.0.0.1:80/app/handshake1":
            client_seed = data
            client_seed_auth_hash = _sha256(data + device_auth_hash)
            # Neither True nor False means extra bytes were supplied.
            if credentials_match is not False and credentials_match is not True:
                client_seed_auth_hash += credentials_match
            return _mock_response(
                response_status[0], server_seed + client_seed_auth_hash
            )
        elif str(url) == "http://127.0.0.1:80/app/handshake2":
            client_seed = data
            client_seed_auth_hash = _sha256(data + device_auth_hash)
            return _mock_response(
                response_status[1], server_seed + client_seed_auth_hash
            )
        elif str(url) == "http://127.0.0.1:80/app/request":
            return _mock_response(response_status[2], b"")

    mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)

    config = DeviceConfig("127.0.0.1", credentials=client_credentials)
    protocol = IotProtocol(transport=KlapTransport(config=config))

    with expectation:
        await protocol.query({})


async def test_port_override():
    """Test that port override sets the app_url."""
    host = "127.0.0.1"
    config = DeviceConfig(
        host, credentials=Credentials("foo", "bar"), port_override=12345
    )
    transport = KlapTransport(config=config)

    assert str(transport._app_url) == "http://127.0.0.1:12345/app"