Various test code cleanups (#725)

* Separate fake protocols for iot and smart

* Move control_child impl into its own method

* Organize schemas into correct places

* Add test_childdevice

* Add missing return for _handle_control_child
This commit is contained in:
Teemu R 2024-01-29 20:26:39 +01:00 committed by GitHub
parent 1e26434205
commit 9e6896a08f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 332 additions and 304 deletions

View File

@ -27,7 +27,8 @@ from kasa.protocol import BaseTransport
from kasa.tapo import TapoBulb, TapoPlug from kasa.tapo import TapoBulb, TapoPlug
from kasa.xortransport import XorEncryption from kasa.xortransport import XorEncryption
from .newfakes import FakeSmartProtocol, FakeTransportProtocol from .fakeprotocol_iot import FakeIotProtocol
from .fakeprotocol_smart import FakeSmartProtocol
SUPPORTED_IOT_DEVICES = [ SUPPORTED_IOT_DEVICES = [
(device, "IOT") (device, "IOT")
@ -410,7 +411,7 @@ async def get_device_for_file(file, protocol):
if protocol == "SMART": if protocol == "SMART":
d.protocol = FakeSmartProtocol(sysinfo) d.protocol = FakeSmartProtocol(sysinfo)
else: else:
d.protocol = FakeTransportProtocol(sysinfo) d.protocol = FakeIotProtocol(sysinfo)
await _update_and_close(d) await _update_and_close(d)
return d return d
@ -521,7 +522,7 @@ def discovery_mock(all_fixture_data, mocker):
if "component_nego" in dm.query_data: if "component_nego" in dm.query_data:
proto = FakeSmartProtocol(dm.query_data) proto = FakeSmartProtocol(dm.query_data)
else: else:
proto = FakeTransportProtocol(dm.query_data) proto = FakeIotProtocol(dm.query_data)
async def _query(request, retry_count: int = 3): async def _query(request, retry_count: int = 3):
return await proto.query(request) return await proto.query(request)

View File

@ -1,185 +1,13 @@
import base64
import copy import copy
import logging import logging
import re
import warnings
from json import loads as json_loads
from voluptuous import (
REMOVE_EXTRA,
All,
Any,
Coerce, # type: ignore
Invalid,
Optional,
Range,
Schema,
)
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
from ..exceptions import SmartDeviceException
from ..iotprotocol import IotProtocol from ..iotprotocol import IotProtocol
from ..protocol import BaseTransport
from ..smartprotocol import SmartProtocol
from ..xortransport import XorTransport from ..xortransport import XorTransport
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def check_int_bool(x):
if x != 0 and x != 1:
raise Invalid(x)
return x
def check_mac(x):
if re.match("[0-9a-f]{2}([-:])[0-9a-f]{2}(\\1[0-9a-f]{2}){4}$", x.lower()):
return x
raise Invalid(x)
def check_mode(x):
if x in ["schedule", "none", "count_down"]:
return x
raise Invalid(f"invalid mode {x}")
def lb_dev_state(x):
if x in ["normal"]:
return x
raise Invalid(f"Invalid dev_state {x}")
TZ_SCHEMA = Schema(
{"zone_str": str, "dst_offset": int, "index": All(int, Range(min=0)), "tz_str": str}
)
CURRENT_CONSUMPTION_SCHEMA = Schema(
Any(
{
"voltage": Any(All(float, Range(min=0, max=300)), None),
"power": Any(Coerce(float, Range(min=0)), None),
"total": Any(Coerce(float, Range(min=0)), None),
"current": Any(All(float, Range(min=0)), None),
"voltage_mv": Any(
All(float, Range(min=0, max=300000)), int, None
), # TODO can this be int?
"power_mw": Any(Coerce(float, Range(min=0)), None),
"total_wh": Any(Coerce(float, Range(min=0)), None),
"current_ma": Any(
All(float, Range(min=0)), int, None
), # TODO can this be int?
"slot_id": Any(Coerce(int, Range(min=0)), None),
},
None,
)
)
# these schemas should go to the mainlib as
# they can be useful when adding support for new features/devices
# as well as to check that faked devices are operating properly.
PLUG_SCHEMA = Schema(
{
"active_mode": check_mode,
"alias": str,
"dev_name": str,
"deviceId": str,
"feature": str,
"fwId": str,
"hwId": str,
"hw_ver": str,
"icon_hash": str,
"led_off": check_int_bool,
"latitude": Any(All(float, Range(min=-90, max=90)), 0, None),
"latitude_i": Any(
All(int, Range(min=-900000, max=900000)),
All(float, Range(min=-900000, max=900000)),
0,
None,
),
"longitude": Any(All(float, Range(min=-180, max=180)), 0, None),
"longitude_i": Any(
All(int, Range(min=-18000000, max=18000000)),
All(float, Range(min=-18000000, max=18000000)),
0,
None,
),
"mac": check_mac,
"model": str,
"oemId": str,
"on_time": int,
"relay_state": int,
"rssi": Any(int, None), # rssi can also be positive, see #54
"sw_ver": str,
"type": str,
"mic_type": str,
"updating": check_int_bool,
# these are available on hs220
"brightness": int,
"preferred_state": [
{"brightness": All(int, Range(min=0, max=100)), "index": int}
],
"next_action": {"type": int},
"child_num": Optional(Any(None, int)), # TODO fix hs300 checks
"children": Optional(list), # TODO fix hs300
# TODO some tplink simulator entries contain invalid (mic_mac, _i variants for lat/lon)
# Therefore we add REMOVE_EXTRA..
# "INVALIDmac": Optional,
# "INVALIDlatitude": Optional,
# "INVALIDlongitude": Optional,
},
extra=REMOVE_EXTRA,
)
LIGHT_STATE_SCHEMA = Schema(
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": int,
"hue": All(int, Range(min=0, max=360)),
"mode": str,
"on_off": check_int_bool,
"saturation": All(int, Range(min=0, max=100)),
"dft_on_state": Optional(
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": All(int, Range(min=0, max=9000)),
"hue": All(int, Range(min=0, max=360)),
"mode": str,
"saturation": All(int, Range(min=0, max=100)),
}
),
"err_code": int,
}
)
BULB_SCHEMA = PLUG_SCHEMA.extend(
{
"ctrl_protocols": Optional(dict),
"description": Optional(str), # TODO: LBxxx similar to dev_name
"dev_state": lb_dev_state,
"disco_ver": str,
"heapsize": int,
"is_color": check_int_bool,
"is_dimmable": check_int_bool,
"is_factory": bool,
"is_variable_color_temp": check_int_bool,
"light_state": LIGHT_STATE_SCHEMA,
"preferred_state": [
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": int,
"hue": All(int, Range(min=0, max=360)),
"index": int,
"saturation": All(int, Range(min=0, max=100)),
}
],
}
)
def get_realtime(obj, x, *args): def get_realtime(obj, x, *args):
return { return {
"current": 0.268587, "current": 0.268587,
@ -294,123 +122,7 @@ TIME_MODULE = {
} }
class FakeSmartProtocol(SmartProtocol): class FakeIotProtocol(IotProtocol):
def __init__(self, info):
super().__init__(
transport=FakeSmartTransport(info),
)
async def query(self, request, retry_count: int = 3):
"""Implement query here so can still patch SmartProtocol.query."""
resp_dict = await self._query(request, retry_count)
return resp_dict
class FakeSmartTransport(BaseTransport):
def __init__(self, info):
super().__init__(
config=DeviceConfig(
"127.0.0.123",
credentials=Credentials(
username="dummy_user",
password="dummy_password", # noqa: S106
),
),
)
self.info = info
self.components = {
comp["id"]: comp["ver_code"]
for comp in self.info["component_nego"]["component_list"]
}
@property
def default_port(self):
"""Default port for the transport."""
return 80
@property
def credentials_hash(self):
"""The hashed credentials used by the transport."""
return self._credentials.username + self._credentials.password + "hash"
FIXTURE_MISSING_MAP = {
"get_wireless_scan_info": ("wireless", {"ap_list": [], "wep_supported": False}),
}
async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]
params = request_dict["params"]
if method == "multipleRequest":
responses = []
for request in params["requests"]:
response = self._send_request(request) # type: ignore[arg-type]
response["method"] = request["method"] # type: ignore[index]
responses.append(response)
return {"result": {"responses": responses}, "error_code": 0}
else:
return self._send_request(request_dict)
def _send_request(self, request_dict: dict):
method = request_dict["method"]
params = request_dict["params"]
info = self.info
if method == "control_child":
device_id = params.get("device_id")
request_data = params.get("requestData")
child_method = request_data.get("method")
child_params = request_data.get("params")
children = info["get_child_device_list"]["child_device_list"]
for child in children:
if child["device_id"] == device_id:
info = child
break
# We only support get & set device info for now.
if child_method == "get_device_info":
return {"result": info, "error_code": 0}
elif child_method == "set_device_info":
info.update(child_params)
return {"error_code": 0}
raise NotImplementedError(
"Method %s not implemented for children" % child_method
)
if method == "component_nego" or method[:4] == "get_":
if method in info:
return {"result": info[method], "error_code": 0}
elif (
missing_result := self.FIXTURE_MISSING_MAP.get(method)
) and missing_result[0] in self.components:
warnings.warn(
UserWarning(
f"Fixture missing expected method {method}, try to regenerate"
),
stacklevel=1,
)
return {"result": missing_result[1], "error_code": 0}
else:
raise SmartDeviceException(f"Fixture doesn't support {method}")
elif method == "set_qs_info":
return {"error_code": 0}
elif method[:4] == "set_":
target_method = f"get_{method[4:]}"
info[target_method].update(params)
return {"error_code": 0}
async def close(self) -> None:
pass
async def reset(self) -> None:
pass
class FakeTransportProtocol(IotProtocol):
def __init__(self, info): def __init__(self, info):
super().__init__( super().__init__(
transport=XorTransport( transport=XorTransport(
@ -420,7 +132,7 @@ class FakeTransportProtocol(IotProtocol):
self.discovery_data = info self.discovery_data = info
self.writer = None self.writer = None
self.reader = None self.reader = None
proto = copy.deepcopy(FakeTransportProtocol.baseproto) proto = copy.deepcopy(FakeIotProtocol.baseproto)
for target in info: for target in info:
# print("target %s" % target) # print("target %s" % target)

View File

@ -0,0 +1,125 @@
import warnings
from json import loads as json_loads
from kasa import Credentials, DeviceConfig, SmartDeviceException, SmartProtocol
from kasa.protocol import BaseTransport
class FakeSmartProtocol(SmartProtocol):
def __init__(self, info):
super().__init__(
transport=FakeSmartTransport(info),
)
async def query(self, request, retry_count: int = 3):
"""Implement query here so can still patch SmartProtocol.query."""
resp_dict = await self._query(request, retry_count)
return resp_dict
class FakeSmartTransport(BaseTransport):
def __init__(self, info):
super().__init__(
config=DeviceConfig(
"127.0.0.123",
credentials=Credentials(
username="dummy_user",
password="dummy_password", # noqa: S106
),
),
)
self.info = info
self.components = {
comp["id"]: comp["ver_code"]
for comp in self.info["component_nego"]["component_list"]
}
@property
def default_port(self):
"""Default port for the transport."""
return 80
@property
def credentials_hash(self):
"""The hashed credentials used by the transport."""
return self._credentials.username + self._credentials.password + "hash"
FIXTURE_MISSING_MAP = {
"get_wireless_scan_info": ("wireless", {"ap_list": [], "wep_supported": False}),
}
async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]
params = request_dict["params"]
if method == "multipleRequest":
responses = []
for request in params["requests"]:
response = self._send_request(request) # type: ignore[arg-type]
response["method"] = request["method"] # type: ignore[index]
responses.append(response)
return {"result": {"responses": responses}, "error_code": 0}
else:
return self._send_request(request_dict)
def _handle_control_child(self, params: dict):
"""Handle control_child command."""
device_id = params.get("device_id")
request_data = params.get("requestData", {})
child_method = request_data.get("method")
child_params = request_data.get("params")
info = self.info
children = info["get_child_device_list"]["child_device_list"]
for child in children:
if child["device_id"] == device_id:
info = child
break
# We only support get & set device info for now.
if child_method == "get_device_info":
return {"result": info, "error_code": 0}
elif child_method == "set_device_info":
info.update(child_params)
return {"error_code": 0}
raise NotImplementedError(
"Method %s not implemented for children" % child_method
)
def _send_request(self, request_dict: dict):
method = request_dict["method"]
params = request_dict["params"]
info = self.info
if method == "control_child":
return self._handle_control_child(params)
elif method == "component_nego" or method[:4] == "get_":
if method in info:
return {"result": info[method], "error_code": 0}
elif (
missing_result := self.FIXTURE_MISSING_MAP.get(method)
) and missing_result[0] in self.components:
warnings.warn(
UserWarning(
f"Fixture missing expected method {method}, try to regenerate"
),
stacklevel=1,
)
return {"result": missing_result[1], "error_code": 0}
else:
raise SmartDeviceException(f"Fixture doesn't support {method}")
elif method == "set_qs_info":
return {"error_code": 0}
elif method[:4] == "set_":
target_method = f"get_{method[4:]}"
info[target_method].update(params)
return {"error_code": 0}
async def close(self) -> None:
pass
async def reset(self) -> None:
pass

View File

@ -1,4 +1,15 @@
import pytest import pytest
from voluptuous import (
REMOVE_EXTRA,
All,
Any,
Boolean,
Coerce, # type: ignore
Invalid,
Optional,
Range,
Schema,
)
from kasa import DeviceType, SmartBulb, SmartBulbPreset, SmartDeviceException from kasa import DeviceType, SmartBulb, SmartBulbPreset, SmartDeviceException
@ -16,13 +27,13 @@ from .conftest import (
variable_temp, variable_temp,
variable_temp_iot, variable_temp_iot,
) )
from .newfakes import BULB_SCHEMA, LIGHT_STATE_SCHEMA from .test_smartdevice import SYSINFO_SCHEMA
@bulb @bulb
async def test_bulb_sysinfo(dev: SmartBulb): async def test_bulb_sysinfo(dev: SmartBulb):
assert dev.sys_info is not None assert dev.sys_info is not None
BULB_SCHEMA(dev.sys_info) SYSINFO_SCHEMA_BULB(dev.sys_info)
assert dev.model is not None assert dev.model is not None
@ -316,3 +327,49 @@ async def test_modify_preset_payloads(dev: SmartBulb, preset, payload, mocker):
query_helper = mocker.patch("kasa.SmartBulb._query_helper") query_helper = mocker.patch("kasa.SmartBulb._query_helper")
await dev.save_preset(preset) await dev.save_preset(preset)
query_helper.assert_called_with(dev.LIGHT_SERVICE, "set_preferred_state", payload) query_helper.assert_called_with(dev.LIGHT_SERVICE, "set_preferred_state", payload)
LIGHT_STATE_SCHEMA = Schema(
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": int,
"hue": All(int, Range(min=0, max=360)),
"mode": str,
"on_off": Boolean,
"saturation": All(int, Range(min=0, max=100)),
"dft_on_state": Optional(
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": All(int, Range(min=0, max=9000)),
"hue": All(int, Range(min=0, max=360)),
"mode": str,
"saturation": All(int, Range(min=0, max=100)),
}
),
"err_code": int,
}
)
SYSINFO_SCHEMA_BULB = SYSINFO_SCHEMA.extend(
{
"ctrl_protocols": Optional(dict),
"description": Optional(str), # Seen on LBxxx, similar to dev_name
"dev_state": str,
"disco_ver": str,
"heapsize": int,
"is_color": Boolean,
"is_dimmable": Boolean,
"is_factory": Boolean,
"is_variable_color_temp": Boolean,
"light_state": LIGHT_STATE_SCHEMA,
"preferred_state": [
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": int,
"hue": All(int, Range(min=0, max=360)),
"index": int,
"saturation": All(int, Range(min=0, max=100)),
}
],
}
)

View File

@ -0,0 +1,31 @@
from kasa.smartprotocol import _ChildProtocolWrapper
from kasa.tapo import ChildDevice
from .conftest import strip_smart
@strip_smart
def test_childdevice_init(dev, dummy_protocol, mocker):
"""Test that child devices get initialized and use protocol wrapper."""
assert len(dev.children) > 0
assert dev.is_strip
first = dev.children[0]
assert isinstance(first.protocol, _ChildProtocolWrapper)
assert first._info["category"] == "plug.powerstrip.sub-plug"
assert "position" in first._info
@strip_smart
async def test_childdevice_update(dev, dummy_protocol, mocker):
"""Test that parent update updates children."""
assert len(dev.children) > 0
first = dev.children[0]
child_update = mocker.patch.object(first, "update")
await dev.update()
child_update.assert_called()
assert dev._last_update != first._last_update
assert dev._last_update["child_info"]["child_device_list"][0] == first._last_update

View File

@ -2,12 +2,38 @@ import datetime
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from voluptuous import (
REMOVE_EXTRA,
All,
Any,
Coerce, # type: ignore
Invalid,
Optional,
Range,
Schema,
)
from kasa import EmeterStatus, SmartDeviceException from kasa import EmeterStatus, SmartDeviceException
from kasa.modules.emeter import Emeter from kasa.modules.emeter import Emeter
from .conftest import has_emeter, has_emeter_iot, no_emeter from .conftest import has_emeter, has_emeter_iot, no_emeter
from .newfakes import CURRENT_CONSUMPTION_SCHEMA
CURRENT_CONSUMPTION_SCHEMA = Schema(
Any(
{
"voltage": Any(All(float, Range(min=0, max=300)), None),
"power": Any(Coerce(float, Range(min=0)), None),
"total": Any(Coerce(float, Range(min=0)), None),
"current": Any(All(float, Range(min=0)), None),
"voltage_mv": Any(All(float, Range(min=0, max=300000)), int, None),
"power_mw": Any(Coerce(float, Range(min=0)), None),
"total_wh": Any(Coerce(float, Range(min=0)), None),
"current_ma": Any(All(float, Range(min=0)), int, None),
"slot_id": Any(Coerce(int, Range(min=0)), None),
},
None,
)
)
@no_emeter @no_emeter

View File

@ -1,13 +1,17 @@
from kasa import DeviceType from kasa import DeviceType
from .conftest import plug, plug_smart from .conftest import plug, plug_smart
from .newfakes import PLUG_SCHEMA from .test_smartdevice import SYSINFO_SCHEMA
# these schemas should go to the mainlib as
# they can be useful when adding support for new features/devices
# as well as to check that faked devices are operating properly.
@plug @plug
async def test_plug_sysinfo(dev): async def test_plug_sysinfo(dev):
assert dev.sys_info is not None assert dev.sys_info is not None
PLUG_SCHEMA(dev.sys_info) SYSINFO_SCHEMA(dev.sys_info)
assert dev.model is not None assert dev.model is not None

View File

@ -1,14 +1,26 @@
import inspect import inspect
import re
from datetime import datetime from datetime import datetime
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from voluptuous import (
REMOVE_EXTRA,
All,
Any,
Boolean,
In,
Invalid,
Optional,
Range,
Schema,
)
import kasa import kasa
from kasa import Credentials, DeviceConfig, SmartDevice, SmartDeviceException from kasa import Credentials, DeviceConfig, SmartDevice, SmartDeviceException
from .conftest import device_iot, handle_turn_on, has_emeter_iot, no_emeter_iot, turn_on from .conftest import device_iot, handle_turn_on, has_emeter_iot, no_emeter_iot, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol from .fakeprotocol_iot import FakeIotProtocol
# List of all SmartXXX classes including the SmartDevice base class # List of all SmartXXX classes including the SmartDevice base class
smart_device_classes = [ smart_device_classes = [
@ -30,7 +42,7 @@ async def test_state_info(dev):
@device_iot @device_iot
async def test_invalid_connection(dev): async def test_invalid_connection(dev):
with patch.object( with patch.object(
FakeTransportProtocol, "query", side_effect=SmartDeviceException FakeIotProtocol, "query", side_effect=SmartDeviceException
), pytest.raises(SmartDeviceException): ), pytest.raises(SmartDeviceException):
await dev.update() await dev.update()
@ -133,22 +145,22 @@ async def test_timezone(dev):
@device_iot @device_iot
async def test_hw_info(dev): async def test_hw_info(dev):
PLUG_SCHEMA(dev.hw_info) SYSINFO_SCHEMA(dev.hw_info)
@device_iot @device_iot
async def test_location(dev): async def test_location(dev):
PLUG_SCHEMA(dev.location) SYSINFO_SCHEMA(dev.location)
@device_iot @device_iot
async def test_rssi(dev): async def test_rssi(dev):
PLUG_SCHEMA({"rssi": dev.rssi}) # wrapping for vol SYSINFO_SCHEMA({"rssi": dev.rssi}) # wrapping for vol
@device_iot @device_iot
async def test_mac(dev): async def test_mac(dev):
PLUG_SCHEMA({"mac": dev.mac}) # wrapping for val SYSINFO_SCHEMA({"mac": dev.mac}) # wrapping for val
@device_iot @device_iot
@ -263,3 +275,63 @@ async def test_modules_not_supported(dev: SmartDevice):
await dev.update() await dev.update()
for module in dev.modules.values(): for module in dev.modules.values():
assert module.is_supported is not None assert module.is_supported is not None
def check_mac(x):
if re.match("[0-9a-f]{2}([-:])[0-9a-f]{2}(\\1[0-9a-f]{2}){4}$", x.lower()):
return x
raise Invalid(x)
TZ_SCHEMA = Schema(
{"zone_str": str, "dst_offset": int, "index": All(int, Range(min=0)), "tz_str": str}
)
SYSINFO_SCHEMA = Schema(
{
"active_mode": In(["schedule", "none", "count_down"]),
"alias": str,
"dev_name": str,
"deviceId": str,
"feature": str,
"fwId": str,
"hwId": str,
"hw_ver": str,
"icon_hash": str,
"led_off": Boolean,
"latitude": Any(All(float, Range(min=-90, max=90)), 0, None),
"latitude_i": Any(
All(int, Range(min=-900000, max=900000)),
All(float, Range(min=-900000, max=900000)),
0,
None,
),
"longitude": Any(All(float, Range(min=-180, max=180)), 0, None),
"longitude_i": Any(
All(int, Range(min=-18000000, max=18000000)),
All(float, Range(min=-18000000, max=18000000)),
0,
None,
),
"mac": check_mac,
"model": str,
"oemId": str,
"on_time": int,
"relay_state": int,
"rssi": Any(int, None), # rssi can also be positive, see #54
"sw_ver": str,
"type": str,
"mic_type": str,
"updating": Boolean,
# these are available on hs220
"brightness": int,
"preferred_state": [
{"brightness": All(int, Range(min=0, max=100)), "index": int}
],
"next_action": {"type": int},
"child_num": Optional(Any(None, int)),
"children": Optional(list),
},
extra=REMOVE_EXTRA,
)