diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py index f8967066..5e676a1d 100644 --- a/kasa/cli/discover.py +++ b/kasa/cli/discover.py @@ -14,9 +14,17 @@ from kasa import ( Discover, UnsupportedDeviceError, ) -from kasa.discover import ConnectAttempt, DiscoveryResult +from kasa.discover import ( + NEW_DISCOVERY_REDACTORS, + ConnectAttempt, + DiscoveredRaw, + DiscoveryResult, +) from kasa.iot.iotdevice import _extract_sys_info +from kasa.protocols.iotprotocol import REDACTORS as IOT_REDACTORS +from kasa.protocols.protocol import redact_data +from ..json import dumps as json_dumps from .common import echo, error @@ -64,7 +72,9 @@ async def detail(ctx): await ctx.parent.invoke(state) echo() - discovered = await _discover(ctx, print_discovered, print_unsupported) + discovered = await _discover( + ctx, print_discovered=print_discovered, print_unsupported=print_unsupported + ) if ctx.parent.parent.params["host"]: return discovered @@ -77,6 +87,33 @@ async def detail(ctx): return discovered +@discover.command() +@click.option( + "--redact/--no-redact", + default=False, + is_flag=True, + type=bool, + help="Set flag to redact sensitive data from raw output.", +) +@click.pass_context +async def raw(ctx, redact: bool): + """Return raw discovery data returned from devices.""" + + def print_raw(discovered: DiscoveredRaw): + if redact: + redactors = ( + NEW_DISCOVERY_REDACTORS + if discovered["meta"]["port"] == Discover.DISCOVERY_PORT_2 + else IOT_REDACTORS + ) + discovered["discovery_response"] = redact_data( + discovered["discovery_response"], redactors + ) + echo(json_dumps(discovered, indent=True)) + + return await _discover(ctx, print_raw=print_raw, do_echo=False) + + @discover.command() @click.pass_context async def list(ctx): @@ -102,10 +139,17 @@ async def list(ctx): echo(f"{host:<15} UNSUPPORTED DEVICE") echo(f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}") - return await _discover(ctx, print_discovered, print_unsupported, do_echo=False) + return await _discover( + ctx, + print_discovered=print_discovered, + print_unsupported=print_unsupported, + do_echo=False, + ) -async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True): +async def _discover( + ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True +): params = ctx.parent.parent.params target = params["target"] username = params["username"] @@ -126,6 +170,7 @@ async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True): timeout=timeout, discovery_timeout=discovery_timeout, on_unsupported=print_unsupported, + on_discovered_raw=print_raw, ) if do_echo: echo(f"Discovering devices on {target} for {discovery_timeout} seconds") @@ -137,6 +182,7 @@ async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True): port=port, timeout=timeout, credentials=credentials, + on_discovered_raw=print_raw, ) for device in discovered_devices.values(): diff --git a/kasa/discover.py b/kasa/discover.py index 9cb0808d..d88fcc09 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -99,6 +99,7 @@ from typing import ( Annotated, Any, NamedTuple, + TypedDict, cast, ) @@ -147,18 +148,35 @@ class ConnectAttempt(NamedTuple): device: type +class DiscoveredMeta(TypedDict): + """Meta info about discovery response.""" + + ip: str + port: int + + +class DiscoveredRaw(TypedDict): + """Try to connect attempt.""" + + meta: DiscoveredMeta + discovery_response: dict + + OnDiscoveredCallable = Callable[[Device], Coroutine] +OnDiscoveredRawCallable = Callable[[DiscoveredRaw], None] OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Coroutine] OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None] DeviceDict = dict[str, Device] NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = { "device_id": lambda x: "REDACTED_" + x[9::], + "device_name": lambda x: "#MASKED_NAME#" if x else "", "owner": lambda x: "REDACTED_" + x[9::], "mac": mask_mac, "master_device_id": lambda x: "REDACTED_" + x[9::], "group_id": lambda x: "REDACTED_" + x[9::], "group_name": lambda x: "I01BU0tFRF9TU0lEIw==", + "encrypt_info": lambda x: {**x, "key": "", "data": ""}, } @@ -216,6 +234,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self, *, on_discovered: OnDiscoveredCallable | None = None, + on_discovered_raw: OnDiscoveredRawCallable | None = None, target: str = "255.255.255.255", discovery_packets: int = 3, discovery_timeout: int = 5, @@ -240,6 +259,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.unsupported_device_exceptions: dict = {} self.invalid_device_exceptions: dict = {} self.on_unsupported = on_unsupported + self.on_discovered_raw = on_discovered_raw self.credentials = credentials self.timeout = timeout self.discovery_timeout = discovery_timeout @@ -329,12 +349,23 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): config.timeout = self.timeout try: if port == self.discovery_port: - device = Discover._get_device_instance_legacy(data, config) + json_func = Discover._get_discovery_json_legacy + device_func = Discover._get_device_instance_legacy elif port == Discover.DISCOVERY_PORT_2: config.uses_http = True - device = Discover._get_device_instance(data, config) + json_func = Discover._get_discovery_json + device_func = Discover._get_device_instance else: return + info = json_func(data, ip) + if self.on_discovered_raw is not None: + self.on_discovered_raw( + { + "discovery_response": info, + "meta": {"ip": ip, "port": port}, + } + ) + device = device_func(info, config) except UnsupportedDeviceError as udex: _LOGGER.debug("Unsupported device found at %s << %s", ip, udex) self.unsupported_device_exceptions[ip] = udex @@ -391,6 +422,7 @@ class Discover: *, target: str = "255.255.255.255", on_discovered: OnDiscoveredCallable | None = None, + on_discovered_raw: OnDiscoveredRawCallable | None = None, discovery_timeout: int = 5, discovery_packets: int = 3, interface: str | None = None, @@ -421,6 +453,8 @@ class Discover: :param target: The target address where to send the broadcast discovery queries if multi-homing (e.g. 192.168.xxx.255). :param on_discovered: coroutine to execute on discovery + :param on_discovered_raw: Optional callback once discovered json is loaded + before any attempt to deserialize it and create devices :param discovery_timeout: Seconds to wait for responses, defaults to 5 :param discovery_packets: Number of discovery packets to broadcast :param interface: Bind to specific interface @@ -443,6 +477,7 @@ class Discover: discovery_packets=discovery_packets, interface=interface, on_unsupported=on_unsupported, + on_discovered_raw=on_discovered_raw, credentials=credentials, timeout=timeout, discovery_timeout=discovery_timeout, @@ -476,6 +511,7 @@ class Discover: credentials: Credentials | None = None, username: str | None = None, password: str | None = None, + on_discovered_raw: OnDiscoveredRawCallable | None = None, on_unsupported: OnUnsupportedCallable | None = None, ) -> Device | None: """Discover a single device by the given IP address. @@ -493,6 +529,9 @@ class Discover: username and password are ignored if provided. :param username: Username for devices that require authentication :param password: Password for devices that require authentication + :param on_discovered_raw: Optional callback once discovered json is loaded + before any attempt to deserialize it and create devices + :param on_unsupported: Optional callback when unsupported devices are discovered :rtype: SmartDevice :return: Object for querying/controlling found device. """ @@ -529,6 +568,7 @@ class Discover: credentials=credentials, timeout=timeout, discovery_timeout=discovery_timeout, + on_discovered_raw=on_discovered_raw, ), local_addr=("0.0.0.0", 0), # noqa: S104 ) @@ -666,15 +706,19 @@ class Discover: return get_device_class_from_sys_info(info) @staticmethod - def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> IotDevice: - """Get SmartDevice from legacy 9999 response.""" + def _get_discovery_json_legacy(data: bytes, ip: str) -> dict: + """Get discovery json from legacy 9999 response.""" try: info = json_loads(XorEncryption.decrypt(data)) except Exception as ex: raise KasaException( - f"Unable to read response from device: {config.host}: {ex}" + f"Unable to read response from device: {ip}: {ex}" ) from ex + return info + @staticmethod + def _get_device_instance_legacy(info: dict, config: DeviceConfig) -> Device: + """Get IotDevice from legacy 9999 response.""" if _LOGGER.isEnabledFor(logging.DEBUG): data = redact_data(info, IOT_REDACTORS) if Discover._redact_data else info _LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data)) @@ -715,20 +759,25 @@ class Discover: discovery_result.decrypted_data = json_loads(decrypted_data) + @staticmethod + def _get_discovery_json(data: bytes, ip: str) -> dict: + """Get discovery json from the new 20002 response.""" + try: + info = json_loads(data[16:]) + except Exception as ex: + _LOGGER.debug("Got invalid response from device %s: %s", ip, data) + raise KasaException( + f"Unable to read response from device: {ip}: {ex}" + ) from ex + return info + @staticmethod def _get_device_instance( - data: bytes, + info: dict, config: DeviceConfig, ) -> Device: """Get SmartDevice from the new 20002 response.""" debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) - try: - info = json_loads(data[16:]) - except Exception as ex: - _LOGGER.debug("Got invalid response from device %s: %s", config.host, data) - raise KasaException( - f"Unable to read response from device: {config.host}: {ex}" - ) from ex try: discovery_result = DiscoveryResult.from_dict(info["result"]) @@ -757,7 +806,9 @@ class Discover: Discover._decrypt_discovery_data(discovery_result) except Exception: _LOGGER.exception( - "Unable to decrypt discovery data %s: %s", config.host, data + "Unable to decrypt discovery data %s: %s", + config.host, + redact_data(info, NEW_DISCOVERY_REDACTORS), ) type_ = discovery_result.device_type diff --git a/kasa/json.py b/kasa/json.py index 21c6fa00..8a0eab7b 100755 --- a/kasa/json.py +++ b/kasa/json.py @@ -8,18 +8,24 @@ from typing import Any try: import orjson - def dumps(obj: Any, *, default: Callable | None = None) -> str: + def dumps( + obj: Any, *, default: Callable | None = None, indent: bool = False + ) -> str: """Dump JSON.""" - return orjson.dumps(obj).decode() + return orjson.dumps( + obj, option=orjson.OPT_INDENT_2 if indent else None + ).decode() loads = orjson.loads except ImportError: import json - def dumps(obj: Any, *, default: Callable | None = None) -> str: + def dumps( + obj: Any, *, default: Callable | None = None, indent: bool = False + ) -> str: """Dump JSON.""" # Separators specified for consistency with orjson - return json.dumps(obj, separators=(",", ":")) + return json.dumps(obj, separators=(",", ":"), indent=2 if indent else None) loads = json.loads diff --git a/tests/test_cli.py b/tests/test_cli.py index d1fc330c..4391b998 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -42,8 +42,9 @@ from kasa.cli.main import TYPES, _legacy_type_to_class, cli, cmd_command, raw_co from kasa.cli.time import time from kasa.cli.usage import energy from kasa.cli.wifi import wifi -from kasa.discover import Discover, DiscoveryResult +from kasa.discover import Discover, DiscoveryResult, redact_data from kasa.iot import IotDevice +from kasa.json import dumps as json_dumps from kasa.smart import SmartDevice from kasa.smartcam import SmartCamDevice @@ -126,6 +127,36 @@ async def test_list_devices(discovery_mock, runner): assert row in res.output +async def test_discover_raw(discovery_mock, runner, mocker): + """Test the discover raw command.""" + redact_spy = mocker.patch( + "kasa.protocols.protocol.redact_data", side_effect=redact_data + ) + res = await runner.invoke( + cli, + ["--username", "foo", "--password", "bar", "discover", "raw"], + catch_exceptions=False, + ) + assert res.exit_code == 0 + + expected = { + "discovery_response": discovery_mock.discovery_data, + "meta": {"ip": "127.0.0.123", "port": discovery_mock.discovery_port}, + } + assert res.output == json_dumps(expected, indent=True) + "\n" + + redact_spy.assert_not_called() + + res = await runner.invoke( + cli, + ["--username", "foo", "--password", "bar", "discover", "raw", "--redact"], + catch_exceptions=False, + ) + assert res.exit_code == 0 + + redact_spy.assert_called() + + @new_discovery async def test_list_auth_failed(discovery_mock, mocker, runner): """Test that device update is called on main.""" @@ -731,6 +762,7 @@ async def test_without_device_type(dev, mocker, runner): timeout=5, discovery_timeout=7, on_unsupported=ANY, + on_discovered_raw=ANY, )