Update SslAesTransport for legacy firmware versions

This commit is contained in:
Steven B 2024-12-10 14:05:30 +00:00
parent ed0481918c
commit 4a5bc20ee2
No known key found for this signature in database
GPG Key ID: 6D5B46B3679F2A43
8 changed files with 248 additions and 34 deletions

View File

@ -38,7 +38,7 @@ from kasa.feature import Feature
from kasa.interfaces.light import HSV, ColorTempRange, Light, LightState from kasa.interfaces.light import HSV, ColorTempRange, Light, LightState
from kasa.interfaces.thermostat import Thermostat, ThermostatState from kasa.interfaces.thermostat import Thermostat, ThermostatState
from kasa.module import Module from kasa.module import Module
from kasa.protocols import BaseProtocol, IotProtocol, SmartProtocol from kasa.protocols import BaseProtocol, IotProtocol, SmartCamProtocol, SmartProtocol
from kasa.protocols.iotprotocol import _deprecated_TPLinkSmartHomeProtocol # noqa: F401 from kasa.protocols.iotprotocol import _deprecated_TPLinkSmartHomeProtocol # noqa: F401
from kasa.smartcam.modules.camera import StreamResolution from kasa.smartcam.modules.camera import StreamResolution
from kasa.transports import BaseTransport from kasa.transports import BaseTransport
@ -52,6 +52,7 @@ __all__ = [
"BaseTransport", "BaseTransport",
"IotProtocol", "IotProtocol",
"SmartProtocol", "SmartProtocol",
"SmartCamProtocol",
"LightState", "LightState",
"TurnOnBehaviors", "TurnOnBehaviors",
"TurnOnBehavior", "TurnOnBehavior",

View File

@ -8,7 +8,7 @@ from typing import Any
from .device import Device from .device import Device
from .device_type import DeviceType from .device_type import DeviceType
from .deviceconfig import DeviceConfig from .deviceconfig import DeviceConfig, DeviceFamily
from .exceptions import KasaException, UnsupportedDeviceError from .exceptions import KasaException, UnsupportedDeviceError
from .iot import ( from .iot import (
IotBulb, IotBulb,
@ -180,19 +180,23 @@ def get_protocol(
config: DeviceConfig, config: DeviceConfig,
) -> BaseProtocol | None: ) -> BaseProtocol | None:
"""Return the protocol from the connection name.""" """Return the protocol from the connection name."""
protocol_name = config.connection_type.device_family.value.split(".")[0]
ctype = config.connection_type ctype = config.connection_type
protocol_name = ctype.device_family.value.split(".")[0]
if ctype.device_family is DeviceFamily.SmartIpCamera:
return SmartCamProtocol(transport=SslAesTransport(config=config))
if ctype.device_family is DeviceFamily.IotIpCamera:
return IotProtocol(transport=LinkieTransportV2(config=config))
if ctype.device_family is DeviceFamily.SmartTapoRobovac:
return SmartProtocol(transport=SslTransport(config=config))
protocol_transport_key = ( protocol_transport_key = (
protocol_name protocol_name
+ "." + "."
+ ctype.encryption_type.value + ctype.encryption_type.value
+ (".HTTPS" if ctype.https else "") + (".HTTPS" if ctype.https else "")
+ (
f".{ctype.login_version}"
if ctype.login_version and ctype.login_version > 1
else ""
)
) )
_LOGGER.debug("Finding transport for %s", protocol_transport_key) _LOGGER.debug("Finding transport for %s", protocol_transport_key)
@ -201,12 +205,11 @@ def get_protocol(
] = { ] = {
"IOT.XOR": (IotProtocol, XorTransport), "IOT.XOR": (IotProtocol, XorTransport),
"IOT.KLAP": (IotProtocol, KlapTransport), "IOT.KLAP": (IotProtocol, KlapTransport),
"IOT.XOR.HTTPS.2": (IotProtocol, LinkieTransportV2),
"SMART.AES": (SmartProtocol, AesTransport), "SMART.AES": (SmartProtocol, AesTransport),
"SMART.AES.2": (SmartProtocol, AesTransport), "SMART.KLAP": (SmartProtocol, KlapTransportV2),
"SMART.KLAP.2": (SmartProtocol, KlapTransportV2), # Still require a lookup for SslAesTransport as H200 has a type of
"SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport), # SMART.TAPOHUB.
"SMART.AES.HTTPS": (SmartProtocol, SslTransport), "SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport),
} }
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)): if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
return None return None

View File

@ -775,12 +775,10 @@ class Discover:
): ):
encrypt_type = encrypt_info.sym_schm encrypt_type = encrypt_info.sym_schm
if ( if not (login_version := encrypt_schm.lv) and (
not (login_version := encrypt_schm.lv) et := discovery_result.encrypt_type
and (et := discovery_result.encrypt_type)
and et == ["3"]
): ):
login_version = 2 login_version = max([int(i) for i in et])
if not encrypt_type: if not encrypt_type:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(

View File

@ -2,6 +2,7 @@
from .iotprotocol import IotProtocol from .iotprotocol import IotProtocol
from .protocol import BaseProtocol from .protocol import BaseProtocol
from .smartcamprotocol import SmartCamProtocol
from .smartprotocol import SmartErrorCode, SmartProtocol from .smartprotocol import SmartErrorCode, SmartProtocol
__all__ = [ __all__ = [
@ -9,4 +10,5 @@ __all__ = [
"IotProtocol", "IotProtocol",
"SmartErrorCode", "SmartErrorCode",
"SmartProtocol", "SmartProtocol",
"SmartCamProtocol",
] ]

View File

@ -19,7 +19,7 @@ from ..transports.sslaestransport import (
SMART_RETRYABLE_ERRORS, SMART_RETRYABLE_ERRORS,
SmartErrorCode, SmartErrorCode,
) )
from . import SmartProtocol from .smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View File

@ -4,6 +4,7 @@ from .aestransport import AesEncyptionSession, AesTransport
from .basetransport import BaseTransport from .basetransport import BaseTransport
from .klaptransport import KlapTransport, KlapTransportV2 from .klaptransport import KlapTransport, KlapTransportV2
from .linkietransport import LinkieTransportV2 from .linkietransport import LinkieTransportV2
from .sslaestransport import SslAesTransport
from .ssltransport import SslTransport from .ssltransport import SslTransport
from .xortransport import XorEncryption, XorTransport from .xortransport import XorEncryption, XorTransport
@ -11,6 +12,7 @@ __all__ = [
"AesTransport", "AesTransport",
"AesEncyptionSession", "AesEncyptionSession",
"SslTransport", "SslTransport",
"SslAesTransport",
"BaseTransport", "BaseTransport",
"KlapTransport", "KlapTransport",
"KlapTransportV2", "KlapTransportV2",

View File

@ -48,6 +48,10 @@ def _sha256_hash(payload: bytes) -> str:
return hashlib.sha256(payload).hexdigest().upper() # noqa: S324 return hashlib.sha256(payload).hexdigest().upper() # noqa: S324
def _sha1_hash(payload: bytes) -> str:
return hashlib.sha1(payload).hexdigest().upper() # noqa: S324
class TransportState(Enum): class TransportState(Enum):
"""Enum for AES state.""" """Enum for AES state."""
@ -107,11 +111,10 @@ class SslAesTransport(BaseTransport):
self._app_url = URL(f"https://{self._host_port}") self._app_url = URL(f"https://{self._host_port}")
self._token_url: URL | None = None self._token_url: URL | None = None
self._ssl_context: ssl.SSLContext | None = None self._ssl_context: ssl.SSLContext | None = None
ref = str(self._token_url) if self._token_url else str(self._app_url)
self._headers = { self._headers = {
**self.COMMON_HEADERS, **self.COMMON_HEADERS,
"Host": self._host_port, "Host": self._host,
"Referer": ref, "Referer": f"https://{self._host}",
} }
self._seq: int | None = None self._seq: int | None = None
self._pwd_hash: str | None = None self._pwd_hash: str | None = None
@ -125,6 +128,7 @@ class SslAesTransport(BaseTransport):
self._password = ch["pwd"] self._password = ch["pwd"]
self._username = ch["un"] self._username = ch["un"]
self._local_nonce: str | None = None self._local_nonce: str | None = None
self._send_secure = True
_LOGGER.debug("Created AES transport for %s", self._host) _LOGGER.debug("Created AES transport for %s", self._host)
@ -194,6 +198,10 @@ class SslAesTransport(BaseTransport):
else: else:
url = self._app_url url = self._app_url
_LOGGER.debug(
"Sending secure passthrough from %s",
self._host,
)
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
passthrough_request = { passthrough_request = {
"method": "securePassthrough", "method": "securePassthrough",
@ -254,6 +262,37 @@ class SslAesTransport(BaseTransport):
) from ex ) from ex
return ret_val # type: ignore[return-value] return ret_val # type: ignore[return-value]
async def send_unencrypted(self, request: str) -> dict[str, Any]:
"""Send encrypted message as passthrough."""
if self._state is TransportState.ESTABLISHED and self._token_url:
url = self._token_url
else:
url = self._app_url
_LOGGER.debug(
"Sending unencrypted from %s",
self._host,
)
status_code, resp_dict = await self._http_client.post(
url,
json=request,
headers=self._headers,
ssl=await self._get_ssl_context(),
)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code}"
)
self._handle_response_error_code(resp_dict, "Error sending message")
if TYPE_CHECKING:
resp_dict = cast(dict[str, Any], resp_dict)
return resp_dict
@staticmethod @staticmethod
def generate_confirm_hash( def generate_confirm_hash(
local_nonce: str, server_nonce: str, pwd_hash: str local_nonce: str, server_nonce: str, pwd_hash: str
@ -302,8 +341,52 @@ class SslAesTransport(BaseTransport):
async def perform_handshake(self) -> None: async def perform_handshake(self) -> None:
"""Perform the handshake.""" """Perform the handshake."""
local_nonce, server_nonce, pwd_hash = await self.perform_handshake1() result = await self.perform_handshake1()
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash) if result:
local_nonce, server_nonce, pwd_hash = result
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
async def try_perform_login(self) -> bool:
"""Perform the md5 login."""
_LOGGER.debug("Performing insecure login ...")
pwd_hash = _md5_hash(self._pwd_to_hash().encode())
username = self._username
body = {
"method": "login",
"params": {
"hashed": True,
"password": pwd_hash,
"username": username,
},
}
http_client = self._http_client
status_code, resp_dict = await http_client.post(
self._app_url,
json=body,
headers=self._headers,
ssl=await self._get_ssl_context(),
)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to handshake2"
)
resp_dict = cast(dict, resp_dict)
if resp_dict.get("error_code") == 0 and (
stok := resp_dict.get("result", {}).get("stok")
):
_LOGGER.debug(
"Succesfully logged in to %s with less secure passthrough", self._host
)
self._send_secure = False
self._token_url = URL(f"{str(self._app_url)}/stok={stok}/ds")
self._pwd_hash = pwd_hash
return True
_LOGGER.debug("Unable to log in to %s with less secure login", self._host)
return False
async def perform_handshake2( async def perform_handshake2(
self, local_nonce: str, server_nonce: str, pwd_hash: str self, local_nonce: str, server_nonce: str, pwd_hash: str
@ -355,13 +438,42 @@ class SslAesTransport(BaseTransport):
self._state = TransportState.ESTABLISHED self._state = TransportState.ESTABLISHED
_LOGGER.debug("Handshake2 complete ...") _LOGGER.debug("Handshake2 complete ...")
async def perform_handshake1(self) -> tuple[str, str, str]: def _pwd_to_hash(self) -> str:
"""Return the password to hash."""
if self._credentials and self._credentials != Credentials():
return self._credentials.password
if self._username and self._password:
return self._password
return self._default_credentials.password
async def perform_handshake1(self) -> tuple[str, str, str] | None:
"""Perform the handshake1.""" """Perform the handshake1."""
resp_dict = None resp_dict = None
if self._username: if self._username:
local_nonce = secrets.token_bytes(8).hex().upper() local_nonce = secrets.token_bytes(8).hex().upper()
resp_dict = await self.try_send_handshake1(self._username, local_nonce) resp_dict = await self.try_send_handshake1(self._username, local_nonce)
if (
resp_dict
and (error_code := self._get_response_error(resp_dict))
is SmartErrorCode.SESSION_EXPIRED
and (
encrypt_type := resp_dict.get("result", {})
.get("data", {})
.get("encrypt_type")
)
and (encrypt_type != ["3"])
):
_LOGGER.debug(
"Received encrypt_type %s for %s, trying less secure login",
encrypt_type,
self._host,
)
if await self.try_perform_login():
return None
# Try the default username. If it fails raise the original error_code # Try the default username. If it fails raise the original error_code
if ( if (
not resp_dict not resp_dict
@ -369,6 +481,7 @@ class SslAesTransport(BaseTransport):
is not SmartErrorCode.INVALID_NONCE is not SmartErrorCode.INVALID_NONCE
or "nonce" not in resp_dict["result"].get("data", {}) or "nonce" not in resp_dict["result"].get("data", {})
): ):
_LOGGER.debug("Trying default credentials to %s", self._host)
local_nonce = secrets.token_bytes(8).hex().upper() local_nonce = secrets.token_bytes(8).hex().upper()
default_resp_dict = await self.try_send_handshake1( default_resp_dict = await self.try_send_handshake1(
self._default_credentials.username, local_nonce self._default_credentials.username, local_nonce
@ -378,7 +491,7 @@ class SslAesTransport(BaseTransport):
) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[ ) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[
"result" "result"
].get("data", {}): ].get("data", {}):
_LOGGER.debug("Connected to {self._host} with default username") _LOGGER.debug("Connected to %s with default username", self._host)
self._username = self._default_credentials.username self._username = self._default_credentials.username
error_code = default_error_code error_code = default_error_code
resp_dict = default_resp_dict resp_dict = default_resp_dict
@ -397,12 +510,8 @@ class SslAesTransport(BaseTransport):
server_nonce = resp_dict["result"]["data"]["nonce"] server_nonce = resp_dict["result"]["data"]["nonce"]
device_confirm = resp_dict["result"]["data"]["device_confirm"] device_confirm = resp_dict["result"]["data"]["device_confirm"]
if self._credentials and self._credentials != Credentials():
pwd_hash = _sha256_hash(self._credentials.password.encode()) pwd_hash = _sha256_hash(self._pwd_to_hash().encode())
elif self._username and self._password:
pwd_hash = _sha256_hash(self._password.encode())
else:
pwd_hash = _sha256_hash(self._default_credentials.password.encode())
expected_confirm_sha256 = self.generate_confirm_hash( expected_confirm_sha256 = self.generate_confirm_hash(
local_nonce, server_nonce, pwd_hash local_nonce, server_nonce, pwd_hash
@ -414,7 +523,9 @@ class SslAesTransport(BaseTransport):
if TYPE_CHECKING: if TYPE_CHECKING:
assert self._credentials assert self._credentials
assert self._credentials.password assert self._credentials.password
pwd_hash = _md5_hash(self._credentials.password.encode())
pwd_hash = _md5_hash(self._pwd_to_hash().encode())
expected_confirm_md5 = self.generate_confirm_hash( expected_confirm_md5 = self.generate_confirm_hash(
local_nonce, server_nonce, pwd_hash local_nonce, server_nonce, pwd_hash
) )
@ -422,8 +533,17 @@ class SslAesTransport(BaseTransport):
_LOGGER.debug("Credentials match") _LOGGER.debug("Credentials match")
return local_nonce, server_nonce, pwd_hash return local_nonce, server_nonce, pwd_hash
for val in {"admin", "tpadmin", "slprealtek"}:
for func in {_sha256_hash, _md5_hash, _sha1_hash, lambda x: x.decode()}:
pwd_hash = func(val.encode())
ec = self.generate_confirm_hash(local_nonce, server_nonce, pwd_hash)
if device_confirm == ec:
_LOGGER.debug("Credentials match with %s %s", val, func.__name__)
return local_nonce, server_nonce, pwd_hash
msg = f"Server response doesn't match our challenge on ip {self._host}" msg = f"Server response doesn't match our challenge on ip {self._host}"
_LOGGER.debug(msg) _LOGGER.debug(msg)
raise AuthenticationError(msg) raise AuthenticationError(msg)
async def try_send_handshake1(self, username: str, local_nonce: str) -> dict: async def try_send_handshake1(self, username: str, local_nonce: str) -> dict:
@ -462,7 +582,10 @@ class SslAesTransport(BaseTransport):
if self._state is TransportState.HANDSHAKE_REQUIRED: if self._state is TransportState.HANDSHAKE_REQUIRED:
await self.perform_handshake() await self.perform_handshake()
return await self.send_secure_passthrough(request) if self._send_secure:
return await self.send_secure_passthrough(request)
return await self.send_unencrypted(request)
async def close(self) -> None: async def close(self) -> None:
"""Close the http client and reset internal state.""" """Close the http client and reset internal state."""

View File

@ -13,9 +13,13 @@ import aiohttp
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import ( from kasa import (
BaseProtocol,
Credentials, Credentials,
Discover, Discover,
IotProtocol,
KasaException, KasaException,
SmartCamProtocol,
SmartProtocol,
) )
from kasa.device_factory import ( from kasa.device_factory import (
Device, Device,
@ -33,6 +37,16 @@ from kasa.deviceconfig import (
DeviceFamily, DeviceFamily,
) )
from kasa.discover import DiscoveryResult from kasa.discover import DiscoveryResult
from kasa.transports import (
AesTransport,
BaseTransport,
KlapTransport,
KlapTransportV2,
LinkieTransportV2,
SslAesTransport,
SslTransport,
XorTransport,
)
from .conftest import DISCOVERY_MOCK_IP from .conftest import DISCOVERY_MOCK_IP
@ -203,3 +217,74 @@ async def test_device_class_from_unknown_family(caplog):
with caplog.at_level(logging.DEBUG): with caplog.at_level(logging.DEBUG):
assert get_device_class_from_family(dummy_name, https=False) == SmartDevice assert get_device_class_from_family(dummy_name, https=False) == SmartDevice
assert f"Unknown SMART device with {dummy_name}" in caplog.text assert f"Unknown SMART device with {dummy_name}" in caplog.text
# Aliases to make the test params more readable
CP = DeviceConnectionParameters
DF = DeviceFamily
ET = DeviceEncryptionType
@pytest.mark.parametrize(
("conn_params", "expected_protocol", "expected_transport"),
[
pytest.param(
CP(DF.SmartIpCamera, ET.Aes, https=True),
SmartCamProtocol,
SslAesTransport,
id="smartcam",
),
pytest.param(
CP(DF.SmartTapoHub, ET.Aes, https=True),
SmartCamProtocol,
SslAesTransport,
id="smartcam-hub",
),
pytest.param(
CP(DF.IotIpCamera, ET.Aes, https=True),
IotProtocol,
LinkieTransportV2,
id="kasacam",
),
pytest.param(
CP(DF.SmartTapoRobovac, ET.Aes, https=True),
SmartProtocol,
SslTransport,
id="robovac",
),
pytest.param(
CP(DF.IotSmartPlugSwitch, ET.Klap, https=False),
IotProtocol,
KlapTransport,
id="iot-klap",
),
pytest.param(
CP(DF.IotSmartPlugSwitch, ET.Xor, https=False),
IotProtocol,
XorTransport,
id="iot-xor",
),
pytest.param(
CP(DF.SmartTapoPlug, ET.Aes, https=False),
SmartProtocol,
AesTransport,
id="smart-aes",
),
pytest.param(
CP(DF.SmartTapoPlug, ET.Klap, https=False),
SmartProtocol,
KlapTransportV2,
id="smart-klap",
),
],
)
async def test_get_protocol(
conn_params: DeviceConnectionParameters,
expected_protocol: type[BaseProtocol],
expected_transport: type[BaseTransport],
):
"""Test get_protocol returns the right protocol."""
config = DeviceConfig("127.0.0.1", connection_type=conn_params)
protocol = get_protocol(config)
assert isinstance(protocol, expected_protocol)
assert isinstance(protocol._transport, expected_transport)