mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-22 12:47:05 +00:00
Enable newer encrypted discovery protocol (#1168)
This commit is contained in:
parent
7fd8c14c1f
commit
380fbb93c3
@ -14,7 +14,7 @@ from collections.abc import AsyncGenerator
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Any, Dict, cast
|
||||
|
||||
from cryptography.hazmat.primitives import padding, serialization
|
||||
from cryptography.hazmat.primitives import hashes, padding, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
@ -108,7 +108,9 @@ class AesTransport(BaseTransport):
|
||||
self._key_pair: KeyPair | None = None
|
||||
if config.aes_keys:
|
||||
aes_keys = config.aes_keys
|
||||
self._key_pair = KeyPair(aes_keys["private"], aes_keys["public"])
|
||||
self._key_pair = KeyPair.create_from_der_keys(
|
||||
aes_keys["private"], aes_keys["public"]
|
||||
)
|
||||
self._app_url = URL(f"http://{self._host}:{self._port}/app")
|
||||
self._token_url: URL | None = None
|
||||
|
||||
@ -277,14 +279,14 @@ class AesTransport(BaseTransport):
|
||||
if not self._key_pair:
|
||||
kp = KeyPair.create_key_pair()
|
||||
self._config.aes_keys = {
|
||||
"private": kp.get_private_key(),
|
||||
"public": kp.get_public_key(),
|
||||
"private": kp.private_key_der_b64,
|
||||
"public": kp.public_key_der_b64,
|
||||
}
|
||||
self._key_pair = kp
|
||||
|
||||
pub_key = (
|
||||
"-----BEGIN PUBLIC KEY-----\n"
|
||||
+ self._key_pair.get_public_key() # type: ignore[union-attr]
|
||||
+ self._key_pair.public_key_der_b64 # type: ignore[union-attr]
|
||||
+ "\n-----END PUBLIC KEY-----\n"
|
||||
)
|
||||
handshake_params = {"key": pub_key}
|
||||
@ -392,18 +394,11 @@ class AesEncyptionSession:
|
||||
"""Class for an AES encryption session."""
|
||||
|
||||
@staticmethod
|
||||
def create_from_keypair(handshake_key: str, keypair):
|
||||
def create_from_keypair(handshake_key: str, keypair: KeyPair):
|
||||
"""Create the encryption session."""
|
||||
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
|
||||
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
|
||||
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())
|
||||
|
||||
private_key = cast(
|
||||
rsa.RSAPrivateKey,
|
||||
serialization.load_der_private_key(private_key_data, None, None),
|
||||
)
|
||||
key_and_iv = private_key.decrypt(
|
||||
handshake_key_bytes, asymmetric_padding.PKCS1v15()
|
||||
)
|
||||
key_and_iv = keypair.decrypt_handshake_key(handshake_key_bytes)
|
||||
if key_and_iv is None:
|
||||
raise ValueError("Decryption failed!")
|
||||
|
||||
@ -438,30 +433,59 @@ class KeyPair:
|
||||
"""Create a key pair."""
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
||||
public_key = private_key.public_key()
|
||||
return KeyPair(private_key, public_key)
|
||||
|
||||
private_key_bytes = private_key.private_bytes(
|
||||
@staticmethod
|
||||
def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str):
|
||||
"""Create a key pair."""
|
||||
key_bytes = base64.b64decode(private_key_der_b64.encode())
|
||||
private_key = cast(
|
||||
rsa.RSAPrivateKey, serialization.load_der_private_key(key_bytes, None)
|
||||
)
|
||||
key_bytes = base64.b64decode(public_key_der_b64.encode())
|
||||
public_key = cast(
|
||||
rsa.RSAPublicKey, serialization.load_der_public_key(key_bytes, None)
|
||||
)
|
||||
|
||||
return KeyPair(private_key, public_key)
|
||||
|
||||
def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey):
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
self.private_key_der_bytes = self.private_key.private_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
public_key_bytes = public_key.public_bytes(
|
||||
self.public_key_der_bytes = self.public_key.public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
self.private_key_der_b64 = base64.b64encode(self.private_key_der_bytes).decode()
|
||||
self.public_key_der_b64 = base64.b64encode(self.public_key_der_bytes).decode()
|
||||
|
||||
return KeyPair(
|
||||
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
|
||||
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
|
||||
def get_public_pem(self) -> bytes:
|
||||
"""Get public key in PEM encoding."""
|
||||
return self.public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def __init__(self, private_key: str, public_key: str):
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
def decrypt_handshake_key(self, encrypted_key: bytes) -> bytes:
|
||||
"""Decrypt an aes handshake key."""
|
||||
decrypted = self.private_key.decrypt(
|
||||
encrypted_key, asymmetric_padding.PKCS1v15()
|
||||
)
|
||||
return decrypted
|
||||
|
||||
def get_private_key(self) -> str:
|
||||
"""Get the private key."""
|
||||
return self.private_key
|
||||
|
||||
def get_public_key(self) -> str:
|
||||
"""Get the public key."""
|
||||
return self.public_key
|
||||
def decrypt_discovery_key(self, encrypted_key: bytes) -> bytes:
|
||||
"""Decrypt an aes discovery key."""
|
||||
decrypted = self.private_key.decrypt(
|
||||
encrypted_key,
|
||||
asymmetric_padding.OAEP(
|
||||
mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
|
||||
algorithm=hashes.SHA1(), # noqa: S303
|
||||
label=None,
|
||||
),
|
||||
)
|
||||
return decrypted
|
||||
|
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pprint import pformat as pf
|
||||
|
||||
import asyncclick as click
|
||||
from pydantic.v1 import ValidationError
|
||||
@ -28,6 +29,7 @@ async def discover(ctx):
|
||||
password = ctx.parent.params["password"]
|
||||
discovery_timeout = ctx.parent.params["discovery_timeout"]
|
||||
timeout = ctx.parent.params["timeout"]
|
||||
host = ctx.parent.params["host"]
|
||||
port = ctx.parent.params["port"]
|
||||
|
||||
credentials = Credentials(username, password) if username and password else None
|
||||
@ -49,8 +51,6 @@ async def discover(ctx):
|
||||
echo(f"\t{unsupported_exception}")
|
||||
echo()
|
||||
|
||||
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
|
||||
|
||||
from .device import state
|
||||
|
||||
async def print_discovered(dev: Device):
|
||||
@ -68,6 +68,18 @@ async def discover(ctx):
|
||||
discovered[dev.host] = dev.internal_state
|
||||
echo()
|
||||
|
||||
if host:
|
||||
echo(f"Discovering device {host} for {discovery_timeout} seconds")
|
||||
return await Discover.discover_single(
|
||||
host,
|
||||
port=port,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
discovery_timeout=discovery_timeout,
|
||||
on_unsupported=print_unsupported,
|
||||
)
|
||||
|
||||
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
|
||||
discovered_devices = await Discover.discover(
|
||||
target=target,
|
||||
discovery_timeout=discovery_timeout,
|
||||
@ -113,21 +125,31 @@ def _echo_discovery_info(discovery_info):
|
||||
_echo_dictionary(discovery_info)
|
||||
return
|
||||
|
||||
def _conditional_echo(label, value):
|
||||
if value:
|
||||
ws = " " * (19 - len(label))
|
||||
echo(f"\t{label}:{ws}{value}")
|
||||
|
||||
echo("\t[bold]== Discovery Result ==[/bold]")
|
||||
echo(f"\tDevice Type: {dr.device_type}")
|
||||
echo(f"\tDevice Model: {dr.device_model}")
|
||||
echo(f"\tIP: {dr.ip}")
|
||||
echo(f"\tMAC: {dr.mac}")
|
||||
echo(f"\tDevice Id (hash): {dr.device_id}")
|
||||
echo(f"\tOwner (hash): {dr.owner}")
|
||||
echo(f"\tHW Ver: {dr.hw_ver}")
|
||||
echo(f"\tSupports IOT Cloud: {dr.is_support_iot_cloud}")
|
||||
echo(f"\tOBD Src: {dr.obd_src}")
|
||||
echo(f"\tFactory Default: {dr.factory_default}")
|
||||
echo(f"\tEncrypt Type: {dr.mgt_encrypt_schm.encrypt_type}")
|
||||
echo(f"\tSupports HTTPS: {dr.mgt_encrypt_schm.is_support_https}")
|
||||
echo(f"\tHTTP Port: {dr.mgt_encrypt_schm.http_port}")
|
||||
echo(f"\tLV (Login Level): {dr.mgt_encrypt_schm.lv}")
|
||||
_conditional_echo("Device Type", dr.device_type)
|
||||
_conditional_echo("Device Model", dr.device_model)
|
||||
_conditional_echo("Device Name", dr.device_name)
|
||||
_conditional_echo("IP", dr.ip)
|
||||
_conditional_echo("MAC", dr.mac)
|
||||
_conditional_echo("Device Id (hash)", dr.device_id)
|
||||
_conditional_echo("Owner (hash)", dr.owner)
|
||||
_conditional_echo("FW Ver", dr.firmware_version)
|
||||
_conditional_echo("HW Ver", dr.hw_ver)
|
||||
_conditional_echo("HW Ver", dr.hardware_version)
|
||||
_conditional_echo("Supports IOT Cloud", dr.is_support_iot_cloud)
|
||||
_conditional_echo("OBD Src", dr.owner)
|
||||
_conditional_echo("Factory Default", dr.factory_default)
|
||||
_conditional_echo("Encrypt Type", dr.mgt_encrypt_schm.encrypt_type)
|
||||
_conditional_echo("Encrypt Type", dr.encrypt_type)
|
||||
_conditional_echo("Supports HTTPS", dr.mgt_encrypt_schm.is_support_https)
|
||||
_conditional_echo("HTTP Port", dr.mgt_encrypt_schm.http_port)
|
||||
_conditional_echo("Encrypt info", pf(dr.encrypt_info) if dr.encrypt_info else None)
|
||||
_conditional_echo("Decrypted", pf(dr.decrypted_data) if dr.decrypted_data else None)
|
||||
|
||||
|
||||
async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3):
|
||||
|
@ -158,6 +158,7 @@ def _legacy_type_to_class(_type):
|
||||
type=click.Choice(ENCRYPT_TYPES, case_sensitive=False),
|
||||
)
|
||||
@click.option(
|
||||
"-df",
|
||||
"--device-family",
|
||||
envvar="KASA_DEVICE_FAMILY",
|
||||
default="SMART.TAPOPLUG",
|
||||
@ -182,7 +183,7 @@ def _legacy_type_to_class(_type):
|
||||
@click.option(
|
||||
"--discovery-timeout",
|
||||
envvar="KASA_DISCOVERY_TIMEOUT",
|
||||
default=5,
|
||||
default=10,
|
||||
required=False,
|
||||
show_default=True,
|
||||
help="Timeout for discovery.",
|
||||
@ -326,15 +327,11 @@ async def cli(
|
||||
dev = await Device.connect(config=config)
|
||||
device_updated = True
|
||||
else:
|
||||
from kasa.discover import Discover
|
||||
from .discover import discover
|
||||
|
||||
dev = await Discover.discover_single(
|
||||
host,
|
||||
port=port,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
discovery_timeout=discovery_timeout,
|
||||
)
|
||||
dev = await ctx.invoke(discover)
|
||||
if not dev:
|
||||
error(f"Unable to create device for {host}")
|
||||
|
||||
# Skip update on specific commands, or if device factory,
|
||||
# that performs an update was used for the device.
|
||||
|
112
kasa/discover.py
112
kasa/discover.py
@ -82,13 +82,16 @@ Discovering a single device returns a kasa.Device object.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import ipaddress
|
||||
import logging
|
||||
import secrets
|
||||
import socket
|
||||
import struct
|
||||
from collections.abc import Awaitable
|
||||
from pprint import pformat as pf
|
||||
from typing import Any, Callable, Dict, Optional, Type, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
|
||||
|
||||
# When support for cpython older than 3.11 is dropped
|
||||
# async_timeout can be replaced with asyncio.timeout
|
||||
@ -96,6 +99,7 @@ from async_timeout import timeout as asyncio_timeout
|
||||
from pydantic.v1 import BaseModel, ValidationError
|
||||
|
||||
from kasa import Device
|
||||
from kasa.aestransport import AesEncyptionSession, KeyPair
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.device_factory import (
|
||||
get_device_class_from_family,
|
||||
@ -133,6 +137,46 @@ NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
|
||||
}
|
||||
|
||||
|
||||
class _AesDiscoveryQuery:
|
||||
keypair: KeyPair | None = None
|
||||
|
||||
@classmethod
|
||||
def generate_query(cls):
|
||||
if not cls.keypair:
|
||||
cls.keypair = KeyPair.create_key_pair(key_size=2048)
|
||||
secret = secrets.token_bytes(4)
|
||||
|
||||
key_payload = {"params": {"rsa_key": cls.keypair.get_public_pem().decode()}}
|
||||
|
||||
key_payload_bytes = json_dumps(key_payload).encode()
|
||||
# https://labs.withsecure.com/advisories/tp-link-ac1750-pwn2own-2019
|
||||
version = 2 # version of tdp
|
||||
msg_type = 0
|
||||
op_code = 1 # probe
|
||||
msg_size = len(key_payload_bytes)
|
||||
flags = 17
|
||||
padding_byte = 0 # blank byte
|
||||
device_serial = int.from_bytes(secret, "big")
|
||||
initial_crc = 0x5A6B7C8D
|
||||
|
||||
disco_header = struct.pack(
|
||||
">BBHHBBII",
|
||||
version,
|
||||
msg_type,
|
||||
op_code,
|
||||
msg_size,
|
||||
flags,
|
||||
padding_byte,
|
||||
device_serial,
|
||||
initial_crc,
|
||||
)
|
||||
|
||||
query = bytearray(disco_header + key_payload_bytes)
|
||||
crc = binascii.crc32(query).to_bytes(length=4, byteorder="big")
|
||||
query[12:16] = crc
|
||||
return query
|
||||
|
||||
|
||||
class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
"""Implementation of the discovery protocol handler.
|
||||
|
||||
@ -224,15 +268,21 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
||||
encrypted_req = XorEncryption.encrypt(req)
|
||||
sleep_between_packets = self.discovery_timeout / self.discovery_packets
|
||||
|
||||
aes_discovery_query = _AesDiscoveryQuery.generate_query()
|
||||
for _ in range(self.discovery_packets):
|
||||
if self.target in self.seen_hosts: # Stop sending for discover_single
|
||||
break
|
||||
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
|
||||
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
|
||||
self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore
|
||||
await asyncio.sleep(sleep_between_packets)
|
||||
|
||||
def datagram_received(self, data, addr) -> None:
|
||||
"""Handle discovery responses."""
|
||||
if TYPE_CHECKING:
|
||||
assert _AesDiscoveryQuery.keypair
|
||||
|
||||
ip, port = addr
|
||||
# Prevent multiple entries due multiple broadcasts
|
||||
if ip in self.seen_hosts:
|
||||
@ -395,7 +445,8 @@ class Discover:
|
||||
credentials: Credentials | None = None,
|
||||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
) -> Device:
|
||||
on_unsupported: OnUnsupportedCallable | None = None,
|
||||
) -> Device | None:
|
||||
"""Discover a single device by the given IP address.
|
||||
|
||||
It is generally preferred to avoid :func:`discover_single()` and
|
||||
@ -465,7 +516,11 @@ class Discover:
|
||||
dev.host = host
|
||||
return dev
|
||||
elif ip in protocol.unsupported_device_exceptions:
|
||||
raise protocol.unsupported_device_exceptions[ip]
|
||||
if on_unsupported:
|
||||
await on_unsupported(protocol.unsupported_device_exceptions[ip])
|
||||
return None
|
||||
else:
|
||||
raise protocol.unsupported_device_exceptions[ip]
|
||||
elif ip in protocol.invalid_device_exceptions:
|
||||
raise protocol.invalid_device_exceptions[ip]
|
||||
else:
|
||||
@ -512,6 +567,25 @@ class Discover:
|
||||
device.update_from_discover_info(info)
|
||||
return device
|
||||
|
||||
@staticmethod
|
||||
def _decrypt_discovery_data(discovery_result: DiscoveryResult) -> None:
|
||||
if TYPE_CHECKING:
|
||||
assert discovery_result.encrypt_info
|
||||
assert _AesDiscoveryQuery.keypair
|
||||
encryped_key = discovery_result.encrypt_info.key
|
||||
encrypted_data = discovery_result.encrypt_info.data
|
||||
|
||||
key_and_iv = _AesDiscoveryQuery.keypair.decrypt_discovery_key(
|
||||
base64.b64decode(encryped_key.encode())
|
||||
)
|
||||
|
||||
key, iv = key_and_iv[:16], key_and_iv[16:]
|
||||
|
||||
session = AesEncyptionSession(key, iv)
|
||||
decrypted_data = session.decrypt(encrypted_data)
|
||||
|
||||
discovery_result.decrypted_data = json_loads(decrypted_data)
|
||||
|
||||
@staticmethod
|
||||
def _get_device_instance(
|
||||
data: bytes,
|
||||
@ -528,6 +602,8 @@ class Discover:
|
||||
) from ex
|
||||
try:
|
||||
discovery_result = DiscoveryResult(**info["result"])
|
||||
if discovery_result.encrypt_info:
|
||||
Discover._decrypt_discovery_data(discovery_result)
|
||||
except ValidationError as ex:
|
||||
if debug_enabled:
|
||||
data = (
|
||||
@ -547,9 +623,19 @@ class Discover:
|
||||
type_ = discovery_result.device_type
|
||||
|
||||
try:
|
||||
if not (
|
||||
encrypt_type := discovery_result.mgt_encrypt_schm.encrypt_type
|
||||
) and (encrypt_info := discovery_result.encrypt_info):
|
||||
encrypt_type = encrypt_info.sym_schm
|
||||
if not encrypt_type:
|
||||
raise UnsupportedDeviceError(
|
||||
f"Unsupported device {config.host} of type {type_} "
|
||||
+ "with no encryption type",
|
||||
discovery_result=discovery_result.get_dict(),
|
||||
)
|
||||
config.connection_type = DeviceConnectionParameters.from_values(
|
||||
type_,
|
||||
discovery_result.mgt_encrypt_schm.encrypt_type,
|
||||
encrypt_type,
|
||||
discovery_result.mgt_encrypt_schm.lv,
|
||||
)
|
||||
except KasaException as ex:
|
||||
@ -593,21 +679,35 @@ class EncryptionScheme(BaseModel):
|
||||
"""Base model for encryption scheme of discovery result."""
|
||||
|
||||
is_support_https: bool
|
||||
encrypt_type: str
|
||||
http_port: int
|
||||
encrypt_type: Optional[str] # noqa: UP007
|
||||
http_port: Optional[int] = None # noqa: UP007
|
||||
lv: Optional[int] = None # noqa: UP007
|
||||
|
||||
|
||||
class EncryptionInfo(BaseModel):
|
||||
"""Base model for encryption info of discovery result."""
|
||||
|
||||
sym_schm: str
|
||||
key: str
|
||||
data: str
|
||||
|
||||
|
||||
class DiscoveryResult(BaseModel):
|
||||
"""Base model for discovery result."""
|
||||
|
||||
device_type: str
|
||||
device_model: str
|
||||
device_name: Optional[str] # noqa: UP007
|
||||
ip: str
|
||||
mac: str
|
||||
mgt_encrypt_schm: EncryptionScheme
|
||||
encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007
|
||||
encrypt_type: Optional[list[str]] = None # noqa: UP007
|
||||
decrypted_data: Optional[dict] = None # noqa: UP007
|
||||
device_id: str
|
||||
|
||||
firmware_version: Optional[str] = None # noqa: UP007
|
||||
hardware_version: Optional[str] = None # noqa: UP007
|
||||
hw_ver: Optional[str] = None # noqa: UP007
|
||||
owner: Optional[str] = None # noqa: UP007
|
||||
is_support_iot_cloud: Optional[bool] = None # noqa: UP007
|
||||
|
@ -99,8 +99,8 @@ async def test_handshake_with_keys(mocker):
|
||||
assert transport._state is TransportState.HANDSHAKE_REQUIRED
|
||||
|
||||
await transport.perform_handshake()
|
||||
assert transport._key_pair.get_private_key() == test_keys["private"]
|
||||
assert transport._key_pair.get_public_key() == test_keys["public"]
|
||||
assert transport._key_pair.private_key_der_b64 == test_keys["private"]
|
||||
assert transport._key_pair.public_key_der_b64 == test_keys["public"]
|
||||
|
||||
|
||||
@status_parameters
|
||||
|
@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from unittest.mock import ANY
|
||||
|
||||
import asyncclick as click
|
||||
import pytest
|
||||
@ -17,7 +18,6 @@ from kasa import (
|
||||
EmeterStatus,
|
||||
KasaException,
|
||||
Module,
|
||||
UnsupportedDeviceError,
|
||||
)
|
||||
from kasa.cli.device import (
|
||||
alias,
|
||||
@ -613,6 +613,7 @@ async def test_without_device_type(dev, mocker, runner):
|
||||
credentials=Credentials("foo", "bar"),
|
||||
timeout=5,
|
||||
discovery_timeout=7,
|
||||
on_unsupported=ANY,
|
||||
)
|
||||
|
||||
|
||||
@ -735,7 +736,7 @@ async def test_host_unsupported(unsupported_device_info, runner):
|
||||
)
|
||||
|
||||
assert res.exit_code != 0
|
||||
assert isinstance(res.exception, UnsupportedDeviceError)
|
||||
assert "== Unsupported device ==" in res.output
|
||||
|
||||
|
||||
@new_discovery
|
||||
|
@ -2,6 +2,8 @@
|
||||
# ruff: noqa: S106
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import socket
|
||||
@ -10,6 +12,8 @@ from unittest.mock import MagicMock
|
||||
import aiohttp
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
|
||||
from kasa import (
|
||||
Credentials,
|
||||
@ -18,11 +22,17 @@ from kasa import (
|
||||
Discover,
|
||||
KasaException,
|
||||
)
|
||||
from kasa.aestransport import AesEncyptionSession
|
||||
from kasa.deviceconfig import (
|
||||
DeviceConfig,
|
||||
DeviceConnectionParameters,
|
||||
)
|
||||
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps
|
||||
from kasa.discover import (
|
||||
DiscoveryResult,
|
||||
_AesDiscoveryQuery,
|
||||
_DiscoverProtocol,
|
||||
json_dumps,
|
||||
)
|
||||
from kasa.exceptions import AuthenticationError, UnsupportedDeviceError
|
||||
from kasa.iot import IotDevice
|
||||
from kasa.xortransport import XorEncryption
|
||||
@ -278,7 +288,7 @@ async def test_discover_send(mocker):
|
||||
assert proto.target_1 == ("255.255.255.255", 9999)
|
||||
transport = mocker.patch.object(proto, "transport")
|
||||
await proto.do_discover()
|
||||
assert transport.sendto.call_count == proto.discovery_packets * 2
|
||||
assert transport.sendto.call_count == proto.discovery_packets * 3
|
||||
|
||||
|
||||
async def test_discover_datagram_received(mocker, discovery_data):
|
||||
@ -485,13 +495,14 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count):
|
||||
discovery_timeout=discovery_timeout,
|
||||
discovery_packets=5,
|
||||
)
|
||||
ft = FakeDatagramTransport(dp, port, do_not_reply_count)
|
||||
expected_send = 1 if port == 9999 else 2
|
||||
ft = FakeDatagramTransport(dp, port, do_not_reply_count * expected_send)
|
||||
dp.connection_made(ft)
|
||||
|
||||
await dp.wait_for_discovery_to_complete()
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert ft.send_count == do_not_reply_count + 1
|
||||
assert ft.send_count == do_not_reply_count * expected_send + expected_send
|
||||
assert dp.discover_task.done()
|
||||
assert dp.discover_task.cancelled()
|
||||
|
||||
@ -603,3 +614,36 @@ async def test_discovery_redaction(discovery_mock, caplog: pytest.LogCaptureFixt
|
||||
await Discover.discover()
|
||||
assert mac not in caplog.text
|
||||
assert "12:34:56:00:00:00" in caplog.text
|
||||
|
||||
|
||||
async def test_discovery_decryption():
|
||||
"""Test discovery decryption."""
|
||||
key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t"
|
||||
iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa"
|
||||
key_iv = key + iv
|
||||
|
||||
_AesDiscoveryQuery.generate_query()
|
||||
keypair = _AesDiscoveryQuery.keypair
|
||||
|
||||
padding = asymmetric_padding.OAEP(
|
||||
mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
|
||||
algorithm=hashes.SHA1(), # noqa: S303
|
||||
label=None,
|
||||
)
|
||||
encrypted_key_iv = keypair.public_key.encrypt(key_iv, padding)
|
||||
encrypted_key_iv_b4 = base64.b64encode(encrypted_key_iv)
|
||||
encryption_session = AesEncyptionSession(key_iv[:16], key_iv[16:])
|
||||
|
||||
data_dict = {"foo": 1, "bar": 2}
|
||||
data = json.dumps(data_dict)
|
||||
encypted_data = encryption_session.encrypt(data.encode())
|
||||
|
||||
encrypt_info = {
|
||||
"data": encypted_data.decode(),
|
||||
"key": encrypted_key_iv_b4.decode(),
|
||||
"sym_schm": "AES",
|
||||
}
|
||||
info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info}
|
||||
dr = DiscoveryResult(**info)
|
||||
Discover._decrypt_discovery_data(dr)
|
||||
assert dr.decrypted_data == data_dict
|
||||
|
Loading…
Reference in New Issue
Block a user