diff --git a/kasa/discover.py b/kasa/discover.py index 36d6f277..a943ddd4 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -799,6 +799,47 @@ class Discover: ) from ex return info + @staticmethod + def _get_connection_parameters( + discovery_result: DiscoveryResult, + ) -> DeviceConnectionParameters: + """Get connection parameters from the discovery result.""" + type_ = discovery_result.device_type + if (encrypt_schm := discovery_result.mgt_encrypt_schm) is None: + raise UnsupportedDeviceError( + f"Unsupported device {discovery_result.ip} of type {type_} " + "with no mgt_encrypt_schm", + discovery_result=discovery_result.to_dict(), + host=discovery_result.ip, + ) + + if not (encrypt_type := encrypt_schm.encrypt_type) and ( + encrypt_info := discovery_result.encrypt_info + ): + encrypt_type = encrypt_info.sym_schm + + if not (login_version := encrypt_schm.lv) and ( + et := discovery_result.encrypt_type + ): + # Known encrypt types are ["1","2"] and ["3"] + # Reuse the login_version attribute to pass the max to transport + login_version = max([int(i) for i in et]) + + if not encrypt_type: + raise UnsupportedDeviceError( + f"Unsupported device {discovery_result.ip} of type {type_} " + + "with no encryption type", + discovery_result=discovery_result.to_dict(), + host=discovery_result.ip, + ) + return DeviceConnectionParameters.from_values( + type_, + encrypt_type, + login_version=login_version, + https=encrypt_schm.is_support_https, + http_port=encrypt_schm.http_port, + ) + @staticmethod def _get_device_instance( info: dict, @@ -838,55 +879,22 @@ class Discover: config.host, redact_data(info, NEW_DISCOVERY_REDACTORS), ) - type_ = discovery_result.device_type - if (encrypt_schm := discovery_result.mgt_encrypt_schm) is None: - raise UnsupportedDeviceError( - f"Unsupported device {config.host} of type {type_} " - "with no mgt_encrypt_schm", - discovery_result=discovery_result.to_dict(), - host=config.host, - ) - try: - if not (encrypt_type := encrypt_schm.encrypt_type) and ( - encrypt_info := discovery_result.encrypt_info - ): - encrypt_type = encrypt_info.sym_schm - - if not (login_version := encrypt_schm.lv) and ( - et := discovery_result.encrypt_type - ): - # Known encrypt types are ["1","2"] and ["3"] - # Reuse the login_version attribute to pass the max to transport - login_version = max([int(i) for i in et]) - - if not encrypt_type: - raise UnsupportedDeviceError( - f"Unsupported device {config.host} of type {type_} " - + "with no encryption type", - discovery_result=discovery_result.to_dict(), - host=config.host, - ) - config.connection_type = DeviceConnectionParameters.from_values( - type_, - encrypt_type, - login_version=login_version, - https=encrypt_schm.is_support_https, - http_port=encrypt_schm.http_port, - ) + conn_params = Discover._get_connection_parameters(discovery_result) + config.connection_type = conn_params except KasaException as ex: + if isinstance(ex, UnsupportedDeviceError): + raise raise UnsupportedDeviceError( f"Unsupported device {config.host} of type {type_} " - + f"with encrypt_type {encrypt_schm.encrypt_type}", + + f"with encrypt_scheme {discovery_result.mgt_encrypt_schm}", discovery_result=discovery_result.to_dict(), host=config.host, ) from ex if ( - device_class := get_device_class_from_family( - type_, https=encrypt_schm.is_support_https - ) + device_class := get_device_class_from_family(type_, https=conn_params.https) ) is None: _LOGGER.debug("Got unsupported device type: %s", type_) raise UnsupportedDeviceError( diff --git a/tests/discovery_fixtures.py b/tests/discovery_fixtures.py index 2db79e91..3cf726f4 100644 --- a/tests/discovery_fixtures.py +++ b/tests/discovery_fixtures.py @@ -1,6 +1,8 @@ from __future__ import annotations +import asyncio import copy +from collections.abc import Coroutine from dataclasses import dataclass from json import dumps as json_dumps from typing import Any, TypedDict @@ -34,7 +36,7 @@ UNSUPPORTED_HOMEWIFISYSTEM = { "group_id": "REDACTED_07d902da02fa9beab8a64", "group_name": "I01BU0tFRF9TU0lEIw==", # '#MASKED_SSID#' "hardware_version": "3.0", - "ip": "192.168.1.192", + "ip": "127.0.0.1", "mac": "24:2F:D0:00:00:00", "master_device_id": "REDACTED_51f72a752213a6c45203530", "need_account_digest": True, @@ -134,7 +136,9 @@ smart_discovery = parametrize_discovery("smart discovery", protocol_filter={"SMA @pytest.fixture( - params=filter_fixtures("discoverable", protocol_filter={"SMART", "IOT"}), + params=filter_fixtures( + "discoverable", protocol_filter={"SMART", "SMARTCAM", "IOT"} + ), ids=idgenerator, ) async def discovery_mock(request, mocker): @@ -251,12 +255,46 @@ def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker): first_ip = list(fixture_infos.keys())[0] first_host = None + # Mock _run_callback_task so the tasks complete in the order they started. + # Otherwise test output is non-deterministic which affects readme examples. + callback_queue: asyncio.Queue = asyncio.Queue() + exception_queue: asyncio.Queue = asyncio.Queue() + + async def process_callback_queue(finished_event: asyncio.Event) -> None: + while (finished_event.is_set() is False) or callback_queue.qsize(): + coro = await callback_queue.get() + try: + await coro + except Exception as ex: + await exception_queue.put(ex) + else: + await exception_queue.put(None) + callback_queue.task_done() + + async def wait_for_coro(): + await callback_queue.join() + if ex := exception_queue.get_nowait(): + raise ex + + def _run_callback_task(self, coro: Coroutine) -> None: + callback_queue.put_nowait(coro) + task = asyncio.create_task(wait_for_coro()) + self.callback_tasks.append(task) + + mocker.patch( + "kasa.discover._DiscoverProtocol._run_callback_task", _run_callback_task + ) + + # do_discover_mock async def mock_discover(self): """Call datagram_received for all mock fixtures. Handles test cases modifying the ip and hostname of the first fixture for discover_single testing. """ + finished_event = asyncio.Event() + asyncio.create_task(process_callback_queue(finished_event)) + for ip, dm in discovery_mocks.items(): first_ip = list(discovery_mocks.values())[0].ip fixture_info = fixture_infos[ip] @@ -283,10 +321,18 @@ def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker): dm._datagram, (dm.ip, port), ) + # Setting this event will stop the processing of callbacks + finished_event.set() + mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover) + + # query_mock async def _query(self, request, retry_count: int = 3): return await protos[self._host].query(request) + mocker.patch("kasa.IotProtocol.query", _query) + mocker.patch("kasa.SmartProtocol.query", _query) + def _getaddrinfo(host, *_, **__): nonlocal first_host, first_ip first_host = host # Store the hostname used by discover single @@ -295,20 +341,21 @@ def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker): ].ip # ip could have been overridden in test return [(None, None, None, None, (first_ip, 0))] - mocker.patch("kasa.IotProtocol.query", _query) - mocker.patch("kasa.SmartProtocol.query", _query) - mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover) - mocker.patch( - "socket.getaddrinfo", - # side_effect=lambda *_, **__: [(None, None, None, None, (first_ip, 0))], - side_effect=_getaddrinfo, - ) + mocker.patch("socket.getaddrinfo", side_effect=_getaddrinfo) + + # Mock decrypt so it doesn't error with unencryptable empty data in the + # fixtures. The discovery result will already contain the decrypted data + # deserialized from the fixture + mocker.patch("kasa.discover.Discover._decrypt_discovery_data") + # Only return the first discovery mock to be used for testing discover single return discovery_mocks[first_ip] @pytest.fixture( - params=filter_fixtures("discoverable", protocol_filter={"SMART", "IOT"}), + params=filter_fixtures( + "discoverable", protocol_filter={"SMART", "SMARTCAM", "IOT"} + ), ids=idgenerator, ) def discovery_data(request, mocker): diff --git a/tests/test_device_factory.py b/tests/test_device_factory.py index 539609c3..19ccfb73 100644 --- a/tests/test_device_factory.py +++ b/tests/test_device_factory.py @@ -60,13 +60,7 @@ def _get_connection_type_device_class(discovery_info): device_class = Discover._get_device_class(discovery_info) dr = DiscoveryResult.from_dict(discovery_info["result"]) - connection_type = DeviceConnectionParameters.from_values( - dr.device_type, - dr.mgt_encrypt_schm.encrypt_type, - login_version=dr.mgt_encrypt_schm.lv, - https=dr.mgt_encrypt_schm.is_support_https, - http_port=dr.mgt_encrypt_schm.http_port, - ) + connection_type = Discover._get_connection_parameters(dr) else: connection_type = DeviceConnectionParameters.from_values( DeviceFamily.IotSmartPlugSwitch.value, DeviceEncryptionType.Xor.value @@ -118,11 +112,7 @@ async def test_connect_custom_port(discovery_mock, mocker, custom_port): connection_type=ctype, credentials=Credentials("dummy_user", "dummy_password"), ) - default_port = ( - DiscoveryResult.from_dict(discovery_data["result"]).mgt_encrypt_schm.http_port - if "result" in discovery_data - else 9999 - ) + default_port = discovery_mock.default_port ctype, _ = _get_connection_type_device_class(discovery_data)