#!/usr/bin/env python3
"""Post-process forensic recovery output from `recover-btrfs`.

Pipeline:
  1. Read the carved candidates file (claude-candidates.jsonl) line-by-line.
  2. Parse each line as JSON; reject malformed.
  3. Classify by shape: history.jsonl entry, session-jsonl turn, or noise.
  4. Filter noise (JSON from non-Claude sources -- browser caches, etc.).
  5. Deduplicate against the live ClaudeTimeline DB (data/claude.db) when
     present, so we only surface lines NOT already known.
  6. Reconstruct into proper format:
       - recovered-history.jsonl      (history.jsonl-shaped entries)
       - sessions/<uuid>.jsonl        (per-session reconstructed turn logs)
       - rejected.jsonl               (everything dropped, with reason)
       - stats.json                   (counts + breakdown)

Usage:
  uv run scripts/process-recovery [--db DB] [INPUT] [OUT_DIR]

Defaults:
  INPUT   = newest /var/tmp/btrfs-recover/carved-*/claude-candidates.jsonl
  OUT_DIR = INPUT's parent / "processed"

Read-only against the live ClaudeTimeline DB. Writes only under OUT_DIR.
"""

from __future__ import annotations

import argparse
import glob
import hashlib
import json
import os
import re
import sqlite3
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

# ── Constants ──────────────────────────────────────────────────────────────

# Canonical lower-case UUID (the form used for sessionId values). Anchored
# with \Z rather than $: `$` also matches just before a trailing newline, which
# would let "<uuid>\n" through and into a per-session output file name.
UUID_RE = re.compile(
    r"^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}\Z"
)

# Shapes that count as "ours". We're conservative: a record must look like
# Claude data specifically, not just any JSON that happens to mention these
# fields.
# Keys that must all be present for a record to be considered history-shaped.
HISTORY_REQUIRED = {"display", "timestamp", "sessionId", "project"}
# Accepted values of the "type" field for a session-turn-shaped record.
SESSION_TYPES = {"user", "assistant", "system", "tool_result", "summary"}

# Heuristic: project paths from real Claude history start with /home/, /Users/,
# /root/, or /tmp/ (rare). Anything else is treated as unrelated.
PROJECT_PATH_HINT = re.compile(r"^(/home/|/Users/|/root/|/tmp/)")

# Default location of the live ClaudeTimeline DB used for dedup (overridable
# with --db on the command line).
DEFAULT_DB = Path("/home/m/Projects/ClaudeTimeline/data/claude.db")


# ── Data containers ────────────────────────────────────────────────────────


@dataclass
class Stats:
    """Counters accumulated over one processing run."""

    # Input side: lines read, and lines that could not be parsed as JSON.
    total_lines: int = 0
    json_parse_failed: int = 0
    # How many parsed records matched each recognised shape, or neither.
    history_shape: int = 0
    session_shape: int = 0
    rejected_noise: int = 0
    # Records dropped as duplicates, per shape.
    duplicates_history: int = 0
    duplicates_session: int = 0
    # Records that survived dedup and were written out, per shape.
    new_history: int = 0
    new_session: int = 0
    # Distinct session ids that contributed at least one written record.
    sessions_seen: set[str] = field(default_factory=set)

    def as_dict(self) -> dict[str, Any]:
        """Return the counters as a plain dict for JSON output.

        ``sessions_seen`` is reported as the number of distinct ids rather
        than the set itself; every other field is copied through unchanged.
        """
        return {**vars(self), "sessions_seen": len(self.sessions_seen)}


@dataclass
class Rejection:
    line_no: int
    reason: str
    snippet: str  # first 200 chars

    def as_dict(self) -> dict[str, str | int]:
        return {"line": self.line_no, "reason": self.reason, "snippet": self.snippet}


# ── Classification ─────────────────────────────────────────────────────────


def classify(obj: Any) -> str:
    """Decide which kind of record a parsed JSON value is.

    Returns one of three strings:

    - ``"history"``: a dict carrying every key in HISTORY_REQUIRED, whose
      ``sessionId`` is a UUID string and whose ``project`` is a string that
      matches PROJECT_PATH_HINT.
    - ``"session"``: a dict whose ``type`` is one of SESSION_TYPES and whose
      ``sessionId`` is a UUID string. A record that has the history keys but
      fails the history checks can still land here.
    - ``"noise"``: everything else, including any non-dict value.
    """
    if not isinstance(obj, dict):
        return "noise"

    # Both shapes require a UUID-looking string sessionId, so check it once.
    session_id = obj.get("sessionId")
    has_uuid_session = isinstance(session_id, str) and bool(UUID_RE.match(session_id))
    if not has_uuid_session:
        return "noise"

    # History shape is tested first; it wins when a record satisfies both.
    if all(required in obj for required in HISTORY_REQUIRED):
        project = obj.get("project")
        if isinstance(project, str) and PROJECT_PATH_HINT.match(project):
            return "history"

    turn_type = obj.get("type")
    if isinstance(turn_type, str) and turn_type in SESSION_TYPES:
        return "session"

    return "noise"


# ── Dedup against live DB ──────────────────────────────────────────────────


def load_known_messages(db_path: Path) -> set[tuple[str, str]]:
    """Return the set of (session_id, sha256_of_display) for messages in the DB.

    Used to detect carved history-shape lines that the live DB already holds.

    Args:
        db_path: Path to the SQLite database. Assumed to contain a
            ``messages`` table with ``session_id`` and ``display`` columns
            (that is what the query below selects).

    Returns:
        One ``(session_id, hex_digest)`` pair per distinct message; an empty
        set when ``db_path`` does not exist (dedup is then skipped). Rows
        whose session id or display is NULL/empty are ignored.

    Raises:
        sqlite3.Error: if the file exists but cannot be opened or queried.
    """
    if not db_path.exists():
        print(f"[dedup] DB not found at {db_path}; skipping dedup", file=sys.stderr)
        return set()

    out: set[tuple[str, str]] = set()
    # Open strictly read-only (URI mode=ro): this tool must never modify the
    # live DB, nor create an empty one if the file vanishes after the check.
    conn = sqlite3.connect(f"{db_path.resolve().as_uri()}?mode=ro", uri=True)
    # sqlite3's own `with conn:` only scopes a transaction and does NOT close
    # the connection, so close it explicitly.
    try:
        # Set per-connection FK enforcement for consistency with the app.
        # This path is read-only, but keeping the invariant uniform avoids
        # foot-guns if someone copy-pastes this connection setup.
        conn.execute("PRAGMA foreign_keys = ON")
        rows = conn.execute("SELECT session_id, display FROM messages")
        for sid, display in rows:
            if not sid or not display:
                continue
            # TEXT comes back as str; hash a BLOB's raw bytes directly so a
            # display stored as UTF-8 bytes yields the same digest as text.
            if isinstance(display, bytes):
                payload = display
            else:
                payload = str(display).encode("utf-8")
            out.add((str(sid), hashlib.sha256(payload).hexdigest()))
    finally:
        conn.close()

    print(f"[dedup] loaded {len(out):,} known message hashes", file=sys.stderr)
    return out


def history_key(obj: dict) -> tuple[str, str]:
    """Return a stable (session_id, sha256_of_display) key for dedup.

    Carved input is untrusted, so this never raises on odd field values:

    - a missing or null ``display`` hashes as the empty string;
    - a non-string ``display`` (number, list, dict, ...) is hashed via its
      canonical JSON text;
    - lone surrogates in the text are encoded with ``surrogatepass`` instead
      of raising ``UnicodeEncodeError``.

    For ordinary string values the key is the plain SHA-256 of the UTF-8
    bytes of ``display``.
    """
    sid = obj.get("sessionId", "")
    if sid is None:
        sid = ""
    elif not isinstance(sid, str):
        # Keep the key hashable and typed as advertised.
        sid = str(sid)

    display = obj.get("display", "")
    if display is None:
        display = ""
    elif not isinstance(display, str):
        display = json.dumps(display, sort_keys=True, ensure_ascii=False, default=str)

    # surrogatepass yields the same bytes as strict encoding for valid text.
    digest = hashlib.sha256(display.encode("utf-8", errors="surrogatepass")).hexdigest()
    return (sid, digest)


# ── Main pipeline ──────────────────────────────────────────────────────────


def newest_carved_file() -> Path | None:
    """Find the most recent recover-btrfs output."""
    matches = sorted(
        glob.glob("/var/tmp/btrfs-recover/carved-*/claude-candidates.jsonl")
    )
    return Path(matches[-1]) if matches else None


# Upper bound on rejection records kept in memory and written to rejected.jsonl.
_MAX_REJECTIONS = 5000


def _json_line(obj: Any) -> str:
    """Serialize *obj* as one newline-terminated JSON line that UTF-8 can encode.

    Non-ASCII text is written verbatim. If the value contains a lone surrogate
    (e.g. from a ``\\ud83d`` escape whose partner was cut off), which UTF-8
    cannot encode, the line falls back to ASCII-escaped JSON so the record is
    still written instead of aborting the run.
    """
    line = json.dumps(obj, ensure_ascii=False)
    try:
        line.encode("utf-8")
    except UnicodeEncodeError:
        line = json.dumps(obj, ensure_ascii=True)
    return line + "\n"


def _unique_sorted_turns(turns: list[dict]) -> list[dict]:
    """Drop repeated turns of one session and order the rest by timestamp.

    Two turns are the same when (timestamp, type, hash of content) agree,
    where content is the ``message`` field if present, else the whole turn.
    The first occurrence wins. Sorting compares timestamps as strings and is
    stable, so turns with equal or missing timestamps keep input order.
    """
    seen: set[tuple[str, str, str]] = set()
    unique: list[dict] = []
    for turn in turns:
        content = json.dumps(turn.get("message", turn), sort_keys=True, ensure_ascii=False)
        # surrogatepass: identical digest for valid text, no crash on lone
        # surrogates.
        digest = hashlib.sha256(content.encode("utf-8", errors="surrogatepass")).hexdigest()[:16]
        key = (str(turn.get("timestamp", "")), str(turn.get("type", "")), digest)
        if key not in seen:
            seen.add(key)
            unique.append(turn)
    unique.sort(key=lambda turn: str(turn.get("timestamp", "")))
    return unique


def process(input_path: Path, out_dir: Path, db_path: Path) -> Stats:
    """Run the whole pipeline over one carved-candidates file.

    Reads ``input_path`` line by line, parses and classifies each line, and
    writes under ``out_dir`` (created if needed):

    - ``recovered-history.jsonl``: history-shaped entries that are neither in
      the DB at ``db_path`` nor already seen earlier in this input;
    - ``sessions/<sessionId>.jsonl``: per-session turns, de-duplicated within
      the session and sorted by timestamp (not de-duplicated against the DB);
    - ``rejected.jsonl``: the first 5000 dropped lines with a reason;
    - ``stats.json``: the counters that are also returned.

    Args:
        input_path: Carved candidates file; read as bytes, decoded leniently.
        out_dir: Directory that receives all outputs.
        db_path: SQLite DB used for history dedup; dedup is skipped if absent.

    Returns:
        The populated Stats for this run.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    sessions_dir = out_dir / "sessions"
    sessions_dir.mkdir(exist_ok=True)

    history_out = out_dir / "recovered-history.jsonl"
    rejected_out = out_dir / "rejected.jsonl"
    stats_out = out_dir / "stats.json"

    known = load_known_messages(db_path)
    seen_history_keys: set[tuple[str, str]] = set()
    session_buckets: dict[str, list[dict]] = defaultdict(list)

    stats = Stats()
    rejections: list[Rejection] = []

    def reject(line_no: int, reason: str, snippet: str) -> None:
        # Only the first _MAX_REJECTIONS records are ever written, so don't
        # hold more than that in memory; the counters in `stats` still cover
        # every dropped line.
        if len(rejections) < _MAX_REJECTIONS:
            rejections.append(Rejection(line_no, reason, snippet))

    with input_path.open("rb") as fin:
        with history_out.open("w", encoding="utf-8") as fhist:
            for line_no, raw in enumerate(fin, start=1):
                stats.total_lines += 1
                # Decode with replacement so garbled bytes cannot raise; no
                # separate decode-failure path is needed.
                text = raw.decode("utf-8", errors="replace").rstrip("\n")

                # Carved lines may not start with `{` (truncated context);
                # try_parse_json also looks for an object inside the line.
                obj = try_parse_json(text)
                if obj is None:
                    stats.json_parse_failed += 1
                    reject(line_no, "json-parse-failed", text[:200])
                    continue

                kind = classify(obj)
                if kind == "noise":
                    stats.rejected_noise += 1
                    reject(line_no, "noise-not-claude-shape", text[:200])
                    continue

                if kind == "history":
                    stats.history_shape += 1
                    key = history_key(obj)
                    # Duplicate if the DB already has it, or if the same hit
                    # appeared earlier in this carved file.
                    if key in known or key in seen_history_keys:
                        stats.duplicates_history += 1
                        continue
                    seen_history_keys.add(key)
                    stats.new_history += 1
                    fhist.write(_json_line(obj))
                    sid = obj.get("sessionId")
                    if sid:
                        stats.sessions_seen.add(sid)

                elif kind == "session":
                    stats.session_shape += 1
                    # Session-shape dedup against the DB isn't possible (the
                    # DB has no full turn logs). Best effort: bucket by
                    # session id now, de-duplicate when writing each file.
                    session_buckets[obj.get("sessionId", "")].append(obj)

    # Write per-session reconstructed files.
    for sid, turns in session_buckets.items():
        unique = _unique_sorted_turns(turns)
        stats.new_session += len(unique)
        stats.sessions_seen.add(sid)
        with (sessions_dir / f"{sid}.jsonl").open("w", encoding="utf-8") as f:
            f.writelines(_json_line(turn) for turn in unique)

    # Write rejections (already capped to the first _MAX_REJECTIONS).
    with rejected_out.open("w", encoding="utf-8") as f:
        f.writelines(_json_line(r.as_dict()) for r in rejections)

    # Write stats.
    stats_out.write_text(
        json.dumps(stats.as_dict(), indent=2, default=str), encoding="utf-8"
    )

    return stats


def try_parse_json(text: str) -> Any:
    """Parse *text* as JSON, tolerating garbage before the first ``{``.

    Carved lines may have junk prepended (when the carving match landed
    mid-record). Strategy: try the whole stripped text; if that fails and a
    ``{`` occurs later in the text, try once more from that ``{``.

    Returns the parsed value, or ``None`` if the text is empty or neither
    attempt parses. (A literal JSON ``null`` therefore also yields ``None``.)
    Never raises on bad input: besides malformed JSON, pathologically deep
    nesting (``[[[[...``) makes the decoder raise ``RecursionError``, which is
    treated as a parse failure rather than aborting the whole run.
    """
    text = text.strip()
    if not text:
        return None

    # Offsets to attempt: the full text, then from the first "{" when that is
    # not already position 0 (position 0 would just repeat the first attempt).
    starts = [0]
    brace = text.find("{")
    if brace > 0:
        starts.append(brace)

    for start in starts:
        try:
            return json.loads(text[start:])
        except (json.JSONDecodeError, RecursionError):
            continue
    return None


# ── CLI ────────────────────────────────────────────────────────────────────


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` (the default) means
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 when no usable input file was
        given or found.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring is laid out by hand (numbered pipeline, usage
        # block); the default formatter would re-wrap it into one paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument(
        "input",
        nargs="?",
        type=Path,
        help="Carved candidates file (default: newest /var/tmp/btrfs-recover/carved-*/claude-candidates.jsonl)",
    )
    ap.add_argument(
        "out_dir",
        nargs="?",
        type=Path,
        help="Output dir (default: <input parent>/processed)",
    )
    ap.add_argument(
        "--db",
        type=Path,
        default=DEFAULT_DB,
        help=f"ClaudeTimeline DB for dedup (default: {DEFAULT_DB})",
    )
    args = ap.parse_args(argv)

    input_path = args.input if args.input is not None else newest_carved_file()
    if input_path is None:
        # Nothing was passed and auto-discovery found nothing.
        print(
            "ERROR: no carved candidates file found "
            "(looked for /var/tmp/btrfs-recover/carved-*/claude-candidates.jsonl). "
            "Run scripts/recover-btrfs first or pass an explicit path.",
            file=sys.stderr,
        )
        return 1
    if not input_path.exists() or input_path.is_dir():
        # Report the offending path rather than suggesting a re-run; a
        # directory would otherwise fail later when opened for reading.
        print(f"ERROR: input file not found: {input_path}", file=sys.stderr)
        return 1

    out_dir = args.out_dir or (input_path.parent / "processed")

    print(f"[process] input:   {input_path}", file=sys.stderr)
    print(f"[process] out dir: {out_dir}", file=sys.stderr)
    print(f"[process] DB:      {args.db}", file=sys.stderr)

    stats = process(input_path, out_dir, args.db)

    print("\n=== summary ===", file=sys.stderr)
    print(json.dumps(stats.as_dict(), indent=2, default=str), file=sys.stderr)
    print(f"\nOutput: {out_dir}", file=sys.stderr)
    print(f"  recovered-history.jsonl  -- {stats.new_history} new entries", file=sys.stderr)
    print(f"  sessions/<uuid>.jsonl    -- {len(stats.sessions_seen)} sessions touched", file=sys.stderr)
    print("  rejected.jsonl           -- first 5000 dropped lines + reason", file=sys.stderr)
    print("  stats.json               -- full counts", file=sys.stderr)

    return 0


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
