Sleep between discovery packets (#656)

* Sleep between discovery packets

* Add tests
This commit is contained in:
Steven B
2024-01-22 17:25:23 +00:00
committed by GitHub
parent 6b0a72d5a7
commit ee487ad837
2 changed files with 147 additions and 13 deletions

View File

@@ -49,6 +49,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
on_discovered: Optional[OnDiscoveredCallable] = None,
target: str = "255.255.255.255",
discovery_packets: int = 3,
discovery_timeout: int = 5,
interface: Optional[str] = None,
on_unsupported: Optional[
Callable[[UnsupportedDeviceException], Awaitable[None]]
@@ -65,7 +66,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.port = port
self.discovery_port = port or Discover.DISCOVERY_PORT
self.target = (target, self.discovery_port)
self.target = target
self.target_1 = (target, self.discovery_port)
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
self.discovered_devices = {}
@@ -75,7 +77,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.discovered_event = discovered_event
self.credentials = credentials
self.timeout = timeout
self.discovery_timeout = discovery_timeout
self.seen_hosts: Set[str] = set()
self.discover_task: Optional[asyncio.Task] = None
def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
@@ -93,16 +97,21 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
socket.SOL_SOCKET, socket.SO_BINDTODEVICE, self.interface.encode()
)
self.do_discover()
self.discover_task = asyncio.create_task(self.do_discover())
def do_discover(self) -> None:
async def do_discover(self) -> None:
"""Send number of discovery datagrams."""
req = json_dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for _i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
sleep_between_packets = self.discovery_timeout / self.discovery_packets
for i in range(self.discovery_packets):
if self.target in self.seen_hosts: # Stop sending for discover_single
break
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
if i < self.discovery_packets - 1:
await asyncio.sleep(sleep_between_packets)
def datagram_received(self, data, addr) -> None:
"""Handle discovery responses."""
@@ -132,14 +141,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.unsupported_device_exceptions[ip] = udex
if self.on_unsupported is not None:
asyncio.ensure_future(self.on_unsupported(udex))
if self.discovered_event is not None:
self.discovered_event.set()
self._handle_discovered_event()
return
except SmartDeviceException as ex:
_LOGGER.debug(f"[DISCOVERY] Unable to find device type for {ip}: {ex}")
self.invalid_device_exceptions[ip] = ex
if self.discovered_event is not None:
self.discovered_event.set()
self._handle_discovered_event()
return
self.discovered_devices[ip] = device
@@ -147,15 +154,23 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
self._handle_discovered_event()
def _handle_discovered_event(self):
"""If discovered_event is available set it and cancel discover_task."""
if self.discovered_event is not None:
if self.discover_task:
self.discover_task.cancel()
self.discovered_event.set()
def error_received(self, ex):
"""Handle asyncio.Protocol errors."""
_LOGGER.error("Got error: %s", ex)
def connection_lost(self, ex):
"""NOP implementation of connection lost."""
def connection_lost(self, ex): # pragma: no cover
"""Cancel the discover task if running."""
if self.discover_task:
self.discover_task.cancel()
class Discover:
@@ -260,6 +275,7 @@ class Discover:
on_unsupported=on_unsupported,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
port=port,
),
local_addr=("0.0.0.0", 0), # noqa: S104
@@ -334,6 +350,7 @@ class Discover:
discovered_event=event,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
),
local_addr=("0.0.0.0", 0), # noqa: S104
)