From 048c84d72cc7163c778d080132a3faed6959b7a3 Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Tue, 22 Oct 2024 18:09:35 +0100 Subject: [PATCH] Add https parameter to device class factory (#1184) `SMART.TAPOHUB` resolves to different device classes based on the https value --- kasa/cli/discover.py | 4 ++-- kasa/device_factory.py | 18 ++++++++++------ kasa/discover.py | 36 +++++++++++++++++++++++-------- kasa/exceptions.py | 1 + kasa/tests/discovery_fixtures.py | 28 ++++++++++++++++++++++-- kasa/tests/test_cli.py | 1 - kasa/tests/test_device_factory.py | 2 +- kasa/tests/test_discovery.py | 6 ++++-- 8 files changed, 73 insertions(+), 23 deletions(-) diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py index deb28b4d..7989dbb1 100644 --- a/kasa/cli/discover.py +++ b/kasa/cli/discover.py @@ -98,8 +98,8 @@ async def list(ctx): echo(f"{infostr} {dev.alias}") async def print_unsupported(unsupported_exception: UnsupportedDeviceError): - if res := unsupported_exception.discovery_result: - echo(f"{res.get('ip'):<15} UNSUPPORTED DEVICE") + if host := unsupported_exception.host: + echo(f"{host:<15} UNSUPPORTED DEVICE") echo(f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}") return await _discover(ctx, print_discovered, print_unsupported, do_echo=False) diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 01b2c8e7..53ae1eff 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -67,7 +67,8 @@ async def connect(*, host: str | None = None, config: DeviceConfig) -> Device: if (protocol := get_protocol(config=config)) is None: raise UnsupportedDeviceError( f"Unsupported device for {config.host}: " - + f"{config.connection_type.device_family.value}" + + f"{config.connection_type.device_family.value}", + host=config.host, ) try: @@ -110,7 +111,7 @@ async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> Device: _perf_log(True, "update") return device elif device_class := get_device_class_from_family( - config.connection_type.device_family.value + config.connection_type.device_family.value, https=config.connection_type.https ): device = device_class(host=config.host, protocol=protocol) await device.update() @@ -119,7 +120,8 @@ async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> Device: else: raise UnsupportedDeviceError( f"Unsupported device for {config.host}: " - + f"{config.connection_type.device_family.value}" + + f"{config.connection_type.device_family.value}", + host=config.host, ) @@ -164,7 +166,9 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]: return TYPE_TO_CLASS[_get_device_type_from_sys_info(sysinfo)] -def get_device_class_from_family(device_type: str) -> type[Device] | None: +def get_device_class_from_family( + device_type: str, *, https: bool +) -> type[Device] | None: """Return the device class from the type name.""" supported_device_types: dict[str, type[Device]] = { "SMART.TAPOPLUG": SmartDevice, @@ -172,14 +176,16 @@ def get_device_class_from_family(device_type: str) -> type[Device] | None: "SMART.TAPOSWITCH": SmartDevice, "SMART.KASAPLUG": SmartDevice, "SMART.TAPOHUB": SmartDevice, + "SMART.TAPOHUB.HTTPS": SmartCamera, "SMART.KASAHUB": SmartDevice, "SMART.KASASWITCH": SmartDevice, - "SMART.IPCAMERA": SmartCamera, + "SMART.IPCAMERA.HTTPS": SmartCamera, "IOT.SMARTPLUGSWITCH": IotPlug, "IOT.SMARTBULB": IotBulb, } + lookup_key = f"{device_type}{'.HTTPS' if https else ''}" if ( - cls := supported_device_types.get(device_type) + cls := supported_device_types.get(lookup_key) ) is None and device_type.startswith("SMART."): _LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type) cls = SmartDevice diff --git a/kasa/discover.py b/kasa/discover.py index e7a3946c..5df094bb 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -573,7 +573,11 @@ class Discover: ) ) and (protocol := get_protocol(config)) - and (device_class := get_device_class_from_family(device_family.value)) + and ( + device_class := get_device_class_from_family( + device_family.value, https=https + ) + ) } for protocol, config in candidates.values(): try: @@ -591,7 +595,10 @@ class Discover: """Find SmartDevice subclass for device described by passed data.""" if "result" in info: discovery_result = DiscoveryResult(**info["result"]) - dev_class = get_device_class_from_family(discovery_result.device_type) + https = discovery_result.mgt_encrypt_schm.is_support_https + dev_class = get_device_class_from_family( + discovery_result.device_type, https=https + ) if not dev_class: raise UnsupportedDeviceError( "Unknown device type: %s" % discovery_result.device_type, @@ -662,7 +669,9 @@ class Discover: ) from ex try: discovery_result = DiscoveryResult(**info["result"]) - if discovery_result.encrypt_info: + if ( + encrypt_info := discovery_result.encrypt_info + ) and encrypt_info.sym_schm == "AES": Discover._decrypt_discovery_data(discovery_result) except ValidationError as ex: if debug_enabled: @@ -677,21 +686,23 @@ class Discover: pf(data), ) raise UnsupportedDeviceError( - f"Unable to parse discovery from device: {config.host}: {ex}" + f"Unable to parse discovery from device: {config.host}: {ex}", + host=config.host, ) from ex type_ = discovery_result.device_type - + encrypt_schm = discovery_result.mgt_encrypt_schm try: - if not ( - encrypt_type := discovery_result.mgt_encrypt_schm.encrypt_type - ) and (encrypt_info := discovery_result.encrypt_info): + if not (encrypt_type := encrypt_schm.encrypt_type) and ( + encrypt_info := discovery_result.encrypt_info + ): encrypt_type = encrypt_info.sym_schm if not encrypt_type: raise UnsupportedDeviceError( f"Unsupported device {config.host} of type {type_} " + "with no encryption type", discovery_result=discovery_result.get_dict(), + host=config.host, ) config.connection_type = DeviceConnectionParameters.from_values( type_, @@ -704,12 +715,18 @@ class Discover: f"Unsupported device {config.host} of type {type_} " + f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}", discovery_result=discovery_result.get_dict(), + host=config.host, ) from ex - if (device_class := get_device_class_from_family(type_)) is None: + if ( + device_class := get_device_class_from_family( + type_, https=encrypt_schm.is_support_https + ) + ) is None: _LOGGER.warning("Got unsupported device type: %s", type_) raise UnsupportedDeviceError( f"Unsupported device {config.host} of type {type_}: {info}", discovery_result=discovery_result.get_dict(), + host=config.host, ) if (protocol := get_protocol(config)) is None: _LOGGER.warning( @@ -719,6 +736,7 @@ class Discover: f"Unsupported encryption scheme {config.host} of " + f"type {config.connection_type.to_dict()}: {info}", discovery_result=discovery_result.get_dict(), + host=config.host, ) if debug_enabled: diff --git a/kasa/exceptions.py b/kasa/exceptions.py index 3f7f301b..e32e9fd1 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -31,6 +31,7 @@ class UnsupportedDeviceError(KasaException): def __init__(self, *args: Any, **kwargs: Any) -> None: self.discovery_result = kwargs.get("discovery_result") + self.host = kwargs.get("host") super().__init__(*args) diff --git a/kasa/tests/discovery_fixtures.py b/kasa/tests/discovery_fixtures.py index d56f1187..ccad1510 100644 --- a/kasa/tests/discovery_fixtures.py +++ b/kasa/tests/discovery_fixtures.py @@ -15,8 +15,10 @@ from .fixtureinfo import FixtureInfo, filter_fixtures, idgenerator DISCOVERY_MOCK_IP = "127.0.0.123" -def _make_unsupported(device_family, encrypt_type): - return { +def _make_unsupported(device_family, encrypt_type, *, omit_keys=None): + if omit_keys is None: + omit_keys = {"encrypt_info": None} + result = { "result": { "device_id": "xx", "owner": "xx", @@ -33,9 +35,17 @@ def _make_unsupported(device_family, encrypt_type): "http_port": 80, "lv": 2, }, + "encrypt_info": {"data": "", "key": "", "sym_schm": encrypt_type}, }, "error_code": 0, } + for key, val in omit_keys.items(): + if val is None: + result["result"].pop(key) + else: + result["result"][key].pop(val) + + return result UNSUPPORTED_DEVICES = { @@ -43,6 +53,16 @@ UNSUPPORTED_DEVICES = { "wrong_encryption_iot": _make_unsupported("IOT.SMARTPLUGSWITCH", "AES"), "wrong_encryption_smart": _make_unsupported("SMART.TAPOBULB", "IOT"), "unknown_encryption": _make_unsupported("IOT.SMARTPLUGSWITCH", "FOO"), + "missing_encrypt_type": _make_unsupported( + "SMART.TAPOBULB", + "FOO", + omit_keys={"mgt_encrypt_schm": "encrypt_type", "encrypt_info": None}, + ), + "unable_to_parse": _make_unsupported( + "SMART.TAPOBULB", + "FOO", + omit_keys={"mgt_encrypt_schm": None}, + ), } @@ -90,6 +110,7 @@ def create_discovery_mock(ip: str, fixture_data: dict): query_data: dict device_type: str encrypt_type: str + https: bool login_version: int | None = None port_override: int | None = None @@ -110,6 +131,7 @@ def create_discovery_mock(ip: str, fixture_data: dict): "encrypt_type" ] login_version = fixture_data["discovery_result"]["mgt_encrypt_schm"].get("lv") + https = fixture_data["discovery_result"]["mgt_encrypt_schm"]["is_support_https"] dm = _DiscoveryMock( ip, 80, @@ -118,6 +140,7 @@ def create_discovery_mock(ip: str, fixture_data: dict): fixture_data, device_type, encrypt_type, + https, login_version, ) else: @@ -134,6 +157,7 @@ def create_discovery_mock(ip: str, fixture_data: dict): fixture_data, device_type, encrypt_type, + False, login_version, ) diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index e1861a29..bd93d430 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -764,7 +764,6 @@ async def test_discover_unsupported(unsupported_device_info, runner): ) assert res.exit_code == 0 assert "== Unsupported device ==" in res.output - assert "== Discovery Result ==" in res.output async def test_host_unsupported(unsupported_device_info, runner): diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index 7940f1e5..35031cd0 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -189,5 +189,5 @@ async def test_device_class_from_unknown_family(caplog): """Verify that unknown SMART devices yield a warning and fallback to SmartDevice.""" dummy_name = "SMART.foo" with caplog.at_level(logging.WARNING): - assert get_device_class_from_family(dummy_name) == SmartDevice + assert get_device_class_from_family(dummy_name, https=False) == SmartDevice assert f"Unknown SMART device with {dummy_name}" in caplog.text diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index d6e0a0db..ff21b610 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -658,12 +658,14 @@ async def test_discovery_decryption(): async def test_discover_try_connect_all(discovery_mock, mocker): """Test that device update is called on main.""" if "result" in discovery_mock.discovery_data: - dev_class = get_device_class_from_family(discovery_mock.device_type) + dev_class = get_device_class_from_family( + discovery_mock.device_type, https=discovery_mock.https + ) cparams = DeviceConnectionParameters.from_values( discovery_mock.device_type, discovery_mock.encrypt_type, discovery_mock.login_version, - False, + discovery_mock.https, ) protocol = get_protocol( DeviceConfig(discovery_mock.ip, connection_type=cparams)