from __future__ import annotations

import copy
from dataclasses import dataclass
from json import dumps as json_dumps
from typing import Any, TypedDict

import pytest

from kasa.transports.xortransport import XorEncryption

from .fakeprotocol_iot import FakeIotProtocol
from .fakeprotocol_smart import FakeSmartProtocol, FakeSmartTransport
from .fakeprotocol_smartcam import FakeSmartCamProtocol
from .fixtureinfo import FixtureInfo, filter_fixtures, idgenerator

DISCOVERY_MOCK_IP = "127.0.0.123"


class DiscoveryResponse(TypedDict):
    """Shape of a discovery response: a ``result`` payload plus an ``error_code``."""

    result: dict[str, Any]
    error_code: int


# Discovery response whose ``device_type`` is "HOMEWIFISYSTEM". It is registered
# under the "homewifi" key of UNSUPPORTED_DEVICES. Identifiers are redacted.
UNSUPPORTED_HOMEWIFISYSTEM = {
    "error_code": 0,
    "result": {
        "channel_2g": "10",
        "channel_5g": "44",
        "device_id": "REDACTED_51f72a752213a6c45203530",
        "device_model": "X20",
        "device_type": "HOMEWIFISYSTEM",
        "factory_default": False,
        "group_id": "REDACTED_07d902da02fa9beab8a64",
        "group_name": "I01BU0tFRF9TU0lEIw==",  # '#MASKED_SSID#'
        "hardware_version": "3.0",
        "ip": "192.168.1.192",
        "mac": "24:2F:D0:00:00:00",
        "master_device_id": "REDACTED_51f72a752213a6c45203530",
        "need_account_digest": True,
        "owner": "REDACTED_341c020d7e8bda184e56a90",
        "role": "master",
        "tmp_port": [20001],
    },
}


def _make_unsupported(
    device_family,
    encrypt_type,
    *,
    https: bool = False,
    omit_keys: dict[str, Any] | None = None,
) -> DiscoveryResponse:
    """Build a discovery response dict with selected keys removed.

    :param device_family: Value stored as ``device_type`` in the result.
    :param encrypt_type: Value stored both as ``mgt_encrypt_schm.encrypt_type``
        and as ``encrypt_info.sym_schm``.
    :param https: Value stored as ``mgt_encrypt_schm.is_support_https``.
    :param omit_keys: Keys to strip from the result. A ``None`` value removes
        the top-level key itself; any other value names the nested key to
        remove from the dict stored under the top-level key. When not given,
        only ``encrypt_info`` is removed.
    :return: Dict with the ``result`` payload and an ``error_code`` of 0.
    :raises KeyError: If a key listed in ``omit_keys`` is not present.
    """
    removals = {"encrypt_info": None} if omit_keys is None else omit_keys

    device_result: dict[str, Any] = {
        "device_id": "xx",
        "owner": "xx",
        "device_type": device_family,
        "device_model": "P110(EU)",
        "ip": "127.0.0.1",
        "mac": "48-22xxx",
        "is_support_iot_cloud": True,
        "obd_src": "tplink",
        "factory_default": False,
        "mgt_encrypt_schm": {
            "is_support_https": https,
            "encrypt_type": encrypt_type,
            "http_port": 80,
            "lv": 2,
        },
        "encrypt_info": {"data": "", "key": "", "sym_schm": encrypt_type},
    }

    for top_key, nested_key in removals.items():
        if nested_key is not None:
            # Strip a single entry from the nested dict
            device_result[top_key].pop(nested_key)
            continue
        # Strip the whole top-level entry
        device_result.pop(top_key)

    response: DiscoveryResponse = {"result": device_result, "error_code": 0}
    return response


# Named discovery responses used to parametrize the ``unsupported_device_info``
# fixture: the values are the fixture params and the keys are the test ids.
UNSUPPORTED_DEVICES = {
    "unknown_device_family": _make_unsupported("SMART.TAPOXMASTREE", "AES"),
    "unknown_iot_device_family": _make_unsupported("IOT.IOTXMASTREE", "AES"),
    "wrong_encryption_iot": _make_unsupported("IOT.SMARTPLUGSWITCH", "AES"),
    "wrong_encryption_smart": _make_unsupported("SMART.TAPOBULB", "IOT"),
    "unknown_encryption": _make_unsupported("IOT.SMARTPLUGSWITCH", "FOO"),
    # Response without mgt_encrypt_schm.encrypt_type and without encrypt_info
    "missing_encrypt_type": _make_unsupported(
        "SMART.TAPOBULB",
        "FOO",
        omit_keys={"mgt_encrypt_schm": "encrypt_type", "encrypt_info": None},
    ),
    # Response without device_id (encrypt_info is kept in this one)
    "unable_to_parse": _make_unsupported(
        "SMART.TAPOBULB",
        "FOO",
        omit_keys={"device_id": None},
    ),
    "invalidinstance": _make_unsupported(
        "IOT.SMARTPLUGSWITCH",
        "KLAP",
        https=True,
    ),
    "homewifi": UNSUPPORTED_HOMEWIFISYSTEM,
}


def parametrize_discovery(
    desc, *, data_root_filter=None, protocol_filter=None, model_filter=None
):
    """Return a parametrize mark feeding filtered fixtures to ``discovery_mock``.

    The filter arguments are forwarded unchanged to ``filter_fixtures``; the
    resulting fixtures are passed indirectly to the ``discovery_mock`` fixture
    and test ids are produced by ``idgenerator``.
    """
    filters = {
        "data_root_filter": data_root_filter,
        "protocol_filter": protocol_filter,
        "model_filter": model_filter,
    }
    selected = filter_fixtures(desc, **filters)
    return pytest.mark.parametrize(
        "discovery_mock", selected, indirect=True, ids=idgenerator
    )


# Mark parametrizing ``discovery_mock`` with the fixtures selected by
# ``data_root_filter="discovery_result"``.
new_discovery = parametrize_discovery(
    "new discovery", data_root_filter="discovery_result"
)

# Mark parametrizing ``discovery_mock`` with the fixtures selected by
# ``protocol_filter={"SMART"}``.
smart_discovery = parametrize_discovery("smart discovery", protocol_filter={"SMART"})


@pytest.fixture(
    params=filter_fixtures("discoverable", protocol_filter={"SMART", "IOT"}),
    ids=idgenerator,
)
async def discovery_mock(request, mocker):
    """Patch discovery for one fixture served at ``DISCOVERY_MOCK_IP``.

    Returns whatever ``patch_discovery`` returns for that single fixture.
    """
    source: FixtureInfo = request.param
    # Hand the patching code a private deep copy of the fixture data
    private_data = copy.deepcopy(source.data)
    isolated = FixtureInfo(source.name, source.protocol, private_data)
    return patch_discovery({DISCOVERY_MOCK_IP: isolated}, mocker)


def create_discovery_mock(ip: str, fixture_data: dict):
    """Build the discovery mock describing how a fixture answers discovery.

    Fixtures containing a ``discovery_result`` key are described with default
    port 80 and discovery port 20002, taking device type, encryption type,
    https support and login version from that result. All other fixtures are
    described from ``system.get_sysinfo`` with port 9999 for both ports,
    ``XOR`` encryption, no https and no login version.

    :param ip: Address stored on the returned mock.
    :param fixture_data: Complete fixture contents, stored as ``query_data``.
    :return: A ``_DiscoveryMock`` instance for the fixture.
    """

    @dataclass
    class _DiscoveryMock:
        ip: str
        default_port: int
        discovery_port: int
        discovery_data: dict
        query_data: dict
        device_type: str
        encrypt_type: str
        https: bool
        login_version: int | None = None
        port_override: int | None = None

        @property
        def model(self) -> str:
            """Return the model name with any ``(...)`` suffix removed."""
            if self.discovery_port == 20002:
                full_model = self.discovery_data["result"]["device_model"]
            else:
                full_model = self.discovery_data["system"]["get_sysinfo"]["model"]
            return full_model.split("(", 1)[0]

        @property
        def _datagram(self) -> bytes:
            """Return the discovery data serialised as datagram bytes."""
            payload = json_dumps(self.discovery_data)
            if self.default_port != 9999:
                # Fixed 16 byte header followed by the plain JSON payload
                return (
                    b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
                    + payload.encode()
                )
            # XOR encrypted payload with its first four bytes dropped
            return XorEncryption.encrypt(payload)[4:]

    if "discovery_result" not in fixture_data:
        sysinfo = fixture_data["system"]["get_sysinfo"]
        return _DiscoveryMock(
            ip=ip,
            default_port=9999,
            discovery_port=9999,
            discovery_data={"system": {"get_sysinfo": sysinfo.copy()}},
            query_data=fixture_data,
            device_type=sysinfo.get("mic_type") or sysinfo.get("type"),
            encrypt_type="XOR",
            https=False,
            login_version=None,
        )

    # Shallow copy of the response; the nested result is read from the original
    response_copy = fixture_data["discovery_result"].copy()
    result = fixture_data["discovery_result"]["result"]
    device_type = result["device_type"]

    scheme = result["mgt_encrypt_schm"]
    # Fall back to encrypt_info.sym_schm when the scheme has no encrypt_type
    sym_schm = result.get("encrypt_info", {}).get("sym_schm")
    encrypt_type = scheme.get("encrypt_type", sym_schm)

    login_version = scheme.get("lv")
    if not login_version:
        # Without "lv" use the highest entry of the result's encrypt_type list
        listed_types = result.get("encrypt_type")
        if listed_types:
            login_version = max(int(entry) for entry in listed_types)

    return _DiscoveryMock(
        ip=ip,
        default_port=80,
        discovery_port=20002,
        discovery_data=response_copy,
        query_data=fixture_data,
        device_type=device_type,
        encrypt_type=encrypt_type,
        https=scheme["is_support_https"],
        login_version=login_version,
    )


def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker):
    """Mock discovery and patch protocol queries to use Fake protocols.

    Patches ``kasa.IotProtocol.query``, ``kasa.SmartProtocol.query``,
    ``kasa.discover._DiscoverProtocol.do_discover`` and ``socket.getaddrinfo``
    through ``mocker``.

    :param fixture_infos: Mapping of ip address to the fixture served there.
    :param mocker: Object providing ``patch`` used to install the mocks.
    :return: The discovery mock created for the first entry of
        ``fixture_infos``, to be used for testing discover single.
    :raises IndexError: If ``fixture_infos`` is empty.
    """

    def _fake_protocol(fixture_info: FixtureInfo):
        """Create the fake protocol matching the fixture's protocol name."""
        if fixture_info.protocol in {"SMART", "SMART.CHILD"}:
            return FakeSmartProtocol(fixture_info.data, fixture_info.name)
        if fixture_info.protocol in {"SMARTCAM", "SMARTCAM.CHILD"}:
            return FakeSmartCamProtocol(fixture_info.data, fixture_info.name)
        return FakeIotProtocol(fixture_info.data, fixture_info.name)

    discovery_mocks = {
        ip: create_discovery_mock(ip, fixture_info.data)
        for ip, fixture_info in fixture_infos.items()
    }
    protos = {
        ip: _fake_protocol(fixture_info) for ip, fixture_info in fixture_infos.items()
    }
    # The mock object of the first fixture never changes, but a test may
    # modify its ``ip`` attribute, so the attribute is always read late.
    first_mock = list(discovery_mocks.values())[0]
    first_host = None

    async def mock_discover(self):
        """Call datagram_received for all mock fixtures.

        Handles test cases modifying the ip and hostname of the first fixture
        for discover_single testing.
        """
        # Ip of first fixture could have been modified by a test
        first_ip = first_mock.ip
        for ip, dm in discovery_mocks.items():
            # A hostname could have been used to reach the first fixture
            host = first_host if first_host and dm.ip == first_ip else dm.ip
            # update the protos for any host testing or the test overriding the first ip
            protos[host] = _fake_protocol(fixture_infos[ip])
            # A port override is only applied when not discovering on 20002
            port = (
                dm.port_override
                if dm.port_override and dm.discovery_port != 20002
                else dm.discovery_port
            )
            self.datagram_received(
                dm._datagram,
                (dm.ip, port),
            )

    async def _query(self, request, retry_count: int = 3):
        return await protos[self._host].query(request)

    def _getaddrinfo(host, *_, **__):
        nonlocal first_host
        first_host = host  # Store the hostname used by discover single
        # ip could have been overridden in test
        return [(None, None, None, None, (first_mock.ip, 0))]

    mocker.patch("kasa.IotProtocol.query", _query)
    mocker.patch("kasa.SmartProtocol.query", _query)
    mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)
    mocker.patch("socket.getaddrinfo", side_effect=_getaddrinfo)
    # Only return the first discovery mock to be used for testing discover single
    return first_mock


@pytest.fixture(
    params=filter_fixtures("discoverable", protocol_filter={"SMART", "IOT"}),
    ids=idgenerator,
)
def discovery_data(request, mocker):
    """Return the discovery payload of a fixture. Used for discovery tests.

    ``kasa.IotProtocol.query`` and ``kasa.SmartProtocol.query`` are patched to
    return a deep copy of the fixture data. The returned payload is a copy of
    the fixture's ``discovery_result`` when it has one, otherwise a dict
    wrapping its ``system.get_sysinfo``.
    """
    data = copy.deepcopy(request.param.data)

    if "component_nego" in data:
        # Add missing queries to fixture data
        negotiated = {
            entry["id"]: int(entry["ver_code"])
            for entry in data["component_nego"]["component_list"]
        }
        for query, missing in FakeSmartTransport.FIXTURE_MISSING_MAP.items():
            # Value is a tuple of component, response
            if query in data or missing[0] not in negotiated:
                continue
            data[query] = missing[1]

    for target in ("kasa.IotProtocol.query", "kasa.SmartProtocol.query"):
        mocker.patch(target, return_value=data)

    if "discovery_result" not in data:
        return {"system": {"get_sysinfo": data["system"]["get_sysinfo"]}}
    return data["discovery_result"].copy()


@pytest.fixture(
    params=UNSUPPORTED_DEVICES.values(), ids=list(UNSUPPORTED_DEVICES.keys())
)
def unsupported_device_info(request, mocker):
    """Return unsupported devices for cli and discovery tests.

    ``kasa.discover._DiscoverProtocol.do_discover`` is patched so that the
    parametrized response is delivered through ``datagram_received`` as
    coming from 127.0.0.1 on port 20002.
    """
    response = request.param
    sender = ("127.0.0.1", 20002)
    header = b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"

    async def mock_discover(self):
        if not response:
            return
        # Serialise at call time so the current content of the response is sent
        self.datagram_received(header + json_dumps(response).encode(), sender)

    mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)

    return response