Keep connection open and lock to prevent duplicate requests (#213)

* Keep connection open and lock to prevent duplicate requests

* option to not update children

* tweaks

* typing

* tweaks

* run tests in the same event loop

* memorize model

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* dry

* tweaks

* warn when the event loop gets switched out from under us

* raise on unable to connect multiple times

* fix patch target

* tweaks

* isrot

* reconnect test

* prune

* fix mocking

* fix mocking

* fix test under python 3.7

* fix test under python 3.7

* less patching

* isort

* use mocker to patch

* disable on old python since mocking doesnt work

* avoid disconnect/reconnect cycles

* isort

* Fix hue validation

* Fix latitude_i/longitude_i units

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
J. Nick Koston 2021-09-24 16:25:43 -05:00 committed by GitHub
parent f1b28e79b9
commit e31cc6662c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 241 additions and 96 deletions

View File

@ -78,16 +78,17 @@ def cli(host, debug):
),
]
protocol = TPLinkSmartHomeProtocol()
successes = []
for test_call in items:
async def _run_query():
protocol = TPLinkSmartHomeProtocol(host)
return await protocol.query({test_call.module: {test_call.method: None}})
try:
click.echo(f"Testing {test_call}..", nl=False)
info = asyncio.run(
protocol.query(host, {test_call.module: {test_call.method: None}})
)
info = asyncio.run(_run_query())
resp = info[test_call.module]
except Exception as ex:
click.echo(click.style(f"FAIL {ex}", fg="red"))
@ -107,8 +108,12 @@ def cli(host, debug):
final = default_to_regular(final)
async def _run_final_query():
protocol = TPLinkSmartHomeProtocol(host)
return await protocol.query(final_query)
try:
final = asyncio.run(protocol.query(host, final_query))
final = asyncio.run(_run_final_query())
except Exception as ex:
click.echo(
click.style(

View File

@ -40,7 +40,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.discovery_packets = discovery_packets
self.interface = interface
self.on_discovered = on_discovered
self.protocol = TPLinkSmartHomeProtocol()
self.target = (target, Discover.DISCOVERY_PORT)
self.discovered_devices = {}
@ -61,7 +60,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Send number of discovery datagrams."""
req = json.dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = self.protocol.encrypt(req)
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
@ -71,7 +70,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
if ip in self.discovered_devices:
return
info = json.loads(self.protocol.decrypt(data))
info = json.loads(TPLinkSmartHomeProtocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
device_class = Discover._get_device_class(info)
@ -190,9 +189,9 @@ class Discover:
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
protocol = TPLinkSmartHomeProtocol()
protocol = TPLinkSmartHomeProtocol(host)
info = await protocol.query(host, Discover.DISCOVERY_QUERY)
info = await protocol.query(Discover.DISCOVERY_QUERY)
device_class = Discover._get_device_class(info)
dev = device_class(host)

View File

@ -10,11 +10,12 @@ which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""
import asyncio
import contextlib
import json
import logging
import struct
from pprint import pformat as pf
from typing import Dict, Union
from typing import Dict, Optional, Union
from .exceptions import SmartDeviceException
@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol:
DEFAULT_PORT = 9999
DEFAULT_TIMEOUT = 5
@staticmethod
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
BLOCK_SIZE = 4
def __init__(self, host: str) -> None:
"""Create a protocol object."""
self.host = host
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.query_lock: Optional[asyncio.Lock] = None
self.loop: Optional[asyncio.AbstractEventLoop] = None
def _detect_event_loop_change(self) -> None:
"""Check if this object has been reused betwen event loops."""
loop = asyncio.get_running_loop()
if not self.loop:
self.loop = loop
elif self.loop != loop:
_LOGGER.warning("Detected protocol reuse between different event loop")
self._reset()
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device.
:param str host: host name or ip address of the device
@ -38,57 +57,106 @@ class TPLinkSmartHomeProtocol:
:param retry_count: how many retries to do in case of failure
:return: response dict
"""
self._detect_event_loop_change()
if not self.query_lock:
self.query_lock = asyncio.Lock()
if isinstance(request, dict):
request = json.dumps(request)
assert isinstance(request, str)
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
writer = None
async with self.query_lock:
return await self._query(request, retry_count, timeout)
async def _connect(self, timeout: int) -> bool:
"""Try to connect or reconnect to the device."""
if self.writer:
return True
with contextlib.suppress(Exception):
self.reader = self.writer = None
task = asyncio.open_connection(
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
return True
return False
async def _execute_query(self, request: str) -> Dict:
"""Execute a query on the device and wait for the response."""
assert self.writer is not None
assert self.reader is not None
_LOGGER.debug("> (%i) %s", len(request), request)
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await self.writer.drain()
packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
length = struct.unpack(">I", packed_block_size)[0]
buffer = await self.reader.readexactly(length)
response = TPLinkSmartHomeProtocol.decrypt(buffer)
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
return json_payload
async def close(self):
"""Close the connection."""
writer = self.writer
self._reset()
if writer:
writer.close()
with contextlib.suppress(Exception):
await writer.wait_closed()
def _reset(self):
"""Clear any varibles that should not survive between loops."""
self.writer = None
self.reader = None
self.query_lock = None
self.loop = None
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device."""
for retry in range(retry_count + 1):
try:
task = asyncio.open_connection(
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
reader, writer = await asyncio.wait_for(task, timeout=timeout)
_LOGGER.debug("> (%i) %s", len(request), request)
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await writer.drain()
buffer = bytes()
# Some devices send responses with a length header of 0 and
# terminate with a zero size chunk. Others send the length and
# will hang if we attempt to read more data.
length = -1
while True:
chunk = await reader.read(4096)
if length == -1:
length = struct.unpack(">I", chunk[0:4])[0]
buffer += chunk
if (length > 0 and len(buffer) >= length + 4) or not chunk:
break
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
return json_payload
except Exception as ex:
if not await self._connect(timeout):
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
"Unable to query the device: %s" % ex
f"Unable to connect to the device: {self.host}"
)
continue
try:
assert self.reader is not None
assert self.writer is not None
return await asyncio.wait_for(
self._execute_query(request), timeout=timeout
)
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
f"Unable to query the device: {ex}"
) from ex
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
finally:
if writer:
writer.close()
await writer.wait_closed()
# make mypy happy, this should never be reached..
await self.close()
raise SmartDeviceException("Query reached somehow to unreachable")
def __del__(self):
if self.writer and self.loop and self.loop.is_running():
self.writer.close()
self._reset()
@staticmethod
def _xor_payload(unencrypted):
key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR

View File

@ -194,7 +194,7 @@ class SmartDevice:
"""
self.host = host
self.protocol = TPLinkSmartHomeProtocol()
self.protocol = TPLinkSmartHomeProtocol(host)
self.emeter_type = "emeter"
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
self._device_type = DeviceType.Unknown
@ -234,7 +234,7 @@ class SmartDevice:
request = self._create_request(target, cmd, arg, child_ids)
try:
response = await self.protocol.query(host=self.host, request=request)
response = await self.protocol.query(request=request)
except Exception as ex:
raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex
@ -272,7 +272,7 @@ class SmartDevice:
"""Retrieve system information."""
return await self._query_helper("system", "get_sysinfo")
async def update(self):
async def update(self, update_children: bool = True):
"""Query the device to update the data.
Needed for properties that are decorated with `requires_update`.
@ -285,7 +285,7 @@ class SmartDevice:
# See #105, #120, #161
if self._last_update is None:
_LOGGER.debug("Performing the initial update to obtain sysinfo")
self._last_update = await self.protocol.query(self.host, req)
self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]
# If the device has no emeter, we are done for the initial update
# Otherwise we will follow the regular code path to also query
@ -299,7 +299,7 @@ class SmartDevice:
)
req.update(self._create_emeter_request())
self._last_update = await self.protocol.query(self.host, req)
self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]
def update_from_discover_info(self, info):
@ -383,8 +383,8 @@ class SmartDevice:
loc["latitude"] = sys_info["latitude"]
loc["longitude"] = sys_info["longitude"]
elif "latitude_i" in sys_info and "longitude_i" in sys_info:
loc["latitude"] = sys_info["latitude_i"]
loc["longitude"] = sys_info["longitude_i"]
loc["latitude"] = sys_info["latitude_i"] / 10000
loc["longitude"] = sys_info["longitude_i"] / 10000
else:
_LOGGER.warning("Unsupported device location.")

View File

@ -87,12 +87,12 @@ class SmartStrip(SmartDevice):
"""Return if any of the outlets are on."""
return any(plug.is_on for plug in self.children)
async def update(self):
async def update(self, update_children: bool = True):
"""Update some of the attributes.
Needed for methods that are decorated with `requires_update`.
"""
await super().update()
await super().update(update_children)
# Initialize the child devices during the first update.
if not self.children:
@ -103,7 +103,7 @@ class SmartStrip(SmartDevice):
SmartStripPlug(self.host, parent=self, child_id=child["id"])
)
if self.has_emeter:
if update_children and self.has_emeter:
for plug in self.children:
await plug.update()
@ -243,13 +243,13 @@ class SmartStripPlug(SmartPlug):
self._sys_info = parent._sys_info
self._device_type = DeviceType.StripSocket
async def update(self):
async def update(self, update_children: bool = True):
"""Query the device to update the data.
Needed for properties that are decorated with `requires_update`.
"""
self._last_update = await self.parent.protocol.query(
self.host, self._create_emeter_request()
self._create_emeter_request()
)
def _create_request(

View File

@ -4,6 +4,7 @@ import json
import os
from os.path import basename
from pathlib import Path, PurePath
from typing import Dict
from unittest.mock import MagicMock
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
@ -39,6 +40,8 @@ WITH_EMETER = {"HS110", "HS300", "KP115", *BULBS}
ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
IP_MODEL_CACHE: Dict[str, str] = {}
def filter_model(desc, filter):
filtered = list()
@ -137,23 +140,39 @@ def device_for_file(model):
raise Exception("Unable to find type for %s", model)
def get_device_for_file(file):
async def _update_and_close(d):
await d.update()
await d.protocol.close()
return d
async def _discover_update_and_close(ip):
d = await Discover.discover_single(ip)
return await _update_and_close(d)
async def get_device_for_file(file):
# if the wanted file is not an absolute path, prepend the fixtures directory
p = Path(file)
if not p.is_absolute():
p = Path(__file__).parent / "fixtures" / file
with open(p) as f:
sysinfo = json.load(f)
model = basename(file)
p = device_for_file(model)(host="127.0.0.123")
p.protocol = FakeTransportProtocol(sysinfo)
asyncio.run(p.update())
return p
def load_file():
with open(p) as f:
return json.load(f)
loop = asyncio.get_running_loop()
sysinfo = await loop.run_in_executor(None, load_file)
model = basename(file)
d = device_for_file(model)(host="127.0.0.123")
d.protocol = FakeTransportProtocol(sysinfo)
await _update_and_close(d)
return d
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
def dev(request):
@pytest.fixture(params=SUPPORTED_DEVICES)
async def dev(request):
"""Device fixture.
Provides a device (given --ip) or parametrized fixture for the supported devices.
@ -163,14 +182,16 @@ def dev(request):
ip = request.config.getoption("--ip")
if ip:
d = asyncio.run(Discover.discover_single(ip))
asyncio.run(d.update())
if d.model in file:
return d
else:
model = IP_MODEL_CACHE.get(ip)
d = None
if not model:
d = await _discover_update_and_close(ip)
IP_MODEL_CACHE[ip] = model = d.model
if model not in file:
pytest.skip(f"skipping file {file}")
return d if d else await _discover_update_and_close(ip)
return get_device_for_file(file)
return await get_device_for_file(file)
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")

View File

@ -83,9 +83,19 @@ PLUG_SCHEMA = Schema(
"icon_hash": str,
"led_off": check_int_bool,
"latitude": Any(All(float, Range(min=-90, max=90)), 0, None),
"latitude_i": Any(All(float, Range(min=-90, max=90)), 0, None),
"latitude_i": Any(
All(int, Range(min=-900000, max=900000)),
All(float, Range(min=-900000, max=900000)),
0,
None,
),
"longitude": Any(All(float, Range(min=-180, max=180)), 0, None),
"longitude_i": Any(All(float, Range(min=-180, max=180)), 0, None),
"longitude_i": Any(
All(int, Range(min=-18000000, max=18000000)),
All(float, Range(min=-18000000, max=18000000)),
0,
None,
),
"mac": check_mac,
"model": str,
"oemId": str,
@ -117,17 +127,17 @@ LIGHT_STATE_SCHEMA = Schema(
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": int,
"hue": All(int, Range(min=0, max=255)),
"hue": All(int, Range(min=0, max=360)),
"mode": str,
"on_off": check_int_bool,
"saturation": All(int, Range(min=0, max=255)),
"saturation": All(int, Range(min=0, max=100)),
"dft_on_state": Optional(
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": All(int, Range(min=0, max=9000)),
"hue": All(int, Range(min=0, max=255)),
"hue": All(int, Range(min=0, max=360)),
"mode": str,
"saturation": All(int, Range(min=0, max=255)),
"saturation": All(int, Range(min=0, max=100)),
}
),
"err_code": int,
@ -276,6 +286,8 @@ TIME_MODULE = {
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info):
self.discovery_data = info
self.writer = None
self.reader = None
proto = FakeTransportProtocol.baseproto
for target in info:
@ -426,7 +438,7 @@ class FakeTransportProtocol(TPLinkSmartHomeProtocol):
},
}
async def query(self, host, request, port=9999):
async def query(self, request, port=9999):
proto = self.proto
# collect child ids from context

View File

@ -60,7 +60,7 @@ async def test_hsv(dev, turn_on):
assert dev.is_color
hue, saturation, brightness = dev.hsv
assert 0 <= hue <= 255
assert 0 <= hue <= 360
assert 0 <= saturation <= 100
assert 0 <= brightness <= 100

View File

@ -3,7 +3,7 @@ import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
from kasa.discover import _DiscoverProtocol
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
@ -94,7 +94,8 @@ async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol()
mocker.patch("json.loads", return_value=discovery_data)
mocker.patch.object(proto, "protocol")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
addr = "127.0.0.1"
proto.datagram_received("<placeholder data>", (addr, 1234))

View File

@ -1,4 +1,6 @@
import json
import struct
import sys
import pytest
@ -21,11 +23,47 @@ async def test_protocol_retries(mocker, retry_count):
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count)
assert conn.call_count == retry_count + 1
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_reconnect(mocker, retry_count):
remaining = retry_count
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE :
]
def _fail_one_less_than_retry_count(*_):
nonlocal remaining
remaining -= 1
if remaining:
raise Exception("Simulated write failure")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _fail_one_less_than_retry_count)
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}, retry_count=retry_count)
assert response == {"great": "success"}
def test_encrypt():
d = json.dumps({"foo": 1, "bar": 2})
encrypted = TPLinkSmartHomeProtocol.encrypt(d)

View File

@ -1,3 +1,4 @@
import asyncio
import sys
import pytest
@ -8,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file
def test_bulb_examples(mocker):
"""Use KL130 (bulb with all features) to test the doctests."""
p = get_device_for_file("KL130(US)_1.0.json")
p = asyncio.run(get_device_for_file("KL130(US)_1.0.json"))
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
mocker.patch("kasa.smartbulb.SmartBulb.update")
res = xdoctest.doctest_module("kasa.smartbulb", "all")
@ -17,7 +18,7 @@ def test_bulb_examples(mocker):
def test_smartdevice_examples(mocker):
"""Use HS110 for emeter examples."""
p = get_device_for_file("HS110(EU)_1.0_real.json")
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json"))
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
mocker.patch("kasa.smartdevice.SmartDevice.update")
res = xdoctest.doctest_module("kasa.smartdevice", "all")
@ -26,7 +27,7 @@ def test_smartdevice_examples(mocker):
def test_plug_examples(mocker):
"""Test plug examples."""
p = get_device_for_file("HS110(EU)_1.0_real.json")
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json"))
mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
mocker.patch("kasa.smartplug.SmartPlug.update")
res = xdoctest.doctest_module("kasa.smartplug", "all")
@ -35,7 +36,7 @@ def test_plug_examples(mocker):
def test_strip_examples(mocker):
"""Test strip examples."""
p = get_device_for_file("KP303(UK)_1.0.json")
p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json"))
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
mocker.patch("kasa.smartstrip.SmartStrip.update")
res = xdoctest.doctest_module("kasa.smartstrip", "all")
@ -44,7 +45,7 @@ def test_strip_examples(mocker):
def test_dimmer_examples(mocker):
"""Test dimmer examples."""
p = get_device_for_file("HS220(US)_1.0_real.json")
p = asyncio.run(get_device_for_file("HS220(US)_1.0_real.json"))
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
mocker.patch("kasa.smartdimmer.SmartDimmer.update")
res = xdoctest.doctest_module("kasa.smartdimmer", "all")
@ -53,7 +54,7 @@ def test_dimmer_examples(mocker):
def test_lightstrip_examples(mocker):
"""Test lightstrip examples."""
p = get_device_for_file("KL430(US)_1.0.json")
p = asyncio.run(get_device_for_file("KL430(US)_1.0.json"))
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
@ -65,7 +66,7 @@ def test_lightstrip_examples(mocker):
)
def test_discovery_examples(mocker):
"""Test discovery examples."""
p = get_device_for_file("KP303(UK)_1.0.json")
p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json"))
# This succeeds on python 3.8 but fails on 3.7
# ValueError: a coroutine was expected, got [<DeviceType.Strip model KP303(UK) ...