Remove sync interface, add asyncio discovery (#14)

* do not update inside __repr__

* Convert discovery to asyncio

* Use asyncio.DatagramProtocol
* Cleanup parameters, no more positional arguments

Closes #7

* Remove sync interface

* This requires #13 to be merged. Closes #12.
* Converts cli to use asyncio.run() where needed.
* The children from smartstrips is being initialized during the first update call.

* Convert on and off commands to use asyncio.run

* conftest: do the initial update automatically for the device, cleans up tests a bit

* return subdevices alias for strip plugs, remove sync from docstrings

* Make tests pass using pytest-asyncio

* Simplify tests and use pytest-asyncio.
* Removed the emeter tests for child devices, as this information do not seem to exist (based on the dummy sysinfo data). Can be added again if needed.
* Remove sync from docstrings.

* Fix incorrect type hint

* Add type hints and some docstrings to discovery
This commit is contained in:
Teemu R 2020-01-12 22:44:19 +01:00 committed by GitHub
parent 3c68d295da
commit 524d28abbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 386 additions and 341 deletions

View File

@ -108,11 +108,12 @@ def discover(ctx, timeout, discover_only, dump_raw):
"""Discover devices in the network."""
target = ctx.parent.params["target"]
click.echo("Discovering devices for %s seconds" % timeout)
found_devs = Discover.discover(
target=target, timeout=timeout, return_raw=dump_raw
).items()
found_devs = asyncio.run(
Discover.discover(target=target, timeout=timeout, return_raw=dump_raw)
)
if not discover_only:
for ip, dev in found_devs:
for ip, dev in found_devs.items():
asyncio.run(dev.update())
if dump_raw:
click.echo(dev)
continue
@ -144,7 +145,7 @@ def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3)
@pass_dev
def sysinfo(dev):
"""Print out full system information."""
dev.sync.update()
asyncio.run(dev.update())
click.echo(click.style("== System info ==", bold=True))
click.echo(pf(dev.sys_info))
@ -154,7 +155,7 @@ def sysinfo(dev):
@click.pass_context
def state(ctx, dev: SmartDevice):
"""Print out device state and versions."""
dev.sync.update()
asyncio.run(dev.update())
click.echo(click.style(f"== {dev.alias} - {dev.model} ==", bold=True))
click.echo(
@ -165,7 +166,7 @@ def state(ctx, dev: SmartDevice):
)
if dev.is_strip:
for plug in dev.plugs: # type: ignore
plug.sync.update()
asyncio.run(plug.update())
is_on = plug.is_on
alias = plug.alias
click.echo(
@ -179,7 +180,7 @@ def state(ctx, dev: SmartDevice):
for k, v in dev.state_information.items():
click.echo(f"{k}: {v}")
click.echo(click.style("== Generic information ==", bold=True))
click.echo("Time: {}".format(dev.sync.get_time()))
click.echo("Time: {}".format(asyncio.run(dev.get_time())))
click.echo("Hardware: {}".format(dev.hw_info["hw_ver"]))
click.echo("Software: {}".format(dev.hw_info["sw_ver"]))
click.echo(f"MAC (rssi): {dev.mac} ({dev.rssi})")
@ -195,7 +196,7 @@ def alias(dev, new_alias):
"""Get or set the device alias."""
if new_alias is not None:
click.echo(f"Setting alias to {new_alias}")
dev.sync.set_alias(new_alias)
asyncio.run(dev.set_alias(new_alias))
click.echo(f"Alias: {dev.alias}")
@ -211,8 +212,8 @@ def raw_command(dev: SmartDevice, module, command, parameters):
if parameters is not None:
parameters = ast.literal_eval(parameters)
res = dev.sync._query_helper(module, command, parameters)
dev.sync.update()
res = asyncio.run(dev._query_helper(module, command, parameters))
asyncio.run(dev.update())
click.echo(res)
@ -224,24 +225,26 @@ def raw_command(dev: SmartDevice, module, command, parameters):
def emeter(dev, year, month, erase):
"""Query emeter for historical consumption."""
click.echo(click.style("== Emeter ==", bold=True))
dev.sync.update()
asyncio.run(dev.update())
if not dev.has_emeter:
click.echo("Device has no emeter")
return
if erase:
click.echo("Erasing emeter statistics..")
dev.sync.erase_emeter_stats()
asyncio.run(dev.erase_emeter_stats())
return
if year:
click.echo(f"== For year {year.year} ==")
emeter_status = dev.sync.get_emeter_monthly(year.year)
emeter_status = asyncio.run(dev.get_emeter_monthly(year.year))
elif month:
click.echo(f"== For month {month.month} of {month.year} ==")
emeter_status = dev.sync.get_emeter_daily(year=month.year, month=month.month)
emeter_status = asyncio.run(
dev.get_emeter_daily(year=month.year, month=month.month)
)
else:
emeter_status = dev.sync.get_emeter_realtime()
emeter_status = asyncio.run(dev.get_emeter_realtime())
click.echo("== Current State ==")
if isinstance(emeter_status, list):
@ -256,7 +259,7 @@ def emeter(dev, year, month, erase):
@pass_dev
def brightness(dev, brightness):
"""Get or set brightness."""
dev.sync.update()
asyncio.run(dev.update())
if not dev.is_dimmable:
click.echo("This device does not support brightness.")
return
@ -264,7 +267,7 @@ def brightness(dev, brightness):
click.echo("Brightness: %s" % dev.brightness)
else:
click.echo("Setting brightness to %s" % brightness)
dev.sync.set_brightness(brightness)
asyncio.run(dev.set_brightness(brightness))
@cli.command()
@ -286,7 +289,7 @@ def temperature(dev: SmartBulb, temperature):
)
else:
click.echo(f"Setting color temperature to {temperature}")
dev.sync.set_color_temp(temperature)
asyncio.run(dev.set_color_temp(temperature))
@cli.command()
@ -303,7 +306,7 @@ def hsv(dev, ctx, h, s, v):
raise click.BadArgumentUsage("Setting a color requires 3 values.", ctx)
else:
click.echo(f"Setting HSV: {h} {s} {v}")
dev.sync.set_hsv(h, s, v)
asyncio.run(dev.set_hsv(h, s, v))
@cli.command()
@ -313,7 +316,7 @@ def led(dev, state):
"""Get or set (Plug's) led state."""
if state is not None:
click.echo("Turning led to %s" % state)
dev.sync.set_led(state)
asyncio.run(dev.set_led(state))
else:
click.echo("LED state: %s" % dev.led)
@ -322,7 +325,7 @@ def led(dev, state):
@pass_dev
def time(dev):
"""Get the device time."""
click.echo(dev.sync.get_time())
click.echo(asyncio.run(dev.get_time()))
@cli.command()
@ -332,9 +335,9 @@ def on(plug, index):
"""Turn the device on."""
click.echo("Turning on..")
if index is None:
plug.turn_on()
asyncio.run(plug.turn_on())
else:
plug.turn_on(index=(index - 1))
asyncio.run(plug.turn_on(index=(index - 1)))
@cli.command()
@ -344,9 +347,9 @@ def off(plug, index):
"""Turn the device off."""
click.echo("Turning off..")
if index is None:
plug.turn_off()
asyncio.run(plug.turn_off())
else:
plug.turn_off(index=(index - 1))
asyncio.run(plug.turn_off(index=(index - 1)))
@cli.command()
@ -355,7 +358,7 @@ def off(plug, index):
def reboot(plug, delay):
"""Reboot the device."""
click.echo("Rebooting the device..")
plug.reboot(delay)
asyncio.run(plug.reboot(delay))
if __name__ == "__main__":

View File

@ -1,8 +1,9 @@
"""Discovery module for TP-Link Smart Home devices."""
import asyncio
import json
import logging
import socket
from typing import Dict, Optional, Type
from typing import Awaitable, Callable, Dict, Mapping, Type, Union, cast
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb
@ -13,6 +14,79 @@ from kasa.smartstrip import SmartStrip
_LOGGER = logging.getLogger(__name__)
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler.
This is internal class, use :func:Discover.discover: instead.
"""
discovered_devices: Dict[str, SmartDevice]
discovered_devices_raw: Dict[str, Dict]
def __init__(
self,
*,
on_discovered: OnDiscoveredCallable = None,
target: str = "255.255.255.255",
timeout: int = 5,
discovery_packets: int = 3,
):
self.transport = None
self.tries = discovery_packets
self.timeout = timeout
self.on_discovered = on_discovered
self.protocol = TPLinkSmartHomeProtocol()
self.target = (target, Discover.DISCOVERY_PORT)
self.discovered_devices = {}
self.discovered_devices_raw = {}
def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
self.transport = transport
sock = transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.do_discover()
def do_discover(self) -> None:
"""Send number of discovery datagrams."""
req = json.dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = self.protocol.encrypt(req)
for i in range(self.tries):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
def datagram_received(self, data, addr) -> None:
ip, port = addr
if ip in self.discovered_devices:
return
info = json.loads(self.protocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
device_class = Discover._get_device_class(info)
device = device_class(ip)
self.discovered_devices[ip] = device
self.discovered_devices_raw[ip] = info
if device_class is not None:
if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
else:
_LOGGER.error("Received invalid response: %s", info)
def error_received(self, ex):
_LOGGER.error("Got error: %s", ex)
def connection_lost(self, ex):
pass
class Discover:
"""Discover TPLink Smart Home devices.
@ -28,6 +102,8 @@ class Discover:
The protocol uses UDP broadcast datagrams on port 9999 for discovery.
"""
DISCOVERY_PORT = 9999
DISCOVERY_QUERY = {
"system": {"get_sysinfo": None},
"emeter": {"get_realtime": None},
@ -37,75 +113,65 @@ class Discover:
}
@staticmethod
def discover(
protocol: TPLinkSmartHomeProtocol = None,
target: str = "255.255.255.255",
port: int = 9999,
timeout: int = 3,
async def discover(
*,
target="255.255.255.255",
on_discovered=None,
timeout=5,
discovery_packets=3,
return_raw=False,
) -> Dict[str, SmartDevice]:
"""Discover devices.
) -> Mapping[str, Union[SmartDevice, Dict]]:
"""Discover supported devices.
Sends discovery message to 255.255.255.255:9999 in order
to detect available supported devices in the local network,
and waits for given timeout for answers from devices.
:param protocol: Protocol implementation to use
If given, `on_discovered` coroutine will get passed with the SmartDevice as parameter.
The results of the discovery can be accessed either via `discovered_devices` (SmartDevice-derived) or
`discovered_devices_raw` (JSON objects).
:param target: The target broadcast address (e.g. 192.168.xxx.255).
:param timeout: How long to wait for responses, defaults to 3
:param port: port to send broadcast messages, defaults to 9999.
:rtype: dict
:return: Array of json objects {"ip", "port", "sys_info"}
:param on_discovered:
:param timeout: How long to wait for responses, defaults to 5
:param discovery_packets: Number of discovery packets are broadcasted.
:param return_raw: True to return JSON objects instead of Devices.
:return:
"""
if protocol is None:
protocol = TPLinkSmartHomeProtocol()
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.settimeout(timeout)
req = json.dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("Sending discovery to %s:%s", target, port)
encrypted_req = protocol.encrypt(req)
for i in range(discovery_packets):
sock.sendto(encrypted_req[4:], (target, port))
devices = {}
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
loop = asyncio.get_event_loop()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(
target=target,
on_discovered=on_discovered,
timeout=timeout,
discovery_packets=discovery_packets,
),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)
try:
while True:
data, addr = sock.recvfrom(4096)
ip, port = addr
info = json.loads(protocol.decrypt(data))
device_class = Discover._get_device_class(info)
if return_raw:
devices[ip] = info
elif device_class is not None:
devices[ip] = device_class(ip)
except socket.timeout:
_LOGGER.debug("Got socket timeout, which is okay.")
except Exception as ex:
_LOGGER.error("Got exception %s", ex, exc_info=True)
_LOGGER.debug("Found %s devices: %s", len(devices), devices)
return devices
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
await asyncio.sleep(5)
finally:
transport.close()
_LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices))
if return_raw:
return protocol.discovered_devices_raw
return protocol.discovered_devices
@staticmethod
async def discover_single(
host: str, protocol: TPLinkSmartHomeProtocol = None
) -> Optional[SmartDevice]:
async def discover_single(host: str) -> SmartDevice:
"""Discover a single device by the given IP address.
:param host: Hostname of device to query
:param protocol: Protocol implementation to use
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
if protocol is None:
protocol = TPLinkSmartHomeProtocol()
protocol = TPLinkSmartHomeProtocol()
info = await protocol.query(host, Discover.DISCOVERY_QUERY)
@ -113,10 +179,10 @@ class Discover:
if device_class is not None:
return device_class(host)
return None
raise SmartDeviceException("Unable to discover device, received: %s" % info)
@staticmethod
def _get_device_class(info: dict) -> Optional[Type[SmartDevice]]:
def _get_device_class(info: dict) -> Type[SmartDevice]:
"""Find SmartDevice subclass for device described by passed data."""
if "system" in info and "get_sysinfo" in info["system"]:
sysinfo = info["system"]["get_sysinfo"]
@ -136,4 +202,17 @@ class Discover:
elif "smartbulb" in type_.lower():
return SmartBulb
return None
raise SmartDeviceException("Unknown device type: %s", type_)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
loop = asyncio.get_event_loop()
async def _on_device(dev):
await dev.update()
_LOGGER.info("Got device: %s", dev)
devices = loop.run_until_complete(Discover.discover(on_discovered=_on_device))
for ip, dev in devices.items():
print(f"[{ip}] {dev}")

View File

@ -27,46 +27,45 @@ class SmartBulb(SmartDevice):
Usage example when used as library:
```python
p = SmartBulb("192.168.1.105")
await p.update()
# print the devices alias
print(p.sync.alias)
print(p.alias)
# change state of bulb
p.sync.turn_on()
p.sync.turn_off()
await p.turn_on()
await p.turn_off()
# query and print current state of plug
print(p.sync.state_information())
print(p.state_information)
# check whether the bulb supports color changes
if p.sync.is_color():
if p.is_color:
# set the color to an HSV tuple
p.sync.set_hsv(180, 100, 100)
await p.set_hsv(180, 100, 100)
# get the current HSV value
print(p.sync.hsv())
print(p.hsv)
# check whether the bulb supports setting color temperature
if p.sync.is_variable_color_temp():
if p.is_variable_color_temp:
# set the color temperature in Kelvin
p.sync.set_color_temp(3000)
await p.set_color_temp(3000)
# get the current color temperature
print(p.sync.color_temp)
print(p.color_temp)
# check whether the bulb is dimmable
if p.is_dimmable:
# set the bulb to 50% brightness
p.sync.set_brightness(50)
await p.set_brightness(50)
# check the current brightness
print(p.brightness)
```
Omit the `sync` attribute to get coroutines.
Errors reported by the device are raised as SmartDeviceExceptions,
and should be handled by the user of the library.
"""

View File

@ -138,7 +138,6 @@ class SmartDevice:
self.cache = defaultdict(lambda: defaultdict(lambda: None)) # type: ignore
self._device_type = DeviceType.Unknown
self.ioloop = ioloop or asyncio.get_event_loop()
self.sync = SyncSmartDevice(self, ioloop=self.ioloop)
self._sys_info = None
def _result_from_cache(self, target, cmd) -> Optional[Dict]:
@ -646,14 +645,13 @@ class SmartDevice:
return False
def __repr__(self):
self.sync.update()
return "<{} model {} at {} ({}), is_on: {} - dev specific: {}>".format(
self.__class__.__name__,
self.model,
self.host,
self.alias,
self.is_on,
self.sync.state_information,
self.state_information,
)

View File

@ -22,18 +22,16 @@ class SmartPlug(SmartDevice):
p = SmartPlug("192.168.1.105")
# print the devices alias
print(p.sync.alias)
print(p.alias)
# change state of plug
p.sync.turn_on()
p.sync.turn_off()
await p.turn_on()
await p.turn_off()
# query and print current state of plug
print(p.sync.state_information)
print(p.state_information)
```
Omit the `sync` attribute to get coroutines.
Errors reported by the device are raised as SmartDeviceExceptions,
and should be handled by the user of the library.
"""
@ -92,6 +90,26 @@ class SmartPlug(SmartDevice):
else:
raise ValueError("Brightness value %s is not valid." % value)
def _get_child_info(self):
for plug in self.sys_info["children"]:
if plug["id"] == self.context:
return plug
raise SmartDeviceException("Unable to find children %s")
@property # type: ignore
@requires_update
def alias(self) -> str:
"""Return device name (alias).
:return: Device name aka alias.
:rtype: str
"""
if self.context:
info = self._get_child_info()
return info["alias"]
else:
return super().alias
@property # type: ignore
@requires_update
def is_dimmable(self):
@ -122,6 +140,10 @@ class SmartPlug(SmartDevice):
:return: True if device is on, False otherwise
"""
if self.context:
info = self._get_child_info()
return info["state"]
sys_info = self.sys_info
return bool(sys_info["relay_state"])
@ -171,10 +193,8 @@ class SmartPlug(SmartDevice):
"""
sys_info = self.sys_info
if self.context:
for plug in sys_info["children"]:
if plug["id"] == self.context:
on_time = plug["on_time"]
break
info = self._get_child_info()
on_time = info["on_time"]
else:
on_time = sys_info["on_time"]

View File

@ -22,22 +22,21 @@ class SmartStrip(SmartPlug):
p = SmartStrip("192.168.1.105")
# query the state of the strip
await p.update()
print(p.is_on)
# change state of all outlets
p.sync.turn_on()
p.sync.turn_off()
await p.turn_on()
await p.turn_off()
# individual outlets are accessible through plugs variable
for plug in p.plugs:
print(f"{p}: {p.is_on}")
# change state of a single outlet
p.plugs[0].sync.turn_on()
await p.plugs[0].turn_on()
```
Omit the `sync` attribute to get coroutines.
Errors reported by the device are raised as SmartDeviceExceptions,
and should be handled by the user of the library.
"""
@ -53,18 +52,6 @@ class SmartStrip(SmartPlug):
self.emeter_type = "emeter"
self._device_type = DeviceType.Strip
self.plugs: List[SmartPlug] = []
children = self.sync.get_sys_info()["children"]
self.num_children = len(children)
for child in children:
self.plugs.append(
SmartPlug(
host,
protocol,
context=child["id"],
cache_ttl=cache_ttl,
ioloop=ioloop,
)
)
@property # type: ignore
@requires_update
@ -82,6 +69,22 @@ class SmartStrip(SmartPlug):
Needed for methods that are decorated with `requires_update`.
"""
await super().update()
# Initialize the child devices during the first update.
if not self.plugs:
children = self.sys_info["children"]
self.num_children = len(children)
for child in children:
self.plugs.append(
SmartPlug(
self.host,
self.protocol,
context=child["id"],
cache_ttl=self.cache_ttl.total_seconds(),
ioloop=self.ioloop,
)
)
for plug in self.plugs:
await plug.update()

View File

@ -73,21 +73,27 @@ non_color_bulb = pytest.mark.parametrize(
turn_on = pytest.mark.parametrize("turn_on", [True, False])
def handle_turn_on(dev, turn_on):
async def handle_turn_on(dev, turn_on):
if turn_on:
dev.sync.turn_on()
await dev.turn_on()
else:
dev.sync.turn_off()
await dev.turn_off()
@pytest.fixture(params=SUPPORTED_DEVICES)
def dev(request):
"""Device fixture.
Provides a device (given --ip) or parametrized fixture for the supported devices.
The initial update is called automatically before returning the device.
"""
ioloop = get_ioloop()
file = request.param
ip = request.config.getoption("--ip")
if ip:
d = ioloop.run_until_complete(Discover.discover_single(ip))
ioloop.run_until_complete(d.update())
print(d.model)
if d.model in file:
return d
@ -109,6 +115,7 @@ def dev(request):
p = SmartPlug(**params, ioloop=ioloop)
else:
raise Exception("No tests for %s" % model)
ioloop.run_until_complete(p.update())
yield p

View File

@ -30,9 +30,9 @@ from .newfakes import (
)
@pytest.mark.asyncio
@plug
def test_plug_sysinfo(dev):
dev.sync.update()
async def test_plug_sysinfo(dev):
assert dev.sys_info is not None
PLUG_SCHEMA(dev.sys_info)
@ -42,9 +42,9 @@ def test_plug_sysinfo(dev):
assert dev.is_plug or dev.is_strip
@pytest.mark.asyncio
@bulb
def test_bulb_sysinfo(dev):
dev.sync.update()
async def test_bulb_sysinfo(dev):
assert dev.sys_info is not None
BULB_SCHEMA(dev.sys_info)
@ -54,83 +54,87 @@ def test_bulb_sysinfo(dev):
assert dev.is_bulb
def test_state_info(dev):
dev.sync.update()
assert isinstance(dev.sync.state_information, dict)
@pytest.mark.asyncio
async def test_state_info(dev):
assert isinstance(dev.state_information, dict)
def test_invalid_connection(dev):
@pytest.mark.asyncio
async def test_invalid_connection(dev):
with patch.object(FakeTransportProtocol, "query", side_effect=SmartDeviceException):
with pytest.raises(SmartDeviceException):
dev.sync.update()
await dev.update()
dev.is_on
def test_query_helper(dev):
@pytest.mark.asyncio
async def test_query_helper(dev):
with pytest.raises(SmartDeviceException):
dev.sync._query_helper("test", "testcmd", {})
await dev._query_helper("test", "testcmd", {})
# TODO check for unwrapping?
@pytest.mark.asyncio
@turn_on
def test_state(dev, turn_on):
handle_turn_on(dev, turn_on)
dev.sync.update()
async def test_state(dev, turn_on):
await handle_turn_on(dev, turn_on)
orig_state = dev.is_on
if orig_state:
dev.sync.turn_off()
await dev.turn_off()
assert not dev.is_on
assert dev.is_off
dev.sync.turn_on()
await dev.turn_on()
assert dev.is_on
assert not dev.is_off
else:
dev.sync.turn_on()
await dev.turn_on()
assert dev.is_on
assert not dev.is_off
dev.sync.turn_off()
await dev.turn_off()
assert not dev.is_on
assert dev.is_off
@pytest.mark.asyncio
@no_emeter
def test_no_emeter(dev):
dev.sync.update()
async def test_no_emeter(dev):
assert not dev.has_emeter
with pytest.raises(SmartDeviceException):
dev.sync.get_emeter_realtime()
await dev.get_emeter_realtime()
with pytest.raises(SmartDeviceException):
dev.sync.get_emeter_daily()
await dev.get_emeter_daily()
with pytest.raises(SmartDeviceException):
dev.sync.get_emeter_monthly()
await dev.get_emeter_monthly()
with pytest.raises(SmartDeviceException):
dev.sync.erase_emeter_stats()
await dev.erase_emeter_stats()
@pytest.mark.asyncio
@has_emeter
def test_get_emeter_realtime(dev):
dev.sync.update()
async def test_get_emeter_realtime(dev):
if dev.is_strip:
pytest.skip("Disabled for HS300 temporarily")
assert dev.has_emeter
current_emeter = dev.sync.get_emeter_realtime()
current_emeter = await dev.get_emeter_realtime()
CURRENT_CONSUMPTION_SCHEMA(current_emeter)
@pytest.mark.asyncio
@has_emeter
def test_get_emeter_daily(dev):
dev.sync.update()
async def test_get_emeter_daily(dev):
if dev.is_strip:
pytest.skip("Disabled for HS300 temporarily")
assert dev.has_emeter
assert dev.sync.get_emeter_daily(year=1900, month=1) == {}
assert await dev.get_emeter_daily(year=1900, month=1) == {}
d = dev.sync.get_emeter_daily()
d = await dev.get_emeter_daily()
assert len(d) > 0
k, v = d.popitem()
@ -138,22 +142,22 @@ def test_get_emeter_daily(dev):
assert isinstance(v, float)
# Test kwh (energy, energy_wh)
d = dev.sync.get_emeter_daily(kwh=False)
d = await dev.get_emeter_daily(kwh=False)
k2, v2 = d.popitem()
assert v * 1000 == v2
@pytest.mark.asyncio
@has_emeter
def test_get_emeter_monthly(dev):
dev.sync.update()
async def test_get_emeter_monthly(dev):
if dev.is_strip:
pytest.skip("Disabled for HS300 temporarily")
assert dev.has_emeter
assert dev.sync.get_emeter_monthly(year=1900) == {}
assert await dev.get_emeter_monthly(year=1900) == {}
d = dev.sync.get_emeter_monthly()
d = await dev.get_emeter_monthly()
assert len(d) > 0
k, v = d.popitem()
@ -161,20 +165,20 @@ def test_get_emeter_monthly(dev):
assert isinstance(v, float)
# Test kwh (energy, energy_wh)
d = dev.sync.get_emeter_monthly(kwh=False)
d = await dev.get_emeter_monthly(kwh=False)
k2, v2 = d.popitem()
assert v * 1000 == v2
@pytest.mark.asyncio
@has_emeter
def test_emeter_status(dev):
dev.sync.update()
async def test_emeter_status(dev):
if dev.is_strip:
pytest.skip("Disabled for HS300 temporarily")
assert dev.has_emeter
d = dev.sync.get_emeter_realtime()
d = await dev.get_emeter_realtime()
with pytest.raises(KeyError):
assert d["foo"]
@ -188,162 +192,165 @@ def test_emeter_status(dev):
assert d["total_wh"] == d["total"] * 1000
@pytest.mark.asyncio
@pytest.mark.skip("not clearing your stats..")
@has_emeter
def test_erase_emeter_stats(dev):
dev.sync.update()
async def test_erase_emeter_stats(dev):
assert dev.has_emeter
dev.sync.erase_emeter()
await dev.erase_emeter()
@pytest.mark.asyncio
@has_emeter
def test_current_consumption(dev):
dev.sync.update()
async def test_current_consumption(dev):
if dev.is_strip:
pytest.skip("Disabled for HS300 temporarily")
if dev.has_emeter:
x = dev.sync.current_consumption()
x = await dev.current_consumption()
assert isinstance(x, float)
assert x >= 0.0
else:
assert dev.sync.current_consumption() is None
assert await dev.current_consumption() is None
def test_alias(dev):
dev.sync.update()
@pytest.mark.asyncio
async def test_alias(dev):
test_alias = "TEST1234"
original = dev.sync.alias
original = dev.alias
assert isinstance(original, str)
dev.sync.set_alias(test_alias)
assert dev.sync.alias == test_alias
await dev.set_alias(test_alias)
assert dev.alias == test_alias
dev.sync.set_alias(original)
assert dev.sync.alias == original
await dev.set_alias(original)
assert dev.alias == original
@pytest.mark.asyncio
@plug
def test_led(dev):
dev.sync.update()
async def test_led(dev):
original = dev.led
dev.sync.set_led(False)
await dev.set_led(False)
assert not dev.led
dev.sync.set_led(True)
await dev.set_led(True)
assert dev.led
dev.sync.set_led(original)
await dev.set_led(original)
@pytest.mark.asyncio
@plug
def test_on_since(dev):
dev.sync.update()
async def test_on_since(dev):
assert isinstance(dev.on_since, datetime.datetime)
def test_icon(dev):
assert set(dev.sync.get_icon().keys()), {"icon", "hash"}
@pytest.mark.asyncio
async def test_icon(dev):
assert set((await dev.get_icon()).keys()), {"icon", "hash"}
def test_time(dev):
assert isinstance(dev.sync.get_time(), datetime.datetime)
@pytest.mark.asyncio
async def test_time(dev):
assert isinstance(await dev.get_time(), datetime.datetime)
# TODO check setting?
def test_timezone(dev):
TZ_SCHEMA(dev.sync.get_timezone())
@pytest.mark.asyncio
async def test_timezone(dev):
TZ_SCHEMA(await dev.get_timezone())
def test_hw_info(dev):
dev.sync.update()
@pytest.mark.asyncio
async def test_hw_info(dev):
PLUG_SCHEMA(dev.hw_info)
def test_location(dev):
dev.sync.update()
@pytest.mark.asyncio
async def test_location(dev):
PLUG_SCHEMA(dev.location)
def test_rssi(dev):
dev.sync.update()
@pytest.mark.asyncio
async def test_rssi(dev):
PLUG_SCHEMA({"rssi": dev.rssi}) # wrapping for vol
def test_mac(dev):
dev.sync.update()
@pytest.mark.asyncio
async def test_mac(dev):
PLUG_SCHEMA({"mac": dev.mac}) # wrapping for val
# TODO check setting?
@pytest.mark.asyncio
@non_variable_temp
def test_temperature_on_nonsupporting(dev):
dev.sync.update()
async def test_temperature_on_nonsupporting(dev):
assert dev.valid_temperature_range == (0, 0)
# TODO test when device does not support temperature range
with pytest.raises(SmartDeviceException):
dev.sync.set_color_temp(2700)
await dev.set_color_temp(2700)
with pytest.raises(SmartDeviceException):
print(dev.sync.color_temp)
print(dev.color_temp)
@pytest.mark.asyncio
@variable_temp
def test_out_of_range_temperature(dev):
dev.sync.update()
async def test_out_of_range_temperature(dev):
with pytest.raises(ValueError):
dev.sync.set_color_temp(1000)
await dev.set_color_temp(1000)
with pytest.raises(ValueError):
dev.sync.set_color_temp(10000)
await dev.set_color_temp(10000)
@pytest.mark.asyncio
@non_dimmable
def test_non_dimmable(dev):
dev.sync.update()
async def test_non_dimmable(dev):
assert not dev.is_dimmable
with pytest.raises(SmartDeviceException):
assert dev.brightness == 0
with pytest.raises(SmartDeviceException):
dev.sync.set_brightness(100)
await dev.set_brightness(100)
@pytest.mark.asyncio
@dimmable
@turn_on
def test_dimmable_brightness(dev, turn_on):
handle_turn_on(dev, turn_on)
dev.sync.update()
async def test_dimmable_brightness(dev, turn_on):
await handle_turn_on(dev, turn_on)
assert dev.is_dimmable
dev.sync.set_brightness(50)
await dev.set_brightness(50)
assert dev.brightness == 50
dev.sync.set_brightness(10)
await dev.set_brightness(10)
assert dev.brightness == 10
with pytest.raises(ValueError):
dev.sync.set_brightness("foo")
await dev.set_brightness("foo")
@pytest.mark.asyncio
@dimmable
def test_invalid_brightness(dev):
dev.sync.update()
async def test_invalid_brightness(dev):
assert dev.is_dimmable
with pytest.raises(ValueError):
dev.sync.set_brightness(110)
await dev.set_brightness(110)
with pytest.raises(ValueError):
dev.sync.set_brightness(-100)
await dev.set_brightness(-100)
@pytest.mark.asyncio
@color_bulb
@turn_on
def test_hsv(dev, turn_on):
handle_turn_on(dev, turn_on)
dev.sync.update()
async def test_hsv(dev, turn_on):
await handle_turn_on(dev, turn_on)
assert dev.is_color
hue, saturation, brightness = dev.hsv
@ -351,7 +358,7 @@ def test_hsv(dev, turn_on):
assert 0 <= saturation <= 100
assert 0 <= brightness <= 100
dev.sync.set_hsv(hue=1, saturation=1, value=1)
await dev.set_hsv(hue=1, saturation=1, value=1)
hue, saturation, brightness = dev.hsv
assert hue == 1
@ -359,105 +366,101 @@ def test_hsv(dev, turn_on):
assert brightness == 1
@pytest.mark.asyncio
@color_bulb
@turn_on
def test_invalid_hsv(dev, turn_on):
handle_turn_on(dev, turn_on)
dev.sync.update()
async def test_invalid_hsv(dev, turn_on):
await handle_turn_on(dev, turn_on)
assert dev.is_color
for invalid_hue in [-1, 361, 0.5]:
with pytest.raises(ValueError):
dev.sync.set_hsv(invalid_hue, 0, 0)
await dev.set_hsv(invalid_hue, 0, 0)
for invalid_saturation in [-1, 101, 0.5]:
with pytest.raises(ValueError):
dev.sync.set_hsv(0, invalid_saturation, 0)
await dev.set_hsv(0, invalid_saturation, 0)
for invalid_brightness in [-1, 101, 0.5]:
with pytest.raises(ValueError):
dev.sync.set_hsv(0, 0, invalid_brightness)
await dev.set_hsv(0, 0, invalid_brightness)
@pytest.mark.asyncio
@non_color_bulb
def test_hsv_on_non_color(dev):
dev.sync.update()
async def test_hsv_on_non_color(dev):
assert not dev.is_color
with pytest.raises(SmartDeviceException):
dev.sync.set_hsv(0, 0, 0)
await dev.set_hsv(0, 0, 0)
with pytest.raises(SmartDeviceException):
print(dev.hsv)
@pytest.mark.asyncio
@variable_temp
@turn_on
def test_try_set_colortemp(dev, turn_on):
dev.sync.update()
handle_turn_on(dev, turn_on)
dev.sync.set_color_temp(2700)
assert dev.sync.color_temp == 2700
async def test_try_set_colortemp(dev, turn_on):
await handle_turn_on(dev, turn_on)
await dev.set_color_temp(2700)
assert dev.color_temp == 2700
@pytest.mark.asyncio
@non_variable_temp
def test_non_variable_temp(dev):
async def test_non_variable_temp(dev):
with pytest.raises(SmartDeviceException):
dev.sync.update()
dev.sync.set_color_temp(2700)
await dev.set_color_temp(2700)
@pytest.mark.asyncio
@strip
@turn_on
def test_children_change_state(dev, turn_on):
dev.sync.update()
handle_turn_on(dev, turn_on)
async def test_children_change_state(dev, turn_on):
await handle_turn_on(dev, turn_on)
for plug in dev.plugs:
plug.sync.update()
orig_state = plug.is_on
if orig_state:
plug.turn_off()
plug.sync.update()
await plug.turn_off()
assert not plug.is_on
assert plug.is_off
plug.sync.turn_on()
plug.sync.update()
await plug.turn_on()
assert plug.is_on
assert not plug.is_off
else:
plug.sync.turn_on()
plug.sync.update()
await plug.turn_on()
assert plug.is_on
assert not plug.is_off
plug.sync.turn_off()
plug.sync.update()
await plug.turn_off()
assert not plug.is_on
assert plug.is_off
@pytest.mark.asyncio
@strip
def test_children_alias(dev):
async def test_children_alias(dev):
test_alias = "TEST1234"
for plug in dev.plugs:
plug.sync.update()
original = plug.alias
plug.sync.set_alias(alias=test_alias)
plug.sync.update()
await plug.set_alias(alias=test_alias)
assert plug.alias == test_alias
plug.sync.set_alias(alias=original)
plug.sync.update()
await plug.set_alias(alias=original)
assert plug.alias == original
@pytest.mark.asyncio
@strip
def test_children_on_since(dev):
async def test_children_on_since(dev):
for plug in dev.plugs:
plug.sync.update()
assert plug.on_since
@pytest.mark.asyncio
@pytest.mark.skip("this test will wear out your relays")
def test_all_binary_states(dev):
async def test_all_binary_states(dev):
# test every binary state
for state in range(2 ** dev.num_children):
# create binary state map
@ -466,9 +469,9 @@ def test_all_binary_states(dev):
state_map[plug_index] = bool((state >> plug_index) & 1)
if state_map[plug_index]:
dev.sync.turn_on(index=plug_index)
await dev.turn_on(index=plug_index)
else:
dev.sync.turn_off(index=plug_index)
await dev.turn_off(index=plug_index)
# check state map applied
for index, state in dev.is_on.items():
@ -479,9 +482,9 @@ def test_all_binary_states(dev):
# toggle state
if state_map[plug_index]:
dev.sync.turn_off(index=plug_index)
await dev.turn_off(index=plug_index)
else:
dev.sync.turn_on(index=plug_index)
await dev.turn_on(index=plug_index)
# only target outlet should have state changed
for index, state in dev.is_on.items():
@ -492,86 +495,19 @@ def test_all_binary_states(dev):
# reset state
if state_map[plug_index]:
dev.sync.turn_on(index=plug_index)
await dev.turn_on(index=plug_index)
else:
dev.sync.turn_off(index=plug_index)
await dev.turn_off(index=plug_index)
# original state map should be restored
for index, state in dev.is_on.items():
assert state == state_map[index]
@strip
def test_children_get_emeter_realtime(dev):
dev.sync.update()
assert dev.has_emeter
# test with index
for plug in dev.plugs:
plug.sync.update()
emeter = plug.sync.get_emeter_realtime()
CURRENT_CONSUMPTION_SCHEMA(emeter)
# test without index
# TODO test that sum matches the sum of individiaul plugs.
# for index, emeter in dev.sync.get_emeter_realtime().items():
# CURRENT_CONSUMPTION_SCHEMA(emeter)
@strip
def test_children_get_emeter_daily(dev):
dev.sync.update()
assert dev.has_emeter
# test individual emeters
for plug in dev.plugs:
plug.sync.update()
emeter = plug.sync.get_emeter_daily(year=1900, month=1)
assert emeter == {}
emeter = plug.sync.get_emeter_daily()
assert len(emeter) > 0
k, v = emeter.popitem()
assert isinstance(k, int)
assert isinstance(v, float)
# test sum of emeters
all_emeter = dev.sync.get_emeter_daily(year=1900, month=1)
k, v = all_emeter.popitem()
assert isinstance(k, int)
assert isinstance(v, float)
@strip
def test_children_get_emeter_monthly(dev):
dev.sync.update()
assert dev.has_emeter
# test individual emeters
for plug in dev.plugs:
plug.sync.update()
emeter = plug.sync.get_emeter_monthly(year=1900)
assert emeter == {}
emeter = plug.sync.get_emeter_monthly()
assert len(emeter) > 0
k, v = emeter.popitem()
assert isinstance(k, int)
assert isinstance(v, float)
# test sum of emeters
all_emeter = dev.sync.get_emeter_monthly(year=1900)
k, v = all_emeter.popitem()
assert isinstance(k, int)
assert isinstance(v, float)
# def test_cache(dev):
# from datetime import timedelta
# dev.sync.cache_ttl = timedelta(seconds=3)
# dev.cache_ttl = timedelta(seconds=3)
# with patch.object(
# FakeTransportProtocol, "query", wraps=dev.protocol.query
# ) as query_mock:
@ -590,7 +526,7 @@ def test_children_get_emeter_monthly(dev):
# def test_cache_invalidates(dev):
# from datetime import timedelta
# dev.sync.cache_ttl = timedelta(seconds=0)
# dev.cache_ttl = timedelta(seconds=0)
# with patch.object(
# FakeTransportProtocol, "query", wraps=dev.protocol.query

View File

@ -14,8 +14,8 @@ deps=
pytest-cov
voluptuous
typing
deprecation
flake8
pytest-asyncio
commands=
py.test --cov --cov-config=tox.ini kasa