Return raw discovery result in cli discover raw (#1342)

Add `on_discovered_raw` callback to Discover and adds a cli command `discover raw` which returns the raw json before serializing to a `DiscoveryResult` and attempting to create a device class.
This commit is contained in:
Steven B. 2024-12-10 22:42:14 +00:00 committed by GitHub
parent 464683e09b
commit bf8f0adabe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 158 additions and 23 deletions

View File

@ -14,9 +14,17 @@ from kasa import (
Discover,
UnsupportedDeviceError,
)
from kasa.discover import ConnectAttempt, DiscoveryResult
from kasa.discover import (
NEW_DISCOVERY_REDACTORS,
ConnectAttempt,
DiscoveredRaw,
DiscoveryResult,
)
from kasa.iot.iotdevice import _extract_sys_info
from kasa.protocols.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.protocols.protocol import redact_data
from ..json import dumps as json_dumps
from .common import echo, error
@ -64,7 +72,9 @@ async def detail(ctx):
await ctx.parent.invoke(state)
echo()
discovered = await _discover(ctx, print_discovered, print_unsupported)
discovered = await _discover(
ctx, print_discovered=print_discovered, print_unsupported=print_unsupported
)
if ctx.parent.parent.params["host"]:
return discovered
@ -77,6 +87,33 @@ async def detail(ctx):
return discovered
@discover.command()
@click.option(
"--redact/--no-redact",
default=False,
is_flag=True,
type=bool,
help="Set flag to redact sensitive data from raw output.",
)
@click.pass_context
async def raw(ctx, redact: bool):
"""Return raw discovery data returned from devices."""
def print_raw(discovered: DiscoveredRaw):
if redact:
redactors = (
NEW_DISCOVERY_REDACTORS
if discovered["meta"]["port"] == Discover.DISCOVERY_PORT_2
else IOT_REDACTORS
)
discovered["discovery_response"] = redact_data(
discovered["discovery_response"], redactors
)
echo(json_dumps(discovered, indent=True))
return await _discover(ctx, print_raw=print_raw, do_echo=False)
@discover.command()
@click.pass_context
async def list(ctx):
@ -102,10 +139,17 @@ async def list(ctx):
echo(f"{host:<15} UNSUPPORTED DEVICE")
echo(f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}")
return await _discover(ctx, print_discovered, print_unsupported, do_echo=False)
return await _discover(
ctx,
print_discovered=print_discovered,
print_unsupported=print_unsupported,
do_echo=False,
)
async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True):
async def _discover(
ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True
):
params = ctx.parent.parent.params
target = params["target"]
username = params["username"]
@ -126,6 +170,7 @@ async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True):
timeout=timeout,
discovery_timeout=discovery_timeout,
on_unsupported=print_unsupported,
on_discovered_raw=print_raw,
)
if do_echo:
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
@ -137,6 +182,7 @@ async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True):
port=port,
timeout=timeout,
credentials=credentials,
on_discovered_raw=print_raw,
)
for device in discovered_devices.values():

View File

@ -99,6 +99,7 @@ from typing import (
Annotated,
Any,
NamedTuple,
TypedDict,
cast,
)
@ -147,18 +148,35 @@ class ConnectAttempt(NamedTuple):
device: type
class DiscoveredMeta(TypedDict):
"""Meta info about discovery response."""
ip: str
port: int
class DiscoveredRaw(TypedDict):
"""Try to connect attempt."""
meta: DiscoveredMeta
discovery_response: dict
OnDiscoveredCallable = Callable[[Device], Coroutine]
OnDiscoveredRawCallable = Callable[[DiscoveredRaw], None]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Coroutine]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = dict[str, Device]
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
"device_id": lambda x: "REDACTED_" + x[9::],
"device_name": lambda x: "#MASKED_NAME#" if x else "",
"owner": lambda x: "REDACTED_" + x[9::],
"mac": mask_mac,
"master_device_id": lambda x: "REDACTED_" + x[9::],
"group_id": lambda x: "REDACTED_" + x[9::],
"group_name": lambda x: "I01BU0tFRF9TU0lEIw==",
"encrypt_info": lambda x: {**x, "key": "", "data": ""},
}
@ -216,6 +234,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self,
*,
on_discovered: OnDiscoveredCallable | None = None,
on_discovered_raw: OnDiscoveredRawCallable | None = None,
target: str = "255.255.255.255",
discovery_packets: int = 3,
discovery_timeout: int = 5,
@ -240,6 +259,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.unsupported_device_exceptions: dict = {}
self.invalid_device_exceptions: dict = {}
self.on_unsupported = on_unsupported
self.on_discovered_raw = on_discovered_raw
self.credentials = credentials
self.timeout = timeout
self.discovery_timeout = discovery_timeout
@ -329,12 +349,23 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
config.timeout = self.timeout
try:
if port == self.discovery_port:
device = Discover._get_device_instance_legacy(data, config)
json_func = Discover._get_discovery_json_legacy
device_func = Discover._get_device_instance_legacy
elif port == Discover.DISCOVERY_PORT_2:
config.uses_http = True
device = Discover._get_device_instance(data, config)
json_func = Discover._get_discovery_json
device_func = Discover._get_device_instance
else:
return
info = json_func(data, ip)
if self.on_discovered_raw is not None:
self.on_discovered_raw(
{
"discovery_response": info,
"meta": {"ip": ip, "port": port},
}
)
device = device_func(info, config)
except UnsupportedDeviceError as udex:
_LOGGER.debug("Unsupported device found at %s << %s", ip, udex)
self.unsupported_device_exceptions[ip] = udex
@ -391,6 +422,7 @@ class Discover:
*,
target: str = "255.255.255.255",
on_discovered: OnDiscoveredCallable | None = None,
on_discovered_raw: OnDiscoveredRawCallable | None = None,
discovery_timeout: int = 5,
discovery_packets: int = 3,
interface: str | None = None,
@ -421,6 +453,8 @@ class Discover:
:param target: The target address where to send the broadcast discovery
queries if multi-homing (e.g. 192.168.xxx.255).
:param on_discovered: coroutine to execute on discovery
:param on_discovered_raw: Optional callback once discovered json is loaded
before any attempt to deserialize it and create devices
:param discovery_timeout: Seconds to wait for responses, defaults to 5
:param discovery_packets: Number of discovery packets to broadcast
:param interface: Bind to specific interface
@ -443,6 +477,7 @@ class Discover:
discovery_packets=discovery_packets,
interface=interface,
on_unsupported=on_unsupported,
on_discovered_raw=on_discovered_raw,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
@ -476,6 +511,7 @@ class Discover:
credentials: Credentials | None = None,
username: str | None = None,
password: str | None = None,
on_discovered_raw: OnDiscoveredRawCallable | None = None,
on_unsupported: OnUnsupportedCallable | None = None,
) -> Device | None:
"""Discover a single device by the given IP address.
@ -493,6 +529,9 @@ class Discover:
username and password are ignored if provided.
:param username: Username for devices that require authentication
:param password: Password for devices that require authentication
:param on_discovered_raw: Optional callback once discovered json is loaded
before any attempt to deserialize it and create devices
:param on_unsupported: Optional callback when unsupported devices are discovered
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
@ -529,6 +568,7 @@ class Discover:
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
on_discovered_raw=on_discovered_raw,
),
local_addr=("0.0.0.0", 0), # noqa: S104
)
@ -666,15 +706,19 @@ class Discover:
return get_device_class_from_sys_info(info)
@staticmethod
def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> IotDevice:
"""Get SmartDevice from legacy 9999 response."""
def _get_discovery_json_legacy(data: bytes, ip: str) -> dict:
"""Get discovery json from legacy 9999 response."""
try:
info = json_loads(XorEncryption.decrypt(data))
except Exception as ex:
raise KasaException(
f"Unable to read response from device: {config.host}: {ex}"
f"Unable to read response from device: {ip}: {ex}"
) from ex
return info
@staticmethod
def _get_device_instance_legacy(info: dict, config: DeviceConfig) -> Device:
"""Get IotDevice from legacy 9999 response."""
if _LOGGER.isEnabledFor(logging.DEBUG):
data = redact_data(info, IOT_REDACTORS) if Discover._redact_data else info
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
@ -715,20 +759,25 @@ class Discover:
discovery_result.decrypted_data = json_loads(decrypted_data)
@staticmethod
def _get_discovery_json(data: bytes, ip: str) -> dict:
"""Get discovery json from the new 20002 response."""
try:
info = json_loads(data[16:])
except Exception as ex:
_LOGGER.debug("Got invalid response from device %s: %s", ip, data)
raise KasaException(
f"Unable to read response from device: {ip}: {ex}"
) from ex
return info
@staticmethod
def _get_device_instance(
data: bytes,
info: dict,
config: DeviceConfig,
) -> Device:
"""Get SmartDevice from the new 20002 response."""
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
try:
info = json_loads(data[16:])
except Exception as ex:
_LOGGER.debug("Got invalid response from device %s: %s", config.host, data)
raise KasaException(
f"Unable to read response from device: {config.host}: {ex}"
) from ex
try:
discovery_result = DiscoveryResult.from_dict(info["result"])
@ -757,7 +806,9 @@ class Discover:
Discover._decrypt_discovery_data(discovery_result)
except Exception:
_LOGGER.exception(
"Unable to decrypt discovery data %s: %s", config.host, data
"Unable to decrypt discovery data %s: %s",
config.host,
redact_data(info, NEW_DISCOVERY_REDACTORS),
)
type_ = discovery_result.device_type

View File

@ -8,18 +8,24 @@ from typing import Any
try:
import orjson
def dumps(obj: Any, *, default: Callable | None = None) -> str:
def dumps(
obj: Any, *, default: Callable | None = None, indent: bool = False
) -> str:
"""Dump JSON."""
return orjson.dumps(obj).decode()
return orjson.dumps(
obj, option=orjson.OPT_INDENT_2 if indent else None
).decode()
loads = orjson.loads
except ImportError:
import json
def dumps(obj: Any, *, default: Callable | None = None) -> str:
def dumps(
obj: Any, *, default: Callable | None = None, indent: bool = False
) -> str:
"""Dump JSON."""
# Separators specified for consistency with orjson
return json.dumps(obj, separators=(",", ":"))
return json.dumps(obj, separators=(",", ":"), indent=2 if indent else None)
loads = json.loads

View File

@ -42,8 +42,9 @@ from kasa.cli.main import TYPES, _legacy_type_to_class, cli, cmd_command, raw_co
from kasa.cli.time import time
from kasa.cli.usage import energy
from kasa.cli.wifi import wifi
from kasa.discover import Discover, DiscoveryResult
from kasa.discover import Discover, DiscoveryResult, redact_data
from kasa.iot import IotDevice
from kasa.json import dumps as json_dumps
from kasa.smart import SmartDevice
from kasa.smartcam import SmartCamDevice
@ -126,6 +127,36 @@ async def test_list_devices(discovery_mock, runner):
assert row in res.output
async def test_discover_raw(discovery_mock, runner, mocker):
"""Test the discover raw command."""
redact_spy = mocker.patch(
"kasa.protocols.protocol.redact_data", side_effect=redact_data
)
res = await runner.invoke(
cli,
["--username", "foo", "--password", "bar", "discover", "raw"],
catch_exceptions=False,
)
assert res.exit_code == 0
expected = {
"discovery_response": discovery_mock.discovery_data,
"meta": {"ip": "127.0.0.123", "port": discovery_mock.discovery_port},
}
assert res.output == json_dumps(expected, indent=True) + "\n"
redact_spy.assert_not_called()
res = await runner.invoke(
cli,
["--username", "foo", "--password", "bar", "discover", "raw", "--redact"],
catch_exceptions=False,
)
assert res.exit_code == 0
redact_spy.assert_called()
@new_discovery
async def test_list_auth_failed(discovery_mock, mocker, runner):
"""Test that device update is called on main."""
@ -731,6 +762,7 @@ async def test_without_device_type(dev, mocker, runner):
timeout=5,
discovery_timeout=7,
on_unsupported=ANY,
on_discovered_raw=ANY,
)