import os
import datetime  # Add this import
from pathlib import Path
import shutil
from micaflow.scripts.util_bids_pathing import check_paths
import json

# Default parameters, taken from the Snakemake ``config`` dict.
SUBJECT = config.get("subject", "")
SESSION = config.get("session", "")
OUT_DIR = config.get("output", "")
# Core count this Snakemake invocation was started with.
THREADS = workflow.cores

# Per-job thread budgets:
# - light jobs always run single-threaded;
# - with more than 2 cores, heavy jobs leave one core free so a light job can
#   run alongside them; with 2 cores or fewer they take every core.
LIGHT_THREADS = 1
HEAVY_THREADS = THREADS - 1 if THREADS > 2 else THREADS

# Input locations (each defaults to "" when absent from the config). These raw
# values are re-validated, and possibly replaced, by _get_validated_params()
# further down.
DATA_DIRECTORY = config.get("data_directory", "")
T1W_FILE = config.get("t1w_file", "")
FLAIR_FILE = config.get("flair_file", "")
DWI_FILE = config.get("dwi_file", "")
BVAL_FILE = config.get("bval_file", "")
BVEC_FILE = config.get("bvec_file", "")
# Optional "inverse" DWI acquisition (presumably reverse phase-encoded).
INVERSE_DWI_FILE = config.get("inverse_dwi_file", "")
INVERSE_BVAL_FILE = config.get("inverse_bval_file", "")
INVERSE_BVEC_FILE = config.get("inverse_bvec_file", "")

# Processing options.
RM_CEREBELLUM = config.get("rm_cerebellum", False)
KEEP_TEMP = config.get("keep_temp", False)  # keep TEMP_DIR instead of deleting it in `rule all`
PED = config.get("PED", "")  # forwarded to `micaflow synth_b0 --phase-encoding`
SHELL_DIMENSION = config.get("shell_dimension", 3)
EXTRACT_BRAIN = config.get("extract_brain", False)
GPU = config.get("gpu", "")

# Global CPU switch for SynthSeg and other rules: "--cpu" unless a GPU was
# requested, in which case no flag is passed.
if GPU:
    CPU_FLAG = ""
else:
    CPU_FLAG = "--cpu"

# Registration flags. Either or both may be requested; when neither is set the
# pipeline defaults to nonlinear registration.
LINEAR = config.get("linear", False)
NONLINEAR = config.get("nonlinear", False)
if not LINEAR and not NONLINEAR:
    NONLINEAR = True

# Single MNI space name derived from the raw flag (linear wins when both are
# set). The list of spaces actually produced, MNI_SPACES, is built after
# parameter validation further down, so it is not computed here.
MNI_SPACE = "MNI152linear" if LINEAR else "MNI152"

# Without an inverse DWI acquisition, distortion correction falls back to the
# synthetic-b0 rules (guarded by USE_SYNTH_B0).
USE_SYNTH_B0 = INVERSE_DWI_FILE == "" or INVERSE_DWI_FILE is None

# Atlas paths: the mni_icbm152_t1_tal_nlin_sym_09a template image, its brain
# mask and its segmentation, expected in an "atlas" folder next to this
# Snakefile.
ATLAS_DIR = os.path.join(workflow.basedir, "atlas")
_ATLAS_STEM = os.path.join(ATLAS_DIR, "mni_icbm152_t1_tal_nlin_sym_09a")
ATLAS = f"{_ATLAS_STEM}.nii"
ATLAS_MASK = f"{_ATLAS_STEM}_mask.nii"
ATLAS_SEG = f"{_ATLAS_STEM}_seg.nii"

# "scripts" directory that sits beside the workflow directory.
SCRIPT_DIR = os.path.join(os.path.dirname(workflow.basedir), "scripts")

# Per-subject/session scratch space, created eagerly at parse time. `rule all`
# deletes it at the end of the run unless keep_temp is set.
TEMP_DIR = f"{OUT_DIR}/{SUBJECT}/{SESSION}/temp"
os.makedirs(TEMP_DIR, exist_ok=True)

# Location of the validated-parameter cache used by _get_validated_params().
_PARAMS_CACHE = os.path.join(OUT_DIR, SUBJECT, SESSION, "micaflow_parameters.json")

# Keys of the tuple returned by check_paths(), in order.
_CHECK_PATHS_KEYS = (
    "DATA_DIRECTORY", "OUT_DIR", "SUBJECT", "SESSION", "RUN_DWI", "GPU",
    "FLAIR_FILE", "T1W_FILE", "DWI_FILE", "BVAL_FILE", "BVEC_FILE",
    "INVERSE_DWI_FILE", "INVERSE_BVAL_FILE", "INVERSE_BVEC_FILE",
    "THREADS", "RUN_FLAIR",
)

# Full parameter order: key order of micaflow_parameters.json and element
# order of the tuple returned by _get_validated_params().
_PARAM_ORDER = _CHECK_PATHS_KEYS + (
    "RM_CEREBELLUM", "KEEP_TEMP", "PED", "SHELL_DIMENSION", "EXTRACT_BRAIN",
    "USE_SYNTH_B0", "LINEAR", "NONLINEAR",
)


def _current_option_flags():
    """Return the option flags that check_paths() never sees, from the current config globals."""
    return {
        "RM_CEREBELLUM": RM_CEREBELLUM,
        "KEEP_TEMP": KEEP_TEMP,
        "PED": PED,
        "SHELL_DIMENSION": SHELL_DIMENSION,
        "EXTRACT_BRAIN": EXTRACT_BRAIN,
        "LINEAR": LINEAR,
        "NONLINEAR": NONLINEAR,
    }


def _load_params_cache():
    """Return the cached parameter dict, or None if the cache is missing or unreadable."""
    try:
        with open(_PARAMS_CACHE, "r") as f:
            cached = json.load(f)
    except (OSError, json.JSONDecodeError):
        return None
    return cached if isinstance(cached, dict) else None


def _save_params_cache(params):
    """Write ``params`` to the cache file via a temp file + rename.

    The rename (os.replace) means another Snakefile parse never observes a
    half-written JSON file.
    """
    cache_dir = os.path.dirname(_PARAMS_CACHE)
    if cache_dir:  # os.makedirs("") would raise
        os.makedirs(cache_dir, exist_ok=True)
    tmp_path = f"{_PARAMS_CACHE}.{os.getpid()}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(params, f, indent=4)
    os.replace(tmp_path, _PARAMS_CACHE)


def _cached_inputs_match(cached):
    """Return True unless an input path given now differs from the cached one.

    Inputs that are empty/None in the current config cannot contradict the
    cache. Every mismatch is reported on stdout.
    """
    current_inputs = (
        ("T1w", T1W_FILE, "T1W_FILE"),
        ("DWI", DWI_FILE, "DWI_FILE"),
        ("FLAIR", FLAIR_FILE, "FLAIR_FILE"),
        ("bval", BVAL_FILE, "BVAL_FILE"),
        ("bvec", BVEC_FILE, "BVEC_FILE"),
        ("Inverse DWI", INVERSE_DWI_FILE, "INVERSE_DWI_FILE"),
        ("Inverse bval", INVERSE_BVAL_FILE, "INVERSE_BVAL_FILE"),
        ("Inverse bvec", INVERSE_BVEC_FILE, "INVERSE_BVEC_FILE"),
    )
    matches = True
    for label, value, key in current_inputs:
        if value and value != cached.get(key):
            print(f"[INFO] {label} input changed. Invalidating cache.")
            matches = False
    return matches


def _get_validated_params():
    """
    Validate the input paths with check_paths() and cache the result.

    Snakemake may parse this Snakefile several times for one run (e.g. again
    when executing ``run:`` blocks). To avoid repeating the path validation,
    the validated values are stored in micaflow_parameters.json
    (``_PARAMS_CACHE``). They are reused as long as no input path given in
    the current config differs from the cached one.

    Only the values produced by check_paths() and the derived USE_SYNTH_B0
    come from the cache. Plain option flags (RM_CEREBELLUM, KEEP_TEMP, PED,
    SHELL_DIMENSION, EXTRACT_BRAIN, LINEAR, NONLINEAR) always come from the
    current config, so changing them between runs takes effect. The cache
    file is updated when they change.

    Returns:
        tuple: 24 values in the order of ``_PARAM_ORDER``, matching the
        module-level unpacking.

    Raises:
        ValueError: if check_paths() returns an unexpected number of values.
    """
    cached = _load_params_cache()
    if cached is not None and _cached_inputs_match(cached):
        try:
            params = {key: cached[key] for key in _CHECK_PATHS_KEYS + ("USE_SYNTH_B0",)}
        except KeyError:
            params = None  # incomplete/older cache: regenerate it below
        if params is not None:
            flags = _current_option_flags()
            params.update(flags)
            # Keep the on-disk record in sync when only option flags changed.
            if any(cached.get(key) != value for key, value in flags.items()):
                _save_params_cache({key: params[key] for key in _PARAM_ORDER})
            return tuple(params[key] for key in _PARAM_ORDER)

    # Cache missing, unreadable or stale: run check_paths.
    print(f"[INFO] Total cores: {THREADS}")
    print(f"[INFO] Heavy job threads: {HEAVY_THREADS}")
    print(f"[INFO] Light job threads: {LIGHT_THREADS}")
    print("GPU: ", GPU)
    print("ATLAS_DIR: ", ATLAS_DIR)
    print(f"Temporary directory: {TEMP_DIR}")

    # check_paths expects strings; map None -> "".
    def safe_str(val):
        return "" if val is None else str(val)

    result = tuple(check_paths(
        safe_str(DATA_DIRECTORY), safe_str(OUT_DIR), safe_str(SUBJECT), safe_str(SESSION), safe_str(GPU),
        safe_str(FLAIR_FILE), safe_str(T1W_FILE), safe_str(DWI_FILE), safe_str(BVAL_FILE), safe_str(BVEC_FILE),
        safe_str(INVERSE_DWI_FILE), safe_str(INVERSE_BVAL_FILE), safe_str(INVERSE_BVEC_FILE), THREADS
    ))
    if len(result) != len(_CHECK_PATHS_KEYS):
        raise ValueError(
            f"check_paths returned {len(result)} values, expected {len(_CHECK_PATHS_KEYS)}"
        )

    params = dict(zip(_CHECK_PATHS_KEYS, result))
    # Recalculate USE_SYNTH_B0 from the validated inverse DWI path.
    inv_dwi = params["INVERSE_DWI_FILE"]
    params["USE_SYNTH_B0"] = inv_dwi == "" or inv_dwi is None
    params.update(_current_option_flags())

    _save_params_cache({key: params[key] for key in _PARAM_ORDER})

    print("flair_file: ", params["FLAIR_FILE"])
    print("run_flair: ", params["RUN_FLAIR"])
    print("keep_temp: ", KEEP_TEMP)

    return tuple(params[key] for key in _PARAM_ORDER)

# Replace the raw config values with the validated ones. They are either read
# from the micaflow_parameters.json cache or produced by check_paths.
(DATA_DIRECTORY, OUT_DIR, SUBJECT, SESSION, RUN_DWI, GPU,
 FLAIR_FILE, T1W_FILE, DWI_FILE, BVAL_FILE, BVEC_FILE,
 INVERSE_DWI_FILE, INVERSE_BVAL_FILE, INVERSE_BVEC_FILE,
 THREADS, RUN_FLAIR, RM_CEREBELLUM, KEEP_TEMP, PED,
 SHELL_DIMENSION, EXTRACT_BRAIN, USE_SYNTH_B0, LINEAR, NONLINEAR) = _get_validated_params()

# MNI target spaces to produce, derived from the returned registration flags:
# "MNI152" for nonlinear and "MNI152linear" for linear registration.
MNI_SPACES = [
    space_name
    for space_name, wanted in (("MNI152", NONLINEAR), ("MNI152linear", LINEAR))
    if wanted
]

print("Target MNI Spaces:", MNI_SPACES)

def get_final_output():
    """
    Build the list of final target files requested by ``rule all``.

    Always requested:
      * native (T1w) space T1w image and its normalized version;
      * for every space in MNI_SPACES: T1w textures, the T1w-to-MNI Dice
        metrics and the MNI-space T1w images.

    Requested depending on the configuration:
      * RUN_FLAIR adds the FLAIR counterparts;
      * RUN_DWI adds the FA/MD maps;
      * EXTRACT_BRAIN adds the brain-extracted images.

    All paths use the ``f"{OUT_DIR}/{SUBJECT}/{SESSION}/<subdir>/..."``
    layout that the rules use for their outputs, so targets match rule
    outputs character for character.

    Returns:
        list[str]: target paths in a deterministic order.
    """
    base = f"{OUT_DIR}/{SUBJECT}/{SESSION}"
    prefix = f"{SUBJECT}_{SESSION}"

    def _path(subdir, suffix):
        # <base>/<subdir>/<SUBJECT>_<SESSION>_<suffix>
        return f"{base}/{subdir}/{prefix}_{suffix}"

    # Native space outputs (always requested).
    outputs = [
        _path("anat", "T1w-space_T1w.nii.gz"),
        _path("anat", "T1w-space_T1w_normalized.nii.gz"),
    ]

    for space in MNI_SPACES:
        outputs += [
            _path("textures", f"{space}-space_textures-T1w_gradient-magnitude.nii.gz"),
            _path("textures", f"{space}-space_textures-T1w_relative-intensity.nii.gz"),
            _path("metrics", f"T1w-to-{space}_DICE.csv"),
            _path("anat", f"{space}-space_T1w_normalized.nii.gz"),
            _path("anat", f"{space}-space_T1w.nii.gz"),
        ]
        if RUN_FLAIR:
            outputs += [
                _path("textures", f"{space}-space_textures-FLAIR_gradient-magnitude.nii.gz"),
                _path("textures", f"{space}-space_textures-FLAIR_relative-intensity.nii.gz"),
                _path("anat", f"{space}-space_FLAIR_normalized.nii.gz"),
                _path("anat", f"{space}-space_FLAIR.nii.gz"),
            ]
        if RUN_DWI:
            outputs += [
                _path("dwi", f"{space}-space_FA.nii.gz"),
                _path("dwi", f"{space}-space_MD.nii.gz"),
            ]
        if EXTRACT_BRAIN:
            outputs += [
                _path("brain-extracted", f"{space}-space_T1w.nii.gz"),
                _path("brain-extracted", f"{space}-space_T1w_normalized.nii.gz"),
            ]
            if RUN_FLAIR:
                outputs += [
                    _path("brain-extracted", f"{space}-space_FLAIR.nii.gz"),
                    _path("brain-extracted", f"{space}-space_FLAIR_normalized.nii.gz"),
                ]

    # Native space brain-extracted images.
    if EXTRACT_BRAIN:
        outputs.append(_path("brain-extracted", "T1w-space_T1w_normalized.nii.gz"))
        if RUN_FLAIR:
            outputs.append(_path("brain-extracted", "T1w-space_FLAIR_normalized.nii.gz"))

    # Native space DWI maps and DWI-to-T1w registration quality.
    if RUN_DWI:
        outputs += [
            _path("dwi", "T1w-space_FA.nii.gz"),
            _path("dwi", "T1w-space_MD.nii.gz"),
            _path("metrics", "DWI-to-T1w-space_DICE.csv"),
            _path("dwi", "T1w-space_FA_normalized.nii.gz"),
            _path("dwi", "T1w-space_MD_normalized.nii.gz"),
        ]

    # Native space FLAIR and FLAIR-to-T1w registration quality.
    if RUN_FLAIR:
        outputs += [
            _path("metrics", "FLAIR-to-T1w-space_DICE.csv"),
            _path("anat", "T1w-space_FLAIR.nii.gz"),
            _path("anat", "T1w-space_FLAIR_normalized.nii.gz"),
        ]

    return outputs


# Final target rule: requests every path from get_final_output(). Its shell
# step then does end-of-run housekeeping:
#   * deletes TEMP_DIR unless keep_temp is set (KEEP_TEMP is rendered into the
#     script as its Python string form, hence the "False"/"false" checks);
#   * when LINEAR is set, deletes the empty placeholder T1w->MNI152linear warp
#     fields that registration_mni152 touches in linear mode.
# NOTE(review): TEMP_DIR contents and the deleted placeholders are declared
# outputs of other rules, so a later invocation may see them as missing and
# re-run those jobs and their dependents — confirm this is intended.
# NOTE(review): when linear AND nonlinear are both requested, the messages
# below only report linear mode.
rule all:
    input: get_final_output()
    shell:
        """
        if [ "{KEEP_TEMP}" = "False" ] || [ "{KEEP_TEMP}" = "false" ]; then
            echo "[INFO] Cleaning up temporary directory: {TEMP_DIR}"
            rm -rf {TEMP_DIR}
        fi
        if [ "{LINEAR}" = "True" ] || [ "{LINEAR}" = "true" ]; then
            echo "[INFO] Linear registration was used throughout the pipeline."
            echo "[INFO] Removing temporary warp fields as they are not generated in linear mode."
            find {OUT_DIR}/{SUBJECT}/{SESSION}/xfm/ -name "*T1w_to-MNI152linear_fwdfield.nii.gz" -type f -delete
            find {OUT_DIR}/{SUBJECT}/{SESSION}/xfm/ -name "*T1w_to-MNI152linear_bakfield.nii.gz" -type f -delete
        else
            echo "[INFO] Non-linear registration was used throughout the pipeline."
        fi
        """

# Defined before the rules that reference `rules.synthseg_t1w`, since a rule
# must already exist to be referenced that way.
# SynthSeg parcellation (--parc --robust) of the raw T1w input. It is used by
# T1w skull stripping, the FLAIR/MNI/b0 registrations and the FLAIR Dice
# metric. CPU_FLAG adds "--cpu" when no GPU was requested.
rule synthseg_t1w:
    input:
        image = T1W_FILE
    output:
        seg = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_synthseg_T1w.nii.gz"
    threads: HEAVY_THREADS
    shell:
        "micaflow synthseg --i {input.image} --o {output.seg} --parc --robust --threads {threads} {CPU_FLAG}"

# FLAIR counterpart of synthseg_t1w; only defined when a FLAIR image is
# available (RUN_FLAIR). Its segmentation drives FLAIR skull stripping and
# FLAIR-to-T1w registration.
if RUN_FLAIR:
    rule synthseg_flair:
        input:
            image = FLAIR_FILE,
        output:
            seg = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_synthseg_FLAIR.nii.gz"
        threads: HEAVY_THREADS
        shell:
            "micaflow synthseg --i {input.image} --o {output.seg} --parc --robust --threads {threads} {CPU_FLAG}"

# Brain extraction of the raw T1w, guided by its SynthSeg parcellation.
# The skull-stripped image goes to TEMP_DIR. The brain mask is kept under
# anat/ and is reused by bias_field_correction and run_texture_native.
# "--remove-cerebellum" is added when rm_cerebellum is set.
rule skull_strip_t1w:
    input:
        image = T1W_FILE,
        seg = rules.synthseg_t1w.output.seg
    output:
        brain = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_brain-extracted_T1w.nii.gz",
        mask = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_brain-extracted_T1w-space_mask.nii.gz"
    threads: LIGHT_THREADS
    params:
        parcellation = lambda wildcards, input: f"--parcellation {input.seg}",
        rm_cerebellum = "--remove-cerebellum" if RM_CEREBELLUM else ""
    shell:
        """
        micaflow bet \
            --input {input.image} \
            --output {output.brain} \
            --output-mask {output.mask} \
            {params.parcellation} \
            {params.rm_cerebellum} 
        """

# Bias-field correction of the raw T1w within the T1w brain mask. The result,
# anat/<SUBJECT>_<SESSION>_T1w-space_T1w.nii.gz, is the T1w reference image
# used by the FLAIR, MNI and b0 registrations.
rule bias_field_correction:
    input:
        image = T1W_FILE,
        mask = rules.skull_strip_t1w.output.mask
    output:
        corrected = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_T1w-space_T1w.nii.gz"
    threads: LIGHT_THREADS
    shell:
        "micaflow bias_correction -i {input.image} -o {output.corrected} -m {input.mask}"

# FLAIR-only rules: defined only when a FLAIR image is available (RUN_FLAIR).
if RUN_FLAIR:
    # Brain extraction of the raw FLAIR guided by its SynthSeg parcellation
    # (mirrors skull_strip_t1w). The skull-stripped image goes to TEMP_DIR;
    # the mask is kept under anat/.
    rule skull_strip_flair:
        input:
            image = FLAIR_FILE,
            seg = rules.synthseg_flair.output.seg
        output:
            brain = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_brain-extracted_FLAIR.nii.gz",
            mask = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_brain-extracted_FLAIR_mask.nii.gz"
        threads: LIGHT_THREADS
        params:
            parcellation = lambda wildcards, input: f"--parcellation {input.seg}",
            rm_cerebellum = "--remove-cerebellum" if RM_CEREBELLUM else ""
        shell:
            """
            micaflow bet \
                --input {input.image} \
                --output {output.brain} \
                --output-mask {output.mask} \
                {params.parcellation} \
                {params.rm_cerebellum} 
            """

    # Bias-field correction of the raw FLAIR inside its brain mask. The result
    # stays in FLAIR space and is the moving image of registration_t1w.
    rule bias_field_correction_flair:
        input:
            image = FLAIR_FILE,
            mask = rules.skull_strip_flair.output.mask
        output:
            corrected = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_FLAIR-space_FLAIR.nii.gz"
        threads: LIGHT_THREADS
        shell:
            "micaflow bias_correction -i {input.image} -o {output.corrected} -m {input.mask}"

    # Despite its name, this registers FLAIR (moving) to T1w (fixed), using
    # both bias-corrected images and both SynthSeg segmentations. It writes
    # under xfm/:
    #   * the warped FLAIR;
    #   * primary and secondary forward/backward warp fields;
    #   * the forward affine;
    #   * the FLAIR segmentation resampled into T1w space (used by
    #     calculate_metrics_FLAIR).
    rule registration_t1w:
        input:
            fixed_seg = rules.synthseg_t1w.output.seg,
            moving_seg = rules.synthseg_flair.output.seg,
            anatomical_fixed = rules.bias_field_correction.output.corrected,
            anatomical_moving = rules.bias_field_correction_flair.output.corrected
        output:
            warped = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_T1w-space_FLAIR.nii.gz",
            fwd_field = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-FLAIR_to-T1w_fwdfield.nii.gz",
            bak_field = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-FLAIR_to-T1w_bakfield.nii.gz",
            fwd_affine = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-FLAIR_to-T1w_fwdaffine.mat",
            fwd_field_secondary = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-FLAIR_to-T1w_fwdfield-secondary.nii.gz",
            bak_field_secondary = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-FLAIR_to-T1w_bakfield-secondary.nii.gz",
            output_segmentation = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-FLAIR-to-T1w_seg.nii.gz"
        threads: HEAVY_THREADS
        shell:
            """
            micaflow coregister \
                --fixed-file {input.anatomical_fixed} \
                --moving-file {input.anatomical_moving} \
                --fixed-segmentation {input.fixed_seg} \
                --moving-segmentation {input.moving_seg} \
                --output {output.warped} \
                --warp-file {output.fwd_field} \
                --affine-file {output.fwd_affine} \
                --rev-warp-file {output.bak_field} \
                --threads {threads} \
                --output-segmentation {output.output_segmentation} \
                --secondary-warp-file {output.fwd_field_secondary} \
                --secondary-rev-warp-file {output.bak_field_secondary}
            """

    # Resample the bias-corrected FLAIR into T1w space (anat/) with the affine
    # and primary forward warp from registration_t1w.
    # NOTE(review): the secondary warp (fwd_field_secondary) is not applied
    # here, while apply_warp_flair_to_mni does include it in its chain —
    # confirm this is intended.
    rule apply_warp_flair_to_t1w:
        input:
            moving = rules.bias_field_correction_flair.output.corrected,
            warp = rules.registration_t1w.output.fwd_field,
            affine = rules.registration_t1w.output.fwd_affine,
            reference = rules.bias_field_correction.output.corrected
        output:
            warped = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_T1w-space_FLAIR.nii.gz"
        threads: LIGHT_THREADS
        shell:
            """
            micaflow apply_warp \
                --moving {input.moving} \
                --reference {input.reference} \
                --affine {input.affine} \
                --warp {input.warp} \
                --output {output.warped}
            """

    # Dice overlap between the FLAIR segmentation resampled into T1w space and
    # the T1w SynthSeg segmentation, written as a CSV under metrics/.
    rule calculate_metrics_FLAIR:
        input:
            image = rules.registration_t1w.output.output_segmentation,
            atlas = rules.synthseg_t1w.output.seg
        output:
            metrics = f"{OUT_DIR}/{SUBJECT}/{SESSION}/metrics/{SUBJECT}_{SESSION}_FLAIR-to-T1w-space_DICE.csv"
        threads: LIGHT_THREADS
        shell:
            """
            micaflow calculate_dice \
                --input {input.image} \
                --reference {input.atlas} \
                --output {output.metrics}
            """

# Generalized T1w -> MNI registration for both linear and nonlinear targets,
# selected by the {mni_space} wildcard:
#   * "MNI152" runs the full registration;
#   * "MNI152linear" adds --linear-only.
# --disable-robust is always passed. The fixed side is the bundled atlas image
# and its segmentation; the moving side is the bias-corrected T1w and its
# SynthSeg segmentation.
# In linear mode the shell step touches the fwd/bak field outputs so every
# declared output exists; downstream rules then apply only the affine for
# linear spaces.
rule registration_mni152:
    input:
        image = rules.bias_field_correction.output.corrected,
        fixed = ATLAS,
        image_segmentation = rules.synthseg_t1w.output.seg,
        fixed_segmentation = ATLAS_SEG
    output:
        warped = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_{{mni_space}}-space_T1w.nii.gz",
        fwd_field = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-T1w_to-{{mni_space}}_fwdfield.nii.gz",
        bak_field = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-T1w_to-{{mni_space}}_bakfield.nii.gz",
        fwd_affine = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-T1w_to-{{mni_space}}_fwdaffine.mat",
        output_segmentation = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-T1w_to-{{mni_space}}_seg.nii.gz"
    threads: HEAVY_THREADS
    wildcard_constraints:
        mni_space = "MNI152|MNI152linear"
    params:
        linear_flag = lambda wildcards: "--linear-only" if "linear" in wildcards.mni_space else "",
        disable_robust = "--disable-robust"
    shell:
        """
        micaflow coregister \
            --fixed-file {input.fixed} \
            --moving-file {input.image} \
            --fixed-segmentation {input.fixed_segmentation} \
            --moving-segmentation {input.image_segmentation} \
            --output {output.warped} \
            --affine-file {output.fwd_affine} \
            --warp-file {output.fwd_field} \
            --rev-warp-file {output.bak_field} \
            --output-segmentation {output.output_segmentation} \
            --threads {threads} \
            {params.disable_robust} \
            {params.linear_flag}
        
        # If linear mode, create placeholder files to satisfy outputs
        if [[ "{wildcards.mni_space}" == *"linear"* ]]; then
            touch "{output.fwd_field}"
            touch "{output.bak_field}"
        fi
        """

# Resample the bias-corrected T1w into the requested MNI space (anat/).
rule apply_warp_t1w_to_mni:
    input:
        moving = rules.bias_field_correction.output.corrected,
        affine = rules.registration_mni152.output.fwd_affine,
        warp = rules.registration_mni152.output.fwd_field,
        reference = ATLAS
    output:
        warped = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_{{mni_space}}-space_T1w.nii.gz"
    threads: LIGHT_THREADS
    wildcard_constraints:
        mni_space = "MNI152|MNI152linear"
    run:
        # Linear spaces use the affine alone; their warp file is only the
        # placeholder touched by registration_mni152.
        warp_arg = "" if "linear" in wildcards.mni_space else f"--warp {input.warp} "
        shell(
            f"micaflow apply_warp --moving {input.moving} --reference {input.reference} "
            f"--affine {input.affine} {warp_arg}--output {output.warped}"
        )

if RUN_FLAIR:
    # Resample the bias-corrected FLAIR directly into an MNI space in a single
    # apply_warp call. The FLAIR->T1w transforms from registration_t1w are
    # chained with the T1w->MNI transforms from registration_mni152.
    rule apply_warp_flair_to_mni:
        input:
            moving = rules.bias_field_correction_flair.output.corrected,
            affine_mni = rules.registration_mni152.output.fwd_affine,
            warp_mni = rules.registration_mni152.output.fwd_field,
            affine_flair = rules.registration_t1w.output.fwd_affine,
            warp_flair = rules.registration_t1w.output.fwd_field,
            secondary_warp_flair = rules.registration_t1w.output.fwd_field_secondary,
            reference = ATLAS
        output:
            warped = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_{{mni_space}}-space_FLAIR.nii.gz"
        threads: LIGHT_THREADS
        wildcard_constraints:
            mni_space = "MNI152|MNI152linear"
        run:
            # Transform list in the order passed to --transforms. The T1w->MNI
            # warp field leads the list only for the nonlinear space; linear
            # spaces rely on the MNI affine alone.
            chain = [
                input.affine_mni,
                input.secondary_warp_flair,
                input.warp_flair,
                input.affine_flair,
            ]
            if "linear" not in wildcards.mni_space:
                chain = [input.warp_mni] + chain
            shell(
                f"micaflow apply_warp --moving {input.moving} --reference {input.reference} "
                f"--transforms {' '.join(chain)} --output {output.warped}"
            )

# Texture maps (gradient magnitude, relative intensity) of a T1w-space image
# of {modality} (e.g. T1w or FLAIR), computed inside the T1w brain mask.
# texture_generation receives only an output prefix. It is expected to write
# <prefix>_gradient-magnitude.nii.gz and <prefix>_relative-intensity.nii.gz.
# NOTE(review): the prefix in the shell command is duplicated from the output
# paths above and must stay in sync with them.
rule run_texture_native:
    input:
        image = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_T1w-space_{{modality}}.nii.gz",
        mask = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_brain-extracted_T1w-space_mask.nii.gz"
    output:
        gradient = f"{OUT_DIR}/{SUBJECT}/{SESSION}/textures/{SUBJECT}_{SESSION}_T1w-space_textures-{{modality}}_gradient-magnitude.nii.gz",
        intensity = f"{OUT_DIR}/{SUBJECT}/{SESSION}/textures/{SUBJECT}_{SESSION}_T1w-space_textures-{{modality}}_relative-intensity.nii.gz"
    threads: LIGHT_THREADS
    shell:
        """
        micaflow texture_generation \
            --input {input.image} \
            --mask {input.mask} \
            --output {OUT_DIR}/{SUBJECT}/{SESSION}/textures/{SUBJECT}_{SESSION}_T1w-space_textures-{wildcards.modality}
        """

# Resample a native (T1w-space) texture map into the requested MNI space.
rule warp_texture_to_mni:
    input:
        moving = f"{OUT_DIR}/{SUBJECT}/{SESSION}/textures/{SUBJECT}_{SESSION}_T1w-space_textures-{{modality}}_{{metric}}.nii.gz",
        affine = rules.registration_mni152.output.fwd_affine,
        warp = rules.registration_mni152.output.fwd_field,
        reference = ATLAS
    output:
        warped = f"{OUT_DIR}/{SUBJECT}/{SESSION}/textures/{SUBJECT}_{SESSION}_{{mni_space}}-space_textures-{{modality}}_{{metric}}.nii.gz"
    wildcard_constraints:
        metric = "gradient-magnitude|relative-intensity",
        mni_space = "MNI152|MNI152linear"
    threads: LIGHT_THREADS
    run:
        # For linear spaces the warp file is only a placeholder touched by
        # registration_mni152, so only the affine is applied.
        if "linear" in wildcards.mni_space:
            warp_arg = ""
        else:
            warp_arg = f"--warp {input.warp} "
        shell(
            f"micaflow apply_warp --moving {input.moving} --reference {input.reference} "
            f"--affine {input.affine} {warp_arg}--output {output.warped}"
        )

# Dice overlap between the T1w segmentation resampled into {mni_space} (from
# registration_mni152) and the atlas segmentation, written as a CSV under
# metrics/.
rule calculate_metrics_T1w:
    input:
        image = rules.registration_mni152.output.output_segmentation,
        atlas = ATLAS_SEG
    output:
        metrics = f"{OUT_DIR}/{SUBJECT}/{SESSION}/metrics/{SUBJECT}_{SESSION}_T1w-to-{{mni_space}}_DICE.csv"
    wildcard_constraints:
        mni_space = "MNI152|MNI152linear"
    threads: LIGHT_THREADS
    shell:
        """
        micaflow calculate_dice \
            --input {input.image} \
            --reference {input.atlas} \
            --output {output.metrics}
        """


if RUN_DWI:
    # Intensity-normalize a T1w-space DWI scalar map ({metric}, e.g. the FA/MD
    # maps requested by get_final_output) with `micaflow normalize`, using
    # --lower/--upper-percentile 1/99 and --min/--max-value 0/100.
    # NOTE(review): {metric} is unconstrained, and the input lambda merely
    # rebuilds the same path. A plain "{{metric}}" pattern plus
    # wildcard_constraints (e.g. "FA|MD") would be clearer; confirm no other
    # metric relies on this rule before constraining it.
    rule normalize_dwi_metrics:
        input:
            image = lambda wildcards: f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_T1w-space_{wildcards.metric}.nii.gz"
        output:
            normalized = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_T1w-space_{{metric}}_normalized.nii.gz"
        threads: LIGHT_THREADS
        shell:
            """
            micaflow normalize \
                --input {input.image} \
                --output {output.normalized} \
                --lower-percentile 1.0 \
                --upper-percentile 99.0 \
                --min-value 0 \
                --max-value 100
            """
    # First DWI step: denoise the raw DWI series (with --gibbs) into TEMP_DIR.
    # The raw bval/bvec files are passed alongside.
    rule dwi_denoise:
        input:
            moving = DWI_FILE,
            bval = BVAL_FILE,
            bvec = BVEC_FILE
        output:
            denoised = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_denoised_DWI.nii.gz"
        threads: HEAVY_THREADS
        shell:
            """
            micaflow denoise \
                --input {input.moving} \
                --bval {input.bval} \
                --bvec {input.bvec} \
                --output {output.denoised} \
                --gibbs \
                --threads {threads}
            """
    # Split the denoised DWI into a b0 image (output.b0), the remaining volumes
    # (output_dwi, "nob0"), and bval/bvec files; all go to TEMP_DIR.
    # NOTE(review): output_bval/output_bvec are named "b0_DWI.*" but are fed to
    # dwi_motion_correction together with output_dwi. They are presumably the
    # gradient table of the non-b0 volumes — confirm and consider renaming.
    rule dwi_b0_extraction:
        input:
            denoised = rules.dwi_denoise.output.denoised,
            bval = BVAL_FILE,
            bvec = BVEC_FILE
        output:
            b0 = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0_DWI.nii.gz",
            output_bvec = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0_DWI.bvec",
            output_bval = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0_DWI.bval",
            output_dwi = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_denoised_DWI_nob0.nii.gz",
            b0_bval = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0_only.bval",
            b0_bvec = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0_only.bvec"
        threads: LIGHT_THREADS
        shell:
            """
            micaflow extract_b0 \
                --input {input.denoised} \
                --bval {input.bval} \
                --bvec {input.bvec} \
                --output-bval {output.output_bval} \
                --output-bvec {output.output_bvec} \
                --output {output.b0} \
                --output-dwi {output.output_dwi} \
                --shell-dimension {SHELL_DIMENSION} \
                --b0-bval {output.b0_bval} \
                --b0-bvec {output.b0_bvec}
            """

    # Motion correction of the non-b0 DWI volumes, with the extracted b0 passed
    # via --b0. Writes the corrected series and a corrected bvec file to
    # TEMP_DIR; TEMP_DIR is also passed as the tool's scratch directory.
    rule dwi_motion_correction:
        input:
            denoised = rules.dwi_b0_extraction.output.output_dwi,
            bvec = rules.dwi_b0_extraction.output.output_bvec,
            b0 = rules.dwi_b0_extraction.output.b0,
            bval = rules.dwi_b0_extraction.output.output_bval
        output:
            corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_denoised_motioncorrected_DWI.nii.gz",
            corrected_bvec = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_denoised_motioncorrected_DWI.bvec"
        threads: HEAVY_THREADS
        shell:
            """
            micaflow motion_correction \
                --denoised {input.denoised} \
                --input-bvecs {input.bvec} \
                --output-bvecs {output.corrected_bvec} \
                --output {output.corrected} \
                --b0 {input.b0} \
                --shell-dimension {SHELL_DIMENSION} \
                --threads {threads} \
                --input-bvals {input.bval} \
                --temp-dir {TEMP_DIR}
            """

    if USE_SYNTH_B0:
        # Synthetic-b0 path (no inverse DWI): bias-correct the motion-corrected
        # DWI series and the b0 image. Both results go to TEMP_DIR.
        rule dwi_bias_correction:
            input:
                image = rules.dwi_motion_correction.output.corrected,
                b0 = rules.dwi_b0_extraction.output.b0,
            output:
                corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_denoised_bias-corrected_DWI.nii.gz",
                b0_corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_biascorrected-b0.nii.gz"
            threads: LIGHT_THREADS
            shell:
                """
                micaflow bias_correction \
                    --input {input.image} \
                    --b0 {input.b0} \
                    --b0-output {output.b0_corrected} \
                    --output {output.corrected} \
                    --shell-dimension {SHELL_DIMENSION} \
                    --threads {threads}
                """
        # SynthSeg parcellation of the bias-corrected b0 before distortion
        # correction ("NoSDC"). It is the moving segmentation for
        # b0_synth_registration.
        rule b0_synthseg:
            input:
                image = rules.dwi_bias_correction.output.b0_corrected,
            output:
                seg = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_synthseg_b0_NoSDC.nii.gz"
            threads: LIGHT_THREADS
            shell:
                """
                micaflow synthseg \
                    --i {input.image} \
                    --o {output.seg} \
                    --parc \
                    --robust \
                    --threads {threads} \
                    {CPU_FLAG}
                """
        # Register the bias-corrected b0 (moving) to the bias-corrected T1w
        # (fixed) using both SynthSeg segmentations. All outputs stay in
        # TEMP_DIR and feed dwi_create_synthetic_b0.
        # NOTE(review): the shell command ends with a dangling "\" line
        # continuation followed only by whitespace. It does no harm now, but it
        # is easy to break if another flag is appended.
        rule b0_synth_registration:
            input:
                moving = rules.dwi_bias_correction.output.b0_corrected,
                fixed = rules.bias_field_correction.output.corrected,
                fixed_seg = rules.synthseg_t1w.output.seg,
                moving_seg = rules.b0_synthseg.output.seg
            output:
                warped = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_registered-b0_NoSDC.nii.gz",
                fwd_field = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0-to-t1_warp_NoSDC.nii.gz",
                fwd_affine = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0-to-t1_affine_NoSDC.mat",
                fwd_field_secondary = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0-to-t1_warp-secondary_NoSDC.nii.gz",
                output_segmentation = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0-to-t1_synthseg_NoSDC.nii.gz"
            threads: HEAVY_THREADS
            shell:
                """
                micaflow coregister \
                    --fixed-file {input.fixed} \
                    --moving-file {input.moving} \
                    --fixed-segmentation {input.fixed_seg} \
                    --moving-segmentation {input.moving_seg} \
                    --output {output.warped} \
                    --warp-file {output.fwd_field} \
                    --affine-file {output.fwd_affine} \
                    --threads {threads} \
                    --output-segmentation {output.output_segmentation} \
                    --secondary-warp-file {output.fwd_field_secondary} \
                """
        
        # Build a synthetic b0 from the normalized T1w and use it to distortion-correct
        # the DWI series (micaflow synth_b0). The b0->T1w affine and warps come from
        # b0_synth_registration.
        # Outputs: SDC-corrected DWI, SDC-corrected b0, the intermediate synthetic b0
        # in T1w space, and the SDC warp (T1w space).
        rule dwi_create_synthetic_b0:
            input:
                t1 = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_T1w-space_T1w_normalized.nii.gz",
                b0 = rules.b0_synth_registration.output.warped,
                dwi = rules.dwi_bias_correction.output.corrected,
                warp = rules.b0_synth_registration.output.fwd_field,
                secondary_warp = rules.b0_synth_registration.output.fwd_field_secondary,
                affine = rules.b0_synth_registration.output.fwd_affine,
            output:
                corrected_DWI = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_SDCcorrected-DWI.nii.gz",
                intermediate = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_synthetic-b0-t1space.nii.gz",
                warp = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_SDC_warp_t1space.nii.gz",
                corrected_b0 = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_SDCcorrected-b0.nii.gz"
            threads: HEAVY_THREADS
            params:
                # Same CPU/GPU switch as the SynthSeg rules: "--cpu" unless a GPU is
                # configured. (Comparing GPU == "--cpu" inverted this for the default
                # gpu="" config and disagreed with every other rule.)
                cpu_flag = CPU_FLAG
            shell:
                """
                micaflow synth_b0 \
                    --t1 {input.t1} \
                    --b0 {input.b0} \
                    --dwi {input.dwi} \
                    --output {output.corrected_DWI} \
                    --intermediate {output.intermediate} \
                    {params.cpu_flag} \
                    --warp {output.warp} \
                    --temp-dir {TEMP_DIR} \
                    --phase-encoding {PED} \
                    --corrected-b0 {output.corrected_b0} \
                    --shell-dimension {SHELL_DIMENSION} \
                    --threads {threads} \
                    --b0-to-T1-affine {input.affine} \
                    --b0-to-T1-warp {input.warp} \
                    --b0-to-T1-warp-secondary {input.secondary_warp}
                """
        
        # SynthSeg parcellation of the SDC-corrected b0 from dwi_create_synthetic_b0.
        # Used for DWI brain extraction and as the moving segmentation in dwi_registration.
        rule synthseg_dwi:
            input:
                # This input needs to be adjusted to work with either workflow
                image = rules.dwi_create_synthetic_b0.output.corrected_b0,
            output:
                seg = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_synthseg_DWI.nii.gz"
            threads: HEAVY_THREADS
            # CPU_FLAG (defined in the header) is "--cpu" unless a GPU is configured.
            shell:
                """
                micaflow synthseg \
                    --i {input.image} \
                    --o {output.seg} \
                    --parc \
                    --robust \
                    --threads {threads} \
                    {CPU_FLAG}
                """
        
        # Remaining DWI rules for this branch: brain mask, DWI->T1w registration, FA/MD maps and registration Dice metrics.

        # Brain-extract the SDC-corrected b0 using its SynthSeg parcellation.
        # The mask is kept in the subject's dwi/ output folder; the extracted image is temporary.
        rule dwi_skull_strip:
            input:
                image = rules.dwi_create_synthetic_b0.output.corrected_b0,
                seg = rules.synthseg_dwi.output.seg
            output:
                image = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_brain-extracted_DWI.nii.gz",
                mask = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_brain-extracted_DWI_mask.nii.gz"
            params:
                # Optional cerebellum removal, controlled by the rm_cerebellum config key.
                rm_cerebellum = "--remove-cerebellum" if RM_CEREBELLUM else ""
            threads: LIGHT_THREADS
            shell:
                """
                micaflow bet \
                    --input {input.image} \
                    --output {output.image} \
                    --output-mask {output.mask} \
                    --parcellation {input.seg} \
                    {params.rm_cerebellum} 
                """

        # Register the SDC-corrected b0 (DWI space) to the bias-corrected T1w, guided by
        # both SynthSeg parcellations. Writes forward and reverse primary and secondary
        # warps plus the affine to xfm/; downstream rules use the forward ones.
        rule dwi_registration:
            input:
                moving_seg = rules.synthseg_dwi.output.seg,
                fixed_seg = rules.synthseg_t1w.output.seg,
                moving = rules.dwi_create_synthetic_b0.output.corrected_b0,
                fixed = rules.bias_field_correction.output.corrected
            output:
                warped = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_T1w-space_DWI.nii.gz",
                fwd_field = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_fwdfield.nii.gz",
                fwd_field_secondary = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_fwdfield-secondary.nii.gz",
                fwd_affine = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_fwdaffine.mat",
                rev_field = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_revfield.nii.gz",
                rev_field_secondary = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_revfield-secondary.nii.gz",
                # DWI SynthSeg resampled into T1w space; scored by calculate_metrics_DWI.
                output_segmentation = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_seg.nii.gz"
            threads: HEAVY_THREADS
            shell:
                """
                micaflow coregister \
                    --fixed-file {input.fixed} \
                    --moving-file {input.moving} \
                    --fixed-segmentation {input.fixed_seg} \
                    --moving-segmentation {input.moving_seg} \
                    --output {output.warped} \
                    --warp-file {output.fwd_field} \
                    --affine-file {output.fwd_affine} \
                    --rev-warp-file {output.rev_field} \
                    --threads {threads} \
                    --output-segmentation {output.output_segmentation} \
                    --secondary-warp-file {output.fwd_field_secondary} \
                    --secondary-rev-warp-file {output.rev_field_secondary}
                """
        
        # Compute FA and MD maps in DWI space from the SDC-corrected DWI, restricted
        # to the DWI brain mask.
        # - bvals: from dwi_b0_extraction.
        # - bvecs: from dwi_motion_correction.
        # - The SDC-corrected b0 and its bvec/bval are passed separately via --b0-*.
        rule dwi_compute_fa_md:
            input:
                image = rules.dwi_create_synthetic_b0.output.corrected_DWI,
                mask = rules.dwi_skull_strip.output.mask,
                bval = rules.dwi_b0_extraction.output.output_bval,
                bvec = rules.dwi_motion_correction.output.corrected_bvec,
                b0_volume = rules.dwi_create_synthetic_b0.output.corrected_b0,
                b0_bvec = rules.dwi_b0_extraction.output.b0_bvec,
                b0_bval = rules.dwi_b0_extraction.output.b0_bval
            output:
                fa = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_DWI-space_FA.nii.gz",
                md = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_DWI-space_MD.nii.gz"
            threads: LIGHT_THREADS
            shell:
                """
                micaflow compute_fa_md \
                    --input {input.image} \
                    --mask {input.mask} \
                    --bval {input.bval} \
                    --bvec {input.bvec} \
                    --output-fa {output.fa} \
                    --output-md {output.md} \
                    --b0-volume {input.b0_volume} \
                    --b0-bvec {input.b0_bvec} \
                    --b0-bval {input.b0_bval}
                """
        # Resample the DWI-space FA and MD maps into T1w space. Each map gets the
        # DWI->T1w affine, primary warp and secondary warp from dwi_registration.
        rule dwi_fa_md_registration:
            input:
                fa = rules.dwi_compute_fa_md.output.fa,
                md = rules.dwi_compute_fa_md.output.md,
                # Static path: this rule has no wildcards, so no input function is needed.
                reference = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_T1w-space_T1w.nii.gz",
                affine = rules.dwi_registration.output.fwd_affine,
                warp = rules.dwi_registration.output.fwd_field,
                secondary_warp = rules.dwi_registration.output.fwd_field_secondary
            output:
                fa_reg = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_T1w-space_FA.nii.gz",
                md_reg = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_T1w-space_MD.nii.gz"
            threads: LIGHT_THREADS
            run:
                # Plain templates: Snakemake's shell() substitutes {input.*}/{output.*}
                # itself, so paths are not run through a second formatting pass.
                # Process FA map
                shell("micaflow apply_warp --moving {input.fa} --reference {input.reference} "
                      "--affine {input.affine} --warp {input.warp} --output {output.fa_reg} "
                      "--secondary-warp {input.secondary_warp}")

                # Process MD map
                shell("micaflow apply_warp --moving {input.md} --reference {input.reference} "
                      "--affine {input.affine} --warp {input.warp} --output {output.md_reg} "
                      "--secondary-warp {input.secondary_warp}")

        # Registration QC: Dice overlap between the DWI SynthSeg resampled into T1w
        # space (dwi_registration) and the T1w SynthSeg, written as CSV under metrics/.
        rule calculate_metrics_DWI:
            input:
                image = rules.dwi_registration.output.output_segmentation,
                atlas = rules.synthseg_t1w.output.seg
            output:
                metrics = f"{OUT_DIR}/{SUBJECT}/{SESSION}/metrics/{SUBJECT}_{SESSION}_DWI-to-T1w-space_DICE.csv"
            threads: LIGHT_THREADS
            shell:
                """
                micaflow calculate_dice \
                    --input {input.image} \
                    --reference {input.atlas} \
                    --output {output.metrics}
                """

    else:
        # Split the reverse phase-encoded DWI acquisition into its b0 volume and
        # the remaining non-b0 series, with matching bval/bvec files. The b0 feeds dwi_topup.
        rule dwi_b0_extraction_reversePE:
            input:
                image = INVERSE_DWI_FILE,
                bval = INVERSE_BVAL_FILE,
                bvec = INVERSE_BVEC_FILE
            output:
                b0 = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0-inverse_DWI.nii.gz",
                output_bvec = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0-inverse_DWI.bvec",
                output_bval = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_b0-inverse_DWI.bval",
                output_dwi = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_denoised_DWI_nob0-inverse.nii.gz"
            threads: LIGHT_THREADS
            shell:
                """
                micaflow extract_b0 \
                    --input {input.image} \
                    --bvals {input.bval} \
                    --bvecs {input.bvec} \
                    --output-bvals {output.output_bval} \
                    --output-bvecs {output.output_bvec} \
                    --output {output.b0} \
                    --output-dwi {output.output_dwi} \
                    --shell-dimension {SHELL_DIMENSION}
                """

        # Estimate the susceptibility-distortion warp from the forward and reverse
        # phase-encoded b0 pair (micaflow SDC). Outputs the corrected b0 (temporary)
        # and the SDC warp (kept in xfm/), which dwi_apply_topup applies to the full series.
        rule dwi_topup:
            input:
                b0 = rules.dwi_b0_extraction.output.b0,
                b0_inverse = rules.dwi_b0_extraction_reversePE.output.b0,
            output:
                warp = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_SDC-warp.nii.gz",
                corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_corrected-b0_DWI.nii.gz"
            threads: LIGHT_THREADS
            shell:
                """
                micaflow SDC \
                    --input {input.b0} \
                    --reverse-image {input.b0_inverse} \
                    --output {output.corrected} \
                    --output-warp {output.warp} \
                    --phase-encoding {PED}
                """

        # Apply the SDC warp from dwi_topup to the motion-corrected DWI series.
        rule dwi_apply_topup:
            input:
                motion_corr = rules.dwi_motion_correction.output.corrected,
                warp = rules.dwi_topup.output.warp,
                # The raw input DWI file is passed as --affine; presumably only its header
                # affine is used by apply_SDC -- NOTE(review): confirm.
                affine = DWI_FILE
            output:
                corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_SDC-DWI.nii.gz"
            threads: LIGHT_THREADS
            shell:
                """
                micaflow apply_SDC \
                    --input {input.motion_corr} \
                    --warp {input.warp} \
                    --affine {input.affine} \
                    --output {output.corrected} \
                    --phase-encoding {PED} 
                """
    
        # SynthSeg parcellation of the distortion-corrected b0 from dwi_topup.
        # Used for DWI brain extraction and as the moving segmentation in dwi_registration.
        rule synthseg_dwi:
            input:
                # This input needs to be adjusted to work with either workflow
                image = rules.dwi_topup.output.corrected,
            output:
                seg = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_synthseg_DWI.nii.gz"
            threads: HEAVY_THREADS
            # CPU_FLAG (defined in the header) is "--cpu" unless a GPU is configured.
            shell:
                """
                micaflow synthseg \
                    --i {input.image} \
                    --o {output.seg} \
                    --parc \
                    --robust \
                    --threads {threads} \
                    {CPU_FLAG}
                """
        
        # Remaining DWI rules for this branch: brain mask, bias correction, DWI->T1w registration, FA/MD maps and registration Dice metrics.

        # Brain-extract the distortion-corrected b0 (dwi_topup) using its SynthSeg
        # parcellation. The mask is also used by dwi_bias_correction and dwi_compute_fa_md.
        # NOTE(review): this branch writes the mask under anat/, while the
        # synthetic-b0 branch writes the same-named mask under dwi/ -- confirm which
        # location downstream targets expect.
        rule dwi_skull_strip:
            input:
                image = rules.dwi_topup.output.corrected,
                seg = rules.synthseg_dwi.output.seg
            output:
                image = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_brain-extracted_DWI.nii.gz",
                mask = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_brain-extracted_DWI_mask.nii.gz"
            params:
                # Optional cerebellum removal, controlled by the rm_cerebellum config key.
                rm_cerebellum = "--remove-cerebellum" if RM_CEREBELLUM else ""
            threads: LIGHT_THREADS
            shell:
                """
                micaflow bet \
                    --input {input.image} \
                    --output {output.image} \
                    --output-mask {output.mask} \
                    --parcellation {input.seg} \
                    {params.rm_cerebellum} 
                """

        # Bias-field-correct the SDC-applied DWI series and the corrected b0 inside the
        # DWI brain mask. In this branch the step runs after distortion correction and
        # brain extraction.
        rule dwi_bias_correction:
            input:
                image = rules.dwi_apply_topup.output.corrected,
                mask = rules.dwi_skull_strip.output.mask,
                b0 = rules.dwi_topup.output.corrected
            output:
                corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_denoised_bias-corrected_DWI.nii.gz",
                b0_corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_biascorrected-b0.nii.gz"
            threads: LIGHT_THREADS
            shell:
                """
                micaflow bias_correction \
                    --input {input.image} \
                    --b0 {input.b0} \
                    --b0-output {output.b0_corrected} \
                    --output {output.corrected} \
                    --mask {input.mask} \
                    --shell-dimension {SHELL_DIMENSION} \
                    --threads {threads}
                """

        # Register the bias-corrected b0 (DWI space) to the bias-corrected T1w, guided by
        # both SynthSeg parcellations. Writes forward and reverse primary and secondary
        # warps plus the affine to xfm/; downstream rules use the forward ones.
        rule dwi_registration:
            input:
                moving_seg = rules.synthseg_dwi.output.seg,
                fixed_seg = rules.synthseg_t1w.output.seg,
                moving = rules.dwi_bias_correction.output.b0_corrected,
                fixed = rules.bias_field_correction.output.corrected
            output:
                warped = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_T1w-space_DWI.nii.gz",
                fwd_field = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_fwdfield.nii.gz",
                fwd_field_secondary = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_fwdfield-secondary.nii.gz",
                fwd_affine = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_fwdaffine.mat",
                rev_field = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_revfield.nii.gz",
                rev_field_secondary = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_revfield-secondary.nii.gz",
                # DWI SynthSeg resampled into T1w space; scored by calculate_metrics_DWI.
                output_segmentation = f"{OUT_DIR}/{SUBJECT}/{SESSION}/xfm/{SUBJECT}_{SESSION}_from-DWI_to-T1w_seg.nii.gz"
            threads: HEAVY_THREADS
            shell:
                """
                micaflow coregister \
                    --fixed-file {input.fixed} \
                    --moving-file {input.moving} \
                    --fixed-segmentation {input.fixed_seg} \
                    --moving-segmentation {input.moving_seg} \
                    --output {output.warped} \
                    --warp-file {output.fwd_field} \
                    --affine-file {output.fwd_affine} \
                    --rev-warp-file {output.rev_field} \
                    --threads {threads} \
                    --output-segmentation {output.output_segmentation} \
                    --secondary-warp-file {output.fwd_field_secondary} \
                    --secondary-rev-warp-file {output.rev_field_secondary}
                """
        
        # Compute FA and MD maps in DWI space from the bias-corrected DWI, restricted
        # to the DWI brain mask.
        # - bvals: from dwi_b0_extraction.
        # - bvecs: from dwi_motion_correction.
        # - The bias-corrected b0 and its bvec/bval are passed separately via --b0-*.
        rule dwi_compute_fa_md:
            input:
                image = rules.dwi_bias_correction.output.corrected,
                mask = rules.dwi_skull_strip.output.mask,
                bval = rules.dwi_b0_extraction.output.output_bval,
                bvec = rules.dwi_motion_correction.output.corrected_bvec,
                b0_volume = rules.dwi_bias_correction.output.b0_corrected,
                b0_bvec = rules.dwi_b0_extraction.output.b0_bvec,
                b0_bval = rules.dwi_b0_extraction.output.b0_bval
            output:
                fa = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_DWI-space_FA.nii.gz",
                md = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_DWI-space_MD.nii.gz"
            threads: LIGHT_THREADS
            shell:
                """
                micaflow compute_fa_md \
                    --input {input.image} \
                    --mask {input.mask} \
                    --bval {input.bval} \
                    --bvec {input.bvec} \
                    --output-fa {output.fa} \
                    --output-md {output.md} \
                    --b0-volume {input.b0_volume} \
                    --b0-bvec {input.b0_bvec} \
                    --b0-bval {input.b0_bval}
                """
        # Resample the DWI-space FA and MD maps into T1w space. Each map gets the
        # DWI->T1w affine, primary warp and secondary warp from dwi_registration.
        rule dwi_fa_md_registration:
            input:
                fa = rules.dwi_compute_fa_md.output.fa,
                md = rules.dwi_compute_fa_md.output.md,
                # Static path: this rule has no wildcards, so no input function is needed.
                reference = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_T1w-space_T1w.nii.gz",
                affine = rules.dwi_registration.output.fwd_affine,
                warp = rules.dwi_registration.output.fwd_field,
                secondary_warp = rules.dwi_registration.output.fwd_field_secondary
            output:
                fa_reg = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_T1w-space_FA.nii.gz",
                md_reg = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_T1w-space_MD.nii.gz"
            threads: LIGHT_THREADS
            run:
                # Plain templates: Snakemake's shell() substitutes {input.*}/{output.*}
                # itself, so paths are not run through a second formatting pass.
                # Process FA map
                shell("micaflow apply_warp --moving {input.fa} --reference {input.reference} "
                      "--affine {input.affine} --warp {input.warp} --output {output.fa_reg} "
                      "--secondary-warp {input.secondary_warp}")

                # Process MD map
                shell("micaflow apply_warp --moving {input.md} --reference {input.reference} "
                      "--affine {input.affine} --warp {input.warp} --output {output.md_reg} "
                      "--secondary-warp {input.secondary_warp}")

        # Registration QC: Dice overlap between the DWI SynthSeg resampled into T1w
        # space (dwi_registration) and the T1w SynthSeg, written as CSV under metrics/.
        rule calculate_metrics_DWI:
            input:
                image = rules.dwi_registration.output.output_segmentation,
                atlas = rules.synthseg_t1w.output.seg
            output:
                metrics = f"{OUT_DIR}/{SUBJECT}/{SESSION}/metrics/{SUBJECT}_{SESSION}_DWI-to-T1w-space_DICE.csv"
            threads: LIGHT_THREADS
            shell:
                """
                micaflow calculate_dice \
                    --input {input.image} \
                    --reference {input.atlas} \
                    --output {output.metrics}
                """

# Resample the T1w brain mask into an MNI template space ({mni_space}).
# Nearest-neighbour interpolation is used so no new mask values are introduced.
# The linear space applies the affine only; the non-linear space also applies the warp.
rule transform_mask_to_mni:
    input:
        mask = rules.skull_strip_t1w.output.mask,
        affine = rules.registration_mni152.output.fwd_affine,
        # Static input, so it is required for both spaces even though only the
        # non-linear branch below applies it.
        warp = rules.registration_mni152.output.fwd_field,
        # secondary_warp = rules.registration_mni152.output.fwd_field_secondary,
        reference = ATLAS
    output:
        mni_mask = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_{{mni_space}}-space_brain_mask.nii.gz"
    wildcard_constraints:
        # Same space set as apply_warp_dwi_to_mni. Without it, any "<X>-space_brain_mask"
        # request (e.g. a T1w-space mask) could be routed through this MNI-only rule.
        mni_space = "MNI152|MNI152linear"
    threads: LIGHT_THREADS
    run:
        # Plain templates: Snakemake's shell() substitutes {input.*}/{output.*} itself.
        if "linear" in wildcards.mni_space:
            shell("micaflow apply_warp --moving {input.mask} --reference {input.reference} "
                  "--affine {input.affine} --output {output.mni_mask} --interpolation nearestNeighbor")
        else:
            shell("micaflow apply_warp --moving {input.mask} --reference {input.reference} "
                  "--affine {input.affine} --warp {input.warp} --output {output.mni_mask} --interpolation nearestNeighbor")

# DWI scalar maps (FA/MD) are resampled from DWI space straight into an MNI space
# in a single apply_warp call. That call chains the T1w->MNI transform(s) with the
# DWI->T1w secondary warp, warp and affine.
if RUN_DWI:
    rule apply_warp_dwi_to_mni:
        input:
            moving = lambda wildcards: f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_DWI-space_{wildcards.metric}.nii.gz",
            affine_mni = rules.registration_mni152.output.fwd_affine,
            warp_mni = rules.registration_mni152.output.fwd_field,
            affine_dwi = rules.dwi_registration.output.fwd_affine,
            warp_dwi = rules.dwi_registration.output.fwd_field,
            secondary_warp_dwi = rules.dwi_registration.output.fwd_field_secondary,
            reference = ATLAS
        output:
            warped = f"{OUT_DIR}/{SUBJECT}/{SESSION}/dwi/{SUBJECT}_{SESSION}_{{mni_space}}-space_{{metric}}.nii.gz"
        threads: LIGHT_THREADS
        wildcard_constraints:
            metric="FA|MD",
            mni_space = "MNI152|MNI152linear"
        run:
            # The linear MNI space uses only the MNI affine; the non-linear one
            # prepends the MNI warp. The chain order is passed verbatim: MNI part
            # first, then the DWI->T1w transforms (presumably ANTs-style, last listed
            # applied first -- NOTE(review): confirm against micaflow apply_warp).
            if "linear" in wildcards.mni_space:
                mni_part = [input.affine_mni]
            else:
                mni_part = [input.warp_mni, input.affine_mni]
            dwi_part = [input.secondary_warp_dwi, input.warp_dwi, input.affine_dwi]
            chain = " ".join(mni_part + dwi_part)
            shell(f"micaflow apply_warp --moving {input.moving} --reference {input.reference} "
                  f"--transforms {chain} --output {output.warped}")

if EXTRACT_BRAIN:
    def _brain_extraction_input_image(wildcards):
        """Return the image that skullstripping_MNI152_BE should skull-strip.

        Parameters
        ----------
        wildcards : Snakemake wildcards carrying ``mni_space`` (T1w / MNI152 /
            MNI152linear) and ``modality`` (FA / MD / T1w / FLAIR).

        Returns
        -------
        str or list
            The path of the whole-head image for this space/modality.
            Returns an empty list when the modality was not processed: FLAIR
            without RUN_FLAIR, or FA/MD without RUN_DWI. In that case nothing is
            demanded from rules that are not defined (apply_warp_dwi_to_mni only
            exists under ``if RUN_DWI``) or files that are never produced. The DAG
            can then be built, and the rule's placeholder branch actually runs.
        """
        modality = wildcards.modality
        space = wildcards.mni_space
        if modality == "FLAIR" and not RUN_FLAIR:
            return []
        if modality in ("FA", "MD"):
            if not RUN_DWI:
                return []
            if space == "T1w":
                # T1w-space DWI maps come from dwi_fa_md_registration.
                registered = rules.dwi_fa_md_registration.output
                return registered.fa_reg if modality == "FA" else registered.md_reg
            # Remaining spaces are MNI152/MNI152linear (see wildcard_constraints).
            return rules.apply_warp_dwi_to_mni.output.warped.format(mni_space=space, metric=modality)
        # Anatomical inputs (T1w/FLAIR): same naming pattern in every space.
        return f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_{space}-space_{modality}.nii.gz"

    # Apply the (T1w- or MNI-space) T1w brain mask to an image in the same space.
    rule skullstripping_MNI152_BE:
        input:
            image = _brain_extraction_input_image,
            mask = lambda wildcards: (
                rules.skull_strip_t1w.output.mask if wildcards.mni_space == "T1w" else
                rules.transform_mask_to_mni.output.mni_mask.format(mni_space=wildcards.mni_space)
            )
        output:
            brain = f"{OUT_DIR}/{SUBJECT}/{SESSION}/brain-extracted/{SUBJECT}_{SESSION}_{{mni_space}}-space_{{modality}}.nii.gz"
        wildcard_constraints:
            modality = "FA|MD|T1w|FLAIR",
            mni_space = "MNI152|MNI152linear|T1w"
        threads: LIGHT_THREADS
        run:
            os.makedirs(os.path.dirname(output.brain), exist_ok=True)
            # Modalities that were not processed get an empty placeholder file; their
            # image input is empty (see _brain_extraction_input_image).
            if wildcards.modality == "FLAIR" and not RUN_FLAIR:
                shell("touch {output.brain}")
            elif wildcards.modality in ["FA", "MD"] and not RUN_DWI:
                shell("touch {output.brain}")
            else:
                shell("micaflow bet --input {input.image} --output {output.brain} --input-mask {input.mask}")

    # Intensity-normalize a brain-extracted image. Presumably this clips to the
    # 1st-99th percentile and rescales to [0, 100] (inferred from the flag names).
    rule normalize_brain_extracted:
        input:
            image = f"{OUT_DIR}/{SUBJECT}/{SESSION}/brain-extracted/{SUBJECT}_{SESSION}_{{space}}-space_{{modality}}.nii.gz"
        output:
            normalized = f"{OUT_DIR}/{SUBJECT}/{SESSION}/brain-extracted/{SUBJECT}_{SESSION}_{{space}}-space_{{modality}}_normalized.nii.gz"
        wildcard_constraints:
            space = "T1w|MNI152|MNI152linear",
            modality = "FA|MD|T1w|FLAIR"
        threads: LIGHT_THREADS
        shell:
            """
            micaflow normalize \
                --input {input.image} \
                --output {output.normalized} \
                --lower-percentile 1.0 \
                --upper-percentile 99.0 \
                --min-value 0 \
                --max-value 100
            """

# Intensity-normalize an anatomical image under anat/. For example, this produces the
# "T1w-space_T1w_normalized" file consumed by dwi_create_synthetic_b0. Presumably it
# clips to the 1st-99th percentile and rescales to [0, 100] (inferred from the flag names).
rule normalize_anatomical:
    input:
        # Plain wildcard pattern: Snakemake fills both {space} and {modality} from the
        # requested output. The former input function resolved {wildcards.space} but
        # left "{modality}" as an escaped literal in the returned path.
        image = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_{{space}}_{{modality}}.nii.gz"
    output:
        normalized = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_{{space}}_{{modality}}_normalized.nii.gz"
    wildcard_constraints:
        space = "T1w-space|MNI152-space|MNI152linear-space"
    threads: LIGHT_THREADS
    shell:
        """
        micaflow normalize \
            --input {input.image} \
            --output {output.normalized} \
            --lower-percentile 1.0 \
            --upper-percentile 99.0 \
            --min-value 0 \
            --max-value 100
        """
