Add support for alternative discovery protocol (20002/udp) (#488)

This will broadcast the new discovery message on the new port and log any responses received as unsupported devices.
This commit is contained in:
sdb9696 2023-08-29 14:04:28 +01:00 committed by GitHub
parent 53021f07fe
commit 6055c29d74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 219 additions and 25 deletions

View File

@ -103,6 +103,7 @@ def json_formatter_cb(result, **kwargs):
"--port", "--port",
envvar="KASA_PORT", envvar="KASA_PORT",
required=False, required=False,
type=int,
help="The port of the device to connect to.", help="The port of the device to connect to.",
) )
@click.option( @click.option(
@ -138,7 +139,17 @@ def json_formatter_cb(result, **kwargs):
) )
@click.version_option(package_name="python-kasa") @click.version_option(package_name="python-kasa")
@click.pass_context @click.pass_context
async def cli(ctx, host, port, alias, target, debug, type, json, discovery_timeout): async def cli(
ctx,
host,
port,
alias,
target,
debug,
type,
json,
discovery_timeout,
):
"""A tool for controlling TP-Link smart home devices.""" # noqa """A tool for controlling TP-Link smart home devices.""" # noqa
# no need to perform any checks if we are just displaying the help # no need to perform any checks if we are just displaying the help
if sys.argv[-1] == "--help": if sys.argv[-1] == "--help":
@ -238,13 +249,29 @@ async def join(dev: SmartDevice, ssid, password, keytype):
@cli.command() @cli.command()
@click.option("--timeout", default=3, required=False) @click.option("--timeout", default=3, required=False)
@click.option(
"--show-unsupported",
envvar="KASA_SHOW_UNSUPPORTED",
required=False,
default=False,
is_flag=True,
help="Print out discovered unsupported devices",
)
@click.pass_context @click.pass_context
async def discover(ctx, timeout): async def discover(ctx, timeout, show_unsupported):
"""Discover devices in the network.""" """Discover devices in the network."""
target = ctx.parent.params["target"] target = ctx.parent.params["target"]
echo(f"Discovering devices on {target} for {timeout} seconds")
sem = asyncio.Semaphore() sem = asyncio.Semaphore()
discovered = dict() discovered = dict()
unsupported = []
async def print_unsupported(data: Dict):
unsupported.append(data)
if show_unsupported:
echo(f"Found unsupported device (tapo/unknown encryption): {data}")
echo()
echo(f"Discovering devices on {target} for {timeout} seconds")
async def print_discovered(dev: SmartDevice): async def print_discovered(dev: SmartDevice):
await dev.update() await dev.update()
@ -255,7 +282,21 @@ async def discover(ctx, timeout):
echo() echo()
await Discover.discover( await Discover.discover(
target=target, timeout=timeout, on_discovered=print_discovered target=target,
timeout=timeout,
on_discovered=print_discovered,
on_unsupported=print_unsupported,
)
echo(f"Found {len(discovered)} devices")
if unsupported:
echo(
f"Found {len(unsupported)} unsupported devices"
+ (
""
if show_unsupported
else ", to show them use: kasa discover --show-unsupported"
)
) )
return discovered return discovered

View File

@ -1,9 +1,15 @@
"""Discovery module for TP-Link Smart Home devices.""" """Discovery module for TP-Link Smart Home devices."""
import asyncio import asyncio
import binascii
import logging import logging
import socket import socket
from typing import Awaitable, Callable, Dict, Optional, Type, cast from typing import Awaitable, Callable, Dict, Optional, Type, cast
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads from kasa.json import loads as json_loads
from kasa.protocol import TPLinkSmartHomeProtocol from kasa.protocol import TPLinkSmartHomeProtocol
@ -36,13 +42,22 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
target: str = "255.255.255.255", target: str = "255.255.255.255",
discovery_packets: int = 3, discovery_packets: int = 3,
interface: Optional[str] = None, interface: Optional[str] = None,
on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None,
port: Optional[int] = None,
discovered_event: Optional[asyncio.Event] = None,
): ):
self.transport = None self.transport = None
self.discovery_packets = discovery_packets self.discovery_packets = discovery_packets
self.interface = interface self.interface = interface
self.on_discovered = on_discovered self.on_discovered = on_discovered
self.target = (target, Discover.DISCOVERY_PORT) self.discovery_port = port or Discover.DISCOVERY_PORT
self.target = (target, self.discovery_port)
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
self.discovered_devices = {} self.discovered_devices = {}
self.unsupported_devices: Dict = {}
self.invalid_device_exceptions: Dict = {}
self.on_unsupported = on_unsupported
self.discovered_event = discovered_event
def connection_made(self, transport) -> None: def connection_made(self, transport) -> None:
"""Set socket options for broadcasting.""" """Set socket options for broadcasting."""
@ -69,23 +84,48 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req) encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for i in range(self.discovery_packets): for i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
def datagram_received(self, data, addr) -> None: def datagram_received(self, data, addr) -> None:
"""Handle discovery responses.""" """Handle discovery responses."""
ip, port = addr ip, port = addr
if ip in self.discovered_devices: if (
ip in self.discovered_devices
or ip in self.unsupported_devices
or ip in self.invalid_device_exceptions
):
return return
if port == self.discovery_port:
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data)) info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info) _LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
elif port == Discover.DISCOVERY_PORT_2:
info = json_loads(data[16:])
self.unsupported_devices[ip] = info
if self.on_unsupported is not None:
asyncio.ensure_future(self.on_unsupported(info))
_LOGGER.debug("[DISCOVERY] Unsupported device found at %s << %s", ip, info)
if self.discovered_event is not None and "255" not in self.target[0].split(
"."
):
self.discovered_event.set()
return
try: try:
device_class = Discover._get_device_class(info) device_class = Discover._get_device_class(info)
except SmartDeviceException as ex: except SmartDeviceException as ex:
_LOGGER.debug("Unable to find device type from %s: %s", info, ex) _LOGGER.debug(
"[DISCOVERY] Unable to find device type from %s: %s", info, ex
)
self.invalid_device_exceptions[ip] = ex
if self.discovered_event is not None and "255" not in self.target[0].split(
"."
):
self.discovered_event.set()
return return
device = device_class(ip) device = device_class(ip, port=port)
device.update_from_discover_info(info) device.update_from_discover_info(info)
self.discovered_devices[ip] = device self.discovered_devices[ip] = device
@ -93,6 +133,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
if self.on_discovered is not None: if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device)) asyncio.ensure_future(self.on_discovered(device))
if self.discovered_event is not None and "255" not in self.target[0].split("."):
self.discovered_event.set()
def error_received(self, ex): def error_received(self, ex):
"""Handle asyncio.Protocol errors.""" """Handle asyncio.Protocol errors."""
_LOGGER.error("Got error: %s", ex) _LOGGER.error("Got error: %s", ex)
@ -142,6 +185,9 @@ class Discover:
"system": {"get_sysinfo": None}, "system": {"get_sysinfo": None},
} }
DISCOVERY_PORT_2 = 20002
DISCOVERY_QUERY_2 = binascii.unhexlify("020000010000000000000000463cb5d3")
@staticmethod @staticmethod
async def discover( async def discover(
*, *,
@ -150,6 +196,7 @@ class Discover:
timeout=5, timeout=5,
discovery_packets=3, discovery_packets=3,
interface=None, interface=None,
on_unsupported=None,
) -> DeviceDict: ) -> DeviceDict:
"""Discover supported devices. """Discover supported devices.
@ -177,6 +224,7 @@ class Discover:
on_discovered=on_discovered, on_discovered=on_discovered,
discovery_packets=discovery_packets, discovery_packets=discovery_packets,
interface=interface, interface=interface,
on_unsupported=on_unsupported,
), ),
local_addr=("0.0.0.0", 0), local_addr=("0.0.0.0", 0),
) )
@ -193,22 +241,47 @@ class Discover:
return protocol.discovered_devices return protocol.discovered_devices
@staticmethod @staticmethod
async def discover_single(host: str, *, port: Optional[int] = None) -> SmartDevice: async def discover_single(
host: str, *, port: Optional[int] = None, timeout=5
) -> SmartDevice:
"""Discover a single device by the given IP address. """Discover a single device by the given IP address.
:param host: Hostname of device to query :param host: Hostname of device to query
:rtype: SmartDevice :rtype: SmartDevice
:return: Object for querying/controlling found device. :return: Object for querying/controlling found device.
""" """
protocol = TPLinkSmartHomeProtocol(host, port=port) loop = asyncio.get_event_loop()
event = asyncio.Event()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(target=host, port=port, discovered_event=event),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)
info = await protocol.query(Discover.DISCOVERY_QUERY) try:
_LOGGER.debug("Waiting a total of %s seconds for responses...", timeout)
device_class = Discover._get_device_class(info) async with asyncio_timeout(timeout):
dev = device_class(host, port=port) await event.wait()
except asyncio.TimeoutError:
raise SmartDeviceException(
f"Timed out getting discovery response for {host}"
)
finally:
transport.close()
if host in protocol.discovered_devices:
dev = protocol.discovered_devices[host]
await dev.update() await dev.update()
return dev return dev
elif host in protocol.unsupported_devices:
raise UnsupportedDeviceException(
f"Unsupported device {host}: {protocol.unsupported_devices[host]}"
)
elif host in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[host]
else:
raise SmartDeviceException(f"Unable to get discovery response for {host}")
@staticmethod @staticmethod
def _get_device_class(info: dict) -> Type[SmartDevice]: def _get_device_class(info: dict) -> Type[SmartDevice]:

View File

@ -3,3 +3,7 @@
class SmartDeviceException(Exception): class SmartDeviceException(Exception):
"""Base exception for device errors.""" """Base exception for device errors."""
class UnsupportedDeviceException(SmartDeviceException):
"""Exception for trying to connect to unsupported devices."""

View File

@ -1,10 +1,12 @@
# type: ignore # type: ignore
import re
import sys import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
from kasa.discover import _DiscoverProtocol from kasa.discover import _DiscoverProtocol, json_dumps
from kasa.exceptions import UnsupportedDeviceException
from .conftest import bulb, dimmer, lightstrip, plug, strip from .conftest import bulb, dimmer, lightstrip, plug, strip
@ -55,11 +57,73 @@ async def test_type_unknown():
@pytest.mark.parametrize("custom_port", [123, None]) @pytest.mark.parametrize("custom_port", [123, None])
async def test_discover_single(discovery_data: dict, mocker, custom_port): async def test_discover_single(discovery_data: dict, mocker, custom_port):
"""Make sure that discover_single returns an initialized SmartDevice instance.""" """Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:],
(host, custom_port or 9999),
)
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
x = await Discover.discover_single("127.0.0.1", port=custom_port)
x = await Discover.discover_single(host, port=custom_port)
assert issubclass(x.__class__, SmartDevice) assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None assert x._sys_info is not None
assert x.port == custom_port assert x.port == custom_port or 9999
UNSUPPORTED = {
"result": {
"device_id": "xx",
"owner": "xx",
"device_type": "SMART.TAPOPLUG",
"device_model": "P110(EU)",
"ip": "127.0.0.1",
"mac": "48-22xxx",
"is_support_iot_cloud": True,
"obd_src": "tplink",
"factory_default": False,
"mgt_encrypt_schm": {
"is_support_https": False,
"encrypt_type": "AES",
"http_port": 80,
"lv": 2,
},
},
"error_code": 0,
}
async def test_discover_single_unsupported(mocker):
"""Make sure that discover_single handles unsupported devices correctly."""
host = "127.0.0.1"
def mock_discover(self):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
self.datagram_received(data, (host, 20002))
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
# Test with a valid unsupported response
discovery_data = UNSUPPORTED
with pytest.raises(
UnsupportedDeviceException,
match=f"Unsupported device {host}: {re.escape(str(UNSUPPORTED))}",
):
await Discover.discover_single(host)
# Test with no response
discovery_data = None
with pytest.raises(
SmartDeviceException, match=f"Timed out getting discovery response for {host}"
):
await Discover.discover_single(host, timeout=0.001)
INVALIDS = [ INVALIDS = [
@ -75,9 +139,17 @@ INVALIDS = [
@pytest.mark.parametrize("msg, data", INVALIDS) @pytest.mark.parametrize("msg, data", INVALIDS)
async def test_discover_invalid_info(msg, data, mocker): async def test_discover_invalid_info(msg, data, mocker):
"""Make sure that invalid discovery information raises an exception.""" """Make sure that invalid discovery information raises an exception."""
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=data) host = "127.0.0.1"
def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(data))[4:], (host, 9999)
)
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
with pytest.raises(SmartDeviceException, match=msg): with pytest.raises(SmartDeviceException, match=msg):
await Discover.discover_single("127.0.0.1") await Discover.discover_single(host)
async def test_discover_send(mocker): async def test_discover_send(mocker):
@ -87,7 +159,7 @@ async def test_discover_send(mocker):
assert proto.target == ("255.255.255.255", 9999) assert proto.target == ("255.255.255.255", 9999)
transport = mocker.patch.object(proto, "transport") transport = mocker.patch.object(proto, "transport")
proto.do_discover() proto.do_discover()
assert transport.sendto.call_count == proto.discovery_packets assert transport.sendto.call_count == proto.discovery_packets * 2
async def test_discover_datagram_received(mocker, discovery_data): async def test_discover_datagram_received(mocker, discovery_data):
@ -98,10 +170,14 @@ async def test_discover_datagram_received(mocker, discovery_data):
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
addr = "127.0.0.1" addr = "127.0.0.1"
proto.datagram_received("<placeholder data>", (addr, 1234)) proto.datagram_received("<placeholder data>", (addr, 9999))
addr2 = "127.0.0.2"
proto.datagram_received("<placeholder data>", (addr2, 20002))
# Check that device in discovered_devices is initialized correctly # Check that device in discovered_devices is initialized correctly
assert len(proto.discovered_devices) == 1 assert len(proto.discovered_devices) == 1
# Check that unsupported device is 1
assert len(proto.unsupported_devices) == 1
dev = proto.discovered_devices[addr] dev = proto.discovered_devices[addr]
assert issubclass(dev.__class__, SmartDevice) assert issubclass(dev.__class__, SmartDevice)
assert dev.host == addr assert dev.host == addr
@ -115,5 +191,5 @@ async def test_discover_invalid_responses(msg, data, mocker):
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt") mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
proto.datagram_received(data, ("127.0.0.1", 1234)) proto.datagram_received(data, ("127.0.0.1", 9999))
assert len(proto.discovered_devices) == 0 assert len(proto.discovered_devices) == 0