# Copyright (C) 2023 Luigi Pertoldi <gipert@pm.me>
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import os
from datetime import datetime
from pathlib import Path
import time

from snakemake.logging import logger
from dbetto import AttrsDict

from legendsimflow import aggregate, nersc, utils
from legendsimflow.aggregate import gen_list_of_all_hpges_valid_for_modeling


logger.info("Initializing workflow context")

# Build the simflow context from the Snakemake-provided `config` and
# `workflow` objects, then rebind the global `config` to the configuration
# held by the context so that the code below uses that object instead of the
# raw Snakemake config.
SIMFLOW_CONTEXT = utils.init_simflow_context(config, workflow)
config = SIMFLOW_CONTEXT.config

# Generating the list of HPGe detectors valid for modeling (across all
# channels) is slow, so do it once at startup and store it in the context
# under "modelable_hpges" for later reuse. The `write_to_file` argument
# presumably also dumps the list to
# `<config.paths.pars>/modelable_hpge_detectors.yaml` — confirm in
# `aggregate.gen_list_of_all_hpges_valid_for_modeling`.
SIMFLOW_CONTEXT |= {
    "modelable_hpges": gen_list_of_all_hpges_valid_for_modeling(
        config, write_to_file=config.paths.pars / "modelable_hpge_detectors.yaml"
    ),
}

logger.info("Done! Handing it over to Snakemake")


def dvs_ro(x):
    """Forward `x` to `legendsimflow.nersc.dvs_ro`, supplying the workflow `config`.

    Convenience wrapper so rule definitions do not have to pass the global
    configuration explicitly. Judging by the name, it presumably maps a path
    to its read-only DVS counterpart when running at NERSC. See
    `nersc.dvs_ro` for the actual behavior.
    """
    return nersc.dvs_ro(config, x)


# Global regular-expression constraints on wildcard values, applied to every
# rule in the workflow:
# - tier: word characters only
# - simid: word characters and dashes
# - jobid: digits only
# - runid: exactly four dash-separated groups of word characters
# - hpge_detector: one of V/P/B/C followed by exactly six word characters
# - hpge_voltage: digits only
wildcard_constraints:
    tier=r"\w+",
    simid=r"[-\w]+",
    jobid=r"\d+",
    runid=r"\w+(?:-\w+){3}",
    hpge_detector=r"[VPBC]\w{6}",
    hpge_voltage=r"\d+",


onstart:
    # Set up the log-directory link before any job starts (see
    # `utils.setup_logdir_link` for what exactly is linked).
    utils.setup_logdir_link(config)

    # make sure some packages are initialized before we begin to avoid race conditions
    # https://numba.readthedocs.io/en/stable/developer/caching.html#cache-sharing
    # Each module listed in `config.precompile_pkg` is imported once, in a
    # separate `python` process, so that the Numba cache is populated before
    # parallel jobs import the same modules concurrently. This is skipped when
    # `workflow.touch` is set (presumably `snakemake --touch`, where no jobs
    # actually run).
    # NOTE(review): `python` is resolved from the shell's PATH, which may not
    # be the interpreter running Snakemake — confirm this is intended.
    if not workflow.touch and "precompile_pkg" in config:
        logger.info(
            f"Pre-compiling Numba cache (NUMBA_CACHE_DIR={os.environ.get('NUMBA_CACHE_DIR')})"
        )
        for module in config.precompile_pkg:
            shell(f"python -c 'from {module} import *'")

    logger.info("Starting workflow")


# Auxiliary rules are always included, independently of `make_steps`.
include: "rules/aux.smk"


# Steps to build: all known steps (`aggregate.STEPS_ORDERED`) unless the
# `make_steps` configuration variable restricts them. The list is sorted in
# pipeline order and written back to `config`, so anything reading
# `config["make_steps"]` later sees the normalised list.
make_steps = config.get("make_steps", aggregate.STEPS_ORDERED)
make_steps = utils.sorted_by(make_steps, aggregate.STEPS_ORDERED)
config["make_steps"] = make_steps

# Include the rule file of each requested step (`rules/<step>.smk`).
for step in make_steps:

    include: f"rules/{step}.smk"


def gen_target_all():
    """Return the list of input files for the default ``all`` rule.

    If the ``simlist`` configuration variable is set to anything other than
    ``"all"`` or ``"*"`` (the default), target generation is delegated to
    ``aggregate.process_simlist()``. Otherwise, the inputs of the
    ``gen_all_tier_<step>`` rule are concatenated, in ``make_steps`` order, for
    every step in ``make_steps`` that is one of the tiers cvt, evt, hit, opt or
    stp. Any other step (e.g. ``vtx``, ``par``) contributes nothing here.
    """
    simlist = config.get("simlist", "*")
    if simlist not in ("all", "*"):
        return aggregate.process_simlist(config, make_steps=make_steps)

    tier_steps = {"cvt", "evt", "hit", "opt", "stp"}

    targets = []
    for tier in make_steps:
        if tier not in tier_steps:
            # non-tier steps (vtx, par, ...) are not collected here
            continue
        targets.extend(getattr(rules, f"gen_all_tier_{tier}").input)

    return targets


rule all:
    """Default rule.

    Builds the targets returned by `gen_target_all()`. When the `simlist`
    configuration variable is unset, "all" or "*", these are the output files
    of every tier listed in the `make_steps` configuration variable (a list of
    steps; steps such as `vtx` and `par` add no targets here). Otherwise they
    are the targets computed by `aggregate.process_simlist()` for the simids
    listed in `simlist`.
    """
    default_target: True
    input:
        gen_target_all(),
