import os

# Relative path to the directory where you store the data of this workflow.
# The helper functions below look for per-system configs under
# <data_root>/<system>/input/config.yml.
data_root = "data"

# Workflow-wide configuration: `config` is indexed by the {lumping} wildcard
# and read for the "kernel similarity" and "feature kernel" entries.
configfile: os.path.join("workflow", "lumpings.yml")

# TODO:
# - Is `source` in the config file required? Probably yes, because of all the files that are loaded within Python.


def get_d(wildcards):
  """Return the "kernel similarity" entry of the lumping selected by the wildcards.

  `wildcards.lumping` is used as the key into the global `config`.
  Raises KeyError if the lumping or the entry is missing.
  """
  lumping_settings = config[wildcards.lumping]
  return lumping_settings["kernel similarity"]

def get_g(wildcards):
  """Return the "feature kernel" entry of the lumping selected by the wildcards.

  `wildcards.lumping` is used as the key into the global `config`.
  Raises KeyError if the lumping or the entry is missing.
  """
  lumping_settings = config[wildcards.lumping]
  return lumping_settings["feature kernel"]

def get_top_file(wildcards):
  """Return the topology file path of the system selected by the wildcards.

  The system's `input/config.yml` is loaded with the `configfile:` directive
  first; the path is then the `source` entry joined with the `topology file`
  entry of `config`.
  """
  configfile: os.path.join(data_root, wildcards.system, "input", "config.yml")
  source_dir, top_name = config["source"], config["topology file"]
  return os.path.join(source_dir, top_name)

def get_xtc_file(wildcards):
  """Return the trajectory (xtc) file path of the system selected by the wildcards.

  The system's `input/config.yml` is loaded with the `configfile:` directive
  first; the path is then the `source` entry joined with the `xtc file`
  entry of `config`.
  """
  configfile: os.path.join(data_root, wildcards.system, "input", "config.yml")
  source_dir, xtc_name = config["source"], config["xtc file"]
  return os.path.join(source_dir, xtc_name)

def get_contact_index_file(wildcards):
  """Return the contact index file path of the selected system, or "none".

  The system's `input/config.yml` is loaded with the `configfile:` directive
  first. If `config` has no "contact index file" entry the literal string
  "none" is returned, otherwise the `source` entry joined with that entry.

  NOTE(review): `config` is shared by the whole workflow; presumably an entry
  loaded earlier for another system stays visible here -- confirm.
  """
  configfile: os.path.join(data_root, wildcards.system, "input", "config.yml")
  if "contact index file" not in config:
    return "none"
  return os.path.join(config["source"], config["contact index file"])

def get_xtc_stride(wildcards):
  """Return the "xtc stride" entry of the selected system as int (default 1).

  The system's `input/config.yml` is loaded with the `configfile:` directive
  first. A missing "xtc stride" entry yields 1.
  """
  configfile: os.path.join(data_root, wildcards.system, "input", "config.yml")
  return int(config["xtc stride"]) if "xtc stride" in config else 1

def get_mem_mb(wildcards):
  """Return the memory request in MB for jobs that read the system's xtc file.

  The system's `input/config.yml` is loaded with the `configfile:` directive
  first. The result is the size of `<source>/<xtc file>` in MB, rounded up,
  plus a fixed 500 MB.

  An int is returned because Snakemake resources must be integers or strings;
  a float is rejected or rounded depending on the Snakemake version.
  Raises OSError if the xtc file does not exist.
  """
  configfile: os.path.join(data_root, wildcards.system, "input", "config.yml")
  xtc_path = os.path.join(config["source"], config["xtc file"])
  size_bytes = os.path.getsize(xtc_path)
  # Ceiling division keeps the request at or above the file size.
  return -(-size_bytes // 10**6) + 500

# Generate Z.npy for one {system}/{lumping} pair: MPP.run is called with the
# system config, the lumping's "kernel similarity" (d) and "feature kernel" (g),
# and writes the result to the path given after -Z.
# Paths are rooted at `data_root`, like the helper functions above.
rule gen_Z:
  input:
    os.path.join(data_root, "{system}", "input", "config.yml"),
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "Z.npy"),
  params:
    d=get_d,
    g=get_g,
  conda:
    "mpp"
  # Eligible for Snakemake's between-workflow caching.
  cache: True
  shell:
    "python -m MPP.run {input} {params.d} {params.g} -Z {output}"

# Create a single plot file <plot>.<ext> from an existing Z.npy. The plot name
# is taken from the {plot} wildcard and passed to MPP.run via -p; {ext} is
# restricted to pdf, png or txt.
# All paths are rooted at `data_root` (previously mixed with a literal "data").
rule plot:
  input:
    config=os.path.join(data_root, "{system}", "input", "config.yml"),
    z=os.path.join(data_root, "{system}", "results", "{lumping}", "Z.npy"),
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "{plot}.{ext}"),
  wildcard_constraints:
    ext="pdf|png|txt"
  params:
    d=get_d,
    g=get_g,
  conda:
    "mpp"
  shell:
    "python -m MPP.run {input.config} {params.d} {params.g} -Z {input.z} -p {wildcards.plot} -o {output}"

# Run MPP.run with --rmsd to produce rmsd_CA.npy for a lumping. The memory
# request is derived from the size of the system's xtc file (get_mem_mb).
# NOTE(review): the `ndx` output is declared but not passed on the command
# line; presumably MPP.run derives its name from the --rmsd path -- confirm.
# Paths are rooted at `data_root`, like the helper functions above.
rule rmsd_CA:
  input:
    config=os.path.join(data_root, "{system}", "input", "config.yml"),
    z=os.path.join(data_root, "{system}", "results", "{lumping}", "Z.npy"),
  output:
    npy=os.path.join(data_root, "{system}", "results", "{lumping}", "rmsd_CA.npy"),
    ndx=os.path.join(data_root, "{system}", "results", "{lumping}", "rmsd_CA_mean_frames.ndx"),
  params:
    d=get_d,
    g=get_g,
  resources:
    mem_mb=get_mem_mb
  conda:
    "mpp"
  shell:
    "python -m MPP.run {input.config} {params.d} {params.g} -Z {input.z} --rmsd {output.npy}"

# Same as rmsd_CA, but MPP.run is additionally called with
# `--rmsd-feature feature` and writes rmsd_feature.npy.
# NOTE(review): the `ndx` output is declared but not passed on the command
# line; presumably MPP.run derives its name from the --rmsd path -- confirm.
# Paths are rooted at `data_root`, like the helper functions above.
rule rmsd_feature:
  input:
    config=os.path.join(data_root, "{system}", "input", "config.yml"),
    z=os.path.join(data_root, "{system}", "results", "{lumping}", "Z.npy"),
  output:
    npy=os.path.join(data_root, "{system}", "results", "{lumping}", "rmsd_feature.npy"),
    ndx=os.path.join(data_root, "{system}", "results", "{lumping}", "rmsd_feature_mean_frames.ndx"),
  params:
    d=get_d,
    g=get_g,
  resources:
    mem_mb=get_mem_mb
  conda:
    "mpp"
    # "mpp.yml"
  shell:
    "python -m MPP.run {input.config} {params.d} {params.g} -Z {input.z} --rmsd {output.npy} --rmsd-feature feature"

# Checkpoint: call MPP.run with `-r <n>` (n = 20) and the `random_frames`
# directory as -o target. The aggregate_input* functions glob this directory
# for `{macrostate}.ndx` files once the checkpoint has run.
# Paths are rooted at `data_root`, like the helper functions above.
checkpoint draw_random:
  input:
    config=os.path.join(data_root, "{system}", "input", "config.yml"),
    z=os.path.join(data_root, "{system}", "results", "{lumping}", "Z.npy"),
  output:
    macrostates=directory(os.path.join(data_root, "{system}", "results", "{lumping}", "random_frames")),
  params:
    d=get_d,
    g=get_g,
    n=20,
  conda:
    "mpp"
    # "mpp.yml"
  shell:
    "python -m MPP.run {input.config} {params.d} {params.g} -Z {input.z} -r {params.n} -o {output}"

def aggregate_input(wildcards):
    """Return the PNG targets for every macrostate produced by `draw_random`.

    The checkpoint output directory is requested for the current wildcards and
    globbed for `{macrostate}.ndx` files; one
    `random_pngs/{macrostate}.png` path is returned per file found.
    Paths are rooted at `data_root`, like the other helper functions.
    """
    frames_dir = checkpoints.draw_random.get(**wildcards).output[0]
    ndx_pattern = os.path.join(frames_dir, "{macrostate}.ndx")
    return expand(
        os.path.join(data_root, "{system}", "results", "{lumping}", "random_pngs", "{macrostate}.png"),
        system=wildcards.system,
        lumping=wildcards.lumping,
        macrostate=glob_wildcards(ndx_pattern).macrostate,
    )

def aggregate_input_pdb(wildcards):
    """Return the PDB targets for every macrostate produced by `draw_random`.

    The checkpoint output directory is requested for the current wildcards and
    globbed for `{macrostate}.ndx` files; one
    `random_pdbs/{macrostate}.pdb` path is returned per file found.
    Paths are rooted at `data_root`, like the other helper functions.
    """
    frames_dir = checkpoints.draw_random.get(**wildcards).output[0]
    ndx_pattern = os.path.join(frames_dir, "{macrostate}.ndx")
    return expand(
        os.path.join(data_root, "{system}", "results", "{lumping}", "random_pdbs", "{macrostate}.pdb"),
        system=wildcards.system,
        lumping=wildcards.lumping,
        macrostate=glob_wildcards(ndx_pattern).macrostate,
    )

# Request all per-macrostate PNGs (see aggregate_input) and write their paths
# into macrostates.txt.
# The output path is rooted at `data_root`, like the helper functions above.
rule aggregate:
  input:
    aggregate_input
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "macrostates.txt"),
  shell:
    "echo {input} > {output}"

# Request all per-macrostate PDBs (see aggregate_input_pdb) and write their
# paths into macrostates_pdb.txt.
# The output path is rooted at `data_root`, like the helper functions above.
rule aggregate_pdb:
  input:
    aggregate_input_pdb
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "macrostates_pdb.txt"),
  shell:
    "echo {input} > {output}"


# Call MPP.run with --get-least-moving-residues and write the result to
# least_moving_residues.ndx. The option's argument is the system's contact
# index file, or the literal "none" when the system config has no such entry
# (see get_contact_index_file).
# Paths are rooted at `data_root`, like the helper functions above.
rule get_least_moving_residues:
  input:
    config=os.path.join(data_root, "{system}", "input", "config.yml"),
    z=os.path.join(data_root, "{system}", "results", "{lumping}", "Z.npy"),
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "least_moving_residues.ndx"),
  params:
    d=get_d,
    g=get_g,
    ndx_file=get_contact_index_file,
  conda:
    "mpp"
  shell:
    "python -m MPP.run {input.config} {params.d} {params.g} -Z {input.z} -o {output} --get-least-moving-residues {params.ndx_file}"

# TODO:
# Mind the xtc stride / --skip not working; rewrite of .ndx required
# Extract the frames listed in random_frames/{macrostate}.ndx from the system's
# trajectory into one PDB file with `gmx trjconv -fr`. The here-document feeds
# "0" to trjconv's interactive group prompt.
# No explicit mkdir is needed: Snakemake creates the directory of the output
# file itself (the former `mkdir -p random_pdbs` only created a stray directory
# in the working directory).
# Paths are rooted at `data_root`, like the helper functions above.
rule get_random_pdb_frames:
  input:
    os.path.join(data_root, "{system}", "results", "{lumping}", "random_frames", "{macrostate}.ndx"),
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "random_pdbs", "{macrostate}.pdb"),
  params:
    top=get_top_file,
    xtc=get_xtc_file,
  shell:
    """gmx trjconv -s {params.top} -f {params.xtc} -o {output} -fr {input} <<eof
0
eof"""

# Extract the frames listed in rmsd_CA_mean_frames.ndx from the system's
# trajectory into mean_frames_CA.pdb with `gmx trjconv -fr`. The here-document
# feeds "0" to trjconv's interactive group prompt.
# The former `mkdir -p random_pdbs` was dropped: it only created an unused
# directory in the working directory, and Snakemake creates the directory of
# the output file itself.
# Paths are rooted at `data_root`, like the helper functions above.
rule get_mean_frames_pdb_CA:
  input:
    os.path.join(data_root, "{system}", "results", "{lumping}", "rmsd_CA_mean_frames.ndx"),
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "mean_frames_CA.pdb"),
  params:
    top=get_top_file,
    xtc=get_xtc_file,
  shell:
    """gmx trjconv -s {params.top} -f {params.xtc} -o {output} -fr {input} <<eof
0
eof"""

# Extract the frames listed in rmsd_feature_mean_frames.ndx from the system's
# trajectory into mean_frames_feature.pdb with `gmx trjconv -fr`. The
# here-document feeds "0" to trjconv's interactive group prompt.
# The former `mkdir -p random_pdbs` was dropped: it only created an unused
# directory in the working directory, and Snakemake creates the directory of
# the output file itself.
# Paths are rooted at `data_root`, like the helper functions above.
rule get_mean_frames_pdb_feature:
  input:
    os.path.join(data_root, "{system}", "results", "{lumping}", "rmsd_feature_mean_frames.ndx"),
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "mean_frames_feature.pdb"),
  params:
    top=get_top_file,
    xtc=get_xtc_file,
  shell:
    """gmx trjconv -s {params.top} -f {params.xtc} -o {output} -fr {input} <<eof
0
eof"""

# Render one macrostate PDB to a PNG by running
# workflow/render_random_structures.py inside PyMOL (`pymol -cq`), passing the
# PDB, the PNG target and least_moving_residues.ndx as script arguments.
# `params.ndx_file` is resolved but not referenced by the shell command.
# Paths are rooted at `data_root`, like the helper functions above.
rule render_pdb_files:
  input:
    pdb=os.path.join(data_root, "{system}", "results", "{lumping}", "random_pdbs", "{macrostate}.pdb"),
    ndx=os.path.join(data_root, "{system}", "results", "{lumping}", "least_moving_residues.ndx"),
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "random_pngs", "{macrostate}.png"),
  conda:
    "pymol"
  params:
    ndx_file=get_contact_index_file,
  shell:
    "pymol -cq " + os.path.join("workflow", "render_random_structures.py") + " -- {input.pdb} {output} {input.ndx}"


# Checkpoint: run workflow/create_cluster_pml.py with the output directory,
# the system config, the mean-frames PDB and the {clusters} wildcard as
# arguments. aggregate_input_draw_cluster globs the output directory for
# `{macrostate}.png` files afterwards.
# Paths are rooted at `data_root`, like the helper functions above.
checkpoint draw_cluster_in_states:
  input:
    config=os.path.join(data_root, "{system}", "input", "config.yml"),
    frames=os.path.join(data_root, "{system}", "results", "{lumping}", "mean_frames_{rmsd_feature}.pdb"),
  output:
    directory(os.path.join(data_root, "{system}", "results", "{lumping}", "mean_frames_{rmsd_feature}_{clusters}")),
  conda:
    "pymol"
  shell:
    os.path.join("workflow", "create_cluster_pml.py") + " {output} {input.config} {input.frames} {wildcards.clusters}"

def aggregate_input_draw_cluster(wildcards):
    """Return the PNG targets for every macrostate found after `draw_cluster_in_states`.

    The checkpoint output directory is requested for the current wildcards and
    globbed for `{macrostate}.png` files; one
    `mean_frames_{rmsd_feature}/{macrostate}.png` path is returned per file.

    `rmsd_feature` is passed to `expand` explicitly: the pattern contains that
    wildcard, and `expand` raises a WildcardError for wildcards without values.

    NOTE(review): the returned paths live in `mean_frames_{rmsd_feature}`,
    while the checkpoint writes to `mean_frames_{rmsd_feature}_{clusters}` --
    confirm which directory is intended.
    """
    cluster_dir = checkpoints.draw_cluster_in_states.get(**wildcards).output[0]
    png_pattern = os.path.join(cluster_dir, "{macrostate}.png")
    return expand(
        os.path.join(data_root, "{system}", "results", "{lumping}", "mean_frames_{rmsd_feature}", "{macrostate}.png"),
        system=wildcards.system,
        lumping=wildcards.lumping,
        rmsd_feature=wildcards.rmsd_feature,
        macrostate=glob_wildcards(png_pattern).macrostate,
    )

# Request all per-macrostate cluster PNGs (see aggregate_input_draw_cluster)
# and write their paths into rmsd_mean_<rmsd_feature>_<clusters>.txt.
# The output path is rooted at `data_root`, like the helper functions above.
rule aggregate_draw_cluster:
  input:
    aggregate_input_draw_cluster
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "rmsd_mean_{rmsd_feature}_{clusters}.txt"),
  shell:
    "echo {input} > {output}"


# Convenience target: request every standard plot as both PDF and PNG (built
# by rule `plot`) and write a timestamp into `plotted` once all of them exist.
# The inputs are generated from one list of plot names so the PDF and PNG sets
# cannot drift apart; the order matches the former explicit list (all PDFs
# first, then all PNGs). Paths are rooted at `data_root`.
rule plot_all:
  input:
    [
      os.path.join(data_root, "{system}", "results", "{lumping}", plot_name + "." + ext)
      for ext in ("pdf", "png")
      for plot_name in (
        "sankey",
        "dendrogram",
        "ck_test",
        "timescales",
        "contacts",
        "macrotraj",
        "rmsd",
        "delta_rmsd",
      )
    ],
  output:
    os.path.join(data_root, "{system}", "results", "{lumping}", "plotted"),
  shell:
    "date > {output}"
