mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-08 22:07:06 +00:00
Keep connection open and lock to prevent duplicate requests (#213)
* Keep connection open and lock to prevent duplicate requests * option to not update children * tweaks * typing * tweaks * run tests in the same event loop * memorize model * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * dry * tweaks * warn when the event loop gets switched out from under us * raise on unable to connect multiple times * fix patch target * tweaks * isrot * reconnect test * prune * fix mocking * fix mocking * fix test under python 3.7 * fix test under python 3.7 * less patching * isort * use mocker to patch * disable on old python since mocking doesnt work * avoid disconnect/reconnect cycles * isort * Fix hue validation * Fix latitude_i/longitude_i units Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
parent
f1b28e79b9
commit
e31cc6662c
@ -78,16 +78,17 @@ def cli(host, debug):
|
||||
),
|
||||
]
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol()
|
||||
|
||||
successes = []
|
||||
|
||||
for test_call in items:
|
||||
|
||||
async def _run_query():
|
||||
protocol = TPLinkSmartHomeProtocol(host)
|
||||
return await protocol.query({test_call.module: {test_call.method: None}})
|
||||
|
||||
try:
|
||||
click.echo(f"Testing {test_call}..", nl=False)
|
||||
info = asyncio.run(
|
||||
protocol.query(host, {test_call.module: {test_call.method: None}})
|
||||
)
|
||||
info = asyncio.run(_run_query())
|
||||
resp = info[test_call.module]
|
||||
except Exception as ex:
|
||||
click.echo(click.style(f"FAIL {ex}", fg="red"))
|
||||
@ -107,8 +108,12 @@ def cli(host, debug):
|
||||
|
||||
final = default_to_regular(final)
|
||||
|
||||
async def _run_final_query():
|
||||
protocol = TPLinkSmartHomeProtocol(host)
|
||||
return await protocol.query(final_query)
|
||||
|
||||
try:
|
||||
final = asyncio.run(protocol.query(host, final_query))
|
||||
final = asyncio.run(_run_final_query())
|
||||
except Exception as ex:
|
||||
click.echo(
|
||||
click.style(
|
||||
|
@ -40,7 +40,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
self.discovery_packets = discovery_packets
|
||||
self.interface = interface
|
||||
self.on_discovered = on_discovered
|
||||
self.protocol = TPLinkSmartHomeProtocol()
|
||||
self.target = (target, Discover.DISCOVERY_PORT)
|
||||
self.discovered_devices = {}
|
||||
|
||||
@ -61,7 +60,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
"""Send number of discovery datagrams."""
|
||||
req = json.dumps(Discover.DISCOVERY_QUERY)
|
||||
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
||||
encrypted_req = self.protocol.encrypt(req)
|
||||
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
|
||||
for i in range(self.discovery_packets):
|
||||
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
|
||||
|
||||
@ -71,7 +70,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||
if ip in self.discovered_devices:
|
||||
return
|
||||
|
||||
info = json.loads(self.protocol.decrypt(data))
|
||||
info = json.loads(TPLinkSmartHomeProtocol.decrypt(data))
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||
|
||||
device_class = Discover._get_device_class(info)
|
||||
@ -190,9 +189,9 @@ class Discover:
|
||||
:rtype: SmartDevice
|
||||
:return: Object for querying/controlling found device.
|
||||
"""
|
||||
protocol = TPLinkSmartHomeProtocol()
|
||||
protocol = TPLinkSmartHomeProtocol(host)
|
||||
|
||||
info = await protocol.query(host, Discover.DISCOVERY_QUERY)
|
||||
info = await protocol.query(Discover.DISCOVERY_QUERY)
|
||||
|
||||
device_class = Discover._get_device_class(info)
|
||||
dev = device_class(host)
|
||||
|
146
kasa/protocol.py
146
kasa/protocol.py
@ -10,11 +10,12 @@ which are licensed under the Apache License, Version 2.0
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import struct
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from .exceptions import SmartDeviceException
|
||||
|
||||
@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol:
|
||||
DEFAULT_PORT = 9999
|
||||
DEFAULT_TIMEOUT = 5
|
||||
|
||||
@staticmethod
|
||||
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
BLOCK_SIZE = 4
|
||||
|
||||
def __init__(self, host: str) -> None:
|
||||
"""Create a protocol object."""
|
||||
self.host = host
|
||||
self.reader: Optional[asyncio.StreamReader] = None
|
||||
self.writer: Optional[asyncio.StreamWriter] = None
|
||||
self.query_lock: Optional[asyncio.Lock] = None
|
||||
self.loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
def _detect_event_loop_change(self) -> None:
|
||||
"""Check if this object has been reused betwen event loops."""
|
||||
loop = asyncio.get_running_loop()
|
||||
if not self.loop:
|
||||
self.loop = loop
|
||||
elif self.loop != loop:
|
||||
_LOGGER.warning("Detected protocol reuse between different event loop")
|
||||
self._reset()
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Request information from a TP-Link SmartHome Device.
|
||||
|
||||
:param str host: host name or ip address of the device
|
||||
@ -38,57 +57,106 @@ class TPLinkSmartHomeProtocol:
|
||||
:param retry_count: how many retries to do in case of failure
|
||||
:return: response dict
|
||||
"""
|
||||
self._detect_event_loop_change()
|
||||
|
||||
if not self.query_lock:
|
||||
self.query_lock = asyncio.Lock()
|
||||
|
||||
if isinstance(request, dict):
|
||||
request = json.dumps(request)
|
||||
assert isinstance(request, str)
|
||||
|
||||
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
|
||||
writer = None
|
||||
|
||||
async with self.query_lock:
|
||||
return await self._query(request, retry_count, timeout)
|
||||
|
||||
async def _connect(self, timeout: int) -> bool:
|
||||
"""Try to connect or reconnect to the device."""
|
||||
if self.writer:
|
||||
return True
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
self.reader = self.writer = None
|
||||
task = asyncio.open_connection(
|
||||
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
|
||||
)
|
||||
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _execute_query(self, request: str) -> Dict:
|
||||
"""Execute a query on the device and wait for the response."""
|
||||
assert self.writer is not None
|
||||
assert self.reader is not None
|
||||
|
||||
_LOGGER.debug("> (%i) %s", len(request), request)
|
||||
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||
await self.writer.drain()
|
||||
|
||||
packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
|
||||
length = struct.unpack(">I", packed_block_size)[0]
|
||||
|
||||
buffer = await self.reader.readexactly(length)
|
||||
response = TPLinkSmartHomeProtocol.decrypt(buffer)
|
||||
json_payload = json.loads(response)
|
||||
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
|
||||
return json_payload
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection."""
|
||||
writer = self.writer
|
||||
self._reset()
|
||||
if writer:
|
||||
writer.close()
|
||||
with contextlib.suppress(Exception):
|
||||
await writer.wait_closed()
|
||||
|
||||
def _reset(self):
|
||||
"""Clear any varibles that should not survive between loops."""
|
||||
self.writer = None
|
||||
self.reader = None
|
||||
self.query_lock = None
|
||||
self.loop = None
|
||||
|
||||
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
|
||||
"""Try to query a device."""
|
||||
for retry in range(retry_count + 1):
|
||||
try:
|
||||
task = asyncio.open_connection(
|
||||
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
|
||||
)
|
||||
reader, writer = await asyncio.wait_for(task, timeout=timeout)
|
||||
_LOGGER.debug("> (%i) %s", len(request), request)
|
||||
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||
await writer.drain()
|
||||
|
||||
buffer = bytes()
|
||||
# Some devices send responses with a length header of 0 and
|
||||
# terminate with a zero size chunk. Others send the length and
|
||||
# will hang if we attempt to read more data.
|
||||
length = -1
|
||||
while True:
|
||||
chunk = await reader.read(4096)
|
||||
if length == -1:
|
||||
length = struct.unpack(">I", chunk[0:4])[0]
|
||||
buffer += chunk
|
||||
if (length > 0 and len(buffer) >= length + 4) or not chunk:
|
||||
break
|
||||
|
||||
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
|
||||
json_payload = json.loads(response)
|
||||
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
|
||||
|
||||
return json_payload
|
||||
|
||||
except Exception as ex:
|
||||
if not await self._connect(timeout):
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up after %s retries", retry)
|
||||
raise SmartDeviceException(
|
||||
"Unable to query the device: %s" % ex
|
||||
f"Unable to connect to the device: {self.host}"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
assert self.reader is not None
|
||||
assert self.writer is not None
|
||||
return await asyncio.wait_for(
|
||||
self._execute_query(request), timeout=timeout
|
||||
)
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up after %s retries", retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to query the device: {ex}"
|
||||
) from ex
|
||||
|
||||
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
|
||||
|
||||
finally:
|
||||
if writer:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
await self.close()
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
def __del__(self):
|
||||
if self.writer and self.loop and self.loop.is_running():
|
||||
self.writer.close()
|
||||
self._reset()
|
||||
|
||||
@staticmethod
|
||||
def _xor_payload(unencrypted):
|
||||
key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR
|
||||
|
@ -194,7 +194,7 @@ class SmartDevice:
|
||||
"""
|
||||
self.host = host
|
||||
|
||||
self.protocol = TPLinkSmartHomeProtocol()
|
||||
self.protocol = TPLinkSmartHomeProtocol(host)
|
||||
self.emeter_type = "emeter"
|
||||
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
|
||||
self._device_type = DeviceType.Unknown
|
||||
@ -234,7 +234,7 @@ class SmartDevice:
|
||||
request = self._create_request(target, cmd, arg, child_ids)
|
||||
|
||||
try:
|
||||
response = await self.protocol.query(host=self.host, request=request)
|
||||
response = await self.protocol.query(request=request)
|
||||
except Exception as ex:
|
||||
raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex
|
||||
|
||||
@ -272,7 +272,7 @@ class SmartDevice:
|
||||
"""Retrieve system information."""
|
||||
return await self._query_helper("system", "get_sysinfo")
|
||||
|
||||
async def update(self):
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Query the device to update the data.
|
||||
|
||||
Needed for properties that are decorated with `requires_update`.
|
||||
@ -285,7 +285,7 @@ class SmartDevice:
|
||||
# See #105, #120, #161
|
||||
if self._last_update is None:
|
||||
_LOGGER.debug("Performing the initial update to obtain sysinfo")
|
||||
self._last_update = await self.protocol.query(self.host, req)
|
||||
self._last_update = await self.protocol.query(req)
|
||||
self._sys_info = self._last_update["system"]["get_sysinfo"]
|
||||
# If the device has no emeter, we are done for the initial update
|
||||
# Otherwise we will follow the regular code path to also query
|
||||
@ -299,7 +299,7 @@ class SmartDevice:
|
||||
)
|
||||
req.update(self._create_emeter_request())
|
||||
|
||||
self._last_update = await self.protocol.query(self.host, req)
|
||||
self._last_update = await self.protocol.query(req)
|
||||
self._sys_info = self._last_update["system"]["get_sysinfo"]
|
||||
|
||||
def update_from_discover_info(self, info):
|
||||
@ -383,8 +383,8 @@ class SmartDevice:
|
||||
loc["latitude"] = sys_info["latitude"]
|
||||
loc["longitude"] = sys_info["longitude"]
|
||||
elif "latitude_i" in sys_info and "longitude_i" in sys_info:
|
||||
loc["latitude"] = sys_info["latitude_i"]
|
||||
loc["longitude"] = sys_info["longitude_i"]
|
||||
loc["latitude"] = sys_info["latitude_i"] / 10000
|
||||
loc["longitude"] = sys_info["longitude_i"] / 10000
|
||||
else:
|
||||
_LOGGER.warning("Unsupported device location.")
|
||||
|
||||
|
@ -87,12 +87,12 @@ class SmartStrip(SmartDevice):
|
||||
"""Return if any of the outlets are on."""
|
||||
return any(plug.is_on for plug in self.children)
|
||||
|
||||
async def update(self):
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Update some of the attributes.
|
||||
|
||||
Needed for methods that are decorated with `requires_update`.
|
||||
"""
|
||||
await super().update()
|
||||
await super().update(update_children)
|
||||
|
||||
# Initialize the child devices during the first update.
|
||||
if not self.children:
|
||||
@ -103,7 +103,7 @@ class SmartStrip(SmartDevice):
|
||||
SmartStripPlug(self.host, parent=self, child_id=child["id"])
|
||||
)
|
||||
|
||||
if self.has_emeter:
|
||||
if update_children and self.has_emeter:
|
||||
for plug in self.children:
|
||||
await plug.update()
|
||||
|
||||
@ -243,13 +243,13 @@ class SmartStripPlug(SmartPlug):
|
||||
self._sys_info = parent._sys_info
|
||||
self._device_type = DeviceType.StripSocket
|
||||
|
||||
async def update(self):
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Query the device to update the data.
|
||||
|
||||
Needed for properties that are decorated with `requires_update`.
|
||||
"""
|
||||
self._last_update = await self.parent.protocol.query(
|
||||
self.host, self._create_emeter_request()
|
||||
self._create_emeter_request()
|
||||
)
|
||||
|
||||
def _create_request(
|
||||
|
@ -4,6 +4,7 @@ import json
|
||||
import os
|
||||
from os.path import basename
|
||||
from pathlib import Path, PurePath
|
||||
from typing import Dict
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
|
||||
@ -39,6 +40,8 @@ WITH_EMETER = {"HS110", "HS300", "KP115", *BULBS}
|
||||
|
||||
ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
|
||||
|
||||
IP_MODEL_CACHE: Dict[str, str] = {}
|
||||
|
||||
|
||||
def filter_model(desc, filter):
|
||||
filtered = list()
|
||||
@ -137,23 +140,39 @@ def device_for_file(model):
|
||||
raise Exception("Unable to find type for %s", model)
|
||||
|
||||
|
||||
def get_device_for_file(file):
|
||||
async def _update_and_close(d):
|
||||
await d.update()
|
||||
await d.protocol.close()
|
||||
return d
|
||||
|
||||
|
||||
async def _discover_update_and_close(ip):
|
||||
d = await Discover.discover_single(ip)
|
||||
return await _update_and_close(d)
|
||||
|
||||
|
||||
async def get_device_for_file(file):
|
||||
# if the wanted file is not an absolute path, prepend the fixtures directory
|
||||
p = Path(file)
|
||||
if not p.is_absolute():
|
||||
p = Path(__file__).parent / "fixtures" / file
|
||||
|
||||
with open(p) as f:
|
||||
sysinfo = json.load(f)
|
||||
model = basename(file)
|
||||
p = device_for_file(model)(host="127.0.0.123")
|
||||
p.protocol = FakeTransportProtocol(sysinfo)
|
||||
asyncio.run(p.update())
|
||||
return p
|
||||
def load_file():
|
||||
with open(p) as f:
|
||||
return json.load(f)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
sysinfo = await loop.run_in_executor(None, load_file)
|
||||
|
||||
model = basename(file)
|
||||
d = device_for_file(model)(host="127.0.0.123")
|
||||
d.protocol = FakeTransportProtocol(sysinfo)
|
||||
await _update_and_close(d)
|
||||
return d
|
||||
|
||||
|
||||
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
|
||||
def dev(request):
|
||||
@pytest.fixture(params=SUPPORTED_DEVICES)
|
||||
async def dev(request):
|
||||
"""Device fixture.
|
||||
|
||||
Provides a device (given --ip) or parametrized fixture for the supported devices.
|
||||
@ -163,14 +182,16 @@ def dev(request):
|
||||
|
||||
ip = request.config.getoption("--ip")
|
||||
if ip:
|
||||
d = asyncio.run(Discover.discover_single(ip))
|
||||
asyncio.run(d.update())
|
||||
if d.model in file:
|
||||
return d
|
||||
else:
|
||||
model = IP_MODEL_CACHE.get(ip)
|
||||
d = None
|
||||
if not model:
|
||||
d = await _discover_update_and_close(ip)
|
||||
IP_MODEL_CACHE[ip] = model = d.model
|
||||
if model not in file:
|
||||
pytest.skip(f"skipping file {file}")
|
||||
return d if d else await _discover_update_and_close(ip)
|
||||
|
||||
return get_device_for_file(file)
|
||||
return await get_device_for_file(file)
|
||||
|
||||
|
||||
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
|
||||
|
@ -83,9 +83,19 @@ PLUG_SCHEMA = Schema(
|
||||
"icon_hash": str,
|
||||
"led_off": check_int_bool,
|
||||
"latitude": Any(All(float, Range(min=-90, max=90)), 0, None),
|
||||
"latitude_i": Any(All(float, Range(min=-90, max=90)), 0, None),
|
||||
"latitude_i": Any(
|
||||
All(int, Range(min=-900000, max=900000)),
|
||||
All(float, Range(min=-900000, max=900000)),
|
||||
0,
|
||||
None,
|
||||
),
|
||||
"longitude": Any(All(float, Range(min=-180, max=180)), 0, None),
|
||||
"longitude_i": Any(All(float, Range(min=-180, max=180)), 0, None),
|
||||
"longitude_i": Any(
|
||||
All(int, Range(min=-18000000, max=18000000)),
|
||||
All(float, Range(min=-18000000, max=18000000)),
|
||||
0,
|
||||
None,
|
||||
),
|
||||
"mac": check_mac,
|
||||
"model": str,
|
||||
"oemId": str,
|
||||
@ -117,17 +127,17 @@ LIGHT_STATE_SCHEMA = Schema(
|
||||
{
|
||||
"brightness": All(int, Range(min=0, max=100)),
|
||||
"color_temp": int,
|
||||
"hue": All(int, Range(min=0, max=255)),
|
||||
"hue": All(int, Range(min=0, max=360)),
|
||||
"mode": str,
|
||||
"on_off": check_int_bool,
|
||||
"saturation": All(int, Range(min=0, max=255)),
|
||||
"saturation": All(int, Range(min=0, max=100)),
|
||||
"dft_on_state": Optional(
|
||||
{
|
||||
"brightness": All(int, Range(min=0, max=100)),
|
||||
"color_temp": All(int, Range(min=0, max=9000)),
|
||||
"hue": All(int, Range(min=0, max=255)),
|
||||
"hue": All(int, Range(min=0, max=360)),
|
||||
"mode": str,
|
||||
"saturation": All(int, Range(min=0, max=255)),
|
||||
"saturation": All(int, Range(min=0, max=100)),
|
||||
}
|
||||
),
|
||||
"err_code": int,
|
||||
@ -276,6 +286,8 @@ TIME_MODULE = {
|
||||
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
||||
def __init__(self, info):
|
||||
self.discovery_data = info
|
||||
self.writer = None
|
||||
self.reader = None
|
||||
proto = FakeTransportProtocol.baseproto
|
||||
|
||||
for target in info:
|
||||
@ -426,7 +438,7 @@ class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
||||
},
|
||||
}
|
||||
|
||||
async def query(self, host, request, port=9999):
|
||||
async def query(self, request, port=9999):
|
||||
proto = self.proto
|
||||
|
||||
# collect child ids from context
|
||||
|
@ -60,7 +60,7 @@ async def test_hsv(dev, turn_on):
|
||||
assert dev.is_color
|
||||
|
||||
hue, saturation, brightness = dev.hsv
|
||||
assert 0 <= hue <= 255
|
||||
assert 0 <= hue <= 360
|
||||
assert 0 <= saturation <= 100
|
||||
assert 0 <= brightness <= 100
|
||||
|
||||
|
@ -3,7 +3,7 @@ import sys
|
||||
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException
|
||||
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
|
||||
from kasa.discover import _DiscoverProtocol
|
||||
|
||||
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
|
||||
@ -94,7 +94,8 @@ async def test_discover_datagram_received(mocker, discovery_data):
|
||||
"""Verify that datagram received fills discovered_devices."""
|
||||
proto = _DiscoverProtocol()
|
||||
mocker.patch("json.loads", return_value=discovery_data)
|
||||
mocker.patch.object(proto, "protocol")
|
||||
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
|
||||
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
|
||||
|
||||
addr = "127.0.0.1"
|
||||
proto.datagram_received("<placeholder data>", (addr, 1234))
|
||||
|
@ -1,4 +1,6 @@
|
||||
import json
|
||||
import struct
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
@ -21,11 +23,47 @@ async def test_protocol_retries(mocker, retry_count):
|
||||
|
||||
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)
|
||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count)
|
||||
|
||||
assert conn.call_count == retry_count + 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
|
||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||
async def test_protocol_reconnect(mocker, retry_count):
|
||||
remaining = retry_count
|
||||
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
|
||||
TPLinkSmartHomeProtocol.BLOCK_SIZE :
|
||||
]
|
||||
|
||||
def _fail_one_less_than_retry_count(*_):
|
||||
nonlocal remaining
|
||||
remaining -= 1
|
||||
if remaining:
|
||||
raise Exception("Simulated write failure")
|
||||
|
||||
async def _mock_read(byte_count):
|
||||
nonlocal encrypted
|
||||
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
|
||||
return struct.pack(">I", len(encrypted))
|
||||
if byte_count == len(encrypted):
|
||||
return encrypted
|
||||
|
||||
raise ValueError(f"No mock for {byte_count}")
|
||||
|
||||
def aio_mock_writer(_, __):
|
||||
reader = mocker.patch("asyncio.StreamReader")
|
||||
writer = mocker.patch("asyncio.StreamWriter")
|
||||
mocker.patch.object(writer, "write", _fail_one_less_than_retry_count)
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({}, retry_count=retry_count)
|
||||
assert response == {"great": "success"}
|
||||
|
||||
|
||||
def test_encrypt():
|
||||
d = json.dumps({"foo": 1, "bar": 2})
|
||||
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
@ -8,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file
|
||||
|
||||
def test_bulb_examples(mocker):
|
||||
"""Use KL130 (bulb with all features) to test the doctests."""
|
||||
p = get_device_for_file("KL130(US)_1.0.json")
|
||||
p = asyncio.run(get_device_for_file("KL130(US)_1.0.json"))
|
||||
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
|
||||
mocker.patch("kasa.smartbulb.SmartBulb.update")
|
||||
res = xdoctest.doctest_module("kasa.smartbulb", "all")
|
||||
@ -17,7 +18,7 @@ def test_bulb_examples(mocker):
|
||||
|
||||
def test_smartdevice_examples(mocker):
|
||||
"""Use HS110 for emeter examples."""
|
||||
p = get_device_for_file("HS110(EU)_1.0_real.json")
|
||||
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json"))
|
||||
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
|
||||
mocker.patch("kasa.smartdevice.SmartDevice.update")
|
||||
res = xdoctest.doctest_module("kasa.smartdevice", "all")
|
||||
@ -26,7 +27,7 @@ def test_smartdevice_examples(mocker):
|
||||
|
||||
def test_plug_examples(mocker):
|
||||
"""Test plug examples."""
|
||||
p = get_device_for_file("HS110(EU)_1.0_real.json")
|
||||
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json"))
|
||||
mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
|
||||
mocker.patch("kasa.smartplug.SmartPlug.update")
|
||||
res = xdoctest.doctest_module("kasa.smartplug", "all")
|
||||
@ -35,7 +36,7 @@ def test_plug_examples(mocker):
|
||||
|
||||
def test_strip_examples(mocker):
|
||||
"""Test strip examples."""
|
||||
p = get_device_for_file("KP303(UK)_1.0.json")
|
||||
p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json"))
|
||||
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
|
||||
mocker.patch("kasa.smartstrip.SmartStrip.update")
|
||||
res = xdoctest.doctest_module("kasa.smartstrip", "all")
|
||||
@ -44,7 +45,7 @@ def test_strip_examples(mocker):
|
||||
|
||||
def test_dimmer_examples(mocker):
|
||||
"""Test dimmer examples."""
|
||||
p = get_device_for_file("HS220(US)_1.0_real.json")
|
||||
p = asyncio.run(get_device_for_file("HS220(US)_1.0_real.json"))
|
||||
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
|
||||
mocker.patch("kasa.smartdimmer.SmartDimmer.update")
|
||||
res = xdoctest.doctest_module("kasa.smartdimmer", "all")
|
||||
@ -53,7 +54,7 @@ def test_dimmer_examples(mocker):
|
||||
|
||||
def test_lightstrip_examples(mocker):
|
||||
"""Test lightstrip examples."""
|
||||
p = get_device_for_file("KL430(US)_1.0.json")
|
||||
p = asyncio.run(get_device_for_file("KL430(US)_1.0.json"))
|
||||
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
|
||||
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
|
||||
res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
|
||||
@ -65,7 +66,7 @@ def test_lightstrip_examples(mocker):
|
||||
)
|
||||
def test_discovery_examples(mocker):
|
||||
"""Test discovery examples."""
|
||||
p = get_device_for_file("KP303(UK)_1.0.json")
|
||||
p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json"))
|
||||
|
||||
# This succeeds on python 3.8 but fails on 3.7
|
||||
# ValueError: a coroutine was expected, got [<DeviceType.Strip model KP303(UK) ...
|
||||
|
Loading…
Reference in New Issue
Block a user