mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Add support for the protocol used by TAPO devices and some newer KASA devices. (#552)
* Add Tapo protocol support * Update get_device_instance and test_unsupported following review
This commit is contained in:
parent
9de3f69033
commit
63d64ad920
498
kasa/aesprotocol.py
Normal file
498
kasa/aesprotocol.py
Normal file
@ -0,0 +1,498 @@
|
||||
"""Implementation of the TP-Link AES Protocol.
|
||||
|
||||
Based on the work of https://github.com/petretiandrea/plugp100
|
||||
under compatible GNU GPL3 license.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import hashes, padding, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .json import loads as json_loads
|
||||
from .protocol import TPLinkProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
|
||||
def _md5(payload: bytes) -> bytes:
|
||||
digest = hashes.Hash(hashes.MD5()) # noqa: S303
|
||||
digest.update(payload)
|
||||
hash = digest.finalize()
|
||||
return hash
|
||||
|
||||
|
||||
def _sha1(payload: bytes) -> str:
|
||||
sha1_algo = hashlib.sha1() # noqa: S324
|
||||
sha1_algo.update(payload)
|
||||
return sha1_algo.hexdigest()
|
||||
|
||||
|
||||
class TPLinkAes(TPLinkProtocol):
|
||||
"""Implementation of the AES encryption protocol.
|
||||
|
||||
AES is the name used in device discovery for TP-Link's TAPO encryption
|
||||
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
DEFAULT_TIMEOUT = 5
|
||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||
COMMON_HEADERS = {
|
||||
"Content-Type": "application/json",
|
||||
"requestByApp": "true",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||
|
||||
self.credentials = (
|
||||
credentials
|
||||
if credentials and credentials.username and credentials.password
|
||||
else Credentials(username="", password="")
|
||||
)
|
||||
|
||||
self._local_seed: Optional[bytes] = None
|
||||
self.local_auth_hash = self.generate_auth_hash(self.credentials)
|
||||
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
|
||||
self.kasa_setup_auth_hash = None
|
||||
self.blank_auth_hash = None
|
||||
self.handshake_lock = asyncio.Lock()
|
||||
self.query_lock = asyncio.Lock()
|
||||
self.handshake_done = False
|
||||
|
||||
self.encryption_session: Optional[AesEncyptionSession] = None
|
||||
self.session_expire_at: Optional[float] = None
|
||||
|
||||
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||
self.session_cookie = None
|
||||
self.terminal_uuid = None
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
self.request_id_generator = SnowflakeId(1, 1)
|
||||
self.login_token = None
|
||||
|
||||
_LOGGER.debug("Created AES object for %s", self.host)
|
||||
|
||||
def hash_credentials(self, credentials, try_login_version2):
|
||||
"""Hash the credentials."""
|
||||
if try_login_version2:
|
||||
un = base64.b64encode(
|
||||
_sha1(credentials.username.encode()).encode()
|
||||
).decode()
|
||||
pw = base64.b64encode(
|
||||
_sha1(credentials.password.encode()).encode()
|
||||
).decode()
|
||||
else:
|
||||
un = base64.b64encode(
|
||||
_sha1(credentials.username.encode()).encode()
|
||||
).decode()
|
||||
pw = base64.b64encode(credentials.password.encode()).decode()
|
||||
return un, pw
|
||||
|
||||
async def client_post(self, url, params=None, data=None, json=None, headers=None):
|
||||
"""Send an http post request to the device."""
|
||||
response_data = None
|
||||
cookies = None
|
||||
if self.session_cookie:
|
||||
cookies = httpx.Cookies()
|
||||
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
|
||||
self.http_client.cookies.clear()
|
||||
resp = await self.http_client.post(
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
json=json,
|
||||
timeout=self.timeout,
|
||||
cookies=cookies,
|
||||
headers=self.COMMON_HEADERS,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
response_data = resp.json()
|
||||
|
||||
return resp.status_code, response_data
|
||||
|
||||
async def send_secure_passthrough(self, request):
|
||||
"""Send encrypted message as passthrough."""
|
||||
url = f"http://{self.host}/app"
|
||||
if self.login_token:
|
||||
url += f"?token={self.login_token}"
|
||||
raw_request = json_dumps(request)
|
||||
encrypted_payload = self.encryption_session.encrypt(raw_request.encode())
|
||||
passthrough_request = {
|
||||
"method": "securePassthrough",
|
||||
"params": {"request": encrypted_payload.decode()},
|
||||
}
|
||||
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||
response = self.encryption_session.decrypt(
|
||||
resp_dict["result"]["response"].encode()
|
||||
)
|
||||
resp_dict = json_loads(response)
|
||||
if resp_dict["error_code"] != 0:
|
||||
raise SmartDeviceException(
|
||||
f"Could not complete send, response was {resp_dict}",
|
||||
)
|
||||
if "result" in resp_dict:
|
||||
return resp_dict["result"]
|
||||
else:
|
||||
raise AuthenticationException("Could not complete send")
|
||||
|
||||
def get_aes_request(self, method, params=None):
|
||||
"""Get a request message."""
|
||||
request = {
|
||||
"method": method,
|
||||
"params": params,
|
||||
"requestID": self.request_id_generator.generate_id(),
|
||||
"request_time_milis": round(time.time() * 1000),
|
||||
"terminal_uuid": self.terminal_uuid,
|
||||
}
|
||||
return request
|
||||
|
||||
async def perform_login(self, login_v2):
|
||||
"""Login to the device."""
|
||||
self.login_token = None
|
||||
|
||||
un, pw = self.hash_credentials(self.credentials, login_v2)
|
||||
params = {"password": pw, "username": un}
|
||||
request = self.get_aes_request("login_device", params)
|
||||
try:
|
||||
result = await self.send_secure_passthrough(request)
|
||||
except SmartDeviceException as ex:
|
||||
raise AuthenticationException(ex) from ex
|
||||
self.login_token = result["token"]
|
||||
|
||||
async def perform_handshake(self):
|
||||
"""Perform the handshake."""
|
||||
_LOGGER.debug("Will perform handshaking...")
|
||||
_LOGGER.debug("Generating keypair")
|
||||
|
||||
self.handshake_done = False
|
||||
self.session_expire_at = None
|
||||
self.session_cookie = None
|
||||
|
||||
url = f"http://{self.host}/app"
|
||||
key_pair = KeyPair.create_key_pair()
|
||||
|
||||
pub_key = (
|
||||
"-----BEGIN PUBLIC KEY-----\n"
|
||||
+ key_pair.get_public_key()
|
||||
+ "\n-----END PUBLIC KEY-----\n"
|
||||
)
|
||||
handshake_params = {"key": pub_key}
|
||||
_LOGGER.debug(f"Handshake params: {handshake_params}")
|
||||
|
||||
request_body = {"method": "handshake", "params": handshake_params}
|
||||
|
||||
_LOGGER.debug(f"Request {request_body}")
|
||||
|
||||
status_code, resp_dict = await self.client_post(url, json=request_body)
|
||||
|
||||
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
||||
|
||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||
_LOGGER.debug("Decoding handshake key...")
|
||||
handshake_key = resp_dict["result"]["key"]
|
||||
|
||||
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
if not self.session_cookie:
|
||||
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
||||
"SESSIONID"
|
||||
)
|
||||
|
||||
self.session_expire_at = time.time() + 86400
|
||||
self.encryption_session = AesEncyptionSession.create_from_keypair(
|
||||
handshake_key, key_pair
|
||||
)
|
||||
|
||||
self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode(
|
||||
"UTF-8"
|
||||
)
|
||||
self.handshake_done = True
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||
|
||||
else:
|
||||
raise AuthenticationException("Could not complete handshake")
|
||||
|
||||
def handshake_session_expired(self):
|
||||
"""Return true if session has expired."""
|
||||
return (
|
||||
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials):
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
un = creds.username or ""
|
||||
pw = creds.password or ""
|
||||
return _md5(_md5(un.encode()) + _md5(pw.encode()))
|
||||
|
||||
@staticmethod
|
||||
def generate_owner_hash(creds: Credentials):
|
||||
"""Return the MD5 hash of the username in this object."""
|
||||
un = creds.username or ""
|
||||
return _md5(un.encode())
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Query the device retrying for retry_count on failure."""
|
||||
async with self.query_lock:
|
||||
return await self._query(request, retry_count)
|
||||
|
||||
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except httpx.CloseError as sdex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except httpx.ConnectError as cex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {cex}"
|
||||
) from cex
|
||||
except TimeoutError as tex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
||||
) from tex
|
||||
except AuthenticationException as auex:
|
||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||
raise auex
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
||||
_LOGGER.debug(
|
||||
"%s >> %s",
|
||||
self.host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(request),
|
||||
)
|
||||
|
||||
if not self.http_client:
|
||||
self.http_client = httpx.AsyncClient()
|
||||
|
||||
if not self.handshake_done or self.handshake_session_expired():
|
||||
try:
|
||||
await self.perform_handshake()
|
||||
await self.perform_login(False)
|
||||
except AuthenticationException:
|
||||
await self.perform_handshake()
|
||||
await self.perform_login(True)
|
||||
|
||||
if isinstance(request, dict):
|
||||
aes_method = next(iter(request))
|
||||
aes_params = request[aes_method]
|
||||
else:
|
||||
aes_method = request
|
||||
aes_params = None
|
||||
|
||||
aes_request = self.get_aes_request(aes_method, aes_params)
|
||||
response_data = await self.send_secure_passthrough(aes_request)
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s << %s",
|
||||
self.host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
client = self.http_client
|
||||
self.http_client = None
|
||||
if client:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
class AesEncyptionSession:
|
||||
"""Class for an AES encryption session."""
|
||||
|
||||
@staticmethod
|
||||
def create_from_keypair(handshake_key: str, keypair):
|
||||
"""Create the encryption session."""
|
||||
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
|
||||
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
|
||||
|
||||
private_key = serialization.load_der_private_key(private_key_data, None, None)
|
||||
key_and_iv = private_key.decrypt(
|
||||
handshake_key_bytes, asymmetric_padding.PKCS1v15()
|
||||
)
|
||||
if key_and_iv is None:
|
||||
raise ValueError("Decryption failed!")
|
||||
|
||||
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
|
||||
|
||||
def __init__(self, key, iv):
|
||||
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
||||
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
|
||||
|
||||
def encrypt(self, data) -> bytes:
|
||||
"""Encrypt the message."""
|
||||
encryptor = self.cipher.encryptor()
|
||||
padder = self.padding_strategy.padder()
|
||||
padded_data = padder.update(data) + padder.finalize()
|
||||
encrypted = encryptor.update(padded_data) + encryptor.finalize()
|
||||
return base64.b64encode(encrypted)
|
||||
|
||||
def decrypt(self, data) -> str:
|
||||
"""Decrypt the message."""
|
||||
decryptor = self.cipher.decryptor()
|
||||
unpadder = self.padding_strategy.unpadder()
|
||||
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
|
||||
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
|
||||
return unpadded_data.decode()
|
||||
|
||||
|
||||
class KeyPair:
|
||||
"""Class for generating key pairs."""
|
||||
|
||||
@staticmethod
|
||||
def create_key_pair(key_size: int = 1024):
|
||||
"""Create a key pair."""
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
||||
public_key = private_key.public_key()
|
||||
|
||||
private_key_bytes = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
public_key_bytes = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
return KeyPair(
|
||||
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
|
||||
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
|
||||
)
|
||||
|
||||
def __init__(self, private_key: str, public_key: str):
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
|
||||
def get_private_key(self) -> str:
|
||||
"""Get the private key."""
|
||||
return self.private_key
|
||||
|
||||
def get_public_key(self) -> str:
|
||||
"""Get the public key."""
|
||||
return self.public_key
|
||||
|
||||
|
||||
class SnowflakeId:
|
||||
"""Class for generating snowflake ids."""
|
||||
|
||||
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
|
||||
WORKER_ID_BITS = 5
|
||||
DATA_CENTER_ID_BITS = 5
|
||||
SEQUENCE_BITS = 12
|
||||
|
||||
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
|
||||
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
|
||||
|
||||
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
|
||||
|
||||
def __init__(self, worker_id, data_center_id):
|
||||
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
|
||||
raise ValueError(
|
||||
"Worker ID can't be greater than "
|
||||
+ str(SnowflakeId.MAX_WORKER_ID)
|
||||
+ " or less than 0"
|
||||
)
|
||||
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
|
||||
raise ValueError(
|
||||
"Data center ID can't be greater than "
|
||||
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
|
||||
+ " or less than 0"
|
||||
)
|
||||
|
||||
self.worker_id = worker_id
|
||||
self.data_center_id = data_center_id
|
||||
self.sequence = 0
|
||||
self.last_timestamp = -1
|
||||
|
||||
def generate_id(self):
|
||||
"""Generate a snowflake id."""
|
||||
timestamp = self._current_millis()
|
||||
|
||||
if timestamp < self.last_timestamp:
|
||||
raise ValueError("Clock moved backwards. Refusing to generate ID.")
|
||||
|
||||
if timestamp == self.last_timestamp:
|
||||
# Within the same millisecond, increment the sequence number
|
||||
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
|
||||
if self.sequence == 0:
|
||||
# Sequence exceeds its bit range, wait until the next millisecond
|
||||
timestamp = self._wait_next_millis(self.last_timestamp)
|
||||
else:
|
||||
# New millisecond, reset the sequence number
|
||||
self.sequence = 0
|
||||
|
||||
# Update the last timestamp
|
||||
self.last_timestamp = timestamp
|
||||
|
||||
# Generate and return the final ID
|
||||
return (
|
||||
(
|
||||
(timestamp - SnowflakeId.EPOCH)
|
||||
<< (
|
||||
SnowflakeId.WORKER_ID_BITS
|
||||
+ SnowflakeId.SEQUENCE_BITS
|
||||
+ SnowflakeId.DATA_CENTER_ID_BITS
|
||||
)
|
||||
)
|
||||
| (
|
||||
self.data_center_id
|
||||
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
|
||||
)
|
||||
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
|
||||
| self.sequence
|
||||
)
|
||||
|
||||
def _current_millis(self):
|
||||
return round(time.time() * 1000)
|
||||
|
||||
def _wait_next_millis(self, last_timestamp):
|
||||
timestamp = self._current_millis()
|
||||
while timestamp <= last_timestamp:
|
||||
timestamp = self._current_millis()
|
||||
return timestamp
|
@ -14,6 +14,7 @@ from .smartdimmer import SmartDimmer
|
||||
from .smartlightstrip import SmartLightStrip
|
||||
from .smartplug import SmartPlug
|
||||
from .smartstrip import SmartStrip
|
||||
from .tapo.tapoplug import TapoPlug
|
||||
|
||||
DEVICE_TYPE_TO_CLASS = {
|
||||
DeviceType.Plug: SmartPlug,
|
||||
@ -21,6 +22,7 @@ DEVICE_TYPE_TO_CLASS = {
|
||||
DeviceType.Strip: SmartStrip,
|
||||
DeviceType.Dimmer: SmartDimmer,
|
||||
DeviceType.LightStrip: SmartLightStrip,
|
||||
DeviceType.TapoPlug: TapoPlug,
|
||||
}
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -14,6 +14,7 @@ class DeviceType(Enum):
|
||||
StripSocket = "stripsocket"
|
||||
Dimmer = "dimmer"
|
||||
LightStrip = "lightstrip"
|
||||
TapoPlug = "tapoplug"
|
||||
Unknown = "unknown"
|
||||
|
||||
@staticmethod
|
||||
|
@ -15,14 +15,16 @@ try:
|
||||
except ImportError:
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from kasa.aesprotocol import TPLinkAes
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.exceptions import UnsupportedDeviceException
|
||||
from kasa.json import dumps as json_dumps
|
||||
from kasa.json import loads as json_loads
|
||||
from kasa.klapprotocol import TPLinkKlap
|
||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
from kasa.smartdevice import SmartDevice, SmartDeviceException
|
||||
from kasa.smartplug import SmartPlug
|
||||
from kasa.tapo.tapoplug import TapoPlug
|
||||
|
||||
from .device_factory import get_device_class_from_info
|
||||
|
||||
@ -378,27 +380,38 @@ class Discover:
|
||||
f"Unable to read response from device: {ip}: {ex}"
|
||||
) from ex
|
||||
|
||||
if (
|
||||
discovery_result.mgt_encrypt_schm.encrypt_type == "KLAP"
|
||||
and discovery_result.mgt_encrypt_schm.lv is None
|
||||
):
|
||||
type_ = discovery_result.device_type
|
||||
device_class = None
|
||||
if type_.upper() == "IOT.SMARTPLUGSWITCH":
|
||||
device_class = SmartPlug
|
||||
type_ = discovery_result.device_type
|
||||
encrypt_type_ = (
|
||||
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
|
||||
)
|
||||
device_class = None
|
||||
|
||||
if device_class:
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
device = device_class(ip, port=port, credentials=credentials)
|
||||
device.update_from_discover_info(discovery_result.get_dict())
|
||||
device.protocol = TPLinkKlap(ip, credentials=credentials)
|
||||
return device
|
||||
else:
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device {ip} of type {type_}: {info}"
|
||||
)
|
||||
else:
|
||||
raise UnsupportedDeviceException(f"Unsupported device {ip}: {info}")
|
||||
supported_device_types: dict[str, Type[SmartDevice]] = {
|
||||
"SMART.TAPOPLUG": TapoPlug,
|
||||
"SMART.KASAPLUG": TapoPlug,
|
||||
"IOT.SMARTPLUGSWITCH": SmartPlug,
|
||||
}
|
||||
supported_device_protocols: dict[str, Type[TPLinkProtocol]] = {
|
||||
"IOT.KLAP": TPLinkKlap,
|
||||
"SMART.AES": TPLinkAes,
|
||||
}
|
||||
|
||||
if (device_class := supported_device_types.get(type_)) is None:
|
||||
_LOGGER.warning("Got unsupported device type: %s", type_)
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device {ip} of type {type_}: {info}"
|
||||
)
|
||||
if (protocol_class := supported_device_protocols.get(encrypt_type_)) is None:
|
||||
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}"
|
||||
)
|
||||
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
device = device_class(ip, port=port, credentials=credentials)
|
||||
device.protocol = protocol_class(ip, credentials=credentials)
|
||||
device.update_from_discover_info(discovery_result.get_dict())
|
||||
return device
|
||||
|
||||
|
||||
class DiscoveryResult(BaseModel):
|
||||
@ -415,7 +428,7 @@ class DiscoveryResult(BaseModel):
|
||||
is_support_https: Optional[bool] = None
|
||||
encrypt_type: Optional[str] = None
|
||||
http_port: Optional[int] = None
|
||||
lv: Optional[int] = None
|
||||
lv: Optional[int] = 1
|
||||
|
||||
device_type: str = Field(alias="device_type_text")
|
||||
device_model: str = Field(alias="model")
|
||||
|
164
kasa/tapo/tapodevice.py
Normal file
164
kasa/tapo/tapodevice.py
Normal file
@ -0,0 +1,164 @@
|
||||
"""Module for a TAPO device."""
|
||||
import base64
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Set, cast
|
||||
|
||||
from ..aesprotocol import TPLinkAes
|
||||
from ..credentials import Credentials
|
||||
from ..exceptions import AuthenticationException
|
||||
from ..smartdevice import SmartDevice
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TapoDevice(SmartDevice):
|
||||
"""Base class to represent a TAPO device."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||
self._state_information: Dict[str, Any] = {}
|
||||
self._discovery_info: Optional[Dict[str, Any]] = None
|
||||
self.protocol = TPLinkAes(host, credentials=credentials, timeout=timeout)
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Update the device."""
|
||||
if self.credentials is None or self.credentials.username is None:
|
||||
raise AuthenticationException("Tapo plug requires authentication.")
|
||||
|
||||
self._info = await self.protocol.query("get_device_info")
|
||||
self._usage = await self.protocol.query("get_device_usage")
|
||||
self._time = await self.protocol.query("get_device_time")
|
||||
|
||||
self._last_update = self._data = {
|
||||
"info": self._info,
|
||||
"usage": self._usage,
|
||||
"time": self._time,
|
||||
}
|
||||
|
||||
_LOGGER.debug("Got an update: %s", self._data)
|
||||
|
||||
@property
|
||||
def sys_info(self) -> Dict[str, Any]:
|
||||
"""Returns the device info."""
|
||||
return self._info
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
"""Returns the device model."""
|
||||
return str(self._info.get("model"))
|
||||
|
||||
@property
|
||||
def alias(self) -> str:
|
||||
"""Returns the device alias or nickname."""
|
||||
return base64.b64decode(str(self._info.get("nickname"))).decode()
|
||||
|
||||
@property
|
||||
def time(self) -> datetime:
|
||||
"""Return the time."""
|
||||
td = timedelta(minutes=cast(float, self._time.get("time_diff")))
|
||||
if self._time.get("region"):
|
||||
tz = timezone(td, str(self._time.get("region")))
|
||||
else:
|
||||
# in case the device returns a blank region this will result in the
|
||||
# tzname being a UTC offset
|
||||
tz = timezone(td)
|
||||
return datetime.fromtimestamp(
|
||||
cast(float, self._time.get("timestamp")),
|
||||
tz=tz,
|
||||
)
|
||||
|
||||
@property
|
||||
def timezone(self) -> Dict:
|
||||
"""Return the timezone and time_difference."""
|
||||
ti = self.time
|
||||
return {"timezone": ti.tzname()}
|
||||
|
||||
@property
|
||||
def hw_info(self) -> Dict:
|
||||
"""Return hardware info for the device."""
|
||||
return {
|
||||
"sw_ver": self._info.get("fw_ver"),
|
||||
"hw_ver": self._info.get("hw_ver"),
|
||||
"mac": self._info.get("mac"),
|
||||
"type": self._info.get("type"),
|
||||
"hwId": self._info.get("device_id"),
|
||||
"dev_name": self.alias,
|
||||
"oemId": self._info.get("oem_id"),
|
||||
}
|
||||
|
||||
@property
|
||||
def location(self) -> Dict:
|
||||
"""Return the device location."""
|
||||
loc = {
|
||||
"latitude": cast(float, self._info.get("latitude")) / 10_000,
|
||||
"longitude": cast(float, self._info.get("longitude")) / 10_000,
|
||||
}
|
||||
return loc
|
||||
|
||||
@property
|
||||
def rssi(self) -> Optional[int]:
|
||||
"""Return the rssi."""
|
||||
rssi = self._info.get("rssi")
|
||||
return int(rssi) if rssi else None
|
||||
|
||||
@property
|
||||
def mac(self) -> str:
|
||||
"""Return the mac formatted with colons."""
|
||||
return str(self._info.get("mac")).replace("-", ":")
|
||||
|
||||
@property
|
||||
def device_id(self) -> str:
|
||||
"""Return the device id."""
|
||||
return str(self._info.get("device_id"))
|
||||
|
||||
@property
|
||||
def internal_state(self) -> Any:
|
||||
"""Return all the internal state data."""
|
||||
return self._data
|
||||
|
||||
async def _query_helper(
|
||||
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
|
||||
) -> Any:
|
||||
res = await self.protocol.query({cmd: arg})
|
||||
|
||||
return res
|
||||
|
||||
@property
|
||||
def state_information(self) -> Dict[str, Any]:
|
||||
"""Return the key state information."""
|
||||
return {
|
||||
"overheated": self._info.get("overheated"),
|
||||
"signal_level": self._info.get("signal_level"),
|
||||
"SSID": base64.b64decode(str(self._info.get("ssid"))).decode(),
|
||||
}
|
||||
|
||||
@property
|
||||
def features(self) -> Set[str]:
|
||||
"""Return the list of supported features."""
|
||||
# TODO:
|
||||
return set()
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool:
|
||||
"""Return true if the device is on."""
|
||||
return bool(self._info.get("device_on"))
|
||||
|
||||
async def turn_on(self, **kwargs):
|
||||
"""Turn on the device."""
|
||||
await self.protocol.query({"set_device_info": {"device_on": True}})
|
||||
|
||||
async def turn_off(self, **kwargs):
|
||||
"""Turn off the device."""
|
||||
await self.protocol.query({"set_device_info": {"device_on": False}})
|
||||
|
||||
def update_from_discover_info(self, info):
|
||||
"""Update state from info from the discover call."""
|
||||
self._discovery_info = info
|
73
kasa/tapo/tapoplug.py
Normal file
73
kasa/tapo/tapoplug.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""Module for a TAPO Plug."""
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
from ..credentials import Credentials
|
||||
from ..emeterstatus import EmeterStatus
|
||||
from ..smartdevice import DeviceType
|
||||
from .tapodevice import TapoDevice
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TapoPlug(TapoDevice):
|
||||
"""Class to represent a TAPO Plug."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||
self._device_type = DeviceType.Plug
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Call the device endpoint and update the device data."""
|
||||
await super().update(update_children)
|
||||
|
||||
self._energy = await self.protocol.query("get_energy_usage")
|
||||
self._emeter = await self.protocol.query("get_current_power")
|
||||
|
||||
self._data["energy"] = self._energy
|
||||
self._data["emeter"] = self._emeter
|
||||
|
||||
_LOGGER.debug("Got an update: %s %s", self._energy, self._emeter)
|
||||
|
||||
@property
|
||||
def state_information(self) -> Dict[str, Any]:
|
||||
"""Return the key state information."""
|
||||
return {
|
||||
**super().state_information,
|
||||
**{
|
||||
"On since": self.on_since,
|
||||
"auto_off_status": self._info.get("auto_off_status"),
|
||||
"auto_off_remain_time": self._info.get("auto_off_remain_time"),
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def emeter_realtime(self) -> EmeterStatus:
|
||||
"""Get the emeter status."""
|
||||
return EmeterStatus({"power_mw": self._energy.get("current_power")})
|
||||
|
||||
@property
|
||||
def emeter_today(self) -> Optional[float]:
|
||||
"""Get the emeter value for today."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def emeter_this_month(self) -> Optional[float]:
|
||||
"""Get the emeter value for this month."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def on_since(self) -> Optional[datetime]:
|
||||
"""Return the time that the device was turned on or None if turned off."""
|
||||
if not self._info.get("device_on"):
|
||||
return None
|
||||
on_time = cast(float, self._info.get("on_time"))
|
||||
return datetime.now().replace(microsecond=0) - timedelta(seconds=on_time)
|
@ -114,7 +114,7 @@ UNSUPPORTED = {
|
||||
"result": {
|
||||
"device_id": "xx",
|
||||
"owner": "xx",
|
||||
"device_type": "SMART.TAPOPLUG",
|
||||
"device_type": "SMART.TAPOXMASTREE",
|
||||
"device_model": "P110(EU)",
|
||||
"ip": "127.0.0.1",
|
||||
"mac": "48-22xxx",
|
||||
@ -150,7 +150,7 @@ async def test_discover_single_unsupported(mocker):
|
||||
discovery_data = UNSUPPORTED
|
||||
with pytest.raises(
|
||||
UnsupportedDeviceException,
|
||||
match=f"Unsupported device {host}: {re.escape(str(UNSUPPORTED))}",
|
||||
match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}",
|
||||
):
|
||||
await Discover.discover_single(host)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user