From 215b8d4e4f02a20918a8472c28666a93b4bd9fcd Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:53:09 +0000 Subject: [PATCH] Fix discovery cli to print devices not printed during discovery timeout (#670) * Fix discovery cli to print devices not printed during discovery * Fix tests * Fix print exceptions not being propagated * Fix tests * Reduce test discover_send time * Simplify wait logic * Add tests * Remove sleep loop and make auth failed a list --- kasa/cli.py | 7 ++- kasa/discover.py | 61 ++++++++++++------- kasa/tests/conftest.py | 4 +- kasa/tests/test_cli.py | 22 ++++++- kasa/tests/test_discovery.py | 111 +++++++++++++++++++++++------------ 5 files changed, 142 insertions(+), 63 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index 74c32e4e..53c68adb 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -444,12 +444,12 @@ async def discover(ctx): _echo_discovery_info(dev._discovery_info) echo() else: - discovered[dev.host] = dev.internal_state ctx.parent.obj = dev await ctx.parent.invoke(state) + discovered[dev.host] = dev.internal_state echo() - await Discover.discover( + discovered_devices = await Discover.discover( target=target, discovery_timeout=discovery_timeout, on_discovered=print_discovered, @@ -459,6 +459,9 @@ async def discover(ctx): credentials=credentials, ) + for device in discovered_devices.values(): + await device.protocol.close() + echo(f"Found {len(discovered)} devices") if unsupported: echo(f"Found {len(unsupported)} unsupported devices") diff --git a/kasa/discover.py b/kasa/discover.py index 858109e2..f9ce6e0a 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -4,7 +4,7 @@ import binascii import ipaddress import logging import socket -from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast +from typing import Awaitable, Callable, Dict, List, Optional, Set, Type, cast # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout @@ -46,6 +46,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): This is internal class, use :func:`Discover.discover`: instead. """ + DISCOVERY_START_TIMEOUT = 1 + discovered_devices: DeviceDict def __init__( @@ -60,7 +62,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): Callable[[UnsupportedDeviceException], Awaitable[None]] ] = None, port: Optional[int] = None, - discovered_event: Optional[asyncio.Event] = None, credentials: Optional[Credentials] = None, timeout: Optional[int] = None, ) -> None: @@ -79,12 +80,32 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.unsupported_device_exceptions: Dict = {} self.invalid_device_exceptions: Dict = {} self.on_unsupported = on_unsupported - self.discovered_event = discovered_event self.credentials = credentials self.timeout = timeout self.discovery_timeout = discovery_timeout self.seen_hosts: Set[str] = set() self.discover_task: Optional[asyncio.Task] = None + self.callback_tasks: List[asyncio.Task] = [] + self.target_discovered: bool = False + self._started_event = asyncio.Event() + + def _run_callback_task(self, coro): + task = asyncio.create_task(coro) + self.callback_tasks.append(task) + + async def wait_for_discovery_to_complete(self): + """Wait for the discovery task to complete.""" + # Give some time for connection_made event to be received + async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT): + await self._started_event.wait() + try: + await self.discover_task + except asyncio.CancelledError: + # if target_discovered then cancel was called internally + if not self.target_discovered: + raise + # Wait for any pending callbacks to complete + await asyncio.gather(*self.callback_tasks) def connection_made(self, transport) -> None: """Set socket options for broadcasting.""" @@ -103,6 +124,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): ) self.discover_task = asyncio.create_task(self.do_discover()) + self._started_event.set() async def do_discover(self) -> None: """Send number of discovery datagrams.""" @@ -110,13 +132,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): _LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY) encrypted_req = XorEncryption.encrypt(req) sleep_between_packets = self.discovery_timeout / self.discovery_packets - for i in range(self.discovery_packets): + for _ in range(self.discovery_packets): if self.target in self.seen_hosts: # Stop sending for discover_single break self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore - if i < self.discovery_packets - 1: - await asyncio.sleep(sleep_between_packets) + await asyncio.sleep(sleep_between_packets) def datagram_received(self, data, addr) -> None: """Handle discovery responses.""" @@ -145,7 +166,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): _LOGGER.debug("Unsupported device found at %s << %s", ip, udex) self.unsupported_device_exceptions[ip] = udex if self.on_unsupported is not None: - asyncio.ensure_future(self.on_unsupported(udex)) + self._run_callback_task(self.on_unsupported(udex)) self._handle_discovered_event() return except SmartDeviceException as ex: @@ -157,16 +178,16 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.discovered_devices[ip] = device if self.on_discovered is not None: - asyncio.ensure_future(self.on_discovered(device)) + self._run_callback_task(self.on_discovered(device)) self._handle_discovered_event() def _handle_discovered_event(self): - """If discovered_event is available set it and cancel discover_task.""" - if self.discovered_event is not None: + """If target is in seen_hosts cancel discover_task.""" + if self.target in self.seen_hosts: + self.target_discovered = True if self.discover_task: self.discover_task.cancel() - self.discovered_event.set() def error_received(self, ex): """Handle asyncio.Protocol errors.""" @@ -289,7 +310,11 @@ class Discover: try: _LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout) - await asyncio.sleep(discovery_timeout) + await protocol.wait_for_discovery_to_complete() + except SmartDeviceException as ex: + for device in protocol.discovered_devices.values(): + await device.protocol.close() + raise ex finally: transport.close() @@ -322,7 +347,6 @@ class Discover: :return: Object for querying/controlling found device. """ loop = asyncio.get_event_loop() - event = asyncio.Event() try: ipaddress.ip_address(host) @@ -352,7 +376,6 @@ class Discover: lambda: _DiscoverProtocol( target=ip, port=port, - discovered_event=event, credentials=credentials, timeout=timeout, discovery_timeout=discovery_timeout, @@ -365,13 +388,7 @@ class Discover: _LOGGER.debug( "Waiting a total of %s seconds for responses...", discovery_timeout ) - - async with asyncio_timeout(discovery_timeout): - await event.wait() - except asyncio.TimeoutError as ex: - raise TimeoutException( - f"Timed out getting discovery response for {host}" - ) from ex + await protocol.wait_for_discovery_to_complete() finally: transport.close() @@ -384,7 +401,7 @@ class Discover: elif ip in protocol.invalid_device_exceptions: raise protocol.invalid_device_exceptions[ip] else: - raise SmartDeviceException(f"Unable to get discovery response for {host}") + raise TimeoutException(f"Timed out getting discovery response for {host}") @staticmethod def _get_device_class(info: dict) -> Type[Device]: diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index b6e9135c..b5b711d9 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -508,7 +508,7 @@ def discovery_mock(all_fixture_data, mocker): login_version, ) - def mock_discover(self): + async def mock_discover(self): port = ( dm.port_override if dm.port_override and dm.discovery_port != 20002 @@ -561,7 +561,7 @@ def unsupported_device_info(request, mocker): discovery_data = request.param host = "127.0.0.1" - def mock_discover(self): + async def mock_discover(self): if discovery_data: data = ( b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 58370d74..2aa07382 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -290,7 +290,7 @@ async def test_brightness(dev): @device_iot async def test_json_output(dev: Device, mocker): """Test that the json output produces correct output.""" - mocker.patch("kasa.Discover.discover", return_value=[dev]) + mocker.patch("kasa.Discover.discover", return_value={"127.0.0.1": dev}) runner = CliRunner() res = await runner.invoke(cli, ["--json", "state"], obj=dev) assert res.exit_code == 0 @@ -415,6 +415,26 @@ async def test_discover(discovery_mock, mocker): assert res.exit_code == 0 +async def test_discover_host(discovery_mock, mocker): + """Test discovery output.""" + runner = CliRunner() + res = await runner.invoke( + cli, + [ + "--discovery-timeout", + 0, + "--host", + "127.0.0.123", + "--username", + "foo", + "--password", + "bar", + "--verbose", + ], + ) + assert res.exit_code == 0 + + async def test_discover_unsupported(unsupported_device_info): """Test discovery output.""" runner = CliRunner() diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index e0a7fdd4..8ce5ca6e 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -191,7 +191,7 @@ async def test_discover_invalid_info(msg, data, mocker): """Make sure that invalid discovery information raises an exception.""" host = "127.0.0.1" - def mock_discover(self): + async def mock_discover(self): self.datagram_received( XorEncryption.encrypt(json_dumps(data))[4:], (host, 9999) ) @@ -204,7 +204,8 @@ async def test_discover_invalid_info(msg, data, mocker): async def test_discover_send(mocker): """Test discovery parameters.""" - proto = _DiscoverProtocol() + discovery_timeout = 0 + proto = _DiscoverProtocol(discovery_timeout=discovery_timeout) assert proto.discovery_packets == 3 assert proto.target_1 == ("255.255.255.255", 9999) transport = mocker.patch.object(proto, "transport") @@ -299,22 +300,25 @@ async def test_discover_single_authentication(discovery_mock, mocker): @new_discovery async def test_device_update_from_new_discovery_info(discovery_data): - device = IotDevice("127.0.0.7") + """Make sure that new discovery devices update from discovery info correctly.""" + device_class = Discover._get_device_class(discovery_data) + device = device_class("127.0.0.1") discover_info = DiscoveryResult(**discovery_data["result"]) discover_dump = discover_info.get_dict() - discover_dump["alias"] = "foobar" - discover_dump["model"] = discover_dump["device_model"] + model, _, _ = discover_dump["device_model"].partition("(") + discover_dump["model"] = model device.update_from_discover_info(discover_dump) - assert device.alias == "foobar" assert device.mac == discover_dump["mac"].replace("-", ":") - assert device.model == discover_dump["device_model"] + assert device.model == model - with pytest.raises( - SmartDeviceException, - match=re.escape("You need to await update() to access the data"), - ): - assert device.supported_modules + # TODO implement requires_update for SmartDevice + if isinstance(device, IotDevice): + with pytest.raises( + SmartDeviceException, + match=re.escape("You need to await update() to access the data"), + ): + assert device.supported_modules async def test_discover_single_http_client(discovery_mock, mocker): @@ -335,7 +339,7 @@ async def test_discover_single_http_client(discovery_mock, mocker): async def test_discover_http_client(discovery_mock, mocker): - """Make sure that discover_single returns an initialized SmartDevice instance.""" + """Make sure that discover returns an initialized SmartDevice instance.""" host = "127.0.0.1" discovery_mock.ip = host @@ -403,31 +407,24 @@ class FakeDatagramTransport(asyncio.DatagramTransport): @pytest.mark.parametrize("port", [9999, 20002]) @pytest.mark.parametrize("do_not_reply_count", [0, 1, 2, 3, 4]) async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): - """Make sure that discover_single handles authenticating devices correctly.""" + """Make sure that _DiscoverProtocol handles authenticating devices correctly.""" host = "127.0.0.1" - discovery_timeout = 1 + discovery_timeout = 0 - event = asyncio.Event() dp = _DiscoverProtocol( target=host, discovery_timeout=discovery_timeout, discovery_packets=5, - discovered_event=event, ) ft = FakeDatagramTransport(dp, port, do_not_reply_count) dp.connection_made(ft) - timed_out = False - try: - async with asyncio_timeout(discovery_timeout): - await event.wait() - except asyncio.TimeoutError: - timed_out = True + await dp.wait_for_discovery_to_complete() await asyncio.sleep(0) assert ft.send_count == do_not_reply_count + 1 assert dp.discover_task.done() - assert timed_out is False + assert dp.discover_task.cancelled() @pytest.mark.parametrize( @@ -436,27 +433,69 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): ids=["unknownport", "unsupporteddevice"], ) async def test_do_discover_invalid(mocker, port, will_timeout): - """Make sure that discover_single handles authenticating devices correctly.""" + """Make sure that _DiscoverProtocol handles invalid devices correctly.""" host = "127.0.0.1" - discovery_timeout = 1 + discovery_timeout = 0 - event = asyncio.Event() dp = _DiscoverProtocol( target=host, discovery_timeout=discovery_timeout, discovery_packets=5, - discovered_event=event, ) ft = FakeDatagramTransport(dp, port, 0, unsupported=True) dp.connection_made(ft) - timed_out = False - try: - async with asyncio_timeout(15): - await event.wait() - except asyncio.TimeoutError: - timed_out = True - + await dp.wait_for_discovery_to_complete() await asyncio.sleep(0) assert dp.discover_task.done() - assert timed_out is will_timeout + assert dp.discover_task.cancelled() != will_timeout + + +async def test_discover_propogates_task_exceptions(discovery_mock): + """Make sure that discover propogates callback exceptions.""" + discovery_timeout = 0 + + async def on_discovered(dev): + raise SmartDeviceException("Dummy exception") + + with pytest.raises(SmartDeviceException): + await Discover.discover( + discovery_timeout=discovery_timeout, on_discovered=on_discovered + ) + + +async def test_do_discover_no_connection(mocker): + """Make sure that if the datagram connection doesnt start a TimeoutError is raised.""" + host = "127.0.0.1" + discovery_timeout = 0 + mocker.patch.object(_DiscoverProtocol, "DISCOVERY_START_TIMEOUT", 0) + dp = _DiscoverProtocol( + target=host, + discovery_timeout=discovery_timeout, + discovery_packets=5, + ) + # Normally tests would simulate connection as per below + # ft = FakeDatagramTransport(dp, port, 0, unsupported=True) + # dp.connection_made(ft) + + with pytest.raises(asyncio.TimeoutError): + await dp.wait_for_discovery_to_complete() + + +async def test_do_discover_external_cancel(mocker): + """Make sure that a cancel other than when target is discovered propogates.""" + host = "127.0.0.1" + discovery_timeout = 1 + + dp = _DiscoverProtocol( + target=host, + discovery_timeout=discovery_timeout, + discovery_packets=1, + ) + # Normally tests would simulate connection as per below + ft = FakeDatagramTransport(dp, 9999, 1, unsupported=True) + dp.connection_made(ft) + + with pytest.raises(asyncio.TimeoutError): + async with asyncio_timeout(0): + await dp.wait_for_discovery_to_complete()