diff --git a/kasa/exceptions.py b/kasa/exceptions.py index a0ecbf8f..f23602a5 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -132,6 +132,7 @@ class SmartErrorCode(IntEnum): # Camera error codes SESSION_EXPIRED = -40401 + BAD_USERNAME = -40411 # determined from testing HOMEKIT_LOGIN_FAIL = -40412 DEVICE_BLOCKED = -40404 DEVICE_FACTORY = -40405 diff --git a/kasa/transports/sslaestransport.py b/kasa/transports/sslaestransport.py index 500d9422..3ea33145 100644 --- a/kasa/transports/sslaestransport.py +++ b/kasa/transports/sslaestransport.py @@ -126,6 +126,7 @@ class SslAesTransport(BaseTransport): self._password = ch["pwd"] self._username = ch["un"] self._local_nonce: str | None = None + self._send_secure = True _LOGGER.debug("Created AES transport for %s", self._host) @@ -162,7 +163,13 @@ class SslAesTransport(BaseTransport): return error_code def _get_response_inner_error(self, resp_dict: Any) -> SmartErrorCode | None: + # Device blocked errors have 'data' element at the root level, other inner + # errors are inside 'result' error_code_raw = resp_dict.get("data", {}).get("code") + + if error_code_raw is None: + error_code_raw = resp_dict.get("result", {}).get("data", {}).get("code") + if error_code_raw is None: return None try: @@ -208,6 +215,10 @@ class SslAesTransport(BaseTransport): else: url = self._app_url + _LOGGER.debug( + "Sending secure passthrough from %s", + self._host, + ) encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore passthrough_request = { "method": "securePassthrough", @@ -292,6 +303,34 @@ class SslAesTransport(BaseTransport): ) from ex return ret_val # type: ignore[return-value] + async def send_unencrypted(self, request: str) -> dict[str, Any]: + """Send encrypted message as passthrough.""" + url = cast(URL, self._token_url) + + _LOGGER.debug( + "Sending unencrypted to %s", + self._host, + ) + + status_code, resp_dict = await self._http_client.post( + url, + json=request, + headers=self._headers, + ssl=await self._get_ssl_context(), + ) + + if status_code != 200: + raise KasaException( + f"{self._host} responded with an unexpected " + + f"status code {status_code} to unencrypted send" + ) + + self._handle_response_error_code(resp_dict, "Error sending message") + + if TYPE_CHECKING: + resp_dict = cast(dict[str, Any], resp_dict) + return resp_dict + @staticmethod def generate_confirm_hash( local_nonce: str, server_nonce: str, pwd_hash: str @@ -340,8 +379,50 @@ class SslAesTransport(BaseTransport): async def perform_handshake(self) -> None: """Perform the handshake.""" - local_nonce, server_nonce, pwd_hash = await self.perform_handshake1() - await self.perform_handshake2(local_nonce, server_nonce, pwd_hash) + result = await self.perform_handshake1() + if result: + local_nonce, server_nonce, pwd_hash = result + await self.perform_handshake2(local_nonce, server_nonce, pwd_hash) + + async def try_perform_less_secure_login(self, username: str, password: str) -> bool: + """Perform the md5 login.""" + _LOGGER.debug("Performing less secure login...") + + pwd_hash = _md5_hash(password.encode()) + body = { + "method": "login", + "params": { + "hashed": True, + "password": pwd_hash, + "username": username, + }, + } + + status_code, resp_dict = await self._http_client.post( + self._app_url, + json=body, + headers=self._headers, + ssl=await self._get_ssl_context(), + ) + if status_code != 200: + raise KasaException( + f"{self._host} responded with an unexpected " + + f"status code {status_code} to login" + ) + resp_dict = cast(dict, resp_dict) + if resp_dict.get("error_code") == 0 and ( + stok := resp_dict.get("result", {}).get("stok") + ): + _LOGGER.debug( + "Succesfully logged in to %s with less secure passthrough", self._host + ) + self._send_secure = False + self._token_url = URL(f"{str(self._app_url)}/stok={stok}/ds") + self._pwd_hash = pwd_hash + return True + + _LOGGER.debug("Unable to log in to %s with less secure login", self._host) + return False async def perform_handshake2( self, local_nonce: str, server_nonce: str, pwd_hash: str @@ -393,13 +474,50 @@ class SslAesTransport(BaseTransport): self._state = TransportState.ESTABLISHED _LOGGER.debug("Handshake2 complete ...") - async def perform_handshake1(self) -> tuple[str, str, str]: + def _pwd_to_hash(self) -> str: + """Return the password to hash.""" + if self._credentials and self._credentials != Credentials(): + return self._credentials.password + + if self._username and self._password: + return self._password + + return self._default_credentials.password + + def _is_less_secure_login(self, resp_dict: dict[str, Any]) -> bool: + result = ( + self._get_response_error(resp_dict) is SmartErrorCode.SESSION_EXPIRED + and (data := resp_dict.get("result", {}).get("data", {})) + and (encrypt_type := data.get("encrypt_type")) + and (encrypt_type != ["3"]) + ) + if result: + _LOGGER.debug( + "Received encrypt_type %s for %s, trying less secure login", + encrypt_type, + self._host, + ) + return result + + async def perform_handshake1(self) -> tuple[str, str, str] | None: """Perform the handshake1.""" resp_dict = None if self._username: local_nonce = secrets.token_bytes(8).hex().upper() resp_dict = await self.try_send_handshake1(self._username, local_nonce) + if ( + resp_dict + and self._is_less_secure_login(resp_dict) + and self._get_response_inner_error(resp_dict) + is not SmartErrorCode.BAD_USERNAME + and await self.try_perform_less_secure_login( + cast(str, self._username), self._pwd_to_hash() + ) + ): + self._state = TransportState.ESTABLISHED + return None + # Try the default username. If it fails raise the original error_code if ( not resp_dict @@ -407,19 +525,30 @@ class SslAesTransport(BaseTransport): is not SmartErrorCode.INVALID_NONCE or "nonce" not in resp_dict["result"].get("data", {}) ): + _LOGGER.debug("Trying default credentials to %s", self._host) local_nonce = secrets.token_bytes(8).hex().upper() default_resp_dict = await self.try_send_handshake1( self._default_credentials.username, local_nonce ) + # INVALID_NONCE means device should perform secure login if ( default_error_code := self._get_response_error(default_resp_dict) ) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[ "result" ].get("data", {}): - _LOGGER.debug("Connected to {self._host} with default username") + _LOGGER.debug("Connected to %s with default username", self._host) self._username = self._default_credentials.username error_code = default_error_code resp_dict = default_resp_dict + # Otherwise could be less secure login + elif self._is_less_secure_login( + default_resp_dict + ) and await self.try_perform_less_secure_login( + self._default_credentials.username, self._pwd_to_hash() + ): + self._username = self._default_credentials.username + self._state = TransportState.ESTABLISHED + return None # If the default login worked it's ok not to provide credentials but if # it didn't raise auth error here. @@ -451,12 +580,8 @@ class SslAesTransport(BaseTransport): server_nonce = resp_dict["result"]["data"]["nonce"] device_confirm = resp_dict["result"]["data"]["device_confirm"] - if self._credentials and self._credentials != Credentials(): - pwd_hash = _sha256_hash(self._credentials.password.encode()) - elif self._username and self._password: - pwd_hash = _sha256_hash(self._password.encode()) - else: - pwd_hash = _sha256_hash(self._default_credentials.password.encode()) + + pwd_hash = _sha256_hash(self._pwd_to_hash().encode()) expected_confirm_sha256 = self.generate_confirm_hash( local_nonce, server_nonce, pwd_hash @@ -468,7 +593,9 @@ class SslAesTransport(BaseTransport): if TYPE_CHECKING: assert self._credentials assert self._credentials.password - pwd_hash = _md5_hash(self._credentials.password.encode()) + + pwd_hash = _md5_hash(self._pwd_to_hash().encode()) + expected_confirm_md5 = self.generate_confirm_hash( local_nonce, server_nonce, pwd_hash ) @@ -478,11 +605,12 @@ class SslAesTransport(BaseTransport): msg = f"Server response doesn't match our challenge on ip {self._host}" _LOGGER.debug(msg) + raise AuthenticationError(msg) async def try_send_handshake1(self, username: str, local_nonce: str) -> dict: """Perform the handshake.""" - _LOGGER.debug("Will to send handshake1...") + _LOGGER.debug("Sending handshake1...") body = { "method": "login", @@ -501,7 +629,7 @@ class SslAesTransport(BaseTransport): ssl=await self._get_ssl_context(), ) - _LOGGER.debug("Device responded with: %s", resp_dict) + _LOGGER.debug("Device responded with status %s: %s", status_code, resp_dict) if status_code != 200: raise KasaException( @@ -516,7 +644,10 @@ class SslAesTransport(BaseTransport): if self._state is TransportState.HANDSHAKE_REQUIRED: await self.perform_handshake() - return await self.send_secure_passthrough(request) + if self._send_secure: + return await self.send_secure_passthrough(request) + + return await self.send_unencrypted(request) async def close(self) -> None: """Close the http client and reset internal state.""" diff --git a/tests/transports/test_sslaestransport.py b/tests/transports/test_sslaestransport.py index 39469967..e8ff9e52 100644 --- a/tests/transports/test_sslaestransport.py +++ b/tests/transports/test_sslaestransport.py @@ -25,16 +25,19 @@ from kasa.transports.aestransport import AesEncyptionSession from kasa.transports.sslaestransport import ( SslAesTransport, TransportState, + _md5_hash, _sha256_hash, ) # Transport tests are not designed for real devices -pytestmark = [pytest.mark.requires_dummy] +# SslAesTransport use a socket to get it's own ip address +pytestmark = [pytest.mark.requires_dummy, pytest.mark.enable_socket] MOCK_ADMIN_USER = get_default_credentials(DEFAULT_CREDENTIALS["TAPOCAMERA"]).username MOCK_PWD = "correct_pwd" # noqa: S105 MOCK_USER = "mock@example.com" MOCK_STOCK = "abcdefghijklmnopqrstuvwxyz1234)(" +MOCK_UNENCRYPTED_PASSTHROUGH_STOK = "32charLowerCaseHexStok" @pytest.mark.parametrize( @@ -202,6 +205,124 @@ async def test_unencrypted_response(mocker, caplog): ) +@pytest.mark.parametrize(("want_default"), [True, False]) +@pytest.mark.xdist_group(name="caplog") +async def test_unencrypted_passthrough(mocker, caplog, want_default): + host = "127.0.0.1" + mock_ssl_aes_device = MockSslAesDevice( + host, unencrypted_passthrough=True, want_default_username=want_default + ) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + transport = SslAesTransport( + config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD)) + ) + + request = { + "method": "getDeviceInfo", + "params": None, + } + caplog.set_level(logging.DEBUG) + res = await transport.send(json_dumps(request)) + assert "result" in res + assert ( + f"Succesfully logged in to {host} with less secure passthrough" in caplog.text + ) + + +@pytest.mark.parametrize(("want_default"), [True, False]) +@pytest.mark.xdist_group(name="caplog") +async def test_unencrypted_passthrough_errors(mocker, caplog, want_default): + host = "127.0.0.1" + request = { + "method": "getDeviceInfo", + "params": None, + } + transport = SslAesTransport( + config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD)) + ) + caplog.set_level(logging.DEBUG) + + # Test bad password + mock_ssl_aes_device = MockSslAesDevice( + host, + unencrypted_passthrough=True, + want_default_username=want_default, + digest_password_fail=True, + ) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + msg = f"Unable to log in to {host} with less secure login" + with pytest.raises(AuthenticationError): + await transport.send(json_dumps(request)) + + assert msg in caplog.text + + # Test bad status code in handshake + mock_ssl_aes_device = MockSslAesDevice( + host, + unencrypted_passthrough=True, + want_default_username=want_default, + status_code=401, + ) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + msg = f"{host} responded with an unexpected " f"status code 401 to handshake1" + with pytest.raises(KasaException, match=msg): + await transport.send(json_dumps(request)) + + # Test bad status code in login + mock_ssl_aes_device = MockSslAesDevice( + host, + unencrypted_passthrough=True, + want_default_username=want_default, + status_code_list=[200, 401], + ) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + msg = f"{host} responded with an unexpected " f"status code 401 to login" + with pytest.raises(KasaException, match=msg): + await transport.send(json_dumps(request)) + + # Test bad status code in send + mock_ssl_aes_device = MockSslAesDevice( + host, + unencrypted_passthrough=True, + want_default_username=want_default, + status_code_list=[200, 200, 401], + ) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + msg = f"{host} responded with an unexpected " f"status code 401 to unencrypted send" + with pytest.raises(KasaException, match=msg): + await transport.send(json_dumps(request)) + + # Test error code in send response + mock_ssl_aes_device = MockSslAesDevice( + host, + unencrypted_passthrough=True, + want_default_username=want_default, + send_error_code=SmartErrorCode.BAD_USERNAME.value, + ) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + msg = f"Error sending message: {host}:" + with pytest.raises(KasaException, match=msg): + await transport.send(json_dumps(request)) + + async def test_device_blocked_response(mocker): host = "127.0.0.1" mock_ssl_aes_device = MockSslAesDevice(host, device_blocked=True) @@ -300,6 +421,38 @@ class MockSslAesDevice: "error_code": SmartErrorCode.SESSION_EXPIRED.value, } + UNENCRYPTED_PASSTHROUGH_BAD_USER_RESP = { + "error_code": SmartErrorCode.SESSION_EXPIRED.value, + "result": { + "data": { + "code": SmartErrorCode.BAD_USERNAME.value, + "encrypt_type": ["1", "2"], + "key": "Someb64keyWithUnknownPurpose", + "nonce": "MixedCaseAlphaNumericWithUnknownPurpose", + } + }, + } + + UNENCRYPTED_PASSTHROUGH_HANDSHAKE_RESP = { + "error_code": SmartErrorCode.SESSION_EXPIRED.value, + "result": { + "data": { + "code": SmartErrorCode.SESSION_EXPIRED.value, + "time": 9, + "max_time": 10, + "sec_left": 0, + "encrypt_type": ["1", "2"], + "key": "Someb64keyWithUnknownPurpose", + "nonce": "MixedCaseAlphaNumericWithUnknownPurpose", + } + }, + } + + UNENCRYPTED_PASSTHROUGH_GOOD_LOGIN_RESPONSE = { + "error_code": 0, + "result": {"stok": MOCK_UNENCRYPTED_PASSTHROUGH_STOK, "user_group": "root"}, + } + class _mock_response: def __init__(self, status, request: dict): self.status = status @@ -321,6 +474,7 @@ class MockSslAesDevice: host, *, status_code=200, + status_code_list=None, want_default_username: bool = False, do_not_encrypt_response=False, send_response=None, @@ -329,6 +483,7 @@ class MockSslAesDevice: secure_passthrough_error_code=0, digest_password_fail=False, device_blocked=False, + unencrypted_passthrough=False, ): self.host = host self.http_client = HttpClient(DeviceConfig(self.host)) @@ -338,15 +493,22 @@ class MockSslAesDevice: # test behaviour attributes self.status_code = status_code + self.status_code_list = status_code_list if status_code_list else [] self.send_error_code = send_error_code self.secure_passthrough_error_code = secure_passthrough_error_code self.do_not_encrypt_response = do_not_encrypt_response self.want_default_username = want_default_username self.digest_password_fail = digest_password_fail self.device_blocked = device_blocked + self.unencrypted_passthrough = unencrypted_passthrough self._next_responses: list[dict | bytes] = [] + def _get_status_code(self): + if self.status_code_list: + return self.status_code_list.pop(0) + return self.status_code + async def post(self, url: URL, params=None, json=None, data=None, *_, **__): if data: json = json_loads(data) @@ -360,12 +522,25 @@ class MockSslAesDevice: return await self._return_handshake1_response(url, json) if method == "login" and self.handshake1_complete: + if self.unencrypted_passthrough: + return await self._return_unencrypted_passthrough_login_response( + url, json + ) + return await self._return_handshake2_response(url, json) elif method == "securePassthrough": assert url == URL(f"https://{self.host}/stok={MOCK_STOCK}/ds") return await self._return_secure_passthrough_response(url, json) else: - assert url == URL(f"https://{self.host}/stok={MOCK_STOCK}/ds") + # The unencrypted passthrough with have actual query method names. + # This path is also used by the mock class to return unencrypted + # responses to single 'get' queries which the secure fw returns as unencrypted + stok = ( + MOCK_UNENCRYPTED_PASSTHROUGH_STOK + if self.unencrypted_passthrough + else MOCK_STOCK + ) + assert url == URL(f"https://{self.host}/stok={stok}/ds") return await self._return_send_response(url, json) async def _return_handshake1_response(self, url: URL, request: dict[str, Any]): @@ -378,12 +553,23 @@ class MockSslAesDevice: if (self.want_default_username and request_username != MOCK_ADMIN_USER) or ( not self.want_default_username and request_username != MOCK_USER ): - return self._mock_response(self.status_code, self.BAD_USER_RESP) + resp = ( + self.UNENCRYPTED_PASSTHROUGH_BAD_USER_RESP + if self.unencrypted_passthrough + else self.BAD_USER_RESP + ) + return self._mock_response(self.status_code, resp) device_confirm = SslAesTransport.generate_confirm_hash( request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode()) ) self.handshake1_complete = True + + if self.unencrypted_passthrough: + return self._mock_response( + self._get_status_code(), self.UNENCRYPTED_PASSTHROUGH_HANDSHAKE_RESP + ) + resp = { "error_code": SmartErrorCode.INVALID_NONCE.value, "result": { @@ -396,7 +582,29 @@ class MockSslAesDevice: } }, } - return self._mock_response(self.status_code, resp) + return self._mock_response(self._get_status_code(), resp) + + async def _return_unencrypted_passthrough_login_response( + self, url: URL, request: dict[str, Any] + ): + request_username = request["params"].get("username") + request_password = request["params"].get("password") + if (self.want_default_username and request_username != MOCK_ADMIN_USER) or ( + not self.want_default_username and request_username != MOCK_USER + ): + return self._mock_response( + self._get_status_code(), self.UNENCRYPTED_PASSTHROUGH_BAD_USER_RESP + ) + + expected_pwd = _md5_hash(MOCK_PWD.encode()) + if request_password != expected_pwd or self.digest_password_fail: + return self._mock_response( + self._get_status_code(), self.UNENCRYPTED_PASSTHROUGH_HANDSHAKE_RESP + ) + + return self._mock_response( + self._get_status_code(), self.UNENCRYPTED_PASSTHROUGH_GOOD_LOGIN_RESPONSE + ) async def _return_handshake2_response(self, url: URL, request: dict[str, Any]): request_nonce = request["params"].get("cnonce") @@ -404,14 +612,14 @@ class MockSslAesDevice: if (self.want_default_username and request_username != MOCK_ADMIN_USER) or ( not self.want_default_username and request_username != MOCK_USER ): - return self._mock_response(self.status_code, self.BAD_USER_RESP) + return self._mock_response(self._get_status_code(), self.BAD_USER_RESP) request_password = request["params"].get("digest_passwd") expected_pwd = SslAesTransport.generate_digest_password( request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode()) ) if request_password != expected_pwd or self.digest_password_fail: - return self._mock_response(self.status_code, self.BAD_PWD_RESP) + return self._mock_response(self._get_status_code(), self.BAD_PWD_RESP) lsk = SslAesTransport.generate_encryption_token( "lsk", request_nonce, self.server_nonce, _sha256_hash(MOCK_PWD.encode()) @@ -424,7 +632,7 @@ class MockSslAesDevice: "error_code": 0, "result": {"stok": MOCK_STOCK, "user_group": "root", "start_seq": 100}, } - return self._mock_response(self.status_code, resp) + return self._mock_response(self._get_status_code(), resp) async def _return_secure_passthrough_response(self, url: URL, json: dict[str, Any]): encrypted_request = json["params"]["request"] @@ -458,11 +666,11 @@ class MockSslAesDevice: "result": {"response": response.decode()}, "error_code": self.secure_passthrough_error_code, } - return self._mock_response(self.status_code, result) + return self._mock_response(self._get_status_code(), result) async def _return_send_response(self, url: URL, json: dict[str, Any]): result = {"result": {"method": None}, "error_code": self.send_error_code} - return self._mock_response(self.status_code, result) + return self._mock_response(self._get_status_code(), result) def put_next_response(self, request: dict | bytes) -> None: self._next_responses.append(request)