#!/usr/bin/env python
# [MISE] description="Extract CI info"
# [MISE] hide=true
# [MISE] tools={python="latest"}
import dataclasses
import glob
import json
import logging
import os
import re
import shlex
import subprocess
import tomllib
from dataclasses import is_dataclass
from datetime import datetime
from io import StringIO
from pprint import pformat
from typing import List, Set

# Module logger; records propagate to the root handler configured in setup_logging().
LOG = logging.getLogger(__name__)


@dataclasses.dataclass(eq=True, frozen=True)
class Task:
    """A single leaf CI task emitted in the task-list outputs.

    Frozen with value equality, so instances are hashable and duplicates
    collapse when collected into a set.
    """

    # Slugified form of ``cmd``.
    name: str
    # Either a mise task name or a shell-quoted argument list.
    cmd: str


class Output:
    """Compute CI values from the mise config and emit them as step outputs.

    Every public method other than ``write`` is an output producer: ``write``
    calls it with the parsed ``mise.toml`` dict and writes
    ``<name-with-dashes>=<JSON value>`` to the file named by ``$GITHUB_OUTPUT``
    (or to an in-memory buffer when that variable is unset). Helpers must
    therefore start with an underscore, or they would be emitted as outputs.
    """

    def __init__(self):
        output = os.getenv("GITHUB_OUTPUT")
        if output:
            # Append mode keeps whatever the file already contains.
            self._fobj = open(output, "a", encoding="utf-8")
        else:
            self._fobj = StringIO()
        # Taken once so version() and clean_version() share the same timestamp.
        self._build_id = datetime.now().strftime("%Y%m%d%H%M")

    def write(self, mise_config):
        """Evaluate every public output method and write its JSON-encoded result.

        Methods are visited in ``dir()`` order (alphabetical). The output file
        is flushed at the end so nothing is left sitting in a buffer.
        """
        LOG.debug("Creating outputs")
        for name in dir(self):
            if name.startswith("_") or name == "write":
                continue
            output_name = name.replace("_", "-")
            value = getattr(self, name)(mise_config)
            line = f"{output_name}={json.dumps(value, default=encode_dataclass)}\n"
            self._fobj.write(line)
            LOG.info(line.strip())
        # The handle is never closed explicitly, so flush what was written.
        self._fobj.flush()

    def manifests_task(self, mise_config):
        """Name of the mise task to use for manifests.

        Returns ``"manifests"`` when mise.toml defines ``[tasks.manifests]``,
        otherwise (including when there is no ``[tasks]`` table at all) the
        fallback ``"k8s:manifests"``.
        """
        try:
            LOG.debug("Repository has explicit manifests task: %r", mise_config["tasks"]["manifests"])
            return "manifests"
        except KeyError:
            return "k8s:manifests"

    def _task_dependencies(self, task_name) -> "Set[Task]":
        """Resolve ``task_name`` into the set of leaf tasks it ultimately runs.

        A task with its own ``run`` or ``file`` is a leaf. Otherwise each entry
        of its ``depends`` list is expanded: string entries are resolved
        recursively by task name; any other entry is treated as an iterable of
        argument strings and becomes an ad-hoc task whose command is their
        shell-quoted join.
        """
        task_info = mise("tasks", "info", task_name)
        if task_info["run"] or task_info["file"]:
            LOG.debug("Task %s is a leaf", task_name)
            return {Task(slugify(task_name), task_name)}
        deps = set()
        for dep in task_info["depends"]:
            if isinstance(dep, str):
                deps.update(self._task_dependencies(dep))
            else:
                LOG.debug("Adding task %r", dep)
                cmd = shlex.join(dep)
                deps.add(Task(slugify(cmd), cmd))
        return deps

    @staticmethod
    def _sorted_tasks(tasks):
        """Return ``tasks`` as a list ordered by (name, cmd)."""
        # Set iteration order depends on per-process string hash randomization;
        # sorting keeps the emitted list identical from run to run.
        return sorted(tasks, key=lambda t: (t.name, t.cmd))

    def ci_tasks(self, mise_config) -> "List[Task]":
        """Leaf tasks behind the ``ci`` mise task, in a stable order."""
        return self._sorted_tasks(self._task_dependencies("ci"))

    def release_build_tasks(self, mise_config) -> "List[Task]":
        """Leaf tasks behind the ``release-build`` mise task, in a stable order."""
        return self._sorted_tasks(self._task_dependencies("release-build"))

    @staticmethod
    def _source_patterns(mise_config):
        """Return the glob patterns whose changes warrant a release, as a new list.

        These are the ``sources`` of the ``release-build`` task plus the GitHub
        workflow files. A fresh list is built so the parsed config is never
        mutated.
        """
        sources = mise_config["tasks"]["release-build"].get("sources", [])
        if isinstance(sources, str):
            # Guard against a single pattern given as a bare string.
            sources = [sources]
        patterns = list(sources)
        patterns.append(".github/workflows/*.y*ml")
        return patterns

    def _has_source_changes(self, mise_config):
        """Return True if a file changed since the last tag matches a source pattern.

        The base is the newest commit reachable from any tag
        (``git rev-list --tags --max-count=1``); when there are no tags the
        empty tree is used, so every file present at HEAD counts as changed.
        Patterns go through ``glob.translate(..., recursive=True)``, so ``**``
        spans directories.
        """
        depth = int(git(["git", "rev-list", "--count", "--all"]))
        if depth == 1:
            # A single reachable commit most likely means a shallow clone.
            LOG.warning("To correctly determine if source files has been changed, you need to fetch all git history.")
            LOG.warning("Consider setting 'fetch-depth: 0' in your checkout action if running in GitHub Actions.")
        base_sha = git(["git", "rev-list", "--tags", "--max-count=1"]).strip()
        if not base_sha:
            base_sha = git(["git", "hash-object", "-t", "tree", "/dev/null"]).strip()  # SHA of empty tree
        LOG.debug("Last tag has sha %r", base_sha)
        head_sha = git(["git", "rev-parse", "HEAD"]).strip()
        LOG.debug("Head has sha %r", head_sha)
        changed_files = git(["git", "diff", "--name-only", base_sha, head_sha]).splitlines(keepends=False)
        LOG.debug("Changed files: %s", pformat(changed_files))
        for src in self._source_patterns(mise_config):
            pattern = re.compile(glob.translate(src, recursive=True))
            for f in changed_files:
                if pattern.match(f):
                    LOG.debug("Found changed source file: %r", f)
                    return True
        LOG.debug("No changes to source files")
        return False

    def run_publish(self, mise_config):
        """Decide whether this run should publish.

        True only when HEAD is ``refs/heads/main`` and release sources changed
        since the last tag.
        """
        rev = git(["git", "rev-parse", "--symbolic-full-name", "HEAD"]).strip()
        if rev != "refs/heads/main":
            LOG.info("Not on main branch, rev is %r", rev)
            return False
        if not self._has_source_changes(mise_config):
            LOG.info("No changes to sources since last tag, no need to publish")
            return False
        return True

    def artifacts(self, mise_config):
        """Newline-joined ``outputs`` of the ``release-build`` task ("" if none)."""
        build_config = mise_config["tasks"]["release-build"]
        return "\n".join(build_config.get("outputs", []))

    def tool_versions(self, mise_config):
        """Map each locally configured tool to its requested versions.

        Versions are sorted in reverse *string* order (lexicographic, not
        semantic-version order).
        """
        tools = mise("ls", "--local")
        return {
            tool: sorted((v["requested_version"] for v in versions), reverse=True)
            for tool, versions in tools.items()
        }

    def _version(self, mise_config, clean=False):
        """Build ``<major>.<build_id>``, plus ``+<local>`` unless ``clean``.

        ``major`` is ``vars.major_version`` from mise.toml (default ``"0"``).
        ``local`` is ``git describe --always --dirty --exclude '*'``: every tag
        is excluded, so ``--always`` falls back to the abbreviated commit id,
        suffixed with ``-dirty`` for uncommitted changes.
        """
        config_vars = mise_config.get("vars", {})
        major = config_vars.get("major_version", "0")
        if clean:
            return f"{major}.{self._build_id}"
        local = git(["git", "describe", "--always", "--dirty", "--exclude", "*"]).strip()
        return f"{major}.{self._build_id}+{local}"

    def version(self, mise_config):
        """Version string including the git-derived local part."""
        return self._version(mise_config, clean=False)

    def clean_version(self, mise_config):
        """Version string without the git-derived local part."""
        return self._version(mise_config, clean=True)


def git(cmd):
    """Run a git command in the mise config root and return its stdout.

    Args:
        cmd: Full argument list, including the ``git`` executable itself.

    Returns:
        The command's standard output as text.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero. The
            error and captured output are logged first, then the exception is
            re-raised. Every caller in this script uses the result immediately,
            so returning ``None`` would only resurface as an unrelated
            TypeError/AttributeError.
    """
    root = os.getenv("MISE_CONFIG_ROOT", ".")
    try:
        p = subprocess.run(cmd, capture_output=True, text=True, check=True, cwd=root)
    except subprocess.CalledProcessError as e:
        LOG.error("%s", str(e))
        LOG.debug("StdOut: %s", pformat(e.stdout))
        LOG.debug("StdErr: %s", pformat(e.stderr))
        raise
    return p.stdout


def mise(*args):
    """Run ``mise <args> --json`` in the mise config root and parse its output.

    ``--json`` is appended unless the caller already passed it. The working
    directory is ``$MISE_CONFIG_ROOT`` (default ``.``).

    Returns:
        The decoded JSON document printed by mise.

    Raises:
        subprocess.CalledProcessError: if mise exits non-zero. The error and
            captured output are logged, then the exception is re-raised rather
            than returning ``None``, which every caller here would index into.
    """
    cmd = ["mise", *args]
    if "--json" not in cmd:
        cmd.append("--json")
    root = os.getenv("MISE_CONFIG_ROOT", ".")
    try:
        p = subprocess.run(cmd, capture_output=True, text=True, check=True, cwd=root)
    except subprocess.CalledProcessError as e:
        LOG.error("%s", str(e))
        LOG.debug("StdOut: %s", pformat(e.stdout))
        LOG.debug("StdErr: %s", pformat(e.stderr))
        raise
    return json.loads(p.stdout)


def encode_dataclass(value):
    """``json.dumps`` ``default=`` hook that serializes dataclass instances as dicts.

    Returns:
        ``dataclasses.asdict(value)`` for a dataclass instance.

    Raises:
        TypeError: for anything else, as the ``json`` module expects from a
            ``default`` hook. Returning the value unchanged would instead make
            the encoder fail with a misleading "Circular reference detected"
            ValueError.
    """
    # is_dataclass() is also true for dataclass *classes*, but asdict() only
    # accepts instances, so exclude types explicitly.
    if is_dataclass(value) and not isinstance(value, type):
        return dataclasses.asdict(value)
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def slugify(text):
    """Turn ``text`` into a lowercase, hyphen-separated slug.

    After lowercasing, every run of characters outside ``a-z``, ``æøå`` and
    ``0-9`` collapses to a single hyphen, and leading/trailing hyphens are
    stripped.
    """
    lowered = text.lower()
    hyphenated = re.sub(r'[^a-zæøå0-9]+', '-', lowered)
    # Hyphens are themselves outside the kept class, so the substitution above
    # already merges adjacent separators; no second squeeze pass is needed.
    return hyphenated.strip('-')


def load_mise():
    """Load mise.toml"""
    root = os.getenv("MISE_CONFIG_ROOT", ".")
    mise_path = os.path.join(root, "mise.toml")
    with open(mise_path, "rb") as fobj:
        return tomllib.load(fobj)


def setup_logging():
    logging.basicConfig(level=logging.DEBUG, format="::%(levelname)s::%(message)s")
    # GitHub wants lowercase levelnames for workflow commands mapping
    for level in logging.getLevelNamesMapping().values():
        logging.addLevelName(level, logging.getLevelName(level).lower())


def main():
    """Script entry point: set up logging, read mise.toml and write all CI outputs."""
    setup_logging()
    LOG.info("Extracting CI info")
    # Config is loaded before Output() so a bad mise.toml fails before the
    # output file is opened.
    config = load_mise()
    writer = Output()
    writer.write(config)


# Run only when executed as a script (e.g. by mise), not when imported.
if __name__ == "__main__":
    main()
