mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-23 11:43:34 +00:00
fcb604e435
* Transport tests under tests/transports/ * Protocol tests under tests/protocols/ * IOT tests under iot/ * Plus some minor cleanups, most code changes are related to splitting up smart & iot tests
451 lines
15 KiB
Python
451 lines
15 KiB
Python
import logging
|
|
|
|
import pytest
|
|
import pytest_mock
|
|
|
|
from kasa.exceptions import (
|
|
SMART_RETRYABLE_ERRORS,
|
|
DeviceError,
|
|
KasaException,
|
|
SmartErrorCode,
|
|
)
|
|
from kasa.protocols.smartprotocol import SmartProtocol, _ChildProtocolWrapper
|
|
from kasa.smart import SmartDevice
|
|
|
|
from ..conftest import device_smart
|
|
from ..fakeprotocol_smart import FakeSmartTransport
|
|
|
|
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
|
DUMMY_MULTIPLE_QUERY = {
|
|
"foobar": {"foo": "bar", "bar": "foo"},
|
|
"barfoo": {"foo": "bar", "bar": "foo"},
|
|
}
|
|
ERRORS = [e for e in SmartErrorCode if e != 0]
|
|
|
|
|
|
async def test_smart_queries(dummy_protocol, mocker: pytest_mock.MockerFixture):
|
|
mock_response = {"result": {"great": "success"}, "error_code": 0}
|
|
|
|
mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response)
|
|
# test sending a method name as a string
|
|
resp = await dummy_protocol.query("foobar")
|
|
assert "foobar" in resp
|
|
assert resp["foobar"] == mock_response["result"]
|
|
|
|
# test sending a method name as a dict
|
|
resp = await dummy_protocol.query(DUMMY_QUERY)
|
|
assert "foobar" in resp
|
|
assert resp["foobar"] == mock_response["result"]
|
|
|
|
|
|
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
|
|
async def test_smart_device_errors(dummy_protocol, mocker, error_code):
|
|
mock_response = {"result": {"great": "success"}, "error_code": error_code.value}
|
|
|
|
send_mock = mocker.patch.object(
|
|
dummy_protocol._transport, "send", return_value=mock_response
|
|
)
|
|
|
|
with pytest.raises(KasaException):
|
|
await dummy_protocol.query(DUMMY_QUERY, retry_count=2)
|
|
|
|
expected_calls = 3 if error_code in SMART_RETRYABLE_ERRORS else 1
|
|
assert send_mock.call_count == expected_calls
|
|
|
|
|
|
@pytest.mark.parametrize("error_code", [-13333, 13333])
|
|
@pytest.mark.xdist_group(name="caplog")
|
|
async def test_smart_device_unknown_errors(
|
|
dummy_protocol, mocker, error_code, caplog: pytest.LogCaptureFixture
|
|
):
|
|
"""Test handling of unknown error codes."""
|
|
mock_response = {"result": {"great": "success"}, "error_code": error_code}
|
|
|
|
send_mock = mocker.patch.object(
|
|
dummy_protocol._transport, "send", return_value=mock_response
|
|
)
|
|
|
|
with pytest.raises(KasaException): # noqa: PT012
|
|
res = await dummy_protocol.query(DUMMY_QUERY)
|
|
assert res is SmartErrorCode.INTERNAL_UNKNOWN_ERROR
|
|
|
|
send_mock.assert_called_once()
|
|
assert f"received unknown error code: {error_code}" in caplog.text
|
|
|
|
|
|
@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
|
|
async def test_smart_device_errors_in_multiple_request(
|
|
dummy_protocol, mocker, error_code
|
|
):
|
|
mock_request = {
|
|
"foobar1": {"foo": "bar", "bar": "foo"},
|
|
"foobar2": {"foo": "bar", "bar": "foo"},
|
|
"foobar3": {"foo": "bar", "bar": "foo"},
|
|
}
|
|
mock_response = {
|
|
"result": {
|
|
"responses": [
|
|
{"method": "foobar1", "result": {"great": "success"}, "error_code": 0},
|
|
{
|
|
"method": "foobar2",
|
|
"result": {"great": "success"},
|
|
"error_code": error_code.value,
|
|
},
|
|
{"method": "foobar3", "result": {"great": "success"}, "error_code": 0},
|
|
]
|
|
},
|
|
"error_code": 0,
|
|
}
|
|
|
|
send_mock = mocker.patch.object(
|
|
dummy_protocol._transport, "send", return_value=mock_response
|
|
)
|
|
|
|
resp_dict = await dummy_protocol.query(mock_request, retry_count=2)
|
|
assert resp_dict["foobar2"] == error_code
|
|
assert send_mock.call_count == 1
|
|
assert len(resp_dict) == len(mock_request)
|
|
|
|
|
|
@pytest.mark.parametrize("request_size", [1, 3, 5, 10])
|
|
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5])
|
|
async def test_smart_device_multiple_request(
|
|
dummy_protocol, mocker, request_size, batch_size
|
|
):
|
|
requests = {}
|
|
mock_response = {
|
|
"result": {"responses": []},
|
|
"error_code": 0,
|
|
}
|
|
for i in range(request_size):
|
|
method = f"get_method_{i}"
|
|
requests[method] = {"foo": "bar", "bar": "foo"}
|
|
mock_response["result"]["responses"].append(
|
|
{"method": method, "result": {"great": "success"}, "error_code": 0}
|
|
)
|
|
|
|
send_mock = mocker.patch.object(
|
|
dummy_protocol._transport, "send", return_value=mock_response
|
|
)
|
|
dummy_protocol._multi_request_batch_size = batch_size
|
|
|
|
await dummy_protocol.query(requests, retry_count=0)
|
|
expected_count = int(request_size / batch_size) + (request_size % batch_size > 0)
|
|
assert send_mock.call_count == expected_count
|
|
|
|
|
|
async def test_smart_device_multiple_request_json_decode_failure(
|
|
dummy_protocol, mocker
|
|
):
|
|
"""Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR."""
|
|
requests = {}
|
|
mock_responses = []
|
|
|
|
mock_json_error = {
|
|
"result": {"responses": []},
|
|
"error_code": SmartErrorCode.JSON_DECODE_FAIL_ERROR.value,
|
|
}
|
|
for i in range(10):
|
|
method = f"get_method_{i}"
|
|
requests[method] = {"foo": "bar", "bar": "foo"}
|
|
mock_responses.append(
|
|
{"method": method, "result": {"great": "success"}, "error_code": 0}
|
|
)
|
|
|
|
send_mock = mocker.patch.object(
|
|
dummy_protocol._transport,
|
|
"send",
|
|
side_effect=[mock_json_error, *mock_responses],
|
|
)
|
|
dummy_protocol._multi_request_batch_size = 5
|
|
assert dummy_protocol._multi_request_batch_size == 5
|
|
await dummy_protocol.query(requests, retry_count=1)
|
|
assert dummy_protocol._multi_request_batch_size == 1
|
|
# Call count should be the first error + number of requests
|
|
assert send_mock.call_count == len(requests) + 1
|
|
|
|
|
|
async def test_smart_device_multiple_request_json_decode_failure_twice(
|
|
dummy_protocol, mocker
|
|
):
|
|
"""Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR."""
|
|
requests = {}
|
|
|
|
mock_json_error = {
|
|
"result": {"responses": []},
|
|
"error_code": SmartErrorCode.JSON_DECODE_FAIL_ERROR.value,
|
|
}
|
|
for i in range(10):
|
|
method = f"get_method_{i}"
|
|
requests[method] = {"foo": "bar", "bar": "foo"}
|
|
|
|
send_mock = mocker.patch.object(
|
|
dummy_protocol._transport,
|
|
"send",
|
|
side_effect=[mock_json_error, KasaException],
|
|
)
|
|
dummy_protocol._multi_request_batch_size = 5
|
|
with pytest.raises(KasaException):
|
|
await dummy_protocol.query(requests, retry_count=1)
|
|
assert dummy_protocol._multi_request_batch_size == 1
|
|
|
|
assert send_mock.call_count == 2
|
|
|
|
|
|
async def test_smart_device_multiple_request_non_json_decode_failure(
|
|
dummy_protocol, mocker
|
|
):
|
|
"""Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR.
|
|
|
|
Ensure other exception types behave as expected.
|
|
"""
|
|
requests = {}
|
|
|
|
mock_json_error = {
|
|
"result": {"responses": []},
|
|
"error_code": SmartErrorCode.UNKNOWN_METHOD_ERROR.value,
|
|
}
|
|
for i in range(10):
|
|
method = f"get_method_{i}"
|
|
requests[method] = {"foo": "bar", "bar": "foo"}
|
|
|
|
send_mock = mocker.patch.object(
|
|
dummy_protocol._transport,
|
|
"send",
|
|
side_effect=[mock_json_error, KasaException],
|
|
)
|
|
dummy_protocol._multi_request_batch_size = 5
|
|
with pytest.raises(DeviceError):
|
|
await dummy_protocol.query(requests, retry_count=1)
|
|
assert dummy_protocol._multi_request_batch_size == 5
|
|
|
|
assert send_mock.call_count == 1
|
|
|
|
|
|
async def test_childdevicewrapper_unwrapping(dummy_protocol, mocker):
|
|
"""Test that responseData gets unwrapped correctly."""
|
|
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
|
|
mock_response = {"error_code": 0, "result": {"responseData": {"error_code": 0}}}
|
|
|
|
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
|
|
res = await wrapped_protocol.query(DUMMY_QUERY)
|
|
assert res == {"foobar": None}
|
|
|
|
|
|
async def test_childdevicewrapper_unwrapping_with_payload(dummy_protocol, mocker):
|
|
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
|
|
mock_response = {
|
|
"error_code": 0,
|
|
"result": {"responseData": {"error_code": 0, "result": {"bar": "bar"}}},
|
|
}
|
|
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
|
|
res = await wrapped_protocol.query(DUMMY_QUERY)
|
|
assert res == {"foobar": {"bar": "bar"}}
|
|
|
|
|
|
async def test_childdevicewrapper_error(dummy_protocol, mocker):
|
|
"""Test that errors inside the responseData payload cause an exception."""
|
|
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
|
|
mock_response = {"error_code": 0, "result": {"responseData": {"error_code": -1001}}}
|
|
|
|
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
|
|
with pytest.raises(KasaException):
|
|
await wrapped_protocol.query(DUMMY_QUERY)
|
|
|
|
|
|
async def test_childdevicewrapper_unwrapping_multiplerequest(dummy_protocol, mocker):
|
|
"""Test that unwrapping multiplerequest works correctly."""
|
|
mock_response = {
|
|
"error_code": 0,
|
|
"result": {
|
|
"responseData": {
|
|
"result": {
|
|
"responses": [
|
|
{
|
|
"error_code": 0,
|
|
"method": "get_device_info",
|
|
"result": {"foo": "bar"},
|
|
},
|
|
{
|
|
"error_code": 0,
|
|
"method": "second_command",
|
|
"result": {"bar": "foo"},
|
|
},
|
|
]
|
|
}
|
|
}
|
|
},
|
|
}
|
|
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
|
|
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
|
|
resp = await wrapped_protocol.query(DUMMY_QUERY)
|
|
assert resp == {"get_device_info": {"foo": "bar"}, "second_command": {"bar": "foo"}}
|
|
|
|
|
|
async def test_childdevicewrapper_multiplerequest_error(dummy_protocol, mocker):
|
|
"""Test that errors inside multipleRequest response of responseData raise an exception."""
|
|
mock_response = {
|
|
"error_code": 0,
|
|
"result": {
|
|
"responseData": {
|
|
"result": {
|
|
"responses": [
|
|
{
|
|
"error_code": 0,
|
|
"method": "get_device_info",
|
|
"result": {"foo": "bar"},
|
|
},
|
|
{"error_code": -1001, "method": "invalid_command"},
|
|
]
|
|
}
|
|
}
|
|
},
|
|
}
|
|
wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol)
|
|
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
|
|
res = await wrapped_protocol.query(DUMMY_QUERY)
|
|
assert res["get_device_info"] == {"foo": "bar"}
|
|
assert res["invalid_command"] == SmartErrorCode(-1001)
|
|
|
|
|
|
@pytest.mark.parametrize("list_sum", [5, 10, 30])
|
|
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
|
|
async def test_smart_protocol_lists_single_request(mocker, list_sum, batch_size):
|
|
child_device_list = [{"foo": i} for i in range(list_sum)]
|
|
response = {
|
|
"get_child_device_list": {
|
|
"child_device_list": child_device_list,
|
|
"start_index": 0,
|
|
"sum": list_sum,
|
|
}
|
|
}
|
|
request = {"get_child_device_list": None}
|
|
|
|
ft = FakeSmartTransport(
|
|
response,
|
|
"foobar",
|
|
list_return_size=batch_size,
|
|
component_nego_not_included=True,
|
|
get_child_fixtures=False,
|
|
)
|
|
protocol = SmartProtocol(transport=ft)
|
|
query_spy = mocker.spy(protocol, "_execute_query")
|
|
resp = await protocol.query(request)
|
|
expected_count = int(list_sum / batch_size) + (1 if list_sum % batch_size else 0)
|
|
assert query_spy.call_count == expected_count
|
|
assert resp == response
|
|
|
|
|
|
@pytest.mark.parametrize("list_sum", [5, 10, 30])
|
|
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
|
|
async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_size):
|
|
child_list = [{"foo": i} for i in range(list_sum)]
|
|
response = {
|
|
"get_child_device_list": {
|
|
"child_device_list": child_list,
|
|
"start_index": 0,
|
|
"sum": list_sum,
|
|
},
|
|
"get_child_device_component_list": {
|
|
"child_component_list": child_list,
|
|
"start_index": 0,
|
|
"sum": list_sum,
|
|
},
|
|
}
|
|
request = {"get_child_device_list": None, "get_child_device_component_list": None}
|
|
|
|
ft = FakeSmartTransport(
|
|
response,
|
|
"foobar",
|
|
list_return_size=batch_size,
|
|
component_nego_not_included=True,
|
|
get_child_fixtures=False,
|
|
)
|
|
protocol = SmartProtocol(transport=ft)
|
|
query_spy = mocker.spy(protocol, "_execute_query")
|
|
resp = await protocol.query(request)
|
|
expected_count = 1 + 2 * (
|
|
int(list_sum / batch_size) + (0 if list_sum % batch_size else -1)
|
|
)
|
|
assert query_spy.call_count == expected_count
|
|
assert resp == response
|
|
|
|
|
|
async def test_incomplete_list(mocker, caplog):
|
|
"""Test for handling incomplete lists returned from queries."""
|
|
info = {
|
|
"get_preset_rules": {
|
|
"start_index": 0,
|
|
"states": [
|
|
{
|
|
"brightness": 50,
|
|
},
|
|
{
|
|
"brightness": 100,
|
|
},
|
|
],
|
|
"sum": 7,
|
|
}
|
|
}
|
|
caplog.set_level(logging.ERROR)
|
|
transport = FakeSmartTransport(
|
|
info,
|
|
"dummy-name",
|
|
component_nego_not_included=True,
|
|
warn_fixture_missing_methods=False,
|
|
)
|
|
protocol = SmartProtocol(transport=transport)
|
|
resp = await protocol.query({"get_preset_rules": None})
|
|
assert resp
|
|
assert resp["get_preset_rules"]["sum"] == 2 # FakeTransport fixes sum
|
|
assert caplog.text == ""
|
|
|
|
# Test behaviour without FakeTranport fix
|
|
transport = FakeSmartTransport(
|
|
info,
|
|
"dummy-name",
|
|
component_nego_not_included=True,
|
|
warn_fixture_missing_methods=False,
|
|
fix_incomplete_fixture_lists=False,
|
|
)
|
|
protocol = SmartProtocol(transport=transport)
|
|
resp = await protocol.query({"get_preset_rules": None})
|
|
assert resp["get_preset_rules"]["sum"] == 7
|
|
assert (
|
|
"Device 127.0.0.123 returned empty results list for method get_preset_rules"
|
|
in caplog.text
|
|
)
|
|
|
|
|
|
@device_smart
|
|
@pytest.mark.xdist_group(name="caplog")
|
|
async def test_smart_queries_redaction(
|
|
dev: SmartDevice, caplog: pytest.LogCaptureFixture
|
|
):
|
|
"""Test query sensitive info redaction."""
|
|
if isinstance(dev.protocol._transport, FakeSmartTransport):
|
|
device_id = "123456789ABCDEF"
|
|
dev.protocol._transport.info["get_device_info"]["device_id"] = device_id
|
|
else: # real device
|
|
device_id = dev.device_id
|
|
|
|
# Info no message logging
|
|
caplog.set_level(logging.INFO)
|
|
await dev.update()
|
|
assert device_id not in caplog.text
|
|
|
|
caplog.set_level(logging.DEBUG)
|
|
|
|
# Debug no redaction
|
|
caplog.clear()
|
|
dev.protocol._redact_data = False
|
|
await dev.update()
|
|
assert device_id in caplog.text
|
|
|
|
# Debug redaction
|
|
caplog.clear()
|
|
dev.protocol._redact_data = True
|
|
await dev.update()
|
|
assert device_id not in caplog.text
|
|
assert "REDACTED_" + device_id[9::] in caplog.text
|