From e31cc6662c8b3da672732773d27140faa58122aa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 24 Sep 2021 16:25:43 -0500 Subject: [PATCH] Keep connection open and lock to prevent duplicate requests (#213) * Keep connection open and lock to prevent duplicate requests * option to not update children * tweaks * typing * tweaks * run tests in the same event loop * memorize model * Update kasa/protocol.py Co-authored-by: Teemu R. * Update kasa/protocol.py Co-authored-by: Teemu R. * Update kasa/protocol.py Co-authored-by: Teemu R. * Update kasa/protocol.py Co-authored-by: Teemu R. * dry * tweaks * warn when the event loop gets switched out from under us * raise on unable to connect multiple times * fix patch target * tweaks * isrot * reconnect test * prune * fix mocking * fix mocking * fix test under python 3.7 * fix test under python 3.7 * less patching * isort * use mocker to patch * disable on old python since mocking doesnt work * avoid disconnect/reconnect cycles * isort * Fix hue validation * Fix latitude_i/longitude_i units Co-authored-by: Teemu R. --- devtools/dump_devinfo.py | 17 ++-- kasa/discover.py | 9 +- kasa/protocol.py | 146 +++++++++++++++++++++-------- kasa/smartdevice.py | 14 +-- kasa/smartstrip.py | 10 +- kasa/tests/conftest.py | 53 +++++++---- kasa/tests/newfakes.py | 26 +++-- kasa/tests/test_bulb.py | 2 +- kasa/tests/test_discovery.py | 5 +- kasa/tests/test_protocol.py | 40 +++++++- kasa/tests/test_readme_examples.py | 15 +-- 11 files changed, 241 insertions(+), 96 deletions(-) diff --git a/devtools/dump_devinfo.py b/devtools/dump_devinfo.py index 9d30f967..1108e7fb 100644 --- a/devtools/dump_devinfo.py +++ b/devtools/dump_devinfo.py @@ -78,16 +78,17 @@ def cli(host, debug): ), ] - protocol = TPLinkSmartHomeProtocol() - successes = [] for test_call in items: + + async def _run_query(): + protocol = TPLinkSmartHomeProtocol(host) + return await protocol.query({test_call.module: {test_call.method: None}}) + try: click.echo(f"Testing {test_call}..", nl=False) - info = asyncio.run( - protocol.query(host, {test_call.module: {test_call.method: None}}) - ) + info = asyncio.run(_run_query()) resp = info[test_call.module] except Exception as ex: click.echo(click.style(f"FAIL {ex}", fg="red")) @@ -107,8 +108,12 @@ def cli(host, debug): final = default_to_regular(final) + async def _run_final_query(): + protocol = TPLinkSmartHomeProtocol(host) + return await protocol.query(final_query) + try: - final = asyncio.run(protocol.query(host, final_query)) + final = asyncio.run(_run_final_query()) except Exception as ex: click.echo( click.style( diff --git a/kasa/discover.py b/kasa/discover.py index f452c54a..a408c2de 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -40,7 +40,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.discovery_packets = discovery_packets self.interface = interface self.on_discovered = on_discovered - self.protocol = TPLinkSmartHomeProtocol() self.target = (target, Discover.DISCOVERY_PORT) self.discovered_devices = {} @@ -61,7 +60,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): """Send number of discovery datagrams.""" req = json.dumps(Discover.DISCOVERY_QUERY) _LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY) - encrypted_req = self.protocol.encrypt(req) + encrypted_req = TPLinkSmartHomeProtocol.encrypt(req) for i in range(self.discovery_packets): self.transport.sendto(encrypted_req[4:], self.target) # type: ignore @@ -71,7 +70,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): if ip in self.discovered_devices: return - info = json.loads(self.protocol.decrypt(data)) + info = json.loads(TPLinkSmartHomeProtocol.decrypt(data)) _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) device_class = Discover._get_device_class(info) @@ -190,9 +189,9 @@ class Discover: :rtype: SmartDevice :return: Object for querying/controlling found device. """ - protocol = TPLinkSmartHomeProtocol() + protocol = TPLinkSmartHomeProtocol(host) - info = await protocol.query(host, Discover.DISCOVERY_QUERY) + info = await protocol.query(Discover.DISCOVERY_QUERY) device_class = Discover._get_device_class(info) dev = device_class(host) diff --git a/kasa/protocol.py b/kasa/protocol.py index bbf13b99..b54029c6 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -10,11 +10,12 @@ which are licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 """ import asyncio +import contextlib import json import logging import struct from pprint import pformat as pf -from typing import Dict, Union +from typing import Dict, Optional, Union from .exceptions import SmartDeviceException @@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol: DEFAULT_PORT = 9999 DEFAULT_TIMEOUT = 5 - @staticmethod - async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict: + BLOCK_SIZE = 4 + + def __init__(self, host: str) -> None: + """Create a protocol object.""" + self.host = host + self.reader: Optional[asyncio.StreamReader] = None + self.writer: Optional[asyncio.StreamWriter] = None + self.query_lock: Optional[asyncio.Lock] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None + + def _detect_event_loop_change(self) -> None: + """Check if this object has been reused betwen event loops.""" + loop = asyncio.get_running_loop() + if not self.loop: + self.loop = loop + elif self.loop != loop: + _LOGGER.warning("Detected protocol reuse between different event loop") + self._reset() + + async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: """Request information from a TP-Link SmartHome Device. :param str host: host name or ip address of the device @@ -38,57 +57,106 @@ class TPLinkSmartHomeProtocol: :param retry_count: how many retries to do in case of failure :return: response dict """ + self._detect_event_loop_change() + + if not self.query_lock: + self.query_lock = asyncio.Lock() + if isinstance(request, dict): request = json.dumps(request) + assert isinstance(request, str) timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT - writer = None + + async with self.query_lock: + return await self._query(request, retry_count, timeout) + + async def _connect(self, timeout: int) -> bool: + """Try to connect or reconnect to the device.""" + if self.writer: + return True + + with contextlib.suppress(Exception): + self.reader = self.writer = None + task = asyncio.open_connection( + self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT + ) + self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout) + return True + + return False + + async def _execute_query(self, request: str) -> Dict: + """Execute a query on the device and wait for the response.""" + assert self.writer is not None + assert self.reader is not None + + _LOGGER.debug("> (%i) %s", len(request), request) + self.writer.write(TPLinkSmartHomeProtocol.encrypt(request)) + await self.writer.drain() + + packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE) + length = struct.unpack(">I", packed_block_size)[0] + + buffer = await self.reader.readexactly(length) + response = TPLinkSmartHomeProtocol.decrypt(buffer) + json_payload = json.loads(response) + _LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) + return json_payload + + async def close(self): + """Close the connection.""" + writer = self.writer + self._reset() + if writer: + writer.close() + with contextlib.suppress(Exception): + await writer.wait_closed() + + def _reset(self): + """Clear any varibles that should not survive between loops.""" + self.writer = None + self.reader = None + self.query_lock = None + self.loop = None + + async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: + """Try to query a device.""" for retry in range(retry_count + 1): - try: - task = asyncio.open_connection( - host, TPLinkSmartHomeProtocol.DEFAULT_PORT - ) - reader, writer = await asyncio.wait_for(task, timeout=timeout) - _LOGGER.debug("> (%i) %s", len(request), request) - writer.write(TPLinkSmartHomeProtocol.encrypt(request)) - await writer.drain() - - buffer = bytes() - # Some devices send responses with a length header of 0 and - # terminate with a zero size chunk. Others send the length and - # will hang if we attempt to read more data. - length = -1 - while True: - chunk = await reader.read(4096) - if length == -1: - length = struct.unpack(">I", chunk[0:4])[0] - buffer += chunk - if (length > 0 and len(buffer) >= length + 4) or not chunk: - break - - response = TPLinkSmartHomeProtocol.decrypt(buffer[4:]) - json_payload = json.loads(response) - _LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) - - return json_payload - - except Exception as ex: + if not await self._connect(timeout): + await self.close() if retry >= retry_count: _LOGGER.debug("Giving up after %s retries", retry) raise SmartDeviceException( - "Unable to query the device: %s" % ex + f"Unable to connect to the device: {self.host}" + ) + continue + + try: + assert self.reader is not None + assert self.writer is not None + return await asyncio.wait_for( + self._execute_query(request), timeout=timeout + ) + except Exception as ex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up after %s retries", retry) + raise SmartDeviceException( + f"Unable to query the device: {ex}" ) from ex _LOGGER.debug("Unable to query the device, retrying: %s", ex) - finally: - if writer: - writer.close() - await writer.wait_closed() - # make mypy happy, this should never be reached.. + await self.close() raise SmartDeviceException("Query reached somehow to unreachable") + def __del__(self): + if self.writer and self.loop and self.loop.is_running(): + self.writer.close() + self._reset() + @staticmethod def _xor_payload(unencrypted): key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 11c7d1c9..fabf26b3 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -194,7 +194,7 @@ class SmartDevice: """ self.host = host - self.protocol = TPLinkSmartHomeProtocol() + self.protocol = TPLinkSmartHomeProtocol(host) self.emeter_type = "emeter" _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) self._device_type = DeviceType.Unknown @@ -234,7 +234,7 @@ class SmartDevice: request = self._create_request(target, cmd, arg, child_ids) try: - response = await self.protocol.query(host=self.host, request=request) + response = await self.protocol.query(request=request) except Exception as ex: raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex @@ -272,7 +272,7 @@ class SmartDevice: """Retrieve system information.""" return await self._query_helper("system", "get_sysinfo") - async def update(self): + async def update(self, update_children: bool = True): """Query the device to update the data. Needed for properties that are decorated with `requires_update`. @@ -285,7 +285,7 @@ class SmartDevice: # See #105, #120, #161 if self._last_update is None: _LOGGER.debug("Performing the initial update to obtain sysinfo") - self._last_update = await self.protocol.query(self.host, req) + self._last_update = await self.protocol.query(req) self._sys_info = self._last_update["system"]["get_sysinfo"] # If the device has no emeter, we are done for the initial update # Otherwise we will follow the regular code path to also query @@ -299,7 +299,7 @@ class SmartDevice: ) req.update(self._create_emeter_request()) - self._last_update = await self.protocol.query(self.host, req) + self._last_update = await self.protocol.query(req) self._sys_info = self._last_update["system"]["get_sysinfo"] def update_from_discover_info(self, info): @@ -383,8 +383,8 @@ class SmartDevice: loc["latitude"] = sys_info["latitude"] loc["longitude"] = sys_info["longitude"] elif "latitude_i" in sys_info and "longitude_i" in sys_info: - loc["latitude"] = sys_info["latitude_i"] - loc["longitude"] = sys_info["longitude_i"] + loc["latitude"] = sys_info["latitude_i"] / 10000 + loc["longitude"] = sys_info["longitude_i"] / 10000 else: _LOGGER.warning("Unsupported device location.") diff --git a/kasa/smartstrip.py b/kasa/smartstrip.py index c1235920..71373a7a 100755 --- a/kasa/smartstrip.py +++ b/kasa/smartstrip.py @@ -87,12 +87,12 @@ class SmartStrip(SmartDevice): """Return if any of the outlets are on.""" return any(plug.is_on for plug in self.children) - async def update(self): + async def update(self, update_children: bool = True): """Update some of the attributes. Needed for methods that are decorated with `requires_update`. """ - await super().update() + await super().update(update_children) # Initialize the child devices during the first update. if not self.children: @@ -103,7 +103,7 @@ class SmartStrip(SmartDevice): SmartStripPlug(self.host, parent=self, child_id=child["id"]) ) - if self.has_emeter: + if update_children and self.has_emeter: for plug in self.children: await plug.update() @@ -243,13 +243,13 @@ class SmartStripPlug(SmartPlug): self._sys_info = parent._sys_info self._device_type = DeviceType.StripSocket - async def update(self): + async def update(self, update_children: bool = True): """Query the device to update the data. Needed for properties that are decorated with `requires_update`. """ self._last_update = await self.parent.protocol.query( - self.host, self._create_emeter_request() + self._create_emeter_request() ) def _create_request( diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index df253e5d..a7ab5d13 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -4,6 +4,7 @@ import json import os from os.path import basename from pathlib import Path, PurePath +from typing import Dict from unittest.mock import MagicMock import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342 @@ -39,6 +40,8 @@ WITH_EMETER = {"HS110", "HS300", "KP115", *BULBS} ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS) +IP_MODEL_CACHE: Dict[str, str] = {} + def filter_model(desc, filter): filtered = list() @@ -137,23 +140,39 @@ def device_for_file(model): raise Exception("Unable to find type for %s", model) -def get_device_for_file(file): +async def _update_and_close(d): + await d.update() + await d.protocol.close() + return d + + +async def _discover_update_and_close(ip): + d = await Discover.discover_single(ip) + return await _update_and_close(d) + + +async def get_device_for_file(file): # if the wanted file is not an absolute path, prepend the fixtures directory p = Path(file) if not p.is_absolute(): p = Path(__file__).parent / "fixtures" / file - with open(p) as f: - sysinfo = json.load(f) - model = basename(file) - p = device_for_file(model)(host="127.0.0.123") - p.protocol = FakeTransportProtocol(sysinfo) - asyncio.run(p.update()) - return p + def load_file(): + with open(p) as f: + return json.load(f) + + loop = asyncio.get_running_loop() + sysinfo = await loop.run_in_executor(None, load_file) + + model = basename(file) + d = device_for_file(model)(host="127.0.0.123") + d.protocol = FakeTransportProtocol(sysinfo) + await _update_and_close(d) + return d -@pytest.fixture(params=SUPPORTED_DEVICES, scope="session") -def dev(request): +@pytest.fixture(params=SUPPORTED_DEVICES) +async def dev(request): """Device fixture. Provides a device (given --ip) or parametrized fixture for the supported devices. @@ -163,14 +182,16 @@ def dev(request): ip = request.config.getoption("--ip") if ip: - d = asyncio.run(Discover.discover_single(ip)) - asyncio.run(d.update()) - if d.model in file: - return d - else: + model = IP_MODEL_CACHE.get(ip) + d = None + if not model: + d = await _discover_update_and_close(ip) + IP_MODEL_CACHE[ip] = model = d.model + if model not in file: pytest.skip(f"skipping file {file}") + return d if d else await _discover_update_and_close(ip) - return get_device_for_file(file) + return await get_device_for_file(file) @pytest.fixture(params=SUPPORTED_DEVICES, scope="session") diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index a37bb414..a4764b66 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -83,9 +83,19 @@ PLUG_SCHEMA = Schema( "icon_hash": str, "led_off": check_int_bool, "latitude": Any(All(float, Range(min=-90, max=90)), 0, None), - "latitude_i": Any(All(float, Range(min=-90, max=90)), 0, None), + "latitude_i": Any( + All(int, Range(min=-900000, max=900000)), + All(float, Range(min=-900000, max=900000)), + 0, + None, + ), "longitude": Any(All(float, Range(min=-180, max=180)), 0, None), - "longitude_i": Any(All(float, Range(min=-180, max=180)), 0, None), + "longitude_i": Any( + All(int, Range(min=-18000000, max=18000000)), + All(float, Range(min=-18000000, max=18000000)), + 0, + None, + ), "mac": check_mac, "model": str, "oemId": str, @@ -117,17 +127,17 @@ LIGHT_STATE_SCHEMA = Schema( { "brightness": All(int, Range(min=0, max=100)), "color_temp": int, - "hue": All(int, Range(min=0, max=255)), + "hue": All(int, Range(min=0, max=360)), "mode": str, "on_off": check_int_bool, - "saturation": All(int, Range(min=0, max=255)), + "saturation": All(int, Range(min=0, max=100)), "dft_on_state": Optional( { "brightness": All(int, Range(min=0, max=100)), "color_temp": All(int, Range(min=0, max=9000)), - "hue": All(int, Range(min=0, max=255)), + "hue": All(int, Range(min=0, max=360)), "mode": str, - "saturation": All(int, Range(min=0, max=255)), + "saturation": All(int, Range(min=0, max=100)), } ), "err_code": int, @@ -276,6 +286,8 @@ TIME_MODULE = { class FakeTransportProtocol(TPLinkSmartHomeProtocol): def __init__(self, info): self.discovery_data = info + self.writer = None + self.reader = None proto = FakeTransportProtocol.baseproto for target in info: @@ -426,7 +438,7 @@ class FakeTransportProtocol(TPLinkSmartHomeProtocol): }, } - async def query(self, host, request, port=9999): + async def query(self, request, port=9999): proto = self.proto # collect child ids from context diff --git a/kasa/tests/test_bulb.py b/kasa/tests/test_bulb.py index 28fcd4cb..ea8a28cb 100644 --- a/kasa/tests/test_bulb.py +++ b/kasa/tests/test_bulb.py @@ -60,7 +60,7 @@ async def test_hsv(dev, turn_on): assert dev.is_color hue, saturation, brightness = dev.hsv - assert 0 <= hue <= 255 + assert 0 <= hue <= 360 assert 0 <= saturation <= 100 assert 0 <= brightness <= 100 diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 13ba3809..c933cb12 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -3,7 +3,7 @@ import sys import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 -from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException +from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol from kasa.discover import _DiscoverProtocol from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip @@ -94,7 +94,8 @@ async def test_discover_datagram_received(mocker, discovery_data): """Verify that datagram received fills discovered_devices.""" proto = _DiscoverProtocol() mocker.patch("json.loads", return_value=discovery_data) - mocker.patch.object(proto, "protocol") + mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt") + mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") addr = "127.0.0.1" proto.datagram_received("", (addr, 1234)) diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 51c01d49..bc0da183 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -1,4 +1,6 @@ import json +import struct +import sys import pytest @@ -21,11 +23,47 @@ async def test_protocol_retries(mocker, retry_count): conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count) + await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count) assert conn.call_count == retry_count + 1 +@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock") +@pytest.mark.parametrize("retry_count", [1, 3, 5]) +async def test_protocol_reconnect(mocker, retry_count): + remaining = retry_count + encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ + TPLinkSmartHomeProtocol.BLOCK_SIZE : + ] + + def _fail_one_less_than_retry_count(*_): + nonlocal remaining + remaining -= 1 + if remaining: + raise Exception("Simulated write failure") + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + mocker.patch.object(writer, "write", _fail_one_less_than_retry_count) + mocker.patch.object(reader, "readexactly", _mock_read) + return reader, writer + + protocol = TPLinkSmartHomeProtocol("127.0.0.1") + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + response = await protocol.query({}, retry_count=retry_count) + assert response == {"great": "success"} + + def test_encrypt(): d = json.dumps({"foo": 1, "bar": 2}) encrypted = TPLinkSmartHomeProtocol.encrypt(d) diff --git a/kasa/tests/test_readme_examples.py b/kasa/tests/test_readme_examples.py index 27455dd8..a64c824c 100644 --- a/kasa/tests/test_readme_examples.py +++ b/kasa/tests/test_readme_examples.py @@ -1,3 +1,4 @@ +import asyncio import sys import pytest @@ -8,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file def test_bulb_examples(mocker): """Use KL130 (bulb with all features) to test the doctests.""" - p = get_device_for_file("KL130(US)_1.0.json") + p = asyncio.run(get_device_for_file("KL130(US)_1.0.json")) mocker.patch("kasa.smartbulb.SmartBulb", return_value=p) mocker.patch("kasa.smartbulb.SmartBulb.update") res = xdoctest.doctest_module("kasa.smartbulb", "all") @@ -17,7 +18,7 @@ def test_bulb_examples(mocker): def test_smartdevice_examples(mocker): """Use HS110 for emeter examples.""" - p = get_device_for_file("HS110(EU)_1.0_real.json") + p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json")) mocker.patch("kasa.smartdevice.SmartDevice", return_value=p) mocker.patch("kasa.smartdevice.SmartDevice.update") res = xdoctest.doctest_module("kasa.smartdevice", "all") @@ -26,7 +27,7 @@ def test_smartdevice_examples(mocker): def test_plug_examples(mocker): """Test plug examples.""" - p = get_device_for_file("HS110(EU)_1.0_real.json") + p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json")) mocker.patch("kasa.smartplug.SmartPlug", return_value=p) mocker.patch("kasa.smartplug.SmartPlug.update") res = xdoctest.doctest_module("kasa.smartplug", "all") @@ -35,7 +36,7 @@ def test_plug_examples(mocker): def test_strip_examples(mocker): """Test strip examples.""" - p = get_device_for_file("KP303(UK)_1.0.json") + p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json")) mocker.patch("kasa.smartstrip.SmartStrip", return_value=p) mocker.patch("kasa.smartstrip.SmartStrip.update") res = xdoctest.doctest_module("kasa.smartstrip", "all") @@ -44,7 +45,7 @@ def test_strip_examples(mocker): def test_dimmer_examples(mocker): """Test dimmer examples.""" - p = get_device_for_file("HS220(US)_1.0_real.json") + p = asyncio.run(get_device_for_file("HS220(US)_1.0_real.json")) mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p) mocker.patch("kasa.smartdimmer.SmartDimmer.update") res = xdoctest.doctest_module("kasa.smartdimmer", "all") @@ -53,7 +54,7 @@ def test_dimmer_examples(mocker): def test_lightstrip_examples(mocker): """Test lightstrip examples.""" - p = get_device_for_file("KL430(US)_1.0.json") + p = asyncio.run(get_device_for_file("KL430(US)_1.0.json")) mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p) mocker.patch("kasa.smartlightstrip.SmartLightStrip.update") res = xdoctest.doctest_module("kasa.smartlightstrip", "all") @@ -65,7 +66,7 @@ def test_lightstrip_examples(mocker): ) def test_discovery_examples(mocker): """Test discovery examples.""" - p = get_device_for_file("KP303(UK)_1.0.json") + p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json")) # This succeeds on python 3.8 but fails on 3.7 # ValueError: a coroutine was expected, got [