mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Fix discovery cli to print devices not printed during discovery timeout (#670)
* Fix discovery cli to print devices not printed during discovery * Fix tests * Fix print exceptions not being propagated * Fix tests * Reduce test discover_send time * Simplify wait logic * Add tests * Remove sleep loop and make auth failed a list
This commit is contained in:
parent
0d119e63d0
commit
215b8d4e4f
@ -444,12 +444,12 @@ async def discover(ctx):
|
||||
_echo_discovery_info(dev._discovery_info)
|
||||
echo()
|
||||
else:
|
||||
discovered[dev.host] = dev.internal_state
|
||||
ctx.parent.obj = dev
|
||||
await ctx.parent.invoke(state)
|
||||
discovered[dev.host] = dev.internal_state
|
||||
echo()
|
||||
|
||||
await Discover.discover(
|
||||
discovered_devices = await Discover.discover(
|
||||
target=target,
|
||||
discovery_timeout=discovery_timeout,
|
||||
on_discovered=print_discovered,
|
||||
@ -459,6 +459,9 @@ async def discover(ctx):
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
for device in discovered_devices.values():
|
||||
await device.protocol.close()
|
||||
|
||||
echo(f"Found {len(discovered)} devices")
|
||||
if unsupported:
|
||||
echo(f"Found {len(unsupported)} unsupported devices")
|
||||
|
@ -4,7 +4,7 @@ import binascii
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Set, Type, cast
|
||||
|
||||
# When support for cpython older than 3.11 is dropped
|
||||
# async_timeout can be replaced with asyncio.timeout
|
||||
@ -46,6 +46,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
This is internal class, use :func:`Discover.discover`: instead.
|
||||
"""
|
||||
|
||||
DISCOVERY_START_TIMEOUT = 1
|
||||
|
||||
discovered_devices: DeviceDict
|
||||
|
||||
def __init__(
|
||||
@ -60,7 +62,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
Callable[[UnsupportedDeviceException], Awaitable[None]]
|
||||
] = None,
|
||||
port: Optional[int] = None,
|
||||
discovered_event: Optional[asyncio.Event] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
@ -79,12 +80,32 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
self.unsupported_device_exceptions: Dict = {}
|
||||
self.invalid_device_exceptions: Dict = {}
|
||||
self.on_unsupported = on_unsupported
|
||||
self.discovered_event = discovered_event
|
||||
self.credentials = credentials
|
||||
self.timeout = timeout
|
||||
self.discovery_timeout = discovery_timeout
|
||||
self.seen_hosts: Set[str] = set()
|
||||
self.discover_task: Optional[asyncio.Task] = None
|
||||
self.callback_tasks: List[asyncio.Task] = []
|
||||
self.target_discovered: bool = False
|
||||
self._started_event = asyncio.Event()
|
||||
|
||||
def _run_callback_task(self, coro):
|
||||
task = asyncio.create_task(coro)
|
||||
self.callback_tasks.append(task)
|
||||
|
||||
async def wait_for_discovery_to_complete(self):
|
||||
"""Wait for the discovery task to complete."""
|
||||
# Give some time for connection_made event to be received
|
||||
async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT):
|
||||
await self._started_event.wait()
|
||||
try:
|
||||
await self.discover_task
|
||||
except asyncio.CancelledError:
|
||||
# if target_discovered then cancel was called internally
|
||||
if not self.target_discovered:
|
||||
raise
|
||||
# Wait for any pending callbacks to complete
|
||||
await asyncio.gather(*self.callback_tasks)
|
||||
|
||||
def connection_made(self, transport) -> None:
|
||||
"""Set socket options for broadcasting."""
|
||||
@ -103,6 +124,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
)
|
||||
|
||||
self.discover_task = asyncio.create_task(self.do_discover())
|
||||
self._started_event.set()
|
||||
|
||||
async def do_discover(self) -> None:
|
||||
"""Send number of discovery datagrams."""
|
||||
@ -110,12 +132,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
||||
encrypted_req = XorEncryption.encrypt(req)
|
||||
sleep_between_packets = self.discovery_timeout / self.discovery_packets
|
||||
for i in range(self.discovery_packets):
|
||||
for _ in range(self.discovery_packets):
|
||||
if self.target in self.seen_hosts: # Stop sending for discover_single
|
||||
break
|
||||
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
|
||||
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
|
||||
if i < self.discovery_packets - 1:
|
||||
await asyncio.sleep(sleep_between_packets)
|
||||
|
||||
def datagram_received(self, data, addr) -> None:
|
||||
@ -145,7 +166,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
_LOGGER.debug("Unsupported device found at %s << %s", ip, udex)
|
||||
self.unsupported_device_exceptions[ip] = udex
|
||||
if self.on_unsupported is not None:
|
||||
asyncio.ensure_future(self.on_unsupported(udex))
|
||||
self._run_callback_task(self.on_unsupported(udex))
|
||||
self._handle_discovered_event()
|
||||
return
|
||||
except SmartDeviceException as ex:
|
||||
@ -157,16 +178,16 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
self.discovered_devices[ip] = device
|
||||
|
||||
if self.on_discovered is not None:
|
||||
asyncio.ensure_future(self.on_discovered(device))
|
||||
self._run_callback_task(self.on_discovered(device))
|
||||
|
||||
self._handle_discovered_event()
|
||||
|
||||
def _handle_discovered_event(self):
|
||||
"""If discovered_event is available set it and cancel discover_task."""
|
||||
if self.discovered_event is not None:
|
||||
"""If target is in seen_hosts cancel discover_task."""
|
||||
if self.target in self.seen_hosts:
|
||||
self.target_discovered = True
|
||||
if self.discover_task:
|
||||
self.discover_task.cancel()
|
||||
self.discovered_event.set()
|
||||
|
||||
def error_received(self, ex):
|
||||
"""Handle asyncio.Protocol errors."""
|
||||
@ -289,7 +310,11 @@ class Discover:
|
||||
|
||||
try:
|
||||
_LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout)
|
||||
await asyncio.sleep(discovery_timeout)
|
||||
await protocol.wait_for_discovery_to_complete()
|
||||
except SmartDeviceException as ex:
|
||||
for device in protocol.discovered_devices.values():
|
||||
await device.protocol.close()
|
||||
raise ex
|
||||
finally:
|
||||
transport.close()
|
||||
|
||||
@ -322,7 +347,6 @@ class Discover:
|
||||
:return: Object for querying/controlling found device.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
event = asyncio.Event()
|
||||
|
||||
try:
|
||||
ipaddress.ip_address(host)
|
||||
@ -352,7 +376,6 @@ class Discover:
|
||||
lambda: _DiscoverProtocol(
|
||||
target=ip,
|
||||
port=port,
|
||||
discovered_event=event,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
discovery_timeout=discovery_timeout,
|
||||
@ -365,13 +388,7 @@ class Discover:
|
||||
_LOGGER.debug(
|
||||
"Waiting a total of %s seconds for responses...", discovery_timeout
|
||||
)
|
||||
|
||||
async with asyncio_timeout(discovery_timeout):
|
||||
await event.wait()
|
||||
except asyncio.TimeoutError as ex:
|
||||
raise TimeoutException(
|
||||
f"Timed out getting discovery response for {host}"
|
||||
) from ex
|
||||
await protocol.wait_for_discovery_to_complete()
|
||||
finally:
|
||||
transport.close()
|
||||
|
||||
@ -384,7 +401,7 @@ class Discover:
|
||||
elif ip in protocol.invalid_device_exceptions:
|
||||
raise protocol.invalid_device_exceptions[ip]
|
||||
else:
|
||||
raise SmartDeviceException(f"Unable to get discovery response for {host}")
|
||||
raise TimeoutException(f"Timed out getting discovery response for {host}")
|
||||
|
||||
@staticmethod
|
||||
def _get_device_class(info: dict) -> Type[Device]:
|
||||
|
@ -508,7 +508,7 @@ def discovery_mock(all_fixture_data, mocker):
|
||||
login_version,
|
||||
)
|
||||
|
||||
def mock_discover(self):
|
||||
async def mock_discover(self):
|
||||
port = (
|
||||
dm.port_override
|
||||
if dm.port_override and dm.discovery_port != 20002
|
||||
@ -561,7 +561,7 @@ def unsupported_device_info(request, mocker):
|
||||
discovery_data = request.param
|
||||
host = "127.0.0.1"
|
||||
|
||||
def mock_discover(self):
|
||||
async def mock_discover(self):
|
||||
if discovery_data:
|
||||
data = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
|
@ -290,7 +290,7 @@ async def test_brightness(dev):
|
||||
@device_iot
|
||||
async def test_json_output(dev: Device, mocker):
|
||||
"""Test that the json output produces correct output."""
|
||||
mocker.patch("kasa.Discover.discover", return_value=[dev])
|
||||
mocker.patch("kasa.Discover.discover", return_value={"127.0.0.1": dev})
|
||||
runner = CliRunner()
|
||||
res = await runner.invoke(cli, ["--json", "state"], obj=dev)
|
||||
assert res.exit_code == 0
|
||||
@ -415,6 +415,26 @@ async def test_discover(discovery_mock, mocker):
|
||||
assert res.exit_code == 0
|
||||
|
||||
|
||||
async def test_discover_host(discovery_mock, mocker):
|
||||
"""Test discovery output."""
|
||||
runner = CliRunner()
|
||||
res = await runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"--discovery-timeout",
|
||||
0,
|
||||
"--host",
|
||||
"127.0.0.123",
|
||||
"--username",
|
||||
"foo",
|
||||
"--password",
|
||||
"bar",
|
||||
"--verbose",
|
||||
],
|
||||
)
|
||||
assert res.exit_code == 0
|
||||
|
||||
|
||||
async def test_discover_unsupported(unsupported_device_info):
|
||||
"""Test discovery output."""
|
||||
runner = CliRunner()
|
||||
|
@ -191,7 +191,7 @@ async def test_discover_invalid_info(msg, data, mocker):
|
||||
"""Make sure that invalid discovery information raises an exception."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
def mock_discover(self):
|
||||
async def mock_discover(self):
|
||||
self.datagram_received(
|
||||
XorEncryption.encrypt(json_dumps(data))[4:], (host, 9999)
|
||||
)
|
||||
@ -204,7 +204,8 @@ async def test_discover_invalid_info(msg, data, mocker):
|
||||
|
||||
async def test_discover_send(mocker):
|
||||
"""Test discovery parameters."""
|
||||
proto = _DiscoverProtocol()
|
||||
discovery_timeout = 0
|
||||
proto = _DiscoverProtocol(discovery_timeout=discovery_timeout)
|
||||
assert proto.discovery_packets == 3
|
||||
assert proto.target_1 == ("255.255.255.255", 9999)
|
||||
transport = mocker.patch.object(proto, "transport")
|
||||
@ -299,17 +300,20 @@ async def test_discover_single_authentication(discovery_mock, mocker):
|
||||
|
||||
@new_discovery
|
||||
async def test_device_update_from_new_discovery_info(discovery_data):
|
||||
device = IotDevice("127.0.0.7")
|
||||
"""Make sure that new discovery devices update from discovery info correctly."""
|
||||
device_class = Discover._get_device_class(discovery_data)
|
||||
device = device_class("127.0.0.1")
|
||||
discover_info = DiscoveryResult(**discovery_data["result"])
|
||||
discover_dump = discover_info.get_dict()
|
||||
discover_dump["alias"] = "foobar"
|
||||
discover_dump["model"] = discover_dump["device_model"]
|
||||
model, _, _ = discover_dump["device_model"].partition("(")
|
||||
discover_dump["model"] = model
|
||||
device.update_from_discover_info(discover_dump)
|
||||
|
||||
assert device.alias == "foobar"
|
||||
assert device.mac == discover_dump["mac"].replace("-", ":")
|
||||
assert device.model == discover_dump["device_model"]
|
||||
assert device.model == model
|
||||
|
||||
# TODO implement requires_update for SmartDevice
|
||||
if isinstance(device, IotDevice):
|
||||
with pytest.raises(
|
||||
SmartDeviceException,
|
||||
match=re.escape("You need to await update() to access the data"),
|
||||
@ -335,7 +339,7 @@ async def test_discover_single_http_client(discovery_mock, mocker):
|
||||
|
||||
|
||||
async def test_discover_http_client(discovery_mock, mocker):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
"""Make sure that discover returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
discovery_mock.ip = host
|
||||
|
||||
@ -403,31 +407,24 @@ class FakeDatagramTransport(asyncio.DatagramTransport):
|
||||
@pytest.mark.parametrize("port", [9999, 20002])
|
||||
@pytest.mark.parametrize("do_not_reply_count", [0, 1, 2, 3, 4])
|
||||
async def test_do_discover_drop_packets(mocker, port, do_not_reply_count):
|
||||
"""Make sure that discover_single handles authenticating devices correctly."""
|
||||
"""Make sure that _DiscoverProtocol handles authenticating devices correctly."""
|
||||
host = "127.0.0.1"
|
||||
discovery_timeout = 1
|
||||
discovery_timeout = 0
|
||||
|
||||
event = asyncio.Event()
|
||||
dp = _DiscoverProtocol(
|
||||
target=host,
|
||||
discovery_timeout=discovery_timeout,
|
||||
discovery_packets=5,
|
||||
discovered_event=event,
|
||||
)
|
||||
ft = FakeDatagramTransport(dp, port, do_not_reply_count)
|
||||
dp.connection_made(ft)
|
||||
|
||||
timed_out = False
|
||||
try:
|
||||
async with asyncio_timeout(discovery_timeout):
|
||||
await event.wait()
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
await dp.wait_for_discovery_to_complete()
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert ft.send_count == do_not_reply_count + 1
|
||||
assert dp.discover_task.done()
|
||||
assert timed_out is False
|
||||
assert dp.discover_task.cancelled()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -436,27 +433,69 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count):
|
||||
ids=["unknownport", "unsupporteddevice"],
|
||||
)
|
||||
async def test_do_discover_invalid(mocker, port, will_timeout):
|
||||
"""Make sure that discover_single handles authenticating devices correctly."""
|
||||
"""Make sure that _DiscoverProtocol handles invalid devices correctly."""
|
||||
host = "127.0.0.1"
|
||||
discovery_timeout = 1
|
||||
discovery_timeout = 0
|
||||
|
||||
event = asyncio.Event()
|
||||
dp = _DiscoverProtocol(
|
||||
target=host,
|
||||
discovery_timeout=discovery_timeout,
|
||||
discovery_packets=5,
|
||||
discovered_event=event,
|
||||
)
|
||||
ft = FakeDatagramTransport(dp, port, 0, unsupported=True)
|
||||
dp.connection_made(ft)
|
||||
|
||||
timed_out = False
|
||||
try:
|
||||
async with asyncio_timeout(15):
|
||||
await event.wait()
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
|
||||
await dp.wait_for_discovery_to_complete()
|
||||
await asyncio.sleep(0)
|
||||
assert dp.discover_task.done()
|
||||
assert timed_out is will_timeout
|
||||
assert dp.discover_task.cancelled() != will_timeout
|
||||
|
||||
|
||||
async def test_discover_propogates_task_exceptions(discovery_mock):
|
||||
"""Make sure that discover propogates callback exceptions."""
|
||||
discovery_timeout = 0
|
||||
|
||||
async def on_discovered(dev):
|
||||
raise SmartDeviceException("Dummy exception")
|
||||
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await Discover.discover(
|
||||
discovery_timeout=discovery_timeout, on_discovered=on_discovered
|
||||
)
|
||||
|
||||
|
||||
async def test_do_discover_no_connection(mocker):
|
||||
"""Make sure that if the datagram connection doesnt start a TimeoutError is raised."""
|
||||
host = "127.0.0.1"
|
||||
discovery_timeout = 0
|
||||
mocker.patch.object(_DiscoverProtocol, "DISCOVERY_START_TIMEOUT", 0)
|
||||
dp = _DiscoverProtocol(
|
||||
target=host,
|
||||
discovery_timeout=discovery_timeout,
|
||||
discovery_packets=5,
|
||||
)
|
||||
# Normally tests would simulate connection as per below
|
||||
# ft = FakeDatagramTransport(dp, port, 0, unsupported=True)
|
||||
# dp.connection_made(ft)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await dp.wait_for_discovery_to_complete()
|
||||
|
||||
|
||||
async def test_do_discover_external_cancel(mocker):
|
||||
"""Make sure that a cancel other than when target is discovered propogates."""
|
||||
host = "127.0.0.1"
|
||||
discovery_timeout = 1
|
||||
|
||||
dp = _DiscoverProtocol(
|
||||
target=host,
|
||||
discovery_timeout=discovery_timeout,
|
||||
discovery_packets=1,
|
||||
)
|
||||
# Normally tests would simulate connection as per below
|
||||
ft = FakeDatagramTransport(dp, 9999, 1, unsupported=True)
|
||||
dp.connection_made(ft)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
async with asyncio_timeout(0):
|
||||
await dp.wait_for_discovery_to_complete()
|
||||
|
Loading…
Reference in New Issue
Block a user