Cleanup discovery & add tests (#212)

* Cleanup discovery & add tests

* discovered_devices_raw is not anymore available, as that can be accessed directly from the device objects
* test most of the discovery code paths
* some minor cleanups to test handling
* update discovery docs

* Move category check to be after the definitions

* skip a couple of tests requiring asyncmock not available on py37

* Remove return_raw usage from cli.discover
This commit is contained in:
Teemu R 2021-09-24 17:18:11 +02:00 committed by GitHub
parent bdb07a749c
commit acb221b1e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 114 additions and 72 deletions

View File

@ -143,13 +143,11 @@ async def discover(ctx, timeout, discover_only, dump_raw):
"""Discover devices in the network."""
target = ctx.parent.params["target"]
click.echo(f"Discovering devices on {target} for {timeout} seconds")
found_devs = await Discover.discover(
target=target, timeout=timeout, return_raw=dump_raw
)
found_devs = await Discover.discover(target=target, timeout=timeout)
if not discover_only:
for ip, dev in found_devs.items():
if dump_raw:
click.echo(dev)
click.echo(dev.sys_info)
continue
ctx.obj = dev
await ctx.invoke(state)

View File

@ -3,7 +3,7 @@ import asyncio
import json
import logging
import socket
from typing import Awaitable, Callable, Dict, Mapping, Optional, Type, Union, cast
from typing import Awaitable, Callable, Dict, Optional, Type, cast
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb
@ -17,6 +17,7 @@ _LOGGER = logging.getLogger(__name__)
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
DeviceDict = Dict[str, SmartDevice]
class _DiscoverProtocol(asyncio.DatagramProtocol):
@ -25,8 +26,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
This is internal class, use :func:`Discover.discover`: instead.
"""
discovered_devices: Dict[str, SmartDevice]
discovered_devices_raw: Dict[str, Dict]
discovered_devices: DeviceDict
def __init__(
self,
@ -43,7 +43,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.protocol = TPLinkSmartHomeProtocol()
self.target = (target, Discover.DISCOVERY_PORT)
self.discovered_devices = {}
self.discovered_devices_raw = {}
def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
@ -80,13 +79,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
device.update_from_discover_info(info)
self.discovered_devices[ip] = device
self.discovered_devices_raw[ip] = info
if device_class is not None:
if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
else:
_LOGGER.error("Received invalid response: %s", info)
if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
def error_received(self, ex):
"""Handle asyncio.Protocol errors."""
@ -144,9 +139,8 @@ class Discover:
on_discovered=None,
timeout=5,
discovery_packets=3,
return_raw=False,
interface=None,
) -> Mapping[str, Union[SmartDevice, Dict]]:
) -> DeviceDict:
"""Discover supported devices.
Sends discovery message to 255.255.255.255:9999 in order
@ -154,17 +148,17 @@ class Discover:
and waits for given timeout for answers from devices.
If you have multiple interfaces, you can use target parameter to specify the network for discovery.
If given, `on_discovered` coroutine will get passed with the :class:`SmartDevice`-derived object as parameter.
If given, `on_discovered` coroutine will get awaited with a :class:`SmartDevice`-derived object as parameter.
The results of the discovery are returned either as a list of :class:`SmartDevice`-derived objects
or as raw response dictionaries objects (if `return_raw` is True).
The results of the discovery are returned as a dict of :class:`SmartDevice`-derived objects keyed with IP addresses.
The devices are already initialized and all but emeter-related properties can be accessed directly.
:param target: The target address where to send the broadcast discovery queries if multi-homing (e.g. 192.168.xxx.255).
:param on_discovered: coroutine to execute on discovery
:param timeout: How long to wait for responses, defaults to 5
:param discovery_packets: Number of discovery packets are broadcasted.
:param return_raw: True to return JSON objects instead of Devices.
:return:
:param discovery_packets: Number of discovery packets to broadcast
:param interface: Bind to specific interface
:return: dictionary with discovered devices
"""
loop = asyncio.get_event_loop()
transport, protocol = await loop.create_datagram_endpoint(
@ -186,9 +180,6 @@ class Discover:
_LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices))
if return_raw:
return protocol.discovered_devices_raw
return protocol.discovered_devices
@staticmethod
@ -204,12 +195,10 @@ class Discover:
info = await protocol.query(host, Discover.DISCOVERY_QUERY)
device_class = Discover._get_device_class(info)
if device_class is not None:
dev = device_class(host)
await dev.update()
return dev
dev = device_class(host)
await dev.update()
raise SmartDeviceException("Unable to discover device, received: %s" % info)
return dev
@staticmethod
def _get_device_class(info: dict) -> Type[SmartDevice]:
@ -237,17 +226,4 @@ class Discover:
return SmartBulb
raise SmartDeviceException("Unknown device type: %s", type_)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
loop = asyncio.get_event_loop()
async def _on_device(dev):
await dev.update()
_LOGGER.info("Got device: %s", dev)
devices = loop.run_until_complete(Discover.discover(on_discovered=_on_device))
for ip, dev in devices.items():
print(f"[{ip}] {dev}")
raise SmartDeviceException("Unknown device type: %s" % type_)

View File

@ -53,8 +53,6 @@ def filter_model(desc, filter):
def parametrize(desc, devices, ids=None):
# if ids is None:
# ids = ["on", "off"]
return pytest.mark.parametrize(
"dev", filter_model(desc, devices), indirect=True, ids=ids
)
@ -63,32 +61,11 @@ def parametrize(desc, devices, ids=None):
has_emeter = parametrize("has emeter", WITH_EMETER)
no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER)
def name_for_filename(x):
from os.path import basename
return basename(x)
bulb = parametrize("bulbs", BULBS, ids=name_for_filename)
plug = parametrize("plugs", PLUGS, ids=name_for_filename)
strip = parametrize("strips", STRIPS, ids=name_for_filename)
dimmer = parametrize("dimmers", DIMMERS, ids=name_for_filename)
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=name_for_filename)
# This ensures that every single file inside fixtures/ is being placed in some category
categorized_fixtures = set(
dimmer.args[1] + strip.args[1] + plug.args[1] + bulb.args[1] + lightstrip.args[1]
)
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
if diff:
for file in diff:
print(
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
% file
)
raise Exception("Missing category for %s" % diff)
bulb = parametrize("bulbs", BULBS, ids=basename)
plug = parametrize("plugs", PLUGS, ids=basename)
strip = parametrize("strips", STRIPS, ids=basename)
dimmer = parametrize("dimmers", DIMMERS, ids=basename)
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename)
# bulb types
dimmable = parametrize("dimmable", DIMMABLE)
@ -98,6 +75,28 @@ non_variable_temp = parametrize("non-variable color temp", BULBS - VARIABLE_TEMP
color_bulb = parametrize("color bulbs", COLOR_BULBS)
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)
def check_categories():
"""Check that every fixture file is categorized."""
categorized_fixtures = set(
dimmer.args[1]
+ strip.args[1]
+ plug.args[1]
+ bulb.args[1]
+ lightstrip.args[1]
)
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
if diff:
for file in diff:
print(
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
% file
)
raise Exception("Missing category for %s" % diff)
check_categories()
# Parametrize tests to run with device both on and off
turn_on = pytest.mark.parametrize("turn_on", [True, False])
@ -174,6 +173,18 @@ def dev(request):
return get_device_for_file(file)
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
def discovery_data(request):
"""Return raw discovery file contents as JSON. Used for discovery tests."""
file = request.param
p = Path(file)
if not p.is_absolute():
p = Path(__file__).parent / "fixtures" / file
with open(p) as f:
return json.load(f)
def pytest_addoption(parser):
parser.addoption(
"--ip", action="store", default=None, help="run against device on given ip"

View File

@ -1,7 +1,10 @@
# type: ignore
import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException
from kasa.discover import _DiscoverProtocol
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
@ -47,3 +50,57 @@ async def test_type_unknown():
invalid_info = {"system": {"get_sysinfo": {"type": "nosuchtype"}}}
with pytest.raises(SmartDeviceException):
Discover._get_device_class(invalid_info)
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
async def test_discover_single(discovery_data: dict, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
x = await Discover.discover_single("127.0.0.1")
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
INVALIDS = [
("No 'system' or 'get_sysinfo' in response", {"no": "data"}),
(
"Unable to find the device type field",
{"system": {"get_sysinfo": {"missing_type": 1}}},
),
("Unknown device type: foo", {"system": {"get_sysinfo": {"type": "foo"}}}),
]
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
@pytest.mark.parametrize("msg, data", INVALIDS)
async def test_discover_invalid_info(msg, data, mocker):
"""Make sure that invalid discovery information raises an exception."""
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=data)
with pytest.raises(SmartDeviceException, match=msg):
await Discover.discover_single("127.0.0.1")
async def test_discover_send(mocker):
"""Test discovery parameters."""
proto = _DiscoverProtocol()
assert proto.discovery_packets == 3
assert proto.target == ("255.255.255.255", 9999)
sendto = mocker.patch.object(proto, "transport")
proto.do_discover()
assert sendto.sendto.call_count == proto.discovery_packets
async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol()
mocker.patch("json.loads", return_value=discovery_data)
mocker.patch.object(proto, "protocol")
addr = "127.0.0.1"
proto.datagram_received("<placeholder data>", (addr, 1234))
# Check that device in discovered_devices is initialized correctly
assert len(proto.discovered_devices) == 1
dev = proto.discovered_devices[addr]
assert issubclass(dev.__class__, SmartDevice)
assert dev.host == addr