"""Implementation of the clear-text passthrough ssl transport.

This transport does not encrypt the passthrough payloads at all, but requires a login.
This has been seen on some devices (like robovacs).
"""

from __future__ import annotations

import asyncio
import base64
import hashlib
import logging
import time
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, cast

from yarl import URL

from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import (
    SMART_AUTHENTICATION_ERRORS,
    SMART_RETRYABLE_ERRORS,
    AuthenticationError,
    DeviceError,
    KasaException,
    SmartErrorCode,
    _RetryableError,
)
from kasa.httpclient import HttpClient
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.transports import BaseTransport

_LOGGER = logging.getLogger(__name__)


ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20


def _md5_hash(payload: bytes) -> str:
    return hashlib.md5(payload).hexdigest().upper()  # noqa: S324


class TransportState(Enum):
    """Enum for transport state."""

    LOGIN_REQUIRED = auto()  # Login needed
    ESTABLISHED = auto()  # Ready to send requests


class SslTransport(BaseTransport):
    """Implementation of the cleartext transport protocol.

    This transport uses HTTPS without any further payload encryption.
    """

    DEFAULT_PORT: int = 4433
    COMMON_HEADERS = {
        "Content-Type": "application/json",
    }
    BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1

    def __init__(
        self,
        *,
        config: DeviceConfig,
    ) -> None:
        super().__init__(config=config)

        if (
            not self._credentials or self._credentials.username is None
        ) and not self._credentials_hash:
            self._credentials = Credentials()

        if self._credentials:
            self._login_params = self._get_login_params(self._credentials)
        else:
            self._login_params = json_loads(
                base64.b64decode(self._credentials_hash.encode()).decode()  # type: ignore[union-attr]
            )

        self._default_credentials: Credentials | None = None
        self._http_client: HttpClient = HttpClient(config)

        self._state = TransportState.LOGIN_REQUIRED
        self._session_expire_at: float | None = None

        self._app_url = URL(f"https://{self._host}:{self._port}/app")

        _LOGGER.debug("Created ssltransport for %s", self._host)

    @property
    def default_port(self) -> int:
        """Default port for the transport."""
        return self.DEFAULT_PORT

    @property
    def credentials_hash(self) -> str:
        """The hashed credentials used by the transport."""
        return base64.b64encode(json_dumps(self._login_params).encode()).decode()

    def _get_login_params(self, credentials: Credentials) -> dict[str, str]:
        """Get the login parameters based on the login_version."""
        un, pw = self.hash_credentials(credentials)
        return {"password": pw, "username": un}

    @staticmethod
    def hash_credentials(credentials: Credentials) -> tuple[str, str]:
        """Hash the credentials."""
        un = credentials.username
        pw = _md5_hash(credentials.password.encode())
        return un, pw

    async def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
        """Handle response errors to request reauth etc."""
        error_code = SmartErrorCode(resp_dict.get("error_code"))  # type: ignore[arg-type]
        if error_code == SmartErrorCode.SUCCESS:
            return

        msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"

        if error_code in SMART_RETRYABLE_ERRORS:
            raise _RetryableError(msg, error_code=error_code)

        if error_code in SMART_AUTHENTICATION_ERRORS:
            await self.reset()
            raise AuthenticationError(msg, error_code=error_code)

        raise DeviceError(msg, error_code=error_code)

    async def send_request(self, request: str) -> dict[str, Any]:
        """Send request."""
        url = self._app_url

        _LOGGER.debug("Sending %s to %s", request, url)

        status_code, resp_dict = await self._http_client.post(
            url,
            json=request,
            headers=self.COMMON_HEADERS,
        )

        if status_code != 200:
            raise KasaException(
                f"{self._host} responded with an unexpected "
                + f"status code {status_code}"
            )

        _LOGGER.debug("Response with %s: %r", status_code, resp_dict)

        await self._handle_response_error_code(resp_dict, "Error sending request")

        if TYPE_CHECKING:
            resp_dict = cast(dict[str, Any], resp_dict)

        return resp_dict

    async def perform_login(self) -> None:
        """Login to the device."""
        try:
            await self.try_login(self._login_params)
        except AuthenticationError as aex:
            try:
                if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
                    raise aex

                _LOGGER.debug("Login failed, going to try default credentials")
                if self._default_credentials is None:
                    self._default_credentials = get_default_credentials(
                        DEFAULT_CREDENTIALS["TAPO"]
                    )
                    await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_LOGIN_ERROR)

                await self.try_login(self._get_login_params(self._default_credentials))
                _LOGGER.debug(
                    "%s: logged in with default credentials",
                    self._host,
                )
            except AuthenticationError:
                raise
            except Exception as ex:
                raise KasaException(
                    "Unable to login and trying default "
                    + f"login raised another exception: {ex}",
                    ex,
                ) from ex

    async def try_login(self, login_params: dict[str, Any]) -> None:
        """Try to login with supplied login_params."""
        login_request = {
            "method": "login",
            "params": login_params,
        }
        request = json_dumps(login_request)
        _LOGGER.debug("Going to send login request")

        resp_dict = await self.send_request(request)
        await self._handle_response_error_code(resp_dict, "Error logging in")

        login_token = resp_dict["result"]["token"]
        self._app_url = self._app_url.with_query(f"token={login_token}")
        self._state = TransportState.ESTABLISHED
        self._session_expire_at = (
            time.time() + ONE_DAY_SECONDS - SESSION_EXPIRE_BUFFER_SECONDS
        )

    def _session_expired(self) -> bool:
        """Return true if session has expired."""
        return (
            self._session_expire_at is None
            or self._session_expire_at - time.time() <= 0
        )

    async def send(self, request: str) -> dict[str, Any]:
        """Send the request."""
        _LOGGER.info("Going to send %s", request)
        if self._state is not TransportState.ESTABLISHED or self._session_expired():
            _LOGGER.debug("Transport not established or session expired, logging in")
            await self.perform_login()

        return await self.send_request(request)

    async def close(self) -> None:
        """Close the http client and reset internal state."""
        await self.reset()
        await self._http_client.close()

    async def reset(self) -> None:
        """Reset internal login state."""
        self._state = TransportState.LOGIN_REQUIRED
        self._app_url = URL(f"https://{self._host}:{self._port}/app")