Keep connection open and lock to prevent duplicate requests (#213)

* Keep connection open and lock to prevent duplicate requests

* option to not update children

* tweaks

* typing

* tweaks

* run tests in the same event loop

* memorize model

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* dry

* tweaks

* warn when the event loop gets switched out from under us

* raise on unable to connect multiple times

* fix patch target

* tweaks

* isrot

* reconnect test

* prune

* fix mocking

* fix mocking

* fix test under python 3.7

* fix test under python 3.7

* less patching

* isort

* use mocker to patch

* disable on old python since mocking doesnt work

* avoid disconnect/reconnect cycles

* isort

* Fix hue validation

* Fix latitude_i/longitude_i units

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
J. Nick Koston 2021-09-24 16:25:43 -05:00 committed by GitHub
parent f1b28e79b9
commit e31cc6662c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 241 additions and 96 deletions

View File

@ -78,16 +78,17 @@ def cli(host, debug):
), ),
] ]
protocol = TPLinkSmartHomeProtocol()
successes = [] successes = []
for test_call in items: for test_call in items:
async def _run_query():
protocol = TPLinkSmartHomeProtocol(host)
return await protocol.query({test_call.module: {test_call.method: None}})
try: try:
click.echo(f"Testing {test_call}..", nl=False) click.echo(f"Testing {test_call}..", nl=False)
info = asyncio.run( info = asyncio.run(_run_query())
protocol.query(host, {test_call.module: {test_call.method: None}})
)
resp = info[test_call.module] resp = info[test_call.module]
except Exception as ex: except Exception as ex:
click.echo(click.style(f"FAIL {ex}", fg="red")) click.echo(click.style(f"FAIL {ex}", fg="red"))
@ -107,8 +108,12 @@ def cli(host, debug):
final = default_to_regular(final) final = default_to_regular(final)
async def _run_final_query():
protocol = TPLinkSmartHomeProtocol(host)
return await protocol.query(final_query)
try: try:
final = asyncio.run(protocol.query(host, final_query)) final = asyncio.run(_run_final_query())
except Exception as ex: except Exception as ex:
click.echo( click.echo(
click.style( click.style(

View File

@ -40,7 +40,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.discovery_packets = discovery_packets self.discovery_packets = discovery_packets
self.interface = interface self.interface = interface
self.on_discovered = on_discovered self.on_discovered = on_discovered
self.protocol = TPLinkSmartHomeProtocol()
self.target = (target, Discover.DISCOVERY_PORT) self.target = (target, Discover.DISCOVERY_PORT)
self.discovered_devices = {} self.discovered_devices = {}
@ -61,7 +60,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Send number of discovery datagrams.""" """Send number of discovery datagrams."""
req = json.dumps(Discover.DISCOVERY_QUERY) req = json.dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY) _LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = self.protocol.encrypt(req) encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for i in range(self.discovery_packets): for i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
@ -71,7 +70,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
if ip in self.discovered_devices: if ip in self.discovered_devices:
return return
info = json.loads(self.protocol.decrypt(data)) info = json.loads(TPLinkSmartHomeProtocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info) _LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
device_class = Discover._get_device_class(info) device_class = Discover._get_device_class(info)
@ -190,9 +189,9 @@ class Discover:
:rtype: SmartDevice :rtype: SmartDevice
:return: Object for querying/controlling found device. :return: Object for querying/controlling found device.
""" """
protocol = TPLinkSmartHomeProtocol() protocol = TPLinkSmartHomeProtocol(host)
info = await protocol.query(host, Discover.DISCOVERY_QUERY) info = await protocol.query(Discover.DISCOVERY_QUERY)
device_class = Discover._get_device_class(info) device_class = Discover._get_device_class(info)
dev = device_class(host) dev = device_class(host)

View File

@ -10,11 +10,12 @@ which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
""" """
import asyncio import asyncio
import contextlib
import json import json
import logging import logging
import struct import struct
from pprint import pformat as pf from pprint import pformat as pf
from typing import Dict, Union from typing import Dict, Optional, Union
from .exceptions import SmartDeviceException from .exceptions import SmartDeviceException
@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol:
DEFAULT_PORT = 9999 DEFAULT_PORT = 9999
DEFAULT_TIMEOUT = 5 DEFAULT_TIMEOUT = 5
@staticmethod BLOCK_SIZE = 4
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
def __init__(self, host: str) -> None:
"""Create a protocol object."""
self.host = host
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.query_lock: Optional[asyncio.Lock] = None
self.loop: Optional[asyncio.AbstractEventLoop] = None
def _detect_event_loop_change(self) -> None:
"""Check if this object has been reused betwen event loops."""
loop = asyncio.get_running_loop()
if not self.loop:
self.loop = loop
elif self.loop != loop:
_LOGGER.warning("Detected protocol reuse between different event loop")
self._reset()
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device. """Request information from a TP-Link SmartHome Device.
:param str host: host name or ip address of the device :param str host: host name or ip address of the device
@ -38,57 +57,106 @@ class TPLinkSmartHomeProtocol:
:param retry_count: how many retries to do in case of failure :param retry_count: how many retries to do in case of failure
:return: response dict :return: response dict
""" """
self._detect_event_loop_change()
if not self.query_lock:
self.query_lock = asyncio.Lock()
if isinstance(request, dict): if isinstance(request, dict):
request = json.dumps(request) request = json.dumps(request)
assert isinstance(request, str)
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
writer = None
async with self.query_lock:
return await self._query(request, retry_count, timeout)
async def _connect(self, timeout: int) -> bool:
"""Try to connect or reconnect to the device."""
if self.writer:
return True
with contextlib.suppress(Exception):
self.reader = self.writer = None
task = asyncio.open_connection(
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
return True
return False
async def _execute_query(self, request: str) -> Dict:
"""Execute a query on the device and wait for the response."""
assert self.writer is not None
assert self.reader is not None
_LOGGER.debug("> (%i) %s", len(request), request)
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await self.writer.drain()
packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
length = struct.unpack(">I", packed_block_size)[0]
buffer = await self.reader.readexactly(length)
response = TPLinkSmartHomeProtocol.decrypt(buffer)
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
return json_payload
async def close(self):
"""Close the connection."""
writer = self.writer
self._reset()
if writer:
writer.close()
with contextlib.suppress(Exception):
await writer.wait_closed()
def _reset(self):
"""Clear any varibles that should not survive between loops."""
self.writer = None
self.reader = None
self.query_lock = None
self.loop = None
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device."""
for retry in range(retry_count + 1): for retry in range(retry_count + 1):
try: if not await self._connect(timeout):
task = asyncio.open_connection( await self.close()
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
reader, writer = await asyncio.wait_for(task, timeout=timeout)
_LOGGER.debug("> (%i) %s", len(request), request)
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await writer.drain()
buffer = bytes()
# Some devices send responses with a length header of 0 and
# terminate with a zero size chunk. Others send the length and
# will hang if we attempt to read more data.
length = -1
while True:
chunk = await reader.read(4096)
if length == -1:
length = struct.unpack(">I", chunk[0:4])[0]
buffer += chunk
if (length > 0 and len(buffer) >= length + 4) or not chunk:
break
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
return json_payload
except Exception as ex:
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry) _LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException( raise SmartDeviceException(
"Unable to query the device: %s" % ex f"Unable to connect to the device: {self.host}"
)
continue
try:
assert self.reader is not None
assert self.writer is not None
return await asyncio.wait_for(
self._execute_query(request), timeout=timeout
)
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
f"Unable to query the device: {ex}"
) from ex ) from ex
_LOGGER.debug("Unable to query the device, retrying: %s", ex) _LOGGER.debug("Unable to query the device, retrying: %s", ex)
finally:
if writer:
writer.close()
await writer.wait_closed()
# make mypy happy, this should never be reached.. # make mypy happy, this should never be reached..
await self.close()
raise SmartDeviceException("Query reached somehow to unreachable") raise SmartDeviceException("Query reached somehow to unreachable")
def __del__(self):
if self.writer and self.loop and self.loop.is_running():
self.writer.close()
self._reset()
@staticmethod @staticmethod
def _xor_payload(unencrypted): def _xor_payload(unencrypted):
key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR

View File

@ -194,7 +194,7 @@ class SmartDevice:
""" """
self.host = host self.host = host
self.protocol = TPLinkSmartHomeProtocol() self.protocol = TPLinkSmartHomeProtocol(host)
self.emeter_type = "emeter" self.emeter_type = "emeter"
_LOGGER.debug("Initializing %s of type %s", self.host, type(self)) _LOGGER.debug("Initializing %s of type %s", self.host, type(self))
self._device_type = DeviceType.Unknown self._device_type = DeviceType.Unknown
@ -234,7 +234,7 @@ class SmartDevice:
request = self._create_request(target, cmd, arg, child_ids) request = self._create_request(target, cmd, arg, child_ids)
try: try:
response = await self.protocol.query(host=self.host, request=request) response = await self.protocol.query(request=request)
except Exception as ex: except Exception as ex:
raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex
@ -272,7 +272,7 @@ class SmartDevice:
"""Retrieve system information.""" """Retrieve system information."""
return await self._query_helper("system", "get_sysinfo") return await self._query_helper("system", "get_sysinfo")
async def update(self): async def update(self, update_children: bool = True):
"""Query the device to update the data. """Query the device to update the data.
Needed for properties that are decorated with `requires_update`. Needed for properties that are decorated with `requires_update`.
@ -285,7 +285,7 @@ class SmartDevice:
# See #105, #120, #161 # See #105, #120, #161
if self._last_update is None: if self._last_update is None:
_LOGGER.debug("Performing the initial update to obtain sysinfo") _LOGGER.debug("Performing the initial update to obtain sysinfo")
self._last_update = await self.protocol.query(self.host, req) self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"] self._sys_info = self._last_update["system"]["get_sysinfo"]
# If the device has no emeter, we are done for the initial update # If the device has no emeter, we are done for the initial update
# Otherwise we will follow the regular code path to also query # Otherwise we will follow the regular code path to also query
@ -299,7 +299,7 @@ class SmartDevice:
) )
req.update(self._create_emeter_request()) req.update(self._create_emeter_request())
self._last_update = await self.protocol.query(self.host, req) self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"] self._sys_info = self._last_update["system"]["get_sysinfo"]
def update_from_discover_info(self, info): def update_from_discover_info(self, info):
@ -383,8 +383,8 @@ class SmartDevice:
loc["latitude"] = sys_info["latitude"] loc["latitude"] = sys_info["latitude"]
loc["longitude"] = sys_info["longitude"] loc["longitude"] = sys_info["longitude"]
elif "latitude_i" in sys_info and "longitude_i" in sys_info: elif "latitude_i" in sys_info and "longitude_i" in sys_info:
loc["latitude"] = sys_info["latitude_i"] loc["latitude"] = sys_info["latitude_i"] / 10000
loc["longitude"] = sys_info["longitude_i"] loc["longitude"] = sys_info["longitude_i"] / 10000
else: else:
_LOGGER.warning("Unsupported device location.") _LOGGER.warning("Unsupported device location.")

View File

@ -87,12 +87,12 @@ class SmartStrip(SmartDevice):
"""Return if any of the outlets are on.""" """Return if any of the outlets are on."""
return any(plug.is_on for plug in self.children) return any(plug.is_on for plug in self.children)
async def update(self): async def update(self, update_children: bool = True):
"""Update some of the attributes. """Update some of the attributes.
Needed for methods that are decorated with `requires_update`. Needed for methods that are decorated with `requires_update`.
""" """
await super().update() await super().update(update_children)
# Initialize the child devices during the first update. # Initialize the child devices during the first update.
if not self.children: if not self.children:
@ -103,7 +103,7 @@ class SmartStrip(SmartDevice):
SmartStripPlug(self.host, parent=self, child_id=child["id"]) SmartStripPlug(self.host, parent=self, child_id=child["id"])
) )
if self.has_emeter: if update_children and self.has_emeter:
for plug in self.children: for plug in self.children:
await plug.update() await plug.update()
@ -243,13 +243,13 @@ class SmartStripPlug(SmartPlug):
self._sys_info = parent._sys_info self._sys_info = parent._sys_info
self._device_type = DeviceType.StripSocket self._device_type = DeviceType.StripSocket
async def update(self): async def update(self, update_children: bool = True):
"""Query the device to update the data. """Query the device to update the data.
Needed for properties that are decorated with `requires_update`. Needed for properties that are decorated with `requires_update`.
""" """
self._last_update = await self.parent.protocol.query( self._last_update = await self.parent.protocol.query(
self.host, self._create_emeter_request() self._create_emeter_request()
) )
def _create_request( def _create_request(

View File

@ -4,6 +4,7 @@ import json
import os import os
from os.path import basename from os.path import basename
from pathlib import Path, PurePath from pathlib import Path, PurePath
from typing import Dict
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
@ -39,6 +40,8 @@ WITH_EMETER = {"HS110", "HS300", "KP115", *BULBS}
ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS) ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
IP_MODEL_CACHE: Dict[str, str] = {}
def filter_model(desc, filter): def filter_model(desc, filter):
filtered = list() filtered = list()
@ -137,23 +140,39 @@ def device_for_file(model):
raise Exception("Unable to find type for %s", model) raise Exception("Unable to find type for %s", model)
def get_device_for_file(file): async def _update_and_close(d):
await d.update()
await d.protocol.close()
return d
async def _discover_update_and_close(ip):
d = await Discover.discover_single(ip)
return await _update_and_close(d)
async def get_device_for_file(file):
# if the wanted file is not an absolute path, prepend the fixtures directory # if the wanted file is not an absolute path, prepend the fixtures directory
p = Path(file) p = Path(file)
if not p.is_absolute(): if not p.is_absolute():
p = Path(__file__).parent / "fixtures" / file p = Path(__file__).parent / "fixtures" / file
with open(p) as f: def load_file():
sysinfo = json.load(f) with open(p) as f:
model = basename(file) return json.load(f)
p = device_for_file(model)(host="127.0.0.123")
p.protocol = FakeTransportProtocol(sysinfo) loop = asyncio.get_running_loop()
asyncio.run(p.update()) sysinfo = await loop.run_in_executor(None, load_file)
return p
model = basename(file)
d = device_for_file(model)(host="127.0.0.123")
d.protocol = FakeTransportProtocol(sysinfo)
await _update_and_close(d)
return d
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session") @pytest.fixture(params=SUPPORTED_DEVICES)
def dev(request): async def dev(request):
"""Device fixture. """Device fixture.
Provides a device (given --ip) or parametrized fixture for the supported devices. Provides a device (given --ip) or parametrized fixture for the supported devices.
@ -163,14 +182,16 @@ def dev(request):
ip = request.config.getoption("--ip") ip = request.config.getoption("--ip")
if ip: if ip:
d = asyncio.run(Discover.discover_single(ip)) model = IP_MODEL_CACHE.get(ip)
asyncio.run(d.update()) d = None
if d.model in file: if not model:
return d d = await _discover_update_and_close(ip)
else: IP_MODEL_CACHE[ip] = model = d.model
if model not in file:
pytest.skip(f"skipping file {file}") pytest.skip(f"skipping file {file}")
return d if d else await _discover_update_and_close(ip)
return get_device_for_file(file) return await get_device_for_file(file)
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session") @pytest.fixture(params=SUPPORTED_DEVICES, scope="session")

View File

@ -83,9 +83,19 @@ PLUG_SCHEMA = Schema(
"icon_hash": str, "icon_hash": str,
"led_off": check_int_bool, "led_off": check_int_bool,
"latitude": Any(All(float, Range(min=-90, max=90)), 0, None), "latitude": Any(All(float, Range(min=-90, max=90)), 0, None),
"latitude_i": Any(All(float, Range(min=-90, max=90)), 0, None), "latitude_i": Any(
All(int, Range(min=-900000, max=900000)),
All(float, Range(min=-900000, max=900000)),
0,
None,
),
"longitude": Any(All(float, Range(min=-180, max=180)), 0, None), "longitude": Any(All(float, Range(min=-180, max=180)), 0, None),
"longitude_i": Any(All(float, Range(min=-180, max=180)), 0, None), "longitude_i": Any(
All(int, Range(min=-18000000, max=18000000)),
All(float, Range(min=-18000000, max=18000000)),
0,
None,
),
"mac": check_mac, "mac": check_mac,
"model": str, "model": str,
"oemId": str, "oemId": str,
@ -117,17 +127,17 @@ LIGHT_STATE_SCHEMA = Schema(
{ {
"brightness": All(int, Range(min=0, max=100)), "brightness": All(int, Range(min=0, max=100)),
"color_temp": int, "color_temp": int,
"hue": All(int, Range(min=0, max=255)), "hue": All(int, Range(min=0, max=360)),
"mode": str, "mode": str,
"on_off": check_int_bool, "on_off": check_int_bool,
"saturation": All(int, Range(min=0, max=255)), "saturation": All(int, Range(min=0, max=100)),
"dft_on_state": Optional( "dft_on_state": Optional(
{ {
"brightness": All(int, Range(min=0, max=100)), "brightness": All(int, Range(min=0, max=100)),
"color_temp": All(int, Range(min=0, max=9000)), "color_temp": All(int, Range(min=0, max=9000)),
"hue": All(int, Range(min=0, max=255)), "hue": All(int, Range(min=0, max=360)),
"mode": str, "mode": str,
"saturation": All(int, Range(min=0, max=255)), "saturation": All(int, Range(min=0, max=100)),
} }
), ),
"err_code": int, "err_code": int,
@ -276,6 +286,8 @@ TIME_MODULE = {
class FakeTransportProtocol(TPLinkSmartHomeProtocol): class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info): def __init__(self, info):
self.discovery_data = info self.discovery_data = info
self.writer = None
self.reader = None
proto = FakeTransportProtocol.baseproto proto = FakeTransportProtocol.baseproto
for target in info: for target in info:
@ -426,7 +438,7 @@ class FakeTransportProtocol(TPLinkSmartHomeProtocol):
}, },
} }
async def query(self, host, request, port=9999): async def query(self, request, port=9999):
proto = self.proto proto = self.proto
# collect child ids from context # collect child ids from context

View File

@ -60,7 +60,7 @@ async def test_hsv(dev, turn_on):
assert dev.is_color assert dev.is_color
hue, saturation, brightness = dev.hsv hue, saturation, brightness = dev.hsv
assert 0 <= hue <= 255 assert 0 <= hue <= 360
assert 0 <= saturation <= 100 assert 0 <= saturation <= 100
assert 0 <= brightness <= 100 assert 0 <= brightness <= 100

View File

@ -3,7 +3,7 @@ import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
from kasa.discover import _DiscoverProtocol from kasa.discover import _DiscoverProtocol
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
@ -94,7 +94,8 @@ async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices.""" """Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol() proto = _DiscoverProtocol()
mocker.patch("json.loads", return_value=discovery_data) mocker.patch("json.loads", return_value=discovery_data)
mocker.patch.object(proto, "protocol") mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
addr = "127.0.0.1" addr = "127.0.0.1"
proto.datagram_received("<placeholder data>", (addr, 1234)) proto.datagram_received("<placeholder data>", (addr, 1234))

View File

@ -1,4 +1,6 @@
import json import json
import struct
import sys
import pytest import pytest
@ -21,11 +23,47 @@ async def test_protocol_retries(mocker, retry_count):
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count) await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count)
assert conn.call_count == retry_count + 1 assert conn.call_count == retry_count + 1
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_reconnect(mocker, retry_count):
remaining = retry_count
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE :
]
def _fail_one_less_than_retry_count(*_):
nonlocal remaining
remaining -= 1
if remaining:
raise Exception("Simulated write failure")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _fail_one_less_than_retry_count)
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}, retry_count=retry_count)
assert response == {"great": "success"}
def test_encrypt(): def test_encrypt():
d = json.dumps({"foo": 1, "bar": 2}) d = json.dumps({"foo": 1, "bar": 2})
encrypted = TPLinkSmartHomeProtocol.encrypt(d) encrypted = TPLinkSmartHomeProtocol.encrypt(d)

View File

@ -1,3 +1,4 @@
import asyncio
import sys import sys
import pytest import pytest
@ -8,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file
def test_bulb_examples(mocker): def test_bulb_examples(mocker):
"""Use KL130 (bulb with all features) to test the doctests.""" """Use KL130 (bulb with all features) to test the doctests."""
p = get_device_for_file("KL130(US)_1.0.json") p = asyncio.run(get_device_for_file("KL130(US)_1.0.json"))
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p) mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
mocker.patch("kasa.smartbulb.SmartBulb.update") mocker.patch("kasa.smartbulb.SmartBulb.update")
res = xdoctest.doctest_module("kasa.smartbulb", "all") res = xdoctest.doctest_module("kasa.smartbulb", "all")
@ -17,7 +18,7 @@ def test_bulb_examples(mocker):
def test_smartdevice_examples(mocker): def test_smartdevice_examples(mocker):
"""Use HS110 for emeter examples.""" """Use HS110 for emeter examples."""
p = get_device_for_file("HS110(EU)_1.0_real.json") p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json"))
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p) mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
mocker.patch("kasa.smartdevice.SmartDevice.update") mocker.patch("kasa.smartdevice.SmartDevice.update")
res = xdoctest.doctest_module("kasa.smartdevice", "all") res = xdoctest.doctest_module("kasa.smartdevice", "all")
@ -26,7 +27,7 @@ def test_smartdevice_examples(mocker):
def test_plug_examples(mocker): def test_plug_examples(mocker):
"""Test plug examples.""" """Test plug examples."""
p = get_device_for_file("HS110(EU)_1.0_real.json") p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json"))
mocker.patch("kasa.smartplug.SmartPlug", return_value=p) mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
mocker.patch("kasa.smartplug.SmartPlug.update") mocker.patch("kasa.smartplug.SmartPlug.update")
res = xdoctest.doctest_module("kasa.smartplug", "all") res = xdoctest.doctest_module("kasa.smartplug", "all")
@ -35,7 +36,7 @@ def test_plug_examples(mocker):
def test_strip_examples(mocker): def test_strip_examples(mocker):
"""Test strip examples.""" """Test strip examples."""
p = get_device_for_file("KP303(UK)_1.0.json") p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json"))
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p) mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
mocker.patch("kasa.smartstrip.SmartStrip.update") mocker.patch("kasa.smartstrip.SmartStrip.update")
res = xdoctest.doctest_module("kasa.smartstrip", "all") res = xdoctest.doctest_module("kasa.smartstrip", "all")
@ -44,7 +45,7 @@ def test_strip_examples(mocker):
def test_dimmer_examples(mocker): def test_dimmer_examples(mocker):
"""Test dimmer examples.""" """Test dimmer examples."""
p = get_device_for_file("HS220(US)_1.0_real.json") p = asyncio.run(get_device_for_file("HS220(US)_1.0_real.json"))
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p) mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
mocker.patch("kasa.smartdimmer.SmartDimmer.update") mocker.patch("kasa.smartdimmer.SmartDimmer.update")
res = xdoctest.doctest_module("kasa.smartdimmer", "all") res = xdoctest.doctest_module("kasa.smartdimmer", "all")
@ -53,7 +54,7 @@ def test_dimmer_examples(mocker):
def test_lightstrip_examples(mocker): def test_lightstrip_examples(mocker):
"""Test lightstrip examples.""" """Test lightstrip examples."""
p = get_device_for_file("KL430(US)_1.0.json") p = asyncio.run(get_device_for_file("KL430(US)_1.0.json"))
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p) mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update") mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
res = xdoctest.doctest_module("kasa.smartlightstrip", "all") res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
@ -65,7 +66,7 @@ def test_lightstrip_examples(mocker):
) )
def test_discovery_examples(mocker): def test_discovery_examples(mocker):
"""Test discovery examples.""" """Test discovery examples."""
p = get_device_for_file("KP303(UK)_1.0.json") p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json"))
# This succeeds on python 3.8 but fails on 3.7 # This succeeds on python 3.8 but fails on 3.7
# ValueError: a coroutine was expected, got [<DeviceType.Strip model KP303(UK) ... # ValueError: a coroutine was expected, got [<DeviceType.Strip model KP303(UK) ...