Update test framework to support smartcam device discovery. (#1477)

Update test framework to support `smartcam` device discovery:
- Add `SMARTCAM` to the default `discovery_mock` filter
- Make connection parameter derivation a self contained static method in `Discover`
- Introduce a queue to the `discovery_mock` to ensure the discovery callbacks
  complete in the same order that they started.
- Patch `Discover._decrypt_discovery_data` in `discovery_mock`
  so it doesn't error trying to decrypt empty fixture data
This commit is contained in:
Steven B. 2025-01-23 11:26:55 +00:00 committed by GitHub
parent 5e57f8bd6c
commit 988eb96bd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 107 additions and 62 deletions

View File

@ -799,6 +799,47 @@ class Discover:
) from ex ) from ex
return info return info
@staticmethod
def _get_connection_parameters(
discovery_result: DiscoveryResult,
) -> DeviceConnectionParameters:
"""Get connection parameters from the discovery result."""
type_ = discovery_result.device_type
if (encrypt_schm := discovery_result.mgt_encrypt_schm) is None:
raise UnsupportedDeviceError(
f"Unsupported device {discovery_result.ip} of type {type_} "
"with no mgt_encrypt_schm",
discovery_result=discovery_result.to_dict(),
host=discovery_result.ip,
)
if not (encrypt_type := encrypt_schm.encrypt_type) and (
encrypt_info := discovery_result.encrypt_info
):
encrypt_type = encrypt_info.sym_schm
if not (login_version := encrypt_schm.lv) and (
et := discovery_result.encrypt_type
):
# Known encrypt types are ["1","2"] and ["3"]
# Reuse the login_version attribute to pass the max to transport
login_version = max([int(i) for i in et])
if not encrypt_type:
raise UnsupportedDeviceError(
f"Unsupported device {discovery_result.ip} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.to_dict(),
host=discovery_result.ip,
)
return DeviceConnectionParameters.from_values(
type_,
encrypt_type,
login_version=login_version,
https=encrypt_schm.is_support_https,
http_port=encrypt_schm.http_port,
)
@staticmethod @staticmethod
def _get_device_instance( def _get_device_instance(
info: dict, info: dict,
@ -838,55 +879,22 @@ class Discover:
config.host, config.host,
redact_data(info, NEW_DISCOVERY_REDACTORS), redact_data(info, NEW_DISCOVERY_REDACTORS),
) )
type_ = discovery_result.device_type type_ = discovery_result.device_type
if (encrypt_schm := discovery_result.mgt_encrypt_schm) is None:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
"with no mgt_encrypt_schm",
discovery_result=discovery_result.to_dict(),
host=config.host,
)
try: try:
if not (encrypt_type := encrypt_schm.encrypt_type) and ( conn_params = Discover._get_connection_parameters(discovery_result)
encrypt_info := discovery_result.encrypt_info config.connection_type = conn_params
):
encrypt_type = encrypt_info.sym_schm
if not (login_version := encrypt_schm.lv) and (
et := discovery_result.encrypt_type
):
# Known encrypt types are ["1","2"] and ["3"]
# Reuse the login_version attribute to pass the max to transport
login_version = max([int(i) for i in et])
if not encrypt_type:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.to_dict(),
host=config.host,
)
config.connection_type = DeviceConnectionParameters.from_values(
type_,
encrypt_type,
login_version=login_version,
https=encrypt_schm.is_support_https,
http_port=encrypt_schm.http_port,
)
except KasaException as ex: except KasaException as ex:
if isinstance(ex, UnsupportedDeviceError):
raise
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} " f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {encrypt_schm.encrypt_type}", + f"with encrypt_scheme {discovery_result.mgt_encrypt_schm}",
discovery_result=discovery_result.to_dict(), discovery_result=discovery_result.to_dict(),
host=config.host, host=config.host,
) from ex ) from ex
if ( if (
device_class := get_device_class_from_family( device_class := get_device_class_from_family(type_, https=conn_params.https)
type_, https=encrypt_schm.is_support_https
)
) is None: ) is None:
_LOGGER.debug("Got unsupported device type: %s", type_) _LOGGER.debug("Got unsupported device type: %s", type_)
raise UnsupportedDeviceError( raise UnsupportedDeviceError(

View File

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import copy import copy
from collections.abc import Coroutine
from dataclasses import dataclass from dataclasses import dataclass
from json import dumps as json_dumps from json import dumps as json_dumps
from typing import Any, TypedDict from typing import Any, TypedDict
@ -34,7 +36,7 @@ UNSUPPORTED_HOMEWIFISYSTEM = {
"group_id": "REDACTED_07d902da02fa9beab8a64", "group_id": "REDACTED_07d902da02fa9beab8a64",
"group_name": "I01BU0tFRF9TU0lEIw==", # '#MASKED_SSID#' "group_name": "I01BU0tFRF9TU0lEIw==", # '#MASKED_SSID#'
"hardware_version": "3.0", "hardware_version": "3.0",
"ip": "192.168.1.192", "ip": "127.0.0.1",
"mac": "24:2F:D0:00:00:00", "mac": "24:2F:D0:00:00:00",
"master_device_id": "REDACTED_51f72a752213a6c45203530", "master_device_id": "REDACTED_51f72a752213a6c45203530",
"need_account_digest": True, "need_account_digest": True,
@ -134,7 +136,9 @@ smart_discovery = parametrize_discovery("smart discovery", protocol_filter={"SMA
@pytest.fixture( @pytest.fixture(
params=filter_fixtures("discoverable", protocol_filter={"SMART", "IOT"}), params=filter_fixtures(
"discoverable", protocol_filter={"SMART", "SMARTCAM", "IOT"}
),
ids=idgenerator, ids=idgenerator,
) )
async def discovery_mock(request, mocker): async def discovery_mock(request, mocker):
@ -251,12 +255,46 @@ def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker):
first_ip = list(fixture_infos.keys())[0] first_ip = list(fixture_infos.keys())[0]
first_host = None first_host = None
# Mock _run_callback_task so the tasks complete in the order they started.
# Otherwise test output is non-deterministic which affects readme examples.
callback_queue: asyncio.Queue = asyncio.Queue()
exception_queue: asyncio.Queue = asyncio.Queue()
async def process_callback_queue(finished_event: asyncio.Event) -> None:
while (finished_event.is_set() is False) or callback_queue.qsize():
coro = await callback_queue.get()
try:
await coro
except Exception as ex:
await exception_queue.put(ex)
else:
await exception_queue.put(None)
callback_queue.task_done()
async def wait_for_coro():
await callback_queue.join()
if ex := exception_queue.get_nowait():
raise ex
def _run_callback_task(self, coro: Coroutine) -> None:
callback_queue.put_nowait(coro)
task = asyncio.create_task(wait_for_coro())
self.callback_tasks.append(task)
mocker.patch(
"kasa.discover._DiscoverProtocol._run_callback_task", _run_callback_task
)
# do_discover_mock
async def mock_discover(self): async def mock_discover(self):
"""Call datagram_received for all mock fixtures. """Call datagram_received for all mock fixtures.
Handles test cases modifying the ip and hostname of the first fixture Handles test cases modifying the ip and hostname of the first fixture
for discover_single testing. for discover_single testing.
""" """
finished_event = asyncio.Event()
asyncio.create_task(process_callback_queue(finished_event))
for ip, dm in discovery_mocks.items(): for ip, dm in discovery_mocks.items():
first_ip = list(discovery_mocks.values())[0].ip first_ip = list(discovery_mocks.values())[0].ip
fixture_info = fixture_infos[ip] fixture_info = fixture_infos[ip]
@ -283,10 +321,18 @@ def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker):
dm._datagram, dm._datagram,
(dm.ip, port), (dm.ip, port),
) )
# Setting this event will stop the processing of callbacks
finished_event.set()
mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)
# query_mock
async def _query(self, request, retry_count: int = 3): async def _query(self, request, retry_count: int = 3):
return await protos[self._host].query(request) return await protos[self._host].query(request)
mocker.patch("kasa.IotProtocol.query", _query)
mocker.patch("kasa.SmartProtocol.query", _query)
def _getaddrinfo(host, *_, **__): def _getaddrinfo(host, *_, **__):
nonlocal first_host, first_ip nonlocal first_host, first_ip
first_host = host # Store the hostname used by discover single first_host = host # Store the hostname used by discover single
@ -295,20 +341,21 @@ def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker):
].ip # ip could have been overridden in test ].ip # ip could have been overridden in test
return [(None, None, None, None, (first_ip, 0))] return [(None, None, None, None, (first_ip, 0))]
mocker.patch("kasa.IotProtocol.query", _query) mocker.patch("socket.getaddrinfo", side_effect=_getaddrinfo)
mocker.patch("kasa.SmartProtocol.query", _query)
mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover) # Mock decrypt so it doesn't error with unencryptable empty data in the
mocker.patch( # fixtures. The discovery result will already contain the decrypted data
"socket.getaddrinfo", # deserialized from the fixture
# side_effect=lambda *_, **__: [(None, None, None, None, (first_ip, 0))], mocker.patch("kasa.discover.Discover._decrypt_discovery_data")
side_effect=_getaddrinfo,
)
# Only return the first discovery mock to be used for testing discover single # Only return the first discovery mock to be used for testing discover single
return discovery_mocks[first_ip] return discovery_mocks[first_ip]
@pytest.fixture( @pytest.fixture(
params=filter_fixtures("discoverable", protocol_filter={"SMART", "IOT"}), params=filter_fixtures(
"discoverable", protocol_filter={"SMART", "SMARTCAM", "IOT"}
),
ids=idgenerator, ids=idgenerator,
) )
def discovery_data(request, mocker): def discovery_data(request, mocker):

View File

@ -60,13 +60,7 @@ def _get_connection_type_device_class(discovery_info):
device_class = Discover._get_device_class(discovery_info) device_class = Discover._get_device_class(discovery_info)
dr = DiscoveryResult.from_dict(discovery_info["result"]) dr = DiscoveryResult.from_dict(discovery_info["result"])
connection_type = DeviceConnectionParameters.from_values( connection_type = Discover._get_connection_parameters(dr)
dr.device_type,
dr.mgt_encrypt_schm.encrypt_type,
login_version=dr.mgt_encrypt_schm.lv,
https=dr.mgt_encrypt_schm.is_support_https,
http_port=dr.mgt_encrypt_schm.http_port,
)
else: else:
connection_type = DeviceConnectionParameters.from_values( connection_type = DeviceConnectionParameters.from_values(
DeviceFamily.IotSmartPlugSwitch.value, DeviceEncryptionType.Xor.value DeviceFamily.IotSmartPlugSwitch.value, DeviceEncryptionType.Xor.value
@ -118,11 +112,7 @@ async def test_connect_custom_port(discovery_mock, mocker, custom_port):
connection_type=ctype, connection_type=ctype,
credentials=Credentials("dummy_user", "dummy_password"), credentials=Credentials("dummy_user", "dummy_password"),
) )
default_port = ( default_port = discovery_mock.default_port
DiscoveryResult.from_dict(discovery_data["result"]).mgt_encrypt_schm.http_port
if "result" in discovery_data
else 9999
)
ctype, _ = _get_connection_type_device_class(discovery_data) ctype, _ = _get_connection_type_device_class(discovery_data)