Add klap support for TAPO protocol by splitting out Transports and Protocols (#557)

* Add support for TAPO/SMART KLAP and seperate transports from protocols

* Add tests and some review changes

* Update following review

* Updates following review
This commit is contained in:
sdb9696 2023-12-04 18:50:05 +00:00 committed by GitHub
parent 347cbfe3bd
commit 4a00199506
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1604 additions and 887 deletions

View File

@ -21,13 +21,14 @@ from kasa.exceptions import (
SmartDeviceException, SmartDeviceException,
UnsupportedDeviceException, UnsupportedDeviceException,
) )
from kasa.klapprotocol import TPLinkKlap from kasa.iotprotocol import IotProtocol
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors
from kasa.smartdevice import DeviceType, SmartDevice from kasa.smartdevice import DeviceType, SmartDevice
from kasa.smartdimmer import SmartDimmer from kasa.smartdimmer import SmartDimmer
from kasa.smartlightstrip import SmartLightStrip from kasa.smartlightstrip import SmartLightStrip
from kasa.smartplug import SmartPlug from kasa.smartplug import SmartPlug
from kasa.smartprotocol import SmartProtocol
from kasa.smartstrip import SmartStrip from kasa.smartstrip import SmartStrip
__version__ = version("python-kasa") __version__ = version("python-kasa")
@ -37,7 +38,8 @@ __all__ = [
"Discover", "Discover",
"TPLinkSmartHomeProtocol", "TPLinkSmartHomeProtocol",
"TPLinkProtocol", "TPLinkProtocol",
"TPLinkKlap", "IotProtocol",
"SmartProtocol",
"SmartBulb", "SmartBulb",
"SmartBulbPreset", "SmartBulbPreset",
"TurnOnBehaviors", "TurnOnBehaviors",

View File

@ -1,498 +0,0 @@
"""Implementation of the TP-Link AES Protocol.
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
import asyncio
import base64
import hashlib
import logging
import time
import uuid
from pprint import pformat as pf
from typing import Dict, Optional, Union
import httpx
from cryptography.hazmat.primitives import hashes, padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import TPLinkProtocol
_LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
def _md5(payload: bytes) -> bytes:
digest = hashes.Hash(hashes.MD5()) # noqa: S303
digest.update(payload)
hash = digest.finalize()
return hash
def _sha1(payload: bytes) -> str:
sha1_algo = hashlib.sha1() # noqa: S324
sha1_algo.update(payload)
return sha1_algo.hexdigest()
class TPLinkAes(TPLinkProtocol):
"""Implementation of the AES encryption protocol.
AES is the name used in device discovery for TP-Link's TAPO encryption
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_PORT = 80
DEFAULT_TIMEOUT = 5
SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = {
"Content-Type": "application/json",
"requestByApp": "true",
"Accept": "application/json",
}
def __init__(
self,
host: str,
*,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self.credentials = (
credentials
if credentials and credentials.username and credentials.password
else Credentials(username="", password="")
)
self._local_seed: Optional[bytes] = None
self.local_auth_hash = self.generate_auth_hash(self.credentials)
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
self.kasa_setup_auth_hash = None
self.blank_auth_hash = None
self.handshake_lock = asyncio.Lock()
self.query_lock = asyncio.Lock()
self.handshake_done = False
self.encryption_session: Optional[AesEncyptionSession] = None
self.session_expire_at: Optional[float] = None
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self.session_cookie = None
self.terminal_uuid = None
self.http_client: Optional[httpx.AsyncClient] = None
self.request_id_generator = SnowflakeId(1, 1)
self.login_token = None
_LOGGER.debug("Created AES object for %s", self.host)
def hash_credentials(self, credentials, try_login_version2):
"""Hash the credentials."""
if try_login_version2:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(
_sha1(credentials.password.encode()).encode()
).decode()
else:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw
async def client_post(self, url, params=None, data=None, json=None, headers=None):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self.session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
self.http_client.cookies.clear()
resp = await self.http_client.post(
url,
params=params,
data=data,
json=json,
timeout=self.timeout,
cookies=cookies,
headers=self.COMMON_HEADERS,
)
if resp.status_code == 200:
response_data = resp.json()
return resp.status_code, response_data
async def send_secure_passthrough(self, request):
"""Send encrypted message as passthrough."""
url = f"http://{self.host}/app"
if self.login_token:
url += f"?token={self.login_token}"
raw_request = json_dumps(request)
encrypted_payload = self.encryption_session.encrypt(raw_request.encode())
passthrough_request = {
"method": "securePassthrough",
"params": {"request": encrypted_payload.decode()},
}
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
if status_code == 200 and resp_dict["error_code"] == 0:
response = self.encryption_session.decrypt(
resp_dict["result"]["response"].encode()
)
resp_dict = json_loads(response)
if resp_dict["error_code"] != 0:
raise SmartDeviceException(
f"Could not complete send, response was {resp_dict}",
)
if "result" in resp_dict:
return resp_dict["result"]
else:
raise AuthenticationException("Could not complete send")
def get_aes_request(self, method, params=None):
"""Get a request message."""
request = {
"method": method,
"params": params,
"requestID": self.request_id_generator.generate_id(),
"request_time_milis": round(time.time() * 1000),
"terminal_uuid": self.terminal_uuid,
}
return request
async def perform_login(self, login_v2):
"""Login to the device."""
self.login_token = None
un, pw = self.hash_credentials(self.credentials, login_v2)
params = {"password": pw, "username": un}
request = self.get_aes_request("login_device", params)
try:
result = await self.send_secure_passthrough(request)
except SmartDeviceException as ex:
raise AuthenticationException(ex) from ex
self.login_token = result["token"]
async def perform_handshake(self):
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")
_LOGGER.debug("Generating keypair")
self.handshake_done = False
self.session_expire_at = None
self.session_cookie = None
url = f"http://{self.host}/app"
key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ key_pair.get_public_key()
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
status_code, resp_dict = await self.client_post(url, json=request_body)
_LOGGER.debug(f"Device responded with: {resp_dict}")
if status_code == 200 and resp_dict["error_code"] == 0:
_LOGGER.debug("Decoding handshake key...")
handshake_key = resp_dict["result"]["key"]
self.session_cookie = self.http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME
)
if not self.session_cookie:
self.session_cookie = self.http_client.cookies.get( # type: ignore
"SESSIONID"
)
self.session_expire_at = time.time() + 86400
self.encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
)
self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode(
"UTF-8"
)
self.handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host)
else:
raise AuthenticationException("Could not complete handshake")
def handshake_session_expired(self):
"""Return true if session has expired."""
return (
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
)
@staticmethod
def generate_auth_hash(creds: Credentials):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username or ""
pw = creds.password or ""
return _md5(_md5(un.encode()) + _md5(pw.encode()))
@staticmethod
def generate_owner_hash(creds: Credentials):
"""Return the MD5 hash of the username in this object."""
un = creds.username or ""
return _md5(un.encode())
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
async with self.query_lock:
return await self._query(request, retry_count)
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
_LOGGER.debug(
"%s >> %s",
self.host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(request),
)
if not self.http_client:
self.http_client = httpx.AsyncClient()
if not self.handshake_done or self.handshake_session_expired():
try:
await self.perform_handshake()
await self.perform_login(False)
except AuthenticationException:
await self.perform_handshake()
await self.perform_login(True)
if isinstance(request, dict):
aes_method = next(iter(request))
aes_params = request[aes_method]
else:
aes_method = request
aes_params = None
aes_request = self.get_aes_request(aes_method, aes_params)
response_data = await self.send_secure_passthrough(aes_request)
_LOGGER.debug(
"%s << %s",
self.host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
)
return response_data
async def close(self) -> None:
"""Close the protocol."""
client = self.http_client
self.http_client = None
if client:
await client.aclose()
class AesEncyptionSession:
"""Class for an AES encryption session."""
@staticmethod
def create_from_keypair(handshake_key: str, keypair):
"""Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
private_key = serialization.load_der_private_key(private_key_data, None, None)
key_and_iv = private_key.decrypt(
handshake_key_bytes, asymmetric_padding.PKCS1v15()
)
if key_and_iv is None:
raise ValueError("Decryption failed!")
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
def __init__(self, key, iv):
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
def encrypt(self, data) -> bytes:
"""Encrypt the message."""
encryptor = self.cipher.encryptor()
padder = self.padding_strategy.padder()
padded_data = padder.update(data) + padder.finalize()
encrypted = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(encrypted)
def decrypt(self, data) -> str:
"""Decrypt the message."""
decryptor = self.cipher.decryptor()
unpadder = self.padding_strategy.unpadder()
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
return unpadded_data.decode()
class KeyPair:
"""Class for generating key pairs."""
@staticmethod
def create_key_pair(key_size: int = 1024):
"""Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key()
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
public_key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return KeyPair(
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
)
def __init__(self, private_key: str, public_key: str):
self.private_key = private_key
self.public_key = public_key
def get_private_key(self) -> str:
"""Get the private key."""
return self.private_key
def get_public_key(self) -> str:
"""Get the public key."""
return self.public_key
class SnowflakeId:
"""Class for generating snowflake ids."""
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
WORKER_ID_BITS = 5
DATA_CENTER_ID_BITS = 5
SEQUENCE_BITS = 12
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
def __init__(self, worker_id, data_center_id):
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
raise ValueError(
"Worker ID can't be greater than "
+ str(SnowflakeId.MAX_WORKER_ID)
+ " or less than 0"
)
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
raise ValueError(
"Data center ID can't be greater than "
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
+ " or less than 0"
)
self.worker_id = worker_id
self.data_center_id = data_center_id
self.sequence = 0
self.last_timestamp = -1
def generate_id(self):
"""Generate a snowflake id."""
timestamp = self._current_millis()
if timestamp < self.last_timestamp:
raise ValueError("Clock moved backwards. Refusing to generate ID.")
if timestamp == self.last_timestamp:
# Within the same millisecond, increment the sequence number
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
if self.sequence == 0:
# Sequence exceeds its bit range, wait until the next millisecond
timestamp = self._wait_next_millis(self.last_timestamp)
else:
# New millisecond, reset the sequence number
self.sequence = 0
# Update the last timestamp
self.last_timestamp = timestamp
# Generate and return the final ID
return (
(
(timestamp - SnowflakeId.EPOCH)
<< (
SnowflakeId.WORKER_ID_BITS
+ SnowflakeId.SEQUENCE_BITS
+ SnowflakeId.DATA_CENTER_ID_BITS
)
)
| (
self.data_center_id
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
)
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
| self.sequence
)
def _current_millis(self):
return round(time.time() * 1000)
def _wait_next_millis(self, last_timestamp):
timestamp = self._current_millis()
while timestamp <= last_timestamp:
timestamp = self._current_millis()
return timestamp

338
kasa/aestransport.py Normal file
View File

@ -0,0 +1,338 @@
"""Implementation of the TP-Link AES transport.
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
import base64
import hashlib
import logging
import time
from typing import Optional, Union
import httpx
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import BaseTransport
_LOGGER = logging.getLogger(__name__)
def _sha1(payload: bytes) -> str:
sha1_algo = hashlib.sha1() # noqa: S324
sha1_algo.update(payload)
return sha1_algo.hexdigest()
class AesTransport(BaseTransport):
"""Implementation of the AES encryption protocol.
AES is the name used in device discovery for TP-Link's TAPO encryption
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_TIMEOUT = 5
SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = {
"Content-Type": "application/json",
"requestByApp": "true",
"Accept": "application/json",
}
def __init__(
self,
host: str,
*,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host)
self._credentials = credentials or Credentials(username="", password="")
self._handshake_done = False
self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self._session_cookie = None
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
self._login_token = None
_LOGGER.debug("Created AES object for %s", self.host)
def hash_credentials(self, login_v2):
"""Hash the credentials."""
if login_v2:
un = base64.b64encode(
_sha1(self._credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(
_sha1(self._credentials.password.encode()).encode()
).decode()
else:
un = base64.b64encode(
_sha1(self._credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(self._credentials.password.encode()).decode()
return un, pw
async def client_post(self, url, params=None, data=None, json=None, headers=None):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self._session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
self._http_client.cookies.clear()
resp = await self._http_client.post(
url,
params=params,
data=data,
json=json,
timeout=self._timeout,
cookies=cookies,
headers=self.COMMON_HEADERS,
)
if resp.status_code == 200:
response_data = resp.json()
return resp.status_code, response_data
async def send_secure_passthrough(self, request: str):
"""Send encrypted message as passthrough."""
url = f"http://{self.host}/app"
if self._login_token:
url += f"?token={self._login_token}"
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
passthrough_request = {
"method": "securePassthrough",
"params": {"request": encrypted_payload.decode()},
}
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
_LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
if status_code == 200 and resp_dict["error_code"] == 0:
response = self._encryption_session.decrypt( # type: ignore
resp_dict["result"]["response"].encode()
)
_LOGGER.debug(f"decrypted secure_passthrough response is {response}")
resp_dict = json_loads(response)
return resp_dict
else:
self._handshake_done = False
self._login_token = None
raise AuthenticationException("Could not complete send")
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
"""Login to the device."""
self._login_token = None
if isinstance(login_request, str):
login_request_dict: dict = json_loads(login_request)
else:
login_request_dict = login_request
un, pw = self.hash_credentials(login_v2)
login_request_dict["params"] = {"password": pw, "username": un}
request = json_dumps(login_request_dict)
try:
resp_dict = await self.send_secure_passthrough(request)
except SmartDeviceException as ex:
raise AuthenticationException(ex) from ex
self._login_token = resp_dict["result"]["token"]
@property
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
return self._login_token is None
async def login(self, request: str) -> None:
"""Login to the device."""
try:
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to login"
)
await self.perform_login(request, login_v2=False)
except AuthenticationException:
await self.perform_handshake()
await self.perform_login(request, login_v2=True)
@property
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
return not self._handshake_done or self._handshake_session_expired()
async def handshake(self) -> None:
"""Perform the encryption handshake."""
await self.perform_handshake()
async def perform_handshake(self):
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")
_LOGGER.debug("Generating keypair")
self._handshake_done = False
self._session_expire_at = None
self._session_cookie = None
url = f"http://{self.host}/app"
key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ key_pair.get_public_key()
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
status_code, resp_dict = await self.client_post(url, json=request_body)
_LOGGER.debug(f"Device responded with: {resp_dict}")
if status_code == 200 and resp_dict["error_code"] == 0:
_LOGGER.debug("Decoding handshake key...")
handshake_key = resp_dict["result"]["key"]
self._session_cookie = self._http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME
)
if not self._session_cookie:
self._session_cookie = self._http_client.cookies.get( # type: ignore
"SESSIONID"
)
self._session_expire_at = time.time() + 86400
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
)
self._handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host)
else:
raise AuthenticationException("Could not complete handshake")
def _handshake_session_expired(self):
"""Return true if session has expired."""
return (
self._session_expire_at is None
or self._session_expire_at - time.time() <= 0
)
async def send(self, request: str):
"""Send the request."""
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to send"
)
if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")
resp_dict = await self.send_secure_passthrough(request)
if resp_dict["error_code"] != 0:
self._handshake_done = False
self._login_token = None
raise SmartDeviceException(
f"Could not complete send, response was {resp_dict}",
)
return resp_dict
async def close(self) -> None:
"""Close the protocol."""
client = self._http_client
self._http_client = None
if client:
await client.aclose()
class AesEncyptionSession:
"""Class for an AES encryption session."""
@staticmethod
def create_from_keypair(handshake_key: str, keypair):
"""Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
private_key = serialization.load_der_private_key(private_key_data, None, None)
key_and_iv = private_key.decrypt(
handshake_key_bytes, asymmetric_padding.PKCS1v15()
)
if key_and_iv is None:
raise ValueError("Decryption failed!")
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
def __init__(self, key, iv):
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
def encrypt(self, data) -> bytes:
"""Encrypt the message."""
encryptor = self.cipher.encryptor()
padder = self.padding_strategy.padder()
padded_data = padder.update(data) + padder.finalize()
encrypted = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(encrypted)
def decrypt(self, data) -> str:
"""Decrypt the message."""
decryptor = self.cipher.decryptor()
unpadder = self.padding_strategy.unpadder()
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
return unpadded_data.decode()
class KeyPair:
"""Class for generating key pairs."""
@staticmethod
def create_key_pair(key_size: int = 1024):
"""Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key()
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
public_key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return KeyPair(
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
)
def __init__(self, private_key: str, public_key: str):
self.private_key = private_key
self.public_key = public_key
def get_private_key(self) -> str:
"""Get the private key."""
return self.private_key
def get_public_key(self) -> str:
"""Get the public key."""
return self.public_key

View File

@ -2,17 +2,21 @@
import logging import logging
import time import time
from typing import Any, Dict, Optional, Type from typing import Any, Dict, Optional, Tuple, Type
from .aestransport import AesTransport
from .credentials import Credentials from .credentials import Credentials
from .device_type import DeviceType from .device_type import DeviceType
from .exceptions import UnsupportedDeviceException from .exceptions import UnsupportedDeviceException
from .protocol import TPLinkProtocol from .iotprotocol import IotProtocol
from .klaptransport import KlapTransport, TPlinkKlapTransportV2
from .protocol import BaseTransport, TPLinkProtocol
from .smartbulb import SmartBulb from .smartbulb import SmartBulb
from .smartdevice import SmartDevice, SmartDeviceException from .smartdevice import SmartDevice, SmartDeviceException
from .smartdimmer import SmartDimmer from .smartdimmer import SmartDimmer
from .smartlightstrip import SmartLightStrip from .smartlightstrip import SmartLightStrip
from .smartplug import SmartPlug from .smartplug import SmartPlug
from .smartprotocol import SmartProtocol
from .smartstrip import SmartStrip from .smartstrip import SmartStrip
from .tapo.tapoplug import TapoPlug from .tapo.tapoplug import TapoPlug
@ -87,7 +91,7 @@ async def connect(
if protocol_class is not None: if protocol_class is not None:
unknown_dev.protocol = protocol_class(host, credentials=credentials) unknown_dev.protocol = protocol_class(host, credentials=credentials)
await unknown_dev.update() await unknown_dev.update()
device_class = get_device_class_from_info(unknown_dev.internal_state) device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout) dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
# Reuse the connection from the unknown device # Reuse the connection from the unknown device
# so we don't have to reconnect # so we don't have to reconnect
@ -104,7 +108,7 @@ async def connect(
return dev return dev
def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]: def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
"""Find SmartDevice subclass for device described by passed data.""" """Find SmartDevice subclass for device described by passed data."""
if "system" not in info or "get_sysinfo" not in info["system"]: if "system" not in info or "get_sysinfo" not in info["system"]:
raise SmartDeviceException("No 'system' or 'get_sysinfo' in response") raise SmartDeviceException("No 'system' or 'get_sysinfo' in response")
@ -129,3 +133,35 @@ def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]:
return SmartBulb return SmartBulb
raise UnsupportedDeviceException("Unknown device type: %s" % type_) raise UnsupportedDeviceException("Unknown device type: %s" % type_)
def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]:
"""Return the device class from the type name."""
supported_device_types: dict[str, Type[SmartDevice]] = {
"SMART.TAPOPLUG": TapoPlug,
"SMART.KASAPLUG": TapoPlug,
"IOT.SMARTPLUGSWITCH": SmartPlug,
}
return supported_device_types.get(device_type)
def get_protocol_from_connection_name(
connection_name: str, host: str, credentials: Optional[Credentials] = None
) -> Optional[TPLinkProtocol]:
"""Return the protocol from the connection name."""
supported_device_protocols: dict[
str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]]
] = {
"IOT.KLAP": (IotProtocol, KlapTransport),
"SMART.AES": (SmartProtocol, AesTransport),
"SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2),
}
if connection_name not in supported_device_protocols:
return None
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
transport: BaseTransport = transport_class(host, credentials=credentials)
protocol: TPLinkProtocol = protocol_class(
host, credentials=credentials, transport=transport
)
return protocol

View File

@ -15,18 +15,18 @@ try:
except ImportError: except ImportError:
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from kasa.aesprotocol import TPLinkAes
from kasa.credentials import Credentials from kasa.credentials import Credentials
from kasa.exceptions import UnsupportedDeviceException from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads from kasa.json import loads as json_loads
from kasa.klapprotocol import TPLinkKlap from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
from kasa.smartdevice import SmartDevice, SmartDeviceException from kasa.smartdevice import SmartDevice, SmartDeviceException
from kasa.smartplug import SmartPlug
from kasa.tapo.tapoplug import TapoPlug
from .device_factory import get_device_class_from_info from .device_factory import (
get_device_class_from_sys_info,
get_device_class_from_type_name,
get_protocol_from_connection_name,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -348,7 +348,16 @@ class Discover:
@staticmethod @staticmethod
def _get_device_class(info: dict) -> Type[SmartDevice]: def _get_device_class(info: dict) -> Type[SmartDevice]:
"""Find SmartDevice subclass for device described by passed data.""" """Find SmartDevice subclass for device described by passed data."""
return get_device_class_from_info(info) if "result" in info:
discovery_result = DiscoveryResult(**info["result"])
dev_class = get_device_class_from_type_name(discovery_result.device_type)
if not dev_class:
raise UnsupportedDeviceException(
"Unknown device type: %s" % discovery_result.device_type
)
return dev_class
else:
return get_device_class_from_sys_info(info)
@staticmethod @staticmethod
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice: def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:
@ -384,24 +393,17 @@ class Discover:
encrypt_type_ = ( encrypt_type_ = (
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}" f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
) )
device_class = None
supported_device_types: dict[str, Type[SmartDevice]] = { if (device_class := get_device_class_from_type_name(type_)) is None:
"SMART.TAPOPLUG": TapoPlug,
"SMART.KASAPLUG": TapoPlug,
"IOT.SMARTPLUGSWITCH": SmartPlug,
}
supported_device_protocols: dict[str, Type[TPLinkProtocol]] = {
"IOT.KLAP": TPLinkKlap,
"SMART.AES": TPLinkAes,
}
if (device_class := supported_device_types.get(type_)) is None:
_LOGGER.warning("Got unsupported device type: %s", type_) _LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceException( raise UnsupportedDeviceException(
f"Unsupported device {ip} of type {type_}: {info}" f"Unsupported device {ip} of type {type_}: {info}"
) )
if (protocol_class := supported_device_protocols.get(encrypt_type_)) is None: if (
protocol := get_protocol_from_connection_name(
encrypt_type_, ip, credentials=credentials
)
) is None:
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_) _LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
raise UnsupportedDeviceException( raise UnsupportedDeviceException(
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}" f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}"
@ -409,7 +411,7 @@ class Discover:
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info) _LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
device = device_class(ip, port=port, credentials=credentials) device = device_class(ip, port=port, credentials=credentials)
device.protocol = protocol_class(ip, credentials=credentials) device.protocol = protocol
device.update_from_discover_info(discovery_result.get_dict()) device.update_from_discover_info(discovery_result.get_dict())
return device return device

100
kasa/iotprotocol.py Executable file
View File

@ -0,0 +1,100 @@
"""Module for the IOT legacy IOT KASA protocol."""
import asyncio
import logging
from typing import Dict, Optional, Union
import httpx
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .klaptransport import KlapTransport
from .protocol import BaseTransport, TPLinkProtocol
_LOGGER = logging.getLogger(__name__)
class IotProtocol(TPLinkProtocol):
"""Class for the legacy TPLink IOT KASA Protocol."""
DEFAULT_PORT = 80
def __init__(
self,
host: str,
*,
transport: Optional[BaseTransport] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self._credentials: Credentials = credentials or Credentials(
username="", password=""
)
self._transport: BaseTransport = transport or KlapTransport(
host, credentials=self._credentials, timeout=timeout
)
self._query_lock = asyncio.Lock()
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
if isinstance(request, dict):
request = json_dumps(request)
assert isinstance(request, str) # noqa: S101
async with self._query_lock:
return await self._query(request, retry_count)
async def _query(self, request: str, retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: str, retry_count: int) -> Dict:
if self._transport.needs_handshake:
await self._transport.handshake()
if self._transport.needs_login: # This shouln't happen
raise SmartDeviceException(
"IOT Protocol needs to login to transport but is not login aware"
)
return await self._transport.send(request)
async def close(self) -> None:
"""Close the protocol."""
await self._transport.close()

295
kasa/klapprotocol.py → kasa/klaptransport.py Executable file → Normal file
View File

@ -47,7 +47,7 @@ import logging
import secrets import secrets
import time import time
from pprint import pformat as pf from pprint import pformat as pf
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Optional, Tuple
import httpx import httpx
from cryptography.hazmat.primitives import hashes, padding from cryptography.hazmat.primitives import hashes, padding
@ -55,33 +55,33 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .json import loads as json_loads from .json import loads as json_loads
from .protocol import TPLinkProtocol from .protocol import BaseTransport, md5
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False logging.getLogger("httpx").propagate = False
def _sha256(payload: bytes) -> bytes: def _sha256(payload: bytes) -> bytes:
return hashlib.sha256(payload).digest() digest = hashes.Hash(hashes.SHA256()) # noqa: S303
def _md5(payload: bytes) -> bytes:
digest = hashes.Hash(hashes.MD5()) # noqa: S303
digest.update(payload) digest.update(payload)
hash = digest.finalize() hash = digest.finalize()
return hash return hash
class TPLinkKlap(TPLinkProtocol): def _sha1(payload: bytes) -> bytes:
digest = hashes.Hash(hashes.SHA1()) # noqa: S303
digest.update(payload)
return digest.finalize()
class KlapTransport(BaseTransport):
"""Implementation of the KLAP encryption protocol. """Implementation of the KLAP encryption protocol.
KLAP is the name used in device discovery for TP-Link's new encryption KLAP is the name used in device discovery for TP-Link's new encryption
protocol, used by newer firmware versions. protocol, used by newer firmware versions.
""" """
DEFAULT_PORT = 80
DEFAULT_TIMEOUT = 5 DEFAULT_TIMEOUT = 5
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
KASA_SETUP_EMAIL = "kasa@tp-link.net" KASA_SETUP_EMAIL = "kasa@tp-link.net"
@ -95,29 +95,24 @@ class TPLinkKlap(TPLinkProtocol):
credentials: Optional[Credentials] = None, credentials: Optional[Credentials] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT) super().__init__(host=host)
self.credentials = (
credentials
if credentials and credentials.username and credentials.password
else Credentials(username="", password="")
)
self._credentials = credentials or Credentials(username="", password="")
self._local_seed: Optional[bytes] = None self._local_seed: Optional[bytes] = None
self.local_auth_hash = self.generate_auth_hash(self.credentials) self._local_auth_hash = self.generate_auth_hash(self._credentials)
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex() self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
self.kasa_setup_auth_hash = None self._kasa_setup_auth_hash = None
self.blank_auth_hash = None self._blank_auth_hash = None
self.handshake_lock = asyncio.Lock() self._handshake_lock = asyncio.Lock()
self.query_lock = asyncio.Lock() self._query_lock = asyncio.Lock()
self.handshake_done = False self._handshake_done = False
self.encryption_session: Optional[KlapEncryptionSession] = None self._encryption_session: Optional[KlapEncryptionSession] = None
self.session_expire_at: Optional[float] = None self._session_expire_at: Optional[float] = None
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self.session_cookie = None self._session_cookie = None
self.http_client: Optional[httpx.AsyncClient] = None self._http_client: httpx.AsyncClient = httpx.AsyncClient()
_LOGGER.debug("Created KLAP object for %s", self.host) _LOGGER.debug("Created KLAP object for %s", self.host)
@ -125,15 +120,15 @@ class TPLinkKlap(TPLinkProtocol):
"""Send an http post request to the device.""" """Send an http post request to the device."""
response_data = None response_data = None
cookies = None cookies = None
if self.session_cookie: if self._session_cookie:
cookies = httpx.Cookies() cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie) cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
self.http_client.cookies.clear() self._http_client.cookies.clear()
resp = await self.http_client.post( resp = await self._http_client.post(
url, url,
params=params, params=params,
data=data, data=data,
timeout=self.timeout, timeout=self._timeout,
cookies=cookies, cookies=cookies,
) )
if resp.status_code == 200: if resp.status_code == 200:
@ -183,44 +178,55 @@ class TPLinkKlap(TPLinkProtocol):
server_hash.hex(), server_hash.hex(),
) )
local_seed_auth_hash = _sha256(local_seed + self.local_auth_hash) local_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed, remote_seed, self._local_auth_hash
) # type: ignore
# Check the response from the device with local credentials # Check the response from the device with local credentials
if local_seed_auth_hash == server_hash: if local_seed_auth_hash == server_hash:
_LOGGER.debug("handshake1 hashes match with expected credentials") _LOGGER.debug("handshake1 hashes match with expected credentials")
return local_seed, remote_seed, self.local_auth_hash # type: ignore return local_seed, remote_seed, self._local_auth_hash # type: ignore
# Now check against the default kasa setup credentials # Now check against the default kasa setup credentials
if not self.kasa_setup_auth_hash: if not self._kasa_setup_auth_hash:
kasa_setup_creds = Credentials( kasa_setup_creds = Credentials(
username=TPLinkKlap.KASA_SETUP_EMAIL, username=self.KASA_SETUP_EMAIL,
password=TPLinkKlap.KASA_SETUP_PASSWORD, password=self.KASA_SETUP_PASSWORD,
) )
self.kasa_setup_auth_hash = TPLinkKlap.generate_auth_hash(kasa_setup_creds) self._kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds)
kasa_setup_seed_auth_hash = _sha256( kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed + self.kasa_setup_auth_hash # type: ignore local_seed,
remote_seed,
self._kasa_setup_auth_hash, # type: ignore
) )
if kasa_setup_seed_auth_hash == server_hash: if kasa_setup_seed_auth_hash == server_hash:
_LOGGER.debug( _LOGGER.debug(
"Server response doesn't match our expected hash on ip %s" "Server response doesn't match our expected hash on ip %s"
+ " but an authentication with kasa setup credentials matched", + " but an authentication with kasa setup credentials matched",
self.host, self.host,
) )
return local_seed, remote_seed, self.kasa_setup_auth_hash # type: ignore return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
# Finally check against blank credentials if not already blank # Finally check against blank credentials if not already blank
if self.credentials != (blank_creds := Credentials(username="", password="")): if self._credentials != (blank_creds := Credentials(username="", password="")):
if not self.blank_auth_hash: if not self._blank_auth_hash:
self.blank_auth_hash = TPLinkKlap.generate_auth_hash(blank_creds) self._blank_auth_hash = self.generate_auth_hash(blank_creds)
blank_seed_auth_hash = _sha256(local_seed + self.blank_auth_hash) # type: ignore
blank_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed,
remote_seed,
self._blank_auth_hash, # type: ignore
)
if blank_seed_auth_hash == server_hash: if blank_seed_auth_hash == server_hash:
_LOGGER.debug( _LOGGER.debug(
"Server response doesn't match our expected hash on ip %s" "Server response doesn't match our expected hash on ip %s"
+ " but an authentication with blank credentials matched", + " but an authentication with blank credentials matched",
self.host, self.host,
) )
return local_seed, remote_seed, self.blank_auth_hash # type: ignore return local_seed, remote_seed, self._blank_auth_hash # type: ignore
msg = f"Server response doesn't match our challenge on ip {self.host}" msg = f"Server response doesn't match our challenge on ip {self.host}"
_LOGGER.debug(msg) _LOGGER.debug(msg)
@ -235,7 +241,7 @@ class TPLinkKlap(TPLinkProtocol):
url = f"http://{self.host}/app/handshake2" url = f"http://{self.host}/app/handshake2"
payload = _sha256(remote_seed + auth_hash) payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
response_status, response_data = await self.client_post(url, data=payload) response_status, response_data = await self.client_post(url, data=payload)
@ -256,115 +262,70 @@ class TPLinkKlap(TPLinkProtocol):
return KlapEncryptionSession(local_seed, remote_seed, auth_hash) return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
@property
def needs_login(self) -> bool:
"""Will return false as KLAP does not do a login."""
return False
async def login(self, request: str) -> None:
"""Will raise and exception as KLAP does not do a login."""
raise SmartDeviceException(
"KLAP does not perform logins and return needs_login == False"
)
@property
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
return not self._handshake_done or self._handshake_session_expired()
async def handshake(self) -> None:
"""Perform the encryption handshake."""
await self.perform_handshake()
async def perform_handshake(self) -> Any: async def perform_handshake(self) -> Any:
"""Perform handshake1 and handshake2. """Perform handshake1 and handshake2.
Sets the encryption_session if successful. Sets the encryption_session if successful.
""" """
_LOGGER.debug("Starting handshake with %s", self.host) _LOGGER.debug("Starting handshake with %s", self.host)
self.handshake_done = False self._handshake_done = False
self.session_expire_at = None self._session_expire_at = None
self.session_cookie = None self._session_cookie = None
local_seed, remote_seed, auth_hash = await self.perform_handshake1() local_seed, remote_seed, auth_hash = await self.perform_handshake1()
self.session_cookie = self.http_client.cookies.get( # type: ignore self._session_cookie = self._http_client.cookies.get( # type: ignore
TPLinkKlap.SESSION_COOKIE_NAME self.SESSION_COOKIE_NAME
) )
# The device returns a TIMEOUT cookie on handshake1 which # The device returns a TIMEOUT cookie on handshake1 which
# it doesn't like to get back so we store the one we want # it doesn't like to get back so we store the one we want
self.session_expire_at = time.time() + 86400 self._session_expire_at = time.time() + 86400
self.encryption_session = await self.perform_handshake2( self._encryption_session = await self.perform_handshake2(
local_seed, remote_seed, auth_hash local_seed, remote_seed, auth_hash
) )
self.handshake_done = True self._handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host) _LOGGER.debug("Handshake with %s complete", self.host)
def handshake_session_expired(self): def _handshake_session_expired(self):
"""Return true if session has expired.""" """Return true if session has expired."""
return ( return (
self.session_expire_at is None or self.session_expire_at - time.time() <= 0 self._session_expire_at is None
or self._session_expire_at - time.time() <= 0
) )
@staticmethod async def send(self, request: str):
def generate_auth_hash(creds: Credentials): """Send the request."""
"""Generate an md5 auth hash for the protocol on the supplied credentials.""" if self.needs_handshake:
un = creds.username or ""
pw = creds.password or ""
return _md5(_md5(un.encode()) + _md5(pw.encode()))
@staticmethod
def generate_owner_hash(creds: Credentials):
"""Return the MD5 hash of the username in this object."""
un = creds.username or ""
return _md5(un.encode())
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
if isinstance(request, dict):
request = json_dumps(request)
assert isinstance(request, str) # noqa: S101
async with self.query_lock:
return await self._query(request, retry_count)
async def _query(self, request: str, retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}" "Handshake must be complete before trying to send"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: str, retry_count: int) -> Dict:
if not self.http_client:
self.http_client = httpx.AsyncClient()
if not self.handshake_done or self.handshake_session_expired():
try:
await self.perform_handshake()
except AuthenticationException as auex:
_LOGGER.debug(
"Unable to complete handshake for device %s, "
+ "authentication failed",
self.host,
) )
raise auex if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")
# Check for mypy # Check for mypy
if self.encryption_session is not None: if self._encryption_session is not None:
payload, seq = self.encryption_session.encrypt(request.encode()) payload, seq = self._encryption_session.encrypt(request.encode())
url = f"http://{self.host}/app/request" url = f"http://{self.host}/app/request"
@ -376,14 +337,14 @@ class TPLinkKlap(TPLinkProtocol):
msg = ( msg = (
f"at {datetime.datetime.now()}. Host is {self.host}, " f"at {datetime.datetime.now()}. Host is {self.host}, "
+ f"Retry count is {retry_count}, Sequence is {seq}, " + f"Sequence is {seq}, "
+ f"Response status is {response_status}, Request was {request}" + f"Response status is {response_status}, Request was {request}"
) )
if response_status != 200: if response_status != 200:
_LOGGER.error("Query failed after succesful authentication " + msg) _LOGGER.error("Query failed after succesful authentication " + msg)
# If we failed with a security error, force a new handshake next time. # If we failed with a security error, force a new handshake next time.
if response_status == 403: if response_status == 403:
self.handshake_done = False self._handshake_done = False
raise AuthenticationException( raise AuthenticationException(
f"Got a security error from {self.host} after handshake " f"Got a security error from {self.host} after handshake "
+ "completed" + "completed"
@ -397,8 +358,8 @@ class TPLinkKlap(TPLinkProtocol):
_LOGGER.debug("Query posted " + msg) _LOGGER.debug("Query posted " + msg)
# Check for mypy # Check for mypy
if self.encryption_session is not None: if self._encryption_session is not None:
decrypted_response = self.encryption_session.decrypt(response_data) decrypted_response = self._encryption_session.decrypt(response_data)
json_payload = json_loads(decrypted_response) json_payload = json_loads(decrypted_response)
@ -411,12 +372,66 @@ class TPLinkKlap(TPLinkProtocol):
return json_payload return json_payload
async def close(self) -> None: async def close(self) -> None:
"""Close the protocol.""" """Close the transport."""
client = self.http_client client = self._http_client
self.http_client = None self._http_client = None
if client: if client:
await client.aclose() await client.aclose()
@staticmethod
def generate_auth_hash(creds: Credentials):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username or ""
pw = creds.password or ""
return md5(md5(un.encode()) + md5(pw.encode()))
@staticmethod
def handshake1_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(local_seed + auth_hash)
@staticmethod
def handshake2_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(remote_seed + auth_hash)
@staticmethod
def generate_owner_hash(creds: Credentials):
"""Return the MD5 hash of the username in this object."""
un = creds.username or ""
return md5(un.encode())
class TPlinkKlapTransportV2(KlapTransport):
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
@staticmethod
def generate_auth_hash(creds: Credentials):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username or ""
pw = creds.password or ""
return _sha256(_sha1(un.encode()) + _sha1(pw.encode()))
@staticmethod
def handshake1_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(local_seed + remote_seed + auth_hash)
@staticmethod
def handshake2_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(remote_seed + local_seed + auth_hash)
class KlapEncryptionSession: class KlapEncryptionSession:
"""Class to represent an encryption session and it's internal state. """Class to represent an encryption session and it's internal state.

View File

@ -22,6 +22,7 @@ from typing import Dict, Generator, Optional, Union
# When support for cpython older than 3.11 is dropped # When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout # async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout from async_timeout import timeout as asyncio_timeout
from cryptography.hazmat.primitives import hashes
from .credentials import Credentials from .credentials import Credentials
from .exceptions import SmartDeviceException from .exceptions import SmartDeviceException
@ -32,6 +33,56 @@ _LOGGER = logging.getLogger(__name__)
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED} _NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
def md5(payload: bytes) -> bytes:
"""Return an md5 hash of the payload."""
digest = hashes.Hash(hashes.MD5()) # noqa: S303
digest.update(payload)
hash = digest.finalize()
return hash
class BaseTransport(ABC):
"""Base class for all TP-Link protocol transports."""
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
) -> None:
"""Create a protocol object."""
self.host = host
self.port = port
self.credentials = credentials
@property
@abstractmethod
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
@property
@abstractmethod
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
@abstractmethod
async def login(self, request: str) -> None:
"""Login to the device."""
@abstractmethod
async def handshake(self) -> None:
"""Perform the encryption handshake."""
@abstractmethod
async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response."""
@abstractmethod
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
class TPLinkProtocol(ABC): class TPLinkProtocol(ABC):
"""Base class for all TP-Link Smart Home communication.""" """Base class for all TP-Link Smart Home communication."""
@ -41,6 +92,7 @@ class TPLinkProtocol(ABC):
*, *,
port: Optional[int] = None, port: Optional[int] = None,
credentials: Optional[Credentials] = None, credentials: Optional[Credentials] = None,
transport: Optional[BaseTransport] = None,
) -> None: ) -> None:
"""Create a protocol object.""" """Create a protocol object."""
self.host = host self.host = host

View File

@ -365,6 +365,7 @@ class SmartDevice:
def update_from_discover_info(self, info: Dict[str, Any]) -> None: def update_from_discover_info(self, info: Dict[str, Any]) -> None:
"""Update state from info from the discover call.""" """Update state from info from the discover call."""
self._discovery_info = info
if "system" in info and (sys_info := info["system"].get("get_sysinfo")): if "system" in info and (sys_info := info["system"].get("get_sysinfo")):
self._last_update = info self._last_update = info
self._set_sys_info(sys_info) self._set_sys_info(sys_info)
@ -372,7 +373,6 @@ class SmartDevice:
# This allows setting of some info properties directly # This allows setting of some info properties directly
# from partial discovery info that will then be found # from partial discovery info that will then be found
# by the requires_update decorator # by the requires_update decorator
self._discovery_info = info
self._set_sys_info(info) self._set_sys_info(info)
def _set_sys_info(self, sys_info: Dict[str, Any]) -> None: def _set_sys_info(self, sys_info: Dict[str, Any]) -> None:

219
kasa/smartprotocol.py Normal file
View File

@ -0,0 +1,219 @@
"""Implementation of the TP-Link AES Protocol.
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
import asyncio
import base64
import logging
import time
import uuid
from pprint import pformat as pf
from typing import Dict, Optional, Union
import httpx
from .aestransport import AesTransport
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .protocol import BaseTransport, TPLinkProtocol, md5
_LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
class SmartProtocol(TPLinkProtocol):
"""Class for the new TPLink SMART protocol."""
DEFAULT_PORT = 80
def __init__(
self,
host: str,
*,
transport: Optional[BaseTransport] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self._credentials: Credentials = credentials or Credentials(
username="", password=""
)
self._transport: BaseTransport = transport or AesTransport(
host, credentials=self._credentials, timeout=timeout
)
self._terminal_uuid: Optional[str] = None
self._request_id_generator = SnowflakeId(1, 1)
self._query_lock = asyncio.Lock()
def get_smart_request(self, method, params=None) -> str:
"""Get a request message as a string."""
request = {
"method": method,
"params": params,
"requestID": self._request_id_generator.generate_id(),
"request_time_milis": round(time.time() * 1000),
"terminal_uuid": self._terminal_uuid,
}
return json_dumps(request)
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
async with self._query_lock:
resp_dict = await self._query(request, retry_count)
if "result" in resp_dict:
return resp_dict["result"]
return {}
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
if isinstance(request, dict):
smart_method = next(iter(request))
smart_params = request[smart_method]
else:
smart_method = request
smart_params = None
if self._transport.needs_handshake:
await self._transport.handshake()
if self._transport.needs_login:
self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
"UTF-8"
)
login_request = self.get_smart_request("login_device")
await self._transport.login(login_request)
smart_request = self.get_smart_request(smart_method, smart_params)
response_data = await self._transport.send(smart_request)
_LOGGER.debug(
"%s << %s",
self.host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
)
return response_data
async def close(self) -> None:
"""Close the protocol."""
await self._transport.close()
class SnowflakeId:
"""Class for generating snowflake ids."""
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
WORKER_ID_BITS = 5
DATA_CENTER_ID_BITS = 5
SEQUENCE_BITS = 12
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
def __init__(self, worker_id, data_center_id):
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
raise ValueError(
"Worker ID can't be greater than "
+ str(SnowflakeId.MAX_WORKER_ID)
+ " or less than 0"
)
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
raise ValueError(
"Data center ID can't be greater than "
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
+ " or less than 0"
)
self.worker_id = worker_id
self.data_center_id = data_center_id
self.sequence = 0
self.last_timestamp = -1
def generate_id(self):
"""Generate a snowflake id."""
timestamp = self._current_millis()
if timestamp < self.last_timestamp:
raise ValueError("Clock moved backwards. Refusing to generate ID.")
if timestamp == self.last_timestamp:
# Within the same millisecond, increment the sequence number
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
if self.sequence == 0:
# Sequence exceeds its bit range, wait until the next millisecond
timestamp = self._wait_next_millis(self.last_timestamp)
else:
# New millisecond, reset the sequence number
self.sequence = 0
# Update the last timestamp
self.last_timestamp = timestamp
# Generate and return the final ID
return (
(
(timestamp - SnowflakeId.EPOCH)
<< (
SnowflakeId.WORKER_ID_BITS
+ SnowflakeId.SEQUENCE_BITS
+ SnowflakeId.DATA_CENTER_ID_BITS
)
)
| (
self.data_center_id
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
)
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
| self.sequence
)
def _current_millis(self):
return round(time.time() * 1000)
def _wait_next_millis(self, last_timestamp):
timestamp = self._current_millis()
while timestamp <= last_timestamp:
timestamp = self._current_millis()
return timestamp

View File

@ -4,10 +4,10 @@ import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Set, cast from typing import Any, Dict, Optional, Set, cast
from ..aesprotocol import TPLinkAes
from ..credentials import Credentials from ..credentials import Credentials
from ..exceptions import AuthenticationException from ..exceptions import AuthenticationException
from ..smartdevice import SmartDevice from ..smartdevice import SmartDevice
from ..smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -26,7 +26,7 @@ class TapoDevice(SmartDevice):
super().__init__(host, port=port, credentials=credentials, timeout=timeout) super().__init__(host, port=port, credentials=credentials, timeout=timeout)
self._state_information: Dict[str, Any] = {} self._state_information: Dict[str, Any] = {}
self._discovery_info: Optional[Dict[str, Any]] = None self._discovery_info: Optional[Dict[str, Any]] = None
self.protocol = TPLinkAes(host, credentials=credentials, timeout=timeout) self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout)
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""Update the device.""" """Update the device."""

View File

@ -2,27 +2,45 @@ import asyncio
import glob import glob
import json import json
import os import os
from dataclasses import dataclass
from json import dumps as json_dumps
from os.path import basename from os.path import basename
from pathlib import Path, PurePath from pathlib import Path, PurePath
from typing import Dict from typing import Dict, Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
from kasa import ( from kasa import (
Credentials,
Discover, Discover,
SmartBulb, SmartBulb,
SmartDimmer, SmartDimmer,
SmartLightStrip, SmartLightStrip,
SmartPlug, SmartPlug,
SmartStrip, SmartStrip,
TPLinkSmartHomeProtocol,
) )
from kasa.tapo import TapoDevice, TapoPlug
from .newfakes import FakeTransportProtocol from .newfakes import FakeSmartProtocol, FakeTransportProtocol
SUPPORTED_DEVICES = glob.glob( SUPPORTED_IOT_DEVICES = [
(device, "IOT")
for device in glob.glob(
os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json" os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json"
) )
]
SUPPORTED_SMART_DEVICES = [
(device, "SMART")
for device in glob.glob(
os.path.dirname(os.path.abspath(__file__)) + "/fixtures/smart/*.json"
)
]
SUPPORTED_DEVICES = SUPPORTED_IOT_DEVICES + SUPPORTED_SMART_DEVICES
LIGHT_STRIPS = {"KL400", "KL430", "KL420"} LIGHT_STRIPS = {"KL400", "KL430", "KL420"}
@ -55,43 +73,59 @@ PLUGS = {
"KP401", "KP401",
"KS200M", "KS200M",
} }
STRIPS = {"HS107", "HS300", "KP303", "KP200", "KP400", "EP40"} STRIPS = {"HS107", "HS300", "KP303", "KP200", "KP400", "EP40"}
DIMMERS = {"ES20M", "HS220", "KS220M", "KS230", "KP405"} DIMMERS = {"ES20M", "HS220", "KS220M", "KS230", "KP405"}
DIMMABLE = {*BULBS, *DIMMERS} DIMMABLE = {*BULBS, *DIMMERS}
WITH_EMETER = {"HS110", "HS300", "KP115", "KP125", *BULBS} WITH_EMETER = {"HS110", "HS300", "KP115", "KP125", *BULBS}
ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS) ALL_DEVICES_IOT = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
PLUGS_SMART = {"P110"}
ALL_DEVICES_SMART = PLUGS_SMART
ALL_DEVICES = ALL_DEVICES_IOT.union(ALL_DEVICES_SMART)
IP_MODEL_CACHE: Dict[str, str] = {} IP_MODEL_CACHE: Dict[str, str] = {}
def filter_model(desc, filter): def idgenerator(paramtuple):
filtered = list() return basename(paramtuple[0]) + (
for dev in SUPPORTED_DEVICES: "" if paramtuple[1] == "IOT" else "-" + paramtuple[1]
for filt in filter: )
if filt in basename(dev):
filtered.append(dev)
filtered_basenames = [basename(f) for f in filtered]
def filter_model(desc, model_filter, protocol_filter=None):
if not protocol_filter:
protocol_filter = {"IOT"}
filtered = list()
for file, protocol in SUPPORTED_DEVICES:
if protocol in protocol_filter:
file_model = basename(file).split("_")[0]
for model in model_filter:
if model in file_model:
filtered.append((file, protocol))
filtered_basenames = [basename(f) + "-" + p for f, p in filtered]
print(f"{desc}: {filtered_basenames}") print(f"{desc}: {filtered_basenames}")
return filtered return filtered
def parametrize(desc, devices, ids=None): def parametrize(desc, devices, protocol_filter=None, ids=None):
return pytest.mark.parametrize( return pytest.mark.parametrize(
"dev", filter_model(desc, devices), indirect=True, ids=ids "dev", filter_model(desc, devices, protocol_filter), indirect=True, ids=ids
) )
has_emeter = parametrize("has emeter", WITH_EMETER) has_emeter = parametrize("has emeter", WITH_EMETER)
no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER) no_emeter = parametrize("no emeter", ALL_DEVICES_IOT - WITH_EMETER)
bulb = parametrize("bulbs", BULBS, ids=basename) bulb = parametrize("bulbs", BULBS, ids=idgenerator)
plug = parametrize("plugs", PLUGS, ids=basename) plug = parametrize("plugs", PLUGS, ids=idgenerator)
strip = parametrize("strips", STRIPS, ids=basename) strip = parametrize("strips", STRIPS, ids=idgenerator)
dimmer = parametrize("dimmers", DIMMERS, ids=basename) dimmer = parametrize("dimmers", DIMMERS, ids=idgenerator)
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename) lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=idgenerator)
# bulb types # bulb types
dimmable = parametrize("dimmable", DIMMABLE) dimmable = parametrize("dimmable", DIMMABLE)
@ -101,6 +135,58 @@ non_variable_temp = parametrize("non-variable color temp", BULBS - VARIABLE_TEMP
color_bulb = parametrize("color bulbs", COLOR_BULBS) color_bulb = parametrize("color bulbs", COLOR_BULBS)
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS) non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)
plug_smart = parametrize(
"plug devices smart", PLUGS_SMART, protocol_filter={"SMART"}, ids=idgenerator
)
device_smart = parametrize(
"devices smart", ALL_DEVICES_SMART, protocol_filter={"SMART"}, ids=idgenerator
)
device_iot = parametrize(
"devices iot", ALL_DEVICES_IOT, protocol_filter={"IOT"}, ids=idgenerator
)
def get_fixture_data():
"""Return raw discovery file contents as JSON. Used for discovery tests."""
fixture_data = {}
for file, protocol in SUPPORTED_DEVICES:
p = Path(file)
if not p.is_absolute():
folder = Path(__file__).parent / "fixtures"
if protocol == "SMART":
folder = folder / "smart"
p = folder / file
with open(p) as f:
fixture_data[basename(p)] = json.load(f)
return fixture_data
FIXTURE_DATA = get_fixture_data()
def filter_fixtures(desc, root_filter):
filtered = {}
for key, val in FIXTURE_DATA.items():
if root_filter in val:
filtered[key] = val
print(f"{desc}: {filtered.keys()}")
return filtered
def parametrize_discovery(desc, root_key):
filtered_fixtures = filter_fixtures(desc, root_key)
return pytest.mark.parametrize(
"discovery_data",
filtered_fixtures.values(),
indirect=True,
ids=filtered_fixtures.keys(),
)
new_discovery = parametrize_discovery("new discovery", "discovery_result")
def check_categories(): def check_categories():
"""Check that every fixture file is categorized.""" """Check that every fixture file is categorized."""
@ -110,15 +196,15 @@ def check_categories():
+ plug.args[1] + plug.args[1]
+ bulb.args[1] + bulb.args[1]
+ lightstrip.args[1] + lightstrip.args[1]
+ plug_smart.args[1]
) )
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures) diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
if diff: if diff:
for file in diff: for file, protocol in diff:
print( print(
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)" f"No category for file {file} protocol {protocol}, add to the corresponding set (BULBS, PLUGS, ..)"
% file
) )
raise Exception("Missing category for %s" % diff) raise Exception(f"Missing category for {diff}")
check_categories() check_categories()
@ -134,7 +220,12 @@ async def handle_turn_on(dev, turn_on):
await dev.turn_off() await dev.turn_off()
def device_for_file(model): def device_for_file(model, protocol):
if protocol == "SMART":
for d in PLUGS_SMART:
if d in model:
return TapoPlug
else:
for d in STRIPS: for d in STRIPS:
if d in model: if d in model:
return SmartStrip return SmartStrip
@ -170,11 +261,14 @@ async def _discover_update_and_close(ip):
return await _update_and_close(d) return await _update_and_close(d)
async def get_device_for_file(file): async def get_device_for_file(file, protocol):
# if the wanted file is not an absolute path, prepend the fixtures directory # if the wanted file is not an absolute path, prepend the fixtures directory
p = Path(file) p = Path(file)
if not p.is_absolute(): if not p.is_absolute():
p = Path(__file__).parent / "fixtures" / file folder = Path(__file__).parent / "fixtures"
if protocol == "SMART":
folder = folder / "smart"
p = folder / file
def load_file(): def load_file():
with open(p) as f: with open(p) as f:
@ -184,7 +278,11 @@ async def get_device_for_file(file):
sysinfo = await loop.run_in_executor(None, load_file) sysinfo = await loop.run_in_executor(None, load_file)
model = basename(file) model = basename(file)
d = device_for_file(model)(host="127.0.0.123") d = device_for_file(model, protocol)(host="127.0.0.123")
if protocol == "SMART":
d.protocol = FakeSmartProtocol(sysinfo)
d.credentials = Credentials("", "")
else:
d.protocol = FakeTransportProtocol(sysinfo) d.protocol = FakeTransportProtocol(sysinfo)
await _update_and_close(d) await _update_and_close(d)
return d return d
@ -197,7 +295,7 @@ async def dev(request):
Provides a device (given --ip) or parametrized fixture for the supported devices. Provides a device (given --ip) or parametrized fixture for the supported devices.
The initial update is called automatically before returning the device. The initial update is called automatically before returning the device.
""" """
file = request.param file, protocol = request.param
ip = request.config.getoption("--ip") ip = request.config.getoption("--ip")
if ip: if ip:
@ -210,19 +308,62 @@ async def dev(request):
pytest.skip(f"skipping file {file}") pytest.skip(f"skipping file {file}")
return d if d else await _discover_update_and_close(ip) return d if d else await _discover_update_and_close(ip)
return await get_device_for_file(file) return await get_device_for_file(file, protocol)
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session") @pytest.fixture
def discovery_mock(discovery_data, mocker):
@dataclass
class _DiscoveryMock:
ip: str
default_port: int
discovery_data: dict
port_override: Optional[int] = None
if "result" in discovery_data:
datagram = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data)
else:
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data)
def mock_discover(self):
port = (
dm.port_override
if dm.port_override and dm.default_port != 20002
else dm.default_port
)
self.datagram_received(
datagram,
(dm.ip, port),
)
mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)
mocker.patch(
"socket.getaddrinfo",
side_effect=lambda *_, **__: [(None, None, None, None, (dm.ip, 0))],
)
yield dm
@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session")
def discovery_data(request): def discovery_data(request):
"""Return raw discovery file contents as JSON. Used for discovery tests.""" """Return raw discovery file contents as JSON. Used for discovery tests."""
file = request.param fixture_data = request.param
p = Path(file) if "discovery_result" in fixture_data:
if not p.is_absolute(): return {"result": fixture_data["discovery_result"]}
p = Path(__file__).parent / "fixtures" / file else:
return {"system": {"get_sysinfo": fixture_data["system"]["get_sysinfo"]}}
with open(p) as f:
return json.load(f) @pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session")
def all_fixture_data(request):
"""Return raw fixture file contents as JSON. Used for discovery tests."""
fixture_data = request.param
return fixture_data
def pytest_addoption(parser): def pytest_addoption(parser):

View File

@ -0,0 +1,180 @@
{
"component_nego": {
"component_list": [
{
"id": "device",
"ver_code": 2
},
{
"id": "firmware",
"ver_code": 2
},
{
"id": "quick_setup",
"ver_code": 3
},
{
"id": "time",
"ver_code": 1
},
{
"id": "wireless",
"ver_code": 1
},
{
"id": "schedule",
"ver_code": 2
},
{
"id": "countdown",
"ver_code": 2
},
{
"id": "antitheft",
"ver_code": 1
},
{
"id": "account",
"ver_code": 1
},
{
"id": "synchronize",
"ver_code": 1
},
{
"id": "sunrise_sunset",
"ver_code": 1
},
{
"id": "led",
"ver_code": 1
},
{
"id": "cloud_connect",
"ver_code": 1
},
{
"id": "iot_cloud",
"ver_code": 1
},
{
"id": "device_local_time",
"ver_code": 1
},
{
"id": "default_states",
"ver_code": 1
},
{
"id": "auto_off",
"ver_code": 2
},
{
"id": "localSmart",
"ver_code": 1
},
{
"id": "energy_monitoring",
"ver_code": 2
},
{
"id": "power_protection",
"ver_code": 1
},
{
"id": "current_protection",
"ver_code": 1
}
]
},
"discovery_result": {
"device_id": "00000000000000000000000000000000",
"device_model": "P110(UK)",
"device_type": "SMART.TAPOPLUG",
"factory_default": false,
"ip": "127.0.0.123",
"is_support_iot_cloud": true,
"mac": "00-00-00-00-00-00",
"mgt_encrypt_schm": {
"encrypt_type": "KLAP",
"http_port": 80,
"is_support_https": false,
"lv": 2
},
"obd_src": "tplink",
"owner": "00000000000000000000000000000000"
},
"get_current_power": {
"current_power": 0
},
"get_device_info": {
"auto_off_remain_time": 0,
"auto_off_status": "off",
"avatar": "plug",
"default_states": {
"state": {},
"type": "last_states"
},
"device_id": "0000000000000000000000000000000000000000",
"device_on": true,
"fw_id": "00000000000000000000000000000000",
"fw_ver": "1.3.0 Build 230905 Rel.152200",
"has_set_location_info": true,
"hw_id": "00000000000000000000000000000000",
"hw_ver": "1.0",
"ip": "127.0.0.123",
"lang": "en_US",
"latitude": 0,
"longitude": 0,
"mac": "00-00-00-00-00-00",
"model": "P110",
"nickname": "VGFwaSBTbWFydCBQbHVnIDE=",
"oem_id": "00000000000000000000000000000000",
"on_time": 119335,
"overcurrent_status": "normal",
"overheated": false,
"power_protection_status": "normal",
"region": "Europe/London",
"rssi": -57,
"signal_level": 2,
"specs": "",
"ssid": "IyNNQVNLRUROQU1FIyM=",
"time_diff": 0,
"type": "SMART.TAPOPLUG"
},
"get_device_time": {
"region": "Europe/London",
"time_diff": 0,
"timestamp": 1701370224
},
"get_device_usage": {
"power_usage": {
"past30": 75,
"past7": 69,
"today": 0
},
"saved_power": {
"past30": 2029,
"past7": 1964,
"today": 1130
},
"time_usage": {
"past30": 2104,
"past7": 2033,
"today": 1130
}
},
"get_energy_usage": {
"current_power": 0,
"electricity_charge": [
0,
0,
0
],
"local_time": "2023-11-30 18:50:24",
"month_energy": 75,
"month_runtime": 2104,
"today_energy": 0,
"today_runtime": 1130
}
}

View File

@ -1,6 +1,7 @@
import copy import copy
import logging import logging
import re import re
from json import loads as json_loads
from voluptuous import ( from voluptuous import (
REMOVE_EXTRA, REMOVE_EXTRA,
@ -13,7 +14,8 @@ from voluptuous import (
Schema, Schema,
) )
from ..protocol import TPLinkSmartHomeProtocol from ..protocol import BaseTransport, TPLinkSmartHomeProtocol
from ..smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -285,6 +287,41 @@ TIME_MODULE = {
} }
class FakeSmartProtocol(SmartProtocol):
def __init__(self, info):
super().__init__("127.0.0.123", transport=FakeSmartTransport(info))
class FakeSmartTransport(BaseTransport):
def __init__(self, info):
self.info = info
@property
def needs_handshake(self) -> bool:
return False
@property
def needs_login(self) -> bool:
return False
async def login(self, request: str) -> None:
pass
async def handshake(self) -> None:
pass
async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]
if method == "component_nego" or method[:4] == "get_":
return self.info[method]
elif method[:4] == "set_":
_LOGGER.debug("Call %s not implemented, doing nothing", method)
async def close(self) -> None:
pass
class FakeTransportProtocol(TPLinkSmartHomeProtocol): class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info): def __init__(self, info):
self.discovery_data = info self.discovery_data = info

View File

@ -6,12 +6,15 @@ from asyncclick.testing import CliRunner
from kasa import SmartDevice, TPLinkSmartHomeProtocol from kasa import SmartDevice, TPLinkSmartHomeProtocol
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
from kasa.discover import Discover from kasa.discover import Discover
from kasa.smartprotocol import SmartProtocol
from .conftest import handle_turn_on, turn_on from .conftest import device_iot, handle_turn_on, new_discovery, turn_on
from .newfakes import FakeTransportProtocol from .newfakes import FakeSmartProtocol, FakeTransportProtocol
@device_iot
async def test_sysinfo(dev): async def test_sysinfo(dev):
runner = CliRunner() runner = CliRunner()
res = await runner.invoke(sysinfo, obj=dev) res = await runner.invoke(sysinfo, obj=dev)
@ -19,6 +22,7 @@ async def test_sysinfo(dev):
assert dev.alias in res.output assert dev.alias in res.output
@device_iot
@turn_on @turn_on
async def test_state(dev, turn_on): async def test_state(dev, turn_on):
await handle_turn_on(dev, turn_on) await handle_turn_on(dev, turn_on)
@ -32,6 +36,7 @@ async def test_state(dev, turn_on):
assert "Device state: False" in res.output assert "Device state: False" in res.output
@device_iot
@turn_on @turn_on
async def test_toggle(dev, turn_on, mocker): async def test_toggle(dev, turn_on, mocker):
await handle_turn_on(dev, turn_on) await handle_turn_on(dev, turn_on)
@ -44,6 +49,7 @@ async def test_toggle(dev, turn_on, mocker):
assert dev.is_on assert dev.is_on
@device_iot
async def test_alias(dev): async def test_alias(dev):
runner = CliRunner() runner = CliRunner()
@ -62,6 +68,7 @@ async def test_alias(dev):
await dev.set_alias(old_alias) await dev.set_alias(old_alias)
@device_iot
async def test_raw_command(dev): async def test_raw_command(dev):
runner = CliRunner() runner = CliRunner()
res = await runner.invoke(raw_command, ["system", "get_sysinfo"], obj=dev) res = await runner.invoke(raw_command, ["system", "get_sysinfo"], obj=dev)
@ -74,6 +81,7 @@ async def test_raw_command(dev):
assert "Usage" in res.output assert "Usage" in res.output
@device_iot
async def test_emeter(dev: SmartDevice, mocker): async def test_emeter(dev: SmartDevice, mocker):
runner = CliRunner() runner = CliRunner()
@ -99,6 +107,7 @@ async def test_emeter(dev: SmartDevice, mocker):
daily.assert_called_with(year=1900, month=12) daily.assert_called_with(year=1900, month=12)
@device_iot
async def test_brightness(dev): async def test_brightness(dev):
runner = CliRunner() runner = CliRunner()
res = await runner.invoke(brightness, obj=dev) res = await runner.invoke(brightness, obj=dev)
@ -116,6 +125,7 @@ async def test_brightness(dev):
assert "Brightness: 12" in res.output assert "Brightness: 12" in res.output
@device_iot
async def test_json_output(dev: SmartDevice, mocker): async def test_json_output(dev: SmartDevice, mocker):
"""Test that the json output produces correct output.""" """Test that the json output produces correct output."""
mocker.patch("kasa.Discover.discover", return_value=[dev]) mocker.patch("kasa.Discover.discover", return_value=[dev])
@ -125,13 +135,9 @@ async def test_json_output(dev: SmartDevice, mocker):
assert json.loads(res.output) == dev.internal_state assert json.loads(res.output) == dev.internal_state
async def test_credentials(discovery_data: dict, mocker): @new_discovery
async def test_credentials(discovery_mock, mocker):
"""Test credentials are passed correctly from cli to device.""" """Test credentials are passed correctly from cli to device."""
# As this is testing the device constructor need to explicitly wire in
# the FakeTransportProtocol
ftp = FakeTransportProtocol(discovery_data)
mocker.patch.object(TPLinkSmartHomeProtocol, "query", ftp.query)
# Patch state to echo username and password # Patch state to echo username and password
pass_dev = click.make_pass_decorator(SmartDevice) pass_dev = click.make_pass_decorator(SmartDevice)
@ -143,18 +149,15 @@ async def test_credentials(discovery_data: dict, mocker):
) )
mocker.patch("kasa.cli.state", new=_state) mocker.patch("kasa.cli.state", new=_state)
cli_device_type = Discover._get_device_class(discovery_data)( for subclass in DEVICE_TYPE_TO_CLASS.values():
"any" mocker.patch.object(subclass, "update")
).device_type.value
runner = CliRunner() runner = CliRunner()
res = await runner.invoke( res = await runner.invoke(
cli, cli,
[ [
"--host", "--host",
"127.0.0.1", "127.0.0.123",
"--type",
cli_device_type,
"--username", "--username",
"foo", "foo",
"--password", "--password",
@ -162,9 +165,11 @@ async def test_credentials(discovery_data: dict, mocker):
], ],
) )
assert res.exit_code == 0 assert res.exit_code == 0
assert res.output == "Username:foo Password:bar\n"
assert "Username:foo Password:bar\n" in res.output
@device_iot
async def test_without_device_type(discovery_data: dict, dev, mocker): async def test_without_device_type(discovery_data: dict, dev, mocker):
"""Test connecting without the device type.""" """Test connecting without the device type."""
runner = CliRunner() runner = CliRunner()

View File

@ -5,7 +5,9 @@ from typing import Type
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import ( from kasa import (
Credentials,
DeviceType, DeviceType,
Discover,
SmartBulb, SmartBulb,
SmartDevice, SmartDevice,
SmartDeviceException, SmartDeviceException,
@ -13,8 +15,13 @@ from kasa import (
SmartLightStrip, SmartLightStrip,
SmartPlug, SmartPlug,
) )
from kasa.device_factory import connect from kasa.device_factory import (
from kasa.klapprotocol import TPLinkKlap DEVICE_TYPE_TO_CLASS,
connect,
get_protocol_from_connection_name,
)
from kasa.discover import DiscoveryResult
from kasa.iotprotocol import IotProtocol
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
@ -22,8 +29,12 @@ from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
async def test_connect(discovery_data: dict, mocker, custom_port): async def test_connect(discovery_data: dict, mocker, custom_port):
"""Make sure that connect returns an initialized SmartDevice instance.""" """Make sure that connect returns an initialized SmartDevice instance."""
host = "127.0.0.1" host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
dev = await connect(host, port=custom_port)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port) dev = await connect(host, port=custom_port)
assert issubclass(dev.__class__, SmartDevice) assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == 9999 assert dev.port == custom_port or dev.port == 9999
@ -49,8 +60,12 @@ async def test_connect_passed_device_type(
): ):
"""Make sure that connect with a passed device type.""" """Make sure that connect with a passed device type."""
host = "127.0.0.1" host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
dev = await connect(host, port=custom_port)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port, device_type=device_type) dev = await connect(host, port=custom_port, device_type=device_type)
assert isinstance(dev, klass) assert isinstance(dev, klass)
assert dev.port == custom_port or dev.port == 9999 assert dev.port == custom_port or dev.port == 9999
@ -70,32 +85,52 @@ async def test_connect_logs_connect_time(
): ):
"""Test that the connect time is logged when debug logging is enabled.""" """Test that the connect time is logged when debug logging is enabled."""
host = "127.0.0.1" host = "127.0.0.1"
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
await connect(host)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
logging.getLogger("kasa").setLevel(logging.DEBUG) logging.getLogger("kasa").setLevel(logging.DEBUG)
await connect(host) await connect(host)
assert "seconds to connect" in caplog.text assert "seconds to connect" in caplog.text
@pytest.mark.parametrize("device_type", [DeviceType.Plug, None])
@pytest.mark.parametrize(
("protocol_in", "protocol_result"),
(
(None, TPLinkSmartHomeProtocol),
(TPLinkKlap, TPLinkKlap),
(TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol),
),
)
async def test_connect_pass_protocol( async def test_connect_pass_protocol(
discovery_data: dict, all_fixture_data: dict,
mocker, mocker,
device_type: DeviceType,
protocol_in: Type[TPLinkProtocol],
protocol_result: Type[TPLinkProtocol],
): ):
"""Test that if the protocol is passed in it's gets set correctly.""" """Test that if the protocol is passed in it's gets set correctly."""
host = "127.0.0.1" if "discovery_result" in all_fixture_data:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) discovery_info = {"result": all_fixture_data["discovery_result"]}
mocker.patch("kasa.TPLinkKlap.query", return_value=discovery_data) device_class = Discover._get_device_class(discovery_info)
else:
device_class = Discover._get_device_class(all_fixture_data)
dev = await connect(host, device_type=device_type, protocol_class=protocol_in) device_type = list(DEVICE_TYPE_TO_CLASS.keys())[
assert isinstance(dev.protocol, protocol_result) list(DEVICE_TYPE_TO_CLASS.values()).index(device_class)
]
host = "127.0.0.1"
if "discovery_result" in all_fixture_data:
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
dr = DiscoveryResult(**discovery_info["result"])
connection_name = (
dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type
)
protocol_class = get_protocol_from_connection_name(
connection_name, host
).__class__
else:
mocker.patch(
"kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data
)
protocol_class = TPLinkSmartHomeProtocol
dev = await connect(
host,
device_type=device_type,
protocol_class=protocol_class,
credentials=Credentials("", ""),
)
assert isinstance(dev.protocol, protocol_class)

View File

@ -17,6 +17,27 @@ from kasa.exceptions import AuthenticationException, UnsupportedDeviceException
from .conftest import bulb, dimmer, lightstrip, plug, strip from .conftest import bulb, dimmer, lightstrip, plug, strip
UNSUPPORTED = {
"result": {
"device_id": "xx",
"owner": "xx",
"device_type": "SMART.TAPOXMASTREE",
"device_model": "P110(EU)",
"ip": "127.0.0.1",
"mac": "48-22xxx",
"is_support_iot_cloud": True,
"obd_src": "tplink",
"factory_default": False,
"mgt_encrypt_schm": {
"is_support_https": False,
"encrypt_type": "AES",
"http_port": 80,
"lv": 2,
},
},
"error_code": 0,
}
@plug @plug
async def test_type_detection_plug(dev: SmartDevice): async def test_type_detection_plug(dev: SmartDevice):
@ -62,76 +83,40 @@ async def test_type_unknown():
@pytest.mark.parametrize("custom_port", [123, None]) @pytest.mark.parametrize("custom_port", [123, None])
async def test_discover_single(discovery_data: dict, mocker, custom_port): # @pytest.mark.parametrize("discovery_mock", [("127.0.0.1",123), ("127.0.0.1",None)], indirect=True)
async def test_discover_single(discovery_mock, custom_port, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance.""" """Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1" host = "127.0.0.1"
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}} discovery_mock.ip = host
query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info) discovery_mock.port_override = custom_port
update_mock = mocker.patch.object(SmartStrip, "update")
def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:],
(host, custom_port or 9999),
)
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
x = await Discover.discover_single(host, port=custom_port) x = await Discover.discover_single(host, port=custom_port)
assert issubclass(x.__class__, SmartDevice) assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None assert x._discovery_info is not None
assert x.port == custom_port or x.port == 9999 assert x.port == custom_port or x.port == discovery_mock.default_port
assert (query_mock.call_count > 0) == isinstance(x, SmartStrip) assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
async def test_discover_single_hostname(discovery_data: dict, mocker): async def test_discover_single_hostname(discovery_mock, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance.""" """Make sure that discover_single returns an initialized SmartDevice instance."""
host = "foobar" host = "foobar"
ip = "127.0.0.1" ip = "127.0.0.1"
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)
def mock_discover(self): discovery_mock.ip = ip
self.datagram_received( update_mock = mocker.patch.object(SmartStrip, "update")
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:],
(ip, 9999),
)
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))])
x = await Discover.discover_single(host) x = await Discover.discover_single(host)
assert issubclass(x.__class__, SmartDevice) assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None assert x._discovery_info is not None
assert x.host == host assert x.host == host
assert (query_mock.call_count > 0) == isinstance(x, SmartStrip) assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror()) mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
x = await Discover.discover_single(host) x = await Discover.discover_single(host)
UNSUPPORTED = {
"result": {
"device_id": "xx",
"owner": "xx",
"device_type": "SMART.TAPOXMASTREE",
"device_model": "P110(EU)",
"ip": "127.0.0.1",
"mac": "48-22xxx",
"is_support_iot_cloud": True,
"obd_src": "tplink",
"factory_default": False,
"mgt_encrypt_schm": {
"is_support_https": False,
"encrypt_type": "AES",
"http_port": 80,
"lv": 2,
},
},
"error_code": 0,
}
async def test_discover_single_unsupported(mocker): async def test_discover_single_unsupported(mocker):
"""Make sure that discover_single handles unsupported devices correctly.""" """Make sure that discover_single handles unsupported devices correctly."""
host = "127.0.0.1" host = "127.0.0.1"
@ -201,14 +186,17 @@ async def test_discover_send(mocker):
async def test_discover_datagram_received(mocker, discovery_data): async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices.""" """Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol() proto = _DiscoverProtocol()
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
mocker.patch("kasa.discover.json_loads", return_value=info)
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
addr = "127.0.0.1" addr = "127.0.0.1"
proto.datagram_received("<placeholder data>", (addr, 9999)) port = 20002 if "result" in discovery_data else 9999
mocker.patch("kasa.discover.json_loads", return_value=discovery_data)
proto.datagram_received("<placeholder data>", (addr, port))
addr2 = "127.0.0.2" addr2 = "127.0.0.2"
mocker.patch("kasa.discover.json_loads", return_value=UNSUPPORTED)
proto.datagram_received("<placeholder data>", (addr2, 20002)) proto.datagram_received("<placeholder data>", (addr2, 20002))
# Check that device in discovered_devices is initialized correctly # Check that device in discovered_devices is initialized correctly

View File

@ -10,9 +10,14 @@ from contextlib import nullcontext as does_not_raise
import httpx import httpx
import pytest import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials from ..credentials import Credentials
from ..exceptions import AuthenticationException, SmartDeviceException from ..exceptions import AuthenticationException, SmartDeviceException
from ..klapprotocol import KlapEncryptionSession, TPLinkKlap, _sha256 from ..iotprotocol import IotProtocol
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
from ..smartprotocol import SmartProtocol
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
class _mock_response: class _mock_response:
@ -21,67 +26,92 @@ class _mock_response:
self.content = content self.content = content
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5]) @pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries(mocker, retry_count): async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class):
host = "127.0.0.1"
conn = mocker.patch.object( conn = mocker.patch.object(
TPLinkKlap, "client_post", side_effect=Exception("dummy exception") transport_class, "client_post", side_effect=Exception("dummy exception")
) )
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkKlap("127.0.0.1").query({}, retry_count=retry_count) await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=retry_count
)
assert conn.call_count == retry_count + 1 assert conn.call_count == retry_count + 1
async def test_protocol_no_retry_on_connection_error(mocker): @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
async def test_protocol_no_retry_on_connection_error(
mocker, protocol_class, transport_class
):
host = "127.0.0.1"
conn = mocker.patch.object( conn = mocker.patch.object(
TPLinkKlap, transport_class,
"client_post", "client_post",
side_effect=httpx.ConnectError("foo"), side_effect=httpx.ConnectError("foo"),
) )
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkKlap("127.0.0.1").query({}, retry_count=5) await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=5
)
assert conn.call_count == 1 assert conn.call_count == 1
async def test_protocol_retry_recoverable_error(mocker): @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
async def test_protocol_retry_recoverable_error(
mocker, protocol_class, transport_class
):
host = "127.0.0.1"
conn = mocker.patch.object( conn = mocker.patch.object(
TPLinkKlap, transport_class,
"client_post", "client_post",
side_effect=httpx.CloseError("foo"), side_effect=httpx.CloseError("foo"),
) )
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkKlap("127.0.0.1").query({}, retry_count=5) await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=5
)
assert conn.call_count == 6 assert conn.call_count == 6
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5]) @pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_reconnect(mocker, retry_count): async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
host = "127.0.0.1"
remaining = retry_count remaining = retry_count
mock_response = {"result": {"great": "success"}}
def _fail_one_less_than_retry_count(*_, **__): def _fail_one_less_than_retry_count(*_, **__):
nonlocal remaining, encryption_session nonlocal remaining
remaining -= 1 remaining -= 1
if remaining: if remaining:
raise Exception("Simulated post failure") raise Exception("Simulated post failure")
# Do the encrypt just before returning the value so the incrementing sequence number is correct
encrypted, seq = encryption_session.encrypt('{"great":"success"}')
return 200, encrypted
seed = secrets.token_bytes(16) return mock_response
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
protocol = TPLinkKlap("127.0.0.1")
protocol.handshake_done = True
protocol.session_expire_at = time.time() + 86400
protocol.encryption_session = encryption_session
mocker.patch.object( mocker.patch.object(
TPLinkKlap, "client_post", side_effect=_fail_one_less_than_retry_count transport_class, "needs_handshake", property(lambda self: False)
)
mocker.patch.object(transport_class, "needs_login", property(lambda self: False))
send_mock = mocker.patch.object(
transport_class,
"send",
side_effect=_fail_one_less_than_retry_count,
) )
response = await protocol.query({}, retry_count=retry_count) response = await protocol_class(host, transport=transport_class(host)).query(
assert response == {"great": "success"} DUMMY_QUERY, retry_count=retry_count
)
assert "result" in response or "great" in response
assert send_mock.call_count == retry_count
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG]) @pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
@ -96,14 +126,14 @@ async def test_protocol_logging(mocker, caplog, log_level):
return 200, encrypted return 200, encrypted
seed = secrets.token_bytes(16) seed = secrets.token_bytes(16)
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
protocol = TPLinkKlap("127.0.0.1") protocol = IotProtocol("127.0.0.1")
protocol.handshake_done = True protocol._transport._handshake_done = True
protocol.session_expire_at = time.time() + 86400 protocol._transport._session_expire_at = time.time() + 86400
protocol.encryption_session = encryption_session protocol._transport._encryption_session = encryption_session
mocker.patch.object(TPLinkKlap, "client_post", side_effect=_return_encrypted) mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted)
response = await protocol.query({}) response = await protocol.query({})
assert response == {"great": "success"} assert response == {"great": "success"}
@ -117,7 +147,7 @@ def test_encrypt():
d = json.dumps({"foo": 1, "bar": 2}) d = json.dumps({"foo": 1, "bar": 2})
seed = secrets.token_bytes(16) seed = secrets.token_bytes(16)
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
encrypted, seq = encryption_session.encrypt(d) encrypted, seq = encryption_session.encrypt(d)
@ -129,7 +159,7 @@ def test_encrypt_unicode():
d = "{'snowman': '\u2603'}" d = "{'snowman': '\u2603'}"
seed = secrets.token_bytes(16) seed = secrets.token_bytes(16)
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
encrypted, seq = encryption_session.encrypt(d) encrypted, seq = encryption_session.encrypt(d)
@ -145,7 +175,10 @@ def test_encrypt_unicode():
(Credentials("foo", "bar"), does_not_raise()), (Credentials("foo", "bar"), does_not_raise()),
(Credentials("", ""), does_not_raise()), (Credentials("", ""), does_not_raise()),
( (
Credentials(TPLinkKlap.KASA_SETUP_EMAIL, TPLinkKlap.KASA_SETUP_PASSWORD), Credentials(
KlapTransport.KASA_SETUP_EMAIL,
KlapTransport.KASA_SETUP_PASSWORD,
),
does_not_raise(), does_not_raise(),
), ),
( (
@ -167,21 +200,21 @@ async def test_handshake1(mocker, device_credentials, expectation):
client_seed = None client_seed = None
server_seed = secrets.token_bytes(16) server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar") client_credentials = Credentials("foo", "bar")
device_auth_hash = TPLinkKlap.generate_auth_hash(device_credentials) device_auth_hash = KlapTransport.generate_auth_hash(device_credentials)
mocker.patch.object( mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake1_response httpx.AsyncClient, "post", side_effect=_return_handshake1_response
) )
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
protocol.http_client = httpx.AsyncClient() protocol._transport.http_client = httpx.AsyncClient()
with expectation: with expectation:
( (
local_seed, local_seed,
device_remote_seed, device_remote_seed,
auth_hash, auth_hash,
) = await protocol.perform_handshake1() ) = await protocol._transport.perform_handshake1()
assert local_seed == client_seed assert local_seed == client_seed
assert device_remote_seed == server_seed assert device_remote_seed == server_seed
@ -204,23 +237,23 @@ async def test_handshake(mocker):
client_seed = None client_seed = None
server_seed = secrets.token_bytes(16) server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar") client_credentials = Credentials("foo", "bar")
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials) device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
mocker.patch.object( mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake_response httpx.AsyncClient, "post", side_effect=_return_handshake_response
) )
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
protocol.http_client = httpx.AsyncClient() protocol._transport.http_client = httpx.AsyncClient()
response_status = 200 response_status = 200
await protocol.perform_handshake() await protocol._transport.perform_handshake()
assert protocol.handshake_done is True assert protocol._transport._handshake_done is True
response_status = 403 response_status = 403
with pytest.raises(AuthenticationException): with pytest.raises(AuthenticationException):
await protocol.perform_handshake() await protocol._transport.perform_handshake()
assert protocol.handshake_done is False assert protocol._transport._handshake_done is False
await protocol.close() await protocol.close()
@ -237,9 +270,9 @@ async def test_query(mocker):
return _mock_response(200, b"") return _mock_response(200, b"")
elif url == "http://127.0.0.1/app/request": elif url == "http://127.0.0.1/app/request":
encryption_session = KlapEncryptionSession( encryption_session = KlapEncryptionSession(
protocol.encryption_session.local_seed, protocol._transport._encryption_session.local_seed,
protocol.encryption_session.remote_seed, protocol._transport._encryption_session.remote_seed,
protocol.encryption_session.user_hash, protocol._transport._encryption_session.user_hash,
) )
seq = params.get("seq") seq = params.get("seq")
encryption_session._seq = seq - 1 encryption_session._seq = seq - 1
@ -252,11 +285,11 @@ async def test_query(mocker):
seq = None seq = None
server_seed = secrets.token_bytes(16) server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar") client_credentials = Credentials("foo", "bar")
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials) device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
for _ in range(10): for _ in range(10):
resp = await protocol.query({}) resp = await protocol.query({})
@ -296,11 +329,11 @@ async def test_authentication_failures(mocker, response_status, expectation):
server_seed = secrets.token_bytes(16) server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar") client_credentials = Credentials("foo", "bar")
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials) device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
with expectation: with expectation:
await protocol.query({}) await protocol.query({})

View File

@ -1,6 +1,6 @@
from kasa import DeviceType from kasa import DeviceType
from .conftest import plug from .conftest import plug, plug_smart
from .newfakes import PLUG_SCHEMA from .newfakes import PLUG_SCHEMA
@ -28,3 +28,14 @@ async def test_led(dev):
assert dev.led assert dev.led
await dev.set_led(original) await dev.set_led(original)
@plug_smart
async def test_plug_device_info(dev):
assert dev._info is not None
# PLUG_SCHEMA(dev.sys_info)
assert dev.model is not None
assert dev.device_type == DeviceType.Plug or dev.device_type == DeviceType.Strip
# assert dev.is_plug or dev.is_strip

View File

@ -9,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file
def test_bulb_examples(mocker): def test_bulb_examples(mocker):
"""Use KL130 (bulb with all features) to test the doctests.""" """Use KL130 (bulb with all features) to test the doctests."""
p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json")) p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json", "IOT"))
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p) mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
mocker.patch("kasa.smartbulb.SmartBulb.update") mocker.patch("kasa.smartbulb.SmartBulb.update")
res = xdoctest.doctest_module("kasa.smartbulb", "all") res = xdoctest.doctest_module("kasa.smartbulb", "all")
@ -18,7 +18,7 @@ def test_bulb_examples(mocker):
def test_smartdevice_examples(mocker): def test_smartdevice_examples(mocker):
"""Use HS110 for emeter examples.""" """Use HS110 for emeter examples."""
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json")) p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT"))
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p) mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
mocker.patch("kasa.smartdevice.SmartDevice.update") mocker.patch("kasa.smartdevice.SmartDevice.update")
res = xdoctest.doctest_module("kasa.smartdevice", "all") res = xdoctest.doctest_module("kasa.smartdevice", "all")
@ -27,7 +27,7 @@ def test_smartdevice_examples(mocker):
def test_plug_examples(mocker): def test_plug_examples(mocker):
"""Test plug examples.""" """Test plug examples."""
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json")) p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT"))
mocker.patch("kasa.smartplug.SmartPlug", return_value=p) mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
mocker.patch("kasa.smartplug.SmartPlug.update") mocker.patch("kasa.smartplug.SmartPlug.update")
res = xdoctest.doctest_module("kasa.smartplug", "all") res = xdoctest.doctest_module("kasa.smartplug", "all")
@ -36,7 +36,7 @@ def test_plug_examples(mocker):
def test_strip_examples(mocker): def test_strip_examples(mocker):
"""Test strip examples.""" """Test strip examples."""
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json")) p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT"))
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p) mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
mocker.patch("kasa.smartstrip.SmartStrip.update") mocker.patch("kasa.smartstrip.SmartStrip.update")
res = xdoctest.doctest_module("kasa.smartstrip", "all") res = xdoctest.doctest_module("kasa.smartstrip", "all")
@ -45,7 +45,7 @@ def test_strip_examples(mocker):
def test_dimmer_examples(mocker): def test_dimmer_examples(mocker):
"""Test dimmer examples.""" """Test dimmer examples."""
p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json")) p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json", "IOT"))
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p) mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
mocker.patch("kasa.smartdimmer.SmartDimmer.update") mocker.patch("kasa.smartdimmer.SmartDimmer.update")
res = xdoctest.doctest_module("kasa.smartdimmer", "all") res = xdoctest.doctest_module("kasa.smartdimmer", "all")
@ -54,7 +54,7 @@ def test_dimmer_examples(mocker):
def test_lightstrip_examples(mocker): def test_lightstrip_examples(mocker):
"""Test lightstrip examples.""" """Test lightstrip examples."""
p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json")) p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json", "IOT"))
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p) mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update") mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
res = xdoctest.doctest_module("kasa.smartlightstrip", "all") res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
@ -63,7 +63,7 @@ def test_lightstrip_examples(mocker):
def test_discovery_examples(mocker): def test_discovery_examples(mocker):
"""Test discovery examples.""" """Test discovery examples."""
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json")) p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT"))
mocker.patch("kasa.discover.Discover.discover", return_value=[p]) mocker.patch("kasa.discover.Discover.discover", return_value=[p])
res = xdoctest.doctest_module("kasa.discover", "all") res = xdoctest.doctest_module("kasa.discover", "all")

View File

@ -8,7 +8,7 @@ import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType from kasa.smartdevice import DeviceType
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
# List of all SmartXXX classes including the SmartDevice base class # List of all SmartXXX classes including the SmartDevice base class
@ -22,11 +22,13 @@ smart_device_classes = [
] ]
@device_iot
async def test_state_info(dev): async def test_state_info(dev):
assert isinstance(dev.state_information, dict) assert isinstance(dev.state_information, dict)
@pytest.mark.requires_dummy @pytest.mark.requires_dummy
@device_iot
async def test_invalid_connection(dev): async def test_invalid_connection(dev):
with patch.object( with patch.object(
FakeTransportProtocol, "query", side_effect=SmartDeviceException FakeTransportProtocol, "query", side_effect=SmartDeviceException
@ -58,12 +60,14 @@ async def test_initial_update_no_emeter(dev, mocker):
assert spy.call_count == 2 assert spy.call_count == 2
@device_iot
async def test_query_helper(dev): async def test_query_helper(dev):
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await dev._query_helper("test", "testcmd", {}) await dev._query_helper("test", "testcmd", {})
# TODO check for unwrapping? # TODO check for unwrapping?
@device_iot
@turn_on @turn_on
async def test_state(dev, turn_on): async def test_state(dev, turn_on):
await handle_turn_on(dev, turn_on) await handle_turn_on(dev, turn_on)
@ -90,6 +94,7 @@ async def test_state(dev, turn_on):
assert dev.is_off assert dev.is_off
@device_iot
async def test_alias(dev): async def test_alias(dev):
test_alias = "TEST1234" test_alias = "TEST1234"
original = dev.alias original = dev.alias
@ -104,6 +109,7 @@ async def test_alias(dev):
assert dev.alias == original assert dev.alias == original
@device_iot
@turn_on @turn_on
async def test_on_since(dev, turn_on): async def test_on_since(dev, turn_on):
await handle_turn_on(dev, turn_on) await handle_turn_on(dev, turn_on)
@ -116,30 +122,37 @@ async def test_on_since(dev, turn_on):
assert dev.on_since is None assert dev.on_since is None
@device_iot
async def test_time(dev): async def test_time(dev):
assert isinstance(await dev.get_time(), datetime) assert isinstance(await dev.get_time(), datetime)
@device_iot
async def test_timezone(dev): async def test_timezone(dev):
TZ_SCHEMA(await dev.get_timezone()) TZ_SCHEMA(await dev.get_timezone())
@device_iot
async def test_hw_info(dev): async def test_hw_info(dev):
PLUG_SCHEMA(dev.hw_info) PLUG_SCHEMA(dev.hw_info)
@device_iot
async def test_location(dev): async def test_location(dev):
PLUG_SCHEMA(dev.location) PLUG_SCHEMA(dev.location)
@device_iot
async def test_rssi(dev): async def test_rssi(dev):
PLUG_SCHEMA({"rssi": dev.rssi}) # wrapping for vol PLUG_SCHEMA({"rssi": dev.rssi}) # wrapping for vol
@device_iot
async def test_mac(dev): async def test_mac(dev):
PLUG_SCHEMA({"mac": dev.mac}) # wrapping for val PLUG_SCHEMA({"mac": dev.mac}) # wrapping for val
@device_iot
async def test_representation(dev): async def test_representation(dev):
import re import re
@ -147,6 +160,7 @@ async def test_representation(dev):
assert pattern.match(str(dev)) assert pattern.match(str(dev))
@device_iot
async def test_childrens(dev): async def test_childrens(dev):
"""Make sure that children property is exposed by every device.""" """Make sure that children property is exposed by every device."""
if dev.is_strip: if dev.is_strip:
@ -155,6 +169,7 @@ async def test_childrens(dev):
assert len(dev.children) == 0 assert len(dev.children) == 0
@device_iot
async def test_children(dev): async def test_children(dev):
"""Make sure that children property is exposed by every device.""" """Make sure that children property is exposed by every device."""
if dev.is_strip: if dev.is_strip:
@ -165,11 +180,13 @@ async def test_children(dev):
assert dev.has_children is False assert dev.has_children is False
@device_iot
async def test_internal_state(dev): async def test_internal_state(dev):
"""Make sure the internal state returns the last update results.""" """Make sure the internal state returns the last update results."""
assert dev.internal_state == dev._last_update assert dev.internal_state == dev._last_update
@device_iot
async def test_features(dev): async def test_features(dev):
"""Make sure features is always accessible.""" """Make sure features is always accessible."""
sysinfo = dev._last_update["system"]["get_sysinfo"] sysinfo = dev._last_update["system"]["get_sysinfo"]
@ -179,11 +196,13 @@ async def test_features(dev):
assert dev.features == set() assert dev.features == set()
@device_iot
async def test_max_device_response_size(dev): async def test_max_device_response_size(dev):
"""Make sure every device return has a set max response size.""" """Make sure every device return has a set max response size."""
assert dev.max_device_response_size > 0 assert dev.max_device_response_size > 0
@device_iot
async def test_estimated_response_sizes(dev): async def test_estimated_response_sizes(dev):
"""Make sure every module has an estimated response size set.""" """Make sure every module has an estimated response size set."""
for mod in dev.modules.values(): for mod in dev.modules.values():
@ -202,6 +221,7 @@ def test_device_class_ctors(device_class):
assert dev.credentials == credentials assert dev.credentials == credentials
@device_iot
async def test_modules_preserved(dev: SmartDevice): async def test_modules_preserved(dev: SmartDevice):
"""Make modules that are not being updated are preserved between updates.""" """Make modules that are not being updated are preserved between updates."""
dev._last_update["some_module_not_being_updated"] = "should_be_kept" dev._last_update["some_module_not_being_updated"] = "should_be_kept"
@ -237,6 +257,7 @@ async def test_create_thin_wrapper():
) )
@device_iot
async def test_modules_not_supported(dev: SmartDevice): async def test_modules_not_supported(dev: SmartDevice):
"""Test that unsupported modules do not break the device.""" """Test that unsupported modules do not break the device."""
for module in dev.modules.values(): for module in dev.modules.values():