mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-08 22:07:06 +00:00
Add klap support for TAPO protocol by splitting out Transports and Protocols (#557)
* Add support for TAPO/SMART KLAP and seperate transports from protocols * Add tests and some review changes * Update following review * Updates following review
This commit is contained in:
parent
347cbfe3bd
commit
4a00199506
@ -21,13 +21,14 @@ from kasa.exceptions import (
|
|||||||
SmartDeviceException,
|
SmartDeviceException,
|
||||||
UnsupportedDeviceException,
|
UnsupportedDeviceException,
|
||||||
)
|
)
|
||||||
from kasa.klapprotocol import TPLinkKlap
|
from kasa.iotprotocol import IotProtocol
|
||||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||||
from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors
|
from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors
|
||||||
from kasa.smartdevice import DeviceType, SmartDevice
|
from kasa.smartdevice import DeviceType, SmartDevice
|
||||||
from kasa.smartdimmer import SmartDimmer
|
from kasa.smartdimmer import SmartDimmer
|
||||||
from kasa.smartlightstrip import SmartLightStrip
|
from kasa.smartlightstrip import SmartLightStrip
|
||||||
from kasa.smartplug import SmartPlug
|
from kasa.smartplug import SmartPlug
|
||||||
|
from kasa.smartprotocol import SmartProtocol
|
||||||
from kasa.smartstrip import SmartStrip
|
from kasa.smartstrip import SmartStrip
|
||||||
|
|
||||||
__version__ = version("python-kasa")
|
__version__ = version("python-kasa")
|
||||||
@ -37,7 +38,8 @@ __all__ = [
|
|||||||
"Discover",
|
"Discover",
|
||||||
"TPLinkSmartHomeProtocol",
|
"TPLinkSmartHomeProtocol",
|
||||||
"TPLinkProtocol",
|
"TPLinkProtocol",
|
||||||
"TPLinkKlap",
|
"IotProtocol",
|
||||||
|
"SmartProtocol",
|
||||||
"SmartBulb",
|
"SmartBulb",
|
||||||
"SmartBulbPreset",
|
"SmartBulbPreset",
|
||||||
"TurnOnBehaviors",
|
"TurnOnBehaviors",
|
||||||
|
@ -1,498 +0,0 @@
|
|||||||
"""Implementation of the TP-Link AES Protocol.
|
|
||||||
|
|
||||||
Based on the work of https://github.com/petretiandrea/plugp100
|
|
||||||
under compatible GNU GPL3 license.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from pprint import pformat as pf
|
|
||||||
from typing import Dict, Optional, Union
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from cryptography.hazmat.primitives import hashes, padding, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
||||||
|
|
||||||
from .credentials import Credentials
|
|
||||||
from .exceptions import AuthenticationException, SmartDeviceException
|
|
||||||
from .json import dumps as json_dumps
|
|
||||||
from .json import loads as json_loads
|
|
||||||
from .protocol import TPLinkProtocol
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
logging.getLogger("httpx").propagate = False
|
|
||||||
|
|
||||||
|
|
||||||
def _md5(payload: bytes) -> bytes:
|
|
||||||
digest = hashes.Hash(hashes.MD5()) # noqa: S303
|
|
||||||
digest.update(payload)
|
|
||||||
hash = digest.finalize()
|
|
||||||
return hash
|
|
||||||
|
|
||||||
|
|
||||||
def _sha1(payload: bytes) -> str:
|
|
||||||
sha1_algo = hashlib.sha1() # noqa: S324
|
|
||||||
sha1_algo.update(payload)
|
|
||||||
return sha1_algo.hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
class TPLinkAes(TPLinkProtocol):
|
|
||||||
"""Implementation of the AES encryption protocol.
|
|
||||||
|
|
||||||
AES is the name used in device discovery for TP-Link's TAPO encryption
|
|
||||||
protocol, sometimes used by newer firmware versions on kasa devices.
|
|
||||||
"""
|
|
||||||
|
|
||||||
DEFAULT_PORT = 80
|
|
||||||
DEFAULT_TIMEOUT = 5
|
|
||||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
|
||||||
COMMON_HEADERS = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"requestByApp": "true",
|
|
||||||
"Accept": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
host: str,
|
|
||||||
*,
|
|
||||||
credentials: Optional[Credentials] = None,
|
|
||||||
timeout: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
|
||||||
|
|
||||||
self.credentials = (
|
|
||||||
credentials
|
|
||||||
if credentials and credentials.username and credentials.password
|
|
||||||
else Credentials(username="", password="")
|
|
||||||
)
|
|
||||||
|
|
||||||
self._local_seed: Optional[bytes] = None
|
|
||||||
self.local_auth_hash = self.generate_auth_hash(self.credentials)
|
|
||||||
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
|
|
||||||
self.kasa_setup_auth_hash = None
|
|
||||||
self.blank_auth_hash = None
|
|
||||||
self.handshake_lock = asyncio.Lock()
|
|
||||||
self.query_lock = asyncio.Lock()
|
|
||||||
self.handshake_done = False
|
|
||||||
|
|
||||||
self.encryption_session: Optional[AesEncyptionSession] = None
|
|
||||||
self.session_expire_at: Optional[float] = None
|
|
||||||
|
|
||||||
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
|
||||||
self.session_cookie = None
|
|
||||||
self.terminal_uuid = None
|
|
||||||
self.http_client: Optional[httpx.AsyncClient] = None
|
|
||||||
self.request_id_generator = SnowflakeId(1, 1)
|
|
||||||
self.login_token = None
|
|
||||||
|
|
||||||
_LOGGER.debug("Created AES object for %s", self.host)
|
|
||||||
|
|
||||||
def hash_credentials(self, credentials, try_login_version2):
|
|
||||||
"""Hash the credentials."""
|
|
||||||
if try_login_version2:
|
|
||||||
un = base64.b64encode(
|
|
||||||
_sha1(credentials.username.encode()).encode()
|
|
||||||
).decode()
|
|
||||||
pw = base64.b64encode(
|
|
||||||
_sha1(credentials.password.encode()).encode()
|
|
||||||
).decode()
|
|
||||||
else:
|
|
||||||
un = base64.b64encode(
|
|
||||||
_sha1(credentials.username.encode()).encode()
|
|
||||||
).decode()
|
|
||||||
pw = base64.b64encode(credentials.password.encode()).decode()
|
|
||||||
return un, pw
|
|
||||||
|
|
||||||
async def client_post(self, url, params=None, data=None, json=None, headers=None):
|
|
||||||
"""Send an http post request to the device."""
|
|
||||||
response_data = None
|
|
||||||
cookies = None
|
|
||||||
if self.session_cookie:
|
|
||||||
cookies = httpx.Cookies()
|
|
||||||
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
|
|
||||||
self.http_client.cookies.clear()
|
|
||||||
resp = await self.http_client.post(
|
|
||||||
url,
|
|
||||||
params=params,
|
|
||||||
data=data,
|
|
||||||
json=json,
|
|
||||||
timeout=self.timeout,
|
|
||||||
cookies=cookies,
|
|
||||||
headers=self.COMMON_HEADERS,
|
|
||||||
)
|
|
||||||
if resp.status_code == 200:
|
|
||||||
response_data = resp.json()
|
|
||||||
|
|
||||||
return resp.status_code, response_data
|
|
||||||
|
|
||||||
async def send_secure_passthrough(self, request):
|
|
||||||
"""Send encrypted message as passthrough."""
|
|
||||||
url = f"http://{self.host}/app"
|
|
||||||
if self.login_token:
|
|
||||||
url += f"?token={self.login_token}"
|
|
||||||
raw_request = json_dumps(request)
|
|
||||||
encrypted_payload = self.encryption_session.encrypt(raw_request.encode())
|
|
||||||
passthrough_request = {
|
|
||||||
"method": "securePassthrough",
|
|
||||||
"params": {"request": encrypted_payload.decode()},
|
|
||||||
}
|
|
||||||
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
|
||||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
|
||||||
response = self.encryption_session.decrypt(
|
|
||||||
resp_dict["result"]["response"].encode()
|
|
||||||
)
|
|
||||||
resp_dict = json_loads(response)
|
|
||||||
if resp_dict["error_code"] != 0:
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Could not complete send, response was {resp_dict}",
|
|
||||||
)
|
|
||||||
if "result" in resp_dict:
|
|
||||||
return resp_dict["result"]
|
|
||||||
else:
|
|
||||||
raise AuthenticationException("Could not complete send")
|
|
||||||
|
|
||||||
def get_aes_request(self, method, params=None):
|
|
||||||
"""Get a request message."""
|
|
||||||
request = {
|
|
||||||
"method": method,
|
|
||||||
"params": params,
|
|
||||||
"requestID": self.request_id_generator.generate_id(),
|
|
||||||
"request_time_milis": round(time.time() * 1000),
|
|
||||||
"terminal_uuid": self.terminal_uuid,
|
|
||||||
}
|
|
||||||
return request
|
|
||||||
|
|
||||||
async def perform_login(self, login_v2):
|
|
||||||
"""Login to the device."""
|
|
||||||
self.login_token = None
|
|
||||||
|
|
||||||
un, pw = self.hash_credentials(self.credentials, login_v2)
|
|
||||||
params = {"password": pw, "username": un}
|
|
||||||
request = self.get_aes_request("login_device", params)
|
|
||||||
try:
|
|
||||||
result = await self.send_secure_passthrough(request)
|
|
||||||
except SmartDeviceException as ex:
|
|
||||||
raise AuthenticationException(ex) from ex
|
|
||||||
self.login_token = result["token"]
|
|
||||||
|
|
||||||
async def perform_handshake(self):
|
|
||||||
"""Perform the handshake."""
|
|
||||||
_LOGGER.debug("Will perform handshaking...")
|
|
||||||
_LOGGER.debug("Generating keypair")
|
|
||||||
|
|
||||||
self.handshake_done = False
|
|
||||||
self.session_expire_at = None
|
|
||||||
self.session_cookie = None
|
|
||||||
|
|
||||||
url = f"http://{self.host}/app"
|
|
||||||
key_pair = KeyPair.create_key_pair()
|
|
||||||
|
|
||||||
pub_key = (
|
|
||||||
"-----BEGIN PUBLIC KEY-----\n"
|
|
||||||
+ key_pair.get_public_key()
|
|
||||||
+ "\n-----END PUBLIC KEY-----\n"
|
|
||||||
)
|
|
||||||
handshake_params = {"key": pub_key}
|
|
||||||
_LOGGER.debug(f"Handshake params: {handshake_params}")
|
|
||||||
|
|
||||||
request_body = {"method": "handshake", "params": handshake_params}
|
|
||||||
|
|
||||||
_LOGGER.debug(f"Request {request_body}")
|
|
||||||
|
|
||||||
status_code, resp_dict = await self.client_post(url, json=request_body)
|
|
||||||
|
|
||||||
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
|
||||||
|
|
||||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
|
||||||
_LOGGER.debug("Decoding handshake key...")
|
|
||||||
handshake_key = resp_dict["result"]["key"]
|
|
||||||
|
|
||||||
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
|
||||||
self.SESSION_COOKIE_NAME
|
|
||||||
)
|
|
||||||
if not self.session_cookie:
|
|
||||||
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
|
||||||
"SESSIONID"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.session_expire_at = time.time() + 86400
|
|
||||||
self.encryption_session = AesEncyptionSession.create_from_keypair(
|
|
||||||
handshake_key, key_pair
|
|
||||||
)
|
|
||||||
|
|
||||||
self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode(
|
|
||||||
"UTF-8"
|
|
||||||
)
|
|
||||||
self.handshake_done = True
|
|
||||||
|
|
||||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise AuthenticationException("Could not complete handshake")
|
|
||||||
|
|
||||||
def handshake_session_expired(self):
|
|
||||||
"""Return true if session has expired."""
|
|
||||||
return (
|
|
||||||
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_auth_hash(creds: Credentials):
|
|
||||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
|
||||||
un = creds.username or ""
|
|
||||||
pw = creds.password or ""
|
|
||||||
return _md5(_md5(un.encode()) + _md5(pw.encode()))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_owner_hash(creds: Credentials):
|
|
||||||
"""Return the MD5 hash of the username in this object."""
|
|
||||||
un = creds.username or ""
|
|
||||||
return _md5(un.encode())
|
|
||||||
|
|
||||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
|
||||||
"""Query the device retrying for retry_count on failure."""
|
|
||||||
async with self.query_lock:
|
|
||||||
return await self._query(request, retry_count)
|
|
||||||
|
|
||||||
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
|
||||||
for retry in range(retry_count + 1):
|
|
||||||
try:
|
|
||||||
return await self._execute_query(request, retry)
|
|
||||||
except httpx.CloseError as sdex:
|
|
||||||
await self.close()
|
|
||||||
if retry >= retry_count:
|
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
|
||||||
) from sdex
|
|
||||||
continue
|
|
||||||
except httpx.ConnectError as cex:
|
|
||||||
await self.close()
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Unable to connect to the device: {self.host}: {cex}"
|
|
||||||
) from cex
|
|
||||||
except TimeoutError as tex:
|
|
||||||
await self.close()
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
|
||||||
) from tex
|
|
||||||
except AuthenticationException as auex:
|
|
||||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
|
||||||
raise auex
|
|
||||||
except Exception as ex:
|
|
||||||
await self.close()
|
|
||||||
if retry >= retry_count:
|
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Unable to connect to the device: {self.host}: {ex}"
|
|
||||||
) from ex
|
|
||||||
continue
|
|
||||||
|
|
||||||
# make mypy happy, this should never be reached..
|
|
||||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
|
||||||
|
|
||||||
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
|
||||||
_LOGGER.debug(
|
|
||||||
"%s >> %s",
|
|
||||||
self.host,
|
|
||||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(request),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.http_client:
|
|
||||||
self.http_client = httpx.AsyncClient()
|
|
||||||
|
|
||||||
if not self.handshake_done or self.handshake_session_expired():
|
|
||||||
try:
|
|
||||||
await self.perform_handshake()
|
|
||||||
await self.perform_login(False)
|
|
||||||
except AuthenticationException:
|
|
||||||
await self.perform_handshake()
|
|
||||||
await self.perform_login(True)
|
|
||||||
|
|
||||||
if isinstance(request, dict):
|
|
||||||
aes_method = next(iter(request))
|
|
||||||
aes_params = request[aes_method]
|
|
||||||
else:
|
|
||||||
aes_method = request
|
|
||||||
aes_params = None
|
|
||||||
|
|
||||||
aes_request = self.get_aes_request(aes_method, aes_params)
|
|
||||||
response_data = await self.send_secure_passthrough(aes_request)
|
|
||||||
|
|
||||||
_LOGGER.debug(
|
|
||||||
"%s << %s",
|
|
||||||
self.host,
|
|
||||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
|
||||||
)
|
|
||||||
|
|
||||||
return response_data
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close the protocol."""
|
|
||||||
client = self.http_client
|
|
||||||
self.http_client = None
|
|
||||||
if client:
|
|
||||||
await client.aclose()
|
|
||||||
|
|
||||||
|
|
||||||
class AesEncyptionSession:
|
|
||||||
"""Class for an AES encryption session."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_from_keypair(handshake_key: str, keypair):
|
|
||||||
"""Create the encryption session."""
|
|
||||||
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
|
|
||||||
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
|
|
||||||
|
|
||||||
private_key = serialization.load_der_private_key(private_key_data, None, None)
|
|
||||||
key_and_iv = private_key.decrypt(
|
|
||||||
handshake_key_bytes, asymmetric_padding.PKCS1v15()
|
|
||||||
)
|
|
||||||
if key_and_iv is None:
|
|
||||||
raise ValueError("Decryption failed!")
|
|
||||||
|
|
||||||
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
|
|
||||||
|
|
||||||
def __init__(self, key, iv):
|
|
||||||
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
|
||||||
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
|
|
||||||
|
|
||||||
def encrypt(self, data) -> bytes:
|
|
||||||
"""Encrypt the message."""
|
|
||||||
encryptor = self.cipher.encryptor()
|
|
||||||
padder = self.padding_strategy.padder()
|
|
||||||
padded_data = padder.update(data) + padder.finalize()
|
|
||||||
encrypted = encryptor.update(padded_data) + encryptor.finalize()
|
|
||||||
return base64.b64encode(encrypted)
|
|
||||||
|
|
||||||
def decrypt(self, data) -> str:
|
|
||||||
"""Decrypt the message."""
|
|
||||||
decryptor = self.cipher.decryptor()
|
|
||||||
unpadder = self.padding_strategy.unpadder()
|
|
||||||
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
|
|
||||||
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
|
|
||||||
return unpadded_data.decode()
|
|
||||||
|
|
||||||
|
|
||||||
class KeyPair:
|
|
||||||
"""Class for generating key pairs."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_key_pair(key_size: int = 1024):
|
|
||||||
"""Create a key pair."""
|
|
||||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
|
||||||
public_key = private_key.public_key()
|
|
||||||
|
|
||||||
private_key_bytes = private_key.private_bytes(
|
|
||||||
encoding=serialization.Encoding.DER,
|
|
||||||
format=serialization.PrivateFormat.PKCS8,
|
|
||||||
encryption_algorithm=serialization.NoEncryption(),
|
|
||||||
)
|
|
||||||
public_key_bytes = public_key.public_bytes(
|
|
||||||
encoding=serialization.Encoding.DER,
|
|
||||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
|
||||||
)
|
|
||||||
|
|
||||||
return KeyPair(
|
|
||||||
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
|
|
||||||
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, private_key: str, public_key: str):
|
|
||||||
self.private_key = private_key
|
|
||||||
self.public_key = public_key
|
|
||||||
|
|
||||||
def get_private_key(self) -> str:
|
|
||||||
"""Get the private key."""
|
|
||||||
return self.private_key
|
|
||||||
|
|
||||||
def get_public_key(self) -> str:
|
|
||||||
"""Get the public key."""
|
|
||||||
return self.public_key
|
|
||||||
|
|
||||||
|
|
||||||
class SnowflakeId:
|
|
||||||
"""Class for generating snowflake ids."""
|
|
||||||
|
|
||||||
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
|
|
||||||
WORKER_ID_BITS = 5
|
|
||||||
DATA_CENTER_ID_BITS = 5
|
|
||||||
SEQUENCE_BITS = 12
|
|
||||||
|
|
||||||
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
|
|
||||||
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
|
|
||||||
|
|
||||||
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
|
|
||||||
|
|
||||||
def __init__(self, worker_id, data_center_id):
|
|
||||||
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Worker ID can't be greater than "
|
|
||||||
+ str(SnowflakeId.MAX_WORKER_ID)
|
|
||||||
+ " or less than 0"
|
|
||||||
)
|
|
||||||
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Data center ID can't be greater than "
|
|
||||||
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
|
|
||||||
+ " or less than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.worker_id = worker_id
|
|
||||||
self.data_center_id = data_center_id
|
|
||||||
self.sequence = 0
|
|
||||||
self.last_timestamp = -1
|
|
||||||
|
|
||||||
def generate_id(self):
|
|
||||||
"""Generate a snowflake id."""
|
|
||||||
timestamp = self._current_millis()
|
|
||||||
|
|
||||||
if timestamp < self.last_timestamp:
|
|
||||||
raise ValueError("Clock moved backwards. Refusing to generate ID.")
|
|
||||||
|
|
||||||
if timestamp == self.last_timestamp:
|
|
||||||
# Within the same millisecond, increment the sequence number
|
|
||||||
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
|
|
||||||
if self.sequence == 0:
|
|
||||||
# Sequence exceeds its bit range, wait until the next millisecond
|
|
||||||
timestamp = self._wait_next_millis(self.last_timestamp)
|
|
||||||
else:
|
|
||||||
# New millisecond, reset the sequence number
|
|
||||||
self.sequence = 0
|
|
||||||
|
|
||||||
# Update the last timestamp
|
|
||||||
self.last_timestamp = timestamp
|
|
||||||
|
|
||||||
# Generate and return the final ID
|
|
||||||
return (
|
|
||||||
(
|
|
||||||
(timestamp - SnowflakeId.EPOCH)
|
|
||||||
<< (
|
|
||||||
SnowflakeId.WORKER_ID_BITS
|
|
||||||
+ SnowflakeId.SEQUENCE_BITS
|
|
||||||
+ SnowflakeId.DATA_CENTER_ID_BITS
|
|
||||||
)
|
|
||||||
)
|
|
||||||
| (
|
|
||||||
self.data_center_id
|
|
||||||
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
|
|
||||||
)
|
|
||||||
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
|
|
||||||
| self.sequence
|
|
||||||
)
|
|
||||||
|
|
||||||
def _current_millis(self):
|
|
||||||
return round(time.time() * 1000)
|
|
||||||
|
|
||||||
def _wait_next_millis(self, last_timestamp):
|
|
||||||
timestamp = self._current_millis()
|
|
||||||
while timestamp <= last_timestamp:
|
|
||||||
timestamp = self._current_millis()
|
|
||||||
return timestamp
|
|
338
kasa/aestransport.py
Normal file
338
kasa/aestransport.py
Normal file
@ -0,0 +1,338 @@
|
|||||||
|
"""Implementation of the TP-Link AES transport.
|
||||||
|
|
||||||
|
Based on the work of https://github.com/petretiandrea/plugp100
|
||||||
|
under compatible GNU GPL3 license.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from cryptography.hazmat.primitives import padding, serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
|
||||||
|
from .credentials import Credentials
|
||||||
|
from .exceptions import AuthenticationException, SmartDeviceException
|
||||||
|
from .json import dumps as json_dumps
|
||||||
|
from .json import loads as json_loads
|
||||||
|
from .protocol import BaseTransport
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _sha1(payload: bytes) -> str:
|
||||||
|
sha1_algo = hashlib.sha1() # noqa: S324
|
||||||
|
sha1_algo.update(payload)
|
||||||
|
return sha1_algo.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class AesTransport(BaseTransport):
|
||||||
|
"""Implementation of the AES encryption protocol.
|
||||||
|
|
||||||
|
AES is the name used in device discovery for TP-Link's TAPO encryption
|
||||||
|
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_TIMEOUT = 5
|
||||||
|
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||||
|
COMMON_HEADERS = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"requestByApp": "true",
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
*,
|
||||||
|
credentials: Optional[Credentials] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(host=host)
|
||||||
|
|
||||||
|
self._credentials = credentials or Credentials(username="", password="")
|
||||||
|
|
||||||
|
self._handshake_done = False
|
||||||
|
|
||||||
|
self._encryption_session: Optional[AesEncyptionSession] = None
|
||||||
|
self._session_expire_at: Optional[float] = None
|
||||||
|
|
||||||
|
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||||
|
self._session_cookie = None
|
||||||
|
|
||||||
|
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||||
|
self._login_token = None
|
||||||
|
|
||||||
|
_LOGGER.debug("Created AES object for %s", self.host)
|
||||||
|
|
||||||
|
def hash_credentials(self, login_v2):
|
||||||
|
"""Hash the credentials."""
|
||||||
|
if login_v2:
|
||||||
|
un = base64.b64encode(
|
||||||
|
_sha1(self._credentials.username.encode()).encode()
|
||||||
|
).decode()
|
||||||
|
pw = base64.b64encode(
|
||||||
|
_sha1(self._credentials.password.encode()).encode()
|
||||||
|
).decode()
|
||||||
|
else:
|
||||||
|
un = base64.b64encode(
|
||||||
|
_sha1(self._credentials.username.encode()).encode()
|
||||||
|
).decode()
|
||||||
|
pw = base64.b64encode(self._credentials.password.encode()).decode()
|
||||||
|
return un, pw
|
||||||
|
|
||||||
|
async def client_post(self, url, params=None, data=None, json=None, headers=None):
|
||||||
|
"""Send an http post request to the device."""
|
||||||
|
response_data = None
|
||||||
|
cookies = None
|
||||||
|
if self._session_cookie:
|
||||||
|
cookies = httpx.Cookies()
|
||||||
|
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
|
||||||
|
self._http_client.cookies.clear()
|
||||||
|
resp = await self._http_client.post(
|
||||||
|
url,
|
||||||
|
params=params,
|
||||||
|
data=data,
|
||||||
|
json=json,
|
||||||
|
timeout=self._timeout,
|
||||||
|
cookies=cookies,
|
||||||
|
headers=self.COMMON_HEADERS,
|
||||||
|
)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
response_data = resp.json()
|
||||||
|
|
||||||
|
return resp.status_code, response_data
|
||||||
|
|
||||||
|
async def send_secure_passthrough(self, request: str):
|
||||||
|
"""Send encrypted message as passthrough."""
|
||||||
|
url = f"http://{self.host}/app"
|
||||||
|
if self._login_token:
|
||||||
|
url += f"?token={self._login_token}"
|
||||||
|
|
||||||
|
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
||||||
|
passthrough_request = {
|
||||||
|
"method": "securePassthrough",
|
||||||
|
"params": {"request": encrypted_payload.decode()},
|
||||||
|
}
|
||||||
|
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
||||||
|
_LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
|
||||||
|
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||||
|
response = self._encryption_session.decrypt( # type: ignore
|
||||||
|
resp_dict["result"]["response"].encode()
|
||||||
|
)
|
||||||
|
_LOGGER.debug(f"decrypted secure_passthrough response is {response}")
|
||||||
|
resp_dict = json_loads(response)
|
||||||
|
return resp_dict
|
||||||
|
else:
|
||||||
|
self._handshake_done = False
|
||||||
|
self._login_token = None
|
||||||
|
raise AuthenticationException("Could not complete send")
|
||||||
|
|
||||||
|
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
|
||||||
|
"""Login to the device."""
|
||||||
|
self._login_token = None
|
||||||
|
|
||||||
|
if isinstance(login_request, str):
|
||||||
|
login_request_dict: dict = json_loads(login_request)
|
||||||
|
else:
|
||||||
|
login_request_dict = login_request
|
||||||
|
|
||||||
|
un, pw = self.hash_credentials(login_v2)
|
||||||
|
login_request_dict["params"] = {"password": pw, "username": un}
|
||||||
|
request = json_dumps(login_request_dict)
|
||||||
|
try:
|
||||||
|
resp_dict = await self.send_secure_passthrough(request)
|
||||||
|
except SmartDeviceException as ex:
|
||||||
|
raise AuthenticationException(ex) from ex
|
||||||
|
self._login_token = resp_dict["result"]["token"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def needs_login(self) -> bool:
|
||||||
|
"""Return true if the transport needs to do a login."""
|
||||||
|
return self._login_token is None
|
||||||
|
|
||||||
|
async def login(self, request: str) -> None:
|
||||||
|
"""Login to the device."""
|
||||||
|
try:
|
||||||
|
if self.needs_handshake:
|
||||||
|
raise SmartDeviceException(
|
||||||
|
"Handshake must be complete before trying to login"
|
||||||
|
)
|
||||||
|
await self.perform_login(request, login_v2=False)
|
||||||
|
except AuthenticationException:
|
||||||
|
await self.perform_handshake()
|
||||||
|
await self.perform_login(request, login_v2=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def needs_handshake(self) -> bool:
|
||||||
|
"""Return true if the transport needs to do a handshake."""
|
||||||
|
return not self._handshake_done or self._handshake_session_expired()
|
||||||
|
|
||||||
|
async def handshake(self) -> None:
|
||||||
|
"""Perform the encryption handshake."""
|
||||||
|
await self.perform_handshake()
|
||||||
|
|
||||||
|
async def perform_handshake(self):
|
||||||
|
"""Perform the handshake."""
|
||||||
|
_LOGGER.debug("Will perform handshaking...")
|
||||||
|
_LOGGER.debug("Generating keypair")
|
||||||
|
|
||||||
|
self._handshake_done = False
|
||||||
|
self._session_expire_at = None
|
||||||
|
self._session_cookie = None
|
||||||
|
|
||||||
|
url = f"http://{self.host}/app"
|
||||||
|
key_pair = KeyPair.create_key_pair()
|
||||||
|
|
||||||
|
pub_key = (
|
||||||
|
"-----BEGIN PUBLIC KEY-----\n"
|
||||||
|
+ key_pair.get_public_key()
|
||||||
|
+ "\n-----END PUBLIC KEY-----\n"
|
||||||
|
)
|
||||||
|
handshake_params = {"key": pub_key}
|
||||||
|
_LOGGER.debug(f"Handshake params: {handshake_params}")
|
||||||
|
|
||||||
|
request_body = {"method": "handshake", "params": handshake_params}
|
||||||
|
|
||||||
|
_LOGGER.debug(f"Request {request_body}")
|
||||||
|
|
||||||
|
status_code, resp_dict = await self.client_post(url, json=request_body)
|
||||||
|
|
||||||
|
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
||||||
|
|
||||||
|
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||||
|
_LOGGER.debug("Decoding handshake key...")
|
||||||
|
handshake_key = resp_dict["result"]["key"]
|
||||||
|
|
||||||
|
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||||
|
self.SESSION_COOKIE_NAME
|
||||||
|
)
|
||||||
|
if not self._session_cookie:
|
||||||
|
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||||
|
"SESSIONID"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._session_expire_at = time.time() + 86400
|
||||||
|
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||||
|
handshake_key, key_pair
|
||||||
|
)
|
||||||
|
|
||||||
|
self._handshake_done = True
|
||||||
|
|
||||||
|
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise AuthenticationException("Could not complete handshake")
|
||||||
|
|
||||||
|
def _handshake_session_expired(self):
|
||||||
|
"""Return true if session has expired."""
|
||||||
|
return (
|
||||||
|
self._session_expire_at is None
|
||||||
|
or self._session_expire_at - time.time() <= 0
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send(self, request: str):
|
||||||
|
"""Send the request."""
|
||||||
|
if self.needs_handshake:
|
||||||
|
raise SmartDeviceException(
|
||||||
|
"Handshake must be complete before trying to send"
|
||||||
|
)
|
||||||
|
if self.needs_login:
|
||||||
|
raise SmartDeviceException("Login must be complete before trying to send")
|
||||||
|
|
||||||
|
resp_dict = await self.send_secure_passthrough(request)
|
||||||
|
if resp_dict["error_code"] != 0:
|
||||||
|
self._handshake_done = False
|
||||||
|
self._login_token = None
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Could not complete send, response was {resp_dict}",
|
||||||
|
)
|
||||||
|
return resp_dict
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the protocol."""
|
||||||
|
client = self._http_client
|
||||||
|
self._http_client = None
|
||||||
|
if client:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class AesEncyptionSession:
|
||||||
|
"""Class for an AES encryption session."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_from_keypair(handshake_key: str, keypair):
|
||||||
|
"""Create the encryption session."""
|
||||||
|
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
|
||||||
|
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
|
||||||
|
|
||||||
|
private_key = serialization.load_der_private_key(private_key_data, None, None)
|
||||||
|
key_and_iv = private_key.decrypt(
|
||||||
|
handshake_key_bytes, asymmetric_padding.PKCS1v15()
|
||||||
|
)
|
||||||
|
if key_and_iv is None:
|
||||||
|
raise ValueError("Decryption failed!")
|
||||||
|
|
||||||
|
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
|
||||||
|
|
||||||
|
def __init__(self, key, iv):
|
||||||
|
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
||||||
|
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
|
||||||
|
|
||||||
|
def encrypt(self, data) -> bytes:
|
||||||
|
"""Encrypt the message."""
|
||||||
|
encryptor = self.cipher.encryptor()
|
||||||
|
padder = self.padding_strategy.padder()
|
||||||
|
padded_data = padder.update(data) + padder.finalize()
|
||||||
|
encrypted = encryptor.update(padded_data) + encryptor.finalize()
|
||||||
|
return base64.b64encode(encrypted)
|
||||||
|
|
||||||
|
def decrypt(self, data) -> str:
|
||||||
|
"""Decrypt the message."""
|
||||||
|
decryptor = self.cipher.decryptor()
|
||||||
|
unpadder = self.padding_strategy.unpadder()
|
||||||
|
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
|
||||||
|
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
|
||||||
|
return unpadded_data.decode()
|
||||||
|
|
||||||
|
|
||||||
|
class KeyPair:
|
||||||
|
"""Class for generating key pairs."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_key_pair(key_size: int = 1024):
|
||||||
|
"""Create a key pair."""
|
||||||
|
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
||||||
|
public_key = private_key.public_key()
|
||||||
|
|
||||||
|
private_key_bytes = private_key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.DER,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
public_key_bytes = public_key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.DER,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
return KeyPair(
|
||||||
|
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
|
||||||
|
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, private_key: str, public_key: str):
|
||||||
|
self.private_key = private_key
|
||||||
|
self.public_key = public_key
|
||||||
|
|
||||||
|
def get_private_key(self) -> str:
|
||||||
|
"""Get the private key."""
|
||||||
|
return self.private_key
|
||||||
|
|
||||||
|
def get_public_key(self) -> str:
|
||||||
|
"""Get the public key."""
|
||||||
|
return self.public_key
|
@ -2,17 +2,21 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Optional, Type
|
from typing import Any, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
from .aestransport import AesTransport
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .device_type import DeviceType
|
from .device_type import DeviceType
|
||||||
from .exceptions import UnsupportedDeviceException
|
from .exceptions import UnsupportedDeviceException
|
||||||
from .protocol import TPLinkProtocol
|
from .iotprotocol import IotProtocol
|
||||||
|
from .klaptransport import KlapTransport, TPlinkKlapTransportV2
|
||||||
|
from .protocol import BaseTransport, TPLinkProtocol
|
||||||
from .smartbulb import SmartBulb
|
from .smartbulb import SmartBulb
|
||||||
from .smartdevice import SmartDevice, SmartDeviceException
|
from .smartdevice import SmartDevice, SmartDeviceException
|
||||||
from .smartdimmer import SmartDimmer
|
from .smartdimmer import SmartDimmer
|
||||||
from .smartlightstrip import SmartLightStrip
|
from .smartlightstrip import SmartLightStrip
|
||||||
from .smartplug import SmartPlug
|
from .smartplug import SmartPlug
|
||||||
|
from .smartprotocol import SmartProtocol
|
||||||
from .smartstrip import SmartStrip
|
from .smartstrip import SmartStrip
|
||||||
from .tapo.tapoplug import TapoPlug
|
from .tapo.tapoplug import TapoPlug
|
||||||
|
|
||||||
@ -87,7 +91,7 @@ async def connect(
|
|||||||
if protocol_class is not None:
|
if protocol_class is not None:
|
||||||
unknown_dev.protocol = protocol_class(host, credentials=credentials)
|
unknown_dev.protocol = protocol_class(host, credentials=credentials)
|
||||||
await unknown_dev.update()
|
await unknown_dev.update()
|
||||||
device_class = get_device_class_from_info(unknown_dev.internal_state)
|
device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
|
||||||
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
|
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
|
||||||
# Reuse the connection from the unknown device
|
# Reuse the connection from the unknown device
|
||||||
# so we don't have to reconnect
|
# so we don't have to reconnect
|
||||||
@ -104,7 +108,7 @@ async def connect(
|
|||||||
return dev
|
return dev
|
||||||
|
|
||||||
|
|
||||||
def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]:
|
def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
|
||||||
"""Find SmartDevice subclass for device described by passed data."""
|
"""Find SmartDevice subclass for device described by passed data."""
|
||||||
if "system" not in info or "get_sysinfo" not in info["system"]:
|
if "system" not in info or "get_sysinfo" not in info["system"]:
|
||||||
raise SmartDeviceException("No 'system' or 'get_sysinfo' in response")
|
raise SmartDeviceException("No 'system' or 'get_sysinfo' in response")
|
||||||
@ -129,3 +133,35 @@ def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]:
|
|||||||
|
|
||||||
return SmartBulb
|
return SmartBulb
|
||||||
raise UnsupportedDeviceException("Unknown device type: %s" % type_)
|
raise UnsupportedDeviceException("Unknown device type: %s" % type_)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]:
|
||||||
|
"""Return the device class from the type name."""
|
||||||
|
supported_device_types: dict[str, Type[SmartDevice]] = {
|
||||||
|
"SMART.TAPOPLUG": TapoPlug,
|
||||||
|
"SMART.KASAPLUG": TapoPlug,
|
||||||
|
"IOT.SMARTPLUGSWITCH": SmartPlug,
|
||||||
|
}
|
||||||
|
return supported_device_types.get(device_type)
|
||||||
|
|
||||||
|
|
||||||
|
def get_protocol_from_connection_name(
|
||||||
|
connection_name: str, host: str, credentials: Optional[Credentials] = None
|
||||||
|
) -> Optional[TPLinkProtocol]:
|
||||||
|
"""Return the protocol from the connection name."""
|
||||||
|
supported_device_protocols: dict[
|
||||||
|
str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]]
|
||||||
|
] = {
|
||||||
|
"IOT.KLAP": (IotProtocol, KlapTransport),
|
||||||
|
"SMART.AES": (SmartProtocol, AesTransport),
|
||||||
|
"SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2),
|
||||||
|
}
|
||||||
|
if connection_name not in supported_device_protocols:
|
||||||
|
return None
|
||||||
|
|
||||||
|
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
|
||||||
|
transport: BaseTransport = transport_class(host, credentials=credentials)
|
||||||
|
protocol: TPLinkProtocol = protocol_class(
|
||||||
|
host, credentials=credentials, transport=transport
|
||||||
|
)
|
||||||
|
return protocol
|
||||||
|
@ -15,18 +15,18 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from kasa.aesprotocol import TPLinkAes
|
|
||||||
from kasa.credentials import Credentials
|
from kasa.credentials import Credentials
|
||||||
from kasa.exceptions import UnsupportedDeviceException
|
from kasa.exceptions import UnsupportedDeviceException
|
||||||
from kasa.json import dumps as json_dumps
|
from kasa.json import dumps as json_dumps
|
||||||
from kasa.json import loads as json_loads
|
from kasa.json import loads as json_loads
|
||||||
from kasa.klapprotocol import TPLinkKlap
|
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
|
||||||
from kasa.smartdevice import SmartDevice, SmartDeviceException
|
from kasa.smartdevice import SmartDevice, SmartDeviceException
|
||||||
from kasa.smartplug import SmartPlug
|
|
||||||
from kasa.tapo.tapoplug import TapoPlug
|
|
||||||
|
|
||||||
from .device_factory import get_device_class_from_info
|
from .device_factory import (
|
||||||
|
get_device_class_from_sys_info,
|
||||||
|
get_device_class_from_type_name,
|
||||||
|
get_protocol_from_connection_name,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -348,7 +348,16 @@ class Discover:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_device_class(info: dict) -> Type[SmartDevice]:
|
def _get_device_class(info: dict) -> Type[SmartDevice]:
|
||||||
"""Find SmartDevice subclass for device described by passed data."""
|
"""Find SmartDevice subclass for device described by passed data."""
|
||||||
return get_device_class_from_info(info)
|
if "result" in info:
|
||||||
|
discovery_result = DiscoveryResult(**info["result"])
|
||||||
|
dev_class = get_device_class_from_type_name(discovery_result.device_type)
|
||||||
|
if not dev_class:
|
||||||
|
raise UnsupportedDeviceException(
|
||||||
|
"Unknown device type: %s" % discovery_result.device_type
|
||||||
|
)
|
||||||
|
return dev_class
|
||||||
|
else:
|
||||||
|
return get_device_class_from_sys_info(info)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:
|
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:
|
||||||
@ -384,24 +393,17 @@ class Discover:
|
|||||||
encrypt_type_ = (
|
encrypt_type_ = (
|
||||||
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
|
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
|
||||||
)
|
)
|
||||||
device_class = None
|
|
||||||
|
|
||||||
supported_device_types: dict[str, Type[SmartDevice]] = {
|
if (device_class := get_device_class_from_type_name(type_)) is None:
|
||||||
"SMART.TAPOPLUG": TapoPlug,
|
|
||||||
"SMART.KASAPLUG": TapoPlug,
|
|
||||||
"IOT.SMARTPLUGSWITCH": SmartPlug,
|
|
||||||
}
|
|
||||||
supported_device_protocols: dict[str, Type[TPLinkProtocol]] = {
|
|
||||||
"IOT.KLAP": TPLinkKlap,
|
|
||||||
"SMART.AES": TPLinkAes,
|
|
||||||
}
|
|
||||||
|
|
||||||
if (device_class := supported_device_types.get(type_)) is None:
|
|
||||||
_LOGGER.warning("Got unsupported device type: %s", type_)
|
_LOGGER.warning("Got unsupported device type: %s", type_)
|
||||||
raise UnsupportedDeviceException(
|
raise UnsupportedDeviceException(
|
||||||
f"Unsupported device {ip} of type {type_}: {info}"
|
f"Unsupported device {ip} of type {type_}: {info}"
|
||||||
)
|
)
|
||||||
if (protocol_class := supported_device_protocols.get(encrypt_type_)) is None:
|
if (
|
||||||
|
protocol := get_protocol_from_connection_name(
|
||||||
|
encrypt_type_, ip, credentials=credentials
|
||||||
|
)
|
||||||
|
) is None:
|
||||||
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
|
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
|
||||||
raise UnsupportedDeviceException(
|
raise UnsupportedDeviceException(
|
||||||
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}"
|
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}"
|
||||||
@ -409,7 +411,7 @@ class Discover:
|
|||||||
|
|
||||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||||
device = device_class(ip, port=port, credentials=credentials)
|
device = device_class(ip, port=port, credentials=credentials)
|
||||||
device.protocol = protocol_class(ip, credentials=credentials)
|
device.protocol = protocol
|
||||||
device.update_from_discover_info(discovery_result.get_dict())
|
device.update_from_discover_info(discovery_result.get_dict())
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
100
kasa/iotprotocol.py
Executable file
100
kasa/iotprotocol.py
Executable file
@ -0,0 +1,100 @@
|
|||||||
|
"""Module for the IOT legacy IOT KASA protocol."""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .credentials import Credentials
|
||||||
|
from .exceptions import AuthenticationException, SmartDeviceException
|
||||||
|
from .json import dumps as json_dumps
|
||||||
|
from .klaptransport import KlapTransport
|
||||||
|
from .protocol import BaseTransport, TPLinkProtocol
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class IotProtocol(TPLinkProtocol):
|
||||||
|
"""Class for the legacy TPLink IOT KASA Protocol."""
|
||||||
|
|
||||||
|
DEFAULT_PORT = 80
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
*,
|
||||||
|
transport: Optional[BaseTransport] = None,
|
||||||
|
credentials: Optional[Credentials] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||||
|
|
||||||
|
self._credentials: Credentials = credentials or Credentials(
|
||||||
|
username="", password=""
|
||||||
|
)
|
||||||
|
self._transport: BaseTransport = transport or KlapTransport(
|
||||||
|
host, credentials=self._credentials, timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
self._query_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
|
"""Query the device retrying for retry_count on failure."""
|
||||||
|
if isinstance(request, dict):
|
||||||
|
request = json_dumps(request)
|
||||||
|
assert isinstance(request, str) # noqa: S101
|
||||||
|
|
||||||
|
async with self._query_lock:
|
||||||
|
return await self._query(request, retry_count)
|
||||||
|
|
||||||
|
async def _query(self, request: str, retry_count: int = 3) -> Dict:
|
||||||
|
for retry in range(retry_count + 1):
|
||||||
|
try:
|
||||||
|
return await self._execute_query(request, retry)
|
||||||
|
except httpx.CloseError as sdex:
|
||||||
|
await self.close()
|
||||||
|
if retry >= retry_count:
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||||
|
) from sdex
|
||||||
|
continue
|
||||||
|
except httpx.ConnectError as cex:
|
||||||
|
await self.close()
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device: {self.host}: {cex}"
|
||||||
|
) from cex
|
||||||
|
except TimeoutError as tex:
|
||||||
|
await self.close()
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
||||||
|
) from tex
|
||||||
|
except AuthenticationException as auex:
|
||||||
|
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||||
|
raise auex
|
||||||
|
except Exception as ex:
|
||||||
|
await self.close()
|
||||||
|
if retry >= retry_count:
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device: {self.host}: {ex}"
|
||||||
|
) from ex
|
||||||
|
continue
|
||||||
|
|
||||||
|
# make mypy happy, this should never be reached..
|
||||||
|
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||||
|
|
||||||
|
async def _execute_query(self, request: str, retry_count: int) -> Dict:
|
||||||
|
if self._transport.needs_handshake:
|
||||||
|
await self._transport.handshake()
|
||||||
|
|
||||||
|
if self._transport.needs_login: # This shouln't happen
|
||||||
|
raise SmartDeviceException(
|
||||||
|
"IOT Protocol needs to login to transport but is not login aware"
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self._transport.send(request)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the protocol."""
|
||||||
|
await self._transport.close()
|
295
kasa/klapprotocol.py → kasa/klaptransport.py
Executable file → Normal file
295
kasa/klapprotocol.py → kasa/klaptransport.py
Executable file → Normal file
@ -47,7 +47,7 @@ import logging
|
|||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from cryptography.hazmat.primitives import hashes, padding
|
from cryptography.hazmat.primitives import hashes, padding
|
||||||
@ -55,33 +55,33 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|||||||
|
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .exceptions import AuthenticationException, SmartDeviceException
|
from .exceptions import AuthenticationException, SmartDeviceException
|
||||||
from .json import dumps as json_dumps
|
|
||||||
from .json import loads as json_loads
|
from .json import loads as json_loads
|
||||||
from .protocol import TPLinkProtocol
|
from .protocol import BaseTransport, md5
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
logging.getLogger("httpx").propagate = False
|
logging.getLogger("httpx").propagate = False
|
||||||
|
|
||||||
|
|
||||||
def _sha256(payload: bytes) -> bytes:
|
def _sha256(payload: bytes) -> bytes:
|
||||||
return hashlib.sha256(payload).digest()
|
digest = hashes.Hash(hashes.SHA256()) # noqa: S303
|
||||||
|
|
||||||
|
|
||||||
def _md5(payload: bytes) -> bytes:
|
|
||||||
digest = hashes.Hash(hashes.MD5()) # noqa: S303
|
|
||||||
digest.update(payload)
|
digest.update(payload)
|
||||||
hash = digest.finalize()
|
hash = digest.finalize()
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
|
|
||||||
class TPLinkKlap(TPLinkProtocol):
|
def _sha1(payload: bytes) -> bytes:
|
||||||
|
digest = hashes.Hash(hashes.SHA1()) # noqa: S303
|
||||||
|
digest.update(payload)
|
||||||
|
return digest.finalize()
|
||||||
|
|
||||||
|
|
||||||
|
class KlapTransport(BaseTransport):
|
||||||
"""Implementation of the KLAP encryption protocol.
|
"""Implementation of the KLAP encryption protocol.
|
||||||
|
|
||||||
KLAP is the name used in device discovery for TP-Link's new encryption
|
KLAP is the name used in device discovery for TP-Link's new encryption
|
||||||
protocol, used by newer firmware versions.
|
protocol, used by newer firmware versions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_PORT = 80
|
|
||||||
DEFAULT_TIMEOUT = 5
|
DEFAULT_TIMEOUT = 5
|
||||||
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
|
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
|
||||||
KASA_SETUP_EMAIL = "kasa@tp-link.net"
|
KASA_SETUP_EMAIL = "kasa@tp-link.net"
|
||||||
@ -95,29 +95,24 @@ class TPLinkKlap(TPLinkProtocol):
|
|||||||
credentials: Optional[Credentials] = None,
|
credentials: Optional[Credentials] = None,
|
||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
super().__init__(host=host)
|
||||||
|
|
||||||
self.credentials = (
|
|
||||||
credentials
|
|
||||||
if credentials and credentials.username and credentials.password
|
|
||||||
else Credentials(username="", password="")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
self._credentials = credentials or Credentials(username="", password="")
|
||||||
self._local_seed: Optional[bytes] = None
|
self._local_seed: Optional[bytes] = None
|
||||||
self.local_auth_hash = self.generate_auth_hash(self.credentials)
|
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
||||||
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
|
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
||||||
self.kasa_setup_auth_hash = None
|
self._kasa_setup_auth_hash = None
|
||||||
self.blank_auth_hash = None
|
self._blank_auth_hash = None
|
||||||
self.handshake_lock = asyncio.Lock()
|
self._handshake_lock = asyncio.Lock()
|
||||||
self.query_lock = asyncio.Lock()
|
self._query_lock = asyncio.Lock()
|
||||||
self.handshake_done = False
|
self._handshake_done = False
|
||||||
|
|
||||||
self.encryption_session: Optional[KlapEncryptionSession] = None
|
self._encryption_session: Optional[KlapEncryptionSession] = None
|
||||||
self.session_expire_at: Optional[float] = None
|
self._session_expire_at: Optional[float] = None
|
||||||
|
|
||||||
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||||
self.session_cookie = None
|
self._session_cookie = None
|
||||||
self.http_client: Optional[httpx.AsyncClient] = None
|
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||||
|
|
||||||
_LOGGER.debug("Created KLAP object for %s", self.host)
|
_LOGGER.debug("Created KLAP object for %s", self.host)
|
||||||
|
|
||||||
@ -125,15 +120,15 @@ class TPLinkKlap(TPLinkProtocol):
|
|||||||
"""Send an http post request to the device."""
|
"""Send an http post request to the device."""
|
||||||
response_data = None
|
response_data = None
|
||||||
cookies = None
|
cookies = None
|
||||||
if self.session_cookie:
|
if self._session_cookie:
|
||||||
cookies = httpx.Cookies()
|
cookies = httpx.Cookies()
|
||||||
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
|
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
|
||||||
self.http_client.cookies.clear()
|
self._http_client.cookies.clear()
|
||||||
resp = await self.http_client.post(
|
resp = await self._http_client.post(
|
||||||
url,
|
url,
|
||||||
params=params,
|
params=params,
|
||||||
data=data,
|
data=data,
|
||||||
timeout=self.timeout,
|
timeout=self._timeout,
|
||||||
cookies=cookies,
|
cookies=cookies,
|
||||||
)
|
)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
@ -183,44 +178,55 @@ class TPLinkKlap(TPLinkProtocol):
|
|||||||
server_hash.hex(),
|
server_hash.hex(),
|
||||||
)
|
)
|
||||||
|
|
||||||
local_seed_auth_hash = _sha256(local_seed + self.local_auth_hash)
|
local_seed_auth_hash = self.handshake1_seed_auth_hash(
|
||||||
|
local_seed, remote_seed, self._local_auth_hash
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
# Check the response from the device with local credentials
|
# Check the response from the device with local credentials
|
||||||
if local_seed_auth_hash == server_hash:
|
if local_seed_auth_hash == server_hash:
|
||||||
_LOGGER.debug("handshake1 hashes match with expected credentials")
|
_LOGGER.debug("handshake1 hashes match with expected credentials")
|
||||||
return local_seed, remote_seed, self.local_auth_hash # type: ignore
|
return local_seed, remote_seed, self._local_auth_hash # type: ignore
|
||||||
|
|
||||||
# Now check against the default kasa setup credentials
|
# Now check against the default kasa setup credentials
|
||||||
if not self.kasa_setup_auth_hash:
|
if not self._kasa_setup_auth_hash:
|
||||||
kasa_setup_creds = Credentials(
|
kasa_setup_creds = Credentials(
|
||||||
username=TPLinkKlap.KASA_SETUP_EMAIL,
|
username=self.KASA_SETUP_EMAIL,
|
||||||
password=TPLinkKlap.KASA_SETUP_PASSWORD,
|
password=self.KASA_SETUP_PASSWORD,
|
||||||
)
|
)
|
||||||
self.kasa_setup_auth_hash = TPLinkKlap.generate_auth_hash(kasa_setup_creds)
|
self._kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds)
|
||||||
|
|
||||||
kasa_setup_seed_auth_hash = _sha256(
|
kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash(
|
||||||
local_seed + self.kasa_setup_auth_hash # type: ignore
|
local_seed,
|
||||||
|
remote_seed,
|
||||||
|
self._kasa_setup_auth_hash, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
if kasa_setup_seed_auth_hash == server_hash:
|
if kasa_setup_seed_auth_hash == server_hash:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Server response doesn't match our expected hash on ip %s"
|
"Server response doesn't match our expected hash on ip %s"
|
||||||
+ " but an authentication with kasa setup credentials matched",
|
+ " but an authentication with kasa setup credentials matched",
|
||||||
self.host,
|
self.host,
|
||||||
)
|
)
|
||||||
return local_seed, remote_seed, self.kasa_setup_auth_hash # type: ignore
|
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
|
||||||
|
|
||||||
# Finally check against blank credentials if not already blank
|
# Finally check against blank credentials if not already blank
|
||||||
if self.credentials != (blank_creds := Credentials(username="", password="")):
|
if self._credentials != (blank_creds := Credentials(username="", password="")):
|
||||||
if not self.blank_auth_hash:
|
if not self._blank_auth_hash:
|
||||||
self.blank_auth_hash = TPLinkKlap.generate_auth_hash(blank_creds)
|
self._blank_auth_hash = self.generate_auth_hash(blank_creds)
|
||||||
blank_seed_auth_hash = _sha256(local_seed + self.blank_auth_hash) # type: ignore
|
|
||||||
|
blank_seed_auth_hash = self.handshake1_seed_auth_hash(
|
||||||
|
local_seed,
|
||||||
|
remote_seed,
|
||||||
|
self._blank_auth_hash, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
if blank_seed_auth_hash == server_hash:
|
if blank_seed_auth_hash == server_hash:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Server response doesn't match our expected hash on ip %s"
|
"Server response doesn't match our expected hash on ip %s"
|
||||||
+ " but an authentication with blank credentials matched",
|
+ " but an authentication with blank credentials matched",
|
||||||
self.host,
|
self.host,
|
||||||
)
|
)
|
||||||
return local_seed, remote_seed, self.blank_auth_hash # type: ignore
|
return local_seed, remote_seed, self._blank_auth_hash # type: ignore
|
||||||
|
|
||||||
msg = f"Server response doesn't match our challenge on ip {self.host}"
|
msg = f"Server response doesn't match our challenge on ip {self.host}"
|
||||||
_LOGGER.debug(msg)
|
_LOGGER.debug(msg)
|
||||||
@ -235,7 +241,7 @@ class TPLinkKlap(TPLinkProtocol):
|
|||||||
|
|
||||||
url = f"http://{self.host}/app/handshake2"
|
url = f"http://{self.host}/app/handshake2"
|
||||||
|
|
||||||
payload = _sha256(remote_seed + auth_hash)
|
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
||||||
|
|
||||||
response_status, response_data = await self.client_post(url, data=payload)
|
response_status, response_data = await self.client_post(url, data=payload)
|
||||||
|
|
||||||
@ -256,115 +262,70 @@ class TPLinkKlap(TPLinkProtocol):
|
|||||||
|
|
||||||
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
|
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def needs_login(self) -> bool:
|
||||||
|
"""Will return false as KLAP does not do a login."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def login(self, request: str) -> None:
|
||||||
|
"""Will raise and exception as KLAP does not do a login."""
|
||||||
|
raise SmartDeviceException(
|
||||||
|
"KLAP does not perform logins and return needs_login == False"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def needs_handshake(self) -> bool:
|
||||||
|
"""Return true if the transport needs to do a handshake."""
|
||||||
|
return not self._handshake_done or self._handshake_session_expired()
|
||||||
|
|
||||||
|
async def handshake(self) -> None:
|
||||||
|
"""Perform the encryption handshake."""
|
||||||
|
await self.perform_handshake()
|
||||||
|
|
||||||
async def perform_handshake(self) -> Any:
|
async def perform_handshake(self) -> Any:
|
||||||
"""Perform handshake1 and handshake2.
|
"""Perform handshake1 and handshake2.
|
||||||
|
|
||||||
Sets the encryption_session if successful.
|
Sets the encryption_session if successful.
|
||||||
"""
|
"""
|
||||||
_LOGGER.debug("Starting handshake with %s", self.host)
|
_LOGGER.debug("Starting handshake with %s", self.host)
|
||||||
self.handshake_done = False
|
self._handshake_done = False
|
||||||
self.session_expire_at = None
|
self._session_expire_at = None
|
||||||
self.session_cookie = None
|
self._session_cookie = None
|
||||||
|
|
||||||
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
|
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
|
||||||
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||||
TPLinkKlap.SESSION_COOKIE_NAME
|
self.SESSION_COOKIE_NAME
|
||||||
)
|
)
|
||||||
# The device returns a TIMEOUT cookie on handshake1 which
|
# The device returns a TIMEOUT cookie on handshake1 which
|
||||||
# it doesn't like to get back so we store the one we want
|
# it doesn't like to get back so we store the one we want
|
||||||
|
|
||||||
self.session_expire_at = time.time() + 86400
|
self._session_expire_at = time.time() + 86400
|
||||||
self.encryption_session = await self.perform_handshake2(
|
self._encryption_session = await self.perform_handshake2(
|
||||||
local_seed, remote_seed, auth_hash
|
local_seed, remote_seed, auth_hash
|
||||||
)
|
)
|
||||||
self.handshake_done = True
|
self._handshake_done = True
|
||||||
|
|
||||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||||
|
|
||||||
def handshake_session_expired(self):
|
def _handshake_session_expired(self):
|
||||||
"""Return true if session has expired."""
|
"""Return true if session has expired."""
|
||||||
return (
|
return (
|
||||||
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
|
self._session_expire_at is None
|
||||||
|
or self._session_expire_at - time.time() <= 0
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
async def send(self, request: str):
|
||||||
def generate_auth_hash(creds: Credentials):
|
"""Send the request."""
|
||||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
if self.needs_handshake:
|
||||||
un = creds.username or ""
|
|
||||||
pw = creds.password or ""
|
|
||||||
return _md5(_md5(un.encode()) + _md5(pw.encode()))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_owner_hash(creds: Credentials):
|
|
||||||
"""Return the MD5 hash of the username in this object."""
|
|
||||||
un = creds.username or ""
|
|
||||||
return _md5(un.encode())
|
|
||||||
|
|
||||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
|
||||||
"""Query the device retrying for retry_count on failure."""
|
|
||||||
if isinstance(request, dict):
|
|
||||||
request = json_dumps(request)
|
|
||||||
assert isinstance(request, str) # noqa: S101
|
|
||||||
|
|
||||||
async with self.query_lock:
|
|
||||||
return await self._query(request, retry_count)
|
|
||||||
|
|
||||||
async def _query(self, request: str, retry_count: int = 3) -> Dict:
|
|
||||||
for retry in range(retry_count + 1):
|
|
||||||
try:
|
|
||||||
return await self._execute_query(request, retry)
|
|
||||||
except httpx.CloseError as sdex:
|
|
||||||
await self.close()
|
|
||||||
if retry >= retry_count:
|
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
"Handshake must be complete before trying to send"
|
||||||
) from sdex
|
|
||||||
continue
|
|
||||||
except httpx.ConnectError as cex:
|
|
||||||
await self.close()
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Unable to connect to the device: {self.host}: {cex}"
|
|
||||||
) from cex
|
|
||||||
except TimeoutError as tex:
|
|
||||||
await self.close()
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
|
||||||
) from tex
|
|
||||||
except AuthenticationException as auex:
|
|
||||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
|
||||||
raise auex
|
|
||||||
except Exception as ex:
|
|
||||||
await self.close()
|
|
||||||
if retry >= retry_count:
|
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Unable to connect to the device: {self.host}: {ex}"
|
|
||||||
) from ex
|
|
||||||
continue
|
|
||||||
|
|
||||||
# make mypy happy, this should never be reached..
|
|
||||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
|
||||||
|
|
||||||
async def _execute_query(self, request: str, retry_count: int) -> Dict:
|
|
||||||
if not self.http_client:
|
|
||||||
self.http_client = httpx.AsyncClient()
|
|
||||||
|
|
||||||
if not self.handshake_done or self.handshake_session_expired():
|
|
||||||
try:
|
|
||||||
await self.perform_handshake()
|
|
||||||
|
|
||||||
except AuthenticationException as auex:
|
|
||||||
_LOGGER.debug(
|
|
||||||
"Unable to complete handshake for device %s, "
|
|
||||||
+ "authentication failed",
|
|
||||||
self.host,
|
|
||||||
)
|
)
|
||||||
raise auex
|
if self.needs_login:
|
||||||
|
raise SmartDeviceException("Login must be complete before trying to send")
|
||||||
|
|
||||||
# Check for mypy
|
# Check for mypy
|
||||||
if self.encryption_session is not None:
|
if self._encryption_session is not None:
|
||||||
payload, seq = self.encryption_session.encrypt(request.encode())
|
payload, seq = self._encryption_session.encrypt(request.encode())
|
||||||
|
|
||||||
url = f"http://{self.host}/app/request"
|
url = f"http://{self.host}/app/request"
|
||||||
|
|
||||||
@ -376,14 +337,14 @@ class TPLinkKlap(TPLinkProtocol):
|
|||||||
|
|
||||||
msg = (
|
msg = (
|
||||||
f"at {datetime.datetime.now()}. Host is {self.host}, "
|
f"at {datetime.datetime.now()}. Host is {self.host}, "
|
||||||
+ f"Retry count is {retry_count}, Sequence is {seq}, "
|
+ f"Sequence is {seq}, "
|
||||||
+ f"Response status is {response_status}, Request was {request}"
|
+ f"Response status is {response_status}, Request was {request}"
|
||||||
)
|
)
|
||||||
if response_status != 200:
|
if response_status != 200:
|
||||||
_LOGGER.error("Query failed after succesful authentication " + msg)
|
_LOGGER.error("Query failed after succesful authentication " + msg)
|
||||||
# If we failed with a security error, force a new handshake next time.
|
# If we failed with a security error, force a new handshake next time.
|
||||||
if response_status == 403:
|
if response_status == 403:
|
||||||
self.handshake_done = False
|
self._handshake_done = False
|
||||||
raise AuthenticationException(
|
raise AuthenticationException(
|
||||||
f"Got a security error from {self.host} after handshake "
|
f"Got a security error from {self.host} after handshake "
|
||||||
+ "completed"
|
+ "completed"
|
||||||
@ -397,8 +358,8 @@ class TPLinkKlap(TPLinkProtocol):
|
|||||||
_LOGGER.debug("Query posted " + msg)
|
_LOGGER.debug("Query posted " + msg)
|
||||||
|
|
||||||
# Check for mypy
|
# Check for mypy
|
||||||
if self.encryption_session is not None:
|
if self._encryption_session is not None:
|
||||||
decrypted_response = self.encryption_session.decrypt(response_data)
|
decrypted_response = self._encryption_session.decrypt(response_data)
|
||||||
|
|
||||||
json_payload = json_loads(decrypted_response)
|
json_payload = json_loads(decrypted_response)
|
||||||
|
|
||||||
@ -411,12 +372,66 @@ class TPLinkKlap(TPLinkProtocol):
|
|||||||
return json_payload
|
return json_payload
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the protocol."""
|
"""Close the transport."""
|
||||||
client = self.http_client
|
client = self._http_client
|
||||||
self.http_client = None
|
self._http_client = None
|
||||||
if client:
|
if client:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_auth_hash(creds: Credentials):
|
||||||
|
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||||
|
un = creds.username or ""
|
||||||
|
pw = creds.password or ""
|
||||||
|
|
||||||
|
return md5(md5(un.encode()) + md5(pw.encode()))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def handshake1_seed_auth_hash(
|
||||||
|
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||||
|
):
|
||||||
|
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||||
|
return _sha256(local_seed + auth_hash)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def handshake2_seed_auth_hash(
|
||||||
|
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||||
|
):
|
||||||
|
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||||
|
return _sha256(remote_seed + auth_hash)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_owner_hash(creds: Credentials):
|
||||||
|
"""Return the MD5 hash of the username in this object."""
|
||||||
|
un = creds.username or ""
|
||||||
|
return md5(un.encode())
|
||||||
|
|
||||||
|
|
||||||
|
class TPlinkKlapTransportV2(KlapTransport):
|
||||||
|
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_auth_hash(creds: Credentials):
|
||||||
|
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||||
|
un = creds.username or ""
|
||||||
|
pw = creds.password or ""
|
||||||
|
|
||||||
|
return _sha256(_sha1(un.encode()) + _sha1(pw.encode()))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def handshake1_seed_auth_hash(
|
||||||
|
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||||
|
):
|
||||||
|
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||||
|
return _sha256(local_seed + remote_seed + auth_hash)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def handshake2_seed_auth_hash(
|
||||||
|
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||||
|
):
|
||||||
|
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||||
|
return _sha256(remote_seed + local_seed + auth_hash)
|
||||||
|
|
||||||
|
|
||||||
class KlapEncryptionSession:
|
class KlapEncryptionSession:
|
||||||
"""Class to represent an encryption session and it's internal state.
|
"""Class to represent an encryption session and it's internal state.
|
@ -22,6 +22,7 @@ from typing import Dict, Generator, Optional, Union
|
|||||||
# When support for cpython older than 3.11 is dropped
|
# When support for cpython older than 3.11 is dropped
|
||||||
# async_timeout can be replaced with asyncio.timeout
|
# async_timeout can be replaced with asyncio.timeout
|
||||||
from async_timeout import timeout as asyncio_timeout
|
from async_timeout import timeout as asyncio_timeout
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .exceptions import SmartDeviceException
|
from .exceptions import SmartDeviceException
|
||||||
@ -32,6 +33,56 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
|
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
|
||||||
|
|
||||||
|
|
||||||
|
def md5(payload: bytes) -> bytes:
|
||||||
|
"""Return an md5 hash of the payload."""
|
||||||
|
digest = hashes.Hash(hashes.MD5()) # noqa: S303
|
||||||
|
digest.update(payload)
|
||||||
|
hash = digest.finalize()
|
||||||
|
return hash
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTransport(ABC):
|
||||||
|
"""Base class for all TP-Link protocol transports."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
*,
|
||||||
|
port: Optional[int] = None,
|
||||||
|
credentials: Optional[Credentials] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Create a protocol object."""
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.credentials = credentials
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def needs_handshake(self) -> bool:
|
||||||
|
"""Return true if the transport needs to do a handshake."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def needs_login(self) -> bool:
|
||||||
|
"""Return true if the transport needs to do a login."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def login(self, request: str) -> None:
|
||||||
|
"""Login to the device."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def handshake(self) -> None:
|
||||||
|
"""Perform the encryption handshake."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def send(self, request: str) -> Dict:
|
||||||
|
"""Send a message to the device and return a response."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the transport. Abstract method to be overriden."""
|
||||||
|
|
||||||
|
|
||||||
class TPLinkProtocol(ABC):
|
class TPLinkProtocol(ABC):
|
||||||
"""Base class for all TP-Link Smart Home communication."""
|
"""Base class for all TP-Link Smart Home communication."""
|
||||||
|
|
||||||
@ -41,6 +92,7 @@ class TPLinkProtocol(ABC):
|
|||||||
*,
|
*,
|
||||||
port: Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
credentials: Optional[Credentials] = None,
|
credentials: Optional[Credentials] = None,
|
||||||
|
transport: Optional[BaseTransport] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a protocol object."""
|
"""Create a protocol object."""
|
||||||
self.host = host
|
self.host = host
|
||||||
|
@ -365,6 +365,7 @@ class SmartDevice:
|
|||||||
|
|
||||||
def update_from_discover_info(self, info: Dict[str, Any]) -> None:
|
def update_from_discover_info(self, info: Dict[str, Any]) -> None:
|
||||||
"""Update state from info from the discover call."""
|
"""Update state from info from the discover call."""
|
||||||
|
self._discovery_info = info
|
||||||
if "system" in info and (sys_info := info["system"].get("get_sysinfo")):
|
if "system" in info and (sys_info := info["system"].get("get_sysinfo")):
|
||||||
self._last_update = info
|
self._last_update = info
|
||||||
self._set_sys_info(sys_info)
|
self._set_sys_info(sys_info)
|
||||||
@ -372,7 +373,6 @@ class SmartDevice:
|
|||||||
# This allows setting of some info properties directly
|
# This allows setting of some info properties directly
|
||||||
# from partial discovery info that will then be found
|
# from partial discovery info that will then be found
|
||||||
# by the requires_update decorator
|
# by the requires_update decorator
|
||||||
self._discovery_info = info
|
|
||||||
self._set_sys_info(info)
|
self._set_sys_info(info)
|
||||||
|
|
||||||
def _set_sys_info(self, sys_info: Dict[str, Any]) -> None:
|
def _set_sys_info(self, sys_info: Dict[str, Any]) -> None:
|
||||||
|
219
kasa/smartprotocol.py
Normal file
219
kasa/smartprotocol.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
"""Implementation of the TP-Link AES Protocol.
|
||||||
|
|
||||||
|
Based on the work of https://github.com/petretiandrea/plugp100
|
||||||
|
under compatible GNU GPL3 license.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from pprint import pformat as pf
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .aestransport import AesTransport
|
||||||
|
from .credentials import Credentials
|
||||||
|
from .exceptions import AuthenticationException, SmartDeviceException
|
||||||
|
from .json import dumps as json_dumps
|
||||||
|
from .protocol import BaseTransport, TPLinkProtocol, md5
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
logging.getLogger("httpx").propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
class SmartProtocol(TPLinkProtocol):
|
||||||
|
"""Class for the new TPLink SMART protocol."""
|
||||||
|
|
||||||
|
DEFAULT_PORT = 80
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
*,
|
||||||
|
transport: Optional[BaseTransport] = None,
|
||||||
|
credentials: Optional[Credentials] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||||
|
|
||||||
|
self._credentials: Credentials = credentials or Credentials(
|
||||||
|
username="", password=""
|
||||||
|
)
|
||||||
|
self._transport: BaseTransport = transport or AesTransport(
|
||||||
|
host, credentials=self._credentials, timeout=timeout
|
||||||
|
)
|
||||||
|
self._terminal_uuid: Optional[str] = None
|
||||||
|
self._request_id_generator = SnowflakeId(1, 1)
|
||||||
|
self._query_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def get_smart_request(self, method, params=None) -> str:
|
||||||
|
"""Get a request message as a string."""
|
||||||
|
request = {
|
||||||
|
"method": method,
|
||||||
|
"params": params,
|
||||||
|
"requestID": self._request_id_generator.generate_id(),
|
||||||
|
"request_time_milis": round(time.time() * 1000),
|
||||||
|
"terminal_uuid": self._terminal_uuid,
|
||||||
|
}
|
||||||
|
return json_dumps(request)
|
||||||
|
|
||||||
|
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
|
"""Query the device retrying for retry_count on failure."""
|
||||||
|
async with self._query_lock:
|
||||||
|
resp_dict = await self._query(request, retry_count)
|
||||||
|
if "result" in resp_dict:
|
||||||
|
return resp_dict["result"]
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
|
for retry in range(retry_count + 1):
|
||||||
|
try:
|
||||||
|
return await self._execute_query(request, retry)
|
||||||
|
except httpx.CloseError as sdex:
|
||||||
|
await self.close()
|
||||||
|
if retry >= retry_count:
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||||
|
) from sdex
|
||||||
|
continue
|
||||||
|
except httpx.ConnectError as cex:
|
||||||
|
await self.close()
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device: {self.host}: {cex}"
|
||||||
|
) from cex
|
||||||
|
except TimeoutError as tex:
|
||||||
|
await self.close()
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
||||||
|
) from tex
|
||||||
|
except AuthenticationException as auex:
|
||||||
|
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||||
|
raise auex
|
||||||
|
except Exception as ex:
|
||||||
|
await self.close()
|
||||||
|
if retry >= retry_count:
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to connect to the device: {self.host}: {ex}"
|
||||||
|
) from ex
|
||||||
|
continue
|
||||||
|
|
||||||
|
# make mypy happy, this should never be reached..
|
||||||
|
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||||
|
|
||||||
|
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
||||||
|
if isinstance(request, dict):
|
||||||
|
smart_method = next(iter(request))
|
||||||
|
smart_params = request[smart_method]
|
||||||
|
else:
|
||||||
|
smart_method = request
|
||||||
|
smart_params = None
|
||||||
|
|
||||||
|
if self._transport.needs_handshake:
|
||||||
|
await self._transport.handshake()
|
||||||
|
|
||||||
|
if self._transport.needs_login:
|
||||||
|
self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
|
||||||
|
"UTF-8"
|
||||||
|
)
|
||||||
|
login_request = self.get_smart_request("login_device")
|
||||||
|
await self._transport.login(login_request)
|
||||||
|
|
||||||
|
smart_request = self.get_smart_request(smart_method, smart_params)
|
||||||
|
response_data = await self._transport.send(smart_request)
|
||||||
|
|
||||||
|
_LOGGER.debug(
|
||||||
|
"%s << %s",
|
||||||
|
self.host,
|
||||||
|
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_data
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the protocol."""
|
||||||
|
await self._transport.close()
|
||||||
|
|
||||||
|
|
||||||
|
class SnowflakeId:
|
||||||
|
"""Class for generating snowflake ids."""
|
||||||
|
|
||||||
|
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
|
||||||
|
WORKER_ID_BITS = 5
|
||||||
|
DATA_CENTER_ID_BITS = 5
|
||||||
|
SEQUENCE_BITS = 12
|
||||||
|
|
||||||
|
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
|
||||||
|
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
|
||||||
|
|
||||||
|
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
|
||||||
|
|
||||||
|
def __init__(self, worker_id, data_center_id):
|
||||||
|
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Worker ID can't be greater than "
|
||||||
|
+ str(SnowflakeId.MAX_WORKER_ID)
|
||||||
|
+ " or less than 0"
|
||||||
|
)
|
||||||
|
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Data center ID can't be greater than "
|
||||||
|
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
|
||||||
|
+ " or less than 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.worker_id = worker_id
|
||||||
|
self.data_center_id = data_center_id
|
||||||
|
self.sequence = 0
|
||||||
|
self.last_timestamp = -1
|
||||||
|
|
||||||
|
def generate_id(self):
|
||||||
|
"""Generate a snowflake id."""
|
||||||
|
timestamp = self._current_millis()
|
||||||
|
|
||||||
|
if timestamp < self.last_timestamp:
|
||||||
|
raise ValueError("Clock moved backwards. Refusing to generate ID.")
|
||||||
|
|
||||||
|
if timestamp == self.last_timestamp:
|
||||||
|
# Within the same millisecond, increment the sequence number
|
||||||
|
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
|
||||||
|
if self.sequence == 0:
|
||||||
|
# Sequence exceeds its bit range, wait until the next millisecond
|
||||||
|
timestamp = self._wait_next_millis(self.last_timestamp)
|
||||||
|
else:
|
||||||
|
# New millisecond, reset the sequence number
|
||||||
|
self.sequence = 0
|
||||||
|
|
||||||
|
# Update the last timestamp
|
||||||
|
self.last_timestamp = timestamp
|
||||||
|
|
||||||
|
# Generate and return the final ID
|
||||||
|
return (
|
||||||
|
(
|
||||||
|
(timestamp - SnowflakeId.EPOCH)
|
||||||
|
<< (
|
||||||
|
SnowflakeId.WORKER_ID_BITS
|
||||||
|
+ SnowflakeId.SEQUENCE_BITS
|
||||||
|
+ SnowflakeId.DATA_CENTER_ID_BITS
|
||||||
|
)
|
||||||
|
)
|
||||||
|
| (
|
||||||
|
self.data_center_id
|
||||||
|
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
|
||||||
|
)
|
||||||
|
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
|
||||||
|
| self.sequence
|
||||||
|
)
|
||||||
|
|
||||||
|
def _current_millis(self):
|
||||||
|
return round(time.time() * 1000)
|
||||||
|
|
||||||
|
def _wait_next_millis(self, last_timestamp):
|
||||||
|
timestamp = self._current_millis()
|
||||||
|
while timestamp <= last_timestamp:
|
||||||
|
timestamp = self._current_millis()
|
||||||
|
return timestamp
|
@ -4,10 +4,10 @@ import logging
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Dict, Optional, Set, cast
|
from typing import Any, Dict, Optional, Set, cast
|
||||||
|
|
||||||
from ..aesprotocol import TPLinkAes
|
|
||||||
from ..credentials import Credentials
|
from ..credentials import Credentials
|
||||||
from ..exceptions import AuthenticationException
|
from ..exceptions import AuthenticationException
|
||||||
from ..smartdevice import SmartDevice
|
from ..smartdevice import SmartDevice
|
||||||
|
from ..smartprotocol import SmartProtocol
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ class TapoDevice(SmartDevice):
|
|||||||
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||||
self._state_information: Dict[str, Any] = {}
|
self._state_information: Dict[str, Any] = {}
|
||||||
self._discovery_info: Optional[Dict[str, Any]] = None
|
self._discovery_info: Optional[Dict[str, Any]] = None
|
||||||
self.protocol = TPLinkAes(host, credentials=credentials, timeout=timeout)
|
self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout)
|
||||||
|
|
||||||
async def update(self, update_children: bool = True):
|
async def update(self, update_children: bool = True):
|
||||||
"""Update the device."""
|
"""Update the device."""
|
||||||
|
@ -2,27 +2,45 @@ import asyncio
|
|||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from json import dumps as json_dumps
|
||||||
from os.path import basename
|
from os.path import basename
|
||||||
from pathlib import Path, PurePath
|
from pathlib import Path, PurePath
|
||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
|
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
|
||||||
|
|
||||||
from kasa import (
|
from kasa import (
|
||||||
|
Credentials,
|
||||||
Discover,
|
Discover,
|
||||||
SmartBulb,
|
SmartBulb,
|
||||||
SmartDimmer,
|
SmartDimmer,
|
||||||
SmartLightStrip,
|
SmartLightStrip,
|
||||||
SmartPlug,
|
SmartPlug,
|
||||||
SmartStrip,
|
SmartStrip,
|
||||||
|
TPLinkSmartHomeProtocol,
|
||||||
)
|
)
|
||||||
|
from kasa.tapo import TapoDevice, TapoPlug
|
||||||
|
|
||||||
from .newfakes import FakeTransportProtocol
|
from .newfakes import FakeSmartProtocol, FakeTransportProtocol
|
||||||
|
|
||||||
SUPPORTED_DEVICES = glob.glob(
|
SUPPORTED_IOT_DEVICES = [
|
||||||
|
(device, "IOT")
|
||||||
|
for device in glob.glob(
|
||||||
os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json"
|
os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json"
|
||||||
)
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
SUPPORTED_SMART_DEVICES = [
|
||||||
|
(device, "SMART")
|
||||||
|
for device in glob.glob(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)) + "/fixtures/smart/*.json"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
SUPPORTED_DEVICES = SUPPORTED_IOT_DEVICES + SUPPORTED_SMART_DEVICES
|
||||||
|
|
||||||
|
|
||||||
LIGHT_STRIPS = {"KL400", "KL430", "KL420"}
|
LIGHT_STRIPS = {"KL400", "KL430", "KL420"}
|
||||||
@ -55,43 +73,59 @@ PLUGS = {
|
|||||||
"KP401",
|
"KP401",
|
||||||
"KS200M",
|
"KS200M",
|
||||||
}
|
}
|
||||||
|
|
||||||
STRIPS = {"HS107", "HS300", "KP303", "KP200", "KP400", "EP40"}
|
STRIPS = {"HS107", "HS300", "KP303", "KP200", "KP400", "EP40"}
|
||||||
DIMMERS = {"ES20M", "HS220", "KS220M", "KS230", "KP405"}
|
DIMMERS = {"ES20M", "HS220", "KS220M", "KS230", "KP405"}
|
||||||
|
|
||||||
DIMMABLE = {*BULBS, *DIMMERS}
|
DIMMABLE = {*BULBS, *DIMMERS}
|
||||||
WITH_EMETER = {"HS110", "HS300", "KP115", "KP125", *BULBS}
|
WITH_EMETER = {"HS110", "HS300", "KP115", "KP125", *BULBS}
|
||||||
|
|
||||||
ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
|
ALL_DEVICES_IOT = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
|
||||||
|
|
||||||
|
PLUGS_SMART = {"P110"}
|
||||||
|
ALL_DEVICES_SMART = PLUGS_SMART
|
||||||
|
|
||||||
|
ALL_DEVICES = ALL_DEVICES_IOT.union(ALL_DEVICES_SMART)
|
||||||
|
|
||||||
IP_MODEL_CACHE: Dict[str, str] = {}
|
IP_MODEL_CACHE: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
def filter_model(desc, filter):
|
def idgenerator(paramtuple):
|
||||||
filtered = list()
|
return basename(paramtuple[0]) + (
|
||||||
for dev in SUPPORTED_DEVICES:
|
"" if paramtuple[1] == "IOT" else "-" + paramtuple[1]
|
||||||
for filt in filter:
|
)
|
||||||
if filt in basename(dev):
|
|
||||||
filtered.append(dev)
|
|
||||||
|
|
||||||
filtered_basenames = [basename(f) for f in filtered]
|
|
||||||
|
def filter_model(desc, model_filter, protocol_filter=None):
|
||||||
|
if not protocol_filter:
|
||||||
|
protocol_filter = {"IOT"}
|
||||||
|
filtered = list()
|
||||||
|
for file, protocol in SUPPORTED_DEVICES:
|
||||||
|
if protocol in protocol_filter:
|
||||||
|
file_model = basename(file).split("_")[0]
|
||||||
|
for model in model_filter:
|
||||||
|
if model in file_model:
|
||||||
|
filtered.append((file, protocol))
|
||||||
|
|
||||||
|
filtered_basenames = [basename(f) + "-" + p for f, p in filtered]
|
||||||
print(f"{desc}: {filtered_basenames}")
|
print(f"{desc}: {filtered_basenames}")
|
||||||
return filtered
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
def parametrize(desc, devices, ids=None):
|
def parametrize(desc, devices, protocol_filter=None, ids=None):
|
||||||
return pytest.mark.parametrize(
|
return pytest.mark.parametrize(
|
||||||
"dev", filter_model(desc, devices), indirect=True, ids=ids
|
"dev", filter_model(desc, devices, protocol_filter), indirect=True, ids=ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
has_emeter = parametrize("has emeter", WITH_EMETER)
|
has_emeter = parametrize("has emeter", WITH_EMETER)
|
||||||
no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER)
|
no_emeter = parametrize("no emeter", ALL_DEVICES_IOT - WITH_EMETER)
|
||||||
|
|
||||||
bulb = parametrize("bulbs", BULBS, ids=basename)
|
bulb = parametrize("bulbs", BULBS, ids=idgenerator)
|
||||||
plug = parametrize("plugs", PLUGS, ids=basename)
|
plug = parametrize("plugs", PLUGS, ids=idgenerator)
|
||||||
strip = parametrize("strips", STRIPS, ids=basename)
|
strip = parametrize("strips", STRIPS, ids=idgenerator)
|
||||||
dimmer = parametrize("dimmers", DIMMERS, ids=basename)
|
dimmer = parametrize("dimmers", DIMMERS, ids=idgenerator)
|
||||||
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename)
|
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=idgenerator)
|
||||||
|
|
||||||
# bulb types
|
# bulb types
|
||||||
dimmable = parametrize("dimmable", DIMMABLE)
|
dimmable = parametrize("dimmable", DIMMABLE)
|
||||||
@ -101,6 +135,58 @@ non_variable_temp = parametrize("non-variable color temp", BULBS - VARIABLE_TEMP
|
|||||||
color_bulb = parametrize("color bulbs", COLOR_BULBS)
|
color_bulb = parametrize("color bulbs", COLOR_BULBS)
|
||||||
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)
|
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)
|
||||||
|
|
||||||
|
plug_smart = parametrize(
|
||||||
|
"plug devices smart", PLUGS_SMART, protocol_filter={"SMART"}, ids=idgenerator
|
||||||
|
)
|
||||||
|
device_smart = parametrize(
|
||||||
|
"devices smart", ALL_DEVICES_SMART, protocol_filter={"SMART"}, ids=idgenerator
|
||||||
|
)
|
||||||
|
device_iot = parametrize(
|
||||||
|
"devices iot", ALL_DEVICES_IOT, protocol_filter={"IOT"}, ids=idgenerator
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_fixture_data():
|
||||||
|
"""Return raw discovery file contents as JSON. Used for discovery tests."""
|
||||||
|
fixture_data = {}
|
||||||
|
for file, protocol in SUPPORTED_DEVICES:
|
||||||
|
p = Path(file)
|
||||||
|
if not p.is_absolute():
|
||||||
|
folder = Path(__file__).parent / "fixtures"
|
||||||
|
if protocol == "SMART":
|
||||||
|
folder = folder / "smart"
|
||||||
|
p = folder / file
|
||||||
|
|
||||||
|
with open(p) as f:
|
||||||
|
fixture_data[basename(p)] = json.load(f)
|
||||||
|
return fixture_data
|
||||||
|
|
||||||
|
|
||||||
|
FIXTURE_DATA = get_fixture_data()
|
||||||
|
|
||||||
|
|
||||||
|
def filter_fixtures(desc, root_filter):
|
||||||
|
filtered = {}
|
||||||
|
for key, val in FIXTURE_DATA.items():
|
||||||
|
if root_filter in val:
|
||||||
|
filtered[key] = val
|
||||||
|
|
||||||
|
print(f"{desc}: {filtered.keys()}")
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def parametrize_discovery(desc, root_key):
|
||||||
|
filtered_fixtures = filter_fixtures(desc, root_key)
|
||||||
|
return pytest.mark.parametrize(
|
||||||
|
"discovery_data",
|
||||||
|
filtered_fixtures.values(),
|
||||||
|
indirect=True,
|
||||||
|
ids=filtered_fixtures.keys(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
new_discovery = parametrize_discovery("new discovery", "discovery_result")
|
||||||
|
|
||||||
|
|
||||||
def check_categories():
|
def check_categories():
|
||||||
"""Check that every fixture file is categorized."""
|
"""Check that every fixture file is categorized."""
|
||||||
@ -110,15 +196,15 @@ def check_categories():
|
|||||||
+ plug.args[1]
|
+ plug.args[1]
|
||||||
+ bulb.args[1]
|
+ bulb.args[1]
|
||||||
+ lightstrip.args[1]
|
+ lightstrip.args[1]
|
||||||
|
+ plug_smart.args[1]
|
||||||
)
|
)
|
||||||
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
|
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
|
||||||
if diff:
|
if diff:
|
||||||
for file in diff:
|
for file, protocol in diff:
|
||||||
print(
|
print(
|
||||||
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
|
f"No category for file {file} protocol {protocol}, add to the corresponding set (BULBS, PLUGS, ..)"
|
||||||
% file
|
|
||||||
)
|
)
|
||||||
raise Exception("Missing category for %s" % diff)
|
raise Exception(f"Missing category for {diff}")
|
||||||
|
|
||||||
|
|
||||||
check_categories()
|
check_categories()
|
||||||
@ -134,7 +220,12 @@ async def handle_turn_on(dev, turn_on):
|
|||||||
await dev.turn_off()
|
await dev.turn_off()
|
||||||
|
|
||||||
|
|
||||||
def device_for_file(model):
|
def device_for_file(model, protocol):
|
||||||
|
if protocol == "SMART":
|
||||||
|
for d in PLUGS_SMART:
|
||||||
|
if d in model:
|
||||||
|
return TapoPlug
|
||||||
|
else:
|
||||||
for d in STRIPS:
|
for d in STRIPS:
|
||||||
if d in model:
|
if d in model:
|
||||||
return SmartStrip
|
return SmartStrip
|
||||||
@ -170,11 +261,14 @@ async def _discover_update_and_close(ip):
|
|||||||
return await _update_and_close(d)
|
return await _update_and_close(d)
|
||||||
|
|
||||||
|
|
||||||
async def get_device_for_file(file):
|
async def get_device_for_file(file, protocol):
|
||||||
# if the wanted file is not an absolute path, prepend the fixtures directory
|
# if the wanted file is not an absolute path, prepend the fixtures directory
|
||||||
p = Path(file)
|
p = Path(file)
|
||||||
if not p.is_absolute():
|
if not p.is_absolute():
|
||||||
p = Path(__file__).parent / "fixtures" / file
|
folder = Path(__file__).parent / "fixtures"
|
||||||
|
if protocol == "SMART":
|
||||||
|
folder = folder / "smart"
|
||||||
|
p = folder / file
|
||||||
|
|
||||||
def load_file():
|
def load_file():
|
||||||
with open(p) as f:
|
with open(p) as f:
|
||||||
@ -184,7 +278,11 @@ async def get_device_for_file(file):
|
|||||||
sysinfo = await loop.run_in_executor(None, load_file)
|
sysinfo = await loop.run_in_executor(None, load_file)
|
||||||
|
|
||||||
model = basename(file)
|
model = basename(file)
|
||||||
d = device_for_file(model)(host="127.0.0.123")
|
d = device_for_file(model, protocol)(host="127.0.0.123")
|
||||||
|
if protocol == "SMART":
|
||||||
|
d.protocol = FakeSmartProtocol(sysinfo)
|
||||||
|
d.credentials = Credentials("", "")
|
||||||
|
else:
|
||||||
d.protocol = FakeTransportProtocol(sysinfo)
|
d.protocol = FakeTransportProtocol(sysinfo)
|
||||||
await _update_and_close(d)
|
await _update_and_close(d)
|
||||||
return d
|
return d
|
||||||
@ -197,7 +295,7 @@ async def dev(request):
|
|||||||
Provides a device (given --ip) or parametrized fixture for the supported devices.
|
Provides a device (given --ip) or parametrized fixture for the supported devices.
|
||||||
The initial update is called automatically before returning the device.
|
The initial update is called automatically before returning the device.
|
||||||
"""
|
"""
|
||||||
file = request.param
|
file, protocol = request.param
|
||||||
|
|
||||||
ip = request.config.getoption("--ip")
|
ip = request.config.getoption("--ip")
|
||||||
if ip:
|
if ip:
|
||||||
@ -210,19 +308,62 @@ async def dev(request):
|
|||||||
pytest.skip(f"skipping file {file}")
|
pytest.skip(f"skipping file {file}")
|
||||||
return d if d else await _discover_update_and_close(ip)
|
return d if d else await _discover_update_and_close(ip)
|
||||||
|
|
||||||
return await get_device_for_file(file)
|
return await get_device_for_file(file, protocol)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
|
@pytest.fixture
|
||||||
|
def discovery_mock(discovery_data, mocker):
|
||||||
|
@dataclass
|
||||||
|
class _DiscoveryMock:
|
||||||
|
ip: str
|
||||||
|
default_port: int
|
||||||
|
discovery_data: dict
|
||||||
|
port_override: Optional[int] = None
|
||||||
|
|
||||||
|
if "result" in discovery_data:
|
||||||
|
datagram = (
|
||||||
|
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||||
|
+ json_dumps(discovery_data).encode()
|
||||||
|
)
|
||||||
|
dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data)
|
||||||
|
else:
|
||||||
|
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
|
||||||
|
dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data)
|
||||||
|
|
||||||
|
def mock_discover(self):
|
||||||
|
port = (
|
||||||
|
dm.port_override
|
||||||
|
if dm.port_override and dm.default_port != 20002
|
||||||
|
else dm.default_port
|
||||||
|
)
|
||||||
|
self.datagram_received(
|
||||||
|
datagram,
|
||||||
|
(dm.ip, port),
|
||||||
|
)
|
||||||
|
|
||||||
|
mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)
|
||||||
|
mocker.patch(
|
||||||
|
"socket.getaddrinfo",
|
||||||
|
side_effect=lambda *_, **__: [(None, None, None, None, (dm.ip, 0))],
|
||||||
|
)
|
||||||
|
yield dm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session")
|
||||||
def discovery_data(request):
|
def discovery_data(request):
|
||||||
"""Return raw discovery file contents as JSON. Used for discovery tests."""
|
"""Return raw discovery file contents as JSON. Used for discovery tests."""
|
||||||
file = request.param
|
fixture_data = request.param
|
||||||
p = Path(file)
|
if "discovery_result" in fixture_data:
|
||||||
if not p.is_absolute():
|
return {"result": fixture_data["discovery_result"]}
|
||||||
p = Path(__file__).parent / "fixtures" / file
|
else:
|
||||||
|
return {"system": {"get_sysinfo": fixture_data["system"]["get_sysinfo"]}}
|
||||||
|
|
||||||
with open(p) as f:
|
|
||||||
return json.load(f)
|
@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session")
|
||||||
|
def all_fixture_data(request):
|
||||||
|
"""Return raw fixture file contents as JSON. Used for discovery tests."""
|
||||||
|
fixture_data = request.param
|
||||||
|
return fixture_data
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
|
180
kasa/tests/fixtures/smart/P110_1.0_1.3.0.json
vendored
Normal file
180
kasa/tests/fixtures/smart/P110_1.0_1.3.0.json
vendored
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
{
|
||||||
|
"component_nego": {
|
||||||
|
"component_list": [
|
||||||
|
{
|
||||||
|
"id": "device",
|
||||||
|
"ver_code": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "firmware",
|
||||||
|
"ver_code": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "quick_setup",
|
||||||
|
"ver_code": 3
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "time",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "wireless",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "schedule",
|
||||||
|
"ver_code": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "countdown",
|
||||||
|
"ver_code": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "antitheft",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "account",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "synchronize",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "sunrise_sunset",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "led",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "cloud_connect",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "iot_cloud",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "device_local_time",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "default_states",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "auto_off",
|
||||||
|
"ver_code": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "localSmart",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "energy_monitoring",
|
||||||
|
"ver_code": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "power_protection",
|
||||||
|
"ver_code": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "current_protection",
|
||||||
|
"ver_code": 1
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"discovery_result": {
|
||||||
|
"device_id": "00000000000000000000000000000000",
|
||||||
|
"device_model": "P110(UK)",
|
||||||
|
"device_type": "SMART.TAPOPLUG",
|
||||||
|
"factory_default": false,
|
||||||
|
"ip": "127.0.0.123",
|
||||||
|
"is_support_iot_cloud": true,
|
||||||
|
"mac": "00-00-00-00-00-00",
|
||||||
|
"mgt_encrypt_schm": {
|
||||||
|
"encrypt_type": "KLAP",
|
||||||
|
"http_port": 80,
|
||||||
|
"is_support_https": false,
|
||||||
|
"lv": 2
|
||||||
|
},
|
||||||
|
"obd_src": "tplink",
|
||||||
|
"owner": "00000000000000000000000000000000"
|
||||||
|
},
|
||||||
|
"get_current_power": {
|
||||||
|
"current_power": 0
|
||||||
|
},
|
||||||
|
"get_device_info": {
|
||||||
|
"auto_off_remain_time": 0,
|
||||||
|
"auto_off_status": "off",
|
||||||
|
"avatar": "plug",
|
||||||
|
"default_states": {
|
||||||
|
"state": {},
|
||||||
|
"type": "last_states"
|
||||||
|
},
|
||||||
|
"device_id": "0000000000000000000000000000000000000000",
|
||||||
|
"device_on": true,
|
||||||
|
"fw_id": "00000000000000000000000000000000",
|
||||||
|
"fw_ver": "1.3.0 Build 230905 Rel.152200",
|
||||||
|
"has_set_location_info": true,
|
||||||
|
"hw_id": "00000000000000000000000000000000",
|
||||||
|
"hw_ver": "1.0",
|
||||||
|
"ip": "127.0.0.123",
|
||||||
|
"lang": "en_US",
|
||||||
|
"latitude": 0,
|
||||||
|
"longitude": 0,
|
||||||
|
"mac": "00-00-00-00-00-00",
|
||||||
|
"model": "P110",
|
||||||
|
"nickname": "VGFwaSBTbWFydCBQbHVnIDE=",
|
||||||
|
"oem_id": "00000000000000000000000000000000",
|
||||||
|
"on_time": 119335,
|
||||||
|
"overcurrent_status": "normal",
|
||||||
|
"overheated": false,
|
||||||
|
"power_protection_status": "normal",
|
||||||
|
"region": "Europe/London",
|
||||||
|
"rssi": -57,
|
||||||
|
"signal_level": 2,
|
||||||
|
"specs": "",
|
||||||
|
"ssid": "IyNNQVNLRUROQU1FIyM=",
|
||||||
|
"time_diff": 0,
|
||||||
|
"type": "SMART.TAPOPLUG"
|
||||||
|
},
|
||||||
|
"get_device_time": {
|
||||||
|
"region": "Europe/London",
|
||||||
|
"time_diff": 0,
|
||||||
|
"timestamp": 1701370224
|
||||||
|
},
|
||||||
|
"get_device_usage": {
|
||||||
|
"power_usage": {
|
||||||
|
"past30": 75,
|
||||||
|
"past7": 69,
|
||||||
|
"today": 0
|
||||||
|
},
|
||||||
|
"saved_power": {
|
||||||
|
"past30": 2029,
|
||||||
|
"past7": 1964,
|
||||||
|
"today": 1130
|
||||||
|
},
|
||||||
|
"time_usage": {
|
||||||
|
"past30": 2104,
|
||||||
|
"past7": 2033,
|
||||||
|
"today": 1130
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"get_energy_usage": {
|
||||||
|
"current_power": 0,
|
||||||
|
"electricity_charge": [
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"local_time": "2023-11-30 18:50:24",
|
||||||
|
"month_energy": 75,
|
||||||
|
"month_runtime": 2104,
|
||||||
|
"today_energy": 0,
|
||||||
|
"today_runtime": 1130
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from json import loads as json_loads
|
||||||
|
|
||||||
from voluptuous import (
|
from voluptuous import (
|
||||||
REMOVE_EXTRA,
|
REMOVE_EXTRA,
|
||||||
@ -13,7 +14,8 @@ from voluptuous import (
|
|||||||
Schema,
|
Schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..protocol import TPLinkSmartHomeProtocol
|
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol
|
||||||
|
from ..smartprotocol import SmartProtocol
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -285,6 +287,41 @@ TIME_MODULE = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FakeSmartProtocol(SmartProtocol):
|
||||||
|
def __init__(self, info):
|
||||||
|
super().__init__("127.0.0.123", transport=FakeSmartTransport(info))
|
||||||
|
|
||||||
|
|
||||||
|
class FakeSmartTransport(BaseTransport):
|
||||||
|
def __init__(self, info):
|
||||||
|
self.info = info
|
||||||
|
|
||||||
|
@property
|
||||||
|
def needs_handshake(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def needs_login(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def login(self, request: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def handshake(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send(self, request: str):
|
||||||
|
request_dict = json_loads(request)
|
||||||
|
method = request_dict["method"]
|
||||||
|
if method == "component_nego" or method[:4] == "get_":
|
||||||
|
return self.info[method]
|
||||||
|
elif method[:4] == "set_":
|
||||||
|
_LOGGER.debug("Call %s not implemented, doing nothing", method)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
||||||
def __init__(self, info):
|
def __init__(self, info):
|
||||||
self.discovery_data = info
|
self.discovery_data = info
|
||||||
|
@ -6,12 +6,15 @@ from asyncclick.testing import CliRunner
|
|||||||
|
|
||||||
from kasa import SmartDevice, TPLinkSmartHomeProtocol
|
from kasa import SmartDevice, TPLinkSmartHomeProtocol
|
||||||
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
|
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
|
||||||
|
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
|
||||||
from kasa.discover import Discover
|
from kasa.discover import Discover
|
||||||
|
from kasa.smartprotocol import SmartProtocol
|
||||||
|
|
||||||
from .conftest import handle_turn_on, turn_on
|
from .conftest import device_iot, handle_turn_on, new_discovery, turn_on
|
||||||
from .newfakes import FakeTransportProtocol
|
from .newfakes import FakeSmartProtocol, FakeTransportProtocol
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_sysinfo(dev):
|
async def test_sysinfo(dev):
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
res = await runner.invoke(sysinfo, obj=dev)
|
res = await runner.invoke(sysinfo, obj=dev)
|
||||||
@ -19,6 +22,7 @@ async def test_sysinfo(dev):
|
|||||||
assert dev.alias in res.output
|
assert dev.alias in res.output
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
@turn_on
|
@turn_on
|
||||||
async def test_state(dev, turn_on):
|
async def test_state(dev, turn_on):
|
||||||
await handle_turn_on(dev, turn_on)
|
await handle_turn_on(dev, turn_on)
|
||||||
@ -32,6 +36,7 @@ async def test_state(dev, turn_on):
|
|||||||
assert "Device state: False" in res.output
|
assert "Device state: False" in res.output
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
@turn_on
|
@turn_on
|
||||||
async def test_toggle(dev, turn_on, mocker):
|
async def test_toggle(dev, turn_on, mocker):
|
||||||
await handle_turn_on(dev, turn_on)
|
await handle_turn_on(dev, turn_on)
|
||||||
@ -44,6 +49,7 @@ async def test_toggle(dev, turn_on, mocker):
|
|||||||
assert dev.is_on
|
assert dev.is_on
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_alias(dev):
|
async def test_alias(dev):
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
@ -62,6 +68,7 @@ async def test_alias(dev):
|
|||||||
await dev.set_alias(old_alias)
|
await dev.set_alias(old_alias)
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_raw_command(dev):
|
async def test_raw_command(dev):
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
res = await runner.invoke(raw_command, ["system", "get_sysinfo"], obj=dev)
|
res = await runner.invoke(raw_command, ["system", "get_sysinfo"], obj=dev)
|
||||||
@ -74,6 +81,7 @@ async def test_raw_command(dev):
|
|||||||
assert "Usage" in res.output
|
assert "Usage" in res.output
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_emeter(dev: SmartDevice, mocker):
|
async def test_emeter(dev: SmartDevice, mocker):
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
@ -99,6 +107,7 @@ async def test_emeter(dev: SmartDevice, mocker):
|
|||||||
daily.assert_called_with(year=1900, month=12)
|
daily.assert_called_with(year=1900, month=12)
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_brightness(dev):
|
async def test_brightness(dev):
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
res = await runner.invoke(brightness, obj=dev)
|
res = await runner.invoke(brightness, obj=dev)
|
||||||
@ -116,6 +125,7 @@ async def test_brightness(dev):
|
|||||||
assert "Brightness: 12" in res.output
|
assert "Brightness: 12" in res.output
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_json_output(dev: SmartDevice, mocker):
|
async def test_json_output(dev: SmartDevice, mocker):
|
||||||
"""Test that the json output produces correct output."""
|
"""Test that the json output produces correct output."""
|
||||||
mocker.patch("kasa.Discover.discover", return_value=[dev])
|
mocker.patch("kasa.Discover.discover", return_value=[dev])
|
||||||
@ -125,13 +135,9 @@ async def test_json_output(dev: SmartDevice, mocker):
|
|||||||
assert json.loads(res.output) == dev.internal_state
|
assert json.loads(res.output) == dev.internal_state
|
||||||
|
|
||||||
|
|
||||||
async def test_credentials(discovery_data: dict, mocker):
|
@new_discovery
|
||||||
|
async def test_credentials(discovery_mock, mocker):
|
||||||
"""Test credentials are passed correctly from cli to device."""
|
"""Test credentials are passed correctly from cli to device."""
|
||||||
# As this is testing the device constructor need to explicitly wire in
|
|
||||||
# the FakeTransportProtocol
|
|
||||||
ftp = FakeTransportProtocol(discovery_data)
|
|
||||||
mocker.patch.object(TPLinkSmartHomeProtocol, "query", ftp.query)
|
|
||||||
|
|
||||||
# Patch state to echo username and password
|
# Patch state to echo username and password
|
||||||
pass_dev = click.make_pass_decorator(SmartDevice)
|
pass_dev = click.make_pass_decorator(SmartDevice)
|
||||||
|
|
||||||
@ -143,18 +149,15 @@ async def test_credentials(discovery_data: dict, mocker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
mocker.patch("kasa.cli.state", new=_state)
|
mocker.patch("kasa.cli.state", new=_state)
|
||||||
cli_device_type = Discover._get_device_class(discovery_data)(
|
for subclass in DEVICE_TYPE_TO_CLASS.values():
|
||||||
"any"
|
mocker.patch.object(subclass, "update")
|
||||||
).device_type.value
|
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
res = await runner.invoke(
|
res = await runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
"--host",
|
"--host",
|
||||||
"127.0.0.1",
|
"127.0.0.123",
|
||||||
"--type",
|
|
||||||
cli_device_type,
|
|
||||||
"--username",
|
"--username",
|
||||||
"foo",
|
"foo",
|
||||||
"--password",
|
"--password",
|
||||||
@ -162,9 +165,11 @@ async def test_credentials(discovery_data: dict, mocker):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert res.exit_code == 0
|
assert res.exit_code == 0
|
||||||
assert res.output == "Username:foo Password:bar\n"
|
|
||||||
|
assert "Username:foo Password:bar\n" in res.output
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_without_device_type(discovery_data: dict, dev, mocker):
|
async def test_without_device_type(discovery_data: dict, dev, mocker):
|
||||||
"""Test connecting without the device type."""
|
"""Test connecting without the device type."""
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
@ -5,7 +5,9 @@ from typing import Type
|
|||||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||||
|
|
||||||
from kasa import (
|
from kasa import (
|
||||||
|
Credentials,
|
||||||
DeviceType,
|
DeviceType,
|
||||||
|
Discover,
|
||||||
SmartBulb,
|
SmartBulb,
|
||||||
SmartDevice,
|
SmartDevice,
|
||||||
SmartDeviceException,
|
SmartDeviceException,
|
||||||
@ -13,8 +15,13 @@ from kasa import (
|
|||||||
SmartLightStrip,
|
SmartLightStrip,
|
||||||
SmartPlug,
|
SmartPlug,
|
||||||
)
|
)
|
||||||
from kasa.device_factory import connect
|
from kasa.device_factory import (
|
||||||
from kasa.klapprotocol import TPLinkKlap
|
DEVICE_TYPE_TO_CLASS,
|
||||||
|
connect,
|
||||||
|
get_protocol_from_connection_name,
|
||||||
|
)
|
||||||
|
from kasa.discover import DiscoveryResult
|
||||||
|
from kasa.iotprotocol import IotProtocol
|
||||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||||
|
|
||||||
|
|
||||||
@ -22,8 +29,12 @@ from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
|||||||
async def test_connect(discovery_data: dict, mocker, custom_port):
|
async def test_connect(discovery_data: dict, mocker, custom_port):
|
||||||
"""Make sure that connect returns an initialized SmartDevice instance."""
|
"""Make sure that connect returns an initialized SmartDevice instance."""
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
|
||||||
|
|
||||||
|
if "result" in discovery_data:
|
||||||
|
with pytest.raises(SmartDeviceException):
|
||||||
|
dev = await connect(host, port=custom_port)
|
||||||
|
else:
|
||||||
|
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||||
dev = await connect(host, port=custom_port)
|
dev = await connect(host, port=custom_port)
|
||||||
assert issubclass(dev.__class__, SmartDevice)
|
assert issubclass(dev.__class__, SmartDevice)
|
||||||
assert dev.port == custom_port or dev.port == 9999
|
assert dev.port == custom_port or dev.port == 9999
|
||||||
@ -49,8 +60,12 @@ async def test_connect_passed_device_type(
|
|||||||
):
|
):
|
||||||
"""Make sure that connect with a passed device type."""
|
"""Make sure that connect with a passed device type."""
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
|
||||||
|
|
||||||
|
if "result" in discovery_data:
|
||||||
|
with pytest.raises(SmartDeviceException):
|
||||||
|
dev = await connect(host, port=custom_port)
|
||||||
|
else:
|
||||||
|
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||||
dev = await connect(host, port=custom_port, device_type=device_type)
|
dev = await connect(host, port=custom_port, device_type=device_type)
|
||||||
assert isinstance(dev, klass)
|
assert isinstance(dev, klass)
|
||||||
assert dev.port == custom_port or dev.port == 9999
|
assert dev.port == custom_port or dev.port == 9999
|
||||||
@ -70,32 +85,52 @@ async def test_connect_logs_connect_time(
|
|||||||
):
|
):
|
||||||
"""Test that the connect time is logged when debug logging is enabled."""
|
"""Test that the connect time is logged when debug logging is enabled."""
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
|
if "result" in discovery_data:
|
||||||
|
with pytest.raises(SmartDeviceException):
|
||||||
|
await connect(host)
|
||||||
|
else:
|
||||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||||
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
||||||
await connect(host)
|
await connect(host)
|
||||||
assert "seconds to connect" in caplog.text
|
assert "seconds to connect" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device_type", [DeviceType.Plug, None])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("protocol_in", "protocol_result"),
|
|
||||||
(
|
|
||||||
(None, TPLinkSmartHomeProtocol),
|
|
||||||
(TPLinkKlap, TPLinkKlap),
|
|
||||||
(TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def test_connect_pass_protocol(
|
async def test_connect_pass_protocol(
|
||||||
discovery_data: dict,
|
all_fixture_data: dict,
|
||||||
mocker,
|
mocker,
|
||||||
device_type: DeviceType,
|
|
||||||
protocol_in: Type[TPLinkProtocol],
|
|
||||||
protocol_result: Type[TPLinkProtocol],
|
|
||||||
):
|
):
|
||||||
"""Test that if the protocol is passed in it's gets set correctly."""
|
"""Test that if the protocol is passed in it's gets set correctly."""
|
||||||
host = "127.0.0.1"
|
if "discovery_result" in all_fixture_data:
|
||||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
discovery_info = {"result": all_fixture_data["discovery_result"]}
|
||||||
mocker.patch("kasa.TPLinkKlap.query", return_value=discovery_data)
|
device_class = Discover._get_device_class(discovery_info)
|
||||||
|
else:
|
||||||
|
device_class = Discover._get_device_class(all_fixture_data)
|
||||||
|
|
||||||
dev = await connect(host, device_type=device_type, protocol_class=protocol_in)
|
device_type = list(DEVICE_TYPE_TO_CLASS.keys())[
|
||||||
assert isinstance(dev.protocol, protocol_result)
|
list(DEVICE_TYPE_TO_CLASS.values()).index(device_class)
|
||||||
|
]
|
||||||
|
host = "127.0.0.1"
|
||||||
|
if "discovery_result" in all_fixture_data:
|
||||||
|
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||||
|
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||||
|
|
||||||
|
dr = DiscoveryResult(**discovery_info["result"])
|
||||||
|
connection_name = (
|
||||||
|
dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type
|
||||||
|
)
|
||||||
|
protocol_class = get_protocol_from_connection_name(
|
||||||
|
connection_name, host
|
||||||
|
).__class__
|
||||||
|
else:
|
||||||
|
mocker.patch(
|
||||||
|
"kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data
|
||||||
|
)
|
||||||
|
protocol_class = TPLinkSmartHomeProtocol
|
||||||
|
|
||||||
|
dev = await connect(
|
||||||
|
host,
|
||||||
|
device_type=device_type,
|
||||||
|
protocol_class=protocol_class,
|
||||||
|
credentials=Credentials("", ""),
|
||||||
|
)
|
||||||
|
assert isinstance(dev.protocol, protocol_class)
|
||||||
|
@ -17,6 +17,27 @@ from kasa.exceptions import AuthenticationException, UnsupportedDeviceException
|
|||||||
|
|
||||||
from .conftest import bulb, dimmer, lightstrip, plug, strip
|
from .conftest import bulb, dimmer, lightstrip, plug, strip
|
||||||
|
|
||||||
|
UNSUPPORTED = {
|
||||||
|
"result": {
|
||||||
|
"device_id": "xx",
|
||||||
|
"owner": "xx",
|
||||||
|
"device_type": "SMART.TAPOXMASTREE",
|
||||||
|
"device_model": "P110(EU)",
|
||||||
|
"ip": "127.0.0.1",
|
||||||
|
"mac": "48-22xxx",
|
||||||
|
"is_support_iot_cloud": True,
|
||||||
|
"obd_src": "tplink",
|
||||||
|
"factory_default": False,
|
||||||
|
"mgt_encrypt_schm": {
|
||||||
|
"is_support_https": False,
|
||||||
|
"encrypt_type": "AES",
|
||||||
|
"http_port": 80,
|
||||||
|
"lv": 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"error_code": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@plug
|
@plug
|
||||||
async def test_type_detection_plug(dev: SmartDevice):
|
async def test_type_detection_plug(dev: SmartDevice):
|
||||||
@ -62,76 +83,40 @@ async def test_type_unknown():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("custom_port", [123, None])
|
@pytest.mark.parametrize("custom_port", [123, None])
|
||||||
async def test_discover_single(discovery_data: dict, mocker, custom_port):
|
# @pytest.mark.parametrize("discovery_mock", [("127.0.0.1",123), ("127.0.0.1",None)], indirect=True)
|
||||||
|
async def test_discover_single(discovery_mock, custom_port, mocker):
|
||||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
|
discovery_mock.ip = host
|
||||||
query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)
|
discovery_mock.port_override = custom_port
|
||||||
|
update_mock = mocker.patch.object(SmartStrip, "update")
|
||||||
def mock_discover(self):
|
|
||||||
self.datagram_received(
|
|
||||||
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:],
|
|
||||||
(host, custom_port or 9999),
|
|
||||||
)
|
|
||||||
|
|
||||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
|
||||||
|
|
||||||
x = await Discover.discover_single(host, port=custom_port)
|
x = await Discover.discover_single(host, port=custom_port)
|
||||||
assert issubclass(x.__class__, SmartDevice)
|
assert issubclass(x.__class__, SmartDevice)
|
||||||
assert x._sys_info is not None
|
assert x._discovery_info is not None
|
||||||
assert x.port == custom_port or x.port == 9999
|
assert x.port == custom_port or x.port == discovery_mock.default_port
|
||||||
assert (query_mock.call_count > 0) == isinstance(x, SmartStrip)
|
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
|
||||||
|
|
||||||
|
|
||||||
async def test_discover_single_hostname(discovery_data: dict, mocker):
|
async def test_discover_single_hostname(discovery_mock, mocker):
|
||||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||||
host = "foobar"
|
host = "foobar"
|
||||||
ip = "127.0.0.1"
|
ip = "127.0.0.1"
|
||||||
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
|
|
||||||
query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)
|
|
||||||
|
|
||||||
def mock_discover(self):
|
discovery_mock.ip = ip
|
||||||
self.datagram_received(
|
update_mock = mocker.patch.object(SmartStrip, "update")
|
||||||
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:],
|
|
||||||
(ip, 9999),
|
|
||||||
)
|
|
||||||
|
|
||||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
|
||||||
mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))])
|
|
||||||
|
|
||||||
x = await Discover.discover_single(host)
|
x = await Discover.discover_single(host)
|
||||||
assert issubclass(x.__class__, SmartDevice)
|
assert issubclass(x.__class__, SmartDevice)
|
||||||
assert x._sys_info is not None
|
assert x._discovery_info is not None
|
||||||
assert x.host == host
|
assert x.host == host
|
||||||
assert (query_mock.call_count > 0) == isinstance(x, SmartStrip)
|
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
|
||||||
|
|
||||||
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
|
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
x = await Discover.discover_single(host)
|
x = await Discover.discover_single(host)
|
||||||
|
|
||||||
|
|
||||||
UNSUPPORTED = {
|
|
||||||
"result": {
|
|
||||||
"device_id": "xx",
|
|
||||||
"owner": "xx",
|
|
||||||
"device_type": "SMART.TAPOXMASTREE",
|
|
||||||
"device_model": "P110(EU)",
|
|
||||||
"ip": "127.0.0.1",
|
|
||||||
"mac": "48-22xxx",
|
|
||||||
"is_support_iot_cloud": True,
|
|
||||||
"obd_src": "tplink",
|
|
||||||
"factory_default": False,
|
|
||||||
"mgt_encrypt_schm": {
|
|
||||||
"is_support_https": False,
|
|
||||||
"encrypt_type": "AES",
|
|
||||||
"http_port": 80,
|
|
||||||
"lv": 2,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"error_code": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def test_discover_single_unsupported(mocker):
|
async def test_discover_single_unsupported(mocker):
|
||||||
"""Make sure that discover_single handles unsupported devices correctly."""
|
"""Make sure that discover_single handles unsupported devices correctly."""
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
@ -201,14 +186,17 @@ async def test_discover_send(mocker):
|
|||||||
async def test_discover_datagram_received(mocker, discovery_data):
|
async def test_discover_datagram_received(mocker, discovery_data):
|
||||||
"""Verify that datagram received fills discovered_devices."""
|
"""Verify that datagram received fills discovered_devices."""
|
||||||
proto = _DiscoverProtocol()
|
proto = _DiscoverProtocol()
|
||||||
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
|
|
||||||
mocker.patch("kasa.discover.json_loads", return_value=info)
|
|
||||||
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
|
|
||||||
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
|
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
|
||||||
|
|
||||||
addr = "127.0.0.1"
|
addr = "127.0.0.1"
|
||||||
proto.datagram_received("<placeholder data>", (addr, 9999))
|
port = 20002 if "result" in discovery_data else 9999
|
||||||
|
|
||||||
|
mocker.patch("kasa.discover.json_loads", return_value=discovery_data)
|
||||||
|
proto.datagram_received("<placeholder data>", (addr, port))
|
||||||
|
|
||||||
addr2 = "127.0.0.2"
|
addr2 = "127.0.0.2"
|
||||||
|
mocker.patch("kasa.discover.json_loads", return_value=UNSUPPORTED)
|
||||||
proto.datagram_received("<placeholder data>", (addr2, 20002))
|
proto.datagram_received("<placeholder data>", (addr2, 20002))
|
||||||
|
|
||||||
# Check that device in discovered_devices is initialized correctly
|
# Check that device in discovered_devices is initialized correctly
|
||||||
|
@ -10,9 +10,14 @@ from contextlib import nullcontext as does_not_raise
|
|||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from ..aestransport import AesTransport
|
||||||
from ..credentials import Credentials
|
from ..credentials import Credentials
|
||||||
from ..exceptions import AuthenticationException, SmartDeviceException
|
from ..exceptions import AuthenticationException, SmartDeviceException
|
||||||
from ..klapprotocol import KlapEncryptionSession, TPLinkKlap, _sha256
|
from ..iotprotocol import IotProtocol
|
||||||
|
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
|
||||||
|
from ..smartprotocol import SmartProtocol
|
||||||
|
|
||||||
|
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||||
|
|
||||||
|
|
||||||
class _mock_response:
|
class _mock_response:
|
||||||
@ -21,67 +26,92 @@ class _mock_response:
|
|||||||
self.content = content
|
self.content = content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||||
|
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||||
async def test_protocol_retries(mocker, retry_count):
|
async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class):
|
||||||
|
host = "127.0.0.1"
|
||||||
conn = mocker.patch.object(
|
conn = mocker.patch.object(
|
||||||
TPLinkKlap, "client_post", side_effect=Exception("dummy exception")
|
transport_class, "client_post", side_effect=Exception("dummy exception")
|
||||||
)
|
)
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await TPLinkKlap("127.0.0.1").query({}, retry_count=retry_count)
|
await protocol_class(host, transport=transport_class(host)).query(
|
||||||
|
DUMMY_QUERY, retry_count=retry_count
|
||||||
|
)
|
||||||
|
|
||||||
assert conn.call_count == retry_count + 1
|
assert conn.call_count == retry_count + 1
|
||||||
|
|
||||||
|
|
||||||
async def test_protocol_no_retry_on_connection_error(mocker):
|
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||||
|
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||||
|
async def test_protocol_no_retry_on_connection_error(
|
||||||
|
mocker, protocol_class, transport_class
|
||||||
|
):
|
||||||
|
host = "127.0.0.1"
|
||||||
conn = mocker.patch.object(
|
conn = mocker.patch.object(
|
||||||
TPLinkKlap,
|
transport_class,
|
||||||
"client_post",
|
"client_post",
|
||||||
side_effect=httpx.ConnectError("foo"),
|
side_effect=httpx.ConnectError("foo"),
|
||||||
)
|
)
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await TPLinkKlap("127.0.0.1").query({}, retry_count=5)
|
await protocol_class(host, transport=transport_class(host)).query(
|
||||||
|
DUMMY_QUERY, retry_count=5
|
||||||
|
)
|
||||||
|
|
||||||
assert conn.call_count == 1
|
assert conn.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_protocol_retry_recoverable_error(mocker):
|
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||||
|
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||||
|
async def test_protocol_retry_recoverable_error(
|
||||||
|
mocker, protocol_class, transport_class
|
||||||
|
):
|
||||||
|
host = "127.0.0.1"
|
||||||
conn = mocker.patch.object(
|
conn = mocker.patch.object(
|
||||||
TPLinkKlap,
|
transport_class,
|
||||||
"client_post",
|
"client_post",
|
||||||
side_effect=httpx.CloseError("foo"),
|
side_effect=httpx.CloseError("foo"),
|
||||||
)
|
)
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await TPLinkKlap("127.0.0.1").query({}, retry_count=5)
|
await protocol_class(host, transport=transport_class(host)).query(
|
||||||
|
DUMMY_QUERY, retry_count=5
|
||||||
|
)
|
||||||
|
|
||||||
assert conn.call_count == 6
|
assert conn.call_count == 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||||
|
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||||
async def test_protocol_reconnect(mocker, retry_count):
|
async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
|
||||||
|
host = "127.0.0.1"
|
||||||
remaining = retry_count
|
remaining = retry_count
|
||||||
|
mock_response = {"result": {"great": "success"}}
|
||||||
|
|
||||||
def _fail_one_less_than_retry_count(*_, **__):
|
def _fail_one_less_than_retry_count(*_, **__):
|
||||||
nonlocal remaining, encryption_session
|
nonlocal remaining
|
||||||
remaining -= 1
|
remaining -= 1
|
||||||
if remaining:
|
if remaining:
|
||||||
raise Exception("Simulated post failure")
|
raise Exception("Simulated post failure")
|
||||||
# Do the encrypt just before returning the value so the incrementing sequence number is correct
|
|
||||||
encrypted, seq = encryption_session.encrypt('{"great":"success"}')
|
|
||||||
return 200, encrypted
|
|
||||||
|
|
||||||
seed = secrets.token_bytes(16)
|
return mock_response
|
||||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
|
||||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
|
||||||
protocol = TPLinkKlap("127.0.0.1")
|
|
||||||
protocol.handshake_done = True
|
|
||||||
protocol.session_expire_at = time.time() + 86400
|
|
||||||
protocol.encryption_session = encryption_session
|
|
||||||
mocker.patch.object(
|
mocker.patch.object(
|
||||||
TPLinkKlap, "client_post", side_effect=_fail_one_less_than_retry_count
|
transport_class, "needs_handshake", property(lambda self: False)
|
||||||
|
)
|
||||||
|
mocker.patch.object(transport_class, "needs_login", property(lambda self: False))
|
||||||
|
|
||||||
|
send_mock = mocker.patch.object(
|
||||||
|
transport_class,
|
||||||
|
"send",
|
||||||
|
side_effect=_fail_one_less_than_retry_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await protocol.query({}, retry_count=retry_count)
|
response = await protocol_class(host, transport=transport_class(host)).query(
|
||||||
assert response == {"great": "success"}
|
DUMMY_QUERY, retry_count=retry_count
|
||||||
|
)
|
||||||
|
assert "result" in response or "great" in response
|
||||||
|
assert send_mock.call_count == retry_count
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
|
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
|
||||||
@ -96,14 +126,14 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
|||||||
return 200, encrypted
|
return 200, encrypted
|
||||||
|
|
||||||
seed = secrets.token_bytes(16)
|
seed = secrets.token_bytes(16)
|
||||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||||
protocol = TPLinkKlap("127.0.0.1")
|
protocol = IotProtocol("127.0.0.1")
|
||||||
|
|
||||||
protocol.handshake_done = True
|
protocol._transport._handshake_done = True
|
||||||
protocol.session_expire_at = time.time() + 86400
|
protocol._transport._session_expire_at = time.time() + 86400
|
||||||
protocol.encryption_session = encryption_session
|
protocol._transport._encryption_session = encryption_session
|
||||||
mocker.patch.object(TPLinkKlap, "client_post", side_effect=_return_encrypted)
|
mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted)
|
||||||
|
|
||||||
response = await protocol.query({})
|
response = await protocol.query({})
|
||||||
assert response == {"great": "success"}
|
assert response == {"great": "success"}
|
||||||
@ -117,7 +147,7 @@ def test_encrypt():
|
|||||||
d = json.dumps({"foo": 1, "bar": 2})
|
d = json.dumps({"foo": 1, "bar": 2})
|
||||||
|
|
||||||
seed = secrets.token_bytes(16)
|
seed = secrets.token_bytes(16)
|
||||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||||
|
|
||||||
encrypted, seq = encryption_session.encrypt(d)
|
encrypted, seq = encryption_session.encrypt(d)
|
||||||
@ -129,7 +159,7 @@ def test_encrypt_unicode():
|
|||||||
d = "{'snowman': '\u2603'}"
|
d = "{'snowman': '\u2603'}"
|
||||||
|
|
||||||
seed = secrets.token_bytes(16)
|
seed = secrets.token_bytes(16)
|
||||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||||
|
|
||||||
encrypted, seq = encryption_session.encrypt(d)
|
encrypted, seq = encryption_session.encrypt(d)
|
||||||
@ -145,7 +175,10 @@ def test_encrypt_unicode():
|
|||||||
(Credentials("foo", "bar"), does_not_raise()),
|
(Credentials("foo", "bar"), does_not_raise()),
|
||||||
(Credentials("", ""), does_not_raise()),
|
(Credentials("", ""), does_not_raise()),
|
||||||
(
|
(
|
||||||
Credentials(TPLinkKlap.KASA_SETUP_EMAIL, TPLinkKlap.KASA_SETUP_PASSWORD),
|
Credentials(
|
||||||
|
KlapTransport.KASA_SETUP_EMAIL,
|
||||||
|
KlapTransport.KASA_SETUP_PASSWORD,
|
||||||
|
),
|
||||||
does_not_raise(),
|
does_not_raise(),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
@ -167,21 +200,21 @@ async def test_handshake1(mocker, device_credentials, expectation):
|
|||||||
client_seed = None
|
client_seed = None
|
||||||
server_seed = secrets.token_bytes(16)
|
server_seed = secrets.token_bytes(16)
|
||||||
client_credentials = Credentials("foo", "bar")
|
client_credentials = Credentials("foo", "bar")
|
||||||
device_auth_hash = TPLinkKlap.generate_auth_hash(device_credentials)
|
device_auth_hash = KlapTransport.generate_auth_hash(device_credentials)
|
||||||
|
|
||||||
mocker.patch.object(
|
mocker.patch.object(
|
||||||
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
||||||
)
|
)
|
||||||
|
|
||||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||||
|
|
||||||
protocol.http_client = httpx.AsyncClient()
|
protocol._transport.http_client = httpx.AsyncClient()
|
||||||
with expectation:
|
with expectation:
|
||||||
(
|
(
|
||||||
local_seed,
|
local_seed,
|
||||||
device_remote_seed,
|
device_remote_seed,
|
||||||
auth_hash,
|
auth_hash,
|
||||||
) = await protocol.perform_handshake1()
|
) = await protocol._transport.perform_handshake1()
|
||||||
|
|
||||||
assert local_seed == client_seed
|
assert local_seed == client_seed
|
||||||
assert device_remote_seed == server_seed
|
assert device_remote_seed == server_seed
|
||||||
@ -204,23 +237,23 @@ async def test_handshake(mocker):
|
|||||||
client_seed = None
|
client_seed = None
|
||||||
server_seed = secrets.token_bytes(16)
|
server_seed = secrets.token_bytes(16)
|
||||||
client_credentials = Credentials("foo", "bar")
|
client_credentials = Credentials("foo", "bar")
|
||||||
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
|
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||||
|
|
||||||
mocker.patch.object(
|
mocker.patch.object(
|
||||||
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
||||||
)
|
)
|
||||||
|
|
||||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||||
protocol.http_client = httpx.AsyncClient()
|
protocol._transport.http_client = httpx.AsyncClient()
|
||||||
|
|
||||||
response_status = 200
|
response_status = 200
|
||||||
await protocol.perform_handshake()
|
await protocol._transport.perform_handshake()
|
||||||
assert protocol.handshake_done is True
|
assert protocol._transport._handshake_done is True
|
||||||
|
|
||||||
response_status = 403
|
response_status = 403
|
||||||
with pytest.raises(AuthenticationException):
|
with pytest.raises(AuthenticationException):
|
||||||
await protocol.perform_handshake()
|
await protocol._transport.perform_handshake()
|
||||||
assert protocol.handshake_done is False
|
assert protocol._transport._handshake_done is False
|
||||||
await protocol.close()
|
await protocol.close()
|
||||||
|
|
||||||
|
|
||||||
@ -237,9 +270,9 @@ async def test_query(mocker):
|
|||||||
return _mock_response(200, b"")
|
return _mock_response(200, b"")
|
||||||
elif url == "http://127.0.0.1/app/request":
|
elif url == "http://127.0.0.1/app/request":
|
||||||
encryption_session = KlapEncryptionSession(
|
encryption_session = KlapEncryptionSession(
|
||||||
protocol.encryption_session.local_seed,
|
protocol._transport._encryption_session.local_seed,
|
||||||
protocol.encryption_session.remote_seed,
|
protocol._transport._encryption_session.remote_seed,
|
||||||
protocol.encryption_session.user_hash,
|
protocol._transport._encryption_session.user_hash,
|
||||||
)
|
)
|
||||||
seq = params.get("seq")
|
seq = params.get("seq")
|
||||||
encryption_session._seq = seq - 1
|
encryption_session._seq = seq - 1
|
||||||
@ -252,11 +285,11 @@ async def test_query(mocker):
|
|||||||
seq = None
|
seq = None
|
||||||
server_seed = secrets.token_bytes(16)
|
server_seed = secrets.token_bytes(16)
|
||||||
client_credentials = Credentials("foo", "bar")
|
client_credentials = Credentials("foo", "bar")
|
||||||
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
|
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||||
|
|
||||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||||
|
|
||||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||||
|
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
resp = await protocol.query({})
|
resp = await protocol.query({})
|
||||||
@ -296,11 +329,11 @@ async def test_authentication_failures(mocker, response_status, expectation):
|
|||||||
|
|
||||||
server_seed = secrets.token_bytes(16)
|
server_seed = secrets.token_bytes(16)
|
||||||
client_credentials = Credentials("foo", "bar")
|
client_credentials = Credentials("foo", "bar")
|
||||||
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
|
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||||
|
|
||||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||||
|
|
||||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||||
|
|
||||||
with expectation:
|
with expectation:
|
||||||
await protocol.query({})
|
await protocol.query({})
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from kasa import DeviceType
|
from kasa import DeviceType
|
||||||
|
|
||||||
from .conftest import plug
|
from .conftest import plug, plug_smart
|
||||||
from .newfakes import PLUG_SCHEMA
|
from .newfakes import PLUG_SCHEMA
|
||||||
|
|
||||||
|
|
||||||
@ -28,3 +28,14 @@ async def test_led(dev):
|
|||||||
assert dev.led
|
assert dev.led
|
||||||
|
|
||||||
await dev.set_led(original)
|
await dev.set_led(original)
|
||||||
|
|
||||||
|
|
||||||
|
@plug_smart
|
||||||
|
async def test_plug_device_info(dev):
|
||||||
|
assert dev._info is not None
|
||||||
|
# PLUG_SCHEMA(dev.sys_info)
|
||||||
|
|
||||||
|
assert dev.model is not None
|
||||||
|
|
||||||
|
assert dev.device_type == DeviceType.Plug or dev.device_type == DeviceType.Strip
|
||||||
|
# assert dev.is_plug or dev.is_strip
|
||||||
|
@ -9,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file
|
|||||||
|
|
||||||
def test_bulb_examples(mocker):
|
def test_bulb_examples(mocker):
|
||||||
"""Use KL130 (bulb with all features) to test the doctests."""
|
"""Use KL130 (bulb with all features) to test the doctests."""
|
||||||
p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json"))
|
p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json", "IOT"))
|
||||||
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
|
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
|
||||||
mocker.patch("kasa.smartbulb.SmartBulb.update")
|
mocker.patch("kasa.smartbulb.SmartBulb.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartbulb", "all")
|
res = xdoctest.doctest_module("kasa.smartbulb", "all")
|
||||||
@ -18,7 +18,7 @@ def test_bulb_examples(mocker):
|
|||||||
|
|
||||||
def test_smartdevice_examples(mocker):
|
def test_smartdevice_examples(mocker):
|
||||||
"""Use HS110 for emeter examples."""
|
"""Use HS110 for emeter examples."""
|
||||||
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json"))
|
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT"))
|
||||||
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
|
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
|
||||||
mocker.patch("kasa.smartdevice.SmartDevice.update")
|
mocker.patch("kasa.smartdevice.SmartDevice.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartdevice", "all")
|
res = xdoctest.doctest_module("kasa.smartdevice", "all")
|
||||||
@ -27,7 +27,7 @@ def test_smartdevice_examples(mocker):
|
|||||||
|
|
||||||
def test_plug_examples(mocker):
|
def test_plug_examples(mocker):
|
||||||
"""Test plug examples."""
|
"""Test plug examples."""
|
||||||
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json"))
|
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT"))
|
||||||
mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
|
mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
|
||||||
mocker.patch("kasa.smartplug.SmartPlug.update")
|
mocker.patch("kasa.smartplug.SmartPlug.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartplug", "all")
|
res = xdoctest.doctest_module("kasa.smartplug", "all")
|
||||||
@ -36,7 +36,7 @@ def test_plug_examples(mocker):
|
|||||||
|
|
||||||
def test_strip_examples(mocker):
|
def test_strip_examples(mocker):
|
||||||
"""Test strip examples."""
|
"""Test strip examples."""
|
||||||
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json"))
|
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT"))
|
||||||
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
|
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
|
||||||
mocker.patch("kasa.smartstrip.SmartStrip.update")
|
mocker.patch("kasa.smartstrip.SmartStrip.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartstrip", "all")
|
res = xdoctest.doctest_module("kasa.smartstrip", "all")
|
||||||
@ -45,7 +45,7 @@ def test_strip_examples(mocker):
|
|||||||
|
|
||||||
def test_dimmer_examples(mocker):
|
def test_dimmer_examples(mocker):
|
||||||
"""Test dimmer examples."""
|
"""Test dimmer examples."""
|
||||||
p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json"))
|
p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json", "IOT"))
|
||||||
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
|
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
|
||||||
mocker.patch("kasa.smartdimmer.SmartDimmer.update")
|
mocker.patch("kasa.smartdimmer.SmartDimmer.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartdimmer", "all")
|
res = xdoctest.doctest_module("kasa.smartdimmer", "all")
|
||||||
@ -54,7 +54,7 @@ def test_dimmer_examples(mocker):
|
|||||||
|
|
||||||
def test_lightstrip_examples(mocker):
|
def test_lightstrip_examples(mocker):
|
||||||
"""Test lightstrip examples."""
|
"""Test lightstrip examples."""
|
||||||
p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json"))
|
p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json", "IOT"))
|
||||||
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
|
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
|
||||||
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
|
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
|
res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
|
||||||
@ -63,7 +63,7 @@ def test_lightstrip_examples(mocker):
|
|||||||
|
|
||||||
def test_discovery_examples(mocker):
|
def test_discovery_examples(mocker):
|
||||||
"""Test discovery examples."""
|
"""Test discovery examples."""
|
||||||
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json"))
|
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT"))
|
||||||
|
|
||||||
mocker.patch("kasa.discover.Discover.discover", return_value=[p])
|
mocker.patch("kasa.discover.Discover.discover", return_value=[p])
|
||||||
res = xdoctest.doctest_module("kasa.discover", "all")
|
res = xdoctest.doctest_module("kasa.discover", "all")
|
||||||
|
@ -8,7 +8,7 @@ import kasa
|
|||||||
from kasa import Credentials, SmartDevice, SmartDeviceException
|
from kasa import Credentials, SmartDevice, SmartDeviceException
|
||||||
from kasa.smartdevice import DeviceType
|
from kasa.smartdevice import DeviceType
|
||||||
|
|
||||||
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
|
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on
|
||||||
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
|
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
|
||||||
|
|
||||||
# List of all SmartXXX classes including the SmartDevice base class
|
# List of all SmartXXX classes including the SmartDevice base class
|
||||||
@ -22,11 +22,13 @@ smart_device_classes = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_state_info(dev):
|
async def test_state_info(dev):
|
||||||
assert isinstance(dev.state_information, dict)
|
assert isinstance(dev.state_information, dict)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires_dummy
|
@pytest.mark.requires_dummy
|
||||||
|
@device_iot
|
||||||
async def test_invalid_connection(dev):
|
async def test_invalid_connection(dev):
|
||||||
with patch.object(
|
with patch.object(
|
||||||
FakeTransportProtocol, "query", side_effect=SmartDeviceException
|
FakeTransportProtocol, "query", side_effect=SmartDeviceException
|
||||||
@ -58,12 +60,14 @@ async def test_initial_update_no_emeter(dev, mocker):
|
|||||||
assert spy.call_count == 2
|
assert spy.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_query_helper(dev):
|
async def test_query_helper(dev):
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await dev._query_helper("test", "testcmd", {})
|
await dev._query_helper("test", "testcmd", {})
|
||||||
# TODO check for unwrapping?
|
# TODO check for unwrapping?
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
@turn_on
|
@turn_on
|
||||||
async def test_state(dev, turn_on):
|
async def test_state(dev, turn_on):
|
||||||
await handle_turn_on(dev, turn_on)
|
await handle_turn_on(dev, turn_on)
|
||||||
@ -90,6 +94,7 @@ async def test_state(dev, turn_on):
|
|||||||
assert dev.is_off
|
assert dev.is_off
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_alias(dev):
|
async def test_alias(dev):
|
||||||
test_alias = "TEST1234"
|
test_alias = "TEST1234"
|
||||||
original = dev.alias
|
original = dev.alias
|
||||||
@ -104,6 +109,7 @@ async def test_alias(dev):
|
|||||||
assert dev.alias == original
|
assert dev.alias == original
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
@turn_on
|
@turn_on
|
||||||
async def test_on_since(dev, turn_on):
|
async def test_on_since(dev, turn_on):
|
||||||
await handle_turn_on(dev, turn_on)
|
await handle_turn_on(dev, turn_on)
|
||||||
@ -116,30 +122,37 @@ async def test_on_since(dev, turn_on):
|
|||||||
assert dev.on_since is None
|
assert dev.on_since is None
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_time(dev):
|
async def test_time(dev):
|
||||||
assert isinstance(await dev.get_time(), datetime)
|
assert isinstance(await dev.get_time(), datetime)
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_timezone(dev):
|
async def test_timezone(dev):
|
||||||
TZ_SCHEMA(await dev.get_timezone())
|
TZ_SCHEMA(await dev.get_timezone())
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_hw_info(dev):
|
async def test_hw_info(dev):
|
||||||
PLUG_SCHEMA(dev.hw_info)
|
PLUG_SCHEMA(dev.hw_info)
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_location(dev):
|
async def test_location(dev):
|
||||||
PLUG_SCHEMA(dev.location)
|
PLUG_SCHEMA(dev.location)
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_rssi(dev):
|
async def test_rssi(dev):
|
||||||
PLUG_SCHEMA({"rssi": dev.rssi}) # wrapping for vol
|
PLUG_SCHEMA({"rssi": dev.rssi}) # wrapping for vol
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_mac(dev):
|
async def test_mac(dev):
|
||||||
PLUG_SCHEMA({"mac": dev.mac}) # wrapping for val
|
PLUG_SCHEMA({"mac": dev.mac}) # wrapping for val
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_representation(dev):
|
async def test_representation(dev):
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -147,6 +160,7 @@ async def test_representation(dev):
|
|||||||
assert pattern.match(str(dev))
|
assert pattern.match(str(dev))
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_childrens(dev):
|
async def test_childrens(dev):
|
||||||
"""Make sure that children property is exposed by every device."""
|
"""Make sure that children property is exposed by every device."""
|
||||||
if dev.is_strip:
|
if dev.is_strip:
|
||||||
@ -155,6 +169,7 @@ async def test_childrens(dev):
|
|||||||
assert len(dev.children) == 0
|
assert len(dev.children) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_children(dev):
|
async def test_children(dev):
|
||||||
"""Make sure that children property is exposed by every device."""
|
"""Make sure that children property is exposed by every device."""
|
||||||
if dev.is_strip:
|
if dev.is_strip:
|
||||||
@ -165,11 +180,13 @@ async def test_children(dev):
|
|||||||
assert dev.has_children is False
|
assert dev.has_children is False
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_internal_state(dev):
|
async def test_internal_state(dev):
|
||||||
"""Make sure the internal state returns the last update results."""
|
"""Make sure the internal state returns the last update results."""
|
||||||
assert dev.internal_state == dev._last_update
|
assert dev.internal_state == dev._last_update
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_features(dev):
|
async def test_features(dev):
|
||||||
"""Make sure features is always accessible."""
|
"""Make sure features is always accessible."""
|
||||||
sysinfo = dev._last_update["system"]["get_sysinfo"]
|
sysinfo = dev._last_update["system"]["get_sysinfo"]
|
||||||
@ -179,11 +196,13 @@ async def test_features(dev):
|
|||||||
assert dev.features == set()
|
assert dev.features == set()
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_max_device_response_size(dev):
|
async def test_max_device_response_size(dev):
|
||||||
"""Make sure every device return has a set max response size."""
|
"""Make sure every device return has a set max response size."""
|
||||||
assert dev.max_device_response_size > 0
|
assert dev.max_device_response_size > 0
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_estimated_response_sizes(dev):
|
async def test_estimated_response_sizes(dev):
|
||||||
"""Make sure every module has an estimated response size set."""
|
"""Make sure every module has an estimated response size set."""
|
||||||
for mod in dev.modules.values():
|
for mod in dev.modules.values():
|
||||||
@ -202,6 +221,7 @@ def test_device_class_ctors(device_class):
|
|||||||
assert dev.credentials == credentials
|
assert dev.credentials == credentials
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_modules_preserved(dev: SmartDevice):
|
async def test_modules_preserved(dev: SmartDevice):
|
||||||
"""Make modules that are not being updated are preserved between updates."""
|
"""Make modules that are not being updated are preserved between updates."""
|
||||||
dev._last_update["some_module_not_being_updated"] = "should_be_kept"
|
dev._last_update["some_module_not_being_updated"] = "should_be_kept"
|
||||||
@ -237,6 +257,7 @@ async def test_create_thin_wrapper():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@device_iot
|
||||||
async def test_modules_not_supported(dev: SmartDevice):
|
async def test_modules_not_supported(dev: SmartDevice):
|
||||||
"""Test that unsupported modules do not break the device."""
|
"""Test that unsupported modules do not break the device."""
|
||||||
for module in dev.modules.values():
|
for module in dev.modules.values():
|
||||||
|
Loading…
Reference in New Issue
Block a user