mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-24 05:37:59 +00:00
Sleep between discovery packets (#656)
* Sleep between discovery packets * Add tests
This commit is contained in:
parent
6b0a72d5a7
commit
ee487ad837
@ -49,6 +49,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
on_discovered: Optional[OnDiscoveredCallable] = None,
|
on_discovered: Optional[OnDiscoveredCallable] = None,
|
||||||
target: str = "255.255.255.255",
|
target: str = "255.255.255.255",
|
||||||
discovery_packets: int = 3,
|
discovery_packets: int = 3,
|
||||||
|
discovery_timeout: int = 5,
|
||||||
interface: Optional[str] = None,
|
interface: Optional[str] = None,
|
||||||
on_unsupported: Optional[
|
on_unsupported: Optional[
|
||||||
Callable[[UnsupportedDeviceException], Awaitable[None]]
|
Callable[[UnsupportedDeviceException], Awaitable[None]]
|
||||||
@ -65,7 +66,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
|
|
||||||
self.port = port
|
self.port = port
|
||||||
self.discovery_port = port or Discover.DISCOVERY_PORT
|
self.discovery_port = port or Discover.DISCOVERY_PORT
|
||||||
self.target = (target, self.discovery_port)
|
self.target = target
|
||||||
|
self.target_1 = (target, self.discovery_port)
|
||||||
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
|
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
|
||||||
|
|
||||||
self.discovered_devices = {}
|
self.discovered_devices = {}
|
||||||
@ -75,7 +77,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
self.discovered_event = discovered_event
|
self.discovered_event = discovered_event
|
||||||
self.credentials = credentials
|
self.credentials = credentials
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
self.discovery_timeout = discovery_timeout
|
||||||
self.seen_hosts: Set[str] = set()
|
self.seen_hosts: Set[str] = set()
|
||||||
|
self.discover_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
def connection_made(self, transport) -> None:
|
def connection_made(self, transport) -> None:
|
||||||
"""Set socket options for broadcasting."""
|
"""Set socket options for broadcasting."""
|
||||||
@ -93,16 +97,21 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
socket.SOL_SOCKET, socket.SO_BINDTODEVICE, self.interface.encode()
|
socket.SOL_SOCKET, socket.SO_BINDTODEVICE, self.interface.encode()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.do_discover()
|
self.discover_task = asyncio.create_task(self.do_discover())
|
||||||
|
|
||||||
def do_discover(self) -> None:
|
async def do_discover(self) -> None:
|
||||||
"""Send number of discovery datagrams."""
|
"""Send number of discovery datagrams."""
|
||||||
req = json_dumps(Discover.DISCOVERY_QUERY)
|
req = json_dumps(Discover.DISCOVERY_QUERY)
|
||||||
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
||||||
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
|
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
|
||||||
for _i in range(self.discovery_packets):
|
sleep_between_packets = self.discovery_timeout / self.discovery_packets
|
||||||
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
|
for i in range(self.discovery_packets):
|
||||||
|
if self.target in self.seen_hosts: # Stop sending for discover_single
|
||||||
|
break
|
||||||
|
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
|
||||||
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
|
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
|
||||||
|
if i < self.discovery_packets - 1:
|
||||||
|
await asyncio.sleep(sleep_between_packets)
|
||||||
|
|
||||||
def datagram_received(self, data, addr) -> None:
|
def datagram_received(self, data, addr) -> None:
|
||||||
"""Handle discovery responses."""
|
"""Handle discovery responses."""
|
||||||
@ -132,14 +141,12 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
self.unsupported_device_exceptions[ip] = udex
|
self.unsupported_device_exceptions[ip] = udex
|
||||||
if self.on_unsupported is not None:
|
if self.on_unsupported is not None:
|
||||||
asyncio.ensure_future(self.on_unsupported(udex))
|
asyncio.ensure_future(self.on_unsupported(udex))
|
||||||
if self.discovered_event is not None:
|
self._handle_discovered_event()
|
||||||
self.discovered_event.set()
|
|
||||||
return
|
return
|
||||||
except SmartDeviceException as ex:
|
except SmartDeviceException as ex:
|
||||||
_LOGGER.debug(f"[DISCOVERY] Unable to find device type for {ip}: {ex}")
|
_LOGGER.debug(f"[DISCOVERY] Unable to find device type for {ip}: {ex}")
|
||||||
self.invalid_device_exceptions[ip] = ex
|
self.invalid_device_exceptions[ip] = ex
|
||||||
if self.discovered_event is not None:
|
self._handle_discovered_event()
|
||||||
self.discovered_event.set()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self.discovered_devices[ip] = device
|
self.discovered_devices[ip] = device
|
||||||
@ -147,15 +154,23 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
if self.on_discovered is not None:
|
if self.on_discovered is not None:
|
||||||
asyncio.ensure_future(self.on_discovered(device))
|
asyncio.ensure_future(self.on_discovered(device))
|
||||||
|
|
||||||
|
self._handle_discovered_event()
|
||||||
|
|
||||||
|
def _handle_discovered_event(self):
|
||||||
|
"""If discovered_event is available set it and cancel discover_task."""
|
||||||
if self.discovered_event is not None:
|
if self.discovered_event is not None:
|
||||||
|
if self.discover_task:
|
||||||
|
self.discover_task.cancel()
|
||||||
self.discovered_event.set()
|
self.discovered_event.set()
|
||||||
|
|
||||||
def error_received(self, ex):
|
def error_received(self, ex):
|
||||||
"""Handle asyncio.Protocol errors."""
|
"""Handle asyncio.Protocol errors."""
|
||||||
_LOGGER.error("Got error: %s", ex)
|
_LOGGER.error("Got error: %s", ex)
|
||||||
|
|
||||||
def connection_lost(self, ex):
|
def connection_lost(self, ex): # pragma: no cover
|
||||||
"""NOP implementation of connection lost."""
|
"""Cancel the discover task if running."""
|
||||||
|
if self.discover_task:
|
||||||
|
self.discover_task.cancel()
|
||||||
|
|
||||||
|
|
||||||
class Discover:
|
class Discover:
|
||||||
@ -260,6 +275,7 @@ class Discover:
|
|||||||
on_unsupported=on_unsupported,
|
on_unsupported=on_unsupported,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
discovery_timeout=discovery_timeout,
|
||||||
port=port,
|
port=port,
|
||||||
),
|
),
|
||||||
local_addr=("0.0.0.0", 0), # noqa: S104
|
local_addr=("0.0.0.0", 0), # noqa: S104
|
||||||
@ -334,6 +350,7 @@ class Discover:
|
|||||||
discovered_event=event,
|
discovered_event=event,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
discovery_timeout=discovery_timeout,
|
||||||
),
|
),
|
||||||
local_addr=("0.0.0.0", 0), # noqa: S104
|
local_addr=("0.0.0.0", 0), # noqa: S104
|
||||||
)
|
)
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
# type: ignore
|
# type: ignore
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||||
|
from async_timeout import timeout as asyncio_timeout
|
||||||
|
|
||||||
from kasa import (
|
from kasa import (
|
||||||
Credentials,
|
Credentials,
|
||||||
@ -12,6 +15,7 @@ from kasa import (
|
|||||||
Discover,
|
Discover,
|
||||||
SmartDevice,
|
SmartDevice,
|
||||||
SmartDeviceException,
|
SmartDeviceException,
|
||||||
|
TPLinkSmartHomeProtocol,
|
||||||
protocol,
|
protocol,
|
||||||
)
|
)
|
||||||
from kasa.deviceconfig import (
|
from kasa.deviceconfig import (
|
||||||
@ -198,9 +202,9 @@ async def test_discover_send(mocker):
|
|||||||
"""Test discovery parameters."""
|
"""Test discovery parameters."""
|
||||||
proto = _DiscoverProtocol()
|
proto = _DiscoverProtocol()
|
||||||
assert proto.discovery_packets == 3
|
assert proto.discovery_packets == 3
|
||||||
assert proto.target == ("255.255.255.255", 9999)
|
assert proto.target_1 == ("255.255.255.255", 9999)
|
||||||
transport = mocker.patch.object(proto, "transport")
|
transport = mocker.patch.object(proto, "transport")
|
||||||
proto.do_discover()
|
await proto.do_discover()
|
||||||
assert transport.sendto.call_count == proto.discovery_packets * 2
|
assert transport.sendto.call_count == proto.discovery_packets * 2
|
||||||
|
|
||||||
|
|
||||||
@ -341,3 +345,116 @@ async def test_discover_http_client(discovery_mock, mocker):
|
|||||||
assert x.protocol._transport._http_client.client != http_client
|
assert x.protocol._transport._http_client.client != http_client
|
||||||
x.config.http_client = http_client
|
x.config.http_client = http_client
|
||||||
assert x.protocol._transport._http_client.client == http_client
|
assert x.protocol._transport._http_client.client == http_client
|
||||||
|
|
||||||
|
|
||||||
|
LEGACY_DISCOVER_DATA = {
|
||||||
|
"system": {
|
||||||
|
"get_sysinfo": {
|
||||||
|
"alias": "#MASKED_NAME#",
|
||||||
|
"dev_name": "Smart Wi-Fi Plug",
|
||||||
|
"deviceId": "0000000000000000000000000000000000000000",
|
||||||
|
"err_code": 0,
|
||||||
|
"hwId": "00000000000000000000000000000000",
|
||||||
|
"hw_ver": "0.0",
|
||||||
|
"mac": "00:00:00:00:00:00",
|
||||||
|
"mic_type": "IOT.SMARTPLUGSWITCH",
|
||||||
|
"model": "HS100(UK)",
|
||||||
|
"sw_ver": "1.1.0 Build 201016 Rel.175121",
|
||||||
|
"updating": 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FakeDatagramTransport(asyncio.DatagramTransport):
|
||||||
|
GHOST_PORT = 8888
|
||||||
|
|
||||||
|
def __init__(self, dp, port, do_not_reply_count, unsupported=False):
|
||||||
|
self.dp = dp
|
||||||
|
self.port = port
|
||||||
|
self.do_not_reply_count = do_not_reply_count
|
||||||
|
self.send_count = 0
|
||||||
|
if port == 9999:
|
||||||
|
self.datagram = TPLinkSmartHomeProtocol.encrypt(
|
||||||
|
json_dumps(LEGACY_DISCOVER_DATA)
|
||||||
|
)[4:]
|
||||||
|
elif port == 20002:
|
||||||
|
discovery_data = UNSUPPORTED if unsupported else AUTHENTICATION_DATA_KLAP
|
||||||
|
self.datagram = (
|
||||||
|
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||||
|
+ json_dumps(discovery_data).encode()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.datagram = {"foo": "bar"}
|
||||||
|
|
||||||
|
def get_extra_info(self, name, default=None):
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
def sendto(self, data, addr=None):
|
||||||
|
ip, port = addr
|
||||||
|
if port == self.port or self.port == self.GHOST_PORT:
|
||||||
|
self.send_count += 1
|
||||||
|
if self.send_count > self.do_not_reply_count:
|
||||||
|
self.dp.datagram_received(self.datagram, (ip, self.port))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("port", [9999, 20002])
|
||||||
|
@pytest.mark.parametrize("do_not_reply_count", [0, 1, 2, 3, 4])
|
||||||
|
async def test_do_discover_drop_packets(mocker, port, do_not_reply_count):
|
||||||
|
"""Make sure that discover_single handles authenticating devices correctly."""
|
||||||
|
host = "127.0.0.1"
|
||||||
|
discovery_timeout = 1
|
||||||
|
|
||||||
|
event = asyncio.Event()
|
||||||
|
dp = _DiscoverProtocol(
|
||||||
|
target=host,
|
||||||
|
discovery_timeout=discovery_timeout,
|
||||||
|
discovery_packets=5,
|
||||||
|
discovered_event=event,
|
||||||
|
)
|
||||||
|
ft = FakeDatagramTransport(dp, port, do_not_reply_count)
|
||||||
|
dp.connection_made(ft)
|
||||||
|
|
||||||
|
timed_out = False
|
||||||
|
try:
|
||||||
|
async with asyncio_timeout(discovery_timeout):
|
||||||
|
await event.wait()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
timed_out = True
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert ft.send_count == do_not_reply_count + 1
|
||||||
|
assert dp.discover_task.done()
|
||||||
|
assert timed_out is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"port, will_timeout",
|
||||||
|
[(FakeDatagramTransport.GHOST_PORT, True), (20002, False)],
|
||||||
|
ids=["unknownport", "unsupporteddevice"],
|
||||||
|
)
|
||||||
|
async def test_do_discover_invalid(mocker, port, will_timeout):
|
||||||
|
"""Make sure that discover_single handles authenticating devices correctly."""
|
||||||
|
host = "127.0.0.1"
|
||||||
|
discovery_timeout = 1
|
||||||
|
|
||||||
|
event = asyncio.Event()
|
||||||
|
dp = _DiscoverProtocol(
|
||||||
|
target=host,
|
||||||
|
discovery_timeout=discovery_timeout,
|
||||||
|
discovery_packets=5,
|
||||||
|
discovered_event=event,
|
||||||
|
)
|
||||||
|
ft = FakeDatagramTransport(dp, port, 0, unsupported=True)
|
||||||
|
dp.connection_made(ft)
|
||||||
|
|
||||||
|
timed_out = False
|
||||||
|
try:
|
||||||
|
async with asyncio_timeout(15):
|
||||||
|
await event.wait()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
timed_out = True
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert dp.discover_task.done()
|
||||||
|
assert timed_out is will_timeout
|
||||||
|
Loading…
Reference in New Issue
Block a user