#!/usr/bin/env python3

import os
import shlex
import subprocess
import sys


def get_current_branch():
    """Return the short name of the checked-out branch.

    Returns an empty string when HEAD is detached (e.g. during a rebase or
    bisect), so callers see "not on any branch" instead of the hook failing.

    Exits the process with git's return code on any other git error.
    """
    # -q: for a detached HEAD, git exits with status 1 silently instead of
    # dying with "fatal: ref HEAD is not a symbolic ref" (status 128).
    result = subprocess.run(['git', 'symbolic-ref', '-q', '--short', 'HEAD'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode == 1:
        return ""
    if result.returncode != 0:
        print(result.stderr.decode('utf-8', errors='replace'), file=sys.stderr)
        sys.exit(result.returncode)
    return result.stdout.decode('utf-8').strip()

def get_changed_files():
    """Return the paths of files staged for the commit.

    Paths are as reported by ``git diff --cached --name-only`` (relative to
    the repository root). Staged deletions are included. Returns an empty
    list when nothing is staged.

    Exits the process with git's return code if the diff command fails.
    """
    # -z: NUL-terminated, unquoted paths. Without it git C-quotes names with
    # non-ASCII or special characters (e.g. "caf\303\251.py"), which then
    # no longer end in ".py".
    result = subprocess.run(['git', 'diff', '--cached', '--name-only', '-z'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        print(result.stderr.decode('utf-8', errors='replace'), file=sys.stderr)
        sys.exit(result.returncode)
    # os.fsdecode decodes like the OS does, so os.path.exists() on the result
    # works even for names that are not valid UTF-8. The `if path` filter
    # drops the empty entry after the trailing NUL, and yields [] when no
    # files are staged.
    return [path for path in os.fsdecode(result.stdout).split('\0') if path]

def run_command(command, print_stdout: bool = False):
    """Run a shell command and return its standard output.

    ``command`` is executed through the shell (``shell=True``), so any
    variable parts must already be shell-quoted.

    If the command exits non-zero, a diagnostic is written to stderr and the
    process exits with the command's return code. The diagnostic contains the
    command line and the command's stderr, plus its stdout when
    ``print_stdout`` is true.

    Args:
        command: Shell command line to execute.
        print_stdout: Also echo the command's stdout on failure (for tools
            that report their diagnostics on stdout).

    Returns:
        The command's stdout decoded as UTF-8, with undecodable bytes
        replaced by U+FFFD.
    """
    result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # errors='replace': tool output is for humans; one stray non-UTF-8 byte
    # must not crash the hook with UnicodeDecodeError.
    stdout = result.stdout.decode('utf-8', errors='replace')
    if result.returncode != 0:
        # All failure output goes to stderr, matching get_current_branch's
        # error reporting.
        print(f"non zero exit ({result.returncode}) for cmd: {command}", file=sys.stderr)
        print(result.stderr.decode('utf-8', errors='replace'), file=sys.stderr)
        if print_stdout:
            print(stdout, file=sys.stderr)
        sys.exit(result.returncode)
    return stdout


def format_file_args(files: list[str]) -> str:
    """Join *files* into one shell-safe, space-separated argument string.

    Each path is passed through :func:`shlex.quote`, so names containing
    spaces or shell metacharacters reach the command as single arguments.
    An empty list yields an empty string.
    """
    quoted_paths = map(shlex.quote, files)
    return " ".join(quoted_paths)


def run_ruff_formatting(files: list[str]):
    """Run the Ruff formatter on *files* and re-stage them.

    Only *files* are re-staged (``git add -- <files>``), so unrelated
    working-tree edits and untracked files are not swept into the commit.

    NOTE(review): if one of *files* was only partially staged, re-adding it
    after formatting also stages its previously unstaged hunks.
    """
    print("Running ruff format...")
    file_args = format_file_args(files)
    run_command(f"uv run --extra dev ruff format {file_args}")
    # "--" stops git from parsing a path that starts with "-" as an option.
    run_command(f"git add -- {file_args}")

def run_ruff_check(files: list[str]):
    """Lint *files* with Ruff (applying autofixes) and re-stage them.

    If violations remain after fixing, the hook aborts and Ruff's stdout is
    echoed so the user can see what failed. Only *files* are re-staged, so
    unrelated working-tree edits and untracked files stay out of the commit.

    NOTE(review): if one of *files* was only partially staged, re-adding it
    after fixing also stages its previously unstaged hunks.
    """
    print("Running ruff check (with fixes)...")
    file_args = format_file_args(files)
    # print_stdout=True: without it a failing check would show only (usually
    # empty) stderr, hiding the reported violations. Mirrors run_ty_check.
    run_command(f"uv run --extra dev ruff check {file_args} --fix", print_stdout=True)
    # "--" stops git from parsing a path that starts with "-" as an option.
    run_command(f"git add -- {file_args}")

def run_ty_check(files: list[str]):
    """Type-check *files* with ty; the hook aborts if the check fails.

    On failure the tool's stdout is echoed along with its stderr.
    """
    print("Running ty check...")
    command = "uv run --extra dev ty check " + format_file_args(files)
    run_command(command, print_stdout=True)

def ensure_not_main():
    """Abort with exit code 1 when committing on the protected ``main`` branch.

    The exception is when the message of the commit at ``HEAD`` starts with
    "Bump version".

    NOTE(review): in a pre-commit hook the new commit does not exist yet, so
    ``HEAD`` is the *previous* commit. The exemption therefore checks the last
    commit's message, not the one being created. Confirm this is intended.
    """
    protected_branch = "main"
    if get_current_branch() != protected_branch:
        return

    log = subprocess.run(
        ['git', 'log', '-1', '--pretty=%B', 'HEAD'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    head_message = log.stdout.decode('utf-8').strip()

    if head_message.startswith("Bump version"):
        return
    print(f"Commits directly to {protected_branch} not allowed")
    sys.exit(1)

def main():
    """Run the pre-commit checks and exit with status 0 if they pass.

    Steps:
    1. Refuse direct commits to the protected branch.
    2. For staged ``.py`` files that exist in the working tree, run
       ruff format, ruff check --fix, and ty check, in that order.
    """
    ensure_not_main()

    staged_py = [
        path
        for path in get_changed_files()
        if path.endswith(".py") and os.path.exists(path)
    ]

    if not staged_py:
        # Nothing to lint; let the commit proceed.
        sys.exit(0)

    for step in (run_ruff_formatting, run_ruff_check, run_ty_check):
        step(staged_py)

    # Continue with the commit
    sys.exit(0)

if __name__ == "__main__":
    # Runs only when executed as a script (presumably as the git hook itself).
    main()
