diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index b7976101..97b23145 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -50,7 +50,7 @@ import logging import secrets import struct import time -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes @@ -354,9 +354,14 @@ class KlapTransport(BaseTransport): else: _LOGGER.debug("Device %s query posted %s", self._host, msg) - # Check for mypy - if self._encryption_session is not None: + if TYPE_CHECKING: + assert self._encryption_session + try: decrypted_response = self._encryption_session.decrypt(response_data) + except Exception as ex: + raise KasaException( + f"Error trying to decrypt device {self._host} response: {ex}" + ) from ex json_payload = json_loads(decrypted_response) diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index b71ea460..0565683a 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -238,6 +238,71 @@ def test_encrypt_unicode(): assert d == decrypted +async def test_transport_decrypt(mocker): + """Test transport decryption.""" + d = {"great": "success"} + + seed = secrets.token_bytes(16) + auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) + encryption_session = KlapEncryptionSession(seed, seed, auth_hash) + + transport = KlapTransport(config=DeviceConfig(host="127.0.0.1")) + transport._handshake_done = True + transport._session_expire_at = time.monotonic() + 60 + transport._encryption_session = encryption_session + + async def _return_response(url: URL, params=None, data=None, *_, **__): + encryption_session = KlapEncryptionSession( + transport._encryption_session.local_seed, + transport._encryption_session.remote_seed, + transport._encryption_session.user_hash, + ) + seq = params.get("seq") + encryption_session._seq = seq - 1 + encrypted, seq = encryption_session.encrypt(json.dumps(d)) + seq = seq + return 200, encrypted + + mocker.patch.object(HttpClient, "post", side_effect=_return_response) + + resp = await transport.send(json.dumps({})) + assert d == resp + + +async def test_transport_decrypt_error(mocker, caplog): + """Test that a decryption error raises a kasa exception.""" + d = {"great": "success"} + + seed = secrets.token_bytes(16) + auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) + encryption_session = KlapEncryptionSession(seed, seed, auth_hash) + + transport = KlapTransport(config=DeviceConfig(host="127.0.0.1")) + transport._handshake_done = True + transport._session_expire_at = time.monotonic() + 60 + transport._encryption_session = encryption_session + + async def _return_response(url: URL, params=None, data=None, *_, **__): + encryption_session = KlapEncryptionSession( + secrets.token_bytes(16), + transport._encryption_session.remote_seed, + transport._encryption_session.user_hash, + ) + seq = params.get("seq") + encryption_session._seq = seq - 1 + encrypted, seq = encryption_session.encrypt(json.dumps(d)) + seq = seq + return 200, encrypted + + mocker.patch.object(HttpClient, "post", side_effect=_return_response) + + with pytest.raises( + KasaException, + match="Error trying to decrypt device 127.0.0.1 response: Invalid padding bytes.", + ): + await transport.send(json.dumps({})) + + @pytest.mark.parametrize( "device_credentials, expectation", [