Add klap support for TAPO protocol by splitting out Transports and Protocols (#557)

* Add support for TAPO/SMART KLAP and seperate transports from protocols

* Add tests and some review changes

* Update following review

* Updates following review
This commit is contained in:
sdb9696 2023-12-04 18:50:05 +00:00 committed by GitHub
parent 347cbfe3bd
commit 4a00199506
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1604 additions and 887 deletions

View File

@ -21,13 +21,14 @@ from kasa.exceptions import (
SmartDeviceException,
UnsupportedDeviceException,
)
from kasa.klapprotocol import TPLinkKlap
from kasa.iotprotocol import IotProtocol
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors
from kasa.smartdevice import DeviceType, SmartDevice
from kasa.smartdimmer import SmartDimmer
from kasa.smartlightstrip import SmartLightStrip
from kasa.smartplug import SmartPlug
from kasa.smartprotocol import SmartProtocol
from kasa.smartstrip import SmartStrip
__version__ = version("python-kasa")
@ -37,7 +38,8 @@ __all__ = [
"Discover",
"TPLinkSmartHomeProtocol",
"TPLinkProtocol",
"TPLinkKlap",
"IotProtocol",
"SmartProtocol",
"SmartBulb",
"SmartBulbPreset",
"TurnOnBehaviors",

View File

@ -1,498 +0,0 @@
"""Implementation of the TP-Link AES Protocol.
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
import asyncio
import base64
import hashlib
import logging
import time
import uuid
from pprint import pformat as pf
from typing import Dict, Optional, Union
import httpx
from cryptography.hazmat.primitives import hashes, padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import TPLinkProtocol
_LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
def _md5(payload: bytes) -> bytes:
digest = hashes.Hash(hashes.MD5()) # noqa: S303
digest.update(payload)
hash = digest.finalize()
return hash
def _sha1(payload: bytes) -> str:
sha1_algo = hashlib.sha1() # noqa: S324
sha1_algo.update(payload)
return sha1_algo.hexdigest()
class TPLinkAes(TPLinkProtocol):
"""Implementation of the AES encryption protocol.
AES is the name used in device discovery for TP-Link's TAPO encryption
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_PORT = 80
DEFAULT_TIMEOUT = 5
SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = {
"Content-Type": "application/json",
"requestByApp": "true",
"Accept": "application/json",
}
def __init__(
self,
host: str,
*,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self.credentials = (
credentials
if credentials and credentials.username and credentials.password
else Credentials(username="", password="")
)
self._local_seed: Optional[bytes] = None
self.local_auth_hash = self.generate_auth_hash(self.credentials)
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
self.kasa_setup_auth_hash = None
self.blank_auth_hash = None
self.handshake_lock = asyncio.Lock()
self.query_lock = asyncio.Lock()
self.handshake_done = False
self.encryption_session: Optional[AesEncyptionSession] = None
self.session_expire_at: Optional[float] = None
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self.session_cookie = None
self.terminal_uuid = None
self.http_client: Optional[httpx.AsyncClient] = None
self.request_id_generator = SnowflakeId(1, 1)
self.login_token = None
_LOGGER.debug("Created AES object for %s", self.host)
def hash_credentials(self, credentials, try_login_version2):
"""Hash the credentials."""
if try_login_version2:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(
_sha1(credentials.password.encode()).encode()
).decode()
else:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw
async def client_post(self, url, params=None, data=None, json=None, headers=None):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self.session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
self.http_client.cookies.clear()
resp = await self.http_client.post(
url,
params=params,
data=data,
json=json,
timeout=self.timeout,
cookies=cookies,
headers=self.COMMON_HEADERS,
)
if resp.status_code == 200:
response_data = resp.json()
return resp.status_code, response_data
async def send_secure_passthrough(self, request):
"""Send encrypted message as passthrough."""
url = f"http://{self.host}/app"
if self.login_token:
url += f"?token={self.login_token}"
raw_request = json_dumps(request)
encrypted_payload = self.encryption_session.encrypt(raw_request.encode())
passthrough_request = {
"method": "securePassthrough",
"params": {"request": encrypted_payload.decode()},
}
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
if status_code == 200 and resp_dict["error_code"] == 0:
response = self.encryption_session.decrypt(
resp_dict["result"]["response"].encode()
)
resp_dict = json_loads(response)
if resp_dict["error_code"] != 0:
raise SmartDeviceException(
f"Could not complete send, response was {resp_dict}",
)
if "result" in resp_dict:
return resp_dict["result"]
else:
raise AuthenticationException("Could not complete send")
def get_aes_request(self, method, params=None):
"""Get a request message."""
request = {
"method": method,
"params": params,
"requestID": self.request_id_generator.generate_id(),
"request_time_milis": round(time.time() * 1000),
"terminal_uuid": self.terminal_uuid,
}
return request
async def perform_login(self, login_v2):
"""Login to the device."""
self.login_token = None
un, pw = self.hash_credentials(self.credentials, login_v2)
params = {"password": pw, "username": un}
request = self.get_aes_request("login_device", params)
try:
result = await self.send_secure_passthrough(request)
except SmartDeviceException as ex:
raise AuthenticationException(ex) from ex
self.login_token = result["token"]
async def perform_handshake(self):
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")
_LOGGER.debug("Generating keypair")
self.handshake_done = False
self.session_expire_at = None
self.session_cookie = None
url = f"http://{self.host}/app"
key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ key_pair.get_public_key()
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
status_code, resp_dict = await self.client_post(url, json=request_body)
_LOGGER.debug(f"Device responded with: {resp_dict}")
if status_code == 200 and resp_dict["error_code"] == 0:
_LOGGER.debug("Decoding handshake key...")
handshake_key = resp_dict["result"]["key"]
self.session_cookie = self.http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME
)
if not self.session_cookie:
self.session_cookie = self.http_client.cookies.get( # type: ignore
"SESSIONID"
)
self.session_expire_at = time.time() + 86400
self.encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
)
self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode(
"UTF-8"
)
self.handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host)
else:
raise AuthenticationException("Could not complete handshake")
def handshake_session_expired(self):
"""Return true if session has expired."""
return (
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
)
@staticmethod
def generate_auth_hash(creds: Credentials):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username or ""
pw = creds.password or ""
return _md5(_md5(un.encode()) + _md5(pw.encode()))
@staticmethod
def generate_owner_hash(creds: Credentials):
"""Return the MD5 hash of the username in this object."""
un = creds.username or ""
return _md5(un.encode())
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
async with self.query_lock:
return await self._query(request, retry_count)
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
_LOGGER.debug(
"%s >> %s",
self.host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(request),
)
if not self.http_client:
self.http_client = httpx.AsyncClient()
if not self.handshake_done or self.handshake_session_expired():
try:
await self.perform_handshake()
await self.perform_login(False)
except AuthenticationException:
await self.perform_handshake()
await self.perform_login(True)
if isinstance(request, dict):
aes_method = next(iter(request))
aes_params = request[aes_method]
else:
aes_method = request
aes_params = None
aes_request = self.get_aes_request(aes_method, aes_params)
response_data = await self.send_secure_passthrough(aes_request)
_LOGGER.debug(
"%s << %s",
self.host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
)
return response_data
async def close(self) -> None:
"""Close the protocol."""
client = self.http_client
self.http_client = None
if client:
await client.aclose()
class AesEncyptionSession:
"""Class for an AES encryption session."""
@staticmethod
def create_from_keypair(handshake_key: str, keypair):
"""Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
private_key = serialization.load_der_private_key(private_key_data, None, None)
key_and_iv = private_key.decrypt(
handshake_key_bytes, asymmetric_padding.PKCS1v15()
)
if key_and_iv is None:
raise ValueError("Decryption failed!")
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
def __init__(self, key, iv):
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
def encrypt(self, data) -> bytes:
"""Encrypt the message."""
encryptor = self.cipher.encryptor()
padder = self.padding_strategy.padder()
padded_data = padder.update(data) + padder.finalize()
encrypted = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(encrypted)
def decrypt(self, data) -> str:
"""Decrypt the message."""
decryptor = self.cipher.decryptor()
unpadder = self.padding_strategy.unpadder()
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
return unpadded_data.decode()
class KeyPair:
"""Class for generating key pairs."""
@staticmethod
def create_key_pair(key_size: int = 1024):
"""Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key()
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
public_key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return KeyPair(
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
)
def __init__(self, private_key: str, public_key: str):
self.private_key = private_key
self.public_key = public_key
def get_private_key(self) -> str:
"""Get the private key."""
return self.private_key
def get_public_key(self) -> str:
"""Get the public key."""
return self.public_key
class SnowflakeId:
"""Class for generating snowflake ids."""
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
WORKER_ID_BITS = 5
DATA_CENTER_ID_BITS = 5
SEQUENCE_BITS = 12
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
def __init__(self, worker_id, data_center_id):
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
raise ValueError(
"Worker ID can't be greater than "
+ str(SnowflakeId.MAX_WORKER_ID)
+ " or less than 0"
)
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
raise ValueError(
"Data center ID can't be greater than "
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
+ " or less than 0"
)
self.worker_id = worker_id
self.data_center_id = data_center_id
self.sequence = 0
self.last_timestamp = -1
def generate_id(self):
"""Generate a snowflake id."""
timestamp = self._current_millis()
if timestamp < self.last_timestamp:
raise ValueError("Clock moved backwards. Refusing to generate ID.")
if timestamp == self.last_timestamp:
# Within the same millisecond, increment the sequence number
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
if self.sequence == 0:
# Sequence exceeds its bit range, wait until the next millisecond
timestamp = self._wait_next_millis(self.last_timestamp)
else:
# New millisecond, reset the sequence number
self.sequence = 0
# Update the last timestamp
self.last_timestamp = timestamp
# Generate and return the final ID
return (
(
(timestamp - SnowflakeId.EPOCH)
<< (
SnowflakeId.WORKER_ID_BITS
+ SnowflakeId.SEQUENCE_BITS
+ SnowflakeId.DATA_CENTER_ID_BITS
)
)
| (
self.data_center_id
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
)
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
| self.sequence
)
def _current_millis(self):
return round(time.time() * 1000)
def _wait_next_millis(self, last_timestamp):
timestamp = self._current_millis()
while timestamp <= last_timestamp:
timestamp = self._current_millis()
return timestamp

338
kasa/aestransport.py Normal file
View File

@ -0,0 +1,338 @@
"""Implementation of the TP-Link AES transport.
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
import base64
import hashlib
import logging
import time
from typing import Optional, Union
import httpx
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import BaseTransport
_LOGGER = logging.getLogger(__name__)
def _sha1(payload: bytes) -> str:
sha1_algo = hashlib.sha1() # noqa: S324
sha1_algo.update(payload)
return sha1_algo.hexdigest()
class AesTransport(BaseTransport):
"""Implementation of the AES encryption protocol.
AES is the name used in device discovery for TP-Link's TAPO encryption
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_TIMEOUT = 5
SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = {
"Content-Type": "application/json",
"requestByApp": "true",
"Accept": "application/json",
}
def __init__(
self,
host: str,
*,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host)
self._credentials = credentials or Credentials(username="", password="")
self._handshake_done = False
self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self._session_cookie = None
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
self._login_token = None
_LOGGER.debug("Created AES object for %s", self.host)
def hash_credentials(self, login_v2):
"""Hash the credentials."""
if login_v2:
un = base64.b64encode(
_sha1(self._credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(
_sha1(self._credentials.password.encode()).encode()
).decode()
else:
un = base64.b64encode(
_sha1(self._credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(self._credentials.password.encode()).decode()
return un, pw
async def client_post(self, url, params=None, data=None, json=None, headers=None):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self._session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
self._http_client.cookies.clear()
resp = await self._http_client.post(
url,
params=params,
data=data,
json=json,
timeout=self._timeout,
cookies=cookies,
headers=self.COMMON_HEADERS,
)
if resp.status_code == 200:
response_data = resp.json()
return resp.status_code, response_data
async def send_secure_passthrough(self, request: str):
"""Send encrypted message as passthrough."""
url = f"http://{self.host}/app"
if self._login_token:
url += f"?token={self._login_token}"
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
passthrough_request = {
"method": "securePassthrough",
"params": {"request": encrypted_payload.decode()},
}
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
_LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
if status_code == 200 and resp_dict["error_code"] == 0:
response = self._encryption_session.decrypt( # type: ignore
resp_dict["result"]["response"].encode()
)
_LOGGER.debug(f"decrypted secure_passthrough response is {response}")
resp_dict = json_loads(response)
return resp_dict
else:
self._handshake_done = False
self._login_token = None
raise AuthenticationException("Could not complete send")
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
"""Login to the device."""
self._login_token = None
if isinstance(login_request, str):
login_request_dict: dict = json_loads(login_request)
else:
login_request_dict = login_request
un, pw = self.hash_credentials(login_v2)
login_request_dict["params"] = {"password": pw, "username": un}
request = json_dumps(login_request_dict)
try:
resp_dict = await self.send_secure_passthrough(request)
except SmartDeviceException as ex:
raise AuthenticationException(ex) from ex
self._login_token = resp_dict["result"]["token"]
@property
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
return self._login_token is None
async def login(self, request: str) -> None:
"""Login to the device."""
try:
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to login"
)
await self.perform_login(request, login_v2=False)
except AuthenticationException:
await self.perform_handshake()
await self.perform_login(request, login_v2=True)
@property
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
return not self._handshake_done or self._handshake_session_expired()
async def handshake(self) -> None:
"""Perform the encryption handshake."""
await self.perform_handshake()
async def perform_handshake(self):
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")
_LOGGER.debug("Generating keypair")
self._handshake_done = False
self._session_expire_at = None
self._session_cookie = None
url = f"http://{self.host}/app"
key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ key_pair.get_public_key()
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
status_code, resp_dict = await self.client_post(url, json=request_body)
_LOGGER.debug(f"Device responded with: {resp_dict}")
if status_code == 200 and resp_dict["error_code"] == 0:
_LOGGER.debug("Decoding handshake key...")
handshake_key = resp_dict["result"]["key"]
self._session_cookie = self._http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME
)
if not self._session_cookie:
self._session_cookie = self._http_client.cookies.get( # type: ignore
"SESSIONID"
)
self._session_expire_at = time.time() + 86400
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
)
self._handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host)
else:
raise AuthenticationException("Could not complete handshake")
def _handshake_session_expired(self):
"""Return true if session has expired."""
return (
self._session_expire_at is None
or self._session_expire_at - time.time() <= 0
)
async def send(self, request: str):
"""Send the request."""
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to send"
)
if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")
resp_dict = await self.send_secure_passthrough(request)
if resp_dict["error_code"] != 0:
self._handshake_done = False
self._login_token = None
raise SmartDeviceException(
f"Could not complete send, response was {resp_dict}",
)
return resp_dict
async def close(self) -> None:
"""Close the protocol."""
client = self._http_client
self._http_client = None
if client:
await client.aclose()
class AesEncyptionSession:
"""Class for an AES encryption session."""
@staticmethod
def create_from_keypair(handshake_key: str, keypair):
"""Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
private_key = serialization.load_der_private_key(private_key_data, None, None)
key_and_iv = private_key.decrypt(
handshake_key_bytes, asymmetric_padding.PKCS1v15()
)
if key_and_iv is None:
raise ValueError("Decryption failed!")
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
def __init__(self, key, iv):
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
def encrypt(self, data) -> bytes:
"""Encrypt the message."""
encryptor = self.cipher.encryptor()
padder = self.padding_strategy.padder()
padded_data = padder.update(data) + padder.finalize()
encrypted = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(encrypted)
def decrypt(self, data) -> str:
"""Decrypt the message."""
decryptor = self.cipher.decryptor()
unpadder = self.padding_strategy.unpadder()
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
return unpadded_data.decode()
class KeyPair:
"""Class for generating key pairs."""
@staticmethod
def create_key_pair(key_size: int = 1024):
"""Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key()
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
public_key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return KeyPair(
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
)
def __init__(self, private_key: str, public_key: str):
self.private_key = private_key
self.public_key = public_key
def get_private_key(self) -> str:
"""Get the private key."""
return self.private_key
def get_public_key(self) -> str:
"""Get the public key."""
return self.public_key

View File

@ -2,17 +2,21 @@
import logging
import time
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Tuple, Type
from .aestransport import AesTransport
from .credentials import Credentials
from .device_type import DeviceType
from .exceptions import UnsupportedDeviceException
from .protocol import TPLinkProtocol
from .iotprotocol import IotProtocol
from .klaptransport import KlapTransport, TPlinkKlapTransportV2
from .protocol import BaseTransport, TPLinkProtocol
from .smartbulb import SmartBulb
from .smartdevice import SmartDevice, SmartDeviceException
from .smartdimmer import SmartDimmer
from .smartlightstrip import SmartLightStrip
from .smartplug import SmartPlug
from .smartprotocol import SmartProtocol
from .smartstrip import SmartStrip
from .tapo.tapoplug import TapoPlug
@ -87,7 +91,7 @@ async def connect(
if protocol_class is not None:
unknown_dev.protocol = protocol_class(host, credentials=credentials)
await unknown_dev.update()
device_class = get_device_class_from_info(unknown_dev.internal_state)
device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
# Reuse the connection from the unknown device
# so we don't have to reconnect
@ -104,7 +108,7 @@ async def connect(
return dev
def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]:
def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]:
"""Find SmartDevice subclass for device described by passed data."""
if "system" not in info or "get_sysinfo" not in info["system"]:
raise SmartDeviceException("No 'system' or 'get_sysinfo' in response")
@ -129,3 +133,35 @@ def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]:
return SmartBulb
raise UnsupportedDeviceException("Unknown device type: %s" % type_)
def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]:
"""Return the device class from the type name."""
supported_device_types: dict[str, Type[SmartDevice]] = {
"SMART.TAPOPLUG": TapoPlug,
"SMART.KASAPLUG": TapoPlug,
"IOT.SMARTPLUGSWITCH": SmartPlug,
}
return supported_device_types.get(device_type)
def get_protocol_from_connection_name(
connection_name: str, host: str, credentials: Optional[Credentials] = None
) -> Optional[TPLinkProtocol]:
"""Return the protocol from the connection name."""
supported_device_protocols: dict[
str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]]
] = {
"IOT.KLAP": (IotProtocol, KlapTransport),
"SMART.AES": (SmartProtocol, AesTransport),
"SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2),
}
if connection_name not in supported_device_protocols:
return None
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
transport: BaseTransport = transport_class(host, credentials=credentials)
protocol: TPLinkProtocol = protocol_class(
host, credentials=credentials, transport=transport
)
return protocol

View File

@ -15,18 +15,18 @@ try:
except ImportError:
from pydantic import BaseModel, Field
from kasa.aesprotocol import TPLinkAes
from kasa.credentials import Credentials
from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.klapprotocol import TPLinkKlap
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartdevice import SmartDevice, SmartDeviceException
from kasa.smartplug import SmartPlug
from kasa.tapo.tapoplug import TapoPlug
from .device_factory import get_device_class_from_info
from .device_factory import (
get_device_class_from_sys_info,
get_device_class_from_type_name,
get_protocol_from_connection_name,
)
_LOGGER = logging.getLogger(__name__)
@ -348,7 +348,16 @@ class Discover:
@staticmethod
def _get_device_class(info: dict) -> Type[SmartDevice]:
"""Find SmartDevice subclass for device described by passed data."""
return get_device_class_from_info(info)
if "result" in info:
discovery_result = DiscoveryResult(**info["result"])
dev_class = get_device_class_from_type_name(discovery_result.device_type)
if not dev_class:
raise UnsupportedDeviceException(
"Unknown device type: %s" % discovery_result.device_type
)
return dev_class
else:
return get_device_class_from_sys_info(info)
@staticmethod
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:
@ -384,24 +393,17 @@ class Discover:
encrypt_type_ = (
f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
)
device_class = None
supported_device_types: dict[str, Type[SmartDevice]] = {
"SMART.TAPOPLUG": TapoPlug,
"SMART.KASAPLUG": TapoPlug,
"IOT.SMARTPLUGSWITCH": SmartPlug,
}
supported_device_protocols: dict[str, Type[TPLinkProtocol]] = {
"IOT.KLAP": TPLinkKlap,
"SMART.AES": TPLinkAes,
}
if (device_class := supported_device_types.get(type_)) is None:
if (device_class := get_device_class_from_type_name(type_)) is None:
_LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceException(
f"Unsupported device {ip} of type {type_}: {info}"
)
if (protocol_class := supported_device_protocols.get(encrypt_type_)) is None:
if (
protocol := get_protocol_from_connection_name(
encrypt_type_, ip, credentials=credentials
)
) is None:
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
raise UnsupportedDeviceException(
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}"
@ -409,7 +411,7 @@ class Discover:
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
device = device_class(ip, port=port, credentials=credentials)
device.protocol = protocol_class(ip, credentials=credentials)
device.protocol = protocol
device.update_from_discover_info(discovery_result.get_dict())
return device

100
kasa/iotprotocol.py Executable file
View File

@ -0,0 +1,100 @@
"""Module for the IOT legacy IOT KASA protocol."""
import asyncio
import logging
from typing import Dict, Optional, Union
import httpx
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .klaptransport import KlapTransport
from .protocol import BaseTransport, TPLinkProtocol
_LOGGER = logging.getLogger(__name__)
class IotProtocol(TPLinkProtocol):
"""Class for the legacy TPLink IOT KASA Protocol."""
DEFAULT_PORT = 80
def __init__(
self,
host: str,
*,
transport: Optional[BaseTransport] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self._credentials: Credentials = credentials or Credentials(
username="", password=""
)
self._transport: BaseTransport = transport or KlapTransport(
host, credentials=self._credentials, timeout=timeout
)
self._query_lock = asyncio.Lock()
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
if isinstance(request, dict):
request = json_dumps(request)
assert isinstance(request, str) # noqa: S101
async with self._query_lock:
return await self._query(request, retry_count)
async def _query(self, request: str, retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: str, retry_count: int) -> Dict:
if self._transport.needs_handshake:
await self._transport.handshake()
if self._transport.needs_login: # This shouln't happen
raise SmartDeviceException(
"IOT Protocol needs to login to transport but is not login aware"
)
return await self._transport.send(request)
async def close(self) -> None:
"""Close the protocol."""
await self._transport.close()

295
kasa/klapprotocol.py → kasa/klaptransport.py Executable file → Normal file
View File

@ -47,7 +47,7 @@ import logging
import secrets
import time
from pprint import pformat as pf
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Optional, Tuple
import httpx
from cryptography.hazmat.primitives import hashes, padding
@ -55,33 +55,33 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import TPLinkProtocol
from .protocol import BaseTransport, md5
_LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
def _sha256(payload: bytes) -> bytes:
return hashlib.sha256(payload).digest()
def _md5(payload: bytes) -> bytes:
digest = hashes.Hash(hashes.MD5()) # noqa: S303
digest = hashes.Hash(hashes.SHA256()) # noqa: S303
digest.update(payload)
hash = digest.finalize()
return hash
class TPLinkKlap(TPLinkProtocol):
def _sha1(payload: bytes) -> bytes:
digest = hashes.Hash(hashes.SHA1()) # noqa: S303
digest.update(payload)
return digest.finalize()
class KlapTransport(BaseTransport):
"""Implementation of the KLAP encryption protocol.
KLAP is the name used in device discovery for TP-Link's new encryption
protocol, used by newer firmware versions.
"""
DEFAULT_PORT = 80
DEFAULT_TIMEOUT = 5
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
KASA_SETUP_EMAIL = "kasa@tp-link.net"
@ -95,29 +95,24 @@ class TPLinkKlap(TPLinkProtocol):
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self.credentials = (
credentials
if credentials and credentials.username and credentials.password
else Credentials(username="", password="")
)
super().__init__(host=host)
self._credentials = credentials or Credentials(username="", password="")
self._local_seed: Optional[bytes] = None
self.local_auth_hash = self.generate_auth_hash(self.credentials)
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
self.kasa_setup_auth_hash = None
self.blank_auth_hash = None
self.handshake_lock = asyncio.Lock()
self.query_lock = asyncio.Lock()
self.handshake_done = False
self._local_auth_hash = self.generate_auth_hash(self._credentials)
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
self._kasa_setup_auth_hash = None
self._blank_auth_hash = None
self._handshake_lock = asyncio.Lock()
self._query_lock = asyncio.Lock()
self._handshake_done = False
self.encryption_session: Optional[KlapEncryptionSession] = None
self.session_expire_at: Optional[float] = None
self._encryption_session: Optional[KlapEncryptionSession] = None
self._session_expire_at: Optional[float] = None
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self.session_cookie = None
self.http_client: Optional[httpx.AsyncClient] = None
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self._session_cookie = None
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
_LOGGER.debug("Created KLAP object for %s", self.host)
@ -125,15 +120,15 @@ class TPLinkKlap(TPLinkProtocol):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self.session_cookie:
if self._session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
self.http_client.cookies.clear()
resp = await self.http_client.post(
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
self._http_client.cookies.clear()
resp = await self._http_client.post(
url,
params=params,
data=data,
timeout=self.timeout,
timeout=self._timeout,
cookies=cookies,
)
if resp.status_code == 200:
@ -183,44 +178,55 @@ class TPLinkKlap(TPLinkProtocol):
server_hash.hex(),
)
local_seed_auth_hash = _sha256(local_seed + self.local_auth_hash)
local_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed, remote_seed, self._local_auth_hash
) # type: ignore
# Check the response from the device with local credentials
if local_seed_auth_hash == server_hash:
_LOGGER.debug("handshake1 hashes match with expected credentials")
return local_seed, remote_seed, self.local_auth_hash # type: ignore
return local_seed, remote_seed, self._local_auth_hash # type: ignore
# Now check against the default kasa setup credentials
if not self.kasa_setup_auth_hash:
if not self._kasa_setup_auth_hash:
kasa_setup_creds = Credentials(
username=TPLinkKlap.KASA_SETUP_EMAIL,
password=TPLinkKlap.KASA_SETUP_PASSWORD,
username=self.KASA_SETUP_EMAIL,
password=self.KASA_SETUP_PASSWORD,
)
self.kasa_setup_auth_hash = TPLinkKlap.generate_auth_hash(kasa_setup_creds)
self._kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds)
kasa_setup_seed_auth_hash = _sha256(
local_seed + self.kasa_setup_auth_hash # type: ignore
kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed,
remote_seed,
self._kasa_setup_auth_hash, # type: ignore
)
if kasa_setup_seed_auth_hash == server_hash:
_LOGGER.debug(
"Server response doesn't match our expected hash on ip %s"
+ " but an authentication with kasa setup credentials matched",
self.host,
)
return local_seed, remote_seed, self.kasa_setup_auth_hash # type: ignore
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
# Finally check against blank credentials if not already blank
if self.credentials != (blank_creds := Credentials(username="", password="")):
if not self.blank_auth_hash:
self.blank_auth_hash = TPLinkKlap.generate_auth_hash(blank_creds)
blank_seed_auth_hash = _sha256(local_seed + self.blank_auth_hash) # type: ignore
if self._credentials != (blank_creds := Credentials(username="", password="")):
if not self._blank_auth_hash:
self._blank_auth_hash = self.generate_auth_hash(blank_creds)
blank_seed_auth_hash = self.handshake1_seed_auth_hash(
local_seed,
remote_seed,
self._blank_auth_hash, # type: ignore
)
if blank_seed_auth_hash == server_hash:
_LOGGER.debug(
"Server response doesn't match our expected hash on ip %s"
+ " but an authentication with blank credentials matched",
self.host,
)
return local_seed, remote_seed, self.blank_auth_hash # type: ignore
return local_seed, remote_seed, self._blank_auth_hash # type: ignore
msg = f"Server response doesn't match our challenge on ip {self.host}"
_LOGGER.debug(msg)
@ -235,7 +241,7 @@ class TPLinkKlap(TPLinkProtocol):
url = f"http://{self.host}/app/handshake2"
payload = _sha256(remote_seed + auth_hash)
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
response_status, response_data = await self.client_post(url, data=payload)
@ -256,115 +262,70 @@ class TPLinkKlap(TPLinkProtocol):
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
@property
def needs_login(self) -> bool:
"""Will return false as KLAP does not do a login."""
return False
async def login(self, request: str) -> None:
"""Will raise and exception as KLAP does not do a login."""
raise SmartDeviceException(
"KLAP does not perform logins and return needs_login == False"
)
@property
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
return not self._handshake_done or self._handshake_session_expired()
async def handshake(self) -> None:
"""Perform the encryption handshake."""
await self.perform_handshake()
async def perform_handshake(self) -> Any:
"""Perform handshake1 and handshake2.
Sets the encryption_session if successful.
"""
_LOGGER.debug("Starting handshake with %s", self.host)
self.handshake_done = False
self.session_expire_at = None
self.session_cookie = None
self._handshake_done = False
self._session_expire_at = None
self._session_cookie = None
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
self.session_cookie = self.http_client.cookies.get( # type: ignore
TPLinkKlap.SESSION_COOKIE_NAME
self._session_cookie = self._http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME
)
# The device returns a TIMEOUT cookie on handshake1 which
# it doesn't like to get back so we store the one we want
self.session_expire_at = time.time() + 86400
self.encryption_session = await self.perform_handshake2(
self._session_expire_at = time.time() + 86400
self._encryption_session = await self.perform_handshake2(
local_seed, remote_seed, auth_hash
)
self.handshake_done = True
self._handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host)
def handshake_session_expired(self):
def _handshake_session_expired(self):
"""Return true if session has expired."""
return (
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
self._session_expire_at is None
or self._session_expire_at - time.time() <= 0
)
@staticmethod
def generate_auth_hash(creds: Credentials):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username or ""
pw = creds.password or ""
return _md5(_md5(un.encode()) + _md5(pw.encode()))
@staticmethod
def generate_owner_hash(creds: Credentials):
"""Return the MD5 hash of the username in this object."""
un = creds.username or ""
return _md5(un.encode())
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
if isinstance(request, dict):
request = json_dumps(request)
assert isinstance(request, str) # noqa: S101
async with self.query_lock:
return await self._query(request, retry_count)
async def _query(self, request: str, retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
async def send(self, request: str):
"""Send the request."""
if self.needs_handshake:
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: str, retry_count: int) -> Dict:
if not self.http_client:
self.http_client = httpx.AsyncClient()
if not self.handshake_done or self.handshake_session_expired():
try:
await self.perform_handshake()
except AuthenticationException as auex:
_LOGGER.debug(
"Unable to complete handshake for device %s, "
+ "authentication failed",
self.host,
"Handshake must be complete before trying to send"
)
raise auex
if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")
# Check for mypy
if self.encryption_session is not None:
payload, seq = self.encryption_session.encrypt(request.encode())
if self._encryption_session is not None:
payload, seq = self._encryption_session.encrypt(request.encode())
url = f"http://{self.host}/app/request"
@ -376,14 +337,14 @@ class TPLinkKlap(TPLinkProtocol):
msg = (
f"at {datetime.datetime.now()}. Host is {self.host}, "
+ f"Retry count is {retry_count}, Sequence is {seq}, "
+ f"Sequence is {seq}, "
+ f"Response status is {response_status}, Request was {request}"
)
if response_status != 200:
_LOGGER.error("Query failed after succesful authentication " + msg)
# If we failed with a security error, force a new handshake next time.
if response_status == 403:
self.handshake_done = False
self._handshake_done = False
raise AuthenticationException(
f"Got a security error from {self.host} after handshake "
+ "completed"
@ -397,8 +358,8 @@ class TPLinkKlap(TPLinkProtocol):
_LOGGER.debug("Query posted " + msg)
# Check for mypy
if self.encryption_session is not None:
decrypted_response = self.encryption_session.decrypt(response_data)
if self._encryption_session is not None:
decrypted_response = self._encryption_session.decrypt(response_data)
json_payload = json_loads(decrypted_response)
@ -411,12 +372,66 @@ class TPLinkKlap(TPLinkProtocol):
return json_payload
async def close(self) -> None:
"""Close the protocol."""
client = self.http_client
self.http_client = None
"""Close the transport."""
client = self._http_client
self._http_client = None
if client:
await client.aclose()
@staticmethod
def generate_auth_hash(creds: Credentials):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username or ""
pw = creds.password or ""
return md5(md5(un.encode()) + md5(pw.encode()))
@staticmethod
def handshake1_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(local_seed + auth_hash)
@staticmethod
def handshake2_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(remote_seed + auth_hash)
@staticmethod
def generate_owner_hash(creds: Credentials):
"""Return the MD5 hash of the username in this object."""
un = creds.username or ""
return md5(un.encode())
class TPlinkKlapTransportV2(KlapTransport):
"""Implementation of the KLAP encryption protocol with v2 hanshake hashes."""
@staticmethod
def generate_auth_hash(creds: Credentials):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username or ""
pw = creds.password or ""
return _sha256(_sha1(un.encode()) + _sha1(pw.encode()))
@staticmethod
def handshake1_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(local_seed + remote_seed + auth_hash)
@staticmethod
def handshake2_seed_auth_hash(
local_seed: bytes, remote_seed: bytes, auth_hash: bytes
):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
return _sha256(remote_seed + local_seed + auth_hash)
class KlapEncryptionSession:
"""Class to represent an encryption session and it's internal state.

View File

@ -22,6 +22,7 @@ from typing import Dict, Generator, Optional, Union
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from cryptography.hazmat.primitives import hashes
from .credentials import Credentials
from .exceptions import SmartDeviceException
@ -32,6 +33,56 @@ _LOGGER = logging.getLogger(__name__)
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
def md5(payload: bytes) -> bytes:
"""Return an md5 hash of the payload."""
digest = hashes.Hash(hashes.MD5()) # noqa: S303
digest.update(payload)
hash = digest.finalize()
return hash
class BaseTransport(ABC):
"""Base class for all TP-Link protocol transports."""
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
) -> None:
"""Create a protocol object."""
self.host = host
self.port = port
self.credentials = credentials
@property
@abstractmethod
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
@property
@abstractmethod
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
@abstractmethod
async def login(self, request: str) -> None:
"""Login to the device."""
@abstractmethod
async def handshake(self) -> None:
"""Perform the encryption handshake."""
@abstractmethod
async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response."""
@abstractmethod
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
class TPLinkProtocol(ABC):
"""Base class for all TP-Link Smart Home communication."""
@ -41,6 +92,7 @@ class TPLinkProtocol(ABC):
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
transport: Optional[BaseTransport] = None,
) -> None:
"""Create a protocol object."""
self.host = host

View File

@ -365,6 +365,7 @@ class SmartDevice:
def update_from_discover_info(self, info: Dict[str, Any]) -> None:
"""Update state from info from the discover call."""
self._discovery_info = info
if "system" in info and (sys_info := info["system"].get("get_sysinfo")):
self._last_update = info
self._set_sys_info(sys_info)
@ -372,7 +373,6 @@ class SmartDevice:
# This allows setting of some info properties directly
# from partial discovery info that will then be found
# by the requires_update decorator
self._discovery_info = info
self._set_sys_info(info)
def _set_sys_info(self, sys_info: Dict[str, Any]) -> None:

219
kasa/smartprotocol.py Normal file
View File

@ -0,0 +1,219 @@
"""Implementation of the TP-Link AES Protocol.
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
import asyncio
import base64
import logging
import time
import uuid
from pprint import pformat as pf
from typing import Dict, Optional, Union
import httpx
from .aestransport import AesTransport
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .protocol import BaseTransport, TPLinkProtocol, md5
_LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
class SmartProtocol(TPLinkProtocol):
"""Class for the new TPLink SMART protocol."""
DEFAULT_PORT = 80
def __init__(
self,
host: str,
*,
transport: Optional[BaseTransport] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self._credentials: Credentials = credentials or Credentials(
username="", password=""
)
self._transport: BaseTransport = transport or AesTransport(
host, credentials=self._credentials, timeout=timeout
)
self._terminal_uuid: Optional[str] = None
self._request_id_generator = SnowflakeId(1, 1)
self._query_lock = asyncio.Lock()
def get_smart_request(self, method, params=None) -> str:
"""Get a request message as a string."""
request = {
"method": method,
"params": params,
"requestID": self._request_id_generator.generate_id(),
"request_time_milis": round(time.time() * 1000),
"terminal_uuid": self._terminal_uuid,
}
return json_dumps(request)
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
async with self._query_lock:
resp_dict = await self._query(request, retry_count)
if "result" in resp_dict:
return resp_dict["result"]
return {}
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
if isinstance(request, dict):
smart_method = next(iter(request))
smart_params = request[smart_method]
else:
smart_method = request
smart_params = None
if self._transport.needs_handshake:
await self._transport.handshake()
if self._transport.needs_login:
self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
"UTF-8"
)
login_request = self.get_smart_request("login_device")
await self._transport.login(login_request)
smart_request = self.get_smart_request(smart_method, smart_params)
response_data = await self._transport.send(smart_request)
_LOGGER.debug(
"%s << %s",
self.host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
)
return response_data
async def close(self) -> None:
"""Close the protocol."""
await self._transport.close()
class SnowflakeId:
"""Class for generating snowflake ids."""
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
WORKER_ID_BITS = 5
DATA_CENTER_ID_BITS = 5
SEQUENCE_BITS = 12
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
def __init__(self, worker_id, data_center_id):
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
raise ValueError(
"Worker ID can't be greater than "
+ str(SnowflakeId.MAX_WORKER_ID)
+ " or less than 0"
)
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
raise ValueError(
"Data center ID can't be greater than "
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
+ " or less than 0"
)
self.worker_id = worker_id
self.data_center_id = data_center_id
self.sequence = 0
self.last_timestamp = -1
def generate_id(self):
"""Generate a snowflake id."""
timestamp = self._current_millis()
if timestamp < self.last_timestamp:
raise ValueError("Clock moved backwards. Refusing to generate ID.")
if timestamp == self.last_timestamp:
# Within the same millisecond, increment the sequence number
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
if self.sequence == 0:
# Sequence exceeds its bit range, wait until the next millisecond
timestamp = self._wait_next_millis(self.last_timestamp)
else:
# New millisecond, reset the sequence number
self.sequence = 0
# Update the last timestamp
self.last_timestamp = timestamp
# Generate and return the final ID
return (
(
(timestamp - SnowflakeId.EPOCH)
<< (
SnowflakeId.WORKER_ID_BITS
+ SnowflakeId.SEQUENCE_BITS
+ SnowflakeId.DATA_CENTER_ID_BITS
)
)
| (
self.data_center_id
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
)
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
| self.sequence
)
def _current_millis(self):
return round(time.time() * 1000)
def _wait_next_millis(self, last_timestamp):
timestamp = self._current_millis()
while timestamp <= last_timestamp:
timestamp = self._current_millis()
return timestamp

View File

@ -4,10 +4,10 @@ import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Set, cast
from ..aesprotocol import TPLinkAes
from ..credentials import Credentials
from ..exceptions import AuthenticationException
from ..smartdevice import SmartDevice
from ..smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__)
@ -26,7 +26,7 @@ class TapoDevice(SmartDevice):
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
self._state_information: Dict[str, Any] = {}
self._discovery_info: Optional[Dict[str, Any]] = None
self.protocol = TPLinkAes(host, credentials=credentials, timeout=timeout)
self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout)
async def update(self, update_children: bool = True):
"""Update the device."""

View File

@ -2,27 +2,45 @@ import asyncio
import glob
import json
import os
from dataclasses import dataclass
from json import dumps as json_dumps
from os.path import basename
from pathlib import Path, PurePath
from typing import Dict
from typing import Dict, Optional
from unittest.mock import MagicMock
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
Credentials,
Discover,
SmartBulb,
SmartDimmer,
SmartLightStrip,
SmartPlug,
SmartStrip,
TPLinkSmartHomeProtocol,
)
from kasa.tapo import TapoDevice, TapoPlug
from .newfakes import FakeTransportProtocol
from .newfakes import FakeSmartProtocol, FakeTransportProtocol
SUPPORTED_DEVICES = glob.glob(
SUPPORTED_IOT_DEVICES = [
(device, "IOT")
for device in glob.glob(
os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json"
)
)
]
SUPPORTED_SMART_DEVICES = [
(device, "SMART")
for device in glob.glob(
os.path.dirname(os.path.abspath(__file__)) + "/fixtures/smart/*.json"
)
]
SUPPORTED_DEVICES = SUPPORTED_IOT_DEVICES + SUPPORTED_SMART_DEVICES
LIGHT_STRIPS = {"KL400", "KL430", "KL420"}
@ -55,43 +73,59 @@ PLUGS = {
"KP401",
"KS200M",
}
STRIPS = {"HS107", "HS300", "KP303", "KP200", "KP400", "EP40"}
DIMMERS = {"ES20M", "HS220", "KS220M", "KS230", "KP405"}
DIMMABLE = {*BULBS, *DIMMERS}
WITH_EMETER = {"HS110", "HS300", "KP115", "KP125", *BULBS}
ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
ALL_DEVICES_IOT = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
PLUGS_SMART = {"P110"}
ALL_DEVICES_SMART = PLUGS_SMART
ALL_DEVICES = ALL_DEVICES_IOT.union(ALL_DEVICES_SMART)
IP_MODEL_CACHE: Dict[str, str] = {}
def filter_model(desc, filter):
filtered = list()
for dev in SUPPORTED_DEVICES:
for filt in filter:
if filt in basename(dev):
filtered.append(dev)
def idgenerator(paramtuple):
return basename(paramtuple[0]) + (
"" if paramtuple[1] == "IOT" else "-" + paramtuple[1]
)
filtered_basenames = [basename(f) for f in filtered]
def filter_model(desc, model_filter, protocol_filter=None):
if not protocol_filter:
protocol_filter = {"IOT"}
filtered = list()
for file, protocol in SUPPORTED_DEVICES:
if protocol in protocol_filter:
file_model = basename(file).split("_")[0]
for model in model_filter:
if model in file_model:
filtered.append((file, protocol))
filtered_basenames = [basename(f) + "-" + p for f, p in filtered]
print(f"{desc}: {filtered_basenames}")
return filtered
def parametrize(desc, devices, ids=None):
def parametrize(desc, devices, protocol_filter=None, ids=None):
return pytest.mark.parametrize(
"dev", filter_model(desc, devices), indirect=True, ids=ids
"dev", filter_model(desc, devices, protocol_filter), indirect=True, ids=ids
)
has_emeter = parametrize("has emeter", WITH_EMETER)
no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER)
no_emeter = parametrize("no emeter", ALL_DEVICES_IOT - WITH_EMETER)
bulb = parametrize("bulbs", BULBS, ids=basename)
plug = parametrize("plugs", PLUGS, ids=basename)
strip = parametrize("strips", STRIPS, ids=basename)
dimmer = parametrize("dimmers", DIMMERS, ids=basename)
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename)
bulb = parametrize("bulbs", BULBS, ids=idgenerator)
plug = parametrize("plugs", PLUGS, ids=idgenerator)
strip = parametrize("strips", STRIPS, ids=idgenerator)
dimmer = parametrize("dimmers", DIMMERS, ids=idgenerator)
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=idgenerator)
# bulb types
dimmable = parametrize("dimmable", DIMMABLE)
@ -101,6 +135,58 @@ non_variable_temp = parametrize("non-variable color temp", BULBS - VARIABLE_TEMP
color_bulb = parametrize("color bulbs", COLOR_BULBS)
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)
plug_smart = parametrize(
"plug devices smart", PLUGS_SMART, protocol_filter={"SMART"}, ids=idgenerator
)
device_smart = parametrize(
"devices smart", ALL_DEVICES_SMART, protocol_filter={"SMART"}, ids=idgenerator
)
device_iot = parametrize(
"devices iot", ALL_DEVICES_IOT, protocol_filter={"IOT"}, ids=idgenerator
)
def get_fixture_data():
"""Return raw discovery file contents as JSON. Used for discovery tests."""
fixture_data = {}
for file, protocol in SUPPORTED_DEVICES:
p = Path(file)
if not p.is_absolute():
folder = Path(__file__).parent / "fixtures"
if protocol == "SMART":
folder = folder / "smart"
p = folder / file
with open(p) as f:
fixture_data[basename(p)] = json.load(f)
return fixture_data
FIXTURE_DATA = get_fixture_data()
def filter_fixtures(desc, root_filter):
filtered = {}
for key, val in FIXTURE_DATA.items():
if root_filter in val:
filtered[key] = val
print(f"{desc}: {filtered.keys()}")
return filtered
def parametrize_discovery(desc, root_key):
filtered_fixtures = filter_fixtures(desc, root_key)
return pytest.mark.parametrize(
"discovery_data",
filtered_fixtures.values(),
indirect=True,
ids=filtered_fixtures.keys(),
)
new_discovery = parametrize_discovery("new discovery", "discovery_result")
def check_categories():
"""Check that every fixture file is categorized."""
@ -110,15 +196,15 @@ def check_categories():
+ plug.args[1]
+ bulb.args[1]
+ lightstrip.args[1]
+ plug_smart.args[1]
)
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
if diff:
for file in diff:
for file, protocol in diff:
print(
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
% file
f"No category for file {file} protocol {protocol}, add to the corresponding set (BULBS, PLUGS, ..)"
)
raise Exception("Missing category for %s" % diff)
raise Exception(f"Missing category for {diff}")
check_categories()
@ -134,7 +220,12 @@ async def handle_turn_on(dev, turn_on):
await dev.turn_off()
def device_for_file(model):
def device_for_file(model, protocol):
if protocol == "SMART":
for d in PLUGS_SMART:
if d in model:
return TapoPlug
else:
for d in STRIPS:
if d in model:
return SmartStrip
@ -170,11 +261,14 @@ async def _discover_update_and_close(ip):
return await _update_and_close(d)
async def get_device_for_file(file):
async def get_device_for_file(file, protocol):
# if the wanted file is not an absolute path, prepend the fixtures directory
p = Path(file)
if not p.is_absolute():
p = Path(__file__).parent / "fixtures" / file
folder = Path(__file__).parent / "fixtures"
if protocol == "SMART":
folder = folder / "smart"
p = folder / file
def load_file():
with open(p) as f:
@ -184,7 +278,11 @@ async def get_device_for_file(file):
sysinfo = await loop.run_in_executor(None, load_file)
model = basename(file)
d = device_for_file(model)(host="127.0.0.123")
d = device_for_file(model, protocol)(host="127.0.0.123")
if protocol == "SMART":
d.protocol = FakeSmartProtocol(sysinfo)
d.credentials = Credentials("", "")
else:
d.protocol = FakeTransportProtocol(sysinfo)
await _update_and_close(d)
return d
@ -197,7 +295,7 @@ async def dev(request):
Provides a device (given --ip) or parametrized fixture for the supported devices.
The initial update is called automatically before returning the device.
"""
file = request.param
file, protocol = request.param
ip = request.config.getoption("--ip")
if ip:
@ -210,19 +308,62 @@ async def dev(request):
pytest.skip(f"skipping file {file}")
return d if d else await _discover_update_and_close(ip)
return await get_device_for_file(file)
return await get_device_for_file(file, protocol)
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
@pytest.fixture
def discovery_mock(discovery_data, mocker):
@dataclass
class _DiscoveryMock:
ip: str
default_port: int
discovery_data: dict
port_override: Optional[int] = None
if "result" in discovery_data:
datagram = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data)
else:
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data)
def mock_discover(self):
port = (
dm.port_override
if dm.port_override and dm.default_port != 20002
else dm.default_port
)
self.datagram_received(
datagram,
(dm.ip, port),
)
mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)
mocker.patch(
"socket.getaddrinfo",
side_effect=lambda *_, **__: [(None, None, None, None, (dm.ip, 0))],
)
yield dm
@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session")
def discovery_data(request):
"""Return raw discovery file contents as JSON. Used for discovery tests."""
file = request.param
p = Path(file)
if not p.is_absolute():
p = Path(__file__).parent / "fixtures" / file
fixture_data = request.param
if "discovery_result" in fixture_data:
return {"result": fixture_data["discovery_result"]}
else:
return {"system": {"get_sysinfo": fixture_data["system"]["get_sysinfo"]}}
with open(p) as f:
return json.load(f)
@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session")
def all_fixture_data(request):
"""Return raw fixture file contents as JSON. Used for discovery tests."""
fixture_data = request.param
return fixture_data
def pytest_addoption(parser):

View File

@ -0,0 +1,180 @@
{
"component_nego": {
"component_list": [
{
"id": "device",
"ver_code": 2
},
{
"id": "firmware",
"ver_code": 2
},
{
"id": "quick_setup",
"ver_code": 3
},
{
"id": "time",
"ver_code": 1
},
{
"id": "wireless",
"ver_code": 1
},
{
"id": "schedule",
"ver_code": 2
},
{
"id": "countdown",
"ver_code": 2
},
{
"id": "antitheft",
"ver_code": 1
},
{
"id": "account",
"ver_code": 1
},
{
"id": "synchronize",
"ver_code": 1
},
{
"id": "sunrise_sunset",
"ver_code": 1
},
{
"id": "led",
"ver_code": 1
},
{
"id": "cloud_connect",
"ver_code": 1
},
{
"id": "iot_cloud",
"ver_code": 1
},
{
"id": "device_local_time",
"ver_code": 1
},
{
"id": "default_states",
"ver_code": 1
},
{
"id": "auto_off",
"ver_code": 2
},
{
"id": "localSmart",
"ver_code": 1
},
{
"id": "energy_monitoring",
"ver_code": 2
},
{
"id": "power_protection",
"ver_code": 1
},
{
"id": "current_protection",
"ver_code": 1
}
]
},
"discovery_result": {
"device_id": "00000000000000000000000000000000",
"device_model": "P110(UK)",
"device_type": "SMART.TAPOPLUG",
"factory_default": false,
"ip": "127.0.0.123",
"is_support_iot_cloud": true,
"mac": "00-00-00-00-00-00",
"mgt_encrypt_schm": {
"encrypt_type": "KLAP",
"http_port": 80,
"is_support_https": false,
"lv": 2
},
"obd_src": "tplink",
"owner": "00000000000000000000000000000000"
},
"get_current_power": {
"current_power": 0
},
"get_device_info": {
"auto_off_remain_time": 0,
"auto_off_status": "off",
"avatar": "plug",
"default_states": {
"state": {},
"type": "last_states"
},
"device_id": "0000000000000000000000000000000000000000",
"device_on": true,
"fw_id": "00000000000000000000000000000000",
"fw_ver": "1.3.0 Build 230905 Rel.152200",
"has_set_location_info": true,
"hw_id": "00000000000000000000000000000000",
"hw_ver": "1.0",
"ip": "127.0.0.123",
"lang": "en_US",
"latitude": 0,
"longitude": 0,
"mac": "00-00-00-00-00-00",
"model": "P110",
"nickname": "VGFwaSBTbWFydCBQbHVnIDE=",
"oem_id": "00000000000000000000000000000000",
"on_time": 119335,
"overcurrent_status": "normal",
"overheated": false,
"power_protection_status": "normal",
"region": "Europe/London",
"rssi": -57,
"signal_level": 2,
"specs": "",
"ssid": "IyNNQVNLRUROQU1FIyM=",
"time_diff": 0,
"type": "SMART.TAPOPLUG"
},
"get_device_time": {
"region": "Europe/London",
"time_diff": 0,
"timestamp": 1701370224
},
"get_device_usage": {
"power_usage": {
"past30": 75,
"past7": 69,
"today": 0
},
"saved_power": {
"past30": 2029,
"past7": 1964,
"today": 1130
},
"time_usage": {
"past30": 2104,
"past7": 2033,
"today": 1130
}
},
"get_energy_usage": {
"current_power": 0,
"electricity_charge": [
0,
0,
0
],
"local_time": "2023-11-30 18:50:24",
"month_energy": 75,
"month_runtime": 2104,
"today_energy": 0,
"today_runtime": 1130
}
}

View File

@ -1,6 +1,7 @@
import copy
import logging
import re
from json import loads as json_loads
from voluptuous import (
REMOVE_EXTRA,
@ -13,7 +14,8 @@ from voluptuous import (
Schema,
)
from ..protocol import TPLinkSmartHomeProtocol
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol
from ..smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__)
@ -285,6 +287,41 @@ TIME_MODULE = {
}
class FakeSmartProtocol(SmartProtocol):
def __init__(self, info):
super().__init__("127.0.0.123", transport=FakeSmartTransport(info))
class FakeSmartTransport(BaseTransport):
def __init__(self, info):
self.info = info
@property
def needs_handshake(self) -> bool:
return False
@property
def needs_login(self) -> bool:
return False
async def login(self, request: str) -> None:
pass
async def handshake(self) -> None:
pass
async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]
if method == "component_nego" or method[:4] == "get_":
return self.info[method]
elif method[:4] == "set_":
_LOGGER.debug("Call %s not implemented, doing nothing", method)
async def close(self) -> None:
pass
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info):
self.discovery_data = info

View File

@ -6,12 +6,15 @@ from asyncclick.testing import CliRunner
from kasa import SmartDevice, TPLinkSmartHomeProtocol
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
from kasa.discover import Discover
from kasa.smartprotocol import SmartProtocol
from .conftest import handle_turn_on, turn_on
from .newfakes import FakeTransportProtocol
from .conftest import device_iot, handle_turn_on, new_discovery, turn_on
from .newfakes import FakeSmartProtocol, FakeTransportProtocol
@device_iot
async def test_sysinfo(dev):
runner = CliRunner()
res = await runner.invoke(sysinfo, obj=dev)
@ -19,6 +22,7 @@ async def test_sysinfo(dev):
assert dev.alias in res.output
@device_iot
@turn_on
async def test_state(dev, turn_on):
await handle_turn_on(dev, turn_on)
@ -32,6 +36,7 @@ async def test_state(dev, turn_on):
assert "Device state: False" in res.output
@device_iot
@turn_on
async def test_toggle(dev, turn_on, mocker):
await handle_turn_on(dev, turn_on)
@ -44,6 +49,7 @@ async def test_toggle(dev, turn_on, mocker):
assert dev.is_on
@device_iot
async def test_alias(dev):
runner = CliRunner()
@ -62,6 +68,7 @@ async def test_alias(dev):
await dev.set_alias(old_alias)
@device_iot
async def test_raw_command(dev):
runner = CliRunner()
res = await runner.invoke(raw_command, ["system", "get_sysinfo"], obj=dev)
@ -74,6 +81,7 @@ async def test_raw_command(dev):
assert "Usage" in res.output
@device_iot
async def test_emeter(dev: SmartDevice, mocker):
runner = CliRunner()
@ -99,6 +107,7 @@ async def test_emeter(dev: SmartDevice, mocker):
daily.assert_called_with(year=1900, month=12)
@device_iot
async def test_brightness(dev):
runner = CliRunner()
res = await runner.invoke(brightness, obj=dev)
@ -116,6 +125,7 @@ async def test_brightness(dev):
assert "Brightness: 12" in res.output
@device_iot
async def test_json_output(dev: SmartDevice, mocker):
"""Test that the json output produces correct output."""
mocker.patch("kasa.Discover.discover", return_value=[dev])
@ -125,13 +135,9 @@ async def test_json_output(dev: SmartDevice, mocker):
assert json.loads(res.output) == dev.internal_state
async def test_credentials(discovery_data: dict, mocker):
@new_discovery
async def test_credentials(discovery_mock, mocker):
"""Test credentials are passed correctly from cli to device."""
# As this is testing the device constructor need to explicitly wire in
# the FakeTransportProtocol
ftp = FakeTransportProtocol(discovery_data)
mocker.patch.object(TPLinkSmartHomeProtocol, "query", ftp.query)
# Patch state to echo username and password
pass_dev = click.make_pass_decorator(SmartDevice)
@ -143,18 +149,15 @@ async def test_credentials(discovery_data: dict, mocker):
)
mocker.patch("kasa.cli.state", new=_state)
cli_device_type = Discover._get_device_class(discovery_data)(
"any"
).device_type.value
for subclass in DEVICE_TYPE_TO_CLASS.values():
mocker.patch.object(subclass, "update")
runner = CliRunner()
res = await runner.invoke(
cli,
[
"--host",
"127.0.0.1",
"--type",
cli_device_type,
"127.0.0.123",
"--username",
"foo",
"--password",
@ -162,9 +165,11 @@ async def test_credentials(discovery_data: dict, mocker):
],
)
assert res.exit_code == 0
assert res.output == "Username:foo Password:bar\n"
assert "Username:foo Password:bar\n" in res.output
@device_iot
async def test_without_device_type(discovery_data: dict, dev, mocker):
"""Test connecting without the device type."""
runner = CliRunner()

View File

@ -5,7 +5,9 @@ from typing import Type
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
Credentials,
DeviceType,
Discover,
SmartBulb,
SmartDevice,
SmartDeviceException,
@ -13,8 +15,13 @@ from kasa import (
SmartLightStrip,
SmartPlug,
)
from kasa.device_factory import connect
from kasa.klapprotocol import TPLinkKlap
from kasa.device_factory import (
DEVICE_TYPE_TO_CLASS,
connect,
get_protocol_from_connection_name,
)
from kasa.discover import DiscoveryResult
from kasa.iotprotocol import IotProtocol
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
@ -22,8 +29,12 @@ from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
async def test_connect(discovery_data: dict, mocker, custom_port):
"""Make sure that connect returns an initialized SmartDevice instance."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
dev = await connect(host, port=custom_port)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == 9999
@ -49,8 +60,12 @@ async def test_connect_passed_device_type(
):
"""Make sure that connect with a passed device type."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
dev = await connect(host, port=custom_port)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port, device_type=device_type)
assert isinstance(dev, klass)
assert dev.port == custom_port or dev.port == 9999
@ -70,32 +85,52 @@ async def test_connect_logs_connect_time(
):
"""Test that the connect time is logged when debug logging is enabled."""
host = "127.0.0.1"
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
await connect(host)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
logging.getLogger("kasa").setLevel(logging.DEBUG)
await connect(host)
assert "seconds to connect" in caplog.text
@pytest.mark.parametrize("device_type", [DeviceType.Plug, None])
@pytest.mark.parametrize(
("protocol_in", "protocol_result"),
(
(None, TPLinkSmartHomeProtocol),
(TPLinkKlap, TPLinkKlap),
(TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol),
),
)
async def test_connect_pass_protocol(
discovery_data: dict,
all_fixture_data: dict,
mocker,
device_type: DeviceType,
protocol_in: Type[TPLinkProtocol],
protocol_result: Type[TPLinkProtocol],
):
"""Test that if the protocol is passed in it's gets set correctly."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
mocker.patch("kasa.TPLinkKlap.query", return_value=discovery_data)
if "discovery_result" in all_fixture_data:
discovery_info = {"result": all_fixture_data["discovery_result"]}
device_class = Discover._get_device_class(discovery_info)
else:
device_class = Discover._get_device_class(all_fixture_data)
dev = await connect(host, device_type=device_type, protocol_class=protocol_in)
assert isinstance(dev.protocol, protocol_result)
device_type = list(DEVICE_TYPE_TO_CLASS.keys())[
list(DEVICE_TYPE_TO_CLASS.values()).index(device_class)
]
host = "127.0.0.1"
if "discovery_result" in all_fixture_data:
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
dr = DiscoveryResult(**discovery_info["result"])
connection_name = (
dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type
)
protocol_class = get_protocol_from_connection_name(
connection_name, host
).__class__
else:
mocker.patch(
"kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data
)
protocol_class = TPLinkSmartHomeProtocol
dev = await connect(
host,
device_type=device_type,
protocol_class=protocol_class,
credentials=Credentials("", ""),
)
assert isinstance(dev.protocol, protocol_class)

View File

@ -17,6 +17,27 @@ from kasa.exceptions import AuthenticationException, UnsupportedDeviceException
from .conftest import bulb, dimmer, lightstrip, plug, strip
UNSUPPORTED = {
"result": {
"device_id": "xx",
"owner": "xx",
"device_type": "SMART.TAPOXMASTREE",
"device_model": "P110(EU)",
"ip": "127.0.0.1",
"mac": "48-22xxx",
"is_support_iot_cloud": True,
"obd_src": "tplink",
"factory_default": False,
"mgt_encrypt_schm": {
"is_support_https": False,
"encrypt_type": "AES",
"http_port": 80,
"lv": 2,
},
},
"error_code": 0,
}
@plug
async def test_type_detection_plug(dev: SmartDevice):
@ -62,76 +83,40 @@ async def test_type_unknown():
@pytest.mark.parametrize("custom_port", [123, None])
async def test_discover_single(discovery_data: dict, mocker, custom_port):
# @pytest.mark.parametrize("discovery_mock", [("127.0.0.1",123), ("127.0.0.1",None)], indirect=True)
async def test_discover_single(discovery_mock, custom_port, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)
def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:],
(host, custom_port or 9999),
)
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
discovery_mock.ip = host
discovery_mock.port_override = custom_port
update_mock = mocker.patch.object(SmartStrip, "update")
x = await Discover.discover_single(host, port=custom_port)
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
assert x.port == custom_port or x.port == 9999
assert (query_mock.call_count > 0) == isinstance(x, SmartStrip)
assert x._discovery_info is not None
assert x.port == custom_port or x.port == discovery_mock.default_port
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
async def test_discover_single_hostname(discovery_data: dict, mocker):
async def test_discover_single_hostname(discovery_mock, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "foobar"
ip = "127.0.0.1"
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)
def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:],
(ip, 9999),
)
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))])
discovery_mock.ip = ip
update_mock = mocker.patch.object(SmartStrip, "update")
x = await Discover.discover_single(host)
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
assert x._discovery_info is not None
assert x.host == host
assert (query_mock.call_count > 0) == isinstance(x, SmartStrip)
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
with pytest.raises(SmartDeviceException):
x = await Discover.discover_single(host)
UNSUPPORTED = {
"result": {
"device_id": "xx",
"owner": "xx",
"device_type": "SMART.TAPOXMASTREE",
"device_model": "P110(EU)",
"ip": "127.0.0.1",
"mac": "48-22xxx",
"is_support_iot_cloud": True,
"obd_src": "tplink",
"factory_default": False,
"mgt_encrypt_schm": {
"is_support_https": False,
"encrypt_type": "AES",
"http_port": 80,
"lv": 2,
},
},
"error_code": 0,
}
async def test_discover_single_unsupported(mocker):
"""Make sure that discover_single handles unsupported devices correctly."""
host = "127.0.0.1"
@ -201,14 +186,17 @@ async def test_discover_send(mocker):
async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol()
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
mocker.patch("kasa.discover.json_loads", return_value=info)
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
addr = "127.0.0.1"
proto.datagram_received("<placeholder data>", (addr, 9999))
port = 20002 if "result" in discovery_data else 9999
mocker.patch("kasa.discover.json_loads", return_value=discovery_data)
proto.datagram_received("<placeholder data>", (addr, port))
addr2 = "127.0.0.2"
mocker.patch("kasa.discover.json_loads", return_value=UNSUPPORTED)
proto.datagram_received("<placeholder data>", (addr2, 20002))
# Check that device in discovered_devices is initialized correctly

View File

@ -10,9 +10,14 @@ from contextlib import nullcontext as does_not_raise
import httpx
import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..exceptions import AuthenticationException, SmartDeviceException
from ..klapprotocol import KlapEncryptionSession, TPLinkKlap, _sha256
from ..iotprotocol import IotProtocol
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
from ..smartprotocol import SmartProtocol
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
class _mock_response:
@ -21,67 +26,92 @@ class _mock_response:
self.content = content
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries(mocker, retry_count):
async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class):
host = "127.0.0.1"
conn = mocker.patch.object(
TPLinkKlap, "client_post", side_effect=Exception("dummy exception")
transport_class, "client_post", side_effect=Exception("dummy exception")
)
with pytest.raises(SmartDeviceException):
await TPLinkKlap("127.0.0.1").query({}, retry_count=retry_count)
await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=retry_count
)
assert conn.call_count == retry_count + 1
async def test_protocol_no_retry_on_connection_error(mocker):
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
async def test_protocol_no_retry_on_connection_error(
mocker, protocol_class, transport_class
):
host = "127.0.0.1"
conn = mocker.patch.object(
TPLinkKlap,
transport_class,
"client_post",
side_effect=httpx.ConnectError("foo"),
)
with pytest.raises(SmartDeviceException):
await TPLinkKlap("127.0.0.1").query({}, retry_count=5)
await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=5
)
assert conn.call_count == 1
async def test_protocol_retry_recoverable_error(mocker):
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
async def test_protocol_retry_recoverable_error(
mocker, protocol_class, transport_class
):
host = "127.0.0.1"
conn = mocker.patch.object(
TPLinkKlap,
transport_class,
"client_post",
side_effect=httpx.CloseError("foo"),
)
with pytest.raises(SmartDeviceException):
await TPLinkKlap("127.0.0.1").query({}, retry_count=5)
await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=5
)
assert conn.call_count == 6
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_reconnect(mocker, retry_count):
async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
host = "127.0.0.1"
remaining = retry_count
mock_response = {"result": {"great": "success"}}
def _fail_one_less_than_retry_count(*_, **__):
nonlocal remaining, encryption_session
nonlocal remaining
remaining -= 1
if remaining:
raise Exception("Simulated post failure")
# Do the encrypt just before returning the value so the incrementing sequence number is correct
encrypted, seq = encryption_session.encrypt('{"great":"success"}')
return 200, encrypted
seed = secrets.token_bytes(16)
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
protocol = TPLinkKlap("127.0.0.1")
protocol.handshake_done = True
protocol.session_expire_at = time.time() + 86400
protocol.encryption_session = encryption_session
return mock_response
mocker.patch.object(
TPLinkKlap, "client_post", side_effect=_fail_one_less_than_retry_count
transport_class, "needs_handshake", property(lambda self: False)
)
mocker.patch.object(transport_class, "needs_login", property(lambda self: False))
send_mock = mocker.patch.object(
transport_class,
"send",
side_effect=_fail_one_less_than_retry_count,
)
response = await protocol.query({}, retry_count=retry_count)
assert response == {"great": "success"}
response = await protocol_class(host, transport=transport_class(host)).query(
DUMMY_QUERY, retry_count=retry_count
)
assert "result" in response or "great" in response
assert send_mock.call_count == retry_count
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
@ -96,14 +126,14 @@ async def test_protocol_logging(mocker, caplog, log_level):
return 200, encrypted
seed = secrets.token_bytes(16)
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
protocol = TPLinkKlap("127.0.0.1")
protocol = IotProtocol("127.0.0.1")
protocol.handshake_done = True
protocol.session_expire_at = time.time() + 86400
protocol.encryption_session = encryption_session
mocker.patch.object(TPLinkKlap, "client_post", side_effect=_return_encrypted)
protocol._transport._handshake_done = True
protocol._transport._session_expire_at = time.time() + 86400
protocol._transport._encryption_session = encryption_session
mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted)
response = await protocol.query({})
assert response == {"great": "success"}
@ -117,7 +147,7 @@ def test_encrypt():
d = json.dumps({"foo": 1, "bar": 2})
seed = secrets.token_bytes(16)
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
encrypted, seq = encryption_session.encrypt(d)
@ -129,7 +159,7 @@ def test_encrypt_unicode():
d = "{'snowman': '\u2603'}"
seed = secrets.token_bytes(16)
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
encrypted, seq = encryption_session.encrypt(d)
@ -145,7 +175,10 @@ def test_encrypt_unicode():
(Credentials("foo", "bar"), does_not_raise()),
(Credentials("", ""), does_not_raise()),
(
Credentials(TPLinkKlap.KASA_SETUP_EMAIL, TPLinkKlap.KASA_SETUP_PASSWORD),
Credentials(
KlapTransport.KASA_SETUP_EMAIL,
KlapTransport.KASA_SETUP_PASSWORD,
),
does_not_raise(),
),
(
@ -167,21 +200,21 @@ async def test_handshake1(mocker, device_credentials, expectation):
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = TPLinkKlap.generate_auth_hash(device_credentials)
device_auth_hash = KlapTransport.generate_auth_hash(device_credentials)
mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
)
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
protocol.http_client = httpx.AsyncClient()
protocol._transport.http_client = httpx.AsyncClient()
with expectation:
(
local_seed,
device_remote_seed,
auth_hash,
) = await protocol.perform_handshake1()
) = await protocol._transport.perform_handshake1()
assert local_seed == client_seed
assert device_remote_seed == server_seed
@ -204,23 +237,23 @@ async def test_handshake(mocker):
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake_response
)
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
protocol.http_client = httpx.AsyncClient()
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
protocol._transport.http_client = httpx.AsyncClient()
response_status = 200
await protocol.perform_handshake()
assert protocol.handshake_done is True
await protocol._transport.perform_handshake()
assert protocol._transport._handshake_done is True
response_status = 403
with pytest.raises(AuthenticationException):
await protocol.perform_handshake()
assert protocol.handshake_done is False
await protocol._transport.perform_handshake()
assert protocol._transport._handshake_done is False
await protocol.close()
@ -237,9 +270,9 @@ async def test_query(mocker):
return _mock_response(200, b"")
elif url == "http://127.0.0.1/app/request":
encryption_session = KlapEncryptionSession(
protocol.encryption_session.local_seed,
protocol.encryption_session.remote_seed,
protocol.encryption_session.user_hash,
protocol._transport._encryption_session.local_seed,
protocol._transport._encryption_session.remote_seed,
protocol._transport._encryption_session.user_hash,
)
seq = params.get("seq")
encryption_session._seq = seq - 1
@ -252,11 +285,11 @@ async def test_query(mocker):
seq = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
for _ in range(10):
resp = await protocol.query({})
@ -296,11 +329,11 @@ async def test_authentication_failures(mocker, response_status, expectation):
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
with expectation:
await protocol.query({})

View File

@ -1,6 +1,6 @@
from kasa import DeviceType
from .conftest import plug
from .conftest import plug, plug_smart
from .newfakes import PLUG_SCHEMA
@ -28,3 +28,14 @@ async def test_led(dev):
assert dev.led
await dev.set_led(original)
@plug_smart
async def test_plug_device_info(dev):
assert dev._info is not None
# PLUG_SCHEMA(dev.sys_info)
assert dev.model is not None
assert dev.device_type == DeviceType.Plug or dev.device_type == DeviceType.Strip
# assert dev.is_plug or dev.is_strip

View File

@ -9,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file
def test_bulb_examples(mocker):
"""Use KL130 (bulb with all features) to test the doctests."""
p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json"))
p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json", "IOT"))
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
mocker.patch("kasa.smartbulb.SmartBulb.update")
res = xdoctest.doctest_module("kasa.smartbulb", "all")
@ -18,7 +18,7 @@ def test_bulb_examples(mocker):
def test_smartdevice_examples(mocker):
"""Use HS110 for emeter examples."""
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json"))
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT"))
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
mocker.patch("kasa.smartdevice.SmartDevice.update")
res = xdoctest.doctest_module("kasa.smartdevice", "all")
@ -27,7 +27,7 @@ def test_smartdevice_examples(mocker):
def test_plug_examples(mocker):
"""Test plug examples."""
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json"))
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT"))
mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
mocker.patch("kasa.smartplug.SmartPlug.update")
res = xdoctest.doctest_module("kasa.smartplug", "all")
@ -36,7 +36,7 @@ def test_plug_examples(mocker):
def test_strip_examples(mocker):
"""Test strip examples."""
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json"))
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT"))
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
mocker.patch("kasa.smartstrip.SmartStrip.update")
res = xdoctest.doctest_module("kasa.smartstrip", "all")
@ -45,7 +45,7 @@ def test_strip_examples(mocker):
def test_dimmer_examples(mocker):
"""Test dimmer examples."""
p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json"))
p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json", "IOT"))
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
mocker.patch("kasa.smartdimmer.SmartDimmer.update")
res = xdoctest.doctest_module("kasa.smartdimmer", "all")
@ -54,7 +54,7 @@ def test_dimmer_examples(mocker):
def test_lightstrip_examples(mocker):
"""Test lightstrip examples."""
p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json"))
p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json", "IOT"))
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
@ -63,7 +63,7 @@ def test_lightstrip_examples(mocker):
def test_discovery_examples(mocker):
"""Test discovery examples."""
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json"))
p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT"))
mocker.patch("kasa.discover.Discover.discover", return_value=[p])
res = xdoctest.doctest_module("kasa.discover", "all")

View File

@ -8,7 +8,7 @@ import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
# List of all SmartXXX classes including the SmartDevice base class
@ -22,11 +22,13 @@ smart_device_classes = [
]
@device_iot
async def test_state_info(dev):
assert isinstance(dev.state_information, dict)
@pytest.mark.requires_dummy
@device_iot
async def test_invalid_connection(dev):
with patch.object(
FakeTransportProtocol, "query", side_effect=SmartDeviceException
@ -58,12 +60,14 @@ async def test_initial_update_no_emeter(dev, mocker):
assert spy.call_count == 2
@device_iot
async def test_query_helper(dev):
with pytest.raises(SmartDeviceException):
await dev._query_helper("test", "testcmd", {})
# TODO check for unwrapping?
@device_iot
@turn_on
async def test_state(dev, turn_on):
await handle_turn_on(dev, turn_on)
@ -90,6 +94,7 @@ async def test_state(dev, turn_on):
assert dev.is_off
@device_iot
async def test_alias(dev):
test_alias = "TEST1234"
original = dev.alias
@ -104,6 +109,7 @@ async def test_alias(dev):
assert dev.alias == original
@device_iot
@turn_on
async def test_on_since(dev, turn_on):
await handle_turn_on(dev, turn_on)
@ -116,30 +122,37 @@ async def test_on_since(dev, turn_on):
assert dev.on_since is None
@device_iot
async def test_time(dev):
assert isinstance(await dev.get_time(), datetime)
@device_iot
async def test_timezone(dev):
TZ_SCHEMA(await dev.get_timezone())
@device_iot
async def test_hw_info(dev):
PLUG_SCHEMA(dev.hw_info)
@device_iot
async def test_location(dev):
PLUG_SCHEMA(dev.location)
@device_iot
async def test_rssi(dev):
PLUG_SCHEMA({"rssi": dev.rssi}) # wrapping for vol
@device_iot
async def test_mac(dev):
PLUG_SCHEMA({"mac": dev.mac}) # wrapping for val
@device_iot
async def test_representation(dev):
import re
@ -147,6 +160,7 @@ async def test_representation(dev):
assert pattern.match(str(dev))
@device_iot
async def test_childrens(dev):
"""Make sure that children property is exposed by every device."""
if dev.is_strip:
@ -155,6 +169,7 @@ async def test_childrens(dev):
assert len(dev.children) == 0
@device_iot
async def test_children(dev):
"""Make sure that children property is exposed by every device."""
if dev.is_strip:
@ -165,11 +180,13 @@ async def test_children(dev):
assert dev.has_children is False
@device_iot
async def test_internal_state(dev):
"""Make sure the internal state returns the last update results."""
assert dev.internal_state == dev._last_update
@device_iot
async def test_features(dev):
"""Make sure features is always accessible."""
sysinfo = dev._last_update["system"]["get_sysinfo"]
@ -179,11 +196,13 @@ async def test_features(dev):
assert dev.features == set()
@device_iot
async def test_max_device_response_size(dev):
"""Make sure every device return has a set max response size."""
assert dev.max_device_response_size > 0
@device_iot
async def test_estimated_response_sizes(dev):
"""Make sure every module has an estimated response size set."""
for mod in dev.modules.values():
@ -202,6 +221,7 @@ def test_device_class_ctors(device_class):
assert dev.credentials == credentials
@device_iot
async def test_modules_preserved(dev: SmartDevice):
"""Make modules that are not being updated are preserved between updates."""
dev._last_update["some_module_not_being_updated"] = "should_be_kept"
@ -237,6 +257,7 @@ async def test_create_thin_wrapper():
)
@device_iot
async def test_modules_not_supported(dev: SmartDevice):
"""Test that unsupported modules do not break the device."""
for module in dev.modules.values():