Coverage for partipy/enrichment.py: 97%

105 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-09 10:33 +0200

1"""Functions to calculate which features (e.g. genes or covariates) are enriched at each archetype.""" 

2 

3import numpy as np 

4import pandas as pd 

5import scanpy as sc 

6from scipy.spatial.distance import cdist 

7 

8from .paretoti import _validate_aa_config, _validate_aa_results 

9 

10 

def compute_archetype_weights(
    adata: sc.AnnData,
    mode: str = "automatic",
    length_scale: None | float = None,
    save_to_anndata: bool = True,
) -> None | np.ndarray:
    """
    Calculate weights for the data points based on their distance to archetypes using a squared exponential kernel.

    The weight of cell ``i`` for archetype ``k`` is ``exp(-d_ik**2 / (2 * length_scale**2))``, where ``d_ik`` is
    the Euclidean distance between the cell and the archetype in the embedding used for archetypal analysis.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object on which archetypal analysis has been run. Cell coordinates are read from
        ``adata.obsm[adata.uns["aa_config"]["obsm_key"]]`` (restricted to the first
        ``adata.uns["aa_config"]["n_dimension"]`` columns) and archetype coordinates from
        ``adata.uns["AA_results"]["Z"]``.
    mode : str, optional (default="automatic")
        The mode for determining the length scale of the kernel:
        - "automatic": The length scale is calculated as half the median distance from the data centroid to the archetypes.
        - "manual": The length scale is provided by the user via the `length_scale` parameter.
    length_scale : float, optional
        The length scale of the kernel. Required if `mode="manual"`; ignored in automatic mode.
    save_to_anndata : bool, optional (default=True)
        If True, store the weights in ``adata.obsm["cell_weights"]`` and return None.
        If False, return the weights without modifying ``adata``.

    Returns
    -------
    None | np.ndarray
        ``None`` when ``save_to_anndata`` is True. Otherwise a float32 array of shape
        (n_cells, n_archetypes) holding the weight of each cell-archetype pair.

    Raises
    ------
    ValueError
        If `mode` is invalid, if `length_scale` is missing in manual mode, or if the length scale
        is not strictly positive.
    """
    # input validation
    _validate_aa_config(adata=adata)
    _validate_aa_results(adata=adata)

    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]
    Z = adata.uns["AA_results"]["Z"]

    # Calculate or validate length_scale based on mode
    if mode == "automatic":
        centroid = np.mean(X, axis=0).reshape(1, -1)
        length_scale = np.median(cdist(centroid, Z)) / 2
    elif mode == "manual":
        if length_scale is None:
            raise ValueError("For 'manual' mode, 'length_scale' must be provided.")
    else:
        raise ValueError("Mode must be either 'automatic' or 'manual'.")

    # A zero length scale divides by zero in the kernel below and yields NaN weights; a negative or
    # NaN value is meaningless. `not (x > 0)` rejects all of these, including NaN.
    if not (length_scale > 0):  # type: ignore[operator]
        raise ValueError(f"'length_scale' must be strictly positive, got {length_scale}.")
    print(f"Applied length scale is {length_scale}.")

    # Weight calculation: cdist(X, Z) has shape (n_cells, n_archetypes)
    euclidean_dist = cdist(X, Z)
    weights = np.exp(-(euclidean_dist**2) / (2 * length_scale**2))  # type: ignore[operator]
    weights = weights.astype(np.float32)

    if save_to_anndata:
        adata.obsm["cell_weights"] = weights
        return None
    else:
        return weights

72 

73 

74# compute_characteristic_gene_expression_per_archetype 

def compute_archetype_expression(adata: sc.AnnData, layer: str | None = None) -> pd.DataFrame:
    """
    Calculate a weighted average gene expression profile for each archetype.

    This function computes the weighted average of gene expression across cells for each archetype.
    The weights should be based on the distance of cells to the archetypes, as computed by
    `compute_archetype_weights`.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object containing the gene expression data and weights. The weights should be stored in
        `adata.obsm["cell_weights"]` as a 2D array of shape (n_samples, n_archetypes).
    layer : str, optional (default=None)
        The layer of the AnnData object to use for gene expression. If `None`, `adata.X` is used. For Pareto analysis of AA data,
        z-scaled data is recommended. Both dense arrays and scipy sparse matrices are supported.

    Returns
    -------
    pd.DataFrame
        A DataFrame of shape (n_archetypes, n_genes) with weighted pseudobulk expression profiles.
        Rows are indexed 0..n_archetypes-1 and columns are `adata.var_names`.

    Raises
    ------
    ValueError
        If no weights are stored in `adata.obsm["cell_weights"]`, or if `layer` is not in `adata.layers`.
    """
    from scipy import sparse

    if "cell_weights" not in adata.obsm:
        raise ValueError("No weights available. Please run compute_archetype_weights()")
    # (n_archetypes, n_cells)
    weights = np.asarray(adata.obsm["cell_weights"]).T

    if layer is None:
        expr = adata.X
    elif layer not in adata.layers:
        raise ValueError("Invalid layer")
    else:
        expr = adata.layers[layer]

    if sparse.issparse(expr):
        # np.einsum cannot consume scipy sparse matrices. Let the sparse matrix drive the product
        # (sparse @ dense -> dense) and transpose back to (n_archetypes, n_genes).
        pseudobulk = np.asarray((expr.T @ weights.T).T)
    else:
        pseudobulk = np.einsum("ij,jk->ik", weights, np.asarray(expr))

    # Out-of-place division so integer-typed inputs are promoted to float instead of raising.
    pseudobulk = pseudobulk / weights.sum(axis=1, keepdims=True)

    pseudobulk_df = pd.DataFrame(pseudobulk, columns=adata.var_names)
    pseudobulk_df.columns.name = None

    return pseudobulk_df

114 

115 

def extract_enriched_processes(
    est: pd.DataFrame,
    pval: pd.DataFrame,
    order: str = "desc",
    n: int = 20,
    p_threshold: float = 0.05,
) -> dict[int, pd.DataFrame]:
    """
    Extract top enriched biological processes for each archetype based on significance and enrichment score.

    This function filters and ranks biological processes using enrichment estimates (`est`) and p-values (`pval`)
    from decoupler output. For each archetype, it selects the top `n` processes with p-values below `p_threshold`,
    optionally sorting by the highest or lowest enrichment scores. It also computes a "specificity" score indicating
    how uniquely enriched a process is for a given archetype compared to others.

    Rows of `est` and `pval` are paired by position: row ``i`` of both frames describes archetype ``i``.

    Parameters
    ----------
    est : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the estimated enrichment scores
        for each process and archetype.
    pval : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the p-values corresponding to
        the enrichment scores in `est`.
    order : str, optional (default="desc")
        The sorting order for selecting the top processes:
        - "desc": Selects the top `n` processes with the highest enrichment scores.
        - "asc": Selects the top `n` processes with the lowest enrichment scores.
    n : int, optional (default=20)
        The number of top processes to extract per archetype.
    p_threshold : float, optional (default=0.05)
        The p-value threshold for filtering processes. Only processes with p-values strictly below this
        threshold are considered.

    Returns
    -------
    Dict[int, pd.DataFrame]
        A dictionary mapping each archetype position (0..n_archetypes-1) to a DataFrame of the top `n`
        enriched processes. Each DataFrame has the following columns:
        - "Process": Name of the biological process.
        - one column per row label of `est`: Enrichment score of the process in that archetype.
        - "specificity": The smallest margin by which the process's score in the given archetype exceeds
          its score in any other archetype (``inf`` if there is only one archetype).

    Raises
    ------
    ValueError
        If `p_threshold` is not in (0, 1], if `est` and `pval` differ in shape, or if `order` is invalid.
    """
    # Validate input
    if not ((p_threshold > 0.0) and (p_threshold <= 1.0)):
        raise ValueError("`p_threshold` must be a valid p value")
    if est.shape != pval.shape:
        raise ValueError("`est` and `pval` must have the same shape.")

    if order not in ["desc", "asc"]:
        raise ValueError("`order` must be either 'desc' or 'asc'.")

    results = {}
    for arch_idx in range(est.shape[0]):
        # Rows are paired with `pval` by position, so take the archetype's label from `est.index` by
        # position too instead of assuming the labels are the strings "0", "1", ...
        arch_label = est.index[arch_idx]

        # Filter processes based on p-value threshold
        is_significant = (pval.iloc[arch_idx] < p_threshold).to_numpy()
        significant_processes = pval.columns[is_significant]

        # processes x archetypes; column `arch_idx` (positionally) is the current archetype
        top_processes = est[significant_processes].T
        scores = top_processes.to_numpy()
        arch_scores = scores[:, [arch_idx]]
        other_scores = np.delete(scores, arch_idx, axis=1)
        # Smallest margin over any competing archetype; `initial=np.inf` keeps the reduction defined
        # when there are no competitors (single archetype).
        top_processes["specificity"] = (arch_scores - other_scores).min(axis=1, initial=np.inf)

        # filter
        if order == "desc":
            top_processes = top_processes.nlargest(n=n, columns=arch_label)
        else:
            top_processes = top_processes.nsmallest(n=n, columns=arch_label)

        results[arch_idx] = top_processes.reset_index(names="Process")

    return results

187 

188 

def extract_specific_processes(
    est: pd.DataFrame,
    pval: pd.DataFrame,
    n: int = 20,
    p_threshold: float = 0.05,
) -> dict[int, pd.DataFrame]:
    """
    Extract the top biological processes that are uniquely enriched in each archetype.

    This function identifies the top `n` biological processes for each archetype based on their
    enrichment scores (`est`) and associated p-values (`pval`). Only processes with p-values below
    `p_threshold` in a given archetype are considered. A "specificity" score is computed for each
    process, reflecting how much more enriched it is in the target archetype compared to others.

    Rows of `est` and `pval` are paired by position: row ``i`` of both frames describes archetype ``i``.

    Parameters
    ----------
    est : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the estimated enrichment scores
        for each process and archetype.
    pval : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the p-values corresponding to
        the enrichment scores in `est`.
    n : int, optional (default=20)
        The number of top processes to extract per archetype.
    p_threshold : float, optional (default=0.05)
        The p-value threshold for filtering processes. Only processes with p-values strictly below this
        threshold are considered.

    Returns
    -------
    dict[int, pd.DataFrame]
        A dictionary mapping each archetype position (0..n_archetypes-1) to a DataFrame containing the
        top `n` processes, ranked by decreasing specificity. Each DataFrame includes:
        - "Process": Name of the biological process.
        - one column per row label of `est`: Enrichment score of the process in that archetype.
        - "specificity": The smallest margin by which the process's score in the given archetype exceeds
          its score in any other archetype (``inf`` if there is only one archetype).

    Raises
    ------
    ValueError
        If `p_threshold` is not in (0, 1] or if `est` and `pval` differ in shape.
    """
    # Validate input
    if not ((p_threshold > 0.0) and (p_threshold <= 1.0)):
        raise ValueError("`p_threshold` must be a valid p value")
    if est.shape != pval.shape:
        raise ValueError("`est` and `pval` must have the same shape.")

    results = {}
    for arch_idx in range(est.shape[0]):
        # Filter processes based on p-value threshold (rows of `pval` are matched by position)
        is_significant = (pval.iloc[arch_idx] < p_threshold).to_numpy()
        significant_processes = pval.columns[is_significant]

        # processes x archetypes. The current archetype is addressed by position so that integer or
        # arbitrary row labels in `est` work, not only the strings "0", "1", ...
        top_processes = est[significant_processes].T
        scores = top_processes.to_numpy()
        arch_scores = scores[:, [arch_idx]]
        other_scores = np.delete(scores, arch_idx, axis=1)
        # Smallest margin over any competing archetype; `initial=np.inf` keeps the reduction defined
        # when there are no competitors (single archetype).
        top_processes["specificity"] = (arch_scores - other_scores).min(axis=1, initial=np.inf)
        top_processes = top_processes.nlargest(n=n, columns="specificity")

        results[arch_idx] = top_processes.reset_index(names="Process")

    return results

247 

248 

def compute_meta_enrichment(adata: sc.AnnData, meta_col: str, datatype: str = "automatic") -> pd.DataFrame:
    """
    Compute the enrichment of metadata categories across archetypes.

    This function estimates how enriched each metadata category is within each archetype using
    a weighted average approach. Weights are based on each cell's contribution to each archetype
    (`adata.obsm["cell_weights"]`). It supports both categorical and continuous metadata.

    Steps for categorical data:
    1. One-hot encode the metadata column from `adata.obs[meta_col]`.
    2. Normalize the metadata so that the sum for each category equals 1 (column-wise). Categories
       without any cells (e.g. unused levels of a pandas Categorical) are left as all-zero columns.
    3. Compute weighted enrichment using cell weights.
    4. Normalize the resulting enrichment scores across metadata categories for each archetype (row-wise).

    Steps for continuous data:
    1. Compute the weighted average of the metadata per archetype.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object with metadata in `adata.obs[meta_col]` and archetype weights
        in `adata.obsm["cell_weights"]`.
    meta_col : str
        The name of the metadata column in `adata.obs` to use for enrichment analysis.
    datatype : str, optional (default="automatic")
        Specifies how to interpret the metadata column:
        - "automatic": infers type based on column dtype (numeric -> continuous;
          pandas Categorical or string -> categorical).
        - "categorical": treats the column as categorical and one-hot encodes it.
        - "continuous": treats the column as numeric and computes weighted averages.

    Returns
    -------
    pd.DataFrame
        A DataFrame of shape (n_archetypes, n_categories) for categorical data or
        (n_archetypes, 1) for continuous data, containing normalized enrichment scores
        or weighted averages respectively.

    Raises
    ------
    ValueError
        If `meta_col` is missing, no weights are stored, `datatype` is invalid, or the column's
        dtype cannot be interpreted in automatic mode.
    """
    if meta_col not in adata.obs:
        raise ValueError("Metadata column does not exist")
    if "cell_weights" not in adata.obsm:
        raise ValueError("No weights available. Please run compute_archetype_weights()")

    metadata = adata.obs[meta_col]
    # (n_archetypes, n_cells)
    weights = adata.obsm["cell_weights"].T

    if datatype == "automatic":
        if pd.api.types.is_numeric_dtype(metadata):
            mode = "continuous"
            metadata = metadata.to_numpy(dtype="float")
        elif isinstance(metadata.dtype, pd.CategoricalDtype) or pd.api.types.is_string_dtype(metadata):
            # AnnData commonly stores string annotations as pandas Categoricals, which
            # `is_string_dtype` does not recognise, so they are checked explicitly.
            mode = "categorical"
        else:
            raise ValueError("Not a valid data type detected")
    elif datatype == "continuous" or datatype == "categorical":
        mode = datatype
    else:
        raise ValueError("Not a valid data type")

    if mode == "categorical":
        # One-hot encoding of metadata: (n_cells, n_categories)
        df_encoded = pd.get_dummies(metadata).astype(float)
        # Normalization: each category column sums to 1. get_dummies also emits columns for unused
        # Categorical levels; their size 0 would give 0/0 = NaN and poison the row normalization
        # below, so those all-zero columns are left unscaled.
        category_sizes = df_encoded.to_numpy().sum(axis=0, keepdims=True)
        category_sizes[category_sizes == 0] = 1.0
        df_encoded = df_encoded / category_sizes

        # Compute weighted enrichment
        weighted_meta = np.einsum("ij,jk->ik", weights, df_encoded.to_numpy())
        weighted_meta /= weights.sum(axis=1, keepdims=True)

        # Normalization across categories so each archetype's row sums to 1
        weighted_meta = weighted_meta / np.sum(weighted_meta, axis=1, keepdims=True)
        weighted_meta_df = pd.DataFrame(weighted_meta, columns=df_encoded.columns)

    elif mode == "continuous":
        metadata = np.asarray(metadata, dtype=float).reshape(-1, 1)

        # Compute weighted enrichment
        weighted_meta = np.einsum("ij,jk->ik", weights, metadata)
        weighted_meta /= weights.sum(axis=1, keepdims=True)

        weighted_meta_df = pd.DataFrame(weighted_meta, columns=[meta_col])

    return weighted_meta_df