"""Module for SmartCamProtocol."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from pprint import pformat as pf
from typing import Any, cast

from ..exceptions import (
    AuthenticationError,
    DeviceError,
    KasaException,
    _RetryableError,
)
from ..json import dumps as json_dumps
from ..transports.sslaestransport import (
    SMART_AUTHENTICATION_ERRORS,
    SMART_RETRYABLE_ERRORS,
    SmartErrorCode,
)
from .smartprotocol import SmartProtocol

_LOGGER = logging.getLogger(__name__)

# List of getMethodNames that should be sent as {"method":"do"}
# https://md.depau.eu/s/r1Ys_oWoP#Modules
GET_METHODS_AS_DO = {
    "getSdCardFormatStatus",
    "getConnectionType",
    "getUserID",
    "getP2PSharePassword",
    "getAESEncryptKey",
    "getFirmwareAFResult",
    "getWhitelampStatus",
}


@dataclass
class SingleRequest:
    """Class for returning single request details from helper functions."""

    method_type: str
    method_name: str
    param_name: str
    request: dict[str, Any]


class SmartCamProtocol(SmartProtocol):
    """Class for SmartCam Protocol."""

    def _get_list_request(
        self, method: str, params: dict | None, start_index: int
    ) -> dict:
        # All smartcam requests have params
        params = cast(dict, params)
        module_name = next(iter(params))
        return {method: {module_name: {"start_index": start_index}}}

    def _handle_response_error_code(
        self, resp_dict: dict, method: str, raise_on_error: bool = True
    ) -> None:
        error_code_raw = resp_dict.get("error_code")
        try:
            error_code = SmartErrorCode.from_int(error_code_raw)
        except ValueError:
            _LOGGER.warning(
                "Device %s received unknown error code: %s", self._host, error_code_raw
            )
            error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR

        if error_code is SmartErrorCode.SUCCESS:
            return

        if not raise_on_error:
            resp_dict["result"] = error_code
            return

        msg = (
            f"Error querying device: {self._host}: "
            + f"{error_code.name}({error_code.value})"
            + f" for method: {method}"
        )
        if error_code in SMART_RETRYABLE_ERRORS:
            raise _RetryableError(msg, error_code=error_code)
        if error_code in SMART_AUTHENTICATION_ERRORS:
            raise AuthenticationError(msg, error_code=error_code)
        raise DeviceError(msg, error_code=error_code)

    async def close(self) -> None:
        """Close the underlying transport."""
        await self._transport.close()

    @staticmethod
    def _get_smart_camera_single_request(
        request: dict[str, dict[str, Any]],
    ) -> SingleRequest:
        method = next(iter(request))
        if method == "multipleRequest":
            method_type = "multi"
            params = request["multipleRequest"]
            req = {"method": "multipleRequest", "params": params}
            return SingleRequest("multi", "multipleRequest", "", req)

        param = next(iter(request[method]))
        method_type = method
        req = {
            "method": method,
            param: request[method][param],
        }
        return SingleRequest(method_type, method, param, req)

    @staticmethod
    def _make_snake_name(name: str) -> str:
        """Convert camel or pascal case to snake name."""
        sn = "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
        return sn

    @staticmethod
    def _make_smart_camera_single_request(
        request: str,
    ) -> SingleRequest:
        """Make a single request given a method name and no params.

        If method like getSomeThing then module will be some_thing.
        """
        method = request
        method_type = request[:3]
        snake_name = SmartCamProtocol._make_snake_name(request)
        param = snake_name[4:]
        if (
            (short_method := method[:3])
            and short_method in {"get", "set"}
            and method not in GET_METHODS_AS_DO
        ):
            method_type = short_method
            param = snake_name[4:]
        else:
            method_type = "do"
            param = snake_name
        req = {"method": method_type, param: {}}
        return SingleRequest(method_type, method, param, req)

    async def _execute_query(
        self, request: str | dict, *, retry_count: int, iterate_list_pages: bool = True
    ) -> dict:
        debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
        if isinstance(request, dict):
            method = next(iter(request))
            if len(request) == 1 and method in {"get", "set", "do", "multipleRequest"}:
                single_request = self._get_smart_camera_single_request(request)
            else:
                return await self._execute_multiple_query(
                    request, retry_count, iterate_list_pages
                )
        else:
            single_request = self._make_smart_camera_single_request(request)

        smart_request = json_dumps(single_request.request)
        if debug_enabled:
            _LOGGER.debug(
                "%s >> %s",
                self._host,
                pf(smart_request),
            )
        response_data = await self._transport.send(smart_request)

        if debug_enabled:
            _LOGGER.debug(
                "%s << %s",
                self._host,
                pf(response_data),
            )

        if "error_code" in response_data:
            # H200 does not return an error code
            self._handle_response_error_code(response_data, single_request.method_name)
        # Requests that are invalid and raise PROTOCOL_FORMAT_ERROR when sent
        # as a multipleRequest will return {} when sent as a single request.
        if single_request.method_type == "get" and (
            not (section := next(iter(response_data))) or response_data[section] == {}
        ):
            raise DeviceError(
                f"No results for get request {single_request.method_name}"
            )

        # TODO need to update handle response lists

        if single_request.method_type == "do":
            return {single_request.method_name: response_data}
        if single_request.method_type == "set":
            return {}
        if single_request.method_type == "multi":
            return {single_request.method_name: response_data["result"]}
        return {
            single_request.method_name: {
                single_request.param_name: response_data[single_request.param_name]
            }
        }


class _ChildCameraProtocolWrapper(SmartProtocol):
    """Protocol wrapper for controlling child devices.

    This is an internal class used to communicate with child devices,
    and should not be used directly.

    This class overrides query() method of the protocol to modify all
    outgoing queries to use ``controlChild`` command, and unwraps the
    device responses before returning to the caller.
    """

    def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
        self._device_id = device_id
        self._protocol = base_protocol
        self._transport = base_protocol._transport

    async def query(self, request: str | dict, retry_count: int = 3) -> dict:
        """Wrap request inside controlChild envelope."""
        return await self._query(request, retry_count)

    async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
        """Wrap request inside controlChild envelope."""
        if not isinstance(request, dict):
            raise KasaException("Child requests must be dictionaries.")
        requests = []
        methods = []
        for key, val in request.items():
            request = {
                "method": "controlChild",
                "params": {
                    "childControl": {
                        "device_id": self._device_id,
                        "request_data": {"method": key, "params": val},
                    }
                },
            }
            methods.append(key)
            requests.append(request)

        multipleRequest = {"multipleRequest": {"requests": requests}}

        response = await self._protocol.query(multipleRequest, retry_count)

        responses = response["multipleRequest"]["responses"]
        response_dict = {}

        # Raise errors for single calls
        raise_on_error = len(requests) == 1

        for index_id, response in enumerate(responses):
            response_data = response["result"]["response_data"]
            method = methods[index_id]
            self._handle_response_error_code(
                response_data, method, raise_on_error=raise_on_error
            )
            response_dict[method] = response_data.get("result")

        return response_dict

    async def close(self) -> None:
        """Do nothing as the parent owns the protocol."""