diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 0048bd12..ae75117c 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -14,7 +14,7 @@ from collections.abc import AsyncGenerator from enum import Enum, auto from typing import TYPE_CHECKING, Any, Dict, cast -from cryptography.hazmat.primitives import padding, serialization +from cryptography.hazmat.primitives import hashes, padding, serialization from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes @@ -108,7 +108,9 @@ class AesTransport(BaseTransport): self._key_pair: KeyPair | None = None if config.aes_keys: aes_keys = config.aes_keys - self._key_pair = KeyPair(aes_keys["private"], aes_keys["public"]) + self._key_pair = KeyPair.create_from_der_keys( + aes_keys["private"], aes_keys["public"] + ) self._app_url = URL(f"http://{self._host}:{self._port}/app") self._token_url: URL | None = None @@ -277,14 +279,14 @@ class AesTransport(BaseTransport): if not self._key_pair: kp = KeyPair.create_key_pair() self._config.aes_keys = { - "private": kp.get_private_key(), - "public": kp.get_public_key(), + "private": kp.private_key_der_b64, + "public": kp.public_key_der_b64, } self._key_pair = kp pub_key = ( "-----BEGIN PUBLIC KEY-----\n" - + self._key_pair.get_public_key() # type: ignore[union-attr] + + self._key_pair.public_key_der_b64 # type: ignore[union-attr] + "\n-----END PUBLIC KEY-----\n" ) handshake_params = {"key": pub_key} @@ -392,18 +394,11 @@ class AesEncyptionSession: """Class for an AES encryption session.""" @staticmethod - def create_from_keypair(handshake_key: str, keypair): + def create_from_keypair(handshake_key: str, keypair: KeyPair): """Create the encryption session.""" - handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8")) - private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8")) + handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode()) - private_key = cast( - rsa.RSAPrivateKey, - serialization.load_der_private_key(private_key_data, None, None), - ) - key_and_iv = private_key.decrypt( - handshake_key_bytes, asymmetric_padding.PKCS1v15() - ) + key_and_iv = keypair.decrypt_handshake_key(handshake_key_bytes) if key_and_iv is None: raise ValueError("Decryption failed!") @@ -438,30 +433,59 @@ class KeyPair: """Create a key pair.""" private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) public_key = private_key.public_key() + return KeyPair(private_key, public_key) - private_key_bytes = private_key.private_bytes( + @staticmethod + def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str): + """Create a key pair.""" + key_bytes = base64.b64decode(private_key_der_b64.encode()) + private_key = cast( + rsa.RSAPrivateKey, serialization.load_der_private_key(key_bytes, None) + ) + key_bytes = base64.b64decode(public_key_der_b64.encode()) + public_key = cast( + rsa.RSAPublicKey, serialization.load_der_public_key(key_bytes, None) + ) + + return KeyPair(private_key, public_key) + + def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey): + self.private_key = private_key + self.public_key = public_key + self.private_key_der_bytes = self.private_key.private_bytes( encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ) - public_key_bytes = public_key.public_bytes( + self.public_key_der_bytes = self.public_key.public_bytes( encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) + self.private_key_der_b64 = base64.b64encode(self.private_key_der_bytes).decode() + self.public_key_der_b64 = base64.b64encode(self.public_key_der_bytes).decode() - return KeyPair( - private_key=base64.b64encode(private_key_bytes).decode("UTF-8"), - public_key=base64.b64encode(public_key_bytes).decode("UTF-8"), + def get_public_pem(self) -> bytes: + """Get public key in PEM encoding.""" + return self.public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, ) - def __init__(self, private_key: str, public_key: str): - self.private_key = private_key - self.public_key = public_key + def decrypt_handshake_key(self, encrypted_key: bytes) -> bytes: + """Decrypt an aes handshake key.""" + decrypted = self.private_key.decrypt( + encrypted_key, asymmetric_padding.PKCS1v15() + ) + return decrypted - def get_private_key(self) -> str: - """Get the private key.""" - return self.private_key - - def get_public_key(self) -> str: - """Get the public key.""" - return self.public_key + def decrypt_discovery_key(self, encrypted_key: bytes) -> bytes: + """Decrypt an aes discovery key.""" + decrypted = self.private_key.decrypt( + encrypted_key, + asymmetric_padding.OAEP( + mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303 + algorithm=hashes.SHA1(), # noqa: S303 + label=None, + ), + ) + return decrypted diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py index 6bf58e72..78f426f5 100644 --- a/kasa/cli/discover.py +++ b/kasa/cli/discover.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from pprint import pformat as pf import asyncclick as click from pydantic.v1 import ValidationError @@ -28,6 +29,7 @@ async def discover(ctx): password = ctx.parent.params["password"] discovery_timeout = ctx.parent.params["discovery_timeout"] timeout = ctx.parent.params["timeout"] + host = ctx.parent.params["host"] port = ctx.parent.params["port"] credentials = Credentials(username, password) if username and password else None @@ -49,8 +51,6 @@ async def discover(ctx): echo(f"\t{unsupported_exception}") echo() - echo(f"Discovering devices on {target} for {discovery_timeout} seconds") - from .device import state async def print_discovered(dev: Device): @@ -68,6 +68,18 @@ async def discover(ctx): discovered[dev.host] = dev.internal_state echo() + if host: + echo(f"Discovering device {host} for {discovery_timeout} seconds") + return await Discover.discover_single( + host, + port=port, + credentials=credentials, + timeout=timeout, + discovery_timeout=discovery_timeout, + on_unsupported=print_unsupported, + ) + + echo(f"Discovering devices on {target} for {discovery_timeout} seconds") discovered_devices = await Discover.discover( target=target, discovery_timeout=discovery_timeout, @@ -113,21 +125,31 @@ def _echo_discovery_info(discovery_info): _echo_dictionary(discovery_info) return + def _conditional_echo(label, value): + if value: + ws = " " * (19 - len(label)) + echo(f"\t{label}:{ws}{value}") + echo("\t[bold]== Discovery Result ==[/bold]") - echo(f"\tDevice Type: {dr.device_type}") - echo(f"\tDevice Model: {dr.device_model}") - echo(f"\tIP: {dr.ip}") - echo(f"\tMAC: {dr.mac}") - echo(f"\tDevice Id (hash): {dr.device_id}") - echo(f"\tOwner (hash): {dr.owner}") - echo(f"\tHW Ver: {dr.hw_ver}") - echo(f"\tSupports IOT Cloud: {dr.is_support_iot_cloud}") - echo(f"\tOBD Src: {dr.obd_src}") - echo(f"\tFactory Default: {dr.factory_default}") - echo(f"\tEncrypt Type: {dr.mgt_encrypt_schm.encrypt_type}") - echo(f"\tSupports HTTPS: {dr.mgt_encrypt_schm.is_support_https}") - echo(f"\tHTTP Port: {dr.mgt_encrypt_schm.http_port}") - echo(f"\tLV (Login Level): {dr.mgt_encrypt_schm.lv}") + _conditional_echo("Device Type", dr.device_type) + _conditional_echo("Device Model", dr.device_model) + _conditional_echo("Device Name", dr.device_name) + _conditional_echo("IP", dr.ip) + _conditional_echo("MAC", dr.mac) + _conditional_echo("Device Id (hash)", dr.device_id) + _conditional_echo("Owner (hash)", dr.owner) + _conditional_echo("FW Ver", dr.firmware_version) + _conditional_echo("HW Ver", dr.hw_ver) + _conditional_echo("HW Ver", dr.hardware_version) + _conditional_echo("Supports IOT Cloud", dr.is_support_iot_cloud) + _conditional_echo("OBD Src", dr.owner) + _conditional_echo("Factory Default", dr.factory_default) + _conditional_echo("Encrypt Type", dr.mgt_encrypt_schm.encrypt_type) + _conditional_echo("Encrypt Type", dr.encrypt_type) + _conditional_echo("Supports HTTPS", dr.mgt_encrypt_schm.is_support_https) + _conditional_echo("HTTP Port", dr.mgt_encrypt_schm.http_port) + _conditional_echo("Encrypt info", pf(dr.encrypt_info) if dr.encrypt_info else None) + _conditional_echo("Decrypted", pf(dr.decrypted_data) if dr.decrypted_data else None) async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3): diff --git a/kasa/cli/main.py b/kasa/cli/main.py index 88b768c4..1550b7af 100755 --- a/kasa/cli/main.py +++ b/kasa/cli/main.py @@ -158,6 +158,7 @@ def _legacy_type_to_class(_type): type=click.Choice(ENCRYPT_TYPES, case_sensitive=False), ) @click.option( + "-df", "--device-family", envvar="KASA_DEVICE_FAMILY", default="SMART.TAPOPLUG", @@ -182,7 +183,7 @@ def _legacy_type_to_class(_type): @click.option( "--discovery-timeout", envvar="KASA_DISCOVERY_TIMEOUT", - default=5, + default=10, required=False, show_default=True, help="Timeout for discovery.", @@ -326,15 +327,11 @@ async def cli( dev = await Device.connect(config=config) device_updated = True else: - from kasa.discover import Discover + from .discover import discover - dev = await Discover.discover_single( - host, - port=port, - credentials=credentials, - timeout=timeout, - discovery_timeout=discovery_timeout, - ) + dev = await ctx.invoke(discover) + if not dev: + error(f"Unable to create device for {host}") # Skip update on specific commands, or if device factory, # that performs an update was used for the device. diff --git a/kasa/discover.py b/kasa/discover.py index a1bc28a3..9d615398 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -82,13 +82,16 @@ Discovering a single device returns a kasa.Device object. from __future__ import annotations import asyncio +import base64 import binascii import ipaddress import logging +import secrets import socket +import struct from collections.abc import Awaitable from pprint import pformat as pf -from typing import Any, Callable, Dict, Optional, Type, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout @@ -96,6 +99,7 @@ from async_timeout import timeout as asyncio_timeout from pydantic.v1 import BaseModel, ValidationError from kasa import Device +from kasa.aestransport import AesEncyptionSession, KeyPair from kasa.credentials import Credentials from kasa.device_factory import ( get_device_class_from_family, @@ -133,6 +137,46 @@ NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = { } +class _AesDiscoveryQuery: + keypair: KeyPair | None = None + + @classmethod + def generate_query(cls): + if not cls.keypair: + cls.keypair = KeyPair.create_key_pair(key_size=2048) + secret = secrets.token_bytes(4) + + key_payload = {"params": {"rsa_key": cls.keypair.get_public_pem().decode()}} + + key_payload_bytes = json_dumps(key_payload).encode() + # https://labs.withsecure.com/advisories/tp-link-ac1750-pwn2own-2019 + version = 2 # version of tdp + msg_type = 0 + op_code = 1 # probe + msg_size = len(key_payload_bytes) + flags = 17 + padding_byte = 0 # blank byte + device_serial = int.from_bytes(secret, "big") + initial_crc = 0x5A6B7C8D + + disco_header = struct.pack( + ">BBHHBBII", + version, + msg_type, + op_code, + msg_size, + flags, + padding_byte, + device_serial, + initial_crc, + ) + + query = bytearray(disco_header + key_payload_bytes) + crc = binascii.crc32(query).to_bytes(length=4, byteorder="big") + query[12:16] = crc + return query + + class _DiscoverProtocol(asyncio.DatagramProtocol): """Implementation of the discovery protocol handler. @@ -224,15 +268,21 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): _LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY) encrypted_req = XorEncryption.encrypt(req) sleep_between_packets = self.discovery_timeout / self.discovery_packets + + aes_discovery_query = _AesDiscoveryQuery.generate_query() for _ in range(self.discovery_packets): if self.target in self.seen_hosts: # Stop sending for discover_single break self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore + self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore await asyncio.sleep(sleep_between_packets) def datagram_received(self, data, addr) -> None: """Handle discovery responses.""" + if TYPE_CHECKING: + assert _AesDiscoveryQuery.keypair + ip, port = addr # Prevent multiple entries due multiple broadcasts if ip in self.seen_hosts: @@ -395,7 +445,8 @@ class Discover: credentials: Credentials | None = None, username: str | None = None, password: str | None = None, - ) -> Device: + on_unsupported: OnUnsupportedCallable | None = None, + ) -> Device | None: """Discover a single device by the given IP address. It is generally preferred to avoid :func:`discover_single()` and @@ -465,7 +516,11 @@ class Discover: dev.host = host return dev elif ip in protocol.unsupported_device_exceptions: - raise protocol.unsupported_device_exceptions[ip] + if on_unsupported: + await on_unsupported(protocol.unsupported_device_exceptions[ip]) + return None + else: + raise protocol.unsupported_device_exceptions[ip] elif ip in protocol.invalid_device_exceptions: raise protocol.invalid_device_exceptions[ip] else: @@ -512,6 +567,25 @@ class Discover: device.update_from_discover_info(info) return device + @staticmethod + def _decrypt_discovery_data(discovery_result: DiscoveryResult) -> None: + if TYPE_CHECKING: + assert discovery_result.encrypt_info + assert _AesDiscoveryQuery.keypair + encryped_key = discovery_result.encrypt_info.key + encrypted_data = discovery_result.encrypt_info.data + + key_and_iv = _AesDiscoveryQuery.keypair.decrypt_discovery_key( + base64.b64decode(encryped_key.encode()) + ) + + key, iv = key_and_iv[:16], key_and_iv[16:] + + session = AesEncyptionSession(key, iv) + decrypted_data = session.decrypt(encrypted_data) + + discovery_result.decrypted_data = json_loads(decrypted_data) + @staticmethod def _get_device_instance( data: bytes, @@ -528,6 +602,8 @@ class Discover: ) from ex try: discovery_result = DiscoveryResult(**info["result"]) + if discovery_result.encrypt_info: + Discover._decrypt_discovery_data(discovery_result) except ValidationError as ex: if debug_enabled: data = ( @@ -547,9 +623,19 @@ class Discover: type_ = discovery_result.device_type try: + if not ( + encrypt_type := discovery_result.mgt_encrypt_schm.encrypt_type + ) and (encrypt_info := discovery_result.encrypt_info): + encrypt_type = encrypt_info.sym_schm + if not encrypt_type: + raise UnsupportedDeviceError( + f"Unsupported device {config.host} of type {type_} " + + "with no encryption type", + discovery_result=discovery_result.get_dict(), + ) config.connection_type = DeviceConnectionParameters.from_values( type_, - discovery_result.mgt_encrypt_schm.encrypt_type, + encrypt_type, discovery_result.mgt_encrypt_schm.lv, ) except KasaException as ex: @@ -593,21 +679,35 @@ class EncryptionScheme(BaseModel): """Base model for encryption scheme of discovery result.""" is_support_https: bool - encrypt_type: str - http_port: int + encrypt_type: Optional[str] # noqa: UP007 + http_port: Optional[int] = None # noqa: UP007 lv: Optional[int] = None # noqa: UP007 +class EncryptionInfo(BaseModel): + """Base model for encryption info of discovery result.""" + + sym_schm: str + key: str + data: str + + class DiscoveryResult(BaseModel): """Base model for discovery result.""" device_type: str device_model: str + device_name: Optional[str] # noqa: UP007 ip: str mac: str mgt_encrypt_schm: EncryptionScheme + encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007 + encrypt_type: Optional[list[str]] = None # noqa: UP007 + decrypted_data: Optional[dict] = None # noqa: UP007 device_id: str + firmware_version: Optional[str] = None # noqa: UP007 + hardware_version: Optional[str] = None # noqa: UP007 hw_ver: Optional[str] = None # noqa: UP007 owner: Optional[str] = None # noqa: UP007 is_support_iot_cloud: Optional[bool] = None # noqa: UP007 diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 53d83858..f1dbfb32 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -99,8 +99,8 @@ async def test_handshake_with_keys(mocker): assert transport._state is TransportState.HANDSHAKE_REQUIRED await transport.perform_handshake() - assert transport._key_pair.get_private_key() == test_keys["private"] - assert transport._key_pair.get_public_key() == test_keys["public"] + assert transport._key_pair.private_key_der_b64 == test_keys["private"] + assert transport._key_pair.public_key_der_b64 == test_keys["public"] @status_parameters diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index e439644b..553f93d3 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -2,6 +2,7 @@ import json import os import re from datetime import datetime +from unittest.mock import ANY import asyncclick as click import pytest @@ -17,7 +18,6 @@ from kasa import ( EmeterStatus, KasaException, Module, - UnsupportedDeviceError, ) from kasa.cli.device import ( alias, @@ -613,6 +613,7 @@ async def test_without_device_type(dev, mocker, runner): credentials=Credentials("foo", "bar"), timeout=5, discovery_timeout=7, + on_unsupported=ANY, ) @@ -735,7 +736,7 @@ async def test_host_unsupported(unsupported_device_info, runner): ) assert res.exit_code != 0 - assert isinstance(res.exception, UnsupportedDeviceError) + assert "== Unsupported device ==" in res.output @new_discovery diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 15d4af9c..8163d4c1 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -2,6 +2,8 @@ # ruff: noqa: S106 import asyncio +import base64 +import json import logging import re import socket @@ -10,6 +12,8 @@ from unittest.mock import MagicMock import aiohttp import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from async_timeout import timeout as asyncio_timeout +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from kasa import ( Credentials, @@ -18,11 +22,17 @@ from kasa import ( Discover, KasaException, ) +from kasa.aestransport import AesEncyptionSession from kasa.deviceconfig import ( DeviceConfig, DeviceConnectionParameters, ) -from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps +from kasa.discover import ( + DiscoveryResult, + _AesDiscoveryQuery, + _DiscoverProtocol, + json_dumps, +) from kasa.exceptions import AuthenticationError, UnsupportedDeviceError from kasa.iot import IotDevice from kasa.xortransport import XorEncryption @@ -278,7 +288,7 @@ async def test_discover_send(mocker): assert proto.target_1 == ("255.255.255.255", 9999) transport = mocker.patch.object(proto, "transport") await proto.do_discover() - assert transport.sendto.call_count == proto.discovery_packets * 2 + assert transport.sendto.call_count == proto.discovery_packets * 3 async def test_discover_datagram_received(mocker, discovery_data): @@ -485,13 +495,14 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): discovery_timeout=discovery_timeout, discovery_packets=5, ) - ft = FakeDatagramTransport(dp, port, do_not_reply_count) + expected_send = 1 if port == 9999 else 2 + ft = FakeDatagramTransport(dp, port, do_not_reply_count * expected_send) dp.connection_made(ft) await dp.wait_for_discovery_to_complete() await asyncio.sleep(0) - assert ft.send_count == do_not_reply_count + 1 + assert ft.send_count == do_not_reply_count * expected_send + expected_send assert dp.discover_task.done() assert dp.discover_task.cancelled() @@ -603,3 +614,36 @@ async def test_discovery_redaction(discovery_mock, caplog: pytest.LogCaptureFixt await Discover.discover() assert mac not in caplog.text assert "12:34:56:00:00:00" in caplog.text + + +async def test_discovery_decryption(): + """Test discovery decryption.""" + key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t" + iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa" + key_iv = key + iv + + _AesDiscoveryQuery.generate_query() + keypair = _AesDiscoveryQuery.keypair + + padding = asymmetric_padding.OAEP( + mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303 + algorithm=hashes.SHA1(), # noqa: S303 + label=None, + ) + encrypted_key_iv = keypair.public_key.encrypt(key_iv, padding) + encrypted_key_iv_b4 = base64.b64encode(encrypted_key_iv) + encryption_session = AesEncyptionSession(key_iv[:16], key_iv[16:]) + + data_dict = {"foo": 1, "bar": 2} + data = json.dumps(data_dict) + encypted_data = encryption_session.encrypt(data.encode()) + + encrypt_info = { + "data": encypted_data.decode(), + "key": encrypted_key_iv_b4.decode(), + "sym_schm": "AES", + } + info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info} + dr = DiscoveryResult(**info) + Discover._decrypt_discovery_data(dr) + assert dr.decrypted_data == data_dict