mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-10-31 20:51:54 +00:00 
			
		
		
		
	Fix SslAesTransport default login and add tests (#1202)
Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
		| @@ -137,6 +137,11 @@ class SslAesTransport(BaseTransport): | ||||
|         """Default port for the transport.""" | ||||
|         return self.DEFAULT_PORT | ||||
|  | ||||
|     @staticmethod | ||||
|     def _create_b64_credentials(credentials: Credentials) -> str: | ||||
|         ch = {"un": credentials.username, "pwd": credentials.password} | ||||
|         return base64.b64encode(json_dumps(ch).encode()).decode() | ||||
|  | ||||
|     @property | ||||
|     def credentials_hash(self) -> str | None: | ||||
|         """The hashed credentials used by the transport.""" | ||||
| @@ -145,8 +150,7 @@ class SslAesTransport(BaseTransport): | ||||
|         if not self._credentials and self._credentials_hash: | ||||
|             return self._credentials_hash | ||||
|         if (cred := self._credentials) and cred.password and cred.username: | ||||
|             ch = {"un": cred.username, "pwd": cred.password} | ||||
|             return base64.b64encode(json_dumps(ch).encode()).decode() | ||||
|             return self._create_b64_credentials(cred) | ||||
|         return None | ||||
|  | ||||
|     def _get_response_error(self, resp_dict: Any) -> SmartErrorCode: | ||||
| @@ -329,6 +333,13 @@ class SslAesTransport(BaseTransport): | ||||
|                 + f"status code {status_code} to handshake2" | ||||
|             ) | ||||
|         resp_dict = cast(dict, resp_dict) | ||||
|         if ( | ||||
|             error_code := self._get_response_error(resp_dict) | ||||
|         ) and error_code is SmartErrorCode.INVALID_NONCE: | ||||
|             raise AuthenticationError( | ||||
|                 f"Invalid password hash in handshake2 for {self._host}" | ||||
|             ) | ||||
|  | ||||
|         self._handle_response_error_code(resp_dict, "Error in handshake2") | ||||
|  | ||||
|         self._seq = resp_dict["result"]["start_seq"] | ||||
| @@ -372,12 +383,12 @@ class SslAesTransport(BaseTransport): | ||||
|  | ||||
|         if not self._username: | ||||
|             raise AuthenticationError( | ||||
|                 "Credentials must be supplied to connect to {self._host}" | ||||
|                 f"Credentials must be supplied to connect to {self._host}" | ||||
|             ) | ||||
|         if error_code is not SmartErrorCode.INVALID_NONCE or ( | ||||
|             resp_dict and "nonce" not in resp_dict["result"].get("data", {}) | ||||
|         ): | ||||
|             raise AuthenticationError("Error trying handshake1: {resp_dict}") | ||||
|             raise AuthenticationError(f"Error trying handshake1: {resp_dict}") | ||||
|  | ||||
|         if TYPE_CHECKING: | ||||
|             resp_dict = cast(Dict[str, Any], resp_dict) | ||||
| @@ -422,7 +433,7 @@ class SslAesTransport(BaseTransport): | ||||
|             "params": { | ||||
|                 "cnonce": local_nonce, | ||||
|                 "encrypt_type": "3", | ||||
|                 "username": self._username, | ||||
|                 "username": username, | ||||
|             }, | ||||
|         } | ||||
|         http_client = self._http_client | ||||
|   | ||||
							
								
								
									
										374
									
								
								kasa/tests/test_sslaestransport.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										374
									
								
								kasa/tests/test_sslaestransport.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,374 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import logging | ||||
| import secrets | ||||
| from contextlib import nullcontext as does_not_raise | ||||
| from json import dumps as json_dumps | ||||
| from json import loads as json_loads | ||||
| from typing import Any | ||||
|  | ||||
| import aiohttp | ||||
| import pytest | ||||
| from yarl import URL | ||||
|  | ||||
| from kasa.protocol import DEFAULT_CREDENTIALS, get_default_credentials | ||||
|  | ||||
| from ..aestransport import AesEncyptionSession | ||||
| from ..credentials import Credentials | ||||
| from ..deviceconfig import DeviceConfig | ||||
| from ..exceptions import ( | ||||
|     AuthenticationError, | ||||
|     KasaException, | ||||
|     SmartErrorCode, | ||||
| ) | ||||
| from ..experimental.sslaestransport import SslAesTransport, TransportState, _sha256_hash | ||||
| from ..httpclient import HttpClient | ||||
|  | ||||
| MOCK_ADMIN_USER = get_default_credentials(DEFAULT_CREDENTIALS["TAPOCAMERA"]).username | ||||
| MOCK_PWD = "correct_pwd"  # noqa: S105 | ||||
| MOCK_USER = "mock@example.com" | ||||
| MOCK_STOCK = "abcdefghijklmnopqrstuvwxyz1234)(" | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     ( | ||||
|         "status_code", | ||||
|         "username", | ||||
|         "password", | ||||
|         "wants_default_user", | ||||
|         "digest_password_fail", | ||||
|         "expectation", | ||||
|     ), | ||||
|     [ | ||||
|         pytest.param( | ||||
|             200, MOCK_USER, MOCK_PWD, False, False, does_not_raise(), id="success" | ||||
|         ), | ||||
|         pytest.param( | ||||
|             200, | ||||
|             MOCK_USER, | ||||
|             MOCK_PWD, | ||||
|             True, | ||||
|             False, | ||||
|             does_not_raise(), | ||||
|             id="success-default", | ||||
|         ), | ||||
|         pytest.param( | ||||
|             400, | ||||
|             MOCK_USER, | ||||
|             MOCK_PWD, | ||||
|             False, | ||||
|             False, | ||||
|             pytest.raises(KasaException), | ||||
|             id="400 error", | ||||
|         ), | ||||
|         pytest.param( | ||||
|             200, | ||||
|             "foobar", | ||||
|             MOCK_PWD, | ||||
|             False, | ||||
|             False, | ||||
|             pytest.raises(AuthenticationError), | ||||
|             id="bad-username", | ||||
|         ), | ||||
|         pytest.param( | ||||
|             200, | ||||
|             MOCK_USER, | ||||
|             "barfoo", | ||||
|             False, | ||||
|             False, | ||||
|             pytest.raises(AuthenticationError), | ||||
|             id="bad-password", | ||||
|         ), | ||||
|         pytest.param( | ||||
|             200, | ||||
|             MOCK_USER, | ||||
|             MOCK_PWD, | ||||
|             False, | ||||
|             True, | ||||
|             pytest.raises(AuthenticationError), | ||||
|             id="bad-password-digest", | ||||
|         ), | ||||
|     ], | ||||
| ) | ||||
| async def test_handshake( | ||||
|     mocker, | ||||
|     status_code, | ||||
|     username, | ||||
|     password, | ||||
|     wants_default_user, | ||||
|     digest_password_fail, | ||||
|     expectation, | ||||
| ): | ||||
|     host = "127.0.0.1" | ||||
|     mock_ssl_aes_device = MockSslAesDevice( | ||||
|         host, | ||||
|         status_code=status_code, | ||||
|         want_default_username=wants_default_user, | ||||
|         digest_password_fail=digest_password_fail, | ||||
|     ) | ||||
|     mocker.patch.object( | ||||
|         aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post | ||||
|     ) | ||||
|  | ||||
|     transport = SslAesTransport( | ||||
|         config=DeviceConfig(host, credentials=Credentials(username, password)) | ||||
|     ) | ||||
|  | ||||
|     assert transport._encryption_session is None | ||||
|     assert transport._state is TransportState.HANDSHAKE_REQUIRED | ||||
|     with expectation: | ||||
|         await transport.perform_handshake() | ||||
|         assert transport._encryption_session is not None | ||||
|         assert transport._state is TransportState.ESTABLISHED | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     ("wants_default_user"), | ||||
|     [pytest.param(False, id="username"), pytest.param(True, id="default")], | ||||
| ) | ||||
| async def test_credentials_hash(mocker, wants_default_user): | ||||
|     host = "127.0.0.1" | ||||
|     mock_ssl_aes_device = MockSslAesDevice( | ||||
|         host, want_default_username=wants_default_user | ||||
|     ) | ||||
|     mocker.patch.object( | ||||
|         aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post | ||||
|     ) | ||||
|     creds = Credentials(MOCK_USER, MOCK_PWD) | ||||
|     creds_hash = SslAesTransport._create_b64_credentials(creds) | ||||
|  | ||||
|     # Test with credentials input | ||||
|     transport = SslAesTransport(config=DeviceConfig(host, credentials=creds)) | ||||
|     assert transport.credentials_hash == creds_hash | ||||
|     await transport.perform_handshake() | ||||
|     assert transport.credentials_hash == creds_hash | ||||
|  | ||||
|     # Test with credentials_hash input | ||||
|     transport = SslAesTransport(config=DeviceConfig(host, credentials_hash=creds_hash)) | ||||
|     mock_ssl_aes_device.handshake1_complete = False | ||||
|     assert transport.credentials_hash == creds_hash | ||||
|     await transport.perform_handshake() | ||||
|     assert transport.credentials_hash == creds_hash | ||||
|  | ||||
|  | ||||
| async def test_send(mocker): | ||||
|     host = "127.0.0.1" | ||||
|     mock_ssl_aes_device = MockSslAesDevice(host, want_default_username=False) | ||||
|     mocker.patch.object( | ||||
|         aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post | ||||
|     ) | ||||
|  | ||||
|     transport = SslAesTransport( | ||||
|         config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD)) | ||||
|     ) | ||||
|     request = { | ||||
|         "method": "getDeviceInfo", | ||||
|         "params": None, | ||||
|     } | ||||
|  | ||||
|     res = await transport.send(json_dumps(request)) | ||||
|     assert "result" in res | ||||
|  | ||||
|  | ||||
| async def test_unencrypted_response(mocker, caplog): | ||||
|     host = "127.0.0.1" | ||||
|     mock_ssl_aes_device = MockSslAesDevice(host, do_not_encrypt_response=True) | ||||
|     mocker.patch.object( | ||||
|         aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post | ||||
|     ) | ||||
|  | ||||
|     transport = SslAesTransport( | ||||
|         config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD)) | ||||
|     ) | ||||
|  | ||||
|     request = { | ||||
|         "method": "getDeviceInfo", | ||||
|         "params": None, | ||||
|     } | ||||
|     caplog.set_level(logging.DEBUG) | ||||
|     res = await transport.send(json_dumps(request)) | ||||
|     assert "result" in res | ||||
|     assert ( | ||||
|         "Received unencrypted response over secure passthrough from 127.0.0.1" | ||||
|         in caplog.text | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def test_port_override(): | ||||
|     """Test that port override sets the app_url.""" | ||||
|     host = "127.0.0.1" | ||||
|     port_override = 12345 | ||||
|     config = DeviceConfig( | ||||
|         host, credentials=Credentials("foo", "bar"), port_override=port_override | ||||
|     ) | ||||
|     transport = SslAesTransport(config=config) | ||||
|  | ||||
|     assert str(transport._app_url) == f"https://127.0.0.1:{port_override}" | ||||
|  | ||||
|  | ||||
| class MockSslAesDevice: | ||||
|     BAD_USER_RESP = { | ||||
|         "error_code": SmartErrorCode.SESSION_EXPIRED.value, | ||||
|         "result": { | ||||
|             "data": { | ||||
|                 "code": -60502, | ||||
|             } | ||||
|         }, | ||||
|     } | ||||
|  | ||||
|     BAD_PWD_RESP = { | ||||
|         "error_code": SmartErrorCode.INVALID_NONCE.value, | ||||
|         "result": { | ||||
|             "data": { | ||||
|                 "code": SmartErrorCode.SESSION_EXPIRED.value, | ||||
|                 "encrypt_type": ["3"], | ||||
|                 "key": "Someb64keyWithUnknownPurpose", | ||||
|                 "nonce": "1234567890ABCDEF",  # Whatever the original nonce was | ||||
|                 "device_confirm": "", | ||||
|             } | ||||
|         }, | ||||
|     } | ||||
|  | ||||
|     class _mock_response: | ||||
|         def __init__(self, status, request: dict): | ||||
|             self.status = status | ||||
|             self._json = request | ||||
|  | ||||
|         async def __aenter__(self): | ||||
|             return self | ||||
|  | ||||
|         async def __aexit__(self, exc_t, exc_v, exc_tb): | ||||
|             pass | ||||
|  | ||||
|         async def read(self): | ||||
|             if isinstance(self._json, dict): | ||||
|                 return json_dumps(self._json).encode() | ||||
|             return self._json | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         host, | ||||
|         *, | ||||
|         status_code=200, | ||||
|         want_default_username: bool = False, | ||||
|         do_not_encrypt_response=False, | ||||
|         send_response=None, | ||||
|         sequential_request_delay=0, | ||||
|         send_error_code=0, | ||||
|         secure_passthrough_error_code=0, | ||||
|         digest_password_fail=False, | ||||
|     ): | ||||
|         self.host = host | ||||
|         self.http_client = HttpClient(DeviceConfig(self.host)) | ||||
|         self.encryption_session: AesEncyptionSession | None = None | ||||
|         self.server_nonce = secrets.token_bytes(8).hex().upper() | ||||
|         self.handshake1_complete = False | ||||
|  | ||||
|         # test behaviour attributes | ||||
|         self.status_code = status_code | ||||
|         self.send_error_code = send_error_code | ||||
|         self.secure_passthrough_error_code = secure_passthrough_error_code | ||||
|         self.do_not_encrypt_response = do_not_encrypt_response | ||||
|         self.want_default_username = want_default_username | ||||
|         self.digest_password_fail = digest_password_fail | ||||
|  | ||||
|     async def post(self, url: URL, params=None, json=None, data=None, *_, **__): | ||||
|         if data: | ||||
|             json = json_loads(data) | ||||
|         res = await self._post(url, json) | ||||
|         return res | ||||
|  | ||||
|     async def _post(self, url: URL, json: dict[str, Any]): | ||||
|         method = json["method"] | ||||
|  | ||||
|         if method == "login" and not self.handshake1_complete: | ||||
|             return await self._return_handshake1_response(url, json) | ||||
|  | ||||
|         if method == "login" and self.handshake1_complete: | ||||
|             return await self._return_handshake2_response(url, json) | ||||
|         elif method == "securePassthrough": | ||||
|             assert url == URL(f"https://{self.host}/stok={MOCK_STOCK}/ds") | ||||
|             return await self._return_secure_passthrough_response(url, json) | ||||
|         else: | ||||
|             assert url == URL(f"https://{self.host}/stok={MOCK_STOCK}/ds") | ||||
|             return await self._return_send_response(url, json) | ||||
|  | ||||
|     async def _return_handshake1_response(self, url: URL, request: dict[str, Any]): | ||||
|         request_nonce = request["params"].get("cnonce") | ||||
|         request_username = request["params"].get("username") | ||||
|  | ||||
|         if (self.want_default_username and request_username != MOCK_ADMIN_USER) or ( | ||||
|             not self.want_default_username and request_username != MOCK_USER | ||||
|         ): | ||||
|             return self._mock_response(self.status_code, self.BAD_USER_RESP) | ||||
|  | ||||
|         device_confirm = SslAesTransport.generate_confirm_hash( | ||||
|             request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode()) | ||||
|         ) | ||||
|         self.handshake1_complete = True | ||||
|         resp = { | ||||
|             "error_code": SmartErrorCode.INVALID_NONCE.value, | ||||
|             "result": { | ||||
|                 "data": { | ||||
|                     "code": SmartErrorCode.INVALID_NONCE.value, | ||||
|                     "encrypt_type": ["3"], | ||||
|                     "key": "Someb64keyWithUnknownPurpose", | ||||
|                     "nonce": self.server_nonce, | ||||
|                     "device_confirm": device_confirm, | ||||
|                 } | ||||
|             }, | ||||
|         } | ||||
|         return self._mock_response(self.status_code, resp) | ||||
|  | ||||
|     async def _return_handshake2_response(self, url: URL, request: dict[str, Any]): | ||||
|         request_nonce = request["params"].get("cnonce") | ||||
|         request_username = request["params"].get("username") | ||||
|         if (self.want_default_username and request_username != MOCK_ADMIN_USER) or ( | ||||
|             not self.want_default_username and request_username != MOCK_USER | ||||
|         ): | ||||
|             return self._mock_response(self.status_code, self.BAD_USER_RESP) | ||||
|  | ||||
|         request_password = request["params"].get("digest_passwd") | ||||
|         expected_pwd = SslAesTransport.generate_digest_password( | ||||
|             request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode()) | ||||
|         ) | ||||
|         if request_password != expected_pwd or self.digest_password_fail: | ||||
|             return self._mock_response(self.status_code, self.BAD_PWD_RESP) | ||||
|  | ||||
|         lsk = SslAesTransport.generate_encryption_token( | ||||
|             "lsk", request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode()) | ||||
|         ) | ||||
|         ivb = SslAesTransport.generate_encryption_token( | ||||
|             "ivb", request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode()) | ||||
|         ) | ||||
|         self.encryption_session = AesEncyptionSession(lsk, ivb) | ||||
|         resp = { | ||||
|             "error_code": 0, | ||||
|             "result": {"stok": MOCK_STOCK, "user_group": "root", "start_seq": 100}, | ||||
|         } | ||||
|         return self._mock_response(self.status_code, resp) | ||||
|  | ||||
|     async def _return_secure_passthrough_response(self, url: URL, json: dict[str, Any]): | ||||
|         encrypted_request = json["params"]["request"] | ||||
|         assert self.encryption_session | ||||
|         decrypted_request = self.encryption_session.decrypt(encrypted_request.encode()) | ||||
|         decrypted_request_dict = json_loads(decrypted_request) | ||||
|         decrypted_response = await self._post(url, decrypted_request_dict) | ||||
|         async with decrypted_response: | ||||
|             decrypted_response_data = await decrypted_response.read() | ||||
|  | ||||
|         encrypted_response = self.encryption_session.encrypt(decrypted_response_data) | ||||
|         response = ( | ||||
|             decrypted_response_data | ||||
|             if self.do_not_encrypt_response | ||||
|             else encrypted_response | ||||
|         ) | ||||
|         result = { | ||||
|             "result": {"response": response.decode()}, | ||||
|             "error_code": self.secure_passthrough_error_code, | ||||
|         } | ||||
|         return self._mock_response(self.status_code, result) | ||||
|  | ||||
|     async def _return_send_response(self, url: URL, json: dict[str, Any]): | ||||
|         result = {"result": {"method": None}, "error_code": self.send_error_code} | ||||
|         return self._mock_response(self.status_code, result) | ||||
		Reference in New Issue
	
	Block a user
	 Steven B.
					Steven B.