#!/usr/bin/env python3
# Vendored from https://github.com/bbugyi200/dotfiles via pyvendor on 2026-02-21
"""Find unused public Python function and class definitions."""

import argparse
import ast
import os
import re
import subprocess
import sys
import tomllib
from dataclasses import dataclass
from pathlib import Path


@dataclass
class _PragmaInfo:
    """One ``# pyvision: <path>`` pragma comment attached to a top-level symbol."""

    symbol_name: str  # name of the def/class the pragma is attached to
    ref_path: str  # relative to git root
    source_file: Path  # .py file containing the pragma
    pragma_line: int  # 1-based line number


def _get_git_root() -> Path | None:
    """Get the git repository root directory."""
    result = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return None
    return Path(result.stdout.strip())


def _extract_pragmas(file_path: Path) -> dict[str, list[_PragmaInfo]]:
    """Extract pyvision pragmas from a Python file.

    Looks for ``# pyvision: <path>`` comments on consecutive lines directly
    above a top-level def/async def/class (or above its first decorator).
    Multiple pragma lines per symbol are supported.

    Args:
        file_path: Python source file to scan.

    Returns:
        Mapping of symbol name to its pragmas, ordered from the pragma line
        closest to the definition upward. Symbols without pragmas are
        omitted. An empty dict is returned when the file cannot be read,
        decoded, or parsed.
    """
    try:
        with open(file_path, encoding="utf-8") as f:
            source = f.read()
        tree = ast.parse(source, filename=str(file_path))
    except (SyntaxError, UnicodeDecodeError, OSError):
        return {}

    # Split on "\n" only. str.splitlines() also breaks on form feeds and other
    # Unicode line boundaries that the parser does not count as new lines,
    # which would shift every index relative to the AST's ``lineno`` values.
    lines = source.split("\n")

    pragma_re = re.compile(r"^#\s*pyvision:\s*(.+)$")
    pragmas: dict[str, list[_PragmaInfo]] = {}

    for node in tree.body:
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            continue

        # The effective start line is the first decorator (if any), else the def/class line
        if node.decorator_list:
            start_line = node.decorator_list[0].lineno
        else:
            start_line = node.lineno

        # Walk upward from the line above (0-based index start_line - 2),
        # collecting all consecutive pragma lines. The range is empty when
        # the definition starts on the first line of the file.
        symbol_pragmas: list[_PragmaInfo] = []
        for comment_idx in range(start_line - 2, -1, -1):
            m = pragma_re.match(lines[comment_idx].strip())
            if not m:
                break
            symbol_pragmas.append(
                _PragmaInfo(
                    symbol_name=node.name,
                    ref_path=m.group(1).strip(),
                    source_file=file_path,
                    pragma_line=comment_idx + 1,  # 1-based line number
                )
            )

        if symbol_pragmas:
            pragmas[node.name] = symbol_pragmas

    return pragmas


def _validate_pragmas(
    all_pragmas: dict[str, list[_PragmaInfo]],
    git_root: Path,
    python_files: list[Path],
) -> list[str]:
    """Validate all collected pragmas and return a list of error strings.

    Args:
        all_pragmas: Mapping of symbol name to the pragmas attached to it.
        git_root: Repository root that each pragma's ``ref_path`` is joined to.
        python_files: Files searched for import-style usage of each symbol.

    Returns:
        One error message per problem found; empty when every pragma is valid.
        A symbol-level problem (private symbol, or symbol already imported)
        yields a single error and skips that symbol's per-pragma checks.
    """
    errors: list[str] = []
    src_root = (git_root / "src").resolve()

    for name, pragma_list in all_pragmas.items():
        if not pragma_list:
            continue

        # Use first pragma for symbol-level error prefix
        first = pragma_list[0]
        symbol_prefix = (
            f"Error: pyvision pragma in {first.source_file}:{first.pragma_line}:"
        )

        # Symbol-level check: pragma on private symbol
        if name.startswith("_"):
            errors.append(
                f"{symbol_prefix} pragma cannot be applied to private symbol '{name}'"
            )
            continue

        # Symbol-level check: stale pragma (symbol already imported by Python files)
        if _search_function_usage(name, python_files):
            errors.append(
                f"{symbol_prefix} symbol '{name}' is already imported by other Python files."
                " Remove this unnecessary pragma"
            )
            continue

        name_re = re.compile(rf"\b{re.escape(name)}\b")

        # Per-pragma checks
        for pragma in pragma_list:
            prefix = f"Error: pyvision pragma in {pragma.source_file}:{pragma.pragma_line}:"
            ref_file = git_root / pragma.ref_path

            # Referenced file doesn't exist
            if not ref_file.is_file():
                errors.append(
                    f"{prefix} referenced file '{pragma.ref_path}' does not exist"
                )
                continue

            # Referenced file inside src/
            if ref_file.resolve().is_relative_to(src_root):
                errors.append(
                    f"{prefix} referenced file '{pragma.ref_path}' is inside src/"
                )
                continue

            # Check that referenced file actually contains the symbol name.
            # UnicodeDecodeError is not an OSError, so it is caught explicitly;
            # a non-UTF-8 file is reported instead of aborting the whole run.
            try:
                ref_content = ref_file.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError) as e:
                errors.append(
                    f"{prefix} could not read referenced file '{pragma.ref_path}': {e}"
                )
                continue

            if not name_re.search(ref_content):
                errors.append(
                    f"{prefix} referenced file '{pragma.ref_path}'"
                    f" does not contain a reference to symbol '{name}'"
                )

    return errors


def _find_python_files(directory: Path, exclude_test_dirs: bool = True) -> list[Path]:
    """Find all .py files in directory, optionally excluding test directories.

    Args:
        directory: Root of the recursive search.
        exclude_test_dirs: When true, skip files that live under a directory
            named ``test``, ``tests``, ``.venv`` or ``venv`` below *directory*.

    Returns:
        Matching paths in the order ``Path.rglob`` yields them.
    """
    excluded_dirs = {"test", "tests", ".venv", "venv"} if exclude_test_dirs else set()

    python_files = []
    for py_file in directory.rglob("*.py"):
        # Only inspect directory components below the search root. Looking at
        # the full path would drop every file whenever the root itself (or one
        # of its ancestors) happens to be named e.g. "tests" or "venv".
        relative_dirs = py_file.relative_to(directory).parts[:-1]
        if excluded_dirs.intersection(relative_dirs):
            continue
        python_files.append(py_file)

    return python_files


def _find_pyproject_toml(directory: Path) -> Path | None:
    """Walk up from directory to find the nearest pyproject.toml."""
    current = directory.resolve()
    while True:
        candidate = current / "pyproject.toml"
        if candidate.is_file():
            return candidate
        parent = current.parent
        if parent == current:
            return None
        current = parent


def _extract_entrypoint_symbols(pyproject_path: Path) -> set[str]:
    """Extract function names from pyproject.toml entry points.

    Parses [project.scripts], [project.gui-scripts], and [project.entry-points.*]
    sections, extracting the function name from each "module.path:function_name" value.
    """
    with open(pyproject_path, "rb") as f:
        data = tomllib.load(f)

    symbols: set[str] = set()
    project = data.get("project", {})

    for section_key in ("scripts", "gui-scripts"):
        section = project.get(section_key, {})
        for value in section.values():
            if ":" in value:
                func_name = value.rsplit(":", 1)[1]
                symbols.add(func_name)

    for group in project.get("entry-points", {}).values():
        for value in group.values():
            if ":" in value:
                func_name = value.rsplit(":", 1)[1]
                symbols.add(func_name)

    return symbols


def _get_decorator_names(node: ast.AST) -> set[str]:
    """Collect the root identifier of every decorator applied to *node*.

    ``@foo`` yields ``"foo"``. A dotted decorator such as ``@foo.bar.baz`` and
    a call such as ``@foo(...)`` or ``@foo.bar(...)`` also yield ``"foo"``,
    i.e. the leftmost name of the expression. Decorators of any other shape
    contribute nothing, and a node without a ``decorator_list`` attribute
    produces an empty set.
    """
    found: set[str] = set()
    for decorator in getattr(node, "decorator_list", []):
        # For a call, inspect the callee; otherwise the decorator itself.
        target = decorator.func if isinstance(decorator, ast.Call) else decorator
        # Strip attribute accesses until the leftmost expression is reached.
        while isinstance(target, ast.Attribute):
            target = target.value
        if isinstance(target, ast.Name):
            found.add(target.id)
    return found


def _extract_public_functions(
    file_path: Path, exclude_decorators: set[str] | None = None
) -> list[str]:
    """Extract all public top-level function and class names from a Python file."""
    try:
        with open(file_path, encoding="utf-8") as f:
            tree = ast.parse(f.read(), filename=str(file_path))

        symbols = []
        # Only iterate over module-level nodes, not nested functions/methods
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                # Only public symbols (not prefixed with _) and not named "main"
                if not node.name.startswith("_") and node.name != "main":
                    if exclude_decorators and _get_decorator_names(node) & exclude_decorators:
                        continue
                    symbols.append(node.name)

        return symbols
    except (SyntaxError, UnicodeDecodeError) as e:
        print(f"Warning: Could not parse {file_path}: {e}", file=sys.stderr)
        return []


def _extract_private_functions(
    file_path: Path, exclude_decorators: set[str] | None = None
) -> list[str]:
    """Extract all private top-level function and class names from a Python file."""
    try:
        with open(file_path, encoding="utf-8") as f:
            tree = ast.parse(f.read(), filename=str(file_path))

        symbols = []
        # Only iterate over module-level nodes, not nested functions/methods
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                # Only private symbols (prefixed with _)
                if node.name.startswith("_"):
                    if exclude_decorators and _get_decorator_names(node) & exclude_decorators:
                        continue
                    symbols.append(node.name)

        return symbols
    except (SyntaxError, UnicodeDecodeError) as e:
        print(f"Warning: Could not parse {file_path}: {e}", file=sys.stderr)
        return []


def _search_function_usage(function_name: str, search_files: list[Path]) -> bool:
    r"""Search for import-style usage of a function or class in the given files.

    Uses two patterns:
    1. ^\s+<symbol>,\s*(#.*)?$ - Symbol alone on line in multi-line import
    2. ^\s*from\s.*\simport.*\s<symbol> - Import statement

    Args:
        function_name: Symbol name to look for (escaped before use in a regex).
        search_files: Files to scan; files that cannot be read or decoded as
            UTF-8 are skipped.

    Returns:
        True as soon as any file matches either pattern, otherwise False.
    """
    escaped = re.escape(function_name)
    # Compile once up front; both patterns are applied to every file.
    patterns = [
        re.compile(rf"^\s+{escaped},\s*(#.*)?$", re.MULTILINE),
        # Trailing \b so that "foo" does not match an import of "foobar".
        re.compile(rf"^\s*from\s.*\simport.*\s{escaped}\b", re.MULTILINE),
    ]

    # Read each file a single time and test both patterns against its content.
    for file_path in search_files:
        try:
            with open(file_path, encoding="utf-8") as f:
                content = f.read()
        except (UnicodeDecodeError, OSError):
            continue
        if any(pattern.search(content) for pattern in patterns):
            return True

    return False


def _is_private_function_used_in_file(function_name: str, file_path: Path) -> bool:
    """Report whether a private symbol is mentioned in its own file.

    Every line except the one holding the symbol's top-level ``def``/``class``
    statement is searched for the name as a whole word. If the file cannot be
    read, decoded, or parsed, the symbol is assumed to be used and ``True`` is
    returned.
    """
    try:
        with open(file_path, encoding="utf-8") as handle:
            source_text = handle.read()
        module = ast.parse(source_text, filename=str(file_path))
    except (UnicodeDecodeError, OSError, SyntaxError):
        return True  # Assume used if we can't parse

    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    # Line of the first matching top-level definition, or None if there is none.
    def_line = next(
        (
            stmt.lineno
            for stmt in module.body
            if isinstance(stmt, definition_types) and stmt.name == function_name
        ),
        None,
    )

    # Whole-word match so that longer identifiers containing the name don't count.
    name_re = re.compile(rf"\b{re.escape(function_name)}\b")
    return any(
        name_re.search(text)
        for lineno, text in enumerate(source_text.split("\n"), start=1)
        if lineno != def_line
    )


def main() -> int:
    """Run the unused-symbol check from the command line.

    Steps, in order:
      1. Collect public top-level symbols from all non-test Python files.
      2. Drop symbols that are pyproject.toml entry points, validated
         ``--epic-symbol`` entries, or covered by valid pyvision pragmas.
      3. Verify private symbols are not imported anywhere and are used in
         the file that defines them.
      4. Report public symbols that no file imports.

    Returns:
        Process exit status: 0 when everything passes (or there is nothing to
        check), 1 on invalid arguments, validation errors, or findings.
    """
    parser = argparse.ArgumentParser(
        description="Find unused public Python function and class definitions"
    )
    parser.add_argument(
        "directory", type=Path, help="Directory to search for Python files"
    )
    parser.add_argument(
        "--exclude-file",
        action="append",
        default=[],
        help="Exclude a file from analysis (can be repeated)",
    )
    parser.add_argument(
        "--epic-symbol",
        action="append",
        default=[],
        help=(
            "Exclude a symbol tied to an open epic bead from unused-symbol analysis."
            " Format: <bead_id>(<symbol_name>). Can be repeated."
        ),
    )
    parser.add_argument(
        "--exclude-decorator",
        action="append",
        default=[],
        help=(
            "Exclude any function/class decorated with this decorator name"
            " from unused-symbol analysis. Can be repeated."
        ),
    )
    args = parser.parse_args()

    if not args.directory.is_dir():
        print(f"Error: {args.directory} is not a directory", file=sys.stderr)
        return 1

    # Find all Python files excluding test directories
    python_files = _find_python_files(args.directory)

    # Apply --exclude-file filters
    if args.exclude_file:
        excluded = {Path(p).resolve() for p in args.exclude_file}
        python_files = [f for f in python_files if f.resolve() not in excluded]

    if not python_files:
        print("No Python files found", file=sys.stderr)
        return 0

    # Build the set of excluded decorator names
    exclude_decorators = set(args.exclude_decorator) if args.exclude_decorator else None

    # Collect all public functions and classes from all files
    all_symbols: dict[str, list[Path]] = {}  # {symbol_name: [file_path, ...]}
    for py_file in python_files:
        for symbol in _extract_public_functions(py_file, exclude_decorators):
            all_symbols.setdefault(symbol, []).append(py_file)

    # Exclude symbols referenced as entry points in pyproject.toml
    pyproject_path = _find_pyproject_toml(args.directory)
    if pyproject_path:
        entrypoint_symbols = _extract_entrypoint_symbols(pyproject_path)
        for sym in entrypoint_symbols:
            all_symbols.pop(sym, None)

    # Apply --epic-symbol filters
    epic_errors = []
    validated_epic_symbols = []
    # The bead CLI can be overridden through the environment; look it up once.
    bd_cmd = os.environ.get("BD_COMMAND", "bd")
    for entry in args.epic_symbol:
        # 1. Parse format: <bead_id>(<symbol_name>)
        m = re.match(r"^(.+)\(([^)]+)\)$", entry)
        if not m:
            epic_errors.append(
                f"Error: Invalid --epic-symbol format '{entry}'."
                " Expected format: <bead_id>(<symbol_name>)"
            )
            continue
        bead_id, symbol_name = m.group(1), m.group(2)

        # 2. Reject private symbols
        if symbol_name.startswith("_"):
            epic_errors.append(
                f"Error: --epic-symbol '{entry}': symbol '{symbol_name}' is private."
                " Only public symbols can be excluded."
            )
            continue

        # 3. Check bead is open
        result = subprocess.run(
            [bd_cmd, "show", bead_id],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            epic_errors.append(
                f"Error: --epic-symbol '{entry}': bead '{bead_id}' not found."
                " Remove this --epic-symbol entry."
            )
            continue
        if "CLOSED" in result.stdout:
            epic_errors.append(
                f"Error: --epic-symbol '{entry}': bead '{bead_id}' is closed."
                " Remove this stale --epic-symbol entry and clean up the symbol."
            )
            continue

        # 4. Check symbol exists
        if symbol_name not in all_symbols:
            epic_errors.append(
                f"Error: --epic-symbol '{entry}': symbol '{symbol_name}' not found"
                " as a public definition. Remove this --epic-symbol entry."
            )
            continue

        # 5. Check symbol actually needs ignoring (not already used)
        if _search_function_usage(symbol_name, python_files):
            epic_errors.append(
                f"Error: --epic-symbol '{entry}': symbol '{symbol_name}' is already"
                " properly used. Remove this unnecessary --epic-symbol entry."
            )
            continue

        validated_epic_symbols.append(symbol_name)

    if epic_errors:
        for err in epic_errors:
            print(err, file=sys.stderr)
        return 1

    for symbol in validated_epic_symbols:
        all_symbols.pop(symbol, None)

    # Collect and validate pyvision pragmas from all Python files.
    # Pragmas for the same symbol name found in different files are merged
    # rather than overwritten, so every pragma line gets validated.
    all_pragmas: dict[str, list[_PragmaInfo]] = {}
    for py_file in python_files:
        for pragma_symbol, infos in _extract_pragmas(py_file).items():
            all_pragmas.setdefault(pragma_symbol, []).extend(infos)

    if all_pragmas:
        git_root = _get_git_root()
        if git_root is None:
            print(
                "Error: pyvision pragmas found but not inside a git repository",
                file=sys.stderr,
            )
            return 1

        pragma_errors = _validate_pragmas(all_pragmas, git_root, python_files)
        if pragma_errors:
            for err in pragma_errors:
                print(err, file=sys.stderr)
            return 1

        # Remove pragma-covered symbols from the unused check
        for symbol_name in all_pragmas:
            all_symbols.pop(symbol_name, None)

    # NOTE: this early exit also skips the private-symbol checks below.
    if not all_symbols:
        print("No public functions or classes found!", file=sys.stderr)
        return 0

    # Collect all private functions and classes from all files
    all_private_symbols: dict[str, list[Path]] = {}  # {symbol_name: [file_path, ...]}
    for py_file in python_files:
        for symbol in _extract_private_functions(py_file, exclude_decorators):
            all_private_symbols.setdefault(symbol, []).append(py_file)

    # Check that private symbols do NOT match import patterns
    imported_private_symbols = [
        (symbol_name, def_files)
        for symbol_name, def_files in all_private_symbols.items()
        if _search_function_usage(symbol_name, python_files)
    ]

    if imported_private_symbols:
        print(
            "Error: Private functions/classes should not be imported. Make these public if they"
            " need to be imported by non-test files!:",
            file=sys.stderr,
        )
        for symbol_name, def_files in sorted(imported_private_symbols):
            for file_path in def_files:
                print(f"  {symbol_name} in {file_path}", file=sys.stderr)
        return 1

    # Check that private symbols ARE used in their own file
    unused_private_symbols = [
        (symbol_name, file_path)
        for symbol_name, def_files in all_private_symbols.items()
        for file_path in def_files
        if not _is_private_function_used_in_file(symbol_name, file_path)
    ]

    if unused_private_symbols:
        print(
            "Error: Private functions/classes must be used in the file where they are defined:",
            file=sys.stderr,
        )
        for symbol_name, file_path in sorted(unused_private_symbols):
            print(f"  {symbol_name} in {file_path}", file=sys.stderr)
        return 1

    # Search for usage of each symbol
    unused_symbols = [
        (symbol_name, def_files)
        for symbol_name, def_files in all_symbols.items()
        if not _search_function_usage(symbol_name, python_files)
    ]

    # Report results
    if unused_symbols:
        print(
            "Unused public functions/classes. Make these private if they are used only within the"
            " file they are defined. If the functions/classes are completely unused, you should delete them:"
        )
        for symbol_name, def_files in sorted(unused_symbols):
            for file_path in def_files:
                print(f"  {symbol_name} in {file_path}")
        return 1

    print("All public/private classes/functions are used properly!")
    return 0


# Script entry point: run the CLI and use main()'s return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())
