python-kasa/kasa/aestransport.py
J. Nick Koston bab40d43e6
Renew the handshake session 20 minutes before we think it will expire (#697)
* Renew the KLAP handshake session 20 minutes before we think it will expire

Currently we assumed the clocks were perfectly aligned and the handshake
session lasted 20 hours.  We now add a 20 minute buffer

* use timeout cookie when available
2024-01-24 10:11:27 +01:00

421 lines
15 KiB
Python

"""Implementation of the TP-Link AES transport.
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
import base64
import hashlib
import logging
import time
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Tuple, cast
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
AuthenticationException,
RetryableException,
SmartDeviceException,
SmartErrorCode,
TimeoutException,
)
from .httpclient import HttpClient
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials
_LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
def _sha1(payload: bytes) -> str:
sha1_algo = hashlib.sha1() # noqa: S324
sha1_algo.update(payload)
return sha1_algo.hexdigest()
class TransportState(Enum):
"""Enum for AES state."""
HANDSHAKE_REQUIRED = auto() # Handshake needed
LOGIN_REQUIRED = auto() # Login needed
ESTABLISHED = auto() # Ready to send requests
class AesTransport(BaseTransport):
"""Implementation of the AES encryption protocol.
AES is the name used in device discovery for TP-Link's TAPO encryption
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_PORT: int = 80
SESSION_COOKIE_NAME = "TP_SESSIONID"
TIMEOUT_COOKIE_NAME = "TIMEOUT"
COMMON_HEADERS = {
"Content-Type": "application/json",
"requestByApp": "true",
"Accept": "application/json",
}
CONTENT_LENGTH = "Content-Length"
KEY_PAIR_CONTENT_LENGTH = 314
def __init__(
self,
*,
config: DeviceConfig,
) -> None:
super().__init__(config=config)
self._login_version = config.connection_type.login_version
if (
not self._credentials or self._credentials.username is None
) and not self._credentials_hash:
self._credentials = Credentials()
if self._credentials:
self._login_params = self._get_login_params(self._credentials)
else:
self._login_params = json_loads(
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_credentials: Optional[Credentials] = None
self._http_client: HttpClient = HttpClient(config)
self._state = TransportState.HANDSHAKE_REQUIRED
self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None
self._session_cookie: Optional[Dict[str, str]] = None
self._login_token: Optional[str] = None
self._key_pair: Optional[KeyPair] = None
_LOGGER.debug("Created AES transport for %s", self._host)
@property
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
def _get_login_params(self, credentials: Credentials) -> Dict[str, str]:
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2, credentials)
password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un}
@staticmethod
def hash_credentials(login_v2: bool, credentials: Credentials) -> Tuple[str, str]:
"""Hash the credentials."""
un = base64.b64encode(_sha1(credentials.username.encode()).encode()).decode()
if login_v2:
pw = base64.b64encode(
_sha1(credentials.password.encode()).encode()
).decode()
else:
pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
return
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg, error_code=error_code)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg, error_code=error_code)
if error_code in SMART_AUTHENTICATION_ERRORS:
self._state = TransportState.HANDSHAKE_REQUIRED
raise AuthenticationException(msg, error_code=error_code)
raise SmartDeviceException(msg, error_code=error_code)
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
"""Send encrypted message as passthrough."""
url = f"http://{self._host}/app"
if self._login_token:
url += f"?token={self._login_token}"
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
passthrough_request = {
"method": "securePassthrough",
"params": {"request": encrypted_payload.decode()},
}
status_code, resp_dict = await self._http_client.post(
url,
json=passthrough_request,
headers=self.COMMON_HEADERS,
cookies_dict=self._session_cookie,
)
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
if status_code != 200:
raise SmartDeviceException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to passthrough"
)
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
assert self._encryption_session is not None
raw_response: str = resp_dict["result"]["response"]
response = self._encryption_session.decrypt(raw_response.encode())
return json_loads(response) # type: ignore[return-value]
async def perform_login(self):
"""Login to the device."""
try:
await self.try_login(self._login_params)
except AuthenticationException as aex:
try:
if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
raise aex
if self._default_credentials is None:
self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"]
)
await self.perform_handshake()
await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug(
"%s: logged in with default credentials",
self._host,
)
except AuthenticationException:
raise
except Exception as ex:
raise AuthenticationException(
"Unable to login and trying default "
+ "login raised another exception: %s",
ex,
) from ex
async def try_login(self, login_params: Dict[str, Any]) -> None:
"""Try to login with supplied login_params."""
login_request = {
"method": "login_device",
"params": login_params,
"request_time_milis": round(time.time() * 1000),
}
request = json_dumps(login_request)
resp_dict = await self.send_secure_passthrough(request)
self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"]
self._state = TransportState.ESTABLISHED
async def _generate_key_pair_payload(self) -> AsyncGenerator:
"""Generate the request body and return an ascyn_generator.
This prevents the key pair being generated unless a connection
can be made to the device.
"""
_LOGGER.debug("Generating keypair")
self._key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ self._key_pair.get_public_key() # type: ignore[union-attr]
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
yield json_dumps(request_body).encode()
async def perform_handshake(self) -> None:
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")
self._key_pair = None
self._session_expire_at = None
self._session_cookie = None
url = f"http://{self._host}/app"
# Device needs the content length or it will response with 500
headers = {
**self.COMMON_HEADERS,
self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH),
}
http_client = self._http_client
status_code, resp_dict = await http_client.post(
url,
json=self._generate_key_pair_payload(),
headers=headers,
cookies_dict=self._session_cookie,
)
_LOGGER.debug("Device responded with: %s", resp_dict)
if status_code != 200:
raise SmartDeviceException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to handshake"
)
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
handshake_key = resp_dict["result"]["key"]
if (
cookie := http_client.get_cookie( # type: ignore
self.SESSION_COOKIE_NAME
)
) or (
cookie := http_client.get_cookie("SESSIONID") # type: ignore
):
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
timeout = int(
http_client.get_cookie(self.TIMEOUT_COOKIE_NAME) or ONE_DAY_SECONDS
)
# There is a 24 hour timeout on the session cookie
# but the clock on the device is not always accurate
# so we set the expiry to 24 hours from now minus a buffer
self._session_expire_at = time.time() + timeout - SESSION_EXPIRE_BUFFER_SECONDS
if TYPE_CHECKING:
assert self._key_pair is not None
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, self._key_pair
)
self._state = TransportState.LOGIN_REQUIRED
_LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self):
"""Return true if session has expired."""
return (
self._session_expire_at is None
or self._session_expire_at - time.time() <= 0
)
async def send(self, request: str) -> Dict[str, Any]:
"""Send the request."""
if (
self._state is TransportState.HANDSHAKE_REQUIRED
or self._handshake_session_expired()
):
await self.perform_handshake()
if self._state is not TransportState.ESTABLISHED:
try:
await self.perform_login()
# After a login failure handshake needs to
# be redone or a 9999 error is received.
except AuthenticationException as ex:
self._state = TransportState.HANDSHAKE_REQUIRED
raise ex
return await self.send_secure_passthrough(request)
async def close(self) -> None:
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()
async def reset(self) -> None:
"""Reset internal handshake and login state."""
self._state = TransportState.HANDSHAKE_REQUIRED
class AesEncyptionSession:
"""Class for an AES encryption session."""
@staticmethod
def create_from_keypair(handshake_key: str, keypair):
"""Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
private_key = serialization.load_der_private_key(private_key_data, None, None)
key_and_iv = private_key.decrypt(
handshake_key_bytes, asymmetric_padding.PKCS1v15()
)
if key_and_iv is None:
raise ValueError("Decryption failed!")
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
def __init__(self, key, iv):
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
def encrypt(self, data) -> bytes:
"""Encrypt the message."""
encryptor = self.cipher.encryptor()
padder = self.padding_strategy.padder()
padded_data = padder.update(data) + padder.finalize()
encrypted = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(encrypted)
def decrypt(self, data) -> str:
"""Decrypt the message."""
decryptor = self.cipher.decryptor()
unpadder = self.padding_strategy.unpadder()
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
return unpadded_data.decode()
class KeyPair:
"""Class for generating key pairs."""
@staticmethod
def create_key_pair(key_size: int = 1024):
"""Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key()
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
public_key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return KeyPair(
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
)
def __init__(self, private_key: str, public_key: str):
self.private_key = private_key
self.public_key = public_key
def get_private_key(self) -> str:
"""Get the private key."""
return self.private_key
def get_public_key(self) -> str:
"""Get the public key."""
return self.public_key