From 9966c6094ae281229224c811f8bcba34f7b9d9dd Mon Sep 17 00:00:00 2001 From: "Teemu R." Date: Sun, 1 Dec 2024 18:06:48 +0100 Subject: [PATCH] Add ssltransport for robovacs (#943) This PR implements a clear-text, token-based transport protocol seen on RV30 Plus (#937). - Client sends `{"username": "email@example.com", "password": md5(password)}` and gets back a token in the response - Rest of the communications are done with POST at `/app?token=` --------- Co-authored-by: Steven B. <51370195+sdb9696@users.noreply.github.com> --- devtools/helpers/smartrequests.py | 24 +- kasa/cli/main.py | 1 + kasa/device_factory.py | 16 +- kasa/device_type.py | 1 + kasa/deviceconfig.py | 1 + kasa/discover.py | 11 +- kasa/smart/smartdevice.py | 2 + kasa/transports/__init__.py | 2 + kasa/transports/ssltransport.py | 233 ++++++++++++++++ tests/test_cli.py | 2 + tests/test_device_factory.py | 5 +- tests/transports/test_ssltransport.py | 374 ++++++++++++++++++++++++++ 12 files changed, 656 insertions(+), 16 deletions(-) create mode 100644 kasa/transports/ssltransport.py create mode 100644 tests/transports/test_ssltransport.py diff --git a/devtools/helpers/smartrequests.py b/devtools/helpers/smartrequests.py index 20b1300e..6ab53937 100644 --- a/devtools/helpers/smartrequests.py +++ b/devtools/helpers/smartrequests.py @@ -427,25 +427,25 @@ COMPONENT_REQUESTS = { "overheat_protection": [], # Vacuum components "clean": [ - SmartRequest.get_raw_request("get_clean_records"), - SmartRequest.get_raw_request("get_vac_state"), + SmartRequest.get_raw_request("getCleanRecords"), + SmartRequest.get_raw_request("getVacStatus"), ], - "battery": [SmartRequest.get_raw_request("get_battery_info")], - "consumables": [SmartRequest.get_raw_request("get_consumables_info")], + "battery": [SmartRequest.get_raw_request("getBatteryInfo")], + "consumables": [SmartRequest.get_raw_request("getConsumablesInfo")], "direction_control": [], "button_and_led": [], "speaker": [ - SmartRequest.get_raw_request("get_support_voice_language"), - SmartRequest.get_raw_request("get_current_voice_language"), + SmartRequest.get_raw_request("getSupportVoiceLanguage"), + SmartRequest.get_raw_request("getCurrentVoiceLanguage"), ], "map": [ - SmartRequest.get_raw_request("get_map_info"), - SmartRequest.get_raw_request("get_map_data"), + SmartRequest.get_raw_request("getMapInfo"), + SmartRequest.get_raw_request("getMapData"), ], - "auto_change_map": [SmartRequest.get_raw_request("get_auto_change_map")], - "dust_bucket": [SmartRequest.get_raw_request("get_auto_dust_collection")], - "mop": [SmartRequest.get_raw_request("get_mop_state")], - "do_not_disturb": [SmartRequest.get_raw_request("get_do_not_disturb")], + "auto_change_map": [SmartRequest.get_raw_request("getAutoChangeMap")], + "dust_bucket": [SmartRequest.get_raw_request("getAutoDustCollection")], + "mop": [SmartRequest.get_raw_request("getMopState")], + "do_not_disturb": [SmartRequest.get_raw_request("getDoNotDisturb")], "charge_pose_clean": [], "continue_breakpoint_sweep": [], "goto_point": [], diff --git a/kasa/cli/main.py b/kasa/cli/main.py index d0efc73f..fbcdf391 100755 --- a/kasa/cli/main.py +++ b/kasa/cli/main.py @@ -308,6 +308,7 @@ async def cli( if type == "camera": encrypt_type = "AES" https = True + login_version = 2 device_family = "SMART.IPCAMERA" from kasa.device import Device diff --git a/kasa/device_factory.py b/kasa/device_factory.py index d7ba5b53..be3c6ca0 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -32,6 +32,7 @@ from .transports import ( BaseTransport, KlapTransport, KlapTransportV2, + SslTransport, XorTransport, ) from .transports.sslaestransport import SslAesTransport @@ -155,6 +156,7 @@ def get_device_class_from_family( "SMART.KASAHUB": SmartDevice, "SMART.KASASWITCH": SmartDevice, "SMART.IPCAMERA.HTTPS": SmartCamDevice, + "SMART.TAPOROBOVAC": SmartDevice, "IOT.SMARTPLUGSWITCH": IotPlug, "IOT.SMARTBULB": IotBulb, } @@ -176,20 +178,30 @@ def get_protocol( """Return the protocol from the connection name.""" protocol_name = config.connection_type.device_family.value.split(".")[0] ctype = config.connection_type + protocol_transport_key = ( protocol_name + "." + ctype.encryption_type.value + (".HTTPS" if ctype.https else "") + + ( + f".{ctype.login_version}" + if ctype.login_version and ctype.login_version > 1 + else "" + ) ) + + _LOGGER.debug("Finding transport for %s", protocol_transport_key) supported_device_protocols: dict[ str, tuple[type[BaseProtocol], type[BaseTransport]] ] = { "IOT.XOR": (IotProtocol, XorTransport), "IOT.KLAP": (IotProtocol, KlapTransport), "SMART.AES": (SmartProtocol, AesTransport), - "SMART.KLAP": (SmartProtocol, KlapTransportV2), - "SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport), + "SMART.AES.2": (SmartProtocol, AesTransport), + "SMART.KLAP.2": (SmartProtocol, KlapTransportV2), + "SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport), + "SMART.AES.HTTPS": (SmartProtocol, SslTransport), } if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)): return None diff --git a/kasa/device_type.py b/kasa/device_type.py index b690f1f1..7fe485d3 100755 --- a/kasa/device_type.py +++ b/kasa/device_type.py @@ -21,6 +21,7 @@ class DeviceType(Enum): Hub = "hub" Fan = "fan" Thermostat = "thermostat" + Vacuum = "vacuum" Unknown = "unknown" @staticmethod diff --git a/kasa/deviceconfig.py b/kasa/deviceconfig.py index 1156cf25..6f9176f5 100644 --- a/kasa/deviceconfig.py +++ b/kasa/deviceconfig.py @@ -77,6 +77,7 @@ class DeviceFamily(Enum): SmartTapoHub = "SMART.TAPOHUB" SmartKasaHub = "SMART.KASAHUB" SmartIpCamera = "SMART.IPCAMERA" + SmartTapoRobovac = "SMART.TAPOROBOVAC" class _DeviceConfigBaseMixin(DataClassJSONMixin): diff --git a/kasa/discover.py b/kasa/discover.py index f89999f4..771c3f5c 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -598,10 +598,12 @@ class Discover: for encrypt in Device.EncryptionType for device_family in main_device_families for https in (True, False) + for login_version in (None, 2) if ( conn_params := DeviceConnectionParameters( device_family=device_family, encryption_type=encrypt, + login_version=login_version, https=https, ) ) @@ -768,6 +770,13 @@ class Discover: ): encrypt_type = encrypt_info.sym_schm + if ( + not (login_version := encrypt_schm.lv) + and (et := discovery_result.encrypt_type) + and et == ["3"] + ): + login_version = 2 + if not encrypt_type: raise UnsupportedDeviceError( f"Unsupported device {config.host} of type {type_} " @@ -778,7 +787,7 @@ class Discover: config.connection_type = DeviceConnectionParameters.from_values( type_, encrypt_type, - encrypt_schm.lv, + login_version, encrypt_schm.is_support_https, ) except KasaException as ex: diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 0989842a..adb4829d 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -802,6 +802,8 @@ class SmartDevice(Device): return DeviceType.Sensor if "ENERGY" in device_type: return DeviceType.Thermostat + if "ROBOVAC" in device_type: + return DeviceType.Vacuum _LOGGER.warning("Unknown device type, falling back to plug") return DeviceType.Plug diff --git a/kasa/transports/__init__.py b/kasa/transports/__init__.py index 8ccdae65..3438aab7 100644 --- a/kasa/transports/__init__.py +++ b/kasa/transports/__init__.py @@ -3,11 +3,13 @@ from .aestransport import AesEncyptionSession, AesTransport from .basetransport import BaseTransport from .klaptransport import KlapTransport, KlapTransportV2 +from .ssltransport import SslTransport from .xortransport import XorEncryption, XorTransport __all__ = [ "AesTransport", "AesEncyptionSession", + "SslTransport", "BaseTransport", "KlapTransport", "KlapTransportV2", diff --git a/kasa/transports/ssltransport.py b/kasa/transports/ssltransport.py new file mode 100644 index 00000000..5ffc935f --- /dev/null +++ b/kasa/transports/ssltransport.py @@ -0,0 +1,233 @@ +"""Implementation of the clear-text passthrough ssl transport. + +This transport does not encrypt the passthrough payloads at all, but requires a login. +This has been seen on some devices (like robovacs). +""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import logging +import time +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, cast + +from yarl import URL + +from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials +from kasa.deviceconfig import DeviceConfig +from kasa.exceptions import ( + SMART_AUTHENTICATION_ERRORS, + SMART_RETRYABLE_ERRORS, + AuthenticationError, + DeviceError, + KasaException, + SmartErrorCode, + _RetryableError, +) +from kasa.httpclient import HttpClient +from kasa.json import dumps as json_dumps +from kasa.json import loads as json_loads +from kasa.transports import BaseTransport + +_LOGGER = logging.getLogger(__name__) + + +ONE_DAY_SECONDS = 86400 +SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20 + + +def _md5_hash(payload: bytes) -> str: + return hashlib.md5(payload).hexdigest().upper() # noqa: S324 + + +class TransportState(Enum): + """Enum for transport state.""" + + LOGIN_REQUIRED = auto() # Login needed + ESTABLISHED = auto() # Ready to send requests + + +class SslTransport(BaseTransport): + """Implementation of the cleartext transport protocol. + + This transport uses HTTPS without any further payload encryption. + """ + + DEFAULT_PORT: int = 4433 + COMMON_HEADERS = { + "Content-Type": "application/json", + } + BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1 + + def __init__( + self, + *, + config: DeviceConfig, + ) -> None: + super().__init__(config=config) + + if ( + not self._credentials or self._credentials.username is None + ) and not self._credentials_hash: + self._credentials = Credentials() + + if self._credentials: + self._login_params = self._get_login_params(self._credentials) + else: + self._login_params = json_loads( + base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr] + ) + + self._default_credentials: Credentials | None = None + self._http_client: HttpClient = HttpClient(config) + + self._state = TransportState.LOGIN_REQUIRED + self._session_expire_at: float | None = None + + self._app_url = URL(f"https://{self._host}:{self._port}/app") + + _LOGGER.debug("Created ssltransport for %s", self._host) + + @property + def default_port(self) -> int: + """Default port for the transport.""" + return self.DEFAULT_PORT + + @property + def credentials_hash(self) -> str: + """The hashed credentials used by the transport.""" + return base64.b64encode(json_dumps(self._login_params).encode()).decode() + + def _get_login_params(self, credentials: Credentials) -> dict[str, str]: + """Get the login parameters based on the login_version.""" + un, pw = self.hash_credentials(credentials) + return {"password": pw, "username": un} + + @staticmethod + def hash_credentials(credentials: Credentials) -> tuple[str, str]: + """Hash the credentials.""" + un = credentials.username + pw = _md5_hash(credentials.password.encode()) + return un, pw + + async def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: + """Handle response errors to request reauth etc.""" + error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] + if error_code == SmartErrorCode.SUCCESS: + return + + msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" + + if error_code in SMART_RETRYABLE_ERRORS: + raise _RetryableError(msg, error_code=error_code) + + if error_code in SMART_AUTHENTICATION_ERRORS: + await self.reset() + raise AuthenticationError(msg, error_code=error_code) + + raise DeviceError(msg, error_code=error_code) + + async def send_request(self, request: str) -> dict[str, Any]: + """Send request.""" + url = self._app_url + + _LOGGER.debug("Sending %s to %s", request, url) + + status_code, resp_dict = await self._http_client.post( + url, + json=request, + headers=self.COMMON_HEADERS, + ) + + if status_code != 200: + raise KasaException( + f"{self._host} responded with an unexpected " + + f"status code {status_code}" + ) + + _LOGGER.debug("Response with %s: %r", status_code, resp_dict) + + await self._handle_response_error_code(resp_dict, "Error sending request") + + if TYPE_CHECKING: + resp_dict = cast(dict[str, Any], resp_dict) + + return resp_dict + + async def perform_login(self) -> None: + """Login to the device.""" + try: + await self.try_login(self._login_params) + except AuthenticationError as aex: + try: + if aex.error_code is not SmartErrorCode.LOGIN_ERROR: + raise aex + + _LOGGER.debug("Login failed, going to try default credentials") + if self._default_credentials is None: + self._default_credentials = get_default_credentials( + DEFAULT_CREDENTIALS["TAPO"] + ) + await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_LOGIN_ERROR) + + await self.try_login(self._get_login_params(self._default_credentials)) + _LOGGER.debug( + "%s: logged in with default credentials", + self._host, + ) + except AuthenticationError: + raise + except Exception as ex: + raise KasaException( + "Unable to login and trying default " + + f"login raised another exception: {ex}", + ex, + ) from ex + + async def try_login(self, login_params: dict[str, Any]) -> None: + """Try to login with supplied login_params.""" + login_request = { + "method": "login", + "params": login_params, + } + request = json_dumps(login_request) + _LOGGER.debug("Going to send login request") + + resp_dict = await self.send_request(request) + await self._handle_response_error_code(resp_dict, "Error logging in") + + login_token = resp_dict["result"]["token"] + self._app_url = self._app_url.with_query(f"token={login_token}") + self._state = TransportState.ESTABLISHED + self._session_expire_at = ( + time.time() + ONE_DAY_SECONDS - SESSION_EXPIRE_BUFFER_SECONDS + ) + + def _session_expired(self) -> bool: + """Return true if session has expired.""" + return ( + self._session_expire_at is None + or self._session_expire_at - time.time() <= 0 + ) + + async def send(self, request: str) -> dict[str, Any]: + """Send the request.""" + _LOGGER.info("Going to send %s", request) + if self._state is not TransportState.ESTABLISHED or self._session_expired(): + _LOGGER.debug("Transport not established or session expired, logging in") + await self.perform_login() + + return await self.send_request(request) + + async def close(self) -> None: + """Close the http client and reset internal state.""" + await self.reset() + await self._http_client.close() + + async def reset(self) -> None: + """Reset internal login state.""" + self._state = TransportState.LOGIN_REQUIRED + self._app_url = URL(f"https://{self._host}:{self._port}/app") diff --git a/tests/test_cli.py b/tests/test_cli.py index bb707bb6..d1fc330c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -692,6 +692,8 @@ async def test_credentials(discovery_mock, mocker, runner): dr.device_type, "--encrypt-type", dr.mgt_encrypt_schm.encrypt_type, + "--login-version", + dr.mgt_encrypt_schm.lv or 1, ], ) assert res.exit_code == 0 diff --git a/tests/test_device_factory.py b/tests/test_device_factory.py index 86003744..ed73b3a3 100644 --- a/tests/test_device_factory.py +++ b/tests/test_device_factory.py @@ -47,7 +47,10 @@ def _get_connection_type_device_class(discovery_info): dr = DiscoveryResult.from_dict(discovery_info["result"]) connection_type = DeviceConnectionParameters.from_values( - dr.device_type, dr.mgt_encrypt_schm.encrypt_type + dr.device_type, + dr.mgt_encrypt_schm.encrypt_type, + dr.mgt_encrypt_schm.lv, + dr.mgt_encrypt_schm.is_support_https, ) else: connection_type = DeviceConnectionParameters.from_values( diff --git a/tests/transports/test_ssltransport.py b/tests/transports/test_ssltransport.py new file mode 100644 index 00000000..37b79725 --- /dev/null +++ b/tests/transports/test_ssltransport.py @@ -0,0 +1,374 @@ +from __future__ import annotations + +import logging +from base64 import b64encode +from contextlib import nullcontext as does_not_raise +from typing import Any + +import aiohttp +import pytest +from yarl import URL + +from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials +from kasa.deviceconfig import DeviceConfig +from kasa.exceptions import ( + AuthenticationError, + DeviceError, + KasaException, + SmartErrorCode, + _RetryableError, +) +from kasa.httpclient import HttpClient +from kasa.json import dumps as json_dumps +from kasa.json import loads as json_loads +from kasa.transports import SslTransport +from kasa.transports.ssltransport import TransportState, _md5_hash + +# Transport tests are not designed for real devices +pytestmark = [pytest.mark.requires_dummy] + +MOCK_PWD = "correct_pwd" # noqa: S105 +MOCK_USER = "mock@example.com" +MOCK_BAD_USER_OR_PWD = "foobar" # noqa: S105 +MOCK_TOKEN = "abcdefghijklmnopqrstuvwxyz1234)(" # noqa: S105 + +DEFAULT_CREDS = get_default_credentials(DEFAULT_CREDENTIALS["TAPO"]) + + +_LOGGER = logging.getLogger(__name__) + + +@pytest.mark.parametrize( + ( + "status_code", + "error_code", + "username", + "password", + "expectation", + ), + [ + pytest.param( + 200, + SmartErrorCode.SUCCESS, + MOCK_USER, + MOCK_PWD, + does_not_raise(), + id="success", + ), + pytest.param( + 200, + SmartErrorCode.UNSPECIFIC_ERROR, + MOCK_USER, + MOCK_PWD, + pytest.raises(_RetryableError), + id="test retry", + ), + pytest.param( + 200, + SmartErrorCode.DEVICE_BLOCKED, + MOCK_USER, + MOCK_PWD, + pytest.raises(DeviceError), + id="test regular error", + ), + pytest.param( + 400, + SmartErrorCode.INTERNAL_UNKNOWN_ERROR, + MOCK_USER, + MOCK_PWD, + pytest.raises(KasaException), + id="400 error", + ), + pytest.param( + 200, + SmartErrorCode.LOGIN_ERROR, + MOCK_BAD_USER_OR_PWD, + MOCK_PWD, + pytest.raises(AuthenticationError), + id="bad-username", + ), + pytest.param( + 200, + [SmartErrorCode.LOGIN_ERROR, SmartErrorCode.SUCCESS], + MOCK_BAD_USER_OR_PWD, + "", + does_not_raise(), + id="working-fallback", + ), + pytest.param( + 200, + [SmartErrorCode.LOGIN_ERROR, SmartErrorCode.LOGIN_ERROR], + MOCK_BAD_USER_OR_PWD, + "", + pytest.raises(AuthenticationError), + id="fallback-fail", + ), + pytest.param( + 200, + SmartErrorCode.LOGIN_ERROR, + MOCK_USER, + MOCK_BAD_USER_OR_PWD, + pytest.raises(AuthenticationError), + id="bad-password", + ), + pytest.param( + 200, + SmartErrorCode.TRANSPORT_UNKNOWN_CREDENTIALS_ERROR, + MOCK_USER, + MOCK_PWD, + pytest.raises(AuthenticationError), + id="auth-error != login_error", + ), + ], +) +async def test_login( + mocker, + status_code, + error_code, + username, + password, + expectation, +): + host = "127.0.0.1" + mock_ssl_aes_device = MockSslDevice( + host, + status_code=status_code, + send_error_code=error_code, + ) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + transport = SslTransport( + config=DeviceConfig(host, credentials=Credentials(username, password)) + ) + + assert transport._state is TransportState.LOGIN_REQUIRED + with expectation: + await transport.perform_login() + assert transport._state is TransportState.ESTABLISHED + + await transport.close() + + +async def test_credentials_hash(mocker): + host = "127.0.0.1" + mock_ssl_aes_device = MockSslDevice(host) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + creds = Credentials(MOCK_USER, MOCK_PWD) + + data = {"password": _md5_hash(MOCK_PWD.encode()), "username": MOCK_USER} + + creds_hash = b64encode(json_dumps(data).encode()).decode() + + # Test with credentials input + transport = SslTransport(config=DeviceConfig(host, credentials=creds)) + assert transport.credentials_hash == creds_hash + + # Test with credentials_hash input + transport = SslTransport(config=DeviceConfig(host, credentials_hash=creds_hash)) + assert transport.credentials_hash == creds_hash + + await transport.close() + + +async def test_send(mocker): + host = "127.0.0.1" + mock_ssl_aes_device = MockSslDevice(host, send_error_code=SmartErrorCode.SUCCESS) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + transport = SslTransport( + config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD)) + ) + try_login_spy = mocker.spy(transport, "try_login") + request = { + "method": "get_device_info", + "params": None, + } + assert transport._state is TransportState.LOGIN_REQUIRED + + res = await transport.send(json_dumps(request)) + assert "result" in res + try_login_spy.assert_called_once() + assert transport._state is TransportState.ESTABLISHED + + # Second request does not + res = await transport.send(json_dumps(request)) + try_login_spy.assert_called_once() + + await transport.close() + + +async def test_no_credentials(mocker): + """Test transport without credentials.""" + host = "127.0.0.1" + mock_ssl_aes_device = MockSslDevice( + host, send_error_code=SmartErrorCode.LOGIN_ERROR + ) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + transport = SslTransport(config=DeviceConfig(host)) + try_login_spy = mocker.spy(transport, "try_login") + + with pytest.raises(AuthenticationError): + await transport.send('{"method": "dummy"}') + + # We get called twice + assert try_login_spy.call_count == 2 + + await transport.close() + + +async def test_reset(mocker): + """Test that transport state adjusts correctly for reset.""" + host = "127.0.0.1" + mock_ssl_aes_device = MockSslDevice(host, send_error_code=SmartErrorCode.SUCCESS) + mocker.patch.object( + aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post + ) + + transport = SslTransport( + config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD)) + ) + + assert transport._state is TransportState.LOGIN_REQUIRED + assert str(transport._app_url) == "https://127.0.0.1:4433/app" + + await transport.perform_login() + assert transport._state is TransportState.ESTABLISHED + assert str(transport._app_url).startswith("https://127.0.0.1:4433/app?token=") + + await transport.close() + assert transport._state is TransportState.LOGIN_REQUIRED + assert str(transport._app_url) == "https://127.0.0.1:4433/app" + + +async def test_port_override(): + """Test that port override sets the app_url.""" + host = "127.0.0.1" + port_override = 12345 + config = DeviceConfig( + host, credentials=Credentials("foo", "bar"), port_override=port_override + ) + transport = SslTransport(config=config) + + assert str(transport._app_url) == f"https://127.0.0.1:{port_override}/app" + + await transport.close() + + +class MockSslDevice: + """Based on MockAesSslDevice.""" + + class _mock_response: + def __init__(self, status, request: dict): + self.status = status + self._json = request + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_t, exc_v, exc_tb): + pass + + async def read(self): + if isinstance(self._json, dict): + return json_dumps(self._json).encode() + return self._json + + def __init__( + self, + host, + *, + status_code=200, + send_error_code=SmartErrorCode.INTERNAL_UNKNOWN_ERROR, + ): + self.host = host + self.http_client = HttpClient(DeviceConfig(self.host)) + + self._state = TransportState.LOGIN_REQUIRED + + # test behaviour attributes + self.status_code = status_code + self.send_error_code = send_error_code + + async def post(self, url: URL, params=None, json=None, data=None, *_, **__): + if data: + json = json_loads(data) + _LOGGER.debug("Request %s: %s", url, json) + res = self._post(url, json) + _LOGGER.debug("Response %s, data: %s", res, await res.read()) + return res + + def _post(self, url: URL, json: dict[str, Any]): + method = json["method"] + + if method == "login": + if self._state is TransportState.LOGIN_REQUIRED: + assert json.get("token") is None + assert url == URL(f"https://{self.host}:4433/app") + return self._return_login_response(url, json) + else: + _LOGGER.warning("Received login although already logged in") + pytest.fail("non-handled re-login logic") + + assert url == URL(f"https://{self.host}:4433/app?token={MOCK_TOKEN}") + return self._return_send_response(url, json) + + def _return_login_response(self, url: URL, request: dict[str, Any]): + request_username = request["params"].get("username") + request_password = request["params"].get("password") + + # Handle multiple error codes + if isinstance(self.send_error_code, list): + error_code = self.send_error_code.pop(0) + else: + error_code = self.send_error_code + + _LOGGER.debug("Using error code %s", error_code) + + def _return_login_error(): + resp = { + "error_code": error_code.value, + "result": {"unknown": "payload"}, + } + + _LOGGER.debug("Returning login error with status %s", self.status_code) + return self._mock_response(self.status_code, resp) + + if error_code is not SmartErrorCode.SUCCESS: + # Bad username + if request_username == MOCK_BAD_USER_OR_PWD: + return _return_login_error() + + # Bad password + if request_password == _md5_hash(MOCK_BAD_USER_OR_PWD.encode()): + return _return_login_error() + + # Empty password + if request_password == _md5_hash(b""): + return _return_login_error() + + self._state = TransportState.ESTABLISHED + resp = { + "error_code": error_code.value, + "result": { + "token": MOCK_TOKEN, + }, + } + _LOGGER.debug("Returning login success with status %s", self.status_code) + return self._mock_response(self.status_code, resp) + + def _return_send_response(self, url: URL, json: dict[str, Any]): + method = json["method"] + result = { + "result": {method: {"dummy": "response"}}, + "error_code": self.send_error_code.value, + } + return self._mock_response(self.status_code, result)