Update SslAesTransport for legacy firmware versions

This commit is contained in:
Steven B 2024-12-10 14:05:30 +00:00
parent ed0481918c
commit 4a5bc20ee2
No known key found for this signature in database
GPG Key ID: 6D5B46B3679F2A43
8 changed files with 248 additions and 34 deletions

View File

@ -38,7 +38,7 @@ from kasa.feature import Feature
from kasa.interfaces.light import HSV, ColorTempRange, Light, LightState
from kasa.interfaces.thermostat import Thermostat, ThermostatState
from kasa.module import Module
from kasa.protocols import BaseProtocol, IotProtocol, SmartProtocol
from kasa.protocols import BaseProtocol, IotProtocol, SmartCamProtocol, SmartProtocol
from kasa.protocols.iotprotocol import _deprecated_TPLinkSmartHomeProtocol # noqa: F401
from kasa.smartcam.modules.camera import StreamResolution
from kasa.transports import BaseTransport
@ -52,6 +52,7 @@ __all__ = [
"BaseTransport",
"IotProtocol",
"SmartProtocol",
"SmartCamProtocol",
"LightState",
"TurnOnBehaviors",
"TurnOnBehavior",

View File

@ -8,7 +8,7 @@ from typing import Any
from .device import Device
from .device_type import DeviceType
from .deviceconfig import DeviceConfig
from .deviceconfig import DeviceConfig, DeviceFamily
from .exceptions import KasaException, UnsupportedDeviceError
from .iot import (
IotBulb,
@ -180,19 +180,23 @@ def get_protocol(
config: DeviceConfig,
) -> BaseProtocol | None:
"""Return the protocol from the connection name."""
protocol_name = config.connection_type.device_family.value.split(".")[0]
ctype = config.connection_type
protocol_name = ctype.device_family.value.split(".")[0]
if ctype.device_family is DeviceFamily.SmartIpCamera:
return SmartCamProtocol(transport=SslAesTransport(config=config))
if ctype.device_family is DeviceFamily.IotIpCamera:
return IotProtocol(transport=LinkieTransportV2(config=config))
if ctype.device_family is DeviceFamily.SmartTapoRobovac:
return SmartProtocol(transport=SslTransport(config=config))
protocol_transport_key = (
protocol_name
+ "."
+ ctype.encryption_type.value
+ (".HTTPS" if ctype.https else "")
+ (
f".{ctype.login_version}"
if ctype.login_version and ctype.login_version > 1
else ""
)
)
_LOGGER.debug("Finding transport for %s", protocol_transport_key)
@ -201,12 +205,11 @@ def get_protocol(
] = {
"IOT.XOR": (IotProtocol, XorTransport),
"IOT.KLAP": (IotProtocol, KlapTransport),
"IOT.XOR.HTTPS.2": (IotProtocol, LinkieTransportV2),
"SMART.AES": (SmartProtocol, AesTransport),
"SMART.AES.2": (SmartProtocol, AesTransport),
"SMART.KLAP.2": (SmartProtocol, KlapTransportV2),
"SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport),
"SMART.AES.HTTPS": (SmartProtocol, SslTransport),
"SMART.KLAP": (SmartProtocol, KlapTransportV2),
# Still require a lookup for SslAesTransport as H200 has a type of
# SMART.TAPOHUB.
"SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport),
}
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
return None

View File

@ -775,12 +775,10 @@ class Discover:
):
encrypt_type = encrypt_info.sym_schm
if (
not (login_version := encrypt_schm.lv)
and (et := discovery_result.encrypt_type)
and et == ["3"]
if not (login_version := encrypt_schm.lv) and (
et := discovery_result.encrypt_type
):
login_version = 2
login_version = max([int(i) for i in et])
if not encrypt_type:
raise UnsupportedDeviceError(

View File

@ -2,6 +2,7 @@
from .iotprotocol import IotProtocol
from .protocol import BaseProtocol
from .smartcamprotocol import SmartCamProtocol
from .smartprotocol import SmartErrorCode, SmartProtocol
__all__ = [
@ -9,4 +10,5 @@ __all__ = [
"IotProtocol",
"SmartErrorCode",
"SmartProtocol",
"SmartCamProtocol",
]

View File

@ -19,7 +19,7 @@ from ..transports.sslaestransport import (
SMART_RETRYABLE_ERRORS,
SmartErrorCode,
)
from . import SmartProtocol
from .smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__)

View File

@ -4,6 +4,7 @@ from .aestransport import AesEncyptionSession, AesTransport
from .basetransport import BaseTransport
from .klaptransport import KlapTransport, KlapTransportV2
from .linkietransport import LinkieTransportV2
from .sslaestransport import SslAesTransport
from .ssltransport import SslTransport
from .xortransport import XorEncryption, XorTransport
@ -11,6 +12,7 @@ __all__ = [
"AesTransport",
"AesEncyptionSession",
"SslTransport",
"SslAesTransport",
"BaseTransport",
"KlapTransport",
"KlapTransportV2",

View File

@ -48,6 +48,10 @@ def _sha256_hash(payload: bytes) -> str:
return hashlib.sha256(payload).hexdigest().upper() # noqa: S324
def _sha1_hash(payload: bytes) -> str:
return hashlib.sha1(payload).hexdigest().upper() # noqa: S324
class TransportState(Enum):
"""Enum for AES state."""
@ -107,11 +111,10 @@ class SslAesTransport(BaseTransport):
self._app_url = URL(f"https://{self._host_port}")
self._token_url: URL | None = None
self._ssl_context: ssl.SSLContext | None = None
ref = str(self._token_url) if self._token_url else str(self._app_url)
self._headers = {
**self.COMMON_HEADERS,
"Host": self._host_port,
"Referer": ref,
"Host": self._host,
"Referer": f"https://{self._host}",
}
self._seq: int | None = None
self._pwd_hash: str | None = None
@ -125,6 +128,7 @@ class SslAesTransport(BaseTransport):
self._password = ch["pwd"]
self._username = ch["un"]
self._local_nonce: str | None = None
self._send_secure = True
_LOGGER.debug("Created AES transport for %s", self._host)
@ -194,6 +198,10 @@ class SslAesTransport(BaseTransport):
else:
url = self._app_url
_LOGGER.debug(
"Sending secure passthrough from %s",
self._host,
)
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
passthrough_request = {
"method": "securePassthrough",
@ -254,6 +262,37 @@ class SslAesTransport(BaseTransport):
) from ex
return ret_val # type: ignore[return-value]
async def send_unencrypted(self, request: str) -> dict[str, Any]:
"""Send encrypted message as passthrough."""
if self._state is TransportState.ESTABLISHED and self._token_url:
url = self._token_url
else:
url = self._app_url
_LOGGER.debug(
"Sending unencrypted from %s",
self._host,
)
status_code, resp_dict = await self._http_client.post(
url,
json=request,
headers=self._headers,
ssl=await self._get_ssl_context(),
)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code}"
)
self._handle_response_error_code(resp_dict, "Error sending message")
if TYPE_CHECKING:
resp_dict = cast(dict[str, Any], resp_dict)
return resp_dict
@staticmethod
def generate_confirm_hash(
local_nonce: str, server_nonce: str, pwd_hash: str
@ -302,8 +341,52 @@ class SslAesTransport(BaseTransport):
async def perform_handshake(self) -> None:
"""Perform the handshake."""
local_nonce, server_nonce, pwd_hash = await self.perform_handshake1()
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
result = await self.perform_handshake1()
if result:
local_nonce, server_nonce, pwd_hash = result
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
async def try_perform_login(self) -> bool:
"""Perform the md5 login."""
_LOGGER.debug("Performing insecure login ...")
pwd_hash = _md5_hash(self._pwd_to_hash().encode())
username = self._username
body = {
"method": "login",
"params": {
"hashed": True,
"password": pwd_hash,
"username": username,
},
}
http_client = self._http_client
status_code, resp_dict = await http_client.post(
self._app_url,
json=body,
headers=self._headers,
ssl=await self._get_ssl_context(),
)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to handshake2"
)
resp_dict = cast(dict, resp_dict)
if resp_dict.get("error_code") == 0 and (
stok := resp_dict.get("result", {}).get("stok")
):
_LOGGER.debug(
"Succesfully logged in to %s with less secure passthrough", self._host
)
self._send_secure = False
self._token_url = URL(f"{str(self._app_url)}/stok={stok}/ds")
self._pwd_hash = pwd_hash
return True
_LOGGER.debug("Unable to log in to %s with less secure login", self._host)
return False
async def perform_handshake2(
self, local_nonce: str, server_nonce: str, pwd_hash: str
@ -355,13 +438,42 @@ class SslAesTransport(BaseTransport):
self._state = TransportState.ESTABLISHED
_LOGGER.debug("Handshake2 complete ...")
async def perform_handshake1(self) -> tuple[str, str, str]:
def _pwd_to_hash(self) -> str:
"""Return the password to hash."""
if self._credentials and self._credentials != Credentials():
return self._credentials.password
if self._username and self._password:
return self._password
return self._default_credentials.password
async def perform_handshake1(self) -> tuple[str, str, str] | None:
"""Perform the handshake1."""
resp_dict = None
if self._username:
local_nonce = secrets.token_bytes(8).hex().upper()
resp_dict = await self.try_send_handshake1(self._username, local_nonce)
if (
resp_dict
and (error_code := self._get_response_error(resp_dict))
is SmartErrorCode.SESSION_EXPIRED
and (
encrypt_type := resp_dict.get("result", {})
.get("data", {})
.get("encrypt_type")
)
and (encrypt_type != ["3"])
):
_LOGGER.debug(
"Received encrypt_type %s for %s, trying less secure login",
encrypt_type,
self._host,
)
if await self.try_perform_login():
return None
# Try the default username. If it fails raise the original error_code
if (
not resp_dict
@ -369,6 +481,7 @@ class SslAesTransport(BaseTransport):
is not SmartErrorCode.INVALID_NONCE
or "nonce" not in resp_dict["result"].get("data", {})
):
_LOGGER.debug("Trying default credentials to %s", self._host)
local_nonce = secrets.token_bytes(8).hex().upper()
default_resp_dict = await self.try_send_handshake1(
self._default_credentials.username, local_nonce
@ -378,7 +491,7 @@ class SslAesTransport(BaseTransport):
) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[
"result"
].get("data", {}):
_LOGGER.debug("Connected to {self._host} with default username")
_LOGGER.debug("Connected to %s with default username", self._host)
self._username = self._default_credentials.username
error_code = default_error_code
resp_dict = default_resp_dict
@ -397,12 +510,8 @@ class SslAesTransport(BaseTransport):
server_nonce = resp_dict["result"]["data"]["nonce"]
device_confirm = resp_dict["result"]["data"]["device_confirm"]
if self._credentials and self._credentials != Credentials():
pwd_hash = _sha256_hash(self._credentials.password.encode())
elif self._username and self._password:
pwd_hash = _sha256_hash(self._password.encode())
else:
pwd_hash = _sha256_hash(self._default_credentials.password.encode())
pwd_hash = _sha256_hash(self._pwd_to_hash().encode())
expected_confirm_sha256 = self.generate_confirm_hash(
local_nonce, server_nonce, pwd_hash
@ -414,7 +523,9 @@ class SslAesTransport(BaseTransport):
if TYPE_CHECKING:
assert self._credentials
assert self._credentials.password
pwd_hash = _md5_hash(self._credentials.password.encode())
pwd_hash = _md5_hash(self._pwd_to_hash().encode())
expected_confirm_md5 = self.generate_confirm_hash(
local_nonce, server_nonce, pwd_hash
)
@ -422,8 +533,17 @@ class SslAesTransport(BaseTransport):
_LOGGER.debug("Credentials match")
return local_nonce, server_nonce, pwd_hash
for val in {"admin", "tpadmin", "slprealtek"}:
for func in {_sha256_hash, _md5_hash, _sha1_hash, lambda x: x.decode()}:
pwd_hash = func(val.encode())
ec = self.generate_confirm_hash(local_nonce, server_nonce, pwd_hash)
if device_confirm == ec:
_LOGGER.debug("Credentials match with %s %s", val, func.__name__)
return local_nonce, server_nonce, pwd_hash
msg = f"Server response doesn't match our challenge on ip {self._host}"
_LOGGER.debug(msg)
raise AuthenticationError(msg)
async def try_send_handshake1(self, username: str, local_nonce: str) -> dict:
@ -462,7 +582,10 @@ class SslAesTransport(BaseTransport):
if self._state is TransportState.HANDSHAKE_REQUIRED:
await self.perform_handshake()
return await self.send_secure_passthrough(request)
if self._send_secure:
return await self.send_secure_passthrough(request)
return await self.send_unencrypted(request)
async def close(self) -> None:
"""Close the http client and reset internal state."""

View File

@ -13,9 +13,13 @@ import aiohttp
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
BaseProtocol,
Credentials,
Discover,
IotProtocol,
KasaException,
SmartCamProtocol,
SmartProtocol,
)
from kasa.device_factory import (
Device,
@ -33,6 +37,16 @@ from kasa.deviceconfig import (
DeviceFamily,
)
from kasa.discover import DiscoveryResult
from kasa.transports import (
AesTransport,
BaseTransport,
KlapTransport,
KlapTransportV2,
LinkieTransportV2,
SslAesTransport,
SslTransport,
XorTransport,
)
from .conftest import DISCOVERY_MOCK_IP
@ -203,3 +217,74 @@ async def test_device_class_from_unknown_family(caplog):
with caplog.at_level(logging.DEBUG):
assert get_device_class_from_family(dummy_name, https=False) == SmartDevice
assert f"Unknown SMART device with {dummy_name}" in caplog.text
# Aliases to make the test params more readable
CP = DeviceConnectionParameters
DF = DeviceFamily
ET = DeviceEncryptionType
@pytest.mark.parametrize(
("conn_params", "expected_protocol", "expected_transport"),
[
pytest.param(
CP(DF.SmartIpCamera, ET.Aes, https=True),
SmartCamProtocol,
SslAesTransport,
id="smartcam",
),
pytest.param(
CP(DF.SmartTapoHub, ET.Aes, https=True),
SmartCamProtocol,
SslAesTransport,
id="smartcam-hub",
),
pytest.param(
CP(DF.IotIpCamera, ET.Aes, https=True),
IotProtocol,
LinkieTransportV2,
id="kasacam",
),
pytest.param(
CP(DF.SmartTapoRobovac, ET.Aes, https=True),
SmartProtocol,
SslTransport,
id="robovac",
),
pytest.param(
CP(DF.IotSmartPlugSwitch, ET.Klap, https=False),
IotProtocol,
KlapTransport,
id="iot-klap",
),
pytest.param(
CP(DF.IotSmartPlugSwitch, ET.Xor, https=False),
IotProtocol,
XorTransport,
id="iot-xor",
),
pytest.param(
CP(DF.SmartTapoPlug, ET.Aes, https=False),
SmartProtocol,
AesTransport,
id="smart-aes",
),
pytest.param(
CP(DF.SmartTapoPlug, ET.Klap, https=False),
SmartProtocol,
KlapTransportV2,
id="smart-klap",
),
],
)
async def test_get_protocol(
conn_params: DeviceConnectionParameters,
expected_protocol: type[BaseProtocol],
expected_transport: type[BaseTransport],
):
"""Test get_protocol returns the right protocol."""
config = DeviceConfig("127.0.0.1", connection_type=conn_params)
protocol = get_protocol(config)
assert isinstance(protocol, expected_protocol)
assert isinstance(protocol._transport, expected_transport)