Allow passing an aiohttp client session during discover try_connect_all (#1198)

This commit is contained in:
Steven B. 2024-10-25 19:43:37 +01:00 committed by GitHub
parent 7eb8d45b6e
commit 88b7951fee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 1 deletions

View File

@ -93,6 +93,8 @@ from collections.abc import Awaitable
from pprint import pformat as pf from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
from aiohttp import ClientSession
# When support for cpython older than 3.11 is dropped # When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout # async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout from async_timeout import timeout as asyncio_timeout
@ -533,6 +535,7 @@ class Discover:
port: int | None = None, port: int | None = None,
timeout: int | None = None, timeout: int | None = None,
credentials: Credentials | None = None, credentials: Credentials | None = None,
http_client: ClientSession | None = None,
) -> Device | None: ) -> Device | None:
"""Try to connect directly to a device with all possible parameters. """Try to connect directly to a device with all possible parameters.
@ -544,6 +547,7 @@ class Discover:
:param port: Optionally set a different port for legacy devices using port 9999 :param port: Optionally set a different port for legacy devices using port 9999
:param timeout: Timeout in seconds device for devices queries :param timeout: Timeout in seconds device for devices queries
:param credentials: Credentials for devices that require authentication. :param credentials: Credentials for devices that require authentication.
:param http_client: Optional client session for devices that use http.
username and password are ignored if provided. username and password are ignored if provided.
""" """
from .device_factory import _connect from .device_factory import _connect
@ -570,6 +574,8 @@ class Discover:
timeout=timeout, timeout=timeout,
port_override=port, port_override=port,
credentials=credentials, credentials=credentials,
http_client=http_client,
uses_http=encrypt is not Device.EncryptionType.Xor,
) )
) )
and (protocol := get_protocol(config)) and (protocol := get_protocol(config))

View File

@ -697,9 +697,13 @@ async def test_discover_try_connect_all(discovery_mock, mocker):
mocker.patch("kasa.SmartProtocol.query", new=_query) mocker.patch("kasa.SmartProtocol.query", new=_query)
mocker.patch.object(dev_class, "update", new=_update) mocker.patch.object(dev_class, "update", new=_update)
dev = await Discover.try_connect_all(discovery_mock.ip) session = aiohttp.ClientSession()
dev = await Discover.try_connect_all(discovery_mock.ip, http_client=session)
assert dev assert dev
assert isinstance(dev, dev_class) assert isinstance(dev, dev_class)
assert isinstance(dev.protocol, protocol_class) assert isinstance(dev.protocol, protocol_class)
assert isinstance(dev.protocol._transport, transport_class) assert isinstance(dev.protocol._transport, transport_class)
assert dev.config.uses_http is (transport_class != XorTransport)
if transport_class != XorTransport:
assert dev.protocol._transport._http_client.client == session