From ee487ad837f51cea7803009dede3e91ffb9cf54f Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Mon, 22 Jan 2024 17:25:23 +0000 Subject: [PATCH] Sleep between discovery packets (#656) * Sleep between discovery packets * Add tests --- kasa/discover.py | 39 +++++++---- kasa/tests/test_discovery.py | 121 ++++++++++++++++++++++++++++++++++- 2 files changed, 147 insertions(+), 13 deletions(-) diff --git a/kasa/discover.py b/kasa/discover.py index fca578a3..8b58d4bd 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -49,6 +49,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): on_discovered: Optional[OnDiscoveredCallable] = None, target: str = "255.255.255.255", discovery_packets: int = 3, + discovery_timeout: int = 5, interface: Optional[str] = None, on_unsupported: Optional[ Callable[[UnsupportedDeviceException], Awaitable[None]] @@ -65,7 +66,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.port = port self.discovery_port = port or Discover.DISCOVERY_PORT - self.target = (target, self.discovery_port) + self.target = target + self.target_1 = (target, self.discovery_port) self.target_2 = (target, Discover.DISCOVERY_PORT_2) self.discovered_devices = {} @@ -75,7 +77,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.discovered_event = discovered_event self.credentials = credentials self.timeout = timeout + self.discovery_timeout = discovery_timeout self.seen_hosts: Set[str] = set() + self.discover_task: Optional[asyncio.Task] = None def connection_made(self, transport) -> None: """Set socket options for broadcasting.""" @@ -93,16 +97,21 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): socket.SOL_SOCKET, socket.SO_BINDTODEVICE, self.interface.encode() ) - self.do_discover() + self.discover_task = asyncio.create_task(self.do_discover()) - def do_discover(self) -> None: + async def do_discover(self) -> None: """Send number of discovery datagrams.""" req = json_dumps(Discover.DISCOVERY_QUERY) _LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY) encrypted_req = TPLinkSmartHomeProtocol.encrypt(req) - for _i in range(self.discovery_packets): - self.transport.sendto(encrypted_req[4:], self.target) # type: ignore + sleep_between_packets = self.discovery_timeout / self.discovery_packets + for i in range(self.discovery_packets): + if self.target in self.seen_hosts: # Stop sending for discover_single + break + self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore + if i < self.discovery_packets - 1: + await asyncio.sleep(sleep_between_packets) def datagram_received(self, data, addr) -> None: """Handle discovery responses.""" @@ -132,14 +141,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.unsupported_device_exceptions[ip] = udex if self.on_unsupported is not None: asyncio.ensure_future(self.on_unsupported(udex)) - if self.discovered_event is not None: - self.discovered_event.set() + self._handle_discovered_event() return except SmartDeviceException as ex: _LOGGER.debug(f"[DISCOVERY] Unable to find device type for {ip}: {ex}") self.invalid_device_exceptions[ip] = ex - if self.discovered_event is not None: - self.discovered_event.set() + self._handle_discovered_event() return self.discovered_devices[ip] = device @@ -147,15 +154,23 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): if self.on_discovered is not None: asyncio.ensure_future(self.on_discovered(device)) + self._handle_discovered_event() + + def _handle_discovered_event(self): + """If discovered_event is available set it and cancel discover_task.""" if self.discovered_event is not None: + if self.discover_task: + self.discover_task.cancel() self.discovered_event.set() def error_received(self, ex): """Handle asyncio.Protocol errors.""" _LOGGER.error("Got error: %s", ex) - def connection_lost(self, ex): - """NOP implementation of connection lost.""" + def connection_lost(self, ex): # pragma: no cover + """Cancel the discover task if running.""" + if self.discover_task: + self.discover_task.cancel() class Discover: @@ -260,6 +275,7 @@ class Discover: on_unsupported=on_unsupported, credentials=credentials, timeout=timeout, + discovery_timeout=discovery_timeout, port=port, ), local_addr=("0.0.0.0", 0), # noqa: S104 @@ -334,6 +350,7 @@ class Discover: discovered_event=event, credentials=credentials, timeout=timeout, + discovery_timeout=discovery_timeout, ), local_addr=("0.0.0.0", 0), # noqa: S104 ) diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 071a6503..2916e60a 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -1,10 +1,13 @@ # type: ignore +import asyncio import logging import re import socket +from unittest.mock import MagicMock import aiohttp import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 +from async_timeout import timeout as asyncio_timeout from kasa import ( Credentials, @@ -12,6 +15,7 @@ from kasa import ( Discover, SmartDevice, SmartDeviceException, + TPLinkSmartHomeProtocol, protocol, ) from kasa.deviceconfig import ( @@ -198,9 +202,9 @@ async def test_discover_send(mocker): """Test discovery parameters.""" proto = _DiscoverProtocol() assert proto.discovery_packets == 3 - assert proto.target == ("255.255.255.255", 9999) + assert proto.target_1 == ("255.255.255.255", 9999) transport = mocker.patch.object(proto, "transport") - proto.do_discover() + await proto.do_discover() assert transport.sendto.call_count == proto.discovery_packets * 2 @@ -341,3 +345,116 @@ async def test_discover_http_client(discovery_mock, mocker): assert x.protocol._transport._http_client.client != http_client x.config.http_client = http_client assert x.protocol._transport._http_client.client == http_client + + +LEGACY_DISCOVER_DATA = { + "system": { + "get_sysinfo": { + "alias": "#MASKED_NAME#", + "dev_name": "Smart Wi-Fi Plug", + "deviceId": "0000000000000000000000000000000000000000", + "err_code": 0, + "hwId": "00000000000000000000000000000000", + "hw_ver": "0.0", + "mac": "00:00:00:00:00:00", + "mic_type": "IOT.SMARTPLUGSWITCH", + "model": "HS100(UK)", + "sw_ver": "1.1.0 Build 201016 Rel.175121", + "updating": 0, + } + } +} + + +class FakeDatagramTransport(asyncio.DatagramTransport): + GHOST_PORT = 8888 + + def __init__(self, dp, port, do_not_reply_count, unsupported=False): + self.dp = dp + self.port = port + self.do_not_reply_count = do_not_reply_count + self.send_count = 0 + if port == 9999: + self.datagram = TPLinkSmartHomeProtocol.encrypt( + json_dumps(LEGACY_DISCOVER_DATA) + )[4:] + elif port == 20002: + discovery_data = UNSUPPORTED if unsupported else AUTHENTICATION_DATA_KLAP + self.datagram = ( + b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" + + json_dumps(discovery_data).encode() + ) + else: + self.datagram = {"foo": "bar"} + + def get_extra_info(self, name, default=None): + return MagicMock() + + def sendto(self, data, addr=None): + ip, port = addr + if port == self.port or self.port == self.GHOST_PORT: + self.send_count += 1 + if self.send_count > self.do_not_reply_count: + self.dp.datagram_received(self.datagram, (ip, self.port)) + + +@pytest.mark.parametrize("port", [9999, 20002]) +@pytest.mark.parametrize("do_not_reply_count", [0, 1, 2, 3, 4]) +async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): + """Make sure that discover_single handles authenticating devices correctly.""" + host = "127.0.0.1" + discovery_timeout = 1 + + event = asyncio.Event() + dp = _DiscoverProtocol( + target=host, + discovery_timeout=discovery_timeout, + discovery_packets=5, + discovered_event=event, + ) + ft = FakeDatagramTransport(dp, port, do_not_reply_count) + dp.connection_made(ft) + + timed_out = False + try: + async with asyncio_timeout(discovery_timeout): + await event.wait() + except asyncio.TimeoutError: + timed_out = True + + await asyncio.sleep(0) + assert ft.send_count == do_not_reply_count + 1 + assert dp.discover_task.done() + assert timed_out is False + + +@pytest.mark.parametrize( + "port, will_timeout", + [(FakeDatagramTransport.GHOST_PORT, True), (20002, False)], + ids=["unknownport", "unsupporteddevice"], +) +async def test_do_discover_invalid(mocker, port, will_timeout): + """Make sure that discover_single handles authenticating devices correctly.""" + host = "127.0.0.1" + discovery_timeout = 1 + + event = asyncio.Event() + dp = _DiscoverProtocol( + target=host, + discovery_timeout=discovery_timeout, + discovery_packets=5, + discovered_event=event, + ) + ft = FakeDatagramTransport(dp, port, 0, unsupported=True) + dp.connection_made(ft) + + timed_out = False + try: + async with asyncio_timeout(15): + await event.wait() + except asyncio.TimeoutError: + timed_out = True + + await asyncio.sleep(0) + assert dp.discover_task.done() + assert timed_out is will_timeout