mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-23 03:33:35 +00:00
63d64ad920
* Add Tapo protocol support * Update get_device_instance and test_unsupported following review
499 lines
18 KiB
Python
499 lines
18 KiB
Python
"""Implementation of the TP-Link AES Protocol.
|
|
|
|
Based on the work of https://github.com/petretiandrea/plugp100
|
|
under compatible GNU GPL3 license.
|
|
"""
|
|
|
|
import asyncio
|
|
import base64
|
|
import hashlib
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from pprint import pformat as pf
|
|
from typing import Dict, Optional, Union
|
|
|
|
import httpx
|
|
from cryptography.hazmat.primitives import hashes, padding, serialization
|
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
|
|
from .credentials import Credentials
|
|
from .exceptions import AuthenticationException, SmartDeviceException
|
|
from .json import dumps as json_dumps
|
|
from .json import loads as json_loads
|
|
from .protocol import TPLinkProtocol
|
|
|
|
# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)
|
|
# Keep records emitted on the "httpx" logger from propagating to the
# handlers of its ancestor loggers.
logging.getLogger("httpx").propagate = False
|
|
|
|
|
|
def _md5(payload: bytes) -> bytes:
|
|
digest = hashes.Hash(hashes.MD5()) # noqa: S303
|
|
digest.update(payload)
|
|
hash = digest.finalize()
|
|
return hash
|
|
|
|
|
|
def _sha1(payload: bytes) -> str:
|
|
sha1_algo = hashlib.sha1() # noqa: S324
|
|
sha1_algo.update(payload)
|
|
return sha1_algo.hexdigest()
|
|
|
|
|
|
class TPLinkAes(TPLinkProtocol):
    """Implementation of the AES encryption protocol.

    AES is the name used in device discovery for TP-Link's TAPO encryption
    protocol, sometimes used by newer firmware versions on kasa devices.
    """

    DEFAULT_PORT = 80
    DEFAULT_TIMEOUT = 5
    SESSION_COOKIE_NAME = "TP_SESSIONID"
    COMMON_HEADERS = {
        "Content-Type": "application/json",
        "requestByApp": "true",
        "Accept": "application/json",
    }

    def __init__(
        self,
        host: str,
        *,
        credentials: Optional[Credentials] = None,
        timeout: Optional[int] = None,
    ) -> None:
        """Create the protocol object for *host*.

        :param host: Address of the device; requests go to ``http://<host>/app``.
        :param credentials: Login credentials. If missing, or if either the
            username or the password is empty, blank credentials are used.
        :param timeout: Per-request timeout; ``DEFAULT_TIMEOUT`` is used when
            it is ``None`` or otherwise falsy.
        """
        super().__init__(host=host, port=self.DEFAULT_PORT)

        self.credentials = (
            credentials
            if credentials and credentials.username and credentials.password
            else Credentials(username="", password="")
        )

        self._local_seed: Optional[bytes] = None
        self.local_auth_hash = self.generate_auth_hash(self.credentials)
        self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
        self.kasa_setup_auth_hash = None
        self.blank_auth_hash = None
        self.handshake_lock = asyncio.Lock()
        self.query_lock = asyncio.Lock()
        self.handshake_done = False

        self.encryption_session: Optional[AesEncyptionSession] = None
        self.session_expire_at: Optional[float] = None

        self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
        self.session_cookie = None
        self.terminal_uuid = None
        self.http_client: Optional[httpx.AsyncClient] = None
        self.request_id_generator = SnowflakeId(1, 1)
        self.login_token = None

        _LOGGER.debug("Created AES object for %s", self.host)

    def hash_credentials(self, credentials, try_login_version2):
        """Hash the credentials.

        The username is always sent as base64 of its SHA-1 hex digest.  The
        password is base64 of its SHA-1 hex digest when *try_login_version2*
        is true, otherwise base64 of the plain password.

        :return: Tuple ``(username, password)`` of encoded strings.
        """
        un = base64.b64encode(_sha1(credentials.username.encode()).encode()).decode()
        if try_login_version2:
            pw = base64.b64encode(
                _sha1(credentials.password.encode()).encode()
            ).decode()
        else:
            pw = base64.b64encode(credentials.password.encode()).decode()
        return un, pw

    async def client_post(self, url, params=None, data=None, json=None, headers=None):
        """Send an http post request to the device.

        :param headers: Optional extra headers; they are merged over
            ``COMMON_HEADERS`` (previously this argument was ignored).
        :return: Tuple ``(status_code, response_data)`` where *response_data*
            is the decoded JSON body for a 200 response and ``None`` otherwise.
        """
        if not self.http_client:
            # Allow calling this method before any query has created a client.
            self.http_client = httpx.AsyncClient()

        request_headers = dict(self.COMMON_HEADERS)
        if headers:
            request_headers.update(headers)

        cookies = None
        if self.session_cookie:
            cookies = httpx.Cookies()
            cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
        # The session cookie is passed explicitly per request, so drop
        # whatever the client's own cookie jar accumulated.
        self.http_client.cookies.clear()
        resp = await self.http_client.post(
            url,
            params=params,
            data=data,
            json=json,
            timeout=self.timeout,
            cookies=cookies,
            headers=request_headers,
        )

        response_data = resp.json() if resp.status_code == 200 else None
        return resp.status_code, response_data

    async def send_secure_passthrough(self, request):
        """Send encrypted message as passthrough.

        :param request: JSON-serialisable request to encrypt and send.
        :return: The ``result`` member of the decrypted response, or ``None``
            when the response carries no ``result``.
        :raises AuthenticationException: If the HTTP status is not 200 or the
            outer response reports a non-zero ``error_code``.
        :raises SmartDeviceException: If the decrypted inner response reports
            a non-zero ``error_code``.
        """
        url = f"http://{self.host}/app"
        if self.login_token:
            url += f"?token={self.login_token}"
        raw_request = json_dumps(request)
        encrypted_payload = self.encryption_session.encrypt(raw_request.encode())
        passthrough_request = {
            "method": "securePassthrough",
            "params": {"request": encrypted_payload.decode()},
        }
        status_code, resp_dict = await self.client_post(url, json=passthrough_request)
        # resp_dict is None for non-200 responses; the short-circuit keeps
        # it from being indexed in that case.
        if status_code != 200 or resp_dict["error_code"] != 0:
            raise AuthenticationException("Could not complete send")

        response = self.encryption_session.decrypt(
            resp_dict["result"]["response"].encode()
        )
        resp_dict = json_loads(response)
        if resp_dict["error_code"] != 0:
            raise SmartDeviceException(
                f"Could not complete send, response was {resp_dict}",
            )
        if "result" in resp_dict:
            return resp_dict["result"]
        return None

    def get_aes_request(self, method, params=None):
        """Get a request message.

        Wraps *method* and *params* with a freshly generated request id, the
        current time in milliseconds and the terminal uuid of this session.
        """
        return {
            "method": method,
            "params": params,
            "requestID": self.request_id_generator.generate_id(),
            "request_time_milis": round(time.time() * 1000),
            "terminal_uuid": self.terminal_uuid,
        }

    async def perform_login(self, login_v2):
        """Login to the device.

        Clears any previous login token, sends a ``login_device`` request and
        stores the token from the result.

        :param login_v2: Forwarded to :meth:`hash_credentials` to select the
            password encoding.
        :raises AuthenticationException: If the request fails, including when
            the device answers with an error (``SmartDeviceException``).
        """
        self.login_token = None

        un, pw = self.hash_credentials(self.credentials, login_v2)
        request = self.get_aes_request("login_device", {"password": pw, "username": un})
        try:
            result = await self.send_secure_passthrough(request)
        except SmartDeviceException as ex:
            raise AuthenticationException(ex) from ex
        self.login_token = result["token"]

    async def perform_handshake(self):
        """Perform the handshake.

        Sends a freshly generated public key to the device, then stores the
        session cookie and builds the encryption session from the key the
        device returns.

        :raises AuthenticationException: If the HTTP status is not 200 or the
            response reports a non-zero ``error_code``.
        """
        _LOGGER.debug("Will perform handshaking...")
        _LOGGER.debug("Generating keypair")

        # Invalidate any previous session before negotiating a new one.
        self.handshake_done = False
        self.session_expire_at = None
        self.session_cookie = None

        url = f"http://{self.host}/app"
        key_pair = KeyPair.create_key_pair()

        pub_key = (
            "-----BEGIN PUBLIC KEY-----\n"
            + key_pair.get_public_key()
            + "\n-----END PUBLIC KEY-----\n"
        )
        handshake_params = {"key": pub_key}
        _LOGGER.debug("Handshake params: %s", handshake_params)

        request_body = {"method": "handshake", "params": handshake_params}

        _LOGGER.debug("Request %s", request_body)

        status_code, resp_dict = await self.client_post(url, json=request_body)

        _LOGGER.debug("Device responded with: %s", resp_dict)

        if status_code != 200 or resp_dict["error_code"] != 0:
            raise AuthenticationException("Could not complete handshake")

        _LOGGER.debug("Decoding handshake key...")
        handshake_key = resp_dict["result"]["key"]

        # Fall back to the alternative cookie name when the primary is absent.
        self.session_cookie = self.http_client.cookies.get(  # type: ignore
            self.SESSION_COOKIE_NAME
        ) or self.http_client.cookies.get(  # type: ignore
            "SESSIONID"
        )

        # The session is treated locally as valid for 86400 seconds.
        self.session_expire_at = time.time() + 86400
        self.encryption_session = AesEncyptionSession.create_from_keypair(
            handshake_key, key_pair
        )

        self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode("UTF-8")
        self.handshake_done = True

        _LOGGER.debug("Handshake with %s complete", self.host)

    def handshake_session_expired(self):
        """Return true if session has expired (or no expiry time is set)."""
        return (
            self.session_expire_at is None or self.session_expire_at - time.time() <= 0
        )

    @staticmethod
    def generate_auth_hash(creds: Credentials):
        """Generate an md5 auth hash for the protocol on the supplied credentials.

        The result is ``md5(md5(username) + md5(password))``; missing values
        are treated as empty strings.
        """
        un = creds.username or ""
        pw = creds.password or ""
        return _md5(_md5(un.encode()) + _md5(pw.encode()))

    @staticmethod
    def generate_owner_hash(creds: Credentials):
        """Return the MD5 hash of the username in the supplied credentials."""
        un = creds.username or ""
        return _md5(un.encode())

    async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
        """Query the device retrying for retry_count on failure.

        Queries are serialised through ``query_lock``.
        """
        async with self.query_lock:
            return await self._query(request, retry_count)

    async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
        """Run the query, retrying on close errors and unexpected errors.

        Connection errors and timeouts are reported immediately as
        ``SmartDeviceException``; authentication errors are re-raised without
        retrying.  Any other failure closes the client and retries until
        *retry_count* retries have been used.
        """
        for retry in range(retry_count + 1):
            try:
                return await self._execute_query(request, retry)
            except httpx.CloseError as sdex:
                await self.close()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
                    raise SmartDeviceException(
                        f"Unable to connect to the device: {self.host}: {sdex}"
                    ) from sdex
                continue
            except httpx.ConnectError as cex:
                await self.close()
                raise SmartDeviceException(
                    f"Unable to connect to the device: {self.host}: {cex}"
                ) from cex
            except (TimeoutError, httpx.TimeoutException) as tex:
                # httpx timeouts derive from httpx.TimeoutException rather
                # than the builtin TimeoutError, so both must be listed for
                # this branch to handle real request timeouts.
                await self.close()
                raise SmartDeviceException(
                    f"Unable to connect to the device, timed out: {self.host}: {tex}"
                ) from tex
            except AuthenticationException:
                _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
                raise
            except Exception as ex:
                await self.close()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
                    raise SmartDeviceException(
                        f"Unable to connect to the device: {self.host}: {ex}"
                    ) from ex
                continue

        # make mypy happy, this should never be reached..
        raise SmartDeviceException("Query reached somehow to unreachable")

    async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
        """Send a single query, handshaking and logging in first if needed.

        :param request: Either a method name, or a dict whose first key is the
            method name and whose value holds that method's params.
        :param retry_count: Current retry number; not used by this method.
        """
        _LOGGER.debug(
            "%s >> %s",
            self.host,
            _LOGGER.isEnabledFor(logging.DEBUG) and pf(request),
        )

        if not self.http_client:
            self.http_client = httpx.AsyncClient()

        if not self.handshake_done or self.handshake_session_expired():
            try:
                await self.perform_handshake()
                await self.perform_login(False)
            except AuthenticationException:
                # Start over with a new handshake and the other login variant.
                await self.perform_handshake()
                await self.perform_login(True)

        if isinstance(request, dict):
            aes_method = next(iter(request))
            aes_params = request[aes_method]
        else:
            aes_method = request
            aes_params = None

        aes_request = self.get_aes_request(aes_method, aes_params)
        response_data = await self.send_secure_passthrough(aes_request)

        _LOGGER.debug(
            "%s << %s",
            self.host,
            _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
        )

        return response_data

    async def close(self) -> None:
        """Close the protocol."""
        # Detach the client first so it is never closed twice.
        client = self.http_client
        self.http_client = None
        if client:
            await client.aclose()
|
|
|
|
|
|
class AesEncyptionSession:
    """AES-CBC cipher session with PKCS7 padding and base64 transport encoding."""

    @staticmethod
    def create_from_keypair(handshake_key: str, keypair):
        """Build a session from the handshake key returned by the device.

        *handshake_key* is base64 text that is decrypted with the private key
        of *keypair* (RSA, PKCS1 v1.5 padding).  The first 16 bytes of the
        decrypted material become the AES key and the remainder the IV.
        """
        encrypted_material: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
        der_private_key = base64.b64decode(keypair.get_private_key().encode("UTF-8"))

        rsa_private_key = serialization.load_der_private_key(
            der_private_key, None, None
        )
        material = rsa_private_key.decrypt(
            encrypted_material, asymmetric_padding.PKCS1v15()
        )
        if material is None:
            raise ValueError("Decryption failed!")

        aes_key, aes_iv = material[:16], material[16:]
        return AesEncyptionSession(aes_key, aes_iv)

    def __init__(self, key, iv):
        """Store the cipher and padding scheme for *key* and *iv*."""
        self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
        self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)

    def encrypt(self, data) -> bytes:
        """Pad and encrypt *data*, returning the ciphertext base64-encoded."""
        pad = self.padding_strategy.padder()
        padded = pad.update(data) + pad.finalize()
        enc = self.cipher.encryptor()
        ciphertext = enc.update(padded) + enc.finalize()
        return base64.b64encode(ciphertext)

    def decrypt(self, data) -> str:
        """Decode base64 *data*, decrypt and unpad it, and return it as text."""
        ciphertext = base64.b64decode(data)
        dec = self.cipher.decryptor()
        padded = dec.update(ciphertext) + dec.finalize()
        unpad = self.padding_strategy.unpadder()
        plaintext = unpad.update(padded) + unpad.finalize()
        return plaintext.decode()
|
|
|
|
|
|
class KeyPair:
    """Container for a private/public key pair held as base64 strings."""

    @staticmethod
    def create_key_pair(key_size: int = 1024):
        """Generate an RSA key pair of *key_size* bits.

        Both keys are serialised to DER (PKCS8 for the private key,
        SubjectPublicKeyInfo for the public key) and stored base64-encoded.
        """
        generated = rsa.generate_private_key(public_exponent=65537, key_size=key_size)

        der_private = generated.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
        der_public = generated.public_key().public_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )

        encoded_private = base64.b64encode(der_private).decode("UTF-8")
        encoded_public = base64.b64encode(der_public).decode("UTF-8")
        return KeyPair(private_key=encoded_private, public_key=encoded_public)

    def __init__(self, private_key: str, public_key: str):
        """Store the two key strings as given."""
        self.private_key, self.public_key = private_key, public_key

    def get_private_key(self) -> str:
        """Return the stored private key string."""
        return self.private_key

    def get_public_key(self) -> str:
        """Return the stored public key string."""
        return self.public_key
|
|
|
|
|
|
class SnowflakeId:
    """Generator of snowflake ids.

    An id packs, from the most to the least significant bits: milliseconds
    since ``EPOCH``, the data center id, the worker id and a per-millisecond
    sequence number.
    """

    EPOCH = 1420041600000  # Custom epoch (in milliseconds)
    WORKER_ID_BITS = 5
    DATA_CENTER_ID_BITS = 5
    SEQUENCE_BITS = 12

    MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
    MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1

    SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1

    def __init__(self, worker_id, data_center_id):
        """Validate and store the worker and data center ids.

        :raises ValueError: If either id is negative or above its maximum.
        """
        if worker_id < 0 or worker_id > SnowflakeId.MAX_WORKER_ID:
            raise ValueError(
                f"Worker ID can't be greater than {SnowflakeId.MAX_WORKER_ID}"
                " or less than 0"
            )
        if data_center_id < 0 or data_center_id > SnowflakeId.MAX_DATA_CENTER_ID:
            raise ValueError(
                "Data center ID can't be greater than "
                f"{SnowflakeId.MAX_DATA_CENTER_ID} or less than 0"
            )

        self.worker_id = worker_id
        self.data_center_id = data_center_id
        self.sequence = 0
        self.last_timestamp = -1

    def generate_id(self):
        """Generate a snowflake id.

        :raises ValueError: If the clock reports a time earlier than the one
            used for the previous id.
        """
        now = self._current_millis()

        if now < self.last_timestamp:
            raise ValueError("Clock moved backwards. Refusing to generate ID.")

        if now != self.last_timestamp:
            # First id of a new millisecond starts the sequence again.
            self.sequence = 0
        else:
            # Another id within the same millisecond: bump the sequence.
            self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
            if self.sequence == 0:
                # The sequence wrapped around, so move on to the next
                # millisecond before handing out another id.
                now = self._wait_next_millis(self.last_timestamp)

        self.last_timestamp = now

        worker_shift = SnowflakeId.SEQUENCE_BITS
        data_center_shift = SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS
        timestamp_shift = (
            SnowflakeId.WORKER_ID_BITS
            + SnowflakeId.SEQUENCE_BITS
            + SnowflakeId.DATA_CENTER_ID_BITS
        )

        snowflake = (now - SnowflakeId.EPOCH) << timestamp_shift
        snowflake |= self.data_center_id << data_center_shift
        snowflake |= self.worker_id << worker_shift
        snowflake |= self.sequence
        return snowflake

    def _current_millis(self):
        """Return the current time in whole milliseconds."""
        return round(time.time() * 1000)

    def _wait_next_millis(self, last_timestamp):
        """Spin until the clock is past *last_timestamp* and return that time."""
        current = self._current_millis()
        while current <= last_timestamp:
            current = self._current_millis()
        return current
|