Coverage for ParTIpy/enrichment.py: 12%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-16 10:20 +0100

1"""Functions to calculate which features (e.g. genes or covariates) are enriched at each archetype""" 

2 

3import numpy as np 

4import pandas as pd 

5import scanpy as sc 

6from scipy.spatial.distance import cdist 

7 

8 

def calculate_weights(
    X: np.ndarray | sc.AnnData,
    Z: np.ndarray | None = None,
    mode: str = "automatic",
    length_scale: None | float = None,
) -> np.ndarray | None:
    """
    Calculate weights for cells based on their distance to archetypes using a squared exponential kernel.

    The weight of cell ``i`` for archetype ``k`` is
    ``exp(-d(x_i, z_k)**2 / (2 * length_scale**2))``, where ``d`` is the Euclidean distance.
    The applied length scale is printed to stdout.

    Parameters
    ----------
    X : np.ndarray | sc.AnnData
        The input data, which can be either:
        - A 2D array of shape (n_samples, n_features) representing the PCA coordinates of the cells.
        - An AnnData object containing the PCA coordinates in `.obsm["X_pca_reduced"]` and archetypes in
          `.uns["archetypal_analysis"]["Z"]`. In this case any `Z` passed explicitly is ignored.
    Z : np.ndarray, optional
        A 2D array of shape (n_archetypes, n_features) representing the PCA coordinates of the archetypes.
        Required if `X` is not an AnnData object.
    mode : str, optional (default="automatic")
        The mode for determining the length scale of the kernel:
        - "automatic": The length scale is calculated as half the median distance from the data centroid
          to the archetypes. Any `length_scale` passed is ignored.
        - "manual": The length scale is provided by the user via the `length_scale` parameter.
    length_scale : float, optional
        The length scale of the kernel. Required if `mode="manual"`; must be strictly positive.

    Returns
    -------
    np.ndarray or None
        - If `X` is an AnnData object, the weights are stored in `X.obsm["cell_weights"]` and None is returned.
        - If `X` is a numpy array, a 2D array of shape (n_samples, n_archetypes) holding the weight of
          each cell-archetype pair.

    Raises
    ------
    ValueError
        If the AnnData object lacks archetypal analysis results, `Z` is missing, `mode` is unknown,
        `length_scale` is missing in manual mode, or the resulting length scale is not strictly positive
        (e.g. all archetypes coincide with the data centroid in automatic mode).
    """
    # Handle and validate input data
    adata = None
    if isinstance(X, sc.AnnData):
        adata = X
        if "archetypal_analysis" not in adata.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")
        Z = adata.uns["archetypal_analysis"]["Z"]
        X = adata.obsm["X_pca_reduced"]

    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")

    # Calculate or validate length_scale based on mode
    if mode == "automatic":
        centroid = np.mean(X, axis=0).reshape(1, -1)
        length_scale = np.median(cdist(centroid, Z)) / 2
    elif mode == "manual":
        if length_scale is None:
            raise ValueError("For 'manual' mode, 'length_scale' must be provided.")
    else:
        raise ValueError("Mode must be either 'automatic' or 'manual'.")

    # A zero (or NaN) length scale divides by zero in the kernel and silently yields NaN or
    # degenerate weights; `not x > 0` also catches NaN, which `x <= 0` would let through.
    if not length_scale > 0:  # type: ignore[operator]
        raise ValueError(f"'length_scale' must be strictly positive, got {length_scale}.")
    print(f"Applied length scale is {length_scale}.")

    # Weight calculation (squared exponential / RBF kernel on Euclidean distances)
    euclidean_dist = cdist(X, Z)
    weights = np.exp(-(euclidean_dist**2) / (2 * length_scale**2))  # type: ignore[operator]

    if adata is not None:
        adata.obsm["cell_weights"] = weights
        return None
    return weights

72 

73 

def weighted_expr(adata: sc.AnnData, layer: str | None = None) -> pd.DataFrame:
    """
    Calculate a weighted pseudobulk expression profile for each archetype.

    For every archetype, the expression of each gene is averaged across all cells, with each cell
    weighted by its archetype weight (as computed by `calculate_weights`):
    ``pseudobulk[k, g] = sum_i w[i, k] * expr[i, g] / sum_i w[i, k]``.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object containing the gene expression data and weights. The weights must be stored in
        `adata.obsm["cell_weights"]` as a 2D array of shape (n_samples, n_archetypes).
    layer : str, optional (default=None)
        The layer of the AnnData object to use for gene expression. If `None`, `adata.X` is used.
        Dense arrays and scipy sparse matrices are both supported. For Pareto analysis of AA data,
        z-scaled data is recommended.

    Returns
    -------
    pd.DataFrame
        A DataFrame of shape (n_archetypes, n_genes) holding the weighted pseudobulk expression profiles.
        Rows are indexed 0..n_archetypes-1 and columns are `adata.var_names`.
    """
    from scipy import sparse

    # Transpose to (n_archetypes, n_samples) so rows correspond to archetypes.
    weights = np.asarray(adata.obsm["cell_weights"]).T
    expr = adata.X if layer is None else adata.layers[layer]

    if sparse.issparse(expr):
        # np.einsum cannot consume scipy sparse matrices; let the sparse operand drive the
        # product (sparse @ dense -> dense) and transpose back to (n_archetypes, n_genes).
        pseudobulk = np.asarray((expr.T @ weights.T).T)
    else:
        pseudobulk = np.einsum("ij,jk->ik", weights, expr)
    # Normalize by the total weight per archetype to get a weighted mean.
    pseudobulk = pseudobulk / weights.sum(axis=1, keepdims=True)

    return pd.DataFrame(pseudobulk, columns=adata.var_names)

106 

107 

def extract_top_processes(
    est: pd.DataFrame,
    pval: pd.DataFrame,
    order: str = "desc",
    n: int = 20,
    p_threshold: float = 0.05,
) -> dict[str, pd.DataFrame]:
    """
    Rank the significantly enriched biological processes of every archetype.

    For each row (archetype) of the decoupler output, processes whose p-value is strictly below
    `p_threshold` are kept and the `n` highest (or lowest) enrichment scores are returned.

    Parameters
    ----------
    est : pd.DataFrame
        Enrichment scores of shape (n_archetypes, n_processes).
    pval : pd.DataFrame
        P-values of shape (n_archetypes, n_processes), in the same row/column order as `est`.
    order : str, optional (default="desc")
        "desc" keeps the `n` largest scores, "asc" keeps the `n` smallest scores.
    n : int, optional (default=20)
        Maximum number of processes reported per archetype.
    p_threshold : float, optional (default=0.05)
        Only processes with a p-value below this threshold are considered.

    Returns
    -------
    dict[str, pd.DataFrame]
        Maps "archetype_<row position>" to a DataFrame with the columns "Process" (process name)
        and "Score" (enrichment score), sorted according to `order`.

    Raises
    ------
    ValueError
        If `est` and `pval` differ in shape or `order` is not "desc"/"asc".
    """
    if est.shape != pval.shape:
        raise ValueError("`est` and `pval` must have the same shape.")
    if order not in ["desc", "asc"]:
        raise ValueError("`order` must be either 'desc' or 'asc'.")

    def _rank_row(row_pos: int) -> pd.DataFrame:
        # Positional boolean mask over the columns: pval and est share the column order.
        is_significant = list(pval.iloc[row_pos] < p_threshold)
        candidates = est.iloc[row_pos, is_significant]
        if order == "desc":
            ranked = candidates.nlargest(n)
        else:
            ranked = candidates.nsmallest(n)
        table = ranked.reset_index()
        table.columns = ["Process", "Score"]
        return table

    return {f"archetype_{row_pos}": _rank_row(row_pos) for row_pos in range(est.shape[0])}

172 

173 

def extract_top_specific_processes(
    est: pd.DataFrame,
    pval: pd.DataFrame,
    drop_threshold: float = 0,
    n: int = 20,
    p_threshold: float = 0.05,
) -> dict[str, pd.DataFrame]:
    """
    Extract the top enriched biological processes that are specific to each archetype.

    For each archetype, the processes with p-values strictly below `p_threshold` are ranked by
    enrichment score (decoupler output) and the `n` highest are kept. A process is then retained
    only if its score is strictly below `drop_threshold` in every *other* archetype.

    Parameters
    ----------
    est : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the estimated enrichment scores
        for each process and archetype. Row labels may be arbitrary; archetypes are identified by
        row position.
    pval : pd.DataFrame
        A DataFrame of shape (n_archetypes, n_processes) containing the p-values corresponding to
        the enrichment scores in `est`, in the same row/column order.
    drop_threshold : float, optional (default=0)
        A process is dropped unless its score is below this value in all other archetypes.
    n : int, optional (default=20)
        The number of top processes considered per archetype before the specificity filter.
    p_threshold : float, optional (default=0.05)
        The p-value threshold for filtering processes. Only processes with p-values below this
        threshold are considered.

    Returns
    -------
    dict[str, pd.DataFrame]
        Maps "archetype_<row position>" to a copy of `est` restricted to the archetype-specific
        processes (columns, in descending score order for that archetype) and keeping all
        archetype rows, so the scores in the other archetypes remain visible.

    Raises
    ------
    ValueError
        If `est` and `pval` differ in shape.
    """
    if est.shape != pval.shape:
        raise ValueError("`est` and `pval` must have the same shape.")

    n_archetypes = est.shape[0]
    results = {}
    for archetype in range(n_archetypes):
        # Filter processes based on p-value threshold, then keep the n best-scoring ones
        significant_processes = list(pval.iloc[archetype] < p_threshold)
        top_processes = est.iloc[archetype, significant_processes].nlargest(n).index

        # Select the other archetypes by *position*: the row labels of `est` need not be the
        # integers 0..n_archetypes-1 (they may be strings such as "archetype_0").
        is_other = np.arange(n_archetypes) != archetype
        others = est.loc[:, top_processes].iloc[is_other]
        is_specific = (others < drop_threshold).all(axis=0).to_numpy()
        filtered_processes = top_processes[is_specific]

        results[f"archetype_{archetype}"] = est.loc[:, filtered_processes].copy()

    return results

229 

230 

def meta_enrichment(adata: sc.AnnData, meta: str) -> pd.DataFrame:
    """
    Compute the weighted enrichment of metadata categories across archetypes.

    This function performs the following steps:
    1. One-hot encodes the categorical metadata.
    2. Normalizes each one-hot column to sum to 1, so category size does not dominate. Categories
       without any cells (e.g. unused levels of a pandas Categorical) are kept as all-zero columns.
    3. Computes the weighted average of every category column for each archetype using the weights
       stored in `adata.obsm["cell_weights"]`, then rescales each archetype row to sum to 1.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object containing the metadata in `adata.obs[meta]` and weights in
        `adata.obsm["cell_weights"]` (shape (n_samples, n_archetypes)).
    meta : str
        The name of the categorical metadata column in `adata.obs` to use for enrichment analysis.

    Returns
    -------
    pd.DataFrame
        A DataFrame of shape (n_archetypes, n_categories) containing the normalized enrichment of each
        metadata category for each archetype. Rows sum to 1; empty categories get an enrichment of 0.
    """
    metadata = adata.obs[meta]
    # Transpose to (n_archetypes, n_samples) so rows correspond to archetypes.
    weights = np.asarray(adata.obsm["cell_weights"]).T

    # One-hot encoding of metadata. For a Categorical column this also emits all-zero columns for
    # categories that have no cells.
    df_encoded = pd.get_dummies(metadata).astype(float)

    # Normalization per category. Empty categories would give 0/0 = NaN here, and that NaN would
    # poison the row normalization below (turning the whole result into NaN), so divide them by 1.
    counts = df_encoded.sum(axis=0)
    df_encoded = df_encoded.div(counts.where(counts > 0, 1.0), axis=1)

    # Compute weighted enrichment (weighted mean over cells per archetype)
    weighted_meta = np.einsum("ij,jk->ik", weights, df_encoded.to_numpy())
    weighted_meta /= weights.sum(axis=1, keepdims=True)

    # Normalization so each archetype's enrichment sums to 1 across categories
    weighted_meta = weighted_meta / np.sum(weighted_meta, axis=1, keepdims=True)
    return pd.DataFrame(weighted_meta, columns=df_encoded.columns)