Try default tapo credentials for klap and aes (#685)

* Try default tapo credentials for klap and aes

* Add tests
This commit is contained in:
Steven B 2024-01-23 14:44:32 +00:00 committed by GitHub
parent c8ac3a29c7
commit 718983c401
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 134 additions and 42 deletions

View File

@ -30,7 +30,7 @@ from .exceptions import (
from .httpclient import HttpClient from .httpclient import HttpClient
from .json import dumps as json_dumps from .json import dumps as json_dumps
from .json import loads as json_loads from .json import loads as json_loads
from .protocol import BaseTransport from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -69,12 +69,12 @@ class AesTransport(BaseTransport):
) and not self._credentials_hash: ) and not self._credentials_hash:
self._credentials = Credentials() self._credentials = Credentials()
if self._credentials: if self._credentials:
self._login_params = self._get_login_params() self._login_params = self._get_login_params(self._credentials)
else: else:
self._login_params = json_loads( self._login_params = json_loads(
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr] base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
) )
self._default_credentials: Optional[Credentials] = None
self._http_client: HttpClient = HttpClient(config) self._http_client: HttpClient = HttpClient(config)
self._handshake_done = False self._handshake_done = False
@ -98,26 +98,27 @@ class AesTransport(BaseTransport):
"""The hashed credentials used by the transport.""" """The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode() return base64.b64encode(json_dumps(self._login_params).encode()).decode()
def _get_login_params(self): def _get_login_params(self, credentials):
"""Get the login parameters based on the login_version.""" """Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2) un, pw = self.hash_credentials(self._login_version == 2, credentials)
password_field_name = "password2" if self._login_version == 2 else "password" password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un} return {password_field_name: pw, "username": un}
def hash_credentials(self, login_v2): @staticmethod
def hash_credentials(login_v2, credentials):
"""Hash the credentials.""" """Hash the credentials."""
if login_v2: if login_v2:
un = base64.b64encode( un = base64.b64encode(
_sha1(self._credentials.username.encode()).encode() _sha1(credentials.username.encode()).encode()
).decode() ).decode()
pw = base64.b64encode( pw = base64.b64encode(
_sha1(self._credentials.password.encode()).encode() _sha1(credentials.password.encode()).encode()
).decode() ).decode()
else: else:
un = base64.b64encode( un = base64.b64encode(
_sha1(self._credentials.username.encode()).encode() _sha1(credentials.username.encode()).encode()
).decode() ).decode()
pw = base64.b64encode(self._credentials.password.encode()).decode() pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw return un, pw
def _handle_response_error_code(self, resp_dict: dict, msg: str): def _handle_response_error_code(self, resp_dict: dict, msg: str):
@ -173,10 +174,28 @@ class AesTransport(BaseTransport):
async def perform_login(self): async def perform_login(self):
"""Login to the device.""" """Login to the device."""
try:
await self.try_login(self._login_params)
except AuthenticationException as ex:
if ex.error_code != SmartErrorCode.LOGIN_ERROR:
raise ex
if self._default_credentials is None:
self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"]
)
await self.perform_handshake()
await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug(
"%s: logged in with default credentials",
self._host,
)
async def try_login(self, login_params):
"""Try to login with supplied login_params."""
self._login_token = None self._login_token = None
login_request = { login_request = {
"method": "login_device", "method": "login_device",
"params": self._login_params, "params": login_params,
"request_time_milis": round(time.time() * 1000), "request_time_milis": round(time.time() * 1000),
} }
request = json_dumps(login_request) request = json_dumps(login_request)
@ -260,7 +279,13 @@ class AesTransport(BaseTransport):
if not self._handshake_done or self._handshake_session_expired(): if not self._handshake_done or self._handshake_session_expired():
await self.perform_handshake() await self.perform_handshake()
if not self._login_token: if not self._login_token:
await self.perform_login() try:
await self.perform_login()
# After a login failure handshake needs to
# be redone or a 9999 error is received.
except AuthenticationException as ex:
self._handshake_done = False
raise ex
return await self.send_secure_passthrough(request) return await self.send_secure_passthrough(request)

View File

@ -58,7 +58,7 @@ from .deviceconfig import DeviceConfig
from .exceptions import AuthenticationException, SmartDeviceException from .exceptions import AuthenticationException, SmartDeviceException
from .httpclient import HttpClient from .httpclient import HttpClient
from .json import loads as json_loads from .json import loads as json_loads
from .protocol import BaseTransport, md5 from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials, md5
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -85,9 +85,6 @@ class KlapTransport(BaseTransport):
DEFAULT_PORT: int = 80 DEFAULT_PORT: int = 80
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
KASA_SETUP_EMAIL = "kasa@tp-link.net"
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
SESSION_COOKIE_NAME = "TP_SESSIONID" SESSION_COOKIE_NAME = "TP_SESSIONID"
def __init__( def __init__(
@ -108,7 +105,7 @@ class KlapTransport(BaseTransport):
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
else: else:
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr] self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
self._kasa_setup_auth_hash = None self._default_credentials_auth_hash: Dict[str, bytes] = {}
self._blank_auth_hash = None self._blank_auth_hash = None
self._handshake_lock = asyncio.Lock() self._handshake_lock = asyncio.Lock()
self._query_lock = asyncio.Lock() self._query_lock = asyncio.Lock()
@ -183,27 +180,27 @@ class KlapTransport(BaseTransport):
_LOGGER.debug("handshake1 hashes match with expected credentials") _LOGGER.debug("handshake1 hashes match with expected credentials")
return local_seed, remote_seed, self._local_auth_hash # type: ignore return local_seed, remote_seed, self._local_auth_hash # type: ignore
# Now check against the default kasa setup credentials # Now check against the default setup credentials
if not self._kasa_setup_auth_hash: for key, value in DEFAULT_CREDENTIALS.items():
kasa_setup_creds = Credentials( if key not in self._default_credentials_auth_hash:
username=self.KASA_SETUP_EMAIL, default_credentials = get_default_credentials(value)
password=self.KASA_SETUP_PASSWORD, self._default_credentials_auth_hash[key] = self.generate_auth_hash(
) default_credentials
self._kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds) )
kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash( default_credentials_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed, local_seed,
remote_seed, remote_seed,
self._kasa_setup_auth_hash, # type: ignore self._default_credentials_auth_hash[key], # type: ignore
)
if kasa_setup_seed_auth_hash == server_hash:
_LOGGER.debug(
"Server response doesn't match our expected hash on ip %s"
+ " but an authentication with kasa setup credentials matched",
self._host,
) )
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
if default_credentials_seed_auth_hash == server_hash:
_LOGGER.debug(
"Server response doesn't match our expected hash on ip %s"
+ f" but an authentication with {key} default credentials matched",
self._host,
)
return local_seed, remote_seed, self._default_credentials_auth_hash[key] # type: ignore
# Finally check against blank credentials if not already blank # Finally check against blank credentials if not already blank
blank_creds = Credentials() blank_creds = Credentials()

View File

@ -10,6 +10,7 @@ which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
""" """
import asyncio import asyncio
import base64
import contextlib import contextlib
import errno import errno
import logging import logging
@ -17,13 +18,14 @@ import socket
import struct import struct
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pprint import pformat as pf from pprint import pformat as pf
from typing import Dict, Generator, Optional, Union from typing import Dict, Generator, Optional, Tuple, Union
# When support for cpython older than 3.11 is dropped # When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout # async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout from async_timeout import timeout as asyncio_timeout
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from .credentials import Credentials
from .deviceconfig import DeviceConfig from .deviceconfig import DeviceConfig
from .exceptions import SmartDeviceException from .exceptions import SmartDeviceException
from .json import dumps as json_dumps from .json import dumps as json_dumps
@ -361,6 +363,18 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
).decode() ).decode()
def get_default_credentials(tuple: Tuple[str, str]) -> Credentials:
"""Return decoded default credentials."""
un = base64.b64decode(tuple[0].encode()).decode()
pw = base64.b64decode(tuple[1].encode()).decode()
return Credentials(un, pw)
DEFAULT_CREDENTIALS = {
"KASA": ("a2FzYUB0cC1saW5rLm5ldA==", "a2FzYVNldHVw"),
"TAPO": ("dGVzdEB0cC1saW5rLm5ldA==", "dGVzdA=="),
}
# Try to load the kasa_crypt module and if it is available # Try to load the kasa_crypt module and if it is available
try: try:
from kasa_crypt import decrypt, encrypt from kasa_crypt import decrypt, encrypt

View File

@ -16,6 +16,7 @@ from ..deviceconfig import DeviceConfig
from ..exceptions import ( from ..exceptions import (
SMART_RETRYABLE_ERRORS, SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS, SMART_TIMEOUT_ERRORS,
AuthenticationException,
SmartDeviceException, SmartDeviceException,
SmartErrorCode, SmartErrorCode,
) )
@ -91,6 +92,53 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
assert transport._login_token == mock_aes_device.token assert transport._login_token == mock_aes_device.token
@pytest.mark.parametrize(
"inner_error_codes, expectation, call_count",
[
([SmartErrorCode.LOGIN_ERROR, 0, 0, 0], does_not_raise(), 4),
(
[SmartErrorCode.LOGIN_ERROR, SmartErrorCode.LOGIN_ERROR],
pytest.raises(AuthenticationException),
3,
),
(
[SmartErrorCode.LOGIN_FAILED_ERROR],
pytest.raises(AuthenticationException),
1,
),
],
ids=("LOGIN_ERROR-success", "LOGIN_ERROR-LOGIN_ERROR", "LOGIN_FAILED_ERROR"),
)
async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, 0, inner_error_codes)
post_mock = mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_aes_device.post
)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
assert transport._login_token is None
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with expectation:
await transport.send(json_dumps(request))
assert transport._login_token == mock_aes_device.token
assert post_mock.call_count == call_count # Login, Handshake, Login
@status_parameters @status_parameters
async def test_send(mocker, status_code, error_code, inner_error_code, expectation): async def test_send(mocker, status_code, error_code, inner_error_code, expectation):
host = "127.0.0.1" host = "127.0.0.1"
@ -166,8 +214,16 @@ class MockAesDevice:
self.host = host self.host = host
self.status_code = status_code self.status_code = status_code
self.error_code = error_code self.error_code = error_code
self.inner_error_code = inner_error_code self._inner_error_code = inner_error_code
self.http_client = HttpClient(DeviceConfig(self.host)) self.http_client = HttpClient(DeviceConfig(self.host))
self.inner_call_count = 0
@property
def inner_error_code(self):
if isinstance(self._inner_error_code, list):
return self._inner_error_code[self.inner_call_count]
else:
return self._inner_error_code
async def post(self, url, params=None, json=None, *_, **__): async def post(self, url, params=None, json=None, *_, **__):
return await self._post(url, json) return await self._post(url, json)
@ -215,8 +271,10 @@ class MockAesDevice:
async def _return_login_response(self, url, json): async def _return_login_response(self, url, json):
result = {"result": {"token": self.token}, "error_code": self.inner_error_code} result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
self.inner_call_count += 1
return self._mock_response(self.status_code, result) return self._mock_response(self.status_code, result)
async def _return_send_response(self, url, json): async def _return_send_response(self, url, json):
result = {"result": {"method": None}, "error_code": self.inner_error_code} result = {"result": {"method": None}, "error_code": self.inner_error_code}
self.inner_call_count += 1
return self._mock_response(self.status_code, result) return self._mock_response(self.status_code, result)

View File

@ -28,6 +28,7 @@ from ..klaptransport import (
KlapTransportV2, KlapTransportV2,
_sha256, _sha256,
) )
from ..protocol import DEFAULT_CREDENTIALS, get_default_credentials
from ..smartprotocol import SmartProtocol from ..smartprotocol import SmartProtocol
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
@ -241,10 +242,7 @@ def test_encrypt_unicode():
(Credentials("foo", "bar"), does_not_raise()), (Credentials("foo", "bar"), does_not_raise()),
(Credentials(), does_not_raise()), (Credentials(), does_not_raise()),
( (
Credentials( get_default_credentials(DEFAULT_CREDENTIALS["KASA"]),
KlapTransport.KASA_SETUP_EMAIL,
KlapTransport.KASA_SETUP_PASSWORD,
),
does_not_raise(), does_not_raise(),
), ),
( (