From 4a5bc20ee2660ee7e179f96b0c54e060401d57cc Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Tue, 10 Dec 2024 14:05:30 +0000 Subject: [PATCH] Update SslAesTransport for legacy firmware versions --- kasa/__init__.py | 3 +- kasa/device_factory.py | 27 ++--- kasa/discover.py | 8 +- kasa/protocols/__init__.py | 2 + kasa/protocols/smartcamprotocol.py | 2 +- kasa/transports/__init__.py | 2 + kasa/transports/sslaestransport.py | 153 ++++++++++++++++++++++++++--- tests/test_device_factory.py | 85 ++++++++++++++++ 8 files changed, 248 insertions(+), 34 deletions(-) diff --git a/kasa/__init__.py b/kasa/__init__.py index ee52eb3a..b8871f99 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -38,7 +38,7 @@ from kasa.feature import Feature from kasa.interfaces.light import HSV, ColorTempRange, Light, LightState from kasa.interfaces.thermostat import Thermostat, ThermostatState from kasa.module import Module -from kasa.protocols import BaseProtocol, IotProtocol, SmartProtocol +from kasa.protocols import BaseProtocol, IotProtocol, SmartCamProtocol, SmartProtocol from kasa.protocols.iotprotocol import _deprecated_TPLinkSmartHomeProtocol # noqa: F401 from kasa.smartcam.modules.camera import StreamResolution from kasa.transports import BaseTransport @@ -52,6 +52,7 @@ __all__ = [ "BaseTransport", "IotProtocol", "SmartProtocol", + "SmartCamProtocol", "LightState", "TurnOnBehaviors", "TurnOnBehavior", diff --git a/kasa/device_factory.py b/kasa/device_factory.py index a1015570..99218c81 100644 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -8,7 +8,7 @@ from typing import Any from .device import Device from .device_type import DeviceType -from .deviceconfig import DeviceConfig +from .deviceconfig import DeviceConfig, DeviceFamily from .exceptions import KasaException, UnsupportedDeviceError from .iot import ( IotBulb, @@ -180,19 +180,23 @@ def get_protocol( config: DeviceConfig, ) -> BaseProtocol | None: """Return the protocol from the connection name.""" - protocol_name = config.connection_type.device_family.value.split(".")[0] ctype = config.connection_type + protocol_name = ctype.device_family.value.split(".")[0] + + if ctype.device_family is DeviceFamily.SmartIpCamera: + return SmartCamProtocol(transport=SslAesTransport(config=config)) + + if ctype.device_family is DeviceFamily.IotIpCamera: + return IotProtocol(transport=LinkieTransportV2(config=config)) + + if ctype.device_family is DeviceFamily.SmartTapoRobovac: + return SmartProtocol(transport=SslTransport(config=config)) protocol_transport_key = ( protocol_name + "." + ctype.encryption_type.value + (".HTTPS" if ctype.https else "") - + ( - f".{ctype.login_version}" - if ctype.login_version and ctype.login_version > 1 - else "" - ) ) _LOGGER.debug("Finding transport for %s", protocol_transport_key) @@ -201,12 +205,11 @@ def get_protocol( ] = { "IOT.XOR": (IotProtocol, XorTransport), "IOT.KLAP": (IotProtocol, KlapTransport), - "IOT.XOR.HTTPS.2": (IotProtocol, LinkieTransportV2), "SMART.AES": (SmartProtocol, AesTransport), - "SMART.AES.2": (SmartProtocol, AesTransport), - "SMART.KLAP.2": (SmartProtocol, KlapTransportV2), - "SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport), - "SMART.AES.HTTPS": (SmartProtocol, SslTransport), + "SMART.KLAP": (SmartProtocol, KlapTransportV2), + # Still require a lookup for SslAesTransport as H200 has a type of + # SMART.TAPOHUB. + "SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport), } if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)): return None diff --git a/kasa/discover.py b/kasa/discover.py index 9cb0808d..5e8388a0 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -775,12 +775,10 @@ class Discover: ): encrypt_type = encrypt_info.sym_schm - if ( - not (login_version := encrypt_schm.lv) - and (et := discovery_result.encrypt_type) - and et == ["3"] + if not (login_version := encrypt_schm.lv) and ( + et := discovery_result.encrypt_type ): - login_version = 2 + login_version = max([int(i) for i in et]) if not encrypt_type: raise UnsupportedDeviceError( diff --git a/kasa/protocols/__init__.py b/kasa/protocols/__init__.py index 44130d7f..b994d732 100644 --- a/kasa/protocols/__init__.py +++ b/kasa/protocols/__init__.py @@ -2,6 +2,7 @@ from .iotprotocol import IotProtocol from .protocol import BaseProtocol +from .smartcamprotocol import SmartCamProtocol from .smartprotocol import SmartErrorCode, SmartProtocol __all__ = [ @@ -9,4 +10,5 @@ __all__ = [ "IotProtocol", "SmartErrorCode", "SmartProtocol", + "SmartCamProtocol", ] diff --git a/kasa/protocols/smartcamprotocol.py b/kasa/protocols/smartcamprotocol.py index 12caa207..324f8056 100644 --- a/kasa/protocols/smartcamprotocol.py +++ b/kasa/protocols/smartcamprotocol.py @@ -19,7 +19,7 @@ from ..transports.sslaestransport import ( SMART_RETRYABLE_ERRORS, SmartErrorCode, ) -from . import SmartProtocol +from .smartprotocol import SmartProtocol _LOGGER = logging.getLogger(__name__) diff --git a/kasa/transports/__init__.py b/kasa/transports/__init__.py index 602d0cca..192b4156 100644 --- a/kasa/transports/__init__.py +++ b/kasa/transports/__init__.py @@ -4,6 +4,7 @@ from .aestransport import AesEncyptionSession, AesTransport from .basetransport import BaseTransport from .klaptransport import KlapTransport, KlapTransportV2 from .linkietransport import LinkieTransportV2 +from .sslaestransport import SslAesTransport from .ssltransport import SslTransport from .xortransport import XorEncryption, XorTransport @@ -11,6 +12,7 @@ __all__ = [ "AesTransport", "AesEncyptionSession", "SslTransport", + "SslAesTransport", "BaseTransport", "KlapTransport", "KlapTransportV2", diff --git a/kasa/transports/sslaestransport.py b/kasa/transports/sslaestransport.py index 2061d293..677c0447 100644 --- a/kasa/transports/sslaestransport.py +++ b/kasa/transports/sslaestransport.py @@ -48,6 +48,10 @@ def _sha256_hash(payload: bytes) -> str: return hashlib.sha256(payload).hexdigest().upper() # noqa: S324 +def _sha1_hash(payload: bytes) -> str: + return hashlib.sha1(payload).hexdigest().upper() # noqa: S324 + + class TransportState(Enum): """Enum for AES state.""" @@ -107,11 +111,10 @@ class SslAesTransport(BaseTransport): self._app_url = URL(f"https://{self._host_port}") self._token_url: URL | None = None self._ssl_context: ssl.SSLContext | None = None - ref = str(self._token_url) if self._token_url else str(self._app_url) self._headers = { **self.COMMON_HEADERS, - "Host": self._host_port, - "Referer": ref, + "Host": self._host, + "Referer": f"https://{self._host}", } self._seq: int | None = None self._pwd_hash: str | None = None @@ -125,6 +128,7 @@ class SslAesTransport(BaseTransport): self._password = ch["pwd"] self._username = ch["un"] self._local_nonce: str | None = None + self._send_secure = True _LOGGER.debug("Created AES transport for %s", self._host) @@ -194,6 +198,10 @@ class SslAesTransport(BaseTransport): else: url = self._app_url + _LOGGER.debug( + "Sending secure passthrough from %s", + self._host, + ) encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore passthrough_request = { "method": "securePassthrough", @@ -254,6 +262,37 @@ class SslAesTransport(BaseTransport): ) from ex return ret_val # type: ignore[return-value] + async def send_unencrypted(self, request: str) -> dict[str, Any]: + """Send encrypted message as passthrough.""" + if self._state is TransportState.ESTABLISHED and self._token_url: + url = self._token_url + else: + url = self._app_url + + _LOGGER.debug( + "Sending unencrypted from %s", + self._host, + ) + + status_code, resp_dict = await self._http_client.post( + url, + json=request, + headers=self._headers, + ssl=await self._get_ssl_context(), + ) + + if status_code != 200: + raise KasaException( + f"{self._host} responded with an unexpected " + + f"status code {status_code}" + ) + + self._handle_response_error_code(resp_dict, "Error sending message") + + if TYPE_CHECKING: + resp_dict = cast(dict[str, Any], resp_dict) + return resp_dict + @staticmethod def generate_confirm_hash( local_nonce: str, server_nonce: str, pwd_hash: str @@ -302,8 +341,52 @@ class SslAesTransport(BaseTransport): async def perform_handshake(self) -> None: """Perform the handshake.""" - local_nonce, server_nonce, pwd_hash = await self.perform_handshake1() - await self.perform_handshake2(local_nonce, server_nonce, pwd_hash) + result = await self.perform_handshake1() + if result: + local_nonce, server_nonce, pwd_hash = result + await self.perform_handshake2(local_nonce, server_nonce, pwd_hash) + + async def try_perform_login(self) -> bool: + """Perform the md5 login.""" + _LOGGER.debug("Performing insecure login ...") + + pwd_hash = _md5_hash(self._pwd_to_hash().encode()) + username = self._username + body = { + "method": "login", + "params": { + "hashed": True, + "password": pwd_hash, + "username": username, + }, + } + + http_client = self._http_client + status_code, resp_dict = await http_client.post( + self._app_url, + json=body, + headers=self._headers, + ssl=await self._get_ssl_context(), + ) + if status_code != 200: + raise KasaException( + f"{self._host} responded with an unexpected " + + f"status code {status_code} to handshake2" + ) + resp_dict = cast(dict, resp_dict) + if resp_dict.get("error_code") == 0 and ( + stok := resp_dict.get("result", {}).get("stok") + ): + _LOGGER.debug( + "Succesfully logged in to %s with less secure passthrough", self._host + ) + self._send_secure = False + self._token_url = URL(f"{str(self._app_url)}/stok={stok}/ds") + self._pwd_hash = pwd_hash + return True + + _LOGGER.debug("Unable to log in to %s with less secure login", self._host) + return False async def perform_handshake2( self, local_nonce: str, server_nonce: str, pwd_hash: str @@ -355,13 +438,42 @@ class SslAesTransport(BaseTransport): self._state = TransportState.ESTABLISHED _LOGGER.debug("Handshake2 complete ...") - async def perform_handshake1(self) -> tuple[str, str, str]: + def _pwd_to_hash(self) -> str: + """Return the password to hash.""" + if self._credentials and self._credentials != Credentials(): + return self._credentials.password + + if self._username and self._password: + return self._password + + return self._default_credentials.password + + async def perform_handshake1(self) -> tuple[str, str, str] | None: """Perform the handshake1.""" resp_dict = None if self._username: local_nonce = secrets.token_bytes(8).hex().upper() resp_dict = await self.try_send_handshake1(self._username, local_nonce) + if ( + resp_dict + and (error_code := self._get_response_error(resp_dict)) + is SmartErrorCode.SESSION_EXPIRED + and ( + encrypt_type := resp_dict.get("result", {}) + .get("data", {}) + .get("encrypt_type") + ) + and (encrypt_type != ["3"]) + ): + _LOGGER.debug( + "Received encrypt_type %s for %s, trying less secure login", + encrypt_type, + self._host, + ) + if await self.try_perform_login(): + return None + # Try the default username. If it fails raise the original error_code if ( not resp_dict @@ -369,6 +481,7 @@ class SslAesTransport(BaseTransport): is not SmartErrorCode.INVALID_NONCE or "nonce" not in resp_dict["result"].get("data", {}) ): + _LOGGER.debug("Trying default credentials to %s", self._host) local_nonce = secrets.token_bytes(8).hex().upper() default_resp_dict = await self.try_send_handshake1( self._default_credentials.username, local_nonce @@ -378,7 +491,7 @@ class SslAesTransport(BaseTransport): ) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[ "result" ].get("data", {}): - _LOGGER.debug("Connected to {self._host} with default username") + _LOGGER.debug("Connected to %s with default username", self._host) self._username = self._default_credentials.username error_code = default_error_code resp_dict = default_resp_dict @@ -397,12 +510,8 @@ class SslAesTransport(BaseTransport): server_nonce = resp_dict["result"]["data"]["nonce"] device_confirm = resp_dict["result"]["data"]["device_confirm"] - if self._credentials and self._credentials != Credentials(): - pwd_hash = _sha256_hash(self._credentials.password.encode()) - elif self._username and self._password: - pwd_hash = _sha256_hash(self._password.encode()) - else: - pwd_hash = _sha256_hash(self._default_credentials.password.encode()) + + pwd_hash = _sha256_hash(self._pwd_to_hash().encode()) expected_confirm_sha256 = self.generate_confirm_hash( local_nonce, server_nonce, pwd_hash @@ -414,7 +523,9 @@ class SslAesTransport(BaseTransport): if TYPE_CHECKING: assert self._credentials assert self._credentials.password - pwd_hash = _md5_hash(self._credentials.password.encode()) + + pwd_hash = _md5_hash(self._pwd_to_hash().encode()) + expected_confirm_md5 = self.generate_confirm_hash( local_nonce, server_nonce, pwd_hash ) @@ -422,8 +533,17 @@ class SslAesTransport(BaseTransport): _LOGGER.debug("Credentials match") return local_nonce, server_nonce, pwd_hash + for val in {"admin", "tpadmin", "slprealtek"}: + for func in {_sha256_hash, _md5_hash, _sha1_hash, lambda x: x.decode()}: + pwd_hash = func(val.encode()) + ec = self.generate_confirm_hash(local_nonce, server_nonce, pwd_hash) + if device_confirm == ec: + _LOGGER.debug("Credentials match with %s %s", val, func.__name__) + return local_nonce, server_nonce, pwd_hash + msg = f"Server response doesn't match our challenge on ip {self._host}" _LOGGER.debug(msg) + raise AuthenticationError(msg) async def try_send_handshake1(self, username: str, local_nonce: str) -> dict: @@ -462,7 +582,10 @@ class SslAesTransport(BaseTransport): if self._state is TransportState.HANDSHAKE_REQUIRED: await self.perform_handshake() - return await self.send_secure_passthrough(request) + if self._send_secure: + return await self.send_secure_passthrough(request) + + return await self.send_unencrypted(request) async def close(self) -> None: """Close the http client and reset internal state.""" diff --git a/tests/test_device_factory.py b/tests/test_device_factory.py index ed73b3a3..66e24324 100644 --- a/tests/test_device_factory.py +++ b/tests/test_device_factory.py @@ -13,9 +13,13 @@ import aiohttp import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( + BaseProtocol, Credentials, Discover, + IotProtocol, KasaException, + SmartCamProtocol, + SmartProtocol, ) from kasa.device_factory import ( Device, @@ -33,6 +37,16 @@ from kasa.deviceconfig import ( DeviceFamily, ) from kasa.discover import DiscoveryResult +from kasa.transports import ( + AesTransport, + BaseTransport, + KlapTransport, + KlapTransportV2, + LinkieTransportV2, + SslAesTransport, + SslTransport, + XorTransport, +) from .conftest import DISCOVERY_MOCK_IP @@ -203,3 +217,74 @@ async def test_device_class_from_unknown_family(caplog): with caplog.at_level(logging.DEBUG): assert get_device_class_from_family(dummy_name, https=False) == SmartDevice assert f"Unknown SMART device with {dummy_name}" in caplog.text + + +# Aliases to make the test params more readable +CP = DeviceConnectionParameters +DF = DeviceFamily +ET = DeviceEncryptionType + + +@pytest.mark.parametrize( + ("conn_params", "expected_protocol", "expected_transport"), + [ + pytest.param( + CP(DF.SmartIpCamera, ET.Aes, https=True), + SmartCamProtocol, + SslAesTransport, + id="smartcam", + ), + pytest.param( + CP(DF.SmartTapoHub, ET.Aes, https=True), + SmartCamProtocol, + SslAesTransport, + id="smartcam-hub", + ), + pytest.param( + CP(DF.IotIpCamera, ET.Aes, https=True), + IotProtocol, + LinkieTransportV2, + id="kasacam", + ), + pytest.param( + CP(DF.SmartTapoRobovac, ET.Aes, https=True), + SmartProtocol, + SslTransport, + id="robovac", + ), + pytest.param( + CP(DF.IotSmartPlugSwitch, ET.Klap, https=False), + IotProtocol, + KlapTransport, + id="iot-klap", + ), + pytest.param( + CP(DF.IotSmartPlugSwitch, ET.Xor, https=False), + IotProtocol, + XorTransport, + id="iot-xor", + ), + pytest.param( + CP(DF.SmartTapoPlug, ET.Aes, https=False), + SmartProtocol, + AesTransport, + id="smart-aes", + ), + pytest.param( + CP(DF.SmartTapoPlug, ET.Klap, https=False), + SmartProtocol, + KlapTransportV2, + id="smart-klap", + ), + ], +) +async def test_get_protocol( + conn_params: DeviceConnectionParameters, + expected_protocol: type[BaseProtocol], + expected_transport: type[BaseTransport], +): + """Test get_protocol returns the right protocol.""" + config = DeviceConfig("127.0.0.1", connection_type=conn_params) + protocol = get_protocol(config) + assert isinstance(protocol, expected_protocol) + assert isinstance(protocol._transport, expected_transport)