diff --git a/kasa/discover.py b/kasa/discover.py index 5df094bb..ade6a54a 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -93,6 +93,8 @@ from collections.abc import Awaitable from pprint import pformat as pf from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast +from aiohttp import ClientSession + # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout @@ -533,6 +535,7 @@ class Discover: port: int | None = None, timeout: int | None = None, credentials: Credentials | None = None, + http_client: ClientSession | None = None, ) -> Device | None: """Try to connect directly to a device with all possible parameters. @@ -544,6 +547,7 @@ class Discover: :param port: Optionally set a different port for legacy devices using port 9999 :param timeout: Timeout in seconds device for devices queries :param credentials: Credentials for devices that require authentication. + :param http_client: Optional client session for devices that use http. username and password are ignored if provided. """ from .device_factory import _connect @@ -570,6 +574,8 @@ class Discover: timeout=timeout, port_override=port, credentials=credentials, + http_client=http_client, + uses_http=encrypt is not Device.EncryptionType.Xor, ) ) and (protocol := get_protocol(config)) diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index ff21b610..a31ef836 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -697,9 +697,13 @@ async def test_discover_try_connect_all(discovery_mock, mocker): mocker.patch("kasa.SmartProtocol.query", new=_query) mocker.patch.object(dev_class, "update", new=_update) - dev = await Discover.try_connect_all(discovery_mock.ip) + session = aiohttp.ClientSession() + dev = await Discover.try_connect_all(discovery_mock.ip, http_client=session) assert dev assert isinstance(dev, dev_class) assert isinstance(dev.protocol, protocol_class) assert isinstance(dev.protocol._transport, transport_class) + assert dev.config.uses_http is (transport_class != XorTransport) + if transport_class != XorTransport: + assert dev.protocol._transport._http_client.client == session