Coverage for partipy/plotting.py: 85%

190 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-09 10:41 +0200

1import matplotlib.pyplot as plt 

2import numpy as np 

3import pandas as pd 

4import plotly.express as px 

5import plotly.graph_objects as go 

6import plotnine as pn 

7import scanpy as sc 

8from mizani.palettes import hue_pal 

9from scipy.spatial import ConvexHull 

10 

11from .paretoti import _validate_aa_config, _validate_aa_results, var_explained_aa 

12 

13 

def plot_var_explained(adata: sc.AnnData) -> pn.ggplot:
    """
    Draw an elbow plot of the variance explained versus the number of archetypes.

    The table is read from `adata.uns["AA_var"]`. When that entry is missing, a
    message is printed and `var_explained_aa(adata=adata)` is called before the
    table is read.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object; `adata.uns["AA_var"]` must provide the columns `k` and `varexpl`.

    Returns
    -------
    pn.ggplot
        Black line and points for `varexpl` over `k`, plus a gray reference line
        joining (min k, min varexpl) and (max k, max varexpl).
    """
    # Compute the table on demand if it has not been stored yet
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    var_df = adata.uns["AA_var"]
    k_min, k_max = var_df["k"].min(), var_df["k"].max()

    # Two-point frame for the gray reference line
    reference_line = pd.DataFrame(
        {
            "k": [k_min, k_max],
            "varexpl": [var_df["varexpl"].min(), var_df["varexpl"].max()],
        }
    )

    layers = [
        pn.geom_line(mapping=pn.aes(x="k", y="varexpl"), color="black"),
        pn.geom_point(mapping=pn.aes(x="k", y="varexpl"), color="black"),
        pn.geom_line(data=reference_line, mapping=pn.aes(x="k", y="varexpl"), color="gray"),
        pn.labs(x="Number of Archetypes (k)", y="Variance Explained"),
        pn.lims(y=[0, 1]),
        # One tick per candidate number of archetypes
        pn.scale_x_continuous(breaks=list(np.arange(k_min, k_max + 1))),
        pn.theme_matplotlib(),
        pn.theme(panel_grid_major=pn.element_line(color="gray", size=0.5, alpha=0.5), figure_size=(6, 3)),
    ]

    elbow_plot = pn.ggplot(var_df)
    for layer in layers:
        elbow_plot = elbow_plot + layer
    return elbow_plot

58 

59 

def plot_IC(adata: sc.AnnData) -> pn.ggplot:
    """
    Plot the information criterion stored per number of archetypes.

    The table is read from `adata.uns["AA_var"]` and its `IC` column is drawn
    against `k`. When `adata.uns["AA_var"]` is missing, a message is printed and
    `var_explained_aa(adata=adata)` is called before the table is read.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object; `adata.uns["AA_var"]` must provide the columns `k` and `IC`.

    Returns
    -------
    pn.ggplot
        Line-and-point plot of `IC` over `k`.
    """
    # Compute the table on demand if it has not been stored yet
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    ic_df = adata.uns["AA_var"]
    # One tick per candidate number of archetypes
    k_breaks = list(np.arange(ic_df["k"].min(), ic_df["k"].max() + 1))

    layers = [
        pn.geom_line(mapping=pn.aes(x="k", y="IC"), color="black"),
        pn.geom_point(mapping=pn.aes(x="k", y="IC"), color="black"),
        pn.labs(x="Number of Archetypes (k)", y="Information Criteria"),
        pn.scale_x_continuous(breaks=k_breaks),
        pn.theme_matplotlib(),
        pn.theme(panel_grid_major=pn.element_line(color="gray", size=0.5, alpha=0.5), figure_size=(6, 3)),
    ]

    ic_plot = pn.ggplot(ic_df)
    for layer in layers:
        ic_plot = ic_plot + layer
    return ic_plot

94 

95 

def plot_bootstrap_2D(adata: sc.AnnData, show_two_panels: bool = True) -> pn.ggplot:
    """
    Static 2D scatter plot of the archetypes obtained from bootstrap samples.

    Works on a copy of the table stored in `adata.uns["AA_bootstrap"]`. Points are
    colored by the `archetype` column and shaped by the `reference` column.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object containing the archetype bootstrap data in `adata.uns["AA_bootstrap"]`.
    show_two_panels : bool, optional
        If True and the table has an `x2` column, draw two facets (`x1` and `x2`,
        each against `x0`). Otherwise draw a single `x0` / `x1` panel.

    Returns
    -------
    pn.ggplot
        A 2D scatter plot visualizing the bootstrap results for the archetypes.

    Raises
    ------
    ValueError
        If `adata.uns` has no `"AA_bootstrap"` entry.
    """
    if "AA_bootstrap" not in adata.uns:
        raise ValueError("AA_bootstrap not found in adata.uns. Please run bootstrap_aa() to compute")

    bootstrap_df = adata.uns["AA_bootstrap"].copy()
    has_third_axis = "x2" in bootstrap_df.columns.to_list()

    if not (has_third_axis and show_two_panels):
        # Single panel: first axis against second axis
        single_panel = pn.ggplot(bootstrap_df)
        single_panel = single_panel + pn.geom_point(pn.aes(x="x0", y="x1", color="archetype", shape="reference"))
        single_panel = single_panel + pn.coord_equal()
        return single_panel

    # Long format: one row per (point, axis) so each of x1 / x2 gets its own facet
    long_df = bootstrap_df.melt(
        id_vars=["x0", "archetype", "reference"],
        value_vars=["x1", "x2"],
        var_name="variable",
        value_name="value",
    )
    two_panel = pn.ggplot(long_df)
    two_panel = two_panel + pn.geom_point(pn.aes(x="x0", y="value", color="archetype", shape="reference"))
    two_panel = two_panel + pn.facet_wrap(facets="variable", scales="fixed")
    two_panel = two_panel + pn.labs(x="First Axis", y="Second / Third Axis")
    two_panel = two_panel + pn.coord_equal()
    return two_panel

138 

139 

def plot_bootstrap_3D(adata: sc.AnnData) -> go.Figure:
    """
    Interactive 3D visualization of archetypes from bootstrap samples to assess their variability.

    Creates an interactive 3D scatter plot of the table stored in
    `adata.uns["AA_bootstrap"]`, using the columns `x0`, `x1` and `x2` as
    coordinates, `archetype` for color and `reference` for the marker symbol.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object containing the archetype bootstrap data in `adata.uns["AA_bootstrap"]`.

    Returns
    -------
    go.Figure
        A 3D scatter plot visualizing the bootstrap results for the archetypes.

    Raises
    ------
    ValueError
        If `adata.uns` has no `"AA_bootstrap"` entry, or if the table lacks one of
        the coordinate columns `x0`, `x1`, `x2`.
    """
    # Validation input
    if "AA_bootstrap" not in adata.uns:
        raise ValueError("AA_bootstrap not found in adata.uns. Please run bootstrap_aa() to compute")

    bootstrap_df = adata.uns["AA_bootstrap"]

    # Fail early with a clear message instead of a plotly column-lookup error
    missing_axes = [axis for axis in ("x0", "x1", "x2") if axis not in bootstrap_df.columns]
    if missing_axes:
        raise ValueError(
            f"AA_bootstrap is missing the coordinate column(s) {missing_axes}; "
            "a 3D plot needs x0, x1 and x2. Use plot_bootstrap_2D() for 2D data."
        )

    # Generate the 3D scatter plot
    fig = px.scatter_3d(
        bootstrap_df,
        x="x0",
        y="x1",
        z="x2",
        color="archetype",
        symbol="reference",
        title="Archetypes on bootstrapped data",
        size_max=10,
        hover_data=["iter", "archetype", "reference"],
        opacity=0.5,
    )
    fig.update_layout(template=None)

    return fig

178 

179 

def plot_bootstrap_multiple_k(adata: sc.AnnData) -> pn.ggplot:
    """
    Show bootstrap variance of archetypes as a function of the number of archetypes.

    Reads the table stored under `adata.uns["AA_boostrap_multiple_k"]` (the key
    is spelled exactly like this in the code), draws every
    `variance_per_archetype` value as a point, and overlays the per-`n_archetypes`
    median and maximum as colored lines with points.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object holding the table in `adata.uns["AA_boostrap_multiple_k"]`
        with the columns `n_archetypes` and `variance_per_archetype`.

    Returns
    -------
    pn.ggplot
        A ggplot object displaying:
        - Scatter points for the individual archetype variances per `n_archetypes`.
        - Lines and points for the median and maximum variance at each `n_archetypes`.

    Raises
    ------
    ValueError
        If `adata.uns` has no `"AA_boostrap_multiple_k"` entry.
    """
    if "AA_boostrap_multiple_k" not in adata.uns:
        raise ValueError(
            "bootstrap_aa_multiple_k not found in adata.uns. Please run bootstrap_aa_multiple_k() to compute"
        )

    variance_df = adata.uns["AA_boostrap_multiple_k"]

    # Median / max per number of archetypes, reshaped to long format for the color aesthetic
    per_k = variance_df.groupby("n_archetypes")["variance_per_archetype"]
    summary_wide = per_k.agg(["median", "max"]).reset_index()
    summary_long = summary_wide.melt(id_vars="n_archetypes", value_vars=["median", "max"])

    stability_plot = pn.ggplot()
    stability_plot = stability_plot + pn.geom_point(
        data=variance_df, mapping=pn.aes(x="n_archetypes", y="variance_per_archetype"), alpha=0.5, size=3
    )
    stability_plot = stability_plot + pn.geom_line(
        data=summary_long, mapping=pn.aes(x="n_archetypes", y="value", color="variable")
    )
    stability_plot = stability_plot + pn.geom_point(
        data=summary_long, mapping=pn.aes(x="n_archetypes", y="value", color="variable")
    )
    stability_plot = stability_plot + pn.labs(x="Number of Archetypes", y="Value", color="Variance\nSummary")
    return stability_plot

217 

218 

def plot_archetypes_2D(
    adata: sc.AnnData, color: str | None = None, alpha: float = 1.0, show_two_panels: bool = True
) -> pn.ggplot:
    """
    Static 2D scatter plot of the data points, the archetypes and the polygon they span.

    The embedding is taken from `adata.obsm[obsm_key]`, restricted to the first
    `n_dimension` columns, where both settings come from `adata.uns["aa_config"]`.
    The archetypes are taken from `adata.uns["AA_results"]["Z"]`. The drawing is
    delegated to `plot_2D`.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object that passes `_validate_aa_config` and `_validate_aa_results`.
    color : str or None, optional
        Key passed to `sc.get.obs_df` to obtain the values used for coloring the
        data points. If None (or empty), no coloring is applied.
    alpha : float, optional
        Transparency of the data points, forwarded to `plot_2D`.
    show_two_panels : bool, optional
        Forwarded to `plot_2D`; controls whether a second facet for the third axis is drawn.

    Returns
    -------
    pn.ggplot
        A static 2D scatter plot showing the data and archetypes.
    """
    _validate_aa_config(adata)
    _validate_aa_results(adata)

    aa_config = adata.uns["aa_config"]
    embedding = adata.obsm[aa_config["obsm_key"]]
    data_points = embedding[:, : aa_config["n_dimension"]]
    archetypes = adata.uns["AA_results"]["Z"]

    if color:
        point_colors = sc.get.obs_df(adata, color).values.flatten()
    else:
        point_colors = None

    return plot_2D(
        X=data_points,
        Z=archetypes,
        color_vec=point_colors,
        alpha=alpha,
        show_two_panels=show_two_panels,
    )

251 

252 

def _closed_archetype_path(Z: np.ndarray, dim: int) -> pd.DataFrame:
    """Project the archetypes onto axes (0, dim), order them by angle around their mean, and close the loop."""
    # Sorting by the angle around the centroid yields a non-self-crossing outline
    angles = np.arctan2(Z[:, dim] - np.mean(Z[:, dim]), Z[:, 0] - np.mean(Z[:, 0]))
    order = np.argsort(angles)

    path = pd.DataFrame(Z[:, [0, dim]], columns=["x0", "value"])
    path["variable"] = f"x{dim}"
    path["archetype_label"] = np.arange(path.shape[0])
    path = path.iloc[order].reset_index(drop=True)
    # Repeat the first vertex so geom_path draws a closed polygon
    return pd.concat([path, path.iloc[:1]], ignore_index=True)


def plot_2D(
    X: np.ndarray, Z: np.ndarray, color_vec: np.ndarray | None = None, alpha: float = 1.0, show_two_panels: bool = True
) -> pn.ggplot:
    """
    2D plot of the datapoints in X and the 2D polytope enclosed by the archetypes in Z.

    Parameters
    ----------
    X : np.ndarray
        A 2D array of shape (n_samples, n_features) representing the data points.
    Z : np.ndarray
        A 2D array of shape (n_archetypes, n_features) representing the archetype coordinates.
    color_vec : np.ndarray, optional
        A 1D array of shape (n_samples,) containing values for coloring the data points in `X`.
    alpha : float, optional
        Transparency of the data points.
    show_two_panels : bool, optional
        If True and `X` has more than two columns, draw two facets: second axis
        against the first, and third axis against the first. Otherwise draw only
        the first two axes.

    Returns
    -------
    pn.ggplot
        2D plot of X and polytope enclosed by Z.

    Raises
    ------
    ValueError
        If `X` or `Z` has fewer than 2 columns, if `color_vec` does not match the
        length of `X`, or if two panels are requested (and possible for `X`) but
        `Z` has fewer than 3 columns.
    """
    if X.shape[1] < 2 or Z.shape[1] < 2:
        raise ValueError("Both X and Z must have at least 2 columns (PCs).")
    if color_vec is not None and len(color_vec) != len(X):
        raise ValueError("color_vec must have the same length as X.")

    two_panels = (X.shape[1] > 2) and show_two_panels
    # The second panel indexes Z[:, 2]; reject a too-narrow Z explicitly instead of an IndexError
    if two_panels and Z.shape[1] < 3:
        raise ValueError(
            "Z must have at least 3 columns (PCs) to draw two panels; "
            "pass show_two_panels=False to plot only the first two axes."
        )

    if two_panels:
        data_df = pd.DataFrame(X[:, :3], columns=["x0", "x1", "x2"])
        id_vars = ["x0"]
        if color_vec is not None:
            data_df["color_vec"] = np.array(color_vec)
            id_vars.append("color_vec")
        # Long format: one row per (point, axis) so x1 and x2 each get a facet
        data_df = data_df.melt(id_vars=id_vars, value_vars=["x1", "x2"], var_name="variable", value_name="value")
        arch_df = pd.concat([_closed_archetype_path(Z, dim) for dim in (1, 2)])
    else:
        data_df = pd.DataFrame(X[:, :2], columns=["x0", "value"])
        if color_vec is not None:
            data_df["color_vec"] = np.array(color_vec)
        data_df["variable"] = "x1"
        arch_df = _closed_archetype_path(Z, 1)

    # Generate plot
    plot = pn.ggplot()

    if color_vec is not None:
        plot += pn.geom_point(data=data_df, mapping=pn.aes(x="x0", y="value", color="color_vec"), alpha=alpha)
    else:
        plot += pn.geom_point(data=data_df, mapping=pn.aes(x="x0", y="value"), color="black", alpha=alpha)

    plot += pn.geom_point(data=arch_df, mapping=pn.aes(x="x0", y="value"), color="red", size=1)
    plot += pn.geom_path(data=arch_df, mapping=pn.aes(x="x0", y="value"), color="red", size=1)
    plot += pn.geom_label(
        data=arch_df, mapping=pn.aes(x="x0", y="value", label="archetype_label"), color="black", size=12
    )
    plot += pn.facet_wrap(facets="variable", scales="fixed")
    plot += pn.labs(x="First Axis", y="Second / Third Axis")
    plot += pn.coord_equal()

    return plot

328 

329 

def plot_archetypes_3D(adata: sc.AnnData, color: str | None = None) -> go.Figure:
    """
    Create an interactive 3D scatter plot showing data points, archetypes and the polytope they span.

    The embedding is taken from `adata.obsm[obsm_key]`, restricted to the first
    `n_dimension` columns, where both settings come from `adata.uns["aa_config"]`.
    The archetypes are taken from `adata.uns["AA_results"]["Z"]`. The drawing is
    delegated to `plot_3D`, which uses the first three columns.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object that passes `_validate_aa_config` and `_validate_aa_results`.
    color : str, optional
        Key passed to `sc.get.obs_df` to obtain the values used for coloring the
        data points. If None (or empty), no coloring is applied.

    Returns
    -------
    go.Figure
        A Plotly figure object showing a 3D scatter plot of the data and archetypes.
    """
    _validate_aa_config(adata)
    _validate_aa_results(adata)
    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]
    Z = adata.uns["AA_results"]["Z"]
    color_vec = sc.get.obs_df(adata, color).values.flatten() if color else None
    return plot_3D(X=X, Z=Z, color_vec=color_vec)

360 

361 

def _polytope_faces(Z_plot: np.ndarray) -> np.ndarray:
    """Return the triangular faces (rows of vertex indices into Z_plot) of the archetype polytope."""
    n_archetypes = Z_plot.shape[0]
    if n_archetypes >= 4:
        return ConvexHull(Z_plot).simplices
    if n_archetypes == 3:
        # Qhull needs at least 4 points in 3D; three archetypes span exactly one triangle
        return np.array([[0, 1, 2]])
    # Fewer than three archetypes do not span a face
    return np.empty((0, 3), dtype=int)


def plot_3D(
    X: np.ndarray,
    Z: np.ndarray,
    color_vec: np.ndarray | None = None,
    marker_size: int = 4,
    color_polyhedron: str = "green",
) -> go.Figure:
    """
    Generates a 3D plot of data points and the polytope formed by archetypes.

    Only the first three columns of `X` and `Z` are plotted. With four or more
    archetypes the polytope is their convex hull. With three archetypes it is
    the triangle they span, and with two it is the connecting segment.

    Parameters
    ----------
    X : np.ndarray
        A 2D array of shape (n_samples, n_features) representing the data points.
    Z : np.ndarray
        A 2D array of shape (n_archetypes, n_features) representing the archetype coordinates.
    color_vec : np.ndarray, optional
        A 1D array of shape (n_samples,) containing values for coloring the data points in `X`.
    marker_size : int, optional (default=4)
        The size of the markers for the data points in `X`.
    color_polyhedron : str, optional (default="green")
        The color of the polytope defined by the archetypes.

    Returns
    -------
    go.Figure
        3D plot of X and polytope enclosed by Z.

    Raises
    ------
    ValueError
        If `Z` is None, if `X` or `Z` has fewer than 3 columns, or if `color_vec`
        does not match the number of rows of `X`.

    Notes
    -----
    For four or more archetypes the hull is computed with `scipy.spatial.ConvexHull`;
    any error it raises (e.g. for degenerate input) is propagated unchanged.
    """
    # Validation input
    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")

    if X.shape[1] < 3 or Z.shape[1] < 3:
        raise ValueError("Both X and Z must have at least 3 columns (PCs).")

    X_plot, Z_plot = X[:, :3], Z[:, :3]

    plot_df = pd.DataFrame(X_plot, columns=["x0", "x1", "x2"])
    plot_df["marker_size"] = np.repeat(marker_size, X_plot.shape[0])

    if color_vec is not None:
        if len(color_vec) != len(plot_df):
            raise ValueError("color_vec must have the same length as X.")
        plot_df["color_vec"] = np.array(color_vec)

    # Create the 3D scatter plot; color=None is plotly's default and means "no color mapping"
    fig = px.scatter_3d(
        plot_df,
        x="x0",
        y="x1",
        z="x2",
        labels={"x0": "PC 1", "x1": "PC 2", "x2": "PC 3"},
        title="3D polytope",
        color="color_vec" if color_vec is not None else None,
        size="marker_size",
        size_max=10,
        opacity=0.5,
    )

    # Triangular faces of the polytope spanned by the archetypes
    faces = _polytope_faces(Z_plot)

    # Add archetypes to the plot
    archetype_labels = [f"Archetype {i}" for i in range(Z_plot.shape[0])]
    fig.add_trace(
        go.Scatter3d(
            x=Z_plot[:, 0],
            y=Z_plot[:, 1],
            z=Z_plot[:, 2],
            mode="markers",
            text=archetype_labels,
            marker={"size": 4, "color": color_polyhedron, "symbol": "circle"},
            hoverinfo="text",
            name="Archetypes",
        )
    )

    # Add the polytope surface to the plot (skipped when there is no face to draw)
    if len(faces) > 0:
        fig.add_trace(
            go.Mesh3d(
                x=Z_plot[:, 0],
                y=Z_plot[:, 1],
                z=Z_plot[:, 2],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color=color_polyhedron,
                opacity=0.1,
            )
        )

    # Add edges of the polytope: each face outline is closed by repeating its first vertex
    edge_paths = [np.append(face, face[0]) for face in faces]
    if Z_plot.shape[0] == 2:
        # Two archetypes span a single segment
        edge_paths.append(np.array([0, 1]))

    for path in edge_paths:
        fig.add_trace(
            go.Scatter3d(
                x=Z_plot[path, 0],
                y=Z_plot[path, 1],
                z=Z_plot[path, 2],
                mode="lines",
                line={"color": color_polyhedron, "width": 4},
                showlegend=False,
            )
        )

    fig.update_layout(template=None)
    return fig

480 

481 

def barplot_meta_enrichment(meta_enrich: pd.DataFrame, meta: str = "Meta"):
    """
    Generate a stacked bar plot showing metadata enrichment across archetypes.

    Parameters
    ----------
    meta_enrich: pd.DataFrame
        Output of `meta_enrichment()`, a DataFrame where rows are archetypes and columns are metadata categories,
        with values representing normalized enrichment scores.
    meta : str, optional
        Label to use for the metadata category legend in the plot. Default is "Meta".

    Returns
    -------
    pn.ggplot.ggplot
        A stacked bar plot of metadata enrichment per archetype.
    """
    # Move the index into an "archetype" column. Naming the axis first makes this
    # work whether or not the incoming index already carries a name.
    meta_enrich = meta_enrich.rename_axis(index="archetype").reset_index()
    meta_enrich_long = meta_enrich.melt(id_vars=["archetype"], var_name="Meta", value_name="Normalized_Enrichment")

    # One color per metadata category
    categories = meta_enrich_long["Meta"].unique()
    color_palette = hue_pal()(len(categories))

    # Create plot
    plot = (
        pn.ggplot(
            meta_enrich_long,
            pn.aes(x="factor(archetype)", y="Normalized_Enrichment", fill="Meta"),
        )
        + pn.geom_bar(stat="identity", position="stack")
        + pn.theme_matplotlib()
        + pn.scale_fill_manual(values=color_palette)
        + pn.labs(
            title="Meta Enrichment Across Archetypes",
            x="Archetype",
            y="Normalized Enrichment",
            fill=meta,
        )
    )
    return plot

525 

526 

def heatmap_meta_enrichment(meta_enrich: pd.DataFrame, meta: str | None = "Meta"):
    """
    Generate a heatmap showing metadata enrichment across archetypes.

    Parameters
    ----------
    meta_enrich: pd.DataFrame
        Output of `meta_enrichment()`, a DataFrame where rows are archetypes and columns are metadata categories,
        with values representing normalized enrichment scores.
    meta : str, optional
        Label to use for the y axis (the metadata categories). Default is "Meta".

    Returns
    -------
    pn.ggplot.ggplot
        A heatmap of normalized enrichment scores per archetype and metadata category.
    """
    # Move the index into an "archetype" column. Naming the axis first makes this
    # work whether or not the incoming index already carries a name.
    meta_enrich = meta_enrich.rename_axis(index="archetype").reset_index()
    meta_enrich_long = meta_enrich.melt(id_vars=["archetype"], var_name="Meta", value_name="Normalized_Enrichment")

    # Create plot
    plot = (
        pn.ggplot(meta_enrich_long, pn.aes("archetype", "Meta", fill="Normalized_Enrichment"))
        + pn.geom_tile()
        + pn.scale_fill_continuous(cmap_name="Blues")
        + pn.theme_matplotlib()
        + pn.labs(title="Heatmap", x="Archetype", y=meta, fill=" Normalized \nEnrichment")
    )
    return plot

557 

558 

def barplot_functional_enrichment(top_features: dict, show: bool = True):
    """
    Generate bar plots showing functional enrichment scores for each archetype.

    Each plot displays the top enriched features (e.g., biological processes) for one archetype.
    The input DataFrames are not modified; each plot works on its own copy.

    Parameters
    ----------
    top_features : dict
        A dictionary where keys are archetype indices (0, 1,...) and values are pd.DataFrames
        containing the data to plot. Each DataFrame should have a column for the feature ('Process') and a column
        named after the archetype index as a string ("0", "1", ...), which holds the enrichment score.
    show: bool, optional
        If the plots should be printed.

    Returns
    -------
    list
        A list of `plotnine.ggplot` objects, one for each archetype.
    """
    plots = []
    # Loop through archetypes; keys are expected to be the consecutive integers 0..len-1
    for key in range(len(top_features)):
        # Copy so the categorical conversion below does not alter the caller's DataFrame
        data = top_features[key].copy()

        # Freeze the row order as category order so the bars keep the table's ordering
        data["Process"] = pd.Categorical(data["Process"], categories=data["Process"].tolist(), ordered=True)

        score_column = str(key)

        # Create plot
        plot = (
            pn.ggplot(data, pn.aes(x="Process", y=score_column, fill=score_column))
            + pn.geom_bar(stat="identity")
            + pn.labs(
                title=f"Enrichment at archetype {key}",
                x="Feature",
                y="Enrichment score",
                fill="Enrichment score",
            )
            + pn.theme_matplotlib()
            + pn.theme(figure_size=(15, 5))
            + pn.coord_flip()
            + pn.scale_fill_gradient2(
                low="blue",
                mid="lightgrey",
                high="red",
                midpoint=0,
            )
        )
        if show:
            plot.show()
        plots.append(plot)

    # Return the list of plots
    return plots

613 

614 

def barplot_enrichment_comparison(specific_processes_arch: pd.DataFrame):
    """
    Draw a grouped (dodged) bar plot comparing enrichment scores across archetypes.

    Parameters
    ----------
    specific_processes_arch : pd.DataFrame
        Output from `extract_specific_processes`. Must contain a 'Process' column, a 'specificity' score,
        and one column per archetype with enrichment values.

    Returns
    -------
    plotnine.ggplot.ggplot
        A grouped bar plot of the enrichment score of every process for every archetype,
        with the processes ordered by decreasing 'specificity' as categories.
    """
    # Category order for the processes: most specific first
    ranked = specific_processes_arch.sort_values("specificity", ascending=False)
    process_order = ranked["Process"].to_list()

    # Every column other than 'Process' and 'specificity' is treated as an archetype
    wide_df = specific_processes_arch.drop(columns="specificity")
    arch_columns = wide_df.columns.drop("Process").to_list()

    long_df = wide_df.melt(
        id_vars=["Process"],
        value_vars=arch_columns,
        var_name="Archetype",
        value_name="Enrichment",
    )
    long_df["Process"] = pd.Categorical(long_df["Process"], categories=process_order)

    comparison_plot = pn.ggplot(long_df, pn.aes(x="Process", y="Enrichment", fill="factor(Archetype)"))
    comparison_plot = comparison_plot + pn.geom_bar(stat="identity", position=pn.position_dodge())
    comparison_plot = comparison_plot + pn.theme_matplotlib()
    comparison_plot = comparison_plot + pn.scale_fill_brewer(type="qual", palette="Dark2")
    comparison_plot = comparison_plot + pn.labs(
        x="Process",
        y="Enrichment score",
        fill="Archetype",
        title="Enrichment Comparison",
    )
    comparison_plot = comparison_plot + pn.theme(figure_size=(10, 5))
    comparison_plot = comparison_plot + pn.coord_flip()
    return comparison_plot

653 

654 

def radarplot_meta_enrichment(meta_enrich: pd.DataFrame):
    """
    Draw one radar (polar) plot per metadata category, with one spoke per archetype.

    Parameters
    ----------
    meta_enrich: pd.DataFrame
        Output of meta_enrichment(), a pd.DataFrame containing the enrichment of meta categories (columns)
        for all archetypes (rows). The radial axis is fixed to the range [0, 1].

    Returns
    -------
    module
        The `matplotlib.pyplot` module itself; the newly created figure holding all
        radar plots is its current figure (e.g. call `.show()` or `.gcf()` on the result).
    """
    # Transpose so that each row is one metadata category and each remaining column one archetype.
    # Naming the axis first makes the new column "Meta_feature" even if the columns axis had a name.
    meta_enrich = meta_enrich.T.rename_axis(index="Meta_feature").reset_index()

    # Layout shared by all radars: one spoke per archetype column
    n_archetypes = len(list(meta_enrich)[1:])
    angles = [n / float(n_archetypes) * 2 * np.pi for n in range(n_archetypes)]
    angles += angles[:1]  # repeat the first angle to close the outline
    archetype_label = [f"A{i}" for i in range(n_archetypes)]
    n_grid_rows = int(np.ceil(len(meta_enrich) / 2))  # two radars per row

    # Initialize the figure
    my_dpi = 96
    plt.figure(figsize=(1000 / my_dpi, 1000 / my_dpi), dpi=my_dpi)

    # Create a color palette:
    my_palette = plt.colormaps.get_cmap("Dark2")

    # One radar per metadata category
    for row in range(len(meta_enrich.index)):
        color = my_palette(row)
        ax = plt.subplot(n_grid_rows, 2, row + 1, polar=True)

        # Put first axis on top and go clockwise
        ax.set_theta_offset(np.pi / 2)
        ax.set_theta_direction(-1)

        # One spoke per archetype, with labels
        plt.xticks(angles[:-1], archetype_label, color="grey", size=8)

        # Radial tick labels
        ax.set_rlabel_position(0)
        plt.yticks(
            [0, 0.25, 0.5, 0.75, 1],
            ["0", "0.25", "0.50", "0.75", "1.0"],
            color="grey",
            size=7,
        )
        plt.ylim(0, 1)

        # Enrichment values of this category across archetypes, closed into a loop
        values = meta_enrich.loc[row].drop("Meta_feature").values.flatten().tolist()
        values += values[:1]
        ax.plot(angles, values, color=color, linewidth=2, linestyle="solid")
        ax.fill(angles, values, color=color, alpha=0.4)

        # Add a title
        plt.title(f"Feature: {meta_enrich['Meta_feature'][row]}", size=11, color=color, y=1.065)

    return plt