Update to use http port from discovery if present

This commit is contained in:
Steven B 2025-01-21 13:49:23 +00:00
parent e163f5f61e
commit 305e732238
No known key found for this signature in database
GPG Key ID: 6D5B46B3679F2A43
12 changed files with 62 additions and 23 deletions

View File

@ -300,7 +300,9 @@ async def cli(
connection_type = DeviceConnectionParameters.from_values( connection_type = DeviceConnectionParameters.from_values(
dr.device_type, dr.device_type,
dr.mgt_encrypt_schm.encrypt_type, dr.mgt_encrypt_schm.encrypt_type,
dr.mgt_encrypt_schm.lv, login_version=dr.mgt_encrypt_schm.lv,
https=dr.mgt_encrypt_schm.is_support_https,
http_port=dr.mgt_encrypt_schm.http_port,
) )
dc = DeviceConfig( dc = DeviceConfig(
host=host, host=host,

View File

@ -205,10 +205,10 @@ def get_protocol(config: DeviceConfig, *, strict: bool = False) -> BaseProtocol
return IotProtocol(transport=LinkieTransportV2(config=config)) return IotProtocol(transport=LinkieTransportV2(config=config))
# Older FW used a different transport # Older FW used a different transport
if ctype.device_family is DeviceFamily.SmartTapoRobovac: if (
if strict and ctype.encryption_type is not DeviceEncryptionType.Aes: ctype.device_family is DeviceFamily.SmartTapoRobovac
return None and ctype.encryption_type is DeviceEncryptionType.Aes
if ctype.encryption_type is DeviceEncryptionType.Aes: ):
return SmartProtocol(transport=SslTransport(config=config)) return SmartProtocol(transport=SslTransport(config=config))
protocol_transport_key = ( protocol_transport_key = (

View File

@ -20,7 +20,7 @@ None
{'host': '127.0.0.3', 'timeout': 5, 'credentials': {'username': 'user@example.com', \ {'host': '127.0.0.3', 'timeout': 5, 'credentials': {'username': 'user@example.com', \
'password': 'great_password'}, 'connection_type'\ 'password': 'great_password'}, 'connection_type'\
: {'device_family': 'SMART.TAPOBULB', 'encryption_type': 'KLAP', 'login_version': 2, \ : {'device_family': 'SMART.TAPOBULB', 'encryption_type': 'KLAP', 'login_version': 2, \
'https': False}} 'https': False, 'http_port': 80}}
>>> later_device = await Device.connect(config=Device.Config.from_dict(config_dict)) >>> later_device = await Device.connect(config=Device.Config.from_dict(config_dict))
>>> print(later_device.alias) # Alias is available as connect() calls update() >>> print(later_device.alias) # Alias is available as connect() calls update()
@ -98,13 +98,16 @@ class DeviceConnectionParameters(_DeviceConfigBaseMixin):
encryption_type: DeviceEncryptionType encryption_type: DeviceEncryptionType
login_version: int | None = None login_version: int | None = None
https: bool = False https: bool = False
http_port: int | None = None
@staticmethod @staticmethod
def from_values( def from_values(
device_family: str, device_family: str,
encryption_type: str, encryption_type: str,
*,
login_version: int | None = None, login_version: int | None = None,
https: bool | None = None, https: bool | None = None,
http_port: int | None = None,
) -> DeviceConnectionParameters: ) -> DeviceConnectionParameters:
"""Return connection parameters from string values.""" """Return connection parameters from string values."""
try: try:
@ -115,6 +118,7 @@ class DeviceConnectionParameters(_DeviceConfigBaseMixin):
DeviceEncryptionType(encryption_type), DeviceEncryptionType(encryption_type),
login_version, login_version,
https, https,
http_port=http_port,
) )
except (ValueError, TypeError) as ex: except (ValueError, TypeError) as ex:
raise KasaException( raise KasaException(

View File

@ -637,10 +637,10 @@ class Discover:
Device.Family.IotIpCamera, Device.Family.IotIpCamera,
} }
candidates: dict[ candidates: dict[
tuple[type[BaseProtocol], type[BaseTransport], type[Device]], tuple[type[BaseProtocol], type[BaseTransport], type[Device], bool],
tuple[BaseProtocol, DeviceConfig], tuple[BaseProtocol, DeviceConfig],
] = { ] = {
(type(protocol), type(protocol._transport), device_class): ( (type(protocol), type(protocol._transport), device_class, https): (
protocol, protocol,
config, config,
) )
@ -870,8 +870,9 @@ class Discover:
config.connection_type = DeviceConnectionParameters.from_values( config.connection_type = DeviceConnectionParameters.from_values(
type_, type_,
encrypt_type, encrypt_type,
login_version, login_version=login_version,
encrypt_schm.is_support_https, https=encrypt_schm.is_support_https,
http_port=encrypt_schm.http_port,
) )
except KasaException as ex: except KasaException as ex:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(

View File

@ -120,6 +120,8 @@ class AesTransport(BaseTransport):
@property @property
def default_port(self) -> int: def default_port(self) -> int:
"""Default port for the transport.""" """Default port for the transport."""
if port := self._config.connection_type.http_port:
return port
return self.DEFAULT_PORT return self.DEFAULT_PORT
@property @property

View File

@ -93,6 +93,8 @@ class KlapTransport(BaseTransport):
""" """
DEFAULT_PORT: int = 80 DEFAULT_PORT: int = 80
DEFAULT_HTTPS_PORT: int = 4433
SESSION_COOKIE_NAME = "TP_SESSIONID" SESSION_COOKIE_NAME = "TP_SESSIONID"
TIMEOUT_COOKIE_NAME = "TIMEOUT" TIMEOUT_COOKIE_NAME = "TIMEOUT"
# Copy & paste from sslaestransport # Copy & paste from sslaestransport
@ -144,6 +146,13 @@ class KlapTransport(BaseTransport):
@property @property
def default_port(self) -> int: def default_port(self) -> int:
"""Default port for the transport.""" """Default port for the transport."""
config = self._config
if port := config.connection_type.http_port:
return port
if config.connection_type.https:
return self.DEFAULT_HTTPS_PORT
return self.DEFAULT_PORT return self.DEFAULT_PORT
@property @property

View File

@ -55,6 +55,8 @@ class LinkieTransportV2(BaseTransport):
@property @property
def default_port(self) -> int: def default_port(self) -> int:
"""Default port for the transport.""" """Default port for the transport."""
if port := self._config.connection_type.http_port:
return port
return self.DEFAULT_PORT return self.DEFAULT_PORT
@property @property

View File

@ -133,6 +133,8 @@ class SslAesTransport(BaseTransport):
@property @property
def default_port(self) -> int: def default_port(self) -> int:
"""Default port for the transport.""" """Default port for the transport."""
if port := self._config.connection_type.http_port:
return port
return self.DEFAULT_PORT return self.DEFAULT_PORT
@staticmethod @staticmethod

View File

@ -94,6 +94,8 @@ class SslTransport(BaseTransport):
@property @property
def default_port(self) -> int: def default_port(self) -> int:
"""Default port for the transport.""" """Default port for the transport."""
if port := self._config.connection_type.http_port:
return port
return self.DEFAULT_PORT return self.DEFAULT_PORT
@property @property

View File

@ -159,6 +159,7 @@ def create_discovery_mock(ip: str, fixture_data: dict):
https: bool https: bool
login_version: int | None = None login_version: int | None = None
port_override: int | None = None port_override: int | None = None
http_port: int | None = None
@property @property
def model(self) -> str: def model(self) -> str:
@ -194,9 +195,15 @@ def create_discovery_mock(ip: str, fixture_data: dict):
): ):
login_version = max([int(i) for i in et]) login_version = max([int(i) for i in et])
https = discovery_result["mgt_encrypt_schm"]["is_support_https"] https = discovery_result["mgt_encrypt_schm"]["is_support_https"]
http_port = discovery_result["mgt_encrypt_schm"].get("http_port")
if not http_port: # noqa: SIM108
# Not all discovery responses set the http port, i.e. smartcam.
default_port = 443 if https else 80
else:
default_port = http_port
dm = _DiscoveryMock( dm = _DiscoveryMock(
ip, ip,
80, default_port,
20002, 20002,
discovery_data, discovery_data,
fixture_data, fixture_data,
@ -204,6 +211,7 @@ def create_discovery_mock(ip: str, fixture_data: dict):
encrypt_type, encrypt_type,
https, https,
login_version, login_version,
http_port=http_port,
) )
else: else:
sys_info = fixture_data["system"]["get_sysinfo"] sys_info = fixture_data["system"]["get_sysinfo"]

View File

@ -63,8 +63,9 @@ def _get_connection_type_device_class(discovery_info):
connection_type = DeviceConnectionParameters.from_values( connection_type = DeviceConnectionParameters.from_values(
dr.device_type, dr.device_type,
dr.mgt_encrypt_schm.encrypt_type, dr.mgt_encrypt_schm.encrypt_type,
dr.mgt_encrypt_schm.lv, login_version=dr.mgt_encrypt_schm.lv,
dr.mgt_encrypt_schm.is_support_https, https=dr.mgt_encrypt_schm.is_support_https,
http_port=dr.mgt_encrypt_schm.http_port,
) )
else: else:
connection_type = DeviceConnectionParameters.from_values( connection_type = DeviceConnectionParameters.from_values(

View File

@ -157,14 +157,15 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
) )
# Make sure discovery does not call update() # Make sure discovery does not call update()
assert update_mock.call_count == 0 assert update_mock.call_count == 0
if discovery_mock.default_port == 80: if discovery_mock.default_port != 9999:
assert x.alias is None assert x.alias is None
ct = DeviceConnectionParameters.from_values( ct = DeviceConnectionParameters.from_values(
discovery_mock.device_type, discovery_mock.device_type,
discovery_mock.encrypt_type, discovery_mock.encrypt_type,
discovery_mock.login_version, login_version=discovery_mock.login_version,
discovery_mock.https, https=discovery_mock.https,
http_port=discovery_mock.http_port,
) )
config = DeviceConfig( config = DeviceConfig(
host=host, host=host,
@ -425,9 +426,9 @@ async def test_discover_single_http_client(discovery_mock, mocker):
x: Device = await Discover.discover_single(host) x: Device = await Discover.discover_single(host)
assert x.config.uses_http == (discovery_mock.default_port == 80) assert x.config.uses_http == (discovery_mock.default_port != 9999)
if discovery_mock.default_port == 80: if discovery_mock.default_port != 9999:
assert x.protocol._transport._http_client.client != http_client assert x.protocol._transport._http_client.client != http_client
x.config.http_client = http_client x.config.http_client = http_client
assert x.protocol._transport._http_client.client == http_client assert x.protocol._transport._http_client.client == http_client
@ -442,9 +443,9 @@ async def test_discover_http_client(discovery_mock, mocker):
devices = await Discover.discover(discovery_timeout=0) devices = await Discover.discover(discovery_timeout=0)
x: Device = devices[host] x: Device = devices[host]
assert x.config.uses_http == (discovery_mock.default_port == 80) assert x.config.uses_http == (discovery_mock.default_port != 9999)
if discovery_mock.default_port == 80: if discovery_mock.default_port != 9999:
assert x.protocol._transport._http_client.client != http_client assert x.protocol._transport._http_client.client != http_client
x.config.http_client = http_client x.config.http_client = http_client
assert x.protocol._transport._http_client.client == http_client assert x.protocol._transport._http_client.client == http_client
@ -674,8 +675,9 @@ async def test_discover_try_connect_all(discovery_mock, mocker):
cparams = DeviceConnectionParameters.from_values( cparams = DeviceConnectionParameters.from_values(
discovery_mock.device_type, discovery_mock.device_type,
discovery_mock.encrypt_type, discovery_mock.encrypt_type,
discovery_mock.login_version, login_version=discovery_mock.login_version,
discovery_mock.https, https=discovery_mock.https,
http_port=discovery_mock.http_port,
) )
protocol = get_protocol( protocol = get_protocol(
DeviceConfig(discovery_mock.ip, connection_type=cparams) DeviceConfig(discovery_mock.ip, connection_type=cparams)
@ -687,10 +689,13 @@ async def test_discover_try_connect_all(discovery_mock, mocker):
protocol_class = IotProtocol protocol_class = IotProtocol
transport_class = XorTransport transport_class = XorTransport
default_port = discovery_mock.default_port
async def _query(self, *args, **kwargs): async def _query(self, *args, **kwargs):
if ( if (
self.__class__ is protocol_class self.__class__ is protocol_class
and self._transport.__class__ is transport_class and self._transport.__class__ is transport_class
and self._transport._port == default_port
): ):
return discovery_mock.query_data return discovery_mock.query_data
raise KasaException("Unable to execute query") raise KasaException("Unable to execute query")
@ -699,6 +704,7 @@ async def test_discover_try_connect_all(discovery_mock, mocker):
if ( if (
self.protocol.__class__ is protocol_class self.protocol.__class__ is protocol_class
and self.protocol._transport.__class__ is transport_class and self.protocol._transport.__class__ is transport_class
and self.protocol._transport._port == default_port
): ):
return return