Enable newer encrypted discovery protocol (#1168)

This commit is contained in:
Steven B. 2024-10-16 15:28:27 +01:00 committed by GitHub
parent 7fd8c14c1f
commit 380fbb93c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 257 additions and 69 deletions

View File

@ -14,7 +14,7 @@ from collections.abc import AsyncGenerator
from enum import Enum, auto from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Dict, cast from typing import TYPE_CHECKING, Any, Dict, cast
from cryptography.hazmat.primitives import padding, serialization from cryptography.hazmat.primitives import hashes, padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
@ -108,7 +108,9 @@ class AesTransport(BaseTransport):
self._key_pair: KeyPair | None = None self._key_pair: KeyPair | None = None
if config.aes_keys: if config.aes_keys:
aes_keys = config.aes_keys aes_keys = config.aes_keys
self._key_pair = KeyPair(aes_keys["private"], aes_keys["public"]) self._key_pair = KeyPair.create_from_der_keys(
aes_keys["private"], aes_keys["public"]
)
self._app_url = URL(f"http://{self._host}:{self._port}/app") self._app_url = URL(f"http://{self._host}:{self._port}/app")
self._token_url: URL | None = None self._token_url: URL | None = None
@ -277,14 +279,14 @@ class AesTransport(BaseTransport):
if not self._key_pair: if not self._key_pair:
kp = KeyPair.create_key_pair() kp = KeyPair.create_key_pair()
self._config.aes_keys = { self._config.aes_keys = {
"private": kp.get_private_key(), "private": kp.private_key_der_b64,
"public": kp.get_public_key(), "public": kp.public_key_der_b64,
} }
self._key_pair = kp self._key_pair = kp
pub_key = ( pub_key = (
"-----BEGIN PUBLIC KEY-----\n" "-----BEGIN PUBLIC KEY-----\n"
+ self._key_pair.get_public_key() # type: ignore[union-attr] + self._key_pair.public_key_der_b64 # type: ignore[union-attr]
+ "\n-----END PUBLIC KEY-----\n" + "\n-----END PUBLIC KEY-----\n"
) )
handshake_params = {"key": pub_key} handshake_params = {"key": pub_key}
@ -392,18 +394,11 @@ class AesEncyptionSession:
"""Class for an AES encryption session.""" """Class for an AES encryption session."""
@staticmethod @staticmethod
def create_from_keypair(handshake_key: str, keypair): def create_from_keypair(handshake_key: str, keypair: KeyPair):
"""Create the encryption session.""" """Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8")) handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
private_key = cast( key_and_iv = keypair.decrypt_handshake_key(handshake_key_bytes)
rsa.RSAPrivateKey,
serialization.load_der_private_key(private_key_data, None, None),
)
key_and_iv = private_key.decrypt(
handshake_key_bytes, asymmetric_padding.PKCS1v15()
)
if key_and_iv is None: if key_and_iv is None:
raise ValueError("Decryption failed!") raise ValueError("Decryption failed!")
@ -438,30 +433,59 @@ class KeyPair:
"""Create a key pair.""" """Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key() public_key = private_key.public_key()
return KeyPair(private_key, public_key)
private_key_bytes = private_key.private_bytes( @staticmethod
def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str):
"""Create a key pair."""
key_bytes = base64.b64decode(private_key_der_b64.encode())
private_key = cast(
rsa.RSAPrivateKey, serialization.load_der_private_key(key_bytes, None)
)
key_bytes = base64.b64decode(public_key_der_b64.encode())
public_key = cast(
rsa.RSAPublicKey, serialization.load_der_public_key(key_bytes, None)
)
return KeyPair(private_key, public_key)
def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey):
self.private_key = private_key
self.public_key = public_key
self.private_key_der_bytes = self.private_key.private_bytes(
encoding=serialization.Encoding.DER, encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8, format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(), encryption_algorithm=serialization.NoEncryption(),
) )
public_key_bytes = public_key.public_bytes( self.public_key_der_bytes = self.public_key.public_bytes(
encoding=serialization.Encoding.DER, encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo, format=serialization.PublicFormat.SubjectPublicKeyInfo,
) )
self.private_key_der_b64 = base64.b64encode(self.private_key_der_bytes).decode()
self.public_key_der_b64 = base64.b64encode(self.public_key_der_bytes).decode()
return KeyPair( def get_public_pem(self) -> bytes:
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"), """Get public key in PEM encoding."""
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"), return self.public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
) )
def __init__(self, private_key: str, public_key: str): def decrypt_handshake_key(self, encrypted_key: bytes) -> bytes:
self.private_key = private_key """Decrypt an aes handshake key."""
self.public_key = public_key decrypted = self.private_key.decrypt(
encrypted_key, asymmetric_padding.PKCS1v15()
)
return decrypted
def get_private_key(self) -> str: def decrypt_discovery_key(self, encrypted_key: bytes) -> bytes:
"""Get the private key.""" """Decrypt an aes discovery key."""
return self.private_key decrypted = self.private_key.decrypt(
encrypted_key,
def get_public_key(self) -> str: asymmetric_padding.OAEP(
"""Get the public key.""" mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
return self.public_key algorithm=hashes.SHA1(), # noqa: S303
label=None,
),
)
return decrypted

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from pprint import pformat as pf
import asyncclick as click import asyncclick as click
from pydantic.v1 import ValidationError from pydantic.v1 import ValidationError
@ -28,6 +29,7 @@ async def discover(ctx):
password = ctx.parent.params["password"] password = ctx.parent.params["password"]
discovery_timeout = ctx.parent.params["discovery_timeout"] discovery_timeout = ctx.parent.params["discovery_timeout"]
timeout = ctx.parent.params["timeout"] timeout = ctx.parent.params["timeout"]
host = ctx.parent.params["host"]
port = ctx.parent.params["port"] port = ctx.parent.params["port"]
credentials = Credentials(username, password) if username and password else None credentials = Credentials(username, password) if username and password else None
@ -49,8 +51,6 @@ async def discover(ctx):
echo(f"\t{unsupported_exception}") echo(f"\t{unsupported_exception}")
echo() echo()
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
from .device import state from .device import state
async def print_discovered(dev: Device): async def print_discovered(dev: Device):
@ -68,6 +68,18 @@ async def discover(ctx):
discovered[dev.host] = dev.internal_state discovered[dev.host] = dev.internal_state
echo() echo()
if host:
echo(f"Discovering device {host} for {discovery_timeout} seconds")
return await Discover.discover_single(
host,
port=port,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
on_unsupported=print_unsupported,
)
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
discovered_devices = await Discover.discover( discovered_devices = await Discover.discover(
target=target, target=target,
discovery_timeout=discovery_timeout, discovery_timeout=discovery_timeout,
@ -113,21 +125,31 @@ def _echo_discovery_info(discovery_info):
_echo_dictionary(discovery_info) _echo_dictionary(discovery_info)
return return
def _conditional_echo(label, value):
if value:
ws = " " * (19 - len(label))
echo(f"\t{label}:{ws}{value}")
echo("\t[bold]== Discovery Result ==[/bold]") echo("\t[bold]== Discovery Result ==[/bold]")
echo(f"\tDevice Type: {dr.device_type}") _conditional_echo("Device Type", dr.device_type)
echo(f"\tDevice Model: {dr.device_model}") _conditional_echo("Device Model", dr.device_model)
echo(f"\tIP: {dr.ip}") _conditional_echo("Device Name", dr.device_name)
echo(f"\tMAC: {dr.mac}") _conditional_echo("IP", dr.ip)
echo(f"\tDevice Id (hash): {dr.device_id}") _conditional_echo("MAC", dr.mac)
echo(f"\tOwner (hash): {dr.owner}") _conditional_echo("Device Id (hash)", dr.device_id)
echo(f"\tHW Ver: {dr.hw_ver}") _conditional_echo("Owner (hash)", dr.owner)
echo(f"\tSupports IOT Cloud: {dr.is_support_iot_cloud}") _conditional_echo("FW Ver", dr.firmware_version)
echo(f"\tOBD Src: {dr.obd_src}") _conditional_echo("HW Ver", dr.hw_ver)
echo(f"\tFactory Default: {dr.factory_default}") _conditional_echo("HW Ver", dr.hardware_version)
echo(f"\tEncrypt Type: {dr.mgt_encrypt_schm.encrypt_type}") _conditional_echo("Supports IOT Cloud", dr.is_support_iot_cloud)
echo(f"\tSupports HTTPS: {dr.mgt_encrypt_schm.is_support_https}") _conditional_echo("OBD Src", dr.owner)
echo(f"\tHTTP Port: {dr.mgt_encrypt_schm.http_port}") _conditional_echo("Factory Default", dr.factory_default)
echo(f"\tLV (Login Level): {dr.mgt_encrypt_schm.lv}") _conditional_echo("Encrypt Type", dr.mgt_encrypt_schm.encrypt_type)
_conditional_echo("Encrypt Type", dr.encrypt_type)
_conditional_echo("Supports HTTPS", dr.mgt_encrypt_schm.is_support_https)
_conditional_echo("HTTP Port", dr.mgt_encrypt_schm.http_port)
_conditional_echo("Encrypt info", pf(dr.encrypt_info) if dr.encrypt_info else None)
_conditional_echo("Decrypted", pf(dr.decrypted_data) if dr.decrypted_data else None)
async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3): async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3):

View File

@ -158,6 +158,7 @@ def _legacy_type_to_class(_type):
type=click.Choice(ENCRYPT_TYPES, case_sensitive=False), type=click.Choice(ENCRYPT_TYPES, case_sensitive=False),
) )
@click.option( @click.option(
"-df",
"--device-family", "--device-family",
envvar="KASA_DEVICE_FAMILY", envvar="KASA_DEVICE_FAMILY",
default="SMART.TAPOPLUG", default="SMART.TAPOPLUG",
@ -182,7 +183,7 @@ def _legacy_type_to_class(_type):
@click.option( @click.option(
"--discovery-timeout", "--discovery-timeout",
envvar="KASA_DISCOVERY_TIMEOUT", envvar="KASA_DISCOVERY_TIMEOUT",
default=5, default=10,
required=False, required=False,
show_default=True, show_default=True,
help="Timeout for discovery.", help="Timeout for discovery.",
@ -326,15 +327,11 @@ async def cli(
dev = await Device.connect(config=config) dev = await Device.connect(config=config)
device_updated = True device_updated = True
else: else:
from kasa.discover import Discover from .discover import discover
dev = await Discover.discover_single( dev = await ctx.invoke(discover)
host, if not dev:
port=port, error(f"Unable to create device for {host}")
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
)
# Skip update on specific commands, or if device factory, # Skip update on specific commands, or if device factory,
# that performs an update was used for the device. # that performs an update was used for the device.

View File

@ -82,13 +82,16 @@ Discovering a single device returns a kasa.Device object.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import base64
import binascii import binascii
import ipaddress import ipaddress
import logging import logging
import secrets
import socket import socket
import struct
from collections.abc import Awaitable from collections.abc import Awaitable
from pprint import pformat as pf from pprint import pformat as pf
from typing import Any, Callable, Dict, Optional, Type, cast from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
# When support for cpython older than 3.11 is dropped # When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout # async_timeout can be replaced with asyncio.timeout
@ -96,6 +99,7 @@ from async_timeout import timeout as asyncio_timeout
from pydantic.v1 import BaseModel, ValidationError from pydantic.v1 import BaseModel, ValidationError
from kasa import Device from kasa import Device
from kasa.aestransport import AesEncyptionSession, KeyPair
from kasa.credentials import Credentials from kasa.credentials import Credentials
from kasa.device_factory import ( from kasa.device_factory import (
get_device_class_from_family, get_device_class_from_family,
@ -133,6 +137,46 @@ NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
} }
class _AesDiscoveryQuery:
keypair: KeyPair | None = None
@classmethod
def generate_query(cls):
if not cls.keypair:
cls.keypair = KeyPair.create_key_pair(key_size=2048)
secret = secrets.token_bytes(4)
key_payload = {"params": {"rsa_key": cls.keypair.get_public_pem().decode()}}
key_payload_bytes = json_dumps(key_payload).encode()
# https://labs.withsecure.com/advisories/tp-link-ac1750-pwn2own-2019
version = 2 # version of tdp
msg_type = 0
op_code = 1 # probe
msg_size = len(key_payload_bytes)
flags = 17
padding_byte = 0 # blank byte
device_serial = int.from_bytes(secret, "big")
initial_crc = 0x5A6B7C8D
disco_header = struct.pack(
">BBHHBBII",
version,
msg_type,
op_code,
msg_size,
flags,
padding_byte,
device_serial,
initial_crc,
)
query = bytearray(disco_header + key_payload_bytes)
crc = binascii.crc32(query).to_bytes(length=4, byteorder="big")
query[12:16] = crc
return query
class _DiscoverProtocol(asyncio.DatagramProtocol): class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler. """Implementation of the discovery protocol handler.
@ -224,15 +268,21 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY) _LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = XorEncryption.encrypt(req) encrypted_req = XorEncryption.encrypt(req)
sleep_between_packets = self.discovery_timeout / self.discovery_packets sleep_between_packets = self.discovery_timeout / self.discovery_packets
aes_discovery_query = _AesDiscoveryQuery.generate_query()
for _ in range(self.discovery_packets): for _ in range(self.discovery_packets):
if self.target in self.seen_hosts: # Stop sending for discover_single if self.target in self.seen_hosts: # Stop sending for discover_single
break break
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore
await asyncio.sleep(sleep_between_packets) await asyncio.sleep(sleep_between_packets)
def datagram_received(self, data, addr) -> None: def datagram_received(self, data, addr) -> None:
"""Handle discovery responses.""" """Handle discovery responses."""
if TYPE_CHECKING:
assert _AesDiscoveryQuery.keypair
ip, port = addr ip, port = addr
# Prevent multiple entries due multiple broadcasts # Prevent multiple entries due multiple broadcasts
if ip in self.seen_hosts: if ip in self.seen_hosts:
@ -395,7 +445,8 @@ class Discover:
credentials: Credentials | None = None, credentials: Credentials | None = None,
username: str | None = None, username: str | None = None,
password: str | None = None, password: str | None = None,
) -> Device: on_unsupported: OnUnsupportedCallable | None = None,
) -> Device | None:
"""Discover a single device by the given IP address. """Discover a single device by the given IP address.
It is generally preferred to avoid :func:`discover_single()` and It is generally preferred to avoid :func:`discover_single()` and
@ -465,6 +516,10 @@ class Discover:
dev.host = host dev.host = host
return dev return dev
elif ip in protocol.unsupported_device_exceptions: elif ip in protocol.unsupported_device_exceptions:
if on_unsupported:
await on_unsupported(protocol.unsupported_device_exceptions[ip])
return None
else:
raise protocol.unsupported_device_exceptions[ip] raise protocol.unsupported_device_exceptions[ip]
elif ip in protocol.invalid_device_exceptions: elif ip in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[ip] raise protocol.invalid_device_exceptions[ip]
@ -512,6 +567,25 @@ class Discover:
device.update_from_discover_info(info) device.update_from_discover_info(info)
return device return device
@staticmethod
def _decrypt_discovery_data(discovery_result: DiscoveryResult) -> None:
if TYPE_CHECKING:
assert discovery_result.encrypt_info
assert _AesDiscoveryQuery.keypair
encryped_key = discovery_result.encrypt_info.key
encrypted_data = discovery_result.encrypt_info.data
key_and_iv = _AesDiscoveryQuery.keypair.decrypt_discovery_key(
base64.b64decode(encryped_key.encode())
)
key, iv = key_and_iv[:16], key_and_iv[16:]
session = AesEncyptionSession(key, iv)
decrypted_data = session.decrypt(encrypted_data)
discovery_result.decrypted_data = json_loads(decrypted_data)
@staticmethod @staticmethod
def _get_device_instance( def _get_device_instance(
data: bytes, data: bytes,
@ -528,6 +602,8 @@ class Discover:
) from ex ) from ex
try: try:
discovery_result = DiscoveryResult(**info["result"]) discovery_result = DiscoveryResult(**info["result"])
if discovery_result.encrypt_info:
Discover._decrypt_discovery_data(discovery_result)
except ValidationError as ex: except ValidationError as ex:
if debug_enabled: if debug_enabled:
data = ( data = (
@ -547,9 +623,19 @@ class Discover:
type_ = discovery_result.device_type type_ = discovery_result.device_type
try: try:
if not (
encrypt_type := discovery_result.mgt_encrypt_schm.encrypt_type
) and (encrypt_info := discovery_result.encrypt_info):
encrypt_type = encrypt_info.sym_schm
if not encrypt_type:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.get_dict(),
)
config.connection_type = DeviceConnectionParameters.from_values( config.connection_type = DeviceConnectionParameters.from_values(
type_, type_,
discovery_result.mgt_encrypt_schm.encrypt_type, encrypt_type,
discovery_result.mgt_encrypt_schm.lv, discovery_result.mgt_encrypt_schm.lv,
) )
except KasaException as ex: except KasaException as ex:
@ -593,21 +679,35 @@ class EncryptionScheme(BaseModel):
"""Base model for encryption scheme of discovery result.""" """Base model for encryption scheme of discovery result."""
is_support_https: bool is_support_https: bool
encrypt_type: str encrypt_type: Optional[str] # noqa: UP007
http_port: int http_port: Optional[int] = None # noqa: UP007
lv: Optional[int] = None # noqa: UP007 lv: Optional[int] = None # noqa: UP007
class EncryptionInfo(BaseModel):
"""Base model for encryption info of discovery result."""
sym_schm: str
key: str
data: str
class DiscoveryResult(BaseModel): class DiscoveryResult(BaseModel):
"""Base model for discovery result.""" """Base model for discovery result."""
device_type: str device_type: str
device_model: str device_model: str
device_name: Optional[str] # noqa: UP007
ip: str ip: str
mac: str mac: str
mgt_encrypt_schm: EncryptionScheme mgt_encrypt_schm: EncryptionScheme
encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007
encrypt_type: Optional[list[str]] = None # noqa: UP007
decrypted_data: Optional[dict] = None # noqa: UP007
device_id: str device_id: str
firmware_version: Optional[str] = None # noqa: UP007
hardware_version: Optional[str] = None # noqa: UP007
hw_ver: Optional[str] = None # noqa: UP007 hw_ver: Optional[str] = None # noqa: UP007
owner: Optional[str] = None # noqa: UP007 owner: Optional[str] = None # noqa: UP007
is_support_iot_cloud: Optional[bool] = None # noqa: UP007 is_support_iot_cloud: Optional[bool] = None # noqa: UP007

View File

@ -99,8 +99,8 @@ async def test_handshake_with_keys(mocker):
assert transport._state is TransportState.HANDSHAKE_REQUIRED assert transport._state is TransportState.HANDSHAKE_REQUIRED
await transport.perform_handshake() await transport.perform_handshake()
assert transport._key_pair.get_private_key() == test_keys["private"] assert transport._key_pair.private_key_der_b64 == test_keys["private"]
assert transport._key_pair.get_public_key() == test_keys["public"] assert transport._key_pair.public_key_der_b64 == test_keys["public"]
@status_parameters @status_parameters

View File

@ -2,6 +2,7 @@ import json
import os import os
import re import re
from datetime import datetime from datetime import datetime
from unittest.mock import ANY
import asyncclick as click import asyncclick as click
import pytest import pytest
@ -17,7 +18,6 @@ from kasa import (
EmeterStatus, EmeterStatus,
KasaException, KasaException,
Module, Module,
UnsupportedDeviceError,
) )
from kasa.cli.device import ( from kasa.cli.device import (
alias, alias,
@ -613,6 +613,7 @@ async def test_without_device_type(dev, mocker, runner):
credentials=Credentials("foo", "bar"), credentials=Credentials("foo", "bar"),
timeout=5, timeout=5,
discovery_timeout=7, discovery_timeout=7,
on_unsupported=ANY,
) )
@ -735,7 +736,7 @@ async def test_host_unsupported(unsupported_device_info, runner):
) )
assert res.exit_code != 0 assert res.exit_code != 0
assert isinstance(res.exception, UnsupportedDeviceError) assert "== Unsupported device ==" in res.output
@new_discovery @new_discovery

View File

@ -2,6 +2,8 @@
# ruff: noqa: S106 # ruff: noqa: S106
import asyncio import asyncio
import base64
import json
import logging import logging
import re import re
import socket import socket
@ -10,6 +12,8 @@ from unittest.mock import MagicMock
import aiohttp import aiohttp
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from async_timeout import timeout as asyncio_timeout from async_timeout import timeout as asyncio_timeout
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from kasa import ( from kasa import (
Credentials, Credentials,
@ -18,11 +22,17 @@ from kasa import (
Discover, Discover,
KasaException, KasaException,
) )
from kasa.aestransport import AesEncyptionSession
from kasa.deviceconfig import ( from kasa.deviceconfig import (
DeviceConfig, DeviceConfig,
DeviceConnectionParameters, DeviceConnectionParameters,
) )
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps from kasa.discover import (
DiscoveryResult,
_AesDiscoveryQuery,
_DiscoverProtocol,
json_dumps,
)
from kasa.exceptions import AuthenticationError, UnsupportedDeviceError from kasa.exceptions import AuthenticationError, UnsupportedDeviceError
from kasa.iot import IotDevice from kasa.iot import IotDevice
from kasa.xortransport import XorEncryption from kasa.xortransport import XorEncryption
@ -278,7 +288,7 @@ async def test_discover_send(mocker):
assert proto.target_1 == ("255.255.255.255", 9999) assert proto.target_1 == ("255.255.255.255", 9999)
transport = mocker.patch.object(proto, "transport") transport = mocker.patch.object(proto, "transport")
await proto.do_discover() await proto.do_discover()
assert transport.sendto.call_count == proto.discovery_packets * 2 assert transport.sendto.call_count == proto.discovery_packets * 3
async def test_discover_datagram_received(mocker, discovery_data): async def test_discover_datagram_received(mocker, discovery_data):
@ -485,13 +495,14 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count):
discovery_timeout=discovery_timeout, discovery_timeout=discovery_timeout,
discovery_packets=5, discovery_packets=5,
) )
ft = FakeDatagramTransport(dp, port, do_not_reply_count) expected_send = 1 if port == 9999 else 2
ft = FakeDatagramTransport(dp, port, do_not_reply_count * expected_send)
dp.connection_made(ft) dp.connection_made(ft)
await dp.wait_for_discovery_to_complete() await dp.wait_for_discovery_to_complete()
await asyncio.sleep(0) await asyncio.sleep(0)
assert ft.send_count == do_not_reply_count + 1 assert ft.send_count == do_not_reply_count * expected_send + expected_send
assert dp.discover_task.done() assert dp.discover_task.done()
assert dp.discover_task.cancelled() assert dp.discover_task.cancelled()
@ -603,3 +614,36 @@ async def test_discovery_redaction(discovery_mock, caplog: pytest.LogCaptureFixt
await Discover.discover() await Discover.discover()
assert mac not in caplog.text assert mac not in caplog.text
assert "12:34:56:00:00:00" in caplog.text assert "12:34:56:00:00:00" in caplog.text
async def test_discovery_decryption():
"""Test discovery decryption."""
key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t"
iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa"
key_iv = key + iv
_AesDiscoveryQuery.generate_query()
keypair = _AesDiscoveryQuery.keypair
padding = asymmetric_padding.OAEP(
mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
algorithm=hashes.SHA1(), # noqa: S303
label=None,
)
encrypted_key_iv = keypair.public_key.encrypt(key_iv, padding)
encrypted_key_iv_b4 = base64.b64encode(encrypted_key_iv)
encryption_session = AesEncyptionSession(key_iv[:16], key_iv[16:])
data_dict = {"foo": 1, "bar": 2}
data = json.dumps(data_dict)
encypted_data = encryption_session.encrypt(data.encode())
encrypt_info = {
"data": encypted_data.decode(),
"key": encrypted_key_iv_b4.decode(),
"sym_schm": "AES",
}
info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info}
dr = DiscoveryResult(**info)
Discover._decrypt_discovery_data(dr)
assert dr.decrypted_data == data_dict