Keep connection open and lock to prevent duplicate requests (#213)

* Keep connection open and lock to prevent duplicate requests

* option to not update children

* tweaks

* typing

* tweaks

* run tests in the same event loop

* memorize model

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* dry

* tweaks

* warn when the event loop gets switched out from under us

* raise on unable to connect multiple times

* fix patch target

* tweaks

* isrot

* reconnect test

* prune

* fix mocking

* fix mocking

* fix test under python 3.7

* fix test under python 3.7

* less patching

* isort

* use mocker to patch

* disable on old python since mocking doesnt work

* avoid disconnect/reconnect cycles

* isort

* Fix hue validation

* Fix latitude_i/longitude_i units

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
J. Nick Koston
2021-09-24 16:25:43 -05:00
committed by GitHub
parent f1b28e79b9
commit e31cc6662c
11 changed files with 241 additions and 96 deletions

View File

@@ -4,6 +4,7 @@ import json
import os
from os.path import basename
from pathlib import Path, PurePath
from typing import Dict
from unittest.mock import MagicMock
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
@@ -39,6 +40,8 @@ WITH_EMETER = {"HS110", "HS300", "KP115", *BULBS}
ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
IP_MODEL_CACHE: Dict[str, str] = {}
def filter_model(desc, filter):
filtered = list()
@@ -137,23 +140,39 @@ def device_for_file(model):
raise Exception("Unable to find type for %s", model)
def get_device_for_file(file):
async def _update_and_close(d):
await d.update()
await d.protocol.close()
return d
async def _discover_update_and_close(ip):
d = await Discover.discover_single(ip)
return await _update_and_close(d)
async def get_device_for_file(file):
# if the wanted file is not an absolute path, prepend the fixtures directory
p = Path(file)
if not p.is_absolute():
p = Path(__file__).parent / "fixtures" / file
with open(p) as f:
sysinfo = json.load(f)
model = basename(file)
p = device_for_file(model)(host="127.0.0.123")
p.protocol = FakeTransportProtocol(sysinfo)
asyncio.run(p.update())
return p
def load_file():
with open(p) as f:
return json.load(f)
loop = asyncio.get_running_loop()
sysinfo = await loop.run_in_executor(None, load_file)
model = basename(file)
d = device_for_file(model)(host="127.0.0.123")
d.protocol = FakeTransportProtocol(sysinfo)
await _update_and_close(d)
return d
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
def dev(request):
@pytest.fixture(params=SUPPORTED_DEVICES)
async def dev(request):
"""Device fixture.
Provides a device (given --ip) or parametrized fixture for the supported devices.
@@ -163,14 +182,16 @@ def dev(request):
ip = request.config.getoption("--ip")
if ip:
d = asyncio.run(Discover.discover_single(ip))
asyncio.run(d.update())
if d.model in file:
return d
else:
model = IP_MODEL_CACHE.get(ip)
d = None
if not model:
d = await _discover_update_and_close(ip)
IP_MODEL_CACHE[ip] = model = d.model
if model not in file:
pytest.skip(f"skipping file {file}")
return d if d else await _discover_update_and_close(ip)
return get_device_for_file(file)
return await get_device_for_file(file)
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")

View File

@@ -83,9 +83,19 @@ PLUG_SCHEMA = Schema(
"icon_hash": str,
"led_off": check_int_bool,
"latitude": Any(All(float, Range(min=-90, max=90)), 0, None),
"latitude_i": Any(All(float, Range(min=-90, max=90)), 0, None),
"latitude_i": Any(
All(int, Range(min=-900000, max=900000)),
All(float, Range(min=-900000, max=900000)),
0,
None,
),
"longitude": Any(All(float, Range(min=-180, max=180)), 0, None),
"longitude_i": Any(All(float, Range(min=-180, max=180)), 0, None),
"longitude_i": Any(
All(int, Range(min=-18000000, max=18000000)),
All(float, Range(min=-18000000, max=18000000)),
0,
None,
),
"mac": check_mac,
"model": str,
"oemId": str,
@@ -117,17 +127,17 @@ LIGHT_STATE_SCHEMA = Schema(
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": int,
"hue": All(int, Range(min=0, max=255)),
"hue": All(int, Range(min=0, max=360)),
"mode": str,
"on_off": check_int_bool,
"saturation": All(int, Range(min=0, max=255)),
"saturation": All(int, Range(min=0, max=100)),
"dft_on_state": Optional(
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": All(int, Range(min=0, max=9000)),
"hue": All(int, Range(min=0, max=255)),
"hue": All(int, Range(min=0, max=360)),
"mode": str,
"saturation": All(int, Range(min=0, max=255)),
"saturation": All(int, Range(min=0, max=100)),
}
),
"err_code": int,
@@ -276,6 +286,8 @@ TIME_MODULE = {
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info):
self.discovery_data = info
self.writer = None
self.reader = None
proto = FakeTransportProtocol.baseproto
for target in info:
@@ -426,7 +438,7 @@ class FakeTransportProtocol(TPLinkSmartHomeProtocol):
},
}
async def query(self, host, request, port=9999):
async def query(self, request, port=9999):
proto = self.proto
# collect child ids from context

View File

@@ -60,7 +60,7 @@ async def test_hsv(dev, turn_on):
assert dev.is_color
hue, saturation, brightness = dev.hsv
assert 0 <= hue <= 255
assert 0 <= hue <= 360
assert 0 <= saturation <= 100
assert 0 <= brightness <= 100

View File

@@ -3,7 +3,7 @@ import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
from kasa.discover import _DiscoverProtocol
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
@@ -94,7 +94,8 @@ async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol()
mocker.patch("json.loads", return_value=discovery_data)
mocker.patch.object(proto, "protocol")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
addr = "127.0.0.1"
proto.datagram_received("<placeholder data>", (addr, 1234))

View File

@@ -1,4 +1,6 @@
import json
import struct
import sys
import pytest
@@ -21,11 +23,47 @@ async def test_protocol_retries(mocker, retry_count):
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count)
assert conn.call_count == retry_count + 1
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_reconnect(mocker, retry_count):
remaining = retry_count
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE :
]
def _fail_one_less_than_retry_count(*_):
nonlocal remaining
remaining -= 1
if remaining:
raise Exception("Simulated write failure")
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
mocker.patch.object(writer, "write", _fail_one_less_than_retry_count)
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}, retry_count=retry_count)
assert response == {"great": "success"}
def test_encrypt():
d = json.dumps({"foo": 1, "bar": 2})
encrypted = TPLinkSmartHomeProtocol.encrypt(d)

View File

@@ -1,3 +1,4 @@
import asyncio
import sys
import pytest
@@ -8,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file
def test_bulb_examples(mocker):
"""Use KL130 (bulb with all features) to test the doctests."""
p = get_device_for_file("KL130(US)_1.0.json")
p = asyncio.run(get_device_for_file("KL130(US)_1.0.json"))
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
mocker.patch("kasa.smartbulb.SmartBulb.update")
res = xdoctest.doctest_module("kasa.smartbulb", "all")
@@ -17,7 +18,7 @@ def test_bulb_examples(mocker):
def test_smartdevice_examples(mocker):
"""Use HS110 for emeter examples."""
p = get_device_for_file("HS110(EU)_1.0_real.json")
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json"))
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
mocker.patch("kasa.smartdevice.SmartDevice.update")
res = xdoctest.doctest_module("kasa.smartdevice", "all")
@@ -26,7 +27,7 @@ def test_smartdevice_examples(mocker):
def test_plug_examples(mocker):
"""Test plug examples."""
p = get_device_for_file("HS110(EU)_1.0_real.json")
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json"))
mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
mocker.patch("kasa.smartplug.SmartPlug.update")
res = xdoctest.doctest_module("kasa.smartplug", "all")
@@ -35,7 +36,7 @@ def test_plug_examples(mocker):
def test_strip_examples(mocker):
"""Test strip examples."""
p = get_device_for_file("KP303(UK)_1.0.json")
p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json"))
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
mocker.patch("kasa.smartstrip.SmartStrip.update")
res = xdoctest.doctest_module("kasa.smartstrip", "all")
@@ -44,7 +45,7 @@ def test_strip_examples(mocker):
def test_dimmer_examples(mocker):
"""Test dimmer examples."""
p = get_device_for_file("HS220(US)_1.0_real.json")
p = asyncio.run(get_device_for_file("HS220(US)_1.0_real.json"))
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
mocker.patch("kasa.smartdimmer.SmartDimmer.update")
res = xdoctest.doctest_module("kasa.smartdimmer", "all")
@@ -53,7 +54,7 @@ def test_dimmer_examples(mocker):
def test_lightstrip_examples(mocker):
"""Test lightstrip examples."""
p = get_device_for_file("KL430(US)_1.0.json")
p = asyncio.run(get_device_for_file("KL430(US)_1.0.json"))
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
@@ -65,7 +66,7 @@ def test_lightstrip_examples(mocker):
)
def test_discovery_examples(mocker):
"""Test discovery examples."""
p = get_device_for_file("KP303(UK)_1.0.json")
p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json"))
# This succeeds on python 3.8 but fails on 3.7
# ValueError: a coroutine was expected, got [<DeviceType.Strip model KP303(UK) ...