Remove sync interface, add asyncio discovery (#14)

* do not update inside __repr__

* Convert discovery to asyncio

* Use asyncio.DatagramProtocol
* Cleanup parameters, no more positional arguments

Closes #7

* Remove sync interface

* This requires #13 to be merged. Closes #12.
* Converts cli to use asyncio.run() where needed.
* The children from smartstrips is being initialized during the first update call.

* Convert on and off commands to use asyncio.run

* conftest: do the initial update automatically for the device, cleans up tests a bit

* return subdevices alias for strip plugs, remove sync from docstrings

* Make tests pass using pytest-asyncio

* Simplify tests and use pytest-asyncio.
* Removed the emeter tests for child devices, as this information do not seem to exist (based on the dummy sysinfo data). Can be added again if needed.
* Remove sync from docstrings.

* Fix incorrect type hint

* Add type hints and some docstrings to discovery
This commit is contained in:
Teemu R
2020-01-12 22:44:19 +01:00
committed by GitHub
parent 3c68d295da
commit 524d28abbc
9 changed files with 386 additions and 341 deletions

View File

@@ -1,8 +1,9 @@
"""Discovery module for TP-Link Smart Home devices."""
import asyncio
import json
import logging
import socket
from typing import Dict, Optional, Type
from typing import Awaitable, Callable, Dict, Mapping, Type, Union, cast
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb
@@ -13,6 +14,79 @@ from kasa.smartstrip import SmartStrip
_LOGGER = logging.getLogger(__name__)
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler.
This is internal class, use :func:Discover.discover: instead.
"""
discovered_devices: Dict[str, SmartDevice]
discovered_devices_raw: Dict[str, Dict]
def __init__(
self,
*,
on_discovered: OnDiscoveredCallable = None,
target: str = "255.255.255.255",
timeout: int = 5,
discovery_packets: int = 3,
):
self.transport = None
self.tries = discovery_packets
self.timeout = timeout
self.on_discovered = on_discovered
self.protocol = TPLinkSmartHomeProtocol()
self.target = (target, Discover.DISCOVERY_PORT)
self.discovered_devices = {}
self.discovered_devices_raw = {}
def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
self.transport = transport
sock = transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.do_discover()
def do_discover(self) -> None:
"""Send number of discovery datagrams."""
req = json.dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = self.protocol.encrypt(req)
for i in range(self.tries):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
def datagram_received(self, data, addr) -> None:
ip, port = addr
if ip in self.discovered_devices:
return
info = json.loads(self.protocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
device_class = Discover._get_device_class(info)
device = device_class(ip)
self.discovered_devices[ip] = device
self.discovered_devices_raw[ip] = info
if device_class is not None:
if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
else:
_LOGGER.error("Received invalid response: %s", info)
def error_received(self, ex):
_LOGGER.error("Got error: %s", ex)
def connection_lost(self, ex):
pass
class Discover:
"""Discover TPLink Smart Home devices.
@@ -28,6 +102,8 @@ class Discover:
The protocol uses UDP broadcast datagrams on port 9999 for discovery.
"""
DISCOVERY_PORT = 9999
DISCOVERY_QUERY = {
"system": {"get_sysinfo": None},
"emeter": {"get_realtime": None},
@@ -37,75 +113,65 @@ class Discover:
}
@staticmethod
def discover(
protocol: TPLinkSmartHomeProtocol = None,
target: str = "255.255.255.255",
port: int = 9999,
timeout: int = 3,
async def discover(
*,
target="255.255.255.255",
on_discovered=None,
timeout=5,
discovery_packets=3,
return_raw=False,
) -> Dict[str, SmartDevice]:
"""Discover devices.
) -> Mapping[str, Union[SmartDevice, Dict]]:
"""Discover supported devices.
Sends discovery message to 255.255.255.255:9999 in order
to detect available supported devices in the local network,
and waits for given timeout for answers from devices.
:param protocol: Protocol implementation to use
If given, `on_discovered` coroutine will get passed with the SmartDevice as parameter.
The results of the discovery can be accessed either via `discovered_devices` (SmartDevice-derived) or
`discovered_devices_raw` (JSON objects).
:param target: The target broadcast address (e.g. 192.168.xxx.255).
:param timeout: How long to wait for responses, defaults to 3
:param port: port to send broadcast messages, defaults to 9999.
:rtype: dict
:return: Array of json objects {"ip", "port", "sys_info"}
:param on_discovered:
:param timeout: How long to wait for responses, defaults to 5
:param discovery_packets: Number of discovery packets are broadcasted.
:param return_raw: True to return JSON objects instead of Devices.
:return:
"""
if protocol is None:
protocol = TPLinkSmartHomeProtocol()
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.settimeout(timeout)
req = json.dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("Sending discovery to %s:%s", target, port)
encrypted_req = protocol.encrypt(req)
for i in range(discovery_packets):
sock.sendto(encrypted_req[4:], (target, port))
devices = {}
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
loop = asyncio.get_event_loop()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(
target=target,
on_discovered=on_discovered,
timeout=timeout,
discovery_packets=discovery_packets,
),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)
try:
while True:
data, addr = sock.recvfrom(4096)
ip, port = addr
info = json.loads(protocol.decrypt(data))
device_class = Discover._get_device_class(info)
if return_raw:
devices[ip] = info
elif device_class is not None:
devices[ip] = device_class(ip)
except socket.timeout:
_LOGGER.debug("Got socket timeout, which is okay.")
except Exception as ex:
_LOGGER.error("Got exception %s", ex, exc_info=True)
_LOGGER.debug("Found %s devices: %s", len(devices), devices)
return devices
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
await asyncio.sleep(5)
finally:
transport.close()
_LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices))
if return_raw:
return protocol.discovered_devices_raw
return protocol.discovered_devices
@staticmethod
async def discover_single(
host: str, protocol: TPLinkSmartHomeProtocol = None
) -> Optional[SmartDevice]:
async def discover_single(host: str) -> SmartDevice:
"""Discover a single device by the given IP address.
:param host: Hostname of device to query
:param protocol: Protocol implementation to use
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
if protocol is None:
protocol = TPLinkSmartHomeProtocol()
protocol = TPLinkSmartHomeProtocol()
info = await protocol.query(host, Discover.DISCOVERY_QUERY)
@@ -113,10 +179,10 @@ class Discover:
if device_class is not None:
return device_class(host)
return None
raise SmartDeviceException("Unable to discover device, received: %s" % info)
@staticmethod
def _get_device_class(info: dict) -> Optional[Type[SmartDevice]]:
def _get_device_class(info: dict) -> Type[SmartDevice]:
"""Find SmartDevice subclass for device described by passed data."""
if "system" in info and "get_sysinfo" in info["system"]:
sysinfo = info["system"]["get_sysinfo"]
@@ -136,4 +202,17 @@ class Discover:
elif "smartbulb" in type_.lower():
return SmartBulb
return None
raise SmartDeviceException("Unknown device type: %s", type_)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
loop = asyncio.get_event_loop()
async def _on_device(dev):
await dev.update()
_LOGGER.info("Got device: %s", dev)
devices = loop.run_until_complete(Discover.discover(on_discovered=_on_device))
for ip, dev in devices.items():
print(f"[{ip}] {dev}")