Add support for alternative discovery protocol (20002/udp) (#488)

This will broadcast the new discovery message on the new port and log any responses received as unsupported devices.
This commit is contained in:
sdb9696 2023-08-29 14:04:28 +01:00 committed by GitHub
parent 53021f07fe
commit 6055c29d74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 219 additions and 25 deletions

View File

@ -103,6 +103,7 @@ def json_formatter_cb(result, **kwargs):
"--port",
envvar="KASA_PORT",
required=False,
type=int,
help="The port of the device to connect to.",
)
@click.option(
@ -138,7 +139,17 @@ def json_formatter_cb(result, **kwargs):
)
@click.version_option(package_name="python-kasa")
@click.pass_context
async def cli(ctx, host, port, alias, target, debug, type, json, discovery_timeout):
async def cli(
ctx,
host,
port,
alias,
target,
debug,
type,
json,
discovery_timeout,
):
"""A tool for controlling TP-Link smart home devices.""" # noqa
# no need to perform any checks if we are just displaying the help
if sys.argv[-1] == "--help":
@ -238,13 +249,29 @@ async def join(dev: SmartDevice, ssid, password, keytype):
@cli.command()
@click.option("--timeout", default=3, required=False)
@click.option(
"--show-unsupported",
envvar="KASA_SHOW_UNSUPPORTED",
required=False,
default=False,
is_flag=True,
help="Print out discovered unsupported devices",
)
@click.pass_context
async def discover(ctx, timeout):
async def discover(ctx, timeout, show_unsupported):
"""Discover devices in the network."""
target = ctx.parent.params["target"]
echo(f"Discovering devices on {target} for {timeout} seconds")
sem = asyncio.Semaphore()
discovered = dict()
unsupported = []
async def print_unsupported(data: Dict):
unsupported.append(data)
if show_unsupported:
echo(f"Found unsupported device (tapo/unknown encryption): {data}")
echo()
echo(f"Discovering devices on {target} for {timeout} seconds")
async def print_discovered(dev: SmartDevice):
await dev.update()
@ -255,9 +282,23 @@ async def discover(ctx, timeout):
echo()
await Discover.discover(
target=target, timeout=timeout, on_discovered=print_discovered
target=target,
timeout=timeout,
on_discovered=print_discovered,
on_unsupported=print_unsupported,
)
echo(f"Found {len(discovered)} devices")
if unsupported:
echo(
f"Found {len(unsupported)} unsupported devices"
+ (
""
if show_unsupported
else ", to show them use: kasa discover --show-unsupported"
)
)
return discovered

View File

@ -1,9 +1,15 @@
"""Discovery module for TP-Link Smart Home devices."""
import asyncio
import binascii
import logging
import socket
from typing import Awaitable, Callable, Dict, Optional, Type, cast
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import TPLinkSmartHomeProtocol
@ -36,13 +42,22 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
target: str = "255.255.255.255",
discovery_packets: int = 3,
interface: Optional[str] = None,
on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None,
port: Optional[int] = None,
discovered_event: Optional[asyncio.Event] = None,
):
self.transport = None
self.discovery_packets = discovery_packets
self.interface = interface
self.on_discovered = on_discovered
self.target = (target, Discover.DISCOVERY_PORT)
self.discovery_port = port or Discover.DISCOVERY_PORT
self.target = (target, self.discovery_port)
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
self.discovered_devices = {}
self.unsupported_devices: Dict = {}
self.invalid_device_exceptions: Dict = {}
self.on_unsupported = on_unsupported
self.discovered_event = discovered_event
def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
@ -69,23 +84,48 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
def datagram_received(self, data, addr) -> None:
"""Handle discovery responses."""
ip, port = addr
if ip in self.discovered_devices:
if (
ip in self.discovered_devices
or ip in self.unsupported_devices
or ip in self.invalid_device_exceptions
):
return
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
if port == self.discovery_port:
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
elif port == Discover.DISCOVERY_PORT_2:
info = json_loads(data[16:])
self.unsupported_devices[ip] = info
if self.on_unsupported is not None:
asyncio.ensure_future(self.on_unsupported(info))
_LOGGER.debug("[DISCOVERY] Unsupported device found at %s << %s", ip, info)
if self.discovered_event is not None and "255" not in self.target[0].split(
"."
):
self.discovered_event.set()
return
try:
device_class = Discover._get_device_class(info)
except SmartDeviceException as ex:
_LOGGER.debug("Unable to find device type from %s: %s", info, ex)
_LOGGER.debug(
"[DISCOVERY] Unable to find device type from %s: %s", info, ex
)
self.invalid_device_exceptions[ip] = ex
if self.discovered_event is not None and "255" not in self.target[0].split(
"."
):
self.discovered_event.set()
return
device = device_class(ip)
device = device_class(ip, port=port)
device.update_from_discover_info(info)
self.discovered_devices[ip] = device
@ -93,6 +133,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
if self.discovered_event is not None and "255" not in self.target[0].split("."):
self.discovered_event.set()
def error_received(self, ex):
"""Handle asyncio.Protocol errors."""
_LOGGER.error("Got error: %s", ex)
@ -142,6 +185,9 @@ class Discover:
"system": {"get_sysinfo": None},
}
DISCOVERY_PORT_2 = 20002
DISCOVERY_QUERY_2 = binascii.unhexlify("020000010000000000000000463cb5d3")
@staticmethod
async def discover(
*,
@ -150,6 +196,7 @@ class Discover:
timeout=5,
discovery_packets=3,
interface=None,
on_unsupported=None,
) -> DeviceDict:
"""Discover supported devices.
@ -177,6 +224,7 @@ class Discover:
on_discovered=on_discovered,
discovery_packets=discovery_packets,
interface=interface,
on_unsupported=on_unsupported,
),
local_addr=("0.0.0.0", 0),
)
@ -193,22 +241,47 @@ class Discover:
return protocol.discovered_devices
@staticmethod
async def discover_single(host: str, *, port: Optional[int] = None) -> SmartDevice:
async def discover_single(
host: str, *, port: Optional[int] = None, timeout=5
) -> SmartDevice:
"""Discover a single device by the given IP address.
:param host: Hostname of device to query
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
protocol = TPLinkSmartHomeProtocol(host, port=port)
loop = asyncio.get_event_loop()
event = asyncio.Event()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(target=host, port=port, discovered_event=event),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)
info = await protocol.query(Discover.DISCOVERY_QUERY)
try:
_LOGGER.debug("Waiting a total of %s seconds for responses...", timeout)
device_class = Discover._get_device_class(info)
dev = device_class(host, port=port)
await dev.update()
async with asyncio_timeout(timeout):
await event.wait()
except asyncio.TimeoutError:
raise SmartDeviceException(
f"Timed out getting discovery response for {host}"
)
finally:
transport.close()
return dev
if host in protocol.discovered_devices:
dev = protocol.discovered_devices[host]
await dev.update()
return dev
elif host in protocol.unsupported_devices:
raise UnsupportedDeviceException(
f"Unsupported device {host}: {protocol.unsupported_devices[host]}"
)
elif host in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[host]
else:
raise SmartDeviceException(f"Unable to get discovery response for {host}")
@staticmethod
def _get_device_class(info: dict) -> Type[SmartDevice]:

View File

@ -3,3 +3,7 @@
class SmartDeviceException(Exception):
"""Base exception for device errors."""
class UnsupportedDeviceException(SmartDeviceException):
"""Exception for trying to connect to unsupported devices."""

View File

@ -1,10 +1,12 @@
# type: ignore
import re
import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
from kasa.discover import _DiscoverProtocol
from kasa.discover import _DiscoverProtocol, json_dumps
from kasa.exceptions import UnsupportedDeviceException
from .conftest import bulb, dimmer, lightstrip, plug, strip
@ -55,11 +57,73 @@ async def test_type_unknown():
@pytest.mark.parametrize("custom_port", [123, None])
async def test_discover_single(discovery_data: dict, mocker, custom_port):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:],
(host, custom_port or 9999),
)
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
x = await Discover.discover_single("127.0.0.1", port=custom_port)
x = await Discover.discover_single(host, port=custom_port)
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
assert x.port == custom_port
assert x.port == custom_port or 9999
UNSUPPORTED = {
"result": {
"device_id": "xx",
"owner": "xx",
"device_type": "SMART.TAPOPLUG",
"device_model": "P110(EU)",
"ip": "127.0.0.1",
"mac": "48-22xxx",
"is_support_iot_cloud": True,
"obd_src": "tplink",
"factory_default": False,
"mgt_encrypt_schm": {
"is_support_https": False,
"encrypt_type": "AES",
"http_port": 80,
"lv": 2,
},
},
"error_code": 0,
}
async def test_discover_single_unsupported(mocker):
"""Make sure that discover_single handles unsupported devices correctly."""
host = "127.0.0.1"
def mock_discover(self):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
self.datagram_received(data, (host, 20002))
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
# Test with a valid unsupported response
discovery_data = UNSUPPORTED
with pytest.raises(
UnsupportedDeviceException,
match=f"Unsupported device {host}: {re.escape(str(UNSUPPORTED))}",
):
await Discover.discover_single(host)
# Test with no response
discovery_data = None
with pytest.raises(
SmartDeviceException, match=f"Timed out getting discovery response for {host}"
):
await Discover.discover_single(host, timeout=0.001)
INVALIDS = [
@ -75,9 +139,17 @@ INVALIDS = [
@pytest.mark.parametrize("msg, data", INVALIDS)
async def test_discover_invalid_info(msg, data, mocker):
"""Make sure that invalid discovery information raises an exception."""
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=data)
host = "127.0.0.1"
def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(data))[4:], (host, 9999)
)
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
with pytest.raises(SmartDeviceException, match=msg):
await Discover.discover_single("127.0.0.1")
await Discover.discover_single(host)
async def test_discover_send(mocker):
@ -87,7 +159,7 @@ async def test_discover_send(mocker):
assert proto.target == ("255.255.255.255", 9999)
transport = mocker.patch.object(proto, "transport")
proto.do_discover()
assert transport.sendto.call_count == proto.discovery_packets
assert transport.sendto.call_count == proto.discovery_packets * 2
async def test_discover_datagram_received(mocker, discovery_data):
@ -98,10 +170,14 @@ async def test_discover_datagram_received(mocker, discovery_data):
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
addr = "127.0.0.1"
proto.datagram_received("<placeholder data>", (addr, 1234))
proto.datagram_received("<placeholder data>", (addr, 9999))
addr2 = "127.0.0.2"
proto.datagram_received("<placeholder data>", (addr2, 20002))
# Check that device in discovered_devices is initialized correctly
assert len(proto.discovered_devices) == 1
# Check that unsupported device is 1
assert len(proto.unsupported_devices) == 1
dev = proto.discovered_devices[addr]
assert issubclass(dev.__class__, SmartDevice)
assert dev.host == addr
@ -115,5 +191,5 @@ async def test_discover_invalid_responses(msg, data, mocker):
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
proto.datagram_received(data, ("127.0.0.1", 1234))
proto.datagram_received(data, ("127.0.0.1", 9999))
assert len(proto.discovered_devices) == 0