Coverage for partipy/enrichment.py: 97%
105 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 10:33 +0200
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 10:33 +0200
1"""Functions to calculate which features (e.g. genes or covariates) are enriched at each archetype."""
3import numpy as np
4import pandas as pd
5import scanpy as sc
6from scipy.spatial.distance import cdist
8from .paretoti import _validate_aa_config, _validate_aa_results
def compute_archetype_weights(
    adata: sc.AnnData,
    mode: str = "automatic",
    length_scale: None | float = None,
    save_to_anndata: bool = True,
) -> None | np.ndarray:
    """
    Calculate weights for the data points based on their distance to archetypes using a squared exponential kernel.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object that passes `_validate_aa_config` and `_validate_aa_results`. The data coordinates are
        read from `adata.obsm[adata.uns["aa_config"]["obsm_key"]]`, restricted to the first
        `adata.uns["aa_config"]["n_dimension"]` columns. The archetype coordinates are read from
        `adata.uns["AA_results"]["Z"]`.
    mode : str, optional (default="automatic")
        The mode for determining the length scale of the kernel:
        - "automatic": The length scale is calculated as half the median distance from the data centroid to the archetypes.
        - "manual": The length scale is provided by the user via the `length_scale` parameter.
    length_scale : float, optional
        The length scale of the kernel. Required if `mode="manual"`; overwritten if `mode="automatic"`.
    save_to_anndata : bool, optional (default=True)
        If `True`, the weights are stored in `adata.obsm["cell_weights"]` and nothing is returned.
        If `False`, the weights are returned instead of being stored.

    Returns
    -------
    None | np.ndarray
        - If `save_to_anndata` is `True`: `None`.
        - Otherwise: a float32 array of shape (n_samples, n_archetypes) holding the weight of each
          cell-archetype pair.

    Raises
    ------
    ValueError
        If `mode` is neither "automatic" nor "manual", if `mode="manual"` and `length_scale` is `None`,
        or if the length scale that would be applied is not a positive number.
    """
    # input validation
    _validate_aa_config(adata=adata)
    _validate_aa_results(adata=adata)

    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]
    Z = adata.uns["AA_results"]["Z"]

    # Calculate or validate length_scale based on mode
    if mode == "automatic":
        centroid = np.mean(X, axis=0).reshape(1, -1)
        length_scale = np.median(cdist(centroid, Z)) / 2
    elif mode == "manual":
        if length_scale is None:
            raise ValueError("For 'manual' mode, 'length_scale' must be provided.")
    else:
        raise ValueError("Mode must be either 'automatic' or 'manual'.")

    # A zero or NaN length scale makes the kernel below produce only 0/NaN weights (division by zero),
    # and a negative value is not a meaningful kernel width, so fail loudly instead.
    if length_scale is None or not length_scale > 0:
        raise ValueError(f"'length_scale' must be a positive number, got {length_scale}.")
    print(f"Applied length scale is {length_scale}.")

    # Weight calculation: squared exponential kernel on the Euclidean cell-archetype distances
    euclidean_dist = cdist(X, Z)
    weights = np.exp(-(euclidean_dist**2) / (2 * length_scale**2))
    weights = weights.astype(np.float32)

    if save_to_anndata:
        adata.obsm["cell_weights"] = weights
        return None
    else:
        return weights
74# compute_characteristic_gene_expression_per_archetype
def compute_archetype_expression(adata: sc.AnnData, layer: str | None = None) -> pd.DataFrame:
    """
    Calculate a weighted average gene expression profile for each archetype.

    This function computes the weighted average of gene expression across cells for each archetype.
    The weights are read from `adata.obsm["cell_weights"]`, as stored by `compute_archetype_weights`.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object containing the gene expression data and weights. The weights should be stored in
        `adata.obsm["cell_weights"]` as a 2D array of shape (n_samples, n_archetypes).
    layer : str, optional (default=None)
        The layer of the AnnData object to use for gene expression. If `None`, `adata.X` is used. For Pareto analysis of AA data,
        z-scaled data is recommended. Both dense arrays and scipy sparse matrices are supported.

    Returns
    -------
    pd.DataFrame
        A DataFrame of shape (n_archetypes, n_genes) with weighted pseudobulk expression profiles.
        Columns are `adata.var_names`.

    Raises
    ------
    ValueError
        If `adata.obsm` has no "cell_weights" entry or if `layer` is not a key of `adata.layers`.
    """
    # Local import: only needed to detect sparse expression matrices.
    from scipy.sparse import issparse

    if "cell_weights" not in adata.obsm:
        raise ValueError("No weights available. Please run compute_archetype_weights()")
    # Transpose to (n_archetypes, n_samples) so each row holds one archetype's cell weights
    weights = adata.obsm["cell_weights"].T

    if layer is None:
        expr = adata.X
    elif layer not in adata.layers:
        raise ValueError("Invalid layer")
    else:
        expr = adata.layers[layer]

    if issparse(expr):
        # np.einsum cannot consume scipy sparse matrices. Keep the sparse operand on the left so the
        # product is computed without densifying the expression matrix.
        pseudobulk = np.asarray((expr.T @ weights.T).T)
    else:
        pseudobulk = np.einsum("ij,jk->ik", weights, expr)
    # Turn the weighted sums into weighted averages
    pseudobulk /= weights.sum(axis=1, keepdims=True)

    pseudobulk_df = pd.DataFrame(pseudobulk, columns=adata.var_names)
    pseudobulk_df.columns.name = None

    return pseudobulk_df
def extract_enriched_processes(
    est: pd.DataFrame,
    pval: pd.DataFrame,
    order: str = "desc",
    n: int = 20,
    p_threshold: float = 0.05,
) -> dict[int, pd.DataFrame]:
    """
    Extract top enriched biological processes for each archetype based on significance and enrichment score.

    This function filters and ranks biological processes using enrichment estimates (`est`) and p-values (`pval`)
    from decoupler output. For each archetype, it selects the top `n` processes with p-values below `p_threshold`,
    sorted by the highest or lowest enrichment scores. It also computes a "specificity" score indicating
    how uniquely enriched a process is for a given archetype compared to others.

    Rows of `est` and `pval` are matched by position: row `i` of both frames describes archetype `i`,
    whatever the index labels are.

    Parameters
    ----------
    est : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the estimated enrichment scores
        for each process and archetype.
    pval : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the p-values corresponding to
        the enrichment scores in `est`. Its column labels must also be columns of `est`.
    order : str, optional (default="desc")
        The sorting order for selecting the top processes:
        - "desc": Selects the top `n` processes with the highest enrichment scores.
        - "asc": Selects the top `n` processes with the lowest enrichment scores.
    n : int, optional (default=20)
        The number of top processes to extract per archetype.
    p_threshold : float, optional (default=0.05)
        The p-value threshold for filtering processes. Only processes with p-values below this
        threshold are considered.

    Returns
    -------
    dict[int, pd.DataFrame]
        A dictionary mapping each archetype position to a DataFrame of the top `n` enriched processes.
        Each DataFrame has the following columns:
        - "Process": Name of the biological process.
        - one column per archetype, labelled with the index labels of `est`: enrichment score of the process.
        - "specificity": The smallest difference between the score in the given archetype and the score
          in any other archetype.

    Raises
    ------
    ValueError
        If `p_threshold` is not in (0, 1], if `est` and `pval` differ in shape, or if `order` is invalid.
    """
    # Validate input
    if not (0.0 < p_threshold <= 1.0):
        raise ValueError("`p_threshold` must be a valid p value")
    if est.shape != pval.shape:
        raise ValueError("`est` and `pval` must have the same shape.")
    if order not in ("desc", "asc"):
        raise ValueError("`order` must be either 'desc' or 'asc'.")

    n_archetypes = est.shape[0]
    results: dict[int, pd.DataFrame] = {}
    for arch_idx in range(n_archetypes):
        # Filter processes based on p-value threshold
        is_significant = (pval.iloc[arch_idx] < p_threshold).to_numpy()
        significant_processes = pval.columns[is_significant]

        # Processes x archetypes table of enrichment scores
        top_processes = est[significant_processes].T
        # Address the archetype by position so that any index labels (not only "0", "1", ...) work
        arch_label = top_processes.columns[arch_idx]
        other_positions = [i for i in range(n_archetypes) if i != arch_idx]

        # compute specificity score
        arch_z_score = top_processes.iloc[:, [arch_idx]].values
        other_z_scores = top_processes.iloc[:, other_positions].values
        top_processes["specificity"] = (arch_z_score - other_z_scores).min(axis=1)

        # filter
        if order == "desc":
            top_processes = top_processes.nlargest(n=n, columns=arch_label)
        else:
            top_processes = top_processes.nsmallest(n=n, columns=arch_label)

        results[arch_idx] = top_processes.reset_index(names="Process")

    return results
def extract_specific_processes(
    est: pd.DataFrame,
    pval: pd.DataFrame,
    n: int = 20,
    p_threshold: float = 0.05,
) -> dict[int, pd.DataFrame]:
    """
    Extract the top biological processes that are uniquely enriched in each archetype.

    This function identifies the top `n` biological processes for each archetype based on their
    enrichment scores (`est`) and associated p-values (`pval`). Only processes with p-values below
    `p_threshold` in a given archetype are considered. A "specificity" score is computed for each
    process, reflecting how much more enriched it is in the target archetype compared to others,
    and the processes with the highest specificity are returned.

    Rows of `est` and `pval` are matched by position: row `i` of both frames describes archetype `i`,
    whatever the index labels are.

    Parameters
    ----------
    est : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the estimated enrichment scores
        for each process and archetype.
    pval : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the p-values corresponding to
        the enrichment scores in `est`. Its column labels must also be columns of `est`.
    n : int, optional (default=20)
        The number of top processes to extract per archetype.
    p_threshold : float, optional (default=0.05)
        The p-value threshold for filtering processes. Only processes with p-values below this
        threshold are considered.

    Returns
    -------
    dict[int, pd.DataFrame]
        A dictionary mapping each archetype position to a DataFrame containing the top `n` processes
        specific to that archetype. Each DataFrame includes:
        - "Process": Name of the biological process.
        - one column per archetype, labelled with the index labels of `est`: enrichment score of the process.
        - "specificity": The smallest difference between the score in the given archetype and the score
          in any other archetype.

    Raises
    ------
    ValueError
        If `p_threshold` is not in (0, 1] or if `est` and `pval` differ in shape.
    """
    # Validate input
    if not (0.0 < p_threshold <= 1.0):
        raise ValueError("`p_threshold` must be a valid p value")
    if est.shape != pval.shape:
        raise ValueError("`est` and `pval` must have the same shape.")

    n_archetypes = est.shape[0]
    results: dict[int, pd.DataFrame] = {}
    for arch_idx in range(n_archetypes):
        # Filter processes based on p-value threshold
        is_significant = (pval.iloc[arch_idx] < p_threshold).to_numpy()
        significant_processes = pval.columns[is_significant]

        # Processes x archetypes table of enrichment scores
        top_processes = est[significant_processes].T
        # Address the archetype by position so that any index labels (not only "0", "1", ...) work
        other_positions = [i for i in range(n_archetypes) if i != arch_idx]

        # compute specificity score
        arch_z_score = top_processes.iloc[:, [arch_idx]].values
        other_z_scores = top_processes.iloc[:, other_positions].values
        top_processes["specificity"] = (arch_z_score - other_z_scores).min(axis=1)
        top_processes = top_processes.nlargest(n=n, columns="specificity").reset_index(names="Process")

        results[arch_idx] = top_processes.copy()

    return results
def compute_meta_enrichment(adata: sc.AnnData, meta_col: str, datatype: str = "automatic") -> pd.DataFrame:
    """
    Compute the enrichment of metadata categories across archetypes.

    This function estimates how enriched each metadata category is within each archetype using
    a weighted average approach. Weights are based on each cell's contribution to each archetype
    (`adata.obsm["cell_weights"]`). It supports both categorical and continuous metadata.

    Steps for categorical data:
        1. One-hot encode the metadata column from `adata.obs[meta_col]`.
        2. Normalize the metadata so that the sum for each category equals 1 (column-wise).
           Categories without any cell are left as all-zero columns.
        3. Compute weighted enrichment using cell weights.
        4. Normalize the resulting enrichment scores across metadata categories for each archetype (row-wise).

    Steps for continuous data:
        1. Compute the weighted average of the metadata per archetype.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object with metadata in `adata.obs[meta_col]` and archetype weights
        in `adata.obsm["cell_weights"]`
    meta_col : str
        The name of the metadata column in `adata.obs` to use for enrichment analysis.
    datatype : str, optional (default="automatic")
        Specifies how to interpret the metadata column:
        - "automatic": infers type based on column dtype. Numeric dtypes are treated as continuous;
          string and pandas categorical dtypes are treated as categorical.
        - "categorical": treats the column as categorical and one-hot encodes it.
        - "continuous": treats the column as numeric and computes weighted averages.

    Returns
    -------
    pd.DataFrame
        A DataFrame of shape (n_archetypes, n_categories) for categorical data or
        (n_archetypes, 1) for continuous data, containing normalized enrichment scores
        or weighted averages respectively.

    Raises
    ------
    ValueError
        If `meta_col` is not in `adata.obs`, if `adata.obsm` has no "cell_weights" entry, if `datatype`
        is not one of the supported values, or if the dtype cannot be classified in automatic mode.
    """
    if meta_col not in adata.obs:
        raise ValueError("Metadata column does not exist")
    if "cell_weights" not in adata.obsm:
        raise ValueError("No weights available. Please run compute_archetype_weights()")

    metadata = adata.obs[meta_col]
    # Transpose to (n_archetypes, n_samples) so each row holds one archetype's cell weights
    weights = adata.obsm["cell_weights"].T

    if datatype == "automatic":
        if pd.api.types.is_numeric_dtype(metadata):
            mode = "continuous"
        elif isinstance(metadata.dtype, pd.CategoricalDtype) or pd.api.types.is_string_dtype(metadata):
            # is_string_dtype() is False for pandas categoricals, so they need their own check
            mode = "categorical"
        else:
            raise ValueError("Not a valid data type detected")
    elif datatype == "continuous" or datatype == "categorical":
        mode = datatype
    else:
        raise ValueError("Not a valid data type")

    if mode == "categorical":
        # One-hot encoding of metadata
        df_encoded = pd.get_dummies(metadata).astype(float)

        # Normalization: every observed category sums to 1. A category without cells has a zero
        # column sum; divide it by 1 instead so it stays all-zero rather than turning into NaN,
        # which would otherwise poison the row-wise normalization below.
        category_sizes = df_encoded.values.sum(axis=0, keepdims=True)
        category_sizes[category_sizes == 0] = 1.0
        df_encoded = df_encoded / category_sizes

        # Compute weighted enrichment
        weighted_meta = np.einsum("ij,jk->ik", weights, df_encoded)
        weighted_meta /= weights.sum(axis=1, keepdims=True)

        # Normalization across categories for each archetype
        weighted_meta = weighted_meta / np.sum(weighted_meta, axis=1, keepdims=True)
        weighted_meta_df = pd.DataFrame(weighted_meta, columns=df_encoded.columns)
    else:
        # mode == "continuous"
        metadata = np.asarray(metadata, dtype=float).reshape(-1, 1)

        # Compute weighted enrichment
        weighted_meta = np.einsum("ij,jk->ik", weights, metadata)
        weighted_meta /= weights.sum(axis=1, keepdims=True)

        weighted_meta_df = pd.DataFrame(weighted_meta, columns=[meta_col])

    return weighted_meta_df