mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-11-03 22:22:06 +00:00 
			
		
		
		
	Re-add protocol_class parameter to connect (#551)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
		@@ -7,6 +7,7 @@ from typing import Any, Dict, Optional, Type
 | 
			
		||||
from .credentials import Credentials
 | 
			
		||||
from .device_type import DeviceType
 | 
			
		||||
from .exceptions import UnsupportedDeviceException
 | 
			
		||||
from .protocol import TPLinkProtocol
 | 
			
		||||
from .smartbulb import SmartBulb
 | 
			
		||||
from .smartdevice import SmartDevice, SmartDeviceException
 | 
			
		||||
from .smartdimmer import SmartDimmer
 | 
			
		||||
@@ -32,6 +33,7 @@ async def connect(
 | 
			
		||||
    timeout=5,
 | 
			
		||||
    credentials: Optional[Credentials] = None,
 | 
			
		||||
    device_type: Optional[DeviceType] = None,
 | 
			
		||||
    protocol_class: Optional[Type[TPLinkProtocol]] = None,
 | 
			
		||||
) -> "SmartDevice":
 | 
			
		||||
    """Connect to a single device by the given IP address.
 | 
			
		||||
 | 
			
		||||
@@ -50,6 +52,8 @@ async def connect(
 | 
			
		||||
        If not given, the device type is discovered by querying the device.
 | 
			
		||||
        If the device type is already known, it is preferred to pass it
 | 
			
		||||
        to avoid the extra query to the device to discover its type.
 | 
			
		||||
    :param protocol_class: Optionally provide the protocol class
 | 
			
		||||
            to use.
 | 
			
		||||
    :rtype: SmartDevice
 | 
			
		||||
    :return: Object for querying/controlling found device.
 | 
			
		||||
    """
 | 
			
		||||
@@ -62,6 +66,8 @@ async def connect(
 | 
			
		||||
        dev: SmartDevice = klass(
 | 
			
		||||
            host=host, port=port, credentials=credentials, timeout=timeout
 | 
			
		||||
        )
 | 
			
		||||
        if protocol_class is not None:
 | 
			
		||||
            dev.protocol = protocol_class(host, credentials=credentials)
 | 
			
		||||
        await dev.update()
 | 
			
		||||
        if debug_enabled:
 | 
			
		||||
            end_time = time.perf_counter()
 | 
			
		||||
@@ -76,6 +82,8 @@ async def connect(
 | 
			
		||||
    unknown_dev = SmartDevice(
 | 
			
		||||
        host=host, port=port, credentials=credentials, timeout=timeout
 | 
			
		||||
    )
 | 
			
		||||
    if protocol_class is not None:
 | 
			
		||||
        unknown_dev.protocol = protocol_class(host, credentials=credentials)
 | 
			
		||||
    await unknown_dev.update()
 | 
			
		||||
    device_class = get_device_class_from_info(unknown_dev.internal_state)
 | 
			
		||||
    dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
 | 
			
		||||
 
 | 
			
		||||
@@ -14,6 +14,8 @@ from kasa import (
 | 
			
		||||
    SmartPlug,
 | 
			
		||||
)
 | 
			
		||||
from kasa.device_factory import connect
 | 
			
		||||
from kasa.klapprotocol import TPLinkKlap
 | 
			
		||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("custom_port", [123, None])
 | 
			
		||||
@@ -72,3 +74,28 @@ async def test_connect_logs_connect_time(
 | 
			
		||||
    logging.getLogger("kasa").setLevel(logging.DEBUG)
 | 
			
		||||
    await connect(host)
 | 
			
		||||
    assert "seconds to connect" in caplog.text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("device_type", [DeviceType.Plug, None])
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    ("protocol_in", "protocol_result"),
 | 
			
		||||
    (
 | 
			
		||||
        (None, TPLinkSmartHomeProtocol),
 | 
			
		||||
        (TPLinkKlap, TPLinkKlap),
 | 
			
		||||
        (TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol),
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
async def test_connect_pass_protocol(
 | 
			
		||||
    discovery_data: dict,
 | 
			
		||||
    mocker,
 | 
			
		||||
    device_type: DeviceType,
 | 
			
		||||
    protocol_in: Type[TPLinkProtocol],
 | 
			
		||||
    protocol_result: Type[TPLinkProtocol],
 | 
			
		||||
):
 | 
			
		||||
    """Test that if the protocol is passed in it's gets set correctly."""
 | 
			
		||||
    host = "127.0.0.1"
 | 
			
		||||
    mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
 | 
			
		||||
    mocker.patch("kasa.TPLinkKlap.query", return_value=discovery_data)
 | 
			
		||||
 | 
			
		||||
    dev = await connect(host, device_type=device_type, protocol_class=protocol_in)
 | 
			
		||||
    assert isinstance(dev.protocol, protocol_result)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user