mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 11:13:34 +00:00
Add ssltransport for robovacs (#943)
This PR implements a clear-text, token-based transport protocol seen on RV30 Plus (#937). - Client sends `{"username": "email@example.com", "password": md5(password)}` and gets back a token in the response - Rest of the communications are done with POST at `/app?token=<token>` --------- Co-authored-by: Steven B. <51370195+sdb9696@users.noreply.github.com>
This commit is contained in:
parent
9a52056522
commit
9966c6094a
@ -427,25 +427,25 @@ COMPONENT_REQUESTS = {
|
|||||||
"overheat_protection": [],
|
"overheat_protection": [],
|
||||||
# Vacuum components
|
# Vacuum components
|
||||||
"clean": [
|
"clean": [
|
||||||
SmartRequest.get_raw_request("get_clean_records"),
|
SmartRequest.get_raw_request("getCleanRecords"),
|
||||||
SmartRequest.get_raw_request("get_vac_state"),
|
SmartRequest.get_raw_request("getVacStatus"),
|
||||||
],
|
],
|
||||||
"battery": [SmartRequest.get_raw_request("get_battery_info")],
|
"battery": [SmartRequest.get_raw_request("getBatteryInfo")],
|
||||||
"consumables": [SmartRequest.get_raw_request("get_consumables_info")],
|
"consumables": [SmartRequest.get_raw_request("getConsumablesInfo")],
|
||||||
"direction_control": [],
|
"direction_control": [],
|
||||||
"button_and_led": [],
|
"button_and_led": [],
|
||||||
"speaker": [
|
"speaker": [
|
||||||
SmartRequest.get_raw_request("get_support_voice_language"),
|
SmartRequest.get_raw_request("getSupportVoiceLanguage"),
|
||||||
SmartRequest.get_raw_request("get_current_voice_language"),
|
SmartRequest.get_raw_request("getCurrentVoiceLanguage"),
|
||||||
],
|
],
|
||||||
"map": [
|
"map": [
|
||||||
SmartRequest.get_raw_request("get_map_info"),
|
SmartRequest.get_raw_request("getMapInfo"),
|
||||||
SmartRequest.get_raw_request("get_map_data"),
|
SmartRequest.get_raw_request("getMapData"),
|
||||||
],
|
],
|
||||||
"auto_change_map": [SmartRequest.get_raw_request("get_auto_change_map")],
|
"auto_change_map": [SmartRequest.get_raw_request("getAutoChangeMap")],
|
||||||
"dust_bucket": [SmartRequest.get_raw_request("get_auto_dust_collection")],
|
"dust_bucket": [SmartRequest.get_raw_request("getAutoDustCollection")],
|
||||||
"mop": [SmartRequest.get_raw_request("get_mop_state")],
|
"mop": [SmartRequest.get_raw_request("getMopState")],
|
||||||
"do_not_disturb": [SmartRequest.get_raw_request("get_do_not_disturb")],
|
"do_not_disturb": [SmartRequest.get_raw_request("getDoNotDisturb")],
|
||||||
"charge_pose_clean": [],
|
"charge_pose_clean": [],
|
||||||
"continue_breakpoint_sweep": [],
|
"continue_breakpoint_sweep": [],
|
||||||
"goto_point": [],
|
"goto_point": [],
|
||||||
|
@ -308,6 +308,7 @@ async def cli(
|
|||||||
if type == "camera":
|
if type == "camera":
|
||||||
encrypt_type = "AES"
|
encrypt_type = "AES"
|
||||||
https = True
|
https = True
|
||||||
|
login_version = 2
|
||||||
device_family = "SMART.IPCAMERA"
|
device_family = "SMART.IPCAMERA"
|
||||||
|
|
||||||
from kasa.device import Device
|
from kasa.device import Device
|
||||||
|
@ -32,6 +32,7 @@ from .transports import (
|
|||||||
BaseTransport,
|
BaseTransport,
|
||||||
KlapTransport,
|
KlapTransport,
|
||||||
KlapTransportV2,
|
KlapTransportV2,
|
||||||
|
SslTransport,
|
||||||
XorTransport,
|
XorTransport,
|
||||||
)
|
)
|
||||||
from .transports.sslaestransport import SslAesTransport
|
from .transports.sslaestransport import SslAesTransport
|
||||||
@ -155,6 +156,7 @@ def get_device_class_from_family(
|
|||||||
"SMART.KASAHUB": SmartDevice,
|
"SMART.KASAHUB": SmartDevice,
|
||||||
"SMART.KASASWITCH": SmartDevice,
|
"SMART.KASASWITCH": SmartDevice,
|
||||||
"SMART.IPCAMERA.HTTPS": SmartCamDevice,
|
"SMART.IPCAMERA.HTTPS": SmartCamDevice,
|
||||||
|
"SMART.TAPOROBOVAC": SmartDevice,
|
||||||
"IOT.SMARTPLUGSWITCH": IotPlug,
|
"IOT.SMARTPLUGSWITCH": IotPlug,
|
||||||
"IOT.SMARTBULB": IotBulb,
|
"IOT.SMARTBULB": IotBulb,
|
||||||
}
|
}
|
||||||
@ -176,20 +178,30 @@ def get_protocol(
|
|||||||
"""Return the protocol from the connection name."""
|
"""Return the protocol from the connection name."""
|
||||||
protocol_name = config.connection_type.device_family.value.split(".")[0]
|
protocol_name = config.connection_type.device_family.value.split(".")[0]
|
||||||
ctype = config.connection_type
|
ctype = config.connection_type
|
||||||
|
|
||||||
protocol_transport_key = (
|
protocol_transport_key = (
|
||||||
protocol_name
|
protocol_name
|
||||||
+ "."
|
+ "."
|
||||||
+ ctype.encryption_type.value
|
+ ctype.encryption_type.value
|
||||||
+ (".HTTPS" if ctype.https else "")
|
+ (".HTTPS" if ctype.https else "")
|
||||||
|
+ (
|
||||||
|
f".{ctype.login_version}"
|
||||||
|
if ctype.login_version and ctype.login_version > 1
|
||||||
|
else ""
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_LOGGER.debug("Finding transport for %s", protocol_transport_key)
|
||||||
supported_device_protocols: dict[
|
supported_device_protocols: dict[
|
||||||
str, tuple[type[BaseProtocol], type[BaseTransport]]
|
str, tuple[type[BaseProtocol], type[BaseTransport]]
|
||||||
] = {
|
] = {
|
||||||
"IOT.XOR": (IotProtocol, XorTransport),
|
"IOT.XOR": (IotProtocol, XorTransport),
|
||||||
"IOT.KLAP": (IotProtocol, KlapTransport),
|
"IOT.KLAP": (IotProtocol, KlapTransport),
|
||||||
"SMART.AES": (SmartProtocol, AesTransport),
|
"SMART.AES": (SmartProtocol, AesTransport),
|
||||||
"SMART.KLAP": (SmartProtocol, KlapTransportV2),
|
"SMART.AES.2": (SmartProtocol, AesTransport),
|
||||||
"SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport),
|
"SMART.KLAP.2": (SmartProtocol, KlapTransportV2),
|
||||||
|
"SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport),
|
||||||
|
"SMART.AES.HTTPS": (SmartProtocol, SslTransport),
|
||||||
}
|
}
|
||||||
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
|
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
|
||||||
return None
|
return None
|
||||||
|
@ -21,6 +21,7 @@ class DeviceType(Enum):
|
|||||||
Hub = "hub"
|
Hub = "hub"
|
||||||
Fan = "fan"
|
Fan = "fan"
|
||||||
Thermostat = "thermostat"
|
Thermostat = "thermostat"
|
||||||
|
Vacuum = "vacuum"
|
||||||
Unknown = "unknown"
|
Unknown = "unknown"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -77,6 +77,7 @@ class DeviceFamily(Enum):
|
|||||||
SmartTapoHub = "SMART.TAPOHUB"
|
SmartTapoHub = "SMART.TAPOHUB"
|
||||||
SmartKasaHub = "SMART.KASAHUB"
|
SmartKasaHub = "SMART.KASAHUB"
|
||||||
SmartIpCamera = "SMART.IPCAMERA"
|
SmartIpCamera = "SMART.IPCAMERA"
|
||||||
|
SmartTapoRobovac = "SMART.TAPOROBOVAC"
|
||||||
|
|
||||||
|
|
||||||
class _DeviceConfigBaseMixin(DataClassJSONMixin):
|
class _DeviceConfigBaseMixin(DataClassJSONMixin):
|
||||||
|
@ -598,10 +598,12 @@ class Discover:
|
|||||||
for encrypt in Device.EncryptionType
|
for encrypt in Device.EncryptionType
|
||||||
for device_family in main_device_families
|
for device_family in main_device_families
|
||||||
for https in (True, False)
|
for https in (True, False)
|
||||||
|
for login_version in (None, 2)
|
||||||
if (
|
if (
|
||||||
conn_params := DeviceConnectionParameters(
|
conn_params := DeviceConnectionParameters(
|
||||||
device_family=device_family,
|
device_family=device_family,
|
||||||
encryption_type=encrypt,
|
encryption_type=encrypt,
|
||||||
|
login_version=login_version,
|
||||||
https=https,
|
https=https,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -768,6 +770,13 @@ class Discover:
|
|||||||
):
|
):
|
||||||
encrypt_type = encrypt_info.sym_schm
|
encrypt_type = encrypt_info.sym_schm
|
||||||
|
|
||||||
|
if (
|
||||||
|
not (login_version := encrypt_schm.lv)
|
||||||
|
and (et := discovery_result.encrypt_type)
|
||||||
|
and et == ["3"]
|
||||||
|
):
|
||||||
|
login_version = 2
|
||||||
|
|
||||||
if not encrypt_type:
|
if not encrypt_type:
|
||||||
raise UnsupportedDeviceError(
|
raise UnsupportedDeviceError(
|
||||||
f"Unsupported device {config.host} of type {type_} "
|
f"Unsupported device {config.host} of type {type_} "
|
||||||
@ -778,7 +787,7 @@ class Discover:
|
|||||||
config.connection_type = DeviceConnectionParameters.from_values(
|
config.connection_type = DeviceConnectionParameters.from_values(
|
||||||
type_,
|
type_,
|
||||||
encrypt_type,
|
encrypt_type,
|
||||||
encrypt_schm.lv,
|
login_version,
|
||||||
encrypt_schm.is_support_https,
|
encrypt_schm.is_support_https,
|
||||||
)
|
)
|
||||||
except KasaException as ex:
|
except KasaException as ex:
|
||||||
|
@ -802,6 +802,8 @@ class SmartDevice(Device):
|
|||||||
return DeviceType.Sensor
|
return DeviceType.Sensor
|
||||||
if "ENERGY" in device_type:
|
if "ENERGY" in device_type:
|
||||||
return DeviceType.Thermostat
|
return DeviceType.Thermostat
|
||||||
|
if "ROBOVAC" in device_type:
|
||||||
|
return DeviceType.Vacuum
|
||||||
_LOGGER.warning("Unknown device type, falling back to plug")
|
_LOGGER.warning("Unknown device type, falling back to plug")
|
||||||
return DeviceType.Plug
|
return DeviceType.Plug
|
||||||
|
|
||||||
|
@ -3,11 +3,13 @@
|
|||||||
from .aestransport import AesEncyptionSession, AesTransport
|
from .aestransport import AesEncyptionSession, AesTransport
|
||||||
from .basetransport import BaseTransport
|
from .basetransport import BaseTransport
|
||||||
from .klaptransport import KlapTransport, KlapTransportV2
|
from .klaptransport import KlapTransport, KlapTransportV2
|
||||||
|
from .ssltransport import SslTransport
|
||||||
from .xortransport import XorEncryption, XorTransport
|
from .xortransport import XorEncryption, XorTransport
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AesTransport",
|
"AesTransport",
|
||||||
"AesEncyptionSession",
|
"AesEncyptionSession",
|
||||||
|
"SslTransport",
|
||||||
"BaseTransport",
|
"BaseTransport",
|
||||||
"KlapTransport",
|
"KlapTransport",
|
||||||
"KlapTransportV2",
|
"KlapTransportV2",
|
||||||
|
233
kasa/transports/ssltransport.py
Normal file
233
kasa/transports/ssltransport.py
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
"""Implementation of the clear-text passthrough ssl transport.
|
||||||
|
|
||||||
|
This transport does not encrypt the passthrough payloads at all, but requires a login.
|
||||||
|
This has been seen on some devices (like robovacs).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
|
from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
|
||||||
|
from kasa.deviceconfig import DeviceConfig
|
||||||
|
from kasa.exceptions import (
|
||||||
|
SMART_AUTHENTICATION_ERRORS,
|
||||||
|
SMART_RETRYABLE_ERRORS,
|
||||||
|
AuthenticationError,
|
||||||
|
DeviceError,
|
||||||
|
KasaException,
|
||||||
|
SmartErrorCode,
|
||||||
|
_RetryableError,
|
||||||
|
)
|
||||||
|
from kasa.httpclient import HttpClient
|
||||||
|
from kasa.json import dumps as json_dumps
|
||||||
|
from kasa.json import loads as json_loads
|
||||||
|
from kasa.transports import BaseTransport
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ONE_DAY_SECONDS = 86400
|
||||||
|
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
|
||||||
|
|
||||||
|
|
||||||
|
def _md5_hash(payload: bytes) -> str:
|
||||||
|
return hashlib.md5(payload).hexdigest().upper() # noqa: S324
|
||||||
|
|
||||||
|
|
||||||
|
class TransportState(Enum):
|
||||||
|
"""Enum for transport state."""
|
||||||
|
|
||||||
|
LOGIN_REQUIRED = auto() # Login needed
|
||||||
|
ESTABLISHED = auto() # Ready to send requests
|
||||||
|
|
||||||
|
|
||||||
|
class SslTransport(BaseTransport):
|
||||||
|
"""Implementation of the cleartext transport protocol.
|
||||||
|
|
||||||
|
This transport uses HTTPS without any further payload encryption.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_PORT: int = 4433
|
||||||
|
COMMON_HEADERS = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
config: DeviceConfig,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(config=config)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not self._credentials or self._credentials.username is None
|
||||||
|
) and not self._credentials_hash:
|
||||||
|
self._credentials = Credentials()
|
||||||
|
|
||||||
|
if self._credentials:
|
||||||
|
self._login_params = self._get_login_params(self._credentials)
|
||||||
|
else:
|
||||||
|
self._login_params = json_loads(
|
||||||
|
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
|
||||||
|
)
|
||||||
|
|
||||||
|
self._default_credentials: Credentials | None = None
|
||||||
|
self._http_client: HttpClient = HttpClient(config)
|
||||||
|
|
||||||
|
self._state = TransportState.LOGIN_REQUIRED
|
||||||
|
self._session_expire_at: float | None = None
|
||||||
|
|
||||||
|
self._app_url = URL(f"https://{self._host}:{self._port}/app")
|
||||||
|
|
||||||
|
_LOGGER.debug("Created ssltransport for %s", self._host)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_port(self) -> int:
|
||||||
|
"""Default port for the transport."""
|
||||||
|
return self.DEFAULT_PORT
|
||||||
|
|
||||||
|
@property
|
||||||
|
def credentials_hash(self) -> str:
|
||||||
|
"""The hashed credentials used by the transport."""
|
||||||
|
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
|
||||||
|
|
||||||
|
def _get_login_params(self, credentials: Credentials) -> dict[str, str]:
|
||||||
|
"""Get the login parameters based on the login_version."""
|
||||||
|
un, pw = self.hash_credentials(credentials)
|
||||||
|
return {"password": pw, "username": un}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hash_credentials(credentials: Credentials) -> tuple[str, str]:
|
||||||
|
"""Hash the credentials."""
|
||||||
|
un = credentials.username
|
||||||
|
pw = _md5_hash(credentials.password.encode())
|
||||||
|
return un, pw
|
||||||
|
|
||||||
|
async def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
|
||||||
|
"""Handle response errors to request reauth etc."""
|
||||||
|
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
|
if error_code == SmartErrorCode.SUCCESS:
|
||||||
|
return
|
||||||
|
|
||||||
|
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
|
||||||
|
|
||||||
|
if error_code in SMART_RETRYABLE_ERRORS:
|
||||||
|
raise _RetryableError(msg, error_code=error_code)
|
||||||
|
|
||||||
|
if error_code in SMART_AUTHENTICATION_ERRORS:
|
||||||
|
await self.reset()
|
||||||
|
raise AuthenticationError(msg, error_code=error_code)
|
||||||
|
|
||||||
|
raise DeviceError(msg, error_code=error_code)
|
||||||
|
|
||||||
|
async def send_request(self, request: str) -> dict[str, Any]:
|
||||||
|
"""Send request."""
|
||||||
|
url = self._app_url
|
||||||
|
|
||||||
|
_LOGGER.debug("Sending %s to %s", request, url)
|
||||||
|
|
||||||
|
status_code, resp_dict = await self._http_client.post(
|
||||||
|
url,
|
||||||
|
json=request,
|
||||||
|
headers=self.COMMON_HEADERS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status_code != 200:
|
||||||
|
raise KasaException(
|
||||||
|
f"{self._host} responded with an unexpected "
|
||||||
|
+ f"status code {status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER.debug("Response with %s: %r", status_code, resp_dict)
|
||||||
|
|
||||||
|
await self._handle_response_error_code(resp_dict, "Error sending request")
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
resp_dict = cast(dict[str, Any], resp_dict)
|
||||||
|
|
||||||
|
return resp_dict
|
||||||
|
|
||||||
|
async def perform_login(self) -> None:
|
||||||
|
"""Login to the device."""
|
||||||
|
try:
|
||||||
|
await self.try_login(self._login_params)
|
||||||
|
except AuthenticationError as aex:
|
||||||
|
try:
|
||||||
|
if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
|
||||||
|
raise aex
|
||||||
|
|
||||||
|
_LOGGER.debug("Login failed, going to try default credentials")
|
||||||
|
if self._default_credentials is None:
|
||||||
|
self._default_credentials = get_default_credentials(
|
||||||
|
DEFAULT_CREDENTIALS["TAPO"]
|
||||||
|
)
|
||||||
|
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_LOGIN_ERROR)
|
||||||
|
|
||||||
|
await self.try_login(self._get_login_params(self._default_credentials))
|
||||||
|
_LOGGER.debug(
|
||||||
|
"%s: logged in with default credentials",
|
||||||
|
self._host,
|
||||||
|
)
|
||||||
|
except AuthenticationError:
|
||||||
|
raise
|
||||||
|
except Exception as ex:
|
||||||
|
raise KasaException(
|
||||||
|
"Unable to login and trying default "
|
||||||
|
+ f"login raised another exception: {ex}",
|
||||||
|
ex,
|
||||||
|
) from ex
|
||||||
|
|
||||||
|
async def try_login(self, login_params: dict[str, Any]) -> None:
|
||||||
|
"""Try to login with supplied login_params."""
|
||||||
|
login_request = {
|
||||||
|
"method": "login",
|
||||||
|
"params": login_params,
|
||||||
|
}
|
||||||
|
request = json_dumps(login_request)
|
||||||
|
_LOGGER.debug("Going to send login request")
|
||||||
|
|
||||||
|
resp_dict = await self.send_request(request)
|
||||||
|
await self._handle_response_error_code(resp_dict, "Error logging in")
|
||||||
|
|
||||||
|
login_token = resp_dict["result"]["token"]
|
||||||
|
self._app_url = self._app_url.with_query(f"token={login_token}")
|
||||||
|
self._state = TransportState.ESTABLISHED
|
||||||
|
self._session_expire_at = (
|
||||||
|
time.time() + ONE_DAY_SECONDS - SESSION_EXPIRE_BUFFER_SECONDS
|
||||||
|
)
|
||||||
|
|
||||||
|
def _session_expired(self) -> bool:
|
||||||
|
"""Return true if session has expired."""
|
||||||
|
return (
|
||||||
|
self._session_expire_at is None
|
||||||
|
or self._session_expire_at - time.time() <= 0
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send(self, request: str) -> dict[str, Any]:
|
||||||
|
"""Send the request."""
|
||||||
|
_LOGGER.info("Going to send %s", request)
|
||||||
|
if self._state is not TransportState.ESTABLISHED or self._session_expired():
|
||||||
|
_LOGGER.debug("Transport not established or session expired, logging in")
|
||||||
|
await self.perform_login()
|
||||||
|
|
||||||
|
return await self.send_request(request)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the http client and reset internal state."""
|
||||||
|
await self.reset()
|
||||||
|
await self._http_client.close()
|
||||||
|
|
||||||
|
async def reset(self) -> None:
|
||||||
|
"""Reset internal login state."""
|
||||||
|
self._state = TransportState.LOGIN_REQUIRED
|
||||||
|
self._app_url = URL(f"https://{self._host}:{self._port}/app")
|
@ -692,6 +692,8 @@ async def test_credentials(discovery_mock, mocker, runner):
|
|||||||
dr.device_type,
|
dr.device_type,
|
||||||
"--encrypt-type",
|
"--encrypt-type",
|
||||||
dr.mgt_encrypt_schm.encrypt_type,
|
dr.mgt_encrypt_schm.encrypt_type,
|
||||||
|
"--login-version",
|
||||||
|
dr.mgt_encrypt_schm.lv or 1,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert res.exit_code == 0
|
assert res.exit_code == 0
|
||||||
|
@ -47,7 +47,10 @@ def _get_connection_type_device_class(discovery_info):
|
|||||||
dr = DiscoveryResult.from_dict(discovery_info["result"])
|
dr = DiscoveryResult.from_dict(discovery_info["result"])
|
||||||
|
|
||||||
connection_type = DeviceConnectionParameters.from_values(
|
connection_type = DeviceConnectionParameters.from_values(
|
||||||
dr.device_type, dr.mgt_encrypt_schm.encrypt_type
|
dr.device_type,
|
||||||
|
dr.mgt_encrypt_schm.encrypt_type,
|
||||||
|
dr.mgt_encrypt_schm.lv,
|
||||||
|
dr.mgt_encrypt_schm.is_support_https,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
connection_type = DeviceConnectionParameters.from_values(
|
connection_type = DeviceConnectionParameters.from_values(
|
||||||
|
374
tests/transports/test_ssltransport.py
Normal file
374
tests/transports/test_ssltransport.py
Normal file
@ -0,0 +1,374 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from base64 import b64encode
|
||||||
|
from contextlib import nullcontext as does_not_raise
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
|
from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
|
||||||
|
from kasa.deviceconfig import DeviceConfig
|
||||||
|
from kasa.exceptions import (
|
||||||
|
AuthenticationError,
|
||||||
|
DeviceError,
|
||||||
|
KasaException,
|
||||||
|
SmartErrorCode,
|
||||||
|
_RetryableError,
|
||||||
|
)
|
||||||
|
from kasa.httpclient import HttpClient
|
||||||
|
from kasa.json import dumps as json_dumps
|
||||||
|
from kasa.json import loads as json_loads
|
||||||
|
from kasa.transports import SslTransport
|
||||||
|
from kasa.transports.ssltransport import TransportState, _md5_hash
|
||||||
|
|
||||||
|
# Transport tests are not designed for real devices
|
||||||
|
pytestmark = [pytest.mark.requires_dummy]
|
||||||
|
|
||||||
|
MOCK_PWD = "correct_pwd" # noqa: S105
|
||||||
|
MOCK_USER = "mock@example.com"
|
||||||
|
MOCK_BAD_USER_OR_PWD = "foobar" # noqa: S105
|
||||||
|
MOCK_TOKEN = "abcdefghijklmnopqrstuvwxyz1234)(" # noqa: S105
|
||||||
|
|
||||||
|
DEFAULT_CREDS = get_default_credentials(DEFAULT_CREDENTIALS["TAPO"])
|
||||||
|
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
(
|
||||||
|
"status_code",
|
||||||
|
"error_code",
|
||||||
|
"username",
|
||||||
|
"password",
|
||||||
|
"expectation",
|
||||||
|
),
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
SmartErrorCode.SUCCESS,
|
||||||
|
MOCK_USER,
|
||||||
|
MOCK_PWD,
|
||||||
|
does_not_raise(),
|
||||||
|
id="success",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
SmartErrorCode.UNSPECIFIC_ERROR,
|
||||||
|
MOCK_USER,
|
||||||
|
MOCK_PWD,
|
||||||
|
pytest.raises(_RetryableError),
|
||||||
|
id="test retry",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
SmartErrorCode.DEVICE_BLOCKED,
|
||||||
|
MOCK_USER,
|
||||||
|
MOCK_PWD,
|
||||||
|
pytest.raises(DeviceError),
|
||||||
|
id="test regular error",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
400,
|
||||||
|
SmartErrorCode.INTERNAL_UNKNOWN_ERROR,
|
||||||
|
MOCK_USER,
|
||||||
|
MOCK_PWD,
|
||||||
|
pytest.raises(KasaException),
|
||||||
|
id="400 error",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
SmartErrorCode.LOGIN_ERROR,
|
||||||
|
MOCK_BAD_USER_OR_PWD,
|
||||||
|
MOCK_PWD,
|
||||||
|
pytest.raises(AuthenticationError),
|
||||||
|
id="bad-username",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
[SmartErrorCode.LOGIN_ERROR, SmartErrorCode.SUCCESS],
|
||||||
|
MOCK_BAD_USER_OR_PWD,
|
||||||
|
"",
|
||||||
|
does_not_raise(),
|
||||||
|
id="working-fallback",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
[SmartErrorCode.LOGIN_ERROR, SmartErrorCode.LOGIN_ERROR],
|
||||||
|
MOCK_BAD_USER_OR_PWD,
|
||||||
|
"",
|
||||||
|
pytest.raises(AuthenticationError),
|
||||||
|
id="fallback-fail",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
SmartErrorCode.LOGIN_ERROR,
|
||||||
|
MOCK_USER,
|
||||||
|
MOCK_BAD_USER_OR_PWD,
|
||||||
|
pytest.raises(AuthenticationError),
|
||||||
|
id="bad-password",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
200,
|
||||||
|
SmartErrorCode.TRANSPORT_UNKNOWN_CREDENTIALS_ERROR,
|
||||||
|
MOCK_USER,
|
||||||
|
MOCK_PWD,
|
||||||
|
pytest.raises(AuthenticationError),
|
||||||
|
id="auth-error != login_error",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_login(
|
||||||
|
mocker,
|
||||||
|
status_code,
|
||||||
|
error_code,
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
expectation,
|
||||||
|
):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_ssl_aes_device = MockSslDevice(
|
||||||
|
host,
|
||||||
|
status_code=status_code,
|
||||||
|
send_error_code=error_code,
|
||||||
|
)
|
||||||
|
mocker.patch.object(
|
||||||
|
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SslTransport(
|
||||||
|
config=DeviceConfig(host, credentials=Credentials(username, password))
|
||||||
|
)
|
||||||
|
|
||||||
|
assert transport._state is TransportState.LOGIN_REQUIRED
|
||||||
|
with expectation:
|
||||||
|
await transport.perform_login()
|
||||||
|
assert transport._state is TransportState.ESTABLISHED
|
||||||
|
|
||||||
|
await transport.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_credentials_hash(mocker):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_ssl_aes_device = MockSslDevice(host)
|
||||||
|
mocker.patch.object(
|
||||||
|
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
|
||||||
|
)
|
||||||
|
creds = Credentials(MOCK_USER, MOCK_PWD)
|
||||||
|
|
||||||
|
data = {"password": _md5_hash(MOCK_PWD.encode()), "username": MOCK_USER}
|
||||||
|
|
||||||
|
creds_hash = b64encode(json_dumps(data).encode()).decode()
|
||||||
|
|
||||||
|
# Test with credentials input
|
||||||
|
transport = SslTransport(config=DeviceConfig(host, credentials=creds))
|
||||||
|
assert transport.credentials_hash == creds_hash
|
||||||
|
|
||||||
|
# Test with credentials_hash input
|
||||||
|
transport = SslTransport(config=DeviceConfig(host, credentials_hash=creds_hash))
|
||||||
|
assert transport.credentials_hash == creds_hash
|
||||||
|
|
||||||
|
await transport.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_send(mocker):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_ssl_aes_device = MockSslDevice(host, send_error_code=SmartErrorCode.SUCCESS)
|
||||||
|
mocker.patch.object(
|
||||||
|
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SslTransport(
|
||||||
|
config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
|
||||||
|
)
|
||||||
|
try_login_spy = mocker.spy(transport, "try_login")
|
||||||
|
request = {
|
||||||
|
"method": "get_device_info",
|
||||||
|
"params": None,
|
||||||
|
}
|
||||||
|
assert transport._state is TransportState.LOGIN_REQUIRED
|
||||||
|
|
||||||
|
res = await transport.send(json_dumps(request))
|
||||||
|
assert "result" in res
|
||||||
|
try_login_spy.assert_called_once()
|
||||||
|
assert transport._state is TransportState.ESTABLISHED
|
||||||
|
|
||||||
|
# Second request does not
|
||||||
|
res = await transport.send(json_dumps(request))
|
||||||
|
try_login_spy.assert_called_once()
|
||||||
|
|
||||||
|
await transport.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_no_credentials(mocker):
|
||||||
|
"""Test transport without credentials."""
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_ssl_aes_device = MockSslDevice(
|
||||||
|
host, send_error_code=SmartErrorCode.LOGIN_ERROR
|
||||||
|
)
|
||||||
|
mocker.patch.object(
|
||||||
|
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SslTransport(config=DeviceConfig(host))
|
||||||
|
try_login_spy = mocker.spy(transport, "try_login")
|
||||||
|
|
||||||
|
with pytest.raises(AuthenticationError):
|
||||||
|
await transport.send('{"method": "dummy"}')
|
||||||
|
|
||||||
|
# We get called twice
|
||||||
|
assert try_login_spy.call_count == 2
|
||||||
|
|
||||||
|
await transport.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reset(mocker):
|
||||||
|
"""Test that transport state adjusts correctly for reset."""
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_ssl_aes_device = MockSslDevice(host, send_error_code=SmartErrorCode.SUCCESS)
|
||||||
|
mocker.patch.object(
|
||||||
|
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SslTransport(
|
||||||
|
config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
|
||||||
|
)
|
||||||
|
|
||||||
|
assert transport._state is TransportState.LOGIN_REQUIRED
|
||||||
|
assert str(transport._app_url) == "https://127.0.0.1:4433/app"
|
||||||
|
|
||||||
|
await transport.perform_login()
|
||||||
|
assert transport._state is TransportState.ESTABLISHED
|
||||||
|
assert str(transport._app_url).startswith("https://127.0.0.1:4433/app?token=")
|
||||||
|
|
||||||
|
await transport.close()
|
||||||
|
assert transport._state is TransportState.LOGIN_REQUIRED
|
||||||
|
assert str(transport._app_url) == "https://127.0.0.1:4433/app"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_port_override():
|
||||||
|
"""Test that port override sets the app_url."""
|
||||||
|
host = "127.0.0.1"
|
||||||
|
port_override = 12345
|
||||||
|
config = DeviceConfig(
|
||||||
|
host, credentials=Credentials("foo", "bar"), port_override=port_override
|
||||||
|
)
|
||||||
|
transport = SslTransport(config=config)
|
||||||
|
|
||||||
|
assert str(transport._app_url) == f"https://127.0.0.1:{port_override}/app"
|
||||||
|
|
||||||
|
await transport.close()
|
||||||
|
|
||||||
|
|
||||||
|
class MockSslDevice:
|
||||||
|
"""Based on MockAesSslDevice."""
|
||||||
|
|
||||||
|
class _mock_response:
|
||||||
|
def __init__(self, status, request: dict):
|
||||||
|
self.status = status
|
||||||
|
self._json = request
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_t, exc_v, exc_tb):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def read(self):
|
||||||
|
if isinstance(self._json, dict):
|
||||||
|
return json_dumps(self._json).encode()
|
||||||
|
return self._json
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host,
|
||||||
|
*,
|
||||||
|
status_code=200,
|
||||||
|
send_error_code=SmartErrorCode.INTERNAL_UNKNOWN_ERROR,
|
||||||
|
):
|
||||||
|
self.host = host
|
||||||
|
self.http_client = HttpClient(DeviceConfig(self.host))
|
||||||
|
|
||||||
|
self._state = TransportState.LOGIN_REQUIRED
|
||||||
|
|
||||||
|
# test behaviour attributes
|
||||||
|
self.status_code = status_code
|
||||||
|
self.send_error_code = send_error_code
|
||||||
|
|
||||||
|
async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
|
||||||
|
if data:
|
||||||
|
json = json_loads(data)
|
||||||
|
_LOGGER.debug("Request %s: %s", url, json)
|
||||||
|
res = self._post(url, json)
|
||||||
|
_LOGGER.debug("Response %s, data: %s", res, await res.read())
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _post(self, url: URL, json: dict[str, Any]):
|
||||||
|
method = json["method"]
|
||||||
|
|
||||||
|
if method == "login":
|
||||||
|
if self._state is TransportState.LOGIN_REQUIRED:
|
||||||
|
assert json.get("token") is None
|
||||||
|
assert url == URL(f"https://{self.host}:4433/app")
|
||||||
|
return self._return_login_response(url, json)
|
||||||
|
else:
|
||||||
|
_LOGGER.warning("Received login although already logged in")
|
||||||
|
pytest.fail("non-handled re-login logic")
|
||||||
|
|
||||||
|
assert url == URL(f"https://{self.host}:4433/app?token={MOCK_TOKEN}")
|
||||||
|
return self._return_send_response(url, json)
|
||||||
|
|
||||||
|
def _return_login_response(self, url: URL, request: dict[str, Any]):
|
||||||
|
request_username = request["params"].get("username")
|
||||||
|
request_password = request["params"].get("password")
|
||||||
|
|
||||||
|
# Handle multiple error codes
|
||||||
|
if isinstance(self.send_error_code, list):
|
||||||
|
error_code = self.send_error_code.pop(0)
|
||||||
|
else:
|
||||||
|
error_code = self.send_error_code
|
||||||
|
|
||||||
|
_LOGGER.debug("Using error code %s", error_code)
|
||||||
|
|
||||||
|
def _return_login_error():
|
||||||
|
resp = {
|
||||||
|
"error_code": error_code.value,
|
||||||
|
"result": {"unknown": "payload"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_LOGGER.debug("Returning login error with status %s", self.status_code)
|
||||||
|
return self._mock_response(self.status_code, resp)
|
||||||
|
|
||||||
|
if error_code is not SmartErrorCode.SUCCESS:
|
||||||
|
# Bad username
|
||||||
|
if request_username == MOCK_BAD_USER_OR_PWD:
|
||||||
|
return _return_login_error()
|
||||||
|
|
||||||
|
# Bad password
|
||||||
|
if request_password == _md5_hash(MOCK_BAD_USER_OR_PWD.encode()):
|
||||||
|
return _return_login_error()
|
||||||
|
|
||||||
|
# Empty password
|
||||||
|
if request_password == _md5_hash(b""):
|
||||||
|
return _return_login_error()
|
||||||
|
|
||||||
|
self._state = TransportState.ESTABLISHED
|
||||||
|
resp = {
|
||||||
|
"error_code": error_code.value,
|
||||||
|
"result": {
|
||||||
|
"token": MOCK_TOKEN,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_LOGGER.debug("Returning login success with status %s", self.status_code)
|
||||||
|
return self._mock_response(self.status_code, resp)
|
||||||
|
|
||||||
|
def _return_send_response(self, url: URL, json: dict[str, Any]):
|
||||||
|
method = json["method"]
|
||||||
|
result = {
|
||||||
|
"result": {method: {"dummy": "response"}},
|
||||||
|
"error_code": self.send_error_code.value,
|
||||||
|
}
|
||||||
|
return self._mock_response(self.status_code, result)
|
Loading…
Reference in New Issue
Block a user