"""Module for the IOT legacy IOT KASA protocol."""

from __future__ import annotations

import asyncio
import logging
from collections.abc import Callable
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any

from ..deviceconfig import DeviceConfig
from ..exceptions import (
    AuthenticationError,
    KasaException,
    TimeoutError,
    _ConnectionError,
    _RetryableError,
)
from ..json import dumps as json_dumps
from ..transports import XorEncryption, XorTransport
from .protocol import BaseProtocol, mask_mac, redact_data

if TYPE_CHECKING:
    from ..transports import BaseTransport

_LOGGER = logging.getLogger(__name__)


def _mask_children(children: list[dict[str, Any]]) -> list[dict[str, Any]]:
    def mask_child(child: dict[str, Any], index: int) -> dict[str, Any]:
        result = {
            **child,
            "id": f"SCRUBBED_CHILD_DEVICE_ID_{index+1}",
        }
        # Will leave empty aliases as blank
        if child.get("alias"):
            result["alias"] = f"#MASKED_NAME# {index + 1}"
        return result

    return [mask_child(child, index) for index, child in enumerate(children)]


REDACTORS: dict[str, Callable[[Any], Any] | None] = {
    "latitude": lambda x: 0,
    "longitude": lambda x: 0,
    "latitude_i": lambda x: 0,
    "longitude_i": lambda x: 0,
    "deviceId": lambda x: "REDACTED_" + x[9::],
    "children": _mask_children,
    "alias": lambda x: "#MASKED_NAME#" if x else "",
    "mac": mask_mac,
    "mic_mac": mask_mac,
    "ssid": lambda x: "#MASKED_SSID#" if x else "",
    "oemId": lambda x: "REDACTED_" + x[9::],
    "username": lambda _: "user@example.com",  # cnCloud
    "hwId": lambda x: "REDACTED_" + x[9::],
}


class IotProtocol(BaseProtocol):
    """Class for the legacy TPLink IOT KASA Protocol."""

    BACKOFF_SECONDS_AFTER_TIMEOUT = 1

    def __init__(
        self,
        *,
        transport: BaseTransport,
    ) -> None:
        """Create a protocol object."""
        super().__init__(transport=transport)

        self._query_lock = asyncio.Lock()
        self._redact_data = True

    async def query(self, request: str | dict, retry_count: int = 3) -> dict:
        """Query the device retrying for retry_count on failure."""
        if isinstance(request, dict):
            request = json_dumps(request)
            assert isinstance(request, str)  # noqa: S101

        async with self._query_lock:
            return await self._query(request, retry_count)

    async def _query(self, request: str, retry_count: int = 3) -> dict:
        for retry in range(retry_count + 1):
            try:
                return await self._execute_query(request, retry)
            except _ConnectionError as sdex:
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
                    raise sdex
                continue
            except AuthenticationError as auex:
                await self._transport.reset()
                _LOGGER.debug(
                    "Unable to authenticate with %s, not retrying", self._host
                )
                raise auex
            except _RetryableError as ex:
                await self._transport.reset()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
                    raise ex
                continue
            except TimeoutError as ex:
                await self._transport.reset()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
                    raise ex
                await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
                continue
            except KasaException as ex:
                await self._transport.reset()
                _LOGGER.debug(
                    "Unable to query the device: %s, not retrying: %s",
                    self._host,
                    ex,
                )
                raise ex

        # make mypy happy, this should never be reached..
        raise KasaException("Query reached somehow to unreachable")

    async def _execute_query(self, request: str, retry_count: int) -> dict:
        debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)

        if debug_enabled:
            _LOGGER.debug(
                "%s >> %s",
                self._host,
                request,
            )
        resp = await self._transport.send(request)

        if debug_enabled:
            data = redact_data(resp, REDACTORS) if self._redact_data else resp
            _LOGGER.debug(
                "%s << %s",
                self._host,
                pf(data),
            )
        return resp

    async def close(self) -> None:
        """Close the underlying transport."""
        await self._transport.close()


class _deprecated_TPLinkSmartHomeProtocol(IotProtocol):
    def __init__(
        self,
        host: str | None = None,
        *,
        port: int | None = None,
        timeout: int | None = None,
        transport: BaseTransport | None = None,
    ) -> None:
        """Create a protocol object."""
        if not host and not transport:
            raise KasaException("host or transport must be supplied")
        if not transport:
            config = DeviceConfig(
                host=host,  # type: ignore[arg-type]
                port_override=port,
                timeout=timeout or XorTransport.DEFAULT_TIMEOUT,
            )
            transport = XorTransport(config=config)
        super().__init__(transport=transport)

    @staticmethod
    def encrypt(request: str) -> bytes:
        """Encrypt a request for a TP-Link Smart Home Device.

        :param request: plaintext request data
        :return: ciphertext to be send over wire, in bytes
        """
        return XorEncryption.encrypt(request)

    @staticmethod
    def decrypt(ciphertext: bytes) -> str:
        """Decrypt a response of a TP-Link Smart Home Device.

        :param ciphertext: encrypted response data
        :return: plaintext response
        """
        return XorEncryption.decrypt(ciphertext)