Coverage for partipy/plotting.py: 85%
190 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 10:41 +0200
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 10:41 +0200
1import matplotlib.pyplot as plt
2import numpy as np
3import pandas as pd
4import plotly.express as px
5import plotly.graph_objects as go
6import plotnine as pn
7import scanpy as sc
8from mizani.palettes import hue_pal
9from scipy.spatial import ConvexHull
11from .paretoti import _validate_aa_config, _validate_aa_results, var_explained_aa
def plot_var_explained(adata: sc.AnnData) -> pn.ggplot:
    """
    Draw an elbow plot of the variance explained by Archetypal Analysis (AA) versus the number of archetypes.

    The per-k results are read from `adata.uns["AA_var"]`. When that entry is absent,
    `var_explained_aa` is run first to populate it. A gray reference line joins the
    (smallest k, smallest variance explained) and (largest k, largest variance explained)
    corners to help locate the elbow.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object holding (or receiving) the variance explained table in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        Elbow plot of variance explained against the number of archetypes.
    """
    # Compute the variance-explained table on demand
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    var_df = adata.uns["AA_var"]
    k_min = var_df["k"].min()
    k_max = var_df["k"].max()

    # Straight reference line spanning the extremes of the curve
    reference_line = pd.DataFrame(
        {
            "k": [k_min, k_max],
            "varexpl": [var_df["varexpl"].min(), var_df["varexpl"].max()],
        }
    )
    integer_breaks = list(np.arange(k_min, k_max + 1))

    plot = pn.ggplot(var_df)
    plot += pn.geom_line(mapping=pn.aes(x="k", y="varexpl"), color="black")
    plot += pn.geom_point(mapping=pn.aes(x="k", y="varexpl"), color="black")
    plot += pn.geom_line(data=reference_line, mapping=pn.aes(x="k", y="varexpl"), color="gray")
    plot += pn.labs(x="Number of Archetypes (k)", y="Variance Explained")
    plot += pn.lims(y=[0, 1])
    plot += pn.scale_x_continuous(breaks=integer_breaks)
    plot += pn.theme_matplotlib()
    plot += pn.theme(panel_grid_major=pn.element_line(color="gray", size=0.5, alpha=0.5), figure_size=(6, 3))
    return plot
def plot_IC(adata: sc.AnnData) -> pn.ggplot:
    """
    Plot the information criterion (IC) of Archetypal Analysis (AA) models for a range of archetypes.

    The IC values are read from the `IC` column of `adata.uns["AA_var"]` and plotted
    against the number of archetypes `k`. If `adata.uns["AA_var"]` is not found,
    `var_explained_aa` is called first to compute it.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing (or receiving) the AA summary table in `adata.uns["AA_var"]`,
        which must provide the columns `k` and `IC`.

    Returns
    -------
    pn.ggplot
        A ggplot object showing the information criterion as a function of the number of archetypes.
    """
    # Compute the AA summary table on demand
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    plot_df = adata.uns["AA_var"]

    p = (
        pn.ggplot(plot_df)
        + pn.geom_line(mapping=pn.aes(x="k", y="IC"), color="black")
        + pn.geom_point(mapping=pn.aes(x="k", y="IC"), color="black")
        + pn.labs(x="Number of Archetypes (k)", y="Information Criteria")
        # One tick per integer k
        + pn.scale_x_continuous(breaks=list(np.arange(plot_df["k"].min(), plot_df["k"].max() + 1)))
        + pn.theme_matplotlib()
        + pn.theme(panel_grid_major=pn.element_line(color="gray", size=0.5, alpha=0.5), figure_size=(6, 3))
    )
    return p
def plot_bootstrap_2D(adata: sc.AnnData, show_two_panels: bool = True) -> pn.ggplot:
    """
    Visualize the distribution and stability of archetypes across bootstrap samples in 2D.

    Creates a static 2D scatter plot of the archetype positions computed from bootstrap
    samples, stored in `adata.uns["AA_bootstrap"]`. Points are colored by `archetype`
    and shaped by `reference`.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object containing the archetype bootstrap data in `adata.uns["AA_bootstrap"]`.
        The table must provide the columns `x0`, `x1`, `archetype` and `reference`, and
        optionally `x2`.
    show_two_panels : bool, optional (default=True)
        If True and a third coordinate `x2` is available, draw two facets: `x1` vs `x0`
        and `x2` vs `x0`. Otherwise only `x1` vs `x0` is drawn.

    Returns
    -------
    pn.ggplot
        A 2D scatter plot visualizing the bootstrap results for the archetypes.

    Raises
    ------
    ValueError
        If `adata.uns["AA_bootstrap"]` is missing.
    """
    # Validation input
    if "AA_bootstrap" not in adata.uns:
        raise ValueError("AA_bootstrap not found in adata.uns. Please run bootstrap_aa() to compute")

    # Copy so the melt below never touches the stored table
    plot_df = adata.uns["AA_bootstrap"].copy()

    if ("x2" in plot_df.columns) and show_two_panels:
        # Long format: one facet per secondary axis (x1, x2), sharing x0 as the horizontal axis
        plot_df = plot_df.melt(
            id_vars=["x0", "archetype", "reference"], value_vars=["x1", "x2"], var_name="variable", value_name="value"
        )
        p = (
            pn.ggplot(plot_df)
            + pn.geom_point(pn.aes(x="x0", y="value", color="archetype", shape="reference"))
            + pn.facet_wrap(facets="variable", scales="fixed")
            + pn.labs(x="First Axis", y="Second / Third Axis")
            + pn.coord_equal()
        )
    else:
        p = (
            pn.ggplot(plot_df)
            + pn.geom_point(pn.aes(x="x0", y="x1", color="archetype", shape="reference"))
            + pn.coord_equal()
        )
    return p
def plot_bootstrap_3D(adata: sc.AnnData) -> go.Figure:
    """
    Interactive 3D visualization of archetypes from bootstrap samples to assess their variability.

    Creates an interactive 3D scatter plot of the archetype positions computed from bootstrap
    samples, stored in `adata.uns["AA_bootstrap"]`. Points are colored by `archetype` and
    use the marker symbol given by `reference`.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object containing the archetype bootstrap data in `adata.uns["AA_bootstrap"]`.
        The table must provide the columns `x0`, `x1`, `x2`, `iter`, `archetype` and `reference`.

    Returns
    -------
    go.Figure
        A 3D scatter plot visualizing the bootstrap results for the archetypes.

    Raises
    ------
    ValueError
        If `adata.uns["AA_bootstrap"]` is missing.
    """
    # Validation input
    if "AA_bootstrap" not in adata.uns:
        raise ValueError("AA_bootstrap not found in adata.uns. Please run bootstrap_aa() to compute")

    # Generate the 3D scatter plot
    bootstrap_df = adata.uns["AA_bootstrap"]
    fig = px.scatter_3d(
        bootstrap_df,
        x="x0",
        y="x1",
        z="x2",
        color="archetype",
        symbol="reference",
        title="Archetypes on bootstrapped data",
        size_max=10,
        hover_data=["iter", "archetype", "reference"],
        opacity=0.5,
    )
    # Drop plotly's default template so the figure uses the plain layout
    fig.update_layout(template=None)

    return fig
def plot_bootstrap_multiple_k(adata: sc.AnnData) -> pn.ggplot:
    """
    Plot archetype stability (bootstrap variance) against the number of archetypes.

    Each archetype's bootstrap variance (`variance_per_archetype`) is drawn as a
    semi-transparent point at its `n_archetypes`. On top, the per-`n_archetypes` median
    and maximum of those variances are drawn as colored lines with points.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object holding the output of `bootstrap_aa_multiple_k` in
        `adata.uns["AA_boostrap_multiple_k"]`. That table must provide the columns
        `n_archetypes` and `variance_per_archetype`.

    Returns
    -------
    pn.ggplot
        Scatter of individual archetype variances plus median / max summary lines.

    Raises
    ------
    ValueError
        If `adata.uns["AA_boostrap_multiple_k"]` is missing.
    """
    # Key spelling must match the one used when the results were stored
    uns_key = "AA_boostrap_multiple_k"
    if uns_key not in adata.uns:
        raise ValueError(
            "bootstrap_aa_multiple_k not found in adata.uns. Please run bootstrap_aa_multiple_k() to compute"
        )

    per_archetype = adata.uns[uns_key]

    # Long-format table with columns n_archetypes / variable ("median" or "max") / value
    summary = (
        per_archetype.groupby("n_archetypes")["variance_per_archetype"]
        .agg(["median", "max"])
        .reset_index()
        .melt(id_vars="n_archetypes", value_vars=["median", "max"])
    )

    plot = pn.ggplot()
    plot += pn.geom_point(
        data=per_archetype, mapping=pn.aes(x="n_archetypes", y="variance_per_archetype"), alpha=0.5, size=3
    )
    plot += pn.geom_line(data=summary, mapping=pn.aes(x="n_archetypes", y="value", color="variable"))
    plot += pn.geom_point(data=summary, mapping=pn.aes(x="n_archetypes", y="value", color="variable"))
    plot += pn.labs(x="Number of Archetypes", y="Value", color="Variance\nSummary")
    return plot
def plot_archetypes_2D(
    adata: sc.AnnData, color: str | None = None, alpha: float = 1.0, show_two_panels: bool = True
) -> pn.ggplot:
    """
    Generate a static 2D scatter plot showing data points, archetypes and the polytope they span.

    The embedding is taken from `adata.obsm[adata.uns["aa_config"]["obsm_key"]]`, restricted
    to its first `adata.uns["aa_config"]["n_dimension"]` columns. The archetypes are read
    from `adata.uns["AA_results"]["Z"]`. Plotting itself is delegated to `plot_2D`.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object containing the AA configuration in `adata.uns["aa_config"]`,
        the archetypes in `adata.uns["AA_results"]["Z"]`, and the reduced data in the
        configured `adata.obsm` slot.
    color : str or None, optional
        Column name in `adata.obs` to use for coloring the data points. If None (or empty),
        no coloring is applied.
    alpha : float, optional (default=1.0)
        Opacity of the data points.
    show_two_panels : bool, optional (default=True)
        If True and at least three dimensions are available, show two facets
        (second and third axis against the first axis).

    Returns
    -------
    pn.ggplot
        A static 2D scatter plot showing the data and archetypes.
    """
    _validate_aa_config(adata)
    _validate_aa_results(adata)
    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]
    Z = adata.uns["AA_results"]["Z"]
    # Flatten the single-column obs DataFrame into a 1D vector aligned with the rows of X
    color_vec = sc.get.obs_df(adata, color).values.flatten() if color else None
    plot = plot_2D(X=X, Z=Z, color_vec=color_vec, alpha=alpha, show_two_panels=show_two_panels)
    return plot
def _archetype_polygon(Z: np.ndarray, dim: int) -> pd.DataFrame:
    """Return the archetypes projected onto axes (0, `dim`), ordered by angle around their centroid."""
    angles = np.arctan2(Z[:, dim] - np.mean(Z[:, dim]), Z[:, 0] - np.mean(Z[:, 0]))
    order = np.argsort(angles)
    arch_df = pd.DataFrame(Z[:, [0, dim]], columns=["x0", "value"])
    arch_df["variable"] = f"x{dim}"
    # Label with the archetype's row index in Z (assigned before reordering)
    arch_df["archetype_label"] = np.arange(arch_df.shape[0])
    return arch_df.iloc[order].reset_index(drop=True)


def plot_2D(
    X: np.ndarray, Z: np.ndarray, color_vec: np.ndarray | None = None, alpha: float = 1.0, show_two_panels: bool = True
) -> pn.ggplot:
    """
    2D plot of the datapoints in X and the 2D polytope enclosed by the archetypes in Z.

    The archetypes are connected in the order of their angle around their centroid, and
    the path is closed back to the first archetype.

    Parameters
    ----------
    X : np.ndarray
        A 2D array of shape (n_samples, n_features) representing the data points.
    Z : np.ndarray
        A 2D array of shape (n_archetypes, n_features) representing the archetype coordinates.
    color_vec : np.ndarray, optional
        A 1D array of shape (n_samples,) containing values for coloring the data points in `X`.
    alpha : float, optional (default=1.0)
        Opacity of the data points.
    show_two_panels : bool, optional (default=True)
        If True and both `X` and `Z` have more than two columns, draw two facets:
        the second axis vs. the first and the third axis vs. the first.

    Returns
    -------
    pn.ggplot
        2D plot of X and polytope enclosed by Z.

    Raises
    ------
    ValueError
        If `X` or `Z` has fewer than 2 columns, or `color_vec` does not match the length of `X`.
    """
    if X.shape[1] < 2 or Z.shape[1] < 2:
        raise ValueError("Both X and Z must have at least 2 columns (PCs).")
    if color_vec is not None and len(color_vec) != len(X):
        raise ValueError("color_vec must have the same length as X.")

    # The third axis is needed in BOTH X and Z; otherwise fall back to a single panel
    two_panels = show_two_panels and X.shape[1] > 2 and Z.shape[1] > 2

    if two_panels:
        data_df = pd.DataFrame(X[:, :3], columns=["x0", "x1", "x2"])
        id_vars = ["x0"]
        if color_vec is not None:
            data_df["color_vec"] = np.array(color_vec)
            id_vars.append("color_vec")
        data_df = data_df.melt(id_vars=id_vars, value_vars=["x1", "x2"], var_name="variable", value_name="value")
        dims = [1, 2]
    else:
        data_df = pd.DataFrame(X[:, :2], columns=["x0", "value"])
        if color_vec is not None:
            data_df["color_vec"] = np.array(color_vec)
        data_df["variable"] = "x1"
        dims = [1]

    polygons = [_archetype_polygon(Z, dim) for dim in dims]
    # Vertices drawn once each (points and labels)
    vertex_df = pd.concat(polygons, ignore_index=True)
    # Each panel's path is closed by repeating its first vertex
    path_df = pd.concat([pd.concat([poly, poly.iloc[:1]]) for poly in polygons], ignore_index=True)

    # Generate plot
    plot = pn.ggplot()

    if color_vec is not None:
        plot += pn.geom_point(data=data_df, mapping=pn.aes(x="x0", y="value", color="color_vec"), alpha=alpha)
    else:
        plot += pn.geom_point(data=data_df, mapping=pn.aes(x="x0", y="value"), color="black", alpha=alpha)

    plot += pn.geom_point(data=vertex_df, mapping=pn.aes(x="x0", y="value"), color="red", size=1)
    plot += pn.geom_path(data=path_df, mapping=pn.aes(x="x0", y="value"), color="red", size=1)
    plot += pn.geom_label(
        data=vertex_df, mapping=pn.aes(x="x0", y="value", label="archetype_label"), color="black", size=12
    )
    plot += pn.facet_wrap(facets="variable", scales="fixed")
    plot += pn.labs(x="First Axis", y="Second / Third Axis")
    plot += pn.coord_equal()

    return plot
def plot_archetypes_3D(adata: sc.AnnData, color: str | None = None) -> go.Figure:
    """
    Create an interactive 3D scatter plot showing data points, archetypes and the polytope they span.

    The embedding is taken from `adata.obsm[adata.uns["aa_config"]["obsm_key"]]`, restricted
    to its first `adata.uns["aa_config"]["n_dimension"]` columns. The archetypes are read
    from `adata.uns["AA_results"]["Z"]`. `plot_3D` then draws the first three of these
    dimensions. If a color key is provided, data points are colored by the corresponding
    values from `adata.obs`.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object containing the AA configuration in `adata.uns["aa_config"]`,
        the archetypes in `adata.uns["AA_results"]["Z"]`, and the reduced data in the
        configured `adata.obsm` slot.
    color : str, optional
        Name of a column in `adata.obs` to color the data points by.

    Returns
    -------
    go.Figure
        A Plotly figure object showing a 3D scatter plot of the data and archetypes.
    """
    _validate_aa_config(adata)
    _validate_aa_results(adata)
    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]
    Z = adata.uns["AA_results"]["Z"]
    # Flatten the single-column obs DataFrame into a 1D vector aligned with the rows of X
    color_vec = sc.get.obs_df(adata, color).values.flatten() if color else None
    plot = plot_3D(X=X, Z=Z, color_vec=color_vec)
    return plot
def plot_3D(
    X: np.ndarray,
    Z: np.ndarray,
    color_vec: np.ndarray | None = None,
    marker_size: int = 4,
    color_polyhedron: str = "green",
) -> go.Figure:
    """
    Generates a 3D plot of data points and the polytope formed by archetypes.

    Only the first three columns of `X` and `Z` are used. With four or more archetypes the
    polytope is their convex hull. With exactly three it is the triangle they span, and
    with two it is the segment between them.

    Parameters
    ----------
    X : np.ndarray
        A 2D array of shape (n_samples, n_features) representing the data points.
    Z : np.ndarray
        A 2D array of shape (n_archetypes, n_features) representing the archetype coordinates.
    color_vec : np.ndarray, optional
        A 1D array of shape (n_samples,) containing values for coloring the data points in `X`.
    marker_size : int, optional (default=4)
        The size of the markers for the data points in `X`.
    color_polyhedron : str, optional (default="green")
        The color of the polytope defined by the archetypes.

    Returns
    -------
    go.Figure
        3D plot of X and polytope enclosed by Z.

    Raises
    ------
    ValueError
        If `Z` is None, if `X` or `Z` has fewer than 3 columns, or if `color_vec` does not
        match the length of `X`.
    """
    # Validation input
    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")

    if X.shape[1] < 3 or Z.shape[1] < 3:
        raise ValueError("Both X and Z must have at least 3 columns (PCs).")

    X_plot, Z_plot = X[:, :3], Z[:, :3]

    plot_df = pd.DataFrame(X_plot, columns=["x0", "x1", "x2"])
    plot_df["marker_size"] = np.repeat(marker_size, X_plot.shape[0])

    if color_vec is not None:
        if len(color_vec) != len(plot_df):
            raise ValueError("color_vec must have the same length as X.")
        plot_df["color_vec"] = np.array(color_vec)

    # Create the 3D scatter plot of the data (color=None is plotly's default: no coloring)
    fig = px.scatter_3d(
        plot_df,
        x="x0",
        y="x1",
        z="x2",
        labels={"x0": "PC 1", "x1": "PC 2", "x2": "PC 3"},
        title="3D polytope",
        color="color_vec" if color_vec is not None else None,
        size="marker_size",
        size_max=10,
        opacity=0.5,
    )

    # Triangular faces of the polytope. Qhull needs >= 4 points in 3D, so the
    # degenerate triangle case is handled explicitly.
    # NOTE: four or more *coplanar* archetypes still make ConvexHull raise.
    n_archetypes = Z_plot.shape[0]
    if n_archetypes >= 4:
        faces = ConvexHull(Z_plot).simplices
    elif n_archetypes == 3:
        faces = np.array([[0, 1, 2]])
    else:
        faces = np.empty((0, 3), dtype=int)

    # Add archetypes to the plot
    archetype_labels = [f"Archetype {i}" for i in range(n_archetypes)]
    fig.add_trace(
        go.Scatter3d(
            x=Z_plot[:, 0],
            y=Z_plot[:, 1],
            z=Z_plot[:, 2],
            mode="markers",
            text=archetype_labels,
            marker={"size": 4, "color": color_polyhedron, "symbol": "circle"},
            hoverinfo="text",
            name="Archetypes",
        )
    )

    # Add the polytope surface to the plot
    if len(faces) > 0:
        fig.add_trace(
            go.Mesh3d(
                x=Z_plot[:, 0],
                y=Z_plot[:, 1],
                z=Z_plot[:, 2],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color=color_polyhedron,
                opacity=0.1,
            )
        )

    # Add edges of the polytope: each face outlined as a closed loop
    edge_paths = [np.append(face, face[0]) for face in faces]
    if n_archetypes == 2:
        # Two archetypes span only a segment
        edge_paths.append(np.array([0, 1]))
    for path in edge_paths:
        fig.add_trace(
            go.Scatter3d(
                x=Z_plot[path, 0],
                y=Z_plot[path, 1],
                z=Z_plot[path, 2],
                mode="lines",
                line={"color": color_polyhedron, "width": 4},
                showlegend=False,
            )
        )

    fig.update_layout(template=None)
    return fig
def barplot_meta_enrichment(meta_enrich: pd.DataFrame, meta: str = "Meta"):
    """
    Generate a stacked bar plot showing metadata enrichment across archetypes.

    Parameters
    ----------
    meta_enrich: pd.DataFrame
        Output of `meta_enrichment()`, a DataFrame where rows are archetypes and columns are metadata categories,
        with values representing normalized enrichment scores. The index may be named or unnamed.
    meta : str, optional
        Label to use for the metadata category legend in the plot. Default is "Meta".

    Returns
    -------
    pn.ggplot.ggplot
        A stacked bar plot of metadata enrichment per archetype.
    """
    # Prepare data. rename_axis guarantees an "archetype" column whether or not the index
    # already has a name (reset_index alone yields "index" only for an unnamed index).
    meta_enrich = meta_enrich.rename_axis("archetype").reset_index()
    meta_enrich_long = meta_enrich.melt(id_vars=["archetype"], var_name="Meta", value_name="Normalized_Enrichment")

    # One hue-palette color per metadata category
    categories = meta_enrich_long["Meta"].unique()
    color_palette = hue_pal()(len(categories))

    # Create plot
    plot = (
        pn.ggplot(
            meta_enrich_long,
            pn.aes(x="factor(archetype)", y="Normalized_Enrichment", fill="Meta"),
        )
        + pn.geom_bar(stat="identity", position="stack")
        + pn.theme_matplotlib()
        + pn.scale_fill_manual(values=color_palette)
        + pn.labs(
            title="Meta Enrichment Across Archetypes",
            x="Archetype",
            y="Normalized Enrichment",
            fill=meta,
        )
    )
    return plot
def heatmap_meta_enrichment(meta_enrich: pd.DataFrame, meta: str | None = "Meta"):
    """
    Generate a heatmap showing metadata enrichment across archetypes.

    Parameters
    ----------
    meta_enrich: pd.DataFrame
        Output of `meta_enrichment()`, a DataFrame where rows are archetypes and columns are metadata categories,
        with values representing normalized enrichment scores. The index may be named or unnamed.
    meta : str, optional
        Label for the y-axis, which lists the metadata categories. Default is "Meta".

    Returns
    -------
    pn.ggplot.ggplot
        A heatmap of normalized enrichment scores per archetype and metadata category.
    """
    # Prepare data. rename_axis guarantees an "archetype" column whether or not the index
    # already has a name (reset_index alone yields "index" only for an unnamed index).
    meta_enrich = meta_enrich.rename_axis("archetype").reset_index()
    meta_enrich_long = meta_enrich.melt(id_vars=["archetype"], var_name="Meta", value_name="Normalized_Enrichment")

    # Create plot
    plot = (
        pn.ggplot(meta_enrich_long, pn.aes("archetype", "Meta", fill="Normalized_Enrichment"))
        + pn.geom_tile()
        + pn.scale_fill_continuous(cmap_name="Blues")
        + pn.theme_matplotlib()
        + pn.labs(title="Heatmap", x="Archetype", y=meta, fill=" Normalized \nEnrichment")
    )
    return plot
def barplot_functional_enrichment(top_features: dict, show: bool = True):
    """
    Generate bar plots showing functional enrichment scores for each archetype.

    Each plot displays the top enriched features (e.g., biological processes) for one archetype.
    Features keep the row order of the input DataFrame. The input DataFrames are not modified.

    Parameters
    ----------
    top_features : dict
        A dictionary where keys are archetype indices (0, 1,...) and values are pd.DataFrames
        containing the data to plot. Each DataFrame should have a column for the feature ('Process') and a column
        for the archetype named by its index as a string ("0", "1", ...).
    show: bool, optional
        If the plots should be printed.

    Returns
    -------
    list
        A list of `plotnine.ggplot` objects, one for each archetype.
    """
    plots = []
    # Keys are expected to be 0 .. len(top_features) - 1
    for key in range(len(top_features)):
        # Work on a copy so the caller's DataFrame keeps its original 'Process' column
        data = top_features[key].copy()
        score_col = str(key)

        # Fix the bar order to the input row order; pd.unique keeps first-occurrence order
        # and avoids the "categories must be unique" error on duplicated process names
        data["Process"] = pd.Categorical(data["Process"], categories=pd.unique(data["Process"]), ordered=True)

        # Create plot
        plot = (
            pn.ggplot(data, pn.aes(x="Process", y=score_col, fill=score_col))
            + pn.geom_bar(stat="identity")
            + pn.labs(
                title=f"Enrichment at archetype {key}",
                x="Feature",
                y="Enrichment score",
                fill="Enrichment score",
            )
            + pn.theme_matplotlib()
            + pn.theme(figure_size=(15, 5))
            + pn.coord_flip()
            # Diverging fill centered on 0: negative scores blue, positive red
            + pn.scale_fill_gradient2(
                low="blue",
                mid="lightgrey",
                high="red",
                midpoint=0,
            )
        )
        if show:
            plot.show()
        plots.append(plot)

    # Return the list of plots
    return plots
def barplot_enrichment_comparison(specific_processes_arch: pd.DataFrame):
    """
    Draw a grouped (dodged) bar plot comparing enrichment scores of selected features across archetypes.

    Features are ordered by decreasing `specificity`. Every column other than `Process` and
    `specificity` is treated as one archetype's enrichment values.

    Parameters
    ----------
    specific_processes_arch : pd.DataFrame
        Output from `extract_specific_processes`. Must contain a 'Process' column, a 'specificity' score,
        and one column per archetype with enrichment values.

    Returns
    -------
    plotnine.ggplot.ggplot
        Horizontal grouped bar plot of enrichment scores per feature and archetype.
    """
    ranked = specific_processes_arch.sort_values("specificity", ascending=False)
    ordered_processes = ranked["Process"].to_list()

    without_specificity = specific_processes_arch.drop(columns="specificity")
    archetype_cols = without_specificity.drop(columns="Process").columns.to_list()

    # Long format: one row per (Process, Archetype) pair
    long_df = without_specificity.melt(
        id_vars=["Process"],
        value_vars=archetype_cols,
        var_name="Archetype",
        value_name="Enrichment",
    )
    long_df["Process"] = pd.Categorical(long_df["Process"], categories=ordered_processes)

    plot = pn.ggplot(long_df, pn.aes(x="Process", y="Enrichment", fill="factor(Archetype)"))
    plot += pn.geom_bar(stat="identity", position=pn.position_dodge())
    plot += pn.theme_matplotlib()
    plot += pn.scale_fill_brewer(type="qual", palette="Dark2")
    plot += pn.labs(x="Process", y="Enrichment score", fill="Archetype", title="Enrichment Comparison")
    plot += pn.theme(figure_size=(10, 5))
    plot += pn.coord_flip()
    return plot
def radarplot_meta_enrichment(meta_enrich: pd.DataFrame):
    """
    Draw one radar plot per metadata category showing its enrichment across all archetypes.

    The subplots are arranged in a grid with two columns on a new matplotlib figure
    (about 1000 x 1000 pixels at 96 dpi). The radial axis is fixed to the range [0, 1].

    Parameters
    ----------
    meta_enrich: pd.DataFrame
        Output of meta_enrichment(), a pd.DataFrame containing the enrichment of meta categories (columns)
        for all archetypes (rows). The column axis may be named or unnamed.

    Returns
    -------
    matplotlib.pyplot
        The `matplotlib.pyplot` module, whose current figure holds the radar plots
        (e.g. call `.show()` or `.savefig(...)` on it).
    """
    # Transpose so each row is one meta category and each remaining column is one archetype.
    # rename_axis guarantees the "Meta_feature" column even if the original columns had a name.
    meta_enrich = meta_enrich.T.rename_axis("Meta_feature").reset_index()

    # Loop-invariant layout, shared by every subplot
    archetype_columns = list(meta_enrich)[1:]
    n_archetypes = len(archetype_columns)
    angles = [n / float(n_archetypes) * 2 * np.pi for n in range(n_archetypes)]
    angles += angles[:1]  # close the polygon
    archetype_labels = [f"A{i}" for i in range(n_archetypes)]
    n_rows = int(np.ceil(len(meta_enrich) / 2))

    # Initialize the figure
    my_dpi = 96
    plt.figure(figsize=(1000 / my_dpi, 1000 / my_dpi), dpi=my_dpi)

    # Create a color palette
    my_palette = plt.colormaps.get_cmap("Dark2")

    for row in range(len(meta_enrich.index)):
        color = my_palette(row)

        # Initialise the radar plot
        ax = plt.subplot(n_rows, 2, row + 1, polar=True)

        # Put first axis on top, going clockwise
        ax.set_theta_offset(np.pi / 2)
        ax.set_theta_direction(-1)

        # One spoke per archetype, labelled A0, A1, ...
        plt.xticks(angles[:-1], archetype_labels, color="grey", size=8)

        # Draw ylabels
        ax.set_rlabel_position(0)
        plt.yticks(
            [0, 0.25, 0.5, 0.75, 1],
            ["0", "0.25", "0.50", "0.75", "1.0"],
            color="grey",
            size=7,
        )
        plt.ylim(0, 1)

        # Draw plot (repeat the first value to close the polygon)
        values = meta_enrich.loc[row].drop("Meta_feature").values.flatten().tolist()
        values += values[:1]
        ax.plot(angles, values, color=color, linewidth=2, linestyle="solid")
        ax.fill(angles, values, color=color, alpha=0.4)

        # Add a title
        plt.title(f"Feature: {meta_enrich['Meta_feature'][row]}", size=11, color=color, y=1.065)

    return plt