Enable newer encrypted discovery protocol (#1168)

This commit is contained in:
Steven B. 2024-10-16 15:28:27 +01:00 committed by GitHub
parent 7fd8c14c1f
commit 380fbb93c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 257 additions and 69 deletions

View File

@ -14,7 +14,7 @@ from collections.abc import AsyncGenerator
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Dict, cast
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives import hashes, padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
@ -108,7 +108,9 @@ class AesTransport(BaseTransport):
self._key_pair: KeyPair | None = None
if config.aes_keys:
aes_keys = config.aes_keys
self._key_pair = KeyPair(aes_keys["private"], aes_keys["public"])
self._key_pair = KeyPair.create_from_der_keys(
aes_keys["private"], aes_keys["public"]
)
self._app_url = URL(f"http://{self._host}:{self._port}/app")
self._token_url: URL | None = None
@ -277,14 +279,14 @@ class AesTransport(BaseTransport):
if not self._key_pair:
kp = KeyPair.create_key_pair()
self._config.aes_keys = {
"private": kp.get_private_key(),
"public": kp.get_public_key(),
"private": kp.private_key_der_b64,
"public": kp.public_key_der_b64,
}
self._key_pair = kp
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ self._key_pair.get_public_key() # type: ignore[union-attr]
+ self._key_pair.public_key_der_b64 # type: ignore[union-attr]
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
@ -392,18 +394,11 @@ class AesEncyptionSession:
"""Class for an AES encryption session."""
@staticmethod
def create_from_keypair(handshake_key: str, keypair):
def create_from_keypair(handshake_key: str, keypair: KeyPair):
"""Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())
private_key = cast(
rsa.RSAPrivateKey,
serialization.load_der_private_key(private_key_data, None, None),
)
key_and_iv = private_key.decrypt(
handshake_key_bytes, asymmetric_padding.PKCS1v15()
)
key_and_iv = keypair.decrypt_handshake_key(handshake_key_bytes)
if key_and_iv is None:
raise ValueError("Decryption failed!")
@ -438,30 +433,59 @@ class KeyPair:
"""Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key()
return KeyPair(private_key, public_key)
private_key_bytes = private_key.private_bytes(
@staticmethod
def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str):
"""Create a key pair."""
key_bytes = base64.b64decode(private_key_der_b64.encode())
private_key = cast(
rsa.RSAPrivateKey, serialization.load_der_private_key(key_bytes, None)
)
key_bytes = base64.b64decode(public_key_der_b64.encode())
public_key = cast(
rsa.RSAPublicKey, serialization.load_der_public_key(key_bytes, None)
)
return KeyPair(private_key, public_key)
def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey):
self.private_key = private_key
self.public_key = public_key
self.private_key_der_bytes = self.private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
public_key_bytes = public_key.public_bytes(
self.public_key_der_bytes = self.public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
self.private_key_der_b64 = base64.b64encode(self.private_key_der_bytes).decode()
self.public_key_der_b64 = base64.b64encode(self.public_key_der_bytes).decode()
return KeyPair(
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
def get_public_pem(self) -> bytes:
"""Get public key in PEM encoding."""
return self.public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
def __init__(self, private_key: str, public_key: str):
self.private_key = private_key
self.public_key = public_key
def decrypt_handshake_key(self, encrypted_key: bytes) -> bytes:
"""Decrypt an aes handshake key."""
decrypted = self.private_key.decrypt(
encrypted_key, asymmetric_padding.PKCS1v15()
)
return decrypted
def get_private_key(self) -> str:
"""Get the private key."""
return self.private_key
def get_public_key(self) -> str:
"""Get the public key."""
return self.public_key
def decrypt_discovery_key(self, encrypted_key: bytes) -> bytes:
"""Decrypt an aes discovery key."""
decrypted = self.private_key.decrypt(
encrypted_key,
asymmetric_padding.OAEP(
mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
algorithm=hashes.SHA1(), # noqa: S303
label=None,
),
)
return decrypted

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
from pprint import pformat as pf
import asyncclick as click
from pydantic.v1 import ValidationError
@ -28,6 +29,7 @@ async def discover(ctx):
password = ctx.parent.params["password"]
discovery_timeout = ctx.parent.params["discovery_timeout"]
timeout = ctx.parent.params["timeout"]
host = ctx.parent.params["host"]
port = ctx.parent.params["port"]
credentials = Credentials(username, password) if username and password else None
@ -49,8 +51,6 @@ async def discover(ctx):
echo(f"\t{unsupported_exception}")
echo()
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
from .device import state
async def print_discovered(dev: Device):
@ -68,6 +68,18 @@ async def discover(ctx):
discovered[dev.host] = dev.internal_state
echo()
if host:
echo(f"Discovering device {host} for {discovery_timeout} seconds")
return await Discover.discover_single(
host,
port=port,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
on_unsupported=print_unsupported,
)
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
discovered_devices = await Discover.discover(
target=target,
discovery_timeout=discovery_timeout,
@ -113,21 +125,31 @@ def _echo_discovery_info(discovery_info):
_echo_dictionary(discovery_info)
return
def _conditional_echo(label, value):
if value:
ws = " " * (19 - len(label))
echo(f"\t{label}:{ws}{value}")
echo("\t[bold]== Discovery Result ==[/bold]")
echo(f"\tDevice Type: {dr.device_type}")
echo(f"\tDevice Model: {dr.device_model}")
echo(f"\tIP: {dr.ip}")
echo(f"\tMAC: {dr.mac}")
echo(f"\tDevice Id (hash): {dr.device_id}")
echo(f"\tOwner (hash): {dr.owner}")
echo(f"\tHW Ver: {dr.hw_ver}")
echo(f"\tSupports IOT Cloud: {dr.is_support_iot_cloud}")
echo(f"\tOBD Src: {dr.obd_src}")
echo(f"\tFactory Default: {dr.factory_default}")
echo(f"\tEncrypt Type: {dr.mgt_encrypt_schm.encrypt_type}")
echo(f"\tSupports HTTPS: {dr.mgt_encrypt_schm.is_support_https}")
echo(f"\tHTTP Port: {dr.mgt_encrypt_schm.http_port}")
echo(f"\tLV (Login Level): {dr.mgt_encrypt_schm.lv}")
_conditional_echo("Device Type", dr.device_type)
_conditional_echo("Device Model", dr.device_model)
_conditional_echo("Device Name", dr.device_name)
_conditional_echo("IP", dr.ip)
_conditional_echo("MAC", dr.mac)
_conditional_echo("Device Id (hash)", dr.device_id)
_conditional_echo("Owner (hash)", dr.owner)
_conditional_echo("FW Ver", dr.firmware_version)
_conditional_echo("HW Ver", dr.hw_ver)
_conditional_echo("HW Ver", dr.hardware_version)
_conditional_echo("Supports IOT Cloud", dr.is_support_iot_cloud)
_conditional_echo("OBD Src", dr.owner)
_conditional_echo("Factory Default", dr.factory_default)
_conditional_echo("Encrypt Type", dr.mgt_encrypt_schm.encrypt_type)
_conditional_echo("Encrypt Type", dr.encrypt_type)
_conditional_echo("Supports HTTPS", dr.mgt_encrypt_schm.is_support_https)
_conditional_echo("HTTP Port", dr.mgt_encrypt_schm.http_port)
_conditional_echo("Encrypt info", pf(dr.encrypt_info) if dr.encrypt_info else None)
_conditional_echo("Decrypted", pf(dr.decrypted_data) if dr.decrypted_data else None)
async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3):

View File

@ -158,6 +158,7 @@ def _legacy_type_to_class(_type):
type=click.Choice(ENCRYPT_TYPES, case_sensitive=False),
)
@click.option(
"-df",
"--device-family",
envvar="KASA_DEVICE_FAMILY",
default="SMART.TAPOPLUG",
@ -182,7 +183,7 @@ def _legacy_type_to_class(_type):
@click.option(
"--discovery-timeout",
envvar="KASA_DISCOVERY_TIMEOUT",
default=5,
default=10,
required=False,
show_default=True,
help="Timeout for discovery.",
@ -326,15 +327,11 @@ async def cli(
dev = await Device.connect(config=config)
device_updated = True
else:
from kasa.discover import Discover
from .discover import discover
dev = await Discover.discover_single(
host,
port=port,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
)
dev = await ctx.invoke(discover)
if not dev:
error(f"Unable to create device for {host}")
# Skip update on specific commands, or if device factory,
# that performs an update was used for the device.

View File

@ -82,13 +82,16 @@ Discovering a single device returns a kasa.Device object.
from __future__ import annotations
import asyncio
import base64
import binascii
import ipaddress
import logging
import secrets
import socket
import struct
from collections.abc import Awaitable
from pprint import pformat as pf
from typing import Any, Callable, Dict, Optional, Type, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
@ -96,6 +99,7 @@ from async_timeout import timeout as asyncio_timeout
from pydantic.v1 import BaseModel, ValidationError
from kasa import Device
from kasa.aestransport import AesEncyptionSession, KeyPair
from kasa.credentials import Credentials
from kasa.device_factory import (
get_device_class_from_family,
@ -133,6 +137,46 @@ NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
}
class _AesDiscoveryQuery:
keypair: KeyPair | None = None
@classmethod
def generate_query(cls):
if not cls.keypair:
cls.keypair = KeyPair.create_key_pair(key_size=2048)
secret = secrets.token_bytes(4)
key_payload = {"params": {"rsa_key": cls.keypair.get_public_pem().decode()}}
key_payload_bytes = json_dumps(key_payload).encode()
# https://labs.withsecure.com/advisories/tp-link-ac1750-pwn2own-2019
version = 2 # version of tdp
msg_type = 0
op_code = 1 # probe
msg_size = len(key_payload_bytes)
flags = 17
padding_byte = 0 # blank byte
device_serial = int.from_bytes(secret, "big")
initial_crc = 0x5A6B7C8D
disco_header = struct.pack(
">BBHHBBII",
version,
msg_type,
op_code,
msg_size,
flags,
padding_byte,
device_serial,
initial_crc,
)
query = bytearray(disco_header + key_payload_bytes)
crc = binascii.crc32(query).to_bytes(length=4, byteorder="big")
query[12:16] = crc
return query
class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler.
@ -224,15 +268,21 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = XorEncryption.encrypt(req)
sleep_between_packets = self.discovery_timeout / self.discovery_packets
aes_discovery_query = _AesDiscoveryQuery.generate_query()
for _ in range(self.discovery_packets):
if self.target in self.seen_hosts: # Stop sending for discover_single
break
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore
await asyncio.sleep(sleep_between_packets)
def datagram_received(self, data, addr) -> None:
"""Handle discovery responses."""
if TYPE_CHECKING:
assert _AesDiscoveryQuery.keypair
ip, port = addr
# Prevent multiple entries due multiple broadcasts
if ip in self.seen_hosts:
@ -395,7 +445,8 @@ class Discover:
credentials: Credentials | None = None,
username: str | None = None,
password: str | None = None,
) -> Device:
on_unsupported: OnUnsupportedCallable | None = None,
) -> Device | None:
"""Discover a single device by the given IP address.
It is generally preferred to avoid :func:`discover_single()` and
@ -465,6 +516,10 @@ class Discover:
dev.host = host
return dev
elif ip in protocol.unsupported_device_exceptions:
if on_unsupported:
await on_unsupported(protocol.unsupported_device_exceptions[ip])
return None
else:
raise protocol.unsupported_device_exceptions[ip]
elif ip in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[ip]
@ -512,6 +567,25 @@ class Discover:
device.update_from_discover_info(info)
return device
@staticmethod
def _decrypt_discovery_data(discovery_result: DiscoveryResult) -> None:
if TYPE_CHECKING:
assert discovery_result.encrypt_info
assert _AesDiscoveryQuery.keypair
encryped_key = discovery_result.encrypt_info.key
encrypted_data = discovery_result.encrypt_info.data
key_and_iv = _AesDiscoveryQuery.keypair.decrypt_discovery_key(
base64.b64decode(encryped_key.encode())
)
key, iv = key_and_iv[:16], key_and_iv[16:]
session = AesEncyptionSession(key, iv)
decrypted_data = session.decrypt(encrypted_data)
discovery_result.decrypted_data = json_loads(decrypted_data)
@staticmethod
def _get_device_instance(
data: bytes,
@ -528,6 +602,8 @@ class Discover:
) from ex
try:
discovery_result = DiscoveryResult(**info["result"])
if discovery_result.encrypt_info:
Discover._decrypt_discovery_data(discovery_result)
except ValidationError as ex:
if debug_enabled:
data = (
@ -547,9 +623,19 @@ class Discover:
type_ = discovery_result.device_type
try:
if not (
encrypt_type := discovery_result.mgt_encrypt_schm.encrypt_type
) and (encrypt_info := discovery_result.encrypt_info):
encrypt_type = encrypt_info.sym_schm
if not encrypt_type:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.get_dict(),
)
config.connection_type = DeviceConnectionParameters.from_values(
type_,
discovery_result.mgt_encrypt_schm.encrypt_type,
encrypt_type,
discovery_result.mgt_encrypt_schm.lv,
)
except KasaException as ex:
@ -593,21 +679,35 @@ class EncryptionScheme(BaseModel):
"""Base model for encryption scheme of discovery result."""
is_support_https: bool
encrypt_type: str
http_port: int
encrypt_type: Optional[str] # noqa: UP007
http_port: Optional[int] = None # noqa: UP007
lv: Optional[int] = None # noqa: UP007
class EncryptionInfo(BaseModel):
"""Base model for encryption info of discovery result."""
sym_schm: str
key: str
data: str
class DiscoveryResult(BaseModel):
"""Base model for discovery result."""
device_type: str
device_model: str
device_name: Optional[str] # noqa: UP007
ip: str
mac: str
mgt_encrypt_schm: EncryptionScheme
encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007
encrypt_type: Optional[list[str]] = None # noqa: UP007
decrypted_data: Optional[dict] = None # noqa: UP007
device_id: str
firmware_version: Optional[str] = None # noqa: UP007
hardware_version: Optional[str] = None # noqa: UP007
hw_ver: Optional[str] = None # noqa: UP007
owner: Optional[str] = None # noqa: UP007
is_support_iot_cloud: Optional[bool] = None # noqa: UP007

View File

@ -99,8 +99,8 @@ async def test_handshake_with_keys(mocker):
assert transport._state is TransportState.HANDSHAKE_REQUIRED
await transport.perform_handshake()
assert transport._key_pair.get_private_key() == test_keys["private"]
assert transport._key_pair.get_public_key() == test_keys["public"]
assert transport._key_pair.private_key_der_b64 == test_keys["private"]
assert transport._key_pair.public_key_der_b64 == test_keys["public"]
@status_parameters

View File

@ -2,6 +2,7 @@ import json
import os
import re
from datetime import datetime
from unittest.mock import ANY
import asyncclick as click
import pytest
@ -17,7 +18,6 @@ from kasa import (
EmeterStatus,
KasaException,
Module,
UnsupportedDeviceError,
)
from kasa.cli.device import (
alias,
@ -613,6 +613,7 @@ async def test_without_device_type(dev, mocker, runner):
credentials=Credentials("foo", "bar"),
timeout=5,
discovery_timeout=7,
on_unsupported=ANY,
)
@ -735,7 +736,7 @@ async def test_host_unsupported(unsupported_device_info, runner):
)
assert res.exit_code != 0
assert isinstance(res.exception, UnsupportedDeviceError)
assert "== Unsupported device ==" in res.output
@new_discovery

View File

@ -2,6 +2,8 @@
# ruff: noqa: S106
import asyncio
import base64
import json
import logging
import re
import socket
@ -10,6 +12,8 @@ from unittest.mock import MagicMock
import aiohttp
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from async_timeout import timeout as asyncio_timeout
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from kasa import (
Credentials,
@ -18,11 +22,17 @@ from kasa import (
Discover,
KasaException,
)
from kasa.aestransport import AesEncyptionSession
from kasa.deviceconfig import (
DeviceConfig,
DeviceConnectionParameters,
)
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps
from kasa.discover import (
DiscoveryResult,
_AesDiscoveryQuery,
_DiscoverProtocol,
json_dumps,
)
from kasa.exceptions import AuthenticationError, UnsupportedDeviceError
from kasa.iot import IotDevice
from kasa.xortransport import XorEncryption
@ -278,7 +288,7 @@ async def test_discover_send(mocker):
assert proto.target_1 == ("255.255.255.255", 9999)
transport = mocker.patch.object(proto, "transport")
await proto.do_discover()
assert transport.sendto.call_count == proto.discovery_packets * 2
assert transport.sendto.call_count == proto.discovery_packets * 3
async def test_discover_datagram_received(mocker, discovery_data):
@ -485,13 +495,14 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count):
discovery_timeout=discovery_timeout,
discovery_packets=5,
)
ft = FakeDatagramTransport(dp, port, do_not_reply_count)
expected_send = 1 if port == 9999 else 2
ft = FakeDatagramTransport(dp, port, do_not_reply_count * expected_send)
dp.connection_made(ft)
await dp.wait_for_discovery_to_complete()
await asyncio.sleep(0)
assert ft.send_count == do_not_reply_count + 1
assert ft.send_count == do_not_reply_count * expected_send + expected_send
assert dp.discover_task.done()
assert dp.discover_task.cancelled()
@ -603,3 +614,36 @@ async def test_discovery_redaction(discovery_mock, caplog: pytest.LogCaptureFixt
await Discover.discover()
assert mac not in caplog.text
assert "12:34:56:00:00:00" in caplog.text
async def test_discovery_decryption():
"""Test discovery decryption."""
key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t"
iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa"
key_iv = key + iv
_AesDiscoveryQuery.generate_query()
keypair = _AesDiscoveryQuery.keypair
padding = asymmetric_padding.OAEP(
mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
algorithm=hashes.SHA1(), # noqa: S303
label=None,
)
encrypted_key_iv = keypair.public_key.encrypt(key_iv, padding)
encrypted_key_iv_b4 = base64.b64encode(encrypted_key_iv)
encryption_session = AesEncyptionSession(key_iv[:16], key_iv[16:])
data_dict = {"foo": 1, "bar": 2}
data = json.dumps(data_dict)
encypted_data = encryption_session.encrypt(data.encode())
encrypt_info = {
"data": encypted_data.decode(),
"key": encrypted_key_iv_b4.decode(),
"sym_schm": "AES",
}
info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info}
dr = DiscoveryResult(**info)
Discover._decrypt_discovery_data(dr)
assert dr.decrypted_data == data_dict