# Source code for spacr.toxo

import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd
from scipy.stats import fisher_exact
from sklearn.metrics import mean_absolute_error
from matplotlib.gridspec import GridSpec

def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location',
                        point_size=50, figsize=20, threshold=0, save_path=None,
                        x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 20]]):
    """
    Draw a broken-axis volcano plot (coefficient vs. -log10(p-value)), coloring
    each point by the sub-cellular compartment given in ``metadata_column``.

    Parameters:
    - data_path (str or pd.DataFrame): Regression results with 'feature',
      'coefficient' and 'p_value' columns, or a path to such a CSV.
    - metadata_path (str or pd.DataFrame): Metadata with 'gene_nr' and
      ``metadata_column`` columns, or a path to such a CSV.
    - metadata_column (str): Metadata column used to color the points.
    - point_size (int): Marker area for the scatter points.
    - figsize (int): Width and height (inches) of the square figure.
    - threshold (float): Absolute-coefficient cutoff for calling a hit.
    - save_path (str or None): If given, the figure is saved there as a PDF.
    - x_lim (list): x-axis limits shared by both panels.
    - y_lims (list of two lists): y ranges of the lower and upper panels;
      points with -log10(p) above ``y_lims[1][0]`` go to the upper panel.

    Returns:
    - list: 'variable' names passing p <= 0.05 and |coefficient| >= threshold.
    """
    # Dictionary mapping compartment to color
    colors = {
        'micronemes': 'black',
        'rhoptries 1': 'darkviolet',
        'rhoptries 2': 'darkviolet',
        'nucleus - chromatin': 'blue',
        'nucleus - non-chromatin': 'blue',
        'dense granules': 'teal',
        'ER 1': 'pink',
        'ER 2': 'pink',
        'unknown': 'black',
        'tubulin cytoskeleton': 'slategray',
        'IMC': 'slategray',
        'PM - peripheral 1': 'slategray',
        'PM - peripheral 2': 'slategray',
        'cytosol': 'turquoise',
        'mitochondrion - soluble': 'red',
        'mitochondrion - membranes': 'red',
        'apicoplast': 'slategray',
        'Golgi': 'green',
        'PM - integral': 'slategray',
        'apical 1': 'orange',
        'apical 2': 'orange',
        '19S proteasome': 'slategray',
        '20S proteasome': 'slategray',
        '60S ribosome': 'slategray',
        '40S ribosome': 'slategray',
    }

    # Increase font size for better readability
    fontsize = 18
    plt.rcParams.update({'font.size': fontsize})

    # --- Load data ---
    if isinstance(data_path, pd.DataFrame):
        data = data_path
    else:
        data = pd.read_csv(data_path)

    # Extract 'variable' and 'gene_nr' from the feature notation
    data['variable'] = data['feature'].str.extract(r'\[(.*?)\]')
    # FIX: column-level fillna(inplace=True) is deprecated (pandas >= 2.1) and is
    # a silent no-op under copy-on-write; assign the result back instead.
    data['variable'] = data['variable'].fillna(data['feature'])
    data['gene_nr'] = data['variable'].str.split('_').str[0]
    data = data[data['variable'] != 'Intercept']

    # --- Load metadata ---
    if isinstance(metadata_path, pd.DataFrame):
        metadata = metadata_path
    else:
        metadata = pd.read_csv(metadata_path)

    metadata['gene_nr'] = metadata['gene_nr'].astype(str)
    data['gene_nr'] = data['gene_nr'].astype(str)

    # Merge data and metadata
    merged_data = pd.merge(data, metadata[['gene_nr', metadata_column]], on='gene_nr', how='left')
    # FIX: same inplace-fillna issue as above
    merged_data[metadata_column] = merged_data[metadata_column].fillna('unknown')

    # --- Create figure with "upper" and "lower" subplots sharing the x-axis ---
    fig = plt.figure(figsize=(figsize, figsize))
    gs = GridSpec(2, 1, height_ratios=[1, 3], hspace=0.05)
    ax_upper = fig.add_subplot(gs[0])
    ax_lower = fig.add_subplot(gs[1], sharex=ax_upper)

    # Hide x-axis labels on the upper plot
    ax_upper.tick_params(axis='x', which='both', bottom=False, labelbottom=False)

    # List to collect the variables (hits) that meet threshold criteria
    hit_list = []

    # --- Scatter plot on both axes ---
    for _, row in merged_data.iterrows():
        y_val = -np.log10(row['p_value'])
        # Decide which axis to draw on based on the p-value
        ax = ax_upper if y_val > y_lims[1][0] else ax_lower
        # Color by the compartment-to-color mapping; unmapped values fall back to gray
        ax.scatter(
            row['coefficient'],
            y_val,
            color=colors.get(row[metadata_column], 'gray'),
            marker='o',
            s=point_size,
            edgecolor='black',
            alpha=0.6
        )
        # Check significance thresholds
        if (row['p_value'] <= 0.05) and (abs(row['coefficient']) >= abs(threshold)):
            hit_list.append(row['variable'])

    # --- Adjust axis limits ---
    ax_upper.set_ylim(y_lims[1])
    ax_lower.set_ylim(y_lims[0])
    ax_lower.set_xlim(x_lim)

    # Hide spines at the "break" between the two panels
    ax_lower.spines['top'].set_visible(False)
    ax_upper.spines['top'].set_visible(False)
    ax_upper.spines['bottom'].set_visible(False)

    # Set x-axis and y-axis labels
    ax_lower.set_xlabel('Coefficient')
    ax_lower.set_ylabel('-log10(p-value)')
    ax_upper.set_ylabel('-log10(p-value)')

    for ax in [ax_upper, ax_lower]:
        ax.spines['right'].set_visible(False)

    # --- Add threshold lines to both axes ---
    for ax in [ax_upper, ax_lower]:
        ax.axvline(x=-abs(threshold), linestyle='--', color='black')
        ax.axvline(x=abs(threshold), linestyle='--', color='black')
    ax_lower.axhline(y=-np.log10(0.05), linestyle='--', color='black')

    # --- Annotate significant points ---
    texts_upper, texts_lower = [], []
    for _, row in merged_data.iterrows():
        y_val = -np.log10(row['p_value'])
        if row['p_value'] > 0.05 or abs(row['coefficient']) < abs(threshold):
            continue
        ax = ax_upper if y_val > y_lims[1][0] else ax_lower
        text = ax.text(
            row['coefficient'],
            y_val,
            row['variable'],
            fontsize=fontsize,
            ha='center',
            va='bottom'
        )
        if ax == ax_upper:
            texts_upper.append(text)
        else:
            texts_lower.append(text)

    # Attempt to keep text labels from overlapping
    adjust_text(texts_upper, ax=ax_upper, arrowprops=dict(arrowstyle='-', color='black'))
    adjust_text(texts_lower, ax=ax_lower, arrowprops=dict(arrowstyle='-', color='black'))

    # --- Add a legend keyed by color ---
    legend_handles = []
    for comp, comp_color in colors.items():
        # Dummy artist so the legend shows each compartment color
        legend_handles.append(
            plt.Line2D([0], [0], marker='o', color=comp_color, label=comp,
                       linewidth=0, markersize=8)
        )
    ax_lower.legend(
        handles=legend_handles,
        bbox_to_anchor=(1.05, 1),
        loc='upper left',
        borderaxespad=0.25,
        labelspacing=2,
        handletextpad=0.25,
        markerscale=1.5,
        prop={'size': fontsize}
    )

    # --- Save and show ---
    if save_path:
        plt.savefig(save_path, format='pdf', bbox_inches='tight')
    plt.show()

    return hit_list
def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=['Computed GO Processes', 'Curated GO Components', 'Curated GO Functions', 'Curated GO Processes']):
    """
    Perform GO term enrichment analysis for each GO term column and generate plots.

    Parameters:
    - significant_df: DataFrame containing the significant genes from the screen
      (must have an 'n_gene' column holding the gene number).
    - metadata_path: Path to the metadata CSV containing GO terms and a 'Gene ID'
      column formatted like 'TGGT1_123456'.
    - go_term_columns: List of columns in the metadata corresponding to GO terms.

    For each GO term column, this function will:
    - Split the GO terms by semicolons.
    - Count the occurrences of GO terms in the hits and in the background.
    - Perform Fisher's exact test for enrichment.
    - Plot the enrichment score vs -log10(p-value).
    """
    # FIX: the original follow-up filter `df['n_gene'] != None` compared
    # elementwise with `!=`, which is True for every value (including NaN) and
    # therefore filtered nothing; notna() expresses the real intent.
    significant_df = significant_df[significant_df['n_gene'].notna()]
    gene_list = significant_df['n_gene'].to_list()

    # Load metadata; gene_nr is the numeric part of 'Gene ID' ('TGGT1_123456' -> '123456')
    metadata = pd.read_csv(metadata_path)
    split_columns = metadata['Gene ID'].str.split('_', expand=True)
    metadata['gene_nr'] = split_columns[1]

    # Subset of metadata with only the rows whose gene is in gene_list (the hits).
    # FIX: copy() so the per-column fillna below does not raise
    # SettingWithCopyWarning (hits_metadata would otherwise be a view).
    hits_metadata = metadata[metadata['gene_nr'].isin(gene_list)].copy()

    # Per-column results, collected for the final combined plot
    combined_results = []

    for go_term_column in go_term_columns:
        # Accumulators for this GO term column
        go_terms = []
        enrichment_scores = []
        p_values = []

        # Split the GO terms in the entire metadata and in the hits
        metadata[go_term_column] = metadata[go_term_column].fillna('')
        hits_metadata[go_term_column] = hits_metadata[go_term_column].fillna('')
        all_go_terms = metadata[go_term_column].str.split(';').explode()
        hit_go_terms = hits_metadata[go_term_column].str.split(';').explode()

        # Count occurrences of each GO term in hits and total metadata
        all_go_term_counts = all_go_terms.value_counts()
        hit_go_term_counts = hit_go_terms.value_counts()

        # Perform enrichment analysis for each GO term
        for go_term in all_go_term_counts.index:
            total_with_go_term = all_go_term_counts.get(go_term, 0)
            hits_with_go_term = hit_go_term_counts.get(go_term, 0)

            # Totals for the contingency table
            total_genes = len(metadata)
            total_hits = len(hits_metadata)

            # Fisher's exact test: hits/non-hits vs. with/without the GO term
            contingency_table = [
                [hits_with_go_term, total_hits - hits_with_go_term],
                [total_with_go_term - hits_with_go_term,
                 total_genes - total_hits - (total_with_go_term - hits_with_go_term)]
            ]
            _, p_value = fisher_exact(contingency_table)

            # Enrichment score: hit frequency relative to background frequency
            if total_with_go_term > 0 and total_hits > 0:
                enrichment_score = (hits_with_go_term / total_hits) / (total_with_go_term / total_genes)
            else:
                enrichment_score = 0.0

            # Store the results only if the enrichment score is non-zero
            if enrichment_score > 0.0:
                go_terms.append(go_term)
                enrichment_scores.append(enrichment_score)
                p_values.append(p_value)

        # Results DataFrame for this GO term column
        results_df = pd.DataFrame({
            'GO Term': go_terms,
            'Enrichment Score': enrichment_scores,
            'P-value': p_values,
            'GO Column': go_term_column  # Track the GO term column for the combined plot
        })

        # Sort by enrichment score
        results_df = results_df.sort_values(by='Enrichment Score', ascending=False)
        combined_results.append(results_df)

        # Plot the enrichment results for this individual column
        plt.figure(figsize=(10, 6))
        sns.scatterplot(data=results_df, x='Enrichment Score',
                        y=-np.log10(results_df['P-value']),
                        hue='GO Term', size='Enrichment Score', sizes=(50, 200))
        plt.title(f'GO Term Enrichment Analysis for {go_term_column}')
        plt.xlabel('Enrichment Score')
        plt.ylabel('-log10(P-value)')
        # Move the legend to the right of the plot
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
        plt.tight_layout()  # Ensure everything fits in the figure area
        plt.show()
        print(f'Results for {go_term_column}')

    # Combine results from all columns into a single DataFrame
    combined_df = pd.concat(combined_results)

    # Plot the combined results with text labels
    plt.figure(figsize=(12, 8))
    sns.scatterplot(data=combined_df, x='Enrichment Score',
                    y=-np.log10(combined_df['P-value']),
                    style='GO Column', size='Enrichment Score', sizes=(50, 200))
    plt.title('Combined GO Term Enrichment Analysis')
    plt.xlabel('Enrichment Score')
    plt.ylabel('-log10(P-value)')

    # Annotate the points, then let adjust_text spread the labels
    texts = []
    for i, row in combined_df.iterrows():
        texts.append(plt.text(row['Enrichment Score'], -np.log10(row['P-value']),
                              row['GO Term'], fontsize=9))
    adjust_text(texts, arrowprops=dict(arrowstyle='-', color='black'))

    plt.tight_layout()
    plt.show()
def plot_gene_phenotypes(data, gene_list, x_column='Gene ID', data_column='T.gondii GT1 CRISPR Phenotype - Mean Phenotype', error_column='T.gondii GT1 CRISPR Phenotype - Standard Error', save_path=None):
    """
    Plot a line graph for the mean phenotype with standard error shading and highlighted genes.

    Args:
        data (pd.DataFrame): The input DataFrame containing gene data.
        gene_list (list): A list of gene names to highlight on the plot.
        x_column (str): Column holding the gene identifier (e.g. 'TGGT1_123456').
        data_column (str): Column holding the mean phenotype values.
        error_column (str): Column holding the standard error values.
        save_path (str): Optional. If provided, the plot is saved there as a PDF.
    """
    # 'TGGT1_123456' -> '123456'; anything else is passed through as a string
    def extract_gene_id(gene):
        if isinstance(gene, str) and '_' in gene:
            return gene.split('_')[1]
        return str(gene)

    # FIX: operate on a copy so the numeric coercion and helper columns below
    # do not mutate the caller's DataFrame.
    data = data.copy()

    # Coerce to numeric and drop rows with unparsable values
    data[data_column] = pd.to_numeric(data[data_column], errors='coerce')
    data = data.dropna(subset=[data_column])
    data[error_column] = pd.to_numeric(data[error_column], errors='coerce')
    data = data.dropna(subset=[error_column])
    data['x'] = data[x_column].apply(extract_gene_id)

    # Sort by the data_column and assign ranks
    data = data.sort_values(by=data_column).reset_index(drop=True)
    data['rank'] = range(1, len(data) + 1)

    # Prepare the x, y, and error values for plotting
    x = data['rank']
    y = data[data_column]
    yerr = data[error_column]

    # Create the plot
    plt.figure(figsize=(10, 10))

    # Plot the mean phenotype with standard error shading
    plt.plot(x, y, label='Mean Phenotype', color=(0/255, 155/255, 155/255), linewidth=2)
    plt.fill_between(
        x, y - yerr, y + yerr,
        color=(0/255, 155/255, 155/255),
        alpha=0.1,
        label='Standard Error'
    )

    # Text objects collected for adjustText
    texts = []

    # Highlight the genes in the gene_list
    for gene in gene_list:
        gene_id = extract_gene_id(gene)
        gene_data = data[data['x'] == gene_id]
        if not gene_data.empty:
            # Scatter the highlighted points in purple
            plt.scatter(
                gene_data['rank'],
                gene_data[data_column],
                color=(155/255, 55/255, 155/255),
                s=200,
                alpha=0.6,
                label=f'Highlighted Gene: {gene}',
                zorder=3  # Ensure the points are on top
            )
            # Add the text label next to the highlighted gene
            texts.append(
                plt.text(
                    gene_data['rank'].values[0],
                    gene_data[data_column].values[0],
                    gene,
                    fontsize=18,
                    ha='right'
                )
            )

    # Adjust text to avoid overlap, with lines drawn from points to text
    adjust_text(texts, arrowprops=dict(arrowstyle='-', color='gray'))

    # Label the plot
    plt.xlabel('Rank')
    plt.ylabel('Mean Phenotype')
    plt.legend().remove()  # Remove the legend if not needed
    plt.tight_layout()

    # Save the plot if a path is provided
    if save_path:
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
def plot_gene_heatmaps(data, gene_list, columns, x_column='Gene ID', normalize=False, save_path=None):
    """
    Generate a viridis heatmap of the specified columns for the specified genes.

    Args:
        data (pd.DataFrame): The input DataFrame containing gene data.
        gene_list (list): A list of gene ids (numeric part) to include in the heatmap.
        columns (list): A list of column names to visualize as heatmaps.
        x_column (str): Column holding the gene identifier (e.g. 'TGGT1_123456').
        normalize (bool): If True, min-max normalize the values for each gene to [0, 1].
        save_path (str): Optional. If provided, the plot will be saved to this path.
    """
    # 'TGGT1_123456' -> '123456'; anything else is passed through as a string
    def extract_gene_id(gene):
        if isinstance(gene, str) and '_' in gene:
            return gene.split('_')[1]
        return str(gene)

    # FIX: copy so the helper 'x' column is not written into the caller's DataFrame
    data = data.copy()
    data['x'] = data[x_column].apply(extract_gene_id)

    # Filter the data to only include the specified genes, indexed by gene id
    filtered_data = data[data['x'].isin(gene_list)].set_index('x')[columns]

    # Min-max normalize each gene's values to [0, 1] if requested
    if normalize:
        # FIX: guard against constant rows, where max == min would divide by zero
        def _min_max(row):
            span = row.max() - row.min()
            return (row - row.min()) / span if span else row * 0.0
        filtered_data = filtered_data.apply(_min_max, axis=1)

    # Size the figure dynamically from the number of genes and columns
    width = len(columns) * 4
    height = len(gene_list) * 1

    # Create the heatmap
    plt.figure(figsize=(width, height))
    cmap = sns.color_palette("viridis", as_cmap=True)

    # Genes on the y-axis, data columns on the x-axis
    sns.heatmap(
        filtered_data,
        cmap=cmap,
        cbar=True,
        annot=False,
        linewidths=0.5,
        square=True
    )

    # Set the labels
    plt.xticks(rotation=90, ha='center')  # Rotate x-axis labels for readability
    plt.yticks(rotation=0)  # Keep y-axis labels horizontal
    plt.xlabel('')
    plt.ylabel('')

    # Adjust layout to ensure the plot fits well
    plt.tight_layout()

    # Save the plot if a path is provided
    if save_path:
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
def generate_score_heatmap(settings):
    """
    Compare per-well classification scores from several sources on one heatmap.

    Combines per-well mean scores found under ``settings['folders']`` with the
    control-sgRNA mixing fraction from ``settings['csv']`` and the
    cross-validation scores from ``settings['cv_csv']``, plots them as a
    plate heatmap, computes per-channel MAE against the fraction, and
    optionally writes the CSVs/PDF to ``settings['dst']``.

    Expected settings keys: 'folders', 'csv_name', 'data_column', 'plateID',
    'columnID', 'csv', 'control_sgrnas', 'fraction_grna', 'cv_csv',
    'data_column_cv', 'dst'.

    Returns:
        pd.DataFrame: the merged per-well table of fraction and all score channels.
    """
    def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
        # Mean cross-validation score per (plate, row, column) well.
        df = pd.read_csv(csv)
        if 'column_name' in df.columns:
            df = df[df['column_name'] == column]
        elif 'column' in df.columns:
            # FIX: the original created 'columnID' and then filtered on the
            # non-existent 'column_name' (KeyError); normalize the name to
            # 'column_name' so the filter and the groupby below both work.
            df['column_name'] = df['column']
            df = df[df['column_name'] == column]
        if not plate is None:
            df['plateID'] = f"plate{plate}"
        grouped_df = df.groupby(['plateID', 'rowID', 'column_name'])[data_column].mean().reset_index()
        grouped_df['prc'] = grouped_df['plateID'].astype(str) + '_' + grouped_df['rowID'].astype(str) + '_' + grouped_df['column_name'].astype(str)
        return grouped_df

    def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas=['TGGT1_220950_1', 'TGGT1_233460_4']):
        # Per-well fraction of counts contributed by each control sgRNA.
        df = pd.read_csv(csv)
        df = df[df['column_name'] == column]
        # FIX: the original tested `plate not in df.columns`, comparing the
        # plate *number* to column names (always True); the intent is to add
        # a 'plateID' column only when it is missing.
        if 'plateID' not in df.columns:
            df['plateID'] = f"plate{plate}"
        df = df[df['grna_name'].str.match(f'^{control_sgrnas[0]}$|^{control_sgrnas[1]}$')]
        grouped_df = df.groupby(['plateID', 'rowID', 'column_name'])['count'].sum().reset_index()
        grouped_df = grouped_df.rename(columns={'count': 'total_count'})
        merged_df = pd.merge(df, grouped_df, on=['plateID', 'rowID', 'column_name'])
        merged_df['fraction'] = merged_df['count'] / merged_df['total_count']
        merged_df['prc'] = merged_df['plateID'].astype(str) + '_' + merged_df['rowID'].astype(str) + '_' + merged_df['column_name'].astype(str)
        return merged_df

    def plot_multi_channel_heatmap(df, column='c3'):
        """
        Plot a heatmap with multiple channels as columns.

        Parameters:
        - df: DataFrame with scores for different channels.
        - column: Column to filter by (default is 'c3').
        """
        # FIX: work on a copy so the temporary 'row_num' / 'plate_row_col'
        # columns added here do not leak into the caller's DataFrame.
        df = df.copy()
        # Extract row number and convert to integer so rows sort naturally
        df['row_num'] = df['rowID'].str.extract(r'(\d+)').astype(int)
        # Filter and sort by plate, row, and column
        df = df[df['column_name'] == column]
        df = df.sort_values(by=['plateID', 'row_num', 'column_name'])
        # Drop the temporary 'row_num' column after sorting
        df = df.drop('row_num', axis=1)
        # Combined plate-row-column label used as the heatmap index
        df['plate_row_col'] = df['plateID'] + '-' + df['rowID'] + '-' + df['column_name']
        df.set_index('plate_row_col', inplace=True)
        # Only numeric data goes into the heatmap
        heatmap_data = df.select_dtypes(include=[float, int])
        plt.figure(figsize=(12, 8))
        sns.heatmap(heatmap_data, cmap="viridis", cbar=True, square=True, annot=False)
        plt.title("Heatmap of Prediction Scores for All Channels")
        plt.xlabel("Channels")
        plt.ylabel("Plate-Row-Column")
        plt.tight_layout()
        # Return the figure object so the caller can save it
        fig = plt.gcf()
        plt.show()
        return fig

    def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
        # Collect <folder>/<sub_folder>/<csv_name> score files and merge their
        # per-well mean scores into one wide DataFrame (one column per source).
        if isinstance(folders, str):
            folders = [folders]
        ls = []  # found CSV file paths
        for folder in folders:
            sub_folders = os.listdir(folder)
            for sub_folder in sub_folders:
                path = os.path.join(folder, sub_folder)
                if os.path.isdir(path):
                    csv = os.path.join(path, csv_name)
                    if os.path.exists(csv):
                        ls.append(csv)
                    else:
                        print(f'No such file: {csv}')
        combined_df = None
        print(f'Found {len(ls)} CSV files')
        for csv_file in ls:
            df = pd.read_csv(csv_file)
            df = df[df['column_name'] == column]
            if not plate is None:
                df['plateID'] = f"plate{plate}"
            # Group the data by 'plateID', 'rowID', and 'column_name'
            grouped_df = df.groupby(['plateID', 'rowID', 'column_name'])[data_column].mean().reset_index()
            # Derive a unique column name from the file's parent folder
            folder_name = os.path.dirname(csv_file).replace(".csv", "")
            new_column_name = os.path.basename(f"{folder_name}_{data_column}")
            print(new_column_name)
            grouped_df = grouped_df.rename(columns={data_column: new_column_name})
            # Merge into the combined DataFrame
            if combined_df is None:
                combined_df = grouped_df
            else:
                combined_df = pd.merge(combined_df, grouped_df, on=['plateID', 'rowID', 'column_name'], how='outer')
        combined_df['prc'] = combined_df['plateID'].astype(str) + '_' + combined_df['rowID'].astype(str) + '_' + combined_df['column_name'].astype(str)
        return combined_df

    def calculate_mae(df):
        """
        Calculate the MAE between each channel's predictions and the fraction column for all rows.
        """
        # Numeric score columns, excluding the reference 'fraction' and the 'prc' key
        channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])
        mae_data = []
        for column in channels.columns:
            for index, row in df.iterrows():
                mae = mean_absolute_error([row['fraction']], [row[column]])
                mae_data.append({'Channel': column, 'MAE': mae, 'Row': row['prc']})
        mae_df = pd.DataFrame(mae_data)
        return mae_df

    # Gather per-source scores, the control fraction, and the CV scores,
    # then merge them per well ('prc' key).
    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plateID'], settings['columnID'])
    df = calculate_fraction_mixed_condition(settings['csv'], settings['plateID'], settings['columnID'], settings['control_sgrnas'])
    df = df[df['grna_name'] == settings['fraction_grna']]
    fraction_df = df[['fraction', 'prc']]
    merged_df = pd.merge(fraction_df, result_df, on=['prc'])
    cv_df = group_cv_score(settings['cv_csv'], settings['plateID'], settings['columnID'], settings['data_column_cv'])
    cv_df = cv_df[[settings['data_column_cv'], 'prc']]
    merged_df = pd.merge(merged_df, cv_df, on=['prc'])

    fig = plot_multi_channel_heatmap(merged_df, settings['columnID'])

    # FIX: the original guarded on 'row_number' while dropping 'row_num', so the
    # helper column leaked into the saved CSV and the MAE computation; the
    # heatmap helper now works on a copy, but keep a defensive drop anyway.
    if 'row_num' in merged_df.columns:
        merged_df = merged_df.drop('row_num', axis=1)

    mae_df = calculate_mae(merged_df)
    if 'row_num' in mae_df.columns:
        mae_df = mae_df.drop('row_num', axis=1)

    # Persist results when a destination directory is configured
    if not settings['dst'] is None:
        mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plateID']}.csv")
        merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plateID']}_data.csv")
        heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plateID']}.pdf")
        mae_df.to_csv(mae_dst, index=False)
        merged_df.to_csv(merged_dst, index=False)
        fig.savefig(heatmap_save, format='pdf', dpi=600, bbox_inches='tight')
    return merged_df