Cleanup discovery & add tests (#212)

* Cleanup discovery & add tests

* discovered_devices_raw is not anymore available, as that can be accessed directly from the device objects
* test most of the discovery code paths
* some minor cleanups to test handling
* update discovery docs

* Move category check to be after the definitions

* skip a couple of tests requiring asyncmock not available on py37

* Remove return_raw usage from cli.discover
This commit is contained in:
Teemu R 2021-09-24 17:18:11 +02:00 committed by GitHub
parent bdb07a749c
commit acb221b1e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 114 additions and 72 deletions

View File

@ -143,13 +143,11 @@ async def discover(ctx, timeout, discover_only, dump_raw):
"""Discover devices in the network.""" """Discover devices in the network."""
target = ctx.parent.params["target"] target = ctx.parent.params["target"]
click.echo(f"Discovering devices on {target} for {timeout} seconds") click.echo(f"Discovering devices on {target} for {timeout} seconds")
found_devs = await Discover.discover( found_devs = await Discover.discover(target=target, timeout=timeout)
target=target, timeout=timeout, return_raw=dump_raw
)
if not discover_only: if not discover_only:
for ip, dev in found_devs.items(): for ip, dev in found_devs.items():
if dump_raw: if dump_raw:
click.echo(dev) click.echo(dev.sys_info)
continue continue
ctx.obj = dev ctx.obj = dev
await ctx.invoke(state) await ctx.invoke(state)

View File

@ -3,7 +3,7 @@ import asyncio
import json import json
import logging import logging
import socket import socket
from typing import Awaitable, Callable, Dict, Mapping, Optional, Type, Union, cast from typing import Awaitable, Callable, Dict, Optional, Type, cast
from kasa.protocol import TPLinkSmartHomeProtocol from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb from kasa.smartbulb import SmartBulb
@ -17,6 +17,7 @@ _LOGGER = logging.getLogger(__name__)
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]] OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
DeviceDict = Dict[str, SmartDevice]
class _DiscoverProtocol(asyncio.DatagramProtocol): class _DiscoverProtocol(asyncio.DatagramProtocol):
@ -25,8 +26,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
This is internal class, use :func:`Discover.discover`: instead. This is internal class, use :func:`Discover.discover`: instead.
""" """
discovered_devices: Dict[str, SmartDevice] discovered_devices: DeviceDict
discovered_devices_raw: Dict[str, Dict]
def __init__( def __init__(
self, self,
@ -43,7 +43,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.protocol = TPLinkSmartHomeProtocol() self.protocol = TPLinkSmartHomeProtocol()
self.target = (target, Discover.DISCOVERY_PORT) self.target = (target, Discover.DISCOVERY_PORT)
self.discovered_devices = {} self.discovered_devices = {}
self.discovered_devices_raw = {}
def connection_made(self, transport) -> None: def connection_made(self, transport) -> None:
"""Set socket options for broadcasting.""" """Set socket options for broadcasting."""
@ -80,13 +79,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
device.update_from_discover_info(info) device.update_from_discover_info(info)
self.discovered_devices[ip] = device self.discovered_devices[ip] = device
self.discovered_devices_raw[ip] = info
if device_class is not None: if self.on_discovered is not None:
if self.on_discovered is not None: asyncio.ensure_future(self.on_discovered(device))
asyncio.ensure_future(self.on_discovered(device))
else:
_LOGGER.error("Received invalid response: %s", info)
def error_received(self, ex): def error_received(self, ex):
"""Handle asyncio.Protocol errors.""" """Handle asyncio.Protocol errors."""
@ -144,9 +139,8 @@ class Discover:
on_discovered=None, on_discovered=None,
timeout=5, timeout=5,
discovery_packets=3, discovery_packets=3,
return_raw=False,
interface=None, interface=None,
) -> Mapping[str, Union[SmartDevice, Dict]]: ) -> DeviceDict:
"""Discover supported devices. """Discover supported devices.
Sends discovery message to 255.255.255.255:9999 in order Sends discovery message to 255.255.255.255:9999 in order
@ -154,17 +148,17 @@ class Discover:
and waits for given timeout for answers from devices. and waits for given timeout for answers from devices.
If you have multiple interfaces, you can use target parameter to specify the network for discovery. If you have multiple interfaces, you can use target parameter to specify the network for discovery.
If given, `on_discovered` coroutine will get passed with the :class:`SmartDevice`-derived object as parameter. If given, `on_discovered` coroutine will get awaited with a :class:`SmartDevice`-derived object as parameter.
The results of the discovery are returned either as a list of :class:`SmartDevice`-derived objects The results of the discovery are returned as a dict of :class:`SmartDevice`-derived objects keyed with IP addresses.
or as raw response dictionaries objects (if `return_raw` is True). The devices are already initialized and all but emeter-related properties can be accessed directly.
:param target: The target address where to send the broadcast discovery queries if multi-homing (e.g. 192.168.xxx.255). :param target: The target address where to send the broadcast discovery queries if multi-homing (e.g. 192.168.xxx.255).
:param on_discovered: coroutine to execute on discovery :param on_discovered: coroutine to execute on discovery
:param timeout: How long to wait for responses, defaults to 5 :param timeout: How long to wait for responses, defaults to 5
:param discovery_packets: Number of discovery packets are broadcasted. :param discovery_packets: Number of discovery packets to broadcast
:param return_raw: True to return JSON objects instead of Devices. :param interface: Bind to specific interface
:return: :return: dictionary with discovered devices
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
transport, protocol = await loop.create_datagram_endpoint( transport, protocol = await loop.create_datagram_endpoint(
@ -186,9 +180,6 @@ class Discover:
_LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices)) _LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices))
if return_raw:
return protocol.discovered_devices_raw
return protocol.discovered_devices return protocol.discovered_devices
@staticmethod @staticmethod
@ -204,12 +195,10 @@ class Discover:
info = await protocol.query(host, Discover.DISCOVERY_QUERY) info = await protocol.query(host, Discover.DISCOVERY_QUERY)
device_class = Discover._get_device_class(info) device_class = Discover._get_device_class(info)
if device_class is not None: dev = device_class(host)
dev = device_class(host) await dev.update()
await dev.update()
return dev
raise SmartDeviceException("Unable to discover device, received: %s" % info) return dev
@staticmethod @staticmethod
def _get_device_class(info: dict) -> Type[SmartDevice]: def _get_device_class(info: dict) -> Type[SmartDevice]:
@ -237,17 +226,4 @@ class Discover:
return SmartBulb return SmartBulb
raise SmartDeviceException("Unknown device type: %s", type_) raise SmartDeviceException("Unknown device type: %s" % type_)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
loop = asyncio.get_event_loop()
async def _on_device(dev):
await dev.update()
_LOGGER.info("Got device: %s", dev)
devices = loop.run_until_complete(Discover.discover(on_discovered=_on_device))
for ip, dev in devices.items():
print(f"[{ip}] {dev}")

View File

@ -53,8 +53,6 @@ def filter_model(desc, filter):
def parametrize(desc, devices, ids=None): def parametrize(desc, devices, ids=None):
# if ids is None:
# ids = ["on", "off"]
return pytest.mark.parametrize( return pytest.mark.parametrize(
"dev", filter_model(desc, devices), indirect=True, ids=ids "dev", filter_model(desc, devices), indirect=True, ids=ids
) )
@ -63,32 +61,11 @@ def parametrize(desc, devices, ids=None):
has_emeter = parametrize("has emeter", WITH_EMETER) has_emeter = parametrize("has emeter", WITH_EMETER)
no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER) no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER)
bulb = parametrize("bulbs", BULBS, ids=basename)
def name_for_filename(x): plug = parametrize("plugs", PLUGS, ids=basename)
from os.path import basename strip = parametrize("strips", STRIPS, ids=basename)
dimmer = parametrize("dimmers", DIMMERS, ids=basename)
return basename(x) lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename)
bulb = parametrize("bulbs", BULBS, ids=name_for_filename)
plug = parametrize("plugs", PLUGS, ids=name_for_filename)
strip = parametrize("strips", STRIPS, ids=name_for_filename)
dimmer = parametrize("dimmers", DIMMERS, ids=name_for_filename)
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=name_for_filename)
# This ensures that every single file inside fixtures/ is being placed in some category
categorized_fixtures = set(
dimmer.args[1] + strip.args[1] + plug.args[1] + bulb.args[1] + lightstrip.args[1]
)
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
if diff:
for file in diff:
print(
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
% file
)
raise Exception("Missing category for %s" % diff)
# bulb types # bulb types
dimmable = parametrize("dimmable", DIMMABLE) dimmable = parametrize("dimmable", DIMMABLE)
@ -98,6 +75,28 @@ non_variable_temp = parametrize("non-variable color temp", BULBS - VARIABLE_TEMP
color_bulb = parametrize("color bulbs", COLOR_BULBS) color_bulb = parametrize("color bulbs", COLOR_BULBS)
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS) non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)
def check_categories():
"""Check that every fixture file is categorized."""
categorized_fixtures = set(
dimmer.args[1]
+ strip.args[1]
+ plug.args[1]
+ bulb.args[1]
+ lightstrip.args[1]
)
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
if diff:
for file in diff:
print(
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
% file
)
raise Exception("Missing category for %s" % diff)
check_categories()
# Parametrize tests to run with device both on and off # Parametrize tests to run with device both on and off
turn_on = pytest.mark.parametrize("turn_on", [True, False]) turn_on = pytest.mark.parametrize("turn_on", [True, False])
@ -174,6 +173,18 @@ def dev(request):
return get_device_for_file(file) return get_device_for_file(file)
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
def discovery_data(request):
"""Return raw discovery file contents as JSON. Used for discovery tests."""
file = request.param
p = Path(file)
if not p.is_absolute():
p = Path(__file__).parent / "fixtures" / file
with open(p) as f:
return json.load(f)
def pytest_addoption(parser): def pytest_addoption(parser):
parser.addoption( parser.addoption(
"--ip", action="store", default=None, help="run against device on given ip" "--ip", action="store", default=None, help="run against device on given ip"

View File

@ -1,7 +1,10 @@
# type: ignore # type: ignore
import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException
from kasa.discover import _DiscoverProtocol
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
@ -47,3 +50,57 @@ async def test_type_unknown():
invalid_info = {"system": {"get_sysinfo": {"type": "nosuchtype"}}} invalid_info = {"system": {"get_sysinfo": {"type": "nosuchtype"}}}
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
Discover._get_device_class(invalid_info) Discover._get_device_class(invalid_info)
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
async def test_discover_single(discovery_data: dict, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
x = await Discover.discover_single("127.0.0.1")
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
INVALIDS = [
("No 'system' or 'get_sysinfo' in response", {"no": "data"}),
(
"Unable to find the device type field",
{"system": {"get_sysinfo": {"missing_type": 1}}},
),
("Unknown device type: foo", {"system": {"get_sysinfo": {"type": "foo"}}}),
]
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
@pytest.mark.parametrize("msg, data", INVALIDS)
async def test_discover_invalid_info(msg, data, mocker):
"""Make sure that invalid discovery information raises an exception."""
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=data)
with pytest.raises(SmartDeviceException, match=msg):
await Discover.discover_single("127.0.0.1")
async def test_discover_send(mocker):
"""Test discovery parameters."""
proto = _DiscoverProtocol()
assert proto.discovery_packets == 3
assert proto.target == ("255.255.255.255", 9999)
sendto = mocker.patch.object(proto, "transport")
proto.do_discover()
assert sendto.sendto.call_count == proto.discovery_packets
async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol()
mocker.patch("json.loads", return_value=discovery_data)
mocker.patch.object(proto, "protocol")
addr = "127.0.0.1"
proto.datagram_received("<placeholder data>", (addr, 1234))
# Check that device in discovered_devices is initialized correctly
assert len(proto.discovered_devices) == 1
dev = proto.discovered_devices[addr]
assert issubclass(dev.__class__, SmartDevice)
assert dev.host == addr