"""Auth bridge generado por spa-cli — corre Lambda Authorizers como middleware FastAPI.

Lee `openapi.json` (mismo dir), indexa `(method, path) -> [scheme_name]`,
y para cada request invoca el `lambda_handler` real del authorizer.

NO EDITAR A MANO. Regenerado por `spa project build --build-mode container`.
"""
from __future__ import annotations

import hmac
import importlib
import json
import os
import re
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple

from fastapi import Request
from fastapi.responses import JSONResponse


# Both files are resolved relative to this module, not the current working directory.
_BRIDGE_DIR = Path(__file__).parent
_OPENAPI_PATH = _BRIDGE_DIR / "openapi.json"
_AUTHORIZER_CONFIG_PATH = _BRIDGE_DIR / "auth_bridge.config.json"

# Optional per-authorizer overrides ({"<key>": {"module": ..., "handler": ...}}), keyed by
# the scheme name with "_authorizer" removed. A missing file yields an empty registry, so
# handlers are looked up only at the conventional module paths, instead of crashing the
# whole app at import time (consistent with a missing openapi.json simply disabling
# enforcement). A present-but-malformed file still fails loudly.
_AUTHORIZER_REGISTRY: Dict[str, Dict[str, Any]] = (
    json.loads(_AUTHORIZER_CONFIG_PATH.read_text(encoding="utf-8"))
    if _AUTHORIZER_CONFIG_PATH.exists()
    else {}
)


class _MockContext:
    """Local stand-in for the AWS Lambda ``context`` object handed to authorizers.

    Exposes the attributes of the real Lambda Python context (plus
    ``get_remaining_time_in_millis()``) so authorizer code that reads them does not
    fail outside Lambda. Every ARN/ID here is a fake, local-only value.
    """

    def __init__(self, function_name: str = "spa-authorizer", timeout_ms: int = 30_000):
        """Create a context for *function_name* with a simulated budget of *timeout_ms*."""
        self.function_name = function_name
        self.function_version = "$LATEST"
        self.memory_limit_in_mb = 128
        self.invoked_function_arn = (
            f"arn:aws:lambda:us-east-1:123456789012:function:{function_name}"
        )
        self.aws_request_id = str(uuid.uuid4())
        self.log_group_name = f"/aws/lambda/{function_name}"
        self.log_stream_name = "local"
        # Monotonic clock so wall-clock adjustments do not distort the countdown.
        self._deadline = time.monotonic() + timeout_ms / 1000.0

    def get_remaining_time_in_millis(self) -> int:
        """Return the milliseconds left before the simulated deadline (never negative)."""
        return max(0, int((self._deadline - time.monotonic()) * 1000))


def _path_to_regex(openapi_path: str) -> re.Pattern:
    """Compile an OpenAPI path template into a regex that matches concrete request paths.

    - Literal text is escaped, so the ``.`` in ``/files/{name}.json`` only matches a dot.
    - ``{param}`` matches exactly one non-empty path segment.
    - The API Gateway greedy form ``{param+}`` matches one or more segments, slashes
      included. Without this, nested paths under a ``{proxy+}`` route would not match
      and would skip authorization.
    - One trailing slash on the request path is tolerated.
    """
    pieces: List[str] = []
    cursor = 0
    for template in re.finditer(r"\{([^/{}]+)\}", openapi_path):
        pieces.append(re.escape(openapi_path[cursor:template.start()]))
        pieces.append(".+" if template.group(1).endswith("+") else "[^/]+")
        cursor = template.end()
    pieces.append(re.escape(openapi_path[cursor:]))
    return re.compile("^" + "".join(pieces) + "/?$")


def _build_security_index(openapi: Dict[str, Any]) -> List[Tuple[re.Pattern, str, List[str]]]:
    """Index every OpenAPI operation as ``(path_regex, METHOD, [scheme_name, ...])``.

    - An operation's own ``security`` overrides the document-level default. This
      includes an explicit ``security: []``, which marks the operation public.
    - Public operations are indexed with an empty scheme list. A first-match lookup
      then stops at them and lets the request through, rather than falling through
      to a broader templated path.
    - ``x-amazon-apigateway-any-method`` is expanded to every HTTP method.
    - Non-operation keys of a path item (``parameters``, ``summary``, extensions)
      are skipped.
    - Entries are ordered concrete paths first, then templated, then greedy
      ``{param+}`` paths; explicit methods come before ANY. The first match is
      therefore the most specific one, as OpenAPI and API Gateway routing require.
    - Scheme names from all requirement objects are flattened into one list, in order.
    """
    http_methods = ("get", "put", "post", "delete", "options", "head", "patch", "trace")
    default_security = openapi.get("security") or []
    paths = openapi.get("paths", {}) or {}

    # (sort_key, path_regex, METHOD, scheme_names)
    entries: List[Tuple[Tuple[int, int, int], re.Pattern, str, List[str]]] = []
    for path, path_item in paths.items():
        if not isinstance(path_item, dict):
            continue
        path_regex = _path_to_regex(path)
        templates = re.findall(r"\{[^/]+?\}", path)
        greedy_count = sum(1 for t in templates if t.endswith("+}"))
        for method_key, op in path_item.items():
            if not isinstance(op, dict):
                continue
            method_name = str(method_key).lower()
            if method_name == "x-amazon-apigateway-any-method":
                methods = [m.upper() for m in http_methods]
                is_any = 1
            elif method_name in http_methods:
                methods = [method_name.upper()]
                is_any = 0
            else:
                continue
            # Only a missing (or null) operation ``security`` inherits the global one.
            security = op.get("security")
            if security is None:
                security = default_security
            scheme_names = [
                scheme
                for requirement in security
                if isinstance(requirement, dict)
                for scheme in requirement
            ]
            sort_key = (greedy_count, len(templates), is_any)
            for method in methods:
                entries.append((sort_key, path_regex, method, scheme_names))

    # Stable sort keeps document order among equally specific entries.
    entries.sort(key=lambda entry: entry[0])
    return [(rx, verb, names) for _, rx, verb, names in entries]


def _resolve_handler(scheme_name: str, openapi: Dict[str, Any]) -> Optional[Callable]:
    """Find the callable that implements the authorizer for *scheme_name*.

    Returns None when the scheme is not declared in the spec, or when no candidate
    module can be imported and exposes a callable. Candidates are tried in order:

      1. ``module``/``handler`` from the registry entry (handler defaults to
         ``lambda_handler``), if the entry names a module;
      2. ``src.authorizers.<key>.handler``;
      3. ``build.infra.components.authorizers.<key>.lambda_function``;
      4. ``infra.components.authorizers.<key>.lambda_function``.

    Here ``<key>`` is *scheme_name* with every ``_authorizer`` substring removed.
    A candidate that raises ImportError is skipped.
    """
    declared = (
        openapi.get("components", {}).get("securitySchemes")
        or openapi.get("securityDefinitions")
        or {}
    )
    if not declared.get(scheme_name):
        return None

    key = scheme_name.replace("_authorizer", "")
    registry_entry = _AUTHORIZER_REGISTRY.get(key, {})

    configured = (
        [(registry_entry["module"], registry_entry.get("handler") or "lambda_handler")]
        if registry_entry.get("module")
        else []
    )
    conventional = [
        (f"src.authorizers.{key}.handler", "lambda_handler"),
        (f"build.infra.components.authorizers.{key}.lambda_function", "lambda_handler"),
        (f"infra.components.authorizers.{key}.lambda_function", "lambda_handler"),
    ]

    for module_path, attr_name in configured + conventional:
        try:
            # Attribute lookup stays inside the try: a module-level __getattr__ may raise ImportError.
            candidate = getattr(importlib.import_module(module_path), attr_name, None)
            if callable(candidate):
                return candidate
        except ImportError:
            continue
    return None


def _extract_token(scheme_cfg: Dict[str, Any], request: Request) -> Optional[str]:
    """Return the credential that *scheme_cfg* expects from *request*, or None.

    Lookup order:
      1. An explicit OpenAPI ``in`` of ``header``/``query``/``cookie`` reads the
         parameter called ``name`` (default ``Authorization``).
      2. Otherwise, the first entry of the API Gateway ``identitySource`` is used.
         Both REST (``method.request.header.X``) and HTTP API
         (``$request.header.X``) notations are accepted, for headers and query strings.
      3. If ``in`` was absent and no identity source applies, the header called
         ``name`` is used. For any other ``in`` value the result is None.
    """
    name = scheme_cfg.get("name") or "Authorization"
    where = (scheme_cfg.get("in") or "").lower()
    if where == "header":
        return request.headers.get(name)
    if where == "query":
        return request.query_params.get(name)
    if where == "cookie":
        return request.cookies.get(name)

    apigw = scheme_cfg.get("x-amazon-apigateway-authorizer") or {}
    identity = apigw.get("identitySource") or ""
    if isinstance(identity, (list, tuple)):
        identity = ",".join(str(item) for item in identity)
    # Several comma-separated sources may be listed; the first one carries the token.
    source = str(identity).split(",")[0].strip()
    for prefix in ("method.request.", "$request."):
        if source.startswith(prefix):
            location, _, param = source[len(prefix):].partition(".")
            if location == "header":
                return request.headers.get(param)
            if location == "querystring":
                return request.query_params.get(param)
            break

    return request.headers.get(name) if not where else None


def _build_apigw_event(request: Request, raw_token: str, path: str) -> Dict[str, Any]:
    """Build an API Gateway-style authorizer event for *request*.

    The event carries the TOKEN-authorizer fields (``type``, ``authorizationToken``,
    ``methodArn``). It also carries the request details a REQUEST-style authorizer
    reads (``headers``, ``queryStringParameters``, ``requestContext``). Account,
    region, API id and stage are fixed fake values.

    Args:
        request: incoming request; its method, headers and query string are copied.
        raw_token: credential extracted for the matched security scheme.
        path: request path with any environment prefix already stripped.
    """
    http_method = request.method.upper()
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    return {
        "type": "TOKEN",
        "authorizationToken": raw_token,
        "methodArn": (
            "arn:aws:execute-api:us-east-1:123456789012:apiid/$default/"
            f"{http_method}{path}"
        ),
        "headers": dict(request.headers),
        "queryStringParameters": dict(request.query_params),
        "httpMethod": http_method,
        "path": path,
        "requestContext": {
            "accountId": "123456789012",
            "apiId": "apiid",
            "httpMethod": http_method,
            "path": path,
            "stage": "$default",
            "requestId": str(uuid.uuid4()),
            "time": now.strftime("%d/%b/%Y:%H:%M:%S +0000"),
        },
    }


def _is_allow(policy: Dict[str, Any]) -> bool:
    """Return True when the authorizer's IAM policy grants access.

    Follows IAM's effect evaluation. Access needs at least one ``Allow`` statement,
    and any explicit ``Deny`` overrides all Allows. A single statement object
    (instead of a list) is accepted, since IAM permits both forms. Non-dict
    entries are ignored.

    NOTE(review): statement ``Resource`` values are not matched against the
    request's method ARN. An Allow scoped to a different resource still counts here.
    """
    document = policy.get("policyDocument") or {}
    statements = document.get("Statement") or []
    if isinstance(statements, dict):
        statements = [statements]
    effects = {s.get("Effect") for s in statements if isinstance(s, dict)}
    return "Allow" in effects and "Deny" not in effects


def build_security_middleware() -> Callable[[Request, Callable[[Request], Awaitable[Any]]], Awaitable[Any]]:
    """Build the HTTP middleware that enforces the security declared in ``openapi.json``.

    Returns a pass-through middleware when ``AUTH_DISABLED`` is ``1``/``true``/``yes``,
    or when ``openapi.json`` does not exist next to this module.

    Otherwise each request is matched (method + path, minus a leading ``/<ENVIRONMENT>``
    prefix) against the security index. For each scheme of the first matching entry:

      * ``api_key`` is checked against the ``API_KEY`` env var (skipped when unset);
      * any other scheme runs its authorizer handler with an API Gateway-style event
        and requires the returned policy to allow access.

    Failure responses:
      * 401 ``Unauthorized``: no token was found, or the handler raised an exception
        whose message is ``"Unauthorized"``.
      * 422 ``MalformedToken``: the handler raised an exception with that message.
      * 403: every other rejection.

    After the first authorizer that allows the request, the remaining schemes are
    not evaluated. Its result is stored in ``request.state.authorizer``.
    """
    if os.getenv("AUTH_DISABLED", "").lower() in ("1", "true", "yes"):
        async def passthrough(request: Request, call_next):
            return await call_next(request)
        return passthrough

    if not _OPENAPI_PATH.exists():
        # Without a spec there is nothing to enforce.
        async def noop(request: Request, call_next):
            return await call_next(request)
        return noop

    openapi = json.loads(_OPENAPI_PATH.read_text(encoding="utf-8"))
    security_index = _build_security_index(openapi)
    schemes_def = (
        openapi.get("components", {}).get("securitySchemes")
        or openapi.get("securityDefinitions")
        or {}
    )

    # Memoized per scheme, None included, so a missing authorizer is looked up only once.
    handler_cache: Dict[str, Optional[Callable]] = {}

    def _get_handler(name: str) -> Optional[Callable]:
        if name not in handler_cache:
            handler_cache[name] = _resolve_handler(name, openapi)
        return handler_cache[name]

    async def middleware(request: Request, call_next):
        path = request.url.path
        http_method = request.method.upper()
        # Spec paths carry no environment prefix, so strip "/<ENVIRONMENT>" before matching.
        env_prefix = "/" + os.getenv("ENVIRONMENT", "dev").lower()
        match_path = path[len(env_prefix):] if path.startswith(env_prefix + "/") else path

        # The first index entry matching method and path decides which schemes apply.
        matched_schemes: List[str] = []
        for path_regex, indexed_method, scheme_names in security_index:
            if indexed_method == http_method and path_regex.match(match_path):
                matched_schemes = scheme_names
                break

        if not matched_schemes:
            return await call_next(request)

        for scheme_name in matched_schemes:
            scheme_cfg = schemes_def.get(scheme_name)
            if not scheme_cfg:
                continue

            # api_key scheme: validated against the API_KEY env var (API Gateway native, no Lambda authorizer).
            if scheme_name == "api_key":
                expected_key = os.getenv("API_KEY", "")
                if not expected_key:
                    continue  # no API_KEY configured -> skip validation
                actual_key = request.headers.get(scheme_cfg.get("name", "x-api-key"), "")
                # Constant-time comparison; a plain != can leak the matching prefix through timing.
                # Bytes avoid compare_digest's TypeError on non-ASCII str input.
                if not hmac.compare_digest(actual_key.encode("utf-8"), expected_key.encode("utf-8")):
                    return JSONResponse(status_code=403, content={"message": "Invalid API key"})
                continue

            handler = _get_handler(scheme_name)
            if handler is None:
                # Fail closed when no authorizer implementation can be found.
                return JSONResponse(status_code=403, content={"message": "Forbidden"})
            token = _extract_token(scheme_cfg, request)
            if not token:
                return JSONResponse(status_code=401, content={"message": "Unauthorized"})

            event = _build_apigw_event(request, token, match_path)
            # NOTE(review): the handler runs synchronously on the event loop; a slow
            # authorizer (e.g. remote key fetches) blocks other requests meanwhile.
            try:
                policy = handler(event, _MockContext())
            except Exception as exc:
                # The exception message selects the HTTP status.
                msg = str(exc)
                if msg == "Unauthorized":
                    return JSONResponse(status_code=401, content={"message": "Unauthorized"})
                if msg == "MalformedToken":
                    return JSONResponse(status_code=422, content={"message": "MalformedToken"})
                return JSONResponse(status_code=403, content={"message": "Forbidden"})

            if not isinstance(policy, dict) or not _is_allow(policy):
                return JSONResponse(status_code=403, content={"message": "Forbidden"})

            request.state.authorizer = {
                "principalId": policy.get("principalId"),
                "context": policy.get("context", {}),
                "scheme": scheme_name,
            }
            break

        return await call_next(request)

    return middleware
