Coverage for ParTIpy/paretoti_funcs.py: 13%
206 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-16 10:20 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-16 10:20 +0100
1import numpy as np
2import pandas as pd
3import plotly.express as px
4import plotly.graph_objects as go
5import plotnine as pn
6import scanpy as sc
7from joblib import Parallel, delayed
8from scipy.optimize import linear_sum_assignment
9from scipy.spatial import ConvexHull
10from scipy.spatial.distance import cdist
11from tqdm import tqdm
13from .arch import AA
14from .const import DEFAULT_INIT, DEFAULT_OPTIM
16####TODO####
17# add/fix t-ratio function
18# Function mean archetype variance for different n_archetypes
19############
def reduce_pca(adata: sc.AnnData, n_pcs: int) -> None:
    """
    Keep only the first `n_pcs` principal components of `adata.obsm["X_pca"]`.

    If `adata.obsm["X_pca"]` does not exist, PCA is computed first with
    `sc.pp.pca(adata, mask_var="highly_variable")`. The truncated representation is
    stored in `adata.obsm["X_pca_reduced"]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing single-cell data.
    n_pcs : int
        The number of principal components (PCs) to retain. Must be at least 1 and at most
        the number of available PCs in `adata.obsm["X_pca"]`.

    Returns
    -------
    None
        The results are stored in `adata.obsm["X_pca_reduced"]`.

    Raises
    ------
    ValueError
        If `n_pcs` is smaller than 1 or larger than the number of available PCs.
    """
    # A negative `n_pcs` would otherwise be interpreted by the slice below as
    # "drop the last |n_pcs| columns" and silently yield the wrong number of PCs.
    if n_pcs < 1:
        raise ValueError(f"`n_pcs` must be at least 1, but got {n_pcs}.")

    if "X_pca" not in adata.obsm:
        print("X_pca not found in adata.obsm. Computing PCA...")
        sc.pp.pca(adata, mask_var="highly_variable")

    n_available = adata.obsm["X_pca"].shape[1]
    if n_pcs > n_available:
        raise ValueError(f"Requested {n_pcs} PCs, but only {n_available} PCs are available.")

    adata.obsm["X_pca_reduced"] = adata.obsm["X_pca"][:, :n_pcs]
def var_explained_aa(
    adata: sc.AnnData,
    min_a: int = 2,
    max_a: int = 10,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    n_jobs: int = -1,
) -> None:
    """
    Compute the variance explained by Archetypal Analysis (AA) for a range of archetypes.

    AA is fitted once for every number of archetypes from `min_a` to `max_a` (inclusive) on
    `adata.obsm["X_pca_reduced"]`. If the reduced PCA representation is not available, the
    full PCA representation (`adata.obsm["X_pca"]`) is used instead. The results are stored
    in `adata.uns["AA_var"]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing `adata.obsm["X_pca_reduced"]` or `adata.obsm["X_pca"]`.
    min_a : int, optional (default=2)
        Minimum number of archetypes to test.
    max_a : int, optional (default=10)
        Maximum number of archetypes to test.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    n_jobs : int, optional (default=-1)
        Number of jobs for parallel computation. `-1` uses all available cores.

    Returns
    -------
    None
        The results are stored in `adata.uns["AA_var"]` as a DataFrame with the columns:
        - `k`: The number of archetypes.
        - `varexpl`: The variance explained by the model.
        - `varexpl_ontop`: The additional variance explained compared to the model with
          `k-1` archetypes (for the first `k` this is the variance explained itself).
        - `dist_to_projected`: The distance between the point (k, varexpl) and its orthogonal
          projection onto the line connecting the points of the first and the last `k`.

    Raises
    ------
    ValueError
        If `min_a < 2`, if `max_a < min_a`, or if neither `X_pca_reduced` nor `X_pca` is
        present in `adata.obsm`.
    """
    if min_a < 2:
        raise ValueError("`min_a` must be at least 2.")
    if max_a < min_a:
        raise ValueError("`max_a` must be greater than or equal to `min_a`.")

    if "X_pca_reduced" in adata.obsm:
        X = adata.obsm["X_pca_reduced"]
    elif "X_pca" in adata.obsm:
        print("No reduced PCA found. Calculating with all available PCs from X_pca")
        X = adata.obsm["X_pca"]
    else:
        # Same message and exception type as bootstrap_aa uses for this situation.
        raise ValueError("Neither `X_pca_reduced` nor `X_pca` found in `adata.obsm`. Please compute PCA first.")

    k_arr = np.arange(min_a, max_a + 1)

    # One independent AA fit per k, run in parallel.
    def compute_aa(k):
        A, B, Z, RSS, varexpl = AA(n_archetypes=k, optim=optim, init=init).fit(X).return_all()
        return k, {"Z": Z, "A": A, "B": B, "RSS": RSS, "varexpl": varexpl}

    results_list = Parallel(n_jobs=n_jobs)(delayed(compute_aa)(k) for k in k_arr)
    results = dict(results_list)

    varexpl_values = np.array([results[k]["varexpl"] for k in k_arr])

    plot_df = pd.DataFrame(
        {
            "k": k_arr,
            "varexpl": varexpl_values,
            "varexpl_ontop": np.insert(np.diff(varexpl_values), 0, varexpl_values[0]),
        }
    )

    # Orthogonal projection of every (k, varexpl) point onto the line through the first and
    # the last point.
    points = plot_df[["k", "varexpl"]].to_numpy(dtype=float)
    start = points[0]
    direction = points[-1] - start
    sq_len = direction @ direction
    if sq_len > 0:
        t = (points - start) @ direction / sq_len
        projected = start + np.outer(t, direction)
    else:
        # Only a single k was tested (min_a == max_a): the "line" degenerates to one point,
        # so the distance is zero instead of a singular-matrix error.
        projected = np.tile(start, (len(points), 1))
    plot_df["dist_to_projected"] = np.linalg.norm(points - projected, axis=1)

    adata.uns["AA_var"] = plot_df
def plot_var_explained_aa(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Generate an elbow plot of the variance explained by Archetypal Analysis (AA).

    The plot shows the variance explained by AA models with different numbers of archetypes,
    together with a gray reference line connecting the points of the smallest and the largest
    `k`. The data is read from `adata.uns["AA_var"]`; if it is missing, `var_explained_aa`
    is called with its default arguments first.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing the variance explained data in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        A ggplot object showing the variance explained plot.
    """
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    plot_df = adata.uns["AA_var"]

    # Endpoints of the reference line: the actual (k, varexpl) points of the smallest and
    # largest k. Using min/max of `varexpl` instead would draw a different line whenever the
    # curve is not monotone, and would disagree with how `dist_to_projected` is defined.
    diag_data = plot_df.sort_values("k").iloc[[0, -1]][["k", "varexpl"]]

    p = (
        pn.ggplot(plot_df)
        + pn.geom_line(mapping=pn.aes(x="k", y="varexpl"), color="black")
        + pn.geom_point(mapping=pn.aes(x="k", y="varexpl"), color="black")
        + pn.geom_line(data=diag_data, mapping=pn.aes(x="k", y="varexpl"), color="gray")
        + pn.labs(x="Number of Archetypes (k)", y="Variance Explained")
        + pn.lims(y=[0, 1])
        + pn.scale_x_continuous(breaks=np.arange(plot_df["k"].min(), plot_df["k"].max() + 1))
        + pn.theme_matplotlib()
    )
    return p
def plot_projected_dist(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Bar plot of `dist_to_projected` for every tested number of archetypes.

    The values are read from `adata.uns["AA_var"]`. When that entry is missing,
    `var_explained_aa` is run with its default arguments to create it.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing the variance explained data in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        A ggplot object showing the projected distance plot.
    """
    # Lazily compute the table if the user has not done so yet.
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    aa_var = adata.uns["AA_var"]

    # One axis tick per integer k in the tested range.
    k_ticks = np.arange(aa_var["k"].min(), aa_var["k"].max() + 1)

    bars = pn.geom_col(mapping=pn.aes(x="k", y="dist_to_projected"))
    x_scale = pn.scale_x_continuous(breaks=k_ticks)
    axis_labels = pn.labs(x="Number of Archetypes (k)", y="Distance to Projected Point")

    return pn.ggplot(aa_var) + bars + x_scale + axis_labels + pn.theme_matplotlib()
def plot_var_on_top(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Plot the additional variance explained when going from `k-1` to `k` archetypes.

    The values are read from the `varexpl_ontop` column of `adata.uns["AA_var"]`. When that
    entry is missing, `var_explained_aa` is run with its default arguments to create it.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing the variance explained data in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        A ggplot object showing the variance explained on top of the (k-1) model.
    """
    # Lazily compute the table if the user has not done so yet.
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    aa_var = adata.uns["AA_var"]

    # One axis tick per integer k in the tested range.
    k_ticks = np.arange(aa_var["k"].min(), aa_var["k"].max() + 1)
    gain_aes = pn.aes(x="k", y="varexpl_ontop")

    markers = pn.geom_point(gain_aes, color="black")
    curve = pn.geom_line(gain_aes, color="black")
    axis_labels = pn.labs(x="Number of Archetypes (k)", y="Variance Explained on Top of (k-1) Model")
    x_scale = pn.scale_x_continuous(breaks=k_ticks)
    # Lower limit fixed at zero, upper limit left to the data.
    y_limits = pn.lims(y=(0, None))

    return pn.ggplot(aa_var) + markers + curve + axis_labels + x_scale + y_limits + pn.theme_matplotlib()
def bootstrap_aa(
    adata: sc.AnnData,
    n_bootstrap: int,
    n_archetypes: int,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    seed: int = 42,
) -> None:
    """
    Perform bootstrap sampling to compute archetypes and assess their stability.

    Reference archetypes are fitted on the full data. Then `n_bootstrap` samples are drawn
    with replacement, archetypes are fitted on each sample and reordered with
    `align_archetypes` to match the reference. The results are stored in
    `adata.uns["AA_bootstrap"]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object. The PCA-reduced data should be stored in `adata.obsm["X_pca_reduced"]`.
        If not found, the full PCA representation (`adata.obsm["X_pca"]`) is used.
    n_bootstrap : int
        The number of bootstrap samples to generate. Must be at least 1.
    n_archetypes : int
        The number of archetypes to compute for each bootstrap sample.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    seed : int, optional (default=42)
        Seed of the random generator used to draw the bootstrap indices.

    Returns
    -------
    None
        The results are stored in `adata.uns["AA_bootstrap"]` as a DataFrame with the columns:
        - `pc_i`: The coordinates of the archetypes in the i-th principal component.
        - `archetype`: The archetype index (categorical).
        - `iter`: The bootstrap iteration index (0 for the reference archetypes).
        - `reference`: A boolean indicating whether the archetype is from the reference model.
        - `mean_variance`: The variance of the aligned archetype coordinates across bootstrap
          samples, averaged over features and archetypes (same value in every row).

    Raises
    ------
    ValueError
        If `n_bootstrap < 1`, or if neither `X_pca_reduced` nor `X_pca` is in `adata.obsm`.
    """
    # Without at least one bootstrap sample there is nothing to stack further below.
    if n_bootstrap < 1:
        raise ValueError(f"`n_bootstrap` must be at least 1, but got {n_bootstrap}.")

    if "X_pca_reduced" not in adata.obsm:
        if "X_pca" not in adata.obsm:
            raise ValueError("Neither `X_pca_reduced` nor `X_pca` found in `adata.obsm`. Please compute PCA first.")
        print("No reduced PCA found. Calculating with all available PCs from `X_pca`.")
        X = adata.obsm["X_pca"]
    else:
        X = adata.obsm["X_pca_reduced"]

    n_samples, n_features = X.shape
    rng = np.random.default_rng(seed)
    pc_columns = [f"pc_{i}" for i in range(n_features)]

    # Reference archetypes fitted on the complete data set
    ref_Z = AA(n_archetypes=n_archetypes, optim=optim, init=init).fit(X).Z

    # One row of resampled observation indices per bootstrap iteration
    idx_bootstrap = rng.choice(n_samples, size=(n_bootstrap, n_samples), replace=True)
    Z_list = [AA(n_archetypes=n_archetypes, optim=optim, init=init).fit(X[idx, :]).Z for idx in idx_bootstrap]

    # Reorder every bootstrap solution so that archetype i corresponds to reference archetype i.
    # align_archetypes does not modify its inputs, so no defensive copies are needed.
    Z_list = [align_archetypes(ref_arch=ref_Z, query_arch=query_Z) for query_Z in Z_list]

    # Variance across bootstrap iterations, averaged over features and then over archetypes
    Z_stack = np.stack(Z_list)
    var_per_archetype = Z_stack.var(axis=0).mean(axis=1)
    mean_variance = var_per_archetype.mean()

    # Bootstrap iterations are numbered from 1; iteration 0 is reserved for the reference.
    bootstrap_data = [
        pd.DataFrame(Z, columns=pc_columns).assign(archetype=np.arange(n_archetypes), iter=i + 1)
        for i, Z in enumerate(Z_list)
    ]
    bootstrap_df = pd.concat(bootstrap_data)

    ref_df = pd.DataFrame(ref_Z, columns=pc_columns)
    ref_df["archetype"] = np.arange(n_archetypes)
    ref_df["iter"] = 0

    bootstrap_df = pd.concat((bootstrap_df, ref_df), axis=0)
    bootstrap_df["reference"] = bootstrap_df["iter"] == 0
    bootstrap_df["archetype"] = pd.Categorical(bootstrap_df["archetype"])
    bootstrap_df["mean_variance"] = mean_variance

    adata.uns["AA_bootstrap"] = bootstrap_df
def plot_bootstrap_aa(adata: sc.AnnData) -> go.Figure:
    """
    Create an interactive 3D scatter plot of the bootstrapped archetype positions.

    The data is read from `adata.uns["AA_bootstrap"]`. The first three principal components
    (`pc_0`, `pc_1`, `pc_2`) are used as axes, points are colored by archetype and the symbol
    distinguishes reference from bootstrap archetypes.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object containing the archetype bootstrap data in
        `adata.uns["AA_bootstrap"]`.

    Returns
    -------
    go.Figure
        3D plot of bootstrap results for the archetypes.

    Raises
    ------
    ValueError
        If `adata.uns["AA_bootstrap"]` is missing or contains fewer than three PC columns.
    """
    if "AA_bootstrap" not in adata.uns:
        raise ValueError("AA_bootstrap not found in adata.uns. Please run bootstrap_aa() to compute")

    bootstrap_df = adata.uns["AA_bootstrap"]

    # The plot is hard-wired to three axes, so fail with a clear message if the bootstrap was
    # run on fewer than three PCs.
    missing = [col for col in ("pc_0", "pc_1", "pc_2") if col not in bootstrap_df.columns]
    if missing:
        raise ValueError(f"AA_bootstrap needs at least 3 PCs for a 3D plot; missing columns: {missing}")

    fig = px.scatter_3d(
        bootstrap_df,
        x="pc_0",
        y="pc_1",
        z="pc_2",
        color="archetype",
        symbol="reference",
        labels={
            "pc_0": "PC 1",
            "pc_1": "PC 2",
            "pc_2": "PC 3",
        },
        title="Archetypes on bootstrapped data",
        size_max=10,
        hover_data=["iter", "archetype", "reference"],
        opacity=0.5,
    )
    fig.update_layout(template="none")

    return fig
def project_on_affine_subspace(X, Z):
    """
    Project the points in X onto the affine subspace spanned by the vertices Z.

    The first vertex `Z[:, 0]` is used as origin and the differences of the remaining
    vertices to it as basis vectors. The returned coordinates are the least-squares
    coefficients of `X - Z[:, 0]` with respect to that basis.

    Parameters
    ----------
    X : numpy.ndarray
        A (D x n) array of n points in D-dimensional space to be projected.
    Z : numpy.ndarray
        A (D x k) array of k vertices (archetypes) defining the affine subspace in
        D-dimensional space.

    Returns
    -------
    proj_coord : numpy.ndarray
        A ((k-1) x n) array with the coordinates of the projected points in the basis
        `Z[:, 1:] - Z[:, 0]`.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the basis vectors are linearly dependent (the vertices are affinely dependent).
    """
    origin = Z[:, [0]]

    # Basis of the subspace relative to the first vertex; for k == 2 this is a single column,
    # so no special case is required.
    basis = Z[:, 1:] - origin

    # Solve the normal equations (B^T B) c = B^T (X - origin) instead of forming the explicit
    # inverse of B^T B, which is less accurate.
    return np.linalg.solve(basis.T @ basis, basis.T @ (X - origin))
def _hull_volume(points):
    """Return the convex hull volume of `points` (n x d); in 1-D this is the covered length."""
    if points.shape[1] == 1:
        # Qhull does not support 1-D input; the hull of points on a line is an interval.
        return points.max() - points.min()
    return ConvexHull(points).volume


def compute_t_ratio(X, Z=None):
    """
    Compute the ratio of the volume of the polytope defined by Z to the volume of the convex hull of X.

    If there are fewer archetypes than `D + 1`, both the data and the archetypes are first
    projected onto the affine subspace spanned by the archetypes, and the volumes are computed
    in that (k-1)-dimensional subspace.

    Parameters
    ----------
    X : numpy.ndarray or sc.AnnData
        Either a (n x D) array of n data points in D-dimensional space, or an AnnData object
        providing:
        - `adata.obsm["X_pca_reduced"]`: the (n x D) data array.
        - `adata.uns["archetypal_analysis"]["Z"]`: the (k x D) archetype array.
    Z : numpy.ndarray, optional
        A (k x D) array of k archetypes. Required when `X` is a numpy array; ignored otherwise.

    Returns
    -------
    float or None
        When `X` is a numpy array, the t-ratio is returned. Otherwise nothing is returned and
        the t-ratio is stored in `adata.uns["t_ratio"]`.

    Raises
    ------
    ValueError
        If `X` is a numpy array and `Z` is not given, or if fewer than 2 archetypes are given.
    """
    adata = None
    if isinstance(X, np.ndarray):
        if Z is None:
            raise ValueError("Z must be provided when X is a numpy.ndarray.")
    else:
        adata = X
        X = adata.obsm["X_pca_reduced"]
        Z = adata.uns["archetypal_analysis"]["Z"]

    # Number of dimensions (PCs) and number of archetypes
    D, k = X.shape[1], Z.shape[0]

    if k < 2:
        raise ValueError("k must satisfy 2 <= k, meaning you need at least 2 archetypes.")

    if k < D + 1:
        # Fewer archetypes than needed for a full-dimensional polytope: work in the
        # (k-1)-dimensional affine subspace spanned by the archetypes.
        hull_points = project_on_affine_subspace(X.T, Z.T).T
        polytope_points = project_on_affine_subspace(Z.T, Z.T).T
    else:
        hull_points, polytope_points = X, Z

    t_ratio = _hull_volume(polytope_points) / _hull_volume(hull_points)

    # `adata` is only set when the caller passed a container object instead of an array.
    if adata is not None:
        adata.uns["t_ratio"] = t_ratio
    else:
        return t_ratio
def t_ratio_significance(adata, iter=1000, seed=42, n_jobs=-1):
    """
    Assess the significance of the polytope spanned by the archetypes.

    The t-ratio of the original data is compared to t-ratios computed from randomized
    datasets, in which every feature (column) is permuted independently.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object containing `adata.obsm["X_pca_reduced"]`,
        `adata.uns["archetypal_analysis"]["Z"]` and optionally `adata.uns["t_ratio"]`.
        If the t-ratio does not exist, it is computed with `compute_t_ratio`.
    iter : int, optional (default=1000)
        Number of randomized datasets to generate. Must be at least 1.
    seed : int, optional (default=42)
        The random seed for reproducibility.
    n_jobs : int, optional (default=-1)
        Number of jobs for parallelization. Use -1 to use all available cores.

    Returns
    -------
    float
        The proportion of randomized datasets with a t-ratio greater than the original
        t-ratio (p-value).

    Raises
    ------
    ValueError
        If `iter < 1` or if `adata.obsm["X_pca_reduced"]` is missing.
    """
    if iter < 1:
        raise ValueError(f"`iter` must be at least 1, but got {iter}.")
    if "X_pca_reduced" not in adata.obsm:
        raise ValueError("adata.obsm['X_pca_reduced'] not found.")
    if "t_ratio" not in adata.uns:
        print("Computing t-ratio...")
        compute_t_ratio(adata)

    X = adata.obsm["X_pca_reduced"]
    t_ratio = adata.uns["t_ratio"]
    n_features = X.shape[1]
    n_archetypes = adata.uns["archetypal_analysis"]["Z"].shape[0]

    # One independent child seed per randomization. A single Generator captured in the closure
    # would be copied into every worker process with the same state, so different tasks would
    # produce identical "random" permutations.
    child_seeds = np.random.SeedSequence(seed).spawn(iter)

    def compute_randomized_t_ratio(child_seed):
        rng = np.random.default_rng(child_seed)
        # Shuffle each feature independently
        shuffled = np.array([rng.permutation(X[:, i]) for i in range(n_features)]).T
        # Compute archetypes and t-ratio for the randomized data
        Z_mix = AA(n_archetypes=n_archetypes).fit(shuffled).Z
        return compute_t_ratio(shuffled, Z_mix)

    rand_ratios = Parallel(n_jobs=n_jobs)(
        delayed(compute_randomized_t_ratio)(child_seed) for child_seed in tqdm(child_seeds, desc="Randomizing")
    )

    # Fraction of randomized datasets whose t-ratio exceeds the observed one
    p_value = np.sum(np.array(rand_ratios) > t_ratio) / iter
    return p_value
def plot_2D(
    X: np.ndarray | sc.AnnData,
    Z: np.ndarray | None = None,
    color_vec: np.ndarray | None = None,
) -> pn.ggplot:
    """
    Scatter the data in the first two dimensions and outline the archetype polygon.

    Parameters
    ----------
    X : Union[np.ndarray, sc.AnnData]
        Either a 2D array of shape (n_samples, n_features), or an AnnData object holding the
        data in `.obsm["X_pca_reduced"]` and the archetypes in
        `.uns["archetypal_analysis"]["Z"]`.
    Z : np.ndarray, optional
        A 2D array of shape (n_archetypes, n_features) with the archetype coordinates.
        Required if `X` is not an AnnData object.
    color_vec : np.ndarray, optional
        A 1D array of shape (n_samples,) with values used to color the data points.

    Returns
    -------
    pn.ggplot
        2D plot of X and the polygon enclosed by Z.

    Raises
    ------
    ValueError
        If the archetypes are unavailable, if X or Z has fewer than 2 columns, or if
        `color_vec` does not have one entry per data point.
    """
    # Unpack data and archetypes when an AnnData object was passed.
    if isinstance(X, sc.AnnData):
        if "archetypal_analysis" not in X.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")
        Z = X.uns["archetypal_analysis"]["Z"]
        X = X.obsm["X_pca_reduced"]

    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")

    if X.shape[1] < 2 or Z.shape[1] < 2:
        raise ValueError("Both X and Z must have at least 2 columns (PCs).")

    data_2d = X[:, :2]
    arch_2d = Z[:, :2]
    axes = pn.aes(x="x0", y="x1")

    # Sort the archetypes by their angle around the archetype centroid so that connecting
    # them in sequence traces the polygon outline.
    dx = arch_2d[:, 0] - np.mean(arch_2d[:, 0])
    dy = arch_2d[:, 1] - np.mean(arch_2d[:, 1])
    angular_order = np.argsort(np.arctan2(dy, dx))

    polygon_df = pd.DataFrame(arch_2d, columns=["x0", "x1"]).iloc[angular_order].reset_index(drop=True)
    # Repeat the first vertex at the end to close the path.
    polygon_df = pd.concat([polygon_df, polygon_df.iloc[:1]], ignore_index=True)

    points_df = pd.DataFrame(data_2d, columns=["x0", "x1"])

    if color_vec is None:
        point_layer = pn.geom_point(data=points_df, mapping=axes, color="black", alpha=0.5)
    else:
        if len(color_vec) != len(points_df):
            raise ValueError("color_vec must have the same length as X.")
        points_df["color_vec"] = np.array(color_vec)
        point_layer = pn.geom_point(data=points_df, mapping=pn.aes(x="x0", y="x1", color="color_vec"), alpha=0.5)

    figure = pn.ggplot()
    figure += point_layer
    figure += pn.geom_point(data=polygon_df, mapping=axes, color="red", size=1)
    figure += pn.geom_path(data=polygon_df, mapping=axes, color="red", size=1)
    figure += pn.labs(x="PC 1", y="PC 2")
    figure += pn.theme_matplotlib()

    return figure
def plot_3D(
    X: np.ndarray | sc.AnnData,
    Z: np.ndarray | None = None,
    color_vec: np.ndarray | None = None,
    marker_size: int = 4,
    color_polyhedron: str = "green",
) -> go.Figure:
    """
    3D plot of the datapoints in X and the 3D polytope enclosed by the archetypes in Z.

    Only the first three columns of X and Z are plotted. With four or more archetypes the
    faces and edges of their convex hull are drawn; with three archetypes a single triangle
    is drawn, and with two archetypes the connecting line.

    Parameters
    ----------
    X : Union[np.ndarray, sc.AnnData]
        The input data, which can be either:
        - A 2D array of shape (n_samples, n_features) representing the data points.
        - An AnnData object containing the PCA-reduced data in `.obsm["X_pca_reduced"]` and
          archetypes in `.uns["archetypal_analysis"]["Z"]`.
    Z : np.ndarray, optional
        A 2D array of shape (n_archetypes, n_features) representing the archetype coordinates.
        Required if `X` is not an AnnData object.
    color_vec : np.ndarray, optional
        A 1D array of shape (n_samples,) containing values for coloring the data points in `X`.
    marker_size : int, optional (default=4)
        The size of the markers for the data points in `X`.
    color_polyhedron : str, optional (default="green")
        The color of the polytope (convex hull) defined by the archetypes.

    Returns
    -------
    go.Figure
        3D plot of X and polytope enclosed by Z

    Raises
    ------
    ValueError
        If the archetypes are unavailable, if X or Z has fewer than 3 columns, or if
        `color_vec` does not have one entry per data point.
    """
    if isinstance(X, sc.AnnData):
        if "archetypal_analysis" not in X.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")

        Z = X.uns["archetypal_analysis"]["Z"]
        X = X.obsm["X_pca_reduced"]

    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")

    if X.shape[1] < 3 or Z.shape[1] < 3:
        raise ValueError("Both X and Z must have at least 3 columns (PCs).")

    X_plot, Z_plot = X[:, :3], Z[:, :3]

    plot_df = pd.DataFrame(X_plot, columns=["x0", "x1", "x2"])
    plot_df["marker_size"] = np.repeat(marker_size, X_plot.shape[0])

    # Arguments shared by the colored and the uncolored scatter plot
    scatter_kwargs = {
        "x": "x0",
        "y": "x1",
        "z": "x2",
        "labels": {"x0": "PC 1", "x1": "PC 2", "x2": "PC 3"},
        "title": "3D polytope",
        "size": "marker_size",
        "size_max": 10,
        "opacity": 0.5,
    }
    if color_vec is not None:
        if len(color_vec) != len(plot_df):
            raise ValueError("color_vec must have the same length as X.")
        plot_df["color_vec"] = np.array(color_vec)
        scatter_kwargs["color"] = "color_vec"

    fig = px.scatter_3d(plot_df, **scatter_kwargs)

    # Triangular faces of the archetype polytope. Qhull needs at least 4 points to build a
    # 3D hull, so smaller archetype sets are handled explicitly instead of raising.
    n_arch = Z_plot.shape[0]
    if n_arch >= 4:
        faces = ConvexHull(Z_plot).simplices
    elif n_arch == 3:
        faces = np.array([[0, 1, 2]])
    else:
        faces = np.empty((0, 3), dtype=int)

    # Add archetypes to the plot
    archetype_labels = [f"Archetype {i}" for i in range(n_arch)]
    fig.add_trace(
        go.Scatter3d(
            x=Z_plot[:, 0],
            y=Z_plot[:, 1],
            z=Z_plot[:, 2],
            mode="markers",
            marker={"size": 4, "color": color_polyhedron, "symbol": "circle"},
            text=archetype_labels,
            hoverinfo="text",
            name="Archetypes",
        )
    )

    # Add the polytope faces to the plot
    if len(faces) > 0:
        fig.add_trace(
            go.Mesh3d(
                x=Z_plot[:, 0],
                y=Z_plot[:, 1],
                z=Z_plot[:, 2],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color=color_polyhedron,
                opacity=0.1,
            )
        )

    # Closed outline of every face; two archetypes only have the connecting segment.
    outlines = [np.append(face, face[0]) for face in faces]
    if n_arch == 2:
        outlines.append(np.array([0, 1]))

    for outline in outlines:
        fig.add_trace(
            go.Scatter3d(
                x=Z_plot[outline, 0],
                y=Z_plot[outline, 1],
                z=Z_plot[outline, 2],
                mode="lines",
                line={"color": color_polyhedron, "width": 4},
                showlegend=False,
            )
        )

    fig.update_layout(template="none")
    return fig
def align_archetypes(ref_arch: np.ndarray, query_arch: np.ndarray) -> np.ndarray:
    """
    Reorder the query archetypes so that they line up with the reference archetypes.

    A cost matrix of pairwise Euclidean distances (reference rows, query columns) is built and
    passed to `linear_sum_assignment`, which picks the pairing with minimal total distance.
    The query rows are then returned in the order of that pairing.

    Parameters
    ----------
    ref_arch : np.ndarray
        A 2D array of shape (n_archetypes, n_features) representing the reference archetypes.
    query_arch : np.ndarray
        A 2D array of shape (n_archetypes, n_features) representing the query archetypes.

    Returns
    -------
    np.ndarray
        A 2D array of shape (n_archetypes, n_features) containing the reordered query archetypes.
    """
    # Rows index reference archetypes, columns index query archetypes.
    cost = cdist(ref_arch, query_arch, metric="euclidean")

    # Only the column (query) indices are needed to reorder the query archetypes.
    _, matched_query = linear_sum_assignment(cost)

    return np.take(query_arch, matched_query, axis=0)
755# def bootstrap_variance_k_arr(X, n_bootstrap, k_arr, delta=0, seed=42, **kwargs):
756# assert k_arr.min() > 1
757# bootstrap_var = np.array(
758# [
759# bootstrap_variance_single_k(
760# X, n_bootstrap=n_bootstrap, k=k, delta=delta, seed=seed, **kwargs
761# )
762# for k in k_arr
763# ]
764# )
765# plot_df = pd.DataFrame({"k": k_arr, "var": bootstrap_var})
766# p = (
767# pn.ggplot(plot_df, pn.aes(x="k", y="var"))
768# + pn.geom_point(color="blue")
769# + pn.geom_line(color="blue")
770# + pn.labs(x="Number of Archetypes", y="Mean Variance in Archetype Position")
771# )
772# return p
775# Appendix
# Not sure whether this ratio of archetype variance over data variance is useful in any way.
778# def compute_var_ratio_vitali(X, Z):
779# # adapted from: https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L1178
780# data_var = X.var(axis=1)
781# arch_var = Z.var(axis=1)
782# return arch_var / data_var
784# see https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L257
785# archetypes = Z.A.T
786# print(archetypes.shape)
787#
788## create random matrix where each row sums to 1
789# n_additional = 100
790# np.random.seed(42)
791# rand_mtx = np.random.rand(n_additional, archetypes.shape[0])
792# rand_mtx /= rand_mtx.sum(axis=1)[:, None]
793#
794# additional_archetypes = (rand_mtx @ archetypes)
795# print(additional_archetypes.shape)
796#
797# stacked_archetypes = np.row_stack([archetypes, additional_archetypes])
798# print(stacked_archetypes.shape)
799#
800## https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L879C39-L879C46
801## Vitali used the "FA" argument, but this is turned on by default I think in the scipy version (FA - report total area and volume),
802## see http://www.qhull.org/html/qhull.htm
803# convhull_archetypes = ConvexHull(stacked_archetypes, qhull_options="QJ")
804# print(convhull_archetypes.volume)