from utils import decimal_year_to_date
from scripts.populate_beast_template import taxa_from_fasta
from snk_cli import validate_config

def normalize_bool_strings(value):
    """Recursively turn the strings "true"/"false" into real booleans.

    Strings are compared after stripping surrounding whitespace and
    lower-casing. Dicts and lists are rebuilt with every contained value
    normalised the same way (dict keys are left alone). Any other value,
    including strings that do not spell a boolean, is returned unchanged.
    """
    if isinstance(value, str):
        # Fall back to the original (unstripped) string when it is not a boolean word.
        return {"true": True, "false": False}.get(value.strip().lower(), value)
    if isinstance(value, list):
        return [normalize_bool_strings(element) for element in value]
    if isinstance(value, dict):
        normalized = {}
        for name, entry in value.items():
            normalized[name] = normalize_bool_strings(entry)
        return normalized
    return value

# Convert "true"/"false" strings anywhere in the config into booleans, then
# hand the result to validate_config together with the snk.yaml that sits one
# directory above the workflow base directory.
config = normalize_bool_strings(config)
validate_config(config, Path(workflow.basedir).parent / "snk.yaml")

if config["alignment"] is None:
    raise ValueError("Must specify alignment")

# "alignment", "clock" and "group" may each be given as a single value or as a
# list; from here on they are always lists.
if not isinstance(config["alignment"], list):
    config["alignment"] = [config["alignment"]]

if not isinstance(config["clock"], list):
    config["clock"] = [config["clock"]]

# A missing group becomes an empty list rather than [None].
if config["group"] is None:
    config["group"] = []
elif not isinstance(config["group"], list):
    config["group"] = [config["group"]]

allowed_clocks = ['strict', 'relaxed', 'flc-stem', 'flc-shared-stem', 'flc-clade', 'flc-shared-clade', 'flc-stem-and-clade', 'flc-shared-stem-and-clade']

if any(clock not in allowed_clocks for clock in config["clock"]):
    raise ValueError(f"Invalid clock specified. Allowed clocks are: {', '.join(allowed_clocks)}")

# Every clock whose name contains "flc" needs at least one group.
if any("flc" in clock for clock in config["clock"]) and not config["group"]:
    raise ValueError("Must specify at least one group when an FLC clock is selected")

# Locations inside the workflow source tree.
SNAKE_DIR = Path(workflow.basedir)
SCRIPT_DIR = SNAKE_DIR / "scripts"
TEMPLATE_DIR = SNAKE_DIR / "templates"

# Results root; when "dated" is set, results go into a sub-directory named
# after the time at which this file is evaluated.
OUT_DIR = Path(config["output"]["dir"])
if config["output"]["dated"]:
    from datetime import datetime

    OUT_DIR /= datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
CLOCK_DIR = OUT_DIR / "clocks"

# A template given in the config wins; otherwise use the one in TEMPLATE_DIR.
beast_xml_template = config["beast"]["template"] or (TEMPLATE_DIR / "beast_xml_template.jinja")

# Switches controlling which parts of the workflow are requested.
fit_clocks: bool = config["beast"]["fit_clocks"]
MLE: bool = config["marginal_likelihood"].get("estimate")

# Gamma prior settings; they are embedded in every clock name below.
rate_gamma_prior_shape = config["rate_gamma_prior_shape"]
rate_gamma_prior_scale = config["rate_gamma_prior_scale"]

# One name per clock/shape/scale combination, e.g. "<clock>_<shape>_<scale>".
clocks = expand(
    "{clock}_{rate_gamma_prior_shape}_{rate_gamma_prior_scale}",
    clock=config["clock"],
    rate_gamma_prior_shape=rate_gamma_prior_shape,
    rate_gamma_prior_scale=rate_gamma_prior_scale,
)
relaxed_clocks = [name for name in clocks if "relaxed" in name]
flc_clocks = [name for name in clocks if "flc" in name]

# Alignment inputs and how sampling dates are located (passed to taxa_from_fasta).
alignment_paths = config["alignment"]
date_index = config["date_index"]
date_delimiter = config["date_delimiter"]
has_multiple_partitions = len(alignment_paths) >= 2

# 1-based duplicate numbers; empty when the corresponding stage is switched off.
if fit_clocks:
    duplicates = range(1, config["beast"]["duplicates"] + 1)
else:
    duplicates = []

if MLE:
    mle_duplicates = range(1, config["marginal_likelihood"].get("duplicates") + 1)
else:
    mle_duplicates = []

# Log file of every duplicate of every clock.
ALL_LOG_FILES = expand(
    CLOCK_DIR / "{clock}" / "{clock}_{duplicate}" / "{clock}_{duplicate}.log",
    clock=clocks,
    duplicate=duplicates,
)


def PER_CLOCK_LOG_FILES(wildcards):
    """Return the log file paths of all duplicates of ``wildcards.clock``."""
    run_names = (f"{wildcards.clock}_{duplicate}" for duplicate in duplicates)
    return [CLOCK_DIR / wildcards.clock / run / f"{run}.log" for run in run_names]

# Taxa are read from the first alignment only.
TAXA = taxa_from_fasta(
    alignment_paths[0],
    date_delimiter=date_delimiter,
    date_index=date_index,
)

# Fail with a clear message instead of letting max() raise its generic
# "empty sequence" ValueError when the alignment yields no taxa.
if not TAXA:
    raise ValueError(f"No taxa found in alignment: {alignment_paths[0]}")

# The taxon with the largest .date value gives the most recent sampling date.
most_recent_sampling_date = decimal_year_to_date(
    max(TAXA, key=lambda taxon: taxon.date).date
)

print(f"Running Episodic with {len(TAXA)} taxa")
print(f"Most recent sampling date: {most_recent_sampling_date}")

# Pull in the rule definitions. These includes come after the configuration
# globals defined above and before CLOCK_FILES / OUTPUT_FILES are built below.
# NOTE(review): presumably the rule files read the globals defined above
# (e.g. CLOCK_DIR, clocks, duplicates) — confirm against the rule files.
include: "rules/beast.smk"
include: "rules/config.smk"
include: "rules/report.smk"
include: "rules/tree.smk"




# Targets that are requested when clocks are fitted (see OUTPUT_FILES).
CLOCK_FILES = []

# Trace plot targets for every clock duplicate and a summary table per clock.
# The two expand() results are added as nested lists, as before.
CLOCK_FILES.extend(
    [
        expand(CLOCK_DIR / "{clock}" / "{clock}_{duplicate}" / "{clock}_{duplicate}_trace_plots", clock=clocks, duplicate=duplicates),
        expand(CLOCK_DIR / "{clock}" / "{clock}-summary.csv", clock=clocks),
    ]
)

# Violin plot and odds table for each FLC clock.
CLOCK_FILES.extend(
    [CLOCK_DIR / clock / f"{clock}-violin.svg" for clock in flc_clocks],
)
CLOCK_FILES.extend(
    [CLOCK_DIR / clock / f"{clock}-odds.csv" for clock in flc_clocks],
)

# Combined odds table, only when at least one FLC clock is present. The prior
# values go through expand() exactly like the combined plots below, so a
# scalar gives the same file name as before and a list of values gives one
# target per shape/scale combination instead of a path containing "[...]".
if flc_clocks:
    CLOCK_FILES.extend(
        expand(
            CLOCK_DIR / "clocks_{rate_gamma_prior_shape}_{rate_gamma_prior_scale}-odds.csv",
            rate_gamma_prior_shape=rate_gamma_prior_shape,
            rate_gamma_prior_scale=rate_gamma_prior_scale,
        ),
    )

# Combined trace and violin plots.
CLOCK_FILES.extend(
    expand(
        CLOCK_DIR / "clocks_{rate_gamma_prior_shape}_{rate_gamma_prior_scale}-{type}.svg",
        type=["trace", "violin"],
        rate_gamma_prior_shape=rate_gamma_prior_shape,
        rate_gamma_prior_scale=rate_gamma_prior_scale,
    ),
)

# MCC tree targets (Newick and SVG) for every duplicate of every clock; the
# heights label comes from the config and falls back to "mean".
CLOCK_FILES += expand(
    CLOCK_DIR / "{clock}" / "{clock}_{duplicate}" / "{clock}_{duplicate}.mcc.{heights}.{ext}",
    clock=clocks,
    duplicate=duplicates,
    heights=config["mcc_tree"].get("heights", "mean"),
    ext=["nwk", "svg"],
)

# Stem rate quantile targets, requested for relaxed clocks only.
CLOCK_FILES += expand(
    CLOCK_DIR / "{clock}" / "{clock}_{duplicate}" / "{clock}_{duplicate}.stem.rate_quantiles.{ext}",
    clock=relaxed_clocks,
    duplicate=duplicates,
    ext=["csv", "svg"],
)

# Per-partition local rate posteriors, requested for FLC clocks only when
# more than one alignment was given.
if has_multiple_partitions:
    CLOCK_FILES += expand(
        CLOCK_DIR / "{clock}" / "{clock}-partition_local_rate_posteriors.{ext}",
        clock=flc_clocks,
        ext=["csv", "svg"],
    )

# Substitution model rate targets for every FLC clock.
CLOCK_FILES += expand(
    CLOCK_DIR / "{clock}" / "{clock}-substitution_model_rates.{ext}",
    clock=flc_clocks,
    ext=["csv", "svg"],
)



# Always requested: the config.yaml and taxon_groups.tsv targets in OUT_DIR.
OUTPUT_FILES = [OUT_DIR / name for name in ("config.yaml", "taxon_groups.tsv")]

# Clock-fitting targets are added only when clock fitting is enabled.
if fit_clocks:
    OUTPUT_FILES += CLOCK_FILES

# Marginal likelihood plot, only when estimation is enabled.
if MLE:
    OUTPUT_FILES += [OUT_DIR / "mle" / "mle.svg"]

# Aggregate target rule: its inputs are every entry collected in OUTPUT_FILES.
# NOTE(review): the rule files are included earlier in this file, so this may
# not be the first rule Snakemake sees — confirm how the default target is chosen.
rule all:
    input: 
        *OUTPUT_FILES
