import logging import pytest import pytest_mock from ..exceptions import ( SMART_RETRYABLE_ERRORS, DeviceError, KasaException, SmartErrorCode, ) from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper from .fakeprotocol_smart import FakeSmartTransport DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} DUMMY_MULTIPLE_QUERY = { "foobar": {"foo": "bar", "bar": "foo"}, "barfoo": {"foo": "bar", "bar": "foo"}, } ERRORS = [e for e in SmartErrorCode if e != 0] async def test_smart_queries(dummy_protocol, mocker: pytest_mock.MockerFixture): mock_response = {"result": {"great": "success"}, "error_code": 0} mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response) # test sending a method name as a string resp = await dummy_protocol.query("foobar") assert "foobar" in resp assert resp["foobar"] == mock_response["result"] # test sending a method name as a dict resp = await dummy_protocol.query(DUMMY_QUERY) assert "foobar" in resp assert resp["foobar"] == mock_response["result"] @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) async def test_smart_device_errors(dummy_protocol, mocker, error_code): mock_response = {"result": {"great": "success"}, "error_code": error_code.value} send_mock = mocker.patch.object( dummy_protocol._transport, "send", return_value=mock_response ) with pytest.raises(KasaException): await dummy_protocol.query(DUMMY_QUERY, retry_count=2) expected_calls = 3 if error_code in SMART_RETRYABLE_ERRORS else 1 assert send_mock.call_count == expected_calls @pytest.mark.parametrize("error_code", [-13333, 13333]) async def test_smart_device_unknown_errors( dummy_protocol, mocker, error_code, caplog: pytest.LogCaptureFixture ): """Test handling of unknown error codes.""" mock_response = {"result": {"great": "success"}, "error_code": error_code} send_mock = mocker.patch.object( dummy_protocol._transport, "send", return_value=mock_response ) with pytest.raises(KasaException): res = await dummy_protocol.query(DUMMY_QUERY) assert res is SmartErrorCode.INTERNAL_UNKNOWN_ERROR send_mock.assert_called_once() assert f"Received unknown error code: {error_code}" in caplog.text @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) async def test_smart_device_errors_in_multiple_request( dummy_protocol, mocker, error_code ): mock_request = { "foobar1": {"foo": "bar", "bar": "foo"}, "foobar2": {"foo": "bar", "bar": "foo"}, "foobar3": {"foo": "bar", "bar": "foo"}, } mock_response = { "result": { "responses": [ {"method": "foobar1", "result": {"great": "success"}, "error_code": 0}, { "method": "foobar2", "result": {"great": "success"}, "error_code": error_code.value, }, {"method": "foobar3", "result": {"great": "success"}, "error_code": 0}, ] }, "error_code": 0, } send_mock = mocker.patch.object( dummy_protocol._transport, "send", return_value=mock_response ) resp_dict = await dummy_protocol.query(mock_request, retry_count=2) assert resp_dict["foobar2"] == error_code assert send_mock.call_count == 1 assert len(resp_dict) == len(mock_request) @pytest.mark.parametrize("request_size", [1, 3, 5, 10]) @pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5]) async def test_smart_device_multiple_request( dummy_protocol, mocker, request_size, batch_size ): requests = {} mock_response = { "result": {"responses": []}, "error_code": 0, } for i in range(request_size): method = f"get_method_{i}" requests[method] = {"foo": "bar", "bar": "foo"} mock_response["result"]["responses"].append( {"method": method, "result": {"great": "success"}, "error_code": 0} ) send_mock = mocker.patch.object( dummy_protocol._transport, "send", return_value=mock_response ) dummy_protocol._multi_request_batch_size = batch_size await dummy_protocol.query(requests, retry_count=0) expected_count = int(request_size / batch_size) + (request_size % batch_size > 0) assert send_mock.call_count == expected_count async def test_smart_device_multiple_request_json_decode_failure( dummy_protocol, mocker ): """Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR.""" requests = {} mock_responses = [] mock_json_error = { "result": {"responses": []}, "error_code": SmartErrorCode.JSON_DECODE_FAIL_ERROR.value, } for i in range(10): method = f"get_method_{i}" requests[method] = {"foo": "bar", "bar": "foo"} mock_responses.append( {"method": method, "result": {"great": "success"}, "error_code": 0} ) send_mock = mocker.patch.object( dummy_protocol._transport, "send", side_effect=[mock_json_error, *mock_responses], ) dummy_protocol._multi_request_batch_size = 5 assert dummy_protocol._multi_request_batch_size == 5 await dummy_protocol.query(requests, retry_count=1) assert dummy_protocol._multi_request_batch_size == 1 # Call count should be the first error + number of requests assert send_mock.call_count == len(requests) + 1 async def test_smart_device_multiple_request_json_decode_failure_twice( dummy_protocol, mocker ): """Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR.""" requests = {} mock_json_error = { "result": {"responses": []}, "error_code": SmartErrorCode.JSON_DECODE_FAIL_ERROR.value, } for i in range(10): method = f"get_method_{i}" requests[method] = {"foo": "bar", "bar": "foo"} send_mock = mocker.patch.object( dummy_protocol._transport, "send", side_effect=[mock_json_error, KasaException], ) dummy_protocol._multi_request_batch_size = 5 with pytest.raises(KasaException): await dummy_protocol.query(requests, retry_count=1) assert dummy_protocol._multi_request_batch_size == 1 assert send_mock.call_count == 2 async def test_smart_device_multiple_request_non_json_decode_failure( dummy_protocol, mocker ): """Test the logic to disable multiple requests on JSON_DECODE_FAIL_ERROR. Ensure other exception types behave as expected. """ requests = {} mock_json_error = { "result": {"responses": []}, "error_code": SmartErrorCode.UNKNOWN_METHOD_ERROR.value, } for i in range(10): method = f"get_method_{i}" requests[method] = {"foo": "bar", "bar": "foo"} send_mock = mocker.patch.object( dummy_protocol._transport, "send", side_effect=[mock_json_error, KasaException], ) dummy_protocol._multi_request_batch_size = 5 with pytest.raises(DeviceError): await dummy_protocol.query(requests, retry_count=1) assert dummy_protocol._multi_request_batch_size == 5 assert send_mock.call_count == 1 async def test_childdevicewrapper_unwrapping(dummy_protocol, mocker): """Test that responseData gets unwrapped correctly.""" wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol) mock_response = {"error_code": 0, "result": {"responseData": {"error_code": 0}}} mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response) res = await wrapped_protocol.query(DUMMY_QUERY) assert res == {"foobar": None} async def test_childdevicewrapper_unwrapping_with_payload(dummy_protocol, mocker): wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol) mock_response = { "error_code": 0, "result": {"responseData": {"error_code": 0, "result": {"bar": "bar"}}}, } mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response) res = await wrapped_protocol.query(DUMMY_QUERY) assert res == {"foobar": {"bar": "bar"}} async def test_childdevicewrapper_error(dummy_protocol, mocker): """Test that errors inside the responseData payload cause an exception.""" wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol) mock_response = {"error_code": 0, "result": {"responseData": {"error_code": -1001}}} mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response) with pytest.raises(KasaException): await wrapped_protocol.query(DUMMY_QUERY) async def test_childdevicewrapper_unwrapping_multiplerequest(dummy_protocol, mocker): """Test that unwrapping multiplerequest works correctly.""" mock_response = { "error_code": 0, "result": { "responseData": { "result": { "responses": [ { "error_code": 0, "method": "get_device_info", "result": {"foo": "bar"}, }, { "error_code": 0, "method": "second_command", "result": {"bar": "foo"}, }, ] } } }, } wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol) mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response) resp = await wrapped_protocol.query(DUMMY_QUERY) assert resp == {"get_device_info": {"foo": "bar"}, "second_command": {"bar": "foo"}} async def test_childdevicewrapper_multiplerequest_error(dummy_protocol, mocker): """Test that errors inside multipleRequest response of responseData raise an exception.""" mock_response = { "error_code": 0, "result": { "responseData": { "result": { "responses": [ { "error_code": 0, "method": "get_device_info", "result": {"foo": "bar"}, }, {"error_code": -1001, "method": "invalid_command"}, ] } } }, } wrapped_protocol = _ChildProtocolWrapper("dummyid", dummy_protocol) mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response) res = await wrapped_protocol.query(DUMMY_QUERY) assert res["get_device_info"] == {"foo": "bar"} assert res["invalid_command"] == SmartErrorCode(-1001) @pytest.mark.parametrize("list_sum", [5, 10, 30]) @pytest.mark.parametrize("batch_size", [1, 2, 3, 50]) async def test_smart_protocol_lists_single_request(mocker, list_sum, batch_size): child_device_list = [{"foo": i} for i in range(list_sum)] response = { "get_child_device_list": { "child_device_list": child_device_list, "start_index": 0, "sum": list_sum, } } request = {"get_child_device_list": None} ft = FakeSmartTransport( response, "foobar", list_return_size=batch_size, component_nego_not_included=True, ) protocol = SmartProtocol(transport=ft) query_spy = mocker.spy(protocol, "_execute_query") resp = await protocol.query(request) expected_count = int(list_sum / batch_size) + (1 if list_sum % batch_size else 0) assert query_spy.call_count == expected_count assert resp == response @pytest.mark.parametrize("list_sum", [5, 10, 30]) @pytest.mark.parametrize("batch_size", [1, 2, 3, 50]) async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_size): child_list = [{"foo": i} for i in range(list_sum)] response = { "get_child_device_list": { "child_device_list": child_list, "start_index": 0, "sum": list_sum, }, "get_child_device_component_list": { "child_component_list": child_list, "start_index": 0, "sum": list_sum, }, } request = {"get_child_device_list": None, "get_child_device_component_list": None} ft = FakeSmartTransport( response, "foobar", list_return_size=batch_size, component_nego_not_included=True, ) protocol = SmartProtocol(transport=ft) query_spy = mocker.spy(protocol, "_execute_query") resp = await protocol.query(request) expected_count = 1 + 2 * ( int(list_sum / batch_size) + (0 if list_sum % batch_size else -1) ) assert query_spy.call_count == expected_count assert resp == response async def test_incomplete_list(mocker, caplog): """Test for handling incomplete lists returned from queries.""" info = { "get_preset_rules": { "start_index": 0, "states": [ { "brightness": 50, }, { "brightness": 100, }, ], "sum": 7, } } caplog.set_level(logging.ERROR) transport = FakeSmartTransport( info, "dummy-name", component_nego_not_included=True, warn_fixture_missing_methods=False, ) protocol = SmartProtocol(transport=transport) resp = await protocol.query({"get_preset_rules": None}) assert resp assert resp["get_preset_rules"]["sum"] == 2 # FakeTransport fixes sum assert caplog.text == "" # Test behaviour without FakeTranport fix transport = FakeSmartTransport( info, "dummy-name", component_nego_not_included=True, warn_fixture_missing_methods=False, fix_incomplete_fixture_lists=False, ) protocol = SmartProtocol(transport=transport) resp = await protocol.query({"get_preset_rules": None}) assert resp["get_preset_rules"]["sum"] == 7 assert ( "Device 127.0.0.123 returned empty results list for method get_preset_rules" in caplog.text )