spacr.plot
==========

.. py:module:: spacr.plot






Module Contents
---------------

.. py:function:: plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, percentiles=(2, 98), thickness=3, save_pdf=True, mode='outlines', export_tiffs=False, all_on_all=False, all_outlines=False, filter_dict=None)

   Plot image and mask overlays.


.. py:function:: plot_cellpose4_output(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, print_object_number=True)

   Plot the masks and flows for a given batch of images.

   :param batch: The batch of images.
   :type batch: numpy.ndarray
   :param masks: The masks corresponding to the images.
   :type masks: list or numpy.ndarray
   :param flows: The flows corresponding to the images.
   :type flows: list or numpy.ndarray
   :param cmap: The colormap to use for displaying the images. Defaults to 'inferno'.
   :type cmap: str, optional
   :param figuresize: The size of the figure. Defaults to 10.
   :type figuresize: int, optional
   :param nr: The maximum number of images to plot. Defaults to 1.
   :type nr: int, optional
   :param print_object_number: Whether to print the object number on the mask. Defaults to True.
   :type print_object_number: bool, optional

   :returns: None


.. py:function:: plot_organelle_output(img_batch, masks, settings, cmap='inferno', figuresize=10, nr=1, print_object_number=True)

   Plot organelle segmentation results for a batch of images.

   Shows the original image, the label mask with object IDs, and a
   morphology-specific diagnostic panel (e.g. thresholded binary,
   ridge-enhanced image, DoG response, or ring edges).

   :param img_batch: Shape (N, H, W) single-channel images.
   :type img_batch: numpy.ndarray
   :param masks: Label masks, one per image.
   :type masks: list or numpy.ndarray
   :param settings: Organelle settings (used to determine morphology/method
                    and generate the diagnostic panel).
   :type settings: dict
   :param cmap: Colormap for the raw image. Defaults to 'inferno'.
   :type cmap: str, optional
   :param figuresize: Base figure size. Defaults to 10.
   :type figuresize: int, optional
   :param nr: Maximum number of images to plot. Defaults to 1.
   :type nr: int, optional
   :param print_object_number: Print object labels on the mask.
                               Defaults to True.
   :type print_object_number: bool, optional

   :returns: None


.. py:function:: plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True)

   Plot the masks and flows for a given batch of images.

   :param batch: The batch of images.
   :type batch: numpy.ndarray
   :param masks: The masks corresponding to the images.
   :type masks: list or numpy.ndarray
   :param flows: The flows corresponding to the images.
   :type flows: list or numpy.ndarray
   :param cmap: The colormap to use for displaying the images. Defaults to 'inferno'.
   :type cmap: str, optional
   :param figuresize: The size of the figure. Defaults to 10.
   :type figuresize: int, optional
   :param nr: The maximum number of images to plot. Defaults to 1.
   :type nr: int, optional
   :param file_type: The file type of the flows. Defaults to '.npz'.
   :type file_type: str, optional
   :param print_object_number: Whether to print the object number on the mask. Defaults to True.
   :type print_object_number: bool, optional

   :returns: None


.. py:function:: generate_mask_random_cmap(mask)

   Generate a random colormap based on the unique labels in the given mask.

   :param mask: The input mask array.
   :type mask: numpy.ndarray

   :returns: The random colormap.
   :rtype: matplotlib.colors.ListedColormap
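
   The following is a minimal, dependency-free sketch of the idea and not the spacr implementation: one random color is drawn per unique label. The helper name ``random_label_colors`` is hypothetical, and treating label 0 as a black background is an assumption. The resulting RGBA tuples could be passed, in label order, to ``matplotlib.colors.ListedColormap``.

   .. code-block:: python

      import random


      def random_label_colors(mask, seed=None):
          """Map each unique label in a 2D nested-list mask to an RGBA tuple."""
          rng = random.Random(seed)
          labels = sorted({value for row in mask for value in row})
          colors = {}
          for label in labels:
              if label == 0:
                  # Assumption: label 0 is background and is drawn in black.
                  colors[label] = (0.0, 0.0, 0.0, 1.0)
              else:
                  # Random RGB components in [0, 1), fully opaque.
                  colors[label] = (rng.random(), rng.random(), rng.random(), 1.0)
          return colors


      mask = [[0, 0, 1],
              [2, 2, 1],
              [0, 3, 3]]
      colors = random_label_colors(mask, seed=0)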


.. py:function:: random_cmap(num_objects=100)

   Generate a random colormap.

   :param num_objects: The number of objects to generate colors for. Defaults to 100.
   :type num_objects: int, optional

   :returns: A random colormap.
   :rtype: matplotlib.colors.ListedColormap


.. py:function:: plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, threshold=1000, extensions=['.npy', '.tif', '.tiff', '.png'], overlay=False, max_nr=None, randomize=True)

   Plot images and arrays from the given folders.

   :param folders: A list of folder paths containing the images and arrays.
   :type folders: list
   :param lower_percentile: The lower percentile for image normalization. Defaults to 1.
   :type lower_percentile: int, optional
   :param upper_percentile: The upper percentile for image normalization. Defaults to 99.
   :type upper_percentile: int, optional
   :param threshold: The threshold for determining whether to display an image as a mask or normalize it. Defaults to 1000.
   :type threshold: int, optional
   :param extensions: A list of file extensions to consider. Defaults to ['.npy', '.tif', '.tiff', '.png'].
   :type extensions: list, optional
   :param overlay: If True, overlay the outlines of the objects on the image. Defaults to False.
   :type overlay: bool, optional


.. py:function:: plot_arrays(src, figuresize=10, cmap='inferno', nr=1, normalize=True, q1=1, q2=99)

   Plot randomly selected arrays from a given directory or a single .npz/.npy file.

   :param src: The directory path or file path containing the arrays.
   :type src: str
   :param figuresize: The size of the figure. Defaults to 10.
   :type figuresize: int, optional
   :param cmap: The colormap to use for displaying the arrays. Defaults to 'inferno'.
   :type cmap: str, optional
   :param nr: The number of arrays to plot. Defaults to 1.
   :type nr: int, optional
   :param normalize: Whether to normalize the arrays. Defaults to True.
   :type normalize: bool, optional
   :param q1: The lower percentile for normalization. Defaults to 1.
   :type q1: int, optional
   :param q2: The upper percentile for normalization. Defaults to 99.
   :type q2: int, optional
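
   Percentile normalization with ``q1`` and ``q2`` can be sketched as follows. This is an illustration under assumptions, not the spacr code: values are clipped to the two percentiles and rescaled to the range 0 to 1, and the helper names ``percentile`` and ``normalize_percentile`` are hypothetical. With NumPy arrays the same cut-offs would typically come from ``numpy.percentile``.

   .. code-block:: python

      def percentile(sorted_values, q):
          """Linearly interpolated percentile (q in 0..100) of a sorted list."""
          if not sorted_values:
              raise ValueError("percentile of an empty sequence")
          position = (len(sorted_values) - 1) * q / 100.0
          lower = int(position)
          upper = min(lower + 1, len(sorted_values) - 1)
          fraction = position - lower
          return sorted_values[lower] + (sorted_values[upper] - sorted_values[lower]) * fraction


      def normalize_percentile(pixels, q1=1, q2=99):
          """Clip pixels to the q1/q2 percentiles and rescale them to [0, 1]."""
          ordered = sorted(pixels)
          low, high = percentile(ordered, q1), percentile(ordered, q2)
          if high == low:
              # A flat image has no contrast to stretch.
              return [0.0 for _ in pixels]
          return [min(1.0, max(0.0, (p - low) / (high - low))) for p in pixels]


      pixels = list(range(101))  # intensities 0..100
      normalized = normalize_percentile(pixels, q1=1, q2=99)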


.. py:function:: plot_merged(src, settings)

   Plot the merged images after applying various filters and modifications.

   :param src: Path to folder with images.
   :type src: path
   :param settings: The settings for the plot.
   :type settings: dict

   :returns: None


.. py:function:: generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_count)

.. py:function:: plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True, dst=None)

.. py:function:: print_mask_and_flows(stack, mask, flows, overlay=True, max_size=1000, thickness=2)

   Display the original image, mask with outlines, and flow images.

   :param stack: Original image or stack.
   :type stack: np.array
   :param mask: Mask image.
   :type mask: np.array
   :param flows: List of flow images.
   :type flows: list
   :param overlay: Whether to overlay the mask outlines on the original image.
   :type overlay: bool
   :param max_size: Maximum allowed size for any dimension of the images.
   :type max_size: int
   :param thickness: Thickness of the contour outlines.
   :type thickness: int


.. py:function:: plot_resize(images, resized_images, labels, resized_labels)

.. py:function:: normalize_and_visualize(image, normalized_image, title='')

   Utility function for visualizing an image alongside its normalized version.


.. py:function:: visualize_masks(mask1, mask2, mask3, title='Masks Comparison')

.. py:function:: visualize_cellpose_masks(masks, titles=None, filename=None, save=False, src=None)

   Visualize multiple masks with optional titles.

   :param masks: A list of masks to visualize.
   :type masks: list of np.ndarray
   :param titles: A list of titles for the masks. If None, default titles will be used.
   :type titles: list of str, optional


.. py:function:: plot_comparison_results(comparison_results)

.. py:function:: plot_object_outlines(src, objects=['nucleus', 'cell', 'pathogen'], channels=[0, 1, 2], max_nr=10)

.. py:function:: volcano_plot(coef_df, filename='volcano_plot.pdf')

.. py:function:: plot_histogram(df, column, dst=None)

.. py:function:: plot_lorenz_curves(csv_files, name_column='grna_name', value_column='count', remove_keys=None, x_lim=[0.0, 1], y_lim=[0, 1], remove_outliers=False, save=True)

.. py:function:: plot_permutation(permutation_df)

.. py:function:: plot_feature_importance(feature_importance_df)

.. py:function:: read_and_plot__vision_results(base_dir, y_axis='accuracy', name_split='_time', y_lim=[0.8, 0.9])

.. py:function:: jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None)

   Reads a CSV file and creates a jitter plot of one column grouped by another column.

   :param src: Path to the source data.
   :type src: str
   :param x_column: Name of the column to be used for the x-axis.
   :type x_column: str
   :param y_column: Name of the column to be used for the y-axis.
   :type y_column: str
   :param plot_title: Title of the plot. Defaults to 'Jitter Plot'.
   :type plot_title: str, optional
   :param output_path: Path to save the plot image. If None, the plot is displayed. Defaults to None.
   :type output_path: str, optional

   :returns: The filtered and balanced DataFrame.
   :rtype: pandas.DataFrame
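
   The jitter itself can be illustrated with a short sketch. This is not the spacr implementation: each category is assumed to sit at an integer x position, and every point is shifted by a small uniform random offset so that overlapping points become visible. The helper name ``jitter_positions`` and the ``width`` parameter are hypothetical.

   .. code-block:: python

      import random


      def jitter_positions(categories, width=0.2, seed=None):
          """Return one jittered x coordinate per entry in ``categories``."""
          rng = random.Random(seed)
          # Assumption: categories are placed at 0, 1, 2, ... in sorted order.
          index = {name: i for i, name in enumerate(sorted(set(categories)))}
          return [index[c] + rng.uniform(-width, width) for c in categories]


      x_values = jitter_positions(["control", "treated", "control"], width=0.2, seed=1)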


.. py:function:: create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summary_func='mean', order=None, colors=None, output_dir='./output', save=False, y_lim=None, error_bar_type='std')

   Create a grouped plot, perform statistical tests, and optionally export the results along with the plot.

   :param df: DataFrame containing the data.
   :param grouping_column: Column name for the categorical grouping.
   :param data_column: Column name for the data to be grouped and plotted.
   :param graph_type: Type of plot ('bar', 'violin', 'jitter', 'box', 'jitter_box').
   :param summary_func: Summary function to apply to each group ('mean', 'median', etc.).
   :param order: List specifying the order of the groups. If None, groups are ordered alphabetically.
   :param colors: List of colors for each group.
   :param output_dir: Directory where the figure and test results are saved if ``save=True``.
   :param save: Boolean flag indicating whether to save the plot and results to files.
   :param y_lim: Optional y-axis min and max.
   :param error_bar_type: Type of error bars to plot, either 'std' for standard deviation or 'sem' for standard error of the mean.

   Outputs:

   - Figure of the plot.
   - DataFrame with full statistical test results, including normality tests.
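
   The difference between the two ``error_bar_type`` options can be sketched without any plotting. The example below is an illustration, not the spacr code: it groups rows by a column, takes the mean as the summary, and reports either the sample standard deviation or the standard error of the mean (standard deviation divided by the square root of the group size). The helper name ``summarize_groups`` and the use of the sample (n - 1) standard deviation are assumptions.

   .. code-block:: python

      import math
      import statistics
      from collections import defaultdict


      def summarize_groups(rows, grouping_column, data_column, error_bar_type="std"):
          """Return {group: (mean, error)} for a list of dict rows."""
          groups = defaultdict(list)
          for row in rows:
              groups[row[grouping_column]].append(row[data_column])
          summary = {}
          for name in sorted(groups):
              values = groups[name]
              mean = statistics.mean(values)
              # Sample standard deviation; undefined for a single value, so use 0.
              std = statistics.stdev(values) if len(values) > 1 else 0.0
              error = std / math.sqrt(len(values)) if error_bar_type == "sem" else std
              summary[name] = (mean, error)
          return summary


      rows = [
          {"condition": "a", "value": 1.0},
          {"condition": "a", "value": 3.0},
          {"condition": "b", "value": 2.0},
          {"condition": "b", "value": 2.0},
      ]
      with_std = summarize_groups(rows, "condition", "value", error_bar_type="std")
      with_sem = summarize_groups(rows, "condition", "value", error_bar_type="sem")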


.. py:class:: spacrGraph(df, grouping_column, data_column, graph_type='bar', summary_func='mean', order=None, colors=None, output_dir='./output', save=False, y_lim=None, log_y=False, log_x=False, error_bar_type='std', remove_outliers=False, theme='pastel', representation='object', paired=False, all_to_all=True, compare_group=None, graph_name=None)

   .. py:attribute:: df


   .. py:attribute:: grouping_column


   .. py:attribute:: order


   .. py:attribute:: data_column


   .. py:attribute:: graph_type
      :value: 'bar'



   .. py:attribute:: summary_func
      :value: 'mean'



   .. py:attribute:: colors
      :value: None



   .. py:attribute:: output_dir
      :value: './output'



   .. py:attribute:: save
      :value: False



   .. py:attribute:: error_bar_type
      :value: 'std'



   .. py:attribute:: remove_outliers
      :value: False



   .. py:attribute:: theme
      :value: 'pastel'



   .. py:attribute:: representation
      :value: 'object'



   .. py:attribute:: paired
      :value: False



   .. py:attribute:: all_to_all
      :value: True



   .. py:attribute:: compare_group
      :value: None



   .. py:attribute:: y_lim
      :value: None



   .. py:attribute:: graph_name
      :value: None



   .. py:attribute:: log_x
      :value: False



   .. py:attribute:: log_y
      :value: False



   .. py:attribute:: results_df


   .. py:attribute:: sns_palette
      :value: None



   .. py:attribute:: fig
      :value: None



   .. py:attribute:: results_name
      :value: '___'



   .. py:attribute:: raw_df


   .. py:method:: preprocess_data()

      Preprocess the data: remove NaNs, optionally ensure 'plateID' column is created,
      then group by either 'prc', 'plateID', or do no grouping at all if representation == 'object'.



   .. py:method:: remove_outliers_from_plot()

      Remove outliers from the plot but keep them in the data.



   .. py:method:: perform_normality_tests()

      Perform normality tests for each group and data column.



   .. py:method:: perform_levene_test(unique_groups)


   .. py:method:: perform_statistical_tests(unique_groups, is_normal)

      Perform statistical tests separately for each data column.



   .. py:method:: perform_posthoc_tests(is_normal, unique_groups)

      Perform post-hoc tests for multiple groups based on all_to_all flag.



   .. py:method:: create_plot(ax=None)

      Create and display the plot based on the chosen graph type.



   .. py:method:: get_results()

      Return the results dataframe.



   .. py:method:: get_figure()

      Return the generated figure.



.. py:function:: plot_data_from_db(settings)

.. py:function:: plot_data_from_csv(settings)

.. py:function:: plot_region(settings)

.. py:function:: plot_image_grid(image_paths, percentiles)

   Plots a square grid of images from a list of image paths.
   Unused subplots are filled with black, and padding is minimized.

   :param image_paths: List of paths to images to be displayed.

   :returns: The generated matplotlib figure.
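
   The square layout can be sketched as a small calculation. This is an assumption about how such a grid is typically sized, not a description of the spacr code: the side length is the smallest integer whose square holds all images, and the remaining cells are the unused subplots. The helper name ``square_grid_layout`` is hypothetical.

   .. code-block:: python

      import math


      def square_grid_layout(n_images):
          """Return (side, unused_cells) for the smallest square grid holding n_images."""
          if n_images < 1:
              raise ValueError("at least one image is required")
          # Integer ceiling of sqrt(n_images), computed without floating point.
          side = math.isqrt(n_images - 1) + 1
          return side, side * side - n_images


      side, unused = square_grid_layout(10)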


.. py:function:: overlay_masks_on_images(img_folder, normalize=True, resize=True, save=False, plot=False, thickness=2)

   Load images and masks from folders, overlay mask contours on images, and optionally normalize, resize, and save.

   :param img_folder: Path to the folder containing images.
   :type img_folder: str
   :param normalize: If True, normalize images to the 1st and 99th percentiles.
   :type normalize: bool
   :param resize: If True, resize the final overlay to 500x500.
   :type resize: bool
   :param save: If True, save the final overlay in an 'overlay' folder within the image folder.
   :type save: bool
   :param thickness: Thickness of the contour lines.
   :type thickness: int


.. py:function:: graph_importance(settings)

.. py:function:: plot_proportion_stacked_bars(settings, df, group_column, bin_column, prc_column='prc', level='object', cmap='viridis')

   Generate a stacked bar plot for proportions and perform chi-squared and pairwise tests.

   :param settings: Analysis settings.
   :type settings: dict
   :param df: Input data.
   :type df: pandas.DataFrame
   :param group_column: Column indicating the groups.
   :type group_column: str
   :param bin_column: Column indicating the categories.
   :type bin_column: str
   :param prc_column: Column for additional stratification. Defaults to 'prc'.
   :type prc_column: str, optional
   :param level: Level of aggregation ('well' or 'object'). Defaults to 'object'.
   :type level: str, optional

   :returns: - **chi2** (*float*): Chi-squared statistic for the overall test.
             - **p** (*float*): p-value for the overall chi-squared test.
             - **dof** (*int*): Degrees of freedom for the overall chi-squared test.
             - **expected** (*ndarray*): Expected frequencies for the overall chi-squared test.
             - **raw_counts** (*DataFrame*): Contingency table of observed counts.
             - **fig** (*Figure*): The generated plot.
             - **pairwise_results** (*list*): Pairwise test results from ``chi_pairwise``.
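
   The overall test can be illustrated with a dependency-free sketch that computes the proportions shown in the stacked bars together with the chi-squared statistic, degrees of freedom, and expected frequencies from a contingency table. This is not the spacr implementation, and the helper names are hypothetical. ``scipy.stats.chi2_contingency`` returns the same quantities plus the p-value; note that it applies Yates' continuity correction to 2x2 tables by default, so its statistic can differ from the uncorrected value computed here.

   .. code-block:: python

      def row_proportions(table):
          """Convert each row of counts into proportions that sum to 1."""
          return [[count / sum(row) for count in row] for row in table]


      def chi_squared(table):
          """Return (chi2, dof, expected) for a table of observed counts."""
          row_totals = [sum(row) for row in table]
          col_totals = [sum(col) for col in zip(*table)]
          total = sum(row_totals)
          # Expected count under independence: row total * column total / grand total.
          expected = [[r * c / total for c in col_totals] for r in row_totals]
          chi2 = sum(
              (obs - exp) ** 2 / exp
              for obs_row, exp_row in zip(table, expected)
              for obs, exp in zip(obs_row, exp_row)
          )
          dof = (len(row_totals) - 1) * (len(col_totals) - 1)
          return chi2, dof, expected


      # Rows are groups, columns are bins (hypothetical counts).
      raw_counts = [[10, 20],
                    [20, 10]]
      proportions = row_proportions(raw_counts)
      chi2, dof, expected = chi_squared(raw_counts)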


.. py:function:: create_venn_diagram(file1, file2, gene_column='gene', filter_coeff=0.1, save=True, save_path=None)

   Reads two CSV files, extracts the gene column named by ``gene_column``, and creates a Venn diagram
   showing overlapping and non-overlapping genes.

   :param file1: Path to the first CSV file.
   :type file1: str
   :param file2: Path to the second CSV file.
   :type file2: str
   :param gene_column: Name of the column containing gene data (default: "gene").
   :type gene_column: str
   :param filter_coeff: Coefficient threshold for filtering genes.
   :type filter_coeff: float
   :param save: Whether to save the plot.
   :type save: bool
   :param save_path: Path to save the Venn diagram figure.
   :type save_path: str

   :returns: Overlapping and non-overlapping genes.
   :rtype: dict
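
   The set logic behind the diagram can be sketched as follows. This is an illustration under several assumptions, not the spacr code: the coefficient column is assumed to be named ``coefficient``, genes are kept when the absolute coefficient is at least ``filter_coeff``, and the keys of the returned dictionary are hypothetical. In-memory CSV text is used instead of file paths so the example is self-contained.

   .. code-block:: python

      import csv
      import io


      def gene_sets(csv_text_1, csv_text_2, gene_column="gene",
                    coeff_column="coefficient", filter_coeff=0.1):
          """Return overlapping and non-overlapping genes from two CSV tables."""

          def load(text):
              reader = csv.DictReader(io.StringIO(text))
              # Assumption: keep genes whose absolute coefficient passes the threshold.
              return {
                  row[gene_column]
                  for row in reader
                  if abs(float(row[coeff_column])) >= filter_coeff
              }

          first, second = load(csv_text_1), load(csv_text_2)
          return {
              "overlap": first & second,
              "only_file1": first - second,
              "only_file2": second - first,
          }


      file1 = "gene,coefficient\nA,0.5\nB,-0.3\nC,0.01\n"
      file2 = "gene,coefficient\nB,0.2\nC,0.4\nD,-0.9\n"
      result = gene_sets(file1, file2)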


.. py:function:: volcano_plot(data: Union[str, pandas.DataFrame], *, fold_change_col: str, p_value_col: str, name_col: Optional[str] = None, x_transform: str = 'none', y_transform: str = '-log10', fold_change_threshold: Optional[float] = None, p_value_threshold: Optional[float] = None, annotate: bool = True, annotate_max: Optional[int] = None, point_size: float = 20.0, alpha: float = 0.7, figsize: Tuple[float, float] = (8.0, 6.0), title: Optional[str] = None, xlim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None, threshold_line_kwargs: Optional[dict] = None, scatter_kwargs: Optional[dict] = None, text_kwargs: Optional[dict] = None, save_path: Optional[str] = None, show: bool = True, ax: Optional[matplotlib.pyplot.Axes] = None, sheet_name: Union[int, str] = 0) -> Tuple[matplotlib.pyplot.Figure, matplotlib.pyplot.Axes, list]

   Read a table (CSV/TSV/XLS/XLSX) and generate a volcano plot.

   Auto-detects file type from extension:
     - .csv -> read_csv
     - .tsv / .tab -> read_csv(sep="\t")
     - .xls / .xlsx -> read_excel (sheet_name)

   :param data: Path to table file or a pandas DataFrame.
   :param fold_change_col: Column name for fold change (or log fold change if x_transform="none").
   :param p_value_col: Column name for p-values.
   :param name_col: Optional column for labels.
   :param x_transform: "none", "log2", "log10", "ln".
                       Use "none" if fold_change_col already contains logFC values (can be negative).
   :param y_transform: "none", "-log10", "-ln", "log10", "ln".
   :param fold_change_threshold:
                                 - If x_transform="none": interpreted in x units (e.g. abs(log2FC) >= 1).
                                 - If x_transform is a log transform: interpreted in *raw FC* units (e.g. 2),
                                   then converted into plotted units for the vertical dashed lines.
   :param p_value_threshold: P-value cutoff (e.g. 0.05); drawn as a dashed horizontal line.
   :param annotate: If True and name_col is provided, annotate points passing thresholds.
                    If no thresholds are provided, nothing is annotated unless annotate_max is set.
   :param annotate_max: If set, annotate at most N eligible points (highest y first).
   :param sheet_name: Excel sheet index/name (used for .xls/.xlsx).

   :returns: (fig, ax, hits) where hits is the list of annotated labels.
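
   The coordinate transforms and thresholds can be illustrated with a small sketch for the case ``x_transform="log2"`` and ``y_transform="-log10"``. It is not the spacr implementation: the helper name ``volcano_coordinates`` is hypothetical, and the rule that a point passes when its absolute transformed fold change and its transformed p-value both reach the cut-offs is an assumption based on the parameter descriptions above.

   .. code-block:: python

      import math


      def volcano_coordinates(rows, fold_change_threshold=2.0, p_value_threshold=0.05):
          """Return (name, x, y, passes) for rows of (name, fold_change, p_value)."""
          # The raw fold-change threshold is converted into plotted (log2) units.
          x_cut = math.log2(fold_change_threshold)
          y_cut = -math.log10(p_value_threshold)
          points = []
          for name, fold_change, p_value in rows:
              x = math.log2(fold_change)
              y = -math.log10(p_value)
              passes = abs(x) >= x_cut and y >= y_cut
              points.append((name, x, y, passes))
          return points


      rows = [
          ("geneA", 4.0, 0.001),   # strongly up, significant
          ("geneB", 1.1, 0.5),     # small change, not significant
          ("geneC", 0.25, 0.01),   # strongly down, significant
      ]
      points = volcano_coordinates(rows)
      hits = [name for name, _, _, passes in points if passes]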


