From 88b7951feeb80f40cded04c3c0d3c42b35a121d1 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Fri, 25 Oct 2024 19:43:37 +0100 Subject: [PATCH] Allow passing an aiohttp client session during discover try_connect_all (#1198) --- kasa/discover.py | 6 ++++++ kasa/tests/test_discovery.py | 6 +++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/kasa/discover.py b/kasa/discover.py index 5df094bb..ade6a54a 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -93,6 +93,8 @@ from collections.abc import Awaitable from pprint import pformat as pf from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast +from aiohttp import ClientSession + # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout @@ -533,6 +535,7 @@ class Discover: port: int | None = None, timeout: int | None = None, credentials: Credentials | None = None, + http_client: ClientSession | None = None, ) -> Device | None: """Try to connect directly to a device with all possible parameters. @@ -544,6 +547,7 @@ class Discover: :param port: Optionally set a different port for legacy devices using port 9999 :param timeout: Timeout in seconds device for devices queries :param credentials: Credentials for devices that require authentication. + :param http_client: Optional client session for devices that use http. username and password are ignored if provided. """ from .device_factory import _connect @@ -570,6 +574,8 @@ class Discover: timeout=timeout, port_override=port, credentials=credentials, + http_client=http_client, + uses_http=encrypt is not Device.EncryptionType.Xor, ) ) and (protocol := get_protocol(config)) diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index ff21b610..a31ef836 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -697,9 +697,13 @@ async def test_discover_try_connect_all(discovery_mock, mocker): mocker.patch("kasa.SmartProtocol.query", new=_query) mocker.patch.object(dev_class, "update", new=_update) - dev = await Discover.try_connect_all(discovery_mock.ip) + session = aiohttp.ClientSession() + dev = await Discover.try_connect_all(discovery_mock.ip, http_client=session) assert dev assert isinstance(dev, dev_class) assert isinstance(dev.protocol, protocol_class) assert isinstance(dev.protocol._transport, transport_class) + assert dev.config.uses_http is (transport_class != XorTransport) + if transport_class != XorTransport: + assert dev.protocol._transport._http_client.client == session