#!python
import argparse
import curses
import datetime as dt
import json
import math
import os
import shlex
import shutil
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path


def _config_dir(env_var, fallback):
    """Return the directory named by ``$env_var`` (or *fallback* when unset), ``~``-expanded."""
    return Path(os.environ.get(env_var, fallback)).expanduser()


HOME = Path.home()
# Data roots of the two CLIs; each can be relocated through an environment variable.
CODEX_HOME = _config_dir("CODEX_HOME", HOME / ".codex")
CLAUDE_HOME = _config_dir("CLAUDE_CONFIG_DIR", HOME / ".claude")
VERSION = "0.1.0"


@dataclass
class Session:
    """A single resumable conversation discovered on disk.

    Instances are built by the per-provider scanners and rendered by the
    pickers; ``command_for`` turns one into the argv that resumes it.
    """

    provider: str  # "codex" or "claude"
    session_id: str  # id passed back to the provider's CLI
    title: str  # human-readable label (latest prompt, summary, or a placeholder)
    cwd: str  # working directory of the session; may be ""
    updated: float  # last-activity time in POSIX seconds (0.0 when unknown)
    count: int = 0  # message count when the provider records one, else 0


def parse_time(value):
    """Convert a timestamp found in session metadata to POSIX seconds.

    Accepted forms:

    * ``int``/``float`` epoch values; anything above 1e10 is taken to be in
      milliseconds and divided by 1000.
    * ISO-8601 strings (``Z`` is rewritten to ``+00:00`` so older Pythons
      accept it). Naive strings are interpreted as local time by
      ``datetime.timestamp``.
    * Strings that hold an epoch number, e.g. ``"1700000000"``.

    Everything else returns ``0.0``, which ``age()`` reports as "unknown".
    That includes ``None``, booleans, unparseable text, and non-finite
    numbers (``json.loads`` accepts ``NaN``/``Infinity``).
    """
    if value is None or isinstance(value, bool):
        # bool is an int subclass, but true/false is never a timestamp.
        return 0.0
    if isinstance(value, str):
        text = value.strip()
        try:
            return dt.datetime.fromisoformat(text.replace("Z", "+00:00")).timestamp()
        except (ValueError, OverflowError, OSError):
            # Not ISO-8601 (or outside the platform's range): try a bare epoch number.
            try:
                value = float(text)
            except ValueError:
                return 0.0
    if not isinstance(value, (int, float)):
        return 0.0
    try:
        seconds = float(value)
    except OverflowError:
        # An integer too large for a float cannot be a real timestamp.
        return 0.0
    if not math.isfinite(seconds):
        # NaN/inf would break sorting and make age() raise on int() conversion.
        return 0.0
    return seconds / 1000 if seconds > 10_000_000_000 else seconds


def trim(text, width=88):
    """Collapse whitespace in *text* and shorten it to at most *width* characters.

    Runs of whitespace, including newlines, become single spaces. Text
    longer than *width* is cut and ends with "…", which counts toward the
    width. ``None`` and other falsy values render as ``""``. A non-positive
    *width* also yields ``""`` because nothing fits.
    """
    text = " ".join(str(text or "").split())
    if len(text) <= width:
        return text
    if width <= 0:
        # Without this guard text[:width - 1] would keep almost the whole string.
        return ""
    return text[: width - 1].rstrip() + "…"


def relpath(path, home=None):
    """Return *path* for display, abbreviating the home directory to ``~``.

    A leading ``~`` in *path* is expanded first. The home directory is
    abbreviated only when it is a whole-component prefix, so with home
    ``/home/al`` the sibling ``/home/alice`` is left alone. Falsy input
    yields ``""``. Input that cannot be treated as a path is returned
    unchanged.

    *home* defaults to the module-level ``HOME``. It is a parameter so the
    abbreviation can be applied against any base directory.
    """
    if not path:
        return ""
    try:
        expanded = Path(path).expanduser()
    except (TypeError, ValueError, RuntimeError):
        # Not path-like, or "~user" could not be resolved: show it verbatim.
        return path
    base = HOME if home is None else Path(home)
    try:
        rest = expanded.relative_to(base)
    except ValueError:
        return str(expanded)
    return "~" if rest == Path(".") else str(Path("~") / rest)


def age(ts):
    """Describe how long ago POSIX time *ts* was, e.g. ``"3h ago"``.

    Uses the largest unit that fits (y = 365 days, mo = 30 days, d, h, m).
    Less than a minute, or a time in the future, reads ``"now"``. A falsy
    *ts* reads ``"unknown"``.
    """
    if not ts:
        return "unknown"
    elapsed = max(0, int(dt.datetime.now().timestamp() - ts))
    day = 24 * 3600
    scale = (("y", 365 * day), ("mo", 30 * day), ("d", day), ("h", 3600), ("m", 60))
    return next(
        (f"{elapsed // span}{label} ago" for label, span in scale if elapsed >= span),
        "now",
    )


def iter_jsonl(path):
    """Yield each JSON object from the JSON-lines file at *path*.

    Blank lines, lines that are not valid JSON, and lines whose value is not
    a JSON object (lists, numbers, ...) are skipped, because every caller
    reads fields with ``.get``. Invalid UTF-8 bytes are replaced rather than
    aborting the read. A missing or unreadable file yields nothing.
    """
    try:
        with path.open("r", encoding="utf-8", errors="replace") as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if isinstance(item, dict):
                    yield item
    except OSError:
        return


def codex_sessions():
    """Return every Codex session found under ``CODEX_HOME``.

    Two sources are merged, keyed by session id:

    * ``history.jsonl`` holds prompt records (``session_id``, ``ts``,
      ``text``). The newest record per session supplies the title and an
      update time.
    * ``sessions/**/*.jsonl`` files supply the id, cwd and creation time
      from their first ``session_meta`` record. A file without one falls
      back to the last five dash-separated parts of its name as the id.

    A session's ``updated`` is the latest of its history time, creation time
    and file mtime. When several files claim the same id, the most recently
    updated one wins. Sessions known only from the history get an empty cwd.
    Files that vanish mid-scan are skipped instead of aborting the listing.
    """
    history = {}
    for item in iter_jsonl(CODEX_HOME / "history.jsonl"):
        session_id = item.get("session_id")
        if not session_id:
            continue
        ts = parse_time(item.get("ts"))
        previous = history.get(session_id)
        # ">=" lets a later line win a timestamp tie.
        if previous is None or ts >= previous["updated"]:
            history[session_id] = {"updated": ts, "title": item.get("text") or ""}

    sessions = {}
    for path in (CODEX_HOME / "sessions").glob("**/*.jsonl"):
        session_id = None
        cwd = ""
        created = 0.0
        for item in iter_jsonl(path):
            if item.get("type") != "session_meta":
                continue
            payload = item.get("payload")
            if not isinstance(payload, dict):
                payload = {}
            session_id = payload.get("id") or session_id
            cwd = payload.get("cwd") or cwd
            created = parse_time(payload.get("timestamp") or item.get("timestamp"))
            break
        if not session_id:
            # Assumes the name ends in a UUID (five dash-separated groups) -- TODO confirm.
            parts = path.stem.split("-")
            if len(parts) >= 7:
                session_id = "-".join(parts[-5:])
        if not session_id:
            continue
        try:
            modified = path.stat().st_mtime
        except OSError:
            # The file disappeared (or became unreadable) after glob() listed it.
            continue
        hist = history.get(session_id, {})
        updated = max(hist.get("updated", 0.0), created, modified)
        existing = sessions.get(session_id)
        if existing is not None and existing.updated >= updated:
            continue
        title = hist.get("title") or f"Codex session {session_id[:8]}"
        sessions[session_id] = Session("codex", session_id, title, cwd, updated)

    for session_id, hist in history.items():
        if session_id in sessions:
            continue
        sessions[session_id] = Session(
            "codex",
            session_id,
            hist.get("title") or f"Codex session {session_id[:8]}",
            "",
            hist.get("updated", 0.0),
        )
    return list(sessions.values())


def claude_sessions():
    """Return every Claude session found under ``CLAUDE_HOME``.

    Three sources are merged, keyed by session id:

    * ``history.jsonl`` holds prompt records (``sessionId``, ``timestamp``,
      ``display``, ``project``). The newest record per session supplies a
      title, cwd and update time.
    * ``projects/*/sessions-index.json`` entries supply a summary or first
      prompt, timestamps, project path and message count.
    * ``projects/*/*.jsonl`` files are named by session id. Their mtime can
      only move ``updated`` forward, and they add sessions missing from the
      indexes.

    Malformed index files or entries, and files that vanish mid-scan, are
    skipped instead of aborting the whole listing.
    """
    history = {}
    for item in iter_jsonl(CLAUDE_HOME / "history.jsonl"):
        session_id = item.get("sessionId")
        if not session_id:
            continue
        ts = parse_time(item.get("timestamp"))
        previous = history.get(session_id)
        if previous is None or ts >= previous["updated"]:
            history[session_id] = {
                "updated": ts,
                "title": item.get("display") or "",
                "cwd": item.get("project") or "",
            }

    sessions = {}
    projects = CLAUDE_HOME / "projects"
    for index in projects.glob("*/sessions-index.json"):
        try:
            data = json.loads(index.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both malformed JSON and undecodable UTF-8.
            continue
        entries = data.get("entries") if isinstance(data, dict) else None
        if not isinstance(entries, list):
            continue
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            session_id = entry.get("sessionId")
            if not session_id:
                continue
            title = entry.get("summary") or entry.get("firstPrompt") or f"Claude session {session_id[:8]}"
            updated = max(
                parse_time(entry.get("modified")),
                parse_time(entry.get("fileMtime")),
                parse_time(entry.get("created")),
            )
            hist = history.get(session_id, {})
            if hist and hist["updated"] >= updated:
                # The history saw this session at least as recently as the index.
                title = hist["title"] or title
                updated = hist["updated"]
            try:
                count = int(entry.get("messageCount") or 0)
            except (TypeError, ValueError, OverflowError):
                count = 0
            sessions[session_id] = Session(
                "claude",
                session_id,
                title,
                hist.get("cwd") or entry.get("projectPath") or "",
                updated,
                count,
            )

    for path in projects.glob("*/*.jsonl"):
        session_id = path.stem
        try:
            modified = path.stat().st_mtime
        except OSError:
            # The transcript disappeared after glob() listed it.
            continue
        hist = history.get(session_id, {})
        existing = sessions.get(session_id)
        title = hist.get("title") or (existing.title if existing else "") or f"Claude session {session_id[:8]}"
        cwd = hist.get("cwd") or (existing.cwd if existing else "")
        updated = max(modified, hist.get("updated", 0.0), existing.updated if existing else 0.0)
        if existing is not None:
            existing.title = title
            existing.cwd = cwd
            existing.updated = updated
        else:
            sessions[session_id] = Session("claude", session_id, title, cwd, updated)

    return list(sessions.values())


def load_sessions(args):
    """Collect, filter and order sessions according to parsed CLI *args*.

    Reads ``args.provider`` ("all", "codex" or "claude"), ``args.query``
    (filter words), ``args.limit`` and ``args.last``. Returns sessions
    newest-first. The list is cut to ``args.limit`` entries unless
    ``--last`` was given or the limit is zero/negative, both meaning
    "no limit".
    """
    sessions = []
    if args.provider in ("all", "codex"):
        sessions.extend(codex_sessions())
    if args.provider in ("all", "claude"):
        sessions.extend(claude_sessions())

    sessions = filter_sessions(sessions, " ".join(args.query))
    sessions.sort(key=lambda s: s.updated, reverse=True)

    # A negative slice bound would silently drop the *oldest* N sessions instead.
    if (args.limit or 0) > 0 and not args.last:
        sessions = sessions[: args.limit]
    return sessions


def filter_sessions(sessions, query):
    """Return the sessions whose searchable text contains every word of *query*.

    Matching is case-insensitive substring matching. A blank query returns a
    new list with all of *sessions*.
    """
    words = query.lower().split()
    if not words:
        return list(sessions)
    matches = []
    for session in sessions:
        haystack = searchable_text(session)
        if all(word in haystack for word in words):
            matches.append(session)
    return matches


def searchable_text(session):
    """Return the lower-cased text that filter words are matched against.

    Joins provider, session id, title and cwd with single spaces.
    """
    parts = (session.provider, session.session_id, session.title, session.cwd)
    return " ".join(parts).lower()


def format_row(index, session):
    """Render *session* as one fixed-width text line for fzf or the plain prompt.

    Columns, separated by two spaces: *index* right-aligned in 3, the raw
    provider padded to 6, the age right-aligned in 8, the cwd padded to 42,
    and the title cut to 92. The leading index is what choose_with_fzf
    parses back out of the selected line.
    """
    columns = (
        f"{index:>3}",
        session.provider.ljust(6),
        age(session.updated).rjust(8),
        trim(relpath(session.cwd), 42).ljust(42),
        trim(session.title, 92),
    )
    return "  ".join(columns)


def provider_label(provider):
    """Return the display name for *provider*; anything but "codex" shows as Claude."""
    if provider == "codex":
        return "Codex"
    return "Claude"


def safe_addstr(stdscr, y, x, text, attr=0):
    """Draw *text* at (y, x) clipped to the window, ignoring curses errors.

    Rows outside the window and start columns at or past its right edge are
    skipped. The text is truncated so it stops before the final column,
    presumably to avoid curses' error when writing the bottom-right cell.
    """
    rows, cols = stdscr.getmaxyx()
    if not (0 <= y < rows) or x >= cols:
        return
    room = max(0, cols - x - 1)
    try:
        stdscr.addnstr(y, x, text, room, attr)
    except curses.error:
        # Drawing is best-effort (e.g. a negative x still fails here).
        pass


def fill_line(stdscr, y, attr=0):
    """Paint row *y* with spaces in *attr*, leaving the last column untouched.

    Rows outside the window are ignored, as are curses drawing errors.
    """
    rows, cols = stdscr.getmaxyx()
    if not 0 <= y < rows:
        return
    span = max(0, cols - 1)
    try:
        stdscr.addnstr(y, 0, " " * span, span, attr)
    except curses.error:
        pass


def set_cursor(visibility):
    """Set the terminal cursor visibility; terminals that refuse are ignored."""
    try:
        curses.curs_set(visibility)
    except curses.error:
        return


def init_colors():
    """Register the picker's colour pairs; quietly give up if colours are unavailable.

    Pair 1 is cyan and pair 2 is magenta, both on the terminal's default
    background (-1); the picker uses them for Codex and Claude rows. Pair 3
    is white on blue and pair 4 black on white. A failure stops registration
    at that point.
    """
    palette = (
        (1, curses.COLOR_CYAN, -1),
        (2, curses.COLOR_MAGENTA, -1),
        (3, curses.COLOR_WHITE, curses.COLOR_BLUE),
        (4, curses.COLOR_BLACK, curses.COLOR_WHITE),
    )
    try:
        curses.start_color()
        curses.use_default_colors()
        for number, foreground, background in palette:
            curses.init_pair(number, foreground, background)
    except curses.error:
        pass


def color_pair(number):
    """Return the attribute for colour pair *number*, or 0 when curses refuses."""
    attr = 0
    try:
        attr = curses.color_pair(number)
    except curses.error:
        pass
    return attr


def choose_with_curses(sessions):
    """Show a full-screen curses picker and return the chosen session's index.

    Keys:

    * Enter resumes the highlighted row.
    * ``/`` starts editing the filter. While editing, Enter or Esc stops
      editing, Backspace deletes a character and Ctrl-U clears the filter.
    * ``j``/``k`` or the arrow keys move one row.
    * Space/PageDown and PageUp move by a screenful.
    * ``g``/Home and ``G``/End jump to the ends.
    * ``q``/``Q``/Esc cancel.

    Returns:
        The index of the chosen session within *sessions* (not within the
        filtered view).

    Raises:
        SystemExit: with status 130 when the user cancels or presses Ctrl-C,
            or with a message when curses cannot start on this terminal.
    """
    def run(stdscr):
        set_cursor(0)
        init_colors()
        selected = 0
        offset = 0
        query = ""
        editing = False

        while True:
            visible = filter_sessions(sessions, query)
            if selected >= len(visible):
                selected = max(0, len(visible) - 1)

            height, width = stdscr.getmaxyx()
            # Two lines per session (metadata + title) when the window is tall enough.
            row_height = 2 if height >= 15 else 1
            list_top = 4
            list_height = max(1, height - list_top - 2)
            visible_rows = max(1, list_height // row_height)

            # Scroll just enough to keep the selection on screen.
            if selected < offset:
                offset = selected
            elif selected >= offset + visible_rows:
                offset = selected - visible_rows + 1

            stdscr.erase()
            safe_addstr(stdscr, 0, 2, "resume", curses.A_BOLD)
            safe_addstr(stdscr, 0, 9, "Codex + Claude sessions", curses.A_DIM)
            safe_addstr(
                stdscr,
                1,
                2,
                "Enter: resume selected   /: filter   up/down or j/k: move   q: quit",
                curses.A_DIM,
            )
            if query:
                safe_addstr(stdscr, 2, 2, f"filter: {query}", color_pair(1))
            else:
                safe_addstr(stdscr, 2, 2, f"{len(sessions)} sessions", curses.A_DIM)

            if not visible:
                safe_addstr(stdscr, list_top, 2, "No sessions match.", curses.A_BOLD)
            else:
                end = min(len(visible), offset + visible_rows)
                for row, session in enumerate(visible[offset:end]):
                    actual_index = offset + row
                    y = list_top + row * row_height
                    is_selected = actual_index == selected
                    attr = curses.A_REVERSE if is_selected else 0
                    muted = attr if is_selected else curses.A_DIM
                    provider_attr = attr | curses.A_BOLD
                    if not is_selected:
                        provider_attr |= color_pair(1 if session.provider == "codex" else 2)
                    else:
                        # Highlight the whole row (both lines in two-line mode).
                        fill_line(stdscr, y, attr)
                        if row_height == 2:
                            fill_line(stdscr, y + 1, attr)

                    prefix = ">" if is_selected else " "
                    safe_addstr(stdscr, y, 2, prefix, attr | curses.A_BOLD)
                    safe_addstr(stdscr, y, 4, provider_label(session.provider).ljust(6), provider_attr)
                    safe_addstr(stdscr, y, 12, age(session.updated).rjust(8), muted)
                    safe_addstr(stdscr, y, 23, trim(relpath(session.cwd), max(16, width - 28)), attr)
                    if row_height == 2:
                        safe_addstr(stdscr, y + 1, 6, trim(session.title, max(16, width - 8)), muted)

            prompt = "filter> " if editing else ""
            if editing:
                set_cursor(1)
                safe_addstr(stdscr, height - 1, 2, prompt + query)
                try:
                    stdscr.move(height - 1, min(width - 2, 2 + len(prompt) + len(query)))
                except curses.error:
                    # Window too small to place the cursor; keep the picker alive.
                    pass
            else:
                set_cursor(0)
                safe_addstr(stdscr, height - 1, 2, "Esc/q cancels", curses.A_DIM)

            stdscr.refresh()
            try:
                key = stdscr.get_wch()
            except KeyboardInterrupt:
                raise SystemExit(130)

            if editing:
                if key in ("\n", "\r", "\x1b"):
                    editing = False
                # curses.wrapper enables keypad mode, so many terminals report
                # Backspace as the integer KEY_BACKSPACE rather than a character.
                elif key in ("\b", "\x7f", curses.KEY_BACKSPACE):
                    query = query[:-1]
                    selected = 0
                    offset = 0
                elif key == "\x15":  # Ctrl-U
                    query = ""
                    selected = 0
                    offset = 0
                elif isinstance(key, str) and key.isprintable():
                    query += key
                    selected = 0
                    offset = 0
                continue

            if key in ("\n", "\r"):
                if visible:
                    # Map the filtered row back to its position in `sessions`.
                    chosen = visible[selected]
                    for index, session in enumerate(sessions):
                        if session is chosen:
                            return index
                continue
            if key in ("q", "Q", "\x1b"):
                raise SystemExit(130)
            if key == "/":
                editing = True
                continue
            if key in ("j", curses.KEY_DOWN):
                selected = min(selected + 1, max(0, len(visible) - 1))
            elif key in ("k", curses.KEY_UP):
                selected = max(0, selected - 1)
            elif key in (curses.KEY_NPAGE, " "):
                selected = min(selected + visible_rows, max(0, len(visible) - 1))
            elif key == curses.KEY_PPAGE:
                selected = max(0, selected - visible_rows)
            elif key in ("g", curses.KEY_HOME):
                selected = 0
            elif key in ("G", curses.KEY_END):
                selected = max(0, len(visible) - 1)

    try:
        return curses.wrapper(run)
    except curses.error as error:
        raise SystemExit(f"Could not start terminal picker: {error}")


def choose_with_fzf(sessions):
    """Let the user pick a session with the external ``fzf`` program.

    Each session is shown as a ``format_row`` line numbered from 1.

    Returns:
        The 0-based index of the chosen session, or None when fzf could not
        be launched or its output does not start with a row number. The
        caller then falls back to the numbered prompt.

    Raises:
        SystemExit: with status 130 when fzf exits non-zero or prints
            nothing, e.g. because the user cancelled.
    """
    menu = [format_row(number, session) for number, session in enumerate(sessions, 1)]
    fzf_argv = ["fzf", "--no-sort", "--ansi", "--height=80%", "--reverse", "--prompt=resume> "]
    try:
        result = subprocess.run(
            fzf_argv,
            input="\n".join(menu),
            text=True,
            stdout=subprocess.PIPE,
            check=False,
        )
    except OSError:
        return None
    picked = result.stdout.strip() if result.returncode == 0 else ""
    if not picked:
        raise SystemExit(130)
    head, *_ = picked.split(None, 1)
    try:
        return int(head) - 1
    except ValueError:
        return None


def choose_with_prompt(sessions):
    """Print a numbered list of *sessions* and read the user's choice from stdin.

    The answer may be a 1-based row number, or text that matches exactly one
    session's id or title (case-insensitive substring).

    Returns:
        The 0-based index of the chosen session.

    Raises:
        SystemExit: with status 130 on an empty answer, end of input or
            Ctrl-C. With a message when the text matches zero or several
            sessions, or the number is out of range.
    """
    for number, session in enumerate(sessions, 1):
        print(format_row(number, session))
    print()
    try:
        choice = input("resume> ").strip()
    except (EOFError, KeyboardInterrupt):
        # Quiet cancel, matching the curses picker's handling of Ctrl-C.
        raise SystemExit(130)
    if not choice:
        raise SystemExit(130)
    try:
        selected = int(choice)
    except ValueError:
        needle = choice.lower()
        matches = [
            index
            for index, session in enumerate(sessions)
            if needle in session.session_id.lower() or needle in session.title.lower()
        ]
        if len(matches) == 1:
            return matches[0]
        raise SystemExit(f"Could not resolve selection: {choice}")
    if selected < 1 or selected > len(sessions):
        raise SystemExit(f"Selection out of range: {selected}")
    return selected - 1


def command_for(session, fork=False):
    """Build the argv that resumes *session*, or forks it when *fork* is true.

    Codex takes the action as a subcommand (``codex resume|fork <id>``).
    Claude takes flags (``claude [--fork-session] --resume <id>``).

    Raises:
        ValueError: if ``session.provider`` is neither "codex" nor "claude".
    """
    provider = session.provider
    if provider == "claude":
        fork_flags = ["--fork-session"] if fork else []
        return ["claude", *fork_flags, "--resume", session.session_id]
    if provider == "codex":
        action = "fork" if fork else "resume"
        return ["codex", action, session.session_id]
    raise ValueError(f"Unknown provider: {session.provider}")


def main():
    """Command-line entry point: list sessions, pick one, then resume or fork it.

    The picker is chosen in this order:

    1. ``--last`` needs no UI and takes the newest session.
    2. ``--fzf`` is used when ``fzf`` is on PATH and ``--no-fzf`` was not
       given. If fzf cannot run, the numbered prompt is used instead.
    3. The numbered prompt is used for ``--plain``, or when stdin or stdout
       is not a terminal.
    4. Otherwise the built-in curses picker runs.

    The process is then replaced by the provider's CLI via ``os.execvp``,
    unless ``--dry-run`` asks for the shell-quoted command to be printed
    instead.
    """
    parser = argparse.ArgumentParser(
        description="Merged picker for Codex and Claude sessions.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--version", action="version", version=f"%(prog)s {VERSION}")
    parser.add_argument("query", nargs="*", help="Filter words to match against provider, id, title, or cwd")
    parser.add_argument("-n", "--limit", type=int, default=50, help="Maximum sessions to show")
    parser.add_argument("--last", action="store_true", help="Resume the most recent merged session")
    parser.add_argument("--fork", action="store_true", help="Fork instead of resuming")
    parser.add_argument("--dry-run", action="store_true", help="Print the command instead of running it")
    parser.add_argument(
        "--provider",
        choices=["all", "codex", "claude"],
        default="all",
        help="Limit results to one provider",
    )
    parser.add_argument("--fzf", action="store_true", help="Use fzf instead of the built-in picker")
    parser.add_argument("--plain", action="store_true", help="Use the numbered prompt")
    # Hidden flag; presumably kept so older invocations keep working.
    parser.add_argument("--no-fzf", action="store_true", help=argparse.SUPPRESS)
    args = parser.parse_args()

    sessions = load_sessions(args)
    if not sessions:
        raise SystemExit("No matching Codex or Claude sessions found.")

    if args.last:
        selected = 0
    elif args.fzf and not args.no_fzf and shutil.which("fzf"):
        selected = choose_with_fzf(sessions)
        if selected is None:
            selected = choose_with_prompt(sessions)
    elif args.plain or not sys.stdin.isatty() or not sys.stdout.isatty():
        selected = choose_with_prompt(sessions)
    else:
        selected = choose_with_curses(sessions)

    command = command_for(sessions[selected], args.fork)
    if args.dry_run:
        # Quote each argument so the printed line can be pasted into a shell.
        print(" ".join(shlex.quote(part) for part in command))
        return
    try:
        os.execvp(command[0], command)
    except OSError as error:
        # execvp only returns by raising, e.g. when the CLI is not on PATH.
        raise SystemExit(f"Could not run {command[0]}: {error}") from None


if __name__ == "__main__":
    # Run the picker only when executed as a script, not when imported.
    main()
