diff --git a/kasa/aestransport.py b/kasa/aestransport.py index c03b6a11..dd61e720 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -30,7 +30,7 @@ from .exceptions import ( from .httpclient import HttpClient from .json import dumps as json_dumps from .json import loads as json_loads -from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials +from .protocol import BaseTransport _LOGGER = logging.getLogger(__name__) @@ -71,12 +71,12 @@ class AesTransport(BaseTransport): ) and not self._credentials_hash: self._credentials = Credentials() if self._credentials: - self._login_params = self._get_login_params(self._credentials) + self._login_params = self._get_login_params() else: self._login_params = json_loads( base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr] ) - self._default_credentials: Optional[Credentials] = None + self._http_client: HttpClient = HttpClient(config) self._handshake_done = False @@ -102,27 +102,26 @@ class AesTransport(BaseTransport): """The hashed credentials used by the transport.""" return base64.b64encode(json_dumps(self._login_params).encode()).decode() - def _get_login_params(self, credentials): + def _get_login_params(self): """Get the login parameters based on the login_version.""" - un, pw = self.hash_credentials(self._login_version == 2, credentials) + un, pw = self.hash_credentials(self._login_version == 2) password_field_name = "password2" if self._login_version == 2 else "password" return {password_field_name: pw, "username": un} - @staticmethod - def hash_credentials(login_v2, credentials): + def hash_credentials(self, login_v2): """Hash the credentials.""" if login_v2: un = base64.b64encode( - _sha1(credentials.username.encode()).encode() + _sha1(self._credentials.username.encode()).encode() ).decode() pw = base64.b64encode( - _sha1(credentials.password.encode()).encode() + _sha1(self._credentials.password.encode()).encode() ).decode() else: un = base64.b64encode( - _sha1(credentials.username.encode()).encode() + _sha1(self._credentials.username.encode()).encode() ).decode() - pw = base64.b64encode(credentials.password.encode()).decode() + pw = base64.b64encode(self._credentials.password.encode()).decode() return un, pw def _handle_response_error_code(self, resp_dict: dict, msg: str): @@ -178,28 +177,10 @@ class AesTransport(BaseTransport): async def perform_login(self): """Login to the device.""" - try: - await self.try_login(self._login_params) - except AuthenticationException as ex: - if ex.error_code != SmartErrorCode.LOGIN_ERROR: - raise ex - if self._default_credentials is None: - self._default_credentials = get_default_credentials( - DEFAULT_CREDENTIALS["TAPO"] - ) - await self.perform_handshake() - await self.try_login(self._get_login_params(self._default_credentials)) - _LOGGER.debug( - "%s: logged in with default credentials", - self._host, - ) - - async def try_login(self, login_params): - """Try to login with supplied login_params.""" self._login_token = None login_request = { "method": "login_device", - "params": login_params, + "params": self._login_params, "request_time_milis": round(time.time() * 1000), } request = json_dumps(login_request) @@ -295,13 +276,7 @@ class AesTransport(BaseTransport): if not self._handshake_done or self._handshake_session_expired(): await self.perform_handshake() if not self._login_token: - try: - await self.perform_login() - # After a login failure handshake needs to - # be redone or a 9999 error is received. - except AuthenticationException as ex: - self._handshake_done = False - raise ex + await self.perform_login() return await self.send_secure_passthrough(request) diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index c678e448..cdc0da5e 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -58,7 +58,7 @@ from .deviceconfig import DeviceConfig from .exceptions import AuthenticationException, SmartDeviceException from .httpclient import HttpClient from .json import loads as json_loads -from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials, md5 +from .protocol import BaseTransport, md5 _LOGGER = logging.getLogger(__name__) @@ -85,6 +85,9 @@ class KlapTransport(BaseTransport): DEFAULT_PORT: int = 80 DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} + + KASA_SETUP_EMAIL = "kasa@tp-link.net" + KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105 SESSION_COOKIE_NAME = "TP_SESSIONID" def __init__( @@ -105,7 +108,7 @@ class KlapTransport(BaseTransport): self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() else: self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr] - self._default_credentials_auth_hash: Dict[str, bytes] = {} + self._kasa_setup_auth_hash = None self._blank_auth_hash = None self._handshake_lock = asyncio.Lock() self._query_lock = asyncio.Lock() @@ -180,27 +183,27 @@ class KlapTransport(BaseTransport): _LOGGER.debug("handshake1 hashes match with expected credentials") return local_seed, remote_seed, self._local_auth_hash # type: ignore - # Now check against the default setup credentials - for key, value in DEFAULT_CREDENTIALS.items(): - if key not in self._default_credentials_auth_hash: - default_credentials = get_default_credentials(value) - self._default_credentials_auth_hash[key] = self.generate_auth_hash( - default_credentials - ) - - default_credentials_seed_auth_hash = self.handshake1_seed_auth_hash( - local_seed, - remote_seed, - self._default_credentials_auth_hash[key], # type: ignore + # Now check against the default kasa setup credentials + if not self._kasa_setup_auth_hash: + kasa_setup_creds = Credentials( + username=self.KASA_SETUP_EMAIL, + password=self.KASA_SETUP_PASSWORD, ) + self._kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds) - if default_credentials_seed_auth_hash == server_hash: - _LOGGER.debug( - "Server response doesn't match our expected hash on ip %s" - + f" but an authentication with {key} default credentials matched", - self._host, - ) - return local_seed, remote_seed, self._default_credentials_auth_hash[key] # type: ignore + kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash( + local_seed, + remote_seed, + self._kasa_setup_auth_hash, # type: ignore + ) + + if kasa_setup_seed_auth_hash == server_hash: + _LOGGER.debug( + "Server response doesn't match our expected hash on ip %s" + + " but an authentication with kasa setup credentials matched", + self._host, + ) + return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore # Finally check against blank credentials if not already blank blank_creds = Credentials() diff --git a/kasa/protocol.py b/kasa/protocol.py index ae8eb89b..ca2e3fb6 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -10,7 +10,6 @@ which are licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 """ import asyncio -import base64 import contextlib import errno import logging @@ -18,14 +17,13 @@ import socket import struct from abc import ABC, abstractmethod from pprint import pformat as pf -from typing import Dict, Generator, Optional, Tuple, Union +from typing import Dict, Generator, Optional, Union # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout from cryptography.hazmat.primitives import hashes -from .credentials import Credentials from .deviceconfig import DeviceConfig from .exceptions import SmartDeviceException from .json import dumps as json_dumps @@ -369,18 +367,6 @@ class TPLinkSmartHomeProtocol(BaseProtocol): ).decode() -def get_default_credentials(tuple: Tuple[str, str]) -> Credentials: - """Return decoded default credentials.""" - un = base64.b64decode(tuple[0].encode()).decode() - pw = base64.b64decode(tuple[1].encode()).decode() - return Credentials(un, pw) - - -DEFAULT_CREDENTIALS = { - "KASA": ("a2FzYUB0cC1saW5rLm5ldA==", "a2FzYVNldHVw"), - "TAPO": ("dGVzdEB0cC1saW5rLm5ldA==", "dGVzdA=="), -} - # Try to load the kasa_crypt module and if it is available try: from kasa_crypt import decrypt, encrypt diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 4694e363..0174c637 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -16,7 +16,6 @@ from ..deviceconfig import DeviceConfig from ..exceptions import ( SMART_RETRYABLE_ERRORS, SMART_TIMEOUT_ERRORS, - AuthenticationException, SmartDeviceException, SmartErrorCode, ) @@ -92,53 +91,6 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat assert transport._login_token == mock_aes_device.token -@pytest.mark.parametrize( - "inner_error_codes, expectation, call_count", - [ - ([SmartErrorCode.LOGIN_ERROR, 0, 0, 0], does_not_raise(), 4), - ( - [SmartErrorCode.LOGIN_ERROR, SmartErrorCode.LOGIN_ERROR], - pytest.raises(AuthenticationException), - 3, - ), - ( - [SmartErrorCode.LOGIN_FAILED_ERROR], - pytest.raises(AuthenticationException), - 1, - ), - ], - ids=("LOGIN_ERROR-success", "LOGIN_ERROR-LOGIN_ERROR", "LOGIN_FAILED_ERROR"), -) -async def test_login_errors(mocker, inner_error_codes, expectation, call_count): - host = "127.0.0.1" - mock_aes_device = MockAesDevice(host, 200, 0, inner_error_codes) - post_mock = mocker.patch.object( - aiohttp.ClientSession, "post", side_effect=mock_aes_device.post - ) - - transport = AesTransport( - config=DeviceConfig(host, credentials=Credentials("foo", "bar")) - ) - transport._handshake_done = True - transport._session_expire_at = time.time() + 86400 - transport._encryption_session = mock_aes_device.encryption_session - - assert transport._login_token is None - - request = { - "method": "get_device_info", - "params": None, - "request_time_milis": round(time.time() * 1000), - "requestID": 1, - "terminal_uuid": "foobar", - } - - with expectation: - await transport.send(json_dumps(request)) - assert transport._login_token == mock_aes_device.token - assert post_mock.call_count == call_count # Login, Handshake, Login - - @status_parameters async def test_send(mocker, status_code, error_code, inner_error_code, expectation): host = "127.0.0.1" @@ -214,16 +166,8 @@ class MockAesDevice: self.host = host self.status_code = status_code self.error_code = error_code - self._inner_error_code = inner_error_code + self.inner_error_code = inner_error_code self.http_client = HttpClient(DeviceConfig(self.host)) - self.inner_call_count = 0 - - @property - def inner_error_code(self): - if isinstance(self._inner_error_code, list): - return self._inner_error_code[self.inner_call_count] - else: - return self._inner_error_code async def post(self, url, params=None, json=None, data=None, *_, **__): if data: @@ -274,10 +218,8 @@ class MockAesDevice: async def _return_login_response(self, url, json): result = {"result": {"token": self.token}, "error_code": self.inner_error_code} - self.inner_call_count += 1 return self._mock_response(self.status_code, result) async def _return_send_response(self, url, json): result = {"result": {"method": None}, "error_code": self.inner_error_code} - self.inner_call_count += 1 return self._mock_response(self.status_code, result) diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 09ceccae..3aab46e8 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -28,7 +28,6 @@ from ..klaptransport import ( KlapTransportV2, _sha256, ) -from ..protocol import DEFAULT_CREDENTIALS, get_default_credentials from ..smartprotocol import SmartProtocol DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} @@ -243,7 +242,10 @@ def test_encrypt_unicode(): (Credentials("foo", "bar"), does_not_raise()), (Credentials(), does_not_raise()), ( - get_default_credentials(DEFAULT_CREDENTIALS["KASA"]), + Credentials( + KlapTransport.KASA_SETUP_EMAIL, + KlapTransport.KASA_SETUP_PASSWORD, + ), does_not_raise(), ), (