Move transports into their own package (#1247)

This moves all transport implementations into a new `transports` package
for cleaner main package & easier to understand project structure.
This commit is contained in:
Teemu R.
2024-11-12 14:40:44 +01:00
committed by GitHub
parent 71ae06fa83
commit 668ba748c5
27 changed files with 159 additions and 102 deletions

View File

@@ -0,0 +1,16 @@
"""Package containing all supported transports."""
from .aestransport import AesEncyptionSession, AesTransport
from .basetransport import BaseTransport
from .klaptransport import KlapTransport, KlapTransportV2
from .xortransport import XorEncryption, XorTransport
__all__ = [
"AesTransport",
"AesEncyptionSession",
"BaseTransport",
"KlapTransport",
"KlapTransportV2",
"XorTransport",
"XorEncryption",
]

View File

@@ -0,0 +1,499 @@
"""Implementation of the TP-Link AES transport.
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
from __future__ import annotations
import base64
import hashlib
import logging
import time
from collections.abc import AsyncGenerator
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Dict, cast
from cryptography.hazmat.primitives import hashes, padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from yarl import URL
from kasa.credentials import Credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
AuthenticationError,
DeviceError,
KasaException,
SmartErrorCode,
TimeoutError,
_ConnectionError,
_RetryableError,
)
from kasa.httpclient import HttpClient
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import DEFAULT_CREDENTIALS, get_default_credentials
from .basetransport import BaseTransport
_LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
def _sha1(payload: bytes) -> str:
sha1_algo = hashlib.sha1() # noqa: S324
sha1_algo.update(payload)
return sha1_algo.hexdigest()
class TransportState(Enum):
"""Enum for AES state."""
HANDSHAKE_REQUIRED = auto() # Handshake needed
LOGIN_REQUIRED = auto() # Login needed
ESTABLISHED = auto() # Ready to send requests
class AesTransport(BaseTransport):
"""Implementation of the AES encryption protocol.
AES is the name used in device discovery for TP-Link's TAPO encryption
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_PORT: int = 80
SESSION_COOKIE_NAME = "TP_SESSIONID"
TIMEOUT_COOKIE_NAME = "TIMEOUT"
COMMON_HEADERS = {
"Content-Type": "application/json",
"requestByApp": "true",
"Accept": "application/json",
}
CONTENT_LENGTH = "Content-Length"
KEY_PAIR_CONTENT_LENGTH = 314
def __init__(
self,
*,
config: DeviceConfig,
) -> None:
super().__init__(config=config)
self._login_version = config.connection_type.login_version
if (
not self._credentials or self._credentials.username is None
) and not self._credentials_hash:
self._credentials = Credentials()
if self._credentials:
self._login_params = self._get_login_params(self._credentials)
else:
self._login_params = json_loads(
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_credentials: Credentials | None = None
self._http_client: HttpClient = HttpClient(config)
self._state = TransportState.HANDSHAKE_REQUIRED
self._encryption_session: AesEncyptionSession | None = None
self._session_expire_at: float | None = None
self._session_cookie: dict[str, str] | None = None
self._key_pair: KeyPair | None = None
if config.aes_keys:
aes_keys = config.aes_keys
self._key_pair = KeyPair.create_from_der_keys(
aes_keys["private"], aes_keys["public"]
)
self._app_url = URL(f"http://{self._host}:{self._port}/app")
self._token_url: URL | None = None
_LOGGER.debug("Created AES transport for %s", self._host)
@property
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str | None:
"""The hashed credentials used by the transport."""
if self._credentials == Credentials():
return None
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
def _get_login_params(self, credentials: Credentials) -> dict[str, str]:
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2, credentials)
password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un}
@staticmethod
def hash_credentials(login_v2: bool, credentials: Credentials) -> tuple[str, str]:
"""Hash the credentials."""
un = base64.b64encode(_sha1(credentials.username.encode()).encode()).decode()
if login_v2:
pw = base64.b64encode(
_sha1(credentials.password.encode()).encode()
).decode()
else:
pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw
def _handle_response_error_code(self, resp_dict: dict, msg: str) -> None:
error_code_raw = resp_dict.get("error_code")
try:
error_code = SmartErrorCode.from_int(error_code_raw)
except ValueError:
_LOGGER.warning(
"Device %s received unknown error code: %s", self._host, error_code_raw
)
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
if error_code is SmartErrorCode.SUCCESS:
return
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_RETRYABLE_ERRORS:
raise _RetryableError(msg, error_code=error_code)
if error_code in SMART_AUTHENTICATION_ERRORS:
self._state = TransportState.HANDSHAKE_REQUIRED
raise AuthenticationError(msg, error_code=error_code)
raise DeviceError(msg, error_code=error_code)
async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
"""Send encrypted message as passthrough."""
if self._state is TransportState.ESTABLISHED and self._token_url:
url = self._token_url
else:
url = self._app_url
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
passthrough_request = {
"method": "securePassthrough",
"params": {"request": encrypted_payload.decode()},
}
status_code, resp_dict = await self._http_client.post(
url,
json=passthrough_request,
headers=self.COMMON_HEADERS,
cookies_dict=self._session_cookie,
)
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to passthrough"
)
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
assert self._encryption_session is not None
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)
raw_response: str = resp_dict["result"]["response"]
try:
response = self._encryption_session.decrypt(raw_response.encode())
ret_val = json_loads(response)
except Exception as ex:
try:
ret_val = json_loads(raw_response)
_LOGGER.debug(
"Received unencrypted response over secure passthrough from %s",
self._host,
)
except Exception:
raise KasaException(
f"Unable to decrypt response from {self._host}, "
+ f"error: {ex}, response: {raw_response}",
ex,
) from ex
return ret_val # type: ignore[return-value]
async def perform_login(self) -> None:
"""Login to the device."""
try:
await self.try_login(self._login_params)
_LOGGER.debug(
"%s: logged in with provided credentials",
self._host,
)
except AuthenticationError as aex:
try:
if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
raise aex
_LOGGER.debug(
"%s: trying login with default TAPO credentials",
self._host,
)
if self._default_credentials is None:
self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"]
)
await self.perform_handshake()
await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug(
"%s: logged in with default TAPO credentials",
self._host,
)
except (AuthenticationError, _ConnectionError, TimeoutError):
raise
except Exception as ex:
raise KasaException(
"Unable to login and trying default "
+ f"login raised another exception: {ex}",
ex,
) from ex
async def try_login(self, login_params: dict[str, Any]) -> None:
"""Try to login with supplied login_params."""
login_request = {
"method": "login_device",
"params": login_params,
"request_time_milis": round(time.time() * 1000),
}
request = json_dumps(login_request)
resp_dict = await self.send_secure_passthrough(request)
self._handle_response_error_code(resp_dict, "Error logging in")
login_token = resp_dict["result"]["token"]
self._token_url = self._app_url.with_query(f"token={login_token}")
self._state = TransportState.ESTABLISHED
async def _generate_key_pair_payload(self) -> AsyncGenerator:
"""Generate the request body and return an ascyn_generator.
This prevents the key pair being generated unless a connection
can be made to the device.
"""
_LOGGER.debug("Generating keypair")
if not self._key_pair:
kp = KeyPair.create_key_pair()
self._config.aes_keys = {
"private": kp.private_key_der_b64,
"public": kp.public_key_der_b64,
}
self._key_pair = kp
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ self._key_pair.public_key_der_b64 # type: ignore[union-attr]
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug("Handshake request: %s", request_body)
yield json_dumps(request_body).encode()
async def perform_handshake(self) -> None:
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")
self._token_url = None
self._session_expire_at = None
self._session_cookie = None
# Device needs the content length or it will response with 500
headers = {
**self.COMMON_HEADERS,
self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH),
}
http_client = self._http_client
status_code, resp_dict = await http_client.post(
self._app_url,
json=self._generate_key_pair_payload(),
headers=headers,
cookies_dict=self._session_cookie,
)
_LOGGER.debug("Device responded with: %s", resp_dict)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to handshake"
)
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
handshake_key = resp_dict["result"]["key"]
if (
cookie := http_client.get_cookie(self.SESSION_COOKIE_NAME) # type: ignore
) or (
cookie := http_client.get_cookie("SESSIONID") # type: ignore
):
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
timeout = int(
http_client.get_cookie(self.TIMEOUT_COOKIE_NAME) or ONE_DAY_SECONDS
)
# There is a 24 hour timeout on the session cookie
# but the clock on the device is not always accurate
# so we set the expiry to 24 hours from now minus a buffer
self._session_expire_at = time.time() + timeout - SESSION_EXPIRE_BUFFER_SECONDS
if TYPE_CHECKING:
assert self._key_pair is not None
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, self._key_pair
)
self._state = TransportState.LOGIN_REQUIRED
_LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self) -> bool:
"""Return true if session has expired."""
return (
self._session_expire_at is None
or self._session_expire_at - time.time() <= 0
)
async def send(self, request: str) -> dict[str, Any]:
"""Send the request."""
if (
self._state is TransportState.HANDSHAKE_REQUIRED
or self._handshake_session_expired()
):
await self.perform_handshake()
if self._state is not TransportState.ESTABLISHED:
try:
await self.perform_login()
# After a login failure handshake needs to
# be redone or a 9999 error is received.
except AuthenticationError as ex:
self._state = TransportState.HANDSHAKE_REQUIRED
raise ex
return await self.send_secure_passthrough(request)
async def close(self) -> None:
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()
async def reset(self) -> None:
"""Reset internal handshake and login state."""
self._state = TransportState.HANDSHAKE_REQUIRED
class AesEncyptionSession:
"""Class for an AES encryption session."""
@staticmethod
def create_from_keypair(
handshake_key: str, keypair: KeyPair
) -> AesEncyptionSession:
"""Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())
key_and_iv = keypair.decrypt_handshake_key(handshake_key_bytes)
if key_and_iv is None:
raise ValueError("Decryption failed!")
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
def __init__(self, key: bytes, iv: bytes) -> None:
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
def encrypt(self, data: bytes) -> bytes:
"""Encrypt the message."""
encryptor = self.cipher.encryptor()
padder = self.padding_strategy.padder()
padded_data = padder.update(data) + padder.finalize()
encrypted = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(encrypted)
def decrypt(self, data: str | bytes) -> str:
"""Decrypt the message."""
decryptor = self.cipher.decryptor()
unpadder = self.padding_strategy.unpadder()
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
return unpadded_data.decode()
class KeyPair:
"""Class for generating key pairs."""
@staticmethod
def create_key_pair(key_size: int = 1024) -> KeyPair:
"""Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key()
return KeyPair(private_key, public_key)
@staticmethod
def create_from_der_keys(
private_key_der_b64: str, public_key_der_b64: str
) -> KeyPair:
"""Create a key pair."""
key_bytes = base64.b64decode(private_key_der_b64.encode())
private_key = cast(
rsa.RSAPrivateKey, serialization.load_der_private_key(key_bytes, None)
)
key_bytes = base64.b64decode(public_key_der_b64.encode())
public_key = cast(
rsa.RSAPublicKey, serialization.load_der_public_key(key_bytes, None)
)
return KeyPair(private_key, public_key)
def __init__(
self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey
) -> None:
self.private_key = private_key
self.public_key = public_key
self.private_key_der_bytes = self.private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
self.public_key_der_bytes = self.public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
self.private_key_der_b64 = base64.b64encode(self.private_key_der_bytes).decode()
self.public_key_der_b64 = base64.b64encode(self.public_key_der_bytes).decode()
def get_public_pem(self) -> bytes:
"""Get public key in PEM encoding."""
return self.public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
def decrypt_handshake_key(self, encrypted_key: bytes) -> bytes:
"""Decrypt an aes handshake key."""
decrypted = self.private_key.decrypt(
encrypted_key, asymmetric_padding.PKCS1v15()
)
return decrypted
def decrypt_discovery_key(self, encrypted_key: bytes) -> bytes:
"""Decrypt an aes discovery key."""
decrypted = self.private_key.decrypt(
encrypted_key,
asymmetric_padding.OAEP(
mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
algorithm=hashes.SHA1(), # noqa: S303
label=None,
),
)
return decrypted

View File

@@ -0,0 +1,55 @@
"""Base class for all transport implementations.
All transport classes must derive from this to implement the common interface.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from kasa import DeviceConfig
class BaseTransport(ABC):
"""Base class for all TP-Link protocol transports."""
DEFAULT_TIMEOUT = 5
def __init__(
self,
*,
config: DeviceConfig,
) -> None:
"""Create a protocol object."""
self._config = config
self._host = config.host
self._port = config.port_override or self.default_port
self._credentials = config.credentials
self._credentials_hash = config.credentials_hash
if not config.timeout:
config.timeout = self.DEFAULT_TIMEOUT
self._timeout = config.timeout
@property
@abstractmethod
def default_port(self) -> int:
"""The default port for the transport."""
@property
@abstractmethod
def credentials_hash(self) -> str | None:
"""The hashed credentials used by the transport."""
@abstractmethod
async def send(self, request: str) -> dict:
"""Send a message to the device and return a response."""
@abstractmethod
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
@abstractmethod
async def reset(self) -> None:
"""Reset internal state."""

View File

@@ -0,0 +1,512 @@
"""Implementation of the TP-Link Klap Home Protocol.
Encryption/Decryption methods based on the works of
Simon Wilkinson and Chris Weeldon
Klap devices that have never been connected to the kasa
cloud should work with blank credentials.
Devices that have been connected to the kasa cloud will
switch intermittently between the users cloud credentials
and default kasa credentials that are hardcoded.
This appears to be an issue with the devices.
The protocol works by doing a two stage handshake to obtain
and encryption key and session id cookie.
Authentication uses an auth_hash which is
md5(md5(username),md5(password))
handshake1: client sends a random 16 byte local_seed to the
device and receives a random 16 bytes remote_seed, followed
by sha256(local_seed + auth_hash). It also returns a
TP_SESSIONID in the cookie header. This implementation
then checks this value against the possible auth_hashes
described above (user cloud, kasa hardcoded, blank). If it
finds a match it moves onto handshake2
handshake2: client sends sha25(remote_seed + auth_hash) to
the device along with the TP_SESSIONID. Device responds with
200 if successful. It generally will be because this
implementation checks the auth_hash it received during handshake1
encryption: local_seed, remote_seed and auth_hash are now used
for encryption. The last 4 bytes of the initialization vector
are used as a sequence number that increments every time the
client calls encrypt and this sequence number is sent as a
url parameter to the device along with the encrypted payload
https://gist.github.com/chriswheeldon/3b17d974db3817613c69191c0480fe55
https://github.com/python-kasa/python-kasa/pull/117
"""
from __future__ import annotations
import asyncio
import base64
import datetime
import hashlib
import logging
import secrets
import struct
import time
from asyncio import Future
from typing import TYPE_CHECKING, Any, Generator, cast
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from yarl import URL
from kasa.credentials import Credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import AuthenticationError, KasaException, _RetryableError
from kasa.httpclient import HttpClient
from kasa.json import loads as json_loads
from kasa.protocol import (
DEFAULT_CREDENTIALS,
get_default_credentials,
md5,
)
from .basetransport import BaseTransport
_LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
PACK_SIGNED_LONG = struct.Struct(">l").pack
def _sha256(payload: bytes) -> bytes:
return hashlib.sha256(payload).digest() # noqa: S324
def _sha1(payload: bytes) -> bytes:
return hashlib.sha1(payload).digest() # noqa: S324
class KlapTransport(BaseTransport):
"""Implementation of the KLAP encryption protocol.
KLAP is the name used in device discovery for TP-Link's new encryption
protocol, used by newer firmware versions.
"""
DEFAULT_PORT: int = 80
SESSION_COOKIE_NAME = "TP_SESSIONID"
TIMEOUT_COOKIE_NAME = "TIMEOUT"
def __init__(
self,
*,
config: DeviceConfig,
) -> None:
super().__init__(config=config)
self._http_client = HttpClient(config)
self._local_seed: bytes | None = None
if (
not self._credentials or self._credentials.username is None
) and not self._credentials_hash:
self._credentials = Credentials()
if self._credentials:
self._local_auth_hash = self.generate_auth_hash(self._credentials)
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
else:
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
self._default_credentials_auth_hash: dict[str, bytes] = {}
self._blank_auth_hash: bytes | None = None
self._handshake_lock = asyncio.Lock()
self._query_lock = asyncio.Lock()
self._handshake_done: bool = False
self._encryption_session: KlapEncryptionSession | None = None
self._session_expire_at: float | None = None
self._session_cookie: dict[str, Any] | None = None
_LOGGER.debug("Created KLAP transport for %s", self._host)
self._app_url = URL(f"http://{self._host}:{self._port}/app")
self._request_url = self._app_url / "request"
@property
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str | None:
"""The hashed credentials used by the transport."""
if self._credentials == Credentials():
return None
return base64.b64encode(self._local_auth_hash).decode()
async def perform_handshake1(self) -> tuple[bytes, bytes, bytes]:
"""Perform handshake1."""
local_seed: bytes = secrets.token_bytes(16)
# Handshake 1 has a payload of local_seed
# and a response of 16 bytes, followed by
# sha256(remote_seed | auth_hash)
payload = local_seed
url = self._app_url / "handshake1"
response_status, response_data = await self._http_client.post(url, data=payload)
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Handshake1 posted at %s. Host is %s, "
"Response status is %s, Request was %s",
datetime.datetime.now(),
self._host,
response_status,
payload.hex(),
)
if response_status != 200:
raise KasaException(
f"Device {self._host} responded with {response_status} to handshake1"
)
response_data = cast(bytes, response_data)
remote_seed: bytes = response_data[0:16]
server_hash = response_data[16:]
if len(server_hash) != 32:
raise KasaException(
f"Device {self._host} responded with unexpected klap response "
+ f"{response_data!r} to handshake1"
)
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Handshake1 success at %s. Host is %s, "
"Server remote_seed is: %s, server hash is: %s",
datetime.datetime.now(),
self._host,
remote_seed.hex(),
server_hash.hex(),
)
local_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed, remote_seed, self._local_auth_hash
) # type: ignore
# Check the response from the device with local credentials
if local_seed_auth_hash == server_hash:
_LOGGER.debug("handshake1 hashes match with expected credentials")
return local_seed, remote_seed, self._local_auth_hash # type: ignore
# Now check against the default setup credentials
for key, value in DEFAULT_CREDENTIALS.items():
if key not in self._default_credentials_auth_hash:
default_credentials = get_default_credentials(value)
self._default_credentials_auth_hash[key] = self.generate_auth_hash(
default_credentials
)
default_credentials_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed,
remote_seed,
self._default_credentials_auth_hash[key], # type: ignore
)
if default_credentials_seed_auth_hash == server_hash:
_LOGGER.debug(
"Server response doesn't match our expected hash on ip %s, "
"but an authentication with %s default credentials matched",
self._host,
key,
)
return local_seed, remote_seed, self._default_credentials_auth_hash[key] # type: ignore
# Finally check against blank credentials if not already blank
blank_creds = Credentials()
if self._credentials != blank_creds:
if not self._blank_auth_hash:
self._blank_auth_hash = self.generate_auth_hash(blank_creds)
blank_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed,
remote_seed,
self._blank_auth_hash, # type: ignore
)
if blank_seed_auth_hash == server_hash:
_LOGGER.debug(
"Server response doesn't match our expected hash on ip %s, "
"but an authentication with blank credentials matched",
self._host,
)
return local_seed, remote_seed, self._blank_auth_hash # type: ignore
msg = f"Server response doesn't match our challenge on ip {self._host}"
_LOGGER.debug(msg)
raise AuthenticationError(msg)
async def perform_handshake2(
self, local_seed: bytes, remote_seed: bytes, auth_hash: bytes
) -> KlapEncryptionSession:
"""Perform handshake2."""
# Handshake 2 has the following payload:
# sha256(serverBytes | authenticator)
url = self._app_url / "handshake2"
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
response_status, _ = await self._http_client.post(
url,
data=payload,
cookies_dict=self._session_cookie,
)
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Handshake2 posted %s. Host is %s, "
"Response status is %s, Request was %s",
datetime.datetime.now(),
self._host,
response_status,
payload.hex(),
)
if response_status != 200:
# This shouldn't be caused by incorrect
# credentials so don't raise AuthenticationError
raise KasaException(
f"Device {self._host} responded with {response_status} to handshake2"
)
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
async def perform_handshake(self) -> None:
"""Perform handshake1 and handshake2.
Sets the encryption_session if successful.
"""
_LOGGER.debug("Starting handshake with %s", self._host)
self._handshake_done = False
self._session_expire_at = None
self._session_cookie = None
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
http_client = self._http_client
if cookie := http_client.get_cookie(self.SESSION_COOKIE_NAME): # type: ignore
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
# The device returns a TIMEOUT cookie on handshake1 which
# it doesn't like to get back so we store the one we want
timeout = int(
http_client.get_cookie(self.TIMEOUT_COOKIE_NAME) or ONE_DAY_SECONDS
)
# There is a 24 hour timeout on the session cookie
# but the clock on the device is not always accurate
# so we set the expiry to 24 hours from now minus a buffer
self._session_expire_at = (
time.monotonic() + timeout - SESSION_EXPIRE_BUFFER_SECONDS
)
self._encryption_session = await self.perform_handshake2(
local_seed, remote_seed, auth_hash
)
self._handshake_done = True
_LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self) -> bool:
"""Return true if session has expired."""
return (
self._session_expire_at is None
or self._session_expire_at - time.monotonic() <= 0
)
async def send(self, request: str) -> Generator[Future, None, dict[str, str]]: # type: ignore[override]
"""Send the request."""
if not self._handshake_done or self._handshake_session_expired():
await self.perform_handshake()
# Check for mypy
if self._encryption_session is not None:
payload, seq = self._encryption_session.encrypt(request.encode())
response_status, response_data = await self._http_client.post(
self._request_url,
params={"seq": seq},
data=payload,
cookies_dict=self._session_cookie,
)
msg = (
f"Host is {self._host}, "
+ f"Sequence is {seq}, "
+ f"Response status is {response_status}, Request was {request}"
)
if response_status != 200:
_LOGGER.error("Query failed after successful authentication: %s", msg)
# If we failed with a security error, force a new handshake next time.
if response_status == 403:
self._handshake_done = False
raise _RetryableError(
"Got a security error from %s after handshake completed", self._host
)
else:
raise KasaException(
f"Device {self._host} responded with {response_status} to "
f"request with seq {seq}"
)
else:
_LOGGER.debug("Device %s query posted %s", self._host, msg)
if TYPE_CHECKING:
assert self._encryption_session
assert isinstance(response_data, bytes)
try:
decrypted_response = self._encryption_session.decrypt(response_data)
except Exception as ex:
raise KasaException(
f"Error trying to decrypt device {self._host} response: {ex}"
) from ex
json_payload = json_loads(decrypted_response)
_LOGGER.debug("Device %s query response received", self._host)
return json_payload
async def close(self) -> None:
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()
async def reset(self) -> None:
"""Reset internal handshake state."""
self._handshake_done = False
@staticmethod
def generate_auth_hash(creds: Credentials) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username
pw = creds.password
return md5(md5(un.encode()) + md5(pw.encode()))
@staticmethod
def handshake1_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(local_seed + auth_hash)
@staticmethod
def handshake2_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(remote_seed + auth_hash)
@staticmethod
def generate_owner_hash(creds: Credentials) -> bytes:
"""Return the MD5 hash of the username in this object."""
un = creds.username
return md5(un.encode())
class KlapTransportV2(KlapTransport):
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
@staticmethod
def generate_auth_hash(creds: Credentials) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username
pw = creds.password
return _sha256(_sha1(un.encode()) + _sha1(pw.encode()))
@staticmethod
def handshake1_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(local_seed + remote_seed + auth_hash)
@staticmethod
def handshake2_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
) -> bytes:
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(remote_seed + local_seed + auth_hash)
class KlapEncryptionSession:
"""Class to represent an encryption session and it's internal state.
i.e. sequence number which the device expects to increment.
"""
_cipher: Cipher
def __init__(self, local_seed: bytes, remote_seed: bytes, user_hash: bytes) -> None:
self.local_seed = local_seed
self.remote_seed = remote_seed
self.user_hash = user_hash
self._key = self._key_derive(local_seed, remote_seed, user_hash)
(self._iv, self._seq) = self._iv_derive(local_seed, remote_seed, user_hash)
self._aes = algorithms.AES(self._key)
self._sig = self._sig_derive(local_seed, remote_seed, user_hash)
def _key_derive(
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
) -> bytes:
payload = b"lsk" + local_seed + remote_seed + user_hash
return hashlib.sha256(payload).digest()[:16]
def _iv_derive(
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
) -> tuple[bytes, int]:
# iv is first 16 bytes of sha256, where the last 4 bytes forms the
# sequence number used in requests and is incremented on each request
payload = b"iv" + local_seed + remote_seed + user_hash
fulliv = hashlib.sha256(payload).digest()
seq = int.from_bytes(fulliv[-4:], "big", signed=True)
return (fulliv[:12], seq)
def _sig_derive(
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
) -> bytes:
# used to create a hash with which to prefix each request
payload = b"ldk" + local_seed + remote_seed + user_hash
return hashlib.sha256(payload).digest()[:28]
def _generate_cipher(self) -> None:
iv_seq = self._iv + PACK_SIGNED_LONG(self._seq)
cbc = modes.CBC(iv_seq)
self._cipher = Cipher(self._aes, cbc)
def encrypt(self, msg: bytes | str) -> tuple[bytes, int]:
"""Encrypt the data and increment the sequence number."""
self._seq += 1
self._generate_cipher()
if isinstance(msg, str):
msg = msg.encode("utf-8")
encryptor = self._cipher.encryptor()
padder = padding.PKCS7(128).padder()
padded_data = padder.update(msg) + padder.finalize()
ciphertext = encryptor.update(padded_data) + encryptor.finalize()
signature = hashlib.sha256(
self._sig + PACK_SIGNED_LONG(self._seq) + ciphertext
).digest()
return (signature + ciphertext, self._seq)
def decrypt(self, msg: bytes) -> str:
"""Decrypt the data."""
decryptor = self._cipher.decryptor()
dp = decryptor.update(msg[32:]) + decryptor.finalize()
unpadder = padding.PKCS7(128).unpadder()
plaintextbytes = unpadder.update(dp) + unpadder.finalize()
return plaintextbytes.decode()

View File

@@ -0,0 +1,234 @@
"""Implementation of the legacy TP-Link Smart Home Protocol.
Encryption/Decryption methods based on the works of
Lubomir Stroetmann and Tobias Esser
https://www.softscheck.com/en/reverse-engineering-tp-link-hs110/
https://github.com/softScheck/tplink-smartplug/
which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""
from __future__ import annotations
import asyncio
import contextlib
import errno
import logging
import socket
import struct
from collections.abc import Generator
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import KasaException, _RetryableError
from kasa.json import loads as json_loads
from .basetransport import BaseTransport
_LOGGER = logging.getLogger(__name__)
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
_UNSIGNED_INT_NETWORK_ORDER = struct.Struct(">I")
class XorTransport(BaseTransport):
"""XorTransport class."""
DEFAULT_PORT: int = 9999
BLOCK_SIZE = 4
def __init__(self, *, config: DeviceConfig) -> None:
super().__init__(config=config)
self.reader: asyncio.StreamReader | None = None
self.writer: asyncio.StreamWriter | None = None
self.query_lock = asyncio.Lock()
self.loop: asyncio.AbstractEventLoop | None = None
@property
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str | None:
"""The hashed credentials used by the transport."""
return None
async def _connect(self, timeout: int) -> None:
"""Try to connect or reconnect to the device."""
if self.writer:
return
self.reader = self.writer = None
task = asyncio.open_connection(self._host, self._port)
async with asyncio_timeout(timeout):
self.reader, self.writer = await task
sock: socket.socket = self.writer.get_extra_info("socket")
# Ensure our packets get sent without delay as we do all
# our writes in a single go and we do not want any buffering
# which would needlessly delay the request or risk overloading
# the buffer on the device
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
async def _execute_send(self, request: str) -> dict:
"""Execute a query on the device and wait for the response."""
assert self.writer is not None # noqa: S101
assert self.reader is not None # noqa: S101
_LOGGER.debug("Device %s sending query %s", self._host, request)
self.writer.write(XorEncryption.encrypt(request))
await self.writer.drain()
packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
length = _UNSIGNED_INT_NETWORK_ORDER.unpack(packed_block_size)[0]
buffer = await self.reader.readexactly(length)
response = XorEncryption.decrypt(buffer)
json_payload = json_loads(response)
_LOGGER.debug("Device %s query response received", self._host)
return json_payload
async def close(self) -> None:
"""Close the connection."""
writer = self.writer
self.close_without_wait()
if writer:
with contextlib.suppress(Exception):
await writer.wait_closed()
def close_without_wait(self) -> None:
"""Close the connection without waiting for the connection to close."""
writer = self.writer
self.reader = self.writer = None
if writer:
writer.close()
async def reset(self) -> None:
"""Reset the transport.
The transport cannot be reset so we must close instead.
"""
await self.close()
async def send(self, request: str) -> dict:
"""Send a message to the device and return a response."""
#
# Most of the time we will already be connected if the device is online
# and the connect call will do nothing and return right away
#
# However, if we get an unrecoverable error (_NO_RETRY_ERRORS and
# ConnectionRefusedError) we do not want to keep trying since many
# connection open/close operations in the same time frame can block
# the event loop.
# This is especially import when there are multiple tplink devices being polled.
try:
await self._connect(self._timeout)
except ConnectionRefusedError as ex:
await self.reset()
raise KasaException(
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
) from ex
except OSError as ex:
await self.reset()
if ex.errno in _NO_RETRY_ERRORS:
raise KasaException(
f"Unable to connect to the device:"
f" {self._host}:{self._port}: {ex}"
) from ex
else:
raise _RetryableError(
f"Unable to connect to the device:"
f" {self._host}:{self._port}: {ex}"
) from ex
except Exception as ex:
await self.reset()
raise _RetryableError(
f"Unable to connect to the device:" f" {self._host}:{self._port}: {ex}"
) from ex
except BaseException:
# Likely something cancelled the task so we need to close the connection
# as we are not in an indeterminate state
self.close_without_wait()
raise
try:
assert self.reader is not None # noqa: S101
assert self.writer is not None # noqa: S101
async with asyncio_timeout(self._timeout):
return await self._execute_send(request)
except Exception as ex:
await self.reset()
raise _RetryableError(
f"Unable to query the device {self._host}:{self._port}: {ex}"
) from ex
except BaseException:
# Likely something cancelled the task so we need to close the connection
# as we are not in an indeterminate state
self.close_without_wait()
raise
def __del__(self) -> None:
if self.writer and self.loop and self.loop.is_running():
# Since __del__ will be called when python does
# garbage collection is can happen in the event loop thread
# or in another thread so we need to make sure the call to
# close is called safely with call_soon_threadsafe
self.loop.call_soon_threadsafe(self.writer.close)
class XorEncryption:
"""XorEncryption class."""
INITIALIZATION_VECTOR = 171
@staticmethod
def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]:
key = XorEncryption.INITIALIZATION_VECTOR
for unencryptedbyte in unencrypted:
key = key ^ unencryptedbyte
yield key
@staticmethod
def encrypt(request: str) -> bytes:
"""Encrypt a request for a TP-Link Smart Home Device.
:param request: plaintext request data
:return: ciphertext to be send over wire, in bytes
"""
plainbytes = request.encode()
return _UNSIGNED_INT_NETWORK_ORDER.pack(len(plainbytes)) + bytes(
XorEncryption._xor_payload(plainbytes)
)
@staticmethod
def _xor_encrypted_payload(ciphertext: bytes) -> Generator[int, None, None]:
key = XorEncryption.INITIALIZATION_VECTOR
for cipherbyte in ciphertext:
plainbyte = key ^ cipherbyte
key = cipherbyte
yield plainbyte
@staticmethod
def decrypt(ciphertext: bytes) -> str:
"""Decrypt a response of a TP-Link Smart Home Device.
:param ciphertext: encrypted response data
:return: plaintext response
"""
return bytes(XorEncryption._xor_encrypted_payload(ciphertext)).decode()
# Try to load the kasa_crypt module and if it is available
try:
from kasa_crypt import decrypt, encrypt
XorEncryption.decrypt = decrypt # type: ignore[assignment]
XorEncryption.encrypt = encrypt # type: ignore[assignment]
except ImportError:
pass