From fe072657b492353525aa5a09c9ffd679eea8ca0c Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Tue, 17 Dec 2024 07:39:17 +0000 Subject: [PATCH] Simplify get_protocol to prevent clashes with smartcam and robovac (#1377) --- kasa/device_factory.py | 34 +++++++++------ kasa/discover.py | 10 ++--- tests/test_device_factory.py | 85 ++++++++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 18 deletions(-) diff --git a/kasa/device_factory.py b/kasa/device_factory.py index a1015570..99654a0c 100644 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -8,7 +8,7 @@ from typing import Any from .device import Device from .device_type import DeviceType -from .deviceconfig import DeviceConfig +from .deviceconfig import DeviceConfig, DeviceFamily from .exceptions import KasaException, UnsupportedDeviceError from .iot import ( IotBulb, @@ -179,20 +179,29 @@ def get_device_class_from_family( def get_protocol( config: DeviceConfig, ) -> BaseProtocol | None: - """Return the protocol from the connection name.""" - protocol_name = config.connection_type.device_family.value.split(".")[0] + """Return the protocol from the connection name. + + For cameras and vacuums the device family is a simple mapping to + the protocol/transport. For other device types the transport varies + based on the discovery information. + """ ctype = config.connection_type + protocol_name = ctype.device_family.value.split(".")[0] + + if ctype.device_family is DeviceFamily.SmartIpCamera: + return SmartCamProtocol(transport=SslAesTransport(config=config)) + + if ctype.device_family is DeviceFamily.IotIpCamera: + return IotProtocol(transport=LinkieTransportV2(config=config)) + + if ctype.device_family is DeviceFamily.SmartTapoRobovac: + return SmartProtocol(transport=SslTransport(config=config)) protocol_transport_key = ( protocol_name + "." + ctype.encryption_type.value + (".HTTPS" if ctype.https else "") - + ( - f".{ctype.login_version}" - if ctype.login_version and ctype.login_version > 1 - else "" - ) ) _LOGGER.debug("Finding transport for %s", protocol_transport_key) @@ -201,12 +210,11 @@ def get_protocol( ] = { "IOT.XOR": (IotProtocol, XorTransport), "IOT.KLAP": (IotProtocol, KlapTransport), - "IOT.XOR.HTTPS.2": (IotProtocol, LinkieTransportV2), "SMART.AES": (SmartProtocol, AesTransport), - "SMART.AES.2": (SmartProtocol, AesTransport), - "SMART.KLAP.2": (SmartProtocol, KlapTransportV2), - "SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport), - "SMART.AES.HTTPS": (SmartProtocol, SslTransport), + "SMART.KLAP": (SmartProtocol, KlapTransportV2), + # H200 is device family SMART.TAPOHUB and uses SmartCamProtocol so use + # https to distuingish from SmartProtocol devices + "SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport), } if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)): return None diff --git a/kasa/discover.py b/kasa/discover.py index 2bd98815..77ef80be 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -847,12 +847,12 @@ class Discover: ): encrypt_type = encrypt_info.sym_schm - if ( - not (login_version := encrypt_schm.lv) - and (et := discovery_result.encrypt_type) - and et == ["3"] + if not (login_version := encrypt_schm.lv) and ( + et := discovery_result.encrypt_type ): - login_version = 2 + # Known encrypt types are ["1","2"] and ["3"] + # Reuse the login_version attribute to pass the max to transport + login_version = max([int(i) for i in et]) if not encrypt_type: raise UnsupportedDeviceError( diff --git a/tests/test_device_factory.py b/tests/test_device_factory.py index ed73b3a3..66e24324 100644 --- a/tests/test_device_factory.py +++ b/tests/test_device_factory.py @@ -13,9 +13,13 @@ import aiohttp import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( + BaseProtocol, Credentials, Discover, + IotProtocol, KasaException, + SmartCamProtocol, + SmartProtocol, ) from kasa.device_factory import ( Device, @@ -33,6 +37,16 @@ from kasa.deviceconfig import ( DeviceFamily, ) from kasa.discover import DiscoveryResult +from kasa.transports import ( + AesTransport, + BaseTransport, + KlapTransport, + KlapTransportV2, + LinkieTransportV2, + SslAesTransport, + SslTransport, + XorTransport, +) from .conftest import DISCOVERY_MOCK_IP @@ -203,3 +217,74 @@ async def test_device_class_from_unknown_family(caplog): with caplog.at_level(logging.DEBUG): assert get_device_class_from_family(dummy_name, https=False) == SmartDevice assert f"Unknown SMART device with {dummy_name}" in caplog.text + + +# Aliases to make the test params more readable +CP = DeviceConnectionParameters +DF = DeviceFamily +ET = DeviceEncryptionType + + +@pytest.mark.parametrize( + ("conn_params", "expected_protocol", "expected_transport"), + [ + pytest.param( + CP(DF.SmartIpCamera, ET.Aes, https=True), + SmartCamProtocol, + SslAesTransport, + id="smartcam", + ), + pytest.param( + CP(DF.SmartTapoHub, ET.Aes, https=True), + SmartCamProtocol, + SslAesTransport, + id="smartcam-hub", + ), + pytest.param( + CP(DF.IotIpCamera, ET.Aes, https=True), + IotProtocol, + LinkieTransportV2, + id="kasacam", + ), + pytest.param( + CP(DF.SmartTapoRobovac, ET.Aes, https=True), + SmartProtocol, + SslTransport, + id="robovac", + ), + pytest.param( + CP(DF.IotSmartPlugSwitch, ET.Klap, https=False), + IotProtocol, + KlapTransport, + id="iot-klap", + ), + pytest.param( + CP(DF.IotSmartPlugSwitch, ET.Xor, https=False), + IotProtocol, + XorTransport, + id="iot-xor", + ), + pytest.param( + CP(DF.SmartTapoPlug, ET.Aes, https=False), + SmartProtocol, + AesTransport, + id="smart-aes", + ), + pytest.param( + CP(DF.SmartTapoPlug, ET.Klap, https=False), + SmartProtocol, + KlapTransportV2, + id="smart-klap", + ), + ], +) +async def test_get_protocol( + conn_params: DeviceConnectionParameters, + expected_protocol: type[BaseProtocol], + expected_transport: type[BaseTransport], +): + """Test get_protocol returns the right protocol.""" + config = DeviceConfig("127.0.0.1", connection_type=conn_params) + protocol = get_protocol(config) + assert isinstance(protocol, expected_protocol) + assert isinstance(protocol._transport, expected_transport)