Simplify get_protocol to prevent clashes with smartcam and robovac (#1377)
Some checks are pending
CI / Perform linting checks (3.13) (push) Waiting to run
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, macos-latest, 3.11) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, macos-latest, 3.12) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, macos-latest, 3.13) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, ubuntu-latest, 3.11) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, ubuntu-latest, 3.12) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, ubuntu-latest, 3.13) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, windows-latest, 3.11) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, windows-latest, 3.12) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (false, windows-latest, 3.13) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (true, ubuntu-latest, 3.11) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (true, ubuntu-latest, 3.12) (push) Blocked by required conditions
CI / Python ${{ matrix.python-version}} on ${{ matrix.os }}${{ fromJSON('[" (extras)", ""]')[matrix.extras == ''] }} (true, ubuntu-latest, 3.13) (push) Blocked by required conditions
CodeQL checks / Analyze (python) (push) Waiting to run

This commit is contained in:
Steven B. 2024-12-17 07:39:17 +00:00 committed by GitHub
parent 5918e4daa7
commit fe072657b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 111 additions and 18 deletions

View File

@ -8,7 +8,7 @@ from typing import Any
from .device import Device from .device import Device
from .device_type import DeviceType from .device_type import DeviceType
from .deviceconfig import DeviceConfig from .deviceconfig import DeviceConfig, DeviceFamily
from .exceptions import KasaException, UnsupportedDeviceError from .exceptions import KasaException, UnsupportedDeviceError
from .iot import ( from .iot import (
IotBulb, IotBulb,
@ -179,20 +179,29 @@ def get_device_class_from_family(
def get_protocol( def get_protocol(
config: DeviceConfig, config: DeviceConfig,
) -> BaseProtocol | None: ) -> BaseProtocol | None:
"""Return the protocol from the connection name.""" """Return the protocol from the connection name.
protocol_name = config.connection_type.device_family.value.split(".")[0]
For cameras and vacuums the device family is a simple mapping to
the protocol/transport. For other device types the transport varies
based on the discovery information.
"""
ctype = config.connection_type ctype = config.connection_type
protocol_name = ctype.device_family.value.split(".")[0]
if ctype.device_family is DeviceFamily.SmartIpCamera:
return SmartCamProtocol(transport=SslAesTransport(config=config))
if ctype.device_family is DeviceFamily.IotIpCamera:
return IotProtocol(transport=LinkieTransportV2(config=config))
if ctype.device_family is DeviceFamily.SmartTapoRobovac:
return SmartProtocol(transport=SslTransport(config=config))
protocol_transport_key = ( protocol_transport_key = (
protocol_name protocol_name
+ "." + "."
+ ctype.encryption_type.value + ctype.encryption_type.value
+ (".HTTPS" if ctype.https else "") + (".HTTPS" if ctype.https else "")
+ (
f".{ctype.login_version}"
if ctype.login_version and ctype.login_version > 1
else ""
)
) )
_LOGGER.debug("Finding transport for %s", protocol_transport_key) _LOGGER.debug("Finding transport for %s", protocol_transport_key)
@ -201,12 +210,11 @@ def get_protocol(
] = { ] = {
"IOT.XOR": (IotProtocol, XorTransport), "IOT.XOR": (IotProtocol, XorTransport),
"IOT.KLAP": (IotProtocol, KlapTransport), "IOT.KLAP": (IotProtocol, KlapTransport),
"IOT.XOR.HTTPS.2": (IotProtocol, LinkieTransportV2),
"SMART.AES": (SmartProtocol, AesTransport), "SMART.AES": (SmartProtocol, AesTransport),
"SMART.AES.2": (SmartProtocol, AesTransport), "SMART.KLAP": (SmartProtocol, KlapTransportV2),
"SMART.KLAP.2": (SmartProtocol, KlapTransportV2), # H200 is device family SMART.TAPOHUB and uses SmartCamProtocol so use
"SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport), # https to distuingish from SmartProtocol devices
"SMART.AES.HTTPS": (SmartProtocol, SslTransport), "SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport),
} }
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)): if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
return None return None

View File

@ -847,12 +847,12 @@ class Discover:
): ):
encrypt_type = encrypt_info.sym_schm encrypt_type = encrypt_info.sym_schm
if ( if not (login_version := encrypt_schm.lv) and (
not (login_version := encrypt_schm.lv) et := discovery_result.encrypt_type
and (et := discovery_result.encrypt_type)
and et == ["3"]
): ):
login_version = 2 # Known encrypt types are ["1","2"] and ["3"]
# Reuse the login_version attribute to pass the max to transport
login_version = max([int(i) for i in et])
if not encrypt_type: if not encrypt_type:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(

View File

@ -13,9 +13,13 @@ import aiohttp
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import ( from kasa import (
BaseProtocol,
Credentials, Credentials,
Discover, Discover,
IotProtocol,
KasaException, KasaException,
SmartCamProtocol,
SmartProtocol,
) )
from kasa.device_factory import ( from kasa.device_factory import (
Device, Device,
@ -33,6 +37,16 @@ from kasa.deviceconfig import (
DeviceFamily, DeviceFamily,
) )
from kasa.discover import DiscoveryResult from kasa.discover import DiscoveryResult
from kasa.transports import (
AesTransport,
BaseTransport,
KlapTransport,
KlapTransportV2,
LinkieTransportV2,
SslAesTransport,
SslTransport,
XorTransport,
)
from .conftest import DISCOVERY_MOCK_IP from .conftest import DISCOVERY_MOCK_IP
@ -203,3 +217,74 @@ async def test_device_class_from_unknown_family(caplog):
with caplog.at_level(logging.DEBUG): with caplog.at_level(logging.DEBUG):
assert get_device_class_from_family(dummy_name, https=False) == SmartDevice assert get_device_class_from_family(dummy_name, https=False) == SmartDevice
assert f"Unknown SMART device with {dummy_name}" in caplog.text assert f"Unknown SMART device with {dummy_name}" in caplog.text
# Aliases to make the test params more readable
CP = DeviceConnectionParameters
DF = DeviceFamily
ET = DeviceEncryptionType
@pytest.mark.parametrize(
("conn_params", "expected_protocol", "expected_transport"),
[
pytest.param(
CP(DF.SmartIpCamera, ET.Aes, https=True),
SmartCamProtocol,
SslAesTransport,
id="smartcam",
),
pytest.param(
CP(DF.SmartTapoHub, ET.Aes, https=True),
SmartCamProtocol,
SslAesTransport,
id="smartcam-hub",
),
pytest.param(
CP(DF.IotIpCamera, ET.Aes, https=True),
IotProtocol,
LinkieTransportV2,
id="kasacam",
),
pytest.param(
CP(DF.SmartTapoRobovac, ET.Aes, https=True),
SmartProtocol,
SslTransport,
id="robovac",
),
pytest.param(
CP(DF.IotSmartPlugSwitch, ET.Klap, https=False),
IotProtocol,
KlapTransport,
id="iot-klap",
),
pytest.param(
CP(DF.IotSmartPlugSwitch, ET.Xor, https=False),
IotProtocol,
XorTransport,
id="iot-xor",
),
pytest.param(
CP(DF.SmartTapoPlug, ET.Aes, https=False),
SmartProtocol,
AesTransport,
id="smart-aes",
),
pytest.param(
CP(DF.SmartTapoPlug, ET.Klap, https=False),
SmartProtocol,
KlapTransportV2,
id="smart-klap",
),
],
)
async def test_get_protocol(
conn_params: DeviceConnectionParameters,
expected_protocol: type[BaseProtocol],
expected_transport: type[BaseTransport],
):
"""Test get_protocol returns the right protocol."""
config = DeviceConfig("127.0.0.1", connection_type=conn_params)
protocol = get_protocol(config)
assert isinstance(protocol, expected_protocol)
assert isinstance(protocol._transport, expected_transport)