mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-24 05:37:59 +00:00
Add support for the protocol used by TAPO devices and some newer KASA devices. (#552)
* Add Tapo protocol support * Update get_device_instance and test_unsupported following review
This commit is contained in:
parent
9de3f69033
commit
63d64ad920
498
kasa/aesprotocol.py
Normal file
498
kasa/aesprotocol.py
Normal file
@ -0,0 +1,498 @@
|
|||||||
|
"""Implementation of the TP-Link AES Protocol.
|
||||||
|
|
||||||
|
Based on the work of https://github.com/petretiandrea/plugp100
|
||||||
|
under compatible GNU GPL3 license.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from pprint import pformat as pf
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from cryptography.hazmat.primitives import hashes, padding, serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
|
||||||
|
from .credentials import Credentials
|
||||||
|
from .exceptions import AuthenticationException, SmartDeviceException
|
||||||
|
from .json import dumps as json_dumps
|
||||||
|
from .json import loads as json_loads
|
||||||
|
from .protocol import TPLinkProtocol
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
logging.getLogger("httpx").propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
def _md5(payload: bytes) -> bytes:
|
||||||
|
digest = hashes.Hash(hashes.MD5()) # noqa: S303
|
||||||
|
digest.update(payload)
|
||||||
|
hash = digest.finalize()
|
||||||
|
return hash
|
||||||
|
|
||||||
|
|
||||||
|
def _sha1(payload: bytes) -> str:
|
||||||
|
sha1_algo = hashlib.sha1() # noqa: S324
|
||||||
|
sha1_algo.update(payload)
|
||||||
|
return sha1_algo.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class TPLinkAes(TPLinkProtocol):
|
||||||
|
"""Implementation of the AES encryption protocol.
|
||||||
|
|
||||||
|
AES is the name used in device discovery for TP-Link's TAPO encryption
|
||||||
|
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_PORT = 80
|
||||||
|
DEFAULT_TIMEOUT = 5
|
||||||
|
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||||
|
COMMON_HEADERS = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"requestByApp": "true",
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
*,
|
||||||
|
credentials: Optional[Credentials] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||||
|
|
||||||
|
self.credentials = (
|
||||||
|
credentials
|
||||||
|
if credentials and credentials.username and credentials.password
|
||||||
|
else Credentials(username="", password="")
|
||||||
|
)
|
||||||
|
|
||||||
|
self._local_seed: Optional[bytes] = None
|
||||||
|
self.local_auth_hash = self.generate_auth_hash(self.credentials)
|
||||||
|
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
|
||||||
|
self.kasa_setup_auth_hash = None
|
||||||
|
self.blank_auth_hash = None
|
||||||
|
self.handshake_lock = asyncio.Lock()
|
||||||
|
self.query_lock = asyncio.Lock()
|
||||||
|
self.handshake_done = False
|
||||||
|
|
||||||
|
self.encryption_session: Optional[AesEncyptionSession] = None
|
||||||
|
self.session_expire_at: Optional[float] = None
|
||||||
|
|
||||||
|
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||||
|
self.session_cookie = None
|
||||||
|
self.terminal_uuid = None
|
||||||
|
self.http_client: Optional[httpx.AsyncClient] = None
|
||||||
|
self.request_id_generator = SnowflakeId(1, 1)
|
||||||
|
self.login_token = None
|
||||||
|
|
||||||
|
_LOGGER.debug("Created AES object for %s", self.host)
|
||||||
|
|
||||||
|
def hash_credentials(self, credentials, try_login_version2):
|
||||||
|
"""Hash the credentials."""
|
||||||
|
if try_login_version2:
|
||||||
|
un = base64.b64encode(
|
||||||
|
_sha1(credentials.username.encode()).encode()
|
||||||
|
).decode()
|
||||||
|
pw = base64.b64encode(
|
||||||
|
_sha1(credentials.password.encode()).encode()
|
||||||
|
).decode()
|
||||||
|
else:
|
||||||
|
un = base64.b64encode(
|
||||||
|
_sha1(credentials.username.encode()).encode()
|
||||||
|
).decode()
|
||||||
|
pw = base64.b64encode(credentials.password.encode()).decode()
|
||||||
|
return un, pw
|
||||||
|
|
||||||
|
async def client_post(self, url, params=None, data=None, json=None, headers=None):
|
||||||
|
"""Send an http post request to the device."""
|
||||||
|
response_data = None
|
||||||
|
cookies = None
|
||||||
|
if self.session_cookie:
|
||||||
|
cookies = httpx.Cookies()
|
||||||
|
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
|
||||||
|
self.http_client.cookies.clear()
|
||||||
|
resp = await self.http_client.post(
|
||||||
|
url,
|
||||||
|
params=params,
|
||||||
|
data=data,
|
||||||
|
json=json,
|
||||||
|
timeout=self.timeout,
|
||||||
|
cookies=cookies,
|
||||||
|
headers=self.COMMON_HEADERS,
|
||||||
|
)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
response_data = resp.json()
|
||||||
|
|
||||||
|
return resp.status_code, response_data
|
||||||
|
|
||||||
|
async def send_secure_passthrough(self, request):
|
||||||
|
"""Send encrypted message as passthrough."""
|
||||||
|
url = f"http://{self.host}/app"
|
||||||
|
if self.login_token:
|
||||||
|
url += f"?token={self.login_token}"
|
||||||
|
raw_request = json_dumps(request)
|
||||||
|
encrypted_payload = self.encryption_session.encrypt(raw_request.encode())
|
||||||
|
passthrough_request = {
|
||||||
|
"method": "securePassthrough",
|
||||||
|
"params": {"request": encrypted_payload.decode()},
|
||||||
|
}
|
||||||
|
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
||||||
|
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||||
|
response = self.encryption_session.decrypt(
|
||||||
|
resp_dict["result"]["response"].encode()
|
||||||
|
)
|
||||||
|
resp_dict = json_loads(response)
|
||||||
|
if resp_dict["error_code"] != 0:
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Could not complete send, response was {resp_dict}",
|
||||||
|
)
|
||||||
|
if "result" in resp_dict:
|
||||||
|
return resp_dict["result"]
|
||||||
|
else:
|
||||||
|
raise AuthenticationException("Could not complete send")
|
||||||
|
|
||||||
|
def get_aes_request(self, method, params=None):
|
||||||
|
"""Get a request message."""
|
||||||
|
request = {
|
||||||
|
"method": method,
|
||||||
|
"params": params,
|
||||||
|
"requestID": self.request_id_generator.generate_id(),
|
||||||
|
"request_time_milis": round(time.time() * 1000),
|
||||||
|
"terminal_uuid": self.terminal_uuid,
|
||||||
|
}
|
||||||
|
return request
|
||||||
|
|
||||||
|
async def perform_login(self, login_v2):
|
||||||
|
"""Login to the device."""
|
||||||
|
self.login_token = None
|
||||||
|
|
||||||
|
un, pw = self.hash_credentials(self.credentials, login_v2)
|
||||||
|
params = {"password": pw, "username": un}
|
||||||
|
request = self.get_aes_request("login_device", params)
|
||||||
|
try:
|
||||||
|
result = await self.send_secure_passthrough(request)
|
||||||
|
except SmartDeviceException as ex:
|
||||||
|
raise AuthenticationException(ex) from ex
|
||||||
|
self.login_token = result["token"]
|
||||||
|
|
||||||
|
async def perform_handshake(self):
|
||||||
|
"""Perform the handshake."""
|
||||||
|
_LOGGER.debug("Will perform handshaking...")
|
||||||
|
_LOGGER.debug("Generating keypair")
|
||||||
|
|
||||||
|
self.handshake_done = False
|
||||||
|
self.session_expire_at = None
|
||||||
|
self.session_cookie = None
|
||||||
|
|
||||||
|
url = f"http://{self.host}/app"
|
||||||
|
key_pair = KeyPair.create_key_pair()
|
||||||
|
|
||||||
|
pub_key = (
|
||||||
|
"-----BEGIN PUBLIC KEY-----\n"
|
||||||
|
+ key_pair.get_public_key()
|
||||||
|
+ "\n-----END PUBLIC KEY-----\n"
|
||||||
|
)
|
||||||
|
handshake_params = {"key": pub_key}
|
||||||
|
_LOGGER.debug(f"Handshake params: {handshake_params}")
|
||||||
|
|
||||||
|
request_body = {"method": "handshake", "params": handshake_params}
|
||||||
|
|
||||||
|
_LOGGER.debug(f"Request {request_body}")
|
||||||
|
|
||||||
|
status_code, resp_dict = await self.client_post(url, json=request_body)
|
||||||
|
|
||||||
|
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
||||||
|
|
||||||
|
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||||
|
_LOGGER.debug("Decoding handshake key...")
|
||||||
|
handshake_key = resp_dict["result"]["key"]
|
||||||
|
|
||||||
|
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
||||||
|
self.SESSION_COOKIE_NAME
|
||||||
|
)
|
||||||
|
if not self.session_cookie:
|
||||||
|
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
||||||
|
"SESSIONID"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session_expire_at = time.time() + 86400
|
||||||
|
self.encryption_session = AesEncyptionSession.create_from_keypair(
|
||||||
|
handshake_key, key_pair
|
||||||
|
)
|
||||||
|
|
||||||
|
self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode(
|
||||||
|
"UTF-8"
|
||||||
|
)
|
||||||
|
self.handshake_done = True
|
||||||
|
|
||||||
|
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise AuthenticationException("Could not complete handshake")
|
||||||
|
|
||||||
|
def handshake_session_expired(self):
|
||||||
|
"""Return true if session has expired."""
|
||||||
|
return (
|
||||||
|
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_auth_hash(creds: Credentials):
|
||||||
|
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||||
|
un = creds.username or ""
|
||||||
|
pw = creds.password or ""
|
||||||
|
return _md5(_md5(un.encode()) + _md5(pw.encode()))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_owner_hash(creds: Credentials):
|
||||||
|
"""Return the MD5 hash of the username in this object."""
|
||||||
|
un = creds.username or ""
|
||||||
|
return _md5(un.encode())
|
||||||
|
|
||||||
|
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
|
"""Query the device retrying for retry_count on failure."""
|
||||||
|
async with self.query_lock:
|
||||||
|
return await self._query(request, retry_count)
|
||||||
|
|
||||||
|
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
|
for retry in range(retry_count + 1):
|
||||||
|
try:
|
||||||
|
return await self._execute_query(request, retry)
|
||||||
|
except httpx.CloseError as sdex:
|
||||||
|
await self.close()
|
||||||
|
if retry >= retry_count:
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||||
|
) from sdex
|
||||||
|
continue
|
||||||
|
except httpx.ConnectError as cex:
|
||||||
|
await self.close()
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device: {self.host}: {cex}"
|
||||||
|
) from cex
|
||||||
|
except TimeoutError as tex:
|
||||||
|
await self.close()
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
||||||
|
) from tex
|
||||||
|
except AuthenticationException as auex:
|
||||||
|
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||||
|
raise auex
|
||||||
|
except Exception as ex:
|
||||||
|
await self.close()
|
||||||
|
if retry >= retry_count:
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device: {self.host}: {ex}"
|
||||||
|
) from ex
|
||||||
|
continue
|
||||||
|
|
||||||
|
# make mypy happy, this should never be reached..
|
||||||
|
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||||
|
|
||||||
|
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"%s >> %s",
|
||||||
|
self.host,
|
||||||
|
_LOGGER.isEnabledFor(logging.DEBUG) and pf(request),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.http_client:
|
||||||
|
self.http_client = httpx.AsyncClient()
|
||||||
|
|
||||||
|
if not self.handshake_done or self.handshake_session_expired():
|
||||||
|
try:
|
||||||
|
await self.perform_handshake()
|
||||||
|
await self.perform_login(False)
|
||||||
|
except AuthenticationException:
|
||||||
|
await self.perform_handshake()
|
||||||
|
await self.perform_login(True)
|
||||||
|
|
||||||
|
if isinstance(request, dict):
|
||||||
|
aes_method = next(iter(request))
|
||||||
|
aes_params = request[aes_method]
|
||||||
|
else:
|
||||||
|
aes_method = request
|
||||||
|
aes_params = None
|
||||||
|
|
||||||
|
aes_request = self.get_aes_request(aes_method, aes_params)
|
||||||
|
response_data = await self.send_secure_passthrough(aes_request)
|
||||||
|
|
||||||
|
_LOGGER.debug(
|
||||||
|
"%s << %s",
|
||||||
|
self.host,
|
||||||
|
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_data
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the protocol."""
|
||||||
|
client = self.http_client
|
||||||
|
self.http_client = None
|
||||||
|
if client:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class AesEncyptionSession:
|
||||||
|
"""Class for an AES encryption session."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_from_keypair(handshake_key: str, keypair):
|
||||||
|
"""Create the encryption session."""
|
||||||
|
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
|
||||||
|
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
|
||||||
|
|
||||||
|
private_key = serialization.load_der_private_key(private_key_data, None, None)
|
||||||
|
key_and_iv = private_key.decrypt(
|
||||||
|
handshake_key_bytes, asymmetric_padding.PKCS1v15()
|
||||||
|
)
|
||||||
|
if key_and_iv is None:
|
||||||
|
raise ValueError("Decryption failed!")
|
||||||
|
|
||||||
|
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
|
||||||
|
|
||||||
|
def __init__(self, key, iv):
|
||||||
|
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
||||||
|
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
|
||||||
|
|
||||||
|
def encrypt(self, data) -> bytes:
|
||||||
|
"""Encrypt the message."""
|
||||||
|
encryptor = self.cipher.encryptor()
|
||||||
|
padder = self.padding_strategy.padder()
|
||||||
|
padded_data = padder.update(data) + padder.finalize()
|
||||||
|
encrypted = encryptor.update(padded_data) + encryptor.finalize()
|
||||||
|
return base64.b64encode(encrypted)
|
||||||
|
|
||||||
|
def decrypt(self, data) -> str:
|
||||||
|
"""Decrypt the message."""
|
||||||
|
decryptor = self.cipher.decryptor()
|
||||||
|
unpadder = self.padding_strategy.unpadder()
|
||||||
|
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
|
||||||
|
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
|
||||||
|
return unpadded_data.decode()
|
||||||
|
|
||||||
|
|
||||||
|
class KeyPair:
|
||||||
|
"""Class for generating key pairs."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_key_pair(key_size: int = 1024):
|
||||||
|
"""Create a key pair."""
|
||||||
|
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
||||||
|
public_key = private_key.public_key()
|
||||||
|
|
||||||
|
private_key_bytes = private_key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.DER,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
public_key_bytes = public_key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.DER,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
return KeyPair(
|
||||||
|
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
|
||||||
|
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, private_key: str, public_key: str):
|
||||||
|
self.private_key = private_key
|
||||||
|
self.public_key = public_key
|
||||||
|
|
||||||
|
def get_private_key(self) -> str:
|
||||||
|
"""Get the private key."""
|
||||||
|
return self.private_key
|
||||||
|
|
||||||
|
def get_public_key(self) -> str:
|
||||||
|
"""Get the public key."""
|
||||||
|
return self.public_key
|
||||||
|
|
||||||
|
|
||||||
|
class SnowflakeId:
|
||||||
|
"""Class for generating snowflake ids."""
|
||||||
|
|
||||||
|
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
|
||||||
|
WORKER_ID_BITS = 5
|
||||||
|
DATA_CENTER_ID_BITS = 5
|
||||||
|
SEQUENCE_BITS = 12
|
||||||
|
|
||||||
|
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
|
||||||
|
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
|
||||||
|
|
||||||
|
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
|
||||||
|
|
||||||
|
def __init__(self, worker_id, data_center_id):
|
||||||
|
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Worker ID can't be greater than "
|
||||||
|
+ str(SnowflakeId.MAX_WORKER_ID)
|
||||||
|
+ " or less than 0"
|
||||||
|
)
|
||||||
|
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Data center ID can't be greater than "
|
||||||
|
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
|
||||||
|
+ " or less than 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.worker_id = worker_id
|
||||||
|
self.data_center_id = data_center_id
|
||||||
|
self.sequence = 0
|
||||||
|
self.last_timestamp = -1
|
||||||
|
|
||||||
|
def generate_id(self):
|
||||||
|
"""Generate a snowflake id."""
|
||||||
|
timestamp = self._current_millis()
|
||||||
|
|
||||||
|
if timestamp < self.last_timestamp:
|
||||||
|
raise ValueError("Clock moved backwards. Refusing to generate ID.")
|
||||||
|
|
||||||
|
if timestamp == self.last_timestamp:
|
||||||
|
# Within the same millisecond, increment the sequence number
|
||||||
|
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
|
||||||
|
if self.sequence == 0:
|
||||||
|
# Sequence exceeds its bit range, wait until the next millisecond
|
||||||
|
timestamp = self._wait_next_millis(self.last_timestamp)
|
||||||
|
else:
|
||||||
|
# New millisecond, reset the sequence number
|
||||||
|
self.sequence = 0
|
||||||
|
|
||||||
|
# Update the last timestamp
|
||||||
|
self.last_timestamp = timestamp
|
||||||
|
|
||||||
|
# Generate and return the final ID
|
||||||
|
return (
|
||||||
|
(
|
||||||
|
(timestamp - SnowflakeId.EPOCH)
|
||||||
|
<< (
|
||||||
|
SnowflakeId.WORKER_ID_BITS
|
||||||
|
+ SnowflakeId.SEQUENCE_BITS
|
||||||
|
+ SnowflakeId.DATA_CENTER_ID_BITS
|
||||||
|
)
|
||||||
|
)
|
||||||
|
| (
|
||||||
|
self.data_center_id
|
||||||
|
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
|
||||||
|
)
|
||||||
|
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
|
||||||
|
| self.sequence
|
||||||
|
)
|
||||||
|
|
||||||
|
def _current_millis(self):
|
||||||
|
return round(time.time() * 1000)
|
||||||
|
|
||||||
|
def _wait_next_millis(self, last_timestamp):
|
||||||
|
timestamp = self._current_millis()
|
||||||
|
while timestamp <= last_timestamp:
|
||||||
|
timestamp = self._current_millis()
|
||||||
|
return timestamp
|
@ -14,6 +14,7 @@ from .smartdimmer import SmartDimmer
|
|||||||
from .smartlightstrip import SmartLightStrip
|
from .smartlightstrip import SmartLightStrip
|
||||||
from .smartplug import SmartPlug
|
from .smartplug import SmartPlug
|
||||||
from .smartstrip import SmartStrip
|
from .smartstrip import SmartStrip
|
||||||
|
from .tapo.tapoplug import TapoPlug
|
||||||
|
|
||||||
DEVICE_TYPE_TO_CLASS = {
|
DEVICE_TYPE_TO_CLASS = {
|
||||||
DeviceType.Plug: SmartPlug,
|
DeviceType.Plug: SmartPlug,
|
||||||
@ -21,6 +22,7 @@ DEVICE_TYPE_TO_CLASS = {
|
|||||||
DeviceType.Strip: SmartStrip,
|
DeviceType.Strip: SmartStrip,
|
||||||
DeviceType.Dimmer: SmartDimmer,
|
DeviceType.Dimmer: SmartDimmer,
|
||||||
DeviceType.LightStrip: SmartLightStrip,
|
DeviceType.LightStrip: SmartLightStrip,
|
||||||
|
DeviceType.TapoPlug: TapoPlug,
|
||||||
}
|
}
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -14,6 +14,7 @@ class DeviceType(Enum):
|
|||||||
StripSocket = "stripsocket"
|
StripSocket = "stripsocket"
|
||||||
Dimmer = "dimmer"
|
Dimmer = "dimmer"
|
||||||
LightStrip = "lightstrip"
|
LightStrip = "lightstrip"
|
||||||
|
TapoPlug = "tapoplug"
|
||||||
Unknown = "unknown"
|
Unknown = "unknown"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -15,14 +15,16 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from kasa.aesprotocol import TPLinkAes
|
||||||
from kasa.credentials import Credentials
|
from kasa.credentials import Credentials
|
||||||
from kasa.exceptions import UnsupportedDeviceException
|
from kasa.exceptions import UnsupportedDeviceException
|
||||||
from kasa.json import dumps as json_dumps
|
from kasa.json import dumps as json_dumps
|
||||||
from kasa.json import loads as json_loads
|
from kasa.json import loads as json_loads
|
||||||
from kasa.klapprotocol import TPLinkKlap
|
from kasa.klapprotocol import TPLinkKlap
|
||||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||||
from kasa.smartdevice import SmartDevice, SmartDeviceException
|
from kasa.smartdevice import SmartDevice, SmartDeviceException
|
||||||
from kasa.smartplug import SmartPlug
|
from kasa.smartplug import SmartPlug
|
||||||
|
from kasa.tapo.tapoplug import TapoPlug
|
||||||
|
|
||||||
from .device_factory import get_device_class_from_info
|
from .device_factory import get_device_class_from_info
|
||||||
|
|
||||||
@ -378,27 +380,38 @@ class Discover:
|
|||||||
f"Unable to read response from device: {ip}: {ex}"
|
f"Unable to read response from device: {ip}: {ex}"
|
||||||
) from ex
|
) from ex
|
||||||
|
|
||||||
if (
|
type_ = discovery_result.device_type
|
||||||
discovery_result.mgt_encrypt_schm.encrypt_type == "KLAP"
|
encrypt_type_ = (
|
||||||
and discovery_result.mgt_encrypt_schm.lv is None
|
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
|
||||||
):
|
)
|
||||||
type_ = discovery_result.device_type
|
device_class = None
|
||||||
device_class = None
|
|
||||||
if type_.upper() == "IOT.SMARTPLUGSWITCH":
|
|
||||||
device_class = SmartPlug
|
|
||||||
|
|
||||||
if device_class:
|
supported_device_types: dict[str, Type[SmartDevice]] = {
|
||||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
"SMART.TAPOPLUG": TapoPlug,
|
||||||
device = device_class(ip, port=port, credentials=credentials)
|
"SMART.KASAPLUG": TapoPlug,
|
||||||
device.update_from_discover_info(discovery_result.get_dict())
|
"IOT.SMARTPLUGSWITCH": SmartPlug,
|
||||||
device.protocol = TPLinkKlap(ip, credentials=credentials)
|
}
|
||||||
return device
|
supported_device_protocols: dict[str, Type[TPLinkProtocol]] = {
|
||||||
else:
|
"IOT.KLAP": TPLinkKlap,
|
||||||
raise UnsupportedDeviceException(
|
"SMART.AES": TPLinkAes,
|
||||||
f"Unsupported device {ip} of type {type_}: {info}"
|
}
|
||||||
)
|
|
||||||
else:
|
if (device_class := supported_device_types.get(type_)) is None:
|
||||||
raise UnsupportedDeviceException(f"Unsupported device {ip}: {info}")
|
_LOGGER.warning("Got unsupported device type: %s", type_)
|
||||||
|
raise UnsupportedDeviceException(
|
||||||
|
f"Unsupported device {ip} of type {type_}: {info}"
|
||||||
|
)
|
||||||
|
if (protocol_class := supported_device_protocols.get(encrypt_type_)) is None:
|
||||||
|
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
|
||||||
|
raise UnsupportedDeviceException(
|
||||||
|
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||||
|
device = device_class(ip, port=port, credentials=credentials)
|
||||||
|
device.protocol = protocol_class(ip, credentials=credentials)
|
||||||
|
device.update_from_discover_info(discovery_result.get_dict())
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
class DiscoveryResult(BaseModel):
|
class DiscoveryResult(BaseModel):
|
||||||
@ -415,7 +428,7 @@ class DiscoveryResult(BaseModel):
|
|||||||
is_support_https: Optional[bool] = None
|
is_support_https: Optional[bool] = None
|
||||||
encrypt_type: Optional[str] = None
|
encrypt_type: Optional[str] = None
|
||||||
http_port: Optional[int] = None
|
http_port: Optional[int] = None
|
||||||
lv: Optional[int] = None
|
lv: Optional[int] = 1
|
||||||
|
|
||||||
device_type: str = Field(alias="device_type_text")
|
device_type: str = Field(alias="device_type_text")
|
||||||
device_model: str = Field(alias="model")
|
device_model: str = Field(alias="model")
|
||||||
|
164
kasa/tapo/tapodevice.py
Normal file
164
kasa/tapo/tapodevice.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
"""Module for a TAPO device."""
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Dict, Optional, Set, cast
|
||||||
|
|
||||||
|
from ..aesprotocol import TPLinkAes
|
||||||
|
from ..credentials import Credentials
|
||||||
|
from ..exceptions import AuthenticationException
|
||||||
|
from ..smartdevice import SmartDevice
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TapoDevice(SmartDevice):
|
||||||
|
"""Base class to represent a TAPO device."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
*,
|
||||||
|
port: Optional[int] = None,
|
||||||
|
credentials: Optional[Credentials] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||||
|
self._state_information: Dict[str, Any] = {}
|
||||||
|
self._discovery_info: Optional[Dict[str, Any]] = None
|
||||||
|
self.protocol = TPLinkAes(host, credentials=credentials, timeout=timeout)
|
||||||
|
|
||||||
|
async def update(self, update_children: bool = True):
|
||||||
|
"""Update the device."""
|
||||||
|
if self.credentials is None or self.credentials.username is None:
|
||||||
|
raise AuthenticationException("Tapo plug requires authentication.")
|
||||||
|
|
||||||
|
self._info = await self.protocol.query("get_device_info")
|
||||||
|
self._usage = await self.protocol.query("get_device_usage")
|
||||||
|
self._time = await self.protocol.query("get_device_time")
|
||||||
|
|
||||||
|
self._last_update = self._data = {
|
||||||
|
"info": self._info,
|
||||||
|
"usage": self._usage,
|
||||||
|
"time": self._time,
|
||||||
|
}
|
||||||
|
|
||||||
|
_LOGGER.debug("Got an update: %s", self._data)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sys_info(self) -> Dict[str, Any]:
|
||||||
|
"""Returns the device info."""
|
||||||
|
return self._info
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> str:
|
||||||
|
"""Returns the device model."""
|
||||||
|
return str(self._info.get("model"))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def alias(self) -> str:
|
||||||
|
"""Returns the device alias or nickname."""
|
||||||
|
return base64.b64decode(str(self._info.get("nickname"))).decode()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def time(self) -> datetime:
|
||||||
|
"""Return the time."""
|
||||||
|
td = timedelta(minutes=cast(float, self._time.get("time_diff")))
|
||||||
|
if self._time.get("region"):
|
||||||
|
tz = timezone(td, str(self._time.get("region")))
|
||||||
|
else:
|
||||||
|
# in case the device returns a blank region this will result in the
|
||||||
|
# tzname being a UTC offset
|
||||||
|
tz = timezone(td)
|
||||||
|
return datetime.fromtimestamp(
|
||||||
|
cast(float, self._time.get("timestamp")),
|
||||||
|
tz=tz,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timezone(self) -> Dict:
|
||||||
|
"""Return the timezone and time_difference."""
|
||||||
|
ti = self.time
|
||||||
|
return {"timezone": ti.tzname()}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hw_info(self) -> Dict:
|
||||||
|
"""Return hardware info for the device."""
|
||||||
|
return {
|
||||||
|
"sw_ver": self._info.get("fw_ver"),
|
||||||
|
"hw_ver": self._info.get("hw_ver"),
|
||||||
|
"mac": self._info.get("mac"),
|
||||||
|
"type": self._info.get("type"),
|
||||||
|
"hwId": self._info.get("device_id"),
|
||||||
|
"dev_name": self.alias,
|
||||||
|
"oemId": self._info.get("oem_id"),
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def location(self) -> Dict:
|
||||||
|
"""Return the device location."""
|
||||||
|
loc = {
|
||||||
|
"latitude": cast(float, self._info.get("latitude")) / 10_000,
|
||||||
|
"longitude": cast(float, self._info.get("longitude")) / 10_000,
|
||||||
|
}
|
||||||
|
return loc
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rssi(self) -> Optional[int]:
|
||||||
|
"""Return the rssi."""
|
||||||
|
rssi = self._info.get("rssi")
|
||||||
|
return int(rssi) if rssi else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mac(self) -> str:
|
||||||
|
"""Return the mac formatted with colons."""
|
||||||
|
return str(self._info.get("mac")).replace("-", ":")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device_id(self) -> str:
|
||||||
|
"""Return the device id."""
|
||||||
|
return str(self._info.get("device_id"))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def internal_state(self) -> Any:
|
||||||
|
"""Return all the internal state data."""
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
async def _query_helper(
|
||||||
|
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
|
||||||
|
) -> Any:
|
||||||
|
res = await self.protocol.query({cmd: arg})
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_information(self) -> Dict[str, Any]:
|
||||||
|
"""Return the key state information."""
|
||||||
|
return {
|
||||||
|
"overheated": self._info.get("overheated"),
|
||||||
|
"signal_level": self._info.get("signal_level"),
|
||||||
|
"SSID": base64.b64decode(str(self._info.get("ssid"))).decode(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def features(self) -> Set[str]:
|
||||||
|
"""Return the list of supported features."""
|
||||||
|
# TODO:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_on(self) -> bool:
|
||||||
|
"""Return true if the device is on."""
|
||||||
|
return bool(self._info.get("device_on"))
|
||||||
|
|
||||||
|
async def turn_on(self, **kwargs):
|
||||||
|
"""Turn on the device."""
|
||||||
|
await self.protocol.query({"set_device_info": {"device_on": True}})
|
||||||
|
|
||||||
|
async def turn_off(self, **kwargs):
|
||||||
|
"""Turn off the device."""
|
||||||
|
await self.protocol.query({"set_device_info": {"device_on": False}})
|
||||||
|
|
||||||
|
def update_from_discover_info(self, info):
|
||||||
|
"""Update state from info from the discover call."""
|
||||||
|
self._discovery_info = info
|
73
kasa/tapo/tapoplug.py
Normal file
73
kasa/tapo/tapoplug.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
"""Module for a TAPO Plug."""
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, Optional, cast
|
||||||
|
|
||||||
|
from ..credentials import Credentials
|
||||||
|
from ..emeterstatus import EmeterStatus
|
||||||
|
from ..smartdevice import DeviceType
|
||||||
|
from .tapodevice import TapoDevice
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TapoPlug(TapoDevice):
|
||||||
|
"""Class to represent a TAPO Plug."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
*,
|
||||||
|
port: Optional[int] = None,
|
||||||
|
credentials: Optional[Credentials] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||||
|
self._device_type = DeviceType.Plug
|
||||||
|
|
||||||
|
async def update(self, update_children: bool = True):
|
||||||
|
"""Call the device endpoint and update the device data."""
|
||||||
|
await super().update(update_children)
|
||||||
|
|
||||||
|
self._energy = await self.protocol.query("get_energy_usage")
|
||||||
|
self._emeter = await self.protocol.query("get_current_power")
|
||||||
|
|
||||||
|
self._data["energy"] = self._energy
|
||||||
|
self._data["emeter"] = self._emeter
|
||||||
|
|
||||||
|
_LOGGER.debug("Got an update: %s %s", self._energy, self._emeter)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_information(self) -> Dict[str, Any]:
|
||||||
|
"""Return the key state information."""
|
||||||
|
return {
|
||||||
|
**super().state_information,
|
||||||
|
**{
|
||||||
|
"On since": self.on_since,
|
||||||
|
"auto_off_status": self._info.get("auto_off_status"),
|
||||||
|
"auto_off_remain_time": self._info.get("auto_off_remain_time"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def emeter_realtime(self) -> EmeterStatus:
|
||||||
|
"""Get the emeter status."""
|
||||||
|
return EmeterStatus({"power_mw": self._energy.get("current_power")})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def emeter_today(self) -> Optional[float]:
|
||||||
|
"""Get the emeter value for today."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def emeter_this_month(self) -> Optional[float]:
|
||||||
|
"""Get the emeter value for this month."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def on_since(self) -> Optional[datetime]:
|
||||||
|
"""Return the time that the device was turned on or None if turned off."""
|
||||||
|
if not self._info.get("device_on"):
|
||||||
|
return None
|
||||||
|
on_time = cast(float, self._info.get("on_time"))
|
||||||
|
return datetime.now().replace(microsecond=0) - timedelta(seconds=on_time)
|
@ -114,7 +114,7 @@ UNSUPPORTED = {
|
|||||||
"result": {
|
"result": {
|
||||||
"device_id": "xx",
|
"device_id": "xx",
|
||||||
"owner": "xx",
|
"owner": "xx",
|
||||||
"device_type": "SMART.TAPOPLUG",
|
"device_type": "SMART.TAPOXMASTREE",
|
||||||
"device_model": "P110(EU)",
|
"device_model": "P110(EU)",
|
||||||
"ip": "127.0.0.1",
|
"ip": "127.0.0.1",
|
||||||
"mac": "48-22xxx",
|
"mac": "48-22xxx",
|
||||||
@ -150,7 +150,7 @@ async def test_discover_single_unsupported(mocker):
|
|||||||
discovery_data = UNSUPPORTED
|
discovery_data = UNSUPPORTED
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
UnsupportedDeviceException,
|
UnsupportedDeviceException,
|
||||||
match=f"Unsupported device {host}: {re.escape(str(UNSUPPORTED))}",
|
match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}",
|
||||||
):
|
):
|
||||||
await Discover.discover_single(host)
|
await Discover.discover_single(host)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user