mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-11-04 06:32:07 +00:00 
			
		
		
		
	Add ssltransport for robovacs (#943)
This PR implements a clear-text, token-based transport protocol seen on RV30 Plus (#937). - Client sends `{"username": "email@example.com", "password": md5(password)}` and gets back a token in the response - Rest of the communications are done with POST at `/app?token=<token>` --------- Co-authored-by: Steven B. <51370195+sdb9696@users.noreply.github.com>
This commit is contained in:
		@@ -427,25 +427,25 @@ COMPONENT_REQUESTS = {
 | 
			
		||||
    "overheat_protection": [],
 | 
			
		||||
    # Vacuum components
 | 
			
		||||
    "clean": [
 | 
			
		||||
        SmartRequest.get_raw_request("get_clean_records"),
 | 
			
		||||
        SmartRequest.get_raw_request("get_vac_state"),
 | 
			
		||||
        SmartRequest.get_raw_request("getCleanRecords"),
 | 
			
		||||
        SmartRequest.get_raw_request("getVacStatus"),
 | 
			
		||||
    ],
 | 
			
		||||
    "battery": [SmartRequest.get_raw_request("get_battery_info")],
 | 
			
		||||
    "consumables": [SmartRequest.get_raw_request("get_consumables_info")],
 | 
			
		||||
    "battery": [SmartRequest.get_raw_request("getBatteryInfo")],
 | 
			
		||||
    "consumables": [SmartRequest.get_raw_request("getConsumablesInfo")],
 | 
			
		||||
    "direction_control": [],
 | 
			
		||||
    "button_and_led": [],
 | 
			
		||||
    "speaker": [
 | 
			
		||||
        SmartRequest.get_raw_request("get_support_voice_language"),
 | 
			
		||||
        SmartRequest.get_raw_request("get_current_voice_language"),
 | 
			
		||||
        SmartRequest.get_raw_request("getSupportVoiceLanguage"),
 | 
			
		||||
        SmartRequest.get_raw_request("getCurrentVoiceLanguage"),
 | 
			
		||||
    ],
 | 
			
		||||
    "map": [
 | 
			
		||||
        SmartRequest.get_raw_request("get_map_info"),
 | 
			
		||||
        SmartRequest.get_raw_request("get_map_data"),
 | 
			
		||||
        SmartRequest.get_raw_request("getMapInfo"),
 | 
			
		||||
        SmartRequest.get_raw_request("getMapData"),
 | 
			
		||||
    ],
 | 
			
		||||
    "auto_change_map": [SmartRequest.get_raw_request("get_auto_change_map")],
 | 
			
		||||
    "dust_bucket": [SmartRequest.get_raw_request("get_auto_dust_collection")],
 | 
			
		||||
    "mop": [SmartRequest.get_raw_request("get_mop_state")],
 | 
			
		||||
    "do_not_disturb": [SmartRequest.get_raw_request("get_do_not_disturb")],
 | 
			
		||||
    "auto_change_map": [SmartRequest.get_raw_request("getAutoChangeMap")],
 | 
			
		||||
    "dust_bucket": [SmartRequest.get_raw_request("getAutoDustCollection")],
 | 
			
		||||
    "mop": [SmartRequest.get_raw_request("getMopState")],
 | 
			
		||||
    "do_not_disturb": [SmartRequest.get_raw_request("getDoNotDisturb")],
 | 
			
		||||
    "charge_pose_clean": [],
 | 
			
		||||
    "continue_breakpoint_sweep": [],
 | 
			
		||||
    "goto_point": [],
 | 
			
		||||
 
 | 
			
		||||
@@ -308,6 +308,7 @@ async def cli(
 | 
			
		||||
        if type == "camera":
 | 
			
		||||
            encrypt_type = "AES"
 | 
			
		||||
            https = True
 | 
			
		||||
            login_version = 2
 | 
			
		||||
            device_family = "SMART.IPCAMERA"
 | 
			
		||||
 | 
			
		||||
        from kasa.device import Device
 | 
			
		||||
 
 | 
			
		||||
@@ -32,6 +32,7 @@ from .transports import (
 | 
			
		||||
    BaseTransport,
 | 
			
		||||
    KlapTransport,
 | 
			
		||||
    KlapTransportV2,
 | 
			
		||||
    SslTransport,
 | 
			
		||||
    XorTransport,
 | 
			
		||||
)
 | 
			
		||||
from .transports.sslaestransport import SslAesTransport
 | 
			
		||||
@@ -155,6 +156,7 @@ def get_device_class_from_family(
 | 
			
		||||
        "SMART.KASAHUB": SmartDevice,
 | 
			
		||||
        "SMART.KASASWITCH": SmartDevice,
 | 
			
		||||
        "SMART.IPCAMERA.HTTPS": SmartCamDevice,
 | 
			
		||||
        "SMART.TAPOROBOVAC": SmartDevice,
 | 
			
		||||
        "IOT.SMARTPLUGSWITCH": IotPlug,
 | 
			
		||||
        "IOT.SMARTBULB": IotBulb,
 | 
			
		||||
    }
 | 
			
		||||
@@ -176,20 +178,30 @@ def get_protocol(
 | 
			
		||||
    """Return the protocol from the connection name."""
 | 
			
		||||
    protocol_name = config.connection_type.device_family.value.split(".")[0]
 | 
			
		||||
    ctype = config.connection_type
 | 
			
		||||
 | 
			
		||||
    protocol_transport_key = (
 | 
			
		||||
        protocol_name
 | 
			
		||||
        + "."
 | 
			
		||||
        + ctype.encryption_type.value
 | 
			
		||||
        + (".HTTPS" if ctype.https else "")
 | 
			
		||||
        + (
 | 
			
		||||
            f".{ctype.login_version}"
 | 
			
		||||
            if ctype.login_version and ctype.login_version > 1
 | 
			
		||||
            else ""
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    _LOGGER.debug("Finding transport for %s", protocol_transport_key)
 | 
			
		||||
    supported_device_protocols: dict[
 | 
			
		||||
        str, tuple[type[BaseProtocol], type[BaseTransport]]
 | 
			
		||||
    ] = {
 | 
			
		||||
        "IOT.XOR": (IotProtocol, XorTransport),
 | 
			
		||||
        "IOT.KLAP": (IotProtocol, KlapTransport),
 | 
			
		||||
        "SMART.AES": (SmartProtocol, AesTransport),
 | 
			
		||||
        "SMART.KLAP": (SmartProtocol, KlapTransportV2),
 | 
			
		||||
        "SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport),
 | 
			
		||||
        "SMART.AES.2": (SmartProtocol, AesTransport),
 | 
			
		||||
        "SMART.KLAP.2": (SmartProtocol, KlapTransportV2),
 | 
			
		||||
        "SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport),
 | 
			
		||||
        "SMART.AES.HTTPS": (SmartProtocol, SslTransport),
 | 
			
		||||
    }
 | 
			
		||||
    if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
 | 
			
		||||
        return None
 | 
			
		||||
 
 | 
			
		||||
@@ -21,6 +21,7 @@ class DeviceType(Enum):
 | 
			
		||||
    Hub = "hub"
 | 
			
		||||
    Fan = "fan"
 | 
			
		||||
    Thermostat = "thermostat"
 | 
			
		||||
    Vacuum = "vacuum"
 | 
			
		||||
    Unknown = "unknown"
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
 
 | 
			
		||||
@@ -77,6 +77,7 @@ class DeviceFamily(Enum):
 | 
			
		||||
    SmartTapoHub = "SMART.TAPOHUB"
 | 
			
		||||
    SmartKasaHub = "SMART.KASAHUB"
 | 
			
		||||
    SmartIpCamera = "SMART.IPCAMERA"
 | 
			
		||||
    SmartTapoRobovac = "SMART.TAPOROBOVAC"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _DeviceConfigBaseMixin(DataClassJSONMixin):
 | 
			
		||||
 
 | 
			
		||||
@@ -598,10 +598,12 @@ class Discover:
 | 
			
		||||
            for encrypt in Device.EncryptionType
 | 
			
		||||
            for device_family in main_device_families
 | 
			
		||||
            for https in (True, False)
 | 
			
		||||
            for login_version in (None, 2)
 | 
			
		||||
            if (
 | 
			
		||||
                conn_params := DeviceConnectionParameters(
 | 
			
		||||
                    device_family=device_family,
 | 
			
		||||
                    encryption_type=encrypt,
 | 
			
		||||
                    login_version=login_version,
 | 
			
		||||
                    https=https,
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
@@ -768,6 +770,13 @@ class Discover:
 | 
			
		||||
            ):
 | 
			
		||||
                encrypt_type = encrypt_info.sym_schm
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                not (login_version := encrypt_schm.lv)
 | 
			
		||||
                and (et := discovery_result.encrypt_type)
 | 
			
		||||
                and et == ["3"]
 | 
			
		||||
            ):
 | 
			
		||||
                login_version = 2
 | 
			
		||||
 | 
			
		||||
            if not encrypt_type:
 | 
			
		||||
                raise UnsupportedDeviceError(
 | 
			
		||||
                    f"Unsupported device {config.host} of type {type_} "
 | 
			
		||||
@@ -778,7 +787,7 @@ class Discover:
 | 
			
		||||
            config.connection_type = DeviceConnectionParameters.from_values(
 | 
			
		||||
                type_,
 | 
			
		||||
                encrypt_type,
 | 
			
		||||
                encrypt_schm.lv,
 | 
			
		||||
                login_version,
 | 
			
		||||
                encrypt_schm.is_support_https,
 | 
			
		||||
            )
 | 
			
		||||
        except KasaException as ex:
 | 
			
		||||
 
 | 
			
		||||
@@ -802,6 +802,8 @@ class SmartDevice(Device):
 | 
			
		||||
            return DeviceType.Sensor
 | 
			
		||||
        if "ENERGY" in device_type:
 | 
			
		||||
            return DeviceType.Thermostat
 | 
			
		||||
        if "ROBOVAC" in device_type:
 | 
			
		||||
            return DeviceType.Vacuum
 | 
			
		||||
        _LOGGER.warning("Unknown device type, falling back to plug")
 | 
			
		||||
        return DeviceType.Plug
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,11 +3,13 @@
 | 
			
		||||
from .aestransport import AesEncyptionSession, AesTransport
 | 
			
		||||
from .basetransport import BaseTransport
 | 
			
		||||
from .klaptransport import KlapTransport, KlapTransportV2
 | 
			
		||||
from .ssltransport import SslTransport
 | 
			
		||||
from .xortransport import XorEncryption, XorTransport
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    "AesTransport",
 | 
			
		||||
    "AesEncyptionSession",
 | 
			
		||||
    "SslTransport",
 | 
			
		||||
    "BaseTransport",
 | 
			
		||||
    "KlapTransport",
 | 
			
		||||
    "KlapTransportV2",
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										233
									
								
								kasa/transports/ssltransport.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										233
									
								
								kasa/transports/ssltransport.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,233 @@
 | 
			
		||||
"""Implementation of the clear-text passthrough ssl transport.
 | 
			
		||||
 | 
			
		||||
This transport does not encrypt the passthrough payloads at all, but requires a login.
 | 
			
		||||
This has been seen on some devices (like robovacs).
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import asyncio
 | 
			
		||||
import base64
 | 
			
		||||
import hashlib
 | 
			
		||||
import logging
 | 
			
		||||
import time
 | 
			
		||||
from enum import Enum, auto
 | 
			
		||||
from typing import TYPE_CHECKING, Any, cast
 | 
			
		||||
 | 
			
		||||
from yarl import URL
 | 
			
		||||
 | 
			
		||||
from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
 | 
			
		||||
from kasa.deviceconfig import DeviceConfig
 | 
			
		||||
from kasa.exceptions import (
 | 
			
		||||
    SMART_AUTHENTICATION_ERRORS,
 | 
			
		||||
    SMART_RETRYABLE_ERRORS,
 | 
			
		||||
    AuthenticationError,
 | 
			
		||||
    DeviceError,
 | 
			
		||||
    KasaException,
 | 
			
		||||
    SmartErrorCode,
 | 
			
		||||
    _RetryableError,
 | 
			
		||||
)
 | 
			
		||||
from kasa.httpclient import HttpClient
 | 
			
		||||
from kasa.json import dumps as json_dumps
 | 
			
		||||
from kasa.json import loads as json_loads
 | 
			
		||||
from kasa.transports import BaseTransport
 | 
			
		||||
 | 
			
		||||
_LOGGER = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
ONE_DAY_SECONDS = 86400
 | 
			
		||||
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _md5_hash(payload: bytes) -> str:
 | 
			
		||||
    return hashlib.md5(payload).hexdigest().upper()  # noqa: S324
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransportState(Enum):
 | 
			
		||||
    """Enum for transport state."""
 | 
			
		||||
 | 
			
		||||
    LOGIN_REQUIRED = auto()  # Login needed
 | 
			
		||||
    ESTABLISHED = auto()  # Ready to send requests
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SslTransport(BaseTransport):
 | 
			
		||||
    """Implementation of the cleartext transport protocol.
 | 
			
		||||
 | 
			
		||||
    This transport uses HTTPS without any further payload encryption.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    DEFAULT_PORT: int = 4433
 | 
			
		||||
    COMMON_HEADERS = {
 | 
			
		||||
        "Content-Type": "application/json",
 | 
			
		||||
    }
 | 
			
		||||
    BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        *,
 | 
			
		||||
        config: DeviceConfig,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__(config=config)
 | 
			
		||||
 | 
			
		||||
        if (
 | 
			
		||||
            not self._credentials or self._credentials.username is None
 | 
			
		||||
        ) and not self._credentials_hash:
 | 
			
		||||
            self._credentials = Credentials()
 | 
			
		||||
 | 
			
		||||
        if self._credentials:
 | 
			
		||||
            self._login_params = self._get_login_params(self._credentials)
 | 
			
		||||
        else:
 | 
			
		||||
            self._login_params = json_loads(
 | 
			
		||||
                base64.b64decode(self._credentials_hash.encode()).decode()  # type: ignore[union-attr]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self._default_credentials: Credentials | None = None
 | 
			
		||||
        self._http_client: HttpClient = HttpClient(config)
 | 
			
		||||
 | 
			
		||||
        self._state = TransportState.LOGIN_REQUIRED
 | 
			
		||||
        self._session_expire_at: float | None = None
 | 
			
		||||
 | 
			
		||||
        self._app_url = URL(f"https://{self._host}:{self._port}/app")
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("Created ssltransport for %s", self._host)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def default_port(self) -> int:
 | 
			
		||||
        """Default port for the transport."""
 | 
			
		||||
        return self.DEFAULT_PORT
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def credentials_hash(self) -> str:
 | 
			
		||||
        """The hashed credentials used by the transport."""
 | 
			
		||||
        return base64.b64encode(json_dumps(self._login_params).encode()).decode()
 | 
			
		||||
 | 
			
		||||
    def _get_login_params(self, credentials: Credentials) -> dict[str, str]:
 | 
			
		||||
        """Get the login parameters based on the login_version."""
 | 
			
		||||
        un, pw = self.hash_credentials(credentials)
 | 
			
		||||
        return {"password": pw, "username": un}
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def hash_credentials(credentials: Credentials) -> tuple[str, str]:
 | 
			
		||||
        """Hash the credentials."""
 | 
			
		||||
        un = credentials.username
 | 
			
		||||
        pw = _md5_hash(credentials.password.encode())
 | 
			
		||||
        return un, pw
 | 
			
		||||
 | 
			
		||||
    async def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
 | 
			
		||||
        """Handle response errors to request reauth etc."""
 | 
			
		||||
        error_code = SmartErrorCode(resp_dict.get("error_code"))  # type: ignore[arg-type]
 | 
			
		||||
        if error_code == SmartErrorCode.SUCCESS:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
 | 
			
		||||
 | 
			
		||||
        if error_code in SMART_RETRYABLE_ERRORS:
 | 
			
		||||
            raise _RetryableError(msg, error_code=error_code)
 | 
			
		||||
 | 
			
		||||
        if error_code in SMART_AUTHENTICATION_ERRORS:
 | 
			
		||||
            await self.reset()
 | 
			
		||||
            raise AuthenticationError(msg, error_code=error_code)
 | 
			
		||||
 | 
			
		||||
        raise DeviceError(msg, error_code=error_code)
 | 
			
		||||
 | 
			
		||||
    async def send_request(self, request: str) -> dict[str, Any]:
 | 
			
		||||
        """Send request."""
 | 
			
		||||
        url = self._app_url
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("Sending %s to %s", request, url)
 | 
			
		||||
 | 
			
		||||
        status_code, resp_dict = await self._http_client.post(
 | 
			
		||||
            url,
 | 
			
		||||
            json=request,
 | 
			
		||||
            headers=self.COMMON_HEADERS,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if status_code != 200:
 | 
			
		||||
            raise KasaException(
 | 
			
		||||
                f"{self._host} responded with an unexpected "
 | 
			
		||||
                + f"status code {status_code}"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("Response with %s: %r", status_code, resp_dict)
 | 
			
		||||
 | 
			
		||||
        await self._handle_response_error_code(resp_dict, "Error sending request")
 | 
			
		||||
 | 
			
		||||
        if TYPE_CHECKING:
 | 
			
		||||
            resp_dict = cast(dict[str, Any], resp_dict)
 | 
			
		||||
 | 
			
		||||
        return resp_dict
 | 
			
		||||
 | 
			
		||||
    async def perform_login(self) -> None:
 | 
			
		||||
        """Login to the device."""
 | 
			
		||||
        try:
 | 
			
		||||
            await self.try_login(self._login_params)
 | 
			
		||||
        except AuthenticationError as aex:
 | 
			
		||||
            try:
 | 
			
		||||
                if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
 | 
			
		||||
                    raise aex
 | 
			
		||||
 | 
			
		||||
                _LOGGER.debug("Login failed, going to try default credentials")
 | 
			
		||||
                if self._default_credentials is None:
 | 
			
		||||
                    self._default_credentials = get_default_credentials(
 | 
			
		||||
                        DEFAULT_CREDENTIALS["TAPO"]
 | 
			
		||||
                    )
 | 
			
		||||
                    await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_LOGIN_ERROR)
 | 
			
		||||
 | 
			
		||||
                await self.try_login(self._get_login_params(self._default_credentials))
 | 
			
		||||
                _LOGGER.debug(
 | 
			
		||||
                    "%s: logged in with default credentials",
 | 
			
		||||
                    self._host,
 | 
			
		||||
                )
 | 
			
		||||
            except AuthenticationError:
 | 
			
		||||
                raise
 | 
			
		||||
            except Exception as ex:
 | 
			
		||||
                raise KasaException(
 | 
			
		||||
                    "Unable to login and trying default "
 | 
			
		||||
                    + f"login raised another exception: {ex}",
 | 
			
		||||
                    ex,
 | 
			
		||||
                ) from ex
 | 
			
		||||
 | 
			
		||||
    async def try_login(self, login_params: dict[str, Any]) -> None:
 | 
			
		||||
        """Try to login with supplied login_params."""
 | 
			
		||||
        login_request = {
 | 
			
		||||
            "method": "login",
 | 
			
		||||
            "params": login_params,
 | 
			
		||||
        }
 | 
			
		||||
        request = json_dumps(login_request)
 | 
			
		||||
        _LOGGER.debug("Going to send login request")
 | 
			
		||||
 | 
			
		||||
        resp_dict = await self.send_request(request)
 | 
			
		||||
        await self._handle_response_error_code(resp_dict, "Error logging in")
 | 
			
		||||
 | 
			
		||||
        login_token = resp_dict["result"]["token"]
 | 
			
		||||
        self._app_url = self._app_url.with_query(f"token={login_token}")
 | 
			
		||||
        self._state = TransportState.ESTABLISHED
 | 
			
		||||
        self._session_expire_at = (
 | 
			
		||||
            time.time() + ONE_DAY_SECONDS - SESSION_EXPIRE_BUFFER_SECONDS
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _session_expired(self) -> bool:
 | 
			
		||||
        """Return true if session has expired."""
 | 
			
		||||
        return (
 | 
			
		||||
            self._session_expire_at is None
 | 
			
		||||
            or self._session_expire_at - time.time() <= 0
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def send(self, request: str) -> dict[str, Any]:
 | 
			
		||||
        """Send the request."""
 | 
			
		||||
        _LOGGER.info("Going to send %s", request)
 | 
			
		||||
        if self._state is not TransportState.ESTABLISHED or self._session_expired():
 | 
			
		||||
            _LOGGER.debug("Transport not established or session expired, logging in")
 | 
			
		||||
            await self.perform_login()
 | 
			
		||||
 | 
			
		||||
        return await self.send_request(request)
 | 
			
		||||
 | 
			
		||||
    async def close(self) -> None:
 | 
			
		||||
        """Close the http client and reset internal state."""
 | 
			
		||||
        await self.reset()
 | 
			
		||||
        await self._http_client.close()
 | 
			
		||||
 | 
			
		||||
    async def reset(self) -> None:
 | 
			
		||||
        """Reset internal login state."""
 | 
			
		||||
        self._state = TransportState.LOGIN_REQUIRED
 | 
			
		||||
        self._app_url = URL(f"https://{self._host}:{self._port}/app")
 | 
			
		||||
@@ -692,6 +692,8 @@ async def test_credentials(discovery_mock, mocker, runner):
 | 
			
		||||
            dr.device_type,
 | 
			
		||||
            "--encrypt-type",
 | 
			
		||||
            dr.mgt_encrypt_schm.encrypt_type,
 | 
			
		||||
            "--login-version",
 | 
			
		||||
            dr.mgt_encrypt_schm.lv or 1,
 | 
			
		||||
        ],
 | 
			
		||||
    )
 | 
			
		||||
    assert res.exit_code == 0
 | 
			
		||||
 
 | 
			
		||||
@@ -47,7 +47,10 @@ def _get_connection_type_device_class(discovery_info):
 | 
			
		||||
        dr = DiscoveryResult.from_dict(discovery_info["result"])
 | 
			
		||||
 | 
			
		||||
        connection_type = DeviceConnectionParameters.from_values(
 | 
			
		||||
            dr.device_type, dr.mgt_encrypt_schm.encrypt_type
 | 
			
		||||
            dr.device_type,
 | 
			
		||||
            dr.mgt_encrypt_schm.encrypt_type,
 | 
			
		||||
            dr.mgt_encrypt_schm.lv,
 | 
			
		||||
            dr.mgt_encrypt_schm.is_support_https,
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        connection_type = DeviceConnectionParameters.from_values(
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										374
									
								
								tests/transports/test_ssltransport.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										374
									
								
								tests/transports/test_ssltransport.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,374 @@
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
from base64 import b64encode
 | 
			
		||||
from contextlib import nullcontext as does_not_raise
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
import aiohttp
 | 
			
		||||
import pytest
 | 
			
		||||
from yarl import URL
 | 
			
		||||
 | 
			
		||||
from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
 | 
			
		||||
from kasa.deviceconfig import DeviceConfig
 | 
			
		||||
from kasa.exceptions import (
 | 
			
		||||
    AuthenticationError,
 | 
			
		||||
    DeviceError,
 | 
			
		||||
    KasaException,
 | 
			
		||||
    SmartErrorCode,
 | 
			
		||||
    _RetryableError,
 | 
			
		||||
)
 | 
			
		||||
from kasa.httpclient import HttpClient
 | 
			
		||||
from kasa.json import dumps as json_dumps
 | 
			
		||||
from kasa.json import loads as json_loads
 | 
			
		||||
from kasa.transports import SslTransport
 | 
			
		||||
from kasa.transports.ssltransport import TransportState, _md5_hash
 | 
			
		||||
 | 
			
		||||
# Transport tests are not designed for real devices
 | 
			
		||||
pytestmark = [pytest.mark.requires_dummy]
 | 
			
		||||
 | 
			
		||||
MOCK_PWD = "correct_pwd"  # noqa: S105
 | 
			
		||||
MOCK_USER = "mock@example.com"
 | 
			
		||||
MOCK_BAD_USER_OR_PWD = "foobar"  # noqa: S105
 | 
			
		||||
MOCK_TOKEN = "abcdefghijklmnopqrstuvwxyz1234)("  # noqa: S105
 | 
			
		||||
 | 
			
		||||
DEFAULT_CREDS = get_default_credentials(DEFAULT_CREDENTIALS["TAPO"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_LOGGER = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    (
 | 
			
		||||
        "status_code",
 | 
			
		||||
        "error_code",
 | 
			
		||||
        "username",
 | 
			
		||||
        "password",
 | 
			
		||||
        "expectation",
 | 
			
		||||
    ),
 | 
			
		||||
    [
 | 
			
		||||
        pytest.param(
 | 
			
		||||
            200,
 | 
			
		||||
            SmartErrorCode.SUCCESS,
 | 
			
		||||
            MOCK_USER,
 | 
			
		||||
            MOCK_PWD,
 | 
			
		||||
            does_not_raise(),
 | 
			
		||||
            id="success",
 | 
			
		||||
        ),
 | 
			
		||||
        pytest.param(
 | 
			
		||||
            200,
 | 
			
		||||
            SmartErrorCode.UNSPECIFIC_ERROR,
 | 
			
		||||
            MOCK_USER,
 | 
			
		||||
            MOCK_PWD,
 | 
			
		||||
            pytest.raises(_RetryableError),
 | 
			
		||||
            id="test retry",
 | 
			
		||||
        ),
 | 
			
		||||
        pytest.param(
 | 
			
		||||
            200,
 | 
			
		||||
            SmartErrorCode.DEVICE_BLOCKED,
 | 
			
		||||
            MOCK_USER,
 | 
			
		||||
            MOCK_PWD,
 | 
			
		||||
            pytest.raises(DeviceError),
 | 
			
		||||
            id="test regular error",
 | 
			
		||||
        ),
 | 
			
		||||
        pytest.param(
 | 
			
		||||
            400,
 | 
			
		||||
            SmartErrorCode.INTERNAL_UNKNOWN_ERROR,
 | 
			
		||||
            MOCK_USER,
 | 
			
		||||
            MOCK_PWD,
 | 
			
		||||
            pytest.raises(KasaException),
 | 
			
		||||
            id="400 error",
 | 
			
		||||
        ),
 | 
			
		||||
        pytest.param(
 | 
			
		||||
            200,
 | 
			
		||||
            SmartErrorCode.LOGIN_ERROR,
 | 
			
		||||
            MOCK_BAD_USER_OR_PWD,
 | 
			
		||||
            MOCK_PWD,
 | 
			
		||||
            pytest.raises(AuthenticationError),
 | 
			
		||||
            id="bad-username",
 | 
			
		||||
        ),
 | 
			
		||||
        pytest.param(
 | 
			
		||||
            200,
 | 
			
		||||
            [SmartErrorCode.LOGIN_ERROR, SmartErrorCode.SUCCESS],
 | 
			
		||||
            MOCK_BAD_USER_OR_PWD,
 | 
			
		||||
            "",
 | 
			
		||||
            does_not_raise(),
 | 
			
		||||
            id="working-fallback",
 | 
			
		||||
        ),
 | 
			
		||||
        pytest.param(
 | 
			
		||||
            200,
 | 
			
		||||
            [SmartErrorCode.LOGIN_ERROR, SmartErrorCode.LOGIN_ERROR],
 | 
			
		||||
            MOCK_BAD_USER_OR_PWD,
 | 
			
		||||
            "",
 | 
			
		||||
            pytest.raises(AuthenticationError),
 | 
			
		||||
            id="fallback-fail",
 | 
			
		||||
        ),
 | 
			
		||||
        pytest.param(
 | 
			
		||||
            200,
 | 
			
		||||
            SmartErrorCode.LOGIN_ERROR,
 | 
			
		||||
            MOCK_USER,
 | 
			
		||||
            MOCK_BAD_USER_OR_PWD,
 | 
			
		||||
            pytest.raises(AuthenticationError),
 | 
			
		||||
            id="bad-password",
 | 
			
		||||
        ),
 | 
			
		||||
        pytest.param(
 | 
			
		||||
            200,
 | 
			
		||||
            SmartErrorCode.TRANSPORT_UNKNOWN_CREDENTIALS_ERROR,
 | 
			
		||||
            MOCK_USER,
 | 
			
		||||
            MOCK_PWD,
 | 
			
		||||
            pytest.raises(AuthenticationError),
 | 
			
		||||
            id="auth-error != login_error",
 | 
			
		||||
        ),
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
async def test_login(
 | 
			
		||||
    mocker,
 | 
			
		||||
    status_code,
 | 
			
		||||
    error_code,
 | 
			
		||||
    username,
 | 
			
		||||
    password,
 | 
			
		||||
    expectation,
 | 
			
		||||
):
 | 
			
		||||
    host = "127.0.0.1"
 | 
			
		||||
    mock_ssl_aes_device = MockSslDevice(
 | 
			
		||||
        host,
 | 
			
		||||
        status_code=status_code,
 | 
			
		||||
        send_error_code=error_code,
 | 
			
		||||
    )
 | 
			
		||||
    mocker.patch.object(
 | 
			
		||||
        aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    transport = SslTransport(
 | 
			
		||||
        config=DeviceConfig(host, credentials=Credentials(username, password))
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert transport._state is TransportState.LOGIN_REQUIRED
 | 
			
		||||
    with expectation:
 | 
			
		||||
        await transport.perform_login()
 | 
			
		||||
        assert transport._state is TransportState.ESTABLISHED
 | 
			
		||||
 | 
			
		||||
    await transport.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_credentials_hash(mocker):
 | 
			
		||||
    host = "127.0.0.1"
 | 
			
		||||
    mock_ssl_aes_device = MockSslDevice(host)
 | 
			
		||||
    mocker.patch.object(
 | 
			
		||||
        aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
 | 
			
		||||
    )
 | 
			
		||||
    creds = Credentials(MOCK_USER, MOCK_PWD)
 | 
			
		||||
 | 
			
		||||
    data = {"password": _md5_hash(MOCK_PWD.encode()), "username": MOCK_USER}
 | 
			
		||||
 | 
			
		||||
    creds_hash = b64encode(json_dumps(data).encode()).decode()
 | 
			
		||||
 | 
			
		||||
    # Test with credentials input
 | 
			
		||||
    transport = SslTransport(config=DeviceConfig(host, credentials=creds))
 | 
			
		||||
    assert transport.credentials_hash == creds_hash
 | 
			
		||||
 | 
			
		||||
    # Test with credentials_hash input
 | 
			
		||||
    transport = SslTransport(config=DeviceConfig(host, credentials_hash=creds_hash))
 | 
			
		||||
    assert transport.credentials_hash == creds_hash
 | 
			
		||||
 | 
			
		||||
    await transport.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_send(mocker):
 | 
			
		||||
    host = "127.0.0.1"
 | 
			
		||||
    mock_ssl_aes_device = MockSslDevice(host, send_error_code=SmartErrorCode.SUCCESS)
 | 
			
		||||
    mocker.patch.object(
 | 
			
		||||
        aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    transport = SslTransport(
 | 
			
		||||
        config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
 | 
			
		||||
    )
 | 
			
		||||
    try_login_spy = mocker.spy(transport, "try_login")
 | 
			
		||||
    request = {
 | 
			
		||||
        "method": "get_device_info",
 | 
			
		||||
        "params": None,
 | 
			
		||||
    }
 | 
			
		||||
    assert transport._state is TransportState.LOGIN_REQUIRED
 | 
			
		||||
 | 
			
		||||
    res = await transport.send(json_dumps(request))
 | 
			
		||||
    assert "result" in res
 | 
			
		||||
    try_login_spy.assert_called_once()
 | 
			
		||||
    assert transport._state is TransportState.ESTABLISHED
 | 
			
		||||
 | 
			
		||||
    # Second request does not
 | 
			
		||||
    res = await transport.send(json_dumps(request))
 | 
			
		||||
    try_login_spy.assert_called_once()
 | 
			
		||||
 | 
			
		||||
    await transport.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_no_credentials(mocker):
 | 
			
		||||
    """Test transport without credentials."""
 | 
			
		||||
    host = "127.0.0.1"
 | 
			
		||||
    mock_ssl_aes_device = MockSslDevice(
 | 
			
		||||
        host, send_error_code=SmartErrorCode.LOGIN_ERROR
 | 
			
		||||
    )
 | 
			
		||||
    mocker.patch.object(
 | 
			
		||||
        aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    transport = SslTransport(config=DeviceConfig(host))
 | 
			
		||||
    try_login_spy = mocker.spy(transport, "try_login")
 | 
			
		||||
 | 
			
		||||
    with pytest.raises(AuthenticationError):
 | 
			
		||||
        await transport.send('{"method": "dummy"}')
 | 
			
		||||
 | 
			
		||||
    # We get called twice
 | 
			
		||||
    assert try_login_spy.call_count == 2
 | 
			
		||||
 | 
			
		||||
    await transport.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_reset(mocker):
 | 
			
		||||
    """Test that transport state adjusts correctly for reset."""
 | 
			
		||||
    host = "127.0.0.1"
 | 
			
		||||
    mock_ssl_aes_device = MockSslDevice(host, send_error_code=SmartErrorCode.SUCCESS)
 | 
			
		||||
    mocker.patch.object(
 | 
			
		||||
        aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    transport = SslTransport(
 | 
			
		||||
        config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert transport._state is TransportState.LOGIN_REQUIRED
 | 
			
		||||
    assert str(transport._app_url) == "https://127.0.0.1:4433/app"
 | 
			
		||||
 | 
			
		||||
    await transport.perform_login()
 | 
			
		||||
    assert transport._state is TransportState.ESTABLISHED
 | 
			
		||||
    assert str(transport._app_url).startswith("https://127.0.0.1:4433/app?token=")
 | 
			
		||||
 | 
			
		||||
    await transport.close()
 | 
			
		||||
    assert transport._state is TransportState.LOGIN_REQUIRED
 | 
			
		||||
    assert str(transport._app_url) == "https://127.0.0.1:4433/app"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_port_override():
 | 
			
		||||
    """Test that port override sets the app_url."""
 | 
			
		||||
    host = "127.0.0.1"
 | 
			
		||||
    port_override = 12345
 | 
			
		||||
    config = DeviceConfig(
 | 
			
		||||
        host, credentials=Credentials("foo", "bar"), port_override=port_override
 | 
			
		||||
    )
 | 
			
		||||
    transport = SslTransport(config=config)
 | 
			
		||||
 | 
			
		||||
    assert str(transport._app_url) == f"https://127.0.0.1:{port_override}/app"
 | 
			
		||||
 | 
			
		||||
    await transport.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockSslDevice:
 | 
			
		||||
    """Based on MockAesSslDevice."""
 | 
			
		||||
 | 
			
		||||
    class _mock_response:
 | 
			
		||||
        def __init__(self, status, request: dict):
 | 
			
		||||
            self.status = status
 | 
			
		||||
            self._json = request
 | 
			
		||||
 | 
			
		||||
        async def __aenter__(self):
 | 
			
		||||
            return self
 | 
			
		||||
 | 
			
		||||
        async def __aexit__(self, exc_t, exc_v, exc_tb):
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        async def read(self):
 | 
			
		||||
            if isinstance(self._json, dict):
 | 
			
		||||
                return json_dumps(self._json).encode()
 | 
			
		||||
            return self._json
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        host,
 | 
			
		||||
        *,
 | 
			
		||||
        status_code=200,
 | 
			
		||||
        send_error_code=SmartErrorCode.INTERNAL_UNKNOWN_ERROR,
 | 
			
		||||
    ):
 | 
			
		||||
        self.host = host
 | 
			
		||||
        self.http_client = HttpClient(DeviceConfig(self.host))
 | 
			
		||||
 | 
			
		||||
        self._state = TransportState.LOGIN_REQUIRED
 | 
			
		||||
 | 
			
		||||
        # test behaviour attributes
 | 
			
		||||
        self.status_code = status_code
 | 
			
		||||
        self.send_error_code = send_error_code
 | 
			
		||||
 | 
			
		||||
    async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
 | 
			
		||||
        if data:
 | 
			
		||||
            json = json_loads(data)
 | 
			
		||||
        _LOGGER.debug("Request %s: %s", url, json)
 | 
			
		||||
        res = self._post(url, json)
 | 
			
		||||
        _LOGGER.debug("Response %s, data: %s", res, await res.read())
 | 
			
		||||
        return res
 | 
			
		||||
 | 
			
		||||
    def _post(self, url: URL, json: dict[str, Any]):
 | 
			
		||||
        method = json["method"]
 | 
			
		||||
 | 
			
		||||
        if method == "login":
 | 
			
		||||
            if self._state is TransportState.LOGIN_REQUIRED:
 | 
			
		||||
                assert json.get("token") is None
 | 
			
		||||
                assert url == URL(f"https://{self.host}:4433/app")
 | 
			
		||||
                return self._return_login_response(url, json)
 | 
			
		||||
            else:
 | 
			
		||||
                _LOGGER.warning("Received login although already logged in")
 | 
			
		||||
                pytest.fail("non-handled re-login logic")
 | 
			
		||||
 | 
			
		||||
        assert url == URL(f"https://{self.host}:4433/app?token={MOCK_TOKEN}")
 | 
			
		||||
        return self._return_send_response(url, json)
 | 
			
		||||
 | 
			
		||||
    def _return_login_response(self, url: URL, request: dict[str, Any]):
 | 
			
		||||
        request_username = request["params"].get("username")
 | 
			
		||||
        request_password = request["params"].get("password")
 | 
			
		||||
 | 
			
		||||
        # Handle multiple error codes
 | 
			
		||||
        if isinstance(self.send_error_code, list):
 | 
			
		||||
            error_code = self.send_error_code.pop(0)
 | 
			
		||||
        else:
 | 
			
		||||
            error_code = self.send_error_code
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("Using error code %s", error_code)
 | 
			
		||||
 | 
			
		||||
        def _return_login_error():
 | 
			
		||||
            resp = {
 | 
			
		||||
                "error_code": error_code.value,
 | 
			
		||||
                "result": {"unknown": "payload"},
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            _LOGGER.debug("Returning login error with status %s", self.status_code)
 | 
			
		||||
            return self._mock_response(self.status_code, resp)
 | 
			
		||||
 | 
			
		||||
        if error_code is not SmartErrorCode.SUCCESS:
 | 
			
		||||
            # Bad username
 | 
			
		||||
            if request_username == MOCK_BAD_USER_OR_PWD:
 | 
			
		||||
                return _return_login_error()
 | 
			
		||||
 | 
			
		||||
            # Bad password
 | 
			
		||||
            if request_password == _md5_hash(MOCK_BAD_USER_OR_PWD.encode()):
 | 
			
		||||
                return _return_login_error()
 | 
			
		||||
 | 
			
		||||
            # Empty password
 | 
			
		||||
            if request_password == _md5_hash(b""):
 | 
			
		||||
                return _return_login_error()
 | 
			
		||||
 | 
			
		||||
        self._state = TransportState.ESTABLISHED
 | 
			
		||||
        resp = {
 | 
			
		||||
            "error_code": error_code.value,
 | 
			
		||||
            "result": {
 | 
			
		||||
                "token": MOCK_TOKEN,
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
        _LOGGER.debug("Returning login success with status %s", self.status_code)
 | 
			
		||||
        return self._mock_response(self.status_code, resp)
 | 
			
		||||
 | 
			
		||||
    def _return_send_response(self, url: URL, json: dict[str, Any]):
 | 
			
		||||
        method = json["method"]
 | 
			
		||||
        result = {
 | 
			
		||||
            "result": {method: {"dummy": "response"}},
 | 
			
		||||
            "error_code": self.send_error_code.value,
 | 
			
		||||
        }
 | 
			
		||||
        return self._mock_response(self.status_code, result)
 | 
			
		||||
		Reference in New Issue
	
	Block a user