From 9728866afb5539e5d5e6b5371d012ad50658eaa7 Mon Sep 17 00:00:00 2001 From: sdb9696 <51370195+sdb9696@users.noreply.github.com> Date: Tue, 28 Nov 2023 19:13:15 +0000 Subject: [PATCH] Re-add protocol_class parameter to connect (#551) Co-authored-by: J. Nick Koston --- kasa/device_factory.py | 8 ++++++++ kasa/tests/test_device_factory.py | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/kasa/device_factory.py b/kasa/device_factory.py index c3ed4de3..049969fb 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Optional, Type from .credentials import Credentials from .device_type import DeviceType from .exceptions import UnsupportedDeviceException +from .protocol import TPLinkProtocol from .smartbulb import SmartBulb from .smartdevice import SmartDevice, SmartDeviceException from .smartdimmer import SmartDimmer @@ -32,6 +33,7 @@ async def connect( timeout=5, credentials: Optional[Credentials] = None, device_type: Optional[DeviceType] = None, + protocol_class: Optional[Type[TPLinkProtocol]] = None, ) -> "SmartDevice": """Connect to a single device by the given IP address. @@ -50,6 +52,8 @@ async def connect( If not given, the device type is discovered by querying the device. If the device type is already known, it is preferred to pass it to avoid the extra query to the device to discover its type. + :param protocol_class: Optionally provide the protocol class + to use. :rtype: SmartDevice :return: Object for querying/controlling found device. """ @@ -62,6 +66,8 @@ async def connect( dev: SmartDevice = klass( host=host, port=port, credentials=credentials, timeout=timeout ) + if protocol_class is not None: + dev.protocol = protocol_class(host, credentials=credentials) await dev.update() if debug_enabled: end_time = time.perf_counter() @@ -76,6 +82,8 @@ async def connect( unknown_dev = SmartDevice( host=host, port=port, credentials=credentials, timeout=timeout ) + if protocol_class is not None: + unknown_dev.protocol = protocol_class(host, credentials=credentials) await unknown_dev.update() device_class = get_device_class_from_info(unknown_dev.internal_state) dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout) diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index 3a08857a..aca38e19 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -14,6 +14,8 @@ from kasa import ( SmartPlug, ) from kasa.device_factory import connect +from kasa.klapprotocol import TPLinkKlap +from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol @pytest.mark.parametrize("custom_port", [123, None]) @@ -72,3 +74,28 @@ async def test_connect_logs_connect_time( logging.getLogger("kasa").setLevel(logging.DEBUG) await connect(host) assert "seconds to connect" in caplog.text + + +@pytest.mark.parametrize("device_type", [DeviceType.Plug, None]) +@pytest.mark.parametrize( + ("protocol_in", "protocol_result"), + ( + (None, TPLinkSmartHomeProtocol), + (TPLinkKlap, TPLinkKlap), + (TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol), + ), +) +async def test_connect_pass_protocol( + discovery_data: dict, + mocker, + device_type: DeviceType, + protocol_in: Type[TPLinkProtocol], + protocol_result: Type[TPLinkProtocol], +): + """Test that if the protocol is passed in it's gets set correctly.""" + host = "127.0.0.1" + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + mocker.patch("kasa.TPLinkKlap.query", return_value=discovery_data) + + dev = await connect(host, device_type=device_type, protocol_class=protocol_in) + assert isinstance(dev.protocol, protocol_result)