diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 65b0045d..cd810b8f 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -30,7 +30,7 @@ from .exceptions import ( from .httpclient import HttpClient from .json import dumps as json_dumps from .json import loads as json_loads -from .protocol import BaseTransport +from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials _LOGGER = logging.getLogger(__name__) @@ -69,12 +69,12 @@ class AesTransport(BaseTransport): ) and not self._credentials_hash: self._credentials = Credentials() if self._credentials: - self._login_params = self._get_login_params() + self._login_params = self._get_login_params(self._credentials) else: self._login_params = json_loads( base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr] ) - + self._default_credentials: Optional[Credentials] = None self._http_client: HttpClient = HttpClient(config) self._handshake_done = False @@ -98,26 +98,27 @@ class AesTransport(BaseTransport): """The hashed credentials used by the transport.""" return base64.b64encode(json_dumps(self._login_params).encode()).decode() - def _get_login_params(self): + def _get_login_params(self, credentials): """Get the login parameters based on the login_version.""" - un, pw = self.hash_credentials(self._login_version == 2) + un, pw = self.hash_credentials(self._login_version == 2, credentials) password_field_name = "password2" if self._login_version == 2 else "password" return {password_field_name: pw, "username": un} - def hash_credentials(self, login_v2): + @staticmethod + def hash_credentials(login_v2, credentials): """Hash the credentials.""" if login_v2: un = base64.b64encode( - _sha1(self._credentials.username.encode()).encode() + _sha1(credentials.username.encode()).encode() ).decode() pw = base64.b64encode( - _sha1(self._credentials.password.encode()).encode() + _sha1(credentials.password.encode()).encode() ).decode() else: un = base64.b64encode( - _sha1(self._credentials.username.encode()).encode() + _sha1(credentials.username.encode()).encode() ).decode() - pw = base64.b64encode(self._credentials.password.encode()).decode() + pw = base64.b64encode(credentials.password.encode()).decode() return un, pw def _handle_response_error_code(self, resp_dict: dict, msg: str): @@ -173,10 +174,28 @@ class AesTransport(BaseTransport): async def perform_login(self): """Login to the device.""" + try: + await self.try_login(self._login_params) + except AuthenticationException as ex: + if ex.error_code != SmartErrorCode.LOGIN_ERROR: + raise ex + if self._default_credentials is None: + self._default_credentials = get_default_credentials( + DEFAULT_CREDENTIALS["TAPO"] + ) + await self.perform_handshake() + await self.try_login(self._get_login_params(self._default_credentials)) + _LOGGER.debug( + "%s: logged in with default credentials", + self._host, + ) + + async def try_login(self, login_params): + """Try to login with supplied login_params.""" self._login_token = None login_request = { "method": "login_device", - "params": self._login_params, + "params": login_params, "request_time_milis": round(time.time() * 1000), } request = json_dumps(login_request) @@ -260,7 +279,13 @@ class AesTransport(BaseTransport): if not self._handshake_done or self._handshake_session_expired(): await self.perform_handshake() if not self._login_token: - await self.perform_login() + try: + await self.perform_login() + # After a login failure handshake needs to + # be redone or a 9999 error is received. + except AuthenticationException as ex: + self._handshake_done = False + raise ex return await self.send_secure_passthrough(request) diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 92d6fd2b..5411314a 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -58,7 +58,7 @@ from .deviceconfig import DeviceConfig from .exceptions import AuthenticationException, SmartDeviceException from .httpclient import HttpClient from .json import loads as json_loads -from .protocol import BaseTransport, md5 +from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials, md5 _LOGGER = logging.getLogger(__name__) @@ -85,9 +85,6 @@ class KlapTransport(BaseTransport): DEFAULT_PORT: int = 80 DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} - - KASA_SETUP_EMAIL = "kasa@tp-link.net" - KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105 SESSION_COOKIE_NAME = "TP_SESSIONID" def __init__( @@ -108,7 +105,7 @@ class KlapTransport(BaseTransport): self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() else: self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr] - self._kasa_setup_auth_hash = None + self._default_credentials_auth_hash: Dict[str, bytes] = {} self._blank_auth_hash = None self._handshake_lock = asyncio.Lock() self._query_lock = asyncio.Lock() @@ -183,27 +180,27 @@ class KlapTransport(BaseTransport): _LOGGER.debug("handshake1 hashes match with expected credentials") return local_seed, remote_seed, self._local_auth_hash # type: ignore - # Now check against the default kasa setup credentials - if not self._kasa_setup_auth_hash: - kasa_setup_creds = Credentials( - username=self.KASA_SETUP_EMAIL, - password=self.KASA_SETUP_PASSWORD, - ) - self._kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds) + # Now check against the default setup credentials + for key, value in DEFAULT_CREDENTIALS.items(): + if key not in self._default_credentials_auth_hash: + default_credentials = get_default_credentials(value) + self._default_credentials_auth_hash[key] = self.generate_auth_hash( + default_credentials + ) - kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash( - local_seed, - remote_seed, - self._kasa_setup_auth_hash, # type: ignore - ) - - if kasa_setup_seed_auth_hash == server_hash: - _LOGGER.debug( - "Server response doesn't match our expected hash on ip %s" - + " but an authentication with kasa setup credentials matched", - self._host, + default_credentials_seed_auth_hash = self.handshake1_seed_auth_hash( + local_seed, + remote_seed, + self._default_credentials_auth_hash[key], # type: ignore ) - return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore + + if default_credentials_seed_auth_hash == server_hash: + _LOGGER.debug( + "Server response doesn't match our expected hash on ip %s" + + f" but an authentication with {key} default credentials matched", + self._host, + ) + return local_seed, remote_seed, self._default_credentials_auth_hash[key] # type: ignore # Finally check against blank credentials if not already blank blank_creds = Credentials() diff --git a/kasa/protocol.py b/kasa/protocol.py index bbdd81fd..59fea4a8 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -10,6 +10,7 @@ which are licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 """ import asyncio +import base64 import contextlib import errno import logging @@ -17,13 +18,14 @@ import socket import struct from abc import ABC, abstractmethod from pprint import pformat as pf -from typing import Dict, Generator, Optional, Union +from typing import Dict, Generator, Optional, Tuple, Union # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout from cryptography.hazmat.primitives import hashes +from .credentials import Credentials from .deviceconfig import DeviceConfig from .exceptions import SmartDeviceException from .json import dumps as json_dumps @@ -361,6 +363,18 @@ class TPLinkSmartHomeProtocol(BaseProtocol): ).decode() +def get_default_credentials(tuple: Tuple[str, str]) -> Credentials: + """Return decoded default credentials.""" + un = base64.b64decode(tuple[0].encode()).decode() + pw = base64.b64decode(tuple[1].encode()).decode() + return Credentials(un, pw) + + +DEFAULT_CREDENTIALS = { + "KASA": ("a2FzYUB0cC1saW5rLm5ldA==", "a2FzYVNldHVw"), + "TAPO": ("dGVzdEB0cC1saW5rLm5ldA==", "dGVzdA=="), +} + # Try to load the kasa_crypt module and if it is available try: from kasa_crypt import decrypt, encrypt diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 774aaf94..748dae9a 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -16,6 +16,7 @@ from ..deviceconfig import DeviceConfig from ..exceptions import ( SMART_RETRYABLE_ERRORS, SMART_TIMEOUT_ERRORS, + AuthenticationException, SmartDeviceException, SmartErrorCode, ) @@ -91,6 +92,53 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat assert transport._login_token == mock_aes_device.token +@pytest.mark.parametrize( + "inner_error_codes, expectation, call_count", + [ + ([SmartErrorCode.LOGIN_ERROR, 0, 0, 0], does_not_raise(), 4), + ( + [SmartErrorCode.LOGIN_ERROR, SmartErrorCode.LOGIN_ERROR], + pytest.raises(AuthenticationException), + 3, + ), + ( + [SmartErrorCode.LOGIN_FAILED_ERROR], + pytest.raises(AuthenticationException), + 1, + ), + ], + ids=("LOGIN_ERROR-success", "LOGIN_ERROR-LOGIN_ERROR", "LOGIN_FAILED_ERROR"), +) +async def test_login_errors(mocker, inner_error_codes, expectation, call_count): + host = "127.0.0.1" + mock_aes_device = MockAesDevice(host, 200, 0, inner_error_codes) + post_mock = mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_aes_device.post + ) + + transport = AesTransport( + config=DeviceConfig(host, credentials=Credentials("foo", "bar")) + ) + transport._handshake_done = True + transport._session_expire_at = time.time() + 86400 + transport._encryption_session = mock_aes_device.encryption_session + + assert transport._login_token is None + + request = { + "method": "get_device_info", + "params": None, + "request_time_milis": round(time.time() * 1000), + "requestID": 1, + "terminal_uuid": "foobar", + } + + with expectation: + await transport.send(json_dumps(request)) + assert transport._login_token == mock_aes_device.token + assert post_mock.call_count == call_count # Login, Handshake, Login + + @status_parameters async def test_send(mocker, status_code, error_code, inner_error_code, expectation): host = "127.0.0.1" @@ -166,8 +214,16 @@ class MockAesDevice: self.host = host self.status_code = status_code self.error_code = error_code - self.inner_error_code = inner_error_code + self._inner_error_code = inner_error_code self.http_client = HttpClient(DeviceConfig(self.host)) + self.inner_call_count = 0 + + @property + def inner_error_code(self): + if isinstance(self._inner_error_code, list): + return self._inner_error_code[self.inner_call_count] + else: + return self._inner_error_code async def post(self, url, params=None, json=None, *_, **__): return await self._post(url, json) @@ -215,8 +271,10 @@ class MockAesDevice: async def _return_login_response(self, url, json): result = {"result": {"token": self.token}, "error_code": self.inner_error_code} + self.inner_call_count += 1 return self._mock_response(self.status_code, result) async def _return_send_response(self, url, json): result = {"result": {"method": None}, "error_code": self.inner_error_code} + self.inner_call_count += 1 return self._mock_response(self.status_code, result) diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 8ae32e3f..54f4a4be 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -28,6 +28,7 @@ from ..klaptransport import ( KlapTransportV2, _sha256, ) +from ..protocol import DEFAULT_CREDENTIALS, get_default_credentials from ..smartprotocol import SmartProtocol DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} @@ -241,10 +242,7 @@ def test_encrypt_unicode(): (Credentials("foo", "bar"), does_not_raise()), (Credentials(), does_not_raise()), ( - Credentials( - KlapTransport.KASA_SETUP_EMAIL, - KlapTransport.KASA_SETUP_PASSWORD, - ), + get_default_credentials(DEFAULT_CREDENTIALS["KASA"]), does_not_raise(), ), (