From 209391c42212372f78d1c98afa47506bf148800e Mon Sep 17 00:00:00 2001 From: sdb9696 <51370195+sdb9696@users.noreply.github.com> Date: Tue, 19 Dec 2023 12:50:33 +0000 Subject: [PATCH] Improve CLI Discovery output (#583) - Show discovery results for unsupported devices and devices that fail to authenticate. - Rename `--show-unsupported` to `--verbose`. - Remove separate `--timeout` parameter from cli discovery so it's not confused with `--timeout` now added to cli command. - Add tests. --- kasa/cli.py | 102 ++++++++++++++++++++-------- kasa/discover.py | 22 +++--- kasa/exceptions.py | 4 ++ kasa/tests/conftest.py | 90 +++++++++++++++++++++---- kasa/tests/newfakes.py | 7 ++ kasa/tests/test_cli.py | 126 +++++++++++++++++++++++++++++++++-- kasa/tests/test_discovery.py | 2 +- 7 files changed, 298 insertions(+), 55 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index 3557bf4e..600494df 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -18,8 +18,15 @@ from kasa import ( SmartBulb, SmartDevice, SmartStrip, + UnsupportedDeviceException, ) from kasa.device_factory import DEVICE_TYPE_TO_CLASS +from kasa.discover import DiscoveryResult + +try: + from pydantic.v1 import ValidationError +except ImportError: + from pydantic import ValidationError try: from rich import print as _do_echo @@ -241,7 +248,7 @@ async def cli( if host is None: echo("No host name given, trying discovery..") - return await ctx.invoke(discover, timeout=discovery_timeout) + return await ctx.invoke(discover) if type is not None: device_type = DeviceType.from_value(type) @@ -300,21 +307,21 @@ async def join(dev: SmartDevice, ssid, password, keytype): @cli.command() -@click.option("--timeout", default=3, required=False) @click.option( - "--show-unsupported", - envvar="KASA_SHOW_UNSUPPORTED", + "--verbose", + envvar="KASA_VERBOSE", required=False, default=False, is_flag=True, - help="Print out discovered unsupported devices", + help="Be more verbose on output", ) @click.pass_context -async def discover(ctx, timeout, show_unsupported): +async def discover(ctx, verbose): """Discover devices in the network.""" target = ctx.parent.params["target"] username = ctx.parent.params["username"] password = ctx.parent.params["password"] + timeout = ctx.parent.params["discovery_timeout"] credentials = Credentials(username, password) @@ -323,24 +330,37 @@ async def discover(ctx, timeout, show_unsupported): unsupported = [] auth_failed = [] - async def print_unsupported(data: str): - unsupported.append(data) - if show_unsupported: - echo(f"Found unsupported device (tapo/unknown encryption): {data}") - echo() + async def print_unsupported(unsupported_exception: UnsupportedDeviceException): + unsupported.append(unsupported_exception) + async with sem: + if unsupported_exception.discovery_result: + echo("== Unsupported device ==") + _echo_discovery_info(unsupported_exception.discovery_result) + echo() + else: + echo("== Unsupported device ==") + echo(f"\t{unsupported_exception}") + echo() echo(f"Discovering devices on {target} for {timeout} seconds") async def print_discovered(dev: SmartDevice): - try: - await dev.update() - async with sem: + async with sem: + try: + await dev.update() + except AuthenticationException: + auth_failed.append(dev._discovery_info) + echo("== Authentication failed for device ==") + _echo_discovery_info(dev._discovery_info) + echo() + else: discovered[dev.host] = dev.internal_state ctx.obj = dev await ctx.invoke(state) - echo() - except AuthenticationException as aex: - auth_failed.append(str(aex)) + if verbose: + echo() + _echo_discovery_info(dev._discovery_info) + echo() await Discover.discover( target=target, @@ -352,22 +372,50 @@ async def discover(ctx, timeout, show_unsupported): echo(f"Found {len(discovered)} devices") if unsupported: - echo( - f"Found {len(unsupported)} unsupported devices" - + ( - "" - if show_unsupported - else ", to show them use: kasa discover --show-unsupported" - ) - ) + echo(f"Found {len(unsupported)} unsupported devices") if auth_failed: echo(f"Found {len(auth_failed)} devices that failed to authenticate") - for fail in auth_failed: - echo(fail) return discovered +def _echo_dictionary(discovery_info: dict): + echo("\t[bold]== Discovery information ==[/bold]") + for key, value in discovery_info.items(): + key_name = " ".join(x.capitalize() or "_" for x in key.split("_")) + key_name_and_spaces = "{:<15}".format(key_name + ":") + echo(f"\t{key_name_and_spaces}{value}") + + +def _echo_discovery_info(discovery_info): + if "system" in discovery_info and "get_sysinfo" in discovery_info["system"]: + _echo_dictionary(discovery_info["system"]["get_sysinfo"]) + return + + try: + dr = DiscoveryResult(**discovery_info) + except ValidationError: + _echo_dictionary(discovery_info) + return + + echo("\t[bold]== Discovery Result ==[/bold]") + echo(f"\tDevice Type: {dr.device_type}") + echo(f"\tDevice Model: {dr.device_model}") + echo(f"\tIP: {dr.ip}") + echo(f"\tMAC: {dr.mac}") + echo(f"\tDevice Id (hash): {dr.device_id}") + echo(f"\tOwner (hash): {dr.owner}") + echo(f"\tHW Ver: {dr.hw_ver}") + echo(f"\tIs Support IOT Cloud: {dr.is_support_iot_cloud})") + echo(f"\tOBD Src: {dr.obd_src}") + echo(f"\tFactory Default: {dr.factory_default}") + echo("\t\t== Encryption Scheme ==") + echo(f"\t\tEncrypt Type: {dr.mgt_encrypt_schm.encrypt_type}") + echo(f"\t\tIs Support HTTPS: {dr.mgt_encrypt_schm.is_support_https}") + echo(f"\t\tHTTP Port: {dr.mgt_encrypt_schm.http_port}") + echo(f"\t\tLV (Login Level): {dr.mgt_encrypt_schm.lv}") + + async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3): """Discover a device identified by its alias.""" for _attempt in range(1, attempts): diff --git a/kasa/discover.py b/kasa/discover.py index 2038369b..4ec3775e 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -50,7 +50,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): target: str = "255.255.255.255", discovery_packets: int = 3, interface: Optional[str] = None, - on_unsupported: Optional[Callable[[str], Awaitable[None]]] = None, + on_unsupported: Optional[ + Callable[[UnsupportedDeviceException], Awaitable[None]] + ] = None, port: Optional[int] = None, discovered_event: Optional[asyncio.Event] = None, credentials: Optional[Credentials] = None, @@ -64,7 +66,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.target = (target, self.discovery_port) self.target_2 = (target, Discover.DISCOVERY_PORT_2) self.discovered_devices = {} - self.unsupported_devices: Dict = {} + self.unsupported_device_exceptions: Dict = {} self.invalid_device_exceptions: Dict = {} self.on_unsupported = on_unsupported self.discovered_event = discovered_event @@ -119,9 +121,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): return except UnsupportedDeviceException as udex: _LOGGER.debug("Unsupported device found at %s << %s", ip, udex) - self.unsupported_devices[ip] = str(udex) + self.unsupported_device_exceptions[ip] = udex if self.on_unsupported is not None: - asyncio.ensure_future(self.on_unsupported(str(udex))) + asyncio.ensure_future(self.on_unsupported(udex)) if self.discovered_event is not None: self.discovered_event.set() return @@ -336,10 +338,8 @@ class Discover: if update_parent_devices and dev.has_children: await dev.update() return dev - elif ip in protocol.unsupported_devices: - raise UnsupportedDeviceException( - f"Unsupported device {host}: {protocol.unsupported_devices[ip]}" - ) + elif ip in protocol.unsupported_device_exceptions: + raise protocol.unsupported_device_exceptions[ip] elif ip in protocol.invalid_device_exceptions: raise protocol.invalid_device_exceptions[ip] else: @@ -397,7 +397,8 @@ class Discover: if (device_class := get_device_class_from_type_name(type_)) is None: _LOGGER.warning("Got unsupported device type: %s", type_) raise UnsupportedDeviceException( - f"Unsupported device {ip} of type {type_}: {info}" + f"Unsupported device {ip} of type {type_}: {info}", + discovery_result=discovery_result.get_dict(), ) if ( protocol := get_protocol_from_connection_name( @@ -406,7 +407,8 @@ class Discover: ) is None: _LOGGER.warning("Got unsupported device type: %s", encrypt_type_) raise UnsupportedDeviceException( - f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}" + f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}", + discovery_result=discovery_result.get_dict(), ) _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) diff --git a/kasa/exceptions.py b/kasa/exceptions.py index 22b3c1ac..e83c9237 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -9,6 +9,10 @@ class SmartDeviceException(Exception): class UnsupportedDeviceException(SmartDeviceException): """Exception for trying to connect to unsupported devices.""" + def __init__(self, *args, discovery_result=None): + self.discovery_result = discovery_result + super().__init__(args) + class AuthenticationException(SmartDeviceException): """Base exception for device authentication errors.""" diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index f84083e2..095971de 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -129,6 +129,37 @@ ALL_DEVICES = ALL_DEVICES_IOT.union(ALL_DEVICES_SMART) IP_MODEL_CACHE: Dict[str, str] = {} +def _make_unsupported(device_family, encrypt_type): + return { + "result": { + "device_id": "xx", + "owner": "xx", + "device_type": device_family, + "device_model": "P110(EU)", + "ip": "127.0.0.1", + "mac": "48-22xxx", + "is_support_iot_cloud": True, + "obd_src": "tplink", + "factory_default": False, + "mgt_encrypt_schm": { + "is_support_https": False, + "encrypt_type": encrypt_type, + "http_port": 80, + "lv": 2, + }, + }, + "error_code": 0, + } + + +UNSUPPORTED_DEVICES = { + "unknown_device_family": _make_unsupported("SMART.TAPOXMASTREE", "AES"), + "wrong_encryption_iot": _make_unsupported("IOT.SMARTPLUGSWITCH", "AES"), + "wrong_encryption_smart": _make_unsupported("SMART.TAPOBULB", "IOT"), + "unknown_encryption": _make_unsupported("IOT.SMARTPLUGSWITCH", "FOO"), +} + + def idgenerator(paramtuple): try: return basename(paramtuple[0]) + ( @@ -242,7 +273,7 @@ def filter_fixtures(desc, root_filter): def parametrize_discovery(desc, root_key): filtered_fixtures = filter_fixtures(desc, root_key) return pytest.mark.parametrize( - "discovery_data", + "all_fixture_data", filtered_fixtures.values(), indirect=True, ids=filtered_fixtures.keys(), @@ -360,7 +391,7 @@ async def get_device_for_file(file, protocol): return d -@pytest.fixture(params=SUPPORTED_DEVICES) +@pytest.fixture(params=SUPPORTED_DEVICES, ids=idgenerator) async def dev(request): """Device fixture. @@ -386,23 +417,27 @@ async def dev(request): @pytest.fixture -def discovery_mock(discovery_data, mocker): +def discovery_mock(all_fixture_data, mocker): @dataclass class _DiscoveryMock: ip: str default_port: int discovery_data: dict + query_data: dict port_override: Optional[int] = None - if "result" in discovery_data: + if "discovery_result" in all_fixture_data: + discovery_data = {"result": all_fixture_data["discovery_result"]} datagram = ( b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" + json_dumps(discovery_data).encode() ) - dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data) + dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data, all_fixture_data) else: + sys_info = all_fixture_data["system"]["get_sysinfo"] + discovery_data = {"system": {"get_sysinfo": sys_info}} datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:] - dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data) + dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data, all_fixture_data) def mock_discover(self): port = ( @@ -420,17 +455,29 @@ def discovery_mock(discovery_data, mocker): "socket.getaddrinfo", side_effect=lambda *_, **__: [(None, None, None, None, (dm.ip, 0))], ) + + if "component_nego" in dm.query_data: + proto = FakeSmartProtocol(dm.query_data) + else: + proto = FakeTransportProtocol(dm.query_data) + + async def _query(request, retry_count: int = 3): + return await proto.query(request) + + mocker.patch("kasa.IotProtocol.query", side_effect=_query) + mocker.patch("kasa.SmartProtocol.query", side_effect=_query) + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=_query) + yield dm -@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session") -def discovery_data(request): +@pytest.fixture +def discovery_data(all_fixture_data): """Return raw discovery file contents as JSON. Used for discovery tests.""" - fixture_data = request.param - if "discovery_result" in fixture_data: - return {"result": fixture_data["discovery_result"]} + if "discovery_result" in all_fixture_data: + return {"result": all_fixture_data["discovery_result"]} else: - return {"system": {"get_sysinfo": fixture_data["system"]["get_sysinfo"]}} + return {"system": {"get_sysinfo": all_fixture_data["system"]["get_sysinfo"]}} @pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session") @@ -440,6 +487,25 @@ def all_fixture_data(request): return fixture_data +@pytest.fixture(params=UNSUPPORTED_DEVICES.values(), ids=UNSUPPORTED_DEVICES.keys()) +def unsupported_device_info(request, mocker): + """Return unsupported devices for cli and discovery tests.""" + discovery_data = request.param + host = "127.0.0.1" + + def mock_discover(self): + if discovery_data: + data = ( + b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" + + json_dumps(discovery_data).encode() + ) + self.datagram_received(data, (host, 20002)) + + mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover) + + yield discovery_data + + def pytest_addoption(parser): parser.addoption( "--ip", action="store", default=None, help="run against device on given ip" diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index 05064c11..284f4e2b 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -291,6 +291,13 @@ class FakeSmartProtocol(SmartProtocol): def __init__(self, info): super().__init__("127.0.0.123", transport=FakeSmartTransport(info)) + async def query(self, request, retry_count: int = 3): + """Implement query here so can still patch SmartProtocol.query.""" + resp_dict = await self._query(request, retry_count) + if "result" in resp_dict: + return resp_dict["result"] + return {} + class FakeSmartTransport(BaseTransport): def __init__(self, info): diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 55e3977a..5add2b58 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -4,14 +4,12 @@ import asyncclick as click import pytest from asyncclick.testing import CliRunner -from kasa import SmartDevice, TPLinkSmartHomeProtocol +from kasa import AuthenticationException, SmartDevice, UnsupportedDeviceException from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle from kasa.device_factory import DEVICE_TYPE_TO_CLASS from kasa.discover import Discover -from kasa.smartprotocol import SmartProtocol from .conftest import device_iot, handle_turn_on, new_discovery, turn_on -from .newfakes import FakeSmartProtocol, FakeTransportProtocol @device_iot @@ -22,7 +20,6 @@ async def test_sysinfo(dev): assert dev.alias in res.output -@device_iot @turn_on async def test_state(dev, turn_on): await handle_turn_on(dev, turn_on) @@ -36,7 +33,6 @@ async def test_state(dev, turn_on): assert "Device state: False" in res.output -@device_iot @turn_on async def test_toggle(dev, turn_on, mocker): await handle_turn_on(dev, turn_on) @@ -226,3 +222,123 @@ async def test_duplicate_target_device(): ) assert res.exit_code == 2 assert "Error: Use either --alias or --host, not both." in res.output + + +async def test_discover(discovery_mock, mocker): + """Test discovery output.""" + runner = CliRunner() + res = await runner.invoke( + cli, + [ + "--discovery-timeout", + 0, + "--username", + "foo", + "--password", + "bar", + "discover", + "--verbose", + ], + ) + assert res.exit_code == 0 + + +async def test_discover_unsupported(unsupported_device_info): + """Test discovery output.""" + runner = CliRunner() + res = await runner.invoke( + cli, + [ + "--discovery-timeout", + 0, + "--username", + "foo", + "--password", + "bar", + "discover", + "--verbose", + ], + ) + assert res.exit_code == 0 + assert "== Unsupported device ==" in res.output + assert "== Discovery Result ==" in res.output + + +async def test_host_unsupported(unsupported_device_info): + """Test discovery output.""" + runner = CliRunner() + host = "127.0.0.1" + + res = await runner.invoke( + cli, + [ + "--host", + host, + "--username", + "foo", + "--password", + "bar", + ], + ) + + assert res.exit_code != 0 + assert isinstance(res.exception, UnsupportedDeviceException) + + +@new_discovery +async def test_discover_auth_failed(discovery_mock, mocker): + """Test discovery output.""" + runner = CliRunner() + host = "127.0.0.1" + discovery_mock.ip = host + device_class = Discover._get_device_class(discovery_mock.discovery_data) + mocker.patch.object( + device_class, + "update", + side_effect=AuthenticationException("Failed to authenticate"), + ) + res = await runner.invoke( + cli, + [ + "--discovery-timeout", + 0, + "--username", + "foo", + "--password", + "bar", + "discover", + "--verbose", + ], + ) + + assert res.exit_code == 0 + assert "== Authentication failed for device ==" in res.output + assert "== Discovery Result ==" in res.output + + +@new_discovery +async def test_host_auth_failed(discovery_mock, mocker): + """Test discovery output.""" + runner = CliRunner() + host = "127.0.0.1" + discovery_mock.ip = host + device_class = Discover._get_device_class(discovery_mock.discovery_data) + mocker.patch.object( + device_class, + "update", + side_effect=AuthenticationException("Failed to authenticate"), + ) + res = await runner.invoke( + cli, + [ + "--host", + host, + "--username", + "foo", + "--password", + "bar", + ], + ) + + assert res.exit_code != 0 + assert isinstance(res.exception, AuthenticationException) diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 72555c7e..18798ab9 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -202,7 +202,7 @@ async def test_discover_datagram_received(mocker, discovery_data): # Check that device in discovered_devices is initialized correctly assert len(proto.discovered_devices) == 1 # Check that unsupported device is 1 - assert len(proto.unsupported_devices) == 1 + assert len(proto.unsupported_device_exceptions) == 1 dev = proto.discovered_devices[addr] assert issubclass(dev.__class__, SmartDevice) assert dev.host == addr