import pandas as pd
import os
import sys
import warnings
import json
import re
warnings.filterwarnings("ignore")
import logging
from spatialsnake.workflow.function.get_sample import get_sample_paths, get_annotation, seg_filter_sample, get_stereoseq_input_spec_map
from spatialsnake.workflow.function.stereoseq_spec import parse_stereoseq_input_spec
# Run types that share the generic (non visium_HD / non slide_seq) code paths
# in this file; membership is tested with `run_type in runtype`.
runtype = [
    "visium",
    "xenium",
    "Merfish",
    "visium_segment",
    "stereoseq",
]

# User-facing logger: timestamped messages written to stdout, detached from
# the root logger so messages are not emitted twice.
formatter = logging.Formatter('%(asctime)s: %(levelname)s - %(message)s')
log_handler = logging.StreamHandler(sys.stdout)
log_handler.setFormatter(formatter)

L = logging.getLogger("spatialsnake_user")
L.propagate = False
L.setLevel(logging.INFO)
# The named logger is process-global; only attach the handler when it has
# none yet so re-evaluating this file does not duplicate output lines.
if len(L.handlers) == 0:
    L.addHandler(log_handler)

# Path of the annotation table handed to get_annotation() for the manual /
# re-annotation modes; defaults to "annotation.txt" when not configured.
annotation_list = config.get("annotation_list", "annotation.txt")
def parse_bool_flag(value):
    """Coerce a config value to a bool.

    Real booleans are returned unchanged and ``None`` is ``False``. Any other
    value is stringified, stripped and lower-cased: "true"/"1"/"yes"/"y"/"t"
    give ``True``; "false"/"0"/"no"/"n"/"f"/"none"/"null"/"" give ``False``.
    Anything unrecognised falls back to Python truthiness of the original
    value.
    """
    if value is None:
        return False
    if isinstance(value, bool):
        return value

    token = str(value).strip().lower()
    is_true = token in ("true", "1", "yes", "y", "t")
    is_false = token in ("false", "0", "no", "n", "f", "none", "null", "")
    if is_true or is_false:
        return is_true
    # Unknown spelling: defer to the truthiness of the raw value.
    return bool(value)

# Toggle for per-sample filtering thresholds (see the `if not filter_list`
# block further down). Normalised with parse_bool_flag so string spellings
# such as "true"/"yes"/"1" are accepted as well as real booleans.
filter_list = parse_bool_flag(config.get("filter_list", False))
# Sample sheet path (default "sample.txt").
sample_list = config.get("sample_list", "sample.txt")

# Package location and root folder for all outputs built in this file.
spatialsnake_path = config.get("spatialsnake_path", "")
results_folder = config.get("results_folder", "results")

# --- `cellchat_compare_*` settings -----------------------------------------
# Only cellchat_compare_output_dir is used in this file (returned by
# parameter_output for option == "compare_stage" with runpipe == "cellchat").
# The rest are presumably consumed by rules/compare_LR.smk — TODO confirm.
cellchat_compare_output_dir = config.get("cellchat_compare_output_dir", os.path.join(results_folder, "compare_cellchat"))
cellchat_compare_sample_name1 = config.get("cellchat_compare_sample_name1", "")
cellchat_compare_sample_name2 = config.get("cellchat_compare_sample_name2", "")
cellchat_compare_pathways = config.get("cellchat_compare_pathways", "")
cellchat_compare_source_cells = config.get("cellchat_compare_source_cells", "")
cellchat_compare_target_cells = config.get("cellchat_compare_target_cells", "")
cellchat_compare_receiver_cells = config.get("cellchat_compare_receiver_cells", "")
cellchat_compare_bubble_angle = config.get("cellchat_compare_bubble_angle", 45)
cellchat_compare_bubble_remove_isolate = config.get("cellchat_compare_bubble_remove_isolate", True)
cellchat_compare_do_ranknet = config.get("cellchat_compare_do_ranknet", True)
cellchat_compare_do_role_heatmap = config.get("cellchat_compare_do_role_heatmap", True)
cellchat_compare_do_pathway_plots = config.get("cellchat_compare_do_pathway_plots", True)
cellchat_compare_do_compare_overview = config.get("cellchat_compare_do_compare_overview", True)
cellchat_compare_do_compare_bubble = config.get("cellchat_compare_do_compare_bubble", True)
cellchat_compare_do_single_bubble = config.get("cellchat_compare_do_single_bubble", True)
cellchat_compare_do_gene_expression = config.get("cellchat_compare_do_gene_expression", False)
cellchat_compare_gene_colors = config.get("cellchat_compare_gene_colors", "white,#FEC44F,#D95F0E")
cellchat_compare_gene_plot_type = config.get("cellchat_compare_gene_plot_type", "dot")
cellchat_compare_pair_lr_use = config.get("cellchat_compare_pair_lr_use", "")
cellchat_compare_save_merged = config.get("cellchat_compare_save_merged", True)

# Workflow stage to run. Values dispatched on in this file: "integrate",
# "preprocess", "clustering", "reclustering", "annotation_help", "annotation",
# "compare_stage", "advance_analysis" and "all".
option = config.get("option","integrate")

# Analysis mode; this file compares it against "single_analysis" and
# "compare_analysis".
channel = config.get("channel", "single_analysis")

# Input data root and platform name. `run_type` is matched against the
# `runtype` list as well as "visium_HD" and "slide_seq" in this file.
data_fold = config.get("data_fold", "data")
run_type = config.get("run_type", "visium")
#################  xenium ###########
# Xenium element toggles; not used in this file (presumably read by the
# included rule files — TODO confirm).
cells_boundaries = config.get("cells_boundaries", False)
nucleus_boundaries = config.get("nucleus_boundaries", False)
nucleus_labels = config.get("nucleus_labels", False)
morphology_mip = config.get("morphology_mip", False)

# Marker-gene test name; default "wilcoxon".
markers_algorithm = config.get("markers_algorithm", "wilcoxon")
# Raw config value is kept as-is in `image_slice` because other files may
# inspect it; only the test below is normalised.
image_slice = config.get("image_slice", False)

# Config values can arrive as a real bool (YAML `image_slice: true`, or a
# `--config image_slice=True` override) as well as the string "True".
# Comparing against the literal "True" skipped this block for real booleans
# and left `coord` undefined, so use the same normalisation as `filter_list`.
if parse_bool_flag(image_slice):
    # Crop window corners; each is None when the key is absent from config.
    x1 = config.get("x1")
    x2 = config.get("x2")
    y1 = config.get("y1")
    y2 = config.get("y2")
    # Order matters for consumers: [x1, x2, y1, y2].
    coord = [x1, x2, y1, y2]
# --- Embedding / neighbourhood / clustering parameters ----------------------
# None of these are used in this file except `batch_method` (reset just below
# for single_analysis); the rest are presumably consumed by the included rule
# files — TODO confirm. Note that the config keys are case-sensitive
# ("MIN_DIST", "SPREAD", "NEIGHBORS").
tsene = config.get("tsene", False)
MIN_DIST = config.get("MIN_DIST", 0.3)
SPREAD = config.get("SPREAD", 1)
variable = config.get("variable", False)
NEIGHBORS = config.get("NEIGHBORS", 10)
pcs = config.get("pcs",30)


# `RES` is read from the "resolution" config key (the names differ).
RES = config.get("resolution", 0.5)
n_top_genes = config.get("n_top_genes", 1000)
# None means no batch method was requested.
batch_method = config.get("batch_method", None)
n_comps = config.get("n_comps", 50)
# `recluster_*` settings: separate defaults for the "reclustering" option.
recluster_resolution = config.get("recluster_resolution", 0.8)
recluster_n_top_genes = config.get("recluster_n_top_genes", 2000)
recluster_neighbors = config.get("recluster_neighbors", 15)
recluster_n_pcs = config.get("recluster_n_pcs", 30)
recluster_marker_method = config.get("recluster_marker_method", "wilcoxon")
recluster_min_pct = config.get("recluster_min_pct", 0.1)
recluster_logfc_threshold = config.get("recluster_logfc_threshold", 0.25)

# A batch method is only honoured in the compare_analysis channel: for
# single_analysis the user is notified and the setting is reset to None.
# Identity comparison (`is not None`) is the correct test for the None
# singleton and cannot be fooled by objects overriding `__ne__`.
if channel == "single_analysis" and batch_method is not None:
    L.info("your argument harmony just allowed setting in compare_analysis channal")
    batch_method = None

# --- Analysis selectors -------------------------------------------------------
# NOTE: the config key is spelled "spacies" (presumably "species"); the
# spelling is part of the config interface and is kept as-is.
spacies = config.get("spacies", "human")
cluster_algorithm = config.get("cluster_algorithm", "leiden")
# "mannul" (sic) is the sentinel this file compares against to select the
# manual-annotation code paths; other values handled below are
# "cell2Location", "reannotation" and "RCTD".
anno_algorithm = config.get("anno_algorithm", "mannul")
compare_algorithm = config.get("compare_algorithm", "DEseq2")
cell_focus = config.get("cell_focus", "Colorectal")
device = config.get("device", "cpu")
# Downstream tool for option == "advance_analysis" / "compare_stage"; values
# matched in this file: "cellPhoneDB", "pysenic", "liana", "cellcharter",
# "banksy", "cellchat".
runpipe = config.get("runpipe", "pysenic")
counts_data = config.get("counts_data", "hgnc_symbol")
iterations = config.get("iterations", 500)
threshold = config.get("threshold", 0.1)
threads = config.get("threads", 32)
pvalue = config.get("pvalue", 0.05)
# NOTE: `sample_id` is read from the "sample_type" key, not "sample_id". It is
# used below as the fallback sample name for several advance_analysis tools.
sample_id = config.get("sample_type", "Normal")
microenvs_file_path = config.get("microenvs_file_path", "")
active_tf_path = config.get("active_tf_path", "")
degs_file_path = config.get("degs_file_path", "")
niche_col = config.get("niche_col", "")
is_singlecell = config.get("is_singlecell", False)
cpdb_method = config.get("cpdb_method", "statistical")
cpdb_de_method = config.get("cpdb_de_method", "wilcoxon")
n_clusters = config.get("n_clusters", 10)
output_name = config.get("output_name", "Normal")
mt_threshold = config.get("mt_threshold", 50.0)
significance = config.get("significance", 0.05)
max_cluster = config.get("max_cluster", 10)
# Column names; presumably columns of the annotated data object's metadata
# table — TODO confirm against the rule scripts.
condition_col = config.get("condition_col", "condition")
sample_col = config.get("sample_col", "sample")
celltype_col = config.get("celltype_col", "celltype")
# `liana_*` settings.
liana_method = config.get("liana_method", "cellphonedb")
liana_resource_name = config.get("liana_resource_name", "consensus")
liana_expr_prop = config.get("liana_expr_prop", 0.1)
liana_min_cells = config.get("liana_min_cells", 5)
liana_use_raw = config.get("liana_use_raw", True)
cellcharter_col = config.get("cellcharter_col", "spatial_cluster")
cell_type1 = config.get("cell_type1", "")
cell_type2 = config.get("cell_type2", "")
gene_family = config.get("gene_family", "")
# Alias of celltype_col under a second name.
celltype = celltype_col
# Explicit input path for cellPhoneDB / pysenic / liana; when empty, the
# sample-loading code below falls back to the first file from the sample sheet.
cellPhoneDB_input = config.get("cellPhoneDB_input", "")

# Fixed file names (not configurable); not used in this file.
geojson = "cell_segmentations.geojson"
image = "tissue_hires_image.png"
scale_factors = "scalefactors_json.json"
coor_file = "BeadLocationsForR.csv"
# `merscope_*` settings. The first two default to the *string* "None", not the
# None object.
merscope_z_layers = config.get("merscope_z_layers", "None")
merscope_region_name = config.get("merscope_region_name", "None")
merscope_transcripts = config.get("merscope_transcripts", True)
merscope_cells_boundaries = config.get("merscope_cells_boundaries", True)
merscope_cells_table = config.get("merscope_cells_table", True)
merscope_mosaic_images = config.get("merscope_mosaic_images", True)

# Presumably BANKSY parameters (the names match that tool's options) — TODO
# confirm. `lambda_list` defaults to a one-element list.
k_geom = config.get("k_geom", 15)
max_m = config.get("max_m", 1)
nbr_weight_decay = config.get("nbr_weight_decay", "scaled_gaussian")
lambda_list = config.get("lambda_list", [0.8])

## downsample for spatialdata
sketch = config.get("sketch",False)
sample_rate = config.get("sample_rate",1.0)

# No defaults here: each of these is None when the key is missing.
# Note the config key casing "INPUT_FIlE" (lower-case "l") and that `barcode`
# is read from the "clusters" key.
INPUT_FIlE=config.get("INPUT_FIlE")
barcode=config.get('clusters')
max_x=config.get('max_x')
min_x=config.get("min_x")
max_y=config.get("max_y")
min_y=config.get("min_y")

feather_input = config.get("feather_input", "")

# Default image/shape element names depend on the platform: Xenium gets
# "morphology_focus"/"cell_circle", everything else "hires"/"cell_boundaries".
# An explicit config value wins in both cases.
if run_type != "xenium":
    image_type = config.get("image_type", "hires")
    shape_type = config.get("shape_type", "cell_boundaries")
else:
    image_type = config.get("image_type", "morphology_focus")
    shape_type = config.get("shape_type", "cell_circle")
vis_mode = config.get("vis_mode", "auto")

# Options that load samples through the "raw input" branch of the
# sample-loading code below (only when anno_algorithm == "mannul").
consistent_option = ["integrate","preprocess","clustering","annotation_help","annotation"]

def build_compare_sample_key(sample_id, group_id):
    """Return the composite lookup key ``"<group_id>::<sample_id>"``.

    Note the order: the group comes first in the key even though it is the
    second argument.
    """
    return "{}::{}".format(group_id, sample_id)

def normalize_run_type_name(value):
    """Return a canonical run-type token for loose comparison.

    The value is stringified, stripped and lower-cased, then every run of
    whitespace, underscores or hyphens is removed, so "Visium_HD",
    "visium-hd" and "visium hd" all become "visiumhd".
    """
    lowered = str(value).strip().lower()
    separators = re.compile(r"[\s_-]+")
    return separators.sub("", lowered)

def normalize_optional_cellchat_spec(value):
    """Return the stripped text of *value*, or "" when it denotes "missing".

    ``None`` and the case-insensitive placeholders "", "none", "null", "na"
    and "nan" all collapse to the empty string; any other value is returned
    as its stripped string form (case preserved).
    """
    cleaned = "" if value is None else str(value).strip()
    placeholders = ("", "none", "null", "na", "nan")
    return "" if cleaned.lower() in placeholders else cleaned

def prepare_cellchat_scale_specs(sample_ids, raw_specs, run_type):
    """Validate and align the per-sample CellChat scale specs.

    Each entry of *raw_specs* is normalised with
    normalize_optional_cellchat_spec, then the list is padded with "" or
    truncated so it has exactly one entry per sample.

    * Visium-family run types ("visium", "visiumhd", "visiumsegment" after
      normalisation): every sample must have a non-empty spec.
    * "stereoseq": every sample must have a non-empty spec that
      parse_stereoseq_input_spec accepts.
    * Any other run type: an empty list is returned.

    Terminates the process via ``sys.exit`` with an explanatory message when
    a required spec is missing or invalid.
    """
    run_key = normalize_run_type_name(run_type)
    wanted = len(sample_ids)

    cleaned = [normalize_optional_cellchat_spec(item) for item in raw_specs]
    # One entry per sample: drop extras, fill gaps with "". A negative
    # multiplier yields an empty list, so this covers both directions.
    cleaned = cleaned[:wanted] + [""] * (wanted - len(cleaned))

    is_visium = run_key in {"visium", "visiumhd", "visiumsegment"}
    is_stereo = run_key == "stereoseq"
    if not (is_visium or is_stereo):
        return []

    unspecified = [name for name, item in zip(sample_ids, cleaned) if item == ""]
    if unspecified and is_visium:
        sys.exit(
            "\nCellChat requires the third column of sample.txt to provide scalefactors_json "
            f"for run_type={run_type}. Missing samples: {', '.join(unspecified)}"
        )
    if unspecified and is_stereo:
        sys.exit(
            "\nCellChat requires the third column of sample.txt to provide Stereo-seq "
            f"bin_size or cellbin for run_type={run_type}. Missing samples: {', '.join(unspecified)}"
        )

    if is_stereo:
        # Syntax check only: the parsed result is discarded here.
        for name, item in zip(sample_ids, cleaned):
            try:
                parse_stereoseq_input_spec(item)
            except ValueError as exc:
                sys.exit(f"\nInvalid CellChat Stereo-seq spec for sample '{name}': {exc}")

    return cleaned

def load_sample_input_dir_map(sample_list_file):
    """Build a ``{sample key: input dir}`` map from the sample sheet.

    The first line of the file is skipped as a header. Every remaining
    non-blank line is split on whitespace; column 1 is the sample id and
    column 2 its input directory. In the "compare_analysis" channel an extra
    ``"<group>::<sample>"`` key (see build_compare_sample_key) is added, with
    the group taken from column 4 for stereoseq rows that have one, otherwise
    from column 3.

    Returns an empty dict when the file does not exist. Lines with fewer than
    two columns are ignored.
    """
    mapping = {}
    if os.path.isfile(sample_list_file):
        with open(sample_list_file) as handle:
            next(handle, None)  # header row
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                fields = re.split(r'\s+', line)
                if len(fields) < 2:
                    continue
                name = fields[0].strip()
                directory_path = fields[1].strip()
                mapping[name] = directory_path
                if channel != "compare_analysis" or len(fields) < 3:
                    continue
                # Stereo-seq rows may carry an extra column before the group.
                has_fourth = run_type == "stereoseq" and len(fields) >= 4
                group_label = fields[3 if has_fourth else 2].strip()
                mapping[build_compare_sample_key(name, group_label)] = directory_path
    return mapping

# Sample -> input directory lookup built from the sample sheet (empty when the
# file is missing).
sample_input_dir_map = load_sample_input_dir_map(sample_list)
# Per-sample Stereo-seq input specs; only populated for run_type "stereoseq",
# otherwise an empty dict.
sample_stereoseq_input_spec_map = get_stereoseq_input_spec_map(sample_list, channel) if run_type == "stereoseq" else {}

# ---------------------------------------------------------------------------
# Sample loading. Three mutually exclusive modes, each defining a different
# set of module-level names:
#   1. raw-input stages with manual annotation  -> samples, main_file,
#      bin_size, group
#   2. compare_stage (non-CellChat)             -> same names as (1)
#   3. everything else (downstream tools)       -> samples, downstream_file,
#      reference, scale_factors_files and the tool-specific names below
# get_sample_paths returns tuples of different length depending on
# run_type/channel/option, as reflected by the unpacking in each branch.
# ---------------------------------------------------------------------------
if  option in consistent_option and anno_algorithm == "mannul":
    main_file = []
    samples = []
    bin_size = []
    group = []
    if channel == 'single_analysis':
        if run_type in runtype:
            samples, main_file = get_sample_paths(sample_list, run_type,channel,option,require_non_empty=True)
        elif run_type == "visium_HD":
            samples, main_file, bin_size = get_sample_paths(sample_list, run_type,channel,option,require_non_empty=True)
    else:
        # Any channel other than single_analysis. "stereoseq" is tested first
        # because it is also a member of `runtype` but yields a 4-tuple.
        if run_type == "stereoseq":
            samples, main_file, bin_size, group = get_sample_paths(sample_list, run_type,channel,option,require_non_empty=True)
        elif run_type in runtype:
            samples, main_file, group = get_sample_paths(sample_list, run_type,channel,option,require_non_empty=True)
        elif run_type == "visium_HD":
            samples, main_file, bin_size, group = get_sample_paths(sample_list,run_type,channel,option,require_non_empty=True)
    # Only the first entry is kept as the shared file name.
    # NOTE(review): unlike the compare_stage branch below, this index is not
    # guarded; a run_type matched by none of the branches above (e.g.
    # "slide_seq") leaves main_file == [] and raises IndexError here.
    main_file = main_file[0]
elif option == "compare_stage" and runpipe != "cellchat":
    main_file = []
    samples = []
    bin_size = []
    group = []
    if channel == 'single_analysis':
        if run_type in runtype:
            samples, main_file = get_sample_paths(sample_list, run_type,channel,option,require_non_empty=True)
        elif run_type == "visium_HD":
            samples, main_file, bin_size = get_sample_paths(sample_list, run_type,channel,option,require_non_empty=True)
    else:
        if run_type == "stereoseq":
            samples, main_file, bin_size, group = get_sample_paths(sample_list, run_type,channel,option,require_non_empty=True)
        elif run_type in runtype:
            samples, main_file, group = get_sample_paths(sample_list, run_type,channel,option,require_non_empty=True)
        elif run_type == "visium_HD":
            samples, main_file, bin_size, group = get_sample_paths(sample_list,run_type,channel,option,require_non_empty=True)
    if len(main_file) > 0:
        main_file = main_file[0]
    # Inform the user when every group label still looks like a placeholder
    # ("group..." prefix, case-insensitive).
    unique_groups = pd.unique(group).tolist() if len(group) > 0 else []
    if len(unique_groups) > 0 and all(str(g).strip().lower().startswith("group") for g in unique_groups):
        L.info("compare_stage uses the third column of sample.txt directly for result naming. Please replace generic labels like Group1/Group2 with real biological group names.")
else:
    # Downstream modes: the second sample-sheet column is an already-processed
    # data file rather than a raw input directory.
    samples = []
    downstream_file = []
    reference=[]
    scale_factors_files = []
    cellchat_sample_names = []
    cellchat_scale_factors = []
    cellchat_output_name = None
    cellcharter_input = ""
    cellcharter_sample_id = sample_id
    banksy_input = ""
    banksy_sample_id = sample_id
    # The third returned list is the reference for annotation and the scale
    # specs otherwise.
    if option=="annotation":
        samples,downstream_file,reference = get_sample_paths(sample_list, run_type,channel,option,require_non_empty=False)
    else:
        samples,downstream_file,scale_factors_files= get_sample_paths(sample_list, run_type,channel,option,require_non_empty=False)
    # CellChat comparison: every path must be an existing file with a ".rds"
    # extension (case-insensitive). The exit messages below are in Chinese.
    if option == "compare_stage" and runpipe == "cellchat":
        if len(downstream_file) == 0:
            sys.exit("\ncompare_stage: 未提供 CellChat 比较所需的 rds 文件路径")
        bad = [p for p in downstream_file if (not os.path.isfile(p)) or (os.path.splitext(p)[1].lower() != ".rds")]
        if len(bad) > 0:
            L.info(f"以下路径不可用或非 .rds 文件: {bad}")
            sys.exit("\ncompare_stage: 请在 sample.txt 第二列提供有效的 CellChat .rds 路径")
    # CellChat run: validate/align the per-sample scale specs.
    if option == "advance_analysis" and runpipe == "cellchat":
        scale_factors_files = prepare_cellchat_scale_specs(samples, scale_factors_files, run_type)
    # CellChat on merged data: remember the real sample names, then replace
    # `samples` with a single output name derived from the first input file's
    # basename (falling back to "concatenated_sdata").
    if option == "advance_analysis" and runpipe == "cellchat" and channel == "compare_analysis":
        cellchat_sample_names = samples
        cellchat_scale_factors = scale_factors_files
        if len(downstream_file) > 0:
            derived_name = os.path.splitext(os.path.basename(downstream_file[0]))[0]
            cellchat_output_name = derived_name if derived_name else "concatenated_sdata"
        else:
            cellchat_output_name = "concatenated_sdata"
        samples = [cellchat_output_name]
    # CellCharter: first input file; sample id from the sheet (single) or from
    # the input file's basename (compare).
    if option == "advance_analysis" and runpipe == "cellcharter":
        if len(downstream_file) == 0:
            sys.exit("\nadvance_analysis: 未提供 CellCharter 输入数据路径")
        cellcharter_input = downstream_file[0]
        if channel == "single_analysis" and len(samples) > 0:
            cellcharter_sample_id = samples[0]
        elif channel == "compare_analysis":
            derived_name = os.path.splitext(os.path.basename(downstream_file[0]))[0]
            cellcharter_sample_id = derived_name if derived_name else "concatenated_sdata"
    # BANKSY: same naming scheme as CellCharter.
    if option == "advance_analysis" and runpipe == "banksy":
        if len(downstream_file) == 0:
            sys.exit("\nadvance_analysis: 未提供 BANKSY 输入数据路径")
        banksy_input = downstream_file[0]
        if channel == "single_analysis" and len(samples) > 0:
            banksy_sample_id = samples[0]
        elif channel == "compare_analysis":
            derived_name = os.path.splitext(os.path.basename(downstream_file[0]))[0]
            banksy_sample_id = derived_name if derived_name else "concatenated_sdata"
    # cellPhoneDB / pysenic / liana share one input variable and one sample id.
    # NOTE(review): `cpdb_sample_id` is only defined inside this branch; it is
    # read by parameter_output for exactly these three runpipe values.
    if option == "advance_analysis" and (runpipe == "cellPhoneDB" or runpipe == "pysenic" or runpipe == "liana"):
        # An explicit cellPhoneDB_input from config takes precedence over the
        # sample sheet.
        if not cellPhoneDB_input:
            if len(downstream_file) == 0:
                sys.exit("\nadvance_analysis: 未提供 CellPhoneDB 输入数据路径")
            cellPhoneDB_input = downstream_file[0]
        if channel == "compare_analysis":
            cpdb_sample_id = "concatenated_sdata"
        elif len(samples) > 0:
            cpdb_sample_id = samples[0]
        else:
            cpdb_sample_id = sample_id
    # NOTE(review): unconditional print with an "@@@" marker looks like
    # leftover debugging output written to stdout on every parse — confirm
    # whether it is still wanted.
    print(samples,downstream_file,reference,scale_factors_files,"@@@@@@@@@@@")

def parameter_output(samples, option):
    """Return the target path(s) that a given workflow stage should produce.

    Used as the input of ``rule all`` and, via nomal_file, for the
    compare_analysis channel. Paths are built under ``results_folder`` from
    the module-level ``channel``, ``run_type``, ``runpipe``,
    ``anno_algorithm``, ``bin_size``, ``group`` and ``flag``, all read at
    call time.

    Args:
        samples: sample names expanded into the ``{sample}`` placeholder.
        option: stage name; also used verbatim as a path segment for
            "clustering" and manual "annotation".

    Returns:
        Usually a list of paths. A single string is returned for
        "annotation_help" in compare_analysis and for "compare_stage".
        ``None`` is returned implicitly when no branch matches (e.g.
        option == "all", or "annotation_help" with an unknown channel).

    Output format by run type, wherever the three-way split appears:
    members of ``runtype`` and "visium_HD" produce ``.zarr`` stores
    (visium_HD under ``{sample}_{bin}um``), "slide_seq" produces ``.h5ad``.
    """
    outs = []
    if option == 'integrate':
        if channel == 'single_analysis':
            if run_type in runtype:
                outs += expand(os.path.join(results_folder, "{sample}",'integrate',"{sample}.zarr"), sample=samples)
            if run_type == "visium_HD":
                # zip: pair sample i with bin size i instead of a full product.
                outs += expand(os.path.join(results_folder, "{sample}_{bin}um", 'integrate',"{sample}.zarr"), zip, sample=samples, bin=bin_size)
            if run_type == "slide_seq":
                outs += expand(os.path.join(results_folder, "{sample}",'integrate',"{sample}.h5ad"), sample=samples)
        # flag == 'merge': per-sample outputs grouped under their group folder.
        if channel == "compare_analysis" and flag == 'merge':
            if run_type in runtype:
                outs += expand(os.path.join(results_folder, "{group}", "{sample}.zarr"), zip, sample=samples, group=group)
            if run_type == "visium_HD":
                outs += expand(os.path.join(results_folder, "{group}_{bin}um", "{sample}.zarr"), zip, sample=samples, bin=bin_size, group=group)
            if run_type == "slide_seq":
                outs += expand(os.path.join(results_folder, "{group}", "{sample}.h5ad"), zip, sample=samples, group=group)
        # flag == 'ALL': the single concatenated object.
        if channel == "compare_analysis" and flag == 'ALL':
            if run_type == "slide_seq":
                outs.append(os.path.join(results_folder, "merge_data", "integrate", "concatenated_sdata.h5ad"))
            else:
                outs.append(os.path.join(results_folder, "merge_data", "integrate", "concatenated_sdata.zarr"))
        return outs
    if option == "preprocess":
        if channel == 'single_analysis':
            if run_type in runtype:
                outs += expand(os.path.join(results_folder, "{sample}", 'preprocess', "filter_{sample}.zarr"), sample=samples)
            if run_type == "visium_HD":
                outs += expand(os.path.join(results_folder, "{sample}_{bin}um", 'preprocess', "filter_{sample}.zarr"), zip, sample=samples, bin=bin_size)
            if run_type == "slide_seq":
                outs += expand(os.path.join(results_folder, "{sample}", 'preprocess', "filter_{sample}.h5ad"), sample=samples)
        if channel == "compare_analysis":
            if run_type == "slide_seq":
                outs.append(os.path.join(results_folder, "merge_data", 'preprocess', "filter_concatenated_sdata.h5ad"))
            else:
                outs.append(os.path.join(results_folder, "merge_data", 'preprocess', "filter_concatenated_sdata.zarr"))
        return outs
    # annotation_help targets live in the 'clustering' folder. Returns a list
    # (single_analysis) or a plain string (compare_analysis).
    if option == "annotation_help":
        if channel == 'single_analysis':
          if run_type == "visium_HD":
            return(expand(os.path.join(results_folder, "{sample}_{bin}um", 'clustering', 'kegg_data.csv'), zip, sample=samples, bin=bin_size))
          else:
            return(expand(os.path.join(results_folder, "{sample}", 'clustering', 'kegg_data.csv'), sample=samples))
        if channel == 'compare_analysis':
            return(os.path.join(results_folder, "merge_data", 'clustering', 'kegg_data.csv'))
    # Clustering and manual annotation share a layout; `option` itself is the
    # folder name.
    if option == "clustering" or (option == "annotation" and anno_algorithm == "mannul"):
        if channel == 'single_analysis':
            if run_type in runtype:
                outs += expand(os.path.join(results_folder, "{sample}", option, "{sample}.zarr"), sample=samples)
            if run_type == "visium_HD":
                outs += expand(os.path.join(results_folder, "{sample}_{bin}um", option, "{sample}.zarr"), zip, sample=samples, bin=bin_size)
            if run_type == "slide_seq":
                outs += expand(os.path.join(results_folder, "{sample}", option, "{sample}.h5ad"), sample=samples)
        if channel == "compare_analysis":
            if run_type == "slide_seq":
                outs.append(os.path.join(results_folder, "merge_data", option, "concatenated_sdata.h5ad"))
            else:
                outs.append(os.path.join(results_folder, "merge_data", option, "concatenated_sdata.zarr"))
        return outs
    # Reclustering always writes per-sample .zarr, regardless of channel or
    # run type.
    if option == "reclustering":
        outs += expand(os.path.join(results_folder, "{sample}", "reclustering", "{sample}.zarr"), sample=samples)
        return outs
    # Automated annotation: the algorithm name is the folder name.
    if anno_algorithm != "mannul" and option == "annotation":
        # RCTD short-circuits: per-sample .zarr regardless of channel/run type.
        if anno_algorithm == "RCTD":
            outs += expand(os.path.join(results_folder, "{sample}", anno_algorithm, "{sample}.zarr"), sample=samples)
            return outs
        if channel == 'single_analysis':
            if run_type in runtype:
                outs += expand(os.path.join(results_folder, "{sample}", anno_algorithm, "{sample}.zarr"), sample=samples)
            # NOTE(review): unlike the other stages, visium_HD here uses
            # "{sample}" without the "_{bin}um" suffix — confirm intended.
            if run_type == "visium_HD":
                outs += expand(os.path.join(results_folder, "{sample}", anno_algorithm, "{sample}.zarr"),sample=samples)
            if run_type == "slide_seq":
                outs += expand(os.path.join(results_folder, "{sample}", anno_algorithm, "{sample}.h5ad"), sample=samples)
        if channel == "compare_analysis":
            if run_type == "slide_seq":
                outs.append(os.path.join(results_folder, "merge_data", anno_algorithm, "concatenated_sdata.h5ad"))
            else:
                outs.append(os.path.join(results_folder, "merge_data", anno_algorithm, "concatenated_sdata.zarr"))
        return outs
    # compare_stage returns a single string: the CellChat comparison output
    # directory, or the positive-set KEGG table otherwise.
    if option == "compare_stage":
        if runpipe == "cellchat":
            return(cellchat_compare_output_dir)
        return(os.path.join(results_folder, "merge_data", 'compare_analysis', 'positive', 'kegg_data.csv'))
    # Downstream tools. The *_sample_id / cellchat_output_name globals are set
    # by the module-level sample-loading code for the matching runpipe.
    if option == "advance_analysis":
        if runpipe == "cellPhoneDB":
            if channel == 'single_analysis':
                # expand() without wildcards simply wraps the path in a list.
                outs += expand(os.path.join(results_folder, f"{cpdb_sample_id}", "cellPhoneDB_results", f"{cpdb_sample_id}_heatmap.png"))
            if channel == "compare_analysis":
                outs.append(os.path.join(results_folder, "merge_data", "cellPhoneDB_results", "concentrate_heatmap.png"))
        elif runpipe == "pysenic":
            outs.append(os.path.join(results_folder, "pysenic_results", f"{cpdb_sample_id}.aucell.loom"))
        elif runpipe == "liana":
            outs.append(os.path.join(results_folder,f"{cpdb_sample_id}","liana_output",f"{cpdb_sample_id}.zarr"))
        elif runpipe == "cellcharter":
            if channel == "compare_analysis":
                outs.append(os.path.join(results_folder, "merge_data", "cellcharter", f"{cellcharter_sample_id}_cellcharter.zarr"))
            else:
                outs.append(os.path.join(results_folder, f"{cellcharter_sample_id}", "cellcharter", f"{cellcharter_sample_id}_cellcharter.zarr"))
        elif runpipe == "banksy":
            if channel == "compare_analysis":
                outs.append(os.path.join(results_folder, "merge_data", "banksy", f"{banksy_sample_id}_banksy.zarr"))
            else:
                outs.append(os.path.join(results_folder, f"{banksy_sample_id}", "banksy", f"{banksy_sample_id}_banksy.zarr"))
        elif runpipe == "cellchat":
            # In compare_analysis the single derived output name replaces the
            # per-sample list; four files are expected per name.
            cellchat_samples = samples
            if channel == "compare_analysis" and cellchat_output_name:
                cellchat_samples = [cellchat_output_name]
            outs += expand(os.path.join(results_folder, "{sample}", "cellchat", "{sample}_cellchat_network.png"), sample=cellchat_samples)
            outs += expand(os.path.join(results_folder, "{sample}", "cellchat", "{sample}_cellchat_network.pdf"), sample=cellchat_samples)
            outs += expand(os.path.join(results_folder, "{sample}", "cellchat", "{sample}_cellchat_stats.csv"), sample=cellchat_samples)
            outs += expand(os.path.join(results_folder, "{sample}", "cellchat", "{sample}_cellchat_lr.csv"), sample=cellchat_samples)
        return outs

def input_file(run_type):
    """Return the raw-input path pattern for *run_type*.

    The pattern is rooted at ``data_fold`` and contains the ``{sample}``
    wildcard (plus ``{bin}`` for visium_HD); ``main_file`` is the file name
    determined while loading the sample sheet.

    * "visium_segment": ``<data>/{sample}/segmented_outputs/<main_file>``
    * "Merfish":        ``<data>/{sample}`` (the sample directory itself)
    * "visium_HD":      ``<data>/{sample}/binned_outputs/square_{bin}um/<main_file>``
    * other members of ``runtype``: ``<data>/{sample}/<main_file>``

    Returns ``None`` for any other run type.

    Platform-specific cases are tested before the generic ``runtype``
    membership test. Previously the generic test came first and, because
    "visium_segment" is itself a member of ``runtype``, the dedicated
    ``segmented_outputs`` branch could never be reached.
    NOTE(review): this changes the visium_segment path from
    ``{sample}/<main_file>`` to ``{sample}/segmented_outputs/<main_file>``;
    confirm this matches the on-disk layout the rules expect.
    """
    if run_type == "visium_segment":
        return os.path.join(data_fold, "{sample}", "segmented_outputs", main_file)
    if run_type == "Merfish":
        return os.path.join(data_fold, "{sample}")
    if run_type == "visium_HD":
        return os.path.join(data_fold, '{sample}', "binned_outputs", "square_{bin}um", main_file)
    if run_type in runtype:
        return os.path.join(data_fold, '{sample}', main_file)
    return None

def get_output(run_type):
    """Return the integrate-stage output pattern for *run_type*.

    The top-level folder is ``{sample}`` in the single_analysis channel and
    ``{group}`` in compare_analysis (with a ``_{bin}um`` suffix for
    visium_HD). Zarr outputs are wrapped in Snakemake's ``directory()``;
    slide_seq yields a plain ``.h5ad`` file path. Returns ``None`` for an
    unknown channel or run type.
    """
    if channel == 'single_analysis':
        top_level = "{sample}"
    elif channel == "compare_analysis":
        top_level = "{group}"
    else:
        return None

    if run_type == "visium_HD":
        return directory(os.path.join(results_folder, top_level + "_{bin}um", "{sample}.zarr"))
    if run_type in runtype:
        return directory(os.path.join(results_folder, top_level, "{sample}.zarr"))
    if run_type == "slide_seq":
        return os.path.join(results_folder, top_level, "{sample}.h5ad")
    return None

def nomal_file(run_type, segs):
    """Return the output pattern of stage *segs* for *run_type*.

    In compare_analysis this delegates to ``parameter_output(samples, segs)``.
    In single_analysis it returns a ``{sample}`` wildcard pattern inside the
    *segs* sub-folder: ``.zarr`` for visium_HD (under ``{sample}_{bin}um``)
    and for members of ``runtype``, ``.h5ad`` for slide_seq. Returns ``None``
    in every other case.
    """
    if channel == "compare_analysis":
        return parameter_output(samples, segs)
    if channel != 'single_analysis':
        return None

    if run_type == "visium_HD":
        sample_dir, file_name = "{sample}_{bin}um", "{sample}.zarr"
    elif run_type in runtype:
        sample_dir, file_name = "{sample}", "{sample}.zarr"
    elif run_type == "slide_seq":
        sample_dir, file_name = "{sample}", "{sample}.h5ad"
    else:
        return None
    return os.path.join(results_folder, sample_dir, segs, file_name)

# Filtering thresholds: either one global pair from the config, or a
# per-sample table loaded from the sample sheet.
if filter_list:
    L.info("!!!!! per-sample filtering is enabled, reading min_cells/min_genes/mt_threshold from sample.txt")
    filter_dict = seg_filter_sample(sample_list)
else:
    # Global thresholds; both default to 50 and are logged as "genes,cells".
    min_cells = config.get('min_cells', 50)
    min_genes = config.get('min_genes', 50)
    L.info(f"{min_genes},{min_cells}")


# Selects which compare_analysis targets parameter_output() builds for the
# "integrate" stage: 'ALL' -> the single concatenated object, 'merge' ->
# per-sample objects under their group folder.
# NOTE(review): the input of `rule all` below is written while flag is still
# 'ALL'; the switch to 'merge' further down presumably only affects the
# included rule files — confirm against Snakemake's evaluation order.
flag = 'ALL'

# Default target rule: requests the final outputs of the selected stage.
rule all:
    input:
        parameter_output(samples, option)

# ---------------------------------------------------------------------------
# Stage dispatch: include only the rule files needed for the selected option.
# Variables assigned right before an `include:` (input_spatial, anno_data,
# ...) are presumably read by the included file — TODO confirm.
# ---------------------------------------------------------------------------
if option == "integrate":
    include: "rules/integrate.smk"
    if channel == "compare_analysis":
        # Switch to per-sample targets before loading the merge rules.
        flag = 'merge'
        include: "rules/merge.smk"
elif option == "preprocess":
    include: "rules/preprocess.smk"
elif option == "clustering":
    include: "rules/cluster.smk"
elif option == "reclustering":
    include: "rules/reclustering.smk"
elif option == "annotation_help":
    include: "rules/annotation_help.smk"
elif option == "annotation":
    # Reference-based methods take the first spatial file and the first
    # reference file from the sample sheet.
    if anno_algorithm == "cell2Location":
        input_spatial = downstream_file[0]
        input_singlecell = reference[0]
        include: "rules/cell2Location_run.smk"
    elif anno_algorithm == "mannul":
        anno_data = get_annotation(annotation_list, samples,channel,results_folder)
        include: "rules/mannul.smk"
    elif anno_algorithm == "reannotation":
        anno_data = get_annotation(annotation_list, samples,channel,results_folder)
        include: "rules/reannotation.smk"
    elif anno_algorithm == "RCTD":
        input_spatial = downstream_file[0]
        input_singlecell = reference[0]
        include: "rules/RCTD.smk"
elif option == "compare_stage":
    if runpipe == "cellchat":
        include: "rules/compare_LR.smk"
    else:
        # NOTE(review): annotation_input is only defined for compare_analysis;
        # compare_gene.smk is included for every channel.
        if channel == "compare_analysis":
            annotation_input = os.path.join(results_folder, "merge_data", "annotation", "concatenated_sdata.h5ad" if run_type == "slide_seq" else "concatenated_sdata.zarr")
        include: "rules/compare_gene.smk"
elif option == "advance_analysis":
    if runpipe == "cellPhoneDB":
        include: "rules/cellPhoneDB.smk"
    elif runpipe == "pysenic":
        # NOTE(review): indexes downstream_file directly; when
        # cellPhoneDB_input is set via config the sample-loading code does not
        # check that downstream_file is non-empty, so this may raise
        # IndexError — confirm. The print looks like leftover debug output.
        input_pysenic = downstream_file[0]
        print(input_pysenic)
        include: "rules/py_senic.smk"
    elif runpipe == "liana":
        # Same unguarded index as for pysenic above.
        liana_inputs = downstream_file[0]
        include: "rules/run_liana.smk"
    elif runpipe == "cellcharter":
        include: "rules/run_cellcharter.smk"
    elif runpipe == "banksy":
        include: "rules/run_banksy.smk"
    elif runpipe == "cellchat":
        input_spatial = downstream_file[0]
        include: "rules/run_cellchat.smk"
elif option == "all":
    # Chains integrate -> (merge) -> preprocess -> cluster -> annotation_help.
    # NOTE(review): parameter_output() has no branch for "all", so `rule all`
    # receives None as its input in this mode — confirm intended.
    include: "rules/integrate.smk"
    if channel == "compare_analysis":
        flag = 'merge'
        include: "rules/merge.smk"
    include: "rules/preprocess.smk"
    include: "rules/cluster.smk"
    include: "rules/annotation_help.smk"