#!/usr/bin/env python

"""Updates the schemas for both the dev environment and the staging/production
environment.

To prepare the schemas for release first check the changes with the --prod flag:

    $ ./tools/update-schemas --prod --diff

This will compare the schemas with all `$id` and `url` fields removed, since
these fields can in theory be present anywhere in any schema. If the output looks
as expected you can write the production schemas:

    $ ./tools/update-schemas --prod --force

"""

from __future__ import annotations

import argparse
import difflib
import json
import os
import subprocess
import sys
from copy import deepcopy
from enum import Enum, auto
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, TypeVar

if TYPE_CHECKING:
    from fmu.datamodels._schema_base import SchemaBase

# ANSI escape sequences used to colour and style terminal output.
GREEN = "\033[32m"
RED = "\033[31m"
YELLOW = "\033[93m"
NC = "\033[0m"  # "no colour": resets any styling applied before it
BOLD = "\033[1m"
# Pre-rendered status tags that prefix the messages printed by this script.
PASS = f"[{BOLD}{GREEN}✔{NC}]"
FAIL = f"[{BOLD}{RED}✖{NC}]"
INFO = f"[{BOLD}{YELLOW}+{NC}]"

# Constrained type variable for `_remove_schema_ids`, which returns a value of the
# same kind (dict, list, or any other object) as the one it was given.
T = TypeVar("T", Dict, List, object)


def _get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Every option is a boolean switch (``store_true``) that defaults to ``False``.
    """
    # (short option, long option, help text), registered in this order.
    switches = (
        (
            "-d",
            "--diff",
            "Show a diff between the current schema and the new one in output.",
        ),
        (
            "-p",
            "--prod",
            "Produce schemas with production URLs",
        ),
        (
            "-t",
            "--test",
            "Run as normal, but don't write the file.",
        ),
        (
            "-f",
            "--force",
            "Force the script to overwrite the current schema with the new schema.",
        ),
    )

    cli = argparse.ArgumentParser()
    for short_opt, long_opt, description in switches:
        cli.add_argument(short_opt, long_opt, action="store_true", help=description)
    return cli


def _show_git_diff(output_filepath: Path) -> None:
    """Print the output of ``git diff`` for a file, indented by six spaces.

    A failure to run git (executable missing, not a git repository, ...) is reported
    with a FAIL message instead of being silently shown as an empty diff. It is never
    fatal: the diff is purely informational.

    Args:
        output_filepath: The file to diff against the git index.
    """
    command = ["git", "diff", str(output_filepath)]
    print(INFO, f"running `{' '.join(command)}` ...")
    try:
        # check=False: a non-zero exit is handled explicitly below.
        output = subprocess.run(command, capture_output=True, text=True, check=False)
    except OSError as e:
        # Raised when the git executable cannot be started at all.
        print(FAIL, f"could not run `{' '.join(command)}`: {e}")
        return

    if output.returncode != 0:
        print(
            FAIL,
            f"`{' '.join(command)}` failed with exit code {output.returncode}: "
            f"{output.stderr.strip()}",
        )
        return

    diff_str = "\n      ".join(output.stdout.split("\n"))
    print(f"      {diff_str}")  # To indent the first line too


def _show_py_diff(existing_schema: dict[str, Any], new_schema: dict[str, Any]) -> None:
    """Print a unified diff between two schemas, indented by six spaces.

    Both schemas are serialized as key-sorted, 2-space indented JSON before being
    compared line by line.
    """

    def _as_lines(schema: dict[str, Any]) -> list[str]:
        return json.dumps(schema, indent=2, sort_keys=True).splitlines()

    indent = " " * 6
    diff_lines = list(
        difflib.unified_diff(
            _as_lines(existing_schema),
            _as_lines(new_schema),
            fromfile="existing schema",
            tofile="new schema",
            lineterm="",
        )
    )
    # The leading indent covers the first line; the separator covers the rest.
    print(indent + f"\n{indent}".join(diff_lines))


def _remove_schema_ids(schema: T) -> T:
    """Return a copy of a schema with every '$id' and 'url' key stripped out.

    Dicts and lists are rebuilt recursively; any other value is returned as-is.
    """
    if isinstance(schema, list):
        return [_remove_schema_ids(element) for element in schema]
    if not isinstance(schema, dict):
        return schema

    stripped = {}
    for name, child in schema.items():
        if name in ("$id", "url"):
            continue
        stripped[name] = _remove_schema_ids(child)
    return stripped


def _schemas_are_equivalent(
    existing_schema: dict[str, Any],
    new_schema: dict[str, Any],
    is_release: bool,
    release_url: str,
) -> bool:
    """Tell whether the existing and the new schema are equivalent.

    Outside of a release the two schemas are compared as they are.

    For a release, every '$id' and 'url' field is stripped from copies of both
    schemas so that only the remaining fields are compared. A root '$id' is then put
    back on each copy: `release_url` on the existing schema and the new schema's own
    root '$id' on the new one, so the root self-reference has to match as well.
    """
    if not is_release:
        return existing_schema == new_schema

    stripped_existing = _remove_schema_ids(deepcopy(existing_schema))
    stripped_existing["$id"] = release_url

    # Read the root '$id' before stripping; it is restored afterwards.
    root_id = new_schema["$id"]
    stripped_new = _remove_schema_ids(deepcopy(new_schema))
    stripped_new["$id"] = root_id

    return stripped_existing == stripped_new


def _update_metadata_examples() -> None:
    """Updates the metadata examples to be in sync with the updated schemas.

    Runs the ``examples/update_examples.sh`` script, resolved relative to the current
    working directory, and captures its output.

    Raises:
        SystemExit: With exit code 1 if the script cannot be started or finishes
            with a non-zero return code.
    """
    script = "examples/update_examples.sh"

    print(INFO, "Updating metadata examples to be in sync with updated schemas...")

    print(INFO, f"Running bash script {script}")
    try:
        # check=False: the return code is inspected explicitly below.
        result = subprocess.run(script, capture_output=True, text=True, check=False)
    except OSError as e:
        # Script missing or not executable (e.g. run from the wrong directory).
        print(
            FAIL,
            f"🚨 The bash script for updating metadata examples could not be run: {e}",
        )
        sys.exit(1)

    # On POSIX a negative return code means the script was killed by a signal, so
    # anything other than zero is a failure.
    if result.returncode != 0:
        print(
            FAIL,
            f"🚨 The bash script for updating metadata examples failed "
            f"with exit code: {result.returncode}\n"
            f"Error: {result.stderr}",
        )
        sys.exit(1)

    print(PASS, "Done updating metadata examples.")
    

class SchemaUpdateResult(Enum):
    """The result of updating a schema."""

    # The schema on disk is already identical to the newly generated one.
    UNCHANGED = auto()
    # The new schema went through the write step (nothing is written in test mode).
    UPDATED = auto()
    # The update was refused; `main` exits with status 1 if any schema failed.
    FAILED = auto()


class SchemaWriter:
    """Handles schema finding, updating, and writing.

    The behaviour is controlled by four keyword-only flags:

    - ``is_release``: compare schemas in release mode, where '$id'/'url' fields
      are ignored except for the root '$id'.
    - ``show_diff``: print a diff when a schema has changed or was written.
    - ``is_test``: run as normal, but never create directories or write files.
    - ``force_overwrite``: skip the equivalence checks and overwrite the existing
      schema.
    """

    def __init__(
        self,
        *,
        is_release: bool = False,
        show_diff: bool = False,
        is_test: bool = False,
        force_overwrite: bool = False,
    ) -> None:
        self.force_overwrite = force_overwrite
        self.is_test = is_test
        self.is_release = is_release
        self.show_diff = show_diff

    def write_schema(
        self,
        schema_base: SchemaBase,
    ) -> SchemaUpdateResult:
        """Write schema to file after some checking.

        A schema with no file on disk is written directly. Unparsable JSON on disk
        is a failure unless overwriting is forced. Otherwise the existing and new
        schema are compared before anything is written.
        """
        output_filepath = self._get_output_filepath(schema_base.PATH)
        self._ensure_output_path(output_filepath.parent)

        new_schema = schema_base.dump()

        try:
            existing_schema = self._load_existing_schema(output_filepath)
        except FileNotFoundError:
            return self._write_new_schema(output_filepath, new_schema, schema_base)
        except json.JSONDecodeError as e:
            if not self.force_overwrite:
                print(FAIL, f"Invalid JSON in existing schema: {e}")
                return SchemaUpdateResult.FAILED
            return self._write_new_schema(output_filepath, new_schema, schema_base)

        return self._handle_existing_schema(
            existing_schema,
            new_schema,
            schema_base,
            output_filepath,
        )

    def _get_output_filepath(self, schema_path: Path) -> Path:
        """Returns the absolute output path of a schema, including its filename.

        `schema_path` is joined onto the project root, which is the parent of the
        directory that contains this script.
        """
        project_root = Path(__file__).parent.parent.resolve()
        return project_root / schema_path

    def _ensure_output_path(self, dir_path: Path) -> None:
        """Make sure the output directory exists, creating it when needed.

        In test mode no directory is created, but the "Created directory" message
        is still printed so the output matches a normal run.

        Raises:
            ValueError: If `dir_path` exists but is not a directory.
        """
        dir_exists = dir_path.exists()
        if dir_exists and not dir_path.is_dir():
            raise ValueError(f"Path exists but is not a directory: {dir_path}")

        if not self.is_test:
            dir_path.mkdir(parents=True, exist_ok=True)
        if not dir_exists:
            print(INFO, f"Created directory: {dir_path}")

    def _load_existing_schema(self, filepath: Path) -> dict[str, Any]:
        """Load and parse the JSON schema stored at `filepath`.

        Raises:
            FileNotFoundError: If there is no file at `filepath`.
            json.JSONDecodeError: If the file does not contain valid JSON.
        """
        # Opening the file raises FileNotFoundError by itself; no pre-check needed.
        with open(filepath, encoding="utf-8") as f:
            return json.load(f)

    def _handle_existing_schema(
        self,
        existing_schema: dict[str, Any],
        new_schema: dict[str, Any],
        schema_base: SchemaBase,
        output_filepath: Path,
    ) -> SchemaUpdateResult:
        """Handles the case when a schema version we're writing to already exists.

        Typically this means we need a version bump. In rare cases it could mean we need
        to force an overwrite.
        """
        if not self.force_overwrite and not _schemas_are_equivalent(
            existing_schema, new_schema, self.is_release, schema_base.url()
        ):
            print(
                FAIL,
                f"🚨 {BOLD}{schema_base.FILENAME}{NC} version "
                f"{BOLD}{schema_base.VERSION}{NC} "
                "has changed. does it need a new version?",
            )
            if self.show_diff:
                _show_py_diff(existing_schema, new_schema)
            return SchemaUpdateResult.FAILED

        # `.get` so that a schema on disk without a root '$id' is reported as a
        # mismatch instead of crashing with a KeyError.
        if (
            self.is_release
            and not self.force_overwrite
            and existing_schema.get("$id") != new_schema.get("$id")
        ):
            print(
                FAIL,
                f"🚨 {BOLD}{schema_base.FILENAME}{NC} version "
                f"{BOLD}{schema_base.VERSION}{NC}: mismatch in '$id'. "
                "you probably need to run `./tools/update-schemas --prod --force`",
            )
            if self.show_diff:
                _show_py_diff(existing_schema, new_schema)
            return SchemaUpdateResult.FAILED

        if new_schema == existing_schema:
            print(
                PASS,
                f"{BOLD}{schema_base.FILENAME}{NC} version "
                f"{BOLD}{schema_base.VERSION}{NC} unchanged",
            )
            return SchemaUpdateResult.UNCHANGED
        return self._write_new_schema(output_filepath, new_schema, schema_base)

    def _write_new_schema(
        self,
        output_filepath: Path,
        schema: dict[str, Any],
        schema_base: SchemaBase,
    ) -> SchemaUpdateResult:
        """Writes the schema to disk as key-sorted, 2-space indented JSON.

        In test mode the file is left untouched, but the messages are printed as in
        a normal run.
        """
        if not self.is_test:
            with open(output_filepath, "w", encoding="utf-8") as f:
                f.write(json.dumps(schema, indent=2, sort_keys=True))

        print(
            PASS,
            f"{BOLD}{schema_base.FILENAME}{NC} version {BOLD}{schema_base.VERSION}{NC} "
            f"written to:\n      {output_filepath}",
        )

        if self.show_diff:
            _show_git_diff(output_filepath)
        return SchemaUpdateResult.UPDATED


def main() -> None:
    """Regenerate every schema and exit with status 1 if any update failed."""
    args = _get_parser().parse_args()

    if args.force:
        print(INFO, "Forcing schema overwrite")

    # Set before the import below, which ensures URLs will differ based on it.
    if args.prod:
        os.environ["DEV_SCHEMA"] = ""
    else:
        os.environ["DEV_SCHEMA"] = "1"
    import fmu.datamodels as models

    writer = SchemaWriter(
        force_overwrite=args.force,
        is_test=args.test,
        show_diff=args.diff,
        is_release=args.prod,
    )

    # Process every schema before deciding on the exit status, so that all
    # failures are reported in a single run.
    outcomes = [writer.write_schema(schema) for schema in models.schemas]

    if any(outcome is SchemaUpdateResult.FAILED for outcome in outcomes):
        sys.exit(1)


# Run the update only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
