import os
import sys
from functools import partial
from pathlib import Path

import snakebids
from appdirs import AppDirs
from snakebids import bids, generate_inputs, get_wildcard_constraints

configfile: "config/snakebids.yml"


# Build the snakebids input object from the workflow configuration.
# Every value below is read straight from `config`; only the pybids
# configuration list is fixed here.
inputs = generate_inputs(
    # where the dataset lives and how it is indexed
    bids_dir=config["bids_dir"],
    derivatives=config["derivatives"],
    pybids_config=["bids", "derivatives"],
    # which images to select
    pybids_inputs=config["pybids_inputs"],
    # which participants to keep / drop
    participant_label=config["participant_label"],
    exclude_participant_label=config["exclude_participant_label"],
)

# Constrain the wildcards that come from the configured pybids inputs.
wildcard_constraints:
    **get_wildcard_constraints(config["pybids_inputs"]),

# ----
# Additional entities used in this workflow's file names: each one is
# restricted to a non-empty run of ASCII letters and digits.
wildcard_constraints:
    **{
        entity: "[a-zA-Z0-9]+"
        for entity in (
            "desc",
            "space",
            "hemi",
            "surfname",
            "modality",
            "density",
            "atlas",
            "autotop",
            "template",
        )
    },

def get_download_dir():
    """Return the directory path to use as the download cache.

    The ``AUTOAFIDS_CACHE_DIR`` environment variable wins whenever it is set
    (its value is returned unchanged, even if empty).  Otherwise the per-user
    cache directory that ``appdirs`` reports for the "autoafids" application
    (author "jclauneurolab") is returned.

    The path is only looked up here; this function does not create it.
    """
    try:
        return os.environ["AUTOAFIDS_CACHE_DIR"]
    except KeyError:
        return AppDirs("autoafids", "jclauneurolab").user_cache_dir

download_dir = get_download_dir()

# ---- begin preproc profile pathways --------------------------------------------

# Each profile name maps to the list of step names it enables.  The rules
# further down test membership in the selected list to decide which rules
# to define and which file feeds the next step.
processing_steps = {
    "slow": [
        "freesurfer_preproc",
        "skull_strip",
        "norm_im",
    ],
    "medium": [
        "n4_bias_correction",
        "skull_strip",
        "norm_im",
        "resample_im",
    ],
    "fast": [
        "n4_bias_correction",
        "norm_im",
        "resample_im",
    ],
    "superfast": [
        "norm_im",
        "resample_im",
    ],
}

# Steps enabled for the profile named in the config (KeyError if unknown).
profile_settings = processing_steps[config["profile"]]

# Normalization scheme: the integer given as config["norm"] indexes this table
# (0 -> "zscore", 1 -> "minmax", ...).
norm_methods = dict(
    enumerate(
        [
            "zscore",  # best for multi-contrast data
            "minmax",  # best for data from the same center
            "percentile",  # best for CT imaging [notavail]
            # triggers FS WM tissue normalization then z-scores output
            # (only for T1w) [notavail]
            "wm",
            "other",  # user must provide a python function (custom) [notavail]
        ]
    )
)

# Chosen norm method
chosen_norm_method = norm_methods[int(config["norm"])]

# Resampling resolution requested in the config, as a float.
resolution = float(config["res"])

# Template file per contrast; both config entries are read up front, then
# the one named by config["template"] is selected.
templates = {"T1w": config["templatet1w"], "T2w": config["templatet2w"]}
chosen_template = templates[config["template"]]

# Announce the selected profile and the steps it enables.
# Written to stderr, not stdout: Snakemake emits machine-readable output on
# stdout for options such as --dag / --rulegraph, and stray print output from
# the Snakefile would corrupt it (e.g. `snakemake --dag | dot`).
print(
    f'..... {config["profile"]} preprocessing configuration .....',
    file=sys.stderr,
)
print(
    f" the following sequence will be triggered {profile_settings}",
    file=sys.stderr,
)


def _t1w_anat_bids(subdir, **entities):
    """Return a ``bids`` partial for anat outputs under ``<output_dir>/<subdir>``.

    The partial pre-fills ``root``, ``datatype="anat"``, the given extra
    entities and the wildcards of the "t1w" input, so callers only add what
    differs per file (typically ``suffix``).
    """
    return partial(
        bids,
        root=str(Path(config["output_dir"]) / subdir),
        datatype="anat",
        **entities,
        **inputs["t1w"].wildcards,
    )


# Path builders for each output folder of the preprocessing stage.
registration_bids = _t1w_anat_bids("registration", space="native")

synthstrip_bids = _t1w_anat_bids("synthstrip", desc="synthstrip")

# Normalized and resampled images are both tagged with the chosen
# normalization method; resampled ones additionally carry the resolution.
im_normed = _t1w_anat_bids("normalize", desc=chosen_norm_method)

im_resamp = _t1w_anat_bids(
    "resample",
    desc=chosen_norm_method,
    res=config["res"],
)

# Only defined for profiles that list the "n4_bias_correction" step.
if "n4_bias_correction" in profile_settings:

    # Runs scripts/n4_bias_corr.py on the T1w image located in the input
    # BIDS directory and writes the result under <output_dir>/n4biascorr
    # with desc-n4corrected.
    rule n4_bias_correction:
        input:
            im=bids(
                root=str(Path(config["bids_dir"])),
                datatype="anat",
                suffix="T1w.nii.gz",
                **inputs["t1w"].wildcards,
            ),
        output:
            corrected_im=bids(
                root=str(Path(config["output_dir"]) / "n4biascorr"),
                datatype="anat",
                desc="n4corrected",
                suffix="T1w.nii.gz",
                **inputs["t1w"].wildcards,
            ),
        script:
            "scripts/n4_bias_corr.py"


# Normalizes the T1w image with scripts/norm_schemes.py using the method
# selected in the config (passed as params.norm_method).
rule norm_im:
    input:
        # The source image depends only on the selected profile, which is
        # known when the workflow is parsed, so it is chosen with a plain
        # conditional expression.  It must NOT be wrapped in an input
        # function: Snakemake uses the return value of an input function as
        # a concrete path and does not fill in wildcards, so the
        # "{subject}"-style placeholders produced by bids() would be left
        # unresolved.
        #
        # Priority: skull-stripped image, else N4-corrected image, else the
        # raw image from the input BIDS directory.
        im_raw=(
            synthstrip_bids(suffix="T1w.nii.gz")
            if "skull_strip" in profile_settings
            else bids(
                root=str(Path(config["output_dir"]) / "n4biascorr"),
                datatype="anat",
                desc="n4corrected",
                suffix="T1w.nii.gz",
                **inputs["t1w"].wildcards,
            )
            if "n4_bias_correction" in profile_settings
            else bids(
                root=str(Path(config["bids_dir"])),
                datatype="anat",
                suffix="T1w.nii.gz",
                **inputs["t1w"].wildcards,
            )
        ),
    output:
        im_norm=im_normed(suffix="T1w.nii.gz"),
    params:
        norm_method=chosen_norm_method,
    script:
        "scripts/norm_schemes.py"

# Only defined for profiles that list the "resample_im" step.
if "resample_im" in profile_settings:

    # Feeds the normalized image to scripts/resample_img.py together with
    # the requested resolution (params.res).
    rule resample_im:
        input:
            im_normed=rules.norm_im.output.im_norm,
        output:
            resam_im=im_resamp(suffix="T1w.nii.gz"),
        params:
            res=resolution,
        script:
            "scripts/resample_img.py"


# Runs scripts/regis_script.py with the chosen template as params.moving and
# writes two transform files plus an image into <output_dir>/registration.
# NOTE(review): presumably registers the template to the subject image;
# confirm against the script.
rule mni2sub:
    input:
        # Raw T1w from the input BIDS directory unless the profile enables
        # N4 correction, in which case the corrected image is used.
        im_resamp=(
            bids(
                root=str(Path(config["bids_dir"])),
                datatype="anat",
                suffix="T1w.nii.gz",
                **inputs["t1w"].wildcards,
            )
            if "n4_bias_correction" not in profile_settings
            else rules.n4_bias_correction.output.corrected_im
        ),
    output:
        xfm_slicer=registration_bids(suffix="slicer.mat"),
        xfm_ras=registration_bids(suffix="xfm.txt"),
        out_im=registration_bids(suffix="MNI.nii.gz"),
    params:
        # Template path is resolved relative to the parent of the workflow
        # base directory.
        moving=str(Path(workflow.basedir).parent / chosen_template),
    script:
        "scripts/regis_script.py"

# Runs scripts/tform_script.py with the xfm.txt transform from mni2sub and
# the two fiducial files named in the config; writes desc-MNI afids.fcsv.
rule mni2subfids:
    input:
        xfm_new=rules.mni2sub.output.xfm_ras,
    output:
        fcsv_new=registration_bids(desc="MNI", suffix="afids.fcsv"),
    params:
        # Both paths are resolved relative to the parent of the workflow
        # base directory.
        fcsv=str(Path(workflow.basedir).parent / config["fcsv"]),
        fcsv_mni=str(Path(workflow.basedir).parent / config["fcsv_mni"]),
    script:
        "scripts/tform_script.py"

# Aggregation target: the final preprocessed image and the transformed
# fiducials for every T1w input.
rule preprocessing_result:
    input:
        # `resample_im` is only defined for profiles that include that
        # step; referencing rules.resample_im unconditionally makes the
        # workflow fail to parse for the others (e.g. "slow").  Fall back to
        # the normalized image, the output of the always-defined norm_im
        # rule.
        # NOTE(review): confirm the normalized image is the intended final
        # output for profiles without resampling.
        samp=inputs["t1w"].expand(
            rules.resample_im.output.resam_im
            if "resample_im" in profile_settings
            else rules.norm_im.output.im_norm
        ),
        mni=inputs["t1w"].expand(rules.mni2subfids.output.fcsv_new),

include: "rules/cnn.smk"


# Default target: the gen_fcsv output for every T1w input.
# NOTE(review): gen_fcsv is not defined in this file; presumably it comes
# from the included rules/cnn.smk.
rule all:
    input:
        models=inputs["t1w"].expand(rules.gen_fcsv.output.fcsv),
    default_target: True