mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-08 22:07:06 +00:00
Cleanup discovery & add tests (#212)
* Cleanup discovery & add tests * discovered_devices_raw is not anymore available, as that can be accessed directly from the device objects * test most of the discovery code paths * some minor cleanups to test handling * update discovery docs * Move category check to be after the definitions * skip a couple of tests requiring asyncmock not available on py37 * Remove return_raw usage from cli.discover
This commit is contained in:
parent
bdb07a749c
commit
acb221b1e0
@ -143,13 +143,11 @@ async def discover(ctx, timeout, discover_only, dump_raw):
|
|||||||
"""Discover devices in the network."""
|
"""Discover devices in the network."""
|
||||||
target = ctx.parent.params["target"]
|
target = ctx.parent.params["target"]
|
||||||
click.echo(f"Discovering devices on {target} for {timeout} seconds")
|
click.echo(f"Discovering devices on {target} for {timeout} seconds")
|
||||||
found_devs = await Discover.discover(
|
found_devs = await Discover.discover(target=target, timeout=timeout)
|
||||||
target=target, timeout=timeout, return_raw=dump_raw
|
|
||||||
)
|
|
||||||
if not discover_only:
|
if not discover_only:
|
||||||
for ip, dev in found_devs.items():
|
for ip, dev in found_devs.items():
|
||||||
if dump_raw:
|
if dump_raw:
|
||||||
click.echo(dev)
|
click.echo(dev.sys_info)
|
||||||
continue
|
continue
|
||||||
ctx.obj = dev
|
ctx.obj = dev
|
||||||
await ctx.invoke(state)
|
await ctx.invoke(state)
|
||||||
|
@ -3,7 +3,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from typing import Awaitable, Callable, Dict, Mapping, Optional, Type, Union, cast
|
from typing import Awaitable, Callable, Dict, Optional, Type, cast
|
||||||
|
|
||||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||||
from kasa.smartbulb import SmartBulb
|
from kasa.smartbulb import SmartBulb
|
||||||
@ -17,6 +17,7 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
|
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
|
||||||
|
DeviceDict = Dict[str, SmartDevice]
|
||||||
|
|
||||||
|
|
||||||
class _DiscoverProtocol(asyncio.DatagramProtocol):
|
class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||||
@ -25,8 +26,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
This is internal class, use :func:`Discover.discover`: instead.
|
This is internal class, use :func:`Discover.discover`: instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
discovered_devices: Dict[str, SmartDevice]
|
discovered_devices: DeviceDict
|
||||||
discovered_devices_raw: Dict[str, Dict]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -43,7 +43,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
self.protocol = TPLinkSmartHomeProtocol()
|
self.protocol = TPLinkSmartHomeProtocol()
|
||||||
self.target = (target, Discover.DISCOVERY_PORT)
|
self.target = (target, Discover.DISCOVERY_PORT)
|
||||||
self.discovered_devices = {}
|
self.discovered_devices = {}
|
||||||
self.discovered_devices_raw = {}
|
|
||||||
|
|
||||||
def connection_made(self, transport) -> None:
|
def connection_made(self, transport) -> None:
|
||||||
"""Set socket options for broadcasting."""
|
"""Set socket options for broadcasting."""
|
||||||
@ -80,13 +79,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
device.update_from_discover_info(info)
|
device.update_from_discover_info(info)
|
||||||
|
|
||||||
self.discovered_devices[ip] = device
|
self.discovered_devices[ip] = device
|
||||||
self.discovered_devices_raw[ip] = info
|
|
||||||
|
|
||||||
if device_class is not None:
|
if self.on_discovered is not None:
|
||||||
if self.on_discovered is not None:
|
asyncio.ensure_future(self.on_discovered(device))
|
||||||
asyncio.ensure_future(self.on_discovered(device))
|
|
||||||
else:
|
|
||||||
_LOGGER.error("Received invalid response: %s", info)
|
|
||||||
|
|
||||||
def error_received(self, ex):
|
def error_received(self, ex):
|
||||||
"""Handle asyncio.Protocol errors."""
|
"""Handle asyncio.Protocol errors."""
|
||||||
@ -144,9 +139,8 @@ class Discover:
|
|||||||
on_discovered=None,
|
on_discovered=None,
|
||||||
timeout=5,
|
timeout=5,
|
||||||
discovery_packets=3,
|
discovery_packets=3,
|
||||||
return_raw=False,
|
|
||||||
interface=None,
|
interface=None,
|
||||||
) -> Mapping[str, Union[SmartDevice, Dict]]:
|
) -> DeviceDict:
|
||||||
"""Discover supported devices.
|
"""Discover supported devices.
|
||||||
|
|
||||||
Sends discovery message to 255.255.255.255:9999 in order
|
Sends discovery message to 255.255.255.255:9999 in order
|
||||||
@ -154,17 +148,17 @@ class Discover:
|
|||||||
and waits for given timeout for answers from devices.
|
and waits for given timeout for answers from devices.
|
||||||
If you have multiple interfaces, you can use target parameter to specify the network for discovery.
|
If you have multiple interfaces, you can use target parameter to specify the network for discovery.
|
||||||
|
|
||||||
If given, `on_discovered` coroutine will get passed with the :class:`SmartDevice`-derived object as parameter.
|
If given, `on_discovered` coroutine will get awaited with a :class:`SmartDevice`-derived object as parameter.
|
||||||
|
|
||||||
The results of the discovery are returned either as a list of :class:`SmartDevice`-derived objects
|
The results of the discovery are returned as a dict of :class:`SmartDevice`-derived objects keyed with IP addresses.
|
||||||
or as raw response dictionaries objects (if `return_raw` is True).
|
The devices are already initialized and all but emeter-related properties can be accessed directly.
|
||||||
|
|
||||||
:param target: The target address where to send the broadcast discovery queries if multi-homing (e.g. 192.168.xxx.255).
|
:param target: The target address where to send the broadcast discovery queries if multi-homing (e.g. 192.168.xxx.255).
|
||||||
:param on_discovered: coroutine to execute on discovery
|
:param on_discovered: coroutine to execute on discovery
|
||||||
:param timeout: How long to wait for responses, defaults to 5
|
:param timeout: How long to wait for responses, defaults to 5
|
||||||
:param discovery_packets: Number of discovery packets are broadcasted.
|
:param discovery_packets: Number of discovery packets to broadcast
|
||||||
:param return_raw: True to return JSON objects instead of Devices.
|
:param interface: Bind to specific interface
|
||||||
:return:
|
:return: dictionary with discovered devices
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
transport, protocol = await loop.create_datagram_endpoint(
|
transport, protocol = await loop.create_datagram_endpoint(
|
||||||
@ -186,9 +180,6 @@ class Discover:
|
|||||||
|
|
||||||
_LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices))
|
_LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices))
|
||||||
|
|
||||||
if return_raw:
|
|
||||||
return protocol.discovered_devices_raw
|
|
||||||
|
|
||||||
return protocol.discovered_devices
|
return protocol.discovered_devices
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -204,12 +195,10 @@ class Discover:
|
|||||||
info = await protocol.query(host, Discover.DISCOVERY_QUERY)
|
info = await protocol.query(host, Discover.DISCOVERY_QUERY)
|
||||||
|
|
||||||
device_class = Discover._get_device_class(info)
|
device_class = Discover._get_device_class(info)
|
||||||
if device_class is not None:
|
dev = device_class(host)
|
||||||
dev = device_class(host)
|
await dev.update()
|
||||||
await dev.update()
|
|
||||||
return dev
|
|
||||||
|
|
||||||
raise SmartDeviceException("Unable to discover device, received: %s" % info)
|
return dev
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_device_class(info: dict) -> Type[SmartDevice]:
|
def _get_device_class(info: dict) -> Type[SmartDevice]:
|
||||||
@ -237,17 +226,4 @@ class Discover:
|
|||||||
|
|
||||||
return SmartBulb
|
return SmartBulb
|
||||||
|
|
||||||
raise SmartDeviceException("Unknown device type: %s", type_)
|
raise SmartDeviceException("Unknown device type: %s" % type_)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
async def _on_device(dev):
|
|
||||||
await dev.update()
|
|
||||||
_LOGGER.info("Got device: %s", dev)
|
|
||||||
|
|
||||||
devices = loop.run_until_complete(Discover.discover(on_discovered=_on_device))
|
|
||||||
for ip, dev in devices.items():
|
|
||||||
print(f"[{ip}] {dev}")
|
|
||||||
|
@ -53,8 +53,6 @@ def filter_model(desc, filter):
|
|||||||
|
|
||||||
|
|
||||||
def parametrize(desc, devices, ids=None):
|
def parametrize(desc, devices, ids=None):
|
||||||
# if ids is None:
|
|
||||||
# ids = ["on", "off"]
|
|
||||||
return pytest.mark.parametrize(
|
return pytest.mark.parametrize(
|
||||||
"dev", filter_model(desc, devices), indirect=True, ids=ids
|
"dev", filter_model(desc, devices), indirect=True, ids=ids
|
||||||
)
|
)
|
||||||
@ -63,32 +61,11 @@ def parametrize(desc, devices, ids=None):
|
|||||||
has_emeter = parametrize("has emeter", WITH_EMETER)
|
has_emeter = parametrize("has emeter", WITH_EMETER)
|
||||||
no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER)
|
no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER)
|
||||||
|
|
||||||
|
bulb = parametrize("bulbs", BULBS, ids=basename)
|
||||||
def name_for_filename(x):
|
plug = parametrize("plugs", PLUGS, ids=basename)
|
||||||
from os.path import basename
|
strip = parametrize("strips", STRIPS, ids=basename)
|
||||||
|
dimmer = parametrize("dimmers", DIMMERS, ids=basename)
|
||||||
return basename(x)
|
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename)
|
||||||
|
|
||||||
|
|
||||||
bulb = parametrize("bulbs", BULBS, ids=name_for_filename)
|
|
||||||
plug = parametrize("plugs", PLUGS, ids=name_for_filename)
|
|
||||||
strip = parametrize("strips", STRIPS, ids=name_for_filename)
|
|
||||||
dimmer = parametrize("dimmers", DIMMERS, ids=name_for_filename)
|
|
||||||
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=name_for_filename)
|
|
||||||
|
|
||||||
# This ensures that every single file inside fixtures/ is being placed in some category
|
|
||||||
categorized_fixtures = set(
|
|
||||||
dimmer.args[1] + strip.args[1] + plug.args[1] + bulb.args[1] + lightstrip.args[1]
|
|
||||||
)
|
|
||||||
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
|
|
||||||
if diff:
|
|
||||||
for file in diff:
|
|
||||||
print(
|
|
||||||
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
|
|
||||||
% file
|
|
||||||
)
|
|
||||||
raise Exception("Missing category for %s" % diff)
|
|
||||||
|
|
||||||
|
|
||||||
# bulb types
|
# bulb types
|
||||||
dimmable = parametrize("dimmable", DIMMABLE)
|
dimmable = parametrize("dimmable", DIMMABLE)
|
||||||
@ -98,6 +75,28 @@ non_variable_temp = parametrize("non-variable color temp", BULBS - VARIABLE_TEMP
|
|||||||
color_bulb = parametrize("color bulbs", COLOR_BULBS)
|
color_bulb = parametrize("color bulbs", COLOR_BULBS)
|
||||||
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)
|
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)
|
||||||
|
|
||||||
|
|
||||||
|
def check_categories():
|
||||||
|
"""Check that every fixture file is categorized."""
|
||||||
|
categorized_fixtures = set(
|
||||||
|
dimmer.args[1]
|
||||||
|
+ strip.args[1]
|
||||||
|
+ plug.args[1]
|
||||||
|
+ bulb.args[1]
|
||||||
|
+ lightstrip.args[1]
|
||||||
|
)
|
||||||
|
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
|
||||||
|
if diff:
|
||||||
|
for file in diff:
|
||||||
|
print(
|
||||||
|
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
|
||||||
|
% file
|
||||||
|
)
|
||||||
|
raise Exception("Missing category for %s" % diff)
|
||||||
|
|
||||||
|
|
||||||
|
check_categories()
|
||||||
|
|
||||||
# Parametrize tests to run with device both on and off
|
# Parametrize tests to run with device both on and off
|
||||||
turn_on = pytest.mark.parametrize("turn_on", [True, False])
|
turn_on = pytest.mark.parametrize("turn_on", [True, False])
|
||||||
|
|
||||||
@ -174,6 +173,18 @@ def dev(request):
|
|||||||
return get_device_for_file(file)
|
return get_device_for_file(file)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
|
||||||
|
def discovery_data(request):
|
||||||
|
"""Return raw discovery file contents as JSON. Used for discovery tests."""
|
||||||
|
file = request.param
|
||||||
|
p = Path(file)
|
||||||
|
if not p.is_absolute():
|
||||||
|
p = Path(__file__).parent / "fixtures" / file
|
||||||
|
|
||||||
|
with open(p) as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--ip", action="store", default=None, help="run against device on given ip"
|
"--ip", action="store", default=None, help="run against device on given ip"
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
# type: ignore
|
# type: ignore
|
||||||
|
import sys
|
||||||
|
|
||||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||||
|
|
||||||
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException
|
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException
|
||||||
|
from kasa.discover import _DiscoverProtocol
|
||||||
|
|
||||||
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
|
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
|
||||||
|
|
||||||
@ -47,3 +50,57 @@ async def test_type_unknown():
|
|||||||
invalid_info = {"system": {"get_sysinfo": {"type": "nosuchtype"}}}
|
invalid_info = {"system": {"get_sysinfo": {"type": "nosuchtype"}}}
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
Discover._get_device_class(invalid_info)
|
Discover._get_device_class(invalid_info)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
|
||||||
|
async def test_discover_single(discovery_data: dict, mocker):
|
||||||
|
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||||
|
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||||
|
x = await Discover.discover_single("127.0.0.1")
|
||||||
|
assert issubclass(x.__class__, SmartDevice)
|
||||||
|
assert x._sys_info is not None
|
||||||
|
|
||||||
|
|
||||||
|
INVALIDS = [
|
||||||
|
("No 'system' or 'get_sysinfo' in response", {"no": "data"}),
|
||||||
|
(
|
||||||
|
"Unable to find the device type field",
|
||||||
|
{"system": {"get_sysinfo": {"missing_type": 1}}},
|
||||||
|
),
|
||||||
|
("Unknown device type: foo", {"system": {"get_sysinfo": {"type": "foo"}}}),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
|
||||||
|
@pytest.mark.parametrize("msg, data", INVALIDS)
|
||||||
|
async def test_discover_invalid_info(msg, data, mocker):
|
||||||
|
"""Make sure that invalid discovery information raises an exception."""
|
||||||
|
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=data)
|
||||||
|
with pytest.raises(SmartDeviceException, match=msg):
|
||||||
|
await Discover.discover_single("127.0.0.1")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_discover_send(mocker):
|
||||||
|
"""Test discovery parameters."""
|
||||||
|
proto = _DiscoverProtocol()
|
||||||
|
assert proto.discovery_packets == 3
|
||||||
|
assert proto.target == ("255.255.255.255", 9999)
|
||||||
|
sendto = mocker.patch.object(proto, "transport")
|
||||||
|
proto.do_discover()
|
||||||
|
assert sendto.sendto.call_count == proto.discovery_packets
|
||||||
|
|
||||||
|
|
||||||
|
async def test_discover_datagram_received(mocker, discovery_data):
|
||||||
|
"""Verify that datagram received fills discovered_devices."""
|
||||||
|
proto = _DiscoverProtocol()
|
||||||
|
mocker.patch("json.loads", return_value=discovery_data)
|
||||||
|
mocker.patch.object(proto, "protocol")
|
||||||
|
|
||||||
|
addr = "127.0.0.1"
|
||||||
|
proto.datagram_received("<placeholder data>", (addr, 1234))
|
||||||
|
|
||||||
|
# Check that device in discovered_devices is initialized correctly
|
||||||
|
assert len(proto.discovered_devices) == 1
|
||||||
|
dev = proto.discovered_devices[addr]
|
||||||
|
assert issubclass(dev.__class__, SmartDevice)
|
||||||
|
assert dev.host == addr
|
||||||
|
Loading…
Reference in New Issue
Block a user