mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Re-add protocol_class parameter to connect (#551)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
d3c2861e4a
commit
9728866afb
@ -7,6 +7,7 @@ from typing import Any, Dict, Optional, Type
|
|||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .device_type import DeviceType
|
from .device_type import DeviceType
|
||||||
from .exceptions import UnsupportedDeviceException
|
from .exceptions import UnsupportedDeviceException
|
||||||
|
from .protocol import TPLinkProtocol
|
||||||
from .smartbulb import SmartBulb
|
from .smartbulb import SmartBulb
|
||||||
from .smartdevice import SmartDevice, SmartDeviceException
|
from .smartdevice import SmartDevice, SmartDeviceException
|
||||||
from .smartdimmer import SmartDimmer
|
from .smartdimmer import SmartDimmer
|
||||||
@ -32,6 +33,7 @@ async def connect(
|
|||||||
timeout=5,
|
timeout=5,
|
||||||
credentials: Optional[Credentials] = None,
|
credentials: Optional[Credentials] = None,
|
||||||
device_type: Optional[DeviceType] = None,
|
device_type: Optional[DeviceType] = None,
|
||||||
|
protocol_class: Optional[Type[TPLinkProtocol]] = None,
|
||||||
) -> "SmartDevice":
|
) -> "SmartDevice":
|
||||||
"""Connect to a single device by the given IP address.
|
"""Connect to a single device by the given IP address.
|
||||||
|
|
||||||
@ -50,6 +52,8 @@ async def connect(
|
|||||||
If not given, the device type is discovered by querying the device.
|
If not given, the device type is discovered by querying the device.
|
||||||
If the device type is already known, it is preferred to pass it
|
If the device type is already known, it is preferred to pass it
|
||||||
to avoid the extra query to the device to discover its type.
|
to avoid the extra query to the device to discover its type.
|
||||||
|
:param protocol_class: Optionally provide the protocol class
|
||||||
|
to use.
|
||||||
:rtype: SmartDevice
|
:rtype: SmartDevice
|
||||||
:return: Object for querying/controlling found device.
|
:return: Object for querying/controlling found device.
|
||||||
"""
|
"""
|
||||||
@ -62,6 +66,8 @@ async def connect(
|
|||||||
dev: SmartDevice = klass(
|
dev: SmartDevice = klass(
|
||||||
host=host, port=port, credentials=credentials, timeout=timeout
|
host=host, port=port, credentials=credentials, timeout=timeout
|
||||||
)
|
)
|
||||||
|
if protocol_class is not None:
|
||||||
|
dev.protocol = protocol_class(host, credentials=credentials)
|
||||||
await dev.update()
|
await dev.update()
|
||||||
if debug_enabled:
|
if debug_enabled:
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
@ -76,6 +82,8 @@ async def connect(
|
|||||||
unknown_dev = SmartDevice(
|
unknown_dev = SmartDevice(
|
||||||
host=host, port=port, credentials=credentials, timeout=timeout
|
host=host, port=port, credentials=credentials, timeout=timeout
|
||||||
)
|
)
|
||||||
|
if protocol_class is not None:
|
||||||
|
unknown_dev.protocol = protocol_class(host, credentials=credentials)
|
||||||
await unknown_dev.update()
|
await unknown_dev.update()
|
||||||
device_class = get_device_class_from_info(unknown_dev.internal_state)
|
device_class = get_device_class_from_info(unknown_dev.internal_state)
|
||||||
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
|
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
|
||||||
|
@ -14,6 +14,8 @@ from kasa import (
|
|||||||
SmartPlug,
|
SmartPlug,
|
||||||
)
|
)
|
||||||
from kasa.device_factory import connect
|
from kasa.device_factory import connect
|
||||||
|
from kasa.klapprotocol import TPLinkKlap
|
||||||
|
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("custom_port", [123, None])
|
@pytest.mark.parametrize("custom_port", [123, None])
|
||||||
@ -72,3 +74,28 @@ async def test_connect_logs_connect_time(
|
|||||||
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
||||||
await connect(host)
|
await connect(host)
|
||||||
assert "seconds to connect" in caplog.text
|
assert "seconds to connect" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device_type", [DeviceType.Plug, None])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("protocol_in", "protocol_result"),
|
||||||
|
(
|
||||||
|
(None, TPLinkSmartHomeProtocol),
|
||||||
|
(TPLinkKlap, TPLinkKlap),
|
||||||
|
(TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def test_connect_pass_protocol(
|
||||||
|
discovery_data: dict,
|
||||||
|
mocker,
|
||||||
|
device_type: DeviceType,
|
||||||
|
protocol_in: Type[TPLinkProtocol],
|
||||||
|
protocol_result: Type[TPLinkProtocol],
|
||||||
|
):
|
||||||
|
"""Test that if the protocol is passed in it's gets set correctly."""
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||||
|
mocker.patch("kasa.TPLinkKlap.query", return_value=discovery_data)
|
||||||
|
|
||||||
|
dev = await connect(host, device_type=device_type, protocol_class=protocol_in)
|
||||||
|
assert isinstance(dev.protocol, protocol_result)
|
||||||
|
Loading…
Reference in New Issue
Block a user