Coverage for ParTIpy/enrichment.py: 12%
74 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-16 10:20 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-16 10:20 +0100
1"""Functions to calculate which features (e.g. genes or covariates) are enriched at each archetype"""
3import numpy as np
4import pandas as pd
5import scanpy as sc
6from scipy.spatial.distance import cdist
def calculate_weights(
    X: np.ndarray | sc.AnnData,
    Z: np.ndarray | None = None,
    mode: str = "automatic",
    length_scale: None | float = None,
) -> np.ndarray | None:
    """
    Calculate weights for cells based on their distance to archetypes using a squared exponential kernel.

    The weight of cell ``i`` for archetype ``k`` is ``exp(-d_ik**2 / (2 * length_scale**2))``, where
    ``d_ik`` is the Euclidean distance between the cell and the archetype.

    Parameters
    ----------
    X : np.ndarray | sc.AnnData
        The input data, which can be either:
        - A 2D array of shape (n_samples, n_features) representing the PCA coordinates of the cells.
        - An AnnData object containing the PCA coordinates in `.obsm["X_pca_reduced"]` and archetypes in
          `.uns["archetypal_analysis"]["Z"]`. In this case the `Z` argument is ignored.
    Z : np.ndarray, optional
        A 2D array of shape (n_archetypes, n_features) representing the PCA coordinates of the archetypes.
        Required if `X` is not an AnnData object.
    mode : str, optional (default="automatic")
        The mode for determining the length scale of the kernel:
        - "automatic": The length scale is calculated as half the median distance from the data centroid to the archetypes.
        - "manual": The length scale is provided by the user via the `length_scale` parameter.
    length_scale : float, optional
        The length scale of the kernel. Required if `mode="manual"`. Must be positive and finite.

    Returns
    -------
    np.ndarray | None
        - If `X` is an AnnData object, the weights are added to `X.obsm["cell_weights"]` and None is returned.
        - If `X` is a numpy array, a 2D array of shape (n_samples, n_archetypes) representing the weights for
          each cell-archetype pair.

    Raises
    ------
    ValueError
        If the AnnData object has no archetypal analysis result, `Z` is missing, `mode` is unknown,
        `length_scale` is missing in manual mode, or the (given or computed) length scale is not a
        positive finite number.
    """
    # Handle and validate input data
    adata = None
    if isinstance(X, sc.AnnData):
        adata = X
        if "archetypal_analysis" not in X.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")
        Z = X.uns["archetypal_analysis"]["Z"]
        X = X.obsm["X_pca_reduced"]

    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")

    # Calculate or validate length_scale based on mode
    if mode == "automatic":
        centroid = np.mean(X, axis=0).reshape(1, -1)
        length_scale = np.median(cdist(centroid, Z)) / 2
    elif mode == "manual":
        if length_scale is None:
            raise ValueError("For 'manual' mode, 'length_scale' must be provided.")
    else:
        raise ValueError("Mode must be either 'automatic' or 'manual'.")

    # A zero or non-finite length scale makes the kernel degenerate (division by zero -> NaN/0 weights);
    # in automatic mode this happens when every archetype coincides with the data centroid.
    if not (np.isfinite(length_scale) and length_scale > 0):
        raise ValueError(f"'length_scale' must be a positive finite number, got {length_scale}.")
    print(f"Applied length scale is {length_scale}.")
    scale = float(length_scale)

    # Weight calculation: squared exponential (Gaussian) kernel on Euclidean distances
    euclidean_dist = cdist(X, Z)
    weights = np.exp(-(euclidean_dist**2) / (2 * scale**2))

    if adata is not None:
        adata.obsm["cell_weights"] = weights
        return None
    return weights
def weighted_expr(adata: sc.AnnData, layer: str | None = None) -> pd.DataFrame:
    """
    Calculate a weighted pseudobulk expression profile for each archetype.

    This function computes the weighted average of gene expression across cells for each archetype.
    The weights should be based on the distance of cells to the archetypes, as computed by `calculate_weights`.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object containing the gene expression data and weights. The weights should be stored in
        `adata.obsm["cell_weights"]` as a 2D array of shape (n_samples, n_archetypes).
    layer : str, optional (default=None)
        The layer of the AnnData object to use for gene expression. If `None`, `adata.X` is used. For Pareto
        analysis of AA data, z-scaled data is recommended. Dense arrays and scipy sparse matrices are supported.

    Returns
    -------
    pd.DataFrame
        A DataFrame of shape (n_archetypes, n_genes) with columns `adata.var_names`, holding the weighted
        pseudobulk expression profiles (one row per archetype, in the column order of the weights).
    """
    weights = adata.obsm["cell_weights"].T  # (n_archetypes, n_samples)
    expr = adata.X if layer is None else adata.layers[layer]

    # `@` (unlike np.einsum) also accepts scipy sparse matrices, the usual storage for AnnData
    # expression; np.asarray turns a possible np.matrix result into a plain ndarray.
    pseudobulk = np.asarray(weights @ expr)
    # Normalize by the total weight per archetype to obtain a weighted mean.
    pseudobulk = pseudobulk / weights.sum(axis=1, keepdims=True)

    return pd.DataFrame(pseudobulk, columns=adata.var_names)
def extract_top_processes(
    est: pd.DataFrame,
    pval: pd.DataFrame,
    order: str = "desc",
    n: int = 20,
    p_threshold: float = 0.05,
) -> dict[str, pd.DataFrame]:
    """
    Rank the significantly enriched biological processes of every archetype.

    For each row (archetype) of the decoupler output, processes whose p-value in `pval` is strictly
    below `p_threshold` are kept, and the `n` highest (`order="desc"`) or lowest (`order="asc"`)
    enrichment scores in `est` among them are reported.

    Parameters
    ----------
    est : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) with the estimated enrichment scores.
    pval : pd.DataFrame
        A DataFrame of the same shape as `est` with the matching p-values. Columns are matched to
        `est` by position, not by label.
    order : str, optional (default="desc")
        "desc" keeps the processes with the highest scores, "asc" those with the lowest scores.
    n : int, optional (default=20)
        Maximum number of processes reported per archetype.
    p_threshold : float, optional (default=0.05)
        A process is only considered if its p-value is strictly below this threshold.

    Returns
    -------
    dict[str, pd.DataFrame]
        Maps "archetype_<row position>" to a DataFrame with the columns
        - "Process": the name of the biological process,
        - "Score": its enrichment score,
        sorted according to `order`. The DataFrame is empty when no process passes the p-value filter.

    Raises
    ------
    ValueError
        If `est` and `pval` differ in shape, or `order` is neither "desc" nor "asc".
    """
    if est.shape != pval.shape:
        raise ValueError("`est` and `pval` must have the same shape.")
    if order not in ("desc", "asc"):
        raise ValueError("`order` must be either 'desc' or 'asc'.")

    # Choose the ranking method once instead of branching on every archetype.
    rank = pd.Series.nlargest if order == "desc" else pd.Series.nsmallest

    ranked_by_archetype = {}
    for row in range(est.shape[0]):
        # Positional boolean mask: column j of `pval` gates column j of `est`.
        passes = (pval.iloc[row] < p_threshold).tolist()
        table = rank(est.iloc[row, passes], n).reset_index()
        table.columns = ["Process", "Score"]
        ranked_by_archetype[f"archetype_{row}"] = table
    return ranked_by_archetype
def extract_top_specific_processes(
    est: pd.DataFrame,
    pval: pd.DataFrame,
    drop_threshold: float = 0,
    n: int = 20,
    p_threshold: float = 0.05,
) -> dict[str, pd.DataFrame]:
    """
    Extract the top enriched biological processes that are specific to each archetype.

    This function identifies the most enriched biological processes for each archetype based on
    estimated enrichment scores (`est`) and corresponding p-values (`pval`) from the decoupler output below the
    specified threshold (`p_threshold`). It ensures that the selected processes are specific to the archetype by
    enforcing that their enrichment scores are below a specified threshold (`drop_threshold`) in all other archetypes.

    Archetypes are identified by row position, so the row labels of `est` may be arbitrary.

    Parameters
    ----------
    est : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the estimated enrichment scores
        for each process and archetype.
    pval : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the p-values corresponding to
        the enrichment scores in `est` (matched by position).
    drop_threshold : float, optional (default=0)
        A top process is kept only if its score is strictly below this value in every other archetype.
    n : int, optional (default=20)
        The number of top processes considered per archetype before the specificity filter
        (so at most `n` processes are returned).
    p_threshold : float, optional (default=0.05)
        The p-value threshold for filtering processes. Only processes with p-values strictly below this
        threshold are considered.

    Returns
    -------
    dict[str, pd.DataFrame]
        A dictionary where keys are of the form "archetype_X" (X = row position) and values are copies of
        `est` restricted to the specific processes of that archetype (all archetypes as rows, the retained
        processes as columns, in descending order of the archetype's score).

    Raises
    ------
    ValueError
        If `est` and `pval` differ in shape.
    """
    if est.shape != pval.shape:
        raise ValueError("`est` and `pval` must have the same shape.")

    n_archetypes = est.shape[0]
    positions = np.arange(n_archetypes)
    results: dict[str, pd.DataFrame] = {}
    for archetype in range(n_archetypes):
        # Filter processes based on p-value threshold (positional mask), then keep the n best scores.
        significant = (pval.iloc[archetype] < p_threshold).to_numpy()
        top_processes = est.iloc[archetype, significant].nlargest(n).index

        # Select the other archetypes by position rather than by label: the row labels of `est`
        # need not be integers, nor equal to 0..n_archetypes-1 in order.
        other_scores = est.loc[:, top_processes].iloc[positions != archetype]
        is_specific = (other_scores < drop_threshold).all(axis=0).to_numpy()
        filtered_processes = top_processes[is_specific]

        results[f"archetype_{archetype}"] = est.loc[:, filtered_processes].copy()

    return results
def meta_enrichment(adata: sc.AnnData, meta: str) -> pd.DataFrame:
    """
    Compute the weighted enrichment of metadata categories across archetypes.

    This function performs the following steps:
    1. One-hot encodes the categorical metadata.
    2. Normalizes the one-hot encoded metadata to sum to 1 for each category (categories without any
       cell, e.g. unused levels of a pandas Categorical, are left at 0).
    3. Computes the weighted enrichment of each metadata category for each archetype using the weights
       stored in `adata.obsm["cell_weights"]`, and normalizes each archetype's row to sum to 1.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object containing the metadata in `adata.obs[meta]` and weights in `adata.obsm["cell_weights"]`.
    meta : str
        The name of the categorical metadata column in `adata.obs` to use for enrichment analysis.

    Returns
    -------
    pd.DataFrame
        A DataFrame of shape (n_archetypes, n_categories) containing the normalized enrichment of each
        metadata category for each archetype.
    """
    metadata = adata.obs[meta]
    weights = adata.obsm["cell_weights"].T  # (n_archetypes, n_samples)

    # One-hot encoding of metadata
    df_encoded = pd.get_dummies(metadata).astype(float)

    # Normalize each category column to sum to 1 so that frequent categories do not dominate.
    # get_dummies also emits all-zero columns for unused Categorical levels; dividing those by their
    # count of 0 would give NaN, which the row normalization below would spread to every entry.
    # Dividing by 1 instead keeps such categories at 0.
    counts = df_encoded.to_numpy().sum(axis=0, keepdims=True)
    df_encoded = df_encoded / np.where(counts > 0, counts, 1.0)

    # Compute weighted enrichment (weighted mean over cells per archetype)
    weighted_meta = np.einsum("ij,jk->ik", weights, df_encoded.to_numpy())
    weighted_meta /= weights.sum(axis=1, keepdims=True)

    # Normalize each archetype's profile to sum to 1
    weighted_meta = weighted_meta / np.sum(weighted_meta, axis=1, keepdims=True)
    return pd.DataFrame(weighted_meta, columns=df_encoded.columns)