# Source code for spacr.object


import os, gc, torch, time
import numpy as np
import pandas as pd
from multiprocessing import Pool, cpu_count
import matplotlib.pyplot as plt
from IPython.display import display
import warnings

from functools import partial
from skimage.segmentation import watershed
from skimage.measure import label as sk_label, regionprops
from scipy.ndimage import distance_transform_edt
from skimage.filters import (threshold_otsu,threshold_local,frangi,sato,meijering,gaussian,difference_of_gaussians,apply_hysteresis_threshold)
from skimage.feature import blob_log, blob_dog, peak_local_max
from skimage.morphology import (remove_small_objects,remove_small_holes,binary_opening,binary_closing,binary_dilation,binary_erosion,disk,skeletonize,white_tophat)
from skimage.exposure import equalize_adapthist, rescale_intensity
from skimage.restoration import rolling_ball

# Suppress a per-plane warning emitted when 3-D stacks are segmented plane-by-plane
# (stitch_threshold=0, do_3D=False); that mode is intentional in this workflow.
warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")

def generate_cellpose_masks(src, settings, object_type):
    """
    Generate Cellpose masks for one object type ('cell', 'nucleus' or 'pathogen')
    from all .npz stacks in ``src``.

    Masks are saved as .npy files in ``{src}/{object_type}_mask_stack/`` and
    object counts are written to ``{dirname(src)}/measurements/measurements.db``.
    Optionally tracks objects over time (btrack / trackpy / IoU) for timelapse data.

    Parameters
    ----------
    src : str
        Directory containing .npz stacks (keys 'data' and 'filenames').
    settings : dict
        Run configuration; completed by set_default_settings_preprocess_generate_masks.
    object_type : str
        One of the configured object types; selects channels, model and thresholds.

    Returns
    -------
    None
    """
    from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_channels, _choose_model, all_elements_match, prepare_batch_for_segmentation
    from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
    from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
    from .plot import plot_cellpose4_output
    from .settings import set_default_settings_preprocess_generate_masks, _get_object_settings
    from .spacr_cellpose import parse_cellpose4_output

    gc.collect()
    if not torch.cuda.is_available():
        print(f'Torch CUDA is not available, using CPU')

    settings['src'] = src
    settings = set_default_settings_preprocess_generate_masks(settings)

    if settings['verbose']:
        settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
        settings_df['setting_value'] = settings_df['setting_value'].apply(str)
        display(settings_df)

    figuresize = 10
    timelapse = settings['timelapse']

    if timelapse:
        timelapse_displacement = settings['timelapse_displacement']
        timelapse_frame_limits = settings['timelapse_frame_limits']
        timelapse_memory = settings['timelapse_memory']
        timelapse_remove_transient = settings['timelapse_remove_transient']
        timelapse_mode = settings['timelapse_mode']
        timelapse_objects = settings['timelapse_objects']

    batch_size = settings['batch_size']
    # Per-object Cellpose thresholds, e.g. 'cell_CP_prob' / 'cell_FT'.
    cellprob_threshold = settings[f'{object_type}_CP_prob']
    flow_threshold = settings[f'{object_type}_FT']

    object_settings = _get_object_settings(object_type, settings)
    model_name = object_settings['model_name']

    cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
    if settings['verbose']:
        print(cellpose_channels)

    if object_type not in cellpose_channels:
        raise ValueError(f"Error: No channels were specified for object_type '{object_type}'. Check your settings.")

    channels = cellpose_channels[object_type]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # A custom pathogen model overrides the default model name.
    if object_type == 'pathogen' and not settings['pathogen_model'] is None:
        model_name = settings['pathogen_model']

    model = _choose_model(model_name, device, object_type=object_type, restore_type=None, object_settings=object_settings)

    paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
    count_loc = os.path.dirname(src)+'/measurements/measurements.db'
    os.makedirs(os.path.dirname(src)+'/measurements', exist_ok=True)
    _create_database(count_loc)

    average_sizes = []
    average_count = []

    for file_index, path in enumerate(paths):
        name = os.path.basename(path)
        name, ext = os.path.splitext(name)
        output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
        os.makedirs(output_folder, exist_ok=True)
        overall_average_size = 0

        with np.load(path) as data:
            stack = data['data']
            filenames = data['filenames']

        # NOTE(review): this loop only reports existing outputs; it does not
        # actually skip processing for them (processing happens per batch below).
        for i, filename in enumerate(filenames):
            output_path = os.path.join(output_folder, filename)
            if os.path.exists(output_path):
                # Fixed: the extracted source showed "(unknown)" here; the loop
                # variable `filename` is the intended interpolation.
                print(f"File {filename} already exists in the output folder. Skipping...")
                continue

        if settings['timelapse']:
            trackable_objects = ['cell','nucleus','pathogen']
            if not all_elements_match(settings['timelapse_objects'], trackable_objects):
                print(f'timelapse_objects {settings["timelapse_objects"]} must be a subset of {trackable_objects}')
                return
            if len(stack) != batch_size:
                print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
                settings['timelapse_batch_size'] = len(stack)
                batch_size = len(stack)
                # NOTE(review): frame-limit cropping only runs when the batch size
                # was adjusted above — confirm this nesting is intended.
                if isinstance(timelapse_frame_limits, list):
                    if len(timelapse_frame_limits) >= 2:
                        stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
                        filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
                        batch_size = len(stack)
                        print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')

        for i in range(0, stack.shape[0], batch_size):
            mask_stack = []
            # Single-channel stacks are duplicated into two channels for Cellpose.
            if stack.shape[3] == 1:
                batch = stack[i: i+batch_size, :, :, [0,0]].astype(stack.dtype)
            else:
                batch = stack[i: i+batch_size, :, :, channels].astype(stack.dtype)
            batch_filenames = filenames[i: i+batch_size].tolist()

            if not settings['plot']:
                batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
            if batch.size == 0:
                continue

            batch = prepare_batch_for_segmentation(batch)
            batch_list = [batch[j] for j in range(batch.shape[0])]

            if timelapse:
                movie_path = os.path.join(os.path.dirname(src), 'movies')
                os.makedirs(movie_path, exist_ok=True)
                save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
                _npz_to_movie(batch, batch_filenames, save_path, fps=2)

            output = model.eval(x=batch_list,
                                batch_size=batch_size,
                                normalize=False,
                                channel_axis=-1,
                                channels=channels,
                                diameter=object_settings['diameter'],
                                flow_threshold=flow_threshold,
                                cellprob_threshold=cellprob_threshold,
                                rescale=None,
                                resample=object_settings['resample'])

            masks, flows, _, _, _ = parse_cellpose4_output(output)

            if timelapse:
                if settings['plot']:
                    plot_cellpose4_output(batch_list, masks, flows, cmap='inferno', figuresize=figuresize, nr=1, print_object_number=True)
                _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
                if object_type in timelapse_objects:
                    if timelapse_mode == 'btrack':
                        if not timelapse_displacement is None:
                            radius = timelapse_displacement
                        else:
                            radius = 100
                        n_jobs = os.cpu_count()-2
                        if n_jobs < 1:
                            n_jobs = 1
                        mask_stack = _btrack_track_cells(src=src,
                                                         name=name,
                                                         batch_filenames=batch_filenames,
                                                         object_type=object_type,
                                                         plot=settings['plot'],
                                                         save=settings['save'],
                                                         masks_3D=masks,
                                                         mode=timelapse_mode,
                                                         timelapse_remove_transient=timelapse_remove_transient,
                                                         radius=radius,
                                                         n_jobs=n_jobs,
                                                         batch_list=None,
                                                         optimizer_time_limit_s=120,
                                                         optimizer_mip_gap=0.01,
                                                         run_optimization=True,
                                                         max_objects_for_optimization=20000)
                    if timelapse_mode == 'trackpy' or timelapse_mode == 'iou':
                        if timelapse_mode == 'iou':
                            track_by_iou = True
                        else:
                            track_by_iou = False
                        mask_stack = _trackpy_track_cells(src=src,
                                                          name=name,
                                                          batch_filenames=batch_filenames,
                                                          object_type=object_type,
                                                          masks=masks,
                                                          timelapse_displacement=timelapse_displacement,
                                                          timelapse_memory=timelapse_memory,
                                                          timelapse_remove_transient=timelapse_remove_transient,
                                                          plot=settings['plot'],
                                                          save=settings['save'],
                                                          mode=timelapse_mode,
                                                          track_by_iou=track_by_iou)
                else:
                    mask_stack = _masks_to_masks_stack(masks)
            else:
                _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
                # NOTE(review): when merge=True and filter=False, this filtered
                # mask_stack is overwritten by the `else` branch below — confirm
                # whether that is the intended behavior.
                if object_settings['merge'] and not settings['filter']:
                    mask_stack = _filter_cp_masks(masks=masks,
                                                  flows=flows,
                                                  filter_size=False,
                                                  filter_intensity=False,
                                                  minimum_size=object_settings['minimum_size'],
                                                  maximum_size=object_settings['maximum_size'],
                                                  remove_border_objects=False,
                                                  merge=object_settings['merge'],
                                                  batch=batch,
                                                  plot=settings['plot'],
                                                  figuresize=figuresize)
                if settings['filter']:
                    mask_stack = _filter_cp_masks(masks=masks,
                                                  flows=flows,
                                                  filter_size=object_settings['filter_size'],
                                                  filter_intensity=object_settings['filter_intensity'],
                                                  minimum_size=object_settings['minimum_size'],
                                                  maximum_size=object_settings['maximum_size'],
                                                  remove_border_objects=object_settings['remove_border_objects'],
                                                  merge=object_settings['merge'],
                                                  batch=batch,
                                                  plot=settings['plot'],
                                                  figuresize=figuresize)
                    _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
                else:
                    mask_stack = _masks_to_masks_stack(masks)

            if timelapse and settings.get("motility_analysis", False):
                from .timelapse import automated_motility_assay
                _ = automated_motility_assay(settings)

            if not np.any(mask_stack):
                avg_num_objects_per_image, average_obj_size = 0, 0
            else:
                avg_num_objects_per_image, average_obj_size = _get_avg_object_size(mask_stack)

            average_count.append(avg_num_objects_per_image)
            average_sizes.append(average_obj_size)
            overall_average_size = np.mean(average_sizes) if len(average_sizes) > 0 else 0
            overall_average_count = np.mean(average_count) if len(average_count) > 0 else 0
            print(f'Found {overall_average_count} {object_type}/FOV. average size: {overall_average_size:.3f} px2')

            if not timelapse:
                if settings['plot']:
                    plot_cellpose4_output(batch_list, masks, flows, cmap='inferno', figuresize=figuresize, nr=batch_size)

            if settings['save']:
                for mask_index, mask in enumerate(mask_stack):
                    output_filename = os.path.join(output_folder, batch_filenames[mask_index])
                    mask = mask.astype(np.uint16)
                    np.save(output_filename, mask)

            mask_stack = []
            batch_filenames = []
            gc.collect()
    torch.cuda.empty_cache()
    return
# NOTE(review): the three definitions below are duplicated (and shadowed) by
# later definitions in this module; this early `_segment_single_image` also
# lacks the 'ring' branch that the later redefinition adds. Kept as-is so the
# module's runtime behavior (last definition wins) is unchanged.
def _segment_single_image(img, settings):
    """
    Segment a single 2-D image. This is the unit of work for multiprocessing.
    Must be a top-level function (picklable).

    Parameters
    ----------
    img : np.ndarray
        2-D float32 image.
    settings : dict
        Classical segmentation settings (pickle-safe).

    Returns
    -------
    np.ndarray
        2-D int32 label array.

    Raises
    ------
    ValueError
        If ``settings['organelle_morphology']`` is not a known morphology.
    """
    morphology = settings['organelle_morphology']
    method = settings['organelle_method']
    if morphology == 'spots':
        return _segment_spots(img, method, settings)
    elif morphology == 'network':
        return _segment_network(img, method, settings)
    elif morphology == 'irregular':
        return _segment_irregular(img, method, settings)
    else:
        raise ValueError(f"Unknown morphology: {morphology}")


def _segment_spots(img, method, settings):
    """
    Segment punctate / spot-like organelles.

    Strategies
    ----------
    'otsu'     : top-hat -> Otsu -> watershed
    'adaptive' : top-hat -> adaptive threshold -> watershed
    'log'      : Laplacian-of-Gaussian blob detection -> marker-based watershed
    """
    tophat_radius = settings['organelle_tophat_radius']
    use_watershed = settings['organelle_watershed_spots']

    if method == 'log':
        return _spots_log(img, settings, use_watershed)

    # --- Pre-filter: white top-hat enhances bright spots on dark bg ---
    filtered = white_tophat(img, disk(tophat_radius))

    # --- Threshold ---
    if method == 'otsu':
        thresh_val = threshold_otsu(filtered)
        binary = filtered > thresh_val
    elif method == 'adaptive':
        block = settings['organelle_adaptive_block_size']
        offset = settings['organelle_adaptive_offset']
        local_thresh = threshold_local(filtered, block_size=block, offset=offset)
        binary = filtered > local_thresh
    else:
        raise ValueError(f"Unsupported spot method: {method}")

    # --- Morphological cleanup ---
    binary = binary_opening(binary, disk(1))
    binary = remove_small_objects(binary, min_size=settings['organelle_min_size'])

    # --- Watershed to split touching spots ---
    if use_watershed:
        labeled = _watershed_split(binary, filtered)
    else:
        labeled = sk_label(binary)
    return labeled


def _segment_network(img, method, settings):
    """
    Segment filamentous / reticular organelles.

    Strategies
    ----------
    'otsu'     : Gaussian smooth -> Otsu -> morphological cleanup
    'adaptive' : Gaussian smooth -> adaptive threshold -> cleanup
    'ridge'    : Ridge filter (Frangi / Sato / Meijering) -> threshold -> label
    """
    if method == 'ridge':
        return _network_ridge(img, settings)

    smooth = gaussian(img, sigma=1)

    if method == 'otsu':
        thresh_val = threshold_otsu(smooth)
        binary = smooth > thresh_val
    elif method == 'adaptive':
        block = settings['organelle_adaptive_block_size']
        offset = settings['organelle_adaptive_offset']
        local_thresh = threshold_local(smooth, block_size=block, offset=offset)
        binary = smooth > local_thresh
    else:
        raise ValueError(f"Unsupported network method: {method}")

    # Close small gaps along filaments, then drop speckle.
    morph_r = max(settings['organelle_morph_radius'] // 2, 1)
    binary = binary_closing(binary, disk(morph_r))
    binary = remove_small_objects(binary, min_size=settings['organelle_min_size'])

    if settings['organelle_skeletonize']:
        skeleton = skeletonize(binary)
        skeleton = binary_dilation(skeleton, disk(1))
        return sk_label(skeleton)
    return sk_label(binary)
def generate_organelle_masks(src, settings, object_type):
    """
    Generate organelle masks using multiple segmentation strategies.

    Supports four morphology modes:
    - 'spots':     punctate structures (lipid droplets, vesicles, peroxisomes)
    - 'network':   filamentous/reticular structures (mitochondria, microtubules, ER tubules)
    - 'irregular': irregular-shaped organelles (Golgi, ER cisternae, lysosomes)
    - 'ring':      hollow / ring-shaped structures (endosomes, autophagosomes, late lysosomes)

    Each mode can use different backends:
    - 'cellpose':   deep-learning segmentation via Cellpose
    - 'stardist':   star-convex polygon instance segmentation (spots only)
    - 'otsu':       global Otsu thresholding with morphological cleanup
    - 'adaptive':   local adaptive thresholding
    - 'log':        Laplacian of Gaussian blob detection (spots, ring)
    - 'dog':        Difference of Gaussians blob detection (spots, ring)
    - 'ridge':      ridge/tubeness filter (network only)
    - 'hysteresis': dual-threshold hysteresis (network only)
    - 'unet':       user-provided U-Net semantic segmentation (network only)

    Parameters
    ----------
    src : str
        Path to the mask source directory containing .npz stacks.
    settings : dict
        Configuration dictionary. Organelle-specific keys (all prefixed with
        'organelle_') are documented in _set_organelle_defaults.
    object_type : str
        Should be 'organelle' (or a custom name used for folder naming).

    Returns
    -------
    None
        Masks are saved as .npy files in ``{src}/{object_type}_mask_stack/``.
    """
    from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
    from .utils import _masks_to_masks_stack, _filter_cp_masks, prepare_batch_for_segmentation
    from .settings import _set_organelle_defaults
    from .plot import plot_organelle_output

    gc.collect()
    settings = _set_organelle_defaults(settings)
    morphology = settings['organelle_morphology']
    method = settings['organelle_method']
    organelle_channel = settings['organelle_channel']
    # Fail fast on invalid morphology/method combinations.
    _validate_organelle_settings(morphology, method)

    n_jobs = settings.get('n_jobs', 1)
    if n_jobs < 1:
        n_jobs = 1

    if settings['verbose']:
        import pandas as pd
        from IPython.display import display
        organ_keys = {k: v for k, v in settings.items() if k.startswith('organelle_')}
        df = pd.DataFrame(list(organ_keys.items()), columns=['setting_key', 'setting_value'])
        df['setting_value'] = df['setting_value'].apply(str)
        display(df)

    paths = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.npz')]
    if not paths:
        print(f'No .npz files found in {src}')
        return

    count_loc = os.path.join(os.path.dirname(src), 'measurements', 'measurements.db')
    os.makedirs(os.path.dirname(count_loc), exist_ok=True)
    _create_database(count_loc)

    batch_size = settings['batch_size']
    average_sizes = []
    average_counts = []
    time_ls = []

    # ------------------------------------------------------------------ #
    # Load deep-learning model once (if needed)
    # ------------------------------------------------------------------ #
    dl_model = None
    is_dl_method = method in ('cellpose', 'stardist', 'unet')
    if method == 'cellpose':
        from .utils import _choose_model
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        dl_model = _choose_model(
            settings['organelle_model_name'],
            device,
            object_type=object_type,
            restore_type=None,
            object_settings=_build_object_settings(settings),
        )
    elif method == 'stardist':
        dl_model = _load_stardist_model(settings)
    elif method == 'unet':
        dl_model = _load_unet_model(settings)

    # ------------------------------------------------------------------ #
    # Build a serialisable settings subset for worker processes
    # ------------------------------------------------------------------ #
    classical_settings = _extract_classical_settings(settings)

    # ------------------------------------------------------------------ #
    # Optionally load cell masks for per-cell masking
    # ------------------------------------------------------------------ #
    cell_mask_folder = None
    if settings.get('organelle_mask_within_cells', False):
        candidate = os.path.join(os.path.dirname(src), 'cell_mask_stack')
        if os.path.exists(candidate):
            cell_mask_folder = candidate
            print(f'Per-cell masking enabled, using cell masks from {candidate}')
        else:
            print(f'Warning: organelle_mask_within_cells=True but no cell_mask_stack found at {candidate}')

    # ------------------------------------------------------------------ #
    # Main loop over .npz stacks
    # ------------------------------------------------------------------ #
    for file_index, path in enumerate(paths):
        name = os.path.splitext(os.path.basename(path))[0]
        output_folder = os.path.join(os.path.dirname(path), f'{object_type}_mask_stack')
        os.makedirs(output_folder, exist_ok=True)

        with np.load(path) as data:
            stack = data['data']
            filenames = data['filenames']

        # Skip already-processed files
        existing = set(os.listdir(output_folder))
        todo_indices = [i for i, fn in enumerate(filenames) if fn not in existing]
        if not todo_indices:
            print(f'All files in {name} already processed. Skipping.')
            continue

        for i in range(0, stack.shape[0], batch_size):
            start = time.time()
            batch = stack[i: i + batch_size]
            batch_filenames = filenames[i: i + batch_size].tolist()

            # ---------------------------------------------------------- #
            # Extract the organelle channel
            # ---------------------------------------------------------- #
            if organelle_channel is not None:
                if batch.ndim == 4:
                    img_batch = batch[:, :, :, organelle_channel].astype(np.float32)
                else:
                    img_batch = batch.astype(np.float32)
            else:
                # No channel configured: fall back to channel 0 of 4-D stacks.
                if batch.ndim == 4:
                    img_batch = batch[:, :, :, 0].astype(np.float32)
                else:
                    img_batch = batch.astype(np.float32)

            # ---------------------------------------------------------- #
            # Per-cell masking: zero out pixels outside cells
            # ---------------------------------------------------------- #
            if cell_mask_folder is not None:
                img_batch = _apply_cell_mask(img_batch, batch_filenames, cell_mask_folder)

            # ---------------------------------------------------------- #
            # Preprocessing: rolling ball and/or CLAHE
            # ---------------------------------------------------------- #
            img_batch = _preprocess_batch(img_batch, settings)

            # ---------------------------------------------------------- #
            # Segment
            # ---------------------------------------------------------- #
            if method == 'cellpose':
                masks = _segment_cellpose(
                    batch, batch_filenames, dl_model, settings, object_type, output_folder,
                )
            elif method == 'stardist':
                masks = _segment_stardist(img_batch, dl_model, settings)
            elif method == 'unet':
                masks = _segment_unet(img_batch, dl_model, settings)
            else:
                # CPU-bound classical methods — parallelise
                masks = _segment_classical_parallel(
                    img_batch, classical_settings, n_jobs=n_jobs,
                )

            if masks is None or len(masks) == 0:
                continue

            # ---------------------------------------------------------- #
            # Post-process: size filter, border removal
            # ---------------------------------------------------------- #
            mask_stack = _postprocess_masks(
                masks,
                min_size=settings['organelle_min_size'],
                max_size=settings['organelle_max_size'],
                remove_border=settings['organelle_remove_border'],
            )

            _save_object_counts_to_database(
                mask_stack, object_type, batch_filenames, count_loc, added_string='',
            )

            # Stats
            if not np.any(mask_stack):
                avg_count, avg_size = 0, 0
            else:
                avg_count, avg_size = _get_avg_object_size(mask_stack)
            average_counts.append(avg_count)
            average_sizes.append(avg_size)
            overall_avg_count = np.mean(average_counts)
            overall_avg_size = np.mean(average_sizes)

            stop = time.time()
            duration = stop - start
            time_ls.append(duration)
            print(
                f'Found {overall_avg_count:.1f} {object_type}/FOV, '
                f'average size: {overall_avg_size:.1f} px2 '
                f'[batch {file_index+1}/{len(paths)}, {duration:.1f}s, '
                f'n_jobs={n_jobs if not is_dl_method else "GPU"}]'
            )

            # ---------------------------------------------------------- #
            # Plot (if enabled)
            # ---------------------------------------------------------- #
            if settings.get('plot', False):
                plot_organelle_output(
                    img_batch[: len(mask_stack)],
                    mask_stack,
                    settings,
                    cmap='inferno',
                    figuresize=10,
                    nr=min(settings.get('examples_to_plot', 1), len(mask_stack)),
                    print_object_number=True,
                )

            # ---------------------------------------------------------- #
            # Save
            # ---------------------------------------------------------- #
            if settings['save']:
                for mask_idx, mask in enumerate(mask_stack):
                    out_path = os.path.join(output_folder, batch_filenames[mask_idx])
                    np.save(out_path, mask.astype(np.uint16))

            mask_stack = []
            batch_filenames = []
            gc.collect()
    torch.cuda.empty_cache()
    return
def _validate_organelle_settings(morphology, method): """Raise early on invalid morphology / method combinations.""" valid_morphologies = ('spots', 'network', 'irregular', 'ring') if morphology not in valid_morphologies: raise ValueError( f"organelle_morphology must be one of {valid_morphologies}, got '{morphology}'" ) method_map = { 'spots': ('otsu', 'adaptive', 'log', 'dog', 'cellpose', 'stardist'), 'network': ('otsu', 'adaptive', 'ridge', 'hysteresis', 'cellpose', 'unet'), 'irregular': ('otsu', 'adaptive', 'cellpose'), 'ring': ('otsu', 'adaptive', 'dog', 'log', 'cellpose'), } valid_methods = method_map[morphology] if method not in valid_methods: raise ValueError( f"For morphology='{morphology}', method must be one of {valid_methods}, got '{method}'" ) def _build_object_settings(settings): """Build an object_settings dict expected by _choose_model / cellpose eval.""" return { 'model_name': settings['organelle_model_name'], 'diameter': settings['organelle_diameter'], 'minimum_size': settings['organelle_min_size'], 'maximum_size': settings['organelle_max_size'], 'resample': settings['organelle_resample'], 'filter_size': False, 'filter_intensity': False, 'remove_border_objects': settings['organelle_remove_border'], 'merge': False, } def _extract_classical_settings(settings): """ Extract only the plain-data keys needed by classical segmentation workers. This dict is safe to pickle for multiprocessing. 
""" keys = [ 'organelle_morphology', 'organelle_method', 'organelle_min_size', 'organelle_max_size', # Spots 'organelle_tophat_radius', 'organelle_watershed_spots', 'organelle_log_min_sigma', 'organelle_log_max_sigma', 'organelle_log_num_sigma', 'organelle_log_threshold', 'organelle_dog_sigma_low', 'organelle_dog_sigma_high', # Network 'organelle_ridge_sigmas', 'organelle_ridge_filter', 'organelle_skeletonize', 'organelle_network_threshold', 'organelle_hysteresis_low', 'organelle_hysteresis_high', # Irregular 'organelle_adaptive_block_size', 'organelle_adaptive_offset', 'organelle_morph_radius', 'organelle_fill_holes', # Ring 'organelle_ring_sigma_inner', 'organelle_ring_sigma_outer', 'organelle_ring_min_prominence', 'organelle_ring_fill_method', ] return {k: settings[k] for k in keys if k in settings} # ====================================================================== # # Preprocessing # ====================================================================== # def _preprocess_batch(img_batch, settings): """ Apply optional preprocessing to the entire image batch. Operations (applied in order if enabled): 1. Rolling ball background subtraction 2. CLAHE (Contrast Limited Adaptive Histogram Equalization) Parameters ---------- img_batch : np.ndarray Shape (N, H, W) float32. settings : dict Returns ------- np.ndarray Preprocessed batch, same shape. 
""" do_rolling_ball = settings.get('organelle_rolling_ball', False) do_clahe = settings.get('organelle_clahe', False) if not do_rolling_ball and not do_clahe: return img_batch out = img_batch.copy() for idx in range(out.shape[0]): img = out[idx] if do_rolling_ball: radius = settings.get('organelle_rolling_ball_radius', 50) bg = rolling_ball(img, radius=radius) img = img - bg img = np.clip(img, 0, None) if do_clahe: clip_limit = settings.get('organelle_clahe_clip_limit', 0.01) pmin, pmax = np.percentile(img, (0.5, 99.5)) if pmax - pmin > 0: img_norm = np.clip((img - pmin) / (pmax - pmin), 0, 1) else: img_norm = np.zeros_like(img) img = equalize_adapthist(img_norm, clip_limit=clip_limit).astype(np.float32) out[idx] = img return out def _apply_cell_mask(img_batch, batch_filenames, cell_mask_folder): """ Zero out pixels outside cell boundaries for per-cell organelle detection. Parameters ---------- img_batch : np.ndarray Shape (N, H, W) float32. batch_filenames : list of str cell_mask_folder : str Path to cell_mask_stack directory. Returns ------- np.ndarray Masked batch. """ out = img_batch.copy() for idx, fn in enumerate(batch_filenames): cell_mask_path = os.path.join(cell_mask_folder, fn) if os.path.exists(cell_mask_path): cell_mask = np.load(cell_mask_path) out[idx][cell_mask == 0] = 0 else: cell_mask_path_npy = cell_mask_path if cell_mask_path.endswith('.npy') else cell_mask_path + '.npy' if os.path.exists(cell_mask_path_npy): cell_mask = np.load(cell_mask_path_npy) out[idx][cell_mask == 0] = 0 return out # ====================================================================== # # Deep-learning model loaders # ====================================================================== # def _load_stardist_model(settings): """Load a Stardist model (pretrained or custom).""" try: from stardist.models import StarDist2D except ImportError: raise ImportError( "Stardist is required for method='stardist'. 
" "Install with: pip install stardist" ) model_name = settings.get('organelle_stardist_model', '2D_versatile_fluo') if os.path.isdir(model_name): model = StarDist2D(None, name=os.path.basename(model_name), basedir=os.path.dirname(model_name)) else: model = StarDist2D.from_pretrained(model_name) return model def _load_unet_model(settings): """Load a user-provided U-Net model from a .pt / .pth file.""" model_path = settings.get('organelle_unet_model_path') if model_path is None or not os.path.exists(model_path): raise ValueError( f"organelle_unet_model_path must point to a valid .pt/.pth file, " f"got '{model_path}'" ) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = torch.load(model_path, map_location=device, weights_only=False) model.eval() return model # ====================================================================== # # Cellpose segmentation # ====================================================================== # def _segment_cellpose(batch, batch_filenames, model, settings, object_type, output_folder): """Run Cellpose on a batch and return a list of 2-D label arrays.""" from .utils import prepare_batch_for_segmentation from .io import _check_masks from .spacr_cellpose import parse_cellpose4_output organelle_ch = settings['organelle_channel'] if organelle_ch is None: organelle_ch = 0 if batch.ndim == 4: ch0 = batch[:, :, :, organelle_ch: organelle_ch + 1] nuc_ch = settings.get('nucleus_channel') if nuc_ch is not None and nuc_ch < batch.shape[3]: ch1 = batch[:, :, :, nuc_ch: nuc_ch + 1] else: ch1 = ch0 cp_batch = np.concatenate([ch0, ch1], axis=-1).astype(batch.dtype) else: cp_batch = np.stack([batch, batch], axis=-1).astype(batch.dtype) if not settings.get('plot', False): cp_batch, batch_filenames = _check_masks(cp_batch, batch_filenames, output_folder) if cp_batch.size == 0: return None cp_batch = prepare_batch_for_segmentation(cp_batch) batch_list = [cp_batch[j] for j in range(cp_batch.shape[0])] output = model.eval( 
x=batch_list, batch_size=settings['batch_size'], normalize=False, channel_axis=-1, channels=[0, 1], diameter=settings['organelle_diameter'], flow_threshold=settings['organelle_FT'], cellprob_threshold=settings['organelle_CP_prob'], rescale=None, resample=settings['organelle_resample'], ) masks, flows, _, _, _ = parse_cellpose4_output(output) return masks # ====================================================================== # # Stardist segmentation (GPU — not parallelised) # ====================================================================== # def _segment_stardist(img_batch, model, settings): """ Run Stardist on a batch of single-channel images. Best for dense, convex, spot-like organelles (lipid droplets, peroxisomes). """ prob_thresh = settings.get('organelle_stardist_prob', 0.5) nms_thresh = settings.get('organelle_stardist_nms', 0.3) masks = [] for idx in range(img_batch.shape[0]): img = img_batch[idx] pmin, pmax = np.percentile(img, (1, 99.8)) if pmax - pmin > 0: img_norm = np.clip((img - pmin) / (pmax - pmin), 0, 1).astype(np.float32) else: img_norm = np.zeros_like(img, dtype=np.float32) labels, _ = model.predict_instances( img_norm, prob_thresh=prob_thresh, nms_thresh=nms_thresh, ) masks.append(labels) return masks # ====================================================================== # # U-Net semantic segmentation (GPU — not parallelised) # ====================================================================== # def _segment_unet(img_batch, model, settings): """ Run a user-provided U-Net for semantic segmentation of networks. Expects a model that takes (B, 1, H, W) and outputs (B, 1, H, W) logits. 
""" device = next(model.parameters()).device threshold = settings.get('organelle_unet_threshold', 0.5) do_skeleton = settings.get('organelle_skeletonize', False) masks = [] with torch.no_grad(): for idx in range(img_batch.shape[0]): img = img_batch[idx] mean, std = img.mean(), img.std() if std > 0: img_norm = (img - mean) / std else: img_norm = np.zeros_like(img) tensor = torch.from_numpy(img_norm[None, None]).float().to(device) pred = model(tensor) if pred.shape[1] > 1: pred = pred[:, 0:1, :, :] pred = pred.sigmoid().cpu().numpy()[0, 0] binary = pred > threshold binary = remove_small_objects(binary, min_size=settings['organelle_min_size']) if do_skeleton: skeleton = skeletonize(binary) skeleton = binary_dilation(skeleton, disk(1)) masks.append(sk_label(skeleton)) else: masks.append(sk_label(binary)) return masks # ====================================================================== # # Classical segmentation — parallel dispatcher # ====================================================================== # def _segment_classical_parallel(img_batch, classical_settings, n_jobs=1): """ Segment a batch of images using classical methods, optionally in parallel. Parameters ---------- img_batch : np.ndarray Shape (N, H, W) float32 single-channel images. classical_settings : dict Pickle-safe subset of settings for classical segmentation. n_jobs : int Number of worker processes. 1 = sequential (no multiprocessing overhead). Returns ------- list of np.ndarray List of 2-D integer label arrays, one per image. 
""" n_images = img_batch.shape[0] if n_jobs == 1 or n_images == 1: return [_segment_single_image(img_batch[idx], classical_settings) for idx in range(n_images)] effective_jobs = min(n_jobs, n_images, cpu_count()) worker_fn = partial(_segment_single_image, settings=classical_settings) image_list = [img_batch[idx] for idx in range(n_images)] with Pool(processes=effective_jobs) as pool: masks = pool.map(worker_fn, image_list) return masks def _segment_single_image(img, settings): """ Segment a single 2-D image. Top-level function for multiprocessing. """ morphology = settings['organelle_morphology'] method = settings['organelle_method'] if morphology == 'spots': return _segment_spots(img, method, settings) elif morphology == 'network': return _segment_network(img, method, settings) elif morphology == 'irregular': return _segment_irregular(img, method, settings) elif morphology == 'ring': return _segment_ring(img, method, settings) else: raise ValueError(f"Unknown morphology: {morphology}") # ====================================================================== # # SPOTS segmentation # ====================================================================== # def _segment_spots(img, method, settings): """ Segment punctate / spot-like organelles. 
    Strategies
    ----------
    'otsu'     : top-hat -> Otsu -> watershed
    'adaptive' : top-hat -> adaptive threshold -> watershed
    'log'      : Laplacian-of-Gaussian blob detection -> marker-based watershed
    'dog'      : Difference-of-Gaussians blob detection -> marker-based watershed
    """
    # NOTE: both keys are read unconditionally, so they must exist in
    # settings even for the 'log'/'dog' paths (which ignore the top-hat radius).
    tophat_radius = settings['organelle_tophat_radius']
    use_watershed = settings['organelle_watershed_spots']

    # Blob-detection strategies bypass the threshold pipeline below.
    if method == 'log':
        return _spots_log(img, settings, use_watershed)
    elif method == 'dog':
        return _spots_dog(img, settings, use_watershed)

    # --- Pre-filter: white top-hat enhances bright spots on dark bg ---
    filtered = white_tophat(img, disk(tophat_radius))

    # --- Threshold ---
    if method == 'otsu':
        thresh_val = threshold_otsu(filtered)
        binary = filtered > thresh_val
    elif method == 'adaptive':
        block = settings['organelle_adaptive_block_size']
        offset = settings['organelle_adaptive_offset']
        local_thresh = threshold_local(filtered, block_size=block, offset=offset)
        binary = filtered > local_thresh
    else:
        raise ValueError(f"Unsupported spot method: {method}")

    # --- Morphological cleanup ---
    binary = binary_opening(binary, disk(1))
    binary = remove_small_objects(binary, min_size=settings['organelle_min_size'])

    # --- Watershed to split touching spots ---
    if use_watershed:
        labeled = _watershed_split(binary, filtered)
    else:
        labeled = sk_label(binary)
    return labeled


def _spots_log(img, settings, use_watershed):
    """LoG blob detection -> marker-seeded watershed."""
    min_s = settings['organelle_log_min_sigma']
    max_s = settings['organelle_log_max_sigma']
    num_s = settings['organelle_log_num_sigma']
    thresh = settings['organelle_log_threshold']

    # Detection runs on a [0, 1]-normalised copy (see _normalize_01).
    img_norm = _normalize_01(img)
    blobs = blob_log(img_norm, min_sigma=min_s, max_sigma=max_s,
                     num_sigma=num_s, threshold=thresh)
    if len(blobs) == 0:
        # No detections: return an empty label image.
        return np.zeros(img.shape, dtype=np.int32)
    return _blobs_to_labels(blobs, img_norm, use_watershed)


def _spots_dog(img, settings, use_watershed):
    """
    Difference-of-Gaussians blob detection -> marker-seeded watershed.
Faster than LoG, nearly equivalent accuracy for fluorescence microscopy. """ sigma_low = settings.get('organelle_dog_sigma_low', 1.0) sigma_high = settings.get('organelle_dog_sigma_high', 3.0) thresh = settings['organelle_log_threshold'] img_norm = _normalize_01(img) blobs = blob_dog(img_norm, min_sigma=sigma_low, max_sigma=sigma_high, threshold=thresh) if len(blobs) == 0: return np.zeros(img.shape, dtype=np.int32) return _blobs_to_labels(blobs, img_norm, use_watershed) def _blobs_to_labels(blobs, img_norm, use_watershed): """ Convert blob coordinates (y, x, sigma) to a label image. Shared by LoG and DoG methods. """ shape = img_norm.shape markers = np.zeros(shape, dtype=np.int32) for i, (y, x, sigma) in enumerate(blobs, start=1): y, x = int(round(y)), int(round(x)) if 0 <= y < shape[0] and 0 <= x < shape[1]: markers[y, x] = i if not use_watershed: labeled = np.zeros(shape, dtype=np.int32) for i, (y, x, sigma) in enumerate(blobs, start=1): rr, cc = _circle_coords(int(round(y)), int(round(x)), max(int(round(sigma * np.sqrt(2))), 1), shape) labeled[rr, cc] = i return labeled smooth = gaussian(img_norm, sigma=1) labeled = watershed(-smooth, markers, mask=(smooth > np.percentile(smooth, 20))) return labeled def _circle_coords(cy, cx, radius, shape): """Return (row, col) arrays for a filled circle clipped to shape.""" yy, xx = np.ogrid[-radius:radius + 1, -radius:radius + 1] circle = yy ** 2 + xx ** 2 <= radius ** 2 rows = np.clip(cy + np.where(circle)[0] - radius, 0, shape[0] - 1) cols = np.clip(cx + np.where(circle)[1] - radius, 0, shape[1] - 1) return rows, cols # ====================================================================== # # NETWORK segmentation # ====================================================================== # def _segment_network(img, method, settings): """ Segment filamentous / reticular organelles. 
Strategies ---------- 'otsu' : Gaussian smooth -> Otsu -> morphological cleanup 'adaptive' : Gaussian smooth -> adaptive threshold -> cleanup 'ridge' : Ridge filter (Frangi / Sato / Meijering) -> threshold -> label 'hysteresis' : Gaussian smooth -> dual-threshold hysteresis -> cleanup """ if method == 'ridge': return _network_ridge(img, settings) elif method == 'hysteresis': return _network_hysteresis(img, settings) smooth = gaussian(img, sigma=1) if method == 'otsu': thresh_val = threshold_otsu(smooth) binary = smooth > thresh_val elif method == 'adaptive': block = settings['organelle_adaptive_block_size'] offset = settings['organelle_adaptive_offset'] local_thresh = threshold_local(smooth, block_size=block, offset=offset) binary = smooth > local_thresh else: raise ValueError(f"Unsupported network method: {method}") morph_r = max(settings['organelle_morph_radius'] // 2, 1) binary = binary_closing(binary, disk(morph_r)) binary = remove_small_objects(binary, min_size=settings['organelle_min_size']) if settings['organelle_skeletonize']: skeleton = skeletonize(binary) skeleton = binary_dilation(skeleton, disk(1)) return sk_label(skeleton) return sk_label(binary) def _network_ridge(img, settings): """Apply a ridge (tubeness) filter then threshold.""" sigmas = settings['organelle_ridge_sigmas'] filter_name = settings['organelle_ridge_filter'] thresh_method = settings['organelle_network_threshold'] img_norm = _normalize_01(img) ridge_filters = { 'frangi': frangi, 'sato': sato, 'meijering': meijering, } if filter_name not in ridge_filters: raise ValueError( f"organelle_ridge_filter must be one of {list(ridge_filters.keys())}, " f"got '{filter_name}'" ) enhanced = ridge_filters[filter_name](img_norm, sigmas=sigmas, black_ridges=False) if thresh_method == 'otsu': t = threshold_otsu(enhanced) binary = enhanced > t elif thresh_method == 'adaptive': block = settings['organelle_adaptive_block_size'] offset = settings['organelle_adaptive_offset'] local_t = 
threshold_local(enhanced, block_size=block, offset=offset) binary = enhanced > local_t else: t = threshold_otsu(enhanced) binary = enhanced > t binary = binary_closing(binary, disk(1)) binary = remove_small_objects(binary, min_size=settings['organelle_min_size']) if settings['organelle_skeletonize']: skeleton = skeletonize(binary) skeleton = binary_dilation(skeleton, disk(1)) return sk_label(skeleton) return sk_label(binary) def _network_hysteresis(img, settings): """ Dual-threshold hysteresis segmentation for networks. Uses a high threshold to seed confident regions and a low threshold to extend into dimmer connected filaments. Much better than single- threshold approaches for variable-contrast mitochondria and microtubules. Settings -------- organelle_hysteresis_low : float Low threshold. If <1.0, interpreted as a percentile of the image. organelle_hysteresis_high : float High threshold. If <1.0, interpreted as a percentile of the image. """ low = settings['organelle_hysteresis_low'] high = settings['organelle_hysteresis_high'] smooth = gaussian(img, sigma=1) # Interpret values <1.0 as percentiles if low < 1.0: low = np.percentile(smooth, low * 100) if high < 1.0: high = np.percentile(smooth, high * 100) binary = apply_hysteresis_threshold(smooth, low, high) morph_r = max(settings['organelle_morph_radius'] // 2, 1) binary = binary_closing(binary, disk(morph_r)) binary = remove_small_objects(binary, min_size=settings['organelle_min_size']) if settings['organelle_skeletonize']: skeleton = skeletonize(binary) skeleton = binary_dilation(skeleton, disk(1)) return sk_label(skeleton) return sk_label(binary) # ====================================================================== # # IRREGULAR segmentation # ====================================================================== # def _segment_irregular(img, method, settings): """ Segment irregularly shaped organelles (Golgi, ER cisternae, lysosomes). 
    Strategies
    ----------
    'otsu'     : Gaussian smooth -> Otsu -> morph open/close -> fill holes -> label
    'adaptive' : Gaussian smooth -> adaptive threshold -> morph cleanup -> label
    """
    morph_r = settings['organelle_morph_radius']
    fill_area = settings['organelle_fill_holes']

    # Smoothing scale tracks the morphology radius (sigma at least 1).
    smooth = gaussian(img, sigma=max(morph_r / 2, 1))

    if method == 'otsu':
        thresh_val = threshold_otsu(smooth)
        binary = smooth > thresh_val
    elif method == 'adaptive':
        block = settings['organelle_adaptive_block_size']
        offset = settings['organelle_adaptive_offset']
        local_thresh = threshold_local(smooth, block_size=block, offset=offset)
        binary = smooth > local_thresh
    else:
        raise ValueError(f"Unsupported irregular method: {method}")

    # Close then open with the same structuring element: bridges small
    # gaps first, then removes thin protrusions and noise.
    selem = disk(morph_r)
    binary = binary_closing(binary, selem)
    binary = binary_opening(binary, selem)

    # fill_area <= 0 disables hole filling.
    if fill_area > 0:
        binary = remove_small_holes(binary, area_threshold=fill_area)

    binary = remove_small_objects(binary, min_size=settings['organelle_min_size'])

    # Split touching objects with a distance-transform watershed.
    labeled = _watershed_split(binary, smooth)
    return labeled


# ====================================================================== #
# RING segmentation
# ====================================================================== #

def _segment_ring(img, method, settings):
    """
    Segment hollow / ring-shaped organelles (endosomes, autophagosomes,
    late endosomes/lysosomes).

    Strategy
    --------
    1. Enhance ring edges using Difference of Gaussians.
    2. Threshold edges using the specified method.
    3. Fill enclosed rings to produce solid instance masks.
    4. Filter out objects that lack ring morphology (no bright-dim
       boundary-interior contrast).

    Settings
    --------
    organelle_ring_sigma_inner : float
        Inner sigma for DoG edge enhancement. Default: 1.0.
    organelle_ring_sigma_outer : float
        Outer sigma for DoG edge enhancement. Default: 3.0.
    organelle_ring_min_prominence : float
        Minimum edge-vs-interior intensity contrast (normalised).
        Objects below this are removed. Default: 0.1.
    organelle_ring_fill_method : str
        'flood' (fill enclosed holes) or 'convex' (convex hull per
        component). Default: 'flood'.
    """
    sigma_inner = settings.get('organelle_ring_sigma_inner', 1.0)
    sigma_outer = settings.get('organelle_ring_sigma_outer', 3.0)
    min_prominence = settings.get('organelle_ring_min_prominence', 0.1)
    fill_method = settings.get('organelle_ring_fill_method', 'flood')

    # Step 1: Enhance ring structures using DoG (edge enhancement)
    img_norm = _normalize_01(img)
    enhanced = np.abs(difference_of_gaussians(img_norm, sigma_inner, sigma_outer))

    # Step 2: Threshold the enhanced image
    if method == 'otsu':
        thresh_val = threshold_otsu(enhanced)
        binary_edges = enhanced > thresh_val
    elif method == 'adaptive':
        block = settings['organelle_adaptive_block_size']
        offset = settings['organelle_adaptive_offset']
        local_thresh = threshold_local(enhanced, block_size=block, offset=offset)
        binary_edges = enhanced > local_thresh
    elif method == 'log':
        # LoG acts only as a presence check here: return empty when no
        # blobs are found, otherwise fall back to Otsu on the DoG image.
        blobs = blob_log(img_norm,
                         min_sigma=settings['organelle_log_min_sigma'],
                         max_sigma=settings['organelle_log_max_sigma'],
                         num_sigma=settings['organelle_log_num_sigma'],
                         threshold=settings['organelle_log_threshold'])
        if len(blobs) == 0:
            return np.zeros(img.shape, dtype=np.int32)
        thresh_val = threshold_otsu(enhanced)
        binary_edges = enhanced > thresh_val
    elif method == 'dog':
        # DoG enhancement already happened in step 1; just Otsu it.
        thresh_val = threshold_otsu(enhanced)
        binary_edges = enhanced > thresh_val
    else:
        raise ValueError(f"Unsupported ring method: {method}")

    # Cleanup edges. Size floor is organelle_min_size // 4, never below 3.
    binary_edges = binary_closing(binary_edges, disk(1))
    binary_edges = remove_small_objects(binary_edges,
                                        min_size=max(settings['organelle_min_size'] // 4, 3))

    # Step 3: Fill rings to get solid objects (unknown fill methods fall
    # back to 'flood').
    if fill_method == 'flood':
        filled = _fill_rings_flood(binary_edges)
    elif fill_method == 'convex':
        filled = _fill_rings_convex(binary_edges)
    else:
        filled = _fill_rings_flood(binary_edges)

    # Step 4: Remove objects that lack ring morphology
    labeled = sk_label(filled)
    labeled = _filter_non_rings(labeled, binary_edges,
                                img_norm, min_prominence)
    return labeled


def _fill_rings_flood(binary_edges):
    """
    Fill ring interiors using flood-fill logic.

    Inverts the edge mask, labels connected components, and identifies
    interior regions (not touching the image border) as filled objects.
    """
    inverted = ~binary_edges
    labeled_bg = sk_label(inverted)

    # Collect the labels of every background component touching the border.
    border_labels = set()
    border_labels.update(labeled_bg[0, :].ravel())
    border_labels.update(labeled_bg[-1, :].ravel())
    border_labels.update(labeled_bg[:, 0].ravel())
    border_labels.update(labeled_bg[:, -1].ravel())

    # Any background component NOT touching the border is an enclosed
    # interior — mark it as part of the filled object.
    filled = binary_edges.copy()
    for region in regionprops(labeled_bg):
        if region.label not in border_labels:
            filled[labeled_bg == region.label] = True
    return filled


def _fill_rings_convex(binary_edges):
    """
    Fill rings using the convex hull of each connected edge component.
    """
    from skimage.morphology import convex_hull_image

    labeled_edges = sk_label(binary_edges)
    filled = np.zeros_like(binary_edges)
    # Hull is computed inside each component's bounding box only, then
    # OR-ed back into the full-size mask.
    for region in regionprops(labeled_edges):
        minr, minc, maxr, maxc = region.bbox
        component = labeled_edges[minr:maxr, minc:maxc] == region.label
        hull = convex_hull_image(component)
        filled[minr:maxr, minc:maxc] |= hull
    return filled


def _filter_non_rings(labeled, binary_edges, img_norm, min_prominence):
    """
    Remove objects that lack ring morphology.

    A ring has a bright boundary with a dimmer interior. We measure this
    as the absolute difference between the mean edge intensity and the
    mean interior intensity, normalised by the object's mean intensity.
""" props = regionprops(labeled, intensity_image=img_norm) output = labeled.copy() for prop in props: mask = labeled == prop.label edge_mask = mask & binary_edges interior_mask = mask & ~binary_edges if np.sum(edge_mask) == 0 or np.sum(interior_mask) == 0: edge_ratio = np.sum(edge_mask) / max(np.sum(mask), 1) if edge_ratio < 0.3: output[mask] = 0 continue mean_edge = img_norm[edge_mask].mean() mean_interior = img_norm[interior_mask].mean() object_mean = img_norm[mask].mean() if object_mean > 0: prominence = abs(mean_edge - mean_interior) / object_mean else: prominence = 0 if prominence < min_prominence: output[mask] = 0 return sk_label(output > 0) # ====================================================================== # # Shared helpers # ====================================================================== # def _normalize_01(img): """Percentile-based normalisation to [0, 1].""" img_norm = img.astype(np.float64) pmin, pmax = np.percentile(img_norm, (1, 99)) if pmax - pmin > 0: img_norm = np.clip((img_norm - pmin) / (pmax - pmin), 0, 1) else: img_norm = np.zeros_like(img_norm) return img_norm def _watershed_split(binary, intensity): """ Marker-controlled watershed to separate touching objects in a binary mask. Uses distance transform peaks as markers. """ distance = distance_transform_edt(binary) coords = peak_local_max(distance, min_distance=5, labels=binary) if len(coords) == 0: return sk_label(binary) markers = np.zeros(binary.shape, dtype=np.int32) for i, (r, c) in enumerate(coords, start=1): markers[r, c] = i labeled = watershed(-distance, markers, mask=binary) return labeled def _postprocess_masks(masks, min_size=10, max_size=None, remove_border=False): """ Apply size filtering and optional border removal to a list of label masks. Returns ------- list of np.ndarray Cleaned label arrays. 
""" processed = [] for mask in masks: mask = mask.copy() if remove_border: border_labels = set() border_labels.update(mask[0, :].ravel()) border_labels.update(mask[-1, :].ravel()) border_labels.update(mask[:, 0].ravel()) border_labels.update(mask[:, -1].ravel()) border_labels.discard(0) for lbl in border_labels: mask[mask == lbl] = 0 if min_size > 0 or max_size is not None: props = regionprops(mask) for prop in props: if prop.area < min_size: mask[mask == prop.label] = 0 elif max_size is not None and prop.area > max_size: mask[mask == prop.label] = 0 mask = sk_label(mask > 0) processed.append(mask) return processed