Return raw discovery result in cli discover raw (#1342)

Add `on_discovered_raw` callback to Discover and adds a cli command `discover raw` which returns the raw json before serializing to a `DiscoveryResult` and attempting to create a device class.
This commit is contained in:
Steven B. 2024-12-10 22:42:14 +00:00 committed by GitHub
parent 464683e09b
commit bf8f0adabe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 158 additions and 23 deletions

View File

@ -14,9 +14,17 @@ from kasa import (
Discover, Discover,
UnsupportedDeviceError, UnsupportedDeviceError,
) )
from kasa.discover import ConnectAttempt, DiscoveryResult from kasa.discover import (
NEW_DISCOVERY_REDACTORS,
ConnectAttempt,
DiscoveredRaw,
DiscoveryResult,
)
from kasa.iot.iotdevice import _extract_sys_info from kasa.iot.iotdevice import _extract_sys_info
from kasa.protocols.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.protocols.protocol import redact_data
from ..json import dumps as json_dumps
from .common import echo, error from .common import echo, error
@ -64,7 +72,9 @@ async def detail(ctx):
await ctx.parent.invoke(state) await ctx.parent.invoke(state)
echo() echo()
discovered = await _discover(ctx, print_discovered, print_unsupported) discovered = await _discover(
ctx, print_discovered=print_discovered, print_unsupported=print_unsupported
)
if ctx.parent.parent.params["host"]: if ctx.parent.parent.params["host"]:
return discovered return discovered
@ -77,6 +87,33 @@ async def detail(ctx):
return discovered return discovered
@discover.command()
@click.option(
"--redact/--no-redact",
default=False,
is_flag=True,
type=bool,
help="Set flag to redact sensitive data from raw output.",
)
@click.pass_context
async def raw(ctx, redact: bool):
"""Return raw discovery data returned from devices."""
def print_raw(discovered: DiscoveredRaw):
if redact:
redactors = (
NEW_DISCOVERY_REDACTORS
if discovered["meta"]["port"] == Discover.DISCOVERY_PORT_2
else IOT_REDACTORS
)
discovered["discovery_response"] = redact_data(
discovered["discovery_response"], redactors
)
echo(json_dumps(discovered, indent=True))
return await _discover(ctx, print_raw=print_raw, do_echo=False)
@discover.command() @discover.command()
@click.pass_context @click.pass_context
async def list(ctx): async def list(ctx):
@ -102,10 +139,17 @@ async def list(ctx):
echo(f"{host:<15} UNSUPPORTED DEVICE") echo(f"{host:<15} UNSUPPORTED DEVICE")
echo(f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}") echo(f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}")
return await _discover(ctx, print_discovered, print_unsupported, do_echo=False) return await _discover(
ctx,
print_discovered=print_discovered,
print_unsupported=print_unsupported,
do_echo=False,
)
async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True): async def _discover(
ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True
):
params = ctx.parent.parent.params params = ctx.parent.parent.params
target = params["target"] target = params["target"]
username = params["username"] username = params["username"]
@ -126,6 +170,7 @@ async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True):
timeout=timeout, timeout=timeout,
discovery_timeout=discovery_timeout, discovery_timeout=discovery_timeout,
on_unsupported=print_unsupported, on_unsupported=print_unsupported,
on_discovered_raw=print_raw,
) )
if do_echo: if do_echo:
echo(f"Discovering devices on {target} for {discovery_timeout} seconds") echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
@ -137,6 +182,7 @@ async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True):
port=port, port=port,
timeout=timeout, timeout=timeout,
credentials=credentials, credentials=credentials,
on_discovered_raw=print_raw,
) )
for device in discovered_devices.values(): for device in discovered_devices.values():

View File

@ -99,6 +99,7 @@ from typing import (
Annotated, Annotated,
Any, Any,
NamedTuple, NamedTuple,
TypedDict,
cast, cast,
) )
@ -147,18 +148,35 @@ class ConnectAttempt(NamedTuple):
device: type device: type
class DiscoveredMeta(TypedDict):
"""Meta info about discovery response."""
ip: str
port: int
class DiscoveredRaw(TypedDict):
"""Try to connect attempt."""
meta: DiscoveredMeta
discovery_response: dict
OnDiscoveredCallable = Callable[[Device], Coroutine] OnDiscoveredCallable = Callable[[Device], Coroutine]
OnDiscoveredRawCallable = Callable[[DiscoveredRaw], None]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Coroutine] OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Coroutine]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None] OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = dict[str, Device] DeviceDict = dict[str, Device]
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = { NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
"device_id": lambda x: "REDACTED_" + x[9::], "device_id": lambda x: "REDACTED_" + x[9::],
"device_name": lambda x: "#MASKED_NAME#" if x else "",
"owner": lambda x: "REDACTED_" + x[9::], "owner": lambda x: "REDACTED_" + x[9::],
"mac": mask_mac, "mac": mask_mac,
"master_device_id": lambda x: "REDACTED_" + x[9::], "master_device_id": lambda x: "REDACTED_" + x[9::],
"group_id": lambda x: "REDACTED_" + x[9::], "group_id": lambda x: "REDACTED_" + x[9::],
"group_name": lambda x: "I01BU0tFRF9TU0lEIw==", "group_name": lambda x: "I01BU0tFRF9TU0lEIw==",
"encrypt_info": lambda x: {**x, "key": "", "data": ""},
} }
@ -216,6 +234,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self, self,
*, *,
on_discovered: OnDiscoveredCallable | None = None, on_discovered: OnDiscoveredCallable | None = None,
on_discovered_raw: OnDiscoveredRawCallable | None = None,
target: str = "255.255.255.255", target: str = "255.255.255.255",
discovery_packets: int = 3, discovery_packets: int = 3,
discovery_timeout: int = 5, discovery_timeout: int = 5,
@ -240,6 +259,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.unsupported_device_exceptions: dict = {} self.unsupported_device_exceptions: dict = {}
self.invalid_device_exceptions: dict = {} self.invalid_device_exceptions: dict = {}
self.on_unsupported = on_unsupported self.on_unsupported = on_unsupported
self.on_discovered_raw = on_discovered_raw
self.credentials = credentials self.credentials = credentials
self.timeout = timeout self.timeout = timeout
self.discovery_timeout = discovery_timeout self.discovery_timeout = discovery_timeout
@ -329,12 +349,23 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
config.timeout = self.timeout config.timeout = self.timeout
try: try:
if port == self.discovery_port: if port == self.discovery_port:
device = Discover._get_device_instance_legacy(data, config) json_func = Discover._get_discovery_json_legacy
device_func = Discover._get_device_instance_legacy
elif port == Discover.DISCOVERY_PORT_2: elif port == Discover.DISCOVERY_PORT_2:
config.uses_http = True config.uses_http = True
device = Discover._get_device_instance(data, config) json_func = Discover._get_discovery_json
device_func = Discover._get_device_instance
else: else:
return return
info = json_func(data, ip)
if self.on_discovered_raw is not None:
self.on_discovered_raw(
{
"discovery_response": info,
"meta": {"ip": ip, "port": port},
}
)
device = device_func(info, config)
except UnsupportedDeviceError as udex: except UnsupportedDeviceError as udex:
_LOGGER.debug("Unsupported device found at %s << %s", ip, udex) _LOGGER.debug("Unsupported device found at %s << %s", ip, udex)
self.unsupported_device_exceptions[ip] = udex self.unsupported_device_exceptions[ip] = udex
@ -391,6 +422,7 @@ class Discover:
*, *,
target: str = "255.255.255.255", target: str = "255.255.255.255",
on_discovered: OnDiscoveredCallable | None = None, on_discovered: OnDiscoveredCallable | None = None,
on_discovered_raw: OnDiscoveredRawCallable | None = None,
discovery_timeout: int = 5, discovery_timeout: int = 5,
discovery_packets: int = 3, discovery_packets: int = 3,
interface: str | None = None, interface: str | None = None,
@ -421,6 +453,8 @@ class Discover:
:param target: The target address where to send the broadcast discovery :param target: The target address where to send the broadcast discovery
queries if multi-homing (e.g. 192.168.xxx.255). queries if multi-homing (e.g. 192.168.xxx.255).
:param on_discovered: coroutine to execute on discovery :param on_discovered: coroutine to execute on discovery
:param on_discovered_raw: Optional callback once discovered json is loaded
before any attempt to deserialize it and create devices
:param discovery_timeout: Seconds to wait for responses, defaults to 5 :param discovery_timeout: Seconds to wait for responses, defaults to 5
:param discovery_packets: Number of discovery packets to broadcast :param discovery_packets: Number of discovery packets to broadcast
:param interface: Bind to specific interface :param interface: Bind to specific interface
@ -443,6 +477,7 @@ class Discover:
discovery_packets=discovery_packets, discovery_packets=discovery_packets,
interface=interface, interface=interface,
on_unsupported=on_unsupported, on_unsupported=on_unsupported,
on_discovered_raw=on_discovered_raw,
credentials=credentials, credentials=credentials,
timeout=timeout, timeout=timeout,
discovery_timeout=discovery_timeout, discovery_timeout=discovery_timeout,
@ -476,6 +511,7 @@ class Discover:
credentials: Credentials | None = None, credentials: Credentials | None = None,
username: str | None = None, username: str | None = None,
password: str | None = None, password: str | None = None,
on_discovered_raw: OnDiscoveredRawCallable | None = None,
on_unsupported: OnUnsupportedCallable | None = None, on_unsupported: OnUnsupportedCallable | None = None,
) -> Device | None: ) -> Device | None:
"""Discover a single device by the given IP address. """Discover a single device by the given IP address.
@ -493,6 +529,9 @@ class Discover:
username and password are ignored if provided. username and password are ignored if provided.
:param username: Username for devices that require authentication :param username: Username for devices that require authentication
:param password: Password for devices that require authentication :param password: Password for devices that require authentication
:param on_discovered_raw: Optional callback once discovered json is loaded
before any attempt to deserialize it and create devices
:param on_unsupported: Optional callback when unsupported devices are discovered
:rtype: SmartDevice :rtype: SmartDevice
:return: Object for querying/controlling found device. :return: Object for querying/controlling found device.
""" """
@ -529,6 +568,7 @@ class Discover:
credentials=credentials, credentials=credentials,
timeout=timeout, timeout=timeout,
discovery_timeout=discovery_timeout, discovery_timeout=discovery_timeout,
on_discovered_raw=on_discovered_raw,
), ),
local_addr=("0.0.0.0", 0), # noqa: S104 local_addr=("0.0.0.0", 0), # noqa: S104
) )
@ -666,15 +706,19 @@ class Discover:
return get_device_class_from_sys_info(info) return get_device_class_from_sys_info(info)
@staticmethod @staticmethod
def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> IotDevice: def _get_discovery_json_legacy(data: bytes, ip: str) -> dict:
"""Get SmartDevice from legacy 9999 response.""" """Get discovery json from legacy 9999 response."""
try: try:
info = json_loads(XorEncryption.decrypt(data)) info = json_loads(XorEncryption.decrypt(data))
except Exception as ex: except Exception as ex:
raise KasaException( raise KasaException(
f"Unable to read response from device: {config.host}: {ex}" f"Unable to read response from device: {ip}: {ex}"
) from ex ) from ex
return info
@staticmethod
def _get_device_instance_legacy(info: dict, config: DeviceConfig) -> Device:
"""Get IotDevice from legacy 9999 response."""
if _LOGGER.isEnabledFor(logging.DEBUG): if _LOGGER.isEnabledFor(logging.DEBUG):
data = redact_data(info, IOT_REDACTORS) if Discover._redact_data else info data = redact_data(info, IOT_REDACTORS) if Discover._redact_data else info
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data)) _LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
@ -715,20 +759,25 @@ class Discover:
discovery_result.decrypted_data = json_loads(decrypted_data) discovery_result.decrypted_data = json_loads(decrypted_data)
@staticmethod
def _get_discovery_json(data: bytes, ip: str) -> dict:
"""Get discovery json from the new 20002 response."""
try:
info = json_loads(data[16:])
except Exception as ex:
_LOGGER.debug("Got invalid response from device %s: %s", ip, data)
raise KasaException(
f"Unable to read response from device: {ip}: {ex}"
) from ex
return info
@staticmethod @staticmethod
def _get_device_instance( def _get_device_instance(
data: bytes, info: dict,
config: DeviceConfig, config: DeviceConfig,
) -> Device: ) -> Device:
"""Get SmartDevice from the new 20002 response.""" """Get SmartDevice from the new 20002 response."""
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
try:
info = json_loads(data[16:])
except Exception as ex:
_LOGGER.debug("Got invalid response from device %s: %s", config.host, data)
raise KasaException(
f"Unable to read response from device: {config.host}: {ex}"
) from ex
try: try:
discovery_result = DiscoveryResult.from_dict(info["result"]) discovery_result = DiscoveryResult.from_dict(info["result"])
@ -757,7 +806,9 @@ class Discover:
Discover._decrypt_discovery_data(discovery_result) Discover._decrypt_discovery_data(discovery_result)
except Exception: except Exception:
_LOGGER.exception( _LOGGER.exception(
"Unable to decrypt discovery data %s: %s", config.host, data "Unable to decrypt discovery data %s: %s",
config.host,
redact_data(info, NEW_DISCOVERY_REDACTORS),
) )
type_ = discovery_result.device_type type_ = discovery_result.device_type

View File

@ -8,18 +8,24 @@ from typing import Any
try: try:
import orjson import orjson
def dumps(obj: Any, *, default: Callable | None = None) -> str: def dumps(
obj: Any, *, default: Callable | None = None, indent: bool = False
) -> str:
"""Dump JSON.""" """Dump JSON."""
return orjson.dumps(obj).decode() return orjson.dumps(
obj, option=orjson.OPT_INDENT_2 if indent else None
).decode()
loads = orjson.loads loads = orjson.loads
except ImportError: except ImportError:
import json import json
def dumps(obj: Any, *, default: Callable | None = None) -> str: def dumps(
obj: Any, *, default: Callable | None = None, indent: bool = False
) -> str:
"""Dump JSON.""" """Dump JSON."""
# Separators specified for consistency with orjson # Separators specified for consistency with orjson
return json.dumps(obj, separators=(",", ":")) return json.dumps(obj, separators=(",", ":"), indent=2 if indent else None)
loads = json.loads loads = json.loads

View File

@ -42,8 +42,9 @@ from kasa.cli.main import TYPES, _legacy_type_to_class, cli, cmd_command, raw_co
from kasa.cli.time import time from kasa.cli.time import time
from kasa.cli.usage import energy from kasa.cli.usage import energy
from kasa.cli.wifi import wifi from kasa.cli.wifi import wifi
from kasa.discover import Discover, DiscoveryResult from kasa.discover import Discover, DiscoveryResult, redact_data
from kasa.iot import IotDevice from kasa.iot import IotDevice
from kasa.json import dumps as json_dumps
from kasa.smart import SmartDevice from kasa.smart import SmartDevice
from kasa.smartcam import SmartCamDevice from kasa.smartcam import SmartCamDevice
@ -126,6 +127,36 @@ async def test_list_devices(discovery_mock, runner):
assert row in res.output assert row in res.output
async def test_discover_raw(discovery_mock, runner, mocker):
"""Test the discover raw command."""
redact_spy = mocker.patch(
"kasa.protocols.protocol.redact_data", side_effect=redact_data
)
res = await runner.invoke(
cli,
["--username", "foo", "--password", "bar", "discover", "raw"],
catch_exceptions=False,
)
assert res.exit_code == 0
expected = {
"discovery_response": discovery_mock.discovery_data,
"meta": {"ip": "127.0.0.123", "port": discovery_mock.discovery_port},
}
assert res.output == json_dumps(expected, indent=True) + "\n"
redact_spy.assert_not_called()
res = await runner.invoke(
cli,
["--username", "foo", "--password", "bar", "discover", "raw", "--redact"],
catch_exceptions=False,
)
assert res.exit_code == 0
redact_spy.assert_called()
@new_discovery @new_discovery
async def test_list_auth_failed(discovery_mock, mocker, runner): async def test_list_auth_failed(discovery_mock, mocker, runner):
"""Test that device update is called on main.""" """Test that device update is called on main."""
@ -731,6 +762,7 @@ async def test_without_device_type(dev, mocker, runner):
timeout=5, timeout=5,
discovery_timeout=7, discovery_timeout=7,
on_unsupported=ANY, on_unsupported=ANY,
on_discovered_raw=ANY,
) )