mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Add support for alternative discovery protocol (20002/udp) (#488)
This will broadcast the new discovery message on the new port and log any responses received as unsupported devices.
This commit is contained in:
parent
53021f07fe
commit
6055c29d74
49
kasa/cli.py
49
kasa/cli.py
@ -103,6 +103,7 @@ def json_formatter_cb(result, **kwargs):
|
||||
"--port",
|
||||
envvar="KASA_PORT",
|
||||
required=False,
|
||||
type=int,
|
||||
help="The port of the device to connect to.",
|
||||
)
|
||||
@click.option(
|
||||
@ -138,7 +139,17 @@ def json_formatter_cb(result, **kwargs):
|
||||
)
|
||||
@click.version_option(package_name="python-kasa")
|
||||
@click.pass_context
|
||||
async def cli(ctx, host, port, alias, target, debug, type, json, discovery_timeout):
|
||||
async def cli(
|
||||
ctx,
|
||||
host,
|
||||
port,
|
||||
alias,
|
||||
target,
|
||||
debug,
|
||||
type,
|
||||
json,
|
||||
discovery_timeout,
|
||||
):
|
||||
"""A tool for controlling TP-Link smart home devices.""" # noqa
|
||||
# no need to perform any checks if we are just displaying the help
|
||||
if sys.argv[-1] == "--help":
|
||||
@ -238,13 +249,29 @@ async def join(dev: SmartDevice, ssid, password, keytype):
|
||||
|
||||
@cli.command()
|
||||
@click.option("--timeout", default=3, required=False)
|
||||
@click.option(
|
||||
"--show-unsupported",
|
||||
envvar="KASA_SHOW_UNSUPPORTED",
|
||||
required=False,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Print out discovered unsupported devices",
|
||||
)
|
||||
@click.pass_context
|
||||
async def discover(ctx, timeout):
|
||||
async def discover(ctx, timeout, show_unsupported):
|
||||
"""Discover devices in the network."""
|
||||
target = ctx.parent.params["target"]
|
||||
echo(f"Discovering devices on {target} for {timeout} seconds")
|
||||
sem = asyncio.Semaphore()
|
||||
discovered = dict()
|
||||
unsupported = []
|
||||
|
||||
async def print_unsupported(data: Dict):
|
||||
unsupported.append(data)
|
||||
if show_unsupported:
|
||||
echo(f"Found unsupported device (tapo/unknown encryption): {data}")
|
||||
echo()
|
||||
|
||||
echo(f"Discovering devices on {target} for {timeout} seconds")
|
||||
|
||||
async def print_discovered(dev: SmartDevice):
|
||||
await dev.update()
|
||||
@ -255,7 +282,21 @@ async def discover(ctx, timeout):
|
||||
echo()
|
||||
|
||||
await Discover.discover(
|
||||
target=target, timeout=timeout, on_discovered=print_discovered
|
||||
target=target,
|
||||
timeout=timeout,
|
||||
on_discovered=print_discovered,
|
||||
on_unsupported=print_unsupported,
|
||||
)
|
||||
|
||||
echo(f"Found {len(discovered)} devices")
|
||||
if unsupported:
|
||||
echo(
|
||||
f"Found {len(unsupported)} unsupported devices"
|
||||
+ (
|
||||
""
|
||||
if show_unsupported
|
||||
else ", to show them use: kasa discover --show-unsupported"
|
||||
)
|
||||
)
|
||||
|
||||
return discovered
|
||||
|
@ -1,9 +1,15 @@
|
||||
"""Discovery module for TP-Link Smart Home devices."""
|
||||
import asyncio
|
||||
import binascii
|
||||
import logging
|
||||
import socket
|
||||
from typing import Awaitable, Callable, Dict, Optional, Type, cast
|
||||
|
||||
# When support for cpython older than 3.11 is dropped
|
||||
# async_timeout can be replaced with asyncio.timeout
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
|
||||
from kasa.exceptions import UnsupportedDeviceException
|
||||
from kasa.json import dumps as json_dumps
|
||||
from kasa.json import loads as json_loads
|
||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||
@ -36,13 +42,22 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
target: str = "255.255.255.255",
|
||||
discovery_packets: int = 3,
|
||||
interface: Optional[str] = None,
|
||||
on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None,
|
||||
port: Optional[int] = None,
|
||||
discovered_event: Optional[asyncio.Event] = None,
|
||||
):
|
||||
self.transport = None
|
||||
self.discovery_packets = discovery_packets
|
||||
self.interface = interface
|
||||
self.on_discovered = on_discovered
|
||||
self.target = (target, Discover.DISCOVERY_PORT)
|
||||
self.discovery_port = port or Discover.DISCOVERY_PORT
|
||||
self.target = (target, self.discovery_port)
|
||||
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
|
||||
self.discovered_devices = {}
|
||||
self.unsupported_devices: Dict = {}
|
||||
self.invalid_device_exceptions: Dict = {}
|
||||
self.on_unsupported = on_unsupported
|
||||
self.discovered_event = discovered_event
|
||||
|
||||
def connection_made(self, transport) -> None:
|
||||
"""Set socket options for broadcasting."""
|
||||
@ -69,23 +84,48 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
|
||||
for i in range(self.discovery_packets):
|
||||
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
|
||||
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
|
||||
|
||||
def datagram_received(self, data, addr) -> None:
|
||||
"""Handle discovery responses."""
|
||||
ip, port = addr
|
||||
if ip in self.discovered_devices:
|
||||
if (
|
||||
ip in self.discovered_devices
|
||||
or ip in self.unsupported_devices
|
||||
or ip in self.invalid_device_exceptions
|
||||
):
|
||||
return
|
||||
|
||||
if port == self.discovery_port:
|
||||
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
|
||||
elif port == Discover.DISCOVERY_PORT_2:
|
||||
info = json_loads(data[16:])
|
||||
self.unsupported_devices[ip] = info
|
||||
if self.on_unsupported is not None:
|
||||
asyncio.ensure_future(self.on_unsupported(info))
|
||||
_LOGGER.debug("[DISCOVERY] Unsupported device found at %s << %s", ip, info)
|
||||
if self.discovered_event is not None and "255" not in self.target[0].split(
|
||||
"."
|
||||
):
|
||||
self.discovered_event.set()
|
||||
return
|
||||
|
||||
try:
|
||||
device_class = Discover._get_device_class(info)
|
||||
except SmartDeviceException as ex:
|
||||
_LOGGER.debug("Unable to find device type from %s: %s", info, ex)
|
||||
_LOGGER.debug(
|
||||
"[DISCOVERY] Unable to find device type from %s: %s", info, ex
|
||||
)
|
||||
self.invalid_device_exceptions[ip] = ex
|
||||
if self.discovered_event is not None and "255" not in self.target[0].split(
|
||||
"."
|
||||
):
|
||||
self.discovered_event.set()
|
||||
return
|
||||
|
||||
device = device_class(ip)
|
||||
device = device_class(ip, port=port)
|
||||
device.update_from_discover_info(info)
|
||||
|
||||
self.discovered_devices[ip] = device
|
||||
@ -93,6 +133,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
if self.on_discovered is not None:
|
||||
asyncio.ensure_future(self.on_discovered(device))
|
||||
|
||||
if self.discovered_event is not None and "255" not in self.target[0].split("."):
|
||||
self.discovered_event.set()
|
||||
|
||||
def error_received(self, ex):
|
||||
"""Handle asyncio.Protocol errors."""
|
||||
_LOGGER.error("Got error: %s", ex)
|
||||
@ -142,6 +185,9 @@ class Discover:
|
||||
"system": {"get_sysinfo": None},
|
||||
}
|
||||
|
||||
DISCOVERY_PORT_2 = 20002
|
||||
DISCOVERY_QUERY_2 = binascii.unhexlify("020000010000000000000000463cb5d3")
|
||||
|
||||
@staticmethod
|
||||
async def discover(
|
||||
*,
|
||||
@ -150,6 +196,7 @@ class Discover:
|
||||
timeout=5,
|
||||
discovery_packets=3,
|
||||
interface=None,
|
||||
on_unsupported=None,
|
||||
) -> DeviceDict:
|
||||
"""Discover supported devices.
|
||||
|
||||
@ -177,6 +224,7 @@ class Discover:
|
||||
on_discovered=on_discovered,
|
||||
discovery_packets=discovery_packets,
|
||||
interface=interface,
|
||||
on_unsupported=on_unsupported,
|
||||
),
|
||||
local_addr=("0.0.0.0", 0),
|
||||
)
|
||||
@ -193,22 +241,47 @@ class Discover:
|
||||
return protocol.discovered_devices
|
||||
|
||||
@staticmethod
|
||||
async def discover_single(host: str, *, port: Optional[int] = None) -> SmartDevice:
|
||||
async def discover_single(
|
||||
host: str, *, port: Optional[int] = None, timeout=5
|
||||
) -> SmartDevice:
|
||||
"""Discover a single device by the given IP address.
|
||||
|
||||
:param host: Hostname of device to query
|
||||
:rtype: SmartDevice
|
||||
:return: Object for querying/controlling found device.
|
||||
"""
|
||||
protocol = TPLinkSmartHomeProtocol(host, port=port)
|
||||
loop = asyncio.get_event_loop()
|
||||
event = asyncio.Event()
|
||||
transport, protocol = await loop.create_datagram_endpoint(
|
||||
lambda: _DiscoverProtocol(target=host, port=port, discovered_event=event),
|
||||
local_addr=("0.0.0.0", 0),
|
||||
)
|
||||
protocol = cast(_DiscoverProtocol, protocol)
|
||||
|
||||
info = await protocol.query(Discover.DISCOVERY_QUERY)
|
||||
try:
|
||||
_LOGGER.debug("Waiting a total of %s seconds for responses...", timeout)
|
||||
|
||||
device_class = Discover._get_device_class(info)
|
||||
dev = device_class(host, port=port)
|
||||
async with asyncio_timeout(timeout):
|
||||
await event.wait()
|
||||
except asyncio.TimeoutError:
|
||||
raise SmartDeviceException(
|
||||
f"Timed out getting discovery response for {host}"
|
||||
)
|
||||
finally:
|
||||
transport.close()
|
||||
|
||||
if host in protocol.discovered_devices:
|
||||
dev = protocol.discovered_devices[host]
|
||||
await dev.update()
|
||||
|
||||
return dev
|
||||
elif host in protocol.unsupported_devices:
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device {host}: {protocol.unsupported_devices[host]}"
|
||||
)
|
||||
elif host in protocol.invalid_device_exceptions:
|
||||
raise protocol.invalid_device_exceptions[host]
|
||||
else:
|
||||
raise SmartDeviceException(f"Unable to get discovery response for {host}")
|
||||
|
||||
@staticmethod
|
||||
def _get_device_class(info: dict) -> Type[SmartDevice]:
|
||||
|
@ -3,3 +3,7 @@
|
||||
|
||||
class SmartDeviceException(Exception):
|
||||
"""Base exception for device errors."""
|
||||
|
||||
|
||||
class UnsupportedDeviceException(SmartDeviceException):
|
||||
"""Exception for trying to connect to unsupported devices."""
|
||||
|
@ -1,10 +1,12 @@
|
||||
# type: ignore
|
||||
import re
|
||||
import sys
|
||||
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
|
||||
from kasa.discover import _DiscoverProtocol
|
||||
from kasa.discover import _DiscoverProtocol, json_dumps
|
||||
from kasa.exceptions import UnsupportedDeviceException
|
||||
|
||||
from .conftest import bulb, dimmer, lightstrip, plug, strip
|
||||
|
||||
@ -55,11 +57,73 @@ async def test_type_unknown():
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
async def test_discover_single(discovery_data: dict, mocker, custom_port):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
def mock_discover(self):
|
||||
self.datagram_received(
|
||||
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:],
|
||||
(host, custom_port or 9999),
|
||||
)
|
||||
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
x = await Discover.discover_single("127.0.0.1", port=custom_port)
|
||||
|
||||
x = await Discover.discover_single(host, port=custom_port)
|
||||
assert issubclass(x.__class__, SmartDevice)
|
||||
assert x._sys_info is not None
|
||||
assert x.port == custom_port
|
||||
assert x.port == custom_port or 9999
|
||||
|
||||
|
||||
UNSUPPORTED = {
|
||||
"result": {
|
||||
"device_id": "xx",
|
||||
"owner": "xx",
|
||||
"device_type": "SMART.TAPOPLUG",
|
||||
"device_model": "P110(EU)",
|
||||
"ip": "127.0.0.1",
|
||||
"mac": "48-22xxx",
|
||||
"is_support_iot_cloud": True,
|
||||
"obd_src": "tplink",
|
||||
"factory_default": False,
|
||||
"mgt_encrypt_schm": {
|
||||
"is_support_https": False,
|
||||
"encrypt_type": "AES",
|
||||
"http_port": 80,
|
||||
"lv": 2,
|
||||
},
|
||||
},
|
||||
"error_code": 0,
|
||||
}
|
||||
|
||||
|
||||
async def test_discover_single_unsupported(mocker):
|
||||
"""Make sure that discover_single handles unsupported devices correctly."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
def mock_discover(self):
|
||||
if discovery_data:
|
||||
data = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
+ json_dumps(discovery_data).encode()
|
||||
)
|
||||
self.datagram_received(data, (host, 20002))
|
||||
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
||||
|
||||
# Test with a valid unsupported response
|
||||
discovery_data = UNSUPPORTED
|
||||
with pytest.raises(
|
||||
UnsupportedDeviceException,
|
||||
match=f"Unsupported device {host}: {re.escape(str(UNSUPPORTED))}",
|
||||
):
|
||||
await Discover.discover_single(host)
|
||||
|
||||
# Test with no response
|
||||
discovery_data = None
|
||||
with pytest.raises(
|
||||
SmartDeviceException, match=f"Timed out getting discovery response for {host}"
|
||||
):
|
||||
await Discover.discover_single(host, timeout=0.001)
|
||||
|
||||
|
||||
INVALIDS = [
|
||||
@ -75,9 +139,17 @@ INVALIDS = [
|
||||
@pytest.mark.parametrize("msg, data", INVALIDS)
|
||||
async def test_discover_invalid_info(msg, data, mocker):
|
||||
"""Make sure that invalid discovery information raises an exception."""
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=data)
|
||||
host = "127.0.0.1"
|
||||
|
||||
def mock_discover(self):
|
||||
self.datagram_received(
|
||||
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(data))[4:], (host, 9999)
|
||||
)
|
||||
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
||||
|
||||
with pytest.raises(SmartDeviceException, match=msg):
|
||||
await Discover.discover_single("127.0.0.1")
|
||||
await Discover.discover_single(host)
|
||||
|
||||
|
||||
async def test_discover_send(mocker):
|
||||
@ -87,7 +159,7 @@ async def test_discover_send(mocker):
|
||||
assert proto.target == ("255.255.255.255", 9999)
|
||||
transport = mocker.patch.object(proto, "transport")
|
||||
proto.do_discover()
|
||||
assert transport.sendto.call_count == proto.discovery_packets
|
||||
assert transport.sendto.call_count == proto.discovery_packets * 2
|
||||
|
||||
|
||||
async def test_discover_datagram_received(mocker, discovery_data):
|
||||
@ -98,10 +170,14 @@ async def test_discover_datagram_received(mocker, discovery_data):
|
||||
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
|
||||
|
||||
addr = "127.0.0.1"
|
||||
proto.datagram_received("<placeholder data>", (addr, 1234))
|
||||
proto.datagram_received("<placeholder data>", (addr, 9999))
|
||||
addr2 = "127.0.0.2"
|
||||
proto.datagram_received("<placeholder data>", (addr2, 20002))
|
||||
|
||||
# Check that device in discovered_devices is initialized correctly
|
||||
assert len(proto.discovered_devices) == 1
|
||||
# Check that unsupported device is 1
|
||||
assert len(proto.unsupported_devices) == 1
|
||||
dev = proto.discovered_devices[addr]
|
||||
assert issubclass(dev.__class__, SmartDevice)
|
||||
assert dev.host == addr
|
||||
@ -115,5 +191,5 @@ async def test_discover_invalid_responses(msg, data, mocker):
|
||||
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
|
||||
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
|
||||
|
||||
proto.datagram_received(data, ("127.0.0.1", 1234))
|
||||
proto.datagram_received(data, ("127.0.0.1", 9999))
|
||||
assert len(proto.discovered_devices) == 0
|
||||
|
Loading…
Reference in New Issue
Block a user