mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Enable newer encrypted discovery protocol (#1168)
This commit is contained in:
parent
7fd8c14c1f
commit
380fbb93c3
@ -14,7 +14,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import TYPE_CHECKING, Any, Dict, cast
|
from typing import TYPE_CHECKING, Any, Dict, cast
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import padding, serialization
|
from cryptography.hazmat.primitives import hashes, padding, serialization
|
||||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
@ -108,7 +108,9 @@ class AesTransport(BaseTransport):
|
|||||||
self._key_pair: KeyPair | None = None
|
self._key_pair: KeyPair | None = None
|
||||||
if config.aes_keys:
|
if config.aes_keys:
|
||||||
aes_keys = config.aes_keys
|
aes_keys = config.aes_keys
|
||||||
self._key_pair = KeyPair(aes_keys["private"], aes_keys["public"])
|
self._key_pair = KeyPair.create_from_der_keys(
|
||||||
|
aes_keys["private"], aes_keys["public"]
|
||||||
|
)
|
||||||
self._app_url = URL(f"http://{self._host}:{self._port}/app")
|
self._app_url = URL(f"http://{self._host}:{self._port}/app")
|
||||||
self._token_url: URL | None = None
|
self._token_url: URL | None = None
|
||||||
|
|
||||||
@ -277,14 +279,14 @@ class AesTransport(BaseTransport):
|
|||||||
if not self._key_pair:
|
if not self._key_pair:
|
||||||
kp = KeyPair.create_key_pair()
|
kp = KeyPair.create_key_pair()
|
||||||
self._config.aes_keys = {
|
self._config.aes_keys = {
|
||||||
"private": kp.get_private_key(),
|
"private": kp.private_key_der_b64,
|
||||||
"public": kp.get_public_key(),
|
"public": kp.public_key_der_b64,
|
||||||
}
|
}
|
||||||
self._key_pair = kp
|
self._key_pair = kp
|
||||||
|
|
||||||
pub_key = (
|
pub_key = (
|
||||||
"-----BEGIN PUBLIC KEY-----\n"
|
"-----BEGIN PUBLIC KEY-----\n"
|
||||||
+ self._key_pair.get_public_key() # type: ignore[union-attr]
|
+ self._key_pair.public_key_der_b64 # type: ignore[union-attr]
|
||||||
+ "\n-----END PUBLIC KEY-----\n"
|
+ "\n-----END PUBLIC KEY-----\n"
|
||||||
)
|
)
|
||||||
handshake_params = {"key": pub_key}
|
handshake_params = {"key": pub_key}
|
||||||
@ -392,18 +394,11 @@ class AesEncyptionSession:
|
|||||||
"""Class for an AES encryption session."""
|
"""Class for an AES encryption session."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_from_keypair(handshake_key: str, keypair):
|
def create_from_keypair(handshake_key: str, keypair: KeyPair):
|
||||||
"""Create the encryption session."""
|
"""Create the encryption session."""
|
||||||
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
|
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())
|
||||||
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
|
|
||||||
|
|
||||||
private_key = cast(
|
key_and_iv = keypair.decrypt_handshake_key(handshake_key_bytes)
|
||||||
rsa.RSAPrivateKey,
|
|
||||||
serialization.load_der_private_key(private_key_data, None, None),
|
|
||||||
)
|
|
||||||
key_and_iv = private_key.decrypt(
|
|
||||||
handshake_key_bytes, asymmetric_padding.PKCS1v15()
|
|
||||||
)
|
|
||||||
if key_and_iv is None:
|
if key_and_iv is None:
|
||||||
raise ValueError("Decryption failed!")
|
raise ValueError("Decryption failed!")
|
||||||
|
|
||||||
@ -438,30 +433,59 @@ class KeyPair:
|
|||||||
"""Create a key pair."""
|
"""Create a key pair."""
|
||||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
||||||
public_key = private_key.public_key()
|
public_key = private_key.public_key()
|
||||||
|
return KeyPair(private_key, public_key)
|
||||||
|
|
||||||
private_key_bytes = private_key.private_bytes(
|
@staticmethod
|
||||||
|
def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str):
|
||||||
|
"""Create a key pair."""
|
||||||
|
key_bytes = base64.b64decode(private_key_der_b64.encode())
|
||||||
|
private_key = cast(
|
||||||
|
rsa.RSAPrivateKey, serialization.load_der_private_key(key_bytes, None)
|
||||||
|
)
|
||||||
|
key_bytes = base64.b64decode(public_key_der_b64.encode())
|
||||||
|
public_key = cast(
|
||||||
|
rsa.RSAPublicKey, serialization.load_der_public_key(key_bytes, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
return KeyPair(private_key, public_key)
|
||||||
|
|
||||||
|
def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey):
|
||||||
|
self.private_key = private_key
|
||||||
|
self.public_key = public_key
|
||||||
|
self.private_key_der_bytes = self.private_key.private_bytes(
|
||||||
encoding=serialization.Encoding.DER,
|
encoding=serialization.Encoding.DER,
|
||||||
format=serialization.PrivateFormat.PKCS8,
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
encryption_algorithm=serialization.NoEncryption(),
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
)
|
)
|
||||||
public_key_bytes = public_key.public_bytes(
|
self.public_key_der_bytes = self.public_key.public_bytes(
|
||||||
encoding=serialization.Encoding.DER,
|
encoding=serialization.Encoding.DER,
|
||||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
)
|
)
|
||||||
|
self.private_key_der_b64 = base64.b64encode(self.private_key_der_bytes).decode()
|
||||||
|
self.public_key_der_b64 = base64.b64encode(self.public_key_der_bytes).decode()
|
||||||
|
|
||||||
return KeyPair(
|
def get_public_pem(self) -> bytes:
|
||||||
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
|
"""Get public key in PEM encoding."""
|
||||||
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
|
return self.public_key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, private_key: str, public_key: str):
|
def decrypt_handshake_key(self, encrypted_key: bytes) -> bytes:
|
||||||
self.private_key = private_key
|
"""Decrypt an aes handshake key."""
|
||||||
self.public_key = public_key
|
decrypted = self.private_key.decrypt(
|
||||||
|
encrypted_key, asymmetric_padding.PKCS1v15()
|
||||||
|
)
|
||||||
|
return decrypted
|
||||||
|
|
||||||
def get_private_key(self) -> str:
|
def decrypt_discovery_key(self, encrypted_key: bytes) -> bytes:
|
||||||
"""Get the private key."""
|
"""Decrypt an aes discovery key."""
|
||||||
return self.private_key
|
decrypted = self.private_key.decrypt(
|
||||||
|
encrypted_key,
|
||||||
def get_public_key(self) -> str:
|
asymmetric_padding.OAEP(
|
||||||
"""Get the public key."""
|
mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
|
||||||
return self.public_key
|
algorithm=hashes.SHA1(), # noqa: S303
|
||||||
|
label=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return decrypted
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from pprint import pformat as pf
|
||||||
|
|
||||||
import asyncclick as click
|
import asyncclick as click
|
||||||
from pydantic.v1 import ValidationError
|
from pydantic.v1 import ValidationError
|
||||||
@ -28,6 +29,7 @@ async def discover(ctx):
|
|||||||
password = ctx.parent.params["password"]
|
password = ctx.parent.params["password"]
|
||||||
discovery_timeout = ctx.parent.params["discovery_timeout"]
|
discovery_timeout = ctx.parent.params["discovery_timeout"]
|
||||||
timeout = ctx.parent.params["timeout"]
|
timeout = ctx.parent.params["timeout"]
|
||||||
|
host = ctx.parent.params["host"]
|
||||||
port = ctx.parent.params["port"]
|
port = ctx.parent.params["port"]
|
||||||
|
|
||||||
credentials = Credentials(username, password) if username and password else None
|
credentials = Credentials(username, password) if username and password else None
|
||||||
@ -49,8 +51,6 @@ async def discover(ctx):
|
|||||||
echo(f"\t{unsupported_exception}")
|
echo(f"\t{unsupported_exception}")
|
||||||
echo()
|
echo()
|
||||||
|
|
||||||
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
|
|
||||||
|
|
||||||
from .device import state
|
from .device import state
|
||||||
|
|
||||||
async def print_discovered(dev: Device):
|
async def print_discovered(dev: Device):
|
||||||
@ -68,6 +68,18 @@ async def discover(ctx):
|
|||||||
discovered[dev.host] = dev.internal_state
|
discovered[dev.host] = dev.internal_state
|
||||||
echo()
|
echo()
|
||||||
|
|
||||||
|
if host:
|
||||||
|
echo(f"Discovering device {host} for {discovery_timeout} seconds")
|
||||||
|
return await Discover.discover_single(
|
||||||
|
host,
|
||||||
|
port=port,
|
||||||
|
credentials=credentials,
|
||||||
|
timeout=timeout,
|
||||||
|
discovery_timeout=discovery_timeout,
|
||||||
|
on_unsupported=print_unsupported,
|
||||||
|
)
|
||||||
|
|
||||||
|
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
|
||||||
discovered_devices = await Discover.discover(
|
discovered_devices = await Discover.discover(
|
||||||
target=target,
|
target=target,
|
||||||
discovery_timeout=discovery_timeout,
|
discovery_timeout=discovery_timeout,
|
||||||
@ -113,21 +125,31 @@ def _echo_discovery_info(discovery_info):
|
|||||||
_echo_dictionary(discovery_info)
|
_echo_dictionary(discovery_info)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def _conditional_echo(label, value):
|
||||||
|
if value:
|
||||||
|
ws = " " * (19 - len(label))
|
||||||
|
echo(f"\t{label}:{ws}{value}")
|
||||||
|
|
||||||
echo("\t[bold]== Discovery Result ==[/bold]")
|
echo("\t[bold]== Discovery Result ==[/bold]")
|
||||||
echo(f"\tDevice Type: {dr.device_type}")
|
_conditional_echo("Device Type", dr.device_type)
|
||||||
echo(f"\tDevice Model: {dr.device_model}")
|
_conditional_echo("Device Model", dr.device_model)
|
||||||
echo(f"\tIP: {dr.ip}")
|
_conditional_echo("Device Name", dr.device_name)
|
||||||
echo(f"\tMAC: {dr.mac}")
|
_conditional_echo("IP", dr.ip)
|
||||||
echo(f"\tDevice Id (hash): {dr.device_id}")
|
_conditional_echo("MAC", dr.mac)
|
||||||
echo(f"\tOwner (hash): {dr.owner}")
|
_conditional_echo("Device Id (hash)", dr.device_id)
|
||||||
echo(f"\tHW Ver: {dr.hw_ver}")
|
_conditional_echo("Owner (hash)", dr.owner)
|
||||||
echo(f"\tSupports IOT Cloud: {dr.is_support_iot_cloud}")
|
_conditional_echo("FW Ver", dr.firmware_version)
|
||||||
echo(f"\tOBD Src: {dr.obd_src}")
|
_conditional_echo("HW Ver", dr.hw_ver)
|
||||||
echo(f"\tFactory Default: {dr.factory_default}")
|
_conditional_echo("HW Ver", dr.hardware_version)
|
||||||
echo(f"\tEncrypt Type: {dr.mgt_encrypt_schm.encrypt_type}")
|
_conditional_echo("Supports IOT Cloud", dr.is_support_iot_cloud)
|
||||||
echo(f"\tSupports HTTPS: {dr.mgt_encrypt_schm.is_support_https}")
|
_conditional_echo("OBD Src", dr.owner)
|
||||||
echo(f"\tHTTP Port: {dr.mgt_encrypt_schm.http_port}")
|
_conditional_echo("Factory Default", dr.factory_default)
|
||||||
echo(f"\tLV (Login Level): {dr.mgt_encrypt_schm.lv}")
|
_conditional_echo("Encrypt Type", dr.mgt_encrypt_schm.encrypt_type)
|
||||||
|
_conditional_echo("Encrypt Type", dr.encrypt_type)
|
||||||
|
_conditional_echo("Supports HTTPS", dr.mgt_encrypt_schm.is_support_https)
|
||||||
|
_conditional_echo("HTTP Port", dr.mgt_encrypt_schm.http_port)
|
||||||
|
_conditional_echo("Encrypt info", pf(dr.encrypt_info) if dr.encrypt_info else None)
|
||||||
|
_conditional_echo("Decrypted", pf(dr.decrypted_data) if dr.decrypted_data else None)
|
||||||
|
|
||||||
|
|
||||||
async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3):
|
async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3):
|
||||||
|
@ -158,6 +158,7 @@ def _legacy_type_to_class(_type):
|
|||||||
type=click.Choice(ENCRYPT_TYPES, case_sensitive=False),
|
type=click.Choice(ENCRYPT_TYPES, case_sensitive=False),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
|
"-df",
|
||||||
"--device-family",
|
"--device-family",
|
||||||
envvar="KASA_DEVICE_FAMILY",
|
envvar="KASA_DEVICE_FAMILY",
|
||||||
default="SMART.TAPOPLUG",
|
default="SMART.TAPOPLUG",
|
||||||
@ -182,7 +183,7 @@ def _legacy_type_to_class(_type):
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--discovery-timeout",
|
"--discovery-timeout",
|
||||||
envvar="KASA_DISCOVERY_TIMEOUT",
|
envvar="KASA_DISCOVERY_TIMEOUT",
|
||||||
default=5,
|
default=10,
|
||||||
required=False,
|
required=False,
|
||||||
show_default=True,
|
show_default=True,
|
||||||
help="Timeout for discovery.",
|
help="Timeout for discovery.",
|
||||||
@ -326,15 +327,11 @@ async def cli(
|
|||||||
dev = await Device.connect(config=config)
|
dev = await Device.connect(config=config)
|
||||||
device_updated = True
|
device_updated = True
|
||||||
else:
|
else:
|
||||||
from kasa.discover import Discover
|
from .discover import discover
|
||||||
|
|
||||||
dev = await Discover.discover_single(
|
dev = await ctx.invoke(discover)
|
||||||
host,
|
if not dev:
|
||||||
port=port,
|
error(f"Unable to create device for {host}")
|
||||||
credentials=credentials,
|
|
||||||
timeout=timeout,
|
|
||||||
discovery_timeout=discovery_timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Skip update on specific commands, or if device factory,
|
# Skip update on specific commands, or if device factory,
|
||||||
# that performs an update was used for the device.
|
# that performs an update was used for the device.
|
||||||
|
110
kasa/discover.py
110
kasa/discover.py
@ -82,13 +82,16 @@ Discovering a single device returns a kasa.Device object.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
|
import secrets
|
||||||
import socket
|
import socket
|
||||||
|
import struct
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import Any, Callable, Dict, Optional, Type, cast
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
|
||||||
|
|
||||||
# When support for cpython older than 3.11 is dropped
|
# When support for cpython older than 3.11 is dropped
|
||||||
# async_timeout can be replaced with asyncio.timeout
|
# async_timeout can be replaced with asyncio.timeout
|
||||||
@ -96,6 +99,7 @@ from async_timeout import timeout as asyncio_timeout
|
|||||||
from pydantic.v1 import BaseModel, ValidationError
|
from pydantic.v1 import BaseModel, ValidationError
|
||||||
|
|
||||||
from kasa import Device
|
from kasa import Device
|
||||||
|
from kasa.aestransport import AesEncyptionSession, KeyPair
|
||||||
from kasa.credentials import Credentials
|
from kasa.credentials import Credentials
|
||||||
from kasa.device_factory import (
|
from kasa.device_factory import (
|
||||||
get_device_class_from_family,
|
get_device_class_from_family,
|
||||||
@ -133,6 +137,46 @@ NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _AesDiscoveryQuery:
|
||||||
|
keypair: KeyPair | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_query(cls):
|
||||||
|
if not cls.keypair:
|
||||||
|
cls.keypair = KeyPair.create_key_pair(key_size=2048)
|
||||||
|
secret = secrets.token_bytes(4)
|
||||||
|
|
||||||
|
key_payload = {"params": {"rsa_key": cls.keypair.get_public_pem().decode()}}
|
||||||
|
|
||||||
|
key_payload_bytes = json_dumps(key_payload).encode()
|
||||||
|
# https://labs.withsecure.com/advisories/tp-link-ac1750-pwn2own-2019
|
||||||
|
version = 2 # version of tdp
|
||||||
|
msg_type = 0
|
||||||
|
op_code = 1 # probe
|
||||||
|
msg_size = len(key_payload_bytes)
|
||||||
|
flags = 17
|
||||||
|
padding_byte = 0 # blank byte
|
||||||
|
device_serial = int.from_bytes(secret, "big")
|
||||||
|
initial_crc = 0x5A6B7C8D
|
||||||
|
|
||||||
|
disco_header = struct.pack(
|
||||||
|
">BBHHBBII",
|
||||||
|
version,
|
||||||
|
msg_type,
|
||||||
|
op_code,
|
||||||
|
msg_size,
|
||||||
|
flags,
|
||||||
|
padding_byte,
|
||||||
|
device_serial,
|
||||||
|
initial_crc,
|
||||||
|
)
|
||||||
|
|
||||||
|
query = bytearray(disco_header + key_payload_bytes)
|
||||||
|
crc = binascii.crc32(query).to_bytes(length=4, byteorder="big")
|
||||||
|
query[12:16] = crc
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
class _DiscoverProtocol(asyncio.DatagramProtocol):
|
class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||||
"""Implementation of the discovery protocol handler.
|
"""Implementation of the discovery protocol handler.
|
||||||
|
|
||||||
@ -224,15 +268,21 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
||||||
encrypted_req = XorEncryption.encrypt(req)
|
encrypted_req = XorEncryption.encrypt(req)
|
||||||
sleep_between_packets = self.discovery_timeout / self.discovery_packets
|
sleep_between_packets = self.discovery_timeout / self.discovery_packets
|
||||||
|
|
||||||
|
aes_discovery_query = _AesDiscoveryQuery.generate_query()
|
||||||
for _ in range(self.discovery_packets):
|
for _ in range(self.discovery_packets):
|
||||||
if self.target in self.seen_hosts: # Stop sending for discover_single
|
if self.target in self.seen_hosts: # Stop sending for discover_single
|
||||||
break
|
break
|
||||||
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
|
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
|
||||||
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
|
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
|
||||||
|
self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore
|
||||||
await asyncio.sleep(sleep_between_packets)
|
await asyncio.sleep(sleep_between_packets)
|
||||||
|
|
||||||
def datagram_received(self, data, addr) -> None:
|
def datagram_received(self, data, addr) -> None:
|
||||||
"""Handle discovery responses."""
|
"""Handle discovery responses."""
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert _AesDiscoveryQuery.keypair
|
||||||
|
|
||||||
ip, port = addr
|
ip, port = addr
|
||||||
# Prevent multiple entries due multiple broadcasts
|
# Prevent multiple entries due multiple broadcasts
|
||||||
if ip in self.seen_hosts:
|
if ip in self.seen_hosts:
|
||||||
@ -395,7 +445,8 @@ class Discover:
|
|||||||
credentials: Credentials | None = None,
|
credentials: Credentials | None = None,
|
||||||
username: str | None = None,
|
username: str | None = None,
|
||||||
password: str | None = None,
|
password: str | None = None,
|
||||||
) -> Device:
|
on_unsupported: OnUnsupportedCallable | None = None,
|
||||||
|
) -> Device | None:
|
||||||
"""Discover a single device by the given IP address.
|
"""Discover a single device by the given IP address.
|
||||||
|
|
||||||
It is generally preferred to avoid :func:`discover_single()` and
|
It is generally preferred to avoid :func:`discover_single()` and
|
||||||
@ -465,6 +516,10 @@ class Discover:
|
|||||||
dev.host = host
|
dev.host = host
|
||||||
return dev
|
return dev
|
||||||
elif ip in protocol.unsupported_device_exceptions:
|
elif ip in protocol.unsupported_device_exceptions:
|
||||||
|
if on_unsupported:
|
||||||
|
await on_unsupported(protocol.unsupported_device_exceptions[ip])
|
||||||
|
return None
|
||||||
|
else:
|
||||||
raise protocol.unsupported_device_exceptions[ip]
|
raise protocol.unsupported_device_exceptions[ip]
|
||||||
elif ip in protocol.invalid_device_exceptions:
|
elif ip in protocol.invalid_device_exceptions:
|
||||||
raise protocol.invalid_device_exceptions[ip]
|
raise protocol.invalid_device_exceptions[ip]
|
||||||
@ -512,6 +567,25 @@ class Discover:
|
|||||||
device.update_from_discover_info(info)
|
device.update_from_discover_info(info)
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _decrypt_discovery_data(discovery_result: DiscoveryResult) -> None:
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert discovery_result.encrypt_info
|
||||||
|
assert _AesDiscoveryQuery.keypair
|
||||||
|
encryped_key = discovery_result.encrypt_info.key
|
||||||
|
encrypted_data = discovery_result.encrypt_info.data
|
||||||
|
|
||||||
|
key_and_iv = _AesDiscoveryQuery.keypair.decrypt_discovery_key(
|
||||||
|
base64.b64decode(encryped_key.encode())
|
||||||
|
)
|
||||||
|
|
||||||
|
key, iv = key_and_iv[:16], key_and_iv[16:]
|
||||||
|
|
||||||
|
session = AesEncyptionSession(key, iv)
|
||||||
|
decrypted_data = session.decrypt(encrypted_data)
|
||||||
|
|
||||||
|
discovery_result.decrypted_data = json_loads(decrypted_data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_device_instance(
|
def _get_device_instance(
|
||||||
data: bytes,
|
data: bytes,
|
||||||
@ -528,6 +602,8 @@ class Discover:
|
|||||||
) from ex
|
) from ex
|
||||||
try:
|
try:
|
||||||
discovery_result = DiscoveryResult(**info["result"])
|
discovery_result = DiscoveryResult(**info["result"])
|
||||||
|
if discovery_result.encrypt_info:
|
||||||
|
Discover._decrypt_discovery_data(discovery_result)
|
||||||
except ValidationError as ex:
|
except ValidationError as ex:
|
||||||
if debug_enabled:
|
if debug_enabled:
|
||||||
data = (
|
data = (
|
||||||
@ -547,9 +623,19 @@ class Discover:
|
|||||||
type_ = discovery_result.device_type
|
type_ = discovery_result.device_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not (
|
||||||
|
encrypt_type := discovery_result.mgt_encrypt_schm.encrypt_type
|
||||||
|
) and (encrypt_info := discovery_result.encrypt_info):
|
||||||
|
encrypt_type = encrypt_info.sym_schm
|
||||||
|
if not encrypt_type:
|
||||||
|
raise UnsupportedDeviceError(
|
||||||
|
f"Unsupported device {config.host} of type {type_} "
|
||||||
|
+ "with no encryption type",
|
||||||
|
discovery_result=discovery_result.get_dict(),
|
||||||
|
)
|
||||||
config.connection_type = DeviceConnectionParameters.from_values(
|
config.connection_type = DeviceConnectionParameters.from_values(
|
||||||
type_,
|
type_,
|
||||||
discovery_result.mgt_encrypt_schm.encrypt_type,
|
encrypt_type,
|
||||||
discovery_result.mgt_encrypt_schm.lv,
|
discovery_result.mgt_encrypt_schm.lv,
|
||||||
)
|
)
|
||||||
except KasaException as ex:
|
except KasaException as ex:
|
||||||
@ -593,21 +679,35 @@ class EncryptionScheme(BaseModel):
|
|||||||
"""Base model for encryption scheme of discovery result."""
|
"""Base model for encryption scheme of discovery result."""
|
||||||
|
|
||||||
is_support_https: bool
|
is_support_https: bool
|
||||||
encrypt_type: str
|
encrypt_type: Optional[str] # noqa: UP007
|
||||||
http_port: int
|
http_port: Optional[int] = None # noqa: UP007
|
||||||
lv: Optional[int] = None # noqa: UP007
|
lv: Optional[int] = None # noqa: UP007
|
||||||
|
|
||||||
|
|
||||||
|
class EncryptionInfo(BaseModel):
|
||||||
|
"""Base model for encryption info of discovery result."""
|
||||||
|
|
||||||
|
sym_schm: str
|
||||||
|
key: str
|
||||||
|
data: str
|
||||||
|
|
||||||
|
|
||||||
class DiscoveryResult(BaseModel):
|
class DiscoveryResult(BaseModel):
|
||||||
"""Base model for discovery result."""
|
"""Base model for discovery result."""
|
||||||
|
|
||||||
device_type: str
|
device_type: str
|
||||||
device_model: str
|
device_model: str
|
||||||
|
device_name: Optional[str] # noqa: UP007
|
||||||
ip: str
|
ip: str
|
||||||
mac: str
|
mac: str
|
||||||
mgt_encrypt_schm: EncryptionScheme
|
mgt_encrypt_schm: EncryptionScheme
|
||||||
|
encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007
|
||||||
|
encrypt_type: Optional[list[str]] = None # noqa: UP007
|
||||||
|
decrypted_data: Optional[dict] = None # noqa: UP007
|
||||||
device_id: str
|
device_id: str
|
||||||
|
|
||||||
|
firmware_version: Optional[str] = None # noqa: UP007
|
||||||
|
hardware_version: Optional[str] = None # noqa: UP007
|
||||||
hw_ver: Optional[str] = None # noqa: UP007
|
hw_ver: Optional[str] = None # noqa: UP007
|
||||||
owner: Optional[str] = None # noqa: UP007
|
owner: Optional[str] = None # noqa: UP007
|
||||||
is_support_iot_cloud: Optional[bool] = None # noqa: UP007
|
is_support_iot_cloud: Optional[bool] = None # noqa: UP007
|
||||||
|
@ -99,8 +99,8 @@ async def test_handshake_with_keys(mocker):
|
|||||||
assert transport._state is TransportState.HANDSHAKE_REQUIRED
|
assert transport._state is TransportState.HANDSHAKE_REQUIRED
|
||||||
|
|
||||||
await transport.perform_handshake()
|
await transport.perform_handshake()
|
||||||
assert transport._key_pair.get_private_key() == test_keys["private"]
|
assert transport._key_pair.private_key_der_b64 == test_keys["private"]
|
||||||
assert transport._key_pair.get_public_key() == test_keys["public"]
|
assert transport._key_pair.public_key_der_b64 == test_keys["public"]
|
||||||
|
|
||||||
|
|
||||||
@status_parameters
|
@status_parameters
|
||||||
|
@ -2,6 +2,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from unittest.mock import ANY
|
||||||
|
|
||||||
import asyncclick as click
|
import asyncclick as click
|
||||||
import pytest
|
import pytest
|
||||||
@ -17,7 +18,6 @@ from kasa import (
|
|||||||
EmeterStatus,
|
EmeterStatus,
|
||||||
KasaException,
|
KasaException,
|
||||||
Module,
|
Module,
|
||||||
UnsupportedDeviceError,
|
|
||||||
)
|
)
|
||||||
from kasa.cli.device import (
|
from kasa.cli.device import (
|
||||||
alias,
|
alias,
|
||||||
@ -613,6 +613,7 @@ async def test_without_device_type(dev, mocker, runner):
|
|||||||
credentials=Credentials("foo", "bar"),
|
credentials=Credentials("foo", "bar"),
|
||||||
timeout=5,
|
timeout=5,
|
||||||
discovery_timeout=7,
|
discovery_timeout=7,
|
||||||
|
on_unsupported=ANY,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -735,7 +736,7 @@ async def test_host_unsupported(unsupported_device_info, runner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert res.exit_code != 0
|
assert res.exit_code != 0
|
||||||
assert isinstance(res.exception, UnsupportedDeviceError)
|
assert "== Unsupported device ==" in res.output
|
||||||
|
|
||||||
|
|
||||||
@new_discovery
|
@new_discovery
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
# ruff: noqa: S106
|
# ruff: noqa: S106
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
@ -10,6 +12,8 @@ from unittest.mock import MagicMock
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||||
from async_timeout import timeout as asyncio_timeout
|
from async_timeout import timeout as asyncio_timeout
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
|
|
||||||
from kasa import (
|
from kasa import (
|
||||||
Credentials,
|
Credentials,
|
||||||
@ -18,11 +22,17 @@ from kasa import (
|
|||||||
Discover,
|
Discover,
|
||||||
KasaException,
|
KasaException,
|
||||||
)
|
)
|
||||||
|
from kasa.aestransport import AesEncyptionSession
|
||||||
from kasa.deviceconfig import (
|
from kasa.deviceconfig import (
|
||||||
DeviceConfig,
|
DeviceConfig,
|
||||||
DeviceConnectionParameters,
|
DeviceConnectionParameters,
|
||||||
)
|
)
|
||||||
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps
|
from kasa.discover import (
|
||||||
|
DiscoveryResult,
|
||||||
|
_AesDiscoveryQuery,
|
||||||
|
_DiscoverProtocol,
|
||||||
|
json_dumps,
|
||||||
|
)
|
||||||
from kasa.exceptions import AuthenticationError, UnsupportedDeviceError
|
from kasa.exceptions import AuthenticationError, UnsupportedDeviceError
|
||||||
from kasa.iot import IotDevice
|
from kasa.iot import IotDevice
|
||||||
from kasa.xortransport import XorEncryption
|
from kasa.xortransport import XorEncryption
|
||||||
@ -278,7 +288,7 @@ async def test_discover_send(mocker):
|
|||||||
assert proto.target_1 == ("255.255.255.255", 9999)
|
assert proto.target_1 == ("255.255.255.255", 9999)
|
||||||
transport = mocker.patch.object(proto, "transport")
|
transport = mocker.patch.object(proto, "transport")
|
||||||
await proto.do_discover()
|
await proto.do_discover()
|
||||||
assert transport.sendto.call_count == proto.discovery_packets * 2
|
assert transport.sendto.call_count == proto.discovery_packets * 3
|
||||||
|
|
||||||
|
|
||||||
async def test_discover_datagram_received(mocker, discovery_data):
|
async def test_discover_datagram_received(mocker, discovery_data):
|
||||||
@ -485,13 +495,14 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count):
|
|||||||
discovery_timeout=discovery_timeout,
|
discovery_timeout=discovery_timeout,
|
||||||
discovery_packets=5,
|
discovery_packets=5,
|
||||||
)
|
)
|
||||||
ft = FakeDatagramTransport(dp, port, do_not_reply_count)
|
expected_send = 1 if port == 9999 else 2
|
||||||
|
ft = FakeDatagramTransport(dp, port, do_not_reply_count * expected_send)
|
||||||
dp.connection_made(ft)
|
dp.connection_made(ft)
|
||||||
|
|
||||||
await dp.wait_for_discovery_to_complete()
|
await dp.wait_for_discovery_to_complete()
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
assert ft.send_count == do_not_reply_count + 1
|
assert ft.send_count == do_not_reply_count * expected_send + expected_send
|
||||||
assert dp.discover_task.done()
|
assert dp.discover_task.done()
|
||||||
assert dp.discover_task.cancelled()
|
assert dp.discover_task.cancelled()
|
||||||
|
|
||||||
@ -603,3 +614,36 @@ async def test_discovery_redaction(discovery_mock, caplog: pytest.LogCaptureFixt
|
|||||||
await Discover.discover()
|
await Discover.discover()
|
||||||
assert mac not in caplog.text
|
assert mac not in caplog.text
|
||||||
assert "12:34:56:00:00:00" in caplog.text
|
assert "12:34:56:00:00:00" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_discovery_decryption():
|
||||||
|
"""Test discovery decryption."""
|
||||||
|
key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t"
|
||||||
|
iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa"
|
||||||
|
key_iv = key + iv
|
||||||
|
|
||||||
|
_AesDiscoveryQuery.generate_query()
|
||||||
|
keypair = _AesDiscoveryQuery.keypair
|
||||||
|
|
||||||
|
padding = asymmetric_padding.OAEP(
|
||||||
|
mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
|
||||||
|
algorithm=hashes.SHA1(), # noqa: S303
|
||||||
|
label=None,
|
||||||
|
)
|
||||||
|
encrypted_key_iv = keypair.public_key.encrypt(key_iv, padding)
|
||||||
|
encrypted_key_iv_b4 = base64.b64encode(encrypted_key_iv)
|
||||||
|
encryption_session = AesEncyptionSession(key_iv[:16], key_iv[16:])
|
||||||
|
|
||||||
|
data_dict = {"foo": 1, "bar": 2}
|
||||||
|
data = json.dumps(data_dict)
|
||||||
|
encypted_data = encryption_session.encrypt(data.encode())
|
||||||
|
|
||||||
|
encrypt_info = {
|
||||||
|
"data": encypted_data.decode(),
|
||||||
|
"key": encrypted_key_iv_b4.decode(),
|
||||||
|
"sym_schm": "AES",
|
||||||
|
}
|
||||||
|
info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info}
|
||||||
|
dr = DiscoveryResult(**info)
|
||||||
|
Discover._decrypt_discovery_data(dr)
|
||||||
|
assert dr.decrypted_data == data_dict
|
||||||
|
Loading…
Reference in New Issue
Block a user