From f045696ebe7dbed55ab928559446506b8fe5aad9 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Tue, 23 Jan 2024 21:51:07 +0000 Subject: [PATCH] Fix P100 error getting conn closed when trying default login after login failure (#690) --- kasa/aestransport.py | 33 +++++++++++++++++++++------------ kasa/tests/test_aestransport.py | 12 +++++++++++- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 14a9ee6a..018176ad 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -180,19 +180,28 @@ class AesTransport(BaseTransport): """Login to the device.""" try: await self.try_login(self._login_params) - except AuthenticationException as ex: - if ex.error_code != SmartErrorCode.LOGIN_ERROR: - raise ex - if self._default_credentials is None: - self._default_credentials = get_default_credentials( - DEFAULT_CREDENTIALS["TAPO"] + except AuthenticationException as aex: + try: + if aex.error_code != SmartErrorCode.LOGIN_ERROR: + raise aex + if self._default_credentials is None: + self._default_credentials = get_default_credentials( + DEFAULT_CREDENTIALS["TAPO"] + ) + await self.perform_handshake() + await self.try_login(self._get_login_params(self._default_credentials)) + _LOGGER.debug( + "%s: logged in with default credentials", + self._host, ) - await self.perform_handshake() - await self.try_login(self._get_login_params(self._default_credentials)) - _LOGGER.debug( - "%s: logged in with default credentials", - self._host, - ) + except AuthenticationException: + raise + except Exception as ex: + raise AuthenticationException( + "Unable to login and trying default " + + "login raised another exception: %s", + ex, + ) from ex async def try_login(self, login_params): """Try to login with supplied login_params.""" diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 4694e363..c58aad4e 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -106,8 +106,18 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat pytest.raises(AuthenticationException), 1, ), + ( + [SmartErrorCode.LOGIN_ERROR, SmartErrorCode.SESSION_TIMEOUT_ERROR], + pytest.raises(SmartDeviceException), + 3, + ), ], - ids=("LOGIN_ERROR-success", "LOGIN_ERROR-LOGIN_ERROR", "LOGIN_FAILED_ERROR"), + ids=( + "LOGIN_ERROR-success", + "LOGIN_ERROR-LOGIN_ERROR", + "LOGIN_FAILED_ERROR", + "LOGIN_ERROR-SESSION_TIMEOUT_ERROR", + ), ) async def test_login_errors(mocker, inner_error_codes, expectation, call_count): host = "127.0.0.1"