import os,re, random, cv2, glob, time, math, torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy.ndimage as ndi
import seaborn as sns
import scipy.stats as stats
import statsmodels.api as sm
import imageio.v2 as imageio
from IPython.display import display
from skimage.segmentation import find_boundaries
from skimage import measure
from skimage.measure import find_contours, label, regionprops
from ipywidgets import IntSlider, interact
from IPython.display import Image as ipyimage
from .logger import log_function_call
[docs]
def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, normalize=True, thickness=3, save_pdf=True):
    """
    Plot image channels from a saved stack with object-mask outlines drawn on top.

    The stack is loaded with ``np.load``. Image channels are selected along the
    last axis with ``channels``. Masks are read from the trailing planes of
    axis 2: the pathogen mask is the last plane, the nucleus mask the plane
    before it and the cell mask the plane before that, with planes shifting
    towards the end for every mask whose ``*_channel`` argument is None.

    Args:
        file (str): Path to the stack file to load.
        channels (list): Indices along the last axis to show as image channels.
        cell_channel: If not None, the cell mask is outlined in red.
        nucleus_channel: If not None, the nucleus mask is outlined in green.
        pathogen_channel: If not None, the pathogen mask is outlined in blue.
        figuresize (int, optional): Figure height; the width is four times this value. Defaults to 10.
        normalize (bool, optional): Currently unused; channels are always percentile-normalized. Defaults to True.
        thickness (int, optional): Outline thickness passed to ``cv2.drawContours``. Defaults to 3.
        save_pdf (bool, optional): Save the figure as a PDF in ``results/overlay`` under the
            parent of the directory containing ``file``. Defaults to True.

    Returns:
        None
    """

    def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness):
        """Plot every channel and a combined RGB panel, each with all outlines."""

        def _normalize_image(image, percentiles=(2, 98)):
            """Rescale the image to [0, 1] between the given percentiles."""
            v_min, v_max = np.percentile(image, percentiles)
            if v_max <= v_min:
                # Percentiles collapsed (flat or very sparse channel): fall back to
                # the full intensity range so the division below cannot be 0/0.
                v_min, v_max = np.min(image), np.max(image)
            if v_max <= v_min:
                # Truly constant channel: nothing to stretch.
                return np.zeros_like(image, dtype=float)
            return np.clip((image - v_min) / (v_max - v_min), 0, 1)

        def _generate_contours(mask):
            """Return the external contours of a binary mask using OpenCV."""
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            return contours

        def _apply_contours(image, mask, color, thickness):
            """Draw the contour of every non-zero label of ``mask`` onto ``image``."""
            rgb = mpl.colors.to_rgb(color)
            for label_value in np.unique(mask):
                if label_value == 0:
                    continue  # Skip background
                label_mask = (mask == label_value).astype(np.uint8)
                for contour in _generate_contours(label_mask):
                    cv2.drawContours(image, [contour], -1, rgb, thickness)
            return image

        num_channels = image.shape[-1]
        fig, ax = plt.subplots(1, num_channels + 1, figsize=(4 * figuresize, figuresize))

        # One grayscale panel per channel, with all outlines on top.
        for v in range(num_channels):
            normalized = _normalize_image(image[..., v])
            channel_image_rgb = np.dstack((normalized, normalized, normalized))
            for outline, color in zip(outlines, outline_colors):
                channel_image_rgb = _apply_contours(channel_image_rgb, outline, color, thickness)
            ax[v].imshow(channel_image_rgb)
            ax[v].set_title(f'Image - Channel {v}')

        # Combined panel: the first (up to) three channels become R, G and B.
        rgb_image = np.zeros((*image.shape[:2], 3), dtype=float)
        for i in range(min(3, num_channels)):
            rgb_image[..., i] = _normalize_image(image[..., i])
        for outline, color in zip(outlines, outline_colors):
            rgb_image = _apply_contours(rgb_image, outline, color, thickness)
        ax[-1].imshow(rgb_image)
        ax[-1].set_title('Combined RGB Image')
        plt.tight_layout()

        # Save before showing so the written figure is the populated one.
        if save_pdf:
            pdf_dir = os.path.join(os.path.dirname(os.path.dirname(file)), 'results', 'overlay')
            os.makedirs(pdf_dir, exist_ok=True)
            pdf_path = os.path.join(pdf_dir, os.path.basename(file).replace('.npy', '.pdf'))
            fig.savefig(pdf_path, format='pdf')
        plt.show()
        return fig

    stack = np.load(file)

    # 8-bit and 16-bit integer stacks are promoted to float for normalization.
    if stack.dtype in (np.uint8, np.uint16):
        stack = stack.astype(np.float32)

    image = stack[..., channels]

    # Walk backwards from the last plane: pathogen, then nucleus, then cell.
    outlines = []
    outline_colors = []
    next_mask_dim = -1
    for requested, color in ((pathogen_channel, 'blue'), (nucleus_channel, 'green'), (cell_channel, 'red')):
        if requested is None:
            continue
        outlines.append(np.take(stack, next_mask_dim, axis=2))
        outline_colors.append(color)
        next_mask_dim -= 1

    _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness)
    return
[docs]
def plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True):
    """
    Plot the channels, mask and flow for up to ``nr`` images of a batch.

    Args:
        batch (numpy.ndarray): The batch of images. A 3D array is treated as a single image.
        masks (list or numpy.ndarray): The masks corresponding to the images.
        flows (list or numpy.ndarray): The flows corresponding to the images. When a list is
            given, its first element is used as the per-image sequence of flows.
        cmap (str, optional): The colormap to use for displaying the image channels. Defaults to 'inferno'.
        figuresize (int, optional): Figure height; the width is four times this value. Defaults to 10.
        nr (int, optional): The maximum number of images to plot. Defaults to 1.
        file_type (str, optional): The file type of the flows. When equal to 'png' (no leading dot),
            the first element of every flow entry is plotted. Defaults to '.npz'.
        print_object_number (bool, optional): Whether to print the object number on the mask. Defaults to True.

    Returns:
        None
    """
    if len(batch.shape) == 3:
        batch = np.expand_dims(batch, axis=0)
    if not isinstance(masks, list):
        masks = [masks]
    if not isinstance(flows, list):
        flows = [flows]
    else:
        flows = flows[0]
    if file_type == 'png':
        flows = [f[0] for f in flows]  # assuming this is what you want to do when file_type is 'png'

    font = figuresize/2
    for index, (image, mask, flow) in enumerate(zip(batch, masks, flows)):
        if index >= nr:
            # Nothing further is plotted, so stop instead of scanning the remaining masks.
            break

        unique_labels = np.unique(mask)
        object_labels = unique_labels[unique_labels != 0]

        # One random opaque colour per object, black for the background entry.
        random_colors = np.random.rand(len(object_labels) + 1, 4)
        random_colors[:, 3] = 1
        random_colors[0, :] = [0, 0, 0, 1]
        random_cmap = mpl.colors.ListedColormap(random_colors)

        chans = image.shape[-1]
        fig, ax = plt.subplots(1, chans + 2, figsize=(4 * figuresize, figuresize))
        for v in range(chans):
            ax[v].imshow(image[..., v], cmap=cmap)
            ax[v].set_title('Image - Channel'+str(v))

        ax[chans].imshow(mask, cmap=random_cmap)
        ax[chans].set_title('Mask')
        if print_object_number:
            # Use the non-zero labels explicitly so no object is skipped when
            # the mask contains no background pixels.
            for obj in object_labels:
                cy, cx = ndi.center_of_mass(mask == obj)
                ax[chans].text(cx, cy, str(obj), color='white', fontsize=font, ha='center', va='center')

        ax[chans+1].imshow(flow, cmap='viridis')
        ax[chans+1].set_title('Flow')
        plt.show()
    return
def _plot_4D_arrays(src, figuresize=10, cmap='inferno', nr_npz=1, nr=1):
    """
    Show images from randomly chosen .npz archives, one panel per channel.

    Each archive is expected to expose a ``'data'`` entry that is indexed as
    (image, row, column, channel).

    Args:
        src (str): Directory that is searched (non-recursively) for .npz files.
        figuresize (int, optional): Height of each figure; the width scales with the channel count. Defaults to 10.
        cmap (str, optional): Colormap used for every channel. Defaults to 'inferno'.
        nr_npz (int, optional): Number of archives to sample. Defaults to 1.
        nr (int, optional): Number of images shown per archive. Defaults to 1.
    """
    npz_paths = []
    for entry in os.listdir(src):
        if entry.endswith('.npz'):
            npz_paths.append(os.path.join(src, entry))
    npz_paths = random.sample(npz_paths, min(nr_npz, len(npz_paths)))

    for npz_path in npz_paths:
        with np.load(npz_path) as archive:
            stack = archive['data']
        n_images, n_channels = stack.shape[0], stack.shape[3]

        for image_index in range(min(nr, n_images)):
            frame = stack[image_index]

            # squeeze=False always yields a 2D grid of axes, so the
            # single-channel case needs no special handling.
            fig, grid = plt.subplots(1, n_channels, figsize=(n_channels * figuresize, figuresize), squeeze=False)
            for channel_index, panel in enumerate(grid[0]):
                panel.imshow(frame[:, :, channel_index], cmap=cmap)
                panel.set_title(f'Channel {channel_index}', size=24)
                panel.axis('off')

            fig.tight_layout()
            plt.show()
    return
[docs]
def generate_mask_random_cmap(mask):
    """
    Build a random colormap sized to the labels present in a mask.

    Parameters:
        mask (numpy.ndarray): The input mask array.

    Returns:
        matplotlib.colors.ListedColormap: One opaque random colour per non-zero
        label, preceded by opaque black for the first entry.
    """
    labels = np.unique(mask)
    n_foreground = int(np.count_nonzero(labels))

    palette = np.random.rand(n_foreground + 1, 4)
    palette[:, 3] = 1              # fully opaque
    palette[0, :] = [0, 0, 0, 1]   # first entry is black
    return mpl.colors.ListedColormap(palette)
[docs]
def random_cmap(num_objects=100):
    """
    Build a random colormap with a fixed number of object colours.

    Parameters:
        num_objects (int): The number of objects to generate colors for. Default is 100.

    Returns:
        matplotlib.colors.ListedColormap: ``num_objects + 1`` opaque colours, the
        first of which is black.
    """
    rgba = np.random.rand(num_objects + 1, 4)
    rgba[:, 3] = 1              # fully opaque
    rgba[0, :] = [0, 0, 0, 1]   # first entry is black
    return mpl.colors.ListedColormap(rgba)
def _generate_mask_random_cmap(mask):
    """
    Create a random colormap with one colour per non-zero label of the mask.

    Parameters:
        mask (ndarray): The mask array containing unique labels.

    Returns:
        ListedColormap: Opaque black for the first entry followed by one opaque
        random colour for each non-zero label found in the mask.
    """
    found_labels = np.unique(mask)
    object_count = len(found_labels[found_labels != 0])

    colours = np.random.rand(object_count + 1, 4)
    colours[:, -1] = 1     # alpha column: fully opaque
    colours[0, :3] = 0     # first entry becomes black (alpha already 1)
    return mpl.colors.ListedColormap(colours)
def _get_colours_merged(outline_color):
"""
Get the merged outline colors based on the specified outline color format.
Parameters:
outline_color (str): The outline color format. Can be one of 'rgb', 'bgr', 'gbr', or 'rbg'.
Returns:
list: A list of merged outline colors based on the specified format.
"""
if outline_color == 'rgb':
outline_colors = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rgb
elif outline_color == 'bgr':
outline_colors = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] # bgr
elif outline_color == 'gbr':
outline_colors = [[0, 1, 0], [0, 0, 1], [1, 0, 0]] # gbr
elif outline_color == 'rbg':
outline_colors = [[1, 0, 0], [0, 0, 1], [0, 1, 0]] # rbg
else:
outline_colors = [[1, 0, 0], [0, 0, 1], [0, 1, 0]] # rbg
return outline_colors
[docs]
def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, threshold=1000, extensions=['.npy', '.tif', '.tiff', '.png'], overlay=False, max_nr=None, randomize=True):
    """
    Plot matching image/mask pairs found under the given folders.

    Files are matched across folders by their name without extension; only
    names present in every folder are plotted. A loaded array with more than
    ``threshold`` unique values is treated as an image (and normalized),
    otherwise as a mask.

    Args:
        folders (list): A list of folder paths containing the images and arrays.
        lower_percentile (int, optional): The lower percentile for image normalization. Defaults to 1.
        upper_percentile (int, optional): The upper percentile for image normalization. Defaults to 99.
        threshold (int, optional): The threshold for determining whether to display an image as a mask or normalize it. Defaults to 1000.
        extensions (list, optional): A list of file extensions to consider. The default list is never
            modified. Defaults to ['.npy', '.tif', '.tiff', '.png'].
        overlay (bool, optional): If True, overlay the outlines of the objects on the image. Defaults to False.
        max_nr (int or float, optional): Maximum number of matched names to plot. Defaults to None (no limit).
        randomize (bool, optional): Shuffle the matched names before applying ``max_nr``. Defaults to True.

    Returns:
        None
    """

    def normalize_image(image, lower=1, upper=99):
        """Rescale the image to [0, 1] between the given percentiles."""
        low, high = np.percentile(image, (lower, upper))
        if high <= low:
            # Collapsed percentile range: fall back to the full range so the
            # division below cannot produce NaN/inf.
            low, high = np.min(image), np.max(image)
        if high <= low:
            return np.zeros_like(image, dtype=float)
        return np.clip((image - low) / (high - low), 0, 1)

    def find_files(folders, extensions=['.npy', '.tif', '.tiff', '.png']):
        """Map each file stem to {folder: path}, keeping stems found in every folder."""
        suffixes = tuple(extensions)
        file_dict = {}
        for folder in folders:
            for root, _, files in os.walk(folder):
                for file in files:
                    if file.endswith(suffixes):
                        stem = os.path.splitext(file)[0]
                        file_dict.setdefault(stem, {})[folder] = os.path.join(root, file)
        # Filter out files that don't have paths in all folders
        return {k: v for k, v in file_dict.items() if len(v) == len(folders)}

    def plot_from_file_dict(file_dict, threshold=1000, lower_percentile=1, upper_percentile=99, overlay=False, save=False):
        """
        Plot the mask next to the normalized image for every entry of ``file_dict``.

        Args:
            file_dict (dict): A dictionary containing the file paths for each image or array.
            threshold (int, optional): The threshold for determining whether to display an image as a mask or normalize it. Defaults to 1000.
            lower_percentile (int, optional): The lower percentile for image normalization. Defaults to 1.
            upper_percentile (int, optional): The upper percentile for image normalization. Defaults to 99.
            overlay (bool, optional): If True, overlay the outlines of the objects on the image. Defaults to False.
            save (bool, optional): If True, write ``<filename>.png`` into the last folder of the entry. Defaults to False.
        """
        for filename, folder_paths in file_dict.items():
            image_data = None
            mask_data = None
            for folder, path in folder_paths.items():
                if path.endswith('.npy'):
                    data = np.load(path)
                elif path.endswith(('.tif', '.tiff', '.png')):
                    # '.png' is an accepted extension, so it has to be loadable too.
                    data = imageio.imread(path)
                else:
                    continue

                if len(np.unique(data)) > threshold:
                    image_data = normalize_image(data, lower_percentile, upper_percentile)
                else:
                    mask_data = data

            if image_data is None or mask_data is None:
                continue

            fig, axes = plt.subplots(1, 2, figsize=(15, 7))

            # Display the mask with random colormap
            cmap = random_cmap(num_objects=len(np.unique(mask_data)))
            axes[0].imshow(mask_data, cmap=cmap)
            axes[0].set_title(f"{filename} - Mask")
            axes[0].axis('off')

            # Display the normalized image
            axes[1].imshow(image_data, cmap='gray')
            if overlay:
                labeled_mask = label(mask_data)
                for region in regionprops(labeled_mask):
                    if region.image.shape[0] >= 2 and region.image.shape[1] >= 2:
                        for contour in find_contours(region.image, 0.75):
                            # Adjust contour coordinates relative to the full image
                            contour[:, 0] += region.bbox[0]
                            contour[:, 1] += region.bbox[1]
                            axes[1].plot(contour[:, 1], contour[:, 0], linewidth=2, color='magenta')
            axes[1].set_title(f"{filename} - Normalized Image")
            axes[1].axis('off')

            plt.tight_layout()
            if save:
                # Save before plt.show(): saving afterwards can write a blank canvas
                # once the shown figure has been released.
                save_path = os.path.join(folder, f"{filename}.png")
                fig.savefig(save_path)
            plt.show()

    if overlay:
        print('Overlay will only work on the first two folders in the list')

    file_dict = find_files(folders, extensions)
    items = list(file_dict.items())
    if randomize:
        random.shuffle(items)
    if isinstance(max_nr, (int, float)):
        items = items[:int(max_nr)]
    file_dict = dict(items)

    plot_from_file_dict(file_dict, threshold, lower_percentile, upper_percentile, overlay, save=False)
    return
def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, include_multinucleated, include_multiinfected):
    """
    Filter the objects of the mask planes of a stack and report what was removed.

    Args:
        stack (numpy.ndarray): The input stack; masks are planes along axis 2.
        cell_mask_dim (int): The dimension index of the cell mask.
        nucleus_mask_dim (int): The dimension index of the nucleus mask.
        pathogen_mask_dim (int): The dimension index of the pathogen mask.
        mask_dims (list): The mask plane indices to process, in order.
        filter_min_max (list): Per entry of ``mask_dims`` a ``[min_area, max_area]`` pair
            (both bounds exclusive), or None to skip area filtering.
        include_multinucleated (bool): Whether to include multinucleated cells.
        include_multiinfected (bool): Whether to include multiinfected cells.

    Returns:
        numpy.ndarray: The filtered stack of masks.
    """
    from .utils import _remove_outside_objects, _remove_multiobject_cells

    stack = _remove_outside_objects(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim)

    # name -> (count_before, count_after, mean_area_before, mean_area_after)
    stats = {}

    for i, mask_dim in enumerate(mask_dims):
        mask = np.take(stack, mask_dim, axis=2)
        props = measure.regionprops_table(mask, properties=['label', 'area'])
        avg_size_before = np.mean(props['area'])
        total_count_before = len(props['label'])

        if filter_min_max is not None:
            min_max = filter_min_max[i]
            in_range = np.logical_and(props['area'] > min_max[0], props['area'] < min_max[1])
            valid_labels = props['label'][in_range]
            stack[:, :, mask_dim] = np.isin(mask, valid_labels) * mask

        # "After" statistics reflect the area filter only; they are taken
        # before any multi-object cells are removed below.
        props_after = measure.regionprops_table(stack[:, :, mask_dim], properties=['label', 'area'])
        avg_size_after = np.mean(props_after['area'])
        total_count_after = len(props_after['label'])
        summary = (total_count_before, total_count_after, avg_size_before, avg_size_after)

        if mask_dim == cell_mask_dim:
            # Each call passes the object plane its own guard checked. Previously the
            # two were crossed, so a missing plane (None) could be passed as object_dim.
            # NOTE(review): object_dim is presumed to select the plane whose objects are
            # counted per cell - confirm against utils._remove_multiobject_cells.
            if include_multinucleated is False and nucleus_mask_dim is not None:
                stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=nucleus_mask_dim)
            if include_multiinfected is False and cell_mask_dim is not None and pathogen_mask_dim is not None:
                stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=pathogen_mask_dim)
            stats['cell'] = summary
        if mask_dim == nucleus_mask_dim:
            stats['nucleus'] = summary
        if mask_dim == pathogen_mask_dim:
            stats['pathogen'] = summary

    report = (
        ('cell', cell_mask_dim, 'cells'),
        ('nucleus', nucleus_mask_dim, 'nucleus'),
        ('pathogen', pathogen_mask_dim, 'pathogens'),
    )
    for name, dim, plural in report:
        # Only report planes that were actually processed; avoids reading
        # statistics that were never collected.
        if dim is None or name not in stats:
            continue
        count_before, count_after, area_before, area_after = stats[name]
        print(f'removed {count_before-count_after} {plural}, {name} size from {area_before} to {area_after}')

    return stack
[docs]
def plot_arrays(src, figuresize=10, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
    """
    Plot randomly selected .npy arrays from a directory, one panel per channel.

    Parameters:
    - src (str): The directory path containing the arrays.
    - figuresize (int): The size of the figure (default: 10).
    - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
    - nr (int): The number of arrays to plot; capped at the number of .npy files found (default: 1).
    - normalize (bool): Whether to normalize the arrays (default: True).
    - q1 (int): The lower percentile for normalization (default: 1).
    - q2 (int): The upper percentile for normalization (default: 99).

    Returns:
    None
    """
    from .utils import normalize_to_dtype

    paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npy')]
    # Cap the sample size: random.sample raises when asked for more items than exist.
    paths = random.sample(paths, min(nr, len(paths)))
    colormap = plt.get_cmap(cmap)

    for path in paths:
        print(f'Image path:{path}')
        img = np.load(path)
        if normalize:
            img = normalize_to_dtype(array=img, p1=q1, p2=q2)

        if len(img.shape) > 2:
            array_nr = img.shape[2]
            # squeeze=False keeps a 2D grid of axes even for a single channel,
            # where plt.subplots would otherwise return a bare Axes object.
            fig, axs = plt.subplots(1, array_nr, figsize=(figuresize, figuresize), squeeze=False)
            for channel, ax in enumerate(axs[0]):
                ax.imshow(np.take(img, channel, axis=2), cmap=colormap)
                ax.set_title('Channel '+str(channel), size=24)
                ax.axis('off')
        else:
            fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
            ax.imshow(img, cmap=colormap)
            ax.set_title('Channel 0', size=24)
            ax.axis('off')

        fig.tight_layout()
        plt.show()
    return
def _normalize_and_outline(image, remove_background, normalize, normalization_percentiles, overlay, overlay_chans, mask_dims, outline_colors, outline_thickness):
    """
    Prepare a stack for plotting: background removal, normalization and outlines.

    Args:
        image (ndarray): The input stack; it is modified in place when
            ``remove_background`` is set.
        remove_background (bool): Zero every non-mask pixel below the 1st
            percentile of its channel.
        normalize (bool): Normalize with ``normalization_percentiles``; otherwise
            the 0th and 100th percentiles are used.
        normalization_percentiles (list): Lower and upper percentile for normalization.
        overlay (bool): Flag indicating whether to overlay outlines onto the image.
        overlay_chans (list): List of channel indices to overlay.
        mask_dims (list): Channel indices that hold masks.
        outline_colors (list): List of colors for the outlines.
        outline_thickness (int): Thickness of the outlines.

    Returns:
        tuple: ``(overlayed_image, image, outlines)`` when ``overlay`` is set,
        otherwise ``([], image_without_mask_channels, [])``.
    """
    from .utils import normalize_to_dtype, _outline_and_overlay, _gen_rgb_image

    if remove_background:
        # Per-channel thresholds are computed once, before anything is zeroed.
        thresholds = np.percentile(image, 1, axis=(0, 1))
        for channel_index in range(image.shape[-1]):
            if channel_index in mask_dims:
                continue
            plane = image[:, :, channel_index]  # view: writes land in `image`
            plane[plane < thresholds[channel_index]] = 0

    if normalize:
        low, high = normalization_percentiles[0], normalization_percentiles[1]
    else:
        low, high = 0, 100
    image = normalize_to_dtype(array=image, p1=low, p2=high)

    rgb_image = _gen_rgb_image(image, channels=overlay_chans)

    if not overlay:
        # Remove mask_dims from image
        kept_channels = [c for c in range(image.shape[-1]) if c not in mask_dims]
        return [], np.take(image, kept_channels, axis=-1), []

    overlayed_image, outlines, image = _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness)
    return overlayed_image, image, outlines
def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_image, outlines, cmap, outline_colors, print_object_number):
    """
    Plot the optional overlay, every image channel with outlines, and every mask.

    Args:
        overlay (bool): Flag indicating whether to show the overlayed image as first panel.
        image (ndarray): Input image array; channels are along the last axis.
        stack (ndarray): Stack the masks are taken from (planes along axis 2).
        mask_dims (list): List of mask plane indices.
        figuresize (float): Figure height; the width is four times this value.
        overlayed_image (ndarray): Overlayed image array (used only when ``overlay`` is set).
        outlines (list): List of outline arrays; every non-zero pixel is painted.
        cmap (str): Currently unused; masks are shown with a random colormap.
        outline_colors (list): List of outline colors, paired with ``outlines``.
        print_object_number (bool): Flag indicating whether to print object numbers on the masks.

    Returns:
        fig (Figure): The generated matplotlib figure.
    """
    n_channels = image.shape[-1]
    ax_index = 1 if overlay else 0
    n_panels = n_channels + len(mask_dims) + ax_index

    # squeeze=False keeps `ax` indexable even when only one panel is needed.
    fig, ax = plt.subplots(1, n_panels, figsize=(4 * figuresize, figuresize), squeeze=False)
    ax = ax[0]

    if overlay:
        ax[0].imshow(overlayed_image)
        ax[0].set_title('Overlayed Image')

    # Normalize and plot each channel with outlines
    for v in range(n_channels):
        channel = image[..., v].astype(float)
        channel -= channel.min()
        peak = channel.max()
        if peak > 0:
            # A constant channel stays all-zero instead of becoming 0/0 = NaN.
            channel /= peak
        channel_image_rgb = np.dstack((channel, channel, channel))

        # Paint every outline pixel in one pass rather than once per label.
        for outline, color in zip(outlines, outline_colors):
            channel_image_rgb[outline != 0] = mpl.colors.to_rgb(color)

        ax[v + ax_index].imshow(channel_image_rgb)
        ax[v + ax_index].set_title('Image - Channel'+str(v))

    for i, mask_dim in enumerate(mask_dims):
        mask = np.take(stack, mask_dim, axis=2)
        mask_ax = ax[i + n_channels + ax_index]
        mask_ax.imshow(mask, cmap=_generate_mask_random_cmap(mask))
        mask_ax.set_title('Mask '+ str(i))
        if print_object_number:
            labels = np.unique(mask)
            # Select non-zero labels explicitly so none is skipped when the
            # mask has no background pixels.
            for obj in labels[labels != 0]:
                cy, cx = ndi.center_of_mass(mask == obj)
                mask_ax.text(cx, cy, str(obj), color='white', fontsize=8, ha='center', va='center')

    plt.tight_layout()
    plt.show()
    return fig
[docs]
def plot_merged(src, settings):
    """
    Plot merged images and masks for up to ``settings['nr']`` stacks in a folder.

    Args:
        src (path): Path to folder with ``.npy`` stacks.
        settings (dict): The settings for the plot. Keys read here: 'outline_color',
            'cell_mask_dim', 'nucleus_mask_dim', 'pathogen_mask_dim', 'verbose',
            'include_noninfected', 'include_multiinfected', 'include_multinucleated',
            'filter_min_max', 'remove_background', 'normalize', 'normalization_percentiles',
            'overlay', 'overlay_chans', 'outline_thickness', 'nr', 'figuresize', 'cmap'
            and 'print_object_number'. 'include_multiinfected' is set to True in place
            when no pathogen mask is configured.

    Returns:
        matplotlib.figure.Figure or None: The last figure that was plotted, or
        None when nothing was plotted.
    """
    from .utils import _remove_noninfected

    outline_colors = _get_colours_merged(settings['outline_color'])
    mask_dims = [settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim']]
    mask_dims = [element for element in mask_dims if element is not None]

    if settings['verbose']:
        display(settings)

    if settings['pathogen_mask_dim'] is None:
        settings['include_multiinfected'] = True

    fig = None
    plotted = 0
    for file in os.listdir(src):
        if plotted >= settings['nr']:
            # Stop before loading and filtering a stack that would never be shown.
            break
        if not file.endswith('.npy'):
            # Anything else in the folder cannot be handled as a stack.
            continue

        path = os.path.join(src, file)
        stack = np.load(path)
        print(f'Loaded: {path}')

        if not settings['include_noninfected']:
            if settings['pathogen_mask_dim'] is not None and settings['cell_mask_dim'] is not None:
                stack = _remove_noninfected(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'])

        if settings['include_multiinfected'] is not True or settings['include_multinucleated'] is not True or settings['filter_min_max'] is not None:
            stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['include_multinucleated'], settings['include_multiinfected'])

        overlayed_image, image, outlines = _normalize_and_outline(image=stack,
                                                                  remove_background=settings['remove_background'],
                                                                  normalize=settings['normalize'],
                                                                  normalization_percentiles=settings['normalization_percentiles'],
                                                                  overlay=settings['overlay'],
                                                                  overlay_chans=settings['overlay_chans'],
                                                                  mask_dims=mask_dims,
                                                                  outline_colors=outline_colors,
                                                                  outline_thickness=settings['outline_thickness'])

        fig = _plot_merged_plot(overlay=settings['overlay'],
                                image=image,
                                stack=stack,
                                mask_dims=mask_dims,
                                figuresize=settings['figuresize'],
                                overlayed_image=overlayed_image,
                                outlines=outlines,
                                cmap=settings['cmap'],
                                outline_colors=outline_colors,
                                print_object_number=settings['print_object_number'])
        plotted += 1

    return fig
def _plot_images_on_grid(image_files, channel_indices, um_per_pixel, scale_bar_length_um=5, fontsize=8, show_filename=True, channel_names=None, plot=False):
    """
    Plot images on a near-square grid with a scale bar and optional channel names.

    Args:
        image_files (list): List of image file paths.
        channel_indices (list): List of channel indices to select from the images. One index
            shows that channel in grayscale, two indices show their mean in grayscale,
            more are shown as a colour image. None shows the image as read.
        um_per_pixel (float): Micrometers per pixel.
        scale_bar_length_um (float, optional): Length of the scale bar in micrometers. Defaults to 5.
        fontsize (int, optional): Font size for the image titles. Defaults to 8.
        show_filename (bool, optional): Whether to show the image file names as titles. Defaults to True.
        channel_names (list, optional): List of channel names. Defaults to None.
        plot (bool, optional): Whether to display the plot. Defaults to False.

    Returns:
        matplotlib.figure.Figure: The generated figure object.
    """
    print(f'scale bar represents {scale_bar_length_um} um')
    nr_of_images = len(image_files)

    # At least a 1x1 grid, so an empty list does not end up dividing by zero.
    cols = max(1, int(np.ceil(np.sqrt(nr_of_images))))
    rows = max(1, int(np.ceil(nr_of_images / cols)))

    # squeeze=False always returns a 2D array of axes, so flatten() also
    # works for a single image.
    fig, axes = plt.subplots(rows, cols, figsize=(20, 20), facecolor='black', squeeze=False)
    fig.patch.set_facecolor('black')
    axes = axes.flatten()

    # Calculate the scale bar length in pixels
    scale_bar_length_px = int(scale_bar_length_um / um_per_pixel)

    channel_colors = ['red', 'green', 'blue']
    for i, image_file in enumerate(image_files):
        img_array = cv2.imread(image_file, cv2.IMREAD_UNCHANGED)

        # OpenCV reads colour images as BGR; matplotlib expects RGB.
        if img_array.ndim == 3 and img_array.shape[2] >= 3:
            img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)

        # Handle different channel selections
        if channel_indices is not None:
            if len(channel_indices) == 1:  # Single channel (grayscale)
                img_array = img_array[:, :, channel_indices[0]]
                cmap = 'gray'
            elif len(channel_indices) == 2:  # Dual channels
                img_array = np.mean(img_array[:, :, channel_indices], axis=2)
                cmap = 'gray'
            else:  # RGB or more channels
                img_array = img_array[:, :, channel_indices]
                cmap = None
        else:
            cmap = None if img_array.ndim == 3 else 'gray'

        # Normalize based on dtype
        if img_array.dtype == np.uint16:
            img_array = img_array.astype(np.float32) / 65535.0
        elif img_array.dtype == np.uint8:
            img_array = img_array.astype(np.float32) / 255.0

        ax = axes[i]
        ax.imshow(img_array, cmap=cmap)
        ax.axis('off')
        if show_filename:
            ax.set_title(os.path.basename(image_file), color='white', fontsize=fontsize, pad=20)

        # Add scale bar
        ax.plot([10, 10 + scale_bar_length_px], [img_array.shape[0] - 10] * 2, lw=2, color='white')

    # Add channel names at the top if specified
    initial_offset = 0.02  # Starting offset from the left side of the figure
    increment = 0.05  # Fixed increment for each subsequent channel name, adjust based on figure width
    if channel_names:
        current_offset = initial_offset
        for name_index, channel_name in enumerate(channel_names):
            color = channel_colors[name_index] if name_index < len(channel_colors) else 'white'
            fig.text(current_offset, 0.99, channel_name, color=color, fontsize=fontsize,
                     verticalalignment='top', horizontalalignment='left',
                     bbox=dict(facecolor='black', edgecolor='none', pad=3))
            current_offset += increment

    # Hide every grid cell that did not receive an image. The range is derived
    # from the image count, not from a loop variable another loop may have reused.
    for unused_ax in axes[nr_of_images:]:
        unused_ax.axis('off')

    plt.tight_layout(pad=3)
    if plot:
        plt.show()
    return fig
def _save_scimg_plot(src, nr_imgs=16, channel_indices=[0,1,2], um_per_pixel=0.1, scale_bar_length_um=10, standardize=True, fontsize=8, show_filename=True, channel_names=None, dpi=300, plot=False, i=1, all_folders=1):
    """
    Save grids of single-cell images: one with all channels, then one per channel.

    Args:
        src (str): The source directory path.
        nr_imgs (int, optional): The number of images to visualize. Defaults to 16.
        channel_indices (list, optional): List of channel indices to visualize. The default
            list is never modified. Defaults to [0,1,2].
        um_per_pixel (float, optional): Micrometers per pixel. Defaults to 0.1.
        scale_bar_length_um (float, optional): Length of the scale bar in micrometers. Defaults to 10.
        standardize (bool, optional): Whether to standardize the image sizes. Defaults to True.
        fontsize (int, optional): Font size for the filename. Defaults to 8.
        show_filename (bool, optional): Whether to show the filename on the image. Defaults to True.
        channel_names (list, optional): List of channel names (all-channel figure only). Defaults to None.
        dpi (int, optional): Currently unused. Defaults to 300.
        plot (bool, optional): Whether to plot the images. Defaults to False.
        i (int, optional): Currently unused. Defaults to 1.
        all_folders (int, optional): Currently unused. Defaults to 1.

    Returns:
        None
    """
    from .io import _save_figure

    def _visualize_scimgs(src, channel_indices=None, um_per_pixel=0.1, scale_bar_length_um=10, show_filename=True, standardize=True, nr_imgs=None, fontsize=8, channel_names=None, plot=False):
        """
        Collect the image files under ``src`` and plot them on a grid.

        Args:
            src (str): The source directory path.
            channel_indices (list, optional): List of channel indices to visualize. Defaults to None.
            um_per_pixel (float, optional): Micrometers per pixel. Defaults to 0.1.
            scale_bar_length_um (float, optional): Length of the scale bar in micrometers. Defaults to 10.
            show_filename (bool, optional): Whether to show the filename on the image. Defaults to True.
            standardize (bool, optional): Whether to standardize the image sizes. Defaults to True.
            nr_imgs (int, optional): The number of images to visualize. Defaults to None.
            fontsize (int, optional): Font size for the filename. Defaults to 8.
            channel_names (list, optional): List of channel names. Defaults to None.
            plot (bool, optional): Whether to plot the images. Defaults to False.

        Returns:
            matplotlib.figure.Figure: The figure object containing the plotted images.
        """
        from .utils import _find_similar_sized_images

        def _generate_filelist(src):
            """Return the paths directly under ``src`` that have an image file extension."""
            image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.gif')
            return [f for f in glob.glob(os.path.join(src, '*')) if f.lower().endswith(image_suffixes)]

        def _random_sample(file_list, nr_imgs=None):
            """
            Pick a reproducible subset of ``nr_imgs`` files.

            The list is returned unchanged when ``nr_imgs`` is None or not
            smaller than the list length.
            """
            if nr_imgs is not None and nr_imgs < len(file_list):
                # A private generator seeded with 42 yields the same selection as
                # seeding the module-level generator, without resetting the global
                # random state for the rest of the program.
                file_list = random.Random(42).sample(file_list, nr_imgs)
            return file_list

        image_files = _generate_filelist(src)
        if standardize:
            image_files = _find_similar_sized_images(image_files)
        if nr_imgs is not None:
            image_files = _random_sample(image_files, nr_imgs)

        return _plot_images_on_grid(image_files, channel_indices, um_per_pixel, scale_bar_length_um, fontsize, show_filename, channel_names, plot)

    fig = _visualize_scimgs(src, channel_indices, um_per_pixel, scale_bar_length_um, show_filename, standardize, nr_imgs, fontsize, channel_names, plot)
    _save_figure(fig, src, text='all_channels')

    # One additional figure per channel; nothing more to do when no explicit
    # channel list was given.
    for channel in (channel_indices or []):
        fig = _visualize_scimgs(src, [channel], um_per_pixel, scale_bar_length_um, show_filename, standardize, nr_imgs, fontsize, channel_names=None, plot=plot)
        _save_figure(fig, src, text=f'channel_{channel}')

    return
def _plot_cropped_arrays(stack, filename, figuresize=10, cmap='inferno', threshold=500):
    """
    Plot a 2D array, or every channel of a stack, as image panels.

    A panel with at most ``threshold`` unique values is drawn with the
    colormap returned by ``_generate_mask_random_cmap`` and gets the number
    of unique values appended to its title; other panels use ``cmap``.

    Args:
        stack (ndarray): The array to be plotted. Either 2D, or with the
            channels along axis 2.
        filename (str): Name printed once the panels have been drawn.
        figuresize (int, optional): Width and height of the figure. Defaults to 10.
        cmap (str, optional): The colormap to be used. Defaults to 'inferno'.
        threshold (int, optional): Maximum number of unique values for a
            panel to be drawn with the mask colormap. Defaults to 500.

    Returns:
        matplotlib.figure.Figure: The figure holding the panels.

    Raises:
        ValueError: If ``stack`` has fewer than two dimensions.
    """
    dim = stack.shape
    if len(dim) < 2:
        # Previously this fell through to `return fig` with `fig` unbound.
        raise ValueError(f'Expected an array with at least 2 dimensions, got shape {dim}.')

    base_cmap = plt.get_cmap(cmap)

    def plot_single_array(array, ax, title, chosen_cmap):
        num_unique_values = len(np.unique(array))
        if num_unique_values <= threshold:
            chosen_cmap = _generate_mask_random_cmap(array)
            title = f'{title}, {num_unique_values} (obj.)'
        ax.imshow(array, cmap=chosen_cmap)
        ax.set_title(title, size=18)
        ax.axis('off')

    if len(dim) == 2:
        fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
        plot_single_array(stack, ax, 'Channel one', base_cmap)
        fig.tight_layout()
        plt.show()
    else:
        num_channels = dim[2]
        # squeeze=False keeps a 2D axes array, so a single-channel stack can
        # still be indexed instead of failing on a bare Axes object.
        fig, axs = plt.subplots(1, num_channels, figsize=(figuresize, figuresize), squeeze=False)
        for channel in range(num_channels):
            plot_single_array(stack[:, :, channel], axs[0, channel], f'C. {channel}', base_cmap)
        fig.tight_layout()

    print(f'{filename}')
    return fig
def _visualize_and_save_timelapse_stack_with_tracks(masks, tracks_df, save, src, name, plot, filenames, object_type, mode='btrack', interactive=False):
    """
    Show and/or export a label-mask timelapse with its tracks drawn on top.

    Args:
        masks (list): One label mask per frame.
        tracks_df (pandas.DataFrame): Track table; the 'track_id', 'x' and 'y' columns are read here.
        save (bool): Write the timelapse as a GIF under ``<dirname(src)>/movies/gif``.
        src (str): Source path; its parent directory receives the output folder.
        name (str): Suffix used in the GIF file name.
        plot (bool): Display the result (frame slider, or the saved GIF).
        filenames (list): Per-frame file names, forwarded to the GIF writer.
        object_type (str): Object type, used in the GIF file name.
        mode (str, optional): Tracking mode; not read by this function. Defaults to 'btrack'.
        interactive (bool, optional): Show a frame slider instead of the saved GIF. Defaults to False.
    """
    from .io import _save_mask_timelapse_as_gif

    top_label = max(np.max(frame_mask) for frame_mask in masks)

    # Colour table: one random opaque RGBA row per label, black for label 0.
    palette = np.random.rand(top_label + 1, 4)
    palette[:, 3] = 1
    palette[0] = [0, 0, 0, 1]
    cmap = plt.cm.colors.ListedColormap(palette)

    # Fixed value range shared by all frames.
    norm = plt.cm.colors.Normalize(vmin=0, vmax=top_label)

    def _view_frame_with_tracks(frame=0):
        """
        Draw one frame with its label numbers and every track.
        Parameters:
        frame (int): Index of the frame to draw.
        Returns:
        None
        """
        fig, ax = plt.subplots(figsize=(50, 50))
        frame_mask = masks[frame]
        ax.imshow(frame_mask, cmap=cmap, norm=norm)
        ax.set_title(f'Frame: {frame}')

        # Write each non-background label at the mean position of its pixels.
        object_labels = [value for value in np.unique(frame_mask) if value != 0]
        for value in object_labels:
            y, x = np.mean(np.where(frame_mask == value), axis=1)
            ax.text(x, y, str(value), color='white', fontsize=24, ha='center', va='center')

        # Draw every track as a thin black line.
        track_ids = tracks_df['track_id']
        for track_id in track_ids.unique():
            path = tracks_df[track_ids == track_id]
            ax.plot(path['x'], path['y'], '-k', linewidth=1)

        ax.axis('off')
        plt.show()

    if plot and interactive:
        frame_slider = IntSlider(min=0, max=len(masks)-1, step=1, value=0)
        interact(_view_frame_with_tracks, frame=frame_slider)

    if save:
        gif_path = os.path.join(os.path.dirname(src), 'movies', 'gif')
        os.makedirs(gif_path, exist_ok=True)
        save_path_gif = os.path.join(gif_path, f'timelapse_masks_{object_type}_{name}.gif')
        _save_mask_timelapse_as_gif(masks, tracks_df, save_path_gif, cmap, norm, filenames)
        # The GIF can only be shown once it exists, hence inside `if save`.
        if plot and not interactive:
            _display_gif(save_path_gif)
def _display_gif(path):
    """
    Show the GIF stored at ``path`` through the IPython display machinery.
    Parameters:
    path (str): The path to the GIF image file.
    Returns:
    None
    """
    with open(path, 'rb') as gif_file:
        gif_bytes = gif_file.read()
    display(ipyimage(gif_bytes))
def _plot_recruitment(df, df_type, channel_of_interest, target, columns=None, figuresize=10):
    """
    Plot recruitment bar charts per condition, coloured by pathogen.

    Two figures are shown. The first has one panel each for the cell,
    nucleus, cytoplasm and pathogen mean intensity of
    ``channel_of_interest``. The second has one panel for every entry of
    ``columns`` followed by a fixed set of pathogen/cytoplasm measurement
    columns.

    Args:
        df (DataFrame): Input data; must contain 'condition', 'pathogen' and
            every plotted column.
        df_type (str): Text appended to the x-axis label ('pathogen <df_type>').
        channel_of_interest (str): Channel used to build the plotted column names.
        target (str): Not used by the current implementation.
        columns (list, optional): Additional columns to plot. Defaults to None
            (no additional columns).
        figuresize (int, optional): Base size of the figures. Defaults to 10.

    Returns:
        None
    """
    color_list = [(55/255, 155/255, 155/255),
                  (155/255, 55/255, 155/255),
                  (55/255, 155/255, 255/255),
                  (255/255, 55/255, 155/255)]
    sns.set_palette(sns.color_palette(color_list))
    font = figuresize/2
    x_label = f'pathogen {df_type}'

    def _bar_panel(ax, y_col, ylabel_size):
        """Draw one condition-vs-``y_col`` bar panel and drop its legend."""
        sns.barplot(ax=ax, data=df, x='condition', y=y_col, hue='pathogen', capsize=.1, ci='sd', dodge=False)
        ax.set_xlabel(x_label, fontsize=font)
        ax.set_ylabel(y_col, fontsize=ylabel_size)
        ax.legend_.remove()
        ax.tick_params(axis='both', which='major', labelsize=font)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

    # Figure 1: mean intensity of the channel of interest in each compartment.
    compartments = ('cell', 'nucleus', 'cytoplasm', 'pathogen')
    fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(figuresize, figuresize/4))
    for ax, compartment in zip(axes, compartments):
        _bar_panel(ax, f'{compartment}_channel_{channel_of_interest}_mean_intensity', font)

    # Only the last panel keeps a legend, placed outside the axes.
    handles, labels = axes[3].get_legend_handles_labels()
    axes[3].legend(handles, labels, bbox_to_anchor=(1.05, 0.5), loc='center left')
    plt.tight_layout()
    plt.show()

    # Figure 2: caller-supplied columns plus the fixed measurement columns.
    columns = list(columns) if columns is not None else []
    columns = columns + ['pathogen_cytoplasm_mean_mean', 'pathogen_cytoplasm_q75_mean', 'pathogen_periphery_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_q75_mean']
    columns = columns + [f'pathogen_slope_channel_{channel_of_interest}', f'pathogen_cell_distance_channel_{channel_of_interest}', f'nucleus_cell_distance_channel_{channel_of_interest}']

    width = figuresize*2
    columns_per_row = math.ceil(len(columns) / 2)
    height = (figuresize*2)/columns_per_row

    fig, axes = plt.subplots(nrows=2, ncols=columns_per_row, figsize=(width, height * 2))
    axes = axes.flatten()
    print(f'{columns}')

    for i, col in enumerate(columns):
        _bar_panel(axes[i], f'{col}', int(font*2))
        # The first six panels start their y-axis at 1.
        if i <= 5:
            axes[i].set_ylim(1, None)

    # Hide grid cells that did not receive a panel.
    for unused_ax in axes[len(columns):]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()
def _plot_controls(df, mask_chans, channel_of_interest, figuresize=5):
    """
    Plot mean intensity per compartment for each channel and condition.

    One row of panels is drawn per unique value of ``df['condition']`` and
    one panel per channel. Each panel shows the mean (with standard deviation
    error bars) of the cell, nucleus, pathogen and cytoplasm
    ``*_channel_<n>_mean_intensity`` columns.

    Args:
        df (pandas.DataFrame): The data; must contain a 'condition' column.
        mask_chans (list): Channels to include. The list is not modified.
        channel_of_interest (int): Channel added to ``mask_chans`` before plotting.
        figuresize (int, optional): Base size of each panel. Defaults to 5.

    Returns:
        None
    """
    # Copy before extending so the caller's list does not grow on every call.
    mask_chans = list(mask_chans) + [channel_of_interest]
    # With up to four channels the channel numbers are replaced by 0..n-1.
    if len(mask_chans) <= 4:
        mask_chans = list(range(len(mask_chans)))

    components = ["cell", "nucleus", "pathogen", "cytoplasm"]
    controls_cols = [[f'{component}_channel_{chan}_mean_intensity' for component in components]
                     for chan in mask_chans]

    unique_conditions = df['condition'].unique().tolist()

    # squeeze=False keeps `axes` 2D even with a single condition.
    # NOTE(review): one more column is allocated than channels are plotted,
    # as in the original layout; confirm whether the empty panel is intended.
    fig, axes = plt.subplots(len(unique_conditions), len(mask_chans)+1,
                             figsize=(figuresize*len(mask_chans), figuresize*len(unique_conditions)),
                             squeeze=False)

    # Define RGB color tuples (scaled to 0-1 range)
    color_list = [(55/255, 155/255, 155/255),
                  (155/255, 55/255, 155/255),
                  (55/255, 155/255, 255/255),
                  (255/255, 55/255, 155/255)]

    for idx_condition, condition in enumerate(unique_conditions):
        df_temp = df[df['condition'] == condition]
        for idx_channel, control_cols_c in enumerate(controls_cols):
            data = []
            std_dev = []
            for control_col in control_cols_c:
                if control_col in df_temp.columns:
                    mean_intensity = df_temp[control_col].mean()
                    data.append(0 if np.isnan(mean_intensity) else mean_intensity)
                    std_dev.append(df_temp[control_col].std())
                else:
                    # Keep four bars so the heights still match the labels.
                    data.append(0)
                    std_dev.append(0)

            current_axis = axes[idx_condition][idx_channel]
            current_axis.bar(components, data, yerr=std_dev,
                             capsize=4, color=color_list)
            current_axis.set_xlabel('Component')
            current_axis.set_ylabel('Mean Intensity')
            current_axis.set_title(f'Condition: {condition} - Channel {idx_channel}')

    plt.tight_layout()
    plt.show()
###################################################
# Classify
###################################################
def _imshow(img, labels, nrow=20, color='white', fontsize=12):
    """
    Tile channel-first images onto one canvas and write a label on each tile.

    Args:
        img (list): Images to display; each is indexed as (channel, height, width).
        labels (list): One label per image; its length sets how many images are drawn.
        nrow (int, optional): Number of images per row in the grid. Defaults to 20.
        color (str, optional): Color of the label text. Defaults to 'white'.
        fontsize (int, optional): Font size of the label text. Defaults to 12.

    Returns:
        matplotlib.figure.Figure: The figure holding the tiled canvas.
    """
    n_images = len(labels)
    n_col = nrow
    n_row = int(np.ceil(n_images / n_col))
    first_shape = img[0].shape
    img_height, img_width = first_shape[1], first_shape[2]

    canvas = np.zeros((img_height * n_row, img_width * n_col, 3))
    for idx in range(n_images):
        grid_row, grid_col = divmod(idx, n_col)
        top = grid_row * img_height
        left = grid_col * img_width
        # Tiles are stored channel-last on the canvas.
        canvas[top:top + img_height, left:left + img_width] = np.transpose(img[idx], (1, 2, 0))

    fig = plt.figure(figsize=(50, 50))
    plt.imshow(canvas)
    plt.axis("off")

    # Labels sit just inside the top-left corner of their tile.
    for idx, text in enumerate(labels):
        grid_row, grid_col = divmod(idx, n_col)
        plt.text(grid_col * img_width + 2, grid_row * img_height + 15, text,
                 color=color, fontsize=fontsize, fontweight='bold')
    return fig
def _imshow_gpu(img, labels, nrow=20, color='white', fontsize=12):
    """
    Tile a batch tensor of images onto one canvas with a label on each tile.

    Args:
        img (torch.Tensor): Batch of images, indexed as (image, channel, height, width).
            A CUDA tensor is copied to the CPU first.
        labels (list): One label per image; its length sets how many images are drawn.
        nrow (int, optional): Number of images per row in the grid. Defaults to 20.
        color (str, optional): Color of the label text. Defaults to 'white'.
        fontsize (int, optional): Font size of the label text. Defaults to 12.

    Returns:
        matplotlib.figure.Figure: The figure holding the tiled canvas.
    """
    if img.is_cuda:
        img = img.cpu()

    n_images = len(labels)
    n_col = nrow
    n_row = int(np.ceil(n_images / n_col))
    img_height, img_width = img.shape[2], img.shape[3]

    # The canvas is assembled as a CPU tensor and converted once at the end.
    canvas = torch.zeros((img_height * n_row, img_width * n_col, 3))
    for idx in range(n_images):
        grid_row, grid_col = divmod(idx, n_col)
        top = grid_row * img_height
        left = grid_col * img_width
        canvas[top:top + img_height, left:left + img_width] = img[idx].permute(1, 2, 0)
    canvas = canvas.numpy()

    fig = plt.figure(figsize=(50, 50))
    plt.imshow(canvas)
    plt.axis("off")

    for idx, text in enumerate(labels):
        grid_row, grid_col = divmod(idx, n_col)
        plt.text(grid_col * img_width + 2, grid_row * img_height + 15, text,
                 color=color, fontsize=fontsize, fontweight='bold')
    return fig
def _plot_histograms_and_stats(df):
    """
    Print summary statistics and show a histogram of 'pred' per condition.

    For every non-missing value of ``df['condition']`` the row count, the
    mean of 'pred', and the counts and percentages of values above and at or
    below 0.5 are printed, followed by a 30-bin histogram with the mean
    marked.

    Args:
        df (pandas.DataFrame): Data with 'condition' and 'pred' columns.

    Returns:
        None
    """
    # A NaN condition never equals itself, so its subset would be empty and
    # the percentages would divide by zero; such entries are skipped.
    conditions = df['condition'].dropna().unique()

    for condition in conditions:
        subset = df[df['condition'] == condition]
        preds = subset['pred']

        # Calculate the statistics
        mean_pred = preds.mean()
        over_0_5 = int((preds > 0.5).sum())
        under_0_5 = int((preds <= 0.5).sum())
        # NaN predictions fall in neither bin, so the total can be zero.
        total = over_0_5 + under_0_5
        percent_positive = (over_0_5 / total) * 100 if total else float('nan')
        percent_negative = (under_0_5 / total) * 100 if total else float('nan')

        # Print the statistics
        print(f"Condition: {condition}")
        print(f"Number of rows: {len(subset)}")
        print(f"Mean of pred: {mean_pred}")
        print(f"Count of pred values over 0.5: {over_0_5}")
        print(f"Count of pred values under 0.5: {under_0_5}")
        print(f"Percent positive: {percent_positive}")
        print(f"Percent negative: {percent_negative}")
        print('-'*40)

        # Plot the histogram
        plt.figure(figsize=(10,6))
        plt.hist(preds, bins=30, edgecolor='black')
        plt.axvline(mean_pred, color='red', linestyle='dashed', linewidth=1, label=f"Mean = {mean_pred:.2f}")
        plt.title(f'Histogram for pred - Condition: {condition}')
        plt.xlabel('Pred Value')
        plt.ylabel('Count')
        plt.legend()
        plt.show()
def _show_residules(model):
    """
    Show residual diagnostics for a fitted model.

    Displays a histogram of the residuals, a QQ plot and a
    residuals-vs-fitted scatter, then prints the Shapiro-Wilk statistic.

    Args:
        model: Fitted model exposing ``resid`` and ``fittedvalues``.

    Returns:
        None
    """
    resid = model.resid

    def _finish(title, xlabel=None, ylabel=None):
        """Apply the title and optional axis labels to the current axes."""
        if xlabel is not None:
            plt.xlabel(xlabel)
        if ylabel is not None:
            plt.ylabel(ylabel)
        plt.title(title)

    # Distribution of the residuals
    plt.hist(resid, bins=30)
    _finish('Histogram of Residuals', 'Residual Value', 'Frequency')
    plt.show()

    # Quantiles against a fitted normal, with the 45-degree reference line
    sm.qqplot(resid, fit=True, line='45')
    _finish('QQ Plot')
    plt.show()

    # Residuals against fitted values, with a zero reference line
    plt.scatter(model.fittedvalues, resid)
    _finish('Residuals vs. Fitted Values', 'Fitted values', 'Residuals')
    plt.axhline(y=0, color='red')
    plt.show()

    # Shapiro-Wilk test for normality
    shapiro_result = stats.shapiro(resid)
    W, p_value = shapiro_result
    print(f'Shapiro-Wilk Test W-statistic: {W}, p-value: {p_value}')
def _reg_v_plot(df, grouping, variable, plate_number):
    """
    Show a volcano plot of 'effect' against -log10('p').

    A '-log10(p)' column is added to ``df`` in place. Points with p < 0.05
    are annotated with their index value.

    Args:
        df (pandas.DataFrame): Data with 'effect' and 'p' columns.
        grouping: Not read by this function.
        variable: Not read by this function.
        plate_number: Not read by this function.

    Returns:
        None
    """
    neg_log_p = -np.log10(df['p'])
    df['-log10(p)'] = neg_log_p

    plt.figure(figsize=(40, 30))
    # Colour encodes only the sign of the effect.
    plt.scatter(df['effect'], df['-log10(p)'], c=np.sign(df['effect']), cmap='coolwarm')
    plt.title('Volcano Plot', fontsize=12)
    plt.xlabel('Coefficient', fontsize=12)
    plt.ylabel('-log10(P-value)', fontsize=12)

    # Label the points below the significance cut-off.
    significant = df[df['p'] < 0.05]
    for idx, row in significant.iterrows():
        plt.text(row['effect'], -np.log10(row['p']), idx, fontsize=12, ha='center', va='bottom', color='black')

    # Reference line at p = 0.05
    plt.axhline(y=-np.log10(0.05), color='gray', linestyle='--')
    plt.show()
[docs]
def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count):
    """
    Aggregate one plate of ``df`` into a row-by-column table for a heatmap.

    ``df['prc']`` is split on '_' into plate, row and column. Rows whose
    ``prc`` is missing or has fewer than three parts get missing values and
    drop out of the aggregation instead of aborting the call.

    Args:
        df (pandas.DataFrame): Data with a 'prc' column and, unless
            ``grouping == 'count'``, a ``variable`` column.
        plate_number (str): Plate identifier (first part of 'prc') to keep.
        variable (str): Column to aggregate per well.
        grouping (str): 'mean', 'sum' or 'count'.
        min_max: 'all' for the table minimum/maximum, 'allq' for its 2nd and
            98th percentiles, a pair of floats used as quantiles of the
            table, or a pair of ints used as explicit limits. Any other
            value is returned unchanged.
        min_count (int): If positive, wells with fewer rows are dropped.

    Returns:
        tuple: ``(plate_map, min_max)``, the pivoted table (missing wells
        filled with 0) and the resolved colour limits.

    Raises:
        ValueError: If ``grouping`` is not supported.
    """
    df = df.copy()  # never modify the caller's frame

    # expand=True pads short or missing 'prc' values with NaN instead of
    # breaking tuple unpacking; only the first three parts are used.
    parts = df['prc'].str.split('_', expand=True).reindex(columns=range(3))
    df['plate'], df['row'], df['col'] = parts[0], parts[1], parts[2]

    df = df[df['plate'] == plate_number].copy()

    # Ordered categories so the heatmap axes run r1..r16 and c1..c27.
    row_order = [f'r{i}' for i in range(1, 17)]
    col_order = [f'c{i}' for i in range(1, 28)]
    df['row'] = pd.Categorical(df['row'], categories=row_order, ordered=True)
    df['col'] = pd.Categorical(df['col'], categories=col_order, ordered=True)

    # Rows per well. 'plate' is non-null for every remaining row, so its
    # count equals the group size.
    df['count'] = df.groupby(['row', 'col'], observed=True)['plate'].transform('count')

    if min_count > 0:
        df = df[df['count'] >= min_count]

    # observed=True restricts groups to wells present in the data.
    grouped = df.groupby(['row', 'col'], observed=True)

    if grouping == 'mean':
        plate = grouped[variable].mean().reset_index()
    elif grouping == 'sum':
        plate = grouped[variable].sum().reset_index()
    elif grouping == 'count':
        variable = 'count'
        plate = grouped[variable].count().reset_index()
    else:
        raise ValueError(f"Unsupported grouping: {grouping}")

    plate_map = pd.pivot_table(plate, values=variable, index='row', columns='col', observed=True).fillna(0)

    if min_max == 'all':
        min_max = [plate_map.min().min(), plate_map.max().max()]
    elif min_max == 'allq':
        min_max = np.quantile(plate_map.values, [0.02, 0.98])
    elif isinstance(min_max, (list, tuple)) and len(min_max) == 2:
        if isinstance(min_max[0], float) and isinstance(min_max[1], float):
            # Two floats are interpreted as quantiles of the table.
            min_max = np.quantile(plate_map.values, [min_max[0], min_max[1]])
        elif isinstance(min_max[0], int) and isinstance(min_max[1], int):
            # Two ints are interpreted as explicit limits.
            min_max = [min_max[0], min_max[1]]

    return plate_map, min_max
[docs]
def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True):
    """
    Draw one heatmap per plate on a four-column grid.

    Plate identifiers are the first '_'-separated part of ``df['prc']``.

    Args:
        df (pandas.DataFrame): Data with a 'prc' column.
        variable (str): Column aggregated per well.
        grouping (str): Aggregation, forwarded to ``generate_plate_heatmap``.
        min_max: Colour-limit specification, forwarded to ``generate_plate_heatmap``.
        cmap: Colormap passed to ``sns.heatmap``.
        min_count (int, optional): Minimum rows per well. Defaults to 0.
        verbose (bool, optional): Show the figure. Defaults to True.

    Returns:
        matplotlib.figure.Figure: The figure holding the heatmaps.
    """
    plate_ids = df['prc'].str.split('_', expand=True)[0].unique()
    n_plates = len(plate_ids)
    n_cols = 4
    n_rows = (n_plates + 3) // 4  # ceiling division by four

    fig, grid = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
    panels = grid.flatten()

    for panel, plate_id in zip(panels, plate_ids):
        plate_map, limits = generate_plate_heatmap(df, plate_id, variable, grouping, min_max, min_count)
        sns.heatmap(plate_map, cmap=cmap, vmin=limits[0], vmax=limits[1], ax=panel)
        panel.set_title(plate_id)

    # Remove grid cells left over after the last plate.
    for spare in panels[n_plates:]:
        fig.delaxes(spare)

    plt.subplots_adjust(wspace=0.1, hspace=0.4)
    if verbose:
        plt.show()
    return fig
[docs]
def print_mask_and_flows(stack, mask, flows, overlay=False):
    """
    Show an image, its mask and the first flow array side by side.

    Args:
        stack (ndarray): Image; a trailing axis of length 1 is squeezed
            away, after which it must be 2D or 3D.
        mask (ndarray): Label mask, shown alone or overlaid on the image.
        flows (list): Non-empty list whose first element is a 2D or 3D
            array; for 3D input only ``[:, :, 0]`` is shown.
        overlay (bool, optional): Draw the mask and its outlines on top of
            the image instead of the mask alone. Defaults to False.

    Raises:
        ValueError: If ``stack`` or ``flows`` has an unexpected structure.
    """
    fig, axs = plt.subplots(1, 3, figsize=(30, 10))
    image_ax, mask_ax, flow_ax = axs[0], axs[1], axs[2]

    if stack.shape[-1] == 1:
        stack = np.squeeze(stack)

    # Panel 1: the image itself
    if stack.ndim not in (2, 3):
        raise ValueError("Unexpected stack dimensionality.")
    is_gray = stack.ndim == 2
    if is_gray:
        image_ax.imshow(stack, cmap='gray')
    else:
        image_ax.imshow(stack)
    image_ax.set_title('Original Image')
    image_ax.axis('off')

    # Panel 2: mask alone, or mask and outlines over the image
    if not overlay:
        mask_ax.imshow(mask, cmap='gray')
    else:
        mask_cmap = generate_mask_random_cmap(mask)
        # Background pixels are masked out so the image stays visible there.
        mask_overlay = np.ma.masked_where(mask == 0, mask)
        outlines = find_boundaries(mask, mode='thick')
        mask_ax.imshow(stack, cmap='gray' if is_gray else None)
        mask_ax.imshow(mask_overlay, cmap=mask_cmap, alpha=0.5)
        mask_ax.contour(outlines, colors='r', linewidths=2)
    mask_ax.set_title('Mask with Overlay' if overlay else 'Mask')
    mask_ax.axis('off')

    # Panel 3: first flow array (first channel when it is 3D)
    if not (flows and isinstance(flows, list) and flows[0].ndim in [2, 3]):
        raise ValueError("Unexpected flow dimensionality or structure.")
    flow_image = flows[0]
    if flow_image.ndim == 3:
        flow_image = flow_image[:, :, 0]
    flow_ax.imshow(flow_image, cmap='jet')
    flow_ax.set_title('Flows')
    flow_ax.axis('off')

    fig.tight_layout()
    plt.show()
[docs]
def plot_resize(images, resized_images, labels, resized_labels):
    """
    Show the first image and label before and after resizing.

    Args:
        images: Original images; only the first is shown.
        resized_images: Resized images; only the first is shown.
        labels: Original labels; only the first is shown.
        resized_labels: Resized labels; only the first is shown.

    Returns:
        None
    """
    fig, ax = plt.subplots(2, 2, figsize=(20, 20))

    def _draw_image(axis, picture, caption):
        """Show ``picture``, using a gray colormap only when it is 2D."""
        if picture.ndim == 2:
            axis.imshow(picture, cmap='gray')
        else:
            axis.imshow(picture)
        axis.set_title(caption)

    # Top row: images
    _draw_image(ax[0, 0], images[0], 'Original Image')
    _draw_image(ax[0, 1], resized_images[0], 'Resized Image')

    # Bottom row: labels, always drawn with the gray colormap
    label_panels = ((ax[1, 0], labels[0], 'Original Label'),
                    (ax[1, 1], resized_labels[0], 'Resized Label'))
    for axis, label_image, caption in label_panels:
        axis.imshow(label_image, cmap='gray')
        axis.set_title(caption)

    plt.show()
[docs]
def normalize_and_visualize(image, normalized_image, title=""):
    """
    Show an image next to its normalized version, both in gray.

    A 3D array is reduced to 2D by averaging over its last axis before it
    is displayed.

    Args:
        image (ndarray): The original image.
        normalized_image (ndarray): The normalized image.
        title (str, optional): Text appended to each panel title. Defaults to "".

    Returns:
        None
    """
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    panels = ((ax[0], image, "Original "),
              (ax[1], normalized_image, "Normalized "))
    for axis, picture, prefix in panels:
        # Average the channels for display; other arrays are shown as they are.
        shown = np.mean(picture, axis=-1) if picture.ndim == 3 else picture
        axis.imshow(shown, cmap='gray')
        axis.set_title(prefix + title)
        axis.axis('off')

    plt.show()
[docs]
def visualize_masks(mask1, mask2, mask3, title="Masks Comparison"):
    """
    Show three masks side by side under a common title.

    Args:
        mask1 (ndarray): First mask.
        mask2 (ndarray): Second mask.
        mask3 (ndarray): Third mask.
        title (str, optional): Title for the whole figure. Defaults to
            "Masks Comparison".

    Returns:
        None
    """
    fig, axs = plt.subplots(1, 3, figsize=(30, 10))
    panel_titles = ['Mask 1', 'Mask 2', 'Mask 3']

    # The loop variable must not be called `title`: that shadowed the
    # parameter and turned the figure title into 'Mask 3'.
    for ax, mask, panel_title in zip(axs, [mask1, mask2, mask3], panel_titles):
        cmap = generate_mask_random_cmap(mask)
        # If the mask is binary, we can skip normalization
        if np.isin(mask, [0, 1]).all():
            ax.imshow(mask, cmap=cmap)
        else:
            # Normalize the image for displaying purposes
            norm = plt.Normalize(vmin=0, vmax=mask.max())
            ax.imshow(mask, cmap=cmap, norm=norm)
        ax.set_title(panel_title)
        ax.axis('off')

    plt.suptitle(title)
    plt.show()
[docs]
def visualize_cellpose_masks(masks, titles=None, filename=None, save=False, src=None):
    """
    Visualize one or more masks side by side, optionally saving the figure.

    Parameters:
    masks (list of np.ndarray): The masks to visualize; at least one.
    titles (list of str, optional): One title per mask. If None, 'Mask 1', 'Mask 2', ... are used.
    filename (str, optional): Used in the figure title and as the PDF file name.
    save (bool, optional): Save the figure as ``<src>/results/<filename>.pdf``. Defaults to False.
    src (str, optional): Base directory for saving. Defaults to the current working directory.

    Raises:
    ValueError: If no masks are given or the number of titles differs from the number of masks.
    """
    comparison_title = f"Masks Comparison for {filename}"
    num_masks = len(masks)

    if num_masks == 0:
        raise ValueError("At least one mask is required")
    if titles is None:
        titles = [f'Mask {i+1}' for i in range(num_masks)]
    # Checked explicitly: an assert is stripped under -O, after which zip()
    # would silently drop the surplus masks or titles.
    if len(titles) != num_masks:
        raise ValueError("Number of titles and masks must match")

    # squeeze=False always returns an axes array, so one mask works too.
    fig, axs = plt.subplots(1, num_masks, figsize=(10 * num_masks, 10), squeeze=False)

    for ax, mask, title in zip(axs.ravel(), masks, titles):
        cmap = generate_mask_random_cmap(mask)
        # Normalize and display the mask
        norm = plt.Normalize(vmin=0, vmax=mask.max())
        ax.imshow(mask, cmap=cmap, norm=norm)
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle(comparison_title)
    plt.show()

    if save:
        if src is None:
            src = os.getcwd()
        results_dir = os.path.join(src, 'results')
        os.makedirs(results_dir, exist_ok=True)
        fig_path = os.path.join(results_dir, f'{filename}.pdf')
        fig.savefig(fig_path, format='pdf')
        print(f'Saved figure to {fig_path}')
    return
[docs]
def plot_comparison_results(comparison_results):
    """
    Show box and strip plots of the comparison metrics, one panel per metric family.

    The results are melted to long form on 'filename'. Metric names
    containing 'jaccard', 'dice', 'boundary_f1' or 'average_precision' go
    to the first, second, third and fourth panel respectively.

    Args:
        comparison_results: Data accepted by ``pd.DataFrame`` that includes a 'filename' field.

    Returns:
        matplotlib.figure.Figure: The figure holding the four panels.
    """
    long_form = pd.melt(pd.DataFrame(comparison_results), id_vars=['filename'], var_name='metric', value_name='value')

    # (substring of the metric name, panel title, y-axis label)
    panel_specs = [
        ('jaccard', 'Jaccard Index by Comparison', 'Jaccard Index'),
        ('dice', 'Dice Coefficient by Comparison', 'Dice Coefficient'),
        ('boundary_f1', 'Boundary F1 Score by Comparison', 'Boundary F1 Score'),
        ('average_precision', 'Average Precision by Comparison', 'Average Precision'),
    ]
    subsets = [long_form[long_form['metric'].str.contains(key)] for key, _, _ in panel_specs]

    fig, axs = plt.subplots(1, 4, figsize=(40, 10))

    for panel, subset, (_, heading, y_label) in zip(axs, subsets, panel_specs):
        sns.boxplot(data=subset, x='metric', y='value', ax=panel, color='lightgrey')
        sns.stripplot(data=subset, x='metric', y='value', ax=panel, jitter=True, alpha=0.6)
        panel.set_title(heading)
        panel.set_xticklabels(panel.get_xticklabels(), rotation=45, horizontalalignment='right')
        panel.set_xlabel('Comparison')
        panel.set_ylabel(y_label)

    plt.tight_layout()
    plt.show()
    return fig
[docs]
def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0,1,2], max_nr=10):
    """
    Plot mask outlines over channel images for each object type.

    Objects and channels are paired by position. For each pair,
    ``<src>/norm_channel_stack/<object>_mask_stack`` and ``<src>/<channel+1>``
    are handed to ``plot_images_and_arrays`` with overlay enabled.

    Args:
        src (str): Base directory of the experiment.
        objects (list, optional): Object names. Defaults to ['nucleus','cell','pathogen'].
        channels (list, optional): Zero-based channel index per object. Defaults to [0,1,2].
        max_nr (int, optional): Forwarded to ``plot_images_and_arrays`` as ``max_nr``. Defaults to 10.

    Returns:
        None
    """
    for object_, channel in zip(objects, channels):
        folders = [os.path.join(src, 'norm_channel_stack', f'{object_}_mask_stack'),
                   os.path.join(src,f'{channel+1}')]
        print(folders)
        plot_images_and_arrays(folders,
                               lower_percentile=2,
                               upper_percentile=99.5,
                               threshold=1000,
                               extensions=['.npy', '.tif', '.tiff', '.png'],
                               overlay=True,
                               # Previously hard-coded to 10, ignoring the argument.
                               max_nr=max_nr,
                               randomize=True)
[docs]
def volcano_plot(coef_df, filename='volcano_plot.pdf'):
    """
    Draw, save and show a volcano plot of regression coefficients.

    Args:
        coef_df (pandas.DataFrame): Data with 'coefficient', '-log10(p_value)'
            and boolean 'highlight' columns.
        filename (str, optional): Path of the saved PDF. Defaults to 'volcano_plot.pdf'.

    Returns:
        None
    """
    highlight_palette = {True: 'red', False: 'blue'}

    plt.figure(figsize=(10, 6))
    sns.scatterplot(data=coef_df, x='coefficient', y='-log10(p_value)',
                    hue='highlight', palette=highlight_palette)

    plt.title('Volcano Plot of Coefficients')
    plt.xlabel('Coefficient')
    plt.ylabel('-log10(p-value)')

    # Dashed reference line at p = 0.05
    significance_level = -np.log10(0.05)
    plt.axhline(y=significance_level, color='red', linestyle='--')

    legend = plt.legend()
    legend.remove()

    plt.savefig(filename, format='pdf')
    print(f'Saved Volcano plot: {filename}')
    plt.show()
[docs]
def plot_histogram(df, dependent_variable):
    """
    Show a histogram with a KDE curve for one column of ``df``.

    Args:
        df (pandas.DataFrame): The data.
        dependent_variable (str): Name of the column to plot.

    Returns:
        None
    """
    values = df[dependent_variable]

    plt.figure(figsize=(10, 6))
    sns.histplot(values, kde=True)
    plt.title(f'Histogram of {dependent_variable}')
    plt.xlabel(dependent_variable)
    plt.ylabel('Frequency')
    plt.show()
[docs]
def plot_lorenz_curves(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
    """
    Plot a Lorenz curve per CSV file plus one for all files combined.

    Each file must contain 'key' and 'value' columns. Rows whose 'key' is in
    ``remove_keys`` are dropped. The figure is saved as 'lorenz_curve.pdf'
    next to the second file, or next to the only file when a single path is
    given.

    Args:
        csv_files (list): Paths of the CSV files to read; at least one.
        remove_keys (list, optional): Keys to exclude from every file.

    Returns:
        None

    Raises:
        ValueError: If ``csv_files`` is empty or no values remain after filtering.
    """
    if len(csv_files) == 0:
        raise ValueError("csv_files must contain at least one path")

    def lorenz_curve(data):
        """Return the cumulative share of the sorted data, starting at 0."""
        cumulative = np.cumsum(np.sort(data))
        return np.insert(cumulative / cumulative[-1], 0, 0)

    # The output directory used to be bound only when a second file existed,
    # so a single-file call failed at savefig with a NameError.
    anchor_file = csv_files[1] if len(csv_files) > 1 else csv_files[0]
    save_path = os.path.join(os.path.dirname(anchor_file), 'lorenz_curve.pdf')

    keys_to_drop = list(remove_keys) if remove_keys is not None else []
    combined_data = []

    plt.figure(figsize=(10, 6))

    for csv_file in csv_files:
        df = pd.read_csv(csv_file)
        df = df[~df['key'].isin(keys_to_drop)]
        values = df['value'].values
        if len(values) == 0:
            # A curve cannot be normalised without data; skip this file.
            print(f'No values left in {csv_file}, skipping.')
            continue
        combined_data.extend(values)

        lorenz = lorenz_curve(values)
        # The first three characters of the file name label the curve.
        name = os.path.basename(csv_file)[:3]
        plt.plot(np.linspace(0, 1, len(lorenz)), lorenz, label=name)

    if not combined_data:
        raise ValueError("No values available to compute a Lorenz curve")

    # Plot combined Lorenz curve
    combined_lorenz = lorenz_curve(np.array(combined_data))
    plt.plot(np.linspace(0, 1, len(combined_lorenz)), combined_lorenz, label="Combined Lorenz Curve", linestyle='--', color='black')

    plt.title('Lorenz Curves')
    plt.xlabel('Cumulative Share of Individuals')
    plt.ylabel('Cumulative Share of Value')
    plt.legend()
    plt.grid(False)
    plt.savefig(save_path)
    plt.show()
[docs]
def plot_permutation(permutation_df):
    """
    Draw permutation importances as horizontal bars with error bars.

    Args:
        permutation_df (pandas.DataFrame): Data with 'feature',
            'importance_mean' and 'importance_std' columns.

    Returns:
        matplotlib.figure.Figure: The figure holding the bar chart.
    """
    n_rows = len(permutation_df)
    # Height grows with the number of features but never drops below 8;
    # the font shrinks with it but never drops below 10.
    figure_size = (10, max(8, n_rows * 0.3))
    label_size = max(10, 12 - n_rows * 0.2)

    fig, ax = plt.subplots(figsize=figure_size)
    ax.barh(
        permutation_df['feature'],
        permutation_df['importance_mean'],
        xerr=permutation_df['importance_std'],
        color="teal",
        align="center",
        alpha=0.6,
    )
    ax.set_xlabel('Permutation Importance', fontsize=label_size)
    ax.tick_params(axis='both', which='major', labelsize=label_size)
    plt.tight_layout()
    return fig
[docs]
def plot_feature_importance(feature_importance_df):
    """
    Draw feature importances as horizontal bars.

    Args:
        feature_importance_df (pandas.DataFrame): Data with 'feature' and
            'importance' columns.

    Returns:
        matplotlib.figure.Figure: The figure holding the bar chart.
    """
    feature_count = len(feature_importance_df)

    # At least 8 high, taller when there are many features.
    height = max(8, feature_count * 0.3)
    width = 10
    # Smaller text for long feature lists, with a floor of 10.
    text_size = max(10, 12 - feature_count * 0.2)

    fig, ax = plt.subplots(figsize=(width, height))
    names = feature_importance_df['feature']
    scores = feature_importance_df['importance']
    ax.barh(names, scores, color="blue", align="center", alpha=0.6)
    ax.set_xlabel('Feature Importance', fontsize=text_size)
    ax.tick_params(axis='both', which='major', labelsize=text_size)
    plt.tight_layout()
    return fig
[docs]
def read_and_plot__vision_results(base_dir, y_axis='accuracy', name_split='_time', y_lim=[0.8, 0.9]):
    """
    Collect '*_test_result.csv' files below ``base_dir`` and plot the mean metric per model.

    The model name is the part of the file name before ``name_split``. The
    name of the folder containing the file is stored as 'epoch'. A 'result'
    folder is created in ``base_dir``; this function does not write into it.

    Args:
        base_dir (str): Directory that is searched recursively.
        y_axis (str, optional): Column averaged per model. Defaults to 'accuracy'.
        name_split (str, optional): Separator that ends the model name in the
            file name. Defaults to '_time'.
        y_lim (list, optional): y-axis limits, or None to leave them automatic.
            Defaults to [0.8, 0.9].

    Returns:
        None
    """
    dst = os.path.join(base_dir, 'result')
    # os.mkdir has no `exists` keyword; makedirs(exist_ok=True) is the call
    # that tolerates an already existing folder.
    os.makedirs(dst, exist_ok=True)

    # Data from all matching CSV files
    data_frames = []
    for root, _dirs, files in os.walk(base_dir):
        for file in files:
            if not file.endswith("_test_result.csv"):
                continue
            file_path = os.path.join(root, file)

            # Model name comes from the file name, epoch from the parent folder.
            model = os.path.basename(file_path).split(f'{name_split}')[0]
            epoch = os.path.basename(os.path.dirname(file_path))

            df = pd.read_csv(file_path)
            df['model'] = model
            df['epoch'] = epoch
            data_frames.append(df)

    if not data_frames:
        print("No CSV files found in the specified directory.")
        return

    result_df = pd.concat(data_frames, ignore_index=True)

    # Average y_axis per model, sorted ascending
    avg_metric = result_df.groupby('model')[y_axis].mean().reset_index()
    avg_metric = avg_metric.sort_values(by=y_axis)
    print(avg_metric)

    # Plotting the results
    plt.figure(figsize=(10, 6))
    plt.bar(avg_metric['model'], avg_metric[y_axis])
    plt.xlabel('Model')
    plt.ylabel(f'{y_axis}')
    plt.title(f'Average {y_axis.capitalize()} per Model')
    plt.xticks(rotation=45)
    plt.tight_layout()
    if y_lim is not None:
        plt.ylim(y_lim)
    plt.show()