Try default tapo credentials for klap and aes (#685)

* Try default tapo credentials for klap and aes

* Add tests
This commit is contained in:
Steven B 2024-01-23 14:44:32 +00:00 committed by GitHub
parent c8ac3a29c7
commit 718983c401
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 134 additions and 42 deletions

View File

@ -30,7 +30,7 @@ from .exceptions import (
from .httpclient import HttpClient
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import BaseTransport
from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials
_LOGGER = logging.getLogger(__name__)
@ -69,12 +69,12 @@ class AesTransport(BaseTransport):
) and not self._credentials_hash:
self._credentials = Credentials()
if self._credentials:
self._login_params = self._get_login_params()
self._login_params = self._get_login_params(self._credentials)
else:
self._login_params = json_loads(
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_credentials: Optional[Credentials] = None
self._http_client: HttpClient = HttpClient(config)
self._handshake_done = False
@ -98,26 +98,27 @@ class AesTransport(BaseTransport):
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
def _get_login_params(self):
def _get_login_params(self, credentials):
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2)
un, pw = self.hash_credentials(self._login_version == 2, credentials)
password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un}
def hash_credentials(self, login_v2):
@staticmethod
def hash_credentials(login_v2, credentials):
"""Hash the credentials."""
if login_v2:
un = base64.b64encode(
_sha1(self._credentials.username.encode()).encode()
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(
_sha1(self._credentials.password.encode()).encode()
_sha1(credentials.password.encode()).encode()
).decode()
else:
un = base64.b64encode(
_sha1(self._credentials.username.encode()).encode()
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(self._credentials.password.encode()).decode()
pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw
def _handle_response_error_code(self, resp_dict: dict, msg: str):
@ -173,10 +174,28 @@ class AesTransport(BaseTransport):
async def perform_login(self):
"""Login to the device."""
try:
await self.try_login(self._login_params)
except AuthenticationException as ex:
if ex.error_code != SmartErrorCode.LOGIN_ERROR:
raise ex
if self._default_credentials is None:
self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"]
)
await self.perform_handshake()
await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug(
"%s: logged in with default credentials",
self._host,
)
async def try_login(self, login_params):
"""Try to login with supplied login_params."""
self._login_token = None
login_request = {
"method": "login_device",
"params": self._login_params,
"params": login_params,
"request_time_milis": round(time.time() * 1000),
}
request = json_dumps(login_request)
@ -260,7 +279,13 @@ class AesTransport(BaseTransport):
if not self._handshake_done or self._handshake_session_expired():
await self.perform_handshake()
if not self._login_token:
try:
await self.perform_login()
# After a login failure handshake needs to
# be redone or a 9999 error is received.
except AuthenticationException as ex:
self._handshake_done = False
raise ex
return await self.send_secure_passthrough(request)

View File

@ -58,7 +58,7 @@ from .deviceconfig import DeviceConfig
from .exceptions import AuthenticationException, SmartDeviceException
from .httpclient import HttpClient
from .json import loads as json_loads
from .protocol import BaseTransport, md5
from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials, md5
_LOGGER = logging.getLogger(__name__)
@ -85,9 +85,6 @@ class KlapTransport(BaseTransport):
DEFAULT_PORT: int = 80
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
KASA_SETUP_EMAIL = "kasa@tp-link.net"
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
SESSION_COOKIE_NAME = "TP_SESSIONID"
def __init__(
@ -108,7 +105,7 @@ class KlapTransport(BaseTransport):
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
else:
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
self._kasa_setup_auth_hash = None
self._default_credentials_auth_hash: Dict[str, bytes] = {}
self._blank_auth_hash = None
self._handshake_lock = asyncio.Lock()
self._query_lock = asyncio.Lock()
@ -183,27 +180,27 @@ class KlapTransport(BaseTransport):
_LOGGER.debug("handshake1 hashes match with expected credentials")
return local_seed, remote_seed, self._local_auth_hash # type: ignore
# Now check against the default kasa setup credentials
if not self._kasa_setup_auth_hash:
kasa_setup_creds = Credentials(
username=self.KASA_SETUP_EMAIL,
password=self.KASA_SETUP_PASSWORD,
# Now check against the default setup credentials
for key, value in DEFAULT_CREDENTIALS.items():
if key not in self._default_credentials_auth_hash:
default_credentials = get_default_credentials(value)
self._default_credentials_auth_hash[key] = self.generate_auth_hash(
default_credentials
)
self._kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds)
kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash(
default_credentials_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed,
remote_seed,
self._kasa_setup_auth_hash, # type: ignore
self._default_credentials_auth_hash[key], # type: ignore
)
if kasa_setup_seed_auth_hash == server_hash:
if default_credentials_seed_auth_hash == server_hash:
_LOGGER.debug(
"Server response doesn't match our expected hash on ip %s"
+ " but an authentication with kasa setup credentials matched",
+ f" but an authentication with {key} default credentials matched",
self._host,
)
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
return local_seed, remote_seed, self._default_credentials_auth_hash[key] # type: ignore
# Finally check against blank credentials if not already blank
blank_creds = Credentials()

View File

@ -10,6 +10,7 @@ which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""
import asyncio
import base64
import contextlib
import errno
import logging
@ -17,13 +18,14 @@ import socket
import struct
from abc import ABC, abstractmethod
from pprint import pformat as pf
from typing import Dict, Generator, Optional, Union
from typing import Dict, Generator, Optional, Tuple, Union
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from cryptography.hazmat.primitives import hashes
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import SmartDeviceException
from .json import dumps as json_dumps
@ -361,6 +363,18 @@ class TPLinkSmartHomeProtocol(BaseProtocol):
).decode()
def get_default_credentials(tuple: Tuple[str, str]) -> Credentials:
"""Return decoded default credentials."""
un = base64.b64decode(tuple[0].encode()).decode()
pw = base64.b64decode(tuple[1].encode()).decode()
return Credentials(un, pw)
DEFAULT_CREDENTIALS = {
"KASA": ("a2FzYUB0cC1saW5rLm5ldA==", "a2FzYVNldHVw"),
"TAPO": ("dGVzdEB0cC1saW5rLm5ldA==", "dGVzdA=="),
}
# Try to load the kasa_crypt module and if it is available
try:
from kasa_crypt import decrypt, encrypt

View File

@ -16,6 +16,7 @@ from ..deviceconfig import DeviceConfig
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
AuthenticationException,
SmartDeviceException,
SmartErrorCode,
)
@ -91,6 +92,53 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
assert transport._login_token == mock_aes_device.token
@pytest.mark.parametrize(
"inner_error_codes, expectation, call_count",
[
([SmartErrorCode.LOGIN_ERROR, 0, 0, 0], does_not_raise(), 4),
(
[SmartErrorCode.LOGIN_ERROR, SmartErrorCode.LOGIN_ERROR],
pytest.raises(AuthenticationException),
3,
),
(
[SmartErrorCode.LOGIN_FAILED_ERROR],
pytest.raises(AuthenticationException),
1,
),
],
ids=("LOGIN_ERROR-success", "LOGIN_ERROR-LOGIN_ERROR", "LOGIN_FAILED_ERROR"),
)
async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, 0, inner_error_codes)
post_mock = mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_aes_device.post
)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
assert transport._login_token is None
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with expectation:
await transport.send(json_dumps(request))
assert transport._login_token == mock_aes_device.token
assert post_mock.call_count == call_count # Login, Handshake, Login
@status_parameters
async def test_send(mocker, status_code, error_code, inner_error_code, expectation):
host = "127.0.0.1"
@ -166,8 +214,16 @@ class MockAesDevice:
self.host = host
self.status_code = status_code
self.error_code = error_code
self.inner_error_code = inner_error_code
self._inner_error_code = inner_error_code
self.http_client = HttpClient(DeviceConfig(self.host))
self.inner_call_count = 0
@property
def inner_error_code(self):
if isinstance(self._inner_error_code, list):
return self._inner_error_code[self.inner_call_count]
else:
return self._inner_error_code
async def post(self, url, params=None, json=None, *_, **__):
return await self._post(url, json)
@ -215,8 +271,10 @@ class MockAesDevice:
async def _return_login_response(self, url, json):
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
self.inner_call_count += 1
return self._mock_response(self.status_code, result)
async def _return_send_response(self, url, json):
result = {"result": {"method": None}, "error_code": self.inner_error_code}
self.inner_call_count += 1
return self._mock_response(self.status_code, result)

View File

@ -28,6 +28,7 @@ from ..klaptransport import (
KlapTransportV2,
_sha256,
)
from ..protocol import DEFAULT_CREDENTIALS, get_default_credentials
from ..smartprotocol import SmartProtocol
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
@ -241,10 +242,7 @@ def test_encrypt_unicode():
(Credentials("foo", "bar"), does_not_raise()),
(Credentials(), does_not_raise()),
(
Credentials(
KlapTransport.KASA_SETUP_EMAIL,
KlapTransport.KASA_SETUP_PASSWORD,
),
get_default_credentials(DEFAULT_CREDENTIALS["KASA"]),
does_not_raise(),
),
(