mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Add klap support for TAPO protocol by splitting out Transports and Protocols (#557)
* Add support for TAPO/SMART KLAP and seperate transports from protocols * Add tests and some review changes * Update following review * Updates following review
This commit is contained in:
parent
347cbfe3bd
commit
4a00199506
@ -21,13 +21,14 @@ from kasa.exceptions import (
|
||||
SmartDeviceException,
|
||||
UnsupportedDeviceException,
|
||||
)
|
||||
from kasa.klapprotocol import TPLinkKlap
|
||||
from kasa.iotprotocol import IotProtocol
|
||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors
|
||||
from kasa.smartdevice import DeviceType, SmartDevice
|
||||
from kasa.smartdimmer import SmartDimmer
|
||||
from kasa.smartlightstrip import SmartLightStrip
|
||||
from kasa.smartplug import SmartPlug
|
||||
from kasa.smartprotocol import SmartProtocol
|
||||
from kasa.smartstrip import SmartStrip
|
||||
|
||||
__version__ = version("python-kasa")
|
||||
@ -37,7 +38,8 @@ __all__ = [
|
||||
"Discover",
|
||||
"TPLinkSmartHomeProtocol",
|
||||
"TPLinkProtocol",
|
||||
"TPLinkKlap",
|
||||
"IotProtocol",
|
||||
"SmartProtocol",
|
||||
"SmartBulb",
|
||||
"SmartBulbPreset",
|
||||
"TurnOnBehaviors",
|
||||
|
@ -1,498 +0,0 @@
|
||||
"""Implementation of the TP-Link AES Protocol.
|
||||
|
||||
Based on the work of https://github.com/petretiandrea/plugp100
|
||||
under compatible GNU GPL3 license.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import hashes, padding, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .json import loads as json_loads
|
||||
from .protocol import TPLinkProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
|
||||
def _md5(payload: bytes) -> bytes:
|
||||
digest = hashes.Hash(hashes.MD5()) # noqa: S303
|
||||
digest.update(payload)
|
||||
hash = digest.finalize()
|
||||
return hash
|
||||
|
||||
|
||||
def _sha1(payload: bytes) -> str:
|
||||
sha1_algo = hashlib.sha1() # noqa: S324
|
||||
sha1_algo.update(payload)
|
||||
return sha1_algo.hexdigest()
|
||||
|
||||
|
||||
class TPLinkAes(TPLinkProtocol):
|
||||
"""Implementation of the AES encryption protocol.
|
||||
|
||||
AES is the name used in device discovery for TP-Link's TAPO encryption
|
||||
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
DEFAULT_TIMEOUT = 5
|
||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||
COMMON_HEADERS = {
|
||||
"Content-Type": "application/json",
|
||||
"requestByApp": "true",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||
|
||||
self.credentials = (
|
||||
credentials
|
||||
if credentials and credentials.username and credentials.password
|
||||
else Credentials(username="", password="")
|
||||
)
|
||||
|
||||
self._local_seed: Optional[bytes] = None
|
||||
self.local_auth_hash = self.generate_auth_hash(self.credentials)
|
||||
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
|
||||
self.kasa_setup_auth_hash = None
|
||||
self.blank_auth_hash = None
|
||||
self.handshake_lock = asyncio.Lock()
|
||||
self.query_lock = asyncio.Lock()
|
||||
self.handshake_done = False
|
||||
|
||||
self.encryption_session: Optional[AesEncyptionSession] = None
|
||||
self.session_expire_at: Optional[float] = None
|
||||
|
||||
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||
self.session_cookie = None
|
||||
self.terminal_uuid = None
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
self.request_id_generator = SnowflakeId(1, 1)
|
||||
self.login_token = None
|
||||
|
||||
_LOGGER.debug("Created AES object for %s", self.host)
|
||||
|
||||
def hash_credentials(self, credentials, try_login_version2):
|
||||
"""Hash the credentials."""
|
||||
if try_login_version2:
|
||||
un = base64.b64encode(
|
||||
_sha1(credentials.username.encode()).encode()
|
||||
).decode()
|
||||
pw = base64.b64encode(
|
||||
_sha1(credentials.password.encode()).encode()
|
||||
).decode()
|
||||
else:
|
||||
un = base64.b64encode(
|
||||
_sha1(credentials.username.encode()).encode()
|
||||
).decode()
|
||||
pw = base64.b64encode(credentials.password.encode()).decode()
|
||||
return un, pw
|
||||
|
||||
async def client_post(self, url, params=None, data=None, json=None, headers=None):
|
||||
"""Send an http post request to the device."""
|
||||
response_data = None
|
||||
cookies = None
|
||||
if self.session_cookie:
|
||||
cookies = httpx.Cookies()
|
||||
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
|
||||
self.http_client.cookies.clear()
|
||||
resp = await self.http_client.post(
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
json=json,
|
||||
timeout=self.timeout,
|
||||
cookies=cookies,
|
||||
headers=self.COMMON_HEADERS,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
response_data = resp.json()
|
||||
|
||||
return resp.status_code, response_data
|
||||
|
||||
async def send_secure_passthrough(self, request):
|
||||
"""Send encrypted message as passthrough."""
|
||||
url = f"http://{self.host}/app"
|
||||
if self.login_token:
|
||||
url += f"?token={self.login_token}"
|
||||
raw_request = json_dumps(request)
|
||||
encrypted_payload = self.encryption_session.encrypt(raw_request.encode())
|
||||
passthrough_request = {
|
||||
"method": "securePassthrough",
|
||||
"params": {"request": encrypted_payload.decode()},
|
||||
}
|
||||
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||
response = self.encryption_session.decrypt(
|
||||
resp_dict["result"]["response"].encode()
|
||||
)
|
||||
resp_dict = json_loads(response)
|
||||
if resp_dict["error_code"] != 0:
|
||||
raise SmartDeviceException(
|
||||
f"Could not complete send, response was {resp_dict}",
|
||||
)
|
||||
if "result" in resp_dict:
|
||||
return resp_dict["result"]
|
||||
else:
|
||||
raise AuthenticationException("Could not complete send")
|
||||
|
||||
def get_aes_request(self, method, params=None):
|
||||
"""Get a request message."""
|
||||
request = {
|
||||
"method": method,
|
||||
"params": params,
|
||||
"requestID": self.request_id_generator.generate_id(),
|
||||
"request_time_milis": round(time.time() * 1000),
|
||||
"terminal_uuid": self.terminal_uuid,
|
||||
}
|
||||
return request
|
||||
|
||||
async def perform_login(self, login_v2):
|
||||
"""Login to the device."""
|
||||
self.login_token = None
|
||||
|
||||
un, pw = self.hash_credentials(self.credentials, login_v2)
|
||||
params = {"password": pw, "username": un}
|
||||
request = self.get_aes_request("login_device", params)
|
||||
try:
|
||||
result = await self.send_secure_passthrough(request)
|
||||
except SmartDeviceException as ex:
|
||||
raise AuthenticationException(ex) from ex
|
||||
self.login_token = result["token"]
|
||||
|
||||
async def perform_handshake(self):
|
||||
"""Perform the handshake."""
|
||||
_LOGGER.debug("Will perform handshaking...")
|
||||
_LOGGER.debug("Generating keypair")
|
||||
|
||||
self.handshake_done = False
|
||||
self.session_expire_at = None
|
||||
self.session_cookie = None
|
||||
|
||||
url = f"http://{self.host}/app"
|
||||
key_pair = KeyPair.create_key_pair()
|
||||
|
||||
pub_key = (
|
||||
"-----BEGIN PUBLIC KEY-----\n"
|
||||
+ key_pair.get_public_key()
|
||||
+ "\n-----END PUBLIC KEY-----\n"
|
||||
)
|
||||
handshake_params = {"key": pub_key}
|
||||
_LOGGER.debug(f"Handshake params: {handshake_params}")
|
||||
|
||||
request_body = {"method": "handshake", "params": handshake_params}
|
||||
|
||||
_LOGGER.debug(f"Request {request_body}")
|
||||
|
||||
status_code, resp_dict = await self.client_post(url, json=request_body)
|
||||
|
||||
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
||||
|
||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||
_LOGGER.debug("Decoding handshake key...")
|
||||
handshake_key = resp_dict["result"]["key"]
|
||||
|
||||
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
if not self.session_cookie:
|
||||
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
||||
"SESSIONID"
|
||||
)
|
||||
|
||||
self.session_expire_at = time.time() + 86400
|
||||
self.encryption_session = AesEncyptionSession.create_from_keypair(
|
||||
handshake_key, key_pair
|
||||
)
|
||||
|
||||
self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode(
|
||||
"UTF-8"
|
||||
)
|
||||
self.handshake_done = True
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||
|
||||
else:
|
||||
raise AuthenticationException("Could not complete handshake")
|
||||
|
||||
def handshake_session_expired(self):
|
||||
"""Return true if session has expired."""
|
||||
return (
|
||||
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials):
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
un = creds.username or ""
|
||||
pw = creds.password or ""
|
||||
return _md5(_md5(un.encode()) + _md5(pw.encode()))
|
||||
|
||||
@staticmethod
|
||||
def generate_owner_hash(creds: Credentials):
|
||||
"""Return the MD5 hash of the username in this object."""
|
||||
un = creds.username or ""
|
||||
return _md5(un.encode())
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Query the device retrying for retry_count on failure."""
|
||||
async with self.query_lock:
|
||||
return await self._query(request, retry_count)
|
||||
|
||||
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except httpx.CloseError as sdex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except httpx.ConnectError as cex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {cex}"
|
||||
) from cex
|
||||
except TimeoutError as tex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
||||
) from tex
|
||||
except AuthenticationException as auex:
|
||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||
raise auex
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
||||
_LOGGER.debug(
|
||||
"%s >> %s",
|
||||
self.host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(request),
|
||||
)
|
||||
|
||||
if not self.http_client:
|
||||
self.http_client = httpx.AsyncClient()
|
||||
|
||||
if not self.handshake_done or self.handshake_session_expired():
|
||||
try:
|
||||
await self.perform_handshake()
|
||||
await self.perform_login(False)
|
||||
except AuthenticationException:
|
||||
await self.perform_handshake()
|
||||
await self.perform_login(True)
|
||||
|
||||
if isinstance(request, dict):
|
||||
aes_method = next(iter(request))
|
||||
aes_params = request[aes_method]
|
||||
else:
|
||||
aes_method = request
|
||||
aes_params = None
|
||||
|
||||
aes_request = self.get_aes_request(aes_method, aes_params)
|
||||
response_data = await self.send_secure_passthrough(aes_request)
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s << %s",
|
||||
self.host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
client = self.http_client
|
||||
self.http_client = None
|
||||
if client:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
class AesEncyptionSession:
|
||||
"""Class for an AES encryption session."""
|
||||
|
||||
@staticmethod
|
||||
def create_from_keypair(handshake_key: str, keypair):
|
||||
"""Create the encryption session."""
|
||||
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
|
||||
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
|
||||
|
||||
private_key = serialization.load_der_private_key(private_key_data, None, None)
|
||||
key_and_iv = private_key.decrypt(
|
||||
handshake_key_bytes, asymmetric_padding.PKCS1v15()
|
||||
)
|
||||
if key_and_iv is None:
|
||||
raise ValueError("Decryption failed!")
|
||||
|
||||
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
|
||||
|
||||
def __init__(self, key, iv):
|
||||
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
||||
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
|
||||
|
||||
def encrypt(self, data) -> bytes:
|
||||
"""Encrypt the message."""
|
||||
encryptor = self.cipher.encryptor()
|
||||
padder = self.padding_strategy.padder()
|
||||
padded_data = padder.update(data) + padder.finalize()
|
||||
encrypted = encryptor.update(padded_data) + encryptor.finalize()
|
||||
return base64.b64encode(encrypted)
|
||||
|
||||
def decrypt(self, data) -> str:
|
||||
"""Decrypt the message."""
|
||||
decryptor = self.cipher.decryptor()
|
||||
unpadder = self.padding_strategy.unpadder()
|
||||
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
|
||||
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
|
||||
return unpadded_data.decode()
|
||||
|
||||
|
||||
class KeyPair:
|
||||
"""Class for generating key pairs."""
|
||||
|
||||
@staticmethod
|
||||
def create_key_pair(key_size: int = 1024):
|
||||
"""Create a key pair."""
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
||||
public_key = private_key.public_key()
|
||||
|
||||
private_key_bytes = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
public_key_bytes = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
return KeyPair(
|
||||
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
|
||||
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
|
||||
)
|
||||
|
||||
def __init__(self, private_key: str, public_key: str):
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
|
||||
def get_private_key(self) -> str:
|
||||
"""Get the private key."""
|
||||
return self.private_key
|
||||
|
||||
def get_public_key(self) -> str:
|
||||
"""Get the public key."""
|
||||
return self.public_key
|
||||
|
||||
|
||||
class SnowflakeId:
|
||||
"""Class for generating snowflake ids."""
|
||||
|
||||
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
|
||||
WORKER_ID_BITS = 5
|
||||
DATA_CENTER_ID_BITS = 5
|
||||
SEQUENCE_BITS = 12
|
||||
|
||||
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
|
||||
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
|
||||
|
||||
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
|
||||
|
||||
def __init__(self, worker_id, data_center_id):
|
||||
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
|
||||
raise ValueError(
|
||||
"Worker ID can't be greater than "
|
||||
+ str(SnowflakeId.MAX_WORKER_ID)
|
||||
+ " or less than 0"
|
||||
)
|
||||
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
|
||||
raise ValueError(
|
||||
"Data center ID can't be greater than "
|
||||
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
|
||||
+ " or less than 0"
|
||||
)
|
||||
|
||||
self.worker_id = worker_id
|
||||
self.data_center_id = data_center_id
|
||||
self.sequence = 0
|
||||
self.last_timestamp = -1
|
||||
|
||||
def generate_id(self):
|
||||
"""Generate a snowflake id."""
|
||||
timestamp = self._current_millis()
|
||||
|
||||
if timestamp < self.last_timestamp:
|
||||
raise ValueError("Clock moved backwards. Refusing to generate ID.")
|
||||
|
||||
if timestamp == self.last_timestamp:
|
||||
# Within the same millisecond, increment the sequence number
|
||||
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
|
||||
if self.sequence == 0:
|
||||
# Sequence exceeds its bit range, wait until the next millisecond
|
||||
timestamp = self._wait_next_millis(self.last_timestamp)
|
||||
else:
|
||||
# New millisecond, reset the sequence number
|
||||
self.sequence = 0
|
||||
|
||||
# Update the last timestamp
|
||||
self.last_timestamp = timestamp
|
||||
|
||||
# Generate and return the final ID
|
||||
return (
|
||||
(
|
||||
(timestamp - SnowflakeId.EPOCH)
|
||||
<< (
|
||||
SnowflakeId.WORKER_ID_BITS
|
||||
+ SnowflakeId.SEQUENCE_BITS
|
||||
+ SnowflakeId.DATA_CENTER_ID_BITS
|
||||
)
|
||||
)
|
||||
| (
|
||||
self.data_center_id
|
||||
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
|
||||
)
|
||||
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
|
||||
| self.sequence
|
||||
)
|
||||
|
||||
def _current_millis(self):
|
||||
return round(time.time() * 1000)
|
||||
|
||||
def _wait_next_millis(self, last_timestamp):
|
||||
timestamp = self._current_millis()
|
||||
while timestamp <= last_timestamp:
|
||||
timestamp = self._current_millis()
|
||||
return timestamp
|
338
kasa/aestransport.py
Normal file
338
kasa/aestransport.py
Normal file
@ -0,0 +1,338 @@
|
||||
"""Implementation of the TP-Link AES transport.
|
||||
|
||||
Based on the work of https://github.com/petretiandrea/plugp100
|
||||
under compatible GNU GPL3 license.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import padding, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .json import loads as json_loads
|
||||
from .protocol import BaseTransport
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _sha1(payload: bytes) -> str:
|
||||
sha1_algo = hashlib.sha1() # noqa: S324
|
||||
sha1_algo.update(payload)
|
||||
return sha1_algo.hexdigest()
|
||||
|
||||
|
||||
class AesTransport(BaseTransport):
|
||||
"""Implementation of the AES encryption protocol.
|
||||
|
||||
AES is the name used in device discovery for TP-Link's TAPO encryption
|
||||
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||
"""
|
||||
|
||||
DEFAULT_TIMEOUT = 5
|
||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||
COMMON_HEADERS = {
|
||||
"Content-Type": "application/json",
|
||||
"requestByApp": "true",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host)
|
||||
|
||||
self._credentials = credentials or Credentials(username="", password="")
|
||||
|
||||
self._handshake_done = False
|
||||
|
||||
self._encryption_session: Optional[AesEncyptionSession] = None
|
||||
self._session_expire_at: Optional[float] = None
|
||||
|
||||
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||
self._session_cookie = None
|
||||
|
||||
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||
self._login_token = None
|
||||
|
||||
_LOGGER.debug("Created AES object for %s", self.host)
|
||||
|
||||
def hash_credentials(self, login_v2):
|
||||
"""Hash the credentials."""
|
||||
if login_v2:
|
||||
un = base64.b64encode(
|
||||
_sha1(self._credentials.username.encode()).encode()
|
||||
).decode()
|
||||
pw = base64.b64encode(
|
||||
_sha1(self._credentials.password.encode()).encode()
|
||||
).decode()
|
||||
else:
|
||||
un = base64.b64encode(
|
||||
_sha1(self._credentials.username.encode()).encode()
|
||||
).decode()
|
||||
pw = base64.b64encode(self._credentials.password.encode()).decode()
|
||||
return un, pw
|
||||
|
||||
async def client_post(self, url, params=None, data=None, json=None, headers=None):
|
||||
"""Send an http post request to the device."""
|
||||
response_data = None
|
||||
cookies = None
|
||||
if self._session_cookie:
|
||||
cookies = httpx.Cookies()
|
||||
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
|
||||
self._http_client.cookies.clear()
|
||||
resp = await self._http_client.post(
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
json=json,
|
||||
timeout=self._timeout,
|
||||
cookies=cookies,
|
||||
headers=self.COMMON_HEADERS,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
response_data = resp.json()
|
||||
|
||||
return resp.status_code, response_data
|
||||
|
||||
async def send_secure_passthrough(self, request: str):
|
||||
"""Send encrypted message as passthrough."""
|
||||
url = f"http://{self.host}/app"
|
||||
if self._login_token:
|
||||
url += f"?token={self._login_token}"
|
||||
|
||||
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
||||
passthrough_request = {
|
||||
"method": "securePassthrough",
|
||||
"params": {"request": encrypted_payload.decode()},
|
||||
}
|
||||
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
||||
_LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
|
||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||
response = self._encryption_session.decrypt( # type: ignore
|
||||
resp_dict["result"]["response"].encode()
|
||||
)
|
||||
_LOGGER.debug(f"decrypted secure_passthrough response is {response}")
|
||||
resp_dict = json_loads(response)
|
||||
return resp_dict
|
||||
else:
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
raise AuthenticationException("Could not complete send")
|
||||
|
||||
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
|
||||
"""Login to the device."""
|
||||
self._login_token = None
|
||||
|
||||
if isinstance(login_request, str):
|
||||
login_request_dict: dict = json_loads(login_request)
|
||||
else:
|
||||
login_request_dict = login_request
|
||||
|
||||
un, pw = self.hash_credentials(login_v2)
|
||||
login_request_dict["params"] = {"password": pw, "username": un}
|
||||
request = json_dumps(login_request_dict)
|
||||
try:
|
||||
resp_dict = await self.send_secure_passthrough(request)
|
||||
except SmartDeviceException as ex:
|
||||
raise AuthenticationException(ex) from ex
|
||||
self._login_token = resp_dict["result"]["token"]
|
||||
|
||||
@property
|
||||
def needs_login(self) -> bool:
|
||||
"""Return true if the transport needs to do a login."""
|
||||
return self._login_token is None
|
||||
|
||||
async def login(self, request: str) -> None:
|
||||
"""Login to the device."""
|
||||
try:
|
||||
if self.needs_handshake:
|
||||
raise SmartDeviceException(
|
||||
"Handshake must be complete before trying to login"
|
||||
)
|
||||
await self.perform_login(request, login_v2=False)
|
||||
except AuthenticationException:
|
||||
await self.perform_handshake()
|
||||
await self.perform_login(request, login_v2=True)
|
||||
|
||||
@property
|
||||
def needs_handshake(self) -> bool:
|
||||
"""Return true if the transport needs to do a handshake."""
|
||||
return not self._handshake_done or self._handshake_session_expired()
|
||||
|
||||
async def handshake(self) -> None:
|
||||
"""Perform the encryption handshake."""
|
||||
await self.perform_handshake()
|
||||
|
||||
async def perform_handshake(self):
|
||||
"""Perform the handshake."""
|
||||
_LOGGER.debug("Will perform handshaking...")
|
||||
_LOGGER.debug("Generating keypair")
|
||||
|
||||
self._handshake_done = False
|
||||
self._session_expire_at = None
|
||||
self._session_cookie = None
|
||||
|
||||
url = f"http://{self.host}/app"
|
||||
key_pair = KeyPair.create_key_pair()
|
||||
|
||||
pub_key = (
|
||||
"-----BEGIN PUBLIC KEY-----\n"
|
||||
+ key_pair.get_public_key()
|
||||
+ "\n-----END PUBLIC KEY-----\n"
|
||||
)
|
||||
handshake_params = {"key": pub_key}
|
||||
_LOGGER.debug(f"Handshake params: {handshake_params}")
|
||||
|
||||
request_body = {"method": "handshake", "params": handshake_params}
|
||||
|
||||
_LOGGER.debug(f"Request {request_body}")
|
||||
|
||||
status_code, resp_dict = await self.client_post(url, json=request_body)
|
||||
|
||||
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
||||
|
||||
if status_code == 200 and resp_dict["error_code"] == 0:
|
||||
_LOGGER.debug("Decoding handshake key...")
|
||||
handshake_key = resp_dict["result"]["key"]
|
||||
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
if not self._session_cookie:
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
"SESSIONID"
|
||||
)
|
||||
|
||||
self._session_expire_at = time.time() + 86400
|
||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||
handshake_key, key_pair
|
||||
)
|
||||
|
||||
self._handshake_done = True
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||
|
||||
else:
|
||||
raise AuthenticationException("Could not complete handshake")
|
||||
|
||||
def _handshake_session_expired(self):
|
||||
"""Return true if session has expired."""
|
||||
return (
|
||||
self._session_expire_at is None
|
||||
or self._session_expire_at - time.time() <= 0
|
||||
)
|
||||
|
||||
async def send(self, request: str):
|
||||
"""Send the request."""
|
||||
if self.needs_handshake:
|
||||
raise SmartDeviceException(
|
||||
"Handshake must be complete before trying to send"
|
||||
)
|
||||
if self.needs_login:
|
||||
raise SmartDeviceException("Login must be complete before trying to send")
|
||||
|
||||
resp_dict = await self.send_secure_passthrough(request)
|
||||
if resp_dict["error_code"] != 0:
|
||||
self._handshake_done = False
|
||||
self._login_token = None
|
||||
raise SmartDeviceException(
|
||||
f"Could not complete send, response was {resp_dict}",
|
||||
)
|
||||
return resp_dict
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
client = self._http_client
|
||||
self._http_client = None
|
||||
if client:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
class AesEncyptionSession:
|
||||
"""Class for an AES encryption session."""
|
||||
|
||||
@staticmethod
|
||||
def create_from_keypair(handshake_key: str, keypair):
|
||||
"""Create the encryption session."""
|
||||
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
|
||||
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
|
||||
|
||||
private_key = serialization.load_der_private_key(private_key_data, None, None)
|
||||
key_and_iv = private_key.decrypt(
|
||||
handshake_key_bytes, asymmetric_padding.PKCS1v15()
|
||||
)
|
||||
if key_and_iv is None:
|
||||
raise ValueError("Decryption failed!")
|
||||
|
||||
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
|
||||
|
||||
def __init__(self, key, iv):
|
||||
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
||||
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
|
||||
|
||||
def encrypt(self, data) -> bytes:
|
||||
"""Encrypt the message."""
|
||||
encryptor = self.cipher.encryptor()
|
||||
padder = self.padding_strategy.padder()
|
||||
padded_data = padder.update(data) + padder.finalize()
|
||||
encrypted = encryptor.update(padded_data) + encryptor.finalize()
|
||||
return base64.b64encode(encrypted)
|
||||
|
||||
def decrypt(self, data) -> str:
|
||||
"""Decrypt the message."""
|
||||
decryptor = self.cipher.decryptor()
|
||||
unpadder = self.padding_strategy.unpadder()
|
||||
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
|
||||
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
|
||||
return unpadded_data.decode()
|
||||
|
||||
|
||||
class KeyPair:
|
||||
"""Class for generating key pairs."""
|
||||
|
||||
@staticmethod
|
||||
def create_key_pair(key_size: int = 1024):
|
||||
"""Create a key pair."""
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
||||
public_key = private_key.public_key()
|
||||
|
||||
private_key_bytes = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
public_key_bytes = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
return KeyPair(
|
||||
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
|
||||
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
|
||||
)
|
||||
|
||||
def __init__(self, private_key: str, public_key: str):
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
|
||||
def get_private_key(self) -> str:
|
||||
"""Get the private key."""
|
||||
return self.private_key
|
||||
|
||||
def get_public_key(self) -> str:
|
||||
"""Get the public key."""
|
||||
return self.public_key
|
@ -2,17 +2,21 @@
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
from .aestransport import AesTransport
|
||||
from .credentials import Credentials
|
||||
from .device_type import DeviceType
|
||||
from .exceptions import UnsupportedDeviceException
|
||||
from .protocol import TPLinkProtocol
|
||||
from .iotprotocol import IotProtocol
|
||||
from .klaptransport import KlapTransport, TPlinkKlapTransportV2
|
||||
from .protocol import BaseTransport, TPLinkProtocol
|
||||
from .smartbulb import SmartBulb
|
||||
from .smartdevice import SmartDevice, SmartDeviceException
|
||||
from .smartdimmer import SmartDimmer
|
||||
from .smartlightstrip import SmartLightStrip
|
||||
from .smartplug import SmartPlug
|
||||
from .smartprotocol import SmartProtocol
|
||||
from .smartstrip import SmartStrip
|
||||
from .tapo.tapoplug import TapoPlug
|
||||
|
||||
@ -87,7 +91,7 @@ async def connect(
|
||||
if protocol_class is not None:
|
||||
unknown_dev.protocol = protocol_class(host, credentials=credentials)
|
||||
await unknown_dev.update()
|
||||
device_class = get_device_class_from_info(unknown_dev.internal_state)
|
||||
device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
|
||||
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
|
||||
# Reuse the connection from the unknown device
|
||||
# so we don't have to reconnect
|
||||
@ -104,7 +108,7 @@ async def connect(
|
||||
return dev
|
||||
|
||||
|
||||
def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]:
|
||||
def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
|
||||
"""Find SmartDevice subclass for device described by passed data."""
|
||||
if "system" not in info or "get_sysinfo" not in info["system"]:
|
||||
raise SmartDeviceException("No 'system' or 'get_sysinfo' in response")
|
||||
@ -129,3 +133,35 @@ def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]:
|
||||
|
||||
return SmartBulb
|
||||
raise UnsupportedDeviceException("Unknown device type: %s" % type_)
|
||||
|
||||
|
||||
def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]:
|
||||
"""Return the device class from the type name."""
|
||||
supported_device_types: dict[str, Type[SmartDevice]] = {
|
||||
"SMART.TAPOPLUG": TapoPlug,
|
||||
"SMART.KASAPLUG": TapoPlug,
|
||||
"IOT.SMARTPLUGSWITCH": SmartPlug,
|
||||
}
|
||||
return supported_device_types.get(device_type)
|
||||
|
||||
|
||||
def get_protocol_from_connection_name(
|
||||
connection_name: str, host: str, credentials: Optional[Credentials] = None
|
||||
) -> Optional[TPLinkProtocol]:
|
||||
"""Return the protocol from the connection name."""
|
||||
supported_device_protocols: dict[
|
||||
str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]]
|
||||
] = {
|
||||
"IOT.KLAP": (IotProtocol, KlapTransport),
|
||||
"SMART.AES": (SmartProtocol, AesTransport),
|
||||
"SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2),
|
||||
}
|
||||
if connection_name not in supported_device_protocols:
|
||||
return None
|
||||
|
||||
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
|
||||
transport: BaseTransport = transport_class(host, credentials=credentials)
|
||||
protocol: TPLinkProtocol = protocol_class(
|
||||
host, credentials=credentials, transport=transport
|
||||
)
|
||||
return protocol
|
||||
|
@ -15,18 +15,18 @@ try:
|
||||
except ImportError:
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from kasa.aesprotocol import TPLinkAes
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.exceptions import UnsupportedDeviceException
|
||||
from kasa.json import dumps as json_dumps
|
||||
from kasa.json import loads as json_loads
|
||||
from kasa.klapprotocol import TPLinkKlap
|
||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||
from kasa.smartdevice import SmartDevice, SmartDeviceException
|
||||
from kasa.smartplug import SmartPlug
|
||||
from kasa.tapo.tapoplug import TapoPlug
|
||||
|
||||
from .device_factory import get_device_class_from_info
|
||||
from .device_factory import (
|
||||
get_device_class_from_sys_info,
|
||||
get_device_class_from_type_name,
|
||||
get_protocol_from_connection_name,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -348,7 +348,16 @@ class Discover:
|
||||
@staticmethod
|
||||
def _get_device_class(info: dict) -> Type[SmartDevice]:
|
||||
"""Find SmartDevice subclass for device described by passed data."""
|
||||
return get_device_class_from_info(info)
|
||||
if "result" in info:
|
||||
discovery_result = DiscoveryResult(**info["result"])
|
||||
dev_class = get_device_class_from_type_name(discovery_result.device_type)
|
||||
if not dev_class:
|
||||
raise UnsupportedDeviceException(
|
||||
"Unknown device type: %s" % discovery_result.device_type
|
||||
)
|
||||
return dev_class
|
||||
else:
|
||||
return get_device_class_from_sys_info(info)
|
||||
|
||||
@staticmethod
|
||||
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:
|
||||
@ -384,24 +393,17 @@ class Discover:
|
||||
encrypt_type_ = (
|
||||
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
|
||||
)
|
||||
device_class = None
|
||||
|
||||
supported_device_types: dict[str, Type[SmartDevice]] = {
|
||||
"SMART.TAPOPLUG": TapoPlug,
|
||||
"SMART.KASAPLUG": TapoPlug,
|
||||
"IOT.SMARTPLUGSWITCH": SmartPlug,
|
||||
}
|
||||
supported_device_protocols: dict[str, Type[TPLinkProtocol]] = {
|
||||
"IOT.KLAP": TPLinkKlap,
|
||||
"SMART.AES": TPLinkAes,
|
||||
}
|
||||
|
||||
if (device_class := supported_device_types.get(type_)) is None:
|
||||
if (device_class := get_device_class_from_type_name(type_)) is None:
|
||||
_LOGGER.warning("Got unsupported device type: %s", type_)
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device {ip} of type {type_}: {info}"
|
||||
)
|
||||
if (protocol_class := supported_device_protocols.get(encrypt_type_)) is None:
|
||||
if (
|
||||
protocol := get_protocol_from_connection_name(
|
||||
encrypt_type_, ip, credentials=credentials
|
||||
)
|
||||
) is None:
|
||||
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}"
|
||||
@ -409,7 +411,7 @@ class Discover:
|
||||
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
device = device_class(ip, port=port, credentials=credentials)
|
||||
device.protocol = protocol_class(ip, credentials=credentials)
|
||||
device.protocol = protocol
|
||||
device.update_from_discover_info(discovery_result.get_dict())
|
||||
return device
|
||||
|
||||
|
100
kasa/iotprotocol.py
Executable file
100
kasa/iotprotocol.py
Executable file
@ -0,0 +1,100 @@
|
||||
"""Module for the IOT legacy IOT KASA protocol."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .klaptransport import KlapTransport
|
||||
from .protocol import BaseTransport, TPLinkProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IotProtocol(TPLinkProtocol):
|
||||
"""Class for the legacy TPLink IOT KASA Protocol."""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
transport: Optional[BaseTransport] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||
|
||||
self._credentials: Credentials = credentials or Credentials(
|
||||
username="", password=""
|
||||
)
|
||||
self._transport: BaseTransport = transport or KlapTransport(
|
||||
host, credentials=self._credentials, timeout=timeout
|
||||
)
|
||||
|
||||
self._query_lock = asyncio.Lock()
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Query the device retrying for retry_count on failure."""
|
||||
if isinstance(request, dict):
|
||||
request = json_dumps(request)
|
||||
assert isinstance(request, str) # noqa: S101
|
||||
|
||||
async with self._query_lock:
|
||||
return await self._query(request, retry_count)
|
||||
|
||||
async def _query(self, request: str, retry_count: int = 3) -> Dict:
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except httpx.CloseError as sdex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except httpx.ConnectError as cex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {cex}"
|
||||
) from cex
|
||||
except TimeoutError as tex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
||||
) from tex
|
||||
except AuthenticationException as auex:
|
||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||
raise auex
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
async def _execute_query(self, request: str, retry_count: int) -> Dict:
|
||||
if self._transport.needs_handshake:
|
||||
await self._transport.handshake()
|
||||
|
||||
if self._transport.needs_login: # This shouln't happen
|
||||
raise SmartDeviceException(
|
||||
"IOT Protocol needs to login to transport but is not login aware"
|
||||
)
|
||||
|
||||
return await self._transport.send(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
await self._transport.close()
|
295
kasa/klapprotocol.py → kasa/klaptransport.py
Executable file → Normal file
295
kasa/klapprotocol.py → kasa/klaptransport.py
Executable file → Normal file
@ -47,7 +47,7 @@ import logging
|
||||
import secrets
|
||||
import time
|
||||
from pprint import pformat as pf
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import hashes, padding
|
||||
@ -55,33 +55,33 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .json import loads as json_loads
|
||||
from .protocol import TPLinkProtocol
|
||||
from .protocol import BaseTransport, md5
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
|
||||
def _sha256(payload: bytes) -> bytes:
|
||||
return hashlib.sha256(payload).digest()
|
||||
|
||||
|
||||
def _md5(payload: bytes) -> bytes:
|
||||
digest = hashes.Hash(hashes.MD5()) # noqa: S303
|
||||
digest = hashes.Hash(hashes.SHA256()) # noqa: S303
|
||||
digest.update(payload)
|
||||
hash = digest.finalize()
|
||||
return hash
|
||||
|
||||
|
||||
class TPLinkKlap(TPLinkProtocol):
|
||||
def _sha1(payload: bytes) -> bytes:
|
||||
digest = hashes.Hash(hashes.SHA1()) # noqa: S303
|
||||
digest.update(payload)
|
||||
return digest.finalize()
|
||||
|
||||
|
||||
class KlapTransport(BaseTransport):
|
||||
"""Implementation of the KLAP encryption protocol.
|
||||
|
||||
KLAP is the name used in device discovery for TP-Link's new encryption
|
||||
protocol, used by newer firmware versions.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
DEFAULT_TIMEOUT = 5
|
||||
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
|
||||
KASA_SETUP_EMAIL = "kasa@tp-link.net"
|
||||
@ -95,29 +95,24 @@ class TPLinkKlap(TPLinkProtocol):
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||
|
||||
self.credentials = (
|
||||
credentials
|
||||
if credentials and credentials.username and credentials.password
|
||||
else Credentials(username="", password="")
|
||||
)
|
||||
super().__init__(host=host)
|
||||
|
||||
self._credentials = credentials or Credentials(username="", password="")
|
||||
self._local_seed: Optional[bytes] = None
|
||||
self.local_auth_hash = self.generate_auth_hash(self.credentials)
|
||||
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
|
||||
self.kasa_setup_auth_hash = None
|
||||
self.blank_auth_hash = None
|
||||
self.handshake_lock = asyncio.Lock()
|
||||
self.query_lock = asyncio.Lock()
|
||||
self.handshake_done = False
|
||||
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
||||
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
||||
self._kasa_setup_auth_hash = None
|
||||
self._blank_auth_hash = None
|
||||
self._handshake_lock = asyncio.Lock()
|
||||
self._query_lock = asyncio.Lock()
|
||||
self._handshake_done = False
|
||||
|
||||
self.encryption_session: Optional[KlapEncryptionSession] = None
|
||||
self.session_expire_at: Optional[float] = None
|
||||
self._encryption_session: Optional[KlapEncryptionSession] = None
|
||||
self._session_expire_at: Optional[float] = None
|
||||
|
||||
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||
self.session_cookie = None
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||
self._session_cookie = None
|
||||
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||
|
||||
_LOGGER.debug("Created KLAP object for %s", self.host)
|
||||
|
||||
@ -125,15 +120,15 @@ class TPLinkKlap(TPLinkProtocol):
|
||||
"""Send an http post request to the device."""
|
||||
response_data = None
|
||||
cookies = None
|
||||
if self.session_cookie:
|
||||
if self._session_cookie:
|
||||
cookies = httpx.Cookies()
|
||||
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
|
||||
self.http_client.cookies.clear()
|
||||
resp = await self.http_client.post(
|
||||
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
|
||||
self._http_client.cookies.clear()
|
||||
resp = await self._http_client.post(
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
timeout=self.timeout,
|
||||
timeout=self._timeout,
|
||||
cookies=cookies,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
@ -183,44 +178,55 @@ class TPLinkKlap(TPLinkProtocol):
|
||||
server_hash.hex(),
|
||||
)
|
||||
|
||||
local_seed_auth_hash = _sha256(local_seed + self.local_auth_hash)
|
||||
local_seed_auth_hash = self.handshake1_seed_auth_hash(
|
||||
local_seed, remote_seed, self._local_auth_hash
|
||||
) # type: ignore
|
||||
|
||||
# Check the response from the device with local credentials
|
||||
if local_seed_auth_hash == server_hash:
|
||||
_LOGGER.debug("handshake1 hashes match with expected credentials")
|
||||
return local_seed, remote_seed, self.local_auth_hash # type: ignore
|
||||
return local_seed, remote_seed, self._local_auth_hash # type: ignore
|
||||
|
||||
# Now check against the default kasa setup credentials
|
||||
if not self.kasa_setup_auth_hash:
|
||||
if not self._kasa_setup_auth_hash:
|
||||
kasa_setup_creds = Credentials(
|
||||
username=TPLinkKlap.KASA_SETUP_EMAIL,
|
||||
password=TPLinkKlap.KASA_SETUP_PASSWORD,
|
||||
username=self.KASA_SETUP_EMAIL,
|
||||
password=self.KASA_SETUP_PASSWORD,
|
||||
)
|
||||
self.kasa_setup_auth_hash = TPLinkKlap.generate_auth_hash(kasa_setup_creds)
|
||||
self._kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds)
|
||||
|
||||
kasa_setup_seed_auth_hash = _sha256(
|
||||
local_seed + self.kasa_setup_auth_hash # type: ignore
|
||||
kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash(
|
||||
local_seed,
|
||||
remote_seed,
|
||||
self._kasa_setup_auth_hash, # type: ignore
|
||||
)
|
||||
|
||||
if kasa_setup_seed_auth_hash == server_hash:
|
||||
_LOGGER.debug(
|
||||
"Server response doesn't match our expected hash on ip %s"
|
||||
+ " but an authentication with kasa setup credentials matched",
|
||||
self.host,
|
||||
)
|
||||
return local_seed, remote_seed, self.kasa_setup_auth_hash # type: ignore
|
||||
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
|
||||
|
||||
# Finally check against blank credentials if not already blank
|
||||
if self.credentials != (blank_creds := Credentials(username="", password="")):
|
||||
if not self.blank_auth_hash:
|
||||
self.blank_auth_hash = TPLinkKlap.generate_auth_hash(blank_creds)
|
||||
blank_seed_auth_hash = _sha256(local_seed + self.blank_auth_hash) # type: ignore
|
||||
if self._credentials != (blank_creds := Credentials(username="", password="")):
|
||||
if not self._blank_auth_hash:
|
||||
self._blank_auth_hash = self.generate_auth_hash(blank_creds)
|
||||
|
||||
blank_seed_auth_hash = self.handshake1_seed_auth_hash(
|
||||
local_seed,
|
||||
remote_seed,
|
||||
self._blank_auth_hash, # type: ignore
|
||||
)
|
||||
|
||||
if blank_seed_auth_hash == server_hash:
|
||||
_LOGGER.debug(
|
||||
"Server response doesn't match our expected hash on ip %s"
|
||||
+ " but an authentication with blank credentials matched",
|
||||
self.host,
|
||||
)
|
||||
return local_seed, remote_seed, self.blank_auth_hash # type: ignore
|
||||
return local_seed, remote_seed, self._blank_auth_hash # type: ignore
|
||||
|
||||
msg = f"Server response doesn't match our challenge on ip {self.host}"
|
||||
_LOGGER.debug(msg)
|
||||
@ -235,7 +241,7 @@ class TPLinkKlap(TPLinkProtocol):
|
||||
|
||||
url = f"http://{self.host}/app/handshake2"
|
||||
|
||||
payload = _sha256(remote_seed + auth_hash)
|
||||
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
||||
|
||||
response_status, response_data = await self.client_post(url, data=payload)
|
||||
|
||||
@ -256,115 +262,70 @@ class TPLinkKlap(TPLinkProtocol):
|
||||
|
||||
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
|
||||
|
||||
@property
|
||||
def needs_login(self) -> bool:
|
||||
"""Will return false as KLAP does not do a login."""
|
||||
return False
|
||||
|
||||
async def login(self, request: str) -> None:
|
||||
"""Will raise and exception as KLAP does not do a login."""
|
||||
raise SmartDeviceException(
|
||||
"KLAP does not perform logins and return needs_login == False"
|
||||
)
|
||||
|
||||
@property
|
||||
def needs_handshake(self) -> bool:
|
||||
"""Return true if the transport needs to do a handshake."""
|
||||
return not self._handshake_done or self._handshake_session_expired()
|
||||
|
||||
async def handshake(self) -> None:
|
||||
"""Perform the encryption handshake."""
|
||||
await self.perform_handshake()
|
||||
|
||||
async def perform_handshake(self) -> Any:
|
||||
"""Perform handshake1 and handshake2.
|
||||
|
||||
Sets the encryption_session if successful.
|
||||
"""
|
||||
_LOGGER.debug("Starting handshake with %s", self.host)
|
||||
self.handshake_done = False
|
||||
self.session_expire_at = None
|
||||
self.session_cookie = None
|
||||
self._handshake_done = False
|
||||
self._session_expire_at = None
|
||||
self._session_cookie = None
|
||||
|
||||
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
|
||||
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
||||
TPLinkKlap.SESSION_COOKIE_NAME
|
||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
||||
self.SESSION_COOKIE_NAME
|
||||
)
|
||||
# The device returns a TIMEOUT cookie on handshake1 which
|
||||
# it doesn't like to get back so we store the one we want
|
||||
|
||||
self.session_expire_at = time.time() + 86400
|
||||
self.encryption_session = await self.perform_handshake2(
|
||||
self._session_expire_at = time.time() + 86400
|
||||
self._encryption_session = await self.perform_handshake2(
|
||||
local_seed, remote_seed, auth_hash
|
||||
)
|
||||
self.handshake_done = True
|
||||
self._handshake_done = True
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||
|
||||
def handshake_session_expired(self):
|
||||
def _handshake_session_expired(self):
|
||||
"""Return true if session has expired."""
|
||||
return (
|
||||
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
|
||||
self._session_expire_at is None
|
||||
or self._session_expire_at - time.time() <= 0
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials):
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
un = creds.username or ""
|
||||
pw = creds.password or ""
|
||||
return _md5(_md5(un.encode()) + _md5(pw.encode()))
|
||||
|
||||
@staticmethod
|
||||
def generate_owner_hash(creds: Credentials):
|
||||
"""Return the MD5 hash of the username in this object."""
|
||||
un = creds.username or ""
|
||||
return _md5(un.encode())
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Query the device retrying for retry_count on failure."""
|
||||
if isinstance(request, dict):
|
||||
request = json_dumps(request)
|
||||
assert isinstance(request, str) # noqa: S101
|
||||
|
||||
async with self.query_lock:
|
||||
return await self._query(request, retry_count)
|
||||
|
||||
async def _query(self, request: str, retry_count: int = 3) -> Dict:
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except httpx.CloseError as sdex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
async def send(self, request: str):
|
||||
"""Send the request."""
|
||||
if self.needs_handshake:
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except httpx.ConnectError as cex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {cex}"
|
||||
) from cex
|
||||
except TimeoutError as tex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
||||
) from tex
|
||||
except AuthenticationException as auex:
|
||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||
raise auex
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
async def _execute_query(self, request: str, retry_count: int) -> Dict:
|
||||
if not self.http_client:
|
||||
self.http_client = httpx.AsyncClient()
|
||||
|
||||
if not self.handshake_done or self.handshake_session_expired():
|
||||
try:
|
||||
await self.perform_handshake()
|
||||
|
||||
except AuthenticationException as auex:
|
||||
_LOGGER.debug(
|
||||
"Unable to complete handshake for device %s, "
|
||||
+ "authentication failed",
|
||||
self.host,
|
||||
"Handshake must be complete before trying to send"
|
||||
)
|
||||
raise auex
|
||||
if self.needs_login:
|
||||
raise SmartDeviceException("Login must be complete before trying to send")
|
||||
|
||||
# Check for mypy
|
||||
if self.encryption_session is not None:
|
||||
payload, seq = self.encryption_session.encrypt(request.encode())
|
||||
if self._encryption_session is not None:
|
||||
payload, seq = self._encryption_session.encrypt(request.encode())
|
||||
|
||||
url = f"http://{self.host}/app/request"
|
||||
|
||||
@ -376,14 +337,14 @@ class TPLinkKlap(TPLinkProtocol):
|
||||
|
||||
msg = (
|
||||
f"at {datetime.datetime.now()}. Host is {self.host}, "
|
||||
+ f"Retry count is {retry_count}, Sequence is {seq}, "
|
||||
+ f"Sequence is {seq}, "
|
||||
+ f"Response status is {response_status}, Request was {request}"
|
||||
)
|
||||
if response_status != 200:
|
||||
_LOGGER.error("Query failed after succesful authentication " + msg)
|
||||
# If we failed with a security error, force a new handshake next time.
|
||||
if response_status == 403:
|
||||
self.handshake_done = False
|
||||
self._handshake_done = False
|
||||
raise AuthenticationException(
|
||||
f"Got a security error from {self.host} after handshake "
|
||||
+ "completed"
|
||||
@ -397,8 +358,8 @@ class TPLinkKlap(TPLinkProtocol):
|
||||
_LOGGER.debug("Query posted " + msg)
|
||||
|
||||
# Check for mypy
|
||||
if self.encryption_session is not None:
|
||||
decrypted_response = self.encryption_session.decrypt(response_data)
|
||||
if self._encryption_session is not None:
|
||||
decrypted_response = self._encryption_session.decrypt(response_data)
|
||||
|
||||
json_payload = json_loads(decrypted_response)
|
||||
|
||||
@ -411,12 +372,66 @@ class TPLinkKlap(TPLinkProtocol):
|
||||
return json_payload
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
client = self.http_client
|
||||
self.http_client = None
|
||||
"""Close the transport."""
|
||||
client = self._http_client
|
||||
self._http_client = None
|
||||
if client:
|
||||
await client.aclose()
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials):
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
un = creds.username or ""
|
||||
pw = creds.password or ""
|
||||
|
||||
return md5(md5(un.encode()) + md5(pw.encode()))
|
||||
|
||||
@staticmethod
|
||||
def handshake1_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
):
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(local_seed + auth_hash)
|
||||
|
||||
@staticmethod
|
||||
def handshake2_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
):
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(remote_seed + auth_hash)
|
||||
|
||||
@staticmethod
|
||||
def generate_owner_hash(creds: Credentials):
|
||||
"""Return the MD5 hash of the username in this object."""
|
||||
un = creds.username or ""
|
||||
return md5(un.encode())
|
||||
|
||||
|
||||
class TPlinkKlapTransportV2(KlapTransport):
|
||||
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials):
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
un = creds.username or ""
|
||||
pw = creds.password or ""
|
||||
|
||||
return _sha256(_sha1(un.encode()) + _sha1(pw.encode()))
|
||||
|
||||
@staticmethod
|
||||
def handshake1_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
):
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(local_seed + remote_seed + auth_hash)
|
||||
|
||||
@staticmethod
|
||||
def handshake2_seed_auth_hash(
|
||||
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
|
||||
):
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
return _sha256(remote_seed + local_seed + auth_hash)
|
||||
|
||||
|
||||
class KlapEncryptionSession:
|
||||
"""Class to represent an encryption session and it's internal state.
|
@ -22,6 +22,7 @@ from typing import Dict, Generator, Optional, Union
|
||||
# When support for cpython older than 3.11 is dropped
|
||||
# async_timeout can be replaced with asyncio.timeout
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import SmartDeviceException
|
||||
@ -32,6 +33,56 @@ _LOGGER = logging.getLogger(__name__)
|
||||
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
|
||||
|
||||
|
||||
def md5(payload: bytes) -> bytes:
|
||||
"""Return an md5 hash of the payload."""
|
||||
digest = hashes.Hash(hashes.MD5()) # noqa: S303
|
||||
digest.update(payload)
|
||||
hash = digest.finalize()
|
||||
return hash
|
||||
|
||||
|
||||
class BaseTransport(ABC):
|
||||
"""Base class for all TP-Link protocol transports."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.credentials = credentials
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def needs_handshake(self) -> bool:
|
||||
"""Return true if the transport needs to do a handshake."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def needs_login(self) -> bool:
|
||||
"""Return true if the transport needs to do a login."""
|
||||
|
||||
@abstractmethod
|
||||
async def login(self, request: str) -> None:
|
||||
"""Login to the device."""
|
||||
|
||||
@abstractmethod
|
||||
async def handshake(self) -> None:
|
||||
"""Perform the encryption handshake."""
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, request: str) -> Dict:
|
||||
"""Send a message to the device and return a response."""
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close the transport. Abstract method to be overriden."""
|
||||
|
||||
|
||||
class TPLinkProtocol(ABC):
|
||||
"""Base class for all TP-Link Smart Home communication."""
|
||||
|
||||
@ -41,6 +92,7 @@ class TPLinkProtocol(ABC):
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
transport: Optional[BaseTransport] = None,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
self.host = host
|
||||
|
@ -365,6 +365,7 @@ class SmartDevice:
|
||||
|
||||
def update_from_discover_info(self, info: Dict[str, Any]) -> None:
|
||||
"""Update state from info from the discover call."""
|
||||
self._discovery_info = info
|
||||
if "system" in info and (sys_info := info["system"].get("get_sysinfo")):
|
||||
self._last_update = info
|
||||
self._set_sys_info(sys_info)
|
||||
@ -372,7 +373,6 @@ class SmartDevice:
|
||||
# This allows setting of some info properties directly
|
||||
# from partial discovery info that will then be found
|
||||
# by the requires_update decorator
|
||||
self._discovery_info = info
|
||||
self._set_sys_info(info)
|
||||
|
||||
def _set_sys_info(self, sys_info: Dict[str, Any]) -> None:
|
||||
|
219
kasa/smartprotocol.py
Normal file
219
kasa/smartprotocol.py
Normal file
@ -0,0 +1,219 @@
|
||||
"""Implementation of the TP-Link AES Protocol.
|
||||
|
||||
Based on the work of https://github.com/petretiandrea/plugp100
|
||||
under compatible GNU GPL3 license.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from .aestransport import AesTransport
|
||||
from .credentials import Credentials
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .protocol import BaseTransport, TPLinkProtocol, md5
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
|
||||
class SmartProtocol(TPLinkProtocol):
|
||||
"""Class for the new TPLink SMART protocol."""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
transport: Optional[BaseTransport] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||
|
||||
self._credentials: Credentials = credentials or Credentials(
|
||||
username="", password=""
|
||||
)
|
||||
self._transport: BaseTransport = transport or AesTransport(
|
||||
host, credentials=self._credentials, timeout=timeout
|
||||
)
|
||||
self._terminal_uuid: Optional[str] = None
|
||||
self._request_id_generator = SnowflakeId(1, 1)
|
||||
self._query_lock = asyncio.Lock()
|
||||
|
||||
def get_smart_request(self, method, params=None) -> str:
|
||||
"""Get a request message as a string."""
|
||||
request = {
|
||||
"method": method,
|
||||
"params": params,
|
||||
"requestID": self._request_id_generator.generate_id(),
|
||||
"request_time_milis": round(time.time() * 1000),
|
||||
"terminal_uuid": self._terminal_uuid,
|
||||
}
|
||||
return json_dumps(request)
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Query the device retrying for retry_count on failure."""
|
||||
async with self._query_lock:
|
||||
resp_dict = await self._query(request, retry_count)
|
||||
if "result" in resp_dict:
|
||||
return resp_dict["result"]
|
||||
return {}
|
||||
|
||||
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except httpx.CloseError as sdex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except httpx.ConnectError as cex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {cex}"
|
||||
) from cex
|
||||
except TimeoutError as tex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
||||
) from tex
|
||||
except AuthenticationException as auex:
|
||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||
raise auex
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
|
||||
if isinstance(request, dict):
|
||||
smart_method = next(iter(request))
|
||||
smart_params = request[smart_method]
|
||||
else:
|
||||
smart_method = request
|
||||
smart_params = None
|
||||
|
||||
if self._transport.needs_handshake:
|
||||
await self._transport.handshake()
|
||||
|
||||
if self._transport.needs_login:
|
||||
self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
|
||||
"UTF-8"
|
||||
)
|
||||
login_request = self.get_smart_request("login_device")
|
||||
await self._transport.login(login_request)
|
||||
|
||||
smart_request = self.get_smart_request(smart_method, smart_params)
|
||||
response_data = await self._transport.send(smart_request)
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s << %s",
|
||||
self.host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
await self._transport.close()
|
||||
|
||||
|
||||
class SnowflakeId:
|
||||
"""Class for generating snowflake ids."""
|
||||
|
||||
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
|
||||
WORKER_ID_BITS = 5
|
||||
DATA_CENTER_ID_BITS = 5
|
||||
SEQUENCE_BITS = 12
|
||||
|
||||
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
|
||||
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
|
||||
|
||||
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
|
||||
|
||||
def __init__(self, worker_id, data_center_id):
|
||||
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
|
||||
raise ValueError(
|
||||
"Worker ID can't be greater than "
|
||||
+ str(SnowflakeId.MAX_WORKER_ID)
|
||||
+ " or less than 0"
|
||||
)
|
||||
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
|
||||
raise ValueError(
|
||||
"Data center ID can't be greater than "
|
||||
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
|
||||
+ " or less than 0"
|
||||
)
|
||||
|
||||
self.worker_id = worker_id
|
||||
self.data_center_id = data_center_id
|
||||
self.sequence = 0
|
||||
self.last_timestamp = -1
|
||||
|
||||
def generate_id(self):
|
||||
"""Generate a snowflake id."""
|
||||
timestamp = self._current_millis()
|
||||
|
||||
if timestamp < self.last_timestamp:
|
||||
raise ValueError("Clock moved backwards. Refusing to generate ID.")
|
||||
|
||||
if timestamp == self.last_timestamp:
|
||||
# Within the same millisecond, increment the sequence number
|
||||
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
|
||||
if self.sequence == 0:
|
||||
# Sequence exceeds its bit range, wait until the next millisecond
|
||||
timestamp = self._wait_next_millis(self.last_timestamp)
|
||||
else:
|
||||
# New millisecond, reset the sequence number
|
||||
self.sequence = 0
|
||||
|
||||
# Update the last timestamp
|
||||
self.last_timestamp = timestamp
|
||||
|
||||
# Generate and return the final ID
|
||||
return (
|
||||
(
|
||||
(timestamp - SnowflakeId.EPOCH)
|
||||
<< (
|
||||
SnowflakeId.WORKER_ID_BITS
|
||||
+ SnowflakeId.SEQUENCE_BITS
|
||||
+ SnowflakeId.DATA_CENTER_ID_BITS
|
||||
)
|
||||
)
|
||||
| (
|
||||
self.data_center_id
|
||||
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
|
||||
)
|
||||
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
|
||||
| self.sequence
|
||||
)
|
||||
|
||||
def _current_millis(self):
|
||||
return round(time.time() * 1000)
|
||||
|
||||
def _wait_next_millis(self, last_timestamp):
|
||||
timestamp = self._current_millis()
|
||||
while timestamp <= last_timestamp:
|
||||
timestamp = self._current_millis()
|
||||
return timestamp
|
@ -4,10 +4,10 @@ import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Set, cast
|
||||
|
||||
from ..aesprotocol import TPLinkAes
|
||||
from ..credentials import Credentials
|
||||
from ..exceptions import AuthenticationException
|
||||
from ..smartdevice import SmartDevice
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -26,7 +26,7 @@ class TapoDevice(SmartDevice):
|
||||
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
|
||||
self._state_information: Dict[str, Any] = {}
|
||||
self._discovery_info: Optional[Dict[str, Any]] = None
|
||||
self.protocol = TPLinkAes(host, credentials=credentials, timeout=timeout)
|
||||
self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout)
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Update the device."""
|
||||
|
@ -2,27 +2,45 @@ import asyncio
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from json import dumps as json_dumps
|
||||
from os.path import basename
|
||||
from pathlib import Path, PurePath
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import (
|
||||
Credentials,
|
||||
Discover,
|
||||
SmartBulb,
|
||||
SmartDimmer,
|
||||
SmartLightStrip,
|
||||
SmartPlug,
|
||||
SmartStrip,
|
||||
TPLinkSmartHomeProtocol,
|
||||
)
|
||||
from kasa.tapo import TapoDevice, TapoPlug
|
||||
|
||||
from .newfakes import FakeTransportProtocol
|
||||
from .newfakes import FakeSmartProtocol, FakeTransportProtocol
|
||||
|
||||
SUPPORTED_DEVICES = glob.glob(
|
||||
SUPPORTED_IOT_DEVICES = [
|
||||
(device, "IOT")
|
||||
for device in glob.glob(
|
||||
os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json"
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
SUPPORTED_SMART_DEVICES = [
|
||||
(device, "SMART")
|
||||
for device in glob.glob(
|
||||
os.path.dirname(os.path.abspath(__file__)) + "/fixtures/smart/*.json"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
SUPPORTED_DEVICES = SUPPORTED_IOT_DEVICES + SUPPORTED_SMART_DEVICES
|
||||
|
||||
|
||||
LIGHT_STRIPS = {"KL400", "KL430", "KL420"}
|
||||
@ -55,43 +73,59 @@ PLUGS = {
|
||||
"KP401",
|
||||
"KS200M",
|
||||
}
|
||||
|
||||
STRIPS = {"HS107", "HS300", "KP303", "KP200", "KP400", "EP40"}
|
||||
DIMMERS = {"ES20M", "HS220", "KS220M", "KS230", "KP405"}
|
||||
|
||||
DIMMABLE = {*BULBS, *DIMMERS}
|
||||
WITH_EMETER = {"HS110", "HS300", "KP115", "KP125", *BULBS}
|
||||
|
||||
ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
|
||||
ALL_DEVICES_IOT = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
|
||||
|
||||
PLUGS_SMART = {"P110"}
|
||||
ALL_DEVICES_SMART = PLUGS_SMART
|
||||
|
||||
ALL_DEVICES = ALL_DEVICES_IOT.union(ALL_DEVICES_SMART)
|
||||
|
||||
IP_MODEL_CACHE: Dict[str, str] = {}
|
||||
|
||||
|
||||
def filter_model(desc, filter):
|
||||
filtered = list()
|
||||
for dev in SUPPORTED_DEVICES:
|
||||
for filt in filter:
|
||||
if filt in basename(dev):
|
||||
filtered.append(dev)
|
||||
def idgenerator(paramtuple):
|
||||
return basename(paramtuple[0]) + (
|
||||
"" if paramtuple[1] == "IOT" else "-" + paramtuple[1]
|
||||
)
|
||||
|
||||
filtered_basenames = [basename(f) for f in filtered]
|
||||
|
||||
def filter_model(desc, model_filter, protocol_filter=None):
|
||||
if not protocol_filter:
|
||||
protocol_filter = {"IOT"}
|
||||
filtered = list()
|
||||
for file, protocol in SUPPORTED_DEVICES:
|
||||
if protocol in protocol_filter:
|
||||
file_model = basename(file).split("_")[0]
|
||||
for model in model_filter:
|
||||
if model in file_model:
|
||||
filtered.append((file, protocol))
|
||||
|
||||
filtered_basenames = [basename(f) + "-" + p for f, p in filtered]
|
||||
print(f"{desc}: {filtered_basenames}")
|
||||
return filtered
|
||||
|
||||
|
||||
def parametrize(desc, devices, ids=None):
|
||||
def parametrize(desc, devices, protocol_filter=None, ids=None):
|
||||
return pytest.mark.parametrize(
|
||||
"dev", filter_model(desc, devices), indirect=True, ids=ids
|
||||
"dev", filter_model(desc, devices, protocol_filter), indirect=True, ids=ids
|
||||
)
|
||||
|
||||
|
||||
has_emeter = parametrize("has emeter", WITH_EMETER)
|
||||
no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER)
|
||||
no_emeter = parametrize("no emeter", ALL_DEVICES_IOT - WITH_EMETER)
|
||||
|
||||
bulb = parametrize("bulbs", BULBS, ids=basename)
|
||||
plug = parametrize("plugs", PLUGS, ids=basename)
|
||||
strip = parametrize("strips", STRIPS, ids=basename)
|
||||
dimmer = parametrize("dimmers", DIMMERS, ids=basename)
|
||||
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename)
|
||||
bulb = parametrize("bulbs", BULBS, ids=idgenerator)
|
||||
plug = parametrize("plugs", PLUGS, ids=idgenerator)
|
||||
strip = parametrize("strips", STRIPS, ids=idgenerator)
|
||||
dimmer = parametrize("dimmers", DIMMERS, ids=idgenerator)
|
||||
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=idgenerator)
|
||||
|
||||
# bulb types
|
||||
dimmable = parametrize("dimmable", DIMMABLE)
|
||||
@ -101,6 +135,58 @@ non_variable_temp = parametrize("non-variable color temp", BULBS - VARIABLE_TEMP
|
||||
color_bulb = parametrize("color bulbs", COLOR_BULBS)
|
||||
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)
|
||||
|
||||
plug_smart = parametrize(
|
||||
"plug devices smart", PLUGS_SMART, protocol_filter={"SMART"}, ids=idgenerator
|
||||
)
|
||||
device_smart = parametrize(
|
||||
"devices smart", ALL_DEVICES_SMART, protocol_filter={"SMART"}, ids=idgenerator
|
||||
)
|
||||
device_iot = parametrize(
|
||||
"devices iot", ALL_DEVICES_IOT, protocol_filter={"IOT"}, ids=idgenerator
|
||||
)
|
||||
|
||||
|
||||
def get_fixture_data():
|
||||
"""Return raw discovery file contents as JSON. Used for discovery tests."""
|
||||
fixture_data = {}
|
||||
for file, protocol in SUPPORTED_DEVICES:
|
||||
p = Path(file)
|
||||
if not p.is_absolute():
|
||||
folder = Path(__file__).parent / "fixtures"
|
||||
if protocol == "SMART":
|
||||
folder = folder / "smart"
|
||||
p = folder / file
|
||||
|
||||
with open(p) as f:
|
||||
fixture_data[basename(p)] = json.load(f)
|
||||
return fixture_data
|
||||
|
||||
|
||||
FIXTURE_DATA = get_fixture_data()
|
||||
|
||||
|
||||
def filter_fixtures(desc, root_filter):
|
||||
filtered = {}
|
||||
for key, val in FIXTURE_DATA.items():
|
||||
if root_filter in val:
|
||||
filtered[key] = val
|
||||
|
||||
print(f"{desc}: {filtered.keys()}")
|
||||
return filtered
|
||||
|
||||
|
||||
def parametrize_discovery(desc, root_key):
|
||||
filtered_fixtures = filter_fixtures(desc, root_key)
|
||||
return pytest.mark.parametrize(
|
||||
"discovery_data",
|
||||
filtered_fixtures.values(),
|
||||
indirect=True,
|
||||
ids=filtered_fixtures.keys(),
|
||||
)
|
||||
|
||||
|
||||
new_discovery = parametrize_discovery("new discovery", "discovery_result")
|
||||
|
||||
|
||||
def check_categories():
|
||||
"""Check that every fixture file is categorized."""
|
||||
@ -110,15 +196,15 @@ def check_categories():
|
||||
+ plug.args[1]
|
||||
+ bulb.args[1]
|
||||
+ lightstrip.args[1]
|
||||
+ plug_smart.args[1]
|
||||
)
|
||||
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
|
||||
if diff:
|
||||
for file in diff:
|
||||
for file, protocol in diff:
|
||||
print(
|
||||
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
|
||||
% file
|
||||
f"No category for file {file} protocol {protocol}, add to the corresponding set (BULBS, PLUGS, ..)"
|
||||
)
|
||||
raise Exception("Missing category for %s" % diff)
|
||||
raise Exception(f"Missing category for {diff}")
|
||||
|
||||
|
||||
check_categories()
|
||||
@ -134,7 +220,12 @@ async def handle_turn_on(dev, turn_on):
|
||||
await dev.turn_off()
|
||||
|
||||
|
||||
def device_for_file(model):
|
||||
def device_for_file(model, protocol):
|
||||
if protocol == "SMART":
|
||||
for d in PLUGS_SMART:
|
||||
if d in model:
|
||||
return TapoPlug
|
||||
else:
|
||||
for d in STRIPS:
|
||||
if d in model:
|
||||
return SmartStrip
|
||||
@ -170,11 +261,14 @@ async def _discover_update_and_close(ip):
|
||||
return await _update_and_close(d)
|
||||
|
||||
|
||||
async def get_device_for_file(file):
|
||||
async def get_device_for_file(file, protocol):
|
||||
# if the wanted file is not an absolute path, prepend the fixtures directory
|
||||
p = Path(file)
|
||||
if not p.is_absolute():
|
||||
p = Path(__file__).parent / "fixtures" / file
|
||||
folder = Path(__file__).parent / "fixtures"
|
||||
if protocol == "SMART":
|
||||
folder = folder / "smart"
|
||||
p = folder / file
|
||||
|
||||
def load_file():
|
||||
with open(p) as f:
|
||||
@ -184,7 +278,11 @@ async def get_device_for_file(file):
|
||||
sysinfo = await loop.run_in_executor(None, load_file)
|
||||
|
||||
model = basename(file)
|
||||
d = device_for_file(model)(host="127.0.0.123")
|
||||
d = device_for_file(model, protocol)(host="127.0.0.123")
|
||||
if protocol == "SMART":
|
||||
d.protocol = FakeSmartProtocol(sysinfo)
|
||||
d.credentials = Credentials("", "")
|
||||
else:
|
||||
d.protocol = FakeTransportProtocol(sysinfo)
|
||||
await _update_and_close(d)
|
||||
return d
|
||||
@ -197,7 +295,7 @@ async def dev(request):
|
||||
Provides a device (given --ip) or parametrized fixture for the supported devices.
|
||||
The initial update is called automatically before returning the device.
|
||||
"""
|
||||
file = request.param
|
||||
file, protocol = request.param
|
||||
|
||||
ip = request.config.getoption("--ip")
|
||||
if ip:
|
||||
@ -210,19 +308,62 @@ async def dev(request):
|
||||
pytest.skip(f"skipping file {file}")
|
||||
return d if d else await _discover_update_and_close(ip)
|
||||
|
||||
return await get_device_for_file(file)
|
||||
return await get_device_for_file(file, protocol)
|
||||
|
||||
|
||||
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
|
||||
@pytest.fixture
|
||||
def discovery_mock(discovery_data, mocker):
|
||||
@dataclass
|
||||
class _DiscoveryMock:
|
||||
ip: str
|
||||
default_port: int
|
||||
discovery_data: dict
|
||||
port_override: Optional[int] = None
|
||||
|
||||
if "result" in discovery_data:
|
||||
datagram = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
+ json_dumps(discovery_data).encode()
|
||||
)
|
||||
dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data)
|
||||
else:
|
||||
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
|
||||
dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data)
|
||||
|
||||
def mock_discover(self):
|
||||
port = (
|
||||
dm.port_override
|
||||
if dm.port_override and dm.default_port != 20002
|
||||
else dm.default_port
|
||||
)
|
||||
self.datagram_received(
|
||||
datagram,
|
||||
(dm.ip, port),
|
||||
)
|
||||
|
||||
mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)
|
||||
mocker.patch(
|
||||
"socket.getaddrinfo",
|
||||
side_effect=lambda *_, **__: [(None, None, None, None, (dm.ip, 0))],
|
||||
)
|
||||
yield dm
|
||||
|
||||
|
||||
@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session")
|
||||
def discovery_data(request):
|
||||
"""Return raw discovery file contents as JSON. Used for discovery tests."""
|
||||
file = request.param
|
||||
p = Path(file)
|
||||
if not p.is_absolute():
|
||||
p = Path(__file__).parent / "fixtures" / file
|
||||
fixture_data = request.param
|
||||
if "discovery_result" in fixture_data:
|
||||
return {"result": fixture_data["discovery_result"]}
|
||||
else:
|
||||
return {"system": {"get_sysinfo": fixture_data["system"]["get_sysinfo"]}}
|
||||
|
||||
with open(p) as f:
|
||||
return json.load(f)
|
||||
|
||||
@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session")
|
||||
def all_fixture_data(request):
|
||||
"""Return raw fixture file contents as JSON. Used for discovery tests."""
|
||||
fixture_data = request.param
|
||||
return fixture_data
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
180
kasa/tests/fixtures/smart/P110_1.0_1.3.0.json
vendored
Normal file
180
kasa/tests/fixtures/smart/P110_1.0_1.3.0.json
vendored
Normal file
@ -0,0 +1,180 @@
|
||||
{
|
||||
"component_nego": {
|
||||
"component_list": [
|
||||
{
|
||||
"id": "device",
|
||||
"ver_code": 2
|
||||
},
|
||||
{
|
||||
"id": "firmware",
|
||||
"ver_code": 2
|
||||
},
|
||||
{
|
||||
"id": "quick_setup",
|
||||
"ver_code": 3
|
||||
},
|
||||
{
|
||||
"id": "time",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "wireless",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "schedule",
|
||||
"ver_code": 2
|
||||
},
|
||||
{
|
||||
"id": "countdown",
|
||||
"ver_code": 2
|
||||
},
|
||||
{
|
||||
"id": "antitheft",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "account",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "synchronize",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "sunrise_sunset",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "led",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "cloud_connect",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "iot_cloud",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "device_local_time",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "default_states",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "auto_off",
|
||||
"ver_code": 2
|
||||
},
|
||||
{
|
||||
"id": "localSmart",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "energy_monitoring",
|
||||
"ver_code": 2
|
||||
},
|
||||
{
|
||||
"id": "power_protection",
|
||||
"ver_code": 1
|
||||
},
|
||||
{
|
||||
"id": "current_protection",
|
||||
"ver_code": 1
|
||||
}
|
||||
]
|
||||
},
|
||||
"discovery_result": {
|
||||
"device_id": "00000000000000000000000000000000",
|
||||
"device_model": "P110(UK)",
|
||||
"device_type": "SMART.TAPOPLUG",
|
||||
"factory_default": false,
|
||||
"ip": "127.0.0.123",
|
||||
"is_support_iot_cloud": true,
|
||||
"mac": "00-00-00-00-00-00",
|
||||
"mgt_encrypt_schm": {
|
||||
"encrypt_type": "KLAP",
|
||||
"http_port": 80,
|
||||
"is_support_https": false,
|
||||
"lv": 2
|
||||
},
|
||||
"obd_src": "tplink",
|
||||
"owner": "00000000000000000000000000000000"
|
||||
},
|
||||
"get_current_power": {
|
||||
"current_power": 0
|
||||
},
|
||||
"get_device_info": {
|
||||
"auto_off_remain_time": 0,
|
||||
"auto_off_status": "off",
|
||||
"avatar": "plug",
|
||||
"default_states": {
|
||||
"state": {},
|
||||
"type": "last_states"
|
||||
},
|
||||
"device_id": "0000000000000000000000000000000000000000",
|
||||
"device_on": true,
|
||||
"fw_id": "00000000000000000000000000000000",
|
||||
"fw_ver": "1.3.0 Build 230905 Rel.152200",
|
||||
"has_set_location_info": true,
|
||||
"hw_id": "00000000000000000000000000000000",
|
||||
"hw_ver": "1.0",
|
||||
"ip": "127.0.0.123",
|
||||
"lang": "en_US",
|
||||
"latitude": 0,
|
||||
"longitude": 0,
|
||||
"mac": "00-00-00-00-00-00",
|
||||
"model": "P110",
|
||||
"nickname": "VGFwaSBTbWFydCBQbHVnIDE=",
|
||||
"oem_id": "00000000000000000000000000000000",
|
||||
"on_time": 119335,
|
||||
"overcurrent_status": "normal",
|
||||
"overheated": false,
|
||||
"power_protection_status": "normal",
|
||||
"region": "Europe/London",
|
||||
"rssi": -57,
|
||||
"signal_level": 2,
|
||||
"specs": "",
|
||||
"ssid": "IyNNQVNLRUROQU1FIyM=",
|
||||
"time_diff": 0,
|
||||
"type": "SMART.TAPOPLUG"
|
||||
},
|
||||
"get_device_time": {
|
||||
"region": "Europe/London",
|
||||
"time_diff": 0,
|
||||
"timestamp": 1701370224
|
||||
},
|
||||
"get_device_usage": {
|
||||
"power_usage": {
|
||||
"past30": 75,
|
||||
"past7": 69,
|
||||
"today": 0
|
||||
},
|
||||
"saved_power": {
|
||||
"past30": 2029,
|
||||
"past7": 1964,
|
||||
"today": 1130
|
||||
},
|
||||
"time_usage": {
|
||||
"past30": 2104,
|
||||
"past7": 2033,
|
||||
"today": 1130
|
||||
}
|
||||
},
|
||||
"get_energy_usage": {
|
||||
"current_power": 0,
|
||||
"electricity_charge": [
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
"local_time": "2023-11-30 18:50:24",
|
||||
"month_energy": 75,
|
||||
"month_runtime": 2104,
|
||||
"today_energy": 0,
|
||||
"today_runtime": 1130
|
||||
}
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
from json import loads as json_loads
|
||||
|
||||
from voluptuous import (
|
||||
REMOVE_EXTRA,
|
||||
@ -13,7 +14,8 @@ from voluptuous import (
|
||||
Schema,
|
||||
)
|
||||
|
||||
from ..protocol import TPLinkSmartHomeProtocol
|
||||
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -285,6 +287,41 @@ TIME_MODULE = {
|
||||
}
|
||||
|
||||
|
||||
class FakeSmartProtocol(SmartProtocol):
|
||||
def __init__(self, info):
|
||||
super().__init__("127.0.0.123", transport=FakeSmartTransport(info))
|
||||
|
||||
|
||||
class FakeSmartTransport(BaseTransport):
|
||||
def __init__(self, info):
|
||||
self.info = info
|
||||
|
||||
@property
|
||||
def needs_handshake(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def needs_login(self) -> bool:
|
||||
return False
|
||||
|
||||
async def login(self, request: str) -> None:
|
||||
pass
|
||||
|
||||
async def handshake(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, request: str):
|
||||
request_dict = json_loads(request)
|
||||
method = request_dict["method"]
|
||||
if method == "component_nego" or method[:4] == "get_":
|
||||
return self.info[method]
|
||||
elif method[:4] == "set_":
|
||||
_LOGGER.debug("Call %s not implemented, doing nothing", method)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
||||
def __init__(self, info):
|
||||
self.discovery_data = info
|
||||
|
@ -6,12 +6,15 @@ from asyncclick.testing import CliRunner
|
||||
|
||||
from kasa import SmartDevice, TPLinkSmartHomeProtocol
|
||||
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
|
||||
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
|
||||
from kasa.discover import Discover
|
||||
from kasa.smartprotocol import SmartProtocol
|
||||
|
||||
from .conftest import handle_turn_on, turn_on
|
||||
from .newfakes import FakeTransportProtocol
|
||||
from .conftest import device_iot, handle_turn_on, new_discovery, turn_on
|
||||
from .newfakes import FakeSmartProtocol, FakeTransportProtocol
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_sysinfo(dev):
|
||||
runner = CliRunner()
|
||||
res = await runner.invoke(sysinfo, obj=dev)
|
||||
@ -19,6 +22,7 @@ async def test_sysinfo(dev):
|
||||
assert dev.alias in res.output
|
||||
|
||||
|
||||
@device_iot
|
||||
@turn_on
|
||||
async def test_state(dev, turn_on):
|
||||
await handle_turn_on(dev, turn_on)
|
||||
@ -32,6 +36,7 @@ async def test_state(dev, turn_on):
|
||||
assert "Device state: False" in res.output
|
||||
|
||||
|
||||
@device_iot
|
||||
@turn_on
|
||||
async def test_toggle(dev, turn_on, mocker):
|
||||
await handle_turn_on(dev, turn_on)
|
||||
@ -44,6 +49,7 @@ async def test_toggle(dev, turn_on, mocker):
|
||||
assert dev.is_on
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_alias(dev):
|
||||
runner = CliRunner()
|
||||
|
||||
@ -62,6 +68,7 @@ async def test_alias(dev):
|
||||
await dev.set_alias(old_alias)
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_raw_command(dev):
|
||||
runner = CliRunner()
|
||||
res = await runner.invoke(raw_command, ["system", "get_sysinfo"], obj=dev)
|
||||
@ -74,6 +81,7 @@ async def test_raw_command(dev):
|
||||
assert "Usage" in res.output
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_emeter(dev: SmartDevice, mocker):
|
||||
runner = CliRunner()
|
||||
|
||||
@ -99,6 +107,7 @@ async def test_emeter(dev: SmartDevice, mocker):
|
||||
daily.assert_called_with(year=1900, month=12)
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_brightness(dev):
|
||||
runner = CliRunner()
|
||||
res = await runner.invoke(brightness, obj=dev)
|
||||
@ -116,6 +125,7 @@ async def test_brightness(dev):
|
||||
assert "Brightness: 12" in res.output
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_json_output(dev: SmartDevice, mocker):
|
||||
"""Test that the json output produces correct output."""
|
||||
mocker.patch("kasa.Discover.discover", return_value=[dev])
|
||||
@ -125,13 +135,9 @@ async def test_json_output(dev: SmartDevice, mocker):
|
||||
assert json.loads(res.output) == dev.internal_state
|
||||
|
||||
|
||||
async def test_credentials(discovery_data: dict, mocker):
|
||||
@new_discovery
|
||||
async def test_credentials(discovery_mock, mocker):
|
||||
"""Test credentials are passed correctly from cli to device."""
|
||||
# As this is testing the device constructor need to explicitly wire in
|
||||
# the FakeTransportProtocol
|
||||
ftp = FakeTransportProtocol(discovery_data)
|
||||
mocker.patch.object(TPLinkSmartHomeProtocol, "query", ftp.query)
|
||||
|
||||
# Patch state to echo username and password
|
||||
pass_dev = click.make_pass_decorator(SmartDevice)
|
||||
|
||||
@ -143,18 +149,15 @@ async def test_credentials(discovery_data: dict, mocker):
|
||||
)
|
||||
|
||||
mocker.patch("kasa.cli.state", new=_state)
|
||||
cli_device_type = Discover._get_device_class(discovery_data)(
|
||||
"any"
|
||||
).device_type.value
|
||||
for subclass in DEVICE_TYPE_TO_CLASS.values():
|
||||
mocker.patch.object(subclass, "update")
|
||||
|
||||
runner = CliRunner()
|
||||
res = await runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--type",
|
||||
cli_device_type,
|
||||
"127.0.0.123",
|
||||
"--username",
|
||||
"foo",
|
||||
"--password",
|
||||
@ -162,9 +165,11 @@ async def test_credentials(discovery_data: dict, mocker):
|
||||
],
|
||||
)
|
||||
assert res.exit_code == 0
|
||||
assert res.output == "Username:foo Password:bar\n"
|
||||
|
||||
assert "Username:foo Password:bar\n" in res.output
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_without_device_type(discovery_data: dict, dev, mocker):
|
||||
"""Test connecting without the device type."""
|
||||
runner = CliRunner()
|
||||
|
@ -5,7 +5,9 @@ from typing import Type
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import (
|
||||
Credentials,
|
||||
DeviceType,
|
||||
Discover,
|
||||
SmartBulb,
|
||||
SmartDevice,
|
||||
SmartDeviceException,
|
||||
@ -13,8 +15,13 @@ from kasa import (
|
||||
SmartLightStrip,
|
||||
SmartPlug,
|
||||
)
|
||||
from kasa.device_factory import connect
|
||||
from kasa.klapprotocol import TPLinkKlap
|
||||
from kasa.device_factory import (
|
||||
DEVICE_TYPE_TO_CLASS,
|
||||
connect,
|
||||
get_protocol_from_connection_name,
|
||||
)
|
||||
from kasa.discover import DiscoveryResult
|
||||
from kasa.iotprotocol import IotProtocol
|
||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
|
||||
|
||||
@ -22,8 +29,12 @@ from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
async def test_connect(discovery_data: dict, mocker, custom_port):
|
||||
"""Make sure that connect returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
dev = await connect(host, port=custom_port)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
dev = await connect(host, port=custom_port)
|
||||
assert issubclass(dev.__class__, SmartDevice)
|
||||
assert dev.port == custom_port or dev.port == 9999
|
||||
@ -49,8 +60,12 @@ async def test_connect_passed_device_type(
|
||||
):
|
||||
"""Make sure that connect with a passed device type."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
dev = await connect(host, port=custom_port)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
dev = await connect(host, port=custom_port, device_type=device_type)
|
||||
assert isinstance(dev, klass)
|
||||
assert dev.port == custom_port or dev.port == 9999
|
||||
@ -70,32 +85,52 @@ async def test_connect_logs_connect_time(
|
||||
):
|
||||
"""Test that the connect time is logged when debug logging is enabled."""
|
||||
host = "127.0.0.1"
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(host)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
||||
await connect(host)
|
||||
assert "seconds to connect" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device_type", [DeviceType.Plug, None])
|
||||
@pytest.mark.parametrize(
|
||||
("protocol_in", "protocol_result"),
|
||||
(
|
||||
(None, TPLinkSmartHomeProtocol),
|
||||
(TPLinkKlap, TPLinkKlap),
|
||||
(TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol),
|
||||
),
|
||||
)
|
||||
async def test_connect_pass_protocol(
|
||||
discovery_data: dict,
|
||||
all_fixture_data: dict,
|
||||
mocker,
|
||||
device_type: DeviceType,
|
||||
protocol_in: Type[TPLinkProtocol],
|
||||
protocol_result: Type[TPLinkProtocol],
|
||||
):
|
||||
"""Test that if the protocol is passed in it's gets set correctly."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
mocker.patch("kasa.TPLinkKlap.query", return_value=discovery_data)
|
||||
if "discovery_result" in all_fixture_data:
|
||||
discovery_info = {"result": all_fixture_data["discovery_result"]}
|
||||
device_class = Discover._get_device_class(discovery_info)
|
||||
else:
|
||||
device_class = Discover._get_device_class(all_fixture_data)
|
||||
|
||||
dev = await connect(host, device_type=device_type, protocol_class=protocol_in)
|
||||
assert isinstance(dev.protocol, protocol_result)
|
||||
device_type = list(DEVICE_TYPE_TO_CLASS.keys())[
|
||||
list(DEVICE_TYPE_TO_CLASS.values()).index(device_class)
|
||||
]
|
||||
host = "127.0.0.1"
|
||||
if "discovery_result" in all_fixture_data:
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
dr = DiscoveryResult(**discovery_info["result"])
|
||||
connection_name = (
|
||||
dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type
|
||||
)
|
||||
protocol_class = get_protocol_from_connection_name(
|
||||
connection_name, host
|
||||
).__class__
|
||||
else:
|
||||
mocker.patch(
|
||||
"kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data
|
||||
)
|
||||
protocol_class = TPLinkSmartHomeProtocol
|
||||
|
||||
dev = await connect(
|
||||
host,
|
||||
device_type=device_type,
|
||||
protocol_class=protocol_class,
|
||||
credentials=Credentials("", ""),
|
||||
)
|
||||
assert isinstance(dev.protocol, protocol_class)
|
||||
|
@ -17,6 +17,27 @@ from kasa.exceptions import AuthenticationException, UnsupportedDeviceException
|
||||
|
||||
from .conftest import bulb, dimmer, lightstrip, plug, strip
|
||||
|
||||
UNSUPPORTED = {
|
||||
"result": {
|
||||
"device_id": "xx",
|
||||
"owner": "xx",
|
||||
"device_type": "SMART.TAPOXMASTREE",
|
||||
"device_model": "P110(EU)",
|
||||
"ip": "127.0.0.1",
|
||||
"mac": "48-22xxx",
|
||||
"is_support_iot_cloud": True,
|
||||
"obd_src": "tplink",
|
||||
"factory_default": False,
|
||||
"mgt_encrypt_schm": {
|
||||
"is_support_https": False,
|
||||
"encrypt_type": "AES",
|
||||
"http_port": 80,
|
||||
"lv": 2,
|
||||
},
|
||||
},
|
||||
"error_code": 0,
|
||||
}
|
||||
|
||||
|
||||
@plug
|
||||
async def test_type_detection_plug(dev: SmartDevice):
|
||||
@ -62,76 +83,40 @@ async def test_type_unknown():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
async def test_discover_single(discovery_data: dict, mocker, custom_port):
|
||||
# @pytest.mark.parametrize("discovery_mock", [("127.0.0.1",123), ("127.0.0.1",None)], indirect=True)
|
||||
async def test_discover_single(discovery_mock, custom_port, mocker):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
|
||||
query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)
|
||||
|
||||
def mock_discover(self):
|
||||
self.datagram_received(
|
||||
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:],
|
||||
(host, custom_port or 9999),
|
||||
)
|
||||
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
||||
discovery_mock.ip = host
|
||||
discovery_mock.port_override = custom_port
|
||||
update_mock = mocker.patch.object(SmartStrip, "update")
|
||||
|
||||
x = await Discover.discover_single(host, port=custom_port)
|
||||
assert issubclass(x.__class__, SmartDevice)
|
||||
assert x._sys_info is not None
|
||||
assert x.port == custom_port or x.port == 9999
|
||||
assert (query_mock.call_count > 0) == isinstance(x, SmartStrip)
|
||||
assert x._discovery_info is not None
|
||||
assert x.port == custom_port or x.port == discovery_mock.default_port
|
||||
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
|
||||
|
||||
|
||||
async def test_discover_single_hostname(discovery_data: dict, mocker):
|
||||
async def test_discover_single_hostname(discovery_mock, mocker):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
host = "foobar"
|
||||
ip = "127.0.0.1"
|
||||
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
|
||||
query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)
|
||||
|
||||
def mock_discover(self):
|
||||
self.datagram_received(
|
||||
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:],
|
||||
(ip, 9999),
|
||||
)
|
||||
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
||||
mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))])
|
||||
discovery_mock.ip = ip
|
||||
update_mock = mocker.patch.object(SmartStrip, "update")
|
||||
|
||||
x = await Discover.discover_single(host)
|
||||
assert issubclass(x.__class__, SmartDevice)
|
||||
assert x._sys_info is not None
|
||||
assert x._discovery_info is not None
|
||||
assert x.host == host
|
||||
assert (query_mock.call_count > 0) == isinstance(x, SmartStrip)
|
||||
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
|
||||
|
||||
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
|
||||
with pytest.raises(SmartDeviceException):
|
||||
x = await Discover.discover_single(host)
|
||||
|
||||
|
||||
UNSUPPORTED = {
|
||||
"result": {
|
||||
"device_id": "xx",
|
||||
"owner": "xx",
|
||||
"device_type": "SMART.TAPOXMASTREE",
|
||||
"device_model": "P110(EU)",
|
||||
"ip": "127.0.0.1",
|
||||
"mac": "48-22xxx",
|
||||
"is_support_iot_cloud": True,
|
||||
"obd_src": "tplink",
|
||||
"factory_default": False,
|
||||
"mgt_encrypt_schm": {
|
||||
"is_support_https": False,
|
||||
"encrypt_type": "AES",
|
||||
"http_port": 80,
|
||||
"lv": 2,
|
||||
},
|
||||
},
|
||||
"error_code": 0,
|
||||
}
|
||||
|
||||
|
||||
async def test_discover_single_unsupported(mocker):
|
||||
"""Make sure that discover_single handles unsupported devices correctly."""
|
||||
host = "127.0.0.1"
|
||||
@ -201,14 +186,17 @@ async def test_discover_send(mocker):
|
||||
async def test_discover_datagram_received(mocker, discovery_data):
|
||||
"""Verify that datagram received fills discovered_devices."""
|
||||
proto = _DiscoverProtocol()
|
||||
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
|
||||
mocker.patch("kasa.discover.json_loads", return_value=info)
|
||||
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
|
||||
|
||||
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
|
||||
|
||||
addr = "127.0.0.1"
|
||||
proto.datagram_received("<placeholder data>", (addr, 9999))
|
||||
port = 20002 if "result" in discovery_data else 9999
|
||||
|
||||
mocker.patch("kasa.discover.json_loads", return_value=discovery_data)
|
||||
proto.datagram_received("<placeholder data>", (addr, port))
|
||||
|
||||
addr2 = "127.0.0.2"
|
||||
mocker.patch("kasa.discover.json_loads", return_value=UNSUPPORTED)
|
||||
proto.datagram_received("<placeholder data>", (addr2, 20002))
|
||||
|
||||
# Check that device in discovered_devices is initialized correctly
|
||||
|
@ -10,9 +10,14 @@ from contextlib import nullcontext as does_not_raise
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..exceptions import AuthenticationException, SmartDeviceException
|
||||
from ..klapprotocol import KlapEncryptionSession, TPLinkKlap, _sha256
|
||||
from ..iotprotocol import IotProtocol
|
||||
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||
|
||||
|
||||
class _mock_response:
|
||||
@ -21,67 +26,92 @@ class _mock_response:
|
||||
self.content = content
|
||||
|
||||
|
||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||
async def test_protocol_retries(mocker, retry_count):
|
||||
async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class):
|
||||
host = "127.0.0.1"
|
||||
conn = mocker.patch.object(
|
||||
TPLinkKlap, "client_post", side_effect=Exception("dummy exception")
|
||||
transport_class, "client_post", side_effect=Exception("dummy exception")
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkKlap("127.0.0.1").query({}, retry_count=retry_count)
|
||||
await protocol_class(host, transport=transport_class(host)).query(
|
||||
DUMMY_QUERY, retry_count=retry_count
|
||||
)
|
||||
|
||||
assert conn.call_count == retry_count + 1
|
||||
|
||||
|
||||
async def test_protocol_no_retry_on_connection_error(mocker):
|
||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||
async def test_protocol_no_retry_on_connection_error(
|
||||
mocker, protocol_class, transport_class
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
conn = mocker.patch.object(
|
||||
TPLinkKlap,
|
||||
transport_class,
|
||||
"client_post",
|
||||
side_effect=httpx.ConnectError("foo"),
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkKlap("127.0.0.1").query({}, retry_count=5)
|
||||
await protocol_class(host, transport=transport_class(host)).query(
|
||||
DUMMY_QUERY, retry_count=5
|
||||
)
|
||||
|
||||
assert conn.call_count == 1
|
||||
|
||||
|
||||
async def test_protocol_retry_recoverable_error(mocker):
|
||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||
async def test_protocol_retry_recoverable_error(
|
||||
mocker, protocol_class, transport_class
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
conn = mocker.patch.object(
|
||||
TPLinkKlap,
|
||||
transport_class,
|
||||
"client_post",
|
||||
side_effect=httpx.CloseError("foo"),
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkKlap("127.0.0.1").query({}, retry_count=5)
|
||||
await protocol_class(host, transport=transport_class(host)).query(
|
||||
DUMMY_QUERY, retry_count=5
|
||||
)
|
||||
|
||||
assert conn.call_count == 6
|
||||
|
||||
|
||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||
async def test_protocol_reconnect(mocker, retry_count):
|
||||
async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
|
||||
host = "127.0.0.1"
|
||||
remaining = retry_count
|
||||
mock_response = {"result": {"great": "success"}}
|
||||
|
||||
def _fail_one_less_than_retry_count(*_, **__):
|
||||
nonlocal remaining, encryption_session
|
||||
nonlocal remaining
|
||||
remaining -= 1
|
||||
if remaining:
|
||||
raise Exception("Simulated post failure")
|
||||
# Do the encrypt just before returning the value so the incrementing sequence number is correct
|
||||
encrypted, seq = encryption_session.encrypt('{"great":"success"}')
|
||||
return 200, encrypted
|
||||
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
protocol = TPLinkKlap("127.0.0.1")
|
||||
protocol.handshake_done = True
|
||||
protocol.session_expire_at = time.time() + 86400
|
||||
protocol.encryption_session = encryption_session
|
||||
return mock_response
|
||||
|
||||
mocker.patch.object(
|
||||
TPLinkKlap, "client_post", side_effect=_fail_one_less_than_retry_count
|
||||
transport_class, "needs_handshake", property(lambda self: False)
|
||||
)
|
||||
mocker.patch.object(transport_class, "needs_login", property(lambda self: False))
|
||||
|
||||
send_mock = mocker.patch.object(
|
||||
transport_class,
|
||||
"send",
|
||||
side_effect=_fail_one_less_than_retry_count,
|
||||
)
|
||||
|
||||
response = await protocol.query({}, retry_count=retry_count)
|
||||
assert response == {"great": "success"}
|
||||
response = await protocol_class(host, transport=transport_class(host)).query(
|
||||
DUMMY_QUERY, retry_count=retry_count
|
||||
)
|
||||
assert "result" in response or "great" in response
|
||||
assert send_mock.call_count == retry_count
|
||||
|
||||
|
||||
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
|
||||
@ -96,14 +126,14 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
||||
return 200, encrypted
|
||||
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
||||
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
protocol = TPLinkKlap("127.0.0.1")
|
||||
protocol = IotProtocol("127.0.0.1")
|
||||
|
||||
protocol.handshake_done = True
|
||||
protocol.session_expire_at = time.time() + 86400
|
||||
protocol.encryption_session = encryption_session
|
||||
mocker.patch.object(TPLinkKlap, "client_post", side_effect=_return_encrypted)
|
||||
protocol._transport._handshake_done = True
|
||||
protocol._transport._session_expire_at = time.time() + 86400
|
||||
protocol._transport._encryption_session = encryption_session
|
||||
mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted)
|
||||
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
@ -117,7 +147,7 @@ def test_encrypt():
|
||||
d = json.dumps({"foo": 1, "bar": 2})
|
||||
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
||||
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
|
||||
encrypted, seq = encryption_session.encrypt(d)
|
||||
@ -129,7 +159,7 @@ def test_encrypt_unicode():
|
||||
d = "{'snowman': '\u2603'}"
|
||||
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
||||
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
|
||||
encrypted, seq = encryption_session.encrypt(d)
|
||||
@ -145,7 +175,10 @@ def test_encrypt_unicode():
|
||||
(Credentials("foo", "bar"), does_not_raise()),
|
||||
(Credentials("", ""), does_not_raise()),
|
||||
(
|
||||
Credentials(TPLinkKlap.KASA_SETUP_EMAIL, TPLinkKlap.KASA_SETUP_PASSWORD),
|
||||
Credentials(
|
||||
KlapTransport.KASA_SETUP_EMAIL,
|
||||
KlapTransport.KASA_SETUP_PASSWORD,
|
||||
),
|
||||
does_not_raise(),
|
||||
),
|
||||
(
|
||||
@ -167,21 +200,21 @@ async def test_handshake1(mocker, device_credentials, expectation):
|
||||
client_seed = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = TPLinkKlap.generate_auth_hash(device_credentials)
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(device_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
||||
)
|
||||
|
||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
|
||||
protocol.http_client = httpx.AsyncClient()
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
with expectation:
|
||||
(
|
||||
local_seed,
|
||||
device_remote_seed,
|
||||
auth_hash,
|
||||
) = await protocol.perform_handshake1()
|
||||
) = await protocol._transport.perform_handshake1()
|
||||
|
||||
assert local_seed == client_seed
|
||||
assert device_remote_seed == server_seed
|
||||
@ -204,23 +237,23 @@ async def test_handshake(mocker):
|
||||
client_seed = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
||||
)
|
||||
|
||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
||||
protocol.http_client = httpx.AsyncClient()
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
|
||||
response_status = 200
|
||||
await protocol.perform_handshake()
|
||||
assert protocol.handshake_done is True
|
||||
await protocol._transport.perform_handshake()
|
||||
assert protocol._transport._handshake_done is True
|
||||
|
||||
response_status = 403
|
||||
with pytest.raises(AuthenticationException):
|
||||
await protocol.perform_handshake()
|
||||
assert protocol.handshake_done is False
|
||||
await protocol._transport.perform_handshake()
|
||||
assert protocol._transport._handshake_done is False
|
||||
await protocol.close()
|
||||
|
||||
|
||||
@ -237,9 +270,9 @@ async def test_query(mocker):
|
||||
return _mock_response(200, b"")
|
||||
elif url == "http://127.0.0.1/app/request":
|
||||
encryption_session = KlapEncryptionSession(
|
||||
protocol.encryption_session.local_seed,
|
||||
protocol.encryption_session.remote_seed,
|
||||
protocol.encryption_session.user_hash,
|
||||
protocol._transport._encryption_session.local_seed,
|
||||
protocol._transport._encryption_session.remote_seed,
|
||||
protocol._transport._encryption_session.user_hash,
|
||||
)
|
||||
seq = params.get("seq")
|
||||
encryption_session._seq = seq - 1
|
||||
@ -252,11 +285,11 @@ async def test_query(mocker):
|
||||
seq = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
|
||||
for _ in range(10):
|
||||
resp = await protocol.query({})
|
||||
@ -296,11 +329,11 @@ async def test_authentication_failures(mocker, response_status, expectation):
|
||||
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
|
||||
with expectation:
|
||||
await protocol.query({})
|
||||
|
@ -1,6 +1,6 @@
|
||||
from kasa import DeviceType
|
||||
|
||||
from .conftest import plug
|
||||
from .conftest import plug, plug_smart
|
||||
from .newfakes import PLUG_SCHEMA
|
||||
|
||||
|
||||
@ -28,3 +28,14 @@ async def test_led(dev):
|
||||
assert dev.led
|
||||
|
||||
await dev.set_led(original)
|
||||
|
||||
|
||||
@plug_smart
|
||||
async def test_plug_device_info(dev):
|
||||
assert dev._info is not None
|
||||
# PLUG_SCHEMA(dev.sys_info)
|
||||
|
||||
assert dev.model is not None
|
||||
|
||||
assert dev.device_type == DeviceType.Plug or dev.device_type == DeviceType.Strip
|
||||
# assert dev.is_plug or dev.is_strip
|
||||
|
@ -9,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file
|
||||
|
||||
def test_bulb_examples(mocker):
|
||||
"""Use KL130 (bulb with all features) to test the doctests."""
|
||||
p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json"))
|
||||
p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json", "IOT"))
|
||||
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
|
||||
mocker.patch("kasa.smartbulb.SmartBulb.update")
|
||||
res = xdoctest.doctest_module("kasa.smartbulb", "all")
|
||||
@ -18,7 +18,7 @@ def test_bulb_examples(mocker):
|
||||
|
||||
def test_smartdevice_examples(mocker):
|
||||
"""Use HS110 for emeter examples."""
|
||||
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json"))
|
||||
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT"))
|
||||
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
|
||||
mocker.patch("kasa.smartdevice.SmartDevice.update")
|
||||
res = xdoctest.doctest_module("kasa.smartdevice", "all")
|
||||
@ -27,7 +27,7 @@ def test_smartdevice_examples(mocker):
|
||||
|
||||
def test_plug_examples(mocker):
|
||||
"""Test plug examples."""
|
||||
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json"))
|
||||
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT"))
|
||||
mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
|
||||
mocker.patch("kasa.smartplug.SmartPlug.update")
|
||||
res = xdoctest.doctest_module("kasa.smartplug", "all")
|
||||
@ -36,7 +36,7 @@ def test_plug_examples(mocker):
|
||||
|
||||
def test_strip_examples(mocker):
|
||||
"""Test strip examples."""
|
||||
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json"))
|
||||
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT"))
|
||||
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
|
||||
mocker.patch("kasa.smartstrip.SmartStrip.update")
|
||||
res = xdoctest.doctest_module("kasa.smartstrip", "all")
|
||||
@ -45,7 +45,7 @@ def test_strip_examples(mocker):
|
||||
|
||||
def test_dimmer_examples(mocker):
|
||||
"""Test dimmer examples."""
|
||||
p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json"))
|
||||
p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json", "IOT"))
|
||||
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
|
||||
mocker.patch("kasa.smartdimmer.SmartDimmer.update")
|
||||
res = xdoctest.doctest_module("kasa.smartdimmer", "all")
|
||||
@ -54,7 +54,7 @@ def test_dimmer_examples(mocker):
|
||||
|
||||
def test_lightstrip_examples(mocker):
|
||||
"""Test lightstrip examples."""
|
||||
p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json"))
|
||||
p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json", "IOT"))
|
||||
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
|
||||
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
|
||||
res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
|
||||
@ -63,7 +63,7 @@ def test_lightstrip_examples(mocker):
|
||||
|
||||
def test_discovery_examples(mocker):
|
||||
"""Test discovery examples."""
|
||||
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json"))
|
||||
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT"))
|
||||
|
||||
mocker.patch("kasa.discover.Discover.discover", return_value=[p])
|
||||
res = xdoctest.doctest_module("kasa.discover", "all")
|
||||
|
@ -8,7 +8,7 @@ import kasa
|
||||
from kasa import Credentials, SmartDevice, SmartDeviceException
|
||||
from kasa.smartdevice import DeviceType
|
||||
|
||||
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
|
||||
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on
|
||||
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
|
||||
|
||||
# List of all SmartXXX classes including the SmartDevice base class
|
||||
@ -22,11 +22,13 @@ smart_device_classes = [
|
||||
]
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_state_info(dev):
|
||||
assert isinstance(dev.state_information, dict)
|
||||
|
||||
|
||||
@pytest.mark.requires_dummy
|
||||
@device_iot
|
||||
async def test_invalid_connection(dev):
|
||||
with patch.object(
|
||||
FakeTransportProtocol, "query", side_effect=SmartDeviceException
|
||||
@ -58,12 +60,14 @@ async def test_initial_update_no_emeter(dev, mocker):
|
||||
assert spy.call_count == 2
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_query_helper(dev):
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await dev._query_helper("test", "testcmd", {})
|
||||
# TODO check for unwrapping?
|
||||
|
||||
|
||||
@device_iot
|
||||
@turn_on
|
||||
async def test_state(dev, turn_on):
|
||||
await handle_turn_on(dev, turn_on)
|
||||
@ -90,6 +94,7 @@ async def test_state(dev, turn_on):
|
||||
assert dev.is_off
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_alias(dev):
|
||||
test_alias = "TEST1234"
|
||||
original = dev.alias
|
||||
@ -104,6 +109,7 @@ async def test_alias(dev):
|
||||
assert dev.alias == original
|
||||
|
||||
|
||||
@device_iot
|
||||
@turn_on
|
||||
async def test_on_since(dev, turn_on):
|
||||
await handle_turn_on(dev, turn_on)
|
||||
@ -116,30 +122,37 @@ async def test_on_since(dev, turn_on):
|
||||
assert dev.on_since is None
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_time(dev):
|
||||
assert isinstance(await dev.get_time(), datetime)
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_timezone(dev):
|
||||
TZ_SCHEMA(await dev.get_timezone())
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_hw_info(dev):
|
||||
PLUG_SCHEMA(dev.hw_info)
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_location(dev):
|
||||
PLUG_SCHEMA(dev.location)
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_rssi(dev):
|
||||
PLUG_SCHEMA({"rssi": dev.rssi}) # wrapping for vol
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_mac(dev):
|
||||
PLUG_SCHEMA({"mac": dev.mac}) # wrapping for val
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_representation(dev):
|
||||
import re
|
||||
|
||||
@ -147,6 +160,7 @@ async def test_representation(dev):
|
||||
assert pattern.match(str(dev))
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_childrens(dev):
|
||||
"""Make sure that children property is exposed by every device."""
|
||||
if dev.is_strip:
|
||||
@ -155,6 +169,7 @@ async def test_childrens(dev):
|
||||
assert len(dev.children) == 0
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_children(dev):
|
||||
"""Make sure that children property is exposed by every device."""
|
||||
if dev.is_strip:
|
||||
@ -165,11 +180,13 @@ async def test_children(dev):
|
||||
assert dev.has_children is False
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_internal_state(dev):
|
||||
"""Make sure the internal state returns the last update results."""
|
||||
assert dev.internal_state == dev._last_update
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_features(dev):
|
||||
"""Make sure features is always accessible."""
|
||||
sysinfo = dev._last_update["system"]["get_sysinfo"]
|
||||
@ -179,11 +196,13 @@ async def test_features(dev):
|
||||
assert dev.features == set()
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_max_device_response_size(dev):
|
||||
"""Make sure every device return has a set max response size."""
|
||||
assert dev.max_device_response_size > 0
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_estimated_response_sizes(dev):
|
||||
"""Make sure every module has an estimated response size set."""
|
||||
for mod in dev.modules.values():
|
||||
@ -202,6 +221,7 @@ def test_device_class_ctors(device_class):
|
||||
assert dev.credentials == credentials
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_modules_preserved(dev: SmartDevice):
|
||||
"""Make modules that are not being updated are preserved between updates."""
|
||||
dev._last_update["some_module_not_being_updated"] = "should_be_kept"
|
||||
@ -237,6 +257,7 @@ async def test_create_thin_wrapper():
|
||||
)
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_modules_not_supported(dev: SmartDevice):
|
||||
"""Test that unsupported modules do not break the device."""
|
||||
for module in dev.modules.values():
|
||||
|
Loading…
Reference in New Issue
Block a user