Add https parameter to device class factory (#1184)

`SMART.TAPOHUB` resolves to different device classes based on the https value
This commit is contained in:
Steven B. 2024-10-22 18:09:35 +01:00 committed by GitHub
parent 3c865b5fb6
commit 048c84d72c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 73 additions and 23 deletions

View File

@ -98,8 +98,8 @@ async def list(ctx):
echo(f"{infostr} {dev.alias}") echo(f"{infostr} {dev.alias}")
async def print_unsupported(unsupported_exception: UnsupportedDeviceError): async def print_unsupported(unsupported_exception: UnsupportedDeviceError):
if res := unsupported_exception.discovery_result: if host := unsupported_exception.host:
echo(f"{res.get('ip'):<15} UNSUPPORTED DEVICE") echo(f"{host:<15} UNSUPPORTED DEVICE")
echo(f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}") echo(f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}")
return await _discover(ctx, print_discovered, print_unsupported, do_echo=False) return await _discover(ctx, print_discovered, print_unsupported, do_echo=False)

View File

@ -67,7 +67,8 @@ async def connect(*, host: str | None = None, config: DeviceConfig) -> Device:
if (protocol := get_protocol(config=config)) is None: if (protocol := get_protocol(config=config)) is None:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unsupported device for {config.host}: " f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}" + f"{config.connection_type.device_family.value}",
host=config.host,
) )
try: try:
@ -110,7 +111,7 @@ async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> Device:
_perf_log(True, "update") _perf_log(True, "update")
return device return device
elif device_class := get_device_class_from_family( elif device_class := get_device_class_from_family(
config.connection_type.device_family.value config.connection_type.device_family.value, https=config.connection_type.https
): ):
device = device_class(host=config.host, protocol=protocol) device = device_class(host=config.host, protocol=protocol)
await device.update() await device.update()
@ -119,7 +120,8 @@ async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> Device:
else: else:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unsupported device for {config.host}: " f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}" + f"{config.connection_type.device_family.value}",
host=config.host,
) )
@ -164,7 +166,9 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:
return TYPE_TO_CLASS[_get_device_type_from_sys_info(sysinfo)] return TYPE_TO_CLASS[_get_device_type_from_sys_info(sysinfo)]
def get_device_class_from_family(device_type: str) -> type[Device] | None: def get_device_class_from_family(
device_type: str, *, https: bool
) -> type[Device] | None:
"""Return the device class from the type name.""" """Return the device class from the type name."""
supported_device_types: dict[str, type[Device]] = { supported_device_types: dict[str, type[Device]] = {
"SMART.TAPOPLUG": SmartDevice, "SMART.TAPOPLUG": SmartDevice,
@ -172,14 +176,16 @@ def get_device_class_from_family(device_type: str) -> type[Device] | None:
"SMART.TAPOSWITCH": SmartDevice, "SMART.TAPOSWITCH": SmartDevice,
"SMART.KASAPLUG": SmartDevice, "SMART.KASAPLUG": SmartDevice,
"SMART.TAPOHUB": SmartDevice, "SMART.TAPOHUB": SmartDevice,
"SMART.TAPOHUB.HTTPS": SmartCamera,
"SMART.KASAHUB": SmartDevice, "SMART.KASAHUB": SmartDevice,
"SMART.KASASWITCH": SmartDevice, "SMART.KASASWITCH": SmartDevice,
"SMART.IPCAMERA": SmartCamera, "SMART.IPCAMERA.HTTPS": SmartCamera,
"IOT.SMARTPLUGSWITCH": IotPlug, "IOT.SMARTPLUGSWITCH": IotPlug,
"IOT.SMARTBULB": IotBulb, "IOT.SMARTBULB": IotBulb,
} }
lookup_key = f"{device_type}{'.HTTPS' if https else ''}"
if ( if (
cls := supported_device_types.get(device_type) cls := supported_device_types.get(lookup_key)
) is None and device_type.startswith("SMART."): ) is None and device_type.startswith("SMART."):
_LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type) _LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type)
cls = SmartDevice cls = SmartDevice

View File

@ -573,7 +573,11 @@ class Discover:
) )
) )
and (protocol := get_protocol(config)) and (protocol := get_protocol(config))
and (device_class := get_device_class_from_family(device_family.value)) and (
device_class := get_device_class_from_family(
device_family.value, https=https
)
)
} }
for protocol, config in candidates.values(): for protocol, config in candidates.values():
try: try:
@ -591,7 +595,10 @@ class Discover:
"""Find SmartDevice subclass for device described by passed data.""" """Find SmartDevice subclass for device described by passed data."""
if "result" in info: if "result" in info:
discovery_result = DiscoveryResult(**info["result"]) discovery_result = DiscoveryResult(**info["result"])
dev_class = get_device_class_from_family(discovery_result.device_type) https = discovery_result.mgt_encrypt_schm.is_support_https
dev_class = get_device_class_from_family(
discovery_result.device_type, https=https
)
if not dev_class: if not dev_class:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
"Unknown device type: %s" % discovery_result.device_type, "Unknown device type: %s" % discovery_result.device_type,
@ -662,7 +669,9 @@ class Discover:
) from ex ) from ex
try: try:
discovery_result = DiscoveryResult(**info["result"]) discovery_result = DiscoveryResult(**info["result"])
if discovery_result.encrypt_info: if (
encrypt_info := discovery_result.encrypt_info
) and encrypt_info.sym_schm == "AES":
Discover._decrypt_discovery_data(discovery_result) Discover._decrypt_discovery_data(discovery_result)
except ValidationError as ex: except ValidationError as ex:
if debug_enabled: if debug_enabled:
@ -677,21 +686,23 @@ class Discover:
pf(data), pf(data),
) )
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unable to parse discovery from device: {config.host}: {ex}" f"Unable to parse discovery from device: {config.host}: {ex}",
host=config.host,
) from ex ) from ex
type_ = discovery_result.device_type type_ = discovery_result.device_type
encrypt_schm = discovery_result.mgt_encrypt_schm
try: try:
if not ( if not (encrypt_type := encrypt_schm.encrypt_type) and (
encrypt_type := discovery_result.mgt_encrypt_schm.encrypt_type encrypt_info := discovery_result.encrypt_info
) and (encrypt_info := discovery_result.encrypt_info): ):
encrypt_type = encrypt_info.sym_schm encrypt_type = encrypt_info.sym_schm
if not encrypt_type: if not encrypt_type:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} " f"Unsupported device {config.host} of type {type_} "
+ "with no encryption type", + "with no encryption type",
discovery_result=discovery_result.get_dict(), discovery_result=discovery_result.get_dict(),
host=config.host,
) )
config.connection_type = DeviceConnectionParameters.from_values( config.connection_type = DeviceConnectionParameters.from_values(
type_, type_,
@ -704,12 +715,18 @@ class Discover:
f"Unsupported device {config.host} of type {type_} " f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}", + f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
discovery_result=discovery_result.get_dict(), discovery_result=discovery_result.get_dict(),
host=config.host,
) from ex ) from ex
if (device_class := get_device_class_from_family(type_)) is None: if (
device_class := get_device_class_from_family(
type_, https=encrypt_schm.is_support_https
)
) is None:
_LOGGER.warning("Got unsupported device type: %s", type_) _LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_}: {info}", f"Unsupported device {config.host} of type {type_}: {info}",
discovery_result=discovery_result.get_dict(), discovery_result=discovery_result.get_dict(),
host=config.host,
) )
if (protocol := get_protocol(config)) is None: if (protocol := get_protocol(config)) is None:
_LOGGER.warning( _LOGGER.warning(
@ -719,6 +736,7 @@ class Discover:
f"Unsupported encryption scheme {config.host} of " f"Unsupported encryption scheme {config.host} of "
+ f"type {config.connection_type.to_dict()}: {info}", + f"type {config.connection_type.to_dict()}: {info}",
discovery_result=discovery_result.get_dict(), discovery_result=discovery_result.get_dict(),
host=config.host,
) )
if debug_enabled: if debug_enabled:

View File

@ -31,6 +31,7 @@ class UnsupportedDeviceError(KasaException):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
self.discovery_result = kwargs.get("discovery_result") self.discovery_result = kwargs.get("discovery_result")
self.host = kwargs.get("host")
super().__init__(*args) super().__init__(*args)

View File

@ -15,8 +15,10 @@ from .fixtureinfo import FixtureInfo, filter_fixtures, idgenerator
DISCOVERY_MOCK_IP = "127.0.0.123" DISCOVERY_MOCK_IP = "127.0.0.123"
def _make_unsupported(device_family, encrypt_type): def _make_unsupported(device_family, encrypt_type, *, omit_keys=None):
return { if omit_keys is None:
omit_keys = {"encrypt_info": None}
result = {
"result": { "result": {
"device_id": "xx", "device_id": "xx",
"owner": "xx", "owner": "xx",
@ -33,9 +35,17 @@ def _make_unsupported(device_family, encrypt_type):
"http_port": 80, "http_port": 80,
"lv": 2, "lv": 2,
}, },
"encrypt_info": {"data": "", "key": "", "sym_schm": encrypt_type},
}, },
"error_code": 0, "error_code": 0,
} }
for key, val in omit_keys.items():
if val is None:
result["result"].pop(key)
else:
result["result"][key].pop(val)
return result
UNSUPPORTED_DEVICES = { UNSUPPORTED_DEVICES = {
@ -43,6 +53,16 @@ UNSUPPORTED_DEVICES = {
"wrong_encryption_iot": _make_unsupported("IOT.SMARTPLUGSWITCH", "AES"), "wrong_encryption_iot": _make_unsupported("IOT.SMARTPLUGSWITCH", "AES"),
"wrong_encryption_smart": _make_unsupported("SMART.TAPOBULB", "IOT"), "wrong_encryption_smart": _make_unsupported("SMART.TAPOBULB", "IOT"),
"unknown_encryption": _make_unsupported("IOT.SMARTPLUGSWITCH", "FOO"), "unknown_encryption": _make_unsupported("IOT.SMARTPLUGSWITCH", "FOO"),
"missing_encrypt_type": _make_unsupported(
"SMART.TAPOBULB",
"FOO",
omit_keys={"mgt_encrypt_schm": "encrypt_type", "encrypt_info": None},
),
"unable_to_parse": _make_unsupported(
"SMART.TAPOBULB",
"FOO",
omit_keys={"mgt_encrypt_schm": None},
),
} }
@ -90,6 +110,7 @@ def create_discovery_mock(ip: str, fixture_data: dict):
query_data: dict query_data: dict
device_type: str device_type: str
encrypt_type: str encrypt_type: str
https: bool
login_version: int | None = None login_version: int | None = None
port_override: int | None = None port_override: int | None = None
@ -110,6 +131,7 @@ def create_discovery_mock(ip: str, fixture_data: dict):
"encrypt_type" "encrypt_type"
] ]
login_version = fixture_data["discovery_result"]["mgt_encrypt_schm"].get("lv") login_version = fixture_data["discovery_result"]["mgt_encrypt_schm"].get("lv")
https = fixture_data["discovery_result"]["mgt_encrypt_schm"]["is_support_https"]
dm = _DiscoveryMock( dm = _DiscoveryMock(
ip, ip,
80, 80,
@ -118,6 +140,7 @@ def create_discovery_mock(ip: str, fixture_data: dict):
fixture_data, fixture_data,
device_type, device_type,
encrypt_type, encrypt_type,
https,
login_version, login_version,
) )
else: else:
@ -134,6 +157,7 @@ def create_discovery_mock(ip: str, fixture_data: dict):
fixture_data, fixture_data,
device_type, device_type,
encrypt_type, encrypt_type,
False,
login_version, login_version,
) )

View File

@ -764,7 +764,6 @@ async def test_discover_unsupported(unsupported_device_info, runner):
) )
assert res.exit_code == 0 assert res.exit_code == 0
assert "== Unsupported device ==" in res.output assert "== Unsupported device ==" in res.output
assert "== Discovery Result ==" in res.output
async def test_host_unsupported(unsupported_device_info, runner): async def test_host_unsupported(unsupported_device_info, runner):

View File

@ -189,5 +189,5 @@ async def test_device_class_from_unknown_family(caplog):
"""Verify that unknown SMART devices yield a warning and fallback to SmartDevice.""" """Verify that unknown SMART devices yield a warning and fallback to SmartDevice."""
dummy_name = "SMART.foo" dummy_name = "SMART.foo"
with caplog.at_level(logging.WARNING): with caplog.at_level(logging.WARNING):
assert get_device_class_from_family(dummy_name) == SmartDevice assert get_device_class_from_family(dummy_name, https=False) == SmartDevice
assert f"Unknown SMART device with {dummy_name}" in caplog.text assert f"Unknown SMART device with {dummy_name}" in caplog.text

View File

@ -658,12 +658,14 @@ async def test_discovery_decryption():
async def test_discover_try_connect_all(discovery_mock, mocker): async def test_discover_try_connect_all(discovery_mock, mocker):
"""Test that device update is called on main.""" """Test that device update is called on main."""
if "result" in discovery_mock.discovery_data: if "result" in discovery_mock.discovery_data:
dev_class = get_device_class_from_family(discovery_mock.device_type) dev_class = get_device_class_from_family(
discovery_mock.device_type, https=discovery_mock.https
)
cparams = DeviceConnectionParameters.from_values( cparams = DeviceConnectionParameters.from_values(
discovery_mock.device_type, discovery_mock.device_type,
discovery_mock.encrypt_type, discovery_mock.encrypt_type,
discovery_mock.login_version, discovery_mock.login_version,
False, discovery_mock.https,
) )
protocol = get_protocol( protocol = get_protocol(
DeviceConfig(discovery_mock.ip, connection_type=cparams) DeviceConfig(discovery_mock.ip, connection_type=cparams)