#!/usr/bin/env python3
"""CDP multiplexer — per-connection fingerprint seeds for stealth Chromium.

Spawns a separate Chrome process per unique fingerprint seed, routing CDP
connections through a single port. Each seed gets its own browser identity.

Usage:
    cloakserve                                      # default, backward compat
    cloakserve --port=9222                           # custom port

Client:
    browser = pw.chromium.connect_over_cdp("http://host:9222?fingerprint=12345")
    browser = pw.chromium.connect_over_cdp(
        "http://host:9222?fingerprint=12345&timezone=America/New_York&locale=en-US"
    )
"""

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import os
import random
import re
import shutil
import socket
import subprocess
import sys
import time
from dataclasses import dataclass
from urllib.parse import parse_qs, quote

from pathlib import Path

import aiohttp
import websockets
from aiohttp import web

from cloakbrowser.browser import build_args, maybe_resolve_geoip, _resolve_webrtc_args, _normalize_socks_string_url
from cloakbrowser.download import ensure_binary

# Configure the root logger when the module is loaded: INFO level, with each
# line rendered as "HH:MM:SS LEVEL message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)
# Module-wide logger used by the pool, the HTTP handlers and the WS proxy.
logger = logging.getLogger("cloakserve")

# Args for running Chrome directly (outside Playwright).
# Playwright normally adds its own version of these.
# They are placed right after the binary path in every Chrome command line
# that ChromePool builds.
BASE_CHROME_ARGS = [
    "--no-first-run",
    "--no-default-browser-check",
    "--disable-dev-shm-usage",
    "--disable-extensions",
    "--disable-popup-blocking",
    "--disable-background-networking",
    "--metrics-recording-only",
    "--ignore-gpu-blocklist",
]

# First port ChromePool tries when it needs a CDP port for a new Chrome
# process; later allocations count upward from here.
BASE_CDP_PORT = 5100


# ---------------------------------------------------------------------------
# ChromeProcess — one running Chrome instance
# ---------------------------------------------------------------------------

@dataclass
class ChromeProcess:
    """Bookkeeping record for one Chrome instance launched by ChromePool."""

    # Fingerprint seed handed to Chrome via --fingerprint=<seed>. For the
    # shared default process this is a randomly generated number, not the
    # pool key.
    seed: str
    # Handle of the spawned Chrome process; poll() is used to test liveness.
    process: subprocess.Popen
    # Loopback port given to Chrome as --remote-debugging-port.
    cdp_port: int
    # Profile directory given to Chrome as --user-data-dir; removed on cleanup.
    user_data_dir: str
    # Launch-time settings, recorded for status reporting and for the
    # "first-launch wins" warning. None means the value was not set.
    timezone: str | None = None
    locale: str | None = None
    proxy: str | None = None


# ---------------------------------------------------------------------------
# ChromePool — manages multiple Chrome processes keyed by seed
# ---------------------------------------------------------------------------

class ChromePool:
    """Launches, caches and tears down Chrome processes keyed by seed.

    Each distinct seed key maps to at most one live ``ChromeProcess``.
    Requests without a seed share a single process stored under the
    ``"__default__"`` key. Launches for the same key are serialized with a
    per-key ``asyncio.Lock``.
    """

    # Seeds matching this pattern are used verbatim as profile directory
    # names; anything else is replaced by a digest (see _profile_dir).
    _SAFE_DIR_NAME = re.compile(r"[A-Za-z0-9_.-]{1,64}")

    def __init__(
        self,
        binary: str,
        global_args: list[str],
        headless: bool,
        data_dir: str = "/tmp/cloakserve",
        default_seed: str | None = None,
        default_locale: str | None = None,
        default_timezone: str | None = None,
    ):
        """Store launch configuration; no process is started here.

        Args:
            binary: Path of the Chrome executable to spawn.
            global_args: Extra Chrome flags appended to every launch.
            headless: Forwarded to ``build_args`` for every launch.
            data_dir: Parent directory of the per-seed profile directories.
            default_seed: Seed used when a request supplies none.
            default_locale: Locale used when a request supplies none.
            default_timezone: Timezone used when a request supplies none.
        """
        self._binary = binary
        self._global_args = global_args
        self._headless = headless
        self._data_dir = data_dir
        self._default_seed = default_seed
        self._default_locale = default_locale
        self._default_timezone = default_timezone
        self._processes: dict[str, ChromeProcess] = {}
        self._default: ChromeProcess | None = None
        self._locks: dict[str, asyncio.Lock] = {}
        self._next_port = BASE_CDP_PORT
        # Connection refcounting for status reporting
        self._connections: dict[str, int] = {}

    def _get_lock(self, seed: str) -> asyncio.Lock:
        """Return the launch lock for ``seed``, creating it on first use."""
        if seed not in self._locks:
            self._locks[seed] = asyncio.Lock()
        return self._locks[seed]

    def _allocate_port(self) -> int:
        """Find a free port starting from _next_port.

        Tries up to 100 consecutive ports by binding each on 127.0.0.1.
        The probe socket is closed again before returning, so the port is
        only known to have been free at probe time.

        Raises:
            RuntimeError: If none of the 100 candidates could be bound.
        """
        for _ in range(100):
            port = self._next_port
            self._next_port += 1
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("127.0.0.1", port))
                return port
            except OSError:
                continue
        raise RuntimeError("No free ports available for Chrome CDP")

    def _profile_dir(self, seed_key: str) -> str:
        """Return the profile directory for ``seed_key``, always inside data_dir.

        Seed keys arrive from query strings and URL paths, and the resulting
        directory is later removed with ``shutil.rmtree``. A key such as
        ``../../x`` or ``/home/user`` must therefore never be joined into the
        path as-is. Keys made only of ``[A-Za-z0-9_.-]`` (max 64 chars, not
        ``.``/``..``) keep their name; every other key maps to ``~`` followed
        by a SHA-256 prefix. ``~`` is outside the safe alphabet, so a hashed
        name cannot collide with a verbatim one.
        """
        if seed_key not in (".", "..") and self._SAFE_DIR_NAME.fullmatch(seed_key):
            name = seed_key
        else:
            digest = hashlib.sha256(seed_key.encode("utf-8", "surrogatepass"))
            name = "~" + digest.hexdigest()[:32]
        return os.path.join(self._data_dir, name)

    def connect(self, seed_key: str) -> None:
        """Increment connection refcount for a seed."""
        self._connections[seed_key] = self._connections.get(seed_key, 0) + 1

    def disconnect(self, seed_key: str) -> None:
        """Decrement connection refcount for a seed, dropping the entry at zero."""
        count = self._connections.get(seed_key, 0) - 1
        if count <= 0:
            self._connections.pop(seed_key, None)
        else:
            self._connections[seed_key] = count

    async def get_or_launch(
        self,
        seed: str | None,
        extra_args: list[str] | None = None,
        timezone: str | None = None,
        locale: str | None = None,
        proxy: str | None = None,
        geoip: bool = False,
    ) -> ChromeProcess:
        """Get existing or launch new Chrome process for a seed.

        Args:
            seed: Fingerprint seed, or None for the CLI default seed / the
                shared default process.
            extra_args: Additional Chrome flags for a fresh launch.
            timezone: Timezone for a fresh launch (CLI default if None).
            locale: Locale for a fresh launch (CLI default if None).
            proxy: Proxy URL for a fresh launch.
            geoip: When true and a proxy is given, timezone/locale/exit IP
                are resolved through ``maybe_resolve_geoip``.

        Returns:
            The live ``ChromeProcess`` for the seed. If one is already
            running, it is returned unchanged and the launch parameters of
            this call are ignored (first launch wins).

        Raises:
            web.HTTPBadGateway: Chrome did not expose CDP in time.
            RuntimeError: No free CDP port could be allocated.
        """
        # Captured before CLI defaults are merged in, so the "ignoring new
        # params" warning only fires when the request itself carried params.
        request_has_params = any([extra_args, timezone, locale, proxy, geoip])

        # Apply CLI defaults when query params don't provide values
        if seed is None and self._default_seed:
            seed = self._default_seed
        if locale is None:
            locale = self._default_locale
        if timezone is None:
            timezone = self._default_timezone

        # No seed = default shared process
        if seed is None:
            seed_key = "__default__"
            actual_seed = str(random.randint(10000, 99999))
        else:
            seed_key = seed
            actual_seed = seed

        async with self._get_lock(seed_key):
            # Check if already running (including default fast-path)
            existing = self._processes.get(seed_key)
            if existing is not None:
                if existing.process.poll() is None:
                    if request_has_params:
                        logger.warning(
                            "Seed %s already running (port %d, tz=%s, locale=%s, proxy=%s) — "
                            "ignoring new params (first-launch wins)",
                            seed_key, existing.cdp_port,
                            existing.timezone, existing.locale, existing.proxy,
                        )
                    return existing
                # Dead — clean up. The lock must stay registered: we are
                # holding it, and dropping it would let a concurrent request
                # for the same seed create a fresh lock and launch in parallel.
                await self._cleanup_process(seed_key, keep_lock=True)

            # Resolve geoip if requested. The helper is synchronous, so it is
            # run in a worker thread; presumably it performs a network lookup
            # through the proxy, which would otherwise stall the event loop.
            exit_ip = None
            if geoip and proxy:
                timezone, locale, exit_ip = await asyncio.to_thread(
                    maybe_resolve_geoip, True, proxy, timezone, locale,
                )

            # Build Chrome args via shared logic
            fp_extra = [f"--fingerprint={actual_seed}"]
            if extra_args:
                fp_extra.extend(extra_args)
            if proxy:
                fp_extra.append(f"--proxy-server={_normalize_socks_string_url(proxy)}")

            # WebRTC IP spoofing: resolve auto, inject geoip exit IP
            fp_extra = _resolve_webrtc_args(fp_extra, proxy)
            if exit_ip and not any(a.startswith("--fingerprint-webrtc-ip") for a in (fp_extra or [])):
                fp_extra = list(fp_extra or [])
                fp_extra.append(f"--fingerprint-webrtc-ip={exit_ip}")

            chrome_args = build_args(
                stealth_args=True,
                extra_args=fp_extra,
                timezone=timezone,
                locale=locale,
                headless=self._headless,
            )

            # Allocate port and user data dir (confined to data_dir)
            port = self._allocate_port()
            user_data_dir = self._profile_dir(seed_key)
            os.makedirs(user_data_dir, exist_ok=True)

            full_args = (
                [self._binary]
                + BASE_CHROME_ARGS
                + chrome_args
                + self._global_args
                + [
                    f"--remote-debugging-port={port}",
                    "--remote-debugging-address=127.0.0.1",
                    f"--user-data-dir={user_data_dir}",
                ]
            )

            logger.info("Launching Chrome (seed=%s, port=%d)", actual_seed, port)
            process = subprocess.Popen(
                full_args,
                stdout=subprocess.DEVNULL,
            )

            # Wait for CDP to be ready
            if not await self._wait_for_cdp(port, process=process):
                process.kill()
                try:
                    await asyncio.to_thread(process.wait, timeout=5)
                except subprocess.TimeoutExpired:
                    logger.warning("Chrome (pid=%d) did not exit after kill", process.pid)
                await asyncio.to_thread(shutil.rmtree, user_data_dir, True)
                raise web.HTTPBadGateway(
                    text=json.dumps({"error": "Chrome failed to start"}),
                    content_type="application/json",
                )

            cp = ChromeProcess(
                seed=actual_seed,
                process=process,
                cdp_port=port,
                user_data_dir=user_data_dir,
                timezone=timezone,
                locale=locale,
                proxy=proxy,
            )
            self._processes[seed_key] = cp

            if seed is None:
                self._default = cp

            logger.info("Chrome ready (seed=%s, port=%d, pid=%d)", actual_seed, port, process.pid)
            return cp

    async def _cleanup_process(self, key: str, keep_lock: bool = False) -> None:
        """Terminate a Chrome process and clean up.

        Args:
            key: Pool key of the process to remove; unknown keys are a no-op.
            keep_lock: Leave the per-key lock registered. Callers that are
                currently holding that lock must pass True.
        """
        proc = self._processes.pop(key, None)
        if proc is None:
            return
        if proc.process.poll() is None:
            proc.process.terminate()
            try:
                await asyncio.to_thread(proc.process.wait, timeout=5)
            except subprocess.TimeoutExpired:
                proc.process.kill()
                # Reap the killed process so it does not linger as a zombie
                # and is no longer writing to the profile we are about to delete.
                try:
                    await asyncio.to_thread(proc.process.wait, timeout=5)
                except subprocess.TimeoutExpired:
                    logger.warning(
                        "Chrome (pid=%d) did not exit after kill", proc.process.pid,
                    )
        # Clean up user data dir (can be slow for large profiles)
        await asyncio.to_thread(shutil.rmtree, proc.user_data_dir, True)
        if self._default is proc:
            self._default = None
        if not keep_lock:
            self._locks.pop(key, None)
        self._connections.pop(key, None)

    async def shutdown(self) -> None:
        """Terminate all Chrome processes."""
        for key in list(self._processes.keys()):
            await self._cleanup_process(key)
        logger.info("All Chrome processes terminated")

    @staticmethod
    async def _wait_for_cdp(
        port: int,
        timeout: float = 10.0,
        process: subprocess.Popen | None = None,
    ) -> bool:
        """Poll Chrome's /json/version until ready.

        Args:
            port: Loopback CDP port to probe.
            timeout: Overall deadline in seconds.
            process: Optional handle of the Chrome process; when given,
                polling stops as soon as the process has exited.

        Returns:
            True once the endpoint answers with HTTP 200, False on timeout
            or early process exit.
        """
        deadline = time.monotonic() + timeout
        delay = 0.1
        url = f"http://127.0.0.1:{port}/json/version"
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=1)
        ) as session:
            while time.monotonic() < deadline:
                if process is not None and process.poll() is not None:
                    logger.error(
                        "Chrome exited with code %s before CDP became ready",
                        process.returncode,
                    )
                    return False
                try:
                    async with session.get(url) as resp:
                        if resp.status == 200:
                            return True
                except (aiohttp.ClientError, asyncio.TimeoutError, OSError):
                    # Not listening yet — retry after the back-off below.
                    pass
                await asyncio.sleep(delay)
                delay = min(delay * 2, 1.0)
        return False


# ---------------------------------------------------------------------------
# Query param parsing
# ---------------------------------------------------------------------------

# Query keys with dedicated handling; every other key is forwarded to Chrome
# as a generic --fingerprint-{name}= flag.
SPECIAL_PARAMS = {"fingerprint", "proxy", "geoip", "locale", "timezone"}


def parse_connection_params(query_string: str) -> dict:
    """Turn a raw URL query string into a connection config dict.

    Only the first value of a repeated key is used, and blank values are
    dropped by ``parse_qs``. The result always has the keys ``seed``,
    ``timezone``, ``locale``, ``proxy`` (str or None), ``geoip`` (bool) and
    ``extra_args`` (list of ``--fingerprint-<key>=<value>`` flags).
    """
    config: dict = {
        "seed": None,
        "timezone": None,
        "locale": None,
        "proxy": None,
        "geoip": False,
        "extra_args": [],
    }
    # Query key -> config key for values that are copied through unchanged.
    copied_as = {
        "fingerprint": "seed",
        "timezone": "timezone",
        "locale": "locale",
        "proxy": "proxy",
    }
    truthy = ("true", "1", "yes")

    parsed = parse_qs(query_string, keep_blank_values=False)
    for name, candidates in parsed.items():
        first = candidates[0]
        target = copied_as.get(name)
        if target is not None:
            config[target] = first
            continue
        if name == "geoip":
            config["geoip"] = first.lower() in truthy
            continue
        if name not in SPECIAL_PARAMS:
            # Generic fingerprint param: map to --fingerprint-{key}={val}
            config["extra_args"].append(f"--fingerprint-{name}={first}")

    return config


# ---------------------------------------------------------------------------
# HTTP handlers
# ---------------------------------------------------------------------------

def _ws_scheme(request: web.Request) -> str:
    """Return 'wss' if client connected via HTTPS (e.g. TLS-terminating proxy), else 'ws'.

    The ``X-Forwarded-Proto`` header takes precedence over the request's own
    scheme. The header is matched case-insensitively, and when it carries a
    comma-separated list (one entry per proxy hop) the first entry — the
    protocol used by the original client — decides. A missing or empty
    header falls back to ``request.scheme``.
    """
    forwarded = request.headers.get("X-Forwarded-Proto")
    proto = forwarded.split(",", 1)[0] if forwarded else request.scheme
    return "wss" if proto.strip().lower() in ("https", "wss") else "ws"


async def handle_root(request: web.Request) -> web.Response:
    """Report service health plus one status entry per live Chrome process.

    Processes whose ``poll()`` shows they have exited are left out of the
    listing; they are not cleaned up here.
    """
    pool: ChromePool = request.app["pool"]

    def describe(key: str, cp: ChromeProcess) -> dict:
        # Status record for a single pool entry.
        return {
            "pid": cp.process.pid,
            "port": cp.cdp_port,
            "seed": cp.seed,
            "connections": pool._connections.get(key, 0),
            "timezone": cp.timezone,
            "locale": cp.locale,
            "proxy": cp.proxy,
        }

    alive = {
        key: describe(key, cp)
        for key, cp in pool._processes.items()
        if cp.process.poll() is None
    }
    payload = {"status": "ok", "active": len(alive), "processes": alive}
    return web.json_response(payload)


async def handle_json_version(request: web.Request) -> web.Response:
    """Proxy /json/version with optional per-seed routing.

    Launches (or reuses) the Chrome process selected by the query params,
    fetches its ``/json/version`` document and rewrites
    ``webSocketDebuggerUrl`` so the client connects back through this
    multiplexer instead of Chrome's loopback port.

    Returns a 502 JSON error if Chrome's CDP endpoint cannot be reached.
    """
    pool: ChromePool = request.app["pool"]
    params = parse_connection_params(request.query_string)

    cp = await pool.get_or_launch(
        seed=params["seed"],
        extra_args=params["extra_args"] or None,
        timezone=params["timezone"],
        locale=params["locale"],
        proxy=params["proxy"],
        geoip=params["geoip"],
    )

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"http://127.0.0.1:{cp.cdp_port}/json/version",
                timeout=aiohttp.ClientTimeout(total=5),
            ) as resp:
                data = await resp.json()
    except Exception as exc:
        logger.error("Failed to reach Chrome CDP (port %d): %s", cp.cdp_port, exc)
        return web.json_response({"error": "CDP endpoint unreachable"}, status=502)

    # Rewrite webSocketDebuggerUrl to route through our multiplexer
    host = request.headers.get("Host", f"localhost:{request.app['port']}")
    seed_key = params["seed"]
    if seed_key:
        # The seed is client-supplied free text. Percent-encode it so that
        # characters such as "/", "?", "#" or spaces cannot break the path
        # that the WebSocket route later has to match.
        ws_path = f"fingerprint/{quote(seed_key, safe='')}/devtools/browser"
    else:
        ws_path = "devtools/browser"

    # Extract the browser GUID from Chrome's original URL
    orig_ws = data.get("webSocketDebuggerUrl", "")
    guid = orig_ws.rsplit("/", 1)[-1] if "/devtools/" in orig_ws else ""

    scheme = _ws_scheme(request)
    data["webSocketDebuggerUrl"] = f"{scheme}://{host}/{ws_path}/{guid}"
    return web.json_response(data)


async def handle_json_list(request: web.Request) -> web.Response:
    """Proxy /json/list with per-seed routing. Rewrites all entries.

    Launches (or reuses) the Chrome process selected by the query params,
    fetches its ``/json/list`` document and points every entry's
    ``webSocketDebuggerUrl`` at this multiplexer.

    Returns a 502 JSON error if Chrome's CDP endpoint cannot be reached.
    """
    pool: ChromePool = request.app["pool"]
    params = parse_connection_params(request.query_string)

    cp = await pool.get_or_launch(
        seed=params["seed"],
        extra_args=params["extra_args"] or None,
        timezone=params["timezone"],
        locale=params["locale"],
        proxy=params["proxy"],
        geoip=params["geoip"],
    )

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"http://127.0.0.1:{cp.cdp_port}/json/list",
                timeout=aiohttp.ClientTimeout(total=5),
            ) as resp:
                data = await resp.json()
    except Exception as exc:
        logger.error("Failed to reach Chrome CDP (port %d): %s", cp.cdp_port, exc)
        return web.json_response({"error": "CDP endpoint unreachable"}, status=502)

    host = request.headers.get("Host", f"localhost:{request.app['port']}")
    scheme = _ws_scheme(request)
    seed_key = params["seed"]

    # Common URL prefix for every rewritten entry. The seed is client-supplied
    # free text, so it is percent-encoded to keep "/", "?", "#" or spaces
    # from breaking the path that the WebSocket route later has to match.
    base = f"{scheme}://{host}"
    if seed_key:
        base = f"{base}/fingerprint/{quote(seed_key, safe='')}"

    for entry in data:
        if "webSocketDebuggerUrl" in entry:
            # Keep everything after Chrome's own "/devtools/" segment.
            ws_tail = entry["webSocketDebuggerUrl"].split("/devtools/")[-1]
            entry["webSocketDebuggerUrl"] = f"{base}/devtools/{ws_tail}"

    return web.json_response(data)


# ---------------------------------------------------------------------------
# WebSocket proxy
# ---------------------------------------------------------------------------

async def proxy_cdp_websocket(
    client_ws: web.WebSocketResponse,
    target_url: str,
    label: str,
) -> None:
    """Bidirectional WebSocket proxy between client and Chrome CDP.

    Opens a connection to ``target_url`` and relays frames in both
    directions until either side stops. Ordinary errors are logged and
    swallowed; cancellation of the calling task propagates after both relay
    tasks have been shut down.

    Args:
        client_ws: Already-prepared WebSocket to the client.
        target_url: ``ws://`` URL of Chrome's CDP endpoint.
        label: Prefix used in log messages for this connection.
    """
    try:
        async with websockets.connect(
            target_url, max_size=None, ping_interval=None, ping_timeout=None,
        ) as cdp_ws:
            logger.info("%s: connected to %s", label, target_url)

            async def client_to_cdp():
                try:
                    async for msg in client_ws:
                        if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
                            await cdp_ws.send(msg.data)
                        elif msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED):
                            break
                except Exception as exc:
                    logger.debug("%s [c->cdp]: %s", label, exc)

            async def cdp_to_client():
                try:
                    async for msg in cdp_ws:
                        if isinstance(msg, str):
                            await client_ws.send_str(msg)
                        else:
                            await client_ws.send_bytes(msg)
                except Exception as exc:
                    logger.debug("%s [cdp->c]: %s", label, exc)

            pumps = [
                asyncio.create_task(client_to_cdp(), name="c2d"),
                asyncio.create_task(cdp_to_client(), name="d2c"),
            ]
            try:
                await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # Runs on normal completion and when this coroutine itself is
                # cancelled: stop whichever direction is still relaying and
                # wait for it, so no task outlives the proxy session.
                for task in pumps:
                    if not task.done():
                        task.cancel()
                await asyncio.gather(*pumps, return_exceptions=True)
            logger.info("%s: disconnected", label)

    except Exception as exc:
        logger.error("%s error: %s", label, exc)


async def handle_ws_default(request: web.Request) -> web.WebSocketResponse:
    """WebSocket proxy for default (no-seed) Chrome: /devtools/{type}/{guid}

    Resolves the default Chrome process (launching it if necessary),
    upgrades the request to a WebSocket and relays CDP traffic until either
    side disconnects. The connection is refcounted in the pool for the
    duration of the session.
    """
    pool: ChromePool = request.app["pool"]
    path = request.match_info.get("path", "")

    cp = await pool.get_or_launch(seed=None)

    # max_msg_size=0 disables aiohttp's 4 MiB limit on inbound messages; the
    # Chrome-facing side of the proxy is likewise opened without a size cap.
    ws = web.WebSocketResponse(max_msg_size=0)
    await ws.prepare(request)

    # get_or_launch may have substituted the CLI default seed, in which case
    # the process is registered under that seed rather than "__default__".
    # Count the connection under the key the process is actually stored at
    # so the status endpoint reports it.
    pool_key = next(
        (key for key, proc in pool._processes.items() if proc is cp),
        "__default__",
    )
    pool.connect(pool_key)
    try:
        target_url = f"ws://127.0.0.1:{cp.cdp_port}/devtools/{path}"
        await proxy_cdp_websocket(ws, target_url, f"CDP default [{path}]")
    finally:
        pool.disconnect(pool_key)
    return ws


async def handle_ws_seed(request: web.Request) -> web.WebSocketResponse:
    """WebSocket proxy for seed-specific Chrome: /fingerprint/{seed}/devtools/{type}/{guid}

    Resolves the Chrome process for the seed in the URL (launching it if
    necessary), upgrades the request to a WebSocket and relays CDP traffic
    until either side disconnects. The connection is refcounted under the
    seed for the duration of the session.
    """
    pool: ChromePool = request.app["pool"]
    seed = request.match_info["seed"]
    path = request.match_info.get("path", "")

    cp = await pool.get_or_launch(seed=seed)

    # max_msg_size=0 disables aiohttp's 4 MiB limit on inbound messages; the
    # Chrome-facing side of the proxy is likewise opened without a size cap.
    ws = web.WebSocketResponse(max_msg_size=0)
    await ws.prepare(request)

    pool.connect(seed)
    try:
        target_url = f"ws://127.0.0.1:{cp.cdp_port}/devtools/{path}"
        await proxy_cdp_websocket(ws, target_url, f"CDP seed={seed} [{path}]")
    finally:
        pool.disconnect(seed)
    return ws


async def on_shutdown(app: web.Application) -> None:
    """Shutdown hook: stop every Chrome process owned by the app's pool."""
    pool = app["pool"]
    await pool.shutdown()


# ---------------------------------------------------------------------------
# CLI arg parsing
# ---------------------------------------------------------------------------

def _default_data_dir() -> str:
    """Pick the data directory used when --data-dir is not given.

    Returns ``/tmp/cloakserve`` when ``/.dockerenv`` exists, otherwise
    ``~/.cloakbrowser/cloakserve`` under the current user's home directory.
    """
    inside_docker = os.path.exists("/.dockerenv")
    if not inside_docker:
        return str(Path.home().joinpath(".cloakbrowser", "cloakserve"))
    return "/tmp/cloakserve"


def parse_cli_args(argv: list[str]) -> tuple[dict, list[str]]:
    """Split argv into cloakserve config and flags forwarded to Chrome.

    Only the ``--name=value`` spelling is recognised. ``--fingerprint``,
    ``--fingerprint-locale`` and ``--fingerprint-timezone`` become config
    defaults (so they are applied through build_args(), e.g. locale needs
    both --lang and --fingerprint-locale) and can be overridden per
    connection via query-string params. ``--remote-debugging-port`` and
    ``--remote-debugging-address`` are dropped silently. ``--headless=false``
    / ``--headless=False`` turn headless mode off and are also forwarded.
    Every other argument is passed through unchanged, in order.

    Raises:
        ValueError: If the ``--port=`` value is not an integer.
    """
    config: dict = {
        "port": 9222,
        "headless": True,
        "data_dir": None,
        "default_seed": None,
        "default_locale": None,
        "default_timezone": None,
    }
    passthrough = []

    # Flag name -> config key for options stored as plain strings.
    string_options = {
        "--data-dir": "data_dir",
        "--fingerprint-locale": "default_locale",
        "--fingerprint-timezone": "default_timezone",
        "--fingerprint": "default_seed",
    }
    # Flags cloakserve sets itself for each Chrome process; never forwarded.
    dropped = ("--remote-debugging-port", "--remote-debugging-address")
    headless_off = ("--headless=false", "--headless=False")

    for arg in argv:
        flag, equals, value = arg.partition("=")
        if not equals:
            # Bare flag without a value: nothing for cloakserve to consume.
            passthrough.append(arg)
        elif flag == "--port":
            config["port"] = int(value)
        elif flag in string_options:
            config[string_options[flag]] = value
        elif arg in headless_off:
            config["headless"] = False
            passthrough.append(arg)
        elif flag in dropped:
            continue
        else:
            passthrough.append(arg)

    if config["data_dir"] is None:
        config["data_dir"] = _default_data_dir()

    return config, passthrough


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Entry point: resolve the Chrome binary, build the app and serve it.

    Parses ``sys.argv``, creates the ``ChromePool``, registers the HTTP and
    WebSocket routes and runs the aiohttp server on all interfaces until it
    is stopped.
    """
    binary = ensure_binary()
    config, global_args = parse_cli_args(sys.argv[1:])
    port = config["port"]

    pool = ChromePool(
        binary=binary,
        global_args=global_args,
        headless=config["headless"],
        data_dir=config["data_dir"],
        default_seed=config["default_seed"],
        default_locale=config["default_locale"],
        default_timezone=config["default_timezone"],
    )

    app = web.Application()
    app["pool"] = pool
    app["port"] = port

    # GET routes in registration order. The seed-specific WebSocket route is
    # registered ahead of the default (no-seed) /devtools route.
    routes = (
        ("/", handle_root),
        ("/json/version", handle_json_version),
        ("/json/version/", handle_json_version),
        ("/json/list", handle_json_list),
        ("/json/list/", handle_json_list),
        ("/json", handle_json_list),
        ("/json/", handle_json_list),
        ("/fingerprint/{seed}/devtools/{path:.+}", handle_ws_seed),
        ("/devtools/{path:.+}", handle_ws_default),
    )
    for route_path, handler in routes:
        app.router.add_get(route_path, handler)

    app.on_shutdown.append(on_shutdown)

    logger.info("CloakBrowser CDP multiplexer starting on port %d", port)
    logger.info(
        'Connect: playwright.chromium.connect_over_cdp("http://localhost:%d?fingerprint=<seed>")',
        port,
    )

    web.run_app(app, host="0.0.0.0", port=port, print=None)


if __name__ == "__main__":
    main()
