mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-23 03:33:35 +00:00
Keep connection open and lock to prevent duplicate requests (#213)
* Keep connection open and lock to prevent duplicate requests * option to not update children * tweaks * typing * tweaks * run tests in the same event loop * memorize model * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * dry * tweaks * warn when the event loop gets switched out from under us * raise on unable to connect multiple times * fix patch target * tweaks * isrot * reconnect test * prune * fix mocking * fix mocking * fix test under python 3.7 * fix test under python 3.7 * less patching * isort * use mocker to patch * disable on old python since mocking doesnt work * avoid disconnect/reconnect cycles * isort * Fix hue validation * Fix latitude_i/longitude_i units Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
parent
f1b28e79b9
commit
e31cc6662c
@ -78,16 +78,17 @@ def cli(host, debug):
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
protocol = TPLinkSmartHomeProtocol()
|
|
||||||
|
|
||||||
successes = []
|
successes = []
|
||||||
|
|
||||||
for test_call in items:
|
for test_call in items:
|
||||||
|
|
||||||
|
async def _run_query():
|
||||||
|
protocol = TPLinkSmartHomeProtocol(host)
|
||||||
|
return await protocol.query({test_call.module: {test_call.method: None}})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
click.echo(f"Testing {test_call}..", nl=False)
|
click.echo(f"Testing {test_call}..", nl=False)
|
||||||
info = asyncio.run(
|
info = asyncio.run(_run_query())
|
||||||
protocol.query(host, {test_call.module: {test_call.method: None}})
|
|
||||||
)
|
|
||||||
resp = info[test_call.module]
|
resp = info[test_call.module]
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
click.echo(click.style(f"FAIL {ex}", fg="red"))
|
click.echo(click.style(f"FAIL {ex}", fg="red"))
|
||||||
@ -107,8 +108,12 @@ def cli(host, debug):
|
|||||||
|
|
||||||
final = default_to_regular(final)
|
final = default_to_regular(final)
|
||||||
|
|
||||||
|
async def _run_final_query():
|
||||||
|
protocol = TPLinkSmartHomeProtocol(host)
|
||||||
|
return await protocol.query(final_query)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
final = asyncio.run(protocol.query(host, final_query))
|
final = asyncio.run(_run_final_query())
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style(
|
click.style(
|
||||||
|
@ -40,7 +40,6 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
self.discovery_packets = discovery_packets
|
self.discovery_packets = discovery_packets
|
||||||
self.interface = interface
|
self.interface = interface
|
||||||
self.on_discovered = on_discovered
|
self.on_discovered = on_discovered
|
||||||
self.protocol = TPLinkSmartHomeProtocol()
|
|
||||||
self.target = (target, Discover.DISCOVERY_PORT)
|
self.target = (target, Discover.DISCOVERY_PORT)
|
||||||
self.discovered_devices = {}
|
self.discovered_devices = {}
|
||||||
|
|
||||||
@ -61,7 +60,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
"""Send number of discovery datagrams."""
|
"""Send number of discovery datagrams."""
|
||||||
req = json.dumps(Discover.DISCOVERY_QUERY)
|
req = json.dumps(Discover.DISCOVERY_QUERY)
|
||||||
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
|
||||||
encrypted_req = self.protocol.encrypt(req)
|
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
|
||||||
for i in range(self.discovery_packets):
|
for i in range(self.discovery_packets):
|
||||||
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
|
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
|
||||||
|
|
||||||
@ -71,7 +70,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
|
|||||||
if ip in self.discovered_devices:
|
if ip in self.discovered_devices:
|
||||||
return
|
return
|
||||||
|
|
||||||
info = json.loads(self.protocol.decrypt(data))
|
info = json.loads(TPLinkSmartHomeProtocol.decrypt(data))
|
||||||
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
|
||||||
|
|
||||||
device_class = Discover._get_device_class(info)
|
device_class = Discover._get_device_class(info)
|
||||||
@ -190,9 +189,9 @@ class Discover:
|
|||||||
:rtype: SmartDevice
|
:rtype: SmartDevice
|
||||||
:return: Object for querying/controlling found device.
|
:return: Object for querying/controlling found device.
|
||||||
"""
|
"""
|
||||||
protocol = TPLinkSmartHomeProtocol()
|
protocol = TPLinkSmartHomeProtocol(host)
|
||||||
|
|
||||||
info = await protocol.query(host, Discover.DISCOVERY_QUERY)
|
info = await protocol.query(Discover.DISCOVERY_QUERY)
|
||||||
|
|
||||||
device_class = Discover._get_device_class(info)
|
device_class = Discover._get_device_class(info)
|
||||||
dev = device_class(host)
|
dev = device_class(host)
|
||||||
|
130
kasa/protocol.py
130
kasa/protocol.py
@ -10,11 +10,12 @@ which are licensed under the Apache License, Version 2.0
|
|||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import Dict, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from .exceptions import SmartDeviceException
|
from .exceptions import SmartDeviceException
|
||||||
|
|
||||||
@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol:
|
|||||||
DEFAULT_PORT = 9999
|
DEFAULT_PORT = 9999
|
||||||
DEFAULT_TIMEOUT = 5
|
DEFAULT_TIMEOUT = 5
|
||||||
|
|
||||||
@staticmethod
|
BLOCK_SIZE = 4
|
||||||
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
|
||||||
|
def __init__(self, host: str) -> None:
|
||||||
|
"""Create a protocol object."""
|
||||||
|
self.host = host
|
||||||
|
self.reader: Optional[asyncio.StreamReader] = None
|
||||||
|
self.writer: Optional[asyncio.StreamWriter] = None
|
||||||
|
self.query_lock: Optional[asyncio.Lock] = None
|
||||||
|
self.loop: Optional[asyncio.AbstractEventLoop] = None
|
||||||
|
|
||||||
|
def _detect_event_loop_change(self) -> None:
|
||||||
|
"""Check if this object has been reused betwen event loops."""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
if not self.loop:
|
||||||
|
self.loop = loop
|
||||||
|
elif self.loop != loop:
|
||||||
|
_LOGGER.warning("Detected protocol reuse between different event loop")
|
||||||
|
self._reset()
|
||||||
|
|
||||||
|
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
"""Request information from a TP-Link SmartHome Device.
|
"""Request information from a TP-Link SmartHome Device.
|
||||||
|
|
||||||
:param str host: host name or ip address of the device
|
:param str host: host name or ip address of the device
|
||||||
@ -38,57 +57,106 @@ class TPLinkSmartHomeProtocol:
|
|||||||
:param retry_count: how many retries to do in case of failure
|
:param retry_count: how many retries to do in case of failure
|
||||||
:return: response dict
|
:return: response dict
|
||||||
"""
|
"""
|
||||||
|
self._detect_event_loop_change()
|
||||||
|
|
||||||
|
if not self.query_lock:
|
||||||
|
self.query_lock = asyncio.Lock()
|
||||||
|
|
||||||
if isinstance(request, dict):
|
if isinstance(request, dict):
|
||||||
request = json.dumps(request)
|
request = json.dumps(request)
|
||||||
|
assert isinstance(request, str)
|
||||||
|
|
||||||
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
|
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
|
||||||
writer = None
|
|
||||||
for retry in range(retry_count + 1):
|
async with self.query_lock:
|
||||||
try:
|
return await self._query(request, retry_count, timeout)
|
||||||
|
|
||||||
|
async def _connect(self, timeout: int) -> bool:
|
||||||
|
"""Try to connect or reconnect to the device."""
|
||||||
|
if self.writer:
|
||||||
|
return True
|
||||||
|
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self.reader = self.writer = None
|
||||||
task = asyncio.open_connection(
|
task = asyncio.open_connection(
|
||||||
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
|
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
|
||||||
)
|
)
|
||||||
reader, writer = await asyncio.wait_for(task, timeout=timeout)
|
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _execute_query(self, request: str) -> Dict:
|
||||||
|
"""Execute a query on the device and wait for the response."""
|
||||||
|
assert self.writer is not None
|
||||||
|
assert self.reader is not None
|
||||||
|
|
||||||
_LOGGER.debug("> (%i) %s", len(request), request)
|
_LOGGER.debug("> (%i) %s", len(request), request)
|
||||||
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||||
await writer.drain()
|
await self.writer.drain()
|
||||||
|
|
||||||
buffer = bytes()
|
packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
|
||||||
# Some devices send responses with a length header of 0 and
|
length = struct.unpack(">I", packed_block_size)[0]
|
||||||
# terminate with a zero size chunk. Others send the length and
|
|
||||||
# will hang if we attempt to read more data.
|
|
||||||
length = -1
|
|
||||||
while True:
|
|
||||||
chunk = await reader.read(4096)
|
|
||||||
if length == -1:
|
|
||||||
length = struct.unpack(">I", chunk[0:4])[0]
|
|
||||||
buffer += chunk
|
|
||||||
if (length > 0 and len(buffer) >= length + 4) or not chunk:
|
|
||||||
break
|
|
||||||
|
|
||||||
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
|
buffer = await self.reader.readexactly(length)
|
||||||
|
response = TPLinkSmartHomeProtocol.decrypt(buffer)
|
||||||
json_payload = json.loads(response)
|
json_payload = json.loads(response)
|
||||||
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
|
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
|
||||||
|
|
||||||
return json_payload
|
return json_payload
|
||||||
|
|
||||||
except Exception as ex:
|
async def close(self):
|
||||||
|
"""Close the connection."""
|
||||||
|
writer = self.writer
|
||||||
|
self._reset()
|
||||||
|
if writer:
|
||||||
|
writer.close()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await writer.wait_closed()
|
||||||
|
|
||||||
|
def _reset(self):
|
||||||
|
"""Clear any varibles that should not survive between loops."""
|
||||||
|
self.writer = None
|
||||||
|
self.reader = None
|
||||||
|
self.query_lock = None
|
||||||
|
self.loop = None
|
||||||
|
|
||||||
|
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
|
||||||
|
"""Try to query a device."""
|
||||||
|
for retry in range(retry_count + 1):
|
||||||
|
if not await self._connect(timeout):
|
||||||
|
await self.close()
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
_LOGGER.debug("Giving up after %s retries", retry)
|
_LOGGER.debug("Giving up after %s retries", retry)
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
"Unable to query the device: %s" % ex
|
f"Unable to connect to the device: {self.host}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert self.reader is not None
|
||||||
|
assert self.writer is not None
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
self._execute_query(request), timeout=timeout
|
||||||
|
)
|
||||||
|
except Exception as ex:
|
||||||
|
await self.close()
|
||||||
|
if retry >= retry_count:
|
||||||
|
_LOGGER.debug("Giving up after %s retries", retry)
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to query the device: {ex}"
|
||||||
) from ex
|
) from ex
|
||||||
|
|
||||||
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
|
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
|
||||||
|
|
||||||
finally:
|
|
||||||
if writer:
|
|
||||||
writer.close()
|
|
||||||
await writer.wait_closed()
|
|
||||||
|
|
||||||
# make mypy happy, this should never be reached..
|
# make mypy happy, this should never be reached..
|
||||||
|
await self.close()
|
||||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self.writer and self.loop and self.loop.is_running():
|
||||||
|
self.writer.close()
|
||||||
|
self._reset()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _xor_payload(unencrypted):
|
def _xor_payload(unencrypted):
|
||||||
key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR
|
key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR
|
||||||
|
@ -194,7 +194,7 @@ class SmartDevice:
|
|||||||
"""
|
"""
|
||||||
self.host = host
|
self.host = host
|
||||||
|
|
||||||
self.protocol = TPLinkSmartHomeProtocol()
|
self.protocol = TPLinkSmartHomeProtocol(host)
|
||||||
self.emeter_type = "emeter"
|
self.emeter_type = "emeter"
|
||||||
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
|
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
|
||||||
self._device_type = DeviceType.Unknown
|
self._device_type = DeviceType.Unknown
|
||||||
@ -234,7 +234,7 @@ class SmartDevice:
|
|||||||
request = self._create_request(target, cmd, arg, child_ids)
|
request = self._create_request(target, cmd, arg, child_ids)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.protocol.query(host=self.host, request=request)
|
response = await self.protocol.query(request=request)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex
|
raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex
|
||||||
|
|
||||||
@ -272,7 +272,7 @@ class SmartDevice:
|
|||||||
"""Retrieve system information."""
|
"""Retrieve system information."""
|
||||||
return await self._query_helper("system", "get_sysinfo")
|
return await self._query_helper("system", "get_sysinfo")
|
||||||
|
|
||||||
async def update(self):
|
async def update(self, update_children: bool = True):
|
||||||
"""Query the device to update the data.
|
"""Query the device to update the data.
|
||||||
|
|
||||||
Needed for properties that are decorated with `requires_update`.
|
Needed for properties that are decorated with `requires_update`.
|
||||||
@ -285,7 +285,7 @@ class SmartDevice:
|
|||||||
# See #105, #120, #161
|
# See #105, #120, #161
|
||||||
if self._last_update is None:
|
if self._last_update is None:
|
||||||
_LOGGER.debug("Performing the initial update to obtain sysinfo")
|
_LOGGER.debug("Performing the initial update to obtain sysinfo")
|
||||||
self._last_update = await self.protocol.query(self.host, req)
|
self._last_update = await self.protocol.query(req)
|
||||||
self._sys_info = self._last_update["system"]["get_sysinfo"]
|
self._sys_info = self._last_update["system"]["get_sysinfo"]
|
||||||
# If the device has no emeter, we are done for the initial update
|
# If the device has no emeter, we are done for the initial update
|
||||||
# Otherwise we will follow the regular code path to also query
|
# Otherwise we will follow the regular code path to also query
|
||||||
@ -299,7 +299,7 @@ class SmartDevice:
|
|||||||
)
|
)
|
||||||
req.update(self._create_emeter_request())
|
req.update(self._create_emeter_request())
|
||||||
|
|
||||||
self._last_update = await self.protocol.query(self.host, req)
|
self._last_update = await self.protocol.query(req)
|
||||||
self._sys_info = self._last_update["system"]["get_sysinfo"]
|
self._sys_info = self._last_update["system"]["get_sysinfo"]
|
||||||
|
|
||||||
def update_from_discover_info(self, info):
|
def update_from_discover_info(self, info):
|
||||||
@ -383,8 +383,8 @@ class SmartDevice:
|
|||||||
loc["latitude"] = sys_info["latitude"]
|
loc["latitude"] = sys_info["latitude"]
|
||||||
loc["longitude"] = sys_info["longitude"]
|
loc["longitude"] = sys_info["longitude"]
|
||||||
elif "latitude_i" in sys_info and "longitude_i" in sys_info:
|
elif "latitude_i" in sys_info and "longitude_i" in sys_info:
|
||||||
loc["latitude"] = sys_info["latitude_i"]
|
loc["latitude"] = sys_info["latitude_i"] / 10000
|
||||||
loc["longitude"] = sys_info["longitude_i"]
|
loc["longitude"] = sys_info["longitude_i"] / 10000
|
||||||
else:
|
else:
|
||||||
_LOGGER.warning("Unsupported device location.")
|
_LOGGER.warning("Unsupported device location.")
|
||||||
|
|
||||||
|
@ -87,12 +87,12 @@ class SmartStrip(SmartDevice):
|
|||||||
"""Return if any of the outlets are on."""
|
"""Return if any of the outlets are on."""
|
||||||
return any(plug.is_on for plug in self.children)
|
return any(plug.is_on for plug in self.children)
|
||||||
|
|
||||||
async def update(self):
|
async def update(self, update_children: bool = True):
|
||||||
"""Update some of the attributes.
|
"""Update some of the attributes.
|
||||||
|
|
||||||
Needed for methods that are decorated with `requires_update`.
|
Needed for methods that are decorated with `requires_update`.
|
||||||
"""
|
"""
|
||||||
await super().update()
|
await super().update(update_children)
|
||||||
|
|
||||||
# Initialize the child devices during the first update.
|
# Initialize the child devices during the first update.
|
||||||
if not self.children:
|
if not self.children:
|
||||||
@ -103,7 +103,7 @@ class SmartStrip(SmartDevice):
|
|||||||
SmartStripPlug(self.host, parent=self, child_id=child["id"])
|
SmartStripPlug(self.host, parent=self, child_id=child["id"])
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.has_emeter:
|
if update_children and self.has_emeter:
|
||||||
for plug in self.children:
|
for plug in self.children:
|
||||||
await plug.update()
|
await plug.update()
|
||||||
|
|
||||||
@ -243,13 +243,13 @@ class SmartStripPlug(SmartPlug):
|
|||||||
self._sys_info = parent._sys_info
|
self._sys_info = parent._sys_info
|
||||||
self._device_type = DeviceType.StripSocket
|
self._device_type = DeviceType.StripSocket
|
||||||
|
|
||||||
async def update(self):
|
async def update(self, update_children: bool = True):
|
||||||
"""Query the device to update the data.
|
"""Query the device to update the data.
|
||||||
|
|
||||||
Needed for properties that are decorated with `requires_update`.
|
Needed for properties that are decorated with `requires_update`.
|
||||||
"""
|
"""
|
||||||
self._last_update = await self.parent.protocol.query(
|
self._last_update = await self.parent.protocol.query(
|
||||||
self.host, self._create_emeter_request()
|
self._create_emeter_request()
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_request(
|
def _create_request(
|
||||||
|
@ -4,6 +4,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
from os.path import basename
|
from os.path import basename
|
||||||
from pathlib import Path, PurePath
|
from pathlib import Path, PurePath
|
||||||
|
from typing import Dict
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
|
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
|
||||||
@ -39,6 +40,8 @@ WITH_EMETER = {"HS110", "HS300", "KP115", *BULBS}
|
|||||||
|
|
||||||
ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
|
ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS)
|
||||||
|
|
||||||
|
IP_MODEL_CACHE: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
def filter_model(desc, filter):
|
def filter_model(desc, filter):
|
||||||
filtered = list()
|
filtered = list()
|
||||||
@ -137,23 +140,39 @@ def device_for_file(model):
|
|||||||
raise Exception("Unable to find type for %s", model)
|
raise Exception("Unable to find type for %s", model)
|
||||||
|
|
||||||
|
|
||||||
def get_device_for_file(file):
|
async def _update_and_close(d):
|
||||||
|
await d.update()
|
||||||
|
await d.protocol.close()
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
async def _discover_update_and_close(ip):
|
||||||
|
d = await Discover.discover_single(ip)
|
||||||
|
return await _update_and_close(d)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_device_for_file(file):
|
||||||
# if the wanted file is not an absolute path, prepend the fixtures directory
|
# if the wanted file is not an absolute path, prepend the fixtures directory
|
||||||
p = Path(file)
|
p = Path(file)
|
||||||
if not p.is_absolute():
|
if not p.is_absolute():
|
||||||
p = Path(__file__).parent / "fixtures" / file
|
p = Path(__file__).parent / "fixtures" / file
|
||||||
|
|
||||||
|
def load_file():
|
||||||
with open(p) as f:
|
with open(p) as f:
|
||||||
sysinfo = json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
sysinfo = await loop.run_in_executor(None, load_file)
|
||||||
|
|
||||||
model = basename(file)
|
model = basename(file)
|
||||||
p = device_for_file(model)(host="127.0.0.123")
|
d = device_for_file(model)(host="127.0.0.123")
|
||||||
p.protocol = FakeTransportProtocol(sysinfo)
|
d.protocol = FakeTransportProtocol(sysinfo)
|
||||||
asyncio.run(p.update())
|
await _update_and_close(d)
|
||||||
return p
|
return d
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
|
@pytest.fixture(params=SUPPORTED_DEVICES)
|
||||||
def dev(request):
|
async def dev(request):
|
||||||
"""Device fixture.
|
"""Device fixture.
|
||||||
|
|
||||||
Provides a device (given --ip) or parametrized fixture for the supported devices.
|
Provides a device (given --ip) or parametrized fixture for the supported devices.
|
||||||
@ -163,14 +182,16 @@ def dev(request):
|
|||||||
|
|
||||||
ip = request.config.getoption("--ip")
|
ip = request.config.getoption("--ip")
|
||||||
if ip:
|
if ip:
|
||||||
d = asyncio.run(Discover.discover_single(ip))
|
model = IP_MODEL_CACHE.get(ip)
|
||||||
asyncio.run(d.update())
|
d = None
|
||||||
if d.model in file:
|
if not model:
|
||||||
return d
|
d = await _discover_update_and_close(ip)
|
||||||
else:
|
IP_MODEL_CACHE[ip] = model = d.model
|
||||||
|
if model not in file:
|
||||||
pytest.skip(f"skipping file {file}")
|
pytest.skip(f"skipping file {file}")
|
||||||
|
return d if d else await _discover_update_and_close(ip)
|
||||||
|
|
||||||
return get_device_for_file(file)
|
return await get_device_for_file(file)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
|
@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
|
||||||
|
@ -83,9 +83,19 @@ PLUG_SCHEMA = Schema(
|
|||||||
"icon_hash": str,
|
"icon_hash": str,
|
||||||
"led_off": check_int_bool,
|
"led_off": check_int_bool,
|
||||||
"latitude": Any(All(float, Range(min=-90, max=90)), 0, None),
|
"latitude": Any(All(float, Range(min=-90, max=90)), 0, None),
|
||||||
"latitude_i": Any(All(float, Range(min=-90, max=90)), 0, None),
|
"latitude_i": Any(
|
||||||
|
All(int, Range(min=-900000, max=900000)),
|
||||||
|
All(float, Range(min=-900000, max=900000)),
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
),
|
||||||
"longitude": Any(All(float, Range(min=-180, max=180)), 0, None),
|
"longitude": Any(All(float, Range(min=-180, max=180)), 0, None),
|
||||||
"longitude_i": Any(All(float, Range(min=-180, max=180)), 0, None),
|
"longitude_i": Any(
|
||||||
|
All(int, Range(min=-18000000, max=18000000)),
|
||||||
|
All(float, Range(min=-18000000, max=18000000)),
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
),
|
||||||
"mac": check_mac,
|
"mac": check_mac,
|
||||||
"model": str,
|
"model": str,
|
||||||
"oemId": str,
|
"oemId": str,
|
||||||
@ -117,17 +127,17 @@ LIGHT_STATE_SCHEMA = Schema(
|
|||||||
{
|
{
|
||||||
"brightness": All(int, Range(min=0, max=100)),
|
"brightness": All(int, Range(min=0, max=100)),
|
||||||
"color_temp": int,
|
"color_temp": int,
|
||||||
"hue": All(int, Range(min=0, max=255)),
|
"hue": All(int, Range(min=0, max=360)),
|
||||||
"mode": str,
|
"mode": str,
|
||||||
"on_off": check_int_bool,
|
"on_off": check_int_bool,
|
||||||
"saturation": All(int, Range(min=0, max=255)),
|
"saturation": All(int, Range(min=0, max=100)),
|
||||||
"dft_on_state": Optional(
|
"dft_on_state": Optional(
|
||||||
{
|
{
|
||||||
"brightness": All(int, Range(min=0, max=100)),
|
"brightness": All(int, Range(min=0, max=100)),
|
||||||
"color_temp": All(int, Range(min=0, max=9000)),
|
"color_temp": All(int, Range(min=0, max=9000)),
|
||||||
"hue": All(int, Range(min=0, max=255)),
|
"hue": All(int, Range(min=0, max=360)),
|
||||||
"mode": str,
|
"mode": str,
|
||||||
"saturation": All(int, Range(min=0, max=255)),
|
"saturation": All(int, Range(min=0, max=100)),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
"err_code": int,
|
"err_code": int,
|
||||||
@ -276,6 +286,8 @@ TIME_MODULE = {
|
|||||||
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
||||||
def __init__(self, info):
|
def __init__(self, info):
|
||||||
self.discovery_data = info
|
self.discovery_data = info
|
||||||
|
self.writer = None
|
||||||
|
self.reader = None
|
||||||
proto = FakeTransportProtocol.baseproto
|
proto = FakeTransportProtocol.baseproto
|
||||||
|
|
||||||
for target in info:
|
for target in info:
|
||||||
@ -426,7 +438,7 @@ class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
async def query(self, host, request, port=9999):
|
async def query(self, request, port=9999):
|
||||||
proto = self.proto
|
proto = self.proto
|
||||||
|
|
||||||
# collect child ids from context
|
# collect child ids from context
|
||||||
|
@ -60,7 +60,7 @@ async def test_hsv(dev, turn_on):
|
|||||||
assert dev.is_color
|
assert dev.is_color
|
||||||
|
|
||||||
hue, saturation, brightness = dev.hsv
|
hue, saturation, brightness = dev.hsv
|
||||||
assert 0 <= hue <= 255
|
assert 0 <= hue <= 360
|
||||||
assert 0 <= saturation <= 100
|
assert 0 <= saturation <= 100
|
||||||
assert 0 <= brightness <= 100
|
assert 0 <= brightness <= 100
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import sys
|
|||||||
|
|
||||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||||
|
|
||||||
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException
|
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
|
||||||
from kasa.discover import _DiscoverProtocol
|
from kasa.discover import _DiscoverProtocol
|
||||||
|
|
||||||
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
|
from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip
|
||||||
@ -94,7 +94,8 @@ async def test_discover_datagram_received(mocker, discovery_data):
|
|||||||
"""Verify that datagram received fills discovered_devices."""
|
"""Verify that datagram received fills discovered_devices."""
|
||||||
proto = _DiscoverProtocol()
|
proto = _DiscoverProtocol()
|
||||||
mocker.patch("json.loads", return_value=discovery_data)
|
mocker.patch("json.loads", return_value=discovery_data)
|
||||||
mocker.patch.object(proto, "protocol")
|
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
|
||||||
|
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
|
||||||
|
|
||||||
addr = "127.0.0.1"
|
addr = "127.0.0.1"
|
||||||
proto.datagram_received("<placeholder data>", (addr, 1234))
|
proto.datagram_received("<placeholder data>", (addr, 1234))
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -21,11 +23,47 @@ async def test_protocol_retries(mocker, retry_count):
|
|||||||
|
|
||||||
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)
|
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count)
|
||||||
|
|
||||||
assert conn.call_count == retry_count + 1
|
assert conn.call_count == retry_count + 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
|
||||||
|
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||||
|
async def test_protocol_reconnect(mocker, retry_count):
|
||||||
|
remaining = retry_count
|
||||||
|
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
|
||||||
|
TPLinkSmartHomeProtocol.BLOCK_SIZE :
|
||||||
|
]
|
||||||
|
|
||||||
|
def _fail_one_less_than_retry_count(*_):
|
||||||
|
nonlocal remaining
|
||||||
|
remaining -= 1
|
||||||
|
if remaining:
|
||||||
|
raise Exception("Simulated write failure")
|
||||||
|
|
||||||
|
async def _mock_read(byte_count):
|
||||||
|
nonlocal encrypted
|
||||||
|
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
|
||||||
|
return struct.pack(">I", len(encrypted))
|
||||||
|
if byte_count == len(encrypted):
|
||||||
|
return encrypted
|
||||||
|
|
||||||
|
raise ValueError(f"No mock for {byte_count}")
|
||||||
|
|
||||||
|
def aio_mock_writer(_, __):
|
||||||
|
reader = mocker.patch("asyncio.StreamReader")
|
||||||
|
writer = mocker.patch("asyncio.StreamWriter")
|
||||||
|
mocker.patch.object(writer, "write", _fail_one_less_than_retry_count)
|
||||||
|
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||||
|
return reader, writer
|
||||||
|
|
||||||
|
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
|
||||||
|
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||||
|
response = await protocol.query({}, retry_count=retry_count)
|
||||||
|
assert response == {"great": "success"}
|
||||||
|
|
||||||
|
|
||||||
def test_encrypt():
|
def test_encrypt():
|
||||||
d = json.dumps({"foo": 1, "bar": 2})
|
d = json.dumps({"foo": 1, "bar": 2})
|
||||||
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
|
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -8,7 +9,7 @@ from kasa.tests.conftest import get_device_for_file
|
|||||||
|
|
||||||
def test_bulb_examples(mocker):
|
def test_bulb_examples(mocker):
|
||||||
"""Use KL130 (bulb with all features) to test the doctests."""
|
"""Use KL130 (bulb with all features) to test the doctests."""
|
||||||
p = get_device_for_file("KL130(US)_1.0.json")
|
p = asyncio.run(get_device_for_file("KL130(US)_1.0.json"))
|
||||||
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
|
mocker.patch("kasa.smartbulb.SmartBulb", return_value=p)
|
||||||
mocker.patch("kasa.smartbulb.SmartBulb.update")
|
mocker.patch("kasa.smartbulb.SmartBulb.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartbulb", "all")
|
res = xdoctest.doctest_module("kasa.smartbulb", "all")
|
||||||
@ -17,7 +18,7 @@ def test_bulb_examples(mocker):
|
|||||||
|
|
||||||
def test_smartdevice_examples(mocker):
|
def test_smartdevice_examples(mocker):
|
||||||
"""Use HS110 for emeter examples."""
|
"""Use HS110 for emeter examples."""
|
||||||
p = get_device_for_file("HS110(EU)_1.0_real.json")
|
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json"))
|
||||||
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
|
mocker.patch("kasa.smartdevice.SmartDevice", return_value=p)
|
||||||
mocker.patch("kasa.smartdevice.SmartDevice.update")
|
mocker.patch("kasa.smartdevice.SmartDevice.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartdevice", "all")
|
res = xdoctest.doctest_module("kasa.smartdevice", "all")
|
||||||
@ -26,7 +27,7 @@ def test_smartdevice_examples(mocker):
|
|||||||
|
|
||||||
def test_plug_examples(mocker):
|
def test_plug_examples(mocker):
|
||||||
"""Test plug examples."""
|
"""Test plug examples."""
|
||||||
p = get_device_for_file("HS110(EU)_1.0_real.json")
|
p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json"))
|
||||||
mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
|
mocker.patch("kasa.smartplug.SmartPlug", return_value=p)
|
||||||
mocker.patch("kasa.smartplug.SmartPlug.update")
|
mocker.patch("kasa.smartplug.SmartPlug.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartplug", "all")
|
res = xdoctest.doctest_module("kasa.smartplug", "all")
|
||||||
@ -35,7 +36,7 @@ def test_plug_examples(mocker):
|
|||||||
|
|
||||||
def test_strip_examples(mocker):
|
def test_strip_examples(mocker):
|
||||||
"""Test strip examples."""
|
"""Test strip examples."""
|
||||||
p = get_device_for_file("KP303(UK)_1.0.json")
|
p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json"))
|
||||||
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
|
mocker.patch("kasa.smartstrip.SmartStrip", return_value=p)
|
||||||
mocker.patch("kasa.smartstrip.SmartStrip.update")
|
mocker.patch("kasa.smartstrip.SmartStrip.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartstrip", "all")
|
res = xdoctest.doctest_module("kasa.smartstrip", "all")
|
||||||
@ -44,7 +45,7 @@ def test_strip_examples(mocker):
|
|||||||
|
|
||||||
def test_dimmer_examples(mocker):
|
def test_dimmer_examples(mocker):
|
||||||
"""Test dimmer examples."""
|
"""Test dimmer examples."""
|
||||||
p = get_device_for_file("HS220(US)_1.0_real.json")
|
p = asyncio.run(get_device_for_file("HS220(US)_1.0_real.json"))
|
||||||
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
|
mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p)
|
||||||
mocker.patch("kasa.smartdimmer.SmartDimmer.update")
|
mocker.patch("kasa.smartdimmer.SmartDimmer.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartdimmer", "all")
|
res = xdoctest.doctest_module("kasa.smartdimmer", "all")
|
||||||
@ -53,7 +54,7 @@ def test_dimmer_examples(mocker):
|
|||||||
|
|
||||||
def test_lightstrip_examples(mocker):
|
def test_lightstrip_examples(mocker):
|
||||||
"""Test lightstrip examples."""
|
"""Test lightstrip examples."""
|
||||||
p = get_device_for_file("KL430(US)_1.0.json")
|
p = asyncio.run(get_device_for_file("KL430(US)_1.0.json"))
|
||||||
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
|
mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p)
|
||||||
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
|
mocker.patch("kasa.smartlightstrip.SmartLightStrip.update")
|
||||||
res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
|
res = xdoctest.doctest_module("kasa.smartlightstrip", "all")
|
||||||
@ -65,7 +66,7 @@ def test_lightstrip_examples(mocker):
|
|||||||
)
|
)
|
||||||
def test_discovery_examples(mocker):
|
def test_discovery_examples(mocker):
|
||||||
"""Test discovery examples."""
|
"""Test discovery examples."""
|
||||||
p = get_device_for_file("KP303(UK)_1.0.json")
|
p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json"))
|
||||||
|
|
||||||
# This succeeds on python 3.8 but fails on 3.7
|
# This succeeds on python 3.8 but fails on 3.7
|
||||||
# ValueError: a coroutine was expected, got [<DeviceType.Strip model KP303(UK) ...
|
# ValueError: a coroutine was expected, got [<DeviceType.Strip model KP303(UK) ...
|
||||||
|
Loading…
Reference in New Issue
Block a user