Source code for spacr.timelapse

import cv2, os, re, glob, random, btrack, sqlite3
import numpy as np
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import matplotlib as mpl
from IPython.display import display
from IPython.display import Image as ipyimage
import trackpy as tp
from btrack import datasets as btrack_datasets
from skimage.measure import regionprops_table
from scipy.signal import find_peaks
from scipy.optimize import curve_fit, linear_sum_assignment
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split

from multiprocessing import Pool, cpu_count
import logging

try:
    from numpy import trapz
except ImportError:
    from scipy.integrate import trapz
    
import matplotlib.pyplot as plt

import logging
from spacr.utils import debug


def _npz_to_movie(arrays, filenames, save_path, fps=10):
    """
    Convert a list of numpy arrays to a movie file.

    Args:
        arrays (List[np.ndarray]): List of numpy arrays representing frames of the movie.
        filenames (List[str]): List of filenames corresponding to each frame.
        save_path (str): Path to save the movie file.
        fps (int, optional): Frames per second of the movie. Defaults to 10.

    Returns:
        None
    """
    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    if save_path.endswith('.mp4'):
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    # Initialize VideoWriter with the size of the first image
    height, width = arrays[0].shape[:2]
    out = cv2.VideoWriter(save_path, fourcc, fps, (width, height))

    for i, frame in enumerate(arrays):
        # Handle float32 images by scaling or normalizing
        if frame.dtype == np.float32:
            frame = np.clip(frame, 0, 1)
            frame = (frame * 255).astype(np.uint8)

        # Convert 16-bit image to 8-bit
        elif frame.dtype == np.uint16:
            frame = cv2.convertScaleAbs(frame, alpha=(255.0/65535.0))

        # Handling 1-channel (grayscale) or 2-channel images
        if frame.ndim == 2 or (frame.ndim == 3 and frame.shape[2] in [1, 2]):
            if frame.ndim == 2 or frame.shape[2] == 1:
                # Convert grayscale to RGB
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
            elif frame.shape[2] == 2:
                # Create an RGB image with the first channel as red, second as green, blue set to zero
                rgb_frame = np.zeros((height, width, 3), dtype=np.uint8)
                rgb_frame[..., 0] = frame[..., 0]  # Red channel
                rgb_frame[..., 1] = frame[..., 1]  # Green channel
                frame = rgb_frame

        # For 3-channel images, ensure it's in BGR format for OpenCV
        elif frame.shape[2] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        # Add filenames as text on frames
        cv2.putText(frame, filenames[i], (10, height - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

        out.write(frame)

    out.release()
    print(f"Movie saved to {save_path}")
    
def _scmovie(folder_paths):
        """
        Generate movies from a collection of PNG images in the given folder paths.

        Args:
            folder_paths (list): List of folder paths containing PNG images.

        Returns:
            None
        """
        folder_paths = list(set(folder_paths))
        for folder_path in folder_paths:
            movie_path = os.path.join(folder_path, 'movies')
            os.makedirs(movie_path, exist_ok=True)
            # Regular expression to parse the filename
            filename_regex = re.compile(r'(\w+)_(\w+)_(\w+)_(\d+)_(\d+).png')
            # Dictionary to hold lists of images by plate, well, field, and object number
            grouped_images = defaultdict(list)
            # Iterate over all PNG files in the folder
            for filename in os.listdir(folder_path):
                if filename.endswith('.png'):
                    match = filename_regex.match(filename)
                    if match:
                        plate, well, field, time, object_number = match.groups()
                        key = (plate, well, field, object_number)
                        grouped_images[key].append((int(time), os.path.join(folder_path, filename)))
            for key, images in grouped_images.items():
                # Sort images by time using sorted and lambda function for custom sort key
                images = sorted(images, key=lambda x: x[0])
                _, image_paths = zip(*images)
                # Determine the size to which all images should be padded
                max_height = max_width = 0
                for image_path in image_paths:
                    image = cv2.imread(image_path)
                    h, w, _ = image.shape
                    max_height, max_width = max(max_height, h), max(max_width, w)
                # Initialize VideoWriter
                plate, well, field, object_number = key
                output_filename = f"{plate}_{well}_{field}_{object_number}.mp4"
                output_path = os.path.join(movie_path, output_filename)
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                video = cv2.VideoWriter(output_path, fourcc, 10, (max_width, max_height))
                # Process each image
                for image_path in image_paths:
                    image = cv2.imread(image_path)
                    h, w, _ = image.shape
                    padded_image = np.zeros((max_height, max_width, 3), dtype=np.uint8)
                    padded_image[:h, :w, :] = image
                    video.write(padded_image)
                video.release()
                
                
def _sort_key(file_path):
    """
    Returns a sort key for the given file path based on the pattern '(\d+)_([A-Z]\d+)_(\d+)_(\d+).npy'.
    The sort key is a tuple containing the plate, well, field, and time values extracted from the file path.
    If the file path does not match the pattern, a default sort key is returned to sort the file as "earliest" or "lowest".

    Args:
        file_path (str): The file path to extract the sort key from.

    Returns:
        tuple: The sort key tuple containing the plate, well, field, and time values.
    """
    match = re.search(r'(\d+)_([A-Z]\d+)_(\d+)_(\d+).npy', os.path.basename(file_path))
    if match:
        plate, well, field, time = match.groups()
        # Assuming plate, well, and field are to be returned as is and time converted to int for sorting
        return (plate, well, field, int(time))
    else:
        # Return a tuple that sorts this file as "earliest" or "lowest"
        return ('', '', '', 0)

def _masks_to_gif(masks, gif_folder, name, filenames, object_type):
    """
    Converts a sequence of masks into a GIF file.

    Args:
        masks (list): List of masks representing the sequence.
        gif_folder (str): Path to the folder where the GIF file will be saved.
        name (str): Name of the GIF file.
        filenames (list): List of filenames corresponding to each mask in the sequence.
        object_type (str): Type of object represented by the masks.

    Returns:
        None
    """

    from .io import _save_mask_timelapse_as_gif

    def _display_gif(path):
        with open(path, 'rb') as file:
            display(ipyimage(file.read()))

    highest_label = max(np.max(mask) for mask in masks)
    random_colors = np.random.rand(highest_label + 1, 4)
    random_colors[:, 3] = 1  # Full opacity
    random_colors[0] = [0, 0, 0, 1]  # Background color
    cmap = plt.cm.colors.ListedColormap(random_colors)
    norm = plt.cm.colors.Normalize(vmin=0, vmax=highest_label)

    save_path_gif = os.path.join(gif_folder, f'timelapse_masks_{object_type}_{name}.gif')
    _save_mask_timelapse_as_gif(masks, None, save_path_gif, cmap, norm, filenames)
    #_display_gif(save_path_gif)
    
def _timelapse_masks_to_gif(folder_path, mask_channels, object_types):
    """
    Converts a sequence of masks into a timelapse GIF file.

    Args:
        folder_path (str): The path to the folder containing the mask files.
        mask_channels (list): List of channel indices to extract masks from.
        object_types (list): List of object types corresponding to each mask channel.

    Returns:
        None
    """
    master_folder = os.path.dirname(folder_path)
    gif_folder = os.path.join(master_folder, 'movies', 'gif')
    os.makedirs(gif_folder, exist_ok=True)

    paths = glob.glob(os.path.join(folder_path, '*.npy'))
    paths.sort(key=_sort_key)

    organized_files = {}
    for file in paths:
        match = re.search(r'(\d+)_([A-Z]\d+)_(\d+)_\d+.npy', os.path.basename(file))
        if match:
            plate, well, field = match.groups()
            key = (plate, well, field)
            if key not in organized_files:
                organized_files[key] = []
            organized_files[key].append(file)

    for key, file_list in organized_files.items():
        # Generate the name for the GIF based on plate, well, field
        name = f'{key[0]}_{key[1]}_{key[2]}'
        save_path_gif = os.path.join(gif_folder, f'timelapse_masks_{name}.gif')

        for i, mask_channel in enumerate(mask_channels):
            object_type = object_types[i]
            # Initialize an empty list to store masks for the current object type
            mask_arrays = []

            for file in file_list:
                # Load only the current time series array
                array = np.load(file)
                # Append the specific channel mask to the mask_arrays list
                mask_arrays.append(array[:, :, mask_channel])

            # Convert mask_arrays list to a numpy array for processing
            mask_arrays_np = np.array(mask_arrays)
            # Generate filenames for each frame in the time series
            filenames = [os.path.basename(f) for f in file_list]
            # Create the GIF for the current time series and object type
            _masks_to_gif(mask_arrays_np, gif_folder, name, filenames, object_type)
            
def _relabel_masks_based_on_tracks(masks, tracks, mode='btrack'):
    """
    Relabels the masks based on the tracks DataFrame.

    Args:
        masks (ndarray): Input masks array with shape (num_frames, height, width).
        tracks (DataFrame): DataFrame containing track information.
        mode (str, optional): Mode for relabeling. Defaults to 'btrack'.

    Returns:
        ndarray: Relabeled masks array with the same shape and dtype as the input masks.
    """
    # Initialize an array to hold the relabeled masks with the same shape and dtype as the input masks
    relabeled_masks = np.zeros(masks.shape, dtype=masks.dtype)

    # Iterate through each frame
    for frame_number in range(masks.shape[0]):
        # Extract the mapping for the current frame from the tracks DataFrame
        frame_tracks = tracks[tracks['frame'] == frame_number]
        mapping = dict(zip(frame_tracks['original_label'], frame_tracks['track_id']))
        current_mask = masks[frame_number, :, :]

        # Apply the mapping to the current mask
        for original_label, new_label in mapping.items():
            # Where the current mask equals the original label, set it to the new label value
            relabeled_masks[frame_number][current_mask == original_label] = new_label

    return relabeled_masks

def _prepare_for_tracking(mask_array):
    frames = []
    for t, frame in enumerate(mask_array):
        props = regionprops_table(
            frame,
            properties=('label', 'centroid', 'area', 'bbox', 'eccentricity')
        )
        df = pd.DataFrame(props)
        df = df.rename(columns={
            'centroid-0': 'y',
            'centroid-1': 'x',
            'area':       'mass',
            'label':      'original_label'
        })
        df['frame'] = t
        frames.append(df[['frame','y','x','mass','original_label',
                          'bbox-0','bbox-1','bbox-2','bbox-3','eccentricity']])
    return pd.concat(frames, ignore_index=True)

def _track_by_iou(masks, iou_threshold=0.1):
    """
    Build a track table by linking masks frame→frame via IoU.
    Returns a DataFrame with columns [frame, original_label, track_id].
    """
    n_frames = masks.shape[0]
    # 1) initialize: every label in frame 0 starts its own track
    labels0 = np.unique(masks[0])[1:]
    next_track = 1
    track_map = {}  # (frame,label) -> track_id
    for L in labels0:
        track_map[(0, L)] = next_track
        next_track += 1

    # 2) iterate through frames
    for t in range(1, n_frames):
        prev, curr = masks[t-1], masks[t]
        matches = link_by_iou(prev, curr, iou_threshold=iou_threshold)
        used_curr = set()
        # a) assign matched labels to existing tracks
        for L_prev, L_curr in matches:
            tid = track_map[(t-1, L_prev)]
            track_map[(t, L_curr)] = tid
            used_curr.add(L_curr)
        # b) any label in curr not matched → new track
        for L in np.unique(curr)[1:]:
            if L not in used_curr:
                track_map[(t, L)] = next_track
                next_track += 1

    # 3) flatten into DataFrame
    records = []
    for (frame, label), tid in track_map.items():
        records.append({'frame': frame, 'original_label': label, 'track_id': tid})
    return pd.DataFrame(records)




def _find_optimal_search_range(features, initial_search_range=500, increment=10, max_attempts=49, memory=3):
    """
    Find the optimal search range for linking features.

    Args:
        features (list): List of features to be linked.
        initial_search_range (int, optional): Initial search range. Defaults to 500.
        increment (int, optional): Increment value for reducing the search range. Defaults to 10.
        max_attempts (int, optional): Maximum number of attempts to find the optimal search range. Defaults to 49.
        memory (int, optional): Memory parameter for linking features. Defaults to 3.

    Returns:
        int: The optimal search range for linking features.
    """
    optimal_search_range = initial_search_range
    for attempt in range(max_attempts):
        try:
            # Attempt to link features with the current search range
            tracks_df = tp.link(features, search_range=optimal_search_range, memory=memory)
            print(f"Success with search_range={optimal_search_range}")
            return optimal_search_range
        except Exception as e:
            #print(f"SubnetOversizeException with search_range={optimal_search_range}: {e}")
            optimal_search_range -= increment
            print(f'Retrying with displacement value: {optimal_search_range}', end='\r', flush=True)
    min_range = initial_search_range-(max_attempts*increment)
    if optimal_search_range <= min_range:
        print(f'timelapse_displacement={optimal_search_range} is too high. Lower timelapse_displacement or set to None for automatic thresholding.')
    return optimal_search_range

def _remove_objects_from_first_frame(masks, percentage=10):
        """
        Removes a specified percentage of objects from the first frame of a sequence of masks.

        Parameters:
        masks (ndarray): Sequence of masks representing the frames.
        percentage (int): Percentage of objects to remove from the first frame.

        Returns:
        ndarray: Sequence of masks with objects removed from the first frame.
        """
        first_frame = masks[0]
        unique_labels = np.unique(first_frame[first_frame != 0])
        num_labels_to_remove = max(1, int(len(unique_labels) * (percentage / 100)))
        labels_to_remove = random.sample(list(unique_labels), num_labels_to_remove)

        for label in labels_to_remove:
            masks[0][first_frame == label] = 0
        return masks

def _track_by_iou(masks, iou_threshold=0.1):
    """
    Build a track table by linking masks frame→frame via IoU.
    Returns a DataFrame with columns [frame, original_label, track_id].
    """
    n_frames = masks.shape[0]
    # 1) initialize: every label in frame 0 starts its own track
    labels0 = np.unique(masks[0])[1:]
    next_track = 1
    track_map = {}  # (frame,label) -> track_id
    for L in labels0:
        track_map[(0, L)] = next_track
        next_track += 1

    # 2) iterate through frames
    for t in range(1, n_frames):
        prev, curr = masks[t-1], masks[t]
        matches = link_by_iou(prev, curr, iou_threshold=iou_threshold)
        used_curr = set()
        # a) assign matched labels to existing tracks
        for L_prev, L_curr in matches:
            tid = track_map[(t-1, L_prev)]
            track_map[(t, L_curr)] = tid
            used_curr.add(L_curr)
        # b) any label in curr not matched → new track
        for L in np.unique(curr)[1:]:
            if L not in used_curr:
                track_map[(t, L)] = next_track
                next_track += 1

    # 3) flatten into DataFrame
    records = []
    for (frame, label), tid in track_map.items():
        records.append({'frame': frame, 'original_label': label, 'track_id': tid})
    return pd.DataFrame(records)

def _facilitate_trackin_with_adaptive_removal(masks, search_range=None, max_attempts=5, memory=3, min_mass=50, track_by_iou=False):
    """
    Facilitates object tracking with deterministic initial filtering and
    trackpy’s constant-velocity prediction.

    Args:
        masks (np.ndarray): integer‐labeled masks (frames × H × W).
        search_range (int|None): max displacement; if None, auto‐computed.
        max_attempts (int): how many times to retry with smaller search_range.
        memory (int): trackpy memory parameter.
        min_mass (float): drop any object in frame 0 with area < min_mass.

    Returns:
        masks, features_df, tracks_df

    Raises:
        RuntimeError if linking fails after max_attempts.
    """
    # 1) initial features & filter frame 0 by area
    features = _prepare_for_tracking(masks)
    f0 = features[features['frame'] == 0]
    valid = f0.loc[f0['mass'] >= min_mass, 'original_label'].unique()
    masks[0] = np.where(np.isin(masks[0], valid), masks[0], 0)

    # 2) recompute features on filtered masks
    features = _prepare_for_tracking(masks)

    # 3) default search_range = 2×sqrt(99th‑pct area)
    if search_range is None:
        a99 = f0['mass'].quantile(0.99)
        search_range = max(1, int(2 * np.sqrt(a99)))

    # 4) attempt linking, shrinking search_range on failure
    for attempt in range(1, max_attempts + 1):
        try:
            if track_by_iou:
                tracks_df = _track_by_iou(masks, iou_threshold=0.1)
            else:
                tracks_df = tp.link_df(features,search_range=search_range, memory=memory, predict=True)
                print(f"Linked on attempt {attempt} with search_range={search_range}")
            return masks, features, tracks_df

        except Exception as e:
            search_range = max(1, int(search_range * 0.8))
            print(f"Attempt {attempt} failed ({e}); reducing search_range to {search_range}")

    raise RuntimeError(
        f"Failed to track after {max_attempts} attempts; last search_range={search_range}"
    )

def _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelapse_displacement, timelapse_memory, timelapse_remove_transient, plot, save, mode, track_by_iou):
        """
        Track cells using the Trackpy library.

        Args:
            src (str): The source file path.
            name (str): The name of the track.
            batch_filenames (list): List of batch filenames.
            object_type (str): The type of object to track.
            masks (list): List of masks.
            timelapse_displacement (int): The displacement for timelapse tracking.
            timelapse_memory (int): The memory for timelapse tracking.
            timelapse_remove_transient (bool): Whether to remove transient objects in timelapse tracking.
            plot (bool): Whether to plot the tracks.
            save (bool): Whether to save the tracks.
            mode (str): The mode of tracking.

        Returns:
            list: The mask stack.

        """
        
        from .plot import _visualize_and_save_timelapse_stack_with_tracks
        from .utils import _masks_to_masks_stack
        
        print(f'Tracking objects with trackpy')

        if timelapse_displacement is None:
            features = _prepare_for_tracking(masks)
            timelapse_displacement = _find_optimal_search_range(features, initial_search_range=500, increment=10, max_attempts=49, memory=3)
            if timelapse_displacement is None:
                timelapse_displacement = 50

        masks, features, tracks_df = _facilitate_trackin_with_adaptive_removal(masks, search_range=timelapse_displacement, max_attempts=100, memory=timelapse_memory, track_by_iou=track_by_iou)

        tracks_df['particle'] += 1

        if timelapse_remove_transient:
            tracks_df_filter = tp.filter_stubs(tracks_df, len(masks))
        else:
            tracks_df_filter = tracks_df.copy()

        tracks_df_filter = tracks_df_filter.rename(columns={'particle': 'track_id'})
        print(f'Removed {len(tracks_df)-len(tracks_df_filter)} objects that were not present in all frames')
        masks = _relabel_masks_based_on_tracks(masks, tracks_df_filter)
        tracks_path = os.path.join(os.path.dirname(src), 'tracks')
        os.makedirs(tracks_path, exist_ok=True)
        tracks_df_filter.to_csv(os.path.join(tracks_path, f'trackpy_tracks_{object_type}_{name}.csv'), index=False)
        if plot or save:
            _visualize_and_save_timelapse_stack_with_tracks(masks, tracks_df_filter, save, src, name, plot, batch_filenames, object_type, mode)

        mask_stack = _masks_to_masks_stack(masks)
        return mask_stack

def _filter_short_tracks(df, min_length=5):
    """Filter out tracks that are shorter than min_length.

    Args:
        df (pandas.DataFrame): The input DataFrame containing track information.
        min_length (int, optional): The minimum length of tracks to keep. Defaults to 5.

    Returns:
        pandas.DataFrame: The filtered DataFrame with only tracks longer than min_length.
    """
    track_lengths = df.groupby('track_id').size()
    long_tracks = track_lengths[track_lengths >= min_length].index
    return df[df['track_id'].isin(long_tracks)]

@debug(enabled=True)
def _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, masks_3D, mode, timelapse_remove_transient, radius=100, n_jobs=10, batch_list=None, optimizer_time_limit_s=120, optimizer_mip_gap=0.01, run_optimization=True, max_objects_for_optimization=20000):
    """
    Track cells using the btrack library.

    Args:
        src (str): The source file path.
        name (str): The name of the track (npz batch name).
        batch_filenames (list[str]): Filenames for frames in this batch.
        object_type (str): The type of object to track (cell, nucleus, pathogen).
        plot (bool): Whether to plot the tracks.
        save (bool): Whether to save plots.
        masks_3D (ndarray or list): 3D label array of masks with shape (T, Y, X),
            or list of 2D (Y, X) label arrays (one per frame).
        mode (str): The tracking mode (unused here but kept for API consistency).
        timelapse_remove_transient (bool): Whether to remove short tracks.
        radius (int or None, optional): Max search radius (pixels). If None,
            it is set automatically to image_width / 20. Defaults to 100.
        n_jobs (int, optional): Number of workers for object extraction. Defaults to 10.
        batch_list (list or None, optional): List of intensity images used by Cellpose.
            Currently not used by btrack (tracking is shape-based here), but kept
            for API compatibility and possible future use.
        optimizer_time_limit_s (float or None): Time limit for GLPK in seconds
            (used only when global optimisation is actually run).
        optimizer_mip_gap (float or None): Relative MIP gap for GLPK (0.01 = 1%).
        run_optimization (bool): If False, skip global optimisation entirely.
        max_objects_for_optimization (int or None): If not None, skip global
            optimisation when the number of objects exceeds this threshold.

    Returns:
        ndarray: The relabelled mask stack (same shape as masks_3D) where labels
        are track IDs.
    """
    import os
    import logging

    import numpy as np
    import pandas as pd
    import btrack
    from btrack import datasets as btrack_datasets
    from btrack.constants import BayesianUpdates

    from .plot import _visualize_and_save_timelapse_stack_with_tracks
    from .utils import _masks_to_masks_stack, _map_wells

    logger = logging.getLogger(__name__)

    logger.debug(
        "Entering _btrack_track_cells: name=%s, object_type=%s, mode=%s",
        name,
        object_type,
        mode,
    )
    logger.debug("src=%s", src)
    logger.debug(
        "timelapse_remove_transient=%s, radius=%s, n_jobs=%s, "
        "optimizer_time_limit_s=%s, optimizer_mip_gap=%s, run_optimization=%s, "
        "max_objects_for_optimization=%s",
        timelapse_remove_transient,
        radius,
        n_jobs,
        optimizer_time_limit_s,
        optimizer_mip_gap,
        run_optimization,
        max_objects_for_optimization,
    )
    logger.debug("masks_3D type: %s", type(masks_3D))

    # ------------------------------------------------------------------
    # Normalise masks_3D to a 3D ndarray (T, Y, X)
    # ------------------------------------------------------------------
    if isinstance(masks_3D, list):
        logger.debug("masks_3D is a list with length=%d", len(masks_3D))
        if len(masks_3D) == 0:
            raise ValueError("masks_3D is an empty list; nothing to track.")
        masks_3D = [np.asarray(m) for m in masks_3D]
        shapes = {m.shape for m in masks_3D}
        logger.debug("Unique mask shapes in list: %s", shapes)
        if len(shapes) != 1:
            raise ValueError(
                f"All masks must have the same shape; got shapes={shapes}."
            )
        masks_3D = np.stack(masks_3D, axis=0)
        logger.debug("Stacked masks_3D into ndarray with shape %s", masks_3D.shape)
    else:
        masks_3D = np.asarray(masks_3D)
        logger.debug("masks_3D array shape: %s", masks_3D.shape)

    if masks_3D.ndim != 3:
        raise ValueError(
            f"masks_3D must be 3D (T, Y, X); got shape {masks_3D.shape}"
        )

    n_frames, height, width = masks_3D.shape
    logger.debug(
        "Parsed geometry: n_frames=%d, height=%d, width=%d",
        n_frames,
        height,
        width,
    )

    # Auto radius if requested
    if radius is None:
        radius = max(1, width // 20)
        logger.debug(
            "radius was None; automatically set radius=%d (width/20)", radius
        )

    # ------------------------------------------------------------------
    # btrack configuration and feature definition
    # ------------------------------------------------------------------
    CONFIG_FILE = btrack_datasets.cell_config()
    
    # Shape-based features only (robust + what your config already expects)
    FEATURES = [
        "area",
        "major_axis_length",
        "minor_axis_length",
        "orientation",
        "solidity",
    ]
    TRACKING_UPDATES = ["motion", "visual"]

    # ------------------------------------------------------------------
    # Convert segmentation to btrack objects
    # ------------------------------------------------------------------
    logger.debug("Converting segmentation to btrack objects...")
    objects = btrack.utils.segmentation_to_objects(
        masks_3D,
        properties=tuple(FEATURES),
        num_workers=n_jobs,
    )
    n_objects = len(objects)
    logger.info("Extracted %d objects for tracking.", n_objects)

    # ------------------------------------------------------------------
    # Run the Bayesian tracker
    # ------------------------------------------------------------------
    with btrack.BayesianTracker() as tracker:
        tracker.configure(CONFIG_FILE)

        # Use APPROXIMATE updates for large datasets (recommended by btrack docs)
        tracker.update_method = BayesianUpdates.APPROXIMATE
        tracker.max_search_radius = radius

        # Features used by the visual model
        tracker.features = FEATURES

        # Append objects and define volume
        tracker.append(objects)
        tracker.volume = ((0, width), (0, height))
        logger.debug(
            "Tracker volume set to x=(0,%d), y=(0,%d); update_method=%s",
            width,
            height,
            tracker.update_method,
        )

        # Tracking
        logger.debug("Starting tracking...")
        try:
            tracker.track(tracking_updates=TRACKING_UPDATES)
        except TypeError:
            # Fallback for older btrack APIs
            logger.debug(
                "tracker.track(tracking_updates=...) not supported; "
                "falling back to tracker.tracking_updates + track(step_size=100)."
            )
            tracker.tracking_updates = [u.upper() for u in TRACKING_UPDATES]
            tracker.track(step_size=100)

        logger.info(
            "Tracking complete. Number of tracks before optimisation: %d",
            len(tracker.tracks),
        )

        # ------------------------------------------------------------------
        # Global optimisation (GLPK) – conditionally disabled for large problems
        # ------------------------------------------------------------------
        do_optimize = bool(run_optimization)

        if max_objects_for_optimization is not None and n_objects > max_objects_for_optimization:
            logger.warning(
                "Skipping btrack global optimisation: %d objects > "
                "max_objects_for_optimization=%d. Using pre-optimisation tracks.",
                n_objects,
                max_objects_for_optimization,
            )
            do_optimize = False

        if do_optimize and len(tracker.tracks) > 0:
            # Build GLPK options from user parameters
            glpk_options = {}
            if optimizer_time_limit_s is not None and optimizer_time_limit_s > 0:
                # GLPK tm_lim is in milliseconds
                glpk_options["tm_lim"] = int(optimizer_time_limit_s * 1000)
            if optimizer_mip_gap is not None and optimizer_mip_gap > 0:
                glpk_options["mip_gap"] = float(optimizer_mip_gap)

            try:
                if glpk_options:
                    logger.info(
                        "Running GLPK optimisation with options: %s", glpk_options
                    )
                    tracker.optimize(
                        backend="glpk",
                        options={"options": glpk_options},
                    )
                else:
                    logger.info("Running GLPK optimisation with default options.")
                    tracker.optimize(backend="glpk")

                logger.info(
                    "Optimisation complete. Number of tracks after optimisation: %d",
                    len(tracker.tracks),
                )
            except Exception as e:
                # If GLPK misbehaves, fall back to pre-optimisation tracks
                logger.warning(
                    "btrack global optimisation failed or stalled (%s). "
                    "Using pre-optimisation tracks instead.",
                    e,
                    exc_info=True,
                )

        # After this point, tracker.tracks always contains the tracks we will use
        tracks = tracker.tracks

    # ------------------------------------------------------------------
    # Convert tracks to DataFrame
    # ------------------------------------------------------------------
    track_data = []
    for track in tracks:
        for t, x, y, z in zip(track.t, track.x, track.y, track.z):
            track_data.append(
                {
                    "track_id": track.ID,
                    "frame": t,
                    "x": x,
                    "y": y,
                    "z": z,
                }
            )

    tracks_df = pd.DataFrame(track_data)
    logger.debug("tracks_df shape: %s", tracks_df.shape)

    # Optionally remove transient tracks (very short trajectories)
    if timelapse_remove_transient and not tracks_df.empty:
        logger.debug("Removing transient tracks with min_length=%d", n_frames)
        tracks_df = _filter_short_tracks(tracks_df, min_length=n_frames)
        logger.debug("tracks_df shape after filtering: %s", tracks_df.shape)

    # ------------------------------------------------------------------
    # Map track positions back to original labels
    # ------------------------------------------------------------------
    logger.debug("Preparing objects_df from masks_3D...")
    objects_df = _prepare_for_tracking(masks_3D)
    logger.debug("objects_df shape: %s", objects_df.shape)

    # Harmonise precision before merge
    tracks_df["x"] = tracks_df["x"].round(2)
    tracks_df["y"] = tracks_df["y"].round(2)
    objects_df["x"] = objects_df["x"].round(2)
    objects_df["y"] = objects_df["y"].round(2)

    logger.debug("Merging tracks_df and objects_df on ['frame', 'x', 'y']...")
    merged_df = pd.merge(
        tracks_df,
        objects_df,
        on=["frame", "x", "y"],
        how="inner",
    )
    logger.debug("merged_df shape: %s", merged_df.shape)

    final_df = merged_df[["track_id", "frame", "x", "y", "original_label"]]
    
    try:
        final_df['file_name'] = name
        final_df[['plateID', 'rowID', 'columnID', 'fieldID', 'prcf']] = (final_df['file_name'].apply(lambda fname: pd.Series(_map_wells(fname, timelapse=False))))
        final_df['wellID'] = final_df['file_name'].str.split('_').str[1]
        
    except IndexError:
        logger.warning("Failed to parse plate, well, field from name: %s", name)
    
    # ------------------------------------------------------------------
    # Relabel masks with track IDs
    # ------------------------------------------------------------------
    logger.debug("Relabelling masks based on tracks...")
    masks = _relabel_masks_based_on_tracks(masks_3D, final_df)

    # ------------------------------------------------------------------
    # Save track table
    # ------------------------------------------------------------------
    tracks_path = os.path.join(os.path.dirname(src), "tracks")
    os.makedirs(tracks_path, exist_ok=True)
    out_csv = os.path.join(tracks_path, f"btrack_tracks_{object_type}_{name}.csv")
    logger.debug("Saving track table to %s", out_csv)
    final_df.to_csv(out_csv, index=False)

    # ------------------------------------------------------------------
    # Optional visualisation
    # ------------------------------------------------------------------
    if plot or save:
        logger.debug("Generating visualisation (plot=%s, save=%s)...", plot, save)
        _visualize_and_save_timelapse_stack_with_tracks(
            masks,
            final_df,
            save,
            src,
            name,
            plot,
            batch_filenames,
            object_type,
            mode,
        )

    # Return in your standard mask stack format
    mask_stack = _masks_to_masks_stack(masks)
    logger.debug(
        "Finished _btrack_track_cells. mask_stack shape: %s",
        getattr(mask_stack, "shape", None),
    )
    return mask_stack


[docs] def exponential_decay(x, a, b, c): return a * np.exp(-b * x) + c
[docs] def preprocess_pathogen_data(pathogen_df): # Group by identifiers and count the number of parasites parasite_counts = pathogen_df.groupby(['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'pathogen_cell_id']).size().reset_index(name='parasite_count') # Aggregate numerical columns and take the first of object columns agg_funcs = {col: 'mean' if np.issubdtype(pathogen_df[col].dtype, np.number) else 'first' for col in pathogen_df.columns if col not in ['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'pathogen_cell_id', 'parasite_count']} pathogen_agg = pathogen_df.groupby(['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'pathogen_cell_id']).agg(agg_funcs).reset_index() # Merge the counts back into the aggregated data pathogen_agg = pathogen_agg.merge(parasite_counts, on=['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'pathogen_cell_id']) # Remove the object_label column as it corresponds to the pathogen ID not the cell ID if 'object_label' in pathogen_agg.columns: pathogen_agg.drop(columns=['object_label'], inplace=True) # Change the name of pathogen_cell_id to object_label pathogen_agg.rename(columns={'pathogen_cell_id': 'object_label'}, inplace=True) return pathogen_agg
[docs] def plot_data(measurement, group, ax, label, marker='o', linestyle='-'): ax.plot(group['time'], group['delta_' + measurement], marker=marker, linestyle=linestyle, label=label)
[docs] def infected_vs_noninfected(result_df, measurement): # Separate the merged dataframe into two groups based on pathogen_count infected_cells_df = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0] uninfected_cells_df = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0] # Plotting fig, axs = plt.subplots(2, 1, figsize=(12, 10), sharex=True) # Plot for cells that were infected at some time for group_id in infected_cells_df['plate_row_column_field_object'].unique(): group = infected_cells_df[infected_cells_df['plate_row_column_field_object'] == group_id] plot_data(measurement, group, axs[0], 'Infected', marker='x') # Plot for cells that were never infected for group_id in uninfected_cells_df['plate_row_column_field_object'].unique(): group = uninfected_cells_df[uninfected_cells_df['plate_row_column_field_object'] == group_id] plot_data(measurement, group, axs[1], 'Uninfected') # Set the titles and labels axs[0].set_title('Cells Infected at Some Time') axs[1].set_title('Cells Never Infected') for ax in axs: ax.set_xlabel('Time') ax.set_ylabel('Normalized Delta ' + measurement) all_timepoints = sorted(result_df['time'].unique()) ax.set_xticks(all_timepoints) ax.set_xticklabels(all_timepoints, rotation=45, ha="right") plt.tight_layout() plt.show()
[docs] def save_figure(fig, src, figure_number): source = os.path.dirname(src) results_fldr = os.path.join(source,'results') os.makedirs(results_fldr, exist_ok=True) fig_loc = os.path.join(results_fldr, f'figure_{figure_number}.pdf') fig.savefig(fig_loc) print(f'Saved figure:{fig_loc}')
[docs] def save_results_dataframe(df, src, results_name): source = os.path.dirname(src) results_fldr = os.path.join(source,'results') os.makedirs(results_fldr, exist_ok=True) csv_loc = os.path.join(results_fldr, f'{results_name}.csv') df.to_csv(csv_loc, index=True) print(f'Saved results:{csv_loc}')
[docs] def summarize_per_well(peak_details_df): # Step 1: Split the 'ID' column split_columns = peak_details_df['ID'].str.split('_', expand=True) peak_details_df[['plateID', 'rowID', 'columnID', 'fieldID', 'object_number']] = split_columns # Step 2: Create 'well_ID' by combining 'rowID' and 'columnID' peak_details_df['well_ID'] = peak_details_df['rowID'] + '_' + peak_details_df['columnID'] # Filter entries where 'amplitude' is not null filtered_df = peak_details_df[peak_details_df['amplitude'].notna()] # Preparation for Step 3: Identify numeric columns for averaging from the filtered dataframe numeric_cols = filtered_df.select_dtypes(include=['number']).columns # Step 3: Calculate summary statistics summary_df = filtered_df.groupby('well_ID').agg( peaks_per_well=('ID', 'size'), unique_IDs_with_amplitude=('ID', 'nunique'), # Count unique IDs per well with non-null amplitude **{col: (col, 'mean') for col in numeric_cols} # exclude 'amplitude' from averaging if it's numeric ).reset_index() # Step 3: Calculate summary statistics summary_df_2 = peak_details_df.groupby('well_ID').agg( cells_per_well=('object_number', 'nunique'), ).reset_index() summary_df['cells_per_well'] = summary_df_2['cells_per_well'] summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well'] return summary_df
[docs] def summarize_per_well_inf_non_inf(peak_details_df): # Step 1: Split the 'ID' column split_columns = peak_details_df['ID'].str.split('_', expand=True) peak_details_df[['plateID', 'rowID', 'columnID', 'fieldID', 'object_number']] = split_columns # Step 2: Create 'well_ID' by combining 'rowID' and 'columnID' peak_details_df['well_ID'] = peak_details_df['rowID'] + '_' + peak_details_df['columnID'] # Assume 'pathogen_count' indicates infection if > 0 # Add an 'infected_status' column to classify cells peak_details_df['infected_status'] = peak_details_df['infected'].apply(lambda x: 'infected' if x > 0 else 'non_infected') # Preparation for Step 3: Identify numeric columns for averaging numeric_cols = peak_details_df.select_dtypes(include=['number']).columns # Step 3: Calculate summary statistics summary_df = peak_details_df.groupby(['well_ID', 'infected_status']).agg( cells_per_well=('object_number', 'nunique'), peaks_per_well=('ID', 'size'), **{col: (col, 'mean') for col in numeric_cols} ).reset_index() # Calculate peaks per cell summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well'] return summary_df
[docs] def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9): # Load data conn = sqlite3.connect(db_loc) # Load cell table cell_df = pd.read_sql(f"SELECT * FROM {'cell'}", conn) if pathogen: pathogen_df = pd.read_sql("SELECT * FROM pathogen", conn) pathogen_df['pathogen_cell_id'] = pathogen_df['pathogen_cell_id'].astype(float).astype('Int64') pathogen_df = preprocess_pathogen_data(pathogen_df) cell_df = cell_df.merge(pathogen_df, on=['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'object_label'], how='left', suffixes=('', '_pathogen')) cell_df['parasite_count'] = cell_df['parasite_count'].fillna(0) print(f'After pathogen merge: {len(cell_df)} objects') # Optionally load cytoplasm table and merge if cytoplasm: cytoplasm_df = pd.read_sql(f"SELECT * FROM {'cytoplasm'}", conn) # Merge on specified columns cell_df = cell_df.merge(cytoplasm_df, on=['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'object_label'], how='left', suffixes=('', '_cytoplasm')) print(f'After cytoplasm merge: {len(cell_df)} objects') conn.close() # Continue with your existing processing on cell_df now containing merged data... # Prepare DataFrame (use cell_df instead of df) prcf_components = cell_df['prcf'].str.split('_', expand=True) cell_df['plateID'] = prcf_components[0] cell_df['rowID'] = prcf_components[1] cell_df['columnID'] = prcf_components[2] cell_df['fieldID'] = prcf_components[3] cell_df['time'] = prcf_components[4].str.extract('t(\d+)').astype(int) cell_df['object_number'] = cell_df['object_label'] cell_df['plate_row_column_field_object'] = cell_df['plateID'].astype(str) + '_' + cell_df['rowID'].astype(str) + '_' + cell_df['columnID'].astype(str) + '_' + cell_df['fieldID'].astype(str) + '_' + cell_df['object_label'].astype(str) df = cell_df.copy() # Fit exponential decay model to all scaled fluorescence data try: params, _ = curve_fit(exponential_decay, df['time'], df[measurement], p0=[max(df[measurement]), 0.01, min(df[measurement])], maxfev=10000) df['corrected_' + measurement] = df[measurement] / exponential_decay(df['time'], *params) except RuntimeError as e: print(f"Curve fitting failed for the entire dataset with error: {e}") return if verbose: print(f'Analyzing: {len(df)} objects') # Normalizing corrected fluorescence for each cell corrected_dfs = [] peak_details_list = [] total_timepoints = df['time'].nunique() size_filter_removed = 0 transience_removed = 0 for unique_id, group in df.groupby('plate_row_column_field_object'): group = group.sort_values('time') if remove_transient: threshold = int(transience_threshold * total_timepoints) if verbose: print(f'Group length: {len(group)} Timelapse length: {total_timepoints}, threshold:{threshold}') if len(group) <= threshold: transience_removed += 1 if verbose: print(f'removed group {unique_id} due to transience') continue size_diff = group[size_filter].std() / group[size_filter].mean() if size_diff <= fluctuation_threshold: group['delta_' + measurement] = group['corrected_' + measurement].diff().fillna(0) corrected_dfs.append(group) # Detect peaks peaks, properties = find_peaks(group['delta_' + measurement], height=peak_height) # Set values < 0 to 0 group_filtered = group.copy() group_filtered['delta_' + measurement] = group['delta_' + measurement].clip(lower=0) above_zero_auc = trapz(y=group_filtered['delta_' + measurement], x=group_filtered['time']) auc = trapz(y=group['delta_' + measurement], x=group_filtered['time']) is_infected = (group['parasite_count'] > 0).any() if is_infected: is_infected = 1 else: is_infected = 0 if len(peaks) == 0: peak_details_list.append({ 'ID': unique_id, 'plateID': group['plateID'].iloc[0], 'rowID': group['rowID'].iloc[0], 'columnID': group['columnID'].iloc[0], 'fieldID': group['fieldID'].iloc[0], 'object_number': group['object_number'].iloc[0], 'time': np.nan, # The time of the peak 'amplitude': np.nan, 'delta': np.nan, 'AUC': auc, 'AUC_positive': above_zero_auc, 'AUC_peak': np.nan, 'infected': is_infected }) # Inside the for loop where peaks are detected for i, peak in enumerate(peaks): amplitude = properties['peak_heights'][i] peak_time = group['time'].iloc[peak] pathogen_count_at_peak = group['parasite_count'].iloc[peak] start_idx = max(peak - 1, 0) end_idx = min(peak + 1, len(group) - 1) # Using indices to slice for AUC calculation peak_segment_y = group['delta_' + measurement].iloc[start_idx:end_idx + 1] peak_segment_x = group['time'].iloc[start_idx:end_idx + 1] peak_auc = trapz(y=peak_segment_y, x=peak_segment_x) peak_details_list.append({ 'ID': unique_id, 'plateID': group['plateID'].iloc[0], 'rowID': group['rowID'].iloc[0], 'columnID': group['columnID'].iloc[0], 'fieldID': group['fieldID'].iloc[0], 'object_number': group['object_number'].iloc[0], 'time': peak_time, # The time of the peak 'amplitude': amplitude, 'delta': group['delta_' + measurement].iloc[peak], 'AUC': auc, 'AUC_positive': above_zero_auc, 'AUC_peak': peak_auc, 'infected': pathogen_count_at_peak }) else: size_filter_removed += 1 if verbose: print(f'Removed {size_filter_removed} objects due to size filter fluctuation') print(f'Removed {transience_removed} objects due to transience') if len(corrected_dfs) > 0: result_df = pd.concat(corrected_dfs) else: print("No suitable cells found for analysis") return peak_details_df = pd.DataFrame(peak_details_list) summary_df = summarize_per_well(peak_details_df) summary_df_inf_non_inf = summarize_per_well_inf_non_inf(peak_details_df) save_results_dataframe(df=peak_details_df, src=db_loc, results_name='peak_details') save_results_dataframe(df=result_df, src=db_loc, results_name='results') save_results_dataframe(df=summary_df, src=db_loc, results_name='well_results') save_results_dataframe(df=summary_df_inf_non_inf, src=db_loc, results_name='well_results_inf_non_inf') # Plotting fig, ax = plt.subplots(figsize=(10, 8)) sampled_groups = result_df['plate_row_column_field_object'].unique() if num_lines is not None and 0 < num_lines < len(sampled_groups): sampled_groups = np.random.choice(sampled_groups, size=num_lines, replace=False) for group_id in sampled_groups: group = result_df[result_df['plate_row_column_field_object'] == group_id] ax.plot(group['time'], group['delta_' + measurement], marker='o', linestyle='-') ax.set_xticks(sorted(df['time'].unique())) ax.set_xticklabels(sorted(df['time'].unique()), rotation=45, ha="right") ax.set_title(f'Normalized Delta of {measurement} Over Time (Corrected for Photobleaching)') ax.set_xlabel('Time') ax.set_ylabel('Normalized Delta ' + measurement) plt.tight_layout() plt.show() save_figure(fig, src=db_loc, figure_number=1) if pathogen: infected_vs_noninfected(result_df, measurement) save_figure(fig, src=db_loc, figure_number=2) # Identify cells with and without pathogens infected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]['plate_row_column_field_object'].unique() noninfected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]['plate_row_column_field_object'].unique() # Peaks in infected and noninfected cells infected_peaks = peak_details_df[peak_details_df['ID'].isin(infected_cells)] noninfected_peaks = peak_details_df[peak_details_df['ID'].isin(noninfected_cells)] # Calculate the average number of peaks per cell avg_inf_peaks_per_cell = len(infected_peaks) / len(infected_cells) if len(infected_cells) > 0 else 0 avg_non_inf_peaks_per_cell = len(noninfected_peaks) / len(noninfected_cells) if len(noninfected_cells) > 0 else 0 print(f'Average number of peaks per infected cell: {avg_inf_peaks_per_cell:.2f}') print(f'Average number of peaks per non-infected cell: {avg_non_inf_peaks_per_cell:.2f}') print(f'done') return result_df, peak_details_df, fig
def _generate_mask_random_cmap(mask): """ Generate a random colormap based on the unique labels in the given mask. Parameters ---------- mask : ndarray 2D label mask. Background must be 0, objects > 0. Returns ------- mpl.colors.ListedColormap Random colormap with a fixed black background (label 0). """ unique_labels = np.unique(mask) # Only count non-zero labels as objects num_objects = np.sum(unique_labels != 0) # +1 so index 0 is background random_colors = np.random.rand(num_objects + 1, 4) random_colors[:, 3] = 1.0 # full alpha # background = black, fully opaque random_colors[0, :] = [0.0, 0.0, 0.0, 1.0] return mpl.colors.ListedColormap(random_colors)
[docs] def create_results_figure(): """ Create a Figure with 3 subplots arranged as: - PCA (top-left) - XGBoost (top-right) - Histogram (bottom spanning both columns) Returns ------- fig : Figure ax_pca, ax_xgb, ax_hist : matplotlib.axes.Axes """ fig = Figure(figsize=(7, 6), dpi=100) gs = fig.add_gridspec(2, 2, height_ratios=[2, 1]) ax_pca = fig.add_subplot(gs[0, 0]) ax_xgb = fig.add_subplot(gs[0, 1]) ax_hist = fig.add_subplot(gs[1, :]) return fig, ax_pca, ax_xgb, ax_hist
def _make_intensity_motility_panel( all_df, infection_col, track_df, per_well_tracks, n_channels, motility_dir, pixels_per_um, seconds_per_frame, vel_unit, settings, label_tag, ): """ Make panels for infection and motility. Behaviour: - "mask_*" label_tag: classic panel with * per-channel mean intensity (infected vs uninfected) * (optional) pathogen-channel p75 intensity bar * (optional) pathogen/cytoplasm intensity ratio bar * all-tracks motility plot (absolute FOV) * motility origin plots (infected / uninfected) * optional small QC image (feature importance PNG) - "adjusted_*" label_tag: same as mask panel, plus method-specific QC subplots appended: * histogram (if strategy == "histogram") * PCA/UMAP/t-SNE embedding (if strategy == "pca"/"umap"/"tsne") * XGBoost: - probability separation histogram - feature-importance barplot """ import os import numpy as np import matplotlib.pyplot as plt import matplotlib.image as mpimg # used for the small QC PNG in mask panel if all_df.empty or track_df.empty or not per_well_tracks: print(f"[_make_intensity_motility_panel] No data for panel '{label_tag}', skipping.") return os.makedirs(motility_dir, exist_ok=True) key_cols = ["plateID", "wellID", "fieldID", "cellID"] # ------------------------------------------------------------------ # Panel type / strategy / QC payload availability # ------------------------------------------------------------------ label_lower = str(label_tag).lower() is_mask_panel = label_lower.startswith("mask") is_adjusted_panel = label_lower.startswith("adjusted") qc_strategy = str(settings.get("infection_intensity_strategy", "none")).lower() method_label = qc_strategy if qc_strategy else "none" panel_label = "mask" if is_mask_panel else ("adjusted" if is_adjusted_panel else label_tag) qc_graphs_enabled = bool(settings.get("infection_intensity_qc_graphs", True)) qc_panel_type = settings.get("infection_intensity_qc_panel_type", None) qc_panel_path = settings.get("infection_intensity_qc_panel_path", None) # Global QC payloads (built in QC helpers) hist_data = settings.get("infection_hist_data", None) pca_data = settings.get("infection_pca_data", None) xgb_data = settings.get("infection_xgb_importance", None) has_hist = hist_data is not None has_pca = pca_data is not None has_xgb = xgb_data is not None # Mask panel: small embedded QC PNG if available qc_panel_needed_mask = ( is_mask_panel and qc_graphs_enabled and isinstance(qc_panel_path, str) and qc_panel_path and os.path.exists(qc_panel_path) ) # Adjusted panel: method-specific QC axes qc_axes_count = 0 if is_adjusted_panel and qc_graphs_enabled: # Histogram: we *always* allocate one QC axis and can compute from df_well if qc_strategy == "histogram": qc_axes_count = 1 # PCA/UMAP/t-SNE: need pca_data for embedding elif qc_strategy in {"pca", "umap", "tsne"} and has_pca: qc_axes_count = 1 # XGBoost: probability separation + feature importance elif qc_strategy == "xgboost" and has_xgb: qc_axes_count = 2 # Motility axis limits: driven by motility_xlim / motility_ylim origin_xlim = settings.get("motility_xlim", settings.get("motility_origin_xlim")) origin_ylim = settings.get("motility_ylim", settings.get("motility_origin_ylim")) # Coordinate scaling if pixels_per_um is not None and pixels_per_um > 0: coord_scale = 1.0 / float(pixels_per_um) coord_label_x = "x (µm)" coord_label_y = "y (µm)" else: coord_scale = 1.0 coord_label_x = "x (pixels)" coord_label_y = "y (pixels)" pathogen_chan = settings.get("pathogen_channel", None) # ------------------------------------------------------------------ # Helpers for QC subplots (used in adjusted panel) # ------------------------------------------------------------------ def _plot_hist_qc(ax, source): """ Draw infected vs uninfected intensity histogram. `source` can be: - a dict payload (settings['infection_hist_data']) with keys: 'intensities_inf', 'intensities_uninf', 'bin_edges', 'thr_val' (optional), 'intensity_col' (optional) - or a DataFrame (df_well), in which case the histogram is computed on the fly using a reasonable intensity column and `infection_intensity_n_bins`. """ try: # Case 1: payload dict from settings if isinstance(source, dict): intens_inf = np.asarray(source["intensities_inf"], dtype=float) intens_uninf = np.asarray(source["intensities_uninf"], dtype=float) bin_edges = np.asarray(source["bin_edges"], dtype=float) thr_val = float(source.get("thr_val", np.nan)) intensity_col = source.get("intensity_col", "intensity") else: # Case 2: compute from per-well DataFrame df_vals = source # Decide which intensity column to use intensity_col = settings.get("infection_hist_intensity_col", None) if not intensity_col or intensity_col not in df_vals.columns: cand_cols = [] if pathogen_chan is not None: cand_cols.extend( [ f"cell_mean_intensity_ch{pathogen_chan}", f"cell_p75_intensity_ch{pathogen_chan}", f"pathogen_mean_intensity_ch{pathogen_chan}", ] ) # Fallback: any cell_mean_intensity_ch* cand_cols.extend( [c for c in df_vals.columns if c.startswith("cell_mean_intensity_ch")] ) for c in cand_cols: if c in df_vals.columns: intensity_col = c break if not intensity_col or intensity_col not in df_vals.columns: ax.set_visible(False) return # Collapse to one value per cell-track cell_level = ( df_vals[key_cols + [intensity_col, infection_col]] .groupby(key_cols, dropna=False) .agg({intensity_col: "mean", infection_col: "max"}) .reset_index() ) cell_level = cell_level.replace([np.inf, -np.inf], np.nan) cell_level = cell_level.dropna(subset=[intensity_col]) if cell_level.empty: ax.set_visible(False) return mask_inf = cell_level[infection_col].astype(bool) intens_inf = cell_level.loc[mask_inf, intensity_col].to_numpy() intens_uninf = cell_level.loc[~mask_inf, intensity_col].to_numpy() if intens_inf.size == 0 and intens_uninf.size == 0: ax.set_visible(False) return all_vals = np.concatenate( [arr for arr in (intens_inf, intens_uninf) if arr.size] ) n_bins = int(settings.get("infection_intensity_n_bins", 64) or 64) bin_edges = np.histogram_bin_edges(all_vals, bins=n_bins) thr_val = settings.get( "infection_hist_thr_val", settings.get("infection_intensity_threshold", np.nan), ) thr_val = float(thr_val) if thr_val is not None else np.nan # Now plot ax.hist( intens_uninf, bins=bin_edges, alpha=0.5, color="green", label="Uninfected", ) ax.hist( intens_inf, bins=bin_edges, alpha=0.5, color="red", label="Infected", ) if np.isfinite(thr_val): ax.axvline(thr_val, color="black", linestyle="--", linewidth=1) ax.set_xlabel(intensity_col) ax.set_ylabel("Count") ax.set_title("Pathogen-channel intensity\n(adjusted labels)") ax.legend(fontsize=7) except Exception as e: print(f"[_make_intensity_motility_panel] Histogram payload invalid: {e}") ax.set_visible(False) def _plot_pca_qc(ax, pdata): import numpy as np try: coords = np.asarray(pdata["coords"], dtype=float) labels = np.asarray(pdata["labels"], dtype=bool) except Exception as e: print(f"[_make_intensity_motility_panel] PCA/embedding payload invalid: {e}") ax.set_visible(False) return if coords.ndim != 2 or coords.shape[1] < 2: ax.set_visible(False) return # Method label stored by _infection_qc_pca_clustering: 'PCA', 'UMAP', or 't-SNE' method_label = str(pdata.get("method_label", "PCA")) x = coords[:, 0] y = coords[:, 1] ax.scatter( x[~labels], y[~labels], s=5, alpha=0.4, color="green", label="Uninfected", ) ax.scatter( x[labels], y[labels], s=5, alpha=0.4, color="red", label="Infected", ) # Generic axis labels that respect the embedding method ax.set_xlabel(f"{method_label} 1") ax.set_ylabel(f"{method_label} 2") # Title also reflects method ax.set_title(f"{method_label} of features\n(adjusted labels)") ax.legend(fontsize=7) def _plot_xgb_importance_qc(ax, xdata): try: feat_names = xdata["feature_names"] feat_vals = xdata["feature_importances"] except Exception as e: print(f"[_make_intensity_motility_panel] XGB importance payload invalid: {e}") ax.set_visible(False) return if not feat_names: ax.set_visible(False) return y_pos = np.arange(len(feat_names)) ax.barh(y_pos, feat_vals) ax.set_yticks(y_pos) ax.set_yticklabels(feat_names, fontsize=7) ax.invert_yaxis() ax.set_xlabel("Importance (gain)") ax.set_title("XGBoost feature importance") def _plot_xgb_prob_qc(ax, df_prob): """ Per-cell probability distribution by adjusted infection label. Uses settings['infection_xgb_proba_column'] if available, otherwise falls back through a few common column names. """ # Resolve probability column prob_col_candidates = [] cfg_col = settings.get("infection_xgb_proba_column", None) if isinstance(cfg_col, str) and cfg_col: prob_col_candidates.append(cfg_col) prob_col_candidates.extend( [ "infection_prob", "infection_xgb_proba", "xgb_prob", ] ) prob_col = None for c in prob_col_candidates: if c in df_prob.columns: prob_col = c break if prob_col is None: ax.set_visible(False) return # per-cell probabilities cell_probs = ( df_prob[key_cols + [prob_col, infection_col]] .groupby(key_cols, dropna=False) .agg({prob_col: "mean", infection_col: "max"}) .reset_index() ) cell_probs = cell_probs.replace([np.inf, -np.inf], np.nan) cell_probs = cell_probs.dropna(subset=[prob_col]) if cell_probs.empty: ax.set_visible(False) return mask_inf = cell_probs[infection_col].astype(bool) probs_inf = cell_probs.loc[mask_inf, prob_col].to_numpy() probs_uninf = cell_probs.loc[~mask_inf, prob_col].to_numpy() bins = np.linspace(0.0, 1.0, 21) if probs_uninf.size: ax.hist( probs_uninf, bins=bins, alpha=0.5, color="green", label="Uninfected", ) if probs_inf.size: ax.hist( probs_inf, bins=bins, alpha=0.5, color="red", label="Infected", ) ax.set_xlabel("XGBoost infection probability") ax.set_ylabel("Cells") ax.set_title("Probability separation (adjusted labels)") ax.legend(fontsize=7) def _plot_inf_uninf_bar(ax, df_vals, value_col, title, ylabel): """ Helper to plot infected vs uninfected distributions for the given column, using violin plots (with mean markers) instead of barplots. """ if value_col not in df_vals.columns: ax.set_visible(False) return # Collapse to one value per cell-track cell_level = ( df_vals[key_cols + [value_col, infection_col]] .groupby(key_cols, dropna=False) .agg({value_col: "mean", infection_col: "max"}) .reset_index() ) cell_level = cell_level.replace([np.inf, -np.inf], np.nan) cell_level = cell_level.dropna(subset=[value_col]) if cell_level.empty: ax.set_visible(False) return mask_inf = cell_level[infection_col].astype(bool) vals_inf = cell_level.loc[mask_inf, value_col].to_numpy() vals_uninf = cell_level.loc[~mask_inf, value_col].to_numpy() data = [] positions = [] colors = [] labels_xtick = [] pos = 0 if vals_inf.size: data.append(vals_inf) positions.append(pos) colors.append("red") labels_xtick.append("Inf") pos += 1 if vals_uninf.size: data.append(vals_uninf) positions.append(pos) colors.append("green") labels_xtick.append("Uninf") if not data: ax.set_visible(False) return # Violin plots vp = ax.violinplot( data, positions=positions, widths=0.6, showmeans=False, showmedians=False, showextrema=False, ) # Color each violin (Inf = red, Uninf = green) for body, color in zip(vp["bodies"], colors): body.set_facecolor(color) body.set_edgecolor("black") body.set_alpha(0.6) # Overlay means as black points means = [float(np.nanmean(d)) for d in data] ax.scatter(positions, means, color="black", s=10, zorder=3) ax.set_xticks(positions) ax.set_xticklabels(labels_xtick) # If all values are non-negative, anchor at 0 flat = np.concatenate(data) if np.nanmin(flat) >= 0: ymin, ymax = ax.get_ylim() ax.set_ylim(bottom=0, top=ymax) ax.set_title(title) ax.set_ylabel(ylabel) # ------------------------------------------------------------------ # One figure per (plateID, wellID) # ------------------------------------------------------------------ if not {"plateID", "wellID"}.issubset(all_df.columns): print( "[_make_intensity_motility_panel] Missing 'plateID'/'wellID' columns; " "cannot make per-well panels." ) return unique_wells = ( all_df[["plateID", "wellID"]] .dropna() .drop_duplicates() .to_records(index=False) ) for plate_id, well_id in unique_wells: # Subset data for this well df_well = all_df[ (all_df["plateID"] == plate_id) & (all_df["wellID"] == well_id) ] track_df_well = track_df[ (track_df["plateID"] == plate_id) & (track_df["wellID"] == well_id) ] # Collect tracks for this well from per_well_tracks well_tracks = [] for tracks in per_well_tracks.values(): for tr in tracks: if tr.get("plateID") == plate_id and tr.get("wellID") == well_id: well_tracks.append(tr) if df_well.empty or track_df_well.empty or not well_tracks: print( f"[_make_intensity_motility_panel] No data for plate={plate_id}, " f"well={well_id}; skipping." ) continue # Determine which channels are available *for this well* available_channels = [ ch for ch in range(n_channels) if f"cell_mean_intensity_ch{ch}" in df_well.columns ] if not available_channels: print( f"[_make_intensity_motility_panel] No cell_mean_intensity_ch* " f"columns for plate={plate_id}, well={well_id}; skipping." ) continue # Extra intensity plots for pathogen channel has_p75_path = False has_rel_int = False if pathogen_chan is not None: p75_col = f"cell_p75_intensity_ch{pathogen_chan}" if p75_col in df_well.columns: has_p75_path = True path_col = f"pathogen_mean_intensity_ch{pathogen_chan}" cyto_col = f"cytoplasm_mean_intensity_ch{pathogen_chan}" if path_col in df_well.columns and cyto_col in df_well.columns: has_rel_int = True extra_int_plots = (1 if has_p75_path else 0) + (1 if has_rel_int else 0) n_int_plots = len(available_channels) + extra_int_plots # +3 for: all-tracks motility, infected origin, uninfected origin # +1 for small QC PNG in mask panel, # +qc_axes_count for adjusted panel QC subplots n_cols = n_int_plots + 3 + (1 if qc_panel_needed_mask else 0) + qc_axes_count fig, axes = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4)) if n_cols == 1: axes = np.array([axes]) else: axes = np.array(axes).ravel() axis_idx = 0 # ----- intensity violins per channel (per well) ----- for ch in available_channels: col_int = f"cell_mean_intensity_ch{ch}" ax = axes[axis_idx] axis_idx += 1 _plot_inf_uninf_bar( ax, df_well, value_col=col_int, title=f"Ch {ch} mean", ylabel="Mean cell intensity", ) # If this is the pathogen channel, append p75 and ratio plots if available if pathogen_chan is not None and ch == pathogen_chan: if has_p75_path: ax_p75 = axes[axis_idx] axis_idx += 1 p75_col = f"cell_p75_intensity_ch{pathogen_chan}" _plot_inf_uninf_bar( ax_p75, df_well, value_col=p75_col, title=f"Ch {pathogen_chan} p75", ylabel="Cell p75 intensity", ) if has_rel_int: ax_rel = axes[axis_idx] axis_idx += 1 path_col = f"pathogen_mean_intensity_ch{pathogen_chan}" cyto_col = f"cytoplasm_mean_intensity_ch{pathogen_chan}" df_ratio = df_well[ key_cols + [path_col, cyto_col, infection_col] ].copy() df_ratio["rel_intensity"] = df_ratio[path_col] / df_ratio[ cyto_col ].replace(0, np.nan) _plot_inf_uninf_bar( ax_rel, df_ratio, value_col="rel_intensity", title=f"Ch {pathogen_chan} pathogen/cytoplasm", ylabel="Intensity ratio", ) # ----- all-tracks FOV plot (absolute coordinates) ----- def _plot_all_tracks(ax): if not well_tracks: ax.set_visible(False) return xs_all = [] ys_all = [] n_inf_tr = 0 n_uninf_tr = 0 for tr in well_tracks: x_px = np.asarray(tr["x_px"], dtype=float) y_px = np.asarray(tr["y_px"], dtype=float) if x_px.size < 2: continue x = x_px * coord_scale y = y_px * coord_scale infected_tr = bool(tr.get("infected", False)) color = "red" if infected_tr else "green" ax.plot(x, y, color=color, alpha=0.15, linewidth=0.5) ax.scatter(x[-1], y[-1], color=color, s=5) xs_all.append(x) ys_all.append(y) if infected_tr: n_inf_tr += 1 else: n_uninf_tr += 1 if not xs_all: ax.set_visible(False) return xs_all = np.concatenate(xs_all) ys_all = np.concatenate(ys_all) ax.set_aspect("equal", "box") ax.set_xlabel(coord_label_x) ax.set_ylabel(coord_label_y) # auto limits from data x_margin = 0.05 * (xs_all.max() - xs_all.min() + 1e-9) y_margin = 0.05 * (ys_all.max() - ys_all.min() + 1e-9) ax.set_xlim(xs_all.min() - x_margin, xs_all.max() + x_margin) ax.set_ylim(ys_all.min() - y_margin, ys_all.max() + y_margin) mask_inf = track_df_well["infected"].astype(bool) v_inf = track_df_well.loc[mask_inf, "velocity"].to_numpy() v_uninf = track_df_well.loc[~mask_inf, "velocity"].to_numpy() mean_inf_v = float(np.nanmean(v_inf)) if v_inf.size else np.nan mean_uninf_v = float(np.nanmean(v_uninf)) if v_uninf.size else np.nan txt_lines = [] txt_lines.append(f"Infected ({mean_inf_v:.2f} {vel_unit})") txt_lines.append(f"Uninfected ({mean_uninf_v:.2f} {vel_unit})") if pixels_per_um is not None and pixels_per_um > 0: txt_lines.append(f"1 µm = {pixels_per_um:.2f} px") if seconds_per_frame is not None: txt_lines.append(f"1 frame = {seconds_per_frame:.0f} s") txt = "\n".join(txt_lines) ax.text( 0.98, 0.02, txt, transform=ax.transAxes, ha="right", va="bottom", fontsize=8, bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.8), ) ax_all = axes[axis_idx] axis_idx += 1 _plot_all_tracks(ax_all) # ----- motility origin plots (infected vs uninfected) for this well ----- def _plot_origin(ax, want_infected: bool): n_tr = 0 color = "red" if want_infected else "green" for tr in well_tracks: if bool(tr.get("infected", False)) != want_infected: continue x_px = np.asarray(tr["x_px"], dtype=float) y_px = np.asarray(tr["y_px"], dtype=float) if x_px.size < 2: continue x = (x_px - x_px[0]) * coord_scale y = (y_px - y_px[0]) * coord_scale ax.plot(x, y, color=color, alpha=0.15, linewidth=0.5) ax.scatter(x[-1], y[-1], color=color, s=5) n_tr += 1 ax.set_aspect("equal", "box") ax.set_xlabel(coord_label_x) ax.set_ylabel(coord_label_y) if origin_xlim is not None and len(origin_xlim) == 2: ax.set_xlim(origin_xlim) if origin_ylim is not None and len(origin_ylim) == 2: ax.set_ylim(origin_ylim) mask = track_df_well["infected"].astype(bool) if not want_infected: mask = ~mask v = track_df_well.loc[mask, "velocity"].to_numpy() mean_v = float(np.nanmean(v)) if v.size else np.nan label = "Infected" if want_infected else "Uninfected" ax.set_title(f"{label}\n(n={n_tr}, v={mean_v:.2f} {vel_unit})") # infected origin plot ax_inf = axes[axis_idx] axis_idx += 1 _plot_origin(ax_inf, True) # uninfected origin plot ax_uninf = axes[axis_idx] axis_idx += 1 _plot_origin(ax_uninf, False) # ----- optional small QC PNG (mask panel only) ----- if qc_panel_needed_mask and axis_idx < len(axes): ax_qc = axes[axis_idx] axis_idx += 1 try: img = mpimg.imread(qc_panel_path) ax_qc.imshow(img) ax_qc.axis("off") tmap = { "histogram": "Intensity histogram", "pca": "PCA/UMAP clustering", "xgboost": "XGBoost feature importance", } ttl = tmap.get(str(qc_panel_type).lower(), "Infection QC") ax_qc.set_title(ttl, fontsize=9) except Exception as e: print( f"[_make_intensity_motility_panel] Could not embed QC plot " f"from {qc_panel_path}: {e}" ) ax_qc.set_visible(False) # ----- adjusted panel QC subplots: method-specific ----- if is_adjusted_panel and qc_graphs_enabled: if qc_strategy == "histogram" and axis_idx < len(axes): ax_hist = axes[axis_idx] axis_idx += 1 # Prefer global payload if present; otherwise compute from per-well df src = hist_data if hist_data is not None else df_well _plot_hist_qc(ax_hist, src) elif qc_strategy in {"pca", "umap", "tsne"} and has_pca and axis_idx < len(axes): ax_pca = axes[axis_idx] axis_idx += 1 _plot_pca_qc(ax_pca, pca_data) elif qc_strategy == "xgboost" and has_xgb: if axis_idx < len(axes): ax_prob = axes[axis_idx] axis_idx += 1 _plot_xgb_prob_qc(ax_prob, df_well) if axis_idx < len(axes): ax_xgb = axes[axis_idx] axis_idx += 1 _plot_xgb_importance_qc(ax_xgb, xgb_data) # Plate/well tag for title & filename meta_tag = f"{plate_id}_{well_id}" fig.suptitle( f"Infection panel – {panel_label} labels – method={method_label}\n{meta_tag}", fontsize=10, ) fig.tight_layout(rect=[0, 0, 1, 0.90]) # Filenames: # mask/original: plate1_A03.pdf # adjusted: plate1_A03_xgboost_adjusted.pdf if is_adjusted_panel: out_name = f"{meta_tag}_{method_label}_adjusted.pdf" elif is_mask_panel: out_name = f"{meta_tag}.pdf" else: # fallback for any unexpected label_tag out_name = f"{meta_tag}_{label_tag}_{method_label}.pdf" out_path = os.path.join(motility_dir, out_name) fig.savefig(out_path) # PDF inferred from extension plt.close(fig) print( f"[summarise_tracks_from_merged] Saved per-well intensity+motility panel " f"({panel_label}, method={method_label}) for plate={plate_id}, well={well_id} " f"to {out_path}" ) def _infer_plate_well_meta_tag(df): """ Infer a compact 'plate_well' tag for filenames from a DataFrame that has plateID / wellID columns. Examples -------- plate1 + A02 -> 'plate1_A02' plate1 + many -> 'plate1_MULTI_WELLS' many + A02 -> 'MULTI_PLATES_A02' many + many -> 'MULTI_PLATES_MULTI_WELLS' """ plates = sorted(df["plateID"].dropna().unique()) if "plateID" in df.columns else [] wells = sorted(df["wellID"].dropna().unique()) if "wellID" in df.columns else [] if len(plates) == 1 and len(wells) == 1: return f"{plates[0]}_{wells[0]}" elif len(plates) == 1 and len(wells) > 1: return f"{plates[0]}_MULTI_WELLS" elif len(plates) > 1 and len(wells) == 1: return f"MULTI_PLATES_{wells[0]}" else: return "MULTI_PLATES_MULTI_WELLS" def _compute_cell_mean_intensity_per_channel( mask_stack, intensity_stack, channel_index, ): """ Compute per-frame, per-cell mean intensity for a given channel. Parameters ---------- mask_stack : ndarray Label image stack of shape (T, Y, X) for cells (track_id labels). intensity_stack : ndarray Intensity stack of shape (T, Y, X, C). channel_index : int Channel index in intensity_stack to use. Returns ------- DataFrame Columns: ['frame', 'track_id', f'cell_mean_intensity_ch{channel_index}'] """ import numpy as np import pandas as pd if intensity_stack is None: print( f"[cell_mean_intensity] channel {channel_index}: " "intensity_stack is None, skipping." ) return pd.DataFrame( columns=["frame", "track_id", f"cell_mean_intensity_ch{channel_index}"] ) if channel_index is None or channel_index < 0 or channel_index >= intensity_stack.shape[-1]: print( f"[cell_mean_intensity] channel {channel_index}: " "invalid channel index for intensity_stack, skipping." ) return pd.DataFrame( columns=["frame", "track_id", f"cell_mean_intensity_ch{channel_index}"] ) T = mask_stack.shape[0] dfs = [] col_name = f"cell_mean_intensity_ch{channel_index}" for frame in range(T): labels = mask_stack[frame] if not np.any(labels): continue intensity_image = intensity_stack[frame, :, :, channel_index] props_table = regionprops_table( labels, intensity_image=intensity_image, properties=("label", "mean_intensity"), ) frame_df = pd.DataFrame(props_table) frame_df = frame_df.rename( columns={ "label": "track_id", "mean_intensity": col_name, } ) frame_df["frame"] = frame dfs.append(frame_df) if not dfs: print( f"[cell_mean_intensity] channel {channel_index}: " f"no objects found in any of {T} frames." ) return pd.DataFrame(columns=["frame", "track_id", col_name]) out_df = pd.concat(dfs, ignore_index=True) n_rows = out_df.shape[0] n_frames_detected = out_df["frame"].nunique() n_objs = out_df["track_id"].nunique() print( f"[cell_mean_intensity] channel {channel_index}: " f"frames_with_objects={n_frames_detected}/{T}, " f"unique_track_id={n_objs}, rows={n_rows}" ) return out_df def _reorient_merged_array(arr, n_channels, max_extra_masks=3): """ Ensure merged array has shape (planes, H, W) with planes as the first axis. Handles both (planes, H, W) and (H, W, planes) layouts by detecting which axis likely corresponds to the small "planes" dimension (~n_channels + masks). """ import numpy as np if arr.ndim != 3: raise ValueError( f"_reorient_merged_array expected 3D array, got ndim={arr.ndim}" ) target_min = n_channels target_max = n_channels + max_extra_masks shape = arr.shape plane_axis = None for ax, dim in enumerate(shape): if target_min <= dim <= target_max: plane_axis = ax break if plane_axis is None: # Fallback: choose the smallest axis as planes plane_axis = int(np.argmin(shape)) if plane_axis != 0: arr = np.moveaxis(arr, plane_axis, 0) planes, H, W = arr.shape return arr, planes, H, W def _parse_merged_filename(fname): """ Parse a merged .npy filename of the form: plate_well_field_time.npy Returns a dict with: plateID, wellID, rowID, columnID, fieldID, timeID, prcf, prcft, filename """ base = os.path.splitext(os.path.basename(fname))[0] parts = base.split("_") plateID = parts[0] if len(parts) > 0 else "" wellID = parts[1] if len(parts) > 1 else "" fieldID = parts[2] if len(parts) > 2 else "1" time_str = parts[3] if len(parts) > 3 else "0" # Extract numeric time index, tolerate formats like "t000" digits = "".join(ch for ch in time_str if ch.isdigit()) timeID = int(digits) if digits else 0 rowID = wellID[0] if wellID else "" col_part = "".join(ch for ch in wellID[1:] if ch.isdigit()) columnID = int(col_part) if col_part else 0 prcf = f"{plateID}_{wellID}_{fieldID}" prcft = f"{prcf}_{timeID}" meta = dict( plateID=plateID, wellID=wellID, rowID=rowID, columnID=columnID, fieldID=fieldID, timeID=timeID, prcf=prcf, prcft=prcft, filename=os.path.basename(fname), ) return meta def _compute_parent_child_overlaps( parent_masks, child_masks, parent_label_col, child_label_col, ): """ For each frame, find which child labels overlap which parent labels. Returns columns: 'frame', parent_label_col, child_label_col """ T = parent_masks.shape[0] records = [] for frame in range(T): p = parent_masks[frame] c = child_masks[frame] m = (p > 0) & (c > 0) if not np.any(m): continue p_flat = p[m].ravel() c_flat = c[m].ravel() pairs = np.stack([p_flat, c_flat], axis=1) unique_pairs = np.unique(pairs, axis=0) for parent_label, child_label in unique_pairs: records.append( { "frame": frame, parent_label_col: int(parent_label), child_label_col: int(child_label), } ) if not records: return pd.DataFrame(columns=["frame", parent_label_col, child_label_col]) return pd.DataFrame.from_records(records) def _summarise_child_features_per_parent( overlaps_df, child_props_df, parent_label_col, child_label_col, count_col_name, ): """ Summarise child object features per parent object. - Counts distinct children -> count_col_name - Aggregates numeric child features per parent: * '*area*' -> sum * '*intensity*' -> mean * '*dist*'/'*distance*' -> min * everything else -> mean """ if overlaps_df.empty or child_props_df.empty: return pd.DataFrame(columns=["frame", parent_label_col, count_col_name]) df = overlaps_df.merge(child_props_df, on=["frame", child_label_col], how="left") if df.empty: return pd.DataFrame(columns=["frame", parent_label_col, count_col_name]) group_cols = ["frame", parent_label_col] counts = ( df.groupby(group_cols)[child_label_col] .nunique() .reset_index() .rename(columns={child_label_col: count_col_name}) ) numeric_cols = [ c for c in df.columns if c not in group_cols + [child_label_col] and np.issubdtype(df[c].dtype, np.number) ] if not numeric_cols: return counts def _agg_for_feature(col_name: str) -> str: name = col_name.lower() if "area" in name: return "sum" if "intensity" in name: return "mean" if "distance" in name or "dist" in name: return "min" return "mean" agg_dict = {c: _agg_for_feature(c) for c in numeric_cols} agg_df = df.groupby(group_cols).agg(agg_dict).reset_index() summary = agg_df.merge(counts, on=group_cols, how="left") return summary def _load_intensity_stack_from_merged( src, filenames, n_channels, height, width, dtype=np.float32, ): """ Load intensity channels from merged/*.npy into a (T, H, W, C) stack. Supports merged arrays stored either as (planes, H, W) or (H, W, planes). The first n_channels planes are intensities and any remaining planes are masks. """ import os import numpy as np merged_dir = os.path.join(src, "merged") T = len(filenames) if not os.path.isdir(merged_dir) or n_channels is None or n_channels <= 0: return np.zeros((T, height, width, 0), dtype=dtype) stack = np.zeros((T, height, width, n_channels), dtype=dtype) for t, fn in enumerate(filenames): base = os.path.splitext(os.path.basename(fn))[0] candidates = [ os.path.join(merged_dir, base + ".npy"), os.path.join(merged_dir, fn), os.path.join(merged_dir, fn + ".npy"), ] arr = None for path in candidates: if os.path.exists(path): arr = np.load(path) break if arr is None or arr.ndim != 3: continue # Standardise to (planes, H, W) try: arr, planes, H_img, W_img = _reorient_merged_array( arr, n_channels=n_channels ) except ValueError: continue if H_img != height or W_img != width: # Skip unexpected size print( f"[_load_intensity_stack_from_merged] Skipping {fn}: " f"reoriented size=({planes}, {H_img}, {W_img}), " f"expected H={height}, W={width}" ) continue use_planes = min(n_channels, planes) if use_planes <= 0: continue img = arr[:use_planes].transpose(1, 2, 0) # (H, W, C) C = img.shape[2] stack[t, :, :, :C] = img return stack def _load_masks_from_merged( src, filenames, n_channels, height, width, nucleus_chan=None, pathogen_chan=None, dtype=None, ): """ Load cell / nucleus / pathogen masks from merged/*.npy. Supports merged arrays stored either as (planes, H, W) or (H, W, planes). Layout per merged array after reorientation (planes, H, W): 0 .. n_channels-1 → intensity channels n_channels → cell_mask (always present) n_channels + 1 (optional) → nucleus_mask or pathogen_mask n_channels + 2 (optional) → pathogen_mask (when both nuc+pathogen exist) The exact interpretation of mask planes depends on whether `nucleus_chan` and/or `pathogen_chan` are None. """ import os import numpy as np if dtype is None: dtype = np.int32 merged_dir = os.path.join(src, "merged") T = len(filenames) cell_masks = np.zeros((T, height, width), dtype=dtype) nucleus_masks = np.zeros((T, height, width), dtype=dtype) pathogen_masks = np.zeros((T, height, width), dtype=dtype) if not os.path.isdir(merged_dir): return cell_masks, nucleus_masks, pathogen_masks for t, fn in enumerate(filenames): base = os.path.splitext(os.path.basename(fn))[0] candidates = [ os.path.join(merged_dir, base + ".npy"), os.path.join(merged_dir, fn), os.path.join(merged_dir, fn + ".npy"), ] arr = None for path in candidates: if os.path.exists(path): arr = np.load(path) break if arr is None or arr.ndim != 3: continue # Standardise to (planes, H, W) try: arr, planes, H_img, W_img = _reorient_merged_array( arr, n_channels=n_channels ) except ValueError: continue if H_img != height or W_img != width: print( f"[_load_masks_from_merged] Skipping {fn}: " f"reoriented size=({planes}, {H_img}, {W_img}), " f"expected H={height}, W={width}" ) continue if planes <= n_channels: # Only intensity planes, no masks continue n_masks = planes - n_channels # First mask plane is always cell cell_masks[t] = arr[n_channels].astype(dtype) # Second mask plane (if present) is nucleus OR pathogen depending on settings if n_masks >= 2: if nucleus_chan is not None and pathogen_chan is None: nucleus_masks[t] = arr[n_channels + 1].astype(dtype) elif nucleus_chan is None and pathogen_chan is not None: pathogen_masks[t] = arr[n_channels + 1].astype(dtype) elif nucleus_chan is not None and pathogen_chan is not None: # both requested → expect nucleus here nucleus_masks[t] = arr[n_channels + 1].astype(dtype) # Third mask plane (if present) is pathogen when both nuc+pathogen exist if n_masks >= 3 and pathogen_chan is not None: pathogen_masks[t] = arr[n_channels + 2].astype(dtype) return cell_masks, nucleus_masks, pathogen_masks def _compute_regionprops_stack( mask_stack, intensity_stack, channel_index, object_prefix, label_as_track_id=False, ): """ Compute regionprops over a (T, Y, X) label stack. Parameters ---------- mask_stack : ndarray Label image stack of shape (T, Y, X). intensity_stack : ndarray or None Intensity stack of shape (T, Y, X, C) or None. channel_index : int or None Channel index in intensity_stack to use for intensity props. object_prefix : str Prefix for column names ("cell", "nucleus", "pathogen", "cytoplasm"). label_as_track_id : bool If True, rename 'label' to 'track_id', otherwise to f"{object_prefix}_label". Returns ------- DataFrame One row per object per frame with prefixed column names. """ import numpy as np import pandas as pd T, H, W = mask_stack.shape use_intensity = ( intensity_stack is not None and channel_index is not None and 0 <= channel_index < intensity_stack.shape[-1] ) # Avoid properties that rely on normalized central moments geom_props = [ "label", "area", "bbox_area", "equivalent_diameter", "perimeter", "perimeter_crofton", "solidity", "centroid", ] intensity_props = [ "max_intensity", "mean_intensity", "min_intensity", ] props = geom_props + intensity_props if use_intensity else geom_props label_col_name = "track_id" if label_as_track_id else f"{object_prefix}_label" dfs = [] for frame in range(T): labels = mask_stack[frame] if not np.any(labels): continue if use_intensity: intensity_image = intensity_stack[frame, :, :, channel_index] props_table = regionprops_table( labels, intensity_image=intensity_image, properties=props, ) else: props_table = regionprops_table(labels, properties=props) frame_df = pd.DataFrame(props_table) frame_df = frame_df.rename(columns={"label": label_col_name}) frame_df["frame"] = frame feature_cols = [ c for c in frame_df.columns if c not in ("frame", label_col_name) ] frame_df = frame_df.rename( columns={c: f"{object_prefix}_{c}" for c in feature_cols} ) dfs.append(frame_df) if not dfs: print(f"[regionprops] {object_prefix}: no objects found in any of {T} frames.") return pd.DataFrame(columns=["frame", label_col_name]) out_df = pd.concat(dfs, ignore_index=True) n_rows = out_df.shape[0] n_frames_detected = out_df["frame"].nunique() n_objs = out_df[label_col_name].nunique() print( f"[regionprops] {object_prefix}: frames_with_objects=" f"{n_frames_detected}/{T}, unique_{label_col_name}={n_objs}, rows={n_rows}" ) return out_df def _process_merged_group(args): """ Worker: process one (plate, well, field) group of merged .npy files. Returns per-cell-per-frame DataFrame with: - metadata - cell features - aggregated nucleus / pathogen / cytoplasm features - per-channel cell mean intensities (cell_mean_intensity_ch{c}) """ import numpy as np import pandas as pd import os ( src, file_basenames, n_channels, cell_chan, nucleus_chan, pathogen_chan, ) = args if not file_basenames: print("[_process_merged_group] Empty file_basenames list.") return pd.DataFrame() merged_dir = os.path.join(src, "merged") # sort filenames by timeID metas = [] for bn in file_basenames: meta = _parse_merged_filename(bn) metas.append(meta) metas_sorted = sorted(metas, key=lambda m: m["timeID"]) sorted_basenames = [m["filename"] for m in metas_sorted] key = ( metas_sorted[0]["plateID"], metas_sorted[0]["wellID"], metas_sorted[0]["fieldID"], ) print(f"[_process_merged_group] Start group {key}, files={len(sorted_basenames)}") # infer size from first file (respecting orientation) first_path = os.path.join(merged_dir, sorted_basenames[0]) first_arr_raw = np.load(first_path) if first_arr_raw.ndim != 3: print( f"[_process_merged_group] First array for group {key} is not 3D, " "skipping." ) return pd.DataFrame() try: first_arr, planes, H, W = _reorient_merged_array( first_arr_raw, n_channels=n_channels ) except ValueError: print( f"[_process_merged_group] Group {key}: could not reorient first array " f"with shape={first_arr_raw.shape}, skipping." ) return pd.DataFrame() base_dtype = first_arr.dtype print( f"[_process_merged_group] Group {key}: first array original_shape=" f"{first_arr_raw.shape}, reoriented_shape=({planes}, {H}, {W}), " f"dtype={base_dtype}" ) # load stacks intensity_stack = _load_intensity_stack_from_merged( src=src, filenames=sorted_basenames, n_channels=n_channels, height=H, width=W, dtype=base_dtype, ) cell_masks, nucleus_masks, pathogen_masks = _load_masks_from_merged( src=src, filenames=sorted_basenames, n_channels=n_channels, height=H, width=W, nucleus_chan=nucleus_chan, pathogen_chan=pathogen_chan, dtype=np.int32, ) T = cell_masks.shape[0] if T == 0 or not np.any(cell_masks): print(f"[_process_merged_group] Group {key}: no cell masks found, skipping.") return pd.DataFrame() print( f"[_process_merged_group] Group {key}: frames={T}, " f"any_nucleus={np.any(nucleus_masks)}, any_pathogen={np.any(pathogen_masks)}" ) # cytoplasm = cell minus (nucleus union pathogen) has_nucleus = np.any(nucleus_masks) has_pathogen = np.any(pathogen_masks) cytoplasm_masks = None if has_nucleus or has_pathogen: cytoplasm_masks = cell_masks.copy() if has_nucleus: cytoplasm_masks[nucleus_masks > 0] = 0 if has_pathogen: cytoplasm_masks[pathogen_masks > 0] = 0 # regionprops for cell geometry (+ intensities in cell_chan) cell_props_df = _compute_regionprops_stack( mask_stack=cell_masks, intensity_stack=intensity_stack, channel_index=cell_chan, object_prefix="cell", label_as_track_id=True, ) nucleus_props_df = _compute_regionprops_stack( mask_stack=nucleus_masks, intensity_stack=intensity_stack, channel_index=nucleus_chan, object_prefix="nucleus", label_as_track_id=False, ) pathogen_props_df = _compute_regionprops_stack( mask_stack=pathogen_masks, intensity_stack=intensity_stack, channel_index=pathogen_chan, object_prefix="pathogen", label_as_track_id=False, ) cytoplasm_props_df = pd.DataFrame() if cytoplasm_masks is not None and np.any(cytoplasm_masks): cytoplasm_props_df = _compute_regionprops_stack( mask_stack=cytoplasm_masks, intensity_stack=intensity_stack, channel_index=cell_chan, # use same channel as cell by default object_prefix="cytoplasm", label_as_track_id=False, ) # --- per-channel intensity percentiles for each compartment --- percentile_dfs_cell = [] percentile_dfs_nucleus = [] percentile_dfs_pathogen = [] percentile_dfs_cytoplasm = [] for ch in range(n_channels): # cell: track_id labels df_p = _compute_intensity_percentiles_per_channel( mask_stack=cell_masks, intensity_stack=intensity_stack, channel_index=ch, object_prefix="cell", label_as_track_id=True, ) if not df_p.empty: percentile_dfs_cell.append(df_p) # nucleus if np.any(nucleus_masks): df_p_n = _compute_intensity_percentiles_per_channel( mask_stack=nucleus_masks, intensity_stack=intensity_stack, channel_index=ch, object_prefix="nucleus", label_as_track_id=False, ) if not df_p_n.empty: percentile_dfs_nucleus.append(df_p_n) # pathogen if np.any(pathogen_masks): df_p_pa = _compute_intensity_percentiles_per_channel( mask_stack=pathogen_masks, intensity_stack=intensity_stack, channel_index=ch, object_prefix="pathogen", label_as_track_id=False, ) if not df_p_pa.empty: percentile_dfs_pathogen.append(df_p_pa) # cytoplasm if cytoplasm_masks is not None and np.any(cytoplasm_masks): df_p_cy = _compute_intensity_percentiles_per_channel( mask_stack=cytoplasm_masks, intensity_stack=intensity_stack, channel_index=ch, object_prefix="cytoplasm", label_as_track_id=False, ) if not df_p_cy.empty: percentile_dfs_cytoplasm.append(df_p_cy) # merge percentile features into base props if percentile_dfs_cell: tmp = percentile_dfs_cell[0] for df_p in percentile_dfs_cell[1:]: tmp = tmp.merge(df_p, on=["frame", "track_id"], how="outer") cell_props_df = cell_props_df.merge( tmp, on=["frame", "track_id"], how="left" ) if np.any(nucleus_masks) and not nucleus_props_df.empty and percentile_dfs_nucleus: tmp = percentile_dfs_nucleus[0] for df_p in percentile_dfs_nucleus[1:]: tmp = tmp.merge(df_p, on=["frame", "nucleus_label"], how="outer") nucleus_props_df = nucleus_props_df.merge( tmp, on=["frame", "nucleus_label"], how="left" ) if np.any(pathogen_masks) and not pathogen_props_df.empty and percentile_dfs_pathogen: tmp = percentile_dfs_pathogen[0] for df_p in percentile_dfs_pathogen[1:]: tmp = tmp.merge(df_p, on=["frame", "pathogen_label"], how="outer") pathogen_props_df = pathogen_props_df.merge( tmp, on=["frame", "pathogen_label"], how="left" ) if ( cytoplasm_masks is not None and np.any(cytoplasm_masks) and not cytoplasm_props_df.empty and percentile_dfs_cytoplasm ): tmp = percentile_dfs_cytoplasm[0] for df_p in percentile_dfs_cytoplasm[1:]: tmp = tmp.merge(df_p, on=["frame", "cytoplasm_label"], how="outer") cytoplasm_props_df = cytoplasm_props_df.merge( tmp, on=["frame", "cytoplasm_label"], how="left" ) if cell_props_df.empty: print(f"[_process_merged_group] Group {key}: cell_props_df empty, skipping.") return pd.DataFrame() # --- per-channel cell mean intensities (one column per channel) --- per_channel_intensity_dfs = [] for ch in range(n_channels): df_ch = _compute_cell_mean_intensity_per_channel( mask_stack=cell_masks, intensity_stack=intensity_stack, channel_index=ch, ) if not df_ch.empty: per_channel_intensity_dfs.append(df_ch) cell_intensity_df = None if per_channel_intensity_dfs: cell_intensity_df = per_channel_intensity_dfs[0] for df_ch in per_channel_intensity_dfs[1:]: cell_intensity_df = cell_intensity_df.merge( df_ch, on=["frame", "track_id"], how="outer", ) added_cols = [ c for c in cell_intensity_df.columns if c.startswith("cell_mean_intensity_ch") ] print( f"[_process_merged_group] Group {key}: added per-channel cell " f"intensity columns: {added_cols}" ) # overlaps and summaries nucleus_summary = None if has_nucleus: overlaps_cn = _compute_parent_child_overlaps( parent_masks=cell_masks, child_masks=nucleus_masks, parent_label_col="track_id", child_label_col="nucleus_label", ) if not overlaps_cn.empty and not nucleus_props_df.empty: nucleus_summary = _summarise_child_features_per_parent( overlaps_df=overlaps_cn, child_props_df=nucleus_props_df, parent_label_col="track_id", child_label_col="nucleus_label", count_col_name="n_nuclei", ) print( f"[_process_merged_group] Group {key}: nucleus_summary rows=" f"{len(nucleus_summary)}" ) pathogen_summary = None if has_pathogen: overlaps_cp = _compute_parent_child_overlaps( parent_masks=cell_masks, child_masks=pathogen_masks, parent_label_col="track_id", child_label_col="pathogen_label", ) if not overlaps_cp.empty and not pathogen_props_df.empty: pathogen_summary = _summarise_child_features_per_parent( overlaps_df=overlaps_cp, child_props_df=pathogen_props_df, parent_label_col="track_id", child_label_col="pathogen_label", count_col_name="n_pathogens", ) print( f"[_process_merged_group] Group {key}: pathogen_summary rows=" f"{len(pathogen_summary)}" ) cytoplasm_summary = None if ( cytoplasm_masks is not None and np.any(cytoplasm_masks) and not cytoplasm_props_df.empty ): overlaps_cc = _compute_parent_child_overlaps( parent_masks=cell_masks, child_masks=cytoplasm_masks, parent_label_col="track_id", child_label_col="cytoplasm_label", ) if not overlaps_cc.empty: cytoplasm_summary = _summarise_child_features_per_parent( overlaps_df=overlaps_cc, child_props_df=cytoplasm_props_df, parent_label_col="track_id", child_label_col="cytoplasm_label", count_col_name="n_cytoplasm", ) print( f"[_process_merged_group] Group {key}: cytoplasm_summary rows=" f"{len(cytoplasm_summary)}" ) enriched_df = cell_props_df.copy() if nucleus_summary is not None and not nucleus_summary.empty: enriched_df = enriched_df.merge( nucleus_summary, on=["frame", "track_id"], how="left", ) if pathogen_summary is not None and not pathogen_summary.empty: enriched_df = enriched_df.merge( pathogen_summary, on=["frame", "track_id"], how="left", ) if cytoplasm_summary is not None and not cytoplasm_summary.empty: enriched_df = enriched_df.merge( cytoplasm_summary, on=["frame", "track_id"], how="left", ) if cell_intensity_df is not None: enriched_df = enriched_df.merge( cell_intensity_df, on=["frame", "track_id"], how="left", ) # attach metadata (plate, well, field, timeID, etc.) meta_records = [] for local_frame_idx, meta in enumerate(metas_sorted): rec = {"frame": local_frame_idx} rec.update(meta) meta_records.append(rec) meta_df = pd.DataFrame(meta_records) enriched_df = enriched_df.merge(meta_df, on="frame", how="left") enriched_df["cellID"] = enriched_df["track_id"] n_tracks = ( enriched_df[["plateID", "wellID", "fieldID", "cellID"]] .drop_duplicates() .shape[0] ) print( f"[_process_merged_group] Group {key}: enriched_df rows={len(enriched_df)}, " f"unique_tracks={n_tracks}" ) return enriched_df def _smooth_tracks_and_features(df, max_displacement=50.0, zscore_thresh=3.0): """ Smooth cell tracks and a small set of scalar features. - Fixes single-frame "teleport" glitches in centroid position. - Optionally drops tracks with impossible jumps. - Smooths a subset of scalar cell_* features using a z-score heuristic. """ import numpy as np import pandas as pd if df.empty: print("[_smooth_tracks_and_features] Input DataFrame is empty.") return df n_rows_before = len(df) n_tracks_before = df[["plateID", "wellID", "fieldID", "cellID"]].drop_duplicates().shape[0] df = df.sort_values( ["plateID", "wellID", "fieldID", "cellID", "frame"] ).reset_index(drop=True) y_col = "cell_centroid-0" x_col = "cell_centroid-1" if y_col not in df.columns or x_col not in df.columns: print("[_smooth_tracks_and_features] Centroid columns missing, nothing to smooth.") return df drop_indices = set() updates = {} # Only smooth scalar features with well-defined numeric dtype candidate_cols = [ "cell_area", "cell_bbox_area", "cell_equivalent_diameter", "cell_perimeter", "cell_perimeter_crofton", "cell_solidity", "cell_mean_intensity", "cell_max_intensity", "cell_min_intensity", ] cell_feature_cols = [c for c in candidate_cols if c in df.columns] # Ensure we are not writing floats into int columns (avoid FutureWarning) for col in [y_col, x_col] + cell_feature_cols: if col in df.columns and not np.issubdtype(df[col].dtype, np.floating): df[col] = df[col].astype(float) grouped = df.groupby(["plateID", "wellID", "fieldID", "cellID"], sort=False) n_tracks_processed = 0 n_tracks_dropped = 0 n_glitches_fixed = 0 for (plateID, wellID, fieldID, cellID), g in grouped: idx = g.index.to_numpy() if len(idx) < 2: continue n_tracks_processed += 1 y = g[y_col].to_numpy(dtype=float) x = g[x_col].to_numpy(dtype=float) n = len(idx) glitch_frames = set() # --- 1) detect and interpolate single-frame centroid glitches --- if n >= 3: for i_local in range(1, n - 1): y_prev, y_curr, y_next = y[i_local - 1], y[i_local], y[i_local + 1] x_prev, x_curr, x_next = x[i_local - 1], x[i_local], x[i_local + 1] d_prev = np.hypot(y_curr - y_prev, x_curr - x_prev) d_next = np.hypot(y_curr - y_next, x_curr - x_next) d_neigh = np.hypot(y_next - y_prev, x_next - x_prev) if ( d_prev > max_displacement and d_next > max_displacement and d_neigh <= max_displacement ): glitch_frames.add(i_local) # interpolate centroid + scalar features at glitch frames for i_local in glitch_frames: if i_local <= 0 or i_local >= n - 1: continue n_glitches_fixed += 1 y_new = 0.5 * (y[i_local - 1] + y[i_local + 1]) x_new = 0.5 * (x[i_local - 1] + x[i_local + 1]) y[i_local] = y_new x[i_local] = x_new for col in cell_feature_cols: s = g[col].to_numpy(dtype=float) if len(s) < 3: continue s_new = 0.5 * (s[i_local - 1] + s[i_local + 1]) updates.setdefault(col, {})[idx[i_local]] = s_new # --- 2) drop tracks with big jumps not explainable as glitches --- drop_track = False for i_local in range(1, n): d = np.hypot(y[i_local] - y[i_local - 1], x[i_local] - x[i_local - 1]) if ( d > max_displacement and i_local not in glitch_frames and (i_local - 1) not in glitch_frames ): drop_track = True break if drop_track: n_tracks_dropped += 1 drop_indices.update(idx.tolist()) continue # write back smoothed centroid for i_local, global_idx in enumerate(idx): if y[i_local] != g[y_col].iloc[i_local]: updates.setdefault(y_col, {})[global_idx] = y[i_local] if x[i_local] != g[x_col].iloc[i_local]: updates.setdefault(x_col, {})[global_idx] = x[i_local] # --- 3) z-score based smoothing of scalar features --- if len(idx) < 3 or not cell_feature_cols: continue for col in cell_feature_cols: s = g[col].to_numpy(dtype=float) if np.all(~np.isfinite(s)): continue mean = np.nanmean(s) std = np.nanstd(s) if not np.isfinite(std) or std == 0: continue z = (s - mean) / std for i_local in range(1, n - 1): if not np.isfinite(z[i_local]) or abs(z[i_local]) <= zscore_thresh: continue if ( abs(z[i_local - 1]) <= zscore_thresh / 2 and abs(z[i_local + 1]) <= zscore_thresh / 2 ): new_val = 0.5 * (s[i_local - 1] + s[i_local + 1]) updates.setdefault(col, {})[idx[i_local]] = new_val # apply all updates in one go for col, mapping in updates.items(): df.loc[list(mapping.keys()), col] = list(mapping.values()) if drop_indices: df = df.drop(index=list(drop_indices)).reset_index(drop=True) n_rows_after = len(df) n_tracks_after = df[["plateID", "wellID", "fieldID", "cellID"]].drop_duplicates().shape[0] print( "[_smooth_tracks_and_features] rows_before=" f"{n_rows_before}, rows_after={n_rows_after}, " f"tracks_before={n_tracks_before}, tracks_after={n_tracks_after}, " f"tracks_processed={n_tracks_processed}, tracks_dropped={n_tracks_dropped}, " f"glitches_fixed={n_glitches_fixed}" ) return df def _debug_plot_merged_planes(src, sample_filename, n_channels, nucleus_chan, pathogen_chan, out_dir): """ Debug-plot a single merged .npy file. The plot is saved as a PDF and contains: - one panel per raw intensity channel (normalized 2–98 percent) - one panel per mask plane (random colormap) - one panel showing merged intensity channels with all masks overlaid using a random colormap with alpha=0.6. """ import os import numpy as np import matplotlib.pyplot as plt import matplotlib as mpl # needed by _generate_mask_random_cmap if defined elsewhere merged_path = os.path.join(src, "merged", sample_filename) if not os.path.isfile(merged_path): print(f"[_debug_plot_merged_planes] File not found: {merged_path}") return arr = np.load(merged_path) original_shape = arr.shape # Re-orient to (planes, y, x) if arr.ndim == 3: # (Y, X, planes) -> (planes, Y, X) if arr.shape[-1] != n_channels and arr.shape[0] == n_channels: planes = arr else: planes = np.moveaxis(arr, -1, 0) elif arr.ndim == 4: # Take first timepoint; assume (T, Y, X, planes) or similar if arr.shape[-1] >= n_channels: planes = np.moveaxis(arr[0], -1, 0) else: # fallback: collapse time into planes planes = arr.reshape(-1, arr.shape[-2], arr.shape[-1]) else: # Fallback, try to interpret leading axis as planes planes = arr reoriented_shape = planes.shape print( f"[_debug_plot_merged_planes] Sample '{sample_filename}': " f"original_shape={original_shape}, reoriented_shape={reoriented_shape}" ) if planes.ndim != 3: print( f"[_debug_plot_merged_planes] Expected 3D array after reorientation, " f"got shape={planes.shape}; skipping." ) return n_planes = planes.shape[0] if n_planes < n_channels: n_channels = n_planes intensity_planes = planes[:n_channels].astype(float) mask_planes = planes[n_channels:] n_masks = mask_planes.shape[0] # Normalize intensity channels to 2–98 percentiles norm_intensity = [] for ch_idx in range(n_channels): p = intensity_planes[ch_idx].astype(float) lo = np.percentile(p, 2) hi = np.percentile(p, 98) if hi <= lo: p_norm = np.zeros_like(p, dtype=float) else: p_norm = np.clip((p - lo) / (hi - lo), 0.0, 1.0) norm_intensity.append(p_norm) norm_intensity = np.asarray(norm_intensity) if norm_intensity.size == 0: print("[_debug_plot_merged_planes] No intensity channels to plot; skipping.") return H, W = norm_intensity[0].shape # Build RGB merge of intensity channels (up to 3) merged_rgb = np.zeros((H, W, 3), dtype=float) if n_channels >= 1: merged_rgb[..., 0] = norm_intensity[0] # red if n_channels >= 2: merged_rgb[..., 1] = norm_intensity[1] # green if n_channels >= 3: merged_rgb[..., 2] = norm_intensity[2] # blue # Combined mask for overlay combined_mask = None if n_masks > 0: combined_mask = np.zeros((H, W), dtype=int) offset = 0 for m in mask_planes: m_int = m.astype(int) if m_int.max() <= 0: continue nonzero = m_int > 0 combined_mask[nonzero] = m_int[nonzero] + offset offset += int(m_int.max()) if offset == 0: combined_mask = None # Figure layout: channels + masks + merged overlay extra = 1 if combined_mask is not None else 0 n_cols = n_channels + n_masks + extra fig, axes = plt.subplots( 1, n_cols, figsize=(3 * n_cols, 3), dpi=150, squeeze=False, ) axes = axes[0] col_idx = 0 # Intensity channels for ch_idx in range(n_channels): ax = axes[col_idx] col_idx += 1 ax.imshow(norm_intensity[ch_idx], cmap="gray") ax.set_title(f"Ch {ch_idx} (2–98% norm)") ax.axis("off") # Individual mask planes with random cmap for m_idx in range(n_masks): ax = axes[col_idx] col_idx += 1 mask_plane = mask_planes[m_idx] try: random_cmap = _generate_mask_random_cmap(mask_plane) except NameError: # Fallback: create a simple random colormap here unique_labels = np.unique(mask_plane) unique_labels = unique_labels[unique_labels != 0] n_labels = len(unique_labels) rng = np.random.default_rng(seed=42) colors = np.ones((n_labels + 1, 4)) colors[1:, :3] = rng.random((n_labels, 3)) random_cmap = mpl.colors.ListedColormap(colors) ax.imshow(mask_plane, cmap=random_cmap, interpolation="nearest") ax.set_title(f"Mask {m_idx}") ax.axis("off") # Merged channels + combined masks if combined_mask is not None: ax = axes[col_idx] try: merged_cmap = _generate_mask_random_cmap(combined_mask) except NameError: unique_labels = np.unique(combined_mask) unique_labels = unique_labels[unique_labels != 0] n_labels = len(unique_labels) rng = np.random.default_rng(seed=123) colors = np.ones((n_labels + 1, 4)) colors[1:, :3] = rng.random((n_labels, 3)) merged_cmap = mpl.colors.ListedColormap(colors) ax.imshow(merged_rgb) ax.imshow(combined_mask, cmap=merged_cmap, alpha=0.6, interpolation="nearest") ax.set_title("Merged channels + masks") ax.axis("off") fig.tight_layout() base = os.path.splitext(sample_filename)[0] out_path = os.path.join(out_dir, f"merged_planes_{base}.pdf") fig.savefig(out_path, bbox_inches="tight") plt.close(fig) print( f"[_debug_plot_merged_planes] Saved merged plane debug figure to {out_path}" ) def _infection_qc_pca_clustering( all_df, settings, infection_col, pathogen_chan, motility_dir, ): """ Embedding-based infection intensity QC (PCA/UMAP/t-SNE). Steps ----- 1. Aggregate to per-cell features (cell_* columns). 2. Select morphology + pathogen-channel intensity features. 3. Optionally transform features to improve structure: - log1p on intensity-like features (if enabled) - up-weight pathogen-channel features (if configured) 4. Embed to 2D using PCA / UMAP / t-SNE (controlled by settings['infection_intensity_strategy']: 'pca', 'umap', or 'tsne'). For UMAP and t-SNE, an internal hyperparameter search is performed (if enabled) to maximize a separation score based on: - distance between the two clusters in the embedding - how well clusters separate GT infected vs GT uninfected. 5. KMeans clustering (2 clusters) in the embedded space. 6. Define "ground-truth" subsets based on pathogen-channel intensity: - uninfected_gt: lowest 25% of intensities in the UNINFECTED group - infected_gt : highest 25% of intensities in the INFECTED group 7. Assign clusters as "infected" vs "uninfected" based on which ground-truth class dominates each cluster. 8. Depending on infection_intensity_mode: - 'relabel': adjusted_infected = cluster assignment. - 'remove' : drop cells whose original label disagrees with the cluster assignment. 9. Map adjusted_infected back to all_df (frame level). Hyperparameter search --------------------- UMAP: - Enabled if settings.get('infection_pca_umap_search', True) is True. - Grid (overrideable via settings): settings['infection_pca_umap_n_neighbors_grid'] (default [5, 10, 15, 30]) settings['infection_pca_umap_min_dist_grid'] (default [0.0, 0.05, 0.1, 0.3]) - Score = centroid_distance * |GT-infected fraction difference between clusters|. t-SNE: - Enabled if settings.get('infection_pca_tsne_search', True) is True. - Grid (overrideable via settings): settings['infection_pca_tsne_perplexity_grid'] (default [15.0, 30.0, 45.0]) settings['infection_pca_tsne_learning_rate_grid'] (default [200.0, 500.0]) (perplexity is clamped to < (n_samples-1)/3.) - Same score definition as for UMAP. PCA “structure” helpers ----------------------- - Optional log1p transform on intensity-like features (if settings.get('infection_pca_log_intensity', True) is True). - Optional up-weighting of pathogen-channel features after standardization: settings['infection_pca_pathogen_weight'] (default 1.0) Side effects ------------ - settings['infection_pca_data'] with: { 'coords': coords (n_cells x 2), 'labels': adjusted_infected (bool), 'cluster_labels': cluster_ids (0 or 1), 'method_label': 'PCA' / 'UMAP' / 't-SNE', 'infected_cluster': int, 'uninfected_cluster': int, 'initial_infected_frac_infected_cluster': float (0-1), 'initial_infected_frac_uninfected_cluster': float (0-1), 'gt_sep_score': float, 'silhouette_score': float or None, 'centroid_distance': float, 'embedding_params': dict (e.g. {'n_neighbors': 15, 'min_dist': 0.1}), } - settings['infection_intensity_qc_panel_type'] = 'pca' - settings['infection_intensity_qc_panel_path'] = None """ import os import numpy as np import pandas as pd from sklearn.preprocessing import StandardScaler from sklearn.cluster import KMeans from sklearn.metrics import silhouette_score from sklearn.decomposition import PCA # Optional imports for alternative embeddings try: from sklearn.manifold import TSNE except Exception: # optional TSNE = None try: import umap # type: ignore except Exception: # optional umap = None # ------------------------------------------------------------------ # Helper: evaluate an embedding + clustering # ------------------------------------------------------------------ def _evaluate_embedding(coords, cluster_labels, y_orig, gt_uninf, gt_inf): """ Compute: - infected/uninfected cluster mapping using GT sets - GT separation score - silhouette score - centroid distance between the two clusters - original infected fractions in each cluster - overall score = centroid_distance * GT separation """ # GT-based fractional infection per cluster frac_inf_gt = [] for k in (0, 1): mask_k = cluster_labels == k n_inf_k = int(np.sum(mask_k & gt_inf)) n_uninf_k = int(np.sum(mask_k & gt_uninf)) tot_k = n_inf_k + n_uninf_k if tot_k == 0: frac_inf_gt.append(0.5) else: frac_inf_gt.append(n_inf_k / float(tot_k)) infected_cluster = int(np.argmax(frac_inf_gt)) uninfected_cluster = 1 - infected_cluster gt_sep_score = abs(frac_inf_gt[0] - frac_inf_gt[1]) # Centroid distance in embedding space centroids = [] for k in (0, 1): mask_k = cluster_labels == k if mask_k.any(): centroids.append(coords[mask_k].mean(axis=0)) else: centroids.append(np.zeros(coords.shape[1], dtype=float)) centroid_distance = float(np.linalg.norm(centroids[0] - centroids[1])) # Silhouette in embedding space sil = None if coords.shape[0] > 10 and len(np.unique(cluster_labels)) > 1: try: sil = float(silhouette_score(coords, cluster_labels)) except Exception: sil = None # Original infected fractions in each cluster mask_inf_cluster = cluster_labels == infected_cluster mask_uninf_cluster = cluster_labels == uninfected_cluster frac_inf_infected_cluster = ( float(y_orig[mask_inf_cluster].mean()) if mask_inf_cluster.any() else 0.0 ) frac_inf_uninfected_cluster = ( float(y_orig[mask_uninf_cluster].mean()) if mask_uninf_cluster.any() else 0.0 ) # Objective: distance * GT separation score = centroid_distance * gt_sep_score return { "score": score, "infected_cluster": infected_cluster, "uninfected_cluster": uninfected_cluster, "gt_sep_score": gt_sep_score, "silhouette_score": sil, "centroid_distance": centroid_distance, "frac_inf_infected_cluster": frac_inf_infected_cluster, "frac_inf_uninfected_cluster": frac_inf_uninfected_cluster, } # ------------------------------------------------------------------ # Helper: UMAP with hyperparameter search # ------------------------------------------------------------------ def _search_umap(X_scaled, y_orig, gt_uninf, gt_inf, settings_local): if umap is None: raise RuntimeError("umap-learn is not installed.") random_state = int(settings_local.get("infection_pca_random_state", 0)) do_search = bool(settings_local.get("infection_pca_umap_search", True)) # No search: single run with configured/default params if not do_search: n_neighbors = int(settings_local.get("infection_pca_umap_n_neighbors", 15)) min_dist = float(settings_local.get("infection_pca_umap_min_dist", 0.1)) reducer = umap.UMAP( n_components=2, random_state=random_state, n_neighbors=n_neighbors, min_dist=min_dist, ) coords = reducer.fit_transform(X_scaled) kmeans = KMeans( n_clusters=2, random_state=random_state, n_init="auto" ) cluster_labels = kmeans.fit_predict(coords) stats = _evaluate_embedding(coords, cluster_labels, y_orig, gt_uninf, gt_inf) return coords, cluster_labels, stats, {"n_neighbors": n_neighbors, "min_dist": min_dist} # With search: small grid over n_neighbors and min_dist nn_grid = settings_local.get( "infection_pca_umap_n_neighbors_grid", [5, 10, 15, 30] ) md_grid = settings_local.get( "infection_pca_umap_min_dist_grid", [0.0, 0.05, 0.1, 0.3] ) best = None for nn in nn_grid: for md in md_grid: try: reducer = umap.UMAP( n_components=2, random_state=random_state, n_neighbors=int(nn), min_dist=float(md), ) coords = reducer.fit_transform(X_scaled) kmeans = KMeans( n_clusters=2, random_state=random_state, n_init="auto" ) cluster_labels = kmeans.fit_predict(coords) stats = _evaluate_embedding( coords, cluster_labels, y_orig, gt_uninf, gt_inf ) if (best is None) or (stats["score"] > best["stats"]["score"]): best = { "coords": coords, "cluster_labels": cluster_labels, "stats": stats, "params": {"n_neighbors": int(nn), "min_dist": float(md)}, } except Exception as e: print( f"[infection_intensity_qc:PCA] UMAP trial failed for " f"n_neighbors={nn}, min_dist={md}: {e}" ) continue if best is None: raise RuntimeError("UMAP hyperparameter search failed for all trials.") return best["coords"], best["cluster_labels"], best["stats"], best["params"] # ------------------------------------------------------------------ # Helper: t-SNE with hyperparameter search # ------------------------------------------------------------------ def _search_tsne(X_scaled, y_orig, gt_uninf, gt_inf, settings_local): if TSNE is None: raise RuntimeError("sklearn.manifold.TSNE is not available.") random_state = int(settings_local.get("infection_pca_random_state", 0)) do_search = bool(settings_local.get("infection_pca_tsne_search", True)) n_samples = X_scaled.shape[0] max_perp = max(5.0, (n_samples - 1) / 3.0) # Utility: run one t-SNE def _run_tsne(perplexity, learning_rate): tsne = TSNE( n_components=2, random_state=random_state, init="pca", learning_rate=learning_rate, perplexity=perplexity, ) coords_ = tsne.fit_transform(X_scaled) kmeans_ = KMeans( n_clusters=2, random_state=random_state, n_init="auto" ) cluster_labels_ = kmeans_.fit_predict(coords_) stats_ = _evaluate_embedding( coords_, cluster_labels_, y_orig, gt_uninf, gt_inf ) return coords_, cluster_labels_, stats_ # No search: single run with configured/default params if not do_search: base_perp = float(settings_local.get("infection_pca_tsne_perplexity", 30.0)) perplexity = min(base_perp, max_perp) if perplexity <= 0: perplexity = max_perp coords, cluster_labels, stats = _run_tsne(perplexity, learning_rate="auto") return coords, cluster_labels, stats, {"perplexity": perplexity, "learning_rate": "auto"} # With search: grid over perplexity and learning_rate perp_grid = settings_local.get( "infection_pca_tsne_perplexity_grid", [15.0, 30.0, 45.0] ) lr_grid = settings_local.get( "infection_pca_tsne_learning_rate_grid", [200.0, 500.0] ) perp_candidates = [ float(p) for p in perp_grid if float(p) < max_perp and float(p) > 0 ] if not perp_candidates: perp_candidates = [min(30.0, max_perp)] best = None for perp in perp_candidates: for lr in lr_grid: try: coords, cluster_labels, stats = _run_tsne(perp, float(lr)) if (best is None) or (stats["score"] > best["stats"]["score"]): best = { "coords": coords, "cluster_labels": cluster_labels, "stats": stats, "params": {"perplexity": float(perp), "learning_rate": float(lr)}, } except Exception as e: print( f"[infection_intensity_qc:PCA] t-SNE trial failed for " f"perplexity={perp}, learning_rate={lr}: {e}" ) continue if best is None: raise RuntimeError("t-SNE hyperparameter search failed for all trials.") return best["coords"], best["cluster_labels"], best["stats"], best["params"] # ------------------------------------------------------------------ # Main body # ------------------------------------------------------------------ if all_df.empty: print("[infection_intensity_qc:PCA] all_df is empty; skipping embedding QC.") return all_df, infection_col if infection_col not in all_df.columns: print( f"[infection_intensity_qc:PCA] infection_col {infection_col!r} missing; " "skipping embedding QC." ) return all_df, infection_col mode = str(settings.get("infection_intensity_mode", "relabel")).lower() if mode not in {"relabel", "remove"}: print( f"[infection_intensity_qc:PCA] Unsupported mode={mode!r}; " "expected 'relabel' or 'remove'. Skipping embedding QC." ) return all_df, infection_col # ------------------------ # 🔴 Key change is here 🔴 # ------------------------ # Use infection_intensity_strategy to define embedding method strategy = str(settings.get("infection_intensity_strategy", "pca")).lower() if strategy in {"pca", "umap", "tsne"}: embed_method = strategy else: embed_method = "pca" # keep settings in sync so downstream code can use this if needed settings["infection_pca_method"] = embed_method if embed_method not in {"pca", "umap", "tsne"}: embed_method = "pca" key_cols = ["plateID", "wellID", "fieldID", "cellID"] for col in key_cols: if col not in all_df.columns: raise KeyError( f"[infection_intensity_qc:PCA] Required column {col!r} not in all_df." ) # Drop any existing adjusted_infected to avoid _x/_y columns on merge cols_to_drop = [ c for c in all_df.columns if c == "adjusted_infected" or c.startswith("adjusted_infected_") ] if cols_to_drop: all_df = all_df.drop(columns=cols_to_drop) # ------------------------------------------------------------------ # Build per-cell feature table # ------------------------------------------------------------------ numeric_cols = [ c for c in all_df.columns if c.startswith("cell_") and c not in {"cellID"} and pd.api.types.is_numeric_dtype(all_df[c]) ] if not numeric_cols: print( "[infection_intensity_qc:PCA] No numeric cell_* features found; " "skipping embedding QC." ) return all_df, infection_col cols_for_group = key_cols + numeric_cols + [infection_col] tmp = all_df[cols_for_group].copy() tmp.replace([np.inf, -np.inf], np.nan, inplace=True) group = tmp.groupby(key_cols, observed=True) cell_level = group[numeric_cols].median(numeric_only=True).reset_index() # any cell that was ever infected in the time series is treated as infected inf_any = group[infection_col].max().reset_index() cell_level = cell_level.merge(inf_any, on=key_cols, how="left", suffixes=("", "_y")) if infection_col not in cell_level.columns: for cand in (f"{infection_col}_y", f"{infection_col}_x"): if cand in cell_level.columns: cell_level[infection_col] = cell_level[cand] break if infection_col not in cell_level.columns: print( f"[infection_intensity_qc:PCA] Could not recover infection_col={infection_col!r} " "after aggregation; skipping embedding QC." ) return all_df, infection_col cell_level[infection_col] = cell_level[infection_col].fillna(0).astype(bool) # ------------------------------------------------------------------ # Decide pathogen-channel intensity column (needed for ground truth) # ------------------------------------------------------------------ intensity_col = None if pathogen_chan is not None: cand_int = [ f"cell_p95_intensity_ch{pathogen_chan}", f"cell_max_intensity_ch{pathogen_chan}", f"cell_mean_intensity_ch{pathogen_chan}", ] for c in cand_int: if c in cell_level.columns: intensity_col = c break if intensity_col is None: print( "[infection_intensity_qc:PCA] No pathogen-channel cell_* intensity column " "found; skipping embedding QC." ) return all_df, infection_col # ------------------------------------------------------------------ # Select morphology + pathogen-channel features # - morphology: cell_* columns without 'ch' (no per-channel intensity) # - pathogen: cell_* columns that mention ch{pathogen_chan} # ------------------------------------------------------------------ morph_cols = [ c for c in numeric_cols if c.startswith("cell_") and ("ch" not in c.lower()) ] path_cols = [ c for c in numeric_cols if c.startswith("cell_") and f"ch{pathogen_chan}" in c.lower() ] feature_cols = sorted(set(morph_cols + path_cols)) if intensity_col not in feature_cols and intensity_col in cell_level.columns: feature_cols.append(intensity_col) # Drop degenerate features clean_feature_cols = [] for c in feature_cols: s = cell_level[c] if s.notna().sum() < 10: continue if s.nunique(dropna=True) <= 1: continue clean_feature_cols.append(c) feature_cols = clean_feature_cols if not feature_cols: print( "[infection_intensity_qc:PCA] No usable morphology + pathogen features; " "skipping embedding QC." ) return all_df, infection_col # ------------------------------------------------------------------ # Prepare feature matrix + ground-truth subsets # ------------------------------------------------------------------ # Optional log1p transform on intensity-like features to sharpen structure log_intensity = bool(settings.get("infection_pca_log_intensity", True)) cell_for_X = cell_level.copy() if log_intensity: for c in feature_cols: cl = c.lower() if ("intensity" in cl) or ("p75" in cl) or ("p95" in cl) or ("max" in cl): vals = cell_for_X[c].to_numpy(dtype=float) finite = np.isfinite(vals) if finite.any() and np.nanmin(vals[finite]) >= 0: vals[finite] = np.log1p(vals[finite]) cell_for_X[c] = vals X = cell_for_X[feature_cols].to_numpy(dtype=float) y_orig = cell_level[infection_col].astype(bool).to_numpy() # Remove rows with all NaNs finite_counts = np.isfinite(X).sum(axis=1) mask_rows = finite_counts > 0 if not mask_rows.any(): print( "[infection_intensity_qc:PCA] No rows with finite features; " "skipping embedding QC." ) return all_df, infection_col X = X[mask_rows] cell_level = cell_level.loc[mask_rows].reset_index(drop=True) y_orig = y_orig[mask_rows] # Median imputation per feature for j in range(X.shape[1]): col = X[:, j] m = np.isfinite(col) if not m.any(): X[:, j] = 0.0 else: med = np.nanmedian(col[m]) col[~m] = med X[:, j] = col if X.shape[0] < 10: print( "[infection_intensity_qc:PCA] Fewer than 10 cells after filtering; " "skipping embedding QC." ) return all_df, infection_col # Optional subsampling for speed max_cells = int(settings.get("infection_pca_max_cells", 50000)) if X.shape[0] > max_cells: rng = np.random.default_rng(0) idx = rng.choice(np.arange(X.shape[0]), size=max_cells, replace=False) X = X[idx] cell_level = cell_level.iloc[idx].reset_index(drop=True) y_orig = y_orig[idx] # ------------------------------------------------------------------ # Build intensity-based ground-truth subsets # ------------------------------------------------------------------ intens = cell_level[intensity_col].to_numpy(dtype=float) mask_finite_int = np.isfinite(intens) intens = intens[mask_finite_int] y_int = y_orig[mask_finite_int] if intens.size < 40 or np.sum(y_int) < 10 or np.sum(~y_int) < 10: print( "[infection_intensity_qc:PCA] Not enough cells with finite intensity in " "both infected/uninfected for ground-truth definition; " "skipping embedding QC." ) return all_df, infection_col inf_vals = intens[y_int] uninf_vals = intens[~y_int] if inf_vals.size < 10 or uninf_vals.size < 10: print( "[infection_intensity_qc:PCA] Too few intensity values per class; " "skipping embedding QC." ) return all_df, infection_col thr_uninf = float(np.nanpercentile(uninf_vals, 25.0)) thr_inf = float(np.nanpercentile(inf_vals, 75.0)) # Boolean masks in the full (post-subsample) cell_level intens_full = cell_level[intensity_col].to_numpy(dtype=float) mask_finite_full = np.isfinite(intens_full) gt_uninf = mask_finite_full & (~y_orig) & (intens_full <= thr_uninf) gt_inf = mask_finite_full & (y_orig) & (intens_full >= thr_inf) n_gt_uninf = int(gt_uninf.sum()) n_gt_inf = int(gt_inf.sum()) print( "[infection_intensity_qc:PCA] Ground-truth sets: " f"uninfected_gt={n_gt_uninf}, infected_gt={n_gt_inf} " f"(thr_uninf={thr_uninf:.3f}, thr_inf={thr_inf:.3f})." ) if n_gt_uninf < 10 or n_gt_inf < 10: print( "[infection_intensity_qc:PCA] Very small ground-truth subsets; " "embedding QC may be unstable." ) # ------------------------------------------------------------------ # Embedding (PCA / UMAP / t-SNE) with optional hyperparameter search # ------------------------------------------------------------------ scaler = StandardScaler() X_scaled = scaler.fit_transform(X) # Optional: up-weight pathogen-channel features to emphasize infection signal path_weight = float(settings.get("infection_pca_pathogen_weight", 1.0)) if path_weight != 1.0 and path_cols: path_idx = [feature_cols.index(c) for c in feature_cols if c in path_cols] if path_idx: X_scaled[:, path_idx] *= path_weight random_state = int(settings.get("infection_pca_random_state", 0)) method_label = "PCA" embedding_params = {} coords = None cluster_labels = None eval_stats = None if embed_method == "umap" and umap is not None: coords, cluster_labels, eval_stats, embedding_params = _search_umap( X_scaled, y_orig, gt_uninf, gt_inf, settings ) method_label = "UMAP" print( "[infection_intensity_qc:PCA] UMAP best params: " f"{embedding_params}, score={eval_stats['score']:.4f}" ) elif embed_method == "tsne" and TSNE is not None: coords, cluster_labels, eval_stats, embedding_params = _search_tsne( X_scaled, y_orig, gt_uninf, gt_inf, settings ) method_label = "t-SNE" print( "[infection_intensity_qc:PCA] t-SNE best params: " f"{embedding_params}, score={eval_stats['score']:.4f}" ) else: # PCA (no hyperparameter search, but benefits from log/weighting above) if embed_method in {"umap", "tsne"}: print( f"[infection_intensity_qc:PCA] Requested method={embed_method!r} " "not available; falling back to PCA." ) pca = PCA( n_components=2, random_state=random_state, ) coords = pca.fit_transform(X_scaled) kmeans = KMeans( n_clusters=2, random_state=random_state, n_init="auto", ) cluster_labels = kmeans.fit_predict(coords) eval_stats = _evaluate_embedding(coords, cluster_labels, y_orig, gt_uninf, gt_inf) method_label = "PCA" embedding_params = {} # Unpack evaluation stats infected_cluster = int(eval_stats["infected_cluster"]) uninfected_cluster = int(eval_stats["uninfected_cluster"]) gt_sep_score = float(eval_stats["gt_sep_score"]) sil_score = eval_stats["silhouette_score"] centroid_distance = float(eval_stats["centroid_distance"]) frac_inf_infected_cluster = float(eval_stats["frac_inf_infected_cluster"]) frac_inf_uninfected_cluster = float(eval_stats["frac_inf_uninfected_cluster"]) min_gt_sep = float(settings.get("infection_pca_min_gt_separation", 0.2)) min_sil = float(settings.get("infection_pca_min_silhouette", 0.05)) if gt_sep_score < min_gt_sep or (sil_score is not None and sil_score < min_sil): print( "[infection_intensity_qc:PCA] WARNING: weak cluster structure " f"(gt_sep_score={gt_sep_score:.3f}, silhouette={sil_score}). " "To improve separation you can try:\n" " - Tightening infection ground-truth thresholds (e.g. more extreme percentiles)\n" " - Reducing noise features, especially non-morphology/non-pathogen\n" " - Adjusting UMAP/t-SNE grids to favor more local structure\n" " - Increasing infection_pca_pathogen_weight to emphasize pathogen features." ) print( "[infection_intensity_qc:PCA] Cluster infected fractions (original labels): " f"infected_cluster={frac_inf_infected_cluster:.3f}, " f"uninfected_cluster={frac_inf_uninfected_cluster:.3f}, " f"centroid_distance={centroid_distance:.3f}, gt_sep={gt_sep_score:.3f}." ) # ------------------------------------------------------------------ # Build cluster-based infection call # ------------------------------------------------------------------ cluster_infected = (cluster_labels == infected_cluster) removed_ids = set() if mode == "relabel": # labels follow cluster adjusted = cluster_infected.astype(bool) n_changed = int((adjusted != y_orig).sum()) print( "[infection_intensity_qc:PCA] Relabel mode: adjusted infection labels for " f"{n_changed} cells based on {method_label} clusters." ) cell_level["adjusted_infected"] = adjusted.astype(bool) else: # mode == "remove" consistent = cluster_infected == y_orig to_remove = ~consistent if to_remove.any(): removed = cell_level.loc[to_remove, key_cols] removed_ids = { (r["plateID"], r["wellID"], r["fieldID"], r["cellID"]) for _, r in removed.iterrows() } cell_level = cell_level.loc[consistent].copy() cluster_infected = cluster_infected[consistent] y_orig = y_orig[consistent] coords = coords[consistent] cluster_labels = cluster_labels[consistent] print( "[infection_intensity_qc:PCA] Remove mode: removed " f"{len(removed_ids)} cells with cluster vs label disagreement." ) cell_level["adjusted_infected"] = y_orig.astype(bool) # ------------------------------------------------------------------ # Map adjusted infection back to all_df (frame level) # ------------------------------------------------------------------ # Ensure key dtypes match for col in key_cols: all_df[col] = all_df[col].astype(cell_level[col].dtype) all_df = all_df.merge( cell_level[key_cols + ["adjusted_infected"]], on=key_cols, how="left", validate="m:1", ) if removed_ids: mask_drop = all_df.apply( lambda r: (r["plateID"], r["wellID"], r["fieldID"], r["cellID"]) in removed_ids, axis=1, ) all_df = all_df.loc[~mask_drop].reset_index(drop=True) # Any rows that did not get an adjusted label inherit the original mask_missing = all_df["adjusted_infected"].isna() if mask_missing.any(): all_df.loc[mask_missing, "adjusted_infected"] = ( all_df.loc[mask_missing, infection_col].astype(bool) ) all_df["adjusted_infected"] = all_df["adjusted_infected"].astype(bool) infection_col = "adjusted_infected" # ------------------------------------------------------------------ # Store payload for combined panel # ------------------------------------------------------------------ settings["infection_pca_data"] = { "coords": coords, "labels": cell_level["adjusted_infected"].astype(bool).to_numpy(), "cluster_labels": cluster_labels, "method_label": method_label, # <- used for axis titles / panel labels "infected_cluster": int(infected_cluster), "uninfected_cluster": int(uninfected_cluster), "initial_infected_frac_infected_cluster": frac_inf_infected_cluster, "initial_infected_frac_uninfected_cluster": frac_inf_uninfected_cluster, "gt_sep_score": gt_sep_score, "silhouette_score": sil_score, "centroid_distance": centroid_distance, "embedding_params": embedding_params, } settings["infection_intensity_qc_panel_type"] = "pca" settings["infection_intensity_qc_panel_path"] = None # ------------------------------------------------------------------ # Optional debug plot # ------------------------------------------------------------------ try: if motility_dir is not None: import matplotlib.pyplot as plt os.makedirs(motility_dir, exist_ok=True) fig, ax = plt.subplots(figsize=(4, 4)) # Masks for remaining cells mask_uninf_cluster_plot = cluster_labels == uninfected_cluster mask_inf_cluster_plot = cluster_labels == infected_cluster # Plot clusters with transparency and filled markers ax.scatter( coords[mask_uninf_cluster_plot, 0], coords[mask_uninf_cluster_plot, 1], s=2, alpha=0.6, color="green", label=( f"Uninfected cluster " f"({frac_inf_uninfected_cluster*100:.1f}% infected at start)" ), ) ax.scatter( coords[mask_inf_cluster_plot, 0], coords[mask_inf_cluster_plot, 1], s=2, alpha=0.6, color="red", label=( f"Infected cluster " f"({frac_inf_infected_cluster*100:.1f}% infected at start)" ), ) # Axis titles and main title reflect method ax.set_xlabel(f"{method_label} 1") ax.set_ylabel(f"{method_label} 2") title = f"{method_label} infection QC" if embedding_params: param_str = ", ".join( f"{k}={v}" for k, v in embedding_params.items() ) title += f"\n{param_str}" if sil_score is not None: title += f"\nGT-sep={gt_sep_score:.2f}, sil={sil_score:.2f}" else: title += f"\nGT-sep={gt_sep_score:.2f}" ax.set_title(title) ax.legend(fontsize=7, loc="best") out_png = os.path.join( motility_dir, f"infection_{embed_method}_qc_embedding.png" ) fig.savefig(out_png, dpi=150, bbox_inches="tight") plt.close(fig) except Exception as e: print(f"[infection_intensity_qc:PCA] Failed to save embedding QC plot: {e}") return all_df, infection_col def _apply_infection_intensity_qc( all_df, settings, infection_col, pathogen_chan, motility_dir, ): """ Dispatch to different infection QC strategies based on settings['infection_intensity_strategy'] and settings['infection_intensity_qc_scope']. Strategies ---------- 'histogram' / 'hist' / 'histagram' 1D intensity histogram thresholding 'pca' / 'umap' / 'tsne' PCA/UMAP/TSNE + clustering (via _infection_qc_pca_clustering) 'xgboost' / 'xgb' Supervised XGBoost classifier on extreme intensities Scope ----- settings['infection_intensity_qc_scope']: 'combined' (default) Run QC once on all_df (old behaviour). 'plate' Run QC separately per plateID. 'well' Run QC separately per (plateID, wellID). 'none' / 'off' Skip QC entirely and return original labels. If settings['infection_intensity_qc'] is False or pathogen_chan is None, this function is a no-op and returns the input as-is. Returns ------- all_df : DataFrame Frame-level measurements with possibly updated 'infection_col'. infection_col : str Name of the column in all_df that encodes the (possibly adjusted) infection status. """ import os import numpy as np import pandas as pd # Reset QC payloads by default; strategy helpers will overwrite if used settings["infection_hist_data"] = None settings["infection_pca_data"] = None settings["infection_xgb_importance"] = None settings["infection_intensity_qc_panel_type"] = None settings["infection_intensity_qc_panel_path"] = None # If QC is disabled or there is no pathogen channel, do nothing. infection_intensity_qc = bool(settings.get("infection_intensity_qc", False)) if (not infection_intensity_qc) or (pathogen_chan is None): print("[infection_intensity_qc] QC disabled or no pathogen channel; skipping.") return all_df, infection_col # Make sure output directory exists for plots os.makedirs(motility_dir, exist_ok=True) strategy = str(settings.get("infection_intensity_strategy", "histogram")).lower() # Strategy → QC helper if strategy in {"hist", "histogram", "histagram"}: qc_func = _infection_qc_histogram elif strategy in {"xgboost", "xgb"}: qc_func = _infection_qc_xgboost elif strategy in {"pca", "umap", "tsne"}: qc_func = _infection_qc_pca_clustering else: print( "[infection_intensity_qc] Unknown strategy " f"{strategy!r}; falling back to 'histogram'." ) qc_func = _infection_qc_histogram # Scope: combined (default), per-plate, per-well, or none scope = str(settings.get("infection_intensity_qc_scope", "combined") or "combined").lower() if scope in {"none", "off"}: # Explicit request to skip QC return all_df, infection_col if scope in {"combined", "global", "all"}: # --- old behaviour: single global run --- local_settings = dict(settings) # shallow copy for QC helper df_qc, inf_col_out = qc_func( all_df=all_df, settings=local_settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) # propagate QC payloads back settings["infection_hist_data"] = local_settings.get("infection_hist_data") settings["infection_pca_data"] = local_settings.get("infection_pca_data") settings["infection_xgb_importance"] = local_settings.get("infection_xgb_importance") settings["infection_intensity_qc_panel_type"] = local_settings.get( "infection_intensity_qc_panel_type" ) settings["infection_intensity_qc_panel_path"] = local_settings.get( "infection_intensity_qc_panel_path" ) # normalise adjusted_infected if present if "adjusted_infected" in df_qc.columns: if df_qc["adjusted_infected"].isna().any(): df_qc["adjusted_infected"] = df_qc["adjusted_infected"].fillna( df_qc[infection_col] ) try: df_qc["adjusted_infected"] = df_qc["adjusted_infected"].astype(int) except Exception: df_qc["adjusted_infected"] = df_qc["adjusted_infected"].astype(bool) inf_col_out = "adjusted_infected" return df_qc, inf_col_out # Grouped scopes if scope in {"plate", "per_plate", "plateid"}: group_cols = ["plateID"] elif scope in {"well", "per_well"}: group_cols = ["plateID", "wellID"] else: print( f"[_apply_infection_intensity_qc] Unknown scope={scope!r}; " "using 'combined' behaviour." ) local_settings = dict(settings) df_qc, inf_col_out = qc_func( all_df=all_df, settings=local_settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) settings["infection_hist_data"] = local_settings.get("infection_hist_data") settings["infection_pca_data"] = local_settings.get("infection_pca_data") settings["infection_xgb_importance"] = local_settings.get("infection_xgb_importance") settings["infection_intensity_qc_panel_type"] = local_settings.get( "infection_intensity_qc_panel_type" ) settings["infection_intensity_qc_panel_path"] = local_settings.get( "infection_intensity_qc_panel_path" ) if "adjusted_infected" in df_qc.columns: if df_qc["adjusted_infected"].isna().any(): df_qc["adjusted_infected"] = df_qc["adjusted_infected"].fillna( df_qc[infection_col] ) try: df_qc["adjusted_infected"] = df_qc["adjusted_infected"].astype(int) except Exception: df_qc["adjusted_infected"] = df_qc["adjusted_infected"].astype(bool) inf_col_out = "adjusted_infected" return df_qc, inf_col_out # If requested group columns are missing, fall back to combined if not set(group_cols).issubset(all_df.columns): print( f"[_apply_infection_intensity_qc] Requested scope={scope!r} but " f"missing grouping columns {group_cols}; falling back to combined QC." ) local_settings = dict(settings) df_qc, inf_col_out = qc_func( all_df=all_df, settings=local_settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) settings["infection_hist_data"] = local_settings.get("infection_hist_data") settings["infection_pca_data"] = local_settings.get("infection_pca_data") settings["infection_xgb_importance"] = local_settings.get("infection_xgb_importance") settings["infection_intensity_qc_panel_type"] = local_settings.get( "infection_intensity_qc_panel_type" ) settings["infection_intensity_qc_panel_path"] = local_settings.get( "infection_intensity_qc_panel_path" ) if "adjusted_infected" in df_qc.columns: if df_qc["adjusted_infected"].isna().any(): df_qc["adjusted_infected"] = df_qc["adjusted_infected"].fillna( df_qc[infection_col] ) try: df_qc["adjusted_infected"] = df_qc["adjusted_infected"].astype(int) except Exception: df_qc["adjusted_infected"] = df_qc["adjusted_infected"].astype(bool) inf_col_out = "adjusted_infected" return df_qc, inf_col_out # ------------------------------------------------------------------ # Group-wise QC: per-plate or per-well # ------------------------------------------------------------------ parts = [] any_adjusted = False first_payload_settings = None for g_key, df_group in all_df.groupby(group_cols, sort=False): if df_group.empty: continue local_settings = dict(settings) df_group_qc, inf_col_group = qc_func( all_df=df_group, settings=local_settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) if "adjusted_infected" in df_group_qc.columns and df_group_qc["adjusted_infected"].notna().any(): any_adjusted = True if first_payload_settings is None: first_payload_settings = local_settings parts.append(df_group_qc) if not parts: # nothing processed → return original return all_df, infection_col all_df_qc = pd.concat(parts, axis=0, ignore_index=True) # Use QC payloads (histogram/PCA/XGB) from the first processed group if first_payload_settings is not None: settings["infection_hist_data"] = first_payload_settings.get("infection_hist_data") settings["infection_pca_data"] = first_payload_settings.get("infection_pca_data") settings["infection_xgb_importance"] = first_payload_settings.get("infection_xgb_importance") settings["infection_intensity_qc_panel_type"] = first_payload_settings.get( "infection_intensity_qc_panel_type" ) settings["infection_intensity_qc_panel_path"] = first_payload_settings.get( "infection_intensity_qc_panel_path" ) if any_adjusted and "adjusted_infected" in all_df_qc.columns: if all_df_qc["adjusted_infected"].isna().any(): all_df_qc["adjusted_infected"] = all_df_qc["adjusted_infected"].fillna( all_df_qc[infection_col] ) try: all_df_qc["adjusted_infected"] = all_df_qc["adjusted_infected"].astype(int) except Exception: all_df_qc["adjusted_infected"] = all_df_qc["adjusted_infected"].astype(bool) infection_col_out = "adjusted_infected" else: infection_col_out = infection_col return all_df_qc, infection_col_out def _compute_velocities_and_well_summary( all_df, settings, infection_col, pixels_per_um, seconds_per_frame, ): """ Compute per-track velocities, straightness and per-well motility summary. Returns ------- track_df : DataFrame per_well_tracks : dict[(plateID, wellID) -> list of track dicts] well_summary_df : DataFrame vel_unit : str """ import numpy as np import pandas as pd y_col = "cell_centroid-0" x_col = "cell_centroid-1" track_df = pd.DataFrame() well_summary_df = pd.DataFrame() per_well_tracks = {} vel_unit = "px/frame" if y_col not in all_df.columns or x_col not in all_df.columns: print( "[summarise_tracks_from_merged] Centroid columns missing; " "motility summary and plots will not be generated." ) return track_df, per_well_tracks, well_summary_df, vel_unit gtracks = all_df.groupby(["plateID", "wellID", "fieldID", "cellID"]) track_records = [] for (plateID, wellID, fieldID, cellID), g in gtracks: g = g.sort_values("frame") x_px = g[x_col].to_numpy(dtype=float) y_px = g[y_col].to_numpy(dtype=float) if len(x_px) < 2: continue dx = np.diff(x_px) dy = np.diff(y_px) d = np.hypot(dx, dy) if d.size == 0 or not np.isfinite(d).any(): continue v_px = float(np.nanmean(d)) path_length = float(np.nansum(d)) net_dx = float(x_px[-1] - x_px[0]) net_dy = float(y_px[-1] - y_px[0]) net_disp = float(np.hypot(net_dx, net_dy)) if path_length > 0 and np.isfinite(net_disp): straightness = net_disp / path_length else: straightness = np.nan infected_track = bool(g[infection_col].any()) track_records.append( { "plateID": plateID, "wellID": wellID, "fieldID": fieldID, "cellID": cellID, "infected": infected_track, "v_px_per_frame": v_px, "straightness": straightness, } ) key_well = (plateID, wellID) per_well_tracks.setdefault(key_well, []).append( { "plateID": plateID, "wellID": wellID, "fieldID": fieldID, "cellID": cellID, "infected": infected_track, "x_px": x_px, "y_px": y_px, "v_px_per_frame": v_px, "straightness": straightness, } ) if not track_records: print( "[summarise_tracks_from_merged] No tracks with >=2 frames; " "skipping motility summary and plots." ) return track_df, per_well_tracks, well_summary_df, vel_unit track_df = pd.DataFrame(track_records) use_physical_units = ( pixels_per_um is not None and seconds_per_frame is not None ) if use_physical_units: pixels_per_um = float(pixels_per_um) seconds_per_frame = float(seconds_per_frame) factor = (1.0 / pixels_per_um) * (60.0 / seconds_per_frame) vel_unit = "µm/min" else: factor = 1.0 vel_unit = "px/frame" track_df["velocity"] = track_df["v_px_per_frame"] * factor track_df["velocity_unit"] = vel_unit # Straightness-based artifact detection / filtering if "straightness" in track_df.columns: straightness_threshold = float( settings.get("straightness_threshold", 0.95) ) straightness_filter = bool(settings.get("straightness_filter", False)) n_tracks_before = track_df.shape[0] n_high = int((track_df["straightness"] >= straightness_threshold).sum()) print( "[summarise_tracks_from_merged] Straightness metric: " f"{n_high} of {n_tracks_before} tracks have straightness " f">= {straightness_threshold:.2f} " "(net displacement / path length)." ) if straightness_filter and n_high > 0: drop_mask = track_df["straightness"] >= straightness_threshold dropped = track_df.loc[ drop_mask, ["plateID", "wellID", "fieldID", "cellID"] ].copy() drop_keys = set( zip( dropped["plateID"], dropped["wellID"], dropped["fieldID"], dropped["cellID"], ) ) track_df = track_df.loc[~drop_mask].reset_index(drop=True) print( "[summarise_tracks_from_merged] Straightness filter " f"removed {n_high} overly straight tracks " f"(threshold={straightness_threshold:.2f})." ) # Filter per_well_tracks accordingly for well_key, track_list in list(per_well_tracks.items()): filtered_list = [ tr for tr in track_list if ( tr["plateID"], tr["wellID"], tr["fieldID"], tr["cellID"], ) not in drop_keys ] if filtered_list: per_well_tracks[well_key] = filtered_list else: del per_well_tracks[well_key] if track_df.empty: print( "[summarise_tracks_from_merged] No tracks left after " "straightness filtering; skipping motility summary and plots." ) return track_df, per_well_tracks, well_summary_df, vel_unit well_records = [] for (plateID, wellID), g in track_df.groupby(["plateID", "wellID"]): n_tracks_well = len(g) n_inf_well = int(g["infected"].sum()) n_uninf_well = n_tracks_well - n_inf_well mean_all = float(g["velocity"].mean()) if n_tracks_well > 0 else np.nan mean_inf = ( float(g.loc[g["infected"], "velocity"].mean()) if n_inf_well > 0 else np.nan ) mean_uninf = ( float(g.loc[~g["infected"], "velocity"].mean()) if n_uninf_well > 0 else np.nan ) well_records.append( dict( plateID=plateID, wellID=wellID, n_tracks=n_tracks_well, n_infected_tracks=n_inf_well, n_uninfected_tracks=n_uninf_well, mean_velocity_all=mean_all, mean_velocity_infected=mean_inf, mean_velocity_uninfected=mean_uninf, velocity_unit=vel_unit, ) ) if well_records: well_summary_df = pd.DataFrame(well_records) print( "[summarise_tracks_from_merged] Computed per-track velocities " f"in units: {vel_unit}" ) return track_df, per_well_tracks, well_summary_df, vel_unit def _save_measurements_and_well_summary( all_df, well_summary_df, src, db_table_name, ): """ Save per-frame measurements and well-level motility summary to SQLite. Returns (measurements_dir, db_path). """ import os import sqlite3 measurements_dir = os.path.join(src, "measurements") os.makedirs(measurements_dir, exist_ok=True) db_path = os.path.join(measurements_dir, "measurements.db") with sqlite3.connect(db_path) as conn: all_df.to_sql(db_table_name, conn, if_exists="replace", index=False) print( f"[summarise_tracks_from_merged] Saved measurements to " f"{db_path} (table='{db_table_name}')" ) if not well_summary_df.empty: well_table_name = db_table_name + "_well_motility" well_summary_df.to_sql( well_table_name, conn, if_exists="replace", index=False, ) print( f"[summarise_tracks_from_merged] Saved well-level motility " f"summary to {db_path} (table='{well_table_name}')" ) else: print( "[summarise_tracks_from_merged] No well-level motility " "summary table was created." ) return measurements_dir, db_path def _feature_velocity_correlations(all_df, track_df, measurements_dir): """ Correlate per-track velocity with median per-track features (all / infected / uninfected). Saves CSV to measurements_dir/velocity_feature_correlations.csv """ import numpy as np import os import pandas as pd if track_df.empty: return try: group_cols = ["plateID", "wellID", "fieldID", "cellID"] numeric_cols = all_df.select_dtypes(include=[np.number]).columns.tolist() for col_rm in ("frame", "timeID", "cellID"): if col_rm in numeric_cols: numeric_cols.remove(col_rm) if not numeric_cols: print( "[summarise_tracks_from_merged] No numeric feature columns " "available for correlation analysis." ) return agg_features = ( all_df[group_cols + numeric_cols] .groupby(group_cols, dropna=False) .median() .reset_index() ) track_features = track_df.merge(agg_features, on=group_cols, how="left") exclude_cols = set( group_cols + ["infected", "v_px_per_frame", "velocity", "velocity_unit"] ) candidate_cols = [ c for c in track_features.columns if c not in exclude_cols and np.issubdtype(track_features[c].dtype, np.number) ] if not candidate_cols: print( "[summarise_tracks_from_merged] No numeric feature columns " "available for correlation analysis." ) return def _corr_subset(mask, label): sub = track_features.loc[mask, candidate_cols + ["velocity"]].copy() sub = sub[np.isfinite(sub["velocity"])] if sub.shape[0] < 5: print( "[summarise_tracks_from_merged] " f"Not enough tracks for correlation ({label})." ) return None corr_series = sub.corr(method="pearson")["velocity"].drop("velocity") corr_df = ( corr_series.rename("pearson_r") .to_frame() .reset_index() .rename(columns={"index": "feature"}) ) corr_df["n_tracks"] = sub.shape[0] corr_df["group"] = label return corr_df mask_all = np.isfinite(track_features["velocity"]) results = [] res_all = _corr_subset(mask_all, "all") if res_all is not None: results.append(res_all) mask_inf = track_features["infected"].astype(bool) & mask_all res_inf = _corr_subset(mask_inf, "infected") if res_inf is not None: results.append(res_inf) mask_uninf = (~track_features["infected"].astype(bool)) & mask_all res_uninf = _corr_subset(mask_uninf, "uninfected") if res_uninf is not None: results.append(res_uninf) if not results: return corr_all = pd.concat(results, ignore_index=True) corr_all["abs_pearson_r"] = corr_all["pearson_r"].abs() corr_all = corr_all.sort_values( ["group", "abs_pearson_r"], ascending=[True, False] ) corr_out = os.path.join(measurements_dir, "velocity_feature_correlations.csv") corr_all.to_csv(corr_out, index=False) print( "[summarise_tracks_from_merged] Saved velocity–feature " f"correlations to {corr_out}" ) except Exception as e: print( "[summarise_tracks_from_merged] Feature–velocity correlation " f"analysis failed with error: {e}" ) def _make_intensity_sanity_plots(all_df, infection_col, n_channels, motility_dir): """ Per-channel intensity sanity-check plots (infected vs uninfected). """ import numpy as np import os import matplotlib.pyplot as plt if all_df.empty: return keys = ["plateID", "wellID", "fieldID", "cellID"] os.makedirs(motility_dir, exist_ok=True) for ch in range(n_channels): col_int = f"cell_mean_intensity_ch{ch}" if col_int not in all_df.columns: continue cell_level_int = ( all_df[keys + [col_int, infection_col]] .groupby(keys, dropna=False) .agg( { col_int: "mean", infection_col: "max", } ) .reset_index() ) cell_level_int = cell_level_int.replace([np.inf, -np.inf], np.nan) cell_level_int = cell_level_int.dropna(subset=[col_int]) if cell_level_int.empty: print( f"[summarise_tracks_from_merged] No data for intensity " f"channel {ch}, skipping sanity plot." ) continue mask_inf = cell_level_int[infection_col].astype(bool) vals_inf = cell_level_int.loc[mask_inf, col_int].to_numpy() vals_uninf = cell_level_int.loc[~mask_inf, col_int].to_numpy() mean_inf = float(np.nanmean(vals_inf)) if vals_inf.size else np.nan std_inf = ( float(np.nanstd(vals_inf, ddof=1)) if vals_inf.size > 1 else np.nan ) mean_uninf = float(np.nanmean(vals_uninf)) if vals_uninf.size else np.nan std_uninf = ( float(np.nanstd(vals_uninf, ddof=1)) if vals_uninf.size > 1 else np.nan ) x_pos = np.arange(2) heights = [mean_inf, mean_uninf] errors = [std_inf, std_uninf] fig_ch, ax_ch = plt.subplots(figsize=(4, 4)) ax_ch.bar( x_pos, heights, yerr=errors, capsize=5, color=["red", "green"], alpha=0.7, ) ax_ch.set_xticks(x_pos) ax_ch.set_xticklabels(["Infected", "Uninfected"]) ax_ch.set_ylabel(f"Mean cell intensity (channel {ch})") ax_ch.set_title(f"Intensity vs infection – channel {ch}") ax_ch.set_ylim(bottom=0) plt.tight_layout() out_ch = os.path.join( motility_dir, f"intensity_channel{ch}_infected_vs_uninfected.png" ) fig_ch.savefig(out_ch, dpi=300) plt.close(fig_ch) print( f"[summarise_tracks_from_merged] Saved intensity sanity plot " f"for channel {ch} to {out_ch}" ) def _make_motility_plots( track_df, per_well_tracks, well_summary_df, motility_dir, pixels_per_um, seconds_per_frame, vel_unit, settings, ): """ Motility plots (combined + per-well) with compact text box. Axis control via settings: - motility_xlim / motility_ylim: applied to absolute-coordinate plots - motility_origin_xlim / motility_origin_ylim: applied to origin plots """ import numpy as np import os import matplotlib.pyplot as plt from matplotlib import patches if track_df.empty or not per_well_tracks: print( "[summarise_tracks_from_merged] No per-track velocities available; " "motility plots were not generated." ) return def _fmt_vel(val): return "n/a" if not np.isfinite(val) else f"{val:.2f}" def _apply_axis_limits(ax, xlim, ylim): if xlim is not None and len(xlim) == 2: ax.set_xlim(float(xlim[0]), float(xlim[1])) if ylim is not None and len(ylim) == 2: ax.set_ylim(float(ylim[0]), float(ylim[1])) abs_xlim = settings.get("motility_xlim", None) abs_ylim = settings.get("motility_ylim", None) origin_xlim = settings.get("motility_origin_xlim", None) origin_ylim = settings.get("motility_origin_ylim", None) if pixels_per_um is not None: unit_line1 = f"1 µm = {float(pixels_per_um):.2f} px" coord_label_x = "x (µm)" coord_label_y = "y (µm)" coord_scale = 1.0 / float(pixels_per_um) else: unit_line1 = "1 µm = ? px" coord_label_x = "x (pixels)" coord_label_y = "y (pixels)" coord_scale = 1.0 if seconds_per_frame is not None: unit_line2 = f"1 frame = {float(seconds_per_frame):g} s" else: unit_line2 = "1 frame = ? s" box_x0 = 0.64 box_y0 = 0.69 box_width = 0.30 box_height = 0.23 text_x = box_x0 + 0.02 y_top = box_y0 + box_height - 0.03 line_spacing = 0.07 fontsize_main = 8 fontsize_units = 7 os.makedirs(motility_dir, exist_ok=True) # Combined plot over all wells fig_all, ax_all = plt.subplots(figsize=(6, 6)) for tracks in per_well_tracks.values(): for tr in tracks: x = tr["x_px"] * coord_scale y = tr["y_px"] * coord_scale infected_track = tr["infected"] color = "red" if infected_track else "green" ax_all.plot(x, y, color=color, alpha=0.2, linewidth=0.5) ax_all.scatter(x[-1], y[-1], color=color, s=5) vel_all = track_df["velocity"].to_numpy() vel_inf = track_df.loc[track_df["infected"], "velocity"].to_numpy() vel_uninf = track_df.loc[~track_df["infected"], "velocity"].to_numpy() mean_vel_all = float(np.nanmean(vel_all)) if vel_all.size else np.nan mean_vel_inf = float(np.nanmean(vel_inf)) if vel_inf.size else np.nan mean_vel_uninf = float(np.nanmean(vel_uninf)) if vel_uninf.size else np.nan print( "[summarise_tracks_from_merged] Velocity stats " f"({vel_unit}): all={mean_vel_all:.3f} " f"(n={vel_all.size} tracks with >=2 frames), " f"infected={mean_vel_inf:.3f} (n={vel_inf.size}), " f"uninfected={mean_vel_uninf:.3f} (n={vel_uninf.size})" ) ax_all.set_aspect("equal", "box") ax_all.set_xlabel(coord_label_x) ax_all.set_ylabel(coord_label_y) _apply_axis_limits(ax_all, abs_xlim, abs_ylim) bbox_all = patches.FancyBboxPatch( (box_x0, box_y0), box_width, box_height, transform=ax_all.transAxes, facecolor="white", edgecolor="black", boxstyle="round,pad=0.02", alpha=0.8, ) ax_all.add_patch(bbox_all) ax_all.text( text_x, y_top, f"Infected ({_fmt_vel(mean_vel_inf)} {vel_unit})", color="red", transform=ax_all.transAxes, fontsize=fontsize_main, va="top", ) ax_all.text( text_x, y_top - line_spacing, f"Uninfected ({_fmt_vel(mean_vel_uninf)} {vel_unit})", color="green", transform=ax_all.transAxes, fontsize=fontsize_main, va="top", ) ax_all.text( text_x, y_top - 2 * line_spacing, unit_line1, color="black", transform=ax_all.transAxes, fontsize=fontsize_units, va="top", ) ax_all.text( text_x, y_top - 3 * line_spacing, unit_line2, color="black", transform=ax_all.transAxes, fontsize=fontsize_units, va="top", ) plt.tight_layout() out_png_all = os.path.join(motility_dir, "motility_all_tracks.png") fig_all.savefig(out_png_all, dpi=300) plt.close(fig_all) print( f"[summarise_tracks_from_merged] Saved combined motility plot to " f"{out_png_all}" ) # Per-well plots well_summary_map = {} if not well_summary_df.empty: for _, row in well_summary_df.iterrows(): well_summary_map[(row["plateID"], row["wellID"])] = row for (plateID, wellID), tracks in per_well_tracks.items(): fig_w, ax_w = plt.subplots(figsize=(6, 6)) has_infected = False has_uninfected = False for tr in tracks: x = tr["x_px"] * coord_scale y = tr["y_px"] * coord_scale infected_track = tr["infected"] color = "red" if infected_track else "green" if infected_track: has_infected = True else: has_uninfected = True ax_w.plot(x, y, color=color, alpha=0.2, linewidth=0.5) ax_w.scatter(x[-1], y[-1], color=color, s=5) ax_w.set_aspect("equal", "box") ax_w.set_xlabel(coord_label_x) ax_w.set_ylabel(coord_label_y) _apply_axis_limits(ax_w, abs_xlim, abs_ylim) mean_inf_w = np.nan mean_uninf_w = np.nan summary_row = well_summary_map.get((plateID, wellID)) if summary_row is not None: mean_inf_w = summary_row["mean_velocity_infected"] mean_uninf_w = summary_row["mean_velocity_uninfected"] bbox_w = patches.FancyBboxPatch( (box_x0, box_y0), box_width, box_height, transform=ax_w.transAxes, facecolor="white", edgecolor="black", boxstyle="round,pad=0.02", alpha=0.8, ) ax_w.add_patch(bbox_w) ax_w.text( text_x, y_top, f"Infected ({_fmt_vel(mean_inf_w)} {vel_unit})", color="red", transform=ax_w.transAxes, fontsize=fontsize_main, va="top", ) ax_w.text( text_x, y_top - line_spacing, f"Uninfected ({_fmt_vel(mean_uninf_w)} {vel_unit})", color="green", transform=ax_w.transAxes, fontsize=fontsize_main, va="top", ) ax_w.text( text_x, y_top - 2 * line_spacing, unit_line1, color="black", transform=ax_w.transAxes, fontsize=fontsize_units, va="top", ) ax_w.text( text_x, y_top - 3 * line_spacing, unit_line2, color="black", transform=ax_w.transAxes, fontsize=fontsize_units, va="top", ) plt.tight_layout() out_well = os.path.join( motility_dir, f"motility_{plateID}_{wellID}_all_tracks.png" ) fig_w.savefig(out_well, dpi=300) plt.close(fig_w) print( f"[summarise_tracks_from_merged] Saved per-well motility plot " f"to {out_well}" ) # infected-only, re-centred to (0,0) if has_infected: fig_inf, ax_inf = plt.subplots(figsize=(6, 6)) for tr in tracks: if not tr["infected"]: continue x = (tr["x_px"] - tr["x_px"][0]) * coord_scale y = (tr["y_px"] - tr["y_px"][0]) * coord_scale ax_inf.plot(x, y, color="red", alpha=0.2, linewidth=0.5) ax_inf.scatter(x[-1], y[-1], color="red", s=5) ax_inf.set_aspect("equal", "box") ax_inf.set_xlabel(coord_label_x) ax_inf.set_ylabel(coord_label_y) _apply_axis_limits(ax_inf, origin_xlim, origin_ylim) plt.tight_layout() out_inf = os.path.join( motility_dir, f"motility_{plateID}_{wellID}_infected_origin.png" ) fig_inf.savefig(out_inf, dpi=300) plt.close(fig_inf) print( f"[summarise_tracks_from_merged] Saved per-well infected " f"origin plot to {out_inf}" ) # uninfected-only, re-centred to (0,0) if has_uninfected: fig_uninf, ax_uninf = plt.subplots(figsize=(6, 6)) for tr in tracks: if tr["infected"]: continue x = (tr["x_px"] - tr["x_px"][0]) * coord_scale y = (tr["y_px"] - tr["y_px"][0]) * coord_scale ax_uninf.plot(x, y, color="green", alpha=0.2, linewidth=0.5) ax_uninf.scatter(x[-1], y[-1], color="green", s=5) ax_uninf.set_aspect("equal", "box") ax_uninf.set_xlabel(coord_label_x) ax_uninf.set_ylabel(coord_label_y) _apply_axis_limits(ax_uninf, origin_xlim, origin_ylim) plt.tight_layout() out_uninf = os.path.join( motility_dir, f"motility_{plateID}_{wellID}_uninfected_origin.png" ) fig_uninf.savefig(out_uninf, dpi=300) plt.close(fig_uninf) print( f"[summarise_tracks_from_merged] Saved per-well uninfected " f"origin plot to {out_uninf}" ) def _select_infection_feature_columns(all_df, pathogen_chan): """ Select numeric feature columns for infection QC: - numeric columns - drop obvious IDs / motility metrics - drop centroid (coordinate) features - keep intensity features only for the pathogen channel - drop near-constant or almost-empty columns at cell level """ import numpy as np numeric_cols = all_df.select_dtypes(include=[np.number]).columns.tolist() exclude = { "frame", "timeID", "cellID", "n_pathogens", "v_px_per_frame", "velocity", "straightness", } # drop any debug / temporary numeric cols if present exclude |= {c for c in numeric_cols if c.endswith("_idx")} # drop centroid features (absolute coordinates) exclude |= {c for c in numeric_cols if "centroid" in c.lower()} # exclude intensity columns for non-pathogen channels if pathogen_chan is not None: for c in numeric_cols: if "intensity_ch" in c: try: digits = "".join(ch for ch in c.split("ch")[-1] if ch.isdigit()) if digits != "": ch_idx = int(digits) if ch_idx != pathogen_chan: exclude.add(c) except Exception: # if parsing fails, keep column pass feature_cols = [c for c in numeric_cols if c not in exclude] if not feature_cols: return [] key_cols = ["plateID", "wellID", "fieldID", "cellID"] agg_cols = [c for c in feature_cols if c in all_df.columns] if not agg_cols: return [] # Build per-cell table to filter out useless columns cell_level = ( all_df[key_cols + agg_cols] .groupby(key_cols, dropna=False) .median() .reset_index() ) filtered = [] for c in agg_cols: arr = cell_level[c].to_numpy(dtype=float) finite = np.isfinite(arr) if finite.sum() < 10: continue if np.nanstd(arr[finite]) < 1e-6: continue filtered.append(c) return filtered def _compute_intensity_percentiles_per_channel( mask_stack, intensity_stack, channel_index, object_prefix, percentiles=(1, 5, 10, 25, 75, 95, 99), label_as_track_id=False, ): """ Compute per-frame, per-object intensity percentiles for a given channel. Parameters ---------- mask_stack : ndarray Label image stack of shape (T, Y, X). intensity_stack : ndarray Intensity stack of shape (T, Y, X, C). channel_index : int Channel index in intensity_stack. object_prefix : str Prefix for column names ("cell", "nucleus", "pathogen", "cytoplasm"). percentiles : tuple of int Percentiles to compute (0–100). label_as_track_id : bool If True, rename 'label' -> 'track_id'; otherwise 'label' -> f"{object_prefix}_label". Returns ------- DataFrame Columns: ['frame', label_col, f'{object_prefix}_pXX_intensity_ch{channel_index}', ...] """ import numpy as np import pandas as pd if intensity_stack is None: return pd.DataFrame( columns=["frame", "track_id" if label_as_track_id else f"{object_prefix}_label"] ) if channel_index is None or channel_index < 0 or channel_index >= intensity_stack.shape[-1]: return pd.DataFrame( columns=["frame", "track_id" if label_as_track_id else f"{object_prefix}_label"] ) T = mask_stack.shape[0] dfs = [] label_col_name = "track_id" if label_as_track_id else f"{object_prefix}_label" perc = np.array(percentiles, dtype=float) for frame in range(T): labels = mask_stack[frame] if not np.any(labels): continue intensity_image = intensity_stack[frame, :, :, channel_index] # unique labels > 0 obj_labels = np.unique(labels) obj_labels = obj_labels[obj_labels > 0] if obj_labels.size == 0: continue records = [] for lab in obj_labels: mask = labels == lab vals = intensity_image[mask] vals = vals[np.isfinite(vals)] if vals.size == 0: continue pvals = np.percentile(vals, perc) rec = {"frame": frame, label_col_name: int(lab)} for p, v in zip(perc, pvals): col_name = f"{object_prefix}_p{int(p):02d}_intensity_ch{channel_index}" rec[col_name] = float(v) records.append(rec) if records: dfs.append(pd.DataFrame.from_records(records)) if not dfs: return pd.DataFrame(columns=["frame", label_col_name]) out_df = pd.concat(dfs, ignore_index=True) return out_df def _make_adjusted_qc_panel( all_df, infection_col, motility_dir, settings, label_tag, ): """ Build a QC results panel for adjusted labels using 3 subplots: - top-left: PCA (adjusted_infected) - top-right: XGBoost feature importance - bottom: pathogen-channel intensity histogram Uses payloads stored in `settings` by the QC functions: settings["infection_hist_data"] settings["infection_pca_data"] settings["infection_xgb_importance"] """ import os import numpy as np import matplotlib.pyplot as plt os.makedirs(motility_dir, exist_ok=True) meta_tag = _infer_plate_well_meta_tag(all_df) # Create figure with desired layout fig, ax_pca, ax_xgb, ax_hist = create_results_figure() # ------------------------------------------------------------------ # Histogram # ------------------------------------------------------------------ hist_data = settings.get("infection_hist_data") or {} vals_inf = np.asarray(hist_data.get("intensities_inf", []), dtype=float) vals_uninf = np.asarray(hist_data.get("intensities_uninf", []), dtype=float) bin_edges = np.asarray(hist_data.get("bin_edges", []), dtype=float) thr_val = hist_data.get("thr_val", None) pathogen_chan = hist_data.get("pathogen_chan", None) do_log = bool(hist_data.get("log_transform", False)) if vals_inf.size + vals_uninf.size > 0 and bin_edges.size > 0: ax_hist.hist( vals_uninf, bins=bin_edges, alpha=0.5, color="green", label="Uninfected", ) ax_hist.hist( vals_inf, bins=bin_edges, alpha=0.5, color="red", label="Infected", ) if thr_val is not None: ax_hist.axvline( thr_val, linestyle="--", linewidth=2, color="black", label=f"thr={thr_val:.2f}", ) if pathogen_chan is not None: if do_log: ax_hist.set_xlabel(f"log10 intensity (channel {pathogen_chan})") else: ax_hist.set_xlabel(f"Intensity (channel {pathogen_chan})") else: ax_hist.set_xlabel("Intensity") ax_hist.set_ylabel("Cell count") ax_hist.set_title("Pathogen-channel intensity histogram") ax_hist.legend(loc="best") else: ax_hist.text( 0.5, 0.5, "No histogram data", ha="center", va="center", transform=ax_hist.transAxes, ) ax_hist.axis("off") # ------------------------------------------------------------------ # PCA # ------------------------------------------------------------------ pca_data = settings.get("infection_pca_data") or {} coords = pca_data.get("coords", None) labels = pca_data.get("labels", None) method_label = pca_data.get("method_label", "PCA") if coords is not None and labels is not None: coords = np.asarray(coords, dtype=float) labels = np.asarray(labels, dtype=bool) if coords.ndim == 2 and coords.shape[0] == labels.shape[0] and coords.shape[1] >= 2: x = coords[:, 0] y = coords[:, 1] ax_pca.scatter( x[~labels], y[~labels], s=8, c="green", alpha=0.5, label="Uninfected", ) ax_pca.scatter( x[labels], y[labels], s=8, c="red", alpha=0.5, label="Infected", ) ax_pca.set_xlabel("component 1") ax_pca.set_ylabel("component 2") ax_pca.set_title(f"{method_label} embedding") ax_pca.legend(loc="best") else: ax_pca.text( 0.5, 0.5, "No PCA data", ha="center", va="center", transform=ax_pca.transAxes, ) ax_pca.axis("off") else: ax_pca.text( 0.5, 0.5, "No PCA data", ha="center", va="center", transform=ax_pca.transAxes, ) ax_pca.axis("off") # ------------------------------------------------------------------ # XGBoost feature importance # ------------------------------------------------------------------ xgb_data = settings.get("infection_xgb_importance") or {} feat_names = xgb_data.get("feature_names") or [] feat_vals = xgb_data.get("feature_importances") or [] if feat_names and feat_vals and len(feat_names) == len(feat_vals): feat_names = list(feat_names) feat_vals = np.asarray(feat_vals, dtype=float) y_pos = np.arange(len(feat_names)) ax_xgb.barh(y_pos, feat_vals) ax_xgb.set_yticks(y_pos) ax_xgb.set_yticklabels(feat_names) ax_xgb.invert_yaxis() ax_xgb.set_xlabel("Importance (gain)") ax_xgb.set_title("XGBoost feature importance") else: ax_xgb.text( 0.5, 0.5, "No XGBoost importance data", ha="center", va="center", transform=ax_xgb.transAxes, ) ax_xgb.axis("off") fig.suptitle( f"Infection QC panel – {label_tag} labels\n{meta_tag}", fontsize=10, ) fig.tight_layout(rect=[0, 0, 1, 0.94]) out_name = f"infection_qc_panel_{label_tag}_{meta_tag}.png" out_path = os.path.join(motility_dir, out_name) fig.savefig(out_path, dpi=300) plt.close(fig) print( f"[summarise_tracks_from_merged] Saved infection QC results panel " f"({label_tag}) to {out_path}" ) def _load_measurements_from_db(db_path, db_table_name): """ Load per-cell measurements from an existing SQLite database. Parameters ---------- db_path : str Path to the SQLite database file (measurements.db). db_table_name : str Name of the table that stores per-cell measurements. Returns ------- pandas.DataFrame DataFrame with the measurements, or an empty DataFrame if the database/table is missing or unreadable. """ import os import sqlite3 import pandas as pd if not os.path.isfile(db_path): return pd.DataFrame() conn = sqlite3.connect(db_path) try: query = f"SELECT * FROM {db_table_name}" df = pd.read_sql_query(query, conn) except Exception as e: print( "[summarise_tracks_from_merged] Could not load existing measurements " f"from {db_path} (table='{db_table_name}'): {e}" ) df = pd.DataFrame() finally: conn.close() return df def _infection_qc_histogram( all_df, settings, infection_col, pathogen_chan, motility_dir, ): import matplotlib.pyplot as plt import os import numpy as np # Prefer 95th percentile of pathogen channel; fall back to mean if needed cand_cols = [ f"cell_p95_intensity_ch{pathogen_chan}", f"cell_mean_intensity_ch{pathogen_chan}", ] intensity_col = None for c in cand_cols: if c in all_df.columns: intensity_col = c break # Initialize payload slot settings["infection_hist_data"] = None if intensity_col is None: print( f"[infection_intensity_qc] None of {cand_cols} found; " f"skipping intensity-based relabelling." ) settings["infection_intensity_qc_panel_type"] = "histogram" settings["infection_intensity_qc_panel_path"] = None return all_df, infection_col if intensity_col not in all_df.columns: print( f"[infection_intensity_qc] Column {intensity_col!r} not found; " f"skipping intensity-based relabelling." ) settings["infection_intensity_qc_panel_type"] = "histogram" settings["infection_intensity_qc_panel_path"] = None return all_df, infection_col # --- IMPORTANT: drop any existing adjusted_infected when reusing DB --- # This prevents merge from creating adjusted_infected_x / adjusted_infected_y # and guarantees we recompute labels fresh each run. cols_to_drop = [ c for c in all_df.columns if c == "adjusted_infected" or c.startswith("adjusted_infected_") ] if cols_to_drop: all_df = all_df.drop(columns=cols_to_drop) key_cols = ["plateID", "wellID", "fieldID", "cellID"] cell_level = ( all_df[key_cols + [intensity_col, infection_col]] .groupby(key_cols, dropna=False) .agg({intensity_col: "mean", infection_col: "max"}) .reset_index() ) cell_level = cell_level.replace([np.inf, -np.inf], np.nan) cell_level = cell_level.dropna(subset=[intensity_col]) if len(cell_level) < 20 or cell_level[intensity_col].nunique() < 2: print( "[infection_intensity_qc] Too few cells or no intensity variation; " "skipping intensity-based relabelling." ) settings["infection_intensity_qc_panel_type"] = "histogram" settings["infection_intensity_qc_panel_path"] = None return all_df, infection_col intensities = cell_level[intensity_col].to_numpy(dtype=float) mask_labels = cell_level[infection_col].to_numpy(dtype=bool) # Optional log-transform to help separate populations do_log = bool(settings.get("infection_intensity_log", False)) if do_log: eps = np.nanmax([np.nanmin(intensities[intensities > 0]) * 0.5, 1e-6]) intensities = np.log10(intensities + eps) n_bins = int(settings.get("infection_intensity_n_bins", 64)) n_bins = max(10, min(n_bins, 256)) counts_all, bin_edges = np.histogram(intensities, bins=n_bins) counts_inf, _ = np.histogram(intensities[mask_labels], bins=bin_edges) denom = np.maximum(counts_all, 1) frac_inf = counts_inf.astype(float) / denom.astype(float) # Target fraction of infected in a bin target_frac = float(settings.get("infection_intensity_frac_infected", 0.7)) target_frac = max(0.5, min(target_frac, 0.95)) # Fallback percentile (now default 25th) hist_pct = float(settings.get("infection_hist_percentile", 25.0)) hist_pct = max(0.0, min(hist_pct, 100.0)) # First bin (low→high) where infected ≥ target_frac thresh_idx = None for i, frac in enumerate(frac_inf): if frac >= target_frac: thresh_idx = i break if thresh_idx is None: # fallback: hist_pct percentile of all cells (after optional log) thr_val = float(np.nanpercentile(intensities, hist_pct)) print( "[infection_intensity_qc] Could not find bin with infected ≥ " f"{target_frac:.2f}; using {hist_pct:.1f}th percentile of all cells " f"({thr_val:.2f}) as threshold." ) else: thr_val = float(bin_edges[thresh_idx]) print( "[infection_intensity_qc] Automatic intensity threshold at first bin " f"where infected ≥ {target_frac:.2f}: {thr_val:.2f} (bin {thresh_idx})" ) cell_level["intensity_positive"] = cell_level[intensity_col] >= thr_val mode = str(settings.get("infection_intensity_mode", "relabel")).lower() if mode not in {"relabel", "remove"}: mode = "relabel" removed_ids = None if mode == "relabel": cell_level["adjusted_infected"] = cell_level["intensity_positive"].astype(bool) n_changed = int( ( cell_level["adjusted_infected"] != cell_level[infection_col].astype(bool) ).sum() ) print( "[infection_intensity_qc] Adjusted infection labels for " f"{n_changed} cells (mode=relabel)." ) else: consistent = ( cell_level[infection_col].astype(bool) == cell_level["intensity_positive"].astype(bool) ) removed = cell_level.loc[ ~consistent, ["plateID", "wellID", "fieldID", "cellID"] ] removed_ids = set( zip( removed["plateID"], removed["wellID"], removed["fieldID"], removed["cellID"], ) ) cell_level = cell_level.loc[consistent].copy() cell_level["adjusted_infected"] = cell_level["intensity_positive"].astype(bool) print( "[infection_intensity_qc] Removed " f"{len(removed_ids)} cells with conflicting mask vs intensity labels " "(mode=remove)." ) # merge back all_df = all_df.merge( cell_level[key_cols + ["adjusted_infected"]], on=key_cols, how="left", ) if removed_ids: mask_keep = ~all_df.apply( lambda r: (r["plateID"], r["wellID"], r["fieldID"], r["cellID"]) in removed_ids, axis=1, ) all_df = all_df.loc[mask_keep].reset_index(drop=True) # Now adjusted_infected definitely exists and comes from histogram QC; # fill any NaNs (cells not in cell_level) from the original infection_col. all_df["adjusted_infected"] = all_df["adjusted_infected"].fillna( all_df[infection_col] ) infection_col = "adjusted_infected" # Decide whether to make / save QC graph make_graphs = bool(settings.get("infection_intensity_qc_graphs", True)) meta_tag = _infer_plate_well_meta_tag(all_df) hist_path = None # Prepare payload (even if we don't generate PNG) vals_inf = intensities[mask_labels] vals_uninf = intensities[~mask_labels] hist_payload = { "intensities_inf": vals_inf, "intensities_uninf": vals_uninf, "bin_edges": bin_edges, "thr_val": thr_val, "pathogen_chan": pathogen_chan, "log_transform": do_log, "intensity_col": intensity_col, } settings["infection_hist_data"] = hist_payload if make_graphs: # Plot histogram os.makedirs(motility_dir, exist_ok=True) fig_h, ax_h = plt.subplots(figsize=(6, 4)) ax_h.hist( vals_uninf, bins=bin_edges, alpha=0.5, color="green", label="Uninfected (mask-based)", ) ax_h.hist( vals_inf, bins=bin_edges, alpha=0.5, color="red", label="Infected (mask-based)", ) ax_h.axvline( thr_val, linestyle="--", linewidth=2, color="black", label=f"Threshold = {thr_val:.1f}", ) if do_log: ax_h.set_xlabel(f"log10 intensity metric (channel {pathogen_chan})") else: ax_h.set_xlabel(f"Intensity metric (channel {pathogen_chan})") ax_h.set_ylabel("Cell count") ax_h.set_title( f"Pathogen-channel intensity histogram (thr={thr_val:.1f}, " f"{hist_pct:.1f}th pct fallback)" ) ax_h.legend(loc="best") hist_filename = f"infection_intensity_histogram_{meta_tag}.png" hist_path = os.path.join(motility_dir, hist_filename) fig_h.tight_layout() fig_h.savefig(hist_path, dpi=200) plt.close(fig_h) print(f"[infection_intensity_qc] Saved histogram to: {hist_path}") else: print( "[infection_intensity_qc] infection_intensity_qc_graphs=False; " "skipping histogram plot." ) # Let the panel know what QC plot to embed (if present) settings["infection_intensity_qc_panel_type"] = "histogram" settings["infection_intensity_qc_panel_path"] = hist_path return all_df, infection_col def _infection_qc_xgboost(all_df, settings, infection_col, pathogen_chan, motility_dir): """ Use an XGBoost classifier to refine infection calling based on per-object features. Key behaviour ------------- - Training: * per-cell (or per-object) medians across frames * {tracked_object}_* morphology + {tracked_object}_pathogen-channel intensity features * training labels from pathogen-channel intensity extremes: - bottom 25% of UNINFECTED → strong negatives - top 25% of INFECTED → strong positives * training data curated per well: - wells with both classes in the extreme set: - if both classes have >= infection_xgb_min_cells_per_class examples: → balanced sampling per class within that well - else (small wells with both classes): → keep all extreme examples from that well - wells with only one class in the extreme set are skipped - if no well has both classes in the extreme set → skip XGBoost QC - Prediction: * get P(infected) for all cells * "mode": - 'relabel': start from original labels, override only when model is confident; ambiguous are still dropped if requested. - 'remove' : drop strong label–model disagreements AND ambiguous if requested. - Ambiguous band removal: * if infection_xgb_drop_ambiguous is True (default), drop cells with proba in [infection_xgb_ambiguous_low, infection_xgb_ambiguous_high] (defaults: 0.25 and 0.75). Additionally, this function stores three QC payloads in `settings`: settings["infection_hist_data"] : dict for histogram panel settings["infection_pca_data"] : dict for PCA panel settings["infection_xgb_importance"] : dict for feature-importance panel Returns ------- all_df, infection_col='adjusted_infected' """ import os import re import numpy as np import pandas as pd import matplotlib.pyplot as plt try: import xgboost as xgb except ImportError: print("[_infection_qc_xgboost] XGBoost not installed; using histogram QC.") return _infection_qc_histogram( all_df=all_df, settings=settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) # init payload slots settings["infection_hist_data"] = None settings["infection_pca_data"] = None settings["infection_xgb_importance"] = None # IMPORTANT: drop any existing adjusted_* / infection_prob* from DB reuse cols_to_drop = [ c for c in all_df.columns if c == "adjusted_infected" or c.startswith("adjusted_infected_") or c == "infection_prob" or c.startswith("infection_prob_") ] if cols_to_drop: all_df = all_df.drop(columns=cols_to_drop) orig_infection_col = infection_col if pathogen_chan is None: print("[_infection_qc_xgboost] pathogen_chan is None; using histogram QC.") return _infection_qc_histogram( all_df=all_df, settings=settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) # Ensure n_pathogens exists and has 0 instead of NaN if "n_pathogens" in all_df.columns: all_df["n_pathogens"] = all_df["n_pathogens"].fillna(0) # Recover infection column if missing if orig_infection_col not in all_df.columns: infect_like = [c for c in all_df.columns if "infect" in c.lower()] if infect_like: new_col = infect_like[0] print( f"[_infection_qc_xgboost] Column {orig_infection_col!r} not found; " f"using {new_col!r} instead." ) orig_infection_col = new_col elif "n_pathogens" in all_df.columns: orig_infection_col = "_infected_from_n_pathogens" all_df[orig_infection_col] = (all_df["n_pathogens"] > 0).astype(int) print( "[_infection_qc_xgboost] Column 'infected' not found; created " f"{orig_infection_col!r} from 'n_pathogens > 0'." ) else: print( "[_infection_qc_xgboost] No infection label and no 'n_pathogens'; " "using histogram QC instead." ) return _infection_qc_histogram( all_df=all_df, settings=settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) key_cols = ["plateID", "wellID", "fieldID", "cellID"] for col in key_cols: if col not in all_df.columns: raise KeyError(f"[_infection_qc_xgboost] Required column {col!r} not in all_df.") # ------------------------------------------------------------------ # Aggregate to per-object level (median across frames) # ------------------------------------------------------------------ agg_cols = [ c for c in all_df.columns if c not in (key_cols + ["frame", "timeID", orig_infection_col]) ] group = all_df.groupby(key_cols, observed=True) cell_level = group[agg_cols].median(numeric_only=True).reset_index() infection_any = ( all_df.groupby(key_cols, observed=True)[orig_infection_col] .max() .reset_index() ) cell_level = cell_level.merge( infection_any, on=key_cols, how="left", suffixes=("", "_y") ) if orig_infection_col not in cell_level.columns: for cand in (f"{orig_infection_col}_y", f"{orig_infection_col}_x"): if cand in cell_level.columns: cell_level[orig_infection_col] = cell_level[cand] break if orig_infection_col not in cell_level.columns: raise KeyError( f"[_infection_qc_xgboost] Infection column {orig_infection_col!r} " "missing from per-cell table after aggregation/merge." ) cell_level[orig_infection_col] = ( cell_level[orig_infection_col] .fillna(0) .astype(bool) ) if "n_pathogens" in cell_level.columns: cell_level["n_pathogens"] = cell_level["n_pathogens"].fillna(0) # ------------------------------------------------------------------ # Decide which object type's features to use (tracked_object) # ------------------------------------------------------------------ tracked_object = str(settings.get("tracked_object", "cell")).strip().lower() if tracked_object not in {"cell", "nucleus", "pathogen"}: print( f"[_infection_qc_xgboost] Unknown tracked_object={tracked_object!r}; " "falling back to 'cell'." ) tracked_object = "cell" obj_prefix = f"{tracked_object}_" # ------------------------------------------------------------------ # Decide pathogen-channel intensity column for this tracked_object # ------------------------------------------------------------------ intensity_candidates = [ f"{obj_prefix}p95_intensity_ch{pathogen_chan}", f"{obj_prefix}max_intensity_ch{pathogen_chan}", f"{obj_prefix}mean_intensity_ch{pathogen_chan}", ] intensity_col = None for c in intensity_candidates: if c in cell_level.columns: intensity_col = c break if intensity_col is None: print( "[_infection_qc_xgboost] No pathogen-channel intensity column found for " f"tracked_object={tracked_object!r} " f"(tried: {intensity_candidates}); using histogram QC." ) return _infection_qc_histogram( all_df=all_df, settings=settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) # ------------------------------------------------------------------ # Build feature set: {tracked_object}_* only, excluding centroids, # non-pathogen channels, degenerate features # ------------------------------------------------------------------ numeric_cols = cell_level.select_dtypes(include=[np.number, "bool"]).columns.tolist() feature_cols = [] pattern_obj = re.compile(rf"^{re.escape(obj_prefix)}") pattern_ch = re.compile(r"ch(\d+)\b") for c in numeric_cols: if c == orig_infection_col: continue if c in {"frame", "timeID"}: continue if "centroid" in c.lower(): continue if not pattern_obj.match(c): continue m = pattern_ch.search(c) if m: ch_idx = int(m.group(1)) if ch_idx != int(pathogen_chan): continue feature_cols.append(c) else: feature_cols.append(c) if intensity_col in cell_level.columns and intensity_col not in feature_cols: feature_cols.append(intensity_col) # Drop degenerate features clean_feature_cols = [] for c in feature_cols: s = cell_level[c] if s.notna().sum() < 10: continue if s.nunique(dropna=True) <= 1: continue clean_feature_cols.append(c) feature_cols = clean_feature_cols if not feature_cols: print( "[_infection_qc_xgboost] No usable " f"{tracked_object}_* feature columns; using histogram QC." ) return _infection_qc_histogram( all_df=all_df, settings=settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) # ------------------------------------------------------------------ # Global sanity check: do we even have enough infected/uninfected cells? # ------------------------------------------------------------------ infected_cells = cell_level[cell_level[orig_infection_col]] uninfected_cells = cell_level[~cell_level[orig_infection_col]] if len(infected_cells) < 10 or len(uninfected_cells) < 10: print( "[_infection_qc_xgboost] Too few infected or uninfected cells overall; " "using histogram QC." ) return _infection_qc_histogram( all_df=all_df, settings=settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) # ------------------------------------------------------------------ # Define confident training sets using intensity quartiles # ------------------------------------------------------------------ inf_int = infected_cells[intensity_col].to_numpy(dtype=float) uninf_int = uninfected_cells[intensity_col].to_numpy(dtype=float) inf_int = inf_int[np.isfinite(inf_int)] uninf_int = uninf_int[np.isfinite(uninf_int)] if inf_int.size == 0 or uninf_int.size == 0: print( "[_infection_qc_xgboost] No finite intensities for infected/uninfected; " "using histogram QC." ) return _infection_qc_histogram( all_df=all_df, settings=settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) high_thr_inf = np.nanpercentile(inf_int, 75.0) low_thr_uninf = np.nanpercentile(uninf_int, 25.0) hi_inf = infected_cells[infected_cells[intensity_col] >= high_thr_inf].copy() lo_uninf = uninfected_cells[uninfected_cells[intensity_col] <= low_thr_uninf].copy() if hi_inf.empty or lo_uninf.empty: print( "[_infection_qc_xgboost] Could not define confident high/low quartiles; " "using histogram QC." ) return _infection_qc_histogram( all_df=all_df, settings=settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) print( "[_infection_qc_xgboost] Extreme-intensity candidates: " f"infected={len(hi_inf)}, uninfected={len(lo_uninf)} " f"(tracked_object={tracked_object}, intensity_col={intensity_col}, " f"low_thr_uninf={low_thr_uninf:.3f}, high_thr_inf={high_thr_inf:.3f})" ) # ------------------------------------------------------------------ # Curate XGBoost training data per well # ------------------------------------------------------------------ hi_inf["xgb_label"] = 1 lo_uninf["xgb_label"] = 0 train_candidates = pd.concat([hi_inf, lo_uninf], axis=0) min_per_class = int(settings.get("infection_xgb_min_cells_per_class", 10)) rng_seed = settings.get("infection_xgb_random_state", 42) try: rng = np.random.default_rng(int(rng_seed)) except Exception: rng = np.random.default_rng(42) train_idx_list = [] y_train_list = [] wells_used = set() wells_single_class = [] grouped = train_candidates.groupby(["plateID", "wellID"], observed=True) for (plate_id, well_id), df_w in grouped: pos_df = df_w[df_w["xgb_label"] == 1] neg_df = df_w[df_w["xgb_label"] == 0] n_pos = len(pos_df) n_neg = len(neg_df) if n_pos == 0 or n_neg == 0: # wells with only one class → skip for training wells_single_class.append((plate_id, well_id)) continue pos_idx = pos_df.index.to_numpy() neg_idx = neg_df.index.to_numpy() if n_pos >= min_per_class and n_neg >= min_per_class: # wells with enough data per class → balanced sampling within well n_per_class = min(n_pos, n_neg) if n_pos > n_per_class: pos_sel = rng.choice(pos_idx, size=n_per_class, replace=False) else: pos_sel = pos_idx if n_neg > n_per_class: neg_sel = rng.choice(neg_idx, size=n_per_class, replace=False) else: neg_sel = neg_idx else: # small wells with both classes → keep all extreme examples pos_sel = pos_idx neg_sel = neg_idx if pos_sel.size == 0 or neg_sel.size == 0: wells_single_class.append((plate_id, well_id)) continue train_idx_list.extend(pos_sel.tolist()) y_train_list.extend([1] * pos_sel.size) train_idx_list.extend(neg_sel.tolist()) y_train_list.extend([0] * neg_sel.size) wells_used.add((plate_id, well_id)) if not wells_used: print( "[_infection_qc_xgboost] No wells with both infected and uninfected " "extreme-intensity examples; skipping XGBoost QC." ) return _infection_qc_histogram( all_df=all_df, settings=settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) train_idx = np.array(train_idx_list, dtype=int) y_train = np.array(y_train_list, dtype=int) n_pos_train = int((y_train == 1).sum()) n_neg_train = int((y_train == 0).sum()) print( "[_infection_qc_xgboost] Training set (after per-well curation): " f"wells used={len(wells_used)}, positives={n_pos_train}, " f"negatives={n_neg_train}, min_per_class={min_per_class}." ) if wells_single_class: print( "[_infection_qc_xgboost] Wells skipped due to single class in extreme set: " + ", ".join([f"{p}_{w}" for (p, w) in wells_single_class]) ) # ------------------------------------------------------------------ # Build feature matrix + median imputation # ------------------------------------------------------------------ X_all = cell_level[feature_cols].to_numpy(dtype=float) for j in range(X_all.shape[1]): col = X_all[:, j] mask = np.isfinite(col) if not mask.any(): X_all[:, j] = 0.0 else: med = np.nanmedian(col[mask]) col[~mask] = med X_all[:, j] = col X_train = X_all[train_idx] # ------------------------------------------------------------------ # Remove highly correlated features # ------------------------------------------------------------------ if X_train.shape[1] > 1: corr = np.corrcoef(X_train, rowvar=False) corr_thr = float(settings.get("infection_xgb_corr_threshold", 0.95)) keep = np.ones(corr.shape[0], dtype=bool) for i in range(corr.shape[0]): if not keep[i]: continue for j in range(i + 1, corr.shape[0]): if keep[j] and abs(corr[i, j]) >= corr_thr: keep[j] = False if not keep.any(): print( "[_infection_qc_xgboost] All features flagged as highly correlated; " "keeping original feature set." ) else: removed = [f for f, k in zip(feature_cols, keep) if not k] if removed: print( "[_infection_qc_xgboost] Removing highly correlated features " f"(>|{corr_thr:.2f}|): " + ", ".join(removed) ) feature_cols = [f for f, k in zip(feature_cols, keep) if k] X_all = X_all[:, keep] X_train = X_train[:, keep] used_feature_cols = feature_cols if X_train.shape[1] == 0: print( "[_infection_qc_xgboost] No usable feature columns after correlation " "filtering; using histogram QC." ) return _infection_qc_histogram( all_df=all_df, settings=settings, infection_col=infection_col, pathogen_chan=pathogen_chan, motility_dir=motility_dir, ) print( f"[_infection_qc_xgboost] Using {len(used_feature_cols)} " f"{tracked_object}_* features:" ) print(" " + ", ".join(used_feature_cols)) # ------------------------------------------------------------------ # Train XGBoost # ------------------------------------------------------------------ dtrain = xgb.DMatrix(X_train, label=y_train, feature_names=used_feature_cols) params = { "objective": "binary:logistic", "eval_metric": "logloss", "max_depth": int(settings.get("infection_xgb_max_depth", 3)), "eta": float(settings.get("infection_xgb_learning_rate", 0.1)), "subsample": float(settings.get("infection_xgb_subsample", 0.8)), "colsample_bytree": float(settings.get("infection_xgb_colsample_bytree", 0.8)), "lambda": float(settings.get("infection_xgb_reg_lambda", 1.0)), "alpha": 0.0, "verbosity": 0, "nthread": int(settings.get("infection_xgb_n_jobs", -1)), } num_round = int(settings.get("infection_xgb_n_estimators", 200)) bst = xgb.train(params, dtrain, num_boost_round=num_round) # ------------------------------------------------------------------ # Predict all cells # ------------------------------------------------------------------ dall = xgb.DMatrix(X_all, feature_names=used_feature_cols) probs = bst.predict(dall) prob_thr = float(settings.get("infection_xgb_proba_threshold", 0.5)) margin = float(settings.get("infection_xgb_margin", 0.0)) margin = max(0.0, min(margin, 0.49)) orig_arr = cell_level[orig_infection_col].astype(int).to_numpy() pred_arr = (probs >= prob_thr).astype(int) mode = str(settings.get("infection_intensity_mode", "relabel")).lower() if mode not in {"relabel", "remove"}: mode = "relabel" removed_ids = set() ambiguous_ids = set() if mode == "relabel": adjusted = orig_arr.copy() hi_conf = probs >= (prob_thr + margin) lo_conf = probs <= (prob_thr - margin) adjusted[hi_conf] = 1 adjusted[lo_conf] = 0 n_changed = int((adjusted != orig_arr).sum()) print( "[_infection_qc_xgboost] Relabel mode: adjusted infection labels for " f"{n_changed} cells (prob_thr={prob_thr:.2f}, margin={margin:.2f})." ) cell_level["adjusted_infected"] = adjusted cell_level["infection_prob"] = probs else: if margin > 0: ambig = np.abs(probs - prob_thr) < margin else: ambig = np.zeros_like(probs, dtype=bool) disagree = pred_arr != orig_arr to_remove = disagree & ~ambig removed = cell_level.loc[to_remove, key_cols] removed_ids = { (r["plateID"], r["wellID"], r["fieldID"], r["cellID"]) for _, r in removed.iterrows() } cell_level = cell_level.loc[~to_remove].copy() kept_idx = (~to_remove).nonzero()[0] adjusted = orig_arr[kept_idx].copy() pred_kept = pred_arr[kept_idx] ambig_kept = ambig[kept_idx] adjusted[~ambig_kept] = pred_kept[~ambig_kept] cell_level["adjusted_infected"] = adjusted cell_level["infection_prob"] = probs[~to_remove] print( "[_infection_qc_xgboost] Remove mode: removed " f"{len(removed_ids)} cells with strong model vs label disagreement " f"(prob_thr={prob_thr:.2f}, margin={margin:.2f})." ) # ------------------------------------------------------------------ # Drop ambiguous band (probability in [low, high]) # ------------------------------------------------------------------ drop_amb = bool(settings.get("infection_xgb_drop_ambiguous", True)) amb_low = float(settings.get("infection_xgb_ambiguous_low", 0.25)) amb_high = float(settings.get("infection_xgb_ambiguous_high", 0.75)) amb_low = max(0.0, min(amb_low, 1.0)) amb_high = max(0.0, min(amb_high, 1.0)) if amb_low > amb_high: amb_low, amb_high = amb_high, amb_low if drop_amb and "infection_prob" in cell_level.columns: amb_mask = ( (cell_level["infection_prob"] >= amb_low) & (cell_level["infection_prob"] <= amb_high) ) if amb_mask.any(): amb = cell_level.loc[amb_mask, key_cols] ambiguous_ids = { (r["plateID"], r["wellID"], r["fieldID"], r["cellID"]) for _, r in amb.iterrows() } cell_level = cell_level.loc[~amb_mask].copy() print( "[_infection_qc_xgboost] Dropped " f"{len(ambiguous_ids)} cells with ambiguous XGBoost probability " f"in [{amb_low:.2f}, {amb_high:.2f}]." ) # ------------------------------------------------------------------ # Map adjusted calls back to all_df # ------------------------------------------------------------------ for col in key_cols: all_df[col] = all_df[col].astype(cell_level[col].dtype) # (any stale adjusted_infected/infection_prob already dropped above) all_df = all_df.merge( cell_level[key_cols + ["adjusted_infected", "infection_prob"]], on=key_cols, how="left", validate="m:1", ) # Combine removed_sets (disagreement + ambiguous) ids_to_remove = set() if removed_ids: ids_to_remove |= removed_ids if ambiguous_ids: ids_to_remove |= ambiguous_ids if ids_to_remove: mask_drop = all_df.apply( lambda r: (r["plateID"], r["wellID"], r["fieldID"], r["cellID"]) in ids_to_remove, axis=1, ) all_df = all_df.loc[~mask_drop].reset_index(drop=True) mask_missing = all_df["adjusted_infected"].isna() if mask_missing.any(): all_df.loc[mask_missing, "adjusted_infected"] = ( all_df.loc[mask_missing, orig_infection_col].astype(int) ) all_df["adjusted_infected"] = all_df["adjusted_infected"].astype(int) infection_col = "adjusted_infected" try: n_inf = int(all_df[infection_col].sum()) n_uninf = int((1 - all_df[infection_col]).sum()) print( f"[_infection_qc_xgboost] Final infection counts (frame-level): " f"infected={n_inf}, uninfected={n_uninf}" ) except Exception: pass # ------------------------------------------------------------------ # Prepare QC payloads for the combined adjusted panel # - histogram of intensity (adjusted labels) # - PCA embedding of used features (adjusted labels) # - XGBoost feature importances (gain) # ------------------------------------------------------------------ try: # --- histogram payload --- if intensity_col in cell_level.columns: intens = cell_level[intensity_col].to_numpy(dtype=float) labels_adj = cell_level["adjusted_infected"].astype(bool).to_numpy() mask_fin = np.isfinite(intens) intens = intens[mask_fin] labels_adj = labels_adj[mask_fin] vals_inf = intens[labels_adj] vals_uninf = intens[~labels_adj] if intens.size >= 10: n_bins = int(settings.get("infection_intensity_n_bins", 64)) n_bins = max(10, min(n_bins, 256)) _, bin_edges = np.histogram(intens, bins=n_bins) # Use midpoint between training thresholds as a visual threshold thr_val = float(0.5 * (low_thr_uninf + high_thr_inf)) hist_payload = { "intensities_inf": vals_inf, "intensities_uninf": vals_uninf, "bin_edges": bin_edges, "thr_val": thr_val, "pathogen_chan": pathogen_chan, "log_transform": False, "intensity_col": intensity_col, "tracked_object": tracked_object, } settings["infection_hist_data"] = hist_payload # --- PCA payload --- from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler if used_feature_cols: X_panel = cell_level[used_feature_cols].to_numpy(dtype=float) for j in range(X_panel.shape[1]): col = X_panel[:, j] m = np.isfinite(col) if not m.any(): X_panel[:, j] = 0.0 else: med = np.nanmedian(col[m]) col[~m] = med X_panel[:, j] = col scaler = StandardScaler() X_scaled_panel = scaler.fit_transform(X_panel) pca = PCA( n_components=2, random_state=int(settings.get("infection_pca_random_state", 0)), ) coords = pca.fit_transform(X_scaled_panel) labels_adj_panel = cell_level["adjusted_infected"].astype(bool).to_numpy() pca_payload = { "coords": coords, "labels": labels_adj_panel, "method_label": "PCA", "tracked_object": tracked_object, } settings["infection_pca_data"] = pca_payload except Exception as e: print(f"[_infection_qc_xgboost] Could not compute histogram/PCA payloads: {e}") # ------------------------------------------------------------------ # Feature importance payload (no PNG; panels draw from this) # ------------------------------------------------------------------ try: importance_dict = bst.get_score(importance_type="gain") or {} feat_names = used_feature_cols feat_vals = [importance_dict.get(f, 0.0) for f in feat_names] # sort and truncate sorted_pairs = sorted( zip(feat_names, feat_vals), key=lambda x: x[1], reverse=True, ) feat_names = [p[0] for p in sorted_pairs] feat_vals = [p[1] for p in sorted_pairs] top_k = int(settings.get("infection_xgb_top_features", 20)) feat_names = feat_names[:top_k] feat_vals = feat_vals[:top_k] settings["infection_xgb_importance"] = { "feature_names": feat_names, "feature_importances": feat_vals, "tracked_object": tracked_object, } if feat_names: print("[_infection_qc_xgboost] Top XGBoost features (gain):") for name, val in zip(feat_names, feat_vals): print(f" {name}: {val:.4g}") except Exception as e: print(f"[_infection_qc_xgboost] Could not compute feature importances: {e}") # Mark QC type for panels; no embedded PNG (mask panel draws nothing) settings["infection_intensity_qc_panel_type"] = "xgboost" settings["infection_intensity_qc_panel_path"] = None return all_df, infection_col
[docs] def automated_motility_assay(settings): """ End-to-end: 1. Read merged/*.npy (plate_well_field_time.npy) 2. Build intensity + cell/nucleus/pathogen masks, derive cytoplasm 3. Per cell & frame: metadata + cell regionprops 4. Aggregate child (nucleus/pathogen/cytoplasm) features per cell 5. Concatenate across all merged files 6. Clean impossible jumps + measurement glitches 7. Save per-cell measurements to SQLite DB (measurements/measurements.db, table=db_table_name) **This table is always the original, pre-QC measurements.** 8. Compute per-track velocities (after smoothing) 9. Save a well-level motility summary table in the same DB 10. Generate *panel* plots combining intensity + motility: - original (mask-based) infection labels - adjusted infection labels (if QC modifies labels) 11. Optional infection intensity QC based on pathogen channel. New-relevant settings (all optional): # Infection QC / strategy 'infection_intensity_qc': True/False 'infection_intensity_strategy': one of {'xgboost', 'histogram', 'pca', 'umap', 'tsne'} 'infection_intensity_mode': {'relabel', 'remove'} # existing # XGBoost ambiguous-band filtering (track-level) 'infection_xgb_drop_ambiguous': True/False (default True) 'infection_xgb_ambiguous_low': 0.25 (default) 'infection_xgb_ambiguous_high': 0.75 (default) 'infection_xgb_proba_column': 'name_of_proba_col' # optional override # Histogram strategy 'infection_hist_percentile': 25 # used inside _apply_infection_intensity_qc # Panel toggles 'make_mask_panel': True/False (default True) 'make_adjusted_panel': True/False (default True) # Plot ranges (unchanged) - 'motility_xlim', 'motility_ylim' - 'motility_origin_xlim', 'motility_origin_ylim' # Measurements reuse 'reuse_existing_measurements': True/False (default True) """ import matplotlib.pyplot as plt # noqa: F401 (used in helpers) from matplotlib import patches # noqa: F401 (used in helpers) import numpy as np import pandas as pd import os from multiprocessing import Pool, cpu_count import sqlite3 from .settings import get_automated_motility_assay_default_settings settings = get_automated_motility_assay_default_settings(settings) src = settings["src"] db_table_name = settings["db_table_name"] n_jobs = settings["n_jobs"] max_displacement = settings["max_displacement"] zscore_thresh = settings["zscore_thresh"] # ------------------------------------------------------------------ # Optional reuse of existing measurements from SQLite # → this table is treated as the ORIGINAL, pre-QC dataset. # ------------------------------------------------------------------ reuse_existing = settings.get("reuse_existing_measurements", True) measurements_dir = os.path.join(src, "measurements") os.makedirs(measurements_dir, exist_ok=True) db_path = os.path.join(measurements_dir, "measurements.db") all_df = None loaded_from_db = False if reuse_existing and os.path.exists(db_path): try: print( f"[summarise_tracks_from_merged] Attempting to reuse existing " f"measurements from {db_path} (table='{db_table_name}')." ) with sqlite3.connect(db_path) as conn: # If the table does not exist, this will raise and fall back to recompute all_df = pd.read_sql_query(f"SELECT * FROM {db_table_name}", conn) if ( all_df is not None and not all_df.empty and {"plateID", "wellID", "fieldID", "cellID", "frame"}.issubset( all_df.columns ) ): n_frames_db = all_df["frame"].nunique() n_tracks_db = ( all_df[["plateID", "wellID", "fieldID", "cellID"]] .drop_duplicates() .shape[0] ) print( "[summarise_tracks_from_merged] Loaded ORIGINAL measurements " f"from DB: shape={all_df.shape}, frames={n_frames_db}, " f"tracks={n_tracks_db}. Skipping regionprops/intensity " "computation from merged .npy files." ) loaded_from_db = True else: print( "[summarise_tracks_from_merged] Loaded table is empty or missing " "required columns; recomputing from merged .npy files." ) all_df = None except Exception as e: print( "[summarise_tracks_from_merged] Failed to reuse existing measurements " f"({e}); recomputing from merged .npy files." ) all_df = None # ------------------------------------------------------------------ # Read merged files & basic metadata (if not reusing DB) # ------------------------------------------------------------------ merged_dir = os.path.join(src, "merged") if not os.path.isdir(merged_dir): raise FileNotFoundError(f"No merged directory at: {merged_dir}") all_files = [f for f in os.listdir(merged_dir) if f.endswith(".npy")] if not all_files: raise FileNotFoundError(f"No .npy files found in {merged_dir}") print( f"[summarise_tracks_from_merged] Found {len(all_files)} merged .npy files " f"in {merged_dir}" ) # group by (plateID, wellID, fieldID) groups = {} for fname in all_files: meta = _parse_merged_filename(fname) key = (meta["plateID"], meta["wellID"], meta["fieldID"]) groups.setdefault(key, []).append(fname) print( "[summarise_tracks_from_merged] Number of (plate, well, field) groups: " f"{len(groups)}" ) cell_chan = settings.get("cell_channel", None) nucleus_chan = settings.get("nucleus_channel", None) pathogen_chan = settings.get("pathogen_channel", None) channels_list = settings.get("channels", []) pixels_per_um = settings.get("pixels_per_um", None) seconds_per_frame = settings.get("seconds_per_frame", None) n_channels = len(channels_list) if isinstance(channels_list, (list, tuple)) else None if n_channels is None or n_channels <= 0: raise ValueError( "settings['channels'] must be a non-empty list of channels used " "in merged arrays." ) print( f"[summarise_tracks_from_merged] Channels={channels_list}, " f"cell_chan={cell_chan}, nucleus_chan={nucleus_chan}, " f"pathogen_chan={pathogen_chan}" ) motility_dir = os.path.join(src, "motility_plots") os.makedirs(motility_dir, exist_ok=True) # Debug-plot one sample merged array with channel/mask labels sample_filename = sorted(all_files)[0] print( "[summarise_tracks_from_merged] Debug plotting planes for sample file: " f"{sample_filename}" ) _debug_plot_merged_planes( src=src, sample_filename=sample_filename, n_channels=n_channels, nucleus_chan=nucleus_chan, pathogen_chan=pathogen_chan, out_dir=motility_dir, ) # ------------------------------------------------------------------ # Build measurements if not reusing from DB # ------------------------------------------------------------------ if not loaded_from_db: worker_args = [] for key, file_basenames in groups.items(): worker_args.append( (src, file_basenames, n_channels, cell_chan, nucleus_chan, pathogen_chan) ) if n_jobs is None: n_jobs = max(cpu_count() - 1, 1) print(f"[summarise_tracks_from_merged] Using n_jobs={n_jobs}") if n_jobs == 1: dfs = [_process_merged_group(args) for args in worker_args] else: with Pool(processes=n_jobs) as pool: dfs = pool.map(_process_merged_group, worker_args) all_df = ( pd.concat([df for df in dfs if not df.empty], ignore_index=True) if dfs else pd.DataFrame() ) if all_df.empty: raise RuntimeError("No measurements were produced from merged .npy files.") print( "[summarise_tracks_from_merged] Combined raw measurements: " f"shape={all_df.shape}, frames={all_df['frame'].nunique()}" ) n_tracks_raw = ( all_df[["plateID", "wellID", "fieldID", "cellID"]] .drop_duplicates() .shape[0] ) print( "[summarise_tracks_from_merged] Unique tracks before smoothing: " f"{n_tracks_raw}" ) # Clean tracks all_df = _smooth_tracks_and_features( all_df, max_displacement=max_displacement, zscore_thresh=zscore_thresh, ) n_tracks_smoothed = ( all_df[["plateID", "wellID", "fieldID", "cellID"]] .drop_duplicates() .shape[0] ) print( "[summarise_tracks_from_merged] After smoothing: " f"shape={all_df.shape}, frames={all_df['frame'].nunique()}, " f"tracks={n_tracks_smoothed}" ) else: # Already loaded smoothed measurements from DB (treated as ORIGINAL) n_frames_db = all_df["frame"].nunique() n_tracks_db = ( all_df[["plateID", "wellID", "fieldID", "cellID"]] .drop_duplicates() .shape[0] ) print( "[summarise_tracks_from_merged] Reusing ORIGINAL smoothed measurements " f"from DB: shape={all_df.shape}, frames={n_frames_db}, " f"tracks={n_tracks_db}" ) # ------------------------------------------------------------------ # Infection status per track (mask-based) # - This is part of the ORIGINAL dataset. # ------------------------------------------------------------------ if "infected" in all_df.columns: # Reuse existing infection labels (DB-reused or previous run), # but ensure no NaNs and correct dtype. all_df["infected"] = all_df["infected"].fillna(False).astype(bool) elif "n_pathogens" in all_df.columns: tmp = all_df[["plateID", "wellID", "fieldID", "cellID", "n_pathogens"]].copy() tmp["n_pathogens"] = tmp["n_pathogens"].fillna(0) infected = ( tmp.groupby(["plateID", "wellID", "fieldID", "cellID"])["n_pathogens"] .max() .gt(0) ) infected = infected.reset_index() infected = infected.rename(columns={"n_pathogens": "infected"}) infected["infected"] = infected["infected"].astype(bool) all_df = all_df.merge( infected[["plateID", "wellID", "fieldID", "cellID", "infected"]], on=["plateID", "wellID", "fieldID", "cellID"], how="left", ) all_df["infected"] = all_df["infected"].fillna(False).astype(bool) else: all_df["infected"] = False n_infected_tracks = ( all_df[all_df["infected"]][["plateID", "wellID", "fieldID", "cellID"]] .drop_duplicates() .shape[0] ) n_uninfected_tracks = ( all_df[~all_df["infected"]][["plateID", "wellID", "fieldID", "cellID"]] .drop_duplicates() .shape[0] ) print( "[summarise_tracks_from_merged] Tracks (mask-based): " f"infected={n_infected_tracks}, uninfected={n_uninfected_tracks}" ) # ------------------------------------------------------------------ # SNAPSHOT: ORIGINAL measurements (pre-QC) to be stored in SQLite. # This copy is never overridden by adjusted/QC'd data. # ------------------------------------------------------------------ all_df_original = all_df.copy(deep=True) # ------------------------------------------------------------------ # Optional infection-intensity QC (may create 'adjusted_infected') # - This operates on all_df only (not on all_df_original). # ------------------------------------------------------------------ infection_col = "infected" all_df, infection_col = _apply_infection_intensity_qc( all_df=all_df, settings=settings, infection_col=infection_col, motility_dir=motility_dir, pathogen_chan=pathogen_chan, ) # ------------------------------------------------------------------ # XGBoost ambiguous-band filtering (track-level) # ------------------------------------------------------------------ if ( settings.get("infection_intensity_qc", False) and str(settings.get("infection_intensity_strategy", "")).lower() == "xgboost" and settings.get("infection_xgb_drop_ambiguous", True) ): low = settings.get("infection_xgb_ambiguous_low", 0.25) high = settings.get("infection_xgb_ambiguous_high", 0.75) # Try to locate a probability column created by the QC step xgb_proba_col = settings.get("infection_xgb_proba_column", None) if xgb_proba_col is None: cand_cols = [ c for c in all_df.columns if "xgb" in c.lower() and ( "proba" in c.lower() or "prob" in c.lower() or "score" in c.lower() ) ] if not cand_cols: cand_cols = [ c for c in all_df.columns if "infection" in c.lower() and ( "proba" in c.lower() or "prob" in c.lower() or "score" in c.lower() ) ] if cand_cols: xgb_proba_col = cand_cols[0] if xgb_proba_col and xgb_proba_col in all_df.columns: track_keys = ["plateID", "wellID", "fieldID", "cellID"] track_scores = ( all_df[track_keys + [xgb_proba_col]] .groupby(track_keys)[xgb_proba_col] .mean() .reset_index() ) ambiguous = track_scores[ (track_scores[xgb_proba_col] > low) & (track_scores[xgb_proba_col] < high) ][track_keys] if not ambiguous.empty: before = all_df.shape[0] all_df = all_df.merge( ambiguous.assign(_ambiguous_flag=1), on=track_keys, how="left", ) all_df = all_df[all_df["_ambiguous_flag"].isna()].drop( columns=["_ambiguous_flag"] ) after = all_df.shape[0] print( "[summarise_tracks_from_merged] Dropped " f"{before - after} rows from {ambiguous.shape[0]} ambiguous " f"XGBoost tracks ({low} < proba < {high})." ) else: print( "[summarise_tracks_from_merged] WARNING: " "infection_xgb_drop_ambiguous is True, but no XGBoost " "probability/score column was found. Skipping ambiguous-track " "filtering." ) # ------------------------------------------------------------------ # Save ADJUSTED frame-level measurements to CSV ONLY. # This NEVER overwrites the SQLite original table. # ------------------------------------------------------------------ try: qc_strategy = str(settings.get("infection_intensity_strategy", "none")).lower() if qc_strategy in {"", "none", "null"}: adjusted_basename = f"{db_table_name}_adjusted.csv" else: adjusted_basename = f"{db_table_name}_adjusted_{qc_strategy}.csv" adjusted_csv_path = os.path.join(measurements_dir, adjusted_basename) all_df.to_csv(adjusted_csv_path, index=False) print( "[summarise_tracks_from_merged] Saved ADJUSTED frame-level measurements " f"to CSV: {adjusted_csv_path}" ) except Exception as e: print( f"[summarise_tracks_from_merged] WARNING: failed to save adjusted CSV " f"({e})" ) # ------------------------------------------------------------------ # Compute per-track velocities + per-well summary # ------------------------------------------------------------------ ( track_df_mask, per_well_tracks_mask, well_summary_mask, vel_unit_mask, ) = _compute_velocities_and_well_summary( all_df=all_df, settings=settings, infection_col="infected", pixels_per_um=pixels_per_um, seconds_per_frame=seconds_per_frame, ) ( track_df, per_well_tracks, well_summary_df, vel_unit, ) = _compute_velocities_and_well_summary( all_df=all_df, settings=settings, infection_col=infection_col, pixels_per_um=pixels_per_um, seconds_per_frame=seconds_per_frame, ) # ------------------------------------------------------------------ # Save to DB: # - all_df_original (pre-QC snapshot) is written to db_table_name # - well_summary_df is written as usual by _save_measurements_and_well_summary # → adjusted labels NEVER touch the canonical measurements table. # ------------------------------------------------------------------ measurements_dir, db_path = _save_measurements_and_well_summary( all_df=all_df_original, well_summary_df=well_summary_df, src=src, db_table_name=db_table_name, ) # Feature–velocity correlation analysis (final labels, adjusted view) _feature_velocity_correlations(all_df, track_df, measurements_dir) # ------------------------------------------------------------------ # Combined intensity + motility panels # ------------------------------------------------------------------ qc_strategy = str(settings.get("infection_intensity_strategy", "none")).lower() # Intensity + motility panel for mask-based labels if settings.get("make_mask_panel", True): _make_intensity_motility_panel( all_df=all_df, infection_col="infected", track_df=track_df_mask, per_well_tracks=per_well_tracks_mask, n_channels=n_channels, motility_dir=motility_dir, pixels_per_um=pixels_per_um, seconds_per_frame=seconds_per_frame, vel_unit=vel_unit_mask, settings=settings, # encode both label type and QC strategy in the tag label_tag=f"mask_{qc_strategy}", ) # Intensity + motility panel for adjusted labels (if distinct) if ( settings.get("make_adjusted_panel", True) and infection_col in all_df.columns and infection_col != "infected" ): _make_intensity_motility_panel( all_df=all_df, infection_col=infection_col, track_df=track_df, per_well_tracks=per_well_tracks, n_channels=n_channels, motility_dir=motility_dir, pixels_per_um=pixels_per_um, seconds_per_frame=seconds_per_frame, vel_unit=vel_unit, settings=settings, label_tag=f"adjusted_{qc_strategy}", ) return all_df