mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-08-09 20:24:02 +00:00
Sleep between discovery packets (#656)
* Sleep between discovery packets * Add tests
This commit is contained in:
@@ -49,6 +49,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
on_discovered: Optional[OnDiscoveredCallable] = None,
|
||||
target: str = "255.255.255.255",
|
||||
discovery_packets: int = 3,
|
||||
discovery_timeout: int = 5,
|
||||
interface: Optional[str] = None,
|
||||
on_unsupported: Optional[
|
||||
Callable[[UnsupportedDeviceException], Awaitable[None]]
|
||||
@@ -65,7 +66,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
|
||||
self.port = port
|
||||
self.discovery_port = port or Discover.DISCOVERY_PORT
|
||||
self.target = (target, self.discovery_port)
|
||||
self.target = target
|
||||
self.target_1 = (target, self.discovery_port)
|
||||
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
|
||||
|
||||
self.discovered_devices = {}
|
||||
@@ -75,7 +77,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
self.discovered_event = discovered_event
|
||||
self.credentials = credentials
|
||||
self.timeout = timeout
|
||||
self.discovery_timeout = discovery_timeout
|
||||
self.seen_hosts: Set[str] = set()
|
||||
self.discover_task: Optional[asyncio.Task] = None
|
||||
|
||||
def connection_made(self, transport) -> None:
|
||||
"""Set socket options for broadcasting."""
|
||||
@@ -93,16 +97,21 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
socket.SOL_SOCKET, socket.SO_BINDTODEVICE, self.interface.encode()
|
||||
)
|
||||
|
||||
self.do_discover()
|
||||
self.discover_task = asyncio.create_task(self.do_discover())
|
||||
|
||||
def do_discover(self) -> None:
|
||||
async def do_discover(self) -> None:
|
||||
"""Send number of discovery datagrams."""
|
||||
req = json_dumps(Discover.DISCOVERY_QUERY)
|
||||
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
||||
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
|
||||
for _i in range(self.discovery_packets):
|
||||
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
|
||||
sleep_between_packets = self.discovery_timeout / self.discovery_packets
|
||||
for i in range(self.discovery_packets):
|
||||
if self.target in self.seen_hosts: # Stop sending for discover_single
|
||||
break
|
||||
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
|
||||
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
|
||||
if i < self.discovery_packets - 1:
|
||||
await asyncio.sleep(sleep_between_packets)
|
||||
|
||||
def datagram_received(self, data, addr) -> None:
|
||||
"""Handle discovery responses."""
|
||||
@@ -132,14 +141,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
self.unsupported_device_exceptions[ip] = udex
|
||||
if self.on_unsupported is not None:
|
||||
asyncio.ensure_future(self.on_unsupported(udex))
|
||||
if self.discovered_event is not None:
|
||||
self.discovered_event.set()
|
||||
self._handle_discovered_event()
|
||||
return
|
||||
except SmartDeviceException as ex:
|
||||
_LOGGER.debug(f"[DISCOVERY] Unable to find device type for {ip}: {ex}")
|
||||
self.invalid_device_exceptions[ip] = ex
|
||||
if self.discovered_event is not None:
|
||||
self.discovered_event.set()
|
||||
self._handle_discovered_event()
|
||||
return
|
||||
|
||||
self.discovered_devices[ip] = device
|
||||
@@ -147,15 +154,23 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
if self.on_discovered is not None:
|
||||
asyncio.ensure_future(self.on_discovered(device))
|
||||
|
||||
self._handle_discovered_event()
|
||||
|
||||
def _handle_discovered_event(self):
|
||||
"""If discovered_event is available set it and cancel discover_task."""
|
||||
if self.discovered_event is not None:
|
||||
if self.discover_task:
|
||||
self.discover_task.cancel()
|
||||
self.discovered_event.set()
|
||||
|
||||
def error_received(self, ex):
|
||||
"""Handle asyncio.Protocol errors."""
|
||||
_LOGGER.error("Got error: %s", ex)
|
||||
|
||||
def connection_lost(self, ex):
|
||||
"""NOP implementation of connection lost."""
|
||||
def connection_lost(self, ex): # pragma: no cover
|
||||
"""Cancel the discover task if running."""
|
||||
if self.discover_task:
|
||||
self.discover_task.cancel()
|
||||
|
||||
|
||||
class Discover:
|
||||
@@ -260,6 +275,7 @@ class Discover:
|
||||
on_unsupported=on_unsupported,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
discovery_timeout=discovery_timeout,
|
||||
port=port,
|
||||
),
|
||||
local_addr=("0.0.0.0", 0), # noqa: S104
|
||||
@@ -334,6 +350,7 @@ class Discover:
|
||||
discovered_event=event,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
discovery_timeout=discovery_timeout,
|
||||
),
|
||||
local_addr=("0.0.0.0", 0), # noqa: S104
|
||||
)
|
||||
|
Reference in New Issue
Block a user