diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 8c6e7634..0d47080f 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -140,14 +140,3 @@ def mock_datagram_endpoint(request): # noqa: PT004 side_effect=_create_datagram_endpoint, ): yield - - -# allow mocks to be awaited -# https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression/51399767#51399767 - - -async def async_magic(): - pass - - -MagicMock.__await__ = lambda x: async_magic().__await__() diff --git a/kasa/tests/test_device.py b/kasa/tests/test_device.py index 0aee5b56..f67d37c2 100644 --- a/kasa/tests/test_device.py +++ b/kasa/tests/test_device.py @@ -7,7 +7,7 @@ import inspect import pkgutil import sys from contextlib import AbstractContextManager -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, patch import pytest @@ -85,7 +85,7 @@ async def test_create_device_with_timeout(): async def test_create_thin_wrapper(): """Make sure thin wrapper is created with the correct device type.""" - mock = Mock() + mock = AsyncMock() config = DeviceConfig( host="test_host", port_override=1234, @@ -281,7 +281,7 @@ async def test_device_type_aliases(): """Test that the device type aliases in Device work.""" def _mock_connect(config, *args, **kwargs): - mock = Mock() + mock = AsyncMock() mock.config = config return mock diff --git a/kasa/tests/test_feature.py b/kasa/tests/test_feature.py index 83b7c24c..938f9547 100644 --- a/kasa/tests/test_feature.py +++ b/kasa/tests/test_feature.py @@ -1,6 +1,6 @@ import logging import sys -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest from pytest_mock import MockerFixture @@ -94,7 +94,9 @@ def test_feature_value_callable(dev, dummy_feature: Feature): async def test_feature_setter(dev, mocker, dummy_feature: Feature): """Verify that *set_value* calls the defined method.""" - mock_set_dummy = mocker.patch.object(dummy_feature.device, "set_dummy", create=True) + mock_set_dummy = mocker.patch.object( + dummy_feature.device, "set_dummy", create=True, new_callable=AsyncMock + ) dummy_feature.attribute_setter = "set_dummy" await dummy_feature.set_value("dummy value") mock_set_dummy.assert_called_with("dummy value") @@ -118,7 +120,9 @@ async def test_feature_action(mocker): icon="mdi:dummy", type=Feature.Type.Action, ) - mock_call_action = mocker.patch.object(feat.device, "call_action", create=True) + mock_call_action = mocker.patch.object( + feat.device, "call_action", create=True, new_callable=AsyncMock + ) assert feat.value == "" await feat.set_value(1234) mock_call_action.assert_called() @@ -129,7 +133,9 @@ async def test_feature_choice_list(dummy_feature, caplog, mocker: MockerFixture) dummy_feature.type = Feature.Type.Choice dummy_feature.choices_getter = lambda: ["first", "second"] - mock_setter = mocker.patch.object(dummy_feature.device, "dummysetter", create=True) + mock_setter = mocker.patch.object( + dummy_feature.device, "dummysetter", create=True, new_callable=AsyncMock + ) await dummy_feature.set_value("first") mock_setter.assert_called_with("first") mock_setter.reset_mock() diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index afb953dd..9c15795f 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -9,6 +9,7 @@ import pkgutil import struct import sys from typing import cast +from unittest.mock import AsyncMock import pytest @@ -175,6 +176,7 @@ async def test_protocol_reconnect( writer = mocker.patch("asyncio.StreamWriter") mocker.patch.object(writer, "write", _fail_one_less_than_retry_count) mocker.patch.object(reader, "readexactly", _mock_read) + mocker.patch.object(writer, "drain", new_callable=AsyncMock) return reader, writer config = DeviceConfig("127.0.0.1") @@ -224,6 +226,7 @@ async def test_protocol_handles_cancellation_during_write( writer = mocker.patch("asyncio.StreamWriter") mocker.patch.object(writer, "write", _cancel_first_attempt) mocker.patch.object(reader, "readexactly", _mock_read) + mocker.patch.object(writer, "drain", new_callable=AsyncMock) return reader, writer config = DeviceConfig("127.0.0.1") @@ -275,6 +278,7 @@ async def test_protocol_handles_cancellation_during_connection( reader = mocker.patch("asyncio.StreamReader") writer = mocker.patch("asyncio.StreamWriter") mocker.patch.object(reader, "readexactly", _mock_read) + mocker.patch.object(writer, "drain", new_callable=AsyncMock) return reader, writer config = DeviceConfig("127.0.0.1") @@ -324,6 +328,7 @@ async def test_protocol_logging( reader = mocker.patch("asyncio.StreamReader") writer = mocker.patch("asyncio.StreamWriter") mocker.patch.object(reader, "readexactly", _mock_read) + mocker.patch.object(writer, "drain", new_callable=AsyncMock) return reader, writer config = DeviceConfig("127.0.0.1") @@ -373,6 +378,7 @@ async def test_protocol_custom_port( else: assert port == custom_port mocker.patch.object(reader, "readexactly", _mock_read) + mocker.patch.object(writer, "drain", new_callable=AsyncMock) return reader, writer config = DeviceConfig("127.0.0.1", port_override=custom_port)