mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-10-12 10:28:01 +00:00
Remove sync interface, add asyncio discovery (#14)
* do not update inside __repr__ * Convert discovery to asyncio * Use asyncio.DatagramProtocol * Cleanup parameters, no more positional arguments Closes #7 * Remove sync interface * This requires #13 to be merged. Closes #12. * Converts cli to use asyncio.run() where needed. * The children from smartstrips is being initialized during the first update call. * Convert on and off commands to use asyncio.run * conftest: do the initial update automatically for the device, cleans up tests a bit * return subdevices alias for strip plugs, remove sync from docstrings * Make tests pass using pytest-asyncio * Simplify tests and use pytest-asyncio. * Removed the emeter tests for child devices, as this information do not seem to exist (based on the dummy sysinfo data). Can be added again if needed. * Remove sync from docstrings. * Fix incorrect type hint * Add type hints and some docstrings to discovery
This commit is contained in:
187
kasa/discover.py
187
kasa/discover.py
@@ -1,8 +1,9 @@
|
||||
"""Discovery module for TP-Link Smart Home devices."""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import socket
|
||||
from typing import Dict, Optional, Type
|
||||
from typing import Awaitable, Callable, Dict, Mapping, Type, Union, cast
|
||||
|
||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||
from kasa.smartbulb import SmartBulb
|
||||
@@ -13,6 +14,79 @@ from kasa.smartstrip import SmartStrip
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
|
||||
|
||||
|
||||
class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
"""Implementation of the discovery protocol handler.
|
||||
|
||||
This is internal class, use :func:Discover.discover: instead.
|
||||
"""
|
||||
|
||||
discovered_devices: Dict[str, SmartDevice]
|
||||
discovered_devices_raw: Dict[str, Dict]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
on_discovered: OnDiscoveredCallable = None,
|
||||
target: str = "255.255.255.255",
|
||||
timeout: int = 5,
|
||||
discovery_packets: int = 3,
|
||||
):
|
||||
self.transport = None
|
||||
self.tries = discovery_packets
|
||||
self.timeout = timeout
|
||||
self.on_discovered = on_discovered
|
||||
self.protocol = TPLinkSmartHomeProtocol()
|
||||
self.target = (target, Discover.DISCOVERY_PORT)
|
||||
self.discovered_devices = {}
|
||||
self.discovered_devices_raw = {}
|
||||
|
||||
def connection_made(self, transport) -> None:
|
||||
"""Set socket options for broadcasting."""
|
||||
self.transport = transport
|
||||
sock = transport.get_extra_info("socket")
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
self.do_discover()
|
||||
|
||||
def do_discover(self) -> None:
|
||||
"""Send number of discovery datagrams."""
|
||||
req = json.dumps(Discover.DISCOVERY_QUERY)
|
||||
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
||||
encrypted_req = self.protocol.encrypt(req)
|
||||
for i in range(self.tries):
|
||||
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
|
||||
|
||||
def datagram_received(self, data, addr) -> None:
|
||||
ip, port = addr
|
||||
if ip in self.discovered_devices:
|
||||
return
|
||||
|
||||
info = json.loads(self.protocol.decrypt(data))
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
|
||||
device_class = Discover._get_device_class(info)
|
||||
device = device_class(ip)
|
||||
|
||||
self.discovered_devices[ip] = device
|
||||
self.discovered_devices_raw[ip] = info
|
||||
|
||||
if device_class is not None:
|
||||
if self.on_discovered is not None:
|
||||
asyncio.ensure_future(self.on_discovered(device))
|
||||
else:
|
||||
_LOGGER.error("Received invalid response: %s", info)
|
||||
|
||||
def error_received(self, ex):
|
||||
_LOGGER.error("Got error: %s", ex)
|
||||
|
||||
def connection_lost(self, ex):
|
||||
pass
|
||||
|
||||
|
||||
class Discover:
|
||||
"""Discover TPLink Smart Home devices.
|
||||
|
||||
@@ -28,6 +102,8 @@ class Discover:
|
||||
The protocol uses UDP broadcast datagrams on port 9999 for discovery.
|
||||
"""
|
||||
|
||||
DISCOVERY_PORT = 9999
|
||||
|
||||
DISCOVERY_QUERY = {
|
||||
"system": {"get_sysinfo": None},
|
||||
"emeter": {"get_realtime": None},
|
||||
@@ -37,75 +113,65 @@ class Discover:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def discover(
|
||||
protocol: TPLinkSmartHomeProtocol = None,
|
||||
target: str = "255.255.255.255",
|
||||
port: int = 9999,
|
||||
timeout: int = 3,
|
||||
async def discover(
|
||||
*,
|
||||
target="255.255.255.255",
|
||||
on_discovered=None,
|
||||
timeout=5,
|
||||
discovery_packets=3,
|
||||
return_raw=False,
|
||||
) -> Dict[str, SmartDevice]:
|
||||
"""Discover devices.
|
||||
) -> Mapping[str, Union[SmartDevice, Dict]]:
|
||||
"""Discover supported devices.
|
||||
|
||||
Sends discovery message to 255.255.255.255:9999 in order
|
||||
to detect available supported devices in the local network,
|
||||
and waits for given timeout for answers from devices.
|
||||
|
||||
:param protocol: Protocol implementation to use
|
||||
If given, `on_discovered` coroutine will get passed with the SmartDevice as parameter.
|
||||
The results of the discovery can be accessed either via `discovered_devices` (SmartDevice-derived) or
|
||||
`discovered_devices_raw` (JSON objects).
|
||||
|
||||
:param target: The target broadcast address (e.g. 192.168.xxx.255).
|
||||
:param timeout: How long to wait for responses, defaults to 3
|
||||
:param port: port to send broadcast messages, defaults to 9999.
|
||||
:rtype: dict
|
||||
:return: Array of json objects {"ip", "port", "sys_info"}
|
||||
:param on_discovered:
|
||||
:param timeout: How long to wait for responses, defaults to 5
|
||||
:param discovery_packets: Number of discovery packets are broadcasted.
|
||||
:param return_raw: True to return JSON objects instead of Devices.
|
||||
:return:
|
||||
"""
|
||||
if protocol is None:
|
||||
protocol = TPLinkSmartHomeProtocol()
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.settimeout(timeout)
|
||||
|
||||
req = json.dumps(Discover.DISCOVERY_QUERY)
|
||||
_LOGGER.debug("Sending discovery to %s:%s", target, port)
|
||||
|
||||
encrypted_req = protocol.encrypt(req)
|
||||
for i in range(discovery_packets):
|
||||
sock.sendto(encrypted_req[4:], (target, port))
|
||||
|
||||
devices = {}
|
||||
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
|
||||
loop = asyncio.get_event_loop()
|
||||
transport, protocol = await loop.create_datagram_endpoint(
|
||||
lambda: _DiscoverProtocol(
|
||||
target=target,
|
||||
on_discovered=on_discovered,
|
||||
timeout=timeout,
|
||||
discovery_packets=discovery_packets,
|
||||
),
|
||||
local_addr=("0.0.0.0", 0),
|
||||
)
|
||||
protocol = cast(_DiscoverProtocol, protocol)
|
||||
|
||||
try:
|
||||
while True:
|
||||
data, addr = sock.recvfrom(4096)
|
||||
ip, port = addr
|
||||
info = json.loads(protocol.decrypt(data))
|
||||
device_class = Discover._get_device_class(info)
|
||||
if return_raw:
|
||||
devices[ip] = info
|
||||
elif device_class is not None:
|
||||
devices[ip] = device_class(ip)
|
||||
except socket.timeout:
|
||||
_LOGGER.debug("Got socket timeout, which is okay.")
|
||||
except Exception as ex:
|
||||
_LOGGER.error("Got exception %s", ex, exc_info=True)
|
||||
_LOGGER.debug("Found %s devices: %s", len(devices), devices)
|
||||
return devices
|
||||
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
|
||||
await asyncio.sleep(5)
|
||||
finally:
|
||||
transport.close()
|
||||
|
||||
_LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices))
|
||||
|
||||
if return_raw:
|
||||
return protocol.discovered_devices_raw
|
||||
|
||||
return protocol.discovered_devices
|
||||
|
||||
@staticmethod
|
||||
async def discover_single(
|
||||
host: str, protocol: TPLinkSmartHomeProtocol = None
|
||||
) -> Optional[SmartDevice]:
|
||||
async def discover_single(host: str) -> SmartDevice:
|
||||
"""Discover a single device by the given IP address.
|
||||
|
||||
:param host: Hostname of device to query
|
||||
:param protocol: Protocol implementation to use
|
||||
:rtype: SmartDevice
|
||||
:return: Object for querying/controlling found device.
|
||||
"""
|
||||
if protocol is None:
|
||||
protocol = TPLinkSmartHomeProtocol()
|
||||
protocol = TPLinkSmartHomeProtocol()
|
||||
|
||||
info = await protocol.query(host, Discover.DISCOVERY_QUERY)
|
||||
|
||||
@@ -113,10 +179,10 @@ class Discover:
|
||||
if device_class is not None:
|
||||
return device_class(host)
|
||||
|
||||
return None
|
||||
raise SmartDeviceException("Unable to discover device, received: %s" % info)
|
||||
|
||||
@staticmethod
|
||||
def _get_device_class(info: dict) -> Optional[Type[SmartDevice]]:
|
||||
def _get_device_class(info: dict) -> Type[SmartDevice]:
|
||||
"""Find SmartDevice subclass for device described by passed data."""
|
||||
if "system" in info and "get_sysinfo" in info["system"]:
|
||||
sysinfo = info["system"]["get_sysinfo"]
|
||||
@@ -136,4 +202,17 @@ class Discover:
|
||||
elif "smartbulb" in type_.lower():
|
||||
return SmartBulb
|
||||
|
||||
return None
|
||||
raise SmartDeviceException("Unknown device type: %s", type_)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async def _on_device(dev):
|
||||
await dev.update()
|
||||
_LOGGER.info("Got device: %s", dev)
|
||||
|
||||
devices = loop.run_until_complete(Discover.discover(on_discovered=_on_device))
|
||||
for ip, dev in devices.items():
|
||||
print(f"[{ip}] {dev}")
|
||||
|
Reference in New Issue
Block a user