Handle KeyboardInterrupts in the cli better (#1391)

Addresses an issue with how `asyncclick` deals with `KeyboardInterrupt`
errors. Instead of the `click.main` receiving `KeyboardInterrupt` it
receives `CancelledError` because it's a task running inside the loop.

Also ensures that discovery catches the `CancelledError` and closes the
http clients.
This commit is contained in:
Steven B. 2024-12-20 13:21:38 +00:00 committed by GitHub
parent fe88b52e19
commit 296af3192e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 1 deletions

View File

@ -2,12 +2,14 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
import re import re
import sys import sys
from collections.abc import Callable from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from functools import singledispatch, update_wrapper, wraps from functools import singledispatch, update_wrapper, wraps
from gettext import gettext
from typing import TYPE_CHECKING, Any, Final from typing import TYPE_CHECKING, Any, Final
import asyncclick as click import asyncclick as click
@ -238,4 +240,19 @@ def CatchAllExceptions(cls):
except Exception as exc: except Exception as exc:
_handle_exception(self._debug, exc) _handle_exception(self._debug, exc)
def __call__(self, *args, **kwargs):
"""Run the coroutine in the event loop and print any exceptions.
python click catches KeyboardInterrupt in main, raises Abort()
and does sys.exit. asyncclick doesn't properly handle a coroutine
receiving CancelledError on a KeyboardInterrupt, so we catch the
KeyboardInterrupt here once asyncio.run has re-raised it. This
avoids large stacktraces when a user presses Ctrl-C.
"""
try:
asyncio.run(self.main(*args, **kwargs))
except KeyboardInterrupt:
click.echo(gettext("\nAborted!"), file=sys.stderr)
sys.exit(1)
return _CommandCls return _CommandCls

View File

@ -498,7 +498,7 @@ class Discover:
try: try:
_LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout) _LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout)
await protocol.wait_for_discovery_to_complete() await protocol.wait_for_discovery_to_complete()
except KasaException as ex: except (KasaException, asyncio.CancelledError) as ex:
for device in protocol.discovered_devices.values(): for device in protocol.discovered_devices.values():
await device.protocol.close() await device.protocol.close()
raise ex raise ex