diff --git a/kasa/cli.py b/kasa/cli.py index 47f53d18..cc782432 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -103,6 +103,7 @@ def json_formatter_cb(result, **kwargs): "--port", envvar="KASA_PORT", required=False, + type=int, help="The port of the device to connect to.", ) @click.option( @@ -138,7 +139,17 @@ def json_formatter_cb(result, **kwargs): ) @click.version_option(package_name="python-kasa") @click.pass_context -async def cli(ctx, host, port, alias, target, debug, type, json, discovery_timeout): +async def cli( + ctx, + host, + port, + alias, + target, + debug, + type, + json, + discovery_timeout, +): """A tool for controlling TP-Link smart home devices.""" # noqa # no need to perform any checks if we are just displaying the help if sys.argv[-1] == "--help": @@ -238,13 +249,29 @@ async def join(dev: SmartDevice, ssid, password, keytype): @cli.command() @click.option("--timeout", default=3, required=False) +@click.option( + "--show-unsupported", + envvar="KASA_SHOW_UNSUPPORTED", + required=False, + default=False, + is_flag=True, + help="Print out discovered unsupported devices", +) @click.pass_context -async def discover(ctx, timeout): +async def discover(ctx, timeout, show_unsupported): """Discover devices in the network.""" target = ctx.parent.params["target"] - echo(f"Discovering devices on {target} for {timeout} seconds") sem = asyncio.Semaphore() discovered = dict() + unsupported = [] + + async def print_unsupported(data: Dict): + unsupported.append(data) + if show_unsupported: + echo(f"Found unsupported device (tapo/unknown encryption): {data}") + echo() + + echo(f"Discovering devices on {target} for {timeout} seconds") async def print_discovered(dev: SmartDevice): await dev.update() @@ -255,9 +282,23 @@ async def discover(ctx, timeout): echo() await Discover.discover( - target=target, timeout=timeout, on_discovered=print_discovered + target=target, + timeout=timeout, + on_discovered=print_discovered, + on_unsupported=print_unsupported, ) + echo(f"Found {len(discovered)} devices") + if unsupported: + echo( + f"Found {len(unsupported)} unsupported devices" + + ( + "" + if show_unsupported + else ", to show them use: kasa discover --show-unsupported" + ) + ) + return discovered diff --git a/kasa/discover.py b/kasa/discover.py index f7b5fbbf..5a78d193 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -1,9 +1,15 @@ """Discovery module for TP-Link Smart Home devices.""" import asyncio +import binascii import logging import socket from typing import Awaitable, Callable, Dict, Optional, Type, cast +# When support for cpython older than 3.11 is dropped +# async_timeout can be replaced with asyncio.timeout +from async_timeout import timeout as asyncio_timeout + +from kasa.exceptions import UnsupportedDeviceException from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads from kasa.protocol import TPLinkSmartHomeProtocol @@ -36,13 +42,22 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): target: str = "255.255.255.255", discovery_packets: int = 3, interface: Optional[str] = None, + on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None, + port: Optional[int] = None, + discovered_event: Optional[asyncio.Event] = None, ): self.transport = None self.discovery_packets = discovery_packets self.interface = interface self.on_discovered = on_discovered - self.target = (target, Discover.DISCOVERY_PORT) + self.discovery_port = port or Discover.DISCOVERY_PORT + self.target = (target, self.discovery_port) + self.target_2 = (target, Discover.DISCOVERY_PORT_2) self.discovered_devices = {} + self.unsupported_devices: Dict = {} + self.invalid_device_exceptions: Dict = {} + self.on_unsupported = on_unsupported + self.discovered_event = discovered_event def connection_made(self, transport) -> None: """Set socket options for broadcasting.""" @@ -69,23 +84,48 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): encrypted_req = TPLinkSmartHomeProtocol.encrypt(req) for i in range(self.discovery_packets): self.transport.sendto(encrypted_req[4:], self.target) # type: ignore + self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore def datagram_received(self, data, addr) -> None: """Handle discovery responses.""" ip, port = addr - if ip in self.discovered_devices: + if ( + ip in self.discovered_devices + or ip in self.unsupported_devices + or ip in self.invalid_device_exceptions + ): return - info = json_loads(TPLinkSmartHomeProtocol.decrypt(data)) - _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) + if port == self.discovery_port: + info = json_loads(TPLinkSmartHomeProtocol.decrypt(data)) + _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) + + elif port == Discover.DISCOVERY_PORT_2: + info = json_loads(data[16:]) + self.unsupported_devices[ip] = info + if self.on_unsupported is not None: + asyncio.ensure_future(self.on_unsupported(info)) + _LOGGER.debug("[DISCOVERY] Unsupported device found at %s << %s", ip, info) + if self.discovered_event is not None and "255" not in self.target[0].split( + "." + ): + self.discovered_event.set() + return try: device_class = Discover._get_device_class(info) except SmartDeviceException as ex: - _LOGGER.debug("Unable to find device type from %s: %s", info, ex) + _LOGGER.debug( + "[DISCOVERY] Unable to find device type from %s: %s", info, ex + ) + self.invalid_device_exceptions[ip] = ex + if self.discovered_event is not None and "255" not in self.target[0].split( + "." + ): + self.discovered_event.set() return - device = device_class(ip) + device = device_class(ip, port=port) device.update_from_discover_info(info) self.discovered_devices[ip] = device @@ -93,6 +133,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): if self.on_discovered is not None: asyncio.ensure_future(self.on_discovered(device)) + if self.discovered_event is not None and "255" not in self.target[0].split("."): + self.discovered_event.set() + def error_received(self, ex): """Handle asyncio.Protocol errors.""" _LOGGER.error("Got error: %s", ex) @@ -142,6 +185,9 @@ class Discover: "system": {"get_sysinfo": None}, } + DISCOVERY_PORT_2 = 20002 + DISCOVERY_QUERY_2 = binascii.unhexlify("020000010000000000000000463cb5d3") + @staticmethod async def discover( *, @@ -150,6 +196,7 @@ class Discover: timeout=5, discovery_packets=3, interface=None, + on_unsupported=None, ) -> DeviceDict: """Discover supported devices. @@ -177,6 +224,7 @@ class Discover: on_discovered=on_discovered, discovery_packets=discovery_packets, interface=interface, + on_unsupported=on_unsupported, ), local_addr=("0.0.0.0", 0), ) @@ -193,22 +241,47 @@ class Discover: return protocol.discovered_devices @staticmethod - async def discover_single(host: str, *, port: Optional[int] = None) -> SmartDevice: + async def discover_single( + host: str, *, port: Optional[int] = None, timeout=5 + ) -> SmartDevice: """Discover a single device by the given IP address. :param host: Hostname of device to query :rtype: SmartDevice :return: Object for querying/controlling found device. """ - protocol = TPLinkSmartHomeProtocol(host, port=port) + loop = asyncio.get_event_loop() + event = asyncio.Event() + transport, protocol = await loop.create_datagram_endpoint( + lambda: _DiscoverProtocol(target=host, port=port, discovered_event=event), + local_addr=("0.0.0.0", 0), + ) + protocol = cast(_DiscoverProtocol, protocol) - info = await protocol.query(Discover.DISCOVERY_QUERY) + try: + _LOGGER.debug("Waiting a total of %s seconds for responses...", timeout) - device_class = Discover._get_device_class(info) - dev = device_class(host, port=port) - await dev.update() + async with asyncio_timeout(timeout): + await event.wait() + except asyncio.TimeoutError: + raise SmartDeviceException( + f"Timed out getting discovery response for {host}" + ) + finally: + transport.close() - return dev + if host in protocol.discovered_devices: + dev = protocol.discovered_devices[host] + await dev.update() + return dev + elif host in protocol.unsupported_devices: + raise UnsupportedDeviceException( + f"Unsupported device {host}: {protocol.unsupported_devices[host]}" + ) + elif host in protocol.invalid_device_exceptions: + raise protocol.invalid_device_exceptions[host] + else: + raise SmartDeviceException(f"Unable to get discovery response for {host}") @staticmethod def _get_device_class(info: dict) -> Type[SmartDevice]: diff --git a/kasa/exceptions.py b/kasa/exceptions.py index 90d36c9a..0d2ff826 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -3,3 +3,7 @@ class SmartDeviceException(Exception): """Base exception for device errors.""" + + +class UnsupportedDeviceException(SmartDeviceException): + """Exception for trying to connect to unsupported devices.""" diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index bbdaf8a8..41578a2c 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -1,10 +1,12 @@ # type: ignore +import re import sys import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol -from kasa.discover import _DiscoverProtocol +from kasa.discover import _DiscoverProtocol, json_dumps +from kasa.exceptions import UnsupportedDeviceException from .conftest import bulb, dimmer, lightstrip, plug, strip @@ -55,11 +57,73 @@ async def test_type_unknown(): @pytest.mark.parametrize("custom_port", [123, None]) async def test_discover_single(discovery_data: dict, mocker, custom_port): """Make sure that discover_single returns an initialized SmartDevice instance.""" + host = "127.0.0.1" + + def mock_discover(self): + self.datagram_received( + protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:], + (host, custom_port or 9999), + ) + + mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - x = await Discover.discover_single("127.0.0.1", port=custom_port) + + x = await Discover.discover_single(host, port=custom_port) assert issubclass(x.__class__, SmartDevice) assert x._sys_info is not None - assert x.port == custom_port + assert x.port == custom_port or 9999 + + +UNSUPPORTED = { + "result": { + "device_id": "xx", + "owner": "xx", + "device_type": "SMART.TAPOPLUG", + "device_model": "P110(EU)", + "ip": "127.0.0.1", + "mac": "48-22xxx", + "is_support_iot_cloud": True, + "obd_src": "tplink", + "factory_default": False, + "mgt_encrypt_schm": { + "is_support_https": False, + "encrypt_type": "AES", + "http_port": 80, + "lv": 2, + }, + }, + "error_code": 0, +} + + +async def test_discover_single_unsupported(mocker): + """Make sure that discover_single handles unsupported devices correctly.""" + host = "127.0.0.1" + + def mock_discover(self): + if discovery_data: + data = ( + b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" + + json_dumps(discovery_data).encode() + ) + self.datagram_received(data, (host, 20002)) + + mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) + + # Test with a valid unsupported response + discovery_data = UNSUPPORTED + with pytest.raises( + UnsupportedDeviceException, + match=f"Unsupported device {host}: {re.escape(str(UNSUPPORTED))}", + ): + await Discover.discover_single(host) + + # Test with no response + discovery_data = None + with pytest.raises( + SmartDeviceException, match=f"Timed out getting discovery response for {host}" + ): + await Discover.discover_single(host, timeout=0.001) INVALIDS = [ @@ -75,9 +139,17 @@ INVALIDS = [ @pytest.mark.parametrize("msg, data", INVALIDS) async def test_discover_invalid_info(msg, data, mocker): """Make sure that invalid discovery information raises an exception.""" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=data) + host = "127.0.0.1" + + def mock_discover(self): + self.datagram_received( + protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(data))[4:], (host, 9999) + ) + + mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) + with pytest.raises(SmartDeviceException, match=msg): - await Discover.discover_single("127.0.0.1") + await Discover.discover_single(host) async def test_discover_send(mocker): @@ -87,7 +159,7 @@ async def test_discover_send(mocker): assert proto.target == ("255.255.255.255", 9999) transport = mocker.patch.object(proto, "transport") proto.do_discover() - assert transport.sendto.call_count == proto.discovery_packets + assert transport.sendto.call_count == proto.discovery_packets * 2 async def test_discover_datagram_received(mocker, discovery_data): @@ -98,10 +170,14 @@ async def test_discover_datagram_received(mocker, discovery_data): mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") addr = "127.0.0.1" - proto.datagram_received("", (addr, 1234)) + proto.datagram_received("", (addr, 9999)) + addr2 = "127.0.0.2" + proto.datagram_received("", (addr2, 20002)) # Check that device in discovered_devices is initialized correctly assert len(proto.discovered_devices) == 1 + # Check that unsupported device is 1 + assert len(proto.unsupported_devices) == 1 dev = proto.discovered_devices[addr] assert issubclass(dev.__class__, SmartDevice) assert dev.host == addr @@ -115,5 +191,5 @@ async def test_discover_invalid_responses(msg, data, mocker): mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt") mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") - proto.datagram_received(data, ("127.0.0.1", 1234)) + proto.datagram_received(data, ("127.0.0.1", 9999)) assert len(proto.discovered_devices) == 0