Update test framework to support smartcam device discovery. (#1477)

Update test framework to support `smartcam` device discovery:
- Add `SMARTCAM` to the default `discovery_mock` filter
- Make connection parameter derivation a self contained static method in `Discover`
- Introduce a queue to the `discovery_mock` to ensure the discovery callbacks
  complete in the same order that they started.
- Patch `Discover._decrypt_discovery_data` in `discovery_mock`
  so it doesn't error trying to decrypt empty fixture data
This commit is contained in:
Steven B. 2025-01-23 11:26:55 +00:00 committed by GitHub
parent 5e57f8bd6c
commit 988eb96bd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 107 additions and 62 deletions

View File

@ -799,6 +799,47 @@ class Discover:
) from ex
return info
@staticmethod
def _get_connection_parameters(
discovery_result: DiscoveryResult,
) -> DeviceConnectionParameters:
"""Get connection parameters from the discovery result."""
type_ = discovery_result.device_type
if (encrypt_schm := discovery_result.mgt_encrypt_schm) is None:
raise UnsupportedDeviceError(
f"Unsupported device {discovery_result.ip} of type {type_} "
"with no mgt_encrypt_schm",
discovery_result=discovery_result.to_dict(),
host=discovery_result.ip,
)
if not (encrypt_type := encrypt_schm.encrypt_type) and (
encrypt_info := discovery_result.encrypt_info
):
encrypt_type = encrypt_info.sym_schm
if not (login_version := encrypt_schm.lv) and (
et := discovery_result.encrypt_type
):
# Known encrypt types are ["1","2"] and ["3"]
# Reuse the login_version attribute to pass the max to transport
login_version = max([int(i) for i in et])
if not encrypt_type:
raise UnsupportedDeviceError(
f"Unsupported device {discovery_result.ip} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.to_dict(),
host=discovery_result.ip,
)
return DeviceConnectionParameters.from_values(
type_,
encrypt_type,
login_version=login_version,
https=encrypt_schm.is_support_https,
http_port=encrypt_schm.http_port,
)
@staticmethod
def _get_device_instance(
info: dict,
@ -838,55 +879,22 @@ class Discover:
config.host,
redact_data(info, NEW_DISCOVERY_REDACTORS),
)
type_ = discovery_result.device_type
if (encrypt_schm := discovery_result.mgt_encrypt_schm) is None:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
"with no mgt_encrypt_schm",
discovery_result=discovery_result.to_dict(),
host=config.host,
)
try:
if not (encrypt_type := encrypt_schm.encrypt_type) and (
encrypt_info := discovery_result.encrypt_info
):
encrypt_type = encrypt_info.sym_schm
if not (login_version := encrypt_schm.lv) and (
et := discovery_result.encrypt_type
):
# Known encrypt types are ["1","2"] and ["3"]
# Reuse the login_version attribute to pass the max to transport
login_version = max([int(i) for i in et])
if not encrypt_type:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.to_dict(),
host=config.host,
)
config.connection_type = DeviceConnectionParameters.from_values(
type_,
encrypt_type,
login_version=login_version,
https=encrypt_schm.is_support_https,
http_port=encrypt_schm.http_port,
)
conn_params = Discover._get_connection_parameters(discovery_result)
config.connection_type = conn_params
except KasaException as ex:
if isinstance(ex, UnsupportedDeviceError):
raise
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {encrypt_schm.encrypt_type}",
+ f"with encrypt_scheme {discovery_result.mgt_encrypt_schm}",
discovery_result=discovery_result.to_dict(),
host=config.host,
) from ex
if (
device_class := get_device_class_from_family(
type_, https=encrypt_schm.is_support_https
)
device_class := get_device_class_from_family(type_, https=conn_params.https)
) is None:
_LOGGER.debug("Got unsupported device type: %s", type_)
raise UnsupportedDeviceError(

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
import copy
from collections.abc import Coroutine
from dataclasses import dataclass
from json import dumps as json_dumps
from typing import Any, TypedDict
@ -34,7 +36,7 @@ UNSUPPORTED_HOMEWIFISYSTEM = {
"group_id": "REDACTED_07d902da02fa9beab8a64",
"group_name": "I01BU0tFRF9TU0lEIw==", # '#MASKED_SSID#'
"hardware_version": "3.0",
"ip": "192.168.1.192",
"ip": "127.0.0.1",
"mac": "24:2F:D0:00:00:00",
"master_device_id": "REDACTED_51f72a752213a6c45203530",
"need_account_digest": True,
@ -134,7 +136,9 @@ smart_discovery = parametrize_discovery("smart discovery", protocol_filter={"SMA
@pytest.fixture(
params=filter_fixtures("discoverable", protocol_filter={"SMART", "IOT"}),
params=filter_fixtures(
"discoverable", protocol_filter={"SMART", "SMARTCAM", "IOT"}
),
ids=idgenerator,
)
async def discovery_mock(request, mocker):
@ -251,12 +255,46 @@ def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker):
first_ip = list(fixture_infos.keys())[0]
first_host = None
# Mock _run_callback_task so the tasks complete in the order they started.
# Otherwise test output is non-deterministic which affects readme examples.
callback_queue: asyncio.Queue = asyncio.Queue()
exception_queue: asyncio.Queue = asyncio.Queue()
async def process_callback_queue(finished_event: asyncio.Event) -> None:
while (finished_event.is_set() is False) or callback_queue.qsize():
coro = await callback_queue.get()
try:
await coro
except Exception as ex:
await exception_queue.put(ex)
else:
await exception_queue.put(None)
callback_queue.task_done()
async def wait_for_coro():
await callback_queue.join()
if ex := exception_queue.get_nowait():
raise ex
def _run_callback_task(self, coro: Coroutine) -> None:
callback_queue.put_nowait(coro)
task = asyncio.create_task(wait_for_coro())
self.callback_tasks.append(task)
mocker.patch(
"kasa.discover._DiscoverProtocol._run_callback_task", _run_callback_task
)
# do_discover_mock
async def mock_discover(self):
"""Call datagram_received for all mock fixtures.
Handles test cases modifying the ip and hostname of the first fixture
for discover_single testing.
"""
finished_event = asyncio.Event()
asyncio.create_task(process_callback_queue(finished_event))
for ip, dm in discovery_mocks.items():
first_ip = list(discovery_mocks.values())[0].ip
fixture_info = fixture_infos[ip]
@ -283,10 +321,18 @@ def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker):
dm._datagram,
(dm.ip, port),
)
# Setting this event will stop the processing of callbacks
finished_event.set()
mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)
# query_mock
async def _query(self, request, retry_count: int = 3):
return await protos[self._host].query(request)
mocker.patch("kasa.IotProtocol.query", _query)
mocker.patch("kasa.SmartProtocol.query", _query)
def _getaddrinfo(host, *_, **__):
nonlocal first_host, first_ip
first_host = host # Store the hostname used by discover single
@ -295,20 +341,21 @@ def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker):
].ip # ip could have been overridden in test
return [(None, None, None, None, (first_ip, 0))]
mocker.patch("kasa.IotProtocol.query", _query)
mocker.patch("kasa.SmartProtocol.query", _query)
mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)
mocker.patch(
"socket.getaddrinfo",
# side_effect=lambda *_, **__: [(None, None, None, None, (first_ip, 0))],
side_effect=_getaddrinfo,
)
mocker.patch("socket.getaddrinfo", side_effect=_getaddrinfo)
# Mock decrypt so it doesn't error with unencryptable empty data in the
# fixtures. The discovery result will already contain the decrypted data
# deserialized from the fixture
mocker.patch("kasa.discover.Discover._decrypt_discovery_data")
# Only return the first discovery mock to be used for testing discover single
return discovery_mocks[first_ip]
@pytest.fixture(
params=filter_fixtures("discoverable", protocol_filter={"SMART", "IOT"}),
params=filter_fixtures(
"discoverable", protocol_filter={"SMART", "SMARTCAM", "IOT"}
),
ids=idgenerator,
)
def discovery_data(request, mocker):

View File

@ -60,13 +60,7 @@ def _get_connection_type_device_class(discovery_info):
device_class = Discover._get_device_class(discovery_info)
dr = DiscoveryResult.from_dict(discovery_info["result"])
connection_type = DeviceConnectionParameters.from_values(
dr.device_type,
dr.mgt_encrypt_schm.encrypt_type,
login_version=dr.mgt_encrypt_schm.lv,
https=dr.mgt_encrypt_schm.is_support_https,
http_port=dr.mgt_encrypt_schm.http_port,
)
connection_type = Discover._get_connection_parameters(dr)
else:
connection_type = DeviceConnectionParameters.from_values(
DeviceFamily.IotSmartPlugSwitch.value, DeviceEncryptionType.Xor.value
@ -118,11 +112,7 @@ async def test_connect_custom_port(discovery_mock, mocker, custom_port):
connection_type=ctype,
credentials=Credentials("dummy_user", "dummy_password"),
)
default_port = (
DiscoveryResult.from_dict(discovery_data["result"]).mgt_encrypt_schm.http_port
if "result" in discovery_data
else 9999
)
default_port = discovery_mock.default_port
ctype, _ = _get_connection_type_device_class(discovery_data)