Add ssltransport for robovacs (#943)

This PR implements a clear-text, token-based transport protocol seen on
RV30 Plus (#937).

- Client sends `{"username": "email@example.com", "password":
md5(password)}` and gets back a token in the response
- Rest of the communications are done with POST at `/app?token=<token>`

---------

Co-authored-by: Steven B. <51370195+sdb9696@users.noreply.github.com>
This commit is contained in:
Teemu R. 2024-12-01 18:06:48 +01:00 committed by GitHub
parent 9a52056522
commit 9966c6094a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 656 additions and 16 deletions

View File

@ -427,25 +427,25 @@ COMPONENT_REQUESTS = {
"overheat_protection": [], "overheat_protection": [],
# Vacuum components # Vacuum components
"clean": [ "clean": [
SmartRequest.get_raw_request("get_clean_records"), SmartRequest.get_raw_request("getCleanRecords"),
SmartRequest.get_raw_request("get_vac_state"), SmartRequest.get_raw_request("getVacStatus"),
], ],
"battery": [SmartRequest.get_raw_request("get_battery_info")], "battery": [SmartRequest.get_raw_request("getBatteryInfo")],
"consumables": [SmartRequest.get_raw_request("get_consumables_info")], "consumables": [SmartRequest.get_raw_request("getConsumablesInfo")],
"direction_control": [], "direction_control": [],
"button_and_led": [], "button_and_led": [],
"speaker": [ "speaker": [
SmartRequest.get_raw_request("get_support_voice_language"), SmartRequest.get_raw_request("getSupportVoiceLanguage"),
SmartRequest.get_raw_request("get_current_voice_language"), SmartRequest.get_raw_request("getCurrentVoiceLanguage"),
], ],
"map": [ "map": [
SmartRequest.get_raw_request("get_map_info"), SmartRequest.get_raw_request("getMapInfo"),
SmartRequest.get_raw_request("get_map_data"), SmartRequest.get_raw_request("getMapData"),
], ],
"auto_change_map": [SmartRequest.get_raw_request("get_auto_change_map")], "auto_change_map": [SmartRequest.get_raw_request("getAutoChangeMap")],
"dust_bucket": [SmartRequest.get_raw_request("get_auto_dust_collection")], "dust_bucket": [SmartRequest.get_raw_request("getAutoDustCollection")],
"mop": [SmartRequest.get_raw_request("get_mop_state")], "mop": [SmartRequest.get_raw_request("getMopState")],
"do_not_disturb": [SmartRequest.get_raw_request("get_do_not_disturb")], "do_not_disturb": [SmartRequest.get_raw_request("getDoNotDisturb")],
"charge_pose_clean": [], "charge_pose_clean": [],
"continue_breakpoint_sweep": [], "continue_breakpoint_sweep": [],
"goto_point": [], "goto_point": [],

View File

@ -308,6 +308,7 @@ async def cli(
if type == "camera": if type == "camera":
encrypt_type = "AES" encrypt_type = "AES"
https = True https = True
login_version = 2
device_family = "SMART.IPCAMERA" device_family = "SMART.IPCAMERA"
from kasa.device import Device from kasa.device import Device

View File

@ -32,6 +32,7 @@ from .transports import (
BaseTransport, BaseTransport,
KlapTransport, KlapTransport,
KlapTransportV2, KlapTransportV2,
SslTransport,
XorTransport, XorTransport,
) )
from .transports.sslaestransport import SslAesTransport from .transports.sslaestransport import SslAesTransport
@ -155,6 +156,7 @@ def get_device_class_from_family(
"SMART.KASAHUB": SmartDevice, "SMART.KASAHUB": SmartDevice,
"SMART.KASASWITCH": SmartDevice, "SMART.KASASWITCH": SmartDevice,
"SMART.IPCAMERA.HTTPS": SmartCamDevice, "SMART.IPCAMERA.HTTPS": SmartCamDevice,
"SMART.TAPOROBOVAC": SmartDevice,
"IOT.SMARTPLUGSWITCH": IotPlug, "IOT.SMARTPLUGSWITCH": IotPlug,
"IOT.SMARTBULB": IotBulb, "IOT.SMARTBULB": IotBulb,
} }
@ -176,20 +178,30 @@ def get_protocol(
"""Return the protocol from the connection name.""" """Return the protocol from the connection name."""
protocol_name = config.connection_type.device_family.value.split(".")[0] protocol_name = config.connection_type.device_family.value.split(".")[0]
ctype = config.connection_type ctype = config.connection_type
protocol_transport_key = ( protocol_transport_key = (
protocol_name protocol_name
+ "." + "."
+ ctype.encryption_type.value + ctype.encryption_type.value
+ (".HTTPS" if ctype.https else "") + (".HTTPS" if ctype.https else "")
+ (
f".{ctype.login_version}"
if ctype.login_version and ctype.login_version > 1
else ""
) )
)
_LOGGER.debug("Finding transport for %s", protocol_transport_key)
supported_device_protocols: dict[ supported_device_protocols: dict[
str, tuple[type[BaseProtocol], type[BaseTransport]] str, tuple[type[BaseProtocol], type[BaseTransport]]
] = { ] = {
"IOT.XOR": (IotProtocol, XorTransport), "IOT.XOR": (IotProtocol, XorTransport),
"IOT.KLAP": (IotProtocol, KlapTransport), "IOT.KLAP": (IotProtocol, KlapTransport),
"SMART.AES": (SmartProtocol, AesTransport), "SMART.AES": (SmartProtocol, AesTransport),
"SMART.KLAP": (SmartProtocol, KlapTransportV2), "SMART.AES.2": (SmartProtocol, AesTransport),
"SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport), "SMART.KLAP.2": (SmartProtocol, KlapTransportV2),
"SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport),
"SMART.AES.HTTPS": (SmartProtocol, SslTransport),
} }
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)): if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
return None return None

View File

@ -21,6 +21,7 @@ class DeviceType(Enum):
Hub = "hub" Hub = "hub"
Fan = "fan" Fan = "fan"
Thermostat = "thermostat" Thermostat = "thermostat"
Vacuum = "vacuum"
Unknown = "unknown" Unknown = "unknown"
@staticmethod @staticmethod

View File

@ -77,6 +77,7 @@ class DeviceFamily(Enum):
SmartTapoHub = "SMART.TAPOHUB" SmartTapoHub = "SMART.TAPOHUB"
SmartKasaHub = "SMART.KASAHUB" SmartKasaHub = "SMART.KASAHUB"
SmartIpCamera = "SMART.IPCAMERA" SmartIpCamera = "SMART.IPCAMERA"
SmartTapoRobovac = "SMART.TAPOROBOVAC"
class _DeviceConfigBaseMixin(DataClassJSONMixin): class _DeviceConfigBaseMixin(DataClassJSONMixin):

View File

@ -598,10 +598,12 @@ class Discover:
for encrypt in Device.EncryptionType for encrypt in Device.EncryptionType
for device_family in main_device_families for device_family in main_device_families
for https in (True, False) for https in (True, False)
for login_version in (None, 2)
if ( if (
conn_params := DeviceConnectionParameters( conn_params := DeviceConnectionParameters(
device_family=device_family, device_family=device_family,
encryption_type=encrypt, encryption_type=encrypt,
login_version=login_version,
https=https, https=https,
) )
) )
@ -768,6 +770,13 @@ class Discover:
): ):
encrypt_type = encrypt_info.sym_schm encrypt_type = encrypt_info.sym_schm
if (
not (login_version := encrypt_schm.lv)
and (et := discovery_result.encrypt_type)
and et == ["3"]
):
login_version = 2
if not encrypt_type: if not encrypt_type:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} " f"Unsupported device {config.host} of type {type_} "
@ -778,7 +787,7 @@ class Discover:
config.connection_type = DeviceConnectionParameters.from_values( config.connection_type = DeviceConnectionParameters.from_values(
type_, type_,
encrypt_type, encrypt_type,
encrypt_schm.lv, login_version,
encrypt_schm.is_support_https, encrypt_schm.is_support_https,
) )
except KasaException as ex: except KasaException as ex:

View File

@ -802,6 +802,8 @@ class SmartDevice(Device):
return DeviceType.Sensor return DeviceType.Sensor
if "ENERGY" in device_type: if "ENERGY" in device_type:
return DeviceType.Thermostat return DeviceType.Thermostat
if "ROBOVAC" in device_type:
return DeviceType.Vacuum
_LOGGER.warning("Unknown device type, falling back to plug") _LOGGER.warning("Unknown device type, falling back to plug")
return DeviceType.Plug return DeviceType.Plug

View File

@ -3,11 +3,13 @@
from .aestransport import AesEncyptionSession, AesTransport from .aestransport import AesEncyptionSession, AesTransport
from .basetransport import BaseTransport from .basetransport import BaseTransport
from .klaptransport import KlapTransport, KlapTransportV2 from .klaptransport import KlapTransport, KlapTransportV2
from .ssltransport import SslTransport
from .xortransport import XorEncryption, XorTransport from .xortransport import XorEncryption, XorTransport
__all__ = [ __all__ = [
"AesTransport", "AesTransport",
"AesEncyptionSession", "AesEncyptionSession",
"SslTransport",
"BaseTransport", "BaseTransport",
"KlapTransport", "KlapTransport",
"KlapTransportV2", "KlapTransportV2",

View File

@ -0,0 +1,233 @@
"""Implementation of the clear-text passthrough ssl transport.
This transport does not encrypt the passthrough payloads at all, but requires a login.
This has been seen on some devices (like robovacs).
"""
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
import time
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, cast
from yarl import URL
from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
AuthenticationError,
DeviceError,
KasaException,
SmartErrorCode,
_RetryableError,
)
from kasa.httpclient import HttpClient
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.transports import BaseTransport
_LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
def _md5_hash(payload: bytes) -> str:
return hashlib.md5(payload).hexdigest().upper() # noqa: S324
class TransportState(Enum):
"""Enum for transport state."""
LOGIN_REQUIRED = auto() # Login needed
ESTABLISHED = auto() # Ready to send requests
class SslTransport(BaseTransport):
"""Implementation of the cleartext transport protocol.
This transport uses HTTPS without any further payload encryption.
"""
DEFAULT_PORT: int = 4433
COMMON_HEADERS = {
"Content-Type": "application/json",
}
BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1
def __init__(
self,
*,
config: DeviceConfig,
) -> None:
super().__init__(config=config)
if (
not self._credentials or self._credentials.username is None
) and not self._credentials_hash:
self._credentials = Credentials()
if self._credentials:
self._login_params = self._get_login_params(self._credentials)
else:
self._login_params = json_loads(
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_credentials: Credentials | None = None
self._http_client: HttpClient = HttpClient(config)
self._state = TransportState.LOGIN_REQUIRED
self._session_expire_at: float | None = None
self._app_url = URL(f"https://{self._host}:{self._port}/app")
_LOGGER.debug("Created ssltransport for %s", self._host)
@property
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
def _get_login_params(self, credentials: Credentials) -> dict[str, str]:
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(credentials)
return {"password": pw, "username": un}
@staticmethod
def hash_credentials(credentials: Credentials) -> tuple[str, str]:
"""Hash the credentials."""
un = credentials.username
pw = _md5_hash(credentials.password.encode())
return un, pw
async def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
"""Handle response errors to request reauth etc."""
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
return
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_RETRYABLE_ERRORS:
raise _RetryableError(msg, error_code=error_code)
if error_code in SMART_AUTHENTICATION_ERRORS:
await self.reset()
raise AuthenticationError(msg, error_code=error_code)
raise DeviceError(msg, error_code=error_code)
async def send_request(self, request: str) -> dict[str, Any]:
"""Send request."""
url = self._app_url
_LOGGER.debug("Sending %s to %s", request, url)
status_code, resp_dict = await self._http_client.post(
url,
json=request,
headers=self.COMMON_HEADERS,
)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code}"
)
_LOGGER.debug("Response with %s: %r", status_code, resp_dict)
await self._handle_response_error_code(resp_dict, "Error sending request")
if TYPE_CHECKING:
resp_dict = cast(dict[str, Any], resp_dict)
return resp_dict
async def perform_login(self) -> None:
"""Login to the device."""
try:
await self.try_login(self._login_params)
except AuthenticationError as aex:
try:
if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
raise aex
_LOGGER.debug("Login failed, going to try default credentials")
if self._default_credentials is None:
self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"]
)
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_LOGIN_ERROR)
await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug(
"%s: logged in with default credentials",
self._host,
)
except AuthenticationError:
raise
except Exception as ex:
raise KasaException(
"Unable to login and trying default "
+ f"login raised another exception: {ex}",
ex,
) from ex
async def try_login(self, login_params: dict[str, Any]) -> None:
"""Try to login with supplied login_params."""
login_request = {
"method": "login",
"params": login_params,
}
request = json_dumps(login_request)
_LOGGER.debug("Going to send login request")
resp_dict = await self.send_request(request)
await self._handle_response_error_code(resp_dict, "Error logging in")
login_token = resp_dict["result"]["token"]
self._app_url = self._app_url.with_query(f"token={login_token}")
self._state = TransportState.ESTABLISHED
self._session_expire_at = (
time.time() + ONE_DAY_SECONDS - SESSION_EXPIRE_BUFFER_SECONDS
)
def _session_expired(self) -> bool:
"""Return true if session has expired."""
return (
self._session_expire_at is None
or self._session_expire_at - time.time() <= 0
)
async def send(self, request: str) -> dict[str, Any]:
"""Send the request."""
_LOGGER.info("Going to send %s", request)
if self._state is not TransportState.ESTABLISHED or self._session_expired():
_LOGGER.debug("Transport not established or session expired, logging in")
await self.perform_login()
return await self.send_request(request)
async def close(self) -> None:
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()
async def reset(self) -> None:
"""Reset internal login state."""
self._state = TransportState.LOGIN_REQUIRED
self._app_url = URL(f"https://{self._host}:{self._port}/app")

View File

@ -692,6 +692,8 @@ async def test_credentials(discovery_mock, mocker, runner):
dr.device_type, dr.device_type,
"--encrypt-type", "--encrypt-type",
dr.mgt_encrypt_schm.encrypt_type, dr.mgt_encrypt_schm.encrypt_type,
"--login-version",
dr.mgt_encrypt_schm.lv or 1,
], ],
) )
assert res.exit_code == 0 assert res.exit_code == 0

View File

@ -47,7 +47,10 @@ def _get_connection_type_device_class(discovery_info):
dr = DiscoveryResult.from_dict(discovery_info["result"]) dr = DiscoveryResult.from_dict(discovery_info["result"])
connection_type = DeviceConnectionParameters.from_values( connection_type = DeviceConnectionParameters.from_values(
dr.device_type, dr.mgt_encrypt_schm.encrypt_type dr.device_type,
dr.mgt_encrypt_schm.encrypt_type,
dr.mgt_encrypt_schm.lv,
dr.mgt_encrypt_schm.is_support_https,
) )
else: else:
connection_type = DeviceConnectionParameters.from_values( connection_type = DeviceConnectionParameters.from_values(

View File

@ -0,0 +1,374 @@
from __future__ import annotations
import logging
from base64 import b64encode
from contextlib import nullcontext as does_not_raise
from typing import Any
import aiohttp
import pytest
from yarl import URL
from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import (
AuthenticationError,
DeviceError,
KasaException,
SmartErrorCode,
_RetryableError,
)
from kasa.httpclient import HttpClient
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.transports import SslTransport
from kasa.transports.ssltransport import TransportState, _md5_hash
# Transport tests are not designed for real devices
pytestmark = [pytest.mark.requires_dummy]
MOCK_PWD = "correct_pwd" # noqa: S105
MOCK_USER = "mock@example.com"
MOCK_BAD_USER_OR_PWD = "foobar" # noqa: S105
MOCK_TOKEN = "abcdefghijklmnopqrstuvwxyz1234)(" # noqa: S105
DEFAULT_CREDS = get_default_credentials(DEFAULT_CREDENTIALS["TAPO"])
_LOGGER = logging.getLogger(__name__)
@pytest.mark.parametrize(
(
"status_code",
"error_code",
"username",
"password",
"expectation",
),
[
pytest.param(
200,
SmartErrorCode.SUCCESS,
MOCK_USER,
MOCK_PWD,
does_not_raise(),
id="success",
),
pytest.param(
200,
SmartErrorCode.UNSPECIFIC_ERROR,
MOCK_USER,
MOCK_PWD,
pytest.raises(_RetryableError),
id="test retry",
),
pytest.param(
200,
SmartErrorCode.DEVICE_BLOCKED,
MOCK_USER,
MOCK_PWD,
pytest.raises(DeviceError),
id="test regular error",
),
pytest.param(
400,
SmartErrorCode.INTERNAL_UNKNOWN_ERROR,
MOCK_USER,
MOCK_PWD,
pytest.raises(KasaException),
id="400 error",
),
pytest.param(
200,
SmartErrorCode.LOGIN_ERROR,
MOCK_BAD_USER_OR_PWD,
MOCK_PWD,
pytest.raises(AuthenticationError),
id="bad-username",
),
pytest.param(
200,
[SmartErrorCode.LOGIN_ERROR, SmartErrorCode.SUCCESS],
MOCK_BAD_USER_OR_PWD,
"",
does_not_raise(),
id="working-fallback",
),
pytest.param(
200,
[SmartErrorCode.LOGIN_ERROR, SmartErrorCode.LOGIN_ERROR],
MOCK_BAD_USER_OR_PWD,
"",
pytest.raises(AuthenticationError),
id="fallback-fail",
),
pytest.param(
200,
SmartErrorCode.LOGIN_ERROR,
MOCK_USER,
MOCK_BAD_USER_OR_PWD,
pytest.raises(AuthenticationError),
id="bad-password",
),
pytest.param(
200,
SmartErrorCode.TRANSPORT_UNKNOWN_CREDENTIALS_ERROR,
MOCK_USER,
MOCK_PWD,
pytest.raises(AuthenticationError),
id="auth-error != login_error",
),
],
)
async def test_login(
mocker,
status_code,
error_code,
username,
password,
expectation,
):
host = "127.0.0.1"
mock_ssl_aes_device = MockSslDevice(
host,
status_code=status_code,
send_error_code=error_code,
)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)
transport = SslTransport(
config=DeviceConfig(host, credentials=Credentials(username, password))
)
assert transport._state is TransportState.LOGIN_REQUIRED
with expectation:
await transport.perform_login()
assert transport._state is TransportState.ESTABLISHED
await transport.close()
async def test_credentials_hash(mocker):
host = "127.0.0.1"
mock_ssl_aes_device = MockSslDevice(host)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)
creds = Credentials(MOCK_USER, MOCK_PWD)
data = {"password": _md5_hash(MOCK_PWD.encode()), "username": MOCK_USER}
creds_hash = b64encode(json_dumps(data).encode()).decode()
# Test with credentials input
transport = SslTransport(config=DeviceConfig(host, credentials=creds))
assert transport.credentials_hash == creds_hash
# Test with credentials_hash input
transport = SslTransport(config=DeviceConfig(host, credentials_hash=creds_hash))
assert transport.credentials_hash == creds_hash
await transport.close()
async def test_send(mocker):
host = "127.0.0.1"
mock_ssl_aes_device = MockSslDevice(host, send_error_code=SmartErrorCode.SUCCESS)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)
transport = SslTransport(
config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
)
try_login_spy = mocker.spy(transport, "try_login")
request = {
"method": "get_device_info",
"params": None,
}
assert transport._state is TransportState.LOGIN_REQUIRED
res = await transport.send(json_dumps(request))
assert "result" in res
try_login_spy.assert_called_once()
assert transport._state is TransportState.ESTABLISHED
# Second request does not
res = await transport.send(json_dumps(request))
try_login_spy.assert_called_once()
await transport.close()
async def test_no_credentials(mocker):
"""Test transport without credentials."""
host = "127.0.0.1"
mock_ssl_aes_device = MockSslDevice(
host, send_error_code=SmartErrorCode.LOGIN_ERROR
)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)
transport = SslTransport(config=DeviceConfig(host))
try_login_spy = mocker.spy(transport, "try_login")
with pytest.raises(AuthenticationError):
await transport.send('{"method": "dummy"}')
# We get called twice
assert try_login_spy.call_count == 2
await transport.close()
async def test_reset(mocker):
"""Test that transport state adjusts correctly for reset."""
host = "127.0.0.1"
mock_ssl_aes_device = MockSslDevice(host, send_error_code=SmartErrorCode.SUCCESS)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)
transport = SslTransport(
config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
)
assert transport._state is TransportState.LOGIN_REQUIRED
assert str(transport._app_url) == "https://127.0.0.1:4433/app"
await transport.perform_login()
assert transport._state is TransportState.ESTABLISHED
assert str(transport._app_url).startswith("https://127.0.0.1:4433/app?token=")
await transport.close()
assert transport._state is TransportState.LOGIN_REQUIRED
assert str(transport._app_url) == "https://127.0.0.1:4433/app"
async def test_port_override():
"""Test that port override sets the app_url."""
host = "127.0.0.1"
port_override = 12345
config = DeviceConfig(
host, credentials=Credentials("foo", "bar"), port_override=port_override
)
transport = SslTransport(config=config)
assert str(transport._app_url) == f"https://127.0.0.1:{port_override}/app"
await transport.close()
class MockSslDevice:
"""Based on MockAesSslDevice."""
class _mock_response:
def __init__(self, status, request: dict):
self.status = status
self._json = request
async def __aenter__(self):
return self
async def __aexit__(self, exc_t, exc_v, exc_tb):
pass
async def read(self):
if isinstance(self._json, dict):
return json_dumps(self._json).encode()
return self._json
def __init__(
self,
host,
*,
status_code=200,
send_error_code=SmartErrorCode.INTERNAL_UNKNOWN_ERROR,
):
self.host = host
self.http_client = HttpClient(DeviceConfig(self.host))
self._state = TransportState.LOGIN_REQUIRED
# test behaviour attributes
self.status_code = status_code
self.send_error_code = send_error_code
async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
if data:
json = json_loads(data)
_LOGGER.debug("Request %s: %s", url, json)
res = self._post(url, json)
_LOGGER.debug("Response %s, data: %s", res, await res.read())
return res
def _post(self, url: URL, json: dict[str, Any]):
method = json["method"]
if method == "login":
if self._state is TransportState.LOGIN_REQUIRED:
assert json.get("token") is None
assert url == URL(f"https://{self.host}:4433/app")
return self._return_login_response(url, json)
else:
_LOGGER.warning("Received login although already logged in")
pytest.fail("non-handled re-login logic")
assert url == URL(f"https://{self.host}:4433/app?token={MOCK_TOKEN}")
return self._return_send_response(url, json)
def _return_login_response(self, url: URL, request: dict[str, Any]):
request_username = request["params"].get("username")
request_password = request["params"].get("password")
# Handle multiple error codes
if isinstance(self.send_error_code, list):
error_code = self.send_error_code.pop(0)
else:
error_code = self.send_error_code
_LOGGER.debug("Using error code %s", error_code)
def _return_login_error():
resp = {
"error_code": error_code.value,
"result": {"unknown": "payload"},
}
_LOGGER.debug("Returning login error with status %s", self.status_code)
return self._mock_response(self.status_code, resp)
if error_code is not SmartErrorCode.SUCCESS:
# Bad username
if request_username == MOCK_BAD_USER_OR_PWD:
return _return_login_error()
# Bad password
if request_password == _md5_hash(MOCK_BAD_USER_OR_PWD.encode()):
return _return_login_error()
# Empty password
if request_password == _md5_hash(b""):
return _return_login_error()
self._state = TransportState.ESTABLISHED
resp = {
"error_code": error_code.value,
"result": {
"token": MOCK_TOKEN,
},
}
_LOGGER.debug("Returning login success with status %s", self.status_code)
return self._mock_response(self.status_code, resp)
def _return_send_response(self, url: URL, json: dict[str, Any]):
method = json["method"]
result = {
"result": {method: {"dummy": "response"}},
"error_code": self.send_error_code.value,
}
return self._mock_response(self.status_code, result)