Fix discovery cli to print devices not printed during discovery timeout (#670)

* Fix discovery cli to print devices not printed during discovery

* Fix tests

* Fix print exceptions not being propagated

* Fix tests

* Reduce test discover_send time

* Simplify wait logic

* Add tests

* Remove sleep loop and make auth failed a list
This commit is contained in:
Steven B 2024-02-05 17:53:09 +00:00 committed by GitHub
parent 0d119e63d0
commit 215b8d4e4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 142 additions and 63 deletions

View File

@ -444,12 +444,12 @@ async def discover(ctx):
_echo_discovery_info(dev._discovery_info) _echo_discovery_info(dev._discovery_info)
echo() echo()
else: else:
discovered[dev.host] = dev.internal_state
ctx.parent.obj = dev ctx.parent.obj = dev
await ctx.parent.invoke(state) await ctx.parent.invoke(state)
discovered[dev.host] = dev.internal_state
echo() echo()
await Discover.discover( discovered_devices = await Discover.discover(
target=target, target=target,
discovery_timeout=discovery_timeout, discovery_timeout=discovery_timeout,
on_discovered=print_discovered, on_discovered=print_discovered,
@ -459,6 +459,9 @@ async def discover(ctx):
credentials=credentials, credentials=credentials,
) )
for device in discovered_devices.values():
await device.protocol.close()
echo(f"Found {len(discovered)} devices") echo(f"Found {len(discovered)} devices")
if unsupported: if unsupported:
echo(f"Found {len(unsupported)} unsupported devices") echo(f"Found {len(unsupported)} unsupported devices")

View File

@ -4,7 +4,7 @@ import binascii
import ipaddress import ipaddress
import logging import logging
import socket import socket
from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast from typing import Awaitable, Callable, Dict, List, Optional, Set, Type, cast
# When support for cpython older than 3.11 is dropped # When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout # async_timeout can be replaced with asyncio.timeout
@ -46,6 +46,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
This is internal class, use :func:`Discover.discover`: instead. This is internal class, use :func:`Discover.discover`: instead.
""" """
DISCOVERY_START_TIMEOUT = 1
discovered_devices: DeviceDict discovered_devices: DeviceDict
def __init__( def __init__(
@ -60,7 +62,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
Callable[[UnsupportedDeviceException], Awaitable[None]] Callable[[UnsupportedDeviceException], Awaitable[None]]
] = None, ] = None,
port: Optional[int] = None, port: Optional[int] = None,
discovered_event: Optional[asyncio.Event] = None,
credentials: Optional[Credentials] = None, credentials: Optional[Credentials] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
) -> None: ) -> None:
@ -79,12 +80,32 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.unsupported_device_exceptions: Dict = {} self.unsupported_device_exceptions: Dict = {}
self.invalid_device_exceptions: Dict = {} self.invalid_device_exceptions: Dict = {}
self.on_unsupported = on_unsupported self.on_unsupported = on_unsupported
self.discovered_event = discovered_event
self.credentials = credentials self.credentials = credentials
self.timeout = timeout self.timeout = timeout
self.discovery_timeout = discovery_timeout self.discovery_timeout = discovery_timeout
self.seen_hosts: Set[str] = set() self.seen_hosts: Set[str] = set()
self.discover_task: Optional[asyncio.Task] = None self.discover_task: Optional[asyncio.Task] = None
self.callback_tasks: List[asyncio.Task] = []
self.target_discovered: bool = False
self._started_event = asyncio.Event()
def _run_callback_task(self, coro):
task = asyncio.create_task(coro)
self.callback_tasks.append(task)
async def wait_for_discovery_to_complete(self):
"""Wait for the discovery task to complete."""
# Give some time for connection_made event to be received
async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT):
await self._started_event.wait()
try:
await self.discover_task
except asyncio.CancelledError:
# if target_discovered then cancel was called internally
if not self.target_discovered:
raise
# Wait for any pending callbacks to complete
await asyncio.gather(*self.callback_tasks)
def connection_made(self, transport) -> None: def connection_made(self, transport) -> None:
"""Set socket options for broadcasting.""" """Set socket options for broadcasting."""
@ -103,6 +124,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
) )
self.discover_task = asyncio.create_task(self.do_discover()) self.discover_task = asyncio.create_task(self.do_discover())
self._started_event.set()
async def do_discover(self) -> None: async def do_discover(self) -> None:
"""Send number of discovery datagrams.""" """Send number of discovery datagrams."""
@ -110,13 +132,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY) _LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = XorEncryption.encrypt(req) encrypted_req = XorEncryption.encrypt(req)
sleep_between_packets = self.discovery_timeout / self.discovery_packets sleep_between_packets = self.discovery_timeout / self.discovery_packets
for i in range(self.discovery_packets): for _ in range(self.discovery_packets):
if self.target in self.seen_hosts: # Stop sending for discover_single if self.target in self.seen_hosts: # Stop sending for discover_single
break break
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
if i < self.discovery_packets - 1: await asyncio.sleep(sleep_between_packets)
await asyncio.sleep(sleep_between_packets)
def datagram_received(self, data, addr) -> None: def datagram_received(self, data, addr) -> None:
"""Handle discovery responses.""" """Handle discovery responses."""
@ -145,7 +166,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
_LOGGER.debug("Unsupported device found at %s << %s", ip, udex) _LOGGER.debug("Unsupported device found at %s << %s", ip, udex)
self.unsupported_device_exceptions[ip] = udex self.unsupported_device_exceptions[ip] = udex
if self.on_unsupported is not None: if self.on_unsupported is not None:
asyncio.ensure_future(self.on_unsupported(udex)) self._run_callback_task(self.on_unsupported(udex))
self._handle_discovered_event() self._handle_discovered_event()
return return
except SmartDeviceException as ex: except SmartDeviceException as ex:
@ -157,16 +178,16 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.discovered_devices[ip] = device self.discovered_devices[ip] = device
if self.on_discovered is not None: if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device)) self._run_callback_task(self.on_discovered(device))
self._handle_discovered_event() self._handle_discovered_event()
def _handle_discovered_event(self): def _handle_discovered_event(self):
"""If discovered_event is available set it and cancel discover_task.""" """If target is in seen_hosts cancel discover_task."""
if self.discovered_event is not None: if self.target in self.seen_hosts:
self.target_discovered = True
if self.discover_task: if self.discover_task:
self.discover_task.cancel() self.discover_task.cancel()
self.discovered_event.set()
def error_received(self, ex): def error_received(self, ex):
"""Handle asyncio.Protocol errors.""" """Handle asyncio.Protocol errors."""
@ -289,7 +310,11 @@ class Discover:
try: try:
_LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout) _LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout)
await asyncio.sleep(discovery_timeout) await protocol.wait_for_discovery_to_complete()
except SmartDeviceException as ex:
for device in protocol.discovered_devices.values():
await device.protocol.close()
raise ex
finally: finally:
transport.close() transport.close()
@ -322,7 +347,6 @@ class Discover:
:return: Object for querying/controlling found device. :return: Object for querying/controlling found device.
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
event = asyncio.Event()
try: try:
ipaddress.ip_address(host) ipaddress.ip_address(host)
@ -352,7 +376,6 @@ class Discover:
lambda: _DiscoverProtocol( lambda: _DiscoverProtocol(
target=ip, target=ip,
port=port, port=port,
discovered_event=event,
credentials=credentials, credentials=credentials,
timeout=timeout, timeout=timeout,
discovery_timeout=discovery_timeout, discovery_timeout=discovery_timeout,
@ -365,13 +388,7 @@ class Discover:
_LOGGER.debug( _LOGGER.debug(
"Waiting a total of %s seconds for responses...", discovery_timeout "Waiting a total of %s seconds for responses...", discovery_timeout
) )
await protocol.wait_for_discovery_to_complete()
async with asyncio_timeout(discovery_timeout):
await event.wait()
except asyncio.TimeoutError as ex:
raise TimeoutException(
f"Timed out getting discovery response for {host}"
) from ex
finally: finally:
transport.close() transport.close()
@ -384,7 +401,7 @@ class Discover:
elif ip in protocol.invalid_device_exceptions: elif ip in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[ip] raise protocol.invalid_device_exceptions[ip]
else: else:
raise SmartDeviceException(f"Unable to get discovery response for {host}") raise TimeoutException(f"Timed out getting discovery response for {host}")
@staticmethod @staticmethod
def _get_device_class(info: dict) -> Type[Device]: def _get_device_class(info: dict) -> Type[Device]:

View File

@ -508,7 +508,7 @@ def discovery_mock(all_fixture_data, mocker):
login_version, login_version,
) )
def mock_discover(self): async def mock_discover(self):
port = ( port = (
dm.port_override dm.port_override
if dm.port_override and dm.discovery_port != 20002 if dm.port_override and dm.discovery_port != 20002
@ -561,7 +561,7 @@ def unsupported_device_info(request, mocker):
discovery_data = request.param discovery_data = request.param
host = "127.0.0.1" host = "127.0.0.1"
def mock_discover(self): async def mock_discover(self):
if discovery_data: if discovery_data:
data = ( data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"

View File

@ -290,7 +290,7 @@ async def test_brightness(dev):
@device_iot @device_iot
async def test_json_output(dev: Device, mocker): async def test_json_output(dev: Device, mocker):
"""Test that the json output produces correct output.""" """Test that the json output produces correct output."""
mocker.patch("kasa.Discover.discover", return_value=[dev]) mocker.patch("kasa.Discover.discover", return_value={"127.0.0.1": dev})
runner = CliRunner() runner = CliRunner()
res = await runner.invoke(cli, ["--json", "state"], obj=dev) res = await runner.invoke(cli, ["--json", "state"], obj=dev)
assert res.exit_code == 0 assert res.exit_code == 0
@ -415,6 +415,26 @@ async def test_discover(discovery_mock, mocker):
assert res.exit_code == 0 assert res.exit_code == 0
async def test_discover_host(discovery_mock, mocker):
"""Test discovery output."""
runner = CliRunner()
res = await runner.invoke(
cli,
[
"--discovery-timeout",
0,
"--host",
"127.0.0.123",
"--username",
"foo",
"--password",
"bar",
"--verbose",
],
)
assert res.exit_code == 0
async def test_discover_unsupported(unsupported_device_info): async def test_discover_unsupported(unsupported_device_info):
"""Test discovery output.""" """Test discovery output."""
runner = CliRunner() runner = CliRunner()

View File

@ -191,7 +191,7 @@ async def test_discover_invalid_info(msg, data, mocker):
"""Make sure that invalid discovery information raises an exception.""" """Make sure that invalid discovery information raises an exception."""
host = "127.0.0.1" host = "127.0.0.1"
def mock_discover(self): async def mock_discover(self):
self.datagram_received( self.datagram_received(
XorEncryption.encrypt(json_dumps(data))[4:], (host, 9999) XorEncryption.encrypt(json_dumps(data))[4:], (host, 9999)
) )
@ -204,7 +204,8 @@ async def test_discover_invalid_info(msg, data, mocker):
async def test_discover_send(mocker): async def test_discover_send(mocker):
"""Test discovery parameters.""" """Test discovery parameters."""
proto = _DiscoverProtocol() discovery_timeout = 0
proto = _DiscoverProtocol(discovery_timeout=discovery_timeout)
assert proto.discovery_packets == 3 assert proto.discovery_packets == 3
assert proto.target_1 == ("255.255.255.255", 9999) assert proto.target_1 == ("255.255.255.255", 9999)
transport = mocker.patch.object(proto, "transport") transport = mocker.patch.object(proto, "transport")
@ -299,22 +300,25 @@ async def test_discover_single_authentication(discovery_mock, mocker):
@new_discovery @new_discovery
async def test_device_update_from_new_discovery_info(discovery_data): async def test_device_update_from_new_discovery_info(discovery_data):
device = IotDevice("127.0.0.7") """Make sure that new discovery devices update from discovery info correctly."""
device_class = Discover._get_device_class(discovery_data)
device = device_class("127.0.0.1")
discover_info = DiscoveryResult(**discovery_data["result"]) discover_info = DiscoveryResult(**discovery_data["result"])
discover_dump = discover_info.get_dict() discover_dump = discover_info.get_dict()
discover_dump["alias"] = "foobar" model, _, _ = discover_dump["device_model"].partition("(")
discover_dump["model"] = discover_dump["device_model"] discover_dump["model"] = model
device.update_from_discover_info(discover_dump) device.update_from_discover_info(discover_dump)
assert device.alias == "foobar"
assert device.mac == discover_dump["mac"].replace("-", ":") assert device.mac == discover_dump["mac"].replace("-", ":")
assert device.model == discover_dump["device_model"] assert device.model == model
with pytest.raises( # TODO implement requires_update for SmartDevice
SmartDeviceException, if isinstance(device, IotDevice):
match=re.escape("You need to await update() to access the data"), with pytest.raises(
): SmartDeviceException,
assert device.supported_modules match=re.escape("You need to await update() to access the data"),
):
assert device.supported_modules
async def test_discover_single_http_client(discovery_mock, mocker): async def test_discover_single_http_client(discovery_mock, mocker):
@ -335,7 +339,7 @@ async def test_discover_single_http_client(discovery_mock, mocker):
async def test_discover_http_client(discovery_mock, mocker): async def test_discover_http_client(discovery_mock, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance.""" """Make sure that discover returns an initialized SmartDevice instance."""
host = "127.0.0.1" host = "127.0.0.1"
discovery_mock.ip = host discovery_mock.ip = host
@ -403,31 +407,24 @@ class FakeDatagramTransport(asyncio.DatagramTransport):
@pytest.mark.parametrize("port", [9999, 20002]) @pytest.mark.parametrize("port", [9999, 20002])
@pytest.mark.parametrize("do_not_reply_count", [0, 1, 2, 3, 4]) @pytest.mark.parametrize("do_not_reply_count", [0, 1, 2, 3, 4])
async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): async def test_do_discover_drop_packets(mocker, port, do_not_reply_count):
"""Make sure that discover_single handles authenticating devices correctly.""" """Make sure that _DiscoverProtocol handles authenticating devices correctly."""
host = "127.0.0.1" host = "127.0.0.1"
discovery_timeout = 1 discovery_timeout = 0
event = asyncio.Event()
dp = _DiscoverProtocol( dp = _DiscoverProtocol(
target=host, target=host,
discovery_timeout=discovery_timeout, discovery_timeout=discovery_timeout,
discovery_packets=5, discovery_packets=5,
discovered_event=event,
) )
ft = FakeDatagramTransport(dp, port, do_not_reply_count) ft = FakeDatagramTransport(dp, port, do_not_reply_count)
dp.connection_made(ft) dp.connection_made(ft)
timed_out = False await dp.wait_for_discovery_to_complete()
try:
async with asyncio_timeout(discovery_timeout):
await event.wait()
except asyncio.TimeoutError:
timed_out = True
await asyncio.sleep(0) await asyncio.sleep(0)
assert ft.send_count == do_not_reply_count + 1 assert ft.send_count == do_not_reply_count + 1
assert dp.discover_task.done() assert dp.discover_task.done()
assert timed_out is False assert dp.discover_task.cancelled()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -436,27 +433,69 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count):
ids=["unknownport", "unsupporteddevice"], ids=["unknownport", "unsupporteddevice"],
) )
async def test_do_discover_invalid(mocker, port, will_timeout): async def test_do_discover_invalid(mocker, port, will_timeout):
"""Make sure that discover_single handles authenticating devices correctly.""" """Make sure that _DiscoverProtocol handles invalid devices correctly."""
host = "127.0.0.1" host = "127.0.0.1"
discovery_timeout = 1 discovery_timeout = 0
event = asyncio.Event()
dp = _DiscoverProtocol( dp = _DiscoverProtocol(
target=host, target=host,
discovery_timeout=discovery_timeout, discovery_timeout=discovery_timeout,
discovery_packets=5, discovery_packets=5,
discovered_event=event,
) )
ft = FakeDatagramTransport(dp, port, 0, unsupported=True) ft = FakeDatagramTransport(dp, port, 0, unsupported=True)
dp.connection_made(ft) dp.connection_made(ft)
timed_out = False await dp.wait_for_discovery_to_complete()
try:
async with asyncio_timeout(15):
await event.wait()
except asyncio.TimeoutError:
timed_out = True
await asyncio.sleep(0) await asyncio.sleep(0)
assert dp.discover_task.done() assert dp.discover_task.done()
assert timed_out is will_timeout assert dp.discover_task.cancelled() != will_timeout
async def test_discover_propogates_task_exceptions(discovery_mock):
"""Make sure that discover propogates callback exceptions."""
discovery_timeout = 0
async def on_discovered(dev):
raise SmartDeviceException("Dummy exception")
with pytest.raises(SmartDeviceException):
await Discover.discover(
discovery_timeout=discovery_timeout, on_discovered=on_discovered
)
async def test_do_discover_no_connection(mocker):
"""Make sure that if the datagram connection doesnt start a TimeoutError is raised."""
host = "127.0.0.1"
discovery_timeout = 0
mocker.patch.object(_DiscoverProtocol, "DISCOVERY_START_TIMEOUT", 0)
dp = _DiscoverProtocol(
target=host,
discovery_timeout=discovery_timeout,
discovery_packets=5,
)
# Normally tests would simulate connection as per below
# ft = FakeDatagramTransport(dp, port, 0, unsupported=True)
# dp.connection_made(ft)
with pytest.raises(asyncio.TimeoutError):
await dp.wait_for_discovery_to_complete()
async def test_do_discover_external_cancel(mocker):
"""Make sure that a cancel other than when target is discovered propogates."""
host = "127.0.0.1"
discovery_timeout = 1
dp = _DiscoverProtocol(
target=host,
discovery_timeout=discovery_timeout,
discovery_packets=1,
)
# Normally tests would simulate connection as per below
ft = FakeDatagramTransport(dp, 9999, 1, unsupported=True)
dp.connection_made(ft)
with pytest.raises(asyncio.TimeoutError):
async with asyncio_timeout(0):
await dp.wait_for_discovery_to_complete()