#!/usr/bin/env python3

import argparse
import sys
import os
from logging import basicConfig, getLogger, info, warning, critical
import importlib
import importlib.metadata
import importlib.util
import pathlib
# The CLI needs several optional third-party packages; they are imported
# together so that a single, actionable message can be shown if any is missing.
try:
    import rich
    # Import the submodule explicitly: 'rich.console' is accessed below and
    # should not depend on being pulled in as a side effect of 'rich.logging'.
    import rich.console
    from rich.logging import RichHandler
    from rich_argparse import RichHelpFormatter
    from rich.progress import track
    import jinja2
    import feyngraph as fg
    try:
        import tomllib as toml
    except ImportError:
        # Fall back to the 'tomli' package when 'tomllib' is not available.
        import tomli as toml
except ImportError:
    # Passing a string to sys.exit() prints it to stderr and exits with
    # status 1, so callers and shell scripts can detect the failure.
    sys.exit("Unable to import all dependencies, please install FeynGraph with the optional 'cli' feature, e.g. 'pip install feyngraph[cli]'.")

# Version of the installed 'feyngraph' distribution, as reported by its metadata.
VERSION = importlib.metadata.version('feyngraph')

# Console instance used by the script body for status output.
console = rich.console.Console()

if __name__ == "__main__":
    ###########################################################################
    # Basic Setup
    ###########################################################################
    basicConfig(
        level="NOTSET",
        format="%(message)s",
        datefmt="[%X]",
        handlers=[
            RichHandler(show_path=False, rich_tracebacks=True),
        ],
    )

    cli_parser = argparse.ArgumentParser(prog="FeynGraph", description="A modern Feynman diagram generation toolkit", formatter_class=RichHelpFormatter)
    cli_parser.add_argument("config", help="Configuration file for the Diagram generation in TOML format")
    cli_parser.add_argument("-t", "--template", help="Jinja2 template to render the output with", default = None)
    cli_parser.add_argument("-o", "--outfile", help="Name of the generated output file", default = "diagrams.out")
    cli_parser.add_argument("-j", "--jobs", help = "Number of worker threads to use", type=int, default = None)
    cli_parser.add_argument("--loglevel", choices = ["DEBUG", "INFO", "WARN", "ERROR"], default = "INFO")
    cli_parser.add_argument("--version", action = "version", version = VERSION)

    cli_args = cli_parser.parse_args()

    getLogger().setLevel(cli_args.loglevel)

    info(f"FeynGraph version {VERSION} running in executable mode")

    ###########################################################################
    # Config File Checks
    ###########################################################################
    with open(cli_args.config, "rb") as f:
        conf = toml.load(f)
    if cli_args.template is not None:
        conf["template"] = cli_args.template
    if not "template" in conf:
        info(f"Config file {cli_args.config} does not contain property 'template', defaulting to 'json.jinja'.")
        conf["template"] = "json.jinja"
    if not "process" in conf:
        critical(f"Config file {cli_args.config} does not contain a '[process]' section.")
        sys.exit()
    if not "in" in conf["process"] or not "out" in conf["process"]:
        critical(f"Config file {cli_args.config} does not contain the required 'in' and 'out' properties in the '[process]' section.")
        sys.exit()
    if not "loops" in conf["process"]:
        info(f"Config file {cli_args.config} does not contain property 'loops' in the '[process]' section, defaulting to 'loops = 0'.")
        conf["process"]["loops"] = 0
    if not "model" in conf["process"]:
        info(f"Config file {cli_args.config} does not contain property 'model' in the '[process]' section, defaulting to the Standard Model in Feynman gauge.")
        conf["process"]["model"] = None
    if not "momenta" in conf["process"]:
        info(f"Config file {cli_args.config} does not contain property 'momenta' in the '[process]' section, defaulting to 'p<i>' for external momenta and 'l<i>' for loop momenta.")
        conf["process"]["momenta"] = None
    else:
        n_legs = len(conf["process"]["in"]) + len(conf["process"]["out"])
        n_loops = conf["process"]["loops"]
        if len(conf["process"]["momenta"]) != n_legs + n_loops:
            critical(f"The given process with {n_legs} legs and {n_loops} loops has {n_legs + n_loops} associated momenta, but {len(conf["process"]["momenta"])} labels were specified. Please speficy exactly {n_legs + n_loops} labels.")
            sys.exit()
    if "drawing" in conf:
        if not "format" in conf["drawing"]:
            conf["drawing"]["format"] = "tikz"
        elif conf["drawing"]["format"] != "tikz" and conf["drawing"]["format"] != "svg":
            critical(f"Unknown drawing format '{conf["drawing"]["format"]}', the supported formats are 'tikz' and 'svg'.")
            sys.exit()
        if not "outdir" in conf["drawing"]:
            conf["drawing"]["outdir"] = "feyngraph_drawings"
        if not "filename" in conf["drawing"]:
            conf["drawing"]["filename"] = "diagram{i}"
    if cli_args.jobs is not None:
        fg.set_threads(cli_args.jobs)

    ###########################################################################
    # Load and Prepare Model
    ###########################################################################
    if conf["process"]["model"] is not None:
        if os.path.isfile(conf["process"]["model"]):
            info(f"Loading QGRAF model file {conf['process']['model']}...")
            model = fg.Model.from_qgraf(conf["process"]["model"])
        elif os.path.isdir(conf["process"]["model"]):
            info(f"Loading UFO model {conf['process']['model']}...")
            model = fg.Model.from_ufo(conf["process"]["model"])
        else:
            critical(f"Unable to import model {conf['process']['model']} because it is not a valid QGRAF or UFO model.")
            sys.exit()
        info(f"Successfully imported model {conf['process']['model']} with {len(model.particles())} particles and {len(model.vertices())} vertices.")
    else:
        info(f"Loading Standard Model in Feynman Gauge...")
        model = fg.Model()
        info(f"Successfully imported Standard Model with {len(model.particles())} particles and {len(model.vertices())} vertices.")

    ###########################################################################
    # Setup Diagram Filters
    ###########################################################################
    if "filter" in conf:
        selector = fg.DiagramSelector()
        for key, value in conf["filter"].items():
            match key:
                case "onshell":
                    if conf["filter"]["onshell"]:
                        selector.select_on_shell()
                case "self_loops":
                    selector.select_self_loops(conf["filter"]["self_loops"])
                case "opi_components":
                   selector.select_opi_components(conf["filter"]["self_loops"])
                case "coupling_orders":
                   for coupling, power in conf["filter"]["coupling_orders"].items():
                      selector.add_coupling_power(coupling, power)
                case "custom_function":
                    if not "custom_code" in conf["filter"]:
                        critical(f"Custom filter function '{conf['filter']['custom_function']}' requested in config, but no code to define it was given.")
                        sys.exit()
                    exec(conf["filter"]["custom_code"], globals())
                    selector.add_custom_function(globals()[conf["filter"]["custom_function"]])
                case "custom_code":
                    if not "custom_function" in conf["filter"]:
                        warning("The config contains custom filter code, but no custom filter function was defined.")
                    continue
                case _:
                   warning(f"Unknown filter keyword '{key}', ignoring.")
    else:
        selector = None

    ###########################################################################
    # Generate Diagrams
    ###########################################################################
    n_legs = len(conf["process"]["in"]) + len(conf["process"]["in"])
    info(f"Generating Feynman diagrams with {n_legs} legs and {conf['process']['loops']} loops.")
    with console.status(f"Generating diagrams...", spinner="bouncingBall"):
        diag_gen = fg.DiagramGenerator(conf["process"]["in"], conf["process"]["out"], conf["process"]["loops"], model = model, selector = selector)
        if conf["process"]["momenta"] is not None:
            diag_gen.set_momentum_labels(conf["process"]["momenta"])
        diagrams = diag_gen.generate()
    info(f"Successfully generated {len(diagrams)} Feynman diagrams.")

    ###########################################################################
    # Render Output Template
    ###########################################################################
    info(f"Rendering output template '{conf['template']}'.")
    if os.path.isfile(conf["template"]):
        with open(conf["template"], "r") as f:
            template = jinja2.Template(f.read())
    else:
        fg_prefix = pathlib.Path(importlib.util.find_spec("feyngraph").origin).parent.parent.parent.parent.parent.absolute() # type: ignore[reportOptionalMemberAccess]
        template_dir = os.path.join(fg_prefix, 'share', 'FeynGraph', 'templates')
        if os.path.isfile(os.path.join(template_dir, conf["template"])):
            with open(os.path.join(template_dir, conf["template"]), "r") as f:
                template = jinja2.Template(f.read())
        else:
            critical(f"Unable to read template '{conf['template']}'")
            sys.exit()
    with open(cli_args.outfile, "w") as f:
        f.write(template.render({"diagrams": diagrams}))
    info(f"Successfully wrote {len(diagrams)} diagrams to '{cli_args.outfile}'.")

     ###########################################################################
     # Draw Diagrams
     ###########################################################################
    if "drawing" in conf:
        info(f"Drawing {len(diagrams)} Feynman diagrams.")
        if not os.path.isdir(conf["drawing"]["outdir"]):
            os.mkdir(conf["drawing"]["outdir"])
        for i, d in track(enumerate(d for d in diagrams), transient = True):
            filename = conf["drawing"]["filename"].format(i=i)
            match conf["drawing"]["format"]:
                case "svg":
                    d.draw_svg(os.path.join(conf["drawing"]["outdir"], filename + ".svg"))
                case "tikz":
                    d.draw_tikz(os.path.join(conf["drawing"]["outdir"], filename + ".tizk"))
        info(f"Successfully wrote {len(diagrams)} diagrams in '{conf['drawing']['format']}' format to '{conf['drawing']['outdir']}'.")
