diff --git a/devtools/dump_devinfo.py b/devtools/dump_devinfo.py index cee7a7bf..a0fff0e5 100644 --- a/devtools/dump_devinfo.py +++ b/devtools/dump_devinfo.py @@ -300,7 +300,9 @@ async def cli( connection_type = DeviceConnectionParameters.from_values( dr.device_type, dr.mgt_encrypt_schm.encrypt_type, - dr.mgt_encrypt_schm.lv, + login_version=dr.mgt_encrypt_schm.lv, + https=dr.mgt_encrypt_schm.is_support_https, + http_port=dr.mgt_encrypt_schm.http_port, ) dc = DeviceConfig( host=host, diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 25792f2c..83661038 100644 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -205,11 +205,11 @@ def get_protocol(config: DeviceConfig, *, strict: bool = False) -> BaseProtocol return IotProtocol(transport=LinkieTransportV2(config=config)) # Older FW used a different transport - if ctype.device_family is DeviceFamily.SmartTapoRobovac: - if strict and ctype.encryption_type is not DeviceEncryptionType.Aes: - return None - if ctype.encryption_type is DeviceEncryptionType.Aes: - return SmartProtocol(transport=SslTransport(config=config)) + if ( + ctype.device_family is DeviceFamily.SmartTapoRobovac + and ctype.encryption_type is DeviceEncryptionType.Aes + ): + return SmartProtocol(transport=SslTransport(config=config)) protocol_transport_key = ( protocol_name diff --git a/kasa/deviceconfig.py b/kasa/deviceconfig.py index c5d5b1d5..b6325570 100644 --- a/kasa/deviceconfig.py +++ b/kasa/deviceconfig.py @@ -20,7 +20,7 @@ None {'host': '127.0.0.3', 'timeout': 5, 'credentials': {'username': 'user@example.com', \ 'password': 'great_password'}, 'connection_type'\ : {'device_family': 'SMART.TAPOBULB', 'encryption_type': 'KLAP', 'login_version': 2, \ -'https': False}} +'https': False, 'http_port': 80}} >>> later_device = await Device.connect(config=Device.Config.from_dict(config_dict)) >>> print(later_device.alias) # Alias is available as connect() calls update() @@ -98,13 +98,16 @@ class DeviceConnectionParameters(_DeviceConfigBaseMixin): encryption_type: DeviceEncryptionType login_version: int | None = None https: bool = False + http_port: int | None = None @staticmethod def from_values( device_family: str, encryption_type: str, + *, login_version: int | None = None, https: bool | None = None, + http_port: int | None = None, ) -> DeviceConnectionParameters: """Return connection parameters from string values.""" try: @@ -115,6 +118,7 @@ class DeviceConnectionParameters(_DeviceConfigBaseMixin): DeviceEncryptionType(encryption_type), login_version, https, + http_port=http_port, ) except (ValueError, TypeError) as ex: raise KasaException( diff --git a/kasa/discover.py b/kasa/discover.py index abcd7d5f..bfdb5fa8 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -637,10 +637,10 @@ class Discover: Device.Family.IotIpCamera, } candidates: dict[ - tuple[type[BaseProtocol], type[BaseTransport], type[Device]], + tuple[type[BaseProtocol], type[BaseTransport], type[Device], bool], tuple[BaseProtocol, DeviceConfig], ] = { - (type(protocol), type(protocol._transport), device_class): ( + (type(protocol), type(protocol._transport), device_class, https): ( protocol, config, ) @@ -870,8 +870,9 @@ class Discover: config.connection_type = DeviceConnectionParameters.from_values( type_, encrypt_type, - login_version, - encrypt_schm.is_support_https, + login_version=login_version, + https=encrypt_schm.is_support_https, + http_port=encrypt_schm.http_port, ) except KasaException as ex: raise UnsupportedDeviceError( diff --git a/kasa/transports/aestransport.py b/kasa/transports/aestransport.py index 3466ca98..45b963fe 100644 --- a/kasa/transports/aestransport.py +++ b/kasa/transports/aestransport.py @@ -120,6 +120,8 @@ class AesTransport(BaseTransport): @property def default_port(self) -> int: """Default port for the transport.""" + if port := self._config.connection_type.http_port: + return port return self.DEFAULT_PORT @property diff --git a/kasa/transports/klaptransport.py b/kasa/transports/klaptransport.py index 0be7cfc1..8253e0ae 100644 --- a/kasa/transports/klaptransport.py +++ b/kasa/transports/klaptransport.py @@ -93,6 +93,8 @@ class KlapTransport(BaseTransport): """ DEFAULT_PORT: int = 80 + DEFAULT_HTTPS_PORT: int = 4433 + SESSION_COOKIE_NAME = "TP_SESSIONID" TIMEOUT_COOKIE_NAME = "TIMEOUT" # Copy & paste from sslaestransport @@ -144,6 +146,13 @@ class KlapTransport(BaseTransport): @property def default_port(self) -> int: """Default port for the transport.""" + config = self._config + if port := config.connection_type.http_port: + return port + + if config.connection_type.https: + return self.DEFAULT_HTTPS_PORT + return self.DEFAULT_PORT @property diff --git a/kasa/transports/linkietransport.py b/kasa/transports/linkietransport.py index 779d182e..b817373c 100644 --- a/kasa/transports/linkietransport.py +++ b/kasa/transports/linkietransport.py @@ -55,6 +55,8 @@ class LinkieTransportV2(BaseTransport): @property def default_port(self) -> int: """Default port for the transport.""" + if port := self._config.connection_type.http_port: + return port return self.DEFAULT_PORT @property diff --git a/kasa/transports/sslaestransport.py b/kasa/transports/sslaestransport.py index eb67eda8..eeb29809 100644 --- a/kasa/transports/sslaestransport.py +++ b/kasa/transports/sslaestransport.py @@ -133,6 +133,8 @@ class SslAesTransport(BaseTransport): @property def default_port(self) -> int: """Default port for the transport.""" + if port := self._config.connection_type.http_port: + return port return self.DEFAULT_PORT @staticmethod diff --git a/kasa/transports/ssltransport.py b/kasa/transports/ssltransport.py index 4471dccb..e4fef9a3 100644 --- a/kasa/transports/ssltransport.py +++ b/kasa/transports/ssltransport.py @@ -94,6 +94,8 @@ class SslTransport(BaseTransport): @property def default_port(self) -> int: """Default port for the transport.""" + if port := self._config.connection_type.http_port: + return port return self.DEFAULT_PORT @property diff --git a/tests/discovery_fixtures.py b/tests/discovery_fixtures.py index eb843f1a..2db79e91 100644 --- a/tests/discovery_fixtures.py +++ b/tests/discovery_fixtures.py @@ -159,6 +159,7 @@ def create_discovery_mock(ip: str, fixture_data: dict): https: bool login_version: int | None = None port_override: int | None = None + http_port: int | None = None @property def model(self) -> str: @@ -194,9 +195,15 @@ def create_discovery_mock(ip: str, fixture_data: dict): ): login_version = max([int(i) for i in et]) https = discovery_result["mgt_encrypt_schm"]["is_support_https"] + http_port = discovery_result["mgt_encrypt_schm"].get("http_port") + if not http_port: # noqa: SIM108 + # Not all discovery responses set the http port, i.e. smartcam. + default_port = 443 if https else 80 + else: + default_port = http_port dm = _DiscoveryMock( ip, - 80, + default_port, 20002, discovery_data, fixture_data, @@ -204,6 +211,7 @@ def create_discovery_mock(ip: str, fixture_data: dict): encrypt_type, https, login_version, + http_port=http_port, ) else: sys_info = fixture_data["system"]["get_sysinfo"] diff --git a/tests/test_device_factory.py b/tests/test_device_factory.py index c21c8fe9..d6bdaedf 100644 --- a/tests/test_device_factory.py +++ b/tests/test_device_factory.py @@ -63,8 +63,9 @@ def _get_connection_type_device_class(discovery_info): connection_type = DeviceConnectionParameters.from_values( dr.device_type, dr.mgt_encrypt_schm.encrypt_type, - dr.mgt_encrypt_schm.lv, - dr.mgt_encrypt_schm.is_support_https, + login_version=dr.mgt_encrypt_schm.lv, + https=dr.mgt_encrypt_schm.is_support_https, + http_port=dr.mgt_encrypt_schm.http_port, ) else: connection_type = DeviceConnectionParameters.from_values( diff --git a/tests/test_discovery.py b/tests/test_discovery.py index fbbed879..96c9e9c6 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -157,14 +157,15 @@ async def test_discover_single(discovery_mock, custom_port, mocker): ) # Make sure discovery does not call update() assert update_mock.call_count == 0 - if discovery_mock.default_port == 80: + if discovery_mock.default_port != 9999: assert x.alias is None ct = DeviceConnectionParameters.from_values( discovery_mock.device_type, discovery_mock.encrypt_type, - discovery_mock.login_version, - discovery_mock.https, + login_version=discovery_mock.login_version, + https=discovery_mock.https, + http_port=discovery_mock.http_port, ) config = DeviceConfig( host=host, @@ -425,9 +426,9 @@ async def test_discover_single_http_client(discovery_mock, mocker): x: Device = await Discover.discover_single(host) - assert x.config.uses_http == (discovery_mock.default_port == 80) + assert x.config.uses_http == (discovery_mock.default_port != 9999) - if discovery_mock.default_port == 80: + if discovery_mock.default_port != 9999: assert x.protocol._transport._http_client.client != http_client x.config.http_client = http_client assert x.protocol._transport._http_client.client == http_client @@ -442,9 +443,9 @@ async def test_discover_http_client(discovery_mock, mocker): devices = await Discover.discover(discovery_timeout=0) x: Device = devices[host] - assert x.config.uses_http == (discovery_mock.default_port == 80) + assert x.config.uses_http == (discovery_mock.default_port != 9999) - if discovery_mock.default_port == 80: + if discovery_mock.default_port != 9999: assert x.protocol._transport._http_client.client != http_client x.config.http_client = http_client assert x.protocol._transport._http_client.client == http_client @@ -674,8 +675,9 @@ async def test_discover_try_connect_all(discovery_mock, mocker): cparams = DeviceConnectionParameters.from_values( discovery_mock.device_type, discovery_mock.encrypt_type, - discovery_mock.login_version, - discovery_mock.https, + login_version=discovery_mock.login_version, + https=discovery_mock.https, + http_port=discovery_mock.http_port, ) protocol = get_protocol( DeviceConfig(discovery_mock.ip, connection_type=cparams) @@ -687,10 +689,13 @@ async def test_discover_try_connect_all(discovery_mock, mocker): protocol_class = IotProtocol transport_class = XorTransport + default_port = discovery_mock.default_port + async def _query(self, *args, **kwargs): if ( self.__class__ is protocol_class and self._transport.__class__ is transport_class + and self._transport._port == default_port ): return discovery_mock.query_data raise KasaException("Unable to execute query") @@ -699,6 +704,7 @@ async def test_discover_try_connect_all(discovery_mock, mocker): if ( self.protocol.__class__ is protocol_class and self.protocol._transport.__class__ is transport_class + and self.protocol._transport._port == default_port ): return