mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-11-04 06:32:07 +00:00 
			
		
		
		
	Cleanup discovery & add tests (#212)
* Cleanup discovery & add tests * discovered_devices_raw is not anymore available, as that can be accessed directly from the device objects * test most of the discovery code paths * some minor cleanups to test handling * update discovery docs * Move category check to be after the definitions * skip a couple of tests requiring asyncmock not available on py37 * Remove return_raw usage from cli.discover
This commit is contained in:
		@@ -143,13 +143,11 @@ async def discover(ctx, timeout, discover_only, dump_raw):
 | 
			
		||||
    """Discover devices in the network."""
 | 
			
		||||
    target = ctx.parent.params["target"]
 | 
			
		||||
    click.echo(f"Discovering devices on {target} for {timeout} seconds")
 | 
			
		||||
    found_devs = await Discover.discover(
 | 
			
		||||
        target=target, timeout=timeout, return_raw=dump_raw
 | 
			
		||||
    )
 | 
			
		||||
    found_devs = await Discover.discover(target=target, timeout=timeout)
 | 
			
		||||
    if not discover_only:
 | 
			
		||||
        for ip, dev in found_devs.items():
 | 
			
		||||
            if dump_raw:
 | 
			
		||||
                click.echo(dev)
 | 
			
		||||
                click.echo(dev.sys_info)
 | 
			
		||||
                continue
 | 
			
		||||
            ctx.obj = dev
 | 
			
		||||
            await ctx.invoke(state)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,7 @@ import asyncio
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
import socket
 | 
			
		||||
from typing import Awaitable, Callable, Dict, Mapping, Optional, Type, Union, cast
 | 
			
		||||
from typing import Awaitable, Callable, Dict, Optional, Type, cast
 | 
			
		||||
 | 
			
		||||
from kasa.protocol import TPLinkSmartHomeProtocol
 | 
			
		||||
from kasa.smartbulb import SmartBulb
 | 
			
		||||
@@ -17,6 +17,7 @@ _LOGGER = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
 | 
			
		||||
DeviceDict = Dict[str, SmartDevice]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _DiscoverProtocol(asyncio.DatagramProtocol):
 | 
			
		||||
@@ -25,8 +26,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
 | 
			
		||||
    This is internal class, use :func:`Discover.discover`: instead.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    discovered_devices: Dict[str, SmartDevice]
 | 
			
		||||
    discovered_devices_raw: Dict[str, Dict]
 | 
			
		||||
    discovered_devices: DeviceDict
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
@@ -43,7 +43,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
 | 
			
		||||
        self.protocol = TPLinkSmartHomeProtocol()
 | 
			
		||||
        self.target = (target, Discover.DISCOVERY_PORT)
 | 
			
		||||
        self.discovered_devices = {}
 | 
			
		||||
        self.discovered_devices_raw = {}
 | 
			
		||||
 | 
			
		||||
    def connection_made(self, transport) -> None:
 | 
			
		||||
        """Set socket options for broadcasting."""
 | 
			
		||||
@@ -80,13 +79,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
 | 
			
		||||
        device.update_from_discover_info(info)
 | 
			
		||||
 | 
			
		||||
        self.discovered_devices[ip] = device
 | 
			
		||||
        self.discovered_devices_raw[ip] = info
 | 
			
		||||
 | 
			
		||||
        if device_class is not None:
 | 
			
		||||
            if self.on_discovered is not None:
 | 
			
		||||
                asyncio.ensure_future(self.on_discovered(device))
 | 
			
		||||
        else:
 | 
			
		||||
            _LOGGER.error("Received invalid response: %s", info)
 | 
			
		||||
        if self.on_discovered is not None:
 | 
			
		||||
            asyncio.ensure_future(self.on_discovered(device))
 | 
			
		||||
 | 
			
		||||
    def error_received(self, ex):
 | 
			
		||||
        """Handle asyncio.Protocol errors."""
 | 
			
		||||
@@ -144,9 +139,8 @@ class Discover:
 | 
			
		||||
        on_discovered=None,
 | 
			
		||||
        timeout=5,
 | 
			
		||||
        discovery_packets=3,
 | 
			
		||||
        return_raw=False,
 | 
			
		||||
        interface=None,
 | 
			
		||||
    ) -> Mapping[str, Union[SmartDevice, Dict]]:
 | 
			
		||||
    ) -> DeviceDict:
 | 
			
		||||
        """Discover supported devices.
 | 
			
		||||
 | 
			
		||||
        Sends discovery message to 255.255.255.255:9999 in order
 | 
			
		||||
@@ -154,17 +148,17 @@ class Discover:
 | 
			
		||||
        and waits for given timeout for answers from devices.
 | 
			
		||||
        If you have multiple interfaces, you can use target parameter to specify the network for discovery.
 | 
			
		||||
 | 
			
		||||
        If given, `on_discovered` coroutine will get passed with the :class:`SmartDevice`-derived object as parameter.
 | 
			
		||||
        If given, `on_discovered` coroutine will get awaited with a :class:`SmartDevice`-derived object as parameter.
 | 
			
		||||
 | 
			
		||||
        The results of the discovery are returned either as a list of :class:`SmartDevice`-derived objects
 | 
			
		||||
        or as raw response dictionaries objects (if `return_raw` is True).
 | 
			
		||||
        The results of the discovery are returned as a dict of :class:`SmartDevice`-derived objects keyed with IP addresses.
 | 
			
		||||
        The devices are already initialized and all but emeter-related properties can be accessed directly.
 | 
			
		||||
 | 
			
		||||
        :param target: The target address where to send the broadcast discovery queries if multi-homing (e.g. 192.168.xxx.255).
 | 
			
		||||
        :param on_discovered: coroutine to execute on discovery
 | 
			
		||||
        :param timeout: How long to wait for responses, defaults to 5
 | 
			
		||||
        :param discovery_packets: Number of discovery packets are broadcasted.
 | 
			
		||||
        :param return_raw: True to return JSON objects instead of Devices.
 | 
			
		||||
        :return:
 | 
			
		||||
        :param discovery_packets: Number of discovery packets to broadcast
 | 
			
		||||
        :param interface: Bind to specific interface
 | 
			
		||||
        :return: dictionary with discovered devices
 | 
			
		||||
        """
 | 
			
		||||
        loop = asyncio.get_event_loop()
 | 
			
		||||
        transport, protocol = await loop.create_datagram_endpoint(
 | 
			
		||||
@@ -186,9 +180,6 @@ class Discover:
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices))
 | 
			
		||||
 | 
			
		||||
        if return_raw:
 | 
			
		||||
            return protocol.discovered_devices_raw
 | 
			
		||||
 | 
			
		||||
        return protocol.discovered_devices
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@@ -204,12 +195,10 @@ class Discover:
 | 
			
		||||
        info = await protocol.query(host, Discover.DISCOVERY_QUERY)
 | 
			
		||||
 | 
			
		||||
        device_class = Discover._get_device_class(info)
 | 
			
		||||
        if device_class is not None:
 | 
			
		||||
            dev = device_class(host)
 | 
			
		||||
            await dev.update()
 | 
			
		||||
            return dev
 | 
			
		||||
        dev = device_class(host)
 | 
			
		||||
        await dev.update()
 | 
			
		||||
 | 
			
		||||
        raise SmartDeviceException("Unable to discover device, received: %s" % info)
 | 
			
		||||
        return dev
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _get_device_class(info: dict) -> Type[SmartDevice]:
 | 
			
		||||
@@ -237,17 +226,4 @@ class Discover:
 | 
			
		||||
 | 
			
		||||
            return SmartBulb
 | 
			
		||||
 | 
			
		||||
        raise SmartDeviceException("Unknown device type: %s", type_)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    logging.basicConfig(level=logging.INFO)
 | 
			
		||||
    loop = asyncio.get_event_loop()
 | 
			
		||||
 | 
			
		||||
    async def _on_device(dev):
 | 
			
		||||
        await dev.update()
 | 
			
		||||
        _LOGGER.info("Got device: %s", dev)
 | 
			
		||||
 | 
			
		||||
    devices = loop.run_until_complete(Discover.discover(on_discovered=_on_device))
 | 
			
		||||
    for ip, dev in devices.items():
 | 
			
		||||
        print(f"[{ip}] {dev}")
 | 
			
		||||
        raise SmartDeviceException("Unknown device type: %s" % type_)
 | 
			
		||||
 
 | 
			
		||||
@@ -53,8 +53,6 @@ def filter_model(desc, filter):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parametrize(desc, devices, ids=None):
 | 
			
		||||
    # if ids is None:
 | 
			
		||||
    #    ids = ["on", "off"]
 | 
			
		||||
    return pytest.mark.parametrize(
 | 
			
		||||
        "dev", filter_model(desc, devices), indirect=True, ids=ids
 | 
			
		||||
    )
 | 
			
		||||
@@ -63,32 +61,11 @@ def parametrize(desc, devices, ids=None):
 | 
			
		||||
has_emeter = parametrize("has emeter", WITH_EMETER)
 | 
			
		||||
no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def name_for_filename(x):
 | 
			
		||||
    from os.path import basename
 | 
			
		||||
 | 
			
		||||
    return basename(x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
bulb = parametrize("bulbs", BULBS, ids=name_for_filename)
 | 
			
		||||
plug = parametrize("plugs", PLUGS, ids=name_for_filename)
 | 
			
		||||
strip = parametrize("strips", STRIPS, ids=name_for_filename)
 | 
			
		||||
dimmer = parametrize("dimmers", DIMMERS, ids=name_for_filename)
 | 
			
		||||
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=name_for_filename)
 | 
			
		||||
 | 
			
		||||
# This ensures that every single file inside fixtures/ is being placed in some category
 | 
			
		||||
categorized_fixtures = set(
 | 
			
		||||
    dimmer.args[1] + strip.args[1] + plug.args[1] + bulb.args[1] + lightstrip.args[1]
 | 
			
		||||
)
 | 
			
		||||
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
 | 
			
		||||
if diff:
 | 
			
		||||
    for file in diff:
 | 
			
		||||
        print(
 | 
			
		||||
            "No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
 | 
			
		||||
            % file
 | 
			
		||||
        )
 | 
			
		||||
    raise Exception("Missing category for %s" % diff)
 | 
			
		||||
 | 
			
		||||
bulb = parametrize("bulbs", BULBS, ids=basename)
 | 
			
		||||
plug = parametrize("plugs", PLUGS, ids=basename)
 | 
			
		||||
strip = parametrize("strips", STRIPS, ids=basename)
 | 
			
		||||
dimmer = parametrize("dimmers", DIMMERS, ids=basename)
 | 
			
		||||
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename)
 | 
			
		||||
 | 
			
		||||
# bulb types
 | 
			
		||||
dimmable = parametrize("dimmable", DIMMABLE)
 | 
			
		||||
@@ -98,6 +75,28 @@ non_variable_temp = parametrize("non-variable color temp", BULBS - VARIABLE_TEMP
 | 
			
		||||
color_bulb = parametrize("color bulbs", COLOR_BULBS)
 | 
			
		||||
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_categories():
 | 
			
		||||
    """Check that every fixture file is categorized."""
 | 
			
		||||
    categorized_fixtures = set(
 | 
			
		||||
        dimmer.args[1]
 | 
			
		||||
        + strip.args[1]
 | 
			
		||||
        + plug.args[1]
 | 
			
		||||
        + bulb.args[1]
 | 
			
		||||
        + lightstrip.args[1]
 | 
			
		||||
    )
 | 
			
		||||
    diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
 | 
			
		||||
    if diff:
 | 
			
		||||
        for file in diff:
 | 
			
		||||
            print(
 | 
			
		||||
                "No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
 | 
			
		||||
                % file
 | 
			
		||||
            )
 | 
			
		||||
        raise Exception("Missing category for %s" % diff)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
check_categories()
 | 
			
		||||
 | 
			
		||||
# Parametrize tests to run with device both on and off
 | 
			
		||||
turn_on = pytest.mark.parametrize("turn_on", [True, False])
 | 
			
		||||
 | 
			
		||||
@@ -174,6 +173,18 @@ def dev(request):
 | 
			
		||||
    return get_device_for_file(file)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
 | 
			
		||||
def discovery_data(request):
 | 
			
		||||
    """Return raw discovery file contents as JSON. Used for discovery tests."""
 | 
			
		||||
    file = request.param
 | 
			
		||||
    p = Path(file)
 | 
			
		||||
    if not p.is_absolute():
 | 
			
		||||
        p = Path(__file__).parent / "fixtures" / file
 | 
			
		||||
 | 
			
		||||
    with open(p) as f:
 | 
			
		||||
        return json.load(f)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pytest_addoption(parser):
 | 
			
		||||
    parser.addoption(
 | 
			
		||||
        "--ip", action="store", default=None, help="run against device on given ip"
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,10 @@
 | 
			
		||||
# type: ignore
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
import pytest  # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
 | 
			
		||||
 | 
			
		||||
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException
 | 
			
		||||
from kasa.discover import _DiscoverProtocol
 | 
			
		||||
 | 
			
		||||
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
 | 
			
		||||
 | 
			
		||||
@@ -47,3 +50,57 @@ async def test_type_unknown():
 | 
			
		||||
    invalid_info = {"system": {"get_sysinfo": {"type": "nosuchtype"}}}
 | 
			
		||||
    with pytest.raises(SmartDeviceException):
 | 
			
		||||
        Discover._get_device_class(invalid_info)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
 | 
			
		||||
async def test_discover_single(discovery_data: dict, mocker):
 | 
			
		||||
    """Make sure that discover_single returns an initialized SmartDevice instance."""
 | 
			
		||||
    mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
 | 
			
		||||
    x = await Discover.discover_single("127.0.0.1")
 | 
			
		||||
    assert issubclass(x.__class__, SmartDevice)
 | 
			
		||||
    assert x._sys_info is not None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
INVALIDS = [
 | 
			
		||||
    ("No 'system' or 'get_sysinfo' in response", {"no": "data"}),
 | 
			
		||||
    (
 | 
			
		||||
        "Unable to find the device type field",
 | 
			
		||||
        {"system": {"get_sysinfo": {"missing_type": 1}}},
 | 
			
		||||
    ),
 | 
			
		||||
    ("Unknown device type: foo", {"system": {"get_sysinfo": {"type": "foo"}}}),
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
 | 
			
		||||
@pytest.mark.parametrize("msg, data", INVALIDS)
 | 
			
		||||
async def test_discover_invalid_info(msg, data, mocker):
 | 
			
		||||
    """Make sure that invalid discovery information raises an exception."""
 | 
			
		||||
    mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=data)
 | 
			
		||||
    with pytest.raises(SmartDeviceException, match=msg):
 | 
			
		||||
        await Discover.discover_single("127.0.0.1")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_discover_send(mocker):
 | 
			
		||||
    """Test discovery parameters."""
 | 
			
		||||
    proto = _DiscoverProtocol()
 | 
			
		||||
    assert proto.discovery_packets == 3
 | 
			
		||||
    assert proto.target == ("255.255.255.255", 9999)
 | 
			
		||||
    sendto = mocker.patch.object(proto, "transport")
 | 
			
		||||
    proto.do_discover()
 | 
			
		||||
    assert sendto.sendto.call_count == proto.discovery_packets
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_discover_datagram_received(mocker, discovery_data):
 | 
			
		||||
    """Verify that datagram received fills discovered_devices."""
 | 
			
		||||
    proto = _DiscoverProtocol()
 | 
			
		||||
    mocker.patch("json.loads", return_value=discovery_data)
 | 
			
		||||
    mocker.patch.object(proto, "protocol")
 | 
			
		||||
 | 
			
		||||
    addr = "127.0.0.1"
 | 
			
		||||
    proto.datagram_received("<placeholder data>", (addr, 1234))
 | 
			
		||||
 | 
			
		||||
    # Check that device in discovered_devices is initialized correctly
 | 
			
		||||
    assert len(proto.discovered_devices) == 1
 | 
			
		||||
    dev = proto.discovered_devices[addr]
 | 
			
		||||
    assert issubclass(dev.__class__, SmartDevice)
 | 
			
		||||
    assert dev.host == addr
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user