mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-08-06 10:44:04 +00:00
Move transports into their own package (#1247)
This moves all transport implementations into a new `transports` package for cleaner main package & easier to understand project structure.
This commit is contained in:
16
kasa/transports/__init__.py
Normal file
16
kasa/transports/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Package containing all supported transports."""
|
||||
|
||||
from .aestransport import AesEncyptionSession, AesTransport
|
||||
from .basetransport import BaseTransport
|
||||
from .klaptransport import KlapTransport, KlapTransportV2
|
||||
from .xortransport import XorEncryption, XorTransport
|
||||
|
||||
__all__ = [
|
||||
"AesTransport",
|
||||
"AesEncyptionSession",
|
||||
"BaseTransport",
|
||||
"KlapTransport",
|
||||
"KlapTransportV2",
|
||||
"XorTransport",
|
||||
"XorEncryption",
|
||||
]
|
499
kasa/transports/aestransport.py
Normal file
499
kasa/transports/aestransport.py
Normal file
@@ -0,0 +1,499 @@
|
||||
"""Implementation of the TP-Link AES transport.
|
||||
|
||||
Based on the work of https://github.com/petretiandrea/plugp100
|
||||
under compatible GNU GPL3 license.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Any, Dict, cast
|
||||
|
||||
from cryptography.hazmat.primitives import hashes, padding, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from yarl import URL
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.deviceconfig import DeviceConfig
|
||||
from kasa.exceptions import (
|
||||
SMART_AUTHENTICATION_ERRORS,
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
AuthenticationError,
|
||||
DeviceError,
|
||||
KasaException,
|
||||
SmartErrorCode,
|
||||
TimeoutError,
|
||||
_ConnectionError,
|
||||
_RetryableError,
|
||||
)
|
||||
from kasa.httpclient import HttpClient
|
||||
from kasa.json import dumps as json_dumps
|
||||
from kasa.json import loads as json_loads
|
||||
from kasa.protocol import DEFAULT_CREDENTIALS, get_default_credentials
|
||||
|
||||
from .basetransport import BaseTransport
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ONE_DAY_SECONDS = 86400
|
||||
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
|
||||
|
||||
|
||||
def _sha1(payload: bytes) -> str:
|
||||
sha1_algo = hashlib.sha1() # noqa: S324
|
||||
sha1_algo.update(payload)
|
||||
return sha1_algo.hexdigest()
|
||||
|
||||
|
||||
class TransportState(Enum):
|
||||
"""Enum for AES state."""
|
||||
|
||||
HANDSHAKE_REQUIRED = auto() # Handshake needed
|
||||
LOGIN_REQUIRED = auto() # Login needed
|
||||
ESTABLISHED = auto() # Ready to send requests
|
||||
|
||||
|
||||
class AesTransport(BaseTransport):
|
||||
"""Implementation of the AES encryption protocol.
|
||||
|
||||
AES is the name used in device discovery for TP-Link's TAPO encryption
|
||||
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT: int = 80
|
||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||
TIMEOUT_COOKIE_NAME = "TIMEOUT"
|
||||
COMMON_HEADERS = {
|
||||
"Content-Type": "application/json",
|
||||
"requestByApp": "true",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
CONTENT_LENGTH = "Content-Length"
|
||||
KEY_PAIR_CONTENT_LENGTH = 314
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: DeviceConfig,
|
||||
) -> None:
|
||||
super().__init__(config=config)
|
||||
|
||||
self._login_version = config.connection_type.login_version
|
||||
if (
|
||||
not self._credentials or self._credentials.username is None
|
||||
) and not self._credentials_hash:
|
||||
self._credentials = Credentials()
|
||||
if self._credentials:
|
||||
self._login_params = self._get_login_params(self._credentials)
|
||||
else:
|
||||
self._login_params = json_loads(
|
||||
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
|
||||
)
|
||||
self._default_credentials: Credentials | None = None
|
||||
self._http_client: HttpClient = HttpClient(config)
|
||||
|
||||
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||
|
||||
self._encryption_session: AesEncyptionSession | None = None
|
||||
self._session_expire_at: float | None = None
|
||||
|
||||
self._session_cookie: dict[str, str] | None = None
|
||||
|
||||
self._key_pair: KeyPair | None = None
|
||||
if config.aes_keys:
|
||||
aes_keys = config.aes_keys
|
||||
self._key_pair = KeyPair.create_from_der_keys(
|
||||
aes_keys["private"], aes_keys["public"]
|
||||
)
|
||||
self._app_url = URL(f"http://{self._host}:{self._port}/app")
|
||||
self._token_url: URL | None = None
|
||||
|
||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||
|
||||
@property
|
||||
def default_port(self) -> int:
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
@property
|
||||
def credentials_hash(self) -> str | None:
|
||||
"""The hashed credentials used by the transport."""
|
||||
if self._credentials == Credentials():
|
||||
return None
|
||||
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
|
||||
|
||||
def _get_login_params(self, credentials: Credentials) -> dict[str, str]:
|
||||
"""Get the login parameters based on the login_version."""
|
||||
un, pw = self.hash_credentials(self._login_version == 2, credentials)
|
||||
password_field_name = "password2" if self._login_version == 2 else "password"
|
||||
return {password_field_name: pw, "username": un}
|
||||
|
||||
@staticmethod
|
||||
def hash_credentials(login_v2: bool, credentials: Credentials) -> tuple[str, str]:
|
||||
"""Hash the credentials."""
|
||||
un = base64.b64encode(_sha1(credentials.username.encode()).encode()).decode()
|
||||
if login_v2:
|
||||
pw = base64.b64encode(
|
||||
_sha1(credentials.password.encode()).encode()
|
||||
).decode()
|
||||
else:
|
||||
pw = base64.b64encode(credentials.password.encode()).decode()
|
||||
return un, pw
|
||||
|
||||
def _handle_response_error_code(self, resp_dict: dict, msg: str) -> None:
|
||||
error_code_raw = resp_dict.get("error_code")
|
||||
try:
|
||||
error_code = SmartErrorCode.from_int(error_code_raw)
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"Device %s received unknown error code: %s", self._host, error_code_raw
|
||||
)
|
||||
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
|
||||
if error_code is SmartErrorCode.SUCCESS:
|
||||
return
|
||||
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
|
||||
if error_code in SMART_RETRYABLE_ERRORS:
|
||||
raise _RetryableError(msg, error_code=error_code)
|
||||
if error_code in SMART_AUTHENTICATION_ERRORS:
|
||||
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||
raise AuthenticationError(msg, error_code=error_code)
|
||||
raise DeviceError(msg, error_code=error_code)
|
||||
|
||||
async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
|
||||
"""Send encrypted message as passthrough."""
|
||||
if self._state is TransportState.ESTABLISHED and self._token_url:
|
||||
url = self._token_url
|
||||
else:
|
||||
url = self._app_url
|
||||
|
||||
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
||||
passthrough_request = {
|
||||
"method": "securePassthrough",
|
||||
"params": {"request": encrypted_payload.decode()},
|
||||
}
|
||||
status_code, resp_dict = await self._http_client.post(
|
||||
url,
|
||||
json=passthrough_request,
|
||||
headers=self.COMMON_HEADERS,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
|
||||
|
||||
if status_code != 200:
|
||||
raise KasaException(
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code} to passthrough"
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
resp_dict = cast(Dict[str, Any], resp_dict)
|
||||
assert self._encryption_session is not None
|
||||
|
||||
self._handle_response_error_code(
|
||||
resp_dict, "Error sending secure_passthrough message"
|
||||
)
|
||||
|
||||
raw_response: str = resp_dict["result"]["response"]
|
||||
|
||||
try:
|
||||
response = self._encryption_session.decrypt(raw_response.encode())
|
||||
ret_val = json_loads(response)
|
||||
except Exception as ex:
|
||||
try:
|
||||
ret_val = json_loads(raw_response)
|
||||
_LOGGER.debug(
|
||||
"Received unencrypted response over secure passthrough from %s",
|
||||
self._host,
|
||||
)
|
||||
except Exception:
|
||||
raise KasaException(
|
||||
f"Unable to decrypt response from {self._host}, "
|
||||
+ f"error: {ex}, response: {raw_response}",
|
||||
ex,
|
||||
) from ex
|
||||
return ret_val # type: ignore[return-value]
|
||||
|
||||
async def perform_login(self) -> None:
|
||||
"""Login to the device."""
|
||||
try:
|
||||
await self.try_login(self._login_params)
|
||||
_LOGGER.debug(
|
||||
"%s: logged in with provided credentials",
|
||||
self._host,
|
||||
)
|
||||
except AuthenticationError as aex:
|
||||
try:
|
||||
if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
|
||||
raise aex
|
||||
_LOGGER.debug(
|
||||
"%s: trying login with default TAPO credentials",
|
||||
self._host,
|
||||
)
|
||||
if self._default_credentials is None:
|
||||
self._default_credentials = get_default_credentials(
|
||||
DEFAULT_CREDENTIALS["TAPO"]
|
||||
)
|
||||
await self.perform_handshake()
|
||||
await self.try_login(self._get_login_params(self._default_credentials))
|
||||
_LOGGER.debug(
|
||||
"%s: logged in with default TAPO credentials",
|
||||
self._host,
|
||||
)
|
||||
except (AuthenticationError, _ConnectionError, TimeoutError):
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise KasaException(
|
||||
"Unable to login and trying default "
|
||||
+ f"login raised another exception: {ex}",
|
||||
ex,
|
||||
) from ex
|
||||
|
||||
async def try_login(self, login_params: dict[str, Any]) -> None:
|
||||
"""Try to login with supplied login_params."""
|
||||
login_request = {
|
||||
"method": "login_device",
|
||||
"params": login_params,
|
||||
"request_time_milis": round(time.time() * 1000),
|
||||
}
|
||||
request = json_dumps(login_request)
|
||||
|
||||
resp_dict = await self.send_secure_passthrough(request)
|
||||
self._handle_response_error_code(resp_dict, "Error logging in")
|
||||
login_token = resp_dict["result"]["token"]
|
||||
self._token_url = self._app_url.with_query(f"token={login_token}")
|
||||
self._state = TransportState.ESTABLISHED
|
||||
|
||||
async def _generate_key_pair_payload(self) -> AsyncGenerator:
|
||||
"""Generate the request body and return an ascyn_generator.
|
||||
|
||||
This prevents the key pair being generated unless a connection
|
||||
can be made to the device.
|
||||
"""
|
||||
_LOGGER.debug("Generating keypair")
|
||||
if not self._key_pair:
|
||||
kp = KeyPair.create_key_pair()
|
||||
self._config.aes_keys = {
|
||||
"private": kp.private_key_der_b64,
|
||||
"public": kp.public_key_der_b64,
|
||||
}
|
||||
self._key_pair = kp
|
||||
|
||||
pub_key = (
|
||||
"-----BEGIN PUBLIC KEY-----\n"
|
||||
+ self._key_pair.public_key_der_b64 # type: ignore[union-attr]
|
||||
+ "\n-----END PUBLIC KEY-----\n"
|
||||
)
|
||||
handshake_params = {"key": pub_key}
|
||||
request_body = {"method": "handshake", "params": handshake_params}
|
||||
_LOGGER.debug("Handshake request: %s", request_body)
|
||||
yield json_dumps(request_body).encode()
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake."""
|
||||
_LOGGER.debug("Will perform handshaking...")
|
||||
|
||||
self._token_url = None
|
||||
self._session_expire_at = None
|
||||
self._session_cookie = None
|
||||
|
||||
# Device needs the content length or it will response with 500
|
||||
headers = {
|
||||
**self.COMMON_HEADERS,
|
||||
self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH),
|
||||
}
|
||||
http_client = self._http_client
|
||||
|
||||
status_code, resp_dict = await http_client.post(
|
||||
self._app_url,
|
||||
json=self._generate_key_pair_payload(),
|
||||
headers=headers,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
|
||||
_LOGGER.debug("Device responded with: %s", resp_dict)
|
||||
|
||||
if status_code != 200:
|
||||
raise KasaException(
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code} to handshake"
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
resp_dict = cast(Dict[str, Any], resp_dict)
|
||||
|
||||
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
|
||||
|
||||
handshake_key = resp_dict["result"]["key"]
|
||||
|
||||
if (
|
||||
cookie := http_client.get_cookie(self.SESSION_COOKIE_NAME) # type: ignore
|
||||
) or (
|
||||
cookie := http_client.get_cookie("SESSIONID") # type: ignore
|
||||
):
|
||||
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
||||
|
||||
timeout = int(
|
||||
http_client.get_cookie(self.TIMEOUT_COOKIE_NAME) or ONE_DAY_SECONDS
|
||||
)
|
||||
# There is a 24 hour timeout on the session cookie
|
||||
# but the clock on the device is not always accurate
|
||||
# so we set the expiry to 24 hours from now minus a buffer
|
||||
self._session_expire_at = time.time() + timeout - SESSION_EXPIRE_BUFFER_SECONDS
|
||||
if TYPE_CHECKING:
|
||||
assert self._key_pair is not None
|
||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||
handshake_key, self._key_pair
|
||||
)
|
||||
|
||||
self._state = TransportState.LOGIN_REQUIRED
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self._host)
|
||||
|
||||
def _handshake_session_expired(self) -> bool:
|
||||
"""Return true if session has expired."""
|
||||
return (
|
||||
self._session_expire_at is None
|
||||
or self._session_expire_at - time.time() <= 0
|
||||
)
|
||||
|
||||
async def send(self, request: str) -> dict[str, Any]:
|
||||
"""Send the request."""
|
||||
if (
|
||||
self._state is TransportState.HANDSHAKE_REQUIRED
|
||||
or self._handshake_session_expired()
|
||||
):
|
||||
await self.perform_handshake()
|
||||
if self._state is not TransportState.ESTABLISHED:
|
||||
try:
|
||||
await self.perform_login()
|
||||
# After a login failure handshake needs to
|
||||
# be redone or a 9999 error is received.
|
||||
except AuthenticationError as ex:
|
||||
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||
raise ex
|
||||
|
||||
return await self.send_secure_passthrough(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the http client and reset internal state."""
|
||||
await self.reset()
|
||||
await self._http_client.close()
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal handshake and login state."""
|
||||
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||
|
||||
|
||||
class AesEncyptionSession:
|
||||
"""Class for an AES encryption session."""
|
||||
|
||||
@staticmethod
|
||||
def create_from_keypair(
|
||||
handshake_key: str, keypair: KeyPair
|
||||
) -> AesEncyptionSession:
|
||||
"""Create the encryption session."""
|
||||
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())
|
||||
|
||||
key_and_iv = keypair.decrypt_handshake_key(handshake_key_bytes)
|
||||
if key_and_iv is None:
|
||||
raise ValueError("Decryption failed!")
|
||||
|
||||
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
|
||||
|
||||
def __init__(self, key: bytes, iv: bytes) -> None:
|
||||
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
||||
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
|
||||
|
||||
def encrypt(self, data: bytes) -> bytes:
|
||||
"""Encrypt the message."""
|
||||
encryptor = self.cipher.encryptor()
|
||||
padder = self.padding_strategy.padder()
|
||||
padded_data = padder.update(data) + padder.finalize()
|
||||
encrypted = encryptor.update(padded_data) + encryptor.finalize()
|
||||
return base64.b64encode(encrypted)
|
||||
|
||||
def decrypt(self, data: str | bytes) -> str:
|
||||
"""Decrypt the message."""
|
||||
decryptor = self.cipher.decryptor()
|
||||
unpadder = self.padding_strategy.unpadder()
|
||||
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
|
||||
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
|
||||
return unpadded_data.decode()
|
||||
|
||||
|
||||
class KeyPair:
|
||||
"""Class for generating key pairs."""
|
||||
|
||||
@staticmethod
|
||||
def create_key_pair(key_size: int = 1024) -> KeyPair:
|
||||
"""Create a key pair."""
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
||||
public_key = private_key.public_key()
|
||||
return KeyPair(private_key, public_key)
|
||||
|
||||
@staticmethod
|
||||
def create_from_der_keys(
|
||||
private_key_der_b64: str, public_key_der_b64: str
|
||||
) -> KeyPair:
|
||||
"""Create a key pair."""
|
||||
key_bytes = base64.b64decode(private_key_der_b64.encode())
|
||||
private_key = cast(
|
||||
rsa.RSAPrivateKey, serialization.load_der_private_key(key_bytes, None)
|
||||
)
|
||||
key_bytes = base64.b64decode(public_key_der_b64.encode())
|
||||
public_key = cast(
|
||||
rsa.RSAPublicKey, serialization.load_der_public_key(key_bytes, None)
|
||||
)
|
||||
|
||||
return KeyPair(private_key, public_key)
|
||||
|
||||
def __init__(
|
||||
self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey
|
||||
) -> None:
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
self.private_key_der_bytes = self.private_key.private_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
self.public_key_der_bytes = self.public_key.public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
self.private_key_der_b64 = base64.b64encode(self.private_key_der_bytes).decode()
|
||||
self.public_key_der_b64 = base64.b64encode(self.public_key_der_bytes).decode()
|
||||
|
||||
def get_public_pem(self) -> bytes:
|
||||
"""Get public key in PEM encoding."""
|
||||
return self.public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def decrypt_handshake_key(self, encrypted_key: bytes) -> bytes:
|
||||
"""Decrypt an aes handshake key."""
|
||||
decrypted = self.private_key.decrypt(
|
||||
encrypted_key, asymmetric_padding.PKCS1v15()
|
||||
)
|
||||
return decrypted
|
||||
|
||||
def decrypt_discovery_key(self, encrypted_key: bytes) -> bytes:
|
||||
"""Decrypt an aes discovery key."""
|
||||
decrypted = self.private_key.decrypt(
|
||||
encrypted_key,
|
||||
asymmetric_padding.OAEP(
|
||||
mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
|
||||
algorithm=hashes.SHA1(), # noqa: S303
|
||||
label=None,
|
||||
),
|
||||
)
|
||||
return decrypted
|
55
kasa/transports/basetransport.py
Normal file
55
kasa/transports/basetransport.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Base class for all transport implementations.
|
||||
|
||||
All transport classes must derive from this to implement the common interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kasa import DeviceConfig
|
||||
|
||||
|
||||
class BaseTransport(ABC):
|
||||
"""Base class for all TP-Link protocol transports."""
|
||||
|
||||
DEFAULT_TIMEOUT = 5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: DeviceConfig,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
self._config = config
|
||||
self._host = config.host
|
||||
self._port = config.port_override or self.default_port
|
||||
self._credentials = config.credentials
|
||||
self._credentials_hash = config.credentials_hash
|
||||
if not config.timeout:
|
||||
config.timeout = self.DEFAULT_TIMEOUT
|
||||
self._timeout = config.timeout
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def default_port(self) -> int:
|
||||
"""The default port for the transport."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def credentials_hash(self) -> str | None:
|
||||
"""The hashed credentials used by the transport."""
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, request: str) -> dict:
|
||||
"""Send a message to the device and return a response."""
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close the transport. Abstract method to be overriden."""
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal state."""
|
512
kasa/transports/klaptransport.py
Normal file
512
kasa/transports/klaptransport.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""Implementation of the TP-Link Klap Home Protocol.
|
||||
|
||||
Encryption/Decryption methods based on the works of
|
||||
Simon Wilkinson and Chris Weeldon
|
||||
|
||||
Klap devices that have never been connected to the kasa
|
||||
cloud should work with blank credentials.
|
||||
Devices that have been connected to the kasa cloud will
|
||||
switch intermittently between the users cloud credentials
|
||||
and default kasa credentials that are hardcoded.
|
||||
This appears to be an issue with the devices.
|
||||
|
||||
The protocol works by doing a two stage handshake to obtain
|
||||
and encryption key and session id cookie.
|
||||
|
||||
Authentication uses an auth_hash which is
|
||||
md5(md5(username),md5(password))
|
||||
|
||||
handshake1: client sends a random 16 byte local_seed to the
|
||||
device and receives a random 16 bytes remote_seed, followed
|
||||
by sha256(local_seed + auth_hash). It also returns a
|
||||
TP_SESSIONID in the cookie header. This implementation
|
||||
then checks this value against the possible auth_hashes
|
||||
described above (user cloud, kasa hardcoded, blank). If it
|
||||
finds a match it moves onto handshake2
|
||||
|
||||
handshake2: client sends sha25(remote_seed + auth_hash) to
|
||||
the device along with the TP_SESSIONID. Device responds with
|
||||
200 if successful. It generally will be because this
|
||||
implementation checks the auth_hash it received during handshake1
|
||||
|
||||
encryption: local_seed, remote_seed and auth_hash are now used
|
||||
for encryption. The last 4 bytes of the initialization vector
|
||||
are used as a sequence number that increments every time the
|
||||
client calls encrypt and this sequence number is sent as a
|
||||
url parameter to the device along with the encrypted payload
|
||||
|
||||
https://gist.github.com/chriswheeldon/3b17d974db3817613c69191c0480fe55
|
||||
https://github.com/python-kasa/python-kasa/pull/117
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import struct
|
||||
import time
|
||||
from asyncio import Future
|
||||
from typing import TYPE_CHECKING, Any, Generator, cast
|
||||
|
||||
from cryptography.hazmat.primitives import padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from yarl import URL
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.deviceconfig import DeviceConfig
|
||||
from kasa.exceptions import AuthenticationError, KasaException, _RetryableError
|
||||
from kasa.httpclient import HttpClient
|
||||
from kasa.json import loads as json_loads
|
||||
from kasa.protocol import (
|
||||
DEFAULT_CREDENTIALS,
|
||||
get_default_credentials,
|
||||
md5,
|
||||
)
|
||||
|
||||
from .basetransport import BaseTransport
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ONE_DAY_SECONDS = 86400
|
||||
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
|
||||
|
||||
PACK_SIGNED_LONG = struct.Struct(">l").pack
|
||||
|
||||
|
||||
def _sha256(payload: bytes) -> bytes:
|
||||
return hashlib.sha256(payload).digest() # noqa: S324
|
||||
|
||||
|
||||
def _sha1(payload: bytes) -> bytes:
|
||||
return hashlib.sha1(payload).digest() # noqa: S324
|
||||
|
||||
|
||||
class KlapTransport(BaseTransport):
|
||||
"""Implementation of the KLAP encryption protocol.
|
||||
|
||||
KLAP is the name used in device discovery for TP-Link's new encryption
|
||||
protocol, used by newer firmware versions.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT: int = 80
|
||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||
TIMEOUT_COOKIE_NAME = "TIMEOUT"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: DeviceConfig,
|
||||
) -> None:
|
||||
super().__init__(config=config)
|
||||
|
||||
self._http_client = HttpClient(config)
|
||||
self._local_seed: bytes | None = None
|
||||
if (
|
||||
not self._credentials or self._credentials.username is None
|
||||
) and not self._credentials_hash:
|
||||
self._credentials = Credentials()
|
||||
if self._credentials:
|
||||
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
||||
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
||||
else:
|
||||
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
|
||||
self._default_credentials_auth_hash: dict[str, bytes] = {}
|
||||
self._blank_auth_hash: bytes | None = None
|
||||
self._handshake_lock = asyncio.Lock()
|
||||
self._query_lock = asyncio.Lock()
|
||||
self._handshake_done: bool = False
|
||||
|
||||
self._encryption_session: KlapEncryptionSession | None = None
|
||||
self._session_expire_at: float | None = None
|
||||
|
||||
self._session_cookie: dict[str, Any] | None = None
|
||||
|
||||
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
||||
self._app_url = URL(f"http://{self._host}:{self._port}/app")
|
||||
self._request_url = self._app_url / "request"
|
||||
|
||||
@property
|
||||
def default_port(self) -> int:
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
@property
|
||||
def credentials_hash(self) -> str | None:
|
||||
"""The hashed credentials used by the transport."""
|
||||
if self._credentials == Credentials():
|
||||
return None
|
||||
return base64.b64encode(self._local_auth_hash).decode()
|
||||
|
||||
async def perform_handshake1(self) -> tuple[bytes, bytes, bytes]:
|
||||
"""Perform handshake1."""
|
||||
local_seed: bytes = secrets.token_bytes(16)
|
||||
|
||||
# Handshake 1 has a payload of local_seed
|
||||
# and a response of 16 bytes, followed by
|
||||
# sha256(remote_seed | auth_hash)
|
||||
|
||||
payload = local_seed
|
||||
|
||||
url = self._app_url / "handshake1"
|
||||
|
||||
response_status, response_data = await self._http_client.post(url, data=payload)
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Handshake1 posted at %s. Host is %s, "
|
||||
"Response status is %s, Request was %s",
|
||||
datetime.datetime.now(),
|
||||
self._host,
|
||||
response_status,
|
||||
payload.hex(),
|
||||
)
|
||||
|
||||
if response_status != 200:
|
||||
raise KasaException(
|
||||
f"Device {self._host} responded with {response_status} to handshake1"
|
||||
)
|
||||
|
||||
response_data = cast(bytes, response_data)
|
||||
remote_seed: bytes = response_data[0:16]
|
||||
server_hash = response_data[16:]
|
||||
|
||||
if len(server_hash) != 32:
|
||||
raise KasaException(
|
||||
f"Device {self._host} responded with unexpected klap response "
|
||||
+ f"{response_data!r} to handshake1"
|
||||
)
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Handshake1 success at %s. Host is %s, "
|
||||
"Server remote_seed is: %s, server hash is: %s",
|
||||
datetime.datetime.now(),
|
||||
self._host,
|
||||
remote_seed.hex(),
|
||||
server_hash.hex(),
|
||||
)
|
||||
|
||||
local_seed_auth_hash = self.handshake1_seed_auth_hash(
|
||||
local_seed, remote_seed, self._local_auth_hash
|
||||
) # type: ignore
|
||||
|
||||
# Check the response from the device with local credentials
|
||||
if local_seed_auth_hash == server_hash:
|
||||
_LOGGER.debug("handshake1 hashes match with expected credentials")
|
||||
return local_seed, remote_seed, self._local_auth_hash # type: ignore
|
||||
|
||||
# Now check against the default setup credentials
|
||||
for key, value in DEFAULT_CREDENTIALS.items():
|
||||
if key not in self._default_credentials_auth_hash:
|
||||
default_credentials = get_default_credentials(value)
|
||||
self._default_credentials_auth_hash[key] = self.generate_auth_hash(
|
||||
default_credentials
|
||||
)
|
||||
|
||||
default_credentials_seed_auth_hash = self.handshake1_seed_auth_hash(
|
||||
local_seed,
|
||||
remote_seed,
|
||||
self._default_credentials_auth_hash[key], # type: ignore
|
||||
)
|
||||
|
||||
if default_credentials_seed_auth_hash == server_hash:
|
||||
_LOGGER.debug(
|
||||
"Server response doesn't match our expected hash on ip %s, "
|
||||
"but an authentication with %s default credentials matched",
|
||||
self._host,
|
||||
key,
|
||||
)
|
||||
return local_seed, remote_seed, self._default_credentials_auth_hash[key] # type: ignore
|
||||
|
||||
# Finally check against blank credentials if not already blank
|
||||
blank_creds = Credentials()
|
||||
if self._credentials != blank_creds:
|
||||
if not self._blank_auth_hash:
|
||||
self._blank_auth_hash = self.generate_auth_hash(blank_creds)
|
||||
|
||||
blank_seed_auth_hash = self.handshake1_seed_auth_hash(
|
||||
local_seed,
|
||||
remote_seed,
|
||||
self._blank_auth_hash, # type: ignore
|
||||
)
|
||||
|
||||
if blank_seed_auth_hash == server_hash:
|
||||
_LOGGER.debug(
|
||||
"Server response doesn't match our expected hash on ip %s, "
|
||||
"but an authentication with blank credentials matched",
|
||||
self._host,
|
||||
)
|
||||
return local_seed, remote_seed, self._blank_auth_hash # type: ignore
|
||||
|
||||
msg = f"Server response doesn't match our challenge on ip {self._host}"
|
||||
_LOGGER.debug(msg)
|
||||
raise AuthenticationError(msg)
|
||||
|
||||
async def perform_handshake2(
|
||||
self, local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
) -> KlapEncryptionSession:
|
||||
"""Perform handshake2."""
|
||||
# Handshake 2 has the following payload:
|
||||
# sha256(serverBytes | authenticator)
|
||||
|
||||
url = self._app_url / "handshake2"
|
||||
|
||||
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
||||
|
||||
response_status, _ = await self._http_client.post(
|
||||
url,
|
||||
data=payload,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Handshake2 posted %s. Host is %s, "
|
||||
"Response status is %s, Request was %s",
|
||||
datetime.datetime.now(),
|
||||
self._host,
|
||||
response_status,
|
||||
payload.hex(),
|
||||
)
|
||||
|
||||
if response_status != 200:
|
||||
# This shouldn't be caused by incorrect
|
||||
# credentials so don't raise AuthenticationError
|
||||
raise KasaException(
|
||||
f"Device {self._host} responded with {response_status} to handshake2"
|
||||
)
|
||||
|
||||
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform handshake1 and handshake2.
|
||||
|
||||
Sets the encryption_session if successful.
|
||||
"""
|
||||
_LOGGER.debug("Starting handshake with %s", self._host)
|
||||
self._handshake_done = False
|
||||
self._session_expire_at = None
|
||||
self._session_cookie = None
|
||||
|
||||
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
|
||||
http_client = self._http_client
|
||||
if cookie := http_client.get_cookie(self.SESSION_COOKIE_NAME): # type: ignore
|
||||
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
||||
# The device returns a TIMEOUT cookie on handshake1 which
|
||||
# it doesn't like to get back so we store the one we want
|
||||
timeout = int(
|
||||
http_client.get_cookie(self.TIMEOUT_COOKIE_NAME) or ONE_DAY_SECONDS
|
||||
)
|
||||
# There is a 24 hour timeout on the session cookie
|
||||
# but the clock on the device is not always accurate
|
||||
# so we set the expiry to 24 hours from now minus a buffer
|
||||
self._session_expire_at = (
|
||||
time.monotonic() + timeout - SESSION_EXPIRE_BUFFER_SECONDS
|
||||
)
|
||||
self._encryption_session = await self.perform_handshake2(
|
||||
local_seed, remote_seed, auth_hash
|
||||
)
|
||||
self._handshake_done = True
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self._host)
|
||||
|
||||
def _handshake_session_expired(self) -> bool:
|
||||
"""Return true if session has expired."""
|
||||
return (
|
||||
self._session_expire_at is None
|
||||
or self._session_expire_at - time.monotonic() <= 0
|
||||
)
|
||||
|
||||
async def send(self, request: str) -> Generator[Future, None, dict[str, str]]: # type: ignore[override]
|
||||
"""Send the request."""
|
||||
if not self._handshake_done or self._handshake_session_expired():
|
||||
await self.perform_handshake()
|
||||
|
||||
# Check for mypy
|
||||
if self._encryption_session is not None:
|
||||
payload, seq = self._encryption_session.encrypt(request.encode())
|
||||
|
||||
response_status, response_data = await self._http_client.post(
|
||||
self._request_url,
|
||||
params={"seq": seq},
|
||||
data=payload,
|
||||
cookies_dict=self._session_cookie,
|
||||
)
|
||||
|
||||
msg = (
|
||||
f"Host is {self._host}, "
|
||||
+ f"Sequence is {seq}, "
|
||||
+ f"Response status is {response_status}, Request was {request}"
|
||||
)
|
||||
if response_status != 200:
|
||||
_LOGGER.error("Query failed after successful authentication: %s", msg)
|
||||
# If we failed with a security error, force a new handshake next time.
|
||||
if response_status == 403:
|
||||
self._handshake_done = False
|
||||
raise _RetryableError(
|
||||
"Got a security error from %s after handshake completed", self._host
|
||||
)
|
||||
else:
|
||||
raise KasaException(
|
||||
f"Device {self._host} responded with {response_status} to "
|
||||
f"request with seq {seq}"
|
||||
)
|
||||
else:
|
||||
_LOGGER.debug("Device %s query posted %s", self._host, msg)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert self._encryption_session
|
||||
assert isinstance(response_data, bytes)
|
||||
try:
|
||||
decrypted_response = self._encryption_session.decrypt(response_data)
|
||||
except Exception as ex:
|
||||
raise KasaException(
|
||||
f"Error trying to decrypt device {self._host} response: {ex}"
|
||||
) from ex
|
||||
|
||||
json_payload = json_loads(decrypted_response)
|
||||
|
||||
_LOGGER.debug("Device %s query response received", self._host)
|
||||
|
||||
return json_payload
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the http client and reset internal state."""
|
||||
await self.reset()
|
||||
await self._http_client.close()
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal handshake state."""
|
||||
self._handshake_done = False
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
un = creds.username
|
||||
pw = creds.password
|
||||
|
||||
return md5(md5(un.encode()) + md5(pw.encode()))
|
||||
|
||||
@staticmethod
|
||||
def handshake1_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(local_seed + auth_hash)
|
||||
|
||||
@staticmethod
|
||||
def handshake2_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(remote_seed + auth_hash)
|
||||
|
||||
@staticmethod
|
||||
def generate_owner_hash(creds: Credentials) -> bytes:
|
||||
"""Return the MD5 hash of the username in this object."""
|
||||
un = creds.username
|
||||
return md5(un.encode())
|
||||
|
||||
|
||||
class KlapTransportV2(KlapTransport):
|
||||
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
un = creds.username
|
||||
pw = creds.password
|
||||
|
||||
return _sha256(_sha1(un.encode()) + _sha1(pw.encode()))
|
||||
|
||||
@staticmethod
|
||||
def handshake1_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(local_seed + remote_seed + auth_hash)
|
||||
|
||||
@staticmethod
|
||||
def handshake2_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
) -> bytes:
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(remote_seed + local_seed + auth_hash)
|
||||
|
||||
|
||||
class KlapEncryptionSession:
|
||||
"""Class to represent an encryption session and it's internal state.
|
||||
|
||||
i.e. sequence number which the device expects to increment.
|
||||
"""
|
||||
|
||||
_cipher: Cipher
|
||||
|
||||
def __init__(self, local_seed: bytes, remote_seed: bytes, user_hash: bytes) -> None:
|
||||
self.local_seed = local_seed
|
||||
self.remote_seed = remote_seed
|
||||
self.user_hash = user_hash
|
||||
self._key = self._key_derive(local_seed, remote_seed, user_hash)
|
||||
(self._iv, self._seq) = self._iv_derive(local_seed, remote_seed, user_hash)
|
||||
self._aes = algorithms.AES(self._key)
|
||||
self._sig = self._sig_derive(local_seed, remote_seed, user_hash)
|
||||
|
||||
def _key_derive(
|
||||
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
|
||||
) -> bytes:
|
||||
payload = b"lsk" + local_seed + remote_seed + user_hash
|
||||
return hashlib.sha256(payload).digest()[:16]
|
||||
|
||||
def _iv_derive(
|
||||
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
|
||||
) -> tuple[bytes, int]:
|
||||
# iv is first 16 bytes of sha256, where the last 4 bytes forms the
|
||||
# sequence number used in requests and is incremented on each request
|
||||
payload = b"iv" + local_seed + remote_seed + user_hash
|
||||
fulliv = hashlib.sha256(payload).digest()
|
||||
seq = int.from_bytes(fulliv[-4:], "big", signed=True)
|
||||
return (fulliv[:12], seq)
|
||||
|
||||
def _sig_derive(
|
||||
self, local_seed: bytes, remote_seed: bytes, user_hash: bytes
|
||||
) -> bytes:
|
||||
# used to create a hash with which to prefix each request
|
||||
payload = b"ldk" + local_seed + remote_seed + user_hash
|
||||
return hashlib.sha256(payload).digest()[:28]
|
||||
|
||||
def _generate_cipher(self) -> None:
|
||||
iv_seq = self._iv + PACK_SIGNED_LONG(self._seq)
|
||||
cbc = modes.CBC(iv_seq)
|
||||
self._cipher = Cipher(self._aes, cbc)
|
||||
|
||||
def encrypt(self, msg: bytes | str) -> tuple[bytes, int]:
|
||||
"""Encrypt the data and increment the sequence number."""
|
||||
self._seq += 1
|
||||
self._generate_cipher()
|
||||
|
||||
if isinstance(msg, str):
|
||||
msg = msg.encode("utf-8")
|
||||
|
||||
encryptor = self._cipher.encryptor()
|
||||
padder = padding.PKCS7(128).padder()
|
||||
padded_data = padder.update(msg) + padder.finalize()
|
||||
ciphertext = encryptor.update(padded_data) + encryptor.finalize()
|
||||
signature = hashlib.sha256(
|
||||
self._sig + PACK_SIGNED_LONG(self._seq) + ciphertext
|
||||
).digest()
|
||||
return (signature + ciphertext, self._seq)
|
||||
|
||||
def decrypt(self, msg: bytes) -> str:
|
||||
"""Decrypt the data."""
|
||||
decryptor = self._cipher.decryptor()
|
||||
dp = decryptor.update(msg[32:]) + decryptor.finalize()
|
||||
unpadder = padding.PKCS7(128).unpadder()
|
||||
plaintextbytes = unpadder.update(dp) + unpadder.finalize()
|
||||
|
||||
return plaintextbytes.decode()
|
234
kasa/transports/xortransport.py
Normal file
234
kasa/transports/xortransport.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Implementation of the legacy TP-Link Smart Home Protocol.
|
||||
|
||||
Encryption/Decryption methods based on the works of
|
||||
Lubomir Stroetmann and Tobias Esser
|
||||
|
||||
https://www.softscheck.com/en/reverse-engineering-tp-link-hs110/
|
||||
https://github.com/softScheck/tplink-smartplug/
|
||||
|
||||
which are licensed under the Apache License, Version 2.0
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
import struct
|
||||
from collections.abc import Generator
|
||||
|
||||
# When support for cpython older than 3.11 is dropped
|
||||
# async_timeout can be replaced with asyncio.timeout
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
|
||||
from kasa.deviceconfig import DeviceConfig
|
||||
from kasa.exceptions import KasaException, _RetryableError
|
||||
from kasa.json import loads as json_loads
|
||||
|
||||
from .basetransport import BaseTransport
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
|
||||
_UNSIGNED_INT_NETWORK_ORDER = struct.Struct(">I")
|
||||
|
||||
|
||||
class XorTransport(BaseTransport):
|
||||
"""XorTransport class."""
|
||||
|
||||
DEFAULT_PORT: int = 9999
|
||||
BLOCK_SIZE = 4
|
||||
|
||||
def __init__(self, *, config: DeviceConfig) -> None:
|
||||
super().__init__(config=config)
|
||||
self.reader: asyncio.StreamReader | None = None
|
||||
self.writer: asyncio.StreamWriter | None = None
|
||||
self.query_lock = asyncio.Lock()
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
@property
|
||||
def default_port(self) -> int:
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
@property
|
||||
def credentials_hash(self) -> str | None:
|
||||
"""The hashed credentials used by the transport."""
|
||||
return None
|
||||
|
||||
async def _connect(self, timeout: int) -> None:
|
||||
"""Try to connect or reconnect to the device."""
|
||||
if self.writer:
|
||||
return
|
||||
self.reader = self.writer = None
|
||||
|
||||
task = asyncio.open_connection(self._host, self._port)
|
||||
async with asyncio_timeout(timeout):
|
||||
self.reader, self.writer = await task
|
||||
sock: socket.socket = self.writer.get_extra_info("socket")
|
||||
# Ensure our packets get sent without delay as we do all
|
||||
# our writes in a single go and we do not want any buffering
|
||||
# which would needlessly delay the request or risk overloading
|
||||
# the buffer on the device
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
|
||||
async def _execute_send(self, request: str) -> dict:
|
||||
"""Execute a query on the device and wait for the response."""
|
||||
assert self.writer is not None # noqa: S101
|
||||
assert self.reader is not None # noqa: S101
|
||||
_LOGGER.debug("Device %s sending query %s", self._host, request)
|
||||
|
||||
self.writer.write(XorEncryption.encrypt(request))
|
||||
await self.writer.drain()
|
||||
|
||||
packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
|
||||
length = _UNSIGNED_INT_NETWORK_ORDER.unpack(packed_block_size)[0]
|
||||
|
||||
buffer = await self.reader.readexactly(length)
|
||||
response = XorEncryption.decrypt(buffer)
|
||||
json_payload = json_loads(response)
|
||||
|
||||
_LOGGER.debug("Device %s query response received", self._host)
|
||||
|
||||
return json_payload
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
writer = self.writer
|
||||
self.close_without_wait()
|
||||
if writer:
|
||||
with contextlib.suppress(Exception):
|
||||
await writer.wait_closed()
|
||||
|
||||
def close_without_wait(self) -> None:
|
||||
"""Close the connection without waiting for the connection to close."""
|
||||
writer = self.writer
|
||||
self.reader = self.writer = None
|
||||
if writer:
|
||||
writer.close()
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset the transport.
|
||||
|
||||
The transport cannot be reset so we must close instead.
|
||||
"""
|
||||
await self.close()
|
||||
|
||||
async def send(self, request: str) -> dict:
|
||||
"""Send a message to the device and return a response."""
|
||||
#
|
||||
# Most of the time we will already be connected if the device is online
|
||||
# and the connect call will do nothing and return right away
|
||||
#
|
||||
# However, if we get an unrecoverable error (_NO_RETRY_ERRORS and
|
||||
# ConnectionRefusedError) we do not want to keep trying since many
|
||||
# connection open/close operations in the same time frame can block
|
||||
# the event loop.
|
||||
# This is especially import when there are multiple tplink devices being polled.
|
||||
try:
|
||||
await self._connect(self._timeout)
|
||||
except ConnectionRefusedError as ex:
|
||||
await self.reset()
|
||||
raise KasaException(
|
||||
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
except OSError as ex:
|
||||
await self.reset()
|
||||
if ex.errno in _NO_RETRY_ERRORS:
|
||||
raise KasaException(
|
||||
f"Unable to connect to the device:"
|
||||
f" {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
else:
|
||||
raise _RetryableError(
|
||||
f"Unable to connect to the device:"
|
||||
f" {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
except Exception as ex:
|
||||
await self.reset()
|
||||
raise _RetryableError(
|
||||
f"Unable to connect to the device:" f" {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
except BaseException:
|
||||
# Likely something cancelled the task so we need to close the connection
|
||||
# as we are not in an indeterminate state
|
||||
self.close_without_wait()
|
||||
raise
|
||||
|
||||
try:
|
||||
assert self.reader is not None # noqa: S101
|
||||
assert self.writer is not None # noqa: S101
|
||||
async with asyncio_timeout(self._timeout):
|
||||
return await self._execute_send(request)
|
||||
except Exception as ex:
|
||||
await self.reset()
|
||||
raise _RetryableError(
|
||||
f"Unable to query the device {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
except BaseException:
|
||||
# Likely something cancelled the task so we need to close the connection
|
||||
# as we are not in an indeterminate state
|
||||
self.close_without_wait()
|
||||
raise
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.writer and self.loop and self.loop.is_running():
|
||||
# Since __del__ will be called when python does
|
||||
# garbage collection is can happen in the event loop thread
|
||||
# or in another thread so we need to make sure the call to
|
||||
# close is called safely with call_soon_threadsafe
|
||||
self.loop.call_soon_threadsafe(self.writer.close)
|
||||
|
||||
|
||||
class XorEncryption:
|
||||
"""XorEncryption class."""
|
||||
|
||||
INITIALIZATION_VECTOR = 171
|
||||
|
||||
@staticmethod
|
||||
def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]:
|
||||
key = XorEncryption.INITIALIZATION_VECTOR
|
||||
for unencryptedbyte in unencrypted:
|
||||
key = key ^ unencryptedbyte
|
||||
yield key
|
||||
|
||||
@staticmethod
|
||||
def encrypt(request: str) -> bytes:
|
||||
"""Encrypt a request for a TP-Link Smart Home Device.
|
||||
|
||||
:param request: plaintext request data
|
||||
:return: ciphertext to be send over wire, in bytes
|
||||
"""
|
||||
plainbytes = request.encode()
|
||||
return _UNSIGNED_INT_NETWORK_ORDER.pack(len(plainbytes)) + bytes(
|
||||
XorEncryption._xor_payload(plainbytes)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _xor_encrypted_payload(ciphertext: bytes) -> Generator[int, None, None]:
|
||||
key = XorEncryption.INITIALIZATION_VECTOR
|
||||
for cipherbyte in ciphertext:
|
||||
plainbyte = key ^ cipherbyte
|
||||
key = cipherbyte
|
||||
yield plainbyte
|
||||
|
||||
@staticmethod
|
||||
def decrypt(ciphertext: bytes) -> str:
|
||||
"""Decrypt a response of a TP-Link Smart Home Device.
|
||||
|
||||
:param ciphertext: encrypted response data
|
||||
:return: plaintext response
|
||||
"""
|
||||
return bytes(XorEncryption._xor_encrypted_payload(ciphertext)).decode()
|
||||
|
||||
|
||||
# Try to load the kasa_crypt module and if it is available
|
||||
try:
|
||||
from kasa_crypt import decrypt, encrypt
|
||||
|
||||
XorEncryption.decrypt = decrypt # type: ignore[assignment]
|
||||
XorEncryption.encrypt = encrypt # type: ignore[assignment]
|
||||
except ImportError:
|
||||
pass
|
Reference in New Issue
Block a user