mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-08-06 10:44:04 +00:00
Add klap protocol (#509)
* Add support for the new encryption protocol This adds support for the new TP-Link discovery and encryption protocols. It is currently incomplete - only devices without username and password are current supported, and single device discovery is not implemented. Discovery should find both old and new devices. When accessing a device by IP the --klap option can be specified on the command line to active the new connection protocol. sdb9696 - This commit also contains 16 later commits from Simon Wilkinson squashed into the original * Update klap changes 2023 to fix encryption, deal with kasa credential switching and work with new discovery changes * Move from aiohttp to httpx * Changes following review comments --------- Co-authored-by: Simon Wilkinson <simon@sxw.org.uk>
This commit is contained in:
@@ -21,7 +21,8 @@ from kasa.exceptions import (
|
||||
SmartDeviceException,
|
||||
UnsupportedDeviceException,
|
||||
)
|
||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||
from kasa.klapprotocol import TPLinkKlap
|
||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors
|
||||
from kasa.smartdevice import DeviceType, SmartDevice
|
||||
from kasa.smartdimmer import SmartDimmer
|
||||
@@ -35,6 +36,8 @@ __version__ = version("python-kasa")
|
||||
__all__ = [
|
||||
"Discover",
|
||||
"TPLinkSmartHomeProtocol",
|
||||
"TPLinkProtocol",
|
||||
"TPLinkKlap",
|
||||
"SmartBulb",
|
||||
"SmartBulbPreset",
|
||||
"TurnOnBehaviors",
|
||||
|
23
kasa/cli.py
23
kasa/cli.py
@@ -11,6 +11,7 @@ from typing import Any, Dict, cast
|
||||
import asyncclick as click
|
||||
|
||||
from kasa import (
|
||||
AuthenticationException,
|
||||
Credentials,
|
||||
Discover,
|
||||
SmartBulb,
|
||||
@@ -308,8 +309,9 @@ async def discover(ctx, timeout, show_unsupported):
|
||||
sem = asyncio.Semaphore()
|
||||
discovered = dict()
|
||||
unsupported = []
|
||||
auth_failed = []
|
||||
|
||||
async def print_unsupported(data: Dict):
|
||||
async def print_unsupported(data: str):
|
||||
unsupported.append(data)
|
||||
if show_unsupported:
|
||||
echo(f"Found unsupported device (tapo/unknown encryption): {data}")
|
||||
@@ -318,12 +320,15 @@ async def discover(ctx, timeout, show_unsupported):
|
||||
echo(f"Discovering devices on {target} for {timeout} seconds")
|
||||
|
||||
async def print_discovered(dev: SmartDevice):
|
||||
await dev.update()
|
||||
async with sem:
|
||||
discovered[dev.host] = dev.internal_state
|
||||
ctx.obj = dev
|
||||
await ctx.invoke(state)
|
||||
echo()
|
||||
try:
|
||||
await dev.update()
|
||||
async with sem:
|
||||
discovered[dev.host] = dev.internal_state
|
||||
ctx.obj = dev
|
||||
await ctx.invoke(state)
|
||||
echo()
|
||||
except AuthenticationException as aex:
|
||||
auth_failed.append(str(aex))
|
||||
|
||||
await Discover.discover(
|
||||
target=target,
|
||||
@@ -343,6 +348,10 @@ async def discover(ctx, timeout, show_unsupported):
|
||||
else ", to show them use: kasa discover --show-unsupported"
|
||||
)
|
||||
)
|
||||
if auth_failed:
|
||||
echo(f"Found {len(auth_failed)} devices that failed to authenticate")
|
||||
for fail in auth_failed:
|
||||
echo(fail)
|
||||
|
||||
return discovered
|
||||
|
||||
|
161
kasa/discover.py
161
kasa/discover.py
@@ -4,17 +4,23 @@ import binascii
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from typing import Awaitable, Callable, Dict, Optional, Type, cast
|
||||
from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast
|
||||
|
||||
# When support for cpython older than 3.11 is dropped
|
||||
# async_timeout can be replaced with asyncio.timeout
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
|
||||
try:
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
except ImportError:
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.exceptions import UnsupportedDeviceException
|
||||
from kasa.json import dumps as json_dumps
|
||||
from kasa.json import loads as json_loads
|
||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||
from kasa.klapprotocol import TPLinkKlap
|
||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
from kasa.smartbulb import SmartBulb
|
||||
from kasa.smartdevice import SmartDevice, SmartDeviceException
|
||||
from kasa.smartdimmer import SmartDimmer
|
||||
@@ -44,7 +50,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
target: str = "255.255.255.255",
|
||||
discovery_packets: int = 3,
|
||||
interface: Optional[str] = None,
|
||||
on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None,
|
||||
on_unsupported: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
port: Optional[int] = None,
|
||||
discovered_event: Optional[asyncio.Event] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
@@ -64,6 +70,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
self.discovered_event = discovered_event
|
||||
self.credentials = credentials
|
||||
self.timeout = timeout
|
||||
self.seen_hosts: Set[str] = set()
|
||||
|
||||
def connection_made(self, transport) -> None:
|
||||
"""Set socket options for broadcasting."""
|
||||
@@ -95,43 +102,36 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
def datagram_received(self, data, addr) -> None:
|
||||
"""Handle discovery responses."""
|
||||
ip, port = addr
|
||||
if (
|
||||
ip in self.discovered_devices
|
||||
or ip in self.unsupported_devices
|
||||
or ip in self.invalid_device_exceptions
|
||||
):
|
||||
# Prevent multiple entries due multiple broadcasts
|
||||
if ip in self.seen_hosts:
|
||||
return
|
||||
self.seen_hosts.add(ip)
|
||||
|
||||
if port == self.discovery_port:
|
||||
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
|
||||
elif port == Discover.DISCOVERY_PORT_2:
|
||||
info = json_loads(data[16:])
|
||||
self.unsupported_devices[ip] = info
|
||||
device = None
|
||||
try:
|
||||
if port == self.discovery_port:
|
||||
device = Discover._get_device_instance_legacy(data, ip, port)
|
||||
elif port == Discover.DISCOVERY_PORT_2:
|
||||
device = Discover._get_device_instance(
|
||||
data, ip, port, self.credentials or Credentials()
|
||||
)
|
||||
else:
|
||||
return
|
||||
except UnsupportedDeviceException as udex:
|
||||
_LOGGER.debug("Unsupported device found at %s << %s", ip, udex)
|
||||
self.unsupported_devices[ip] = str(udex)
|
||||
if self.on_unsupported is not None:
|
||||
asyncio.ensure_future(self.on_unsupported(info))
|
||||
_LOGGER.debug("[DISCOVERY] Unsupported device found at %s << %s", ip, info)
|
||||
asyncio.ensure_future(self.on_unsupported(str(udex)))
|
||||
if self.discovered_event is not None:
|
||||
self.discovered_event.set()
|
||||
return
|
||||
|
||||
try:
|
||||
device_class = Discover._get_device_class(info)
|
||||
except SmartDeviceException as ex:
|
||||
_LOGGER.debug(
|
||||
"[DISCOVERY] Unable to find device type from %s: %s", info, ex
|
||||
)
|
||||
_LOGGER.debug(f"[DISCOVERY] Unable to find device type for {ip}: {ex}")
|
||||
self.invalid_device_exceptions[ip] = ex
|
||||
if self.discovered_event is not None:
|
||||
self.discovered_event.set()
|
||||
return
|
||||
|
||||
device = device_class(
|
||||
ip, port=port, credentials=self.credentials, timeout=self.timeout
|
||||
)
|
||||
device.update_from_discover_info(info)
|
||||
|
||||
self.discovered_devices[ip] = device
|
||||
|
||||
if self.on_discovered is not None:
|
||||
@@ -269,6 +269,10 @@ class Discover:
|
||||
to discovery requests.
|
||||
|
||||
:param host: Hostname of device to query
|
||||
:param port: Optionally set a different port for the device
|
||||
:param timeout: Timeout for discovery
|
||||
:param credentials: Optionally provide credentials for
|
||||
devices requiring them
|
||||
:rtype: SmartDevice
|
||||
:return: Object for querying/controlling found device.
|
||||
"""
|
||||
@@ -344,6 +348,7 @@ class Discover:
|
||||
port: Optional[int] = None,
|
||||
timeout=5,
|
||||
credentials: Optional[Credentials] = None,
|
||||
protocol_class: Optional[Type[TPLinkProtocol]] = None,
|
||||
) -> SmartDevice:
|
||||
"""Connect to a single device by the given IP address.
|
||||
|
||||
@@ -358,12 +363,20 @@ class Discover:
|
||||
The device type is discovered by querying the device.
|
||||
|
||||
:param host: Hostname of device to query
|
||||
:param port: Optionally set a different port for the device
|
||||
:param timeout: Timeout for discovery
|
||||
:param credentials: Optionally provide credentials for
|
||||
devices requiring them
|
||||
:param protocol_class: Optionally provide the protocol class
|
||||
to use.
|
||||
:rtype: SmartDevice
|
||||
:return: Object for querying/controlling found device.
|
||||
"""
|
||||
unknown_dev = SmartDevice(
|
||||
host=host, port=port, credentials=credentials, timeout=timeout
|
||||
)
|
||||
if protocol_class is not None:
|
||||
unknown_dev.protocol = protocol_class(host, credentials=credentials)
|
||||
await unknown_dev.update()
|
||||
device_class = Discover._get_device_class(unknown_dev.internal_state)
|
||||
dev = device_class(
|
||||
@@ -399,5 +412,95 @@ class Discover:
|
||||
return SmartLightStrip
|
||||
|
||||
return SmartBulb
|
||||
raise UnsupportedDeviceException("Unknown device type: %s" % type_)
|
||||
|
||||
raise SmartDeviceException("Unknown device type: %s" % type_)
|
||||
@staticmethod
|
||||
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:
|
||||
"""Get SmartDevice from legacy 9999 response."""
|
||||
try:
|
||||
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
|
||||
except Exception as ex:
|
||||
raise SmartDeviceException(
|
||||
f"Unable to read response from device: {ip}: {ex}"
|
||||
) from ex
|
||||
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
|
||||
device_class = Discover._get_device_class(info)
|
||||
device = device_class(ip, port=port)
|
||||
device.update_from_discover_info(info)
|
||||
return device
|
||||
|
||||
@staticmethod
|
||||
def _get_device_instance(
|
||||
data: bytes, ip: str, port: int, credentials: Credentials
|
||||
) -> SmartDevice:
|
||||
"""Get SmartDevice from the new 20002 response."""
|
||||
try:
|
||||
info = json_loads(data[16:])
|
||||
discovery_result = DiscoveryResult(**info["result"])
|
||||
except Exception as ex:
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unable to read response from device: {ip}: {ex}"
|
||||
) from ex
|
||||
|
||||
if (
|
||||
discovery_result.mgt_encrypt_schm.encrypt_type == "KLAP"
|
||||
and discovery_result.mgt_encrypt_schm.lv is None
|
||||
):
|
||||
type_ = discovery_result.device_type
|
||||
device_class = None
|
||||
if type_.upper() == "IOT.SMARTPLUGSWITCH":
|
||||
device_class = SmartPlug
|
||||
|
||||
if device_class:
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
device = device_class(ip, port=port, credentials=credentials)
|
||||
device.update_from_discover_info(discovery_result.get_dict())
|
||||
device.protocol = TPLinkKlap(ip, credentials=credentials)
|
||||
return device
|
||||
else:
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device {ip} of type {type_}: {info}"
|
||||
)
|
||||
else:
|
||||
raise UnsupportedDeviceException(f"Unsupported device {ip}: {info}")
|
||||
|
||||
|
||||
class DiscoveryResult(BaseModel):
|
||||
"""Base model for discovery result."""
|
||||
|
||||
class Config:
|
||||
"""Class for configuring model behaviour."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
class EncryptionScheme(BaseModel):
|
||||
"""Base model for encryption scheme of discovery result."""
|
||||
|
||||
is_support_https: Optional[bool] = None
|
||||
encrypt_type: Optional[str] = None
|
||||
http_port: Optional[int] = None
|
||||
lv: Optional[int] = None
|
||||
|
||||
device_type: str = Field(alias="device_type_text")
|
||||
device_model: str = Field(alias="model")
|
||||
ip: str = Field(alias="alias")
|
||||
mac: str
|
||||
mgt_encrypt_schm: EncryptionScheme
|
||||
|
||||
device_id: Optional[str] = Field(default=None, alias="device_type_hash")
|
||||
owner: Optional[str] = Field(default=None, alias="device_owner_hash")
|
||||
hw_ver: Optional[str] = None
|
||||
is_support_iot_cloud: Optional[bool] = None
|
||||
obd_src: Optional[str] = None
|
||||
factory_default: Optional[bool] = None
|
||||
|
||||
def get_dict(self) -> dict:
|
||||
"""Return a dict for this discovery result.
|
||||
|
||||
containing only the values actually set and with aliases as field names.
|
||||
"""
|
||||
return self.dict(
|
||||
by_alias=True, exclude_unset=True, exclude_none=True, exclude_defaults=True
|
||||
)
|
||||
|
485
kasa/klapprotocol.py
Executable file
485
kasa/klapprotocol.py
Executable file
@@ -0,0 +1,485 @@
|
||||
"""Implementation of the TP-Link Klap Home Protocol.
|
||||
|
||||
Encryption/Decryption methods based on the works of
|
||||
Simon Wilkinson and Chris Weeldon
|
||||
|
||||
Klap devices that have never been connected to the kasa
|
||||
cloud should work with blank credentials.
|
||||
Devices that have been connected to the kasa cloud will
|
||||
switch intermittently between the users cloud credentials
|
||||
and default kasa credentials that are hardcoded.
|
||||
This appears to be an issue with the devices.
|
||||
|
||||
The protocol works by doing a two stage handshake to obtain
|
||||
and encryption key and session id cookie.
|
||||
|
||||
Authentication uses an auth_hash which is
|
||||
md5(md5(username),md5(password))
|
||||
|
||||
handshake1: client sends a random 16 byte local_seed to the
|
||||
device and receives a random 16 bytes remote_seed, followed
|
||||
by sha256(local_seed + auth_hash). It also returns a
|
||||
TP_SESSIONID in the cookie header. This implementation
|
||||
then checks this value against the possible auth_hashes
|
||||
described above (user cloud, kasa hardcoded, blank). If it
|
||||
finds a match it moves onto handshake2
|
||||
|
||||
handshake2: client sends sha25(remote_seed + auth_hash) to
|
||||
the device along with the TP_SESSIONID. Device responds with
|
||||
200 if succesful. It generally will be because this
|
||||
implemenation checks the auth_hash it recevied during handshake1
|
||||
|
||||
encryption: local_seed, remote_seed and auth_hash are now used
|
||||
for encryption. The last 4 bytes of the initialisation vector
|
||||
are used as a sequence number that increments every time the
|
||||
client calls encrypt and this sequence number is sent as a
|
||||
url parameter to the device along with the encrypted payload
|
||||
|
||||
https://gist.github.com/chriswheeldon/3b17d974db3817613c69191c0480fe55
|
||||
https://github.com/python-kasa/python-kasa/pull/117
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
from pprint import pformat as pf
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import hashes, padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .json import loads as json_loads
|
||||
from .protocol import TPLinkProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
|
||||
def _sha256(payload: bytes) -> bytes:
|
||||
return hashlib.sha256(payload).digest()
|
||||
|
||||
|
||||
def _md5(payload: bytes) -> bytes:
|
||||
digest = hashes.Hash(hashes.MD5()) # noqa: S303
|
||||
digest.update(payload)
|
||||
hash = digest.finalize()
|
||||
return hash
|
||||
|
||||
|
||||
class TPLinkKlap(TPLinkProtocol):
|
||||
"""Implementation of the KLAP encryption protocol.
|
||||
|
||||
KLAP is the name used in device discovery for TP-Link's new encryption
|
||||
protocol, used by newer firmware versions.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
DEFAULT_TIMEOUT = 5
|
||||
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
|
||||
KASA_SETUP_EMAIL = "kasa@tp-link.net"
|
||||
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
|
||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||
|
||||
self.credentials = (
|
||||
credentials
|
||||
if credentials and credentials.username and credentials.password
|
||||
else Credentials(username="", password="")
|
||||
)
|
||||
|
||||
self._local_seed: Optional[bytes] = None
|
||||
self.local_auth_hash = self.generate_auth_hash(self.credentials)
|
||||
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
|
||||
self.kasa_setup_auth_hash = None
|
||||
self.blank_auth_hash = None
|
||||
self.handshake_lock = asyncio.Lock()
|
||||
self.query_lock = asyncio.Lock()
|
||||
self.handshake_done = False
|
||||
|
||||
self.encryption_session: Optional[KlapEncryptionSession] = None
|
||||
self.session_expire_at: Optional[float] = None
|
||||
|
||||
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||
self.session_cookie = None
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
_LOGGER.debug("Created KLAP object for %s", self.host)
|
||||
|
||||
async def client_post(self, url, params=None, data=None):
|
||||
"""Send an http post request to the device."""
|
||||
response_data = None
|
||||
cookies = None
|
||||
if self.session_cookie:
|
||||
cookies = httpx.Cookies()
|
||||
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
|
||||
self.http_client.cookies.clear()
|
||||
resp = await self.http_client.post(
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
timeout=self.timeout,
|
||||
cookies=cookies,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
response_data = resp.content
|
||||
|
||||
return resp.status_code, response_data
|
||||
|
||||
async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]:
|
||||
"""Perform handshake1."""
|
||||
local_seed: bytes = secrets.token_bytes(16)
|
||||
|
||||
# Handshake 1 has a payload of local_seed
|
||||
# and a response of 16 bytes, followed by
|
||||
# sha256(remote_seed | auth_hash)
|
||||
|
||||
payload = local_seed
|
||||
|
||||
url = f"http://{self.host}/app/handshake1"
|
||||
|
||||
response_status, response_data = await self.client_post(url, data=payload)
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Handshake1 posted at %s. Host is %s, Response"
|
||||
+ "status is %s, Request was %s",
|
||||
datetime.datetime.now(),
|
||||
self.host,
|
||||
response_status,
|
||||
payload.hex(),
|
||||
)
|
||||
|
||||
if response_status != 200:
|
||||
raise AuthenticationException(
|
||||
f"Device {self.host} responded with {response_status} to handshake1"
|
||||
)
|
||||
|
||||
remote_seed: bytes = response_data[0:16]
|
||||
server_hash = response_data[16:]
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Handshake1 success at %s. Host is %s, "
|
||||
+ "Server remote_seed is: %s, server hash is: %s",
|
||||
datetime.datetime.now(),
|
||||
self.host,
|
||||
remote_seed.hex(),
|
||||
server_hash.hex(),
|
||||
)
|
||||
|
||||
local_seed_auth_hash = _sha256(local_seed + self.local_auth_hash)
|
||||
|
||||
# Check the response from the device with local credentials
|
||||
if local_seed_auth_hash == server_hash:
|
||||
_LOGGER.debug("handshake1 hashes match with expected credentials")
|
||||
return local_seed, remote_seed, self.local_auth_hash # type: ignore
|
||||
|
||||
# Now check against the default kasa setup credentials
|
||||
if not self.kasa_setup_auth_hash:
|
||||
kasa_setup_creds = Credentials(
|
||||
username=TPLinkKlap.KASA_SETUP_EMAIL,
|
||||
password=TPLinkKlap.KASA_SETUP_PASSWORD,
|
||||
)
|
||||
self.kasa_setup_auth_hash = TPLinkKlap.generate_auth_hash(kasa_setup_creds)
|
||||
|
||||
kasa_setup_seed_auth_hash = _sha256(
|
||||
local_seed + self.kasa_setup_auth_hash # type: ignore
|
||||
)
|
||||
if kasa_setup_seed_auth_hash == server_hash:
|
||||
_LOGGER.debug(
|
||||
"Server response doesn't match our expected hash on ip %s"
|
||||
+ " but an authentication with kasa setup credentials matched",
|
||||
self.host,
|
||||
)
|
||||
return local_seed, remote_seed, self.kasa_setup_auth_hash # type: ignore
|
||||
|
||||
# Finally check against blank credentials if not already blank
|
||||
if self.credentials != (blank_creds := Credentials(username="", password="")):
|
||||
if not self.blank_auth_hash:
|
||||
self.blank_auth_hash = TPLinkKlap.generate_auth_hash(blank_creds)
|
||||
blank_seed_auth_hash = _sha256(local_seed + self.blank_auth_hash) # type: ignore
|
||||
if blank_seed_auth_hash == server_hash:
|
||||
_LOGGER.debug(
|
||||
"Server response doesn't match our expected hash on ip %s"
|
||||
+ " but an authentication with blank credentials matched",
|
||||
self.host,
|
||||
)
|
||||
return local_seed, remote_seed, self.blank_auth_hash # type: ignore
|
||||
|
||||
msg = f"Server response doesn't match our challenge on ip {self.host}"
|
||||
_LOGGER.debug(msg)
|
||||
raise AuthenticationException(msg)
|
||||
|
||||
async def perform_handshake2(
|
||||
self, local_seed, remote_seed, auth_hash
|
||||
) -> "KlapEncryptionSession":
|
||||
"""Perform handshake2."""
|
||||
# Handshake 2 has the following payload:
|
||||
# sha256(serverBytes | authenticator)
|
||||
|
||||
url = f"http://{self.host}/app/handshake2"
|
||||
|
||||
payload = _sha256(remote_seed + auth_hash)
|
||||
|
||||
response_status, response_data = await self.client_post(url, data=payload)
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Handshake2 posted %s. Host is %s, Response status is %s, "
|
||||
+ "Request was %s",
|
||||
datetime.datetime.now(),
|
||||
self.host,
|
||||
response_status,
|
||||
payload.hex(),
|
||||
)
|
||||
|
||||
if response_status != 200:
|
||||
raise AuthenticationException(
|
||||
f"Device {self.host} responded with {response_status} to handshake2"
|
||||
)
|
||||
|
||||
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
|
||||
|
||||
async def perform_handshake(self) -> Any:
|
||||
"""Perform handshake1 and handshake2.
|
||||
|
||||
Sets the encryption_session if successful.
|
||||
"""
|
||||
_LOGGER.debug("Starting handshake with %s", self.host)
|
||||
self.handshake_done = False
|
||||
self.session_expire_at = None
|
||||
self.session_cookie = None
|
||||
|
||||
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
|
||||
self.session_cookie = self.http_client.cookies.get( # type: ignore
|
||||
TPLinkKlap.SESSION_COOKIE_NAME
|
||||
)
|
||||
# The device returns a TIMEOUT cookie on handshake1 which
|
||||
# it doesn't like to get back so we store the one we want
|
||||
|
||||
self.session_expire_at = time.time() + 86400
|
||||
self.encryption_session = await self.perform_handshake2(
|
||||
local_seed, remote_seed, auth_hash
|
||||
)
|
||||
self.handshake_done = True
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||
|
||||
def handshake_session_expired(self):
|
||||
"""Return true if session has expired."""
|
||||
return (
|
||||
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_auth_hash(creds: Credentials):
|
||||
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
|
||||
un = creds.username or ""
|
||||
pw = creds.password or ""
|
||||
return _md5(_md5(un.encode()) + _md5(pw.encode()))
|
||||
|
||||
@staticmethod
|
||||
def generate_owner_hash(creds: Credentials):
|
||||
"""Return the MD5 hash of the username in this object."""
|
||||
un = creds.username or ""
|
||||
return _md5(un.encode())
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Query the device retrying for retry_count on failure."""
|
||||
if isinstance(request, dict):
|
||||
request = json_dumps(request)
|
||||
assert isinstance(request, str) # noqa: S101
|
||||
|
||||
async with self.query_lock:
|
||||
return await self._query(request, retry_count)
|
||||
|
||||
async def _query(self, request: str, retry_count: int = 3) -> Dict:
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
return await self._execute_query(request, retry)
|
||||
except httpx.CloseError as sdex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except httpx.ConnectError as cex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {cex}"
|
||||
) from cex
|
||||
except TimeoutError as tex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
||||
) from tex
|
||||
except AuthenticationException as auex:
|
||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||
raise auex
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
async def _execute_query(self, request: str, retry_count: int) -> Dict:
|
||||
if not self.http_client:
|
||||
self.http_client = httpx.AsyncClient()
|
||||
|
||||
if not self.handshake_done or self.handshake_session_expired():
|
||||
try:
|
||||
await self.perform_handshake()
|
||||
|
||||
except AuthenticationException as auex:
|
||||
_LOGGER.debug(
|
||||
"Unable to complete handshake for device %s, "
|
||||
+ "authentication failed",
|
||||
self.host,
|
||||
)
|
||||
raise auex
|
||||
|
||||
# Check for mypy
|
||||
if self.encryption_session is not None:
|
||||
payload, seq = self.encryption_session.encrypt(request.encode())
|
||||
|
||||
url = f"http://{self.host}/app/request"
|
||||
|
||||
response_status, response_data = await self.client_post(
|
||||
url,
|
||||
params={"seq": seq},
|
||||
data=payload,
|
||||
)
|
||||
|
||||
msg = (
|
||||
f"at {datetime.datetime.now()}. Host is {self.host}, "
|
||||
+ f"Retry count is {retry_count}, Sequence is {seq}, "
|
||||
+ f"Response status is {response_status}, Request was {request}"
|
||||
)
|
||||
if response_status != 200:
|
||||
_LOGGER.error("Query failed after succesful authentication " + msg)
|
||||
# If we failed with a security error, force a new handshake next time.
|
||||
if response_status == 403:
|
||||
self.handshake_done = False
|
||||
raise AuthenticationException(
|
||||
f"Got a security error from {self.host} after handshake "
|
||||
+ "completed"
|
||||
)
|
||||
else:
|
||||
raise SmartDeviceException(
|
||||
f"Device {self.host} responded with {response_status} to"
|
||||
+ f"request with seq {seq}"
|
||||
)
|
||||
else:
|
||||
_LOGGER.debug("Query posted " + msg)
|
||||
|
||||
# Check for mypy
|
||||
if self.encryption_session is not None:
|
||||
decrypted_response = self.encryption_session.decrypt(response_data)
|
||||
|
||||
json_payload = json_loads(decrypted_response)
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s << %s",
|
||||
self.host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(json_payload),
|
||||
)
|
||||
|
||||
return json_payload
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol."""
|
||||
client = self.http_client
|
||||
self.http_client = None
|
||||
if client:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
class KlapEncryptionSession:
|
||||
"""Class to represent an encryption session and it's internal state.
|
||||
|
||||
i.e. sequence number which the device expects to increment.
|
||||
"""
|
||||
|
||||
def __init__(self, local_seed, remote_seed, user_hash):
|
||||
self.local_seed = local_seed
|
||||
self.remote_seed = remote_seed
|
||||
self.user_hash = user_hash
|
||||
self._key = self._key_derive(local_seed, remote_seed, user_hash)
|
||||
(self._iv, self._seq) = self._iv_derive(local_seed, remote_seed, user_hash)
|
||||
self._sig = self._sig_derive(local_seed, remote_seed, user_hash)
|
||||
|
||||
def _key_derive(self, local_seed, remote_seed, user_hash):
|
||||
payload = b"lsk" + local_seed + remote_seed + user_hash
|
||||
return hashlib.sha256(payload).digest()[:16]
|
||||
|
||||
def _iv_derive(self, local_seed, remote_seed, user_hash):
|
||||
# iv is first 16 bytes of sha256, where the last 4 bytes forms the
|
||||
# sequence number used in requests and is incremented on each request
|
||||
payload = b"iv" + local_seed + remote_seed + user_hash
|
||||
fulliv = hashlib.sha256(payload).digest()
|
||||
seq = int.from_bytes(fulliv[-4:], "big", signed=True)
|
||||
return (fulliv[:12], seq)
|
||||
|
||||
def _sig_derive(self, local_seed, remote_seed, user_hash):
|
||||
# used to create a hash with which to prefix each request
|
||||
payload = b"ldk" + local_seed + remote_seed + user_hash
|
||||
return hashlib.sha256(payload).digest()[:28]
|
||||
|
||||
def _iv_seq(self):
|
||||
seq = self._seq.to_bytes(4, "big", signed=True)
|
||||
iv = self._iv + seq
|
||||
return iv
|
||||
|
||||
def encrypt(self, msg):
|
||||
"""Encrypt the data and increment the sequence number."""
|
||||
self._seq = self._seq + 1
|
||||
if isinstance(msg, str):
|
||||
msg = msg.encode("utf-8")
|
||||
|
||||
cipher = Cipher(algorithms.AES(self._key), modes.CBC(self._iv_seq()))
|
||||
encryptor = cipher.encryptor()
|
||||
padder = padding.PKCS7(128).padder()
|
||||
padded_data = padder.update(msg) + padder.finalize()
|
||||
ciphertext = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
digest = hashes.Hash(hashes.SHA256())
|
||||
digest.update(
|
||||
self._sig + self._seq.to_bytes(4, "big", signed=True) + ciphertext
|
||||
)
|
||||
signature = digest.finalize()
|
||||
|
||||
return (signature + ciphertext, self._seq)
|
||||
|
||||
def decrypt(self, msg):
|
||||
"""Decrypt the data."""
|
||||
cipher = Cipher(algorithms.AES(self._key), modes.CBC(self._iv_seq()))
|
||||
decryptor = cipher.decryptor()
|
||||
dp = decryptor.update(msg[32:]) + decryptor.finalize()
|
||||
unpadder = padding.PKCS7(128).unpadder()
|
||||
plaintextbytes = unpadder.update(dp) + unpadder.finalize()
|
||||
|
||||
return plaintextbytes.decode()
|
@@ -14,6 +14,7 @@ import contextlib
|
||||
import errno
|
||||
import logging
|
||||
import struct
|
||||
from abc import ABC, abstractmethod
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Generator, Optional, Union
|
||||
|
||||
@@ -21,6 +22,7 @@ from typing import Dict, Generator, Optional, Union
|
||||
# async_timeout can be replaced with asyncio.timeout
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .json import loads as json_loads
|
||||
@@ -29,7 +31,31 @@ _LOGGER = logging.getLogger(__name__)
|
||||
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
|
||||
|
||||
|
||||
class TPLinkSmartHomeProtocol:
|
||||
class TPLinkProtocol(ABC):
|
||||
"""Base class for all TP-Link Smart Home communication."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.credentials = credentials
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Query the device for the protocol. Abstract method to be overriden."""
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol. Abstract method to be overriden."""
|
||||
|
||||
|
||||
class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
"""Implementation of the TP-Link Smart Home protocol."""
|
||||
|
||||
INITIALIZATION_VECTOR = 171
|
||||
@@ -38,11 +64,18 @@ class TPLinkSmartHomeProtocol:
|
||||
BLOCK_SIZE = 4
|
||||
|
||||
def __init__(
|
||||
self, host: str, *, port: Optional[int] = None, timeout: Optional[int] = None
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
self.host = host
|
||||
self.port = port or TPLinkSmartHomeProtocol.DEFAULT_PORT
|
||||
super().__init__(
|
||||
host=host, port=port or self.DEFAULT_PORT, credentials=credentials
|
||||
)
|
||||
|
||||
self.reader: Optional[asyncio.StreamReader] = None
|
||||
self.writer: Optional[asyncio.StreamWriter] = None
|
||||
self.query_lock = asyncio.Lock()
|
||||
|
@@ -24,7 +24,7 @@ from .credentials import Credentials
|
||||
from .emeterstatus import EmeterStatus
|
||||
from .exceptions import SmartDeviceException
|
||||
from .modules import Emeter, Module
|
||||
from .protocol import TPLinkSmartHomeProtocol
|
||||
from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,7 +71,7 @@ def requires_update(f):
|
||||
@functools.wraps(f)
|
||||
async def wrapped(*args, **kwargs):
|
||||
self = args[0]
|
||||
if self._last_update is None:
|
||||
if self._last_update is None and f.__name__ not in self._sys_info:
|
||||
raise SmartDeviceException(
|
||||
"You need to await update() to access the data"
|
||||
)
|
||||
@@ -82,7 +82,7 @@ def requires_update(f):
|
||||
@functools.wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
self = args[0]
|
||||
if self._last_update is None:
|
||||
if self._last_update is None and f.__name__ not in self._sys_info:
|
||||
raise SmartDeviceException(
|
||||
"You need to await update() to access the data"
|
||||
)
|
||||
@@ -213,8 +213,9 @@ class SmartDevice:
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self.protocol = TPLinkSmartHomeProtocol(host, port=port, timeout=timeout)
|
||||
self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol(
|
||||
host, port=port, timeout=timeout
|
||||
)
|
||||
self.credentials = credentials
|
||||
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
|
||||
self._device_type = DeviceType.Unknown
|
||||
@@ -222,6 +223,7 @@ class SmartDevice:
|
||||
# checks in accessors. the @updated_required decorator does not ensure
|
||||
# mypy that these are not accessed incorrectly.
|
||||
self._last_update: Any = None
|
||||
|
||||
self._sys_info: Any = None # TODO: this is here to avoid changing tests
|
||||
self._features: Set[str] = set()
|
||||
self.modules: Dict[str, Any] = {}
|
||||
@@ -374,8 +376,14 @@ class SmartDevice:
|
||||
|
||||
def update_from_discover_info(self, info: Dict[str, Any]) -> None:
|
||||
"""Update state from info from the discover call."""
|
||||
self._last_update = info
|
||||
self._set_sys_info(info["system"]["get_sysinfo"])
|
||||
if "system" in info and (sys_info := info["system"].get("get_sysinfo")):
|
||||
self._last_update = info
|
||||
self._set_sys_info(sys_info)
|
||||
else:
|
||||
# This allows setting of some info properties directly
|
||||
# from partial discovery info that will then be found
|
||||
# by the requires_update decorator
|
||||
self._set_sys_info(info)
|
||||
|
||||
def _set_sys_info(self, sys_info: Dict[str, Any]) -> None:
|
||||
"""Set sys_info."""
|
||||
@@ -388,21 +396,26 @@ class SmartDevice:
|
||||
@property # type: ignore
|
||||
@requires_update
|
||||
def sys_info(self) -> Dict[str, Any]:
|
||||
"""Return system information."""
|
||||
"""
|
||||
Return system information.
|
||||
|
||||
Do not call this function from within the SmartDevice
|
||||
class itself as @requires_update will be affected for other properties.
|
||||
"""
|
||||
return self._sys_info # type: ignore
|
||||
|
||||
@property # type: ignore
|
||||
@requires_update
|
||||
def model(self) -> str:
|
||||
"""Return device model."""
|
||||
sys_info = self.sys_info
|
||||
sys_info = self._sys_info
|
||||
return str(sys_info["model"])
|
||||
|
||||
@property # type: ignore
|
||||
@requires_update
|
||||
def alias(self) -> str:
|
||||
"""Return device name (alias)."""
|
||||
sys_info = self.sys_info
|
||||
sys_info = self._sys_info
|
||||
return str(sys_info["alias"])
|
||||
|
||||
async def set_alias(self, alias: str) -> None:
|
||||
@@ -454,14 +467,14 @@ class SmartDevice:
|
||||
"oemId",
|
||||
"dev_name",
|
||||
]
|
||||
sys_info = self.sys_info
|
||||
sys_info = self._sys_info
|
||||
return {key: sys_info[key] for key in keys if key in sys_info}
|
||||
|
||||
@property # type: ignore
|
||||
@requires_update
|
||||
def location(self) -> Dict:
|
||||
"""Return geographical location."""
|
||||
sys_info = self.sys_info
|
||||
sys_info = self._sys_info
|
||||
loc = {"latitude": None, "longitude": None}
|
||||
|
||||
if "latitude" in sys_info and "longitude" in sys_info:
|
||||
@@ -479,7 +492,7 @@ class SmartDevice:
|
||||
@requires_update
|
||||
def rssi(self) -> Optional[int]:
|
||||
"""Return WiFi signal strength (rssi)."""
|
||||
rssi = self.sys_info.get("rssi")
|
||||
rssi = self._sys_info.get("rssi")
|
||||
return None if rssi is None else int(rssi)
|
||||
|
||||
@property # type: ignore
|
||||
@@ -489,14 +502,14 @@ class SmartDevice:
|
||||
|
||||
:return: mac address in hexadecimal with colons, e.g. 01:23:45:67:89:ab
|
||||
"""
|
||||
sys_info = self.sys_info
|
||||
|
||||
sys_info = self._sys_info
|
||||
mac = sys_info.get("mac", sys_info.get("mic_mac"))
|
||||
if not mac:
|
||||
raise SmartDeviceException(
|
||||
"Unknown mac, please submit a bug report with sys_info output."
|
||||
)
|
||||
|
||||
mac = mac.replace("-", ":")
|
||||
# Format a mac that has no colons (usually from mic_mac field)
|
||||
if ":" not in mac:
|
||||
mac = ":".join(format(s, "02x") for s in bytes.fromhex(mac))
|
||||
|
||||
@@ -607,13 +620,13 @@ class SmartDevice:
|
||||
@requires_update
|
||||
def on_since(self) -> Optional[datetime]:
|
||||
"""Return pretty-printed on-time, or None if not available."""
|
||||
if "on_time" not in self.sys_info:
|
||||
if "on_time" not in self._sys_info:
|
||||
return None
|
||||
|
||||
if self.is_off:
|
||||
return None
|
||||
|
||||
on_time = self.sys_info["on_time"]
|
||||
on_time = self._sys_info["on_time"]
|
||||
|
||||
return datetime.now().replace(microsecond=0) - timedelta(seconds=on_time)
|
||||
|
||||
|
@@ -6,8 +6,8 @@ import sys
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
|
||||
from kasa.discover import _DiscoverProtocol, json_dumps
|
||||
from kasa.exceptions import UnsupportedDeviceException
|
||||
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps
|
||||
from kasa.exceptions import AuthenticationException, UnsupportedDeviceException
|
||||
|
||||
from .conftest import bulb, dimmer, lightstrip, plug, strip
|
||||
|
||||
@@ -51,7 +51,7 @@ async def test_type_detection_lightstrip(dev: SmartDevice):
|
||||
|
||||
async def test_type_unknown():
|
||||
invalid_info = {"system": {"get_sysinfo": {"type": "nosuchtype"}}}
|
||||
with pytest.raises(SmartDeviceException):
|
||||
with pytest.raises(UnsupportedDeviceException):
|
||||
Discover._get_device_class(invalid_info)
|
||||
|
||||
|
||||
@@ -239,3 +239,73 @@ async def test_discover_invalid_responses(msg, data, mocker):
|
||||
|
||||
proto.datagram_received(data, ("127.0.0.1", 9999))
|
||||
assert len(proto.discovered_devices) == 0
|
||||
|
||||
|
||||
AUTHENTICATION_DATA_KLAP = {
|
||||
"result": {
|
||||
"device_id": "xx",
|
||||
"owner": "xx",
|
||||
"device_type": "IOT.SMARTPLUGSWITCH",
|
||||
"device_model": "HS100(UK)",
|
||||
"ip": "127.0.0.1",
|
||||
"mac": "12-34-56-78-90-AB",
|
||||
"is_support_iot_cloud": True,
|
||||
"obd_src": "tplink",
|
||||
"factory_default": False,
|
||||
"mgt_encrypt_schm": {
|
||||
"is_support_https": False,
|
||||
"encrypt_type": "KLAP",
|
||||
"http_port": 80,
|
||||
},
|
||||
},
|
||||
"error_code": 0,
|
||||
}
|
||||
|
||||
|
||||
async def test_discover_single_authentication(mocker):
|
||||
"""Make sure that discover_single handles authenticating devices correctly."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
def mock_discover(self):
|
||||
if discovery_data:
|
||||
data = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
+ json_dumps(discovery_data).encode()
|
||||
)
|
||||
self.datagram_received(data, (host, 20002))
|
||||
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
||||
mocker.patch.object(
|
||||
SmartDevice,
|
||||
"update",
|
||||
side_effect=AuthenticationException("Failed to authenticate"),
|
||||
)
|
||||
|
||||
# Test with a valid unsupported response
|
||||
discovery_data = AUTHENTICATION_DATA_KLAP
|
||||
with pytest.raises(
|
||||
AuthenticationException,
|
||||
match="Failed to authenticate",
|
||||
):
|
||||
await Discover.discover_single(host)
|
||||
|
||||
mocker.patch.object(SmartDevice, "update")
|
||||
device = await Discover.discover_single(host)
|
||||
assert device.device_type == DeviceType.Plug
|
||||
|
||||
|
||||
async def test_device_update_from_new_discovery_info():
|
||||
device = SmartDevice("127.0.0.7")
|
||||
discover_info = DiscoveryResult(**AUTHENTICATION_DATA_KLAP["result"])
|
||||
discover_dump = discover_info.get_dict()
|
||||
device.update_from_discover_info(discover_dump)
|
||||
|
||||
assert device.alias == discover_dump["alias"]
|
||||
assert device.mac == discover_dump["mac"].replace("-", ":")
|
||||
assert device.model == discover_dump["model"]
|
||||
|
||||
with pytest.raises(
|
||||
SmartDeviceException,
|
||||
match=re.escape("You need to await update() to access the data"),
|
||||
):
|
||||
assert device.supported_modules
|
||||
|
306
kasa/tests/test_klapprotocol.py
Normal file
306
kasa/tests/test_klapprotocol.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import errno
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from contextlib import nullcontext as does_not_raise
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from ..credentials import Credentials
|
||||
from ..exceptions import AuthenticationException, SmartDeviceException
|
||||
from ..klapprotocol import KlapEncryptionSession, TPLinkKlap, _sha256
|
||||
|
||||
|
||||
class _mock_response:
|
||||
def __init__(self, status_code, content: bytes):
|
||||
self.status_code = status_code
|
||||
self.content = content
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||
async def test_protocol_retries(mocker, retry_count):
|
||||
conn = mocker.patch.object(
|
||||
TPLinkKlap, "client_post", side_effect=Exception("dummy exception")
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkKlap("127.0.0.1").query({}, retry_count=retry_count)
|
||||
|
||||
assert conn.call_count == retry_count + 1
|
||||
|
||||
|
||||
async def test_protocol_no_retry_on_connection_error(mocker):
|
||||
conn = mocker.patch.object(
|
||||
TPLinkKlap,
|
||||
"client_post",
|
||||
side_effect=httpx.ConnectError("foo"),
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkKlap("127.0.0.1").query({}, retry_count=5)
|
||||
|
||||
assert conn.call_count == 1
|
||||
|
||||
|
||||
async def test_protocol_retry_recoverable_error(mocker):
|
||||
conn = mocker.patch.object(
|
||||
TPLinkKlap,
|
||||
"client_post",
|
||||
side_effect=httpx.CloseError("foo"),
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkKlap("127.0.0.1").query({}, retry_count=5)
|
||||
|
||||
assert conn.call_count == 6
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||
async def test_protocol_reconnect(mocker, retry_count):
|
||||
remaining = retry_count
|
||||
|
||||
def _fail_one_less_than_retry_count(*_, **__):
|
||||
nonlocal remaining, encryption_session
|
||||
remaining -= 1
|
||||
if remaining:
|
||||
raise Exception("Simulated post failure")
|
||||
# Do the encrypt just before returning the value so the incrementing sequence number is correct
|
||||
encrypted, seq = encryption_session.encrypt('{"great":"success"}')
|
||||
return 200, encrypted
|
||||
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
protocol = TPLinkKlap("127.0.0.1")
|
||||
protocol.handshake_done = True
|
||||
protocol.session_expire_at = time.time() + 86400
|
||||
protocol.encryption_session = encryption_session
|
||||
mocker.patch.object(
|
||||
TPLinkKlap, "client_post", side_effect=_fail_one_less_than_retry_count
|
||||
)
|
||||
|
||||
response = await protocol.query({}, retry_count=retry_count)
|
||||
assert response == {"great": "success"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG])
|
||||
async def test_protocol_logging(mocker, caplog, log_level):
|
||||
caplog.set_level(log_level)
|
||||
logging.getLogger("kasa").setLevel(log_level)
|
||||
|
||||
def _return_encrypted(*_, **__):
|
||||
nonlocal encryption_session
|
||||
# Do the encrypt just before returning the value so the incrementing sequence number is correct
|
||||
encrypted, seq = encryption_session.encrypt('{"great":"success"}')
|
||||
return 200, encrypted
|
||||
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
protocol = TPLinkKlap("127.0.0.1")
|
||||
|
||||
protocol.handshake_done = True
|
||||
protocol.session_expire_at = time.time() + 86400
|
||||
protocol.encryption_session = encryption_session
|
||||
mocker.patch.object(TPLinkKlap, "client_post", side_effect=_return_encrypted)
|
||||
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
if log_level == logging.DEBUG:
|
||||
assert "success" in caplog.text
|
||||
else:
|
||||
assert "success" not in caplog.text
|
||||
|
||||
|
||||
def test_encrypt():
|
||||
d = json.dumps({"foo": 1, "bar": 2})
|
||||
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
|
||||
encrypted, seq = encryption_session.encrypt(d)
|
||||
|
||||
assert d == encryption_session.decrypt(encrypted)
|
||||
|
||||
|
||||
def test_encrypt_unicode():
|
||||
d = "{'snowman': '\u2603'}"
|
||||
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
|
||||
encrypted, seq = encryption_session.encrypt(d)
|
||||
|
||||
decrypted = encryption_session.decrypt(encrypted)
|
||||
|
||||
assert d == decrypted
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device_credentials, expectation",
|
||||
[
|
||||
(Credentials("foo", "bar"), does_not_raise()),
|
||||
(Credentials("", ""), does_not_raise()),
|
||||
(
|
||||
Credentials(TPLinkKlap.KASA_SETUP_EMAIL, TPLinkKlap.KASA_SETUP_PASSWORD),
|
||||
does_not_raise(),
|
||||
),
|
||||
(
|
||||
Credentials("shouldfail", "shouldfail"),
|
||||
pytest.raises(AuthenticationException),
|
||||
),
|
||||
],
|
||||
ids=("client", "blank", "kasa_setup", "shouldfail"),
|
||||
)
|
||||
async def test_handshake1(mocker, device_credentials, expectation):
|
||||
async def _return_handshake1_response(url, params=None, data=None, *_, **__):
|
||||
nonlocal client_seed, server_seed, device_auth_hash
|
||||
|
||||
client_seed = data
|
||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||
|
||||
return _mock_response(200, server_seed + client_seed_auth_hash)
|
||||
|
||||
client_seed = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = TPLinkKlap.generate_auth_hash(device_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
||||
)
|
||||
|
||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
||||
|
||||
protocol.http_client = httpx.AsyncClient()
|
||||
with expectation:
|
||||
(
|
||||
local_seed,
|
||||
device_remote_seed,
|
||||
auth_hash,
|
||||
) = await protocol.perform_handshake1()
|
||||
|
||||
assert local_seed == client_seed
|
||||
assert device_remote_seed == server_seed
|
||||
assert device_auth_hash == auth_hash
|
||||
await protocol.close()
|
||||
|
||||
|
||||
async def test_handshake(mocker):
|
||||
async def _return_handshake_response(url, params=None, data=None, *_, **__):
|
||||
nonlocal response_status, client_seed, server_seed, device_auth_hash
|
||||
|
||||
if url == "http://127.0.0.1/app/handshake1":
|
||||
client_seed = data
|
||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||
|
||||
return _mock_response(200, server_seed + client_seed_auth_hash)
|
||||
elif url == "http://127.0.0.1/app/handshake2":
|
||||
return _mock_response(response_status, b"")
|
||||
|
||||
client_seed = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
||||
)
|
||||
|
||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
||||
protocol.http_client = httpx.AsyncClient()
|
||||
|
||||
response_status = 200
|
||||
await protocol.perform_handshake()
|
||||
assert protocol.handshake_done is True
|
||||
|
||||
response_status = 403
|
||||
with pytest.raises(AuthenticationException):
|
||||
await protocol.perform_handshake()
|
||||
assert protocol.handshake_done is False
|
||||
await protocol.close()
|
||||
|
||||
|
||||
async def test_query(mocker):
|
||||
async def _return_response(url, params=None, data=None, *_, **__):
|
||||
nonlocal client_seed, server_seed, device_auth_hash, protocol, seq
|
||||
|
||||
if url == "http://127.0.0.1/app/handshake1":
|
||||
client_seed = data
|
||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||
|
||||
return _mock_response(200, server_seed + client_seed_auth_hash)
|
||||
elif url == "http://127.0.0.1/app/handshake2":
|
||||
return _mock_response(200, b"")
|
||||
elif url == "http://127.0.0.1/app/request":
|
||||
encryption_session = KlapEncryptionSession(
|
||||
protocol.encryption_session.local_seed,
|
||||
protocol.encryption_session.remote_seed,
|
||||
protocol.encryption_session.user_hash,
|
||||
)
|
||||
seq = params.get("seq")
|
||||
encryption_session._seq = seq - 1
|
||||
encrypted, seq = encryption_session.encrypt('{"great": "success"}')
|
||||
seq = seq
|
||||
return _mock_response(200, encrypted)
|
||||
|
||||
client_seed = None
|
||||
last_seq = None
|
||||
seq = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
||||
|
||||
for _ in range(10):
|
||||
resp = await protocol.query({})
|
||||
assert resp == {"great": "success"}
|
||||
# Check the protocol is incrementing the sequence number
|
||||
assert last_seq is None or last_seq + 1 == seq
|
||||
last_seq = seq
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response_status, expectation",
|
||||
[
|
||||
((403, 403, 403), pytest.raises(AuthenticationException)),
|
||||
((200, 403, 403), pytest.raises(AuthenticationException)),
|
||||
((200, 200, 403), pytest.raises(AuthenticationException)),
|
||||
((200, 200, 400), pytest.raises(SmartDeviceException)),
|
||||
],
|
||||
ids=("handshake1", "handshake2", "request", "non_auth_error"),
|
||||
)
|
||||
async def test_authentication_failures(mocker, response_status, expectation):
|
||||
async def _return_response(url, params=None, data=None, *_, **__):
|
||||
nonlocal client_seed, server_seed, device_auth_hash, response_status
|
||||
|
||||
if url == "http://127.0.0.1/app/handshake1":
|
||||
client_seed = data
|
||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||
|
||||
return _mock_response(
|
||||
response_status[0], server_seed + client_seed_auth_hash
|
||||
)
|
||||
elif url == "http://127.0.0.1/app/handshake2":
|
||||
return _mock_response(response_status[1], b"")
|
||||
elif url == "http://127.0.0.1/app/request":
|
||||
return _mock_response(response_status[2], None)
|
||||
|
||||
client_seed = None
|
||||
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials)
|
||||
|
||||
with expectation:
|
||||
await protocol.query({})
|
Reference in New Issue
Block a user