mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-24 05:37:59 +00:00
Update SslAesTransport for legacy firmware versions
This commit is contained in:
parent
ed0481918c
commit
4a5bc20ee2
@ -38,7 +38,7 @@ from kasa.feature import Feature
|
||||
from kasa.interfaces.light import HSV, ColorTempRange, Light, LightState
|
||||
from kasa.interfaces.thermostat import Thermostat, ThermostatState
|
||||
from kasa.module import Module
|
||||
from kasa.protocols import BaseProtocol, IotProtocol, SmartProtocol
|
||||
from kasa.protocols import BaseProtocol, IotProtocol, SmartCamProtocol, SmartProtocol
|
||||
from kasa.protocols.iotprotocol import _deprecated_TPLinkSmartHomeProtocol # noqa: F401
|
||||
from kasa.smartcam.modules.camera import StreamResolution
|
||||
from kasa.transports import BaseTransport
|
||||
@ -52,6 +52,7 @@ __all__ = [
|
||||
"BaseTransport",
|
||||
"IotProtocol",
|
||||
"SmartProtocol",
|
||||
"SmartCamProtocol",
|
||||
"LightState",
|
||||
"TurnOnBehaviors",
|
||||
"TurnOnBehavior",
|
||||
|
@ -8,7 +8,7 @@ from typing import Any
|
||||
|
||||
from .device import Device
|
||||
from .device_type import DeviceType
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .deviceconfig import DeviceConfig, DeviceFamily
|
||||
from .exceptions import KasaException, UnsupportedDeviceError
|
||||
from .iot import (
|
||||
IotBulb,
|
||||
@ -180,19 +180,23 @@ def get_protocol(
|
||||
config: DeviceConfig,
|
||||
) -> BaseProtocol | None:
|
||||
"""Return the protocol from the connection name."""
|
||||
protocol_name = config.connection_type.device_family.value.split(".")[0]
|
||||
ctype = config.connection_type
|
||||
protocol_name = ctype.device_family.value.split(".")[0]
|
||||
|
||||
if ctype.device_family is DeviceFamily.SmartIpCamera:
|
||||
return SmartCamProtocol(transport=SslAesTransport(config=config))
|
||||
|
||||
if ctype.device_family is DeviceFamily.IotIpCamera:
|
||||
return IotProtocol(transport=LinkieTransportV2(config=config))
|
||||
|
||||
if ctype.device_family is DeviceFamily.SmartTapoRobovac:
|
||||
return SmartProtocol(transport=SslTransport(config=config))
|
||||
|
||||
protocol_transport_key = (
|
||||
protocol_name
|
||||
+ "."
|
||||
+ ctype.encryption_type.value
|
||||
+ (".HTTPS" if ctype.https else "")
|
||||
+ (
|
||||
f".{ctype.login_version}"
|
||||
if ctype.login_version and ctype.login_version > 1
|
||||
else ""
|
||||
)
|
||||
)
|
||||
|
||||
_LOGGER.debug("Finding transport for %s", protocol_transport_key)
|
||||
@ -201,12 +205,11 @@ def get_protocol(
|
||||
] = {
|
||||
"IOT.XOR": (IotProtocol, XorTransport),
|
||||
"IOT.KLAP": (IotProtocol, KlapTransport),
|
||||
"IOT.XOR.HTTPS.2": (IotProtocol, LinkieTransportV2),
|
||||
"SMART.AES": (SmartProtocol, AesTransport),
|
||||
"SMART.AES.2": (SmartProtocol, AesTransport),
|
||||
"SMART.KLAP.2": (SmartProtocol, KlapTransportV2),
|
||||
"SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport),
|
||||
"SMART.AES.HTTPS": (SmartProtocol, SslTransport),
|
||||
"SMART.KLAP": (SmartProtocol, KlapTransportV2),
|
||||
# Still require a lookup for SslAesTransport as H200 has a type of
|
||||
# SMART.TAPOHUB.
|
||||
"SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport),
|
||||
}
|
||||
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
|
||||
return None
|
||||
|
@ -775,12 +775,10 @@ class Discover:
|
||||
):
|
||||
encrypt_type = encrypt_info.sym_schm
|
||||
|
||||
if (
|
||||
not (login_version := encrypt_schm.lv)
|
||||
and (et := discovery_result.encrypt_type)
|
||||
and et == ["3"]
|
||||
if not (login_version := encrypt_schm.lv) and (
|
||||
et := discovery_result.encrypt_type
|
||||
):
|
||||
login_version = 2
|
||||
login_version = max([int(i) for i in et])
|
||||
|
||||
if not encrypt_type:
|
||||
raise UnsupportedDeviceError(
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from .iotprotocol import IotProtocol
|
||||
from .protocol import BaseProtocol
|
||||
from .smartcamprotocol import SmartCamProtocol
|
||||
from .smartprotocol import SmartErrorCode, SmartProtocol
|
||||
|
||||
__all__ = [
|
||||
@ -9,4 +10,5 @@ __all__ = [
|
||||
"IotProtocol",
|
||||
"SmartErrorCode",
|
||||
"SmartProtocol",
|
||||
"SmartCamProtocol",
|
||||
]
|
||||
|
@ -19,7 +19,7 @@ from ..transports.sslaestransport import (
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
SmartErrorCode,
|
||||
)
|
||||
from . import SmartProtocol
|
||||
from .smartprotocol import SmartProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -4,6 +4,7 @@ from .aestransport import AesEncyptionSession, AesTransport
|
||||
from .basetransport import BaseTransport
|
||||
from .klaptransport import KlapTransport, KlapTransportV2
|
||||
from .linkietransport import LinkieTransportV2
|
||||
from .sslaestransport import SslAesTransport
|
||||
from .ssltransport import SslTransport
|
||||
from .xortransport import XorEncryption, XorTransport
|
||||
|
||||
@ -11,6 +12,7 @@ __all__ = [
|
||||
"AesTransport",
|
||||
"AesEncyptionSession",
|
||||
"SslTransport",
|
||||
"SslAesTransport",
|
||||
"BaseTransport",
|
||||
"KlapTransport",
|
||||
"KlapTransportV2",
|
||||
|
@ -48,6 +48,10 @@ def _sha256_hash(payload: bytes) -> str:
|
||||
return hashlib.sha256(payload).hexdigest().upper() # noqa: S324
|
||||
|
||||
|
||||
def _sha1_hash(payload: bytes) -> str:
|
||||
return hashlib.sha1(payload).hexdigest().upper() # noqa: S324
|
||||
|
||||
|
||||
class TransportState(Enum):
|
||||
"""Enum for AES state."""
|
||||
|
||||
@ -107,11 +111,10 @@ class SslAesTransport(BaseTransport):
|
||||
self._app_url = URL(f"https://{self._host_port}")
|
||||
self._token_url: URL | None = None
|
||||
self._ssl_context: ssl.SSLContext | None = None
|
||||
ref = str(self._token_url) if self._token_url else str(self._app_url)
|
||||
self._headers = {
|
||||
**self.COMMON_HEADERS,
|
||||
"Host": self._host_port,
|
||||
"Referer": ref,
|
||||
"Host": self._host,
|
||||
"Referer": f"https://{self._host}",
|
||||
}
|
||||
self._seq: int | None = None
|
||||
self._pwd_hash: str | None = None
|
||||
@ -125,6 +128,7 @@ class SslAesTransport(BaseTransport):
|
||||
self._password = ch["pwd"]
|
||||
self._username = ch["un"]
|
||||
self._local_nonce: str | None = None
|
||||
self._send_secure = True
|
||||
|
||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||
|
||||
@ -194,6 +198,10 @@ class SslAesTransport(BaseTransport):
|
||||
else:
|
||||
url = self._app_url
|
||||
|
||||
_LOGGER.debug(
|
||||
"Sending secure passthrough from %s",
|
||||
self._host,
|
||||
)
|
||||
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
||||
passthrough_request = {
|
||||
"method": "securePassthrough",
|
||||
@ -254,6 +262,37 @@ class SslAesTransport(BaseTransport):
|
||||
) from ex
|
||||
return ret_val # type: ignore[return-value]
|
||||
|
||||
async def send_unencrypted(self, request: str) -> dict[str, Any]:
|
||||
"""Send encrypted message as passthrough."""
|
||||
if self._state is TransportState.ESTABLISHED and self._token_url:
|
||||
url = self._token_url
|
||||
else:
|
||||
url = self._app_url
|
||||
|
||||
_LOGGER.debug(
|
||||
"Sending unencrypted from %s",
|
||||
self._host,
|
||||
)
|
||||
|
||||
status_code, resp_dict = await self._http_client.post(
|
||||
url,
|
||||
json=request,
|
||||
headers=self._headers,
|
||||
ssl=await self._get_ssl_context(),
|
||||
)
|
||||
|
||||
if status_code != 200:
|
||||
raise KasaException(
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code}"
|
||||
)
|
||||
|
||||
self._handle_response_error_code(resp_dict, "Error sending message")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
resp_dict = cast(dict[str, Any], resp_dict)
|
||||
return resp_dict
|
||||
|
||||
@staticmethod
|
||||
def generate_confirm_hash(
|
||||
local_nonce: str, server_nonce: str, pwd_hash: str
|
||||
@ -302,8 +341,52 @@ class SslAesTransport(BaseTransport):
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake."""
|
||||
local_nonce, server_nonce, pwd_hash = await self.perform_handshake1()
|
||||
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
|
||||
result = await self.perform_handshake1()
|
||||
if result:
|
||||
local_nonce, server_nonce, pwd_hash = result
|
||||
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
|
||||
|
||||
async def try_perform_login(self) -> bool:
|
||||
"""Perform the md5 login."""
|
||||
_LOGGER.debug("Performing insecure login ...")
|
||||
|
||||
pwd_hash = _md5_hash(self._pwd_to_hash().encode())
|
||||
username = self._username
|
||||
body = {
|
||||
"method": "login",
|
||||
"params": {
|
||||
"hashed": True,
|
||||
"password": pwd_hash,
|
||||
"username": username,
|
||||
},
|
||||
}
|
||||
|
||||
http_client = self._http_client
|
||||
status_code, resp_dict = await http_client.post(
|
||||
self._app_url,
|
||||
json=body,
|
||||
headers=self._headers,
|
||||
ssl=await self._get_ssl_context(),
|
||||
)
|
||||
if status_code != 200:
|
||||
raise KasaException(
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code} to handshake2"
|
||||
)
|
||||
resp_dict = cast(dict, resp_dict)
|
||||
if resp_dict.get("error_code") == 0 and (
|
||||
stok := resp_dict.get("result", {}).get("stok")
|
||||
):
|
||||
_LOGGER.debug(
|
||||
"Succesfully logged in to %s with less secure passthrough", self._host
|
||||
)
|
||||
self._send_secure = False
|
||||
self._token_url = URL(f"{str(self._app_url)}/stok={stok}/ds")
|
||||
self._pwd_hash = pwd_hash
|
||||
return True
|
||||
|
||||
_LOGGER.debug("Unable to log in to %s with less secure login", self._host)
|
||||
return False
|
||||
|
||||
async def perform_handshake2(
|
||||
self, local_nonce: str, server_nonce: str, pwd_hash: str
|
||||
@ -355,13 +438,42 @@ class SslAesTransport(BaseTransport):
|
||||
self._state = TransportState.ESTABLISHED
|
||||
_LOGGER.debug("Handshake2 complete ...")
|
||||
|
||||
async def perform_handshake1(self) -> tuple[str, str, str]:
|
||||
def _pwd_to_hash(self) -> str:
|
||||
"""Return the password to hash."""
|
||||
if self._credentials and self._credentials != Credentials():
|
||||
return self._credentials.password
|
||||
|
||||
if self._username and self._password:
|
||||
return self._password
|
||||
|
||||
return self._default_credentials.password
|
||||
|
||||
async def perform_handshake1(self) -> tuple[str, str, str] | None:
|
||||
"""Perform the handshake1."""
|
||||
resp_dict = None
|
||||
if self._username:
|
||||
local_nonce = secrets.token_bytes(8).hex().upper()
|
||||
resp_dict = await self.try_send_handshake1(self._username, local_nonce)
|
||||
|
||||
if (
|
||||
resp_dict
|
||||
and (error_code := self._get_response_error(resp_dict))
|
||||
is SmartErrorCode.SESSION_EXPIRED
|
||||
and (
|
||||
encrypt_type := resp_dict.get("result", {})
|
||||
.get("data", {})
|
||||
.get("encrypt_type")
|
||||
)
|
||||
and (encrypt_type != ["3"])
|
||||
):
|
||||
_LOGGER.debug(
|
||||
"Received encrypt_type %s for %s, trying less secure login",
|
||||
encrypt_type,
|
||||
self._host,
|
||||
)
|
||||
if await self.try_perform_login():
|
||||
return None
|
||||
|
||||
# Try the default username. If it fails raise the original error_code
|
||||
if (
|
||||
not resp_dict
|
||||
@ -369,6 +481,7 @@ class SslAesTransport(BaseTransport):
|
||||
is not SmartErrorCode.INVALID_NONCE
|
||||
or "nonce" not in resp_dict["result"].get("data", {})
|
||||
):
|
||||
_LOGGER.debug("Trying default credentials to %s", self._host)
|
||||
local_nonce = secrets.token_bytes(8).hex().upper()
|
||||
default_resp_dict = await self.try_send_handshake1(
|
||||
self._default_credentials.username, local_nonce
|
||||
@ -378,7 +491,7 @@ class SslAesTransport(BaseTransport):
|
||||
) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[
|
||||
"result"
|
||||
].get("data", {}):
|
||||
_LOGGER.debug("Connected to {self._host} with default username")
|
||||
_LOGGER.debug("Connected to %s with default username", self._host)
|
||||
self._username = self._default_credentials.username
|
||||
error_code = default_error_code
|
||||
resp_dict = default_resp_dict
|
||||
@ -397,12 +510,8 @@ class SslAesTransport(BaseTransport):
|
||||
|
||||
server_nonce = resp_dict["result"]["data"]["nonce"]
|
||||
device_confirm = resp_dict["result"]["data"]["device_confirm"]
|
||||
if self._credentials and self._credentials != Credentials():
|
||||
pwd_hash = _sha256_hash(self._credentials.password.encode())
|
||||
elif self._username and self._password:
|
||||
pwd_hash = _sha256_hash(self._password.encode())
|
||||
else:
|
||||
pwd_hash = _sha256_hash(self._default_credentials.password.encode())
|
||||
|
||||
pwd_hash = _sha256_hash(self._pwd_to_hash().encode())
|
||||
|
||||
expected_confirm_sha256 = self.generate_confirm_hash(
|
||||
local_nonce, server_nonce, pwd_hash
|
||||
@ -414,7 +523,9 @@ class SslAesTransport(BaseTransport):
|
||||
if TYPE_CHECKING:
|
||||
assert self._credentials
|
||||
assert self._credentials.password
|
||||
pwd_hash = _md5_hash(self._credentials.password.encode())
|
||||
|
||||
pwd_hash = _md5_hash(self._pwd_to_hash().encode())
|
||||
|
||||
expected_confirm_md5 = self.generate_confirm_hash(
|
||||
local_nonce, server_nonce, pwd_hash
|
||||
)
|
||||
@ -422,8 +533,17 @@ class SslAesTransport(BaseTransport):
|
||||
_LOGGER.debug("Credentials match")
|
||||
return local_nonce, server_nonce, pwd_hash
|
||||
|
||||
for val in {"admin", "tpadmin", "slprealtek"}:
|
||||
for func in {_sha256_hash, _md5_hash, _sha1_hash, lambda x: x.decode()}:
|
||||
pwd_hash = func(val.encode())
|
||||
ec = self.generate_confirm_hash(local_nonce, server_nonce, pwd_hash)
|
||||
if device_confirm == ec:
|
||||
_LOGGER.debug("Credentials match with %s %s", val, func.__name__)
|
||||
return local_nonce, server_nonce, pwd_hash
|
||||
|
||||
msg = f"Server response doesn't match our challenge on ip {self._host}"
|
||||
_LOGGER.debug(msg)
|
||||
|
||||
raise AuthenticationError(msg)
|
||||
|
||||
async def try_send_handshake1(self, username: str, local_nonce: str) -> dict:
|
||||
@ -462,7 +582,10 @@ class SslAesTransport(BaseTransport):
|
||||
if self._state is TransportState.HANDSHAKE_REQUIRED:
|
||||
await self.perform_handshake()
|
||||
|
||||
return await self.send_secure_passthrough(request)
|
||||
if self._send_secure:
|
||||
return await self.send_secure_passthrough(request)
|
||||
|
||||
return await self.send_unencrypted(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the http client and reset internal state."""
|
||||
|
@ -13,9 +13,13 @@ import aiohttp
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import (
|
||||
BaseProtocol,
|
||||
Credentials,
|
||||
Discover,
|
||||
IotProtocol,
|
||||
KasaException,
|
||||
SmartCamProtocol,
|
||||
SmartProtocol,
|
||||
)
|
||||
from kasa.device_factory import (
|
||||
Device,
|
||||
@ -33,6 +37,16 @@ from kasa.deviceconfig import (
|
||||
DeviceFamily,
|
||||
)
|
||||
from kasa.discover import DiscoveryResult
|
||||
from kasa.transports import (
|
||||
AesTransport,
|
||||
BaseTransport,
|
||||
KlapTransport,
|
||||
KlapTransportV2,
|
||||
LinkieTransportV2,
|
||||
SslAesTransport,
|
||||
SslTransport,
|
||||
XorTransport,
|
||||
)
|
||||
|
||||
from .conftest import DISCOVERY_MOCK_IP
|
||||
|
||||
@ -203,3 +217,74 @@ async def test_device_class_from_unknown_family(caplog):
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
assert get_device_class_from_family(dummy_name, https=False) == SmartDevice
|
||||
assert f"Unknown SMART device with {dummy_name}" in caplog.text
|
||||
|
||||
|
||||
# Aliases to make the test params more readable
|
||||
CP = DeviceConnectionParameters
|
||||
DF = DeviceFamily
|
||||
ET = DeviceEncryptionType
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("conn_params", "expected_protocol", "expected_transport"),
|
||||
[
|
||||
pytest.param(
|
||||
CP(DF.SmartIpCamera, ET.Aes, https=True),
|
||||
SmartCamProtocol,
|
||||
SslAesTransport,
|
||||
id="smartcam",
|
||||
),
|
||||
pytest.param(
|
||||
CP(DF.SmartTapoHub, ET.Aes, https=True),
|
||||
SmartCamProtocol,
|
||||
SslAesTransport,
|
||||
id="smartcam-hub",
|
||||
),
|
||||
pytest.param(
|
||||
CP(DF.IotIpCamera, ET.Aes, https=True),
|
||||
IotProtocol,
|
||||
LinkieTransportV2,
|
||||
id="kasacam",
|
||||
),
|
||||
pytest.param(
|
||||
CP(DF.SmartTapoRobovac, ET.Aes, https=True),
|
||||
SmartProtocol,
|
||||
SslTransport,
|
||||
id="robovac",
|
||||
),
|
||||
pytest.param(
|
||||
CP(DF.IotSmartPlugSwitch, ET.Klap, https=False),
|
||||
IotProtocol,
|
||||
KlapTransport,
|
||||
id="iot-klap",
|
||||
),
|
||||
pytest.param(
|
||||
CP(DF.IotSmartPlugSwitch, ET.Xor, https=False),
|
||||
IotProtocol,
|
||||
XorTransport,
|
||||
id="iot-xor",
|
||||
),
|
||||
pytest.param(
|
||||
CP(DF.SmartTapoPlug, ET.Aes, https=False),
|
||||
SmartProtocol,
|
||||
AesTransport,
|
||||
id="smart-aes",
|
||||
),
|
||||
pytest.param(
|
||||
CP(DF.SmartTapoPlug, ET.Klap, https=False),
|
||||
SmartProtocol,
|
||||
KlapTransportV2,
|
||||
id="smart-klap",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_get_protocol(
|
||||
conn_params: DeviceConnectionParameters,
|
||||
expected_protocol: type[BaseProtocol],
|
||||
expected_transport: type[BaseTransport],
|
||||
):
|
||||
"""Test get_protocol returns the right protocol."""
|
||||
config = DeviceConfig("127.0.0.1", connection_type=conn_params)
|
||||
protocol = get_protocol(config)
|
||||
assert isinstance(protocol, expected_protocol)
|
||||
assert isinstance(protocol._transport, expected_transport)
|
||||
|
Loading…
Reference in New Issue
Block a user