Coverage for partipy/paretoti_funcs.py: 12%
226 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-16 12:01 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-16 12:01 +0100
1import inspect
3import numpy as np
4import pandas as pd
5import plotly.express as px
6import plotly.graph_objects as go
7import plotnine as pn
8import scanpy as sc
9from joblib import Parallel, delayed
10from scipy.optimize import linear_sum_assignment
11from scipy.spatial import ConvexHull
12from scipy.spatial.distance import cdist
13from tqdm import tqdm
15from .arch import AA
16from .const import DEFAULT_INIT, DEFAULT_OPTIM
def set_dimension(adata: sc.AnnData, n_pcs: int) -> None:
    """
    Set the number of principal components (PCs) used by the archetypal analysis functions.

    If `adata.obsm["X_pca"]` does not exist, PCA is computed with
    `sc.pp.pca(adata, mask_var="highly_variable")`. The requested number of PCs is stored in
    `adata.uns["PCs"]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing single-cell data.
    n_pcs : int
        The number of principal components (PCs) to retain. Must be at least 1 and less than or
        equal to the number of available PCs in `adata.obsm["X_pca"]`.

    Returns
    -------
    None
        The number of PCs is stored in `adata.uns["PCs"]`.

    Raises
    ------
    ValueError
        If `n_pcs` is smaller than 1 or larger than the number of available PCs.
    """
    # Reject non-positive values up front: they would otherwise be stored and later
    # produce an empty slice `X_pca[:, :n_pcs]`.
    if n_pcs < 1:
        raise ValueError(f"`n_pcs` must be a positive integer, got {n_pcs}.")

    if "X_pca" not in adata.obsm:
        print("X_pca not found in adata.obsm. Computing PCA on highly variable genes...")
        sc.pp.pca(adata, mask_var="highly_variable")

    n_available = adata.obsm["X_pca"].shape[1]
    if n_pcs > n_available:
        raise ValueError(f"Requested {n_pcs} PCs, but only {n_available} PCs are available.")

    adata.uns["PCs"] = n_pcs
def var_explained_aa(
    adata: sc.AnnData,
    min_a: int = 2,
    max_a: int = 10,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    n_jobs: int = -1,
) -> None:
    """
    Compute the variance explained by Archetypal Analysis (AA) for a range of archetypes.

    AA is fitted for every number of archetypes from `min_a` to `max_a` (inclusive) on the first
    `adata.uns["PCs"]` columns of `adata.obsm["X_pca"]`. The results are stored in
    `adata.uns["AA_var"]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing `adata.obsm["X_pca"]` and `adata.uns["PCs"]`.
    min_a : int, optional (default=2)
        Minimum number of archetypes to test.
    max_a : int, optional (default=10)
        Maximum number of archetypes to test.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    n_jobs : int, optional (default=-1)
        Number of jobs for parallel computation. `-1` uses all available cores.

    Returns
    -------
    None
        The results are stored in `adata.uns["AA_var"]` as a DataFrame with the following columns:
        - `k`: The number of archetypes.
        - `varexpl`: The variance explained by the model.
        - `varexpl_ontop`: The additional variance explained compared to the model with `k-1`
          archetypes (for the first `k` this is its own `varexpl`).
        - `dist_to_projected`: The distance between each (k, varexpl) point and its orthogonal
          projection on the line connecting the points of the first and last `k`.

    Raises
    ------
    ValueError
        If `min_a < 2`, `max_a < min_a`, or `adata.uns["PCs"]` is not set.
    """
    # Validation input
    if min_a < 2:
        raise ValueError("`min_a` must be at least 2.")
    if max_a < min_a:
        raise ValueError("`max_a` must be greater than or equal to `min_a`.")
    if "PCs" not in adata.uns:
        raise ValueError(
            "PCs not found in adata.uns. Please set the dimension for archetypal analysis with set_dimension()"
        )

    X = adata.obsm["X_pca"][:, : adata.uns["PCs"]]
    k_arr = np.arange(min_a, max_a + 1)

    # Parallel computation of AA. Only the explained variance is needed afterwards, so the
    # (potentially large) A, B and Z matrices are not sent back from the workers.
    def compute_aa(k):
        _A, _B, _Z, _RSS, varexpl = AA(n_archetypes=k, optim=optim, init=init).fit(X).return_all()
        return k, varexpl

    results = dict(Parallel(n_jobs=n_jobs)(delayed(compute_aa)(k) for k in k_arr))
    varexpl_values = np.array([results[k] for k in k_arr])

    plot_df = pd.DataFrame(
        {
            "k": k_arr,
            "varexpl": varexpl_values,
            "varexpl_ontop": np.insert(np.diff(varexpl_values), 0, varexpl_values[0]),
        }
    )

    # Distance of every (k, varexpl) point to its orthogonal projection on the chord that
    # connects the first and the last point.
    points = plot_df[["k", "varexpl"]].to_numpy(dtype=float)
    shifted = points - points[0]
    chord = shifted[-1]
    chord_sq_norm = chord @ chord
    if chord_sq_norm > 0:
        projected = np.outer(shifted @ chord, chord) / chord_sq_norm
    else:
        # Only one k was tested (min_a == max_a): the chord degenerates to a point, so the
        # projection coincides with that point and the distance is zero.
        projected = np.zeros_like(shifted)
    plot_df["dist_to_projected"] = np.linalg.norm(shifted - projected, axis=1)

    adata.uns["AA_var"] = plot_df
def plot_var_explained_aa(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Generate an elbow plot of the variance explained by Archetypal Analysis (AA) for a range of archetypes.

    The plot shows the variance explained by AA models with different numbers of archetypes,
    together with a gray reference line connecting the points of the smallest and largest `k`.
    The data is retrieved from `adata.uns["AA_var"]`. If `AA_var` is not found, `var_explained_aa`
    is called with its default arguments.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing the variance explained data in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        A ggplot object showing the variance explained plot.
    """
    # Validation input
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    plot_df = adata.uns["AA_var"]

    # Reference line between the points of the first and last k. Taking the actual endpoints
    # (rather than the min/max of `varexpl`) keeps the line on the curve even when the explained
    # variance is not strictly increasing in k.
    diag_data = plot_df.sort_values("k").iloc[[0, -1]][["k", "varexpl"]].reset_index(drop=True)

    k_breaks = list(np.arange(plot_df["k"].min(), plot_df["k"].max() + 1))

    p = (
        pn.ggplot(plot_df)
        + pn.geom_line(mapping=pn.aes(x="k", y="varexpl"), color="black")
        + pn.geom_point(mapping=pn.aes(x="k", y="varexpl"), color="black")
        + pn.geom_line(data=diag_data, mapping=pn.aes(x="k", y="varexpl"), color="gray")
        + pn.labs(x="Number of Archetypes (k)", y="Variance Explained")
        + pn.lims(y=[0, 1])
        + pn.scale_x_continuous(breaks=k_breaks)
        + pn.theme_matplotlib()
    )
    return p
def plot_projected_dist(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Bar plot of the distance-to-projection for every tested number of archetypes.

    Reads `adata.uns["AA_var"]`; when that entry is missing, `var_explained_aa` is run first
    with its default arguments.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object holding (or receiving) the variance explained table in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        A ggplot object with `k` on the x-axis and `dist_to_projected` on the y-axis.
    """
    # Make sure the table exists before plotting
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    aa_table = adata.uns["AA_var"]

    # One tick per tested k
    k_lo = aa_table["k"].min()
    k_hi = aa_table["k"].max()
    ticks = list(np.arange(k_lo, k_hi + 1))

    figure = pn.ggplot(aa_table)
    figure = figure + pn.geom_col(mapping=pn.aes(x="k", y="dist_to_projected"))
    figure = figure + pn.scale_x_continuous(breaks=ticks)
    figure = figure + pn.labs(x="Number of Archetypes (k)", y="Distance to Projected Point")
    figure = figure + pn.theme_matplotlib()
    return figure
def plot_var_on_top(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Plot the additional variance explained when going from `k-1` to `k` archetypes.

    Reads `adata.uns["AA_var"]`; when that entry is missing, `var_explained_aa` is run first
    with its default arguments.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object holding (or receiving) the variance explained table in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        A ggplot object with `k` on the x-axis and `varexpl_ontop` on the y-axis.
    """
    # Make sure the table exists before plotting
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    aa_table = adata.uns["AA_var"]

    # One tick per tested k
    k_lo = aa_table["k"].min()
    k_hi = aa_table["k"].max()
    ticks = list(np.arange(k_lo, k_hi + 1))

    on_top = pn.aes(x="k", y="varexpl_ontop")

    figure = pn.ggplot(aa_table)
    figure = figure + pn.geom_point(on_top, color="black")
    figure = figure + pn.geom_line(pn.aes(x="k", y="varexpl_ontop"), color="black")
    figure = figure + pn.labs(x="Number of Archetypes (k)", y="Variance Explained on Top of (k-1) Model")
    figure = figure + pn.scale_x_continuous(breaks=ticks)
    figure = figure + pn.lims(y=(0, None))
    figure = figure + pn.theme_matplotlib()
    return figure
def bootstrap_aa(
    adata: sc.AnnData,
    n_bootstrap: int,
    n_archetypes: int,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    seed: int = 42,
) -> None:
    """
    Perform bootstrap sampling to compute archetypes and assess their stability.

    Archetypes are first fitted on the full data (the reference). Then `n_bootstrap` samples are
    drawn with replacement, archetypes are fitted on each sample and aligned with the reference
    via `align_archetypes`. The results are stored in `adata.uns["AA_bootstrap"]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object. The PCA data should be stored in `adata.obsm["X_pca"]` and the number of
        PCs to use in `adata.uns["PCs"]`.
    n_bootstrap : int
        The number of bootstrap samples to generate. Must be at least 1.
    n_archetypes : int
        The number of archetypes to compute for each bootstrap sample.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    seed : int, optional (default=42)
        The random seed used to draw the bootstrap indices.

    Returns
    -------
    None
        The results are stored in `adata.uns["AA_bootstrap"]` as a DataFrame with the following columns:
        - `pc_i`: The coordinates of the archetypes in the i-th principal component (0-based).
        - `archetype`: The archetype index (categorical).
        - `iter`: The bootstrap iteration index (0 for the reference archetypes).
        - `reference`: A boolean indicating whether the archetype is from the reference model.
        - `mean_variance`: The variance of the aligned archetype coordinates across bootstrap
          samples, averaged over features and archetypes (same value in every row).

    Raises
    ------
    ValueError
        If `n_bootstrap < 1` or `adata.uns["PCs"]` is not set.
    """
    # Validation input
    if n_bootstrap < 1:
        # Without this check the failure only surfaces later as an opaque `np.stack` error.
        raise ValueError("`n_bootstrap` must be at least 1.")
    if "PCs" not in adata.uns:
        raise ValueError(
            "PCs not found in adata.uns. Please set the dimension for archetypal analysis with set_dimension()"
        )

    X = adata.obsm["X_pca"][:, : adata.uns["PCs"]]
    n_samples, n_features = X.shape
    rng = np.random.default_rng(seed)

    pc_columns = [f"pc_{i}" for i in range(n_features)]
    archetype_ids = np.arange(n_archetypes)

    def fit_archetypes(data):
        return AA(n_archetypes=n_archetypes, optim=optim, init=init).fit(data).Z

    def to_frame(Z, iteration):
        return pd.DataFrame(Z, columns=pc_columns).assign(archetype=archetype_ids, iter=iteration)

    # Reference archetypes on the full data
    ref_Z = fit_archetypes(X)

    # Bootstrap: resample rows with replacement, refit, and align each fit to the reference
    idx_bootstrap = rng.choice(n_samples, size=(n_bootstrap, n_samples), replace=True)
    Z_list = [align_archetypes(ref_arch=ref_Z, query_arch=fit_archetypes(X[idx, :])) for idx in idx_bootstrap]

    # Variance across bootstrap fits, averaged over features and then over archetypes
    var_per_archetype = np.stack(Z_list).var(axis=0).mean(axis=1)
    mean_variance = var_per_archetype.mean()

    # Bootstrap iterations are numbered from 1; the reference rows (iter=0) are appended last
    frames = [to_frame(Z, i + 1) for i, Z in enumerate(Z_list)]
    frames.append(to_frame(ref_Z, 0))
    bootstrap_df = pd.concat(frames, axis=0)

    bootstrap_df["reference"] = bootstrap_df["iter"] == 0
    bootstrap_df["archetype"] = pd.Categorical(bootstrap_df["archetype"])
    bootstrap_df["mean_variance"] = mean_variance

    adata.uns["AA_bootstrap"] = bootstrap_df
def plot_bootstrap_aa(adata: sc.AnnData) -> go.Figure:
    """
    Create an interactive 3D scatter plot showing the positions of archetypes
    computed from bootstrap samples, stored in `adata.uns["AA_bootstrap"]`.

    The first three principal components (`pc_0`, `pc_1`, `pc_2`) are used as axes, points are
    colored by archetype, and the marker symbol distinguishes reference from bootstrap archetypes.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object containing the archetype bootstrap data in `adata.uns["AA_bootstrap"]`.

    Returns
    -------
    go.Figure
        3D plot of bootstrap results for the archetypes.

    Raises
    ------
    ValueError
        If `adata.uns["AA_bootstrap"]` is missing or contains fewer than three PC columns.
    """
    # Validation input
    if "AA_bootstrap" not in adata.uns:
        raise ValueError("AA_bootstrap not found in adata.uns. Please run bootstrap_aa() to compute")

    bootstrap_df = adata.uns["AA_bootstrap"]

    # A 3D plot needs three coordinates; fail with a clear message instead of a plotting error
    axis_columns = ["pc_0", "pc_1", "pc_2"]
    missing = [col for col in axis_columns if col not in bootstrap_df.columns]
    if missing:
        raise ValueError(
            f"AA_bootstrap is missing the column(s) {missing}. "
            "At least 3 PCs are required for the 3D bootstrap plot."
        )

    # Generate the 3D scatter plot
    fig = px.scatter_3d(
        bootstrap_df,
        x="pc_0",
        y="pc_1",
        z="pc_2",
        color="archetype",
        symbol="reference",
        labels={
            "pc_0": "PC 1",
            "pc_1": "PC 2",
            "pc_2": "PC 3",
        },
        title="Archetypes on bootstrapped data",
        size_max=10,
        hover_data=["iter", "archetype", "reference"],
        opacity=0.5,
    )
    fig.update_layout(template="none")

    return fig
def project_on_affine_subspace(X, Z) -> np.ndarray:
    """
    Projects a set of points X onto the affine subspace spanned by the vertices Z.

    The first vertex `Z[:, 0]` is used as origin and the differences `Z[:, j] - Z[:, 0]`
    (j = 1, ..., k-1) as basis vectors. The returned values are the least-squares coordinates
    of each point with respect to that basis.

    Parameters
    ----------
    X : numpy.ndarray
        A (D x n) array of n points in D-dimensional space to be projected.
    Z : numpy.ndarray
        A (D x k) array of k vertices (archetypes) defining the affine subspace in D-dimensional space.

    Returns
    -------
    proj_coord : numpy.ndarray
        A ((k-1) x n) array with the coordinates of the projected points in the subspace defined by Z.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the basis vectors are linearly dependent (the Gram matrix is singular).
    """
    # Slicing with a list keeps the origin as a (D x 1) column so it broadcasts over columns
    origin = Z[:, [0]]

    # Basis of the subspace relative to the first vertex; `Z[:, 1:]` is already (D x 1) for k == 2,
    # so no special case is needed for a line
    basis = Z[:, 1:] - origin

    # Solve the normal equations (B^T B) c = B^T (X - origin) directly instead of forming
    # the explicit inverse of the Gram matrix
    gram = basis.T @ basis
    proj_coord = np.linalg.solve(gram, basis.T @ (X - origin))

    return proj_coord
def compute_t_ratio(
    X: sc.AnnData | np.ndarray,
    Z: np.ndarray | None = None,
) -> float | None:
    """
    Compute the t-ratio, which is the ratio of the volume of the polytope defined by the archetypes (Z)
    to the volume of the convex hull of the data points (X).

    When there are fewer archetypes than `n_features + 1`, both the data and the archetypes are
    first projected onto the affine subspace spanned by the archetypes, and the volumes are
    computed in that subspace.

    Parameters
    ----------
    X : Union[sc.AnnData, np.ndarray]
        The input data, which can be either:
        - An AnnData object containing the following attributes:
            - `adata.obsm["X_pca"]`: A 2D array of shape (n_samples, n_features) representing the PCA coordinates of the data.
            - `adata.uns["PCs"]`: The number of principal components used for AA.
            - `adata.uns["archetypal_analysis"]["Z"]`: A 2D array of shape (n_archetypes, n_features) representing the archetypes.
        - A 2D numpy array of shape (n_samples, n_features) representing the data matrix. In this case, `Z` must be provided.
    Z : np.ndarray, optional
        A 2D array of shape (n_archetypes, n_features) representing the archetypes. Required if `X` is a numpy array.

    Returns
    -------
    Optional[float]
        - If `X` is an AnnData object, the t-ratio is stored in `X.uns["t_ratio"]` and nothing is returned.
        - If `X` is a numpy array, the t-ratio is returned as a float.

    Raises
    ------
    ValueError
        If `X` is a numpy array and `Z` is missing, or if fewer than 2 archetypes are given.
    """
    adata = None
    if isinstance(X, np.ndarray):
        if Z is None:
            raise ValueError("Z must be provided when X is a numpy.ndarray.")
    else:
        adata = X
        X = adata.obsm["X_pca"][:, : adata.uns["PCs"]]
        Z = adata.uns["archetypal_analysis"]["Z"]

    # Extract dimensions D (PCs), and number of archetypes
    D, k = X.shape[1], Z.shape[0]  # type: ignore[union-attr]

    # Input validation
    if k < 2:
        raise ValueError("k must satisfy 2 <= k, meaning you need at least 2 archetypes.")

    if k < D + 1:
        # Too few archetypes to span a full-dimensional polytope: work in the
        # (k-1)-dimensional affine subspace spanned by Z
        data_points = project_on_affine_subspace(X.T, Z.T).T  # type: ignore[union-attr]
        arch_points = project_on_affine_subspace(Z.T, Z.T).T  # type: ignore[union-attr]
    else:
        data_points, arch_points = X, Z

    t_ratio = _hull_volume(arch_points) / _hull_volume(data_points)

    if adata is not None:
        adata.uns["t_ratio"] = t_ratio
        return None
    return t_ratio


def _hull_volume(points: np.ndarray) -> float:
    """Return the convex hull volume of `points` (n_points, n_dims); in 1-D, the interval length."""
    if points.shape[1] == 1:
        # ConvexHull rejects 1-D input; the hull of points on a line is the interval they span
        return float(points.max() - points.min())
    return ConvexHull(points).volume
def t_ratio_significance(adata, iter=1000, seed=42, n_jobs=-1):
    """
    Assesses the significance of the polytope spanned by the archetypes by comparing the t-ratio
    of the original data to t-ratios computed from randomized datasets.

    Each randomized dataset is created by permuting every feature (PC) independently; archetypes
    are then fitted on the randomized data and its t-ratio is computed.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object containing `adata.obsm["X_pca"]`, `adata.uns["PCs"]` and
        `adata.uns["archetypal_analysis"]["Z"]`, optionally `adata.uns["t_ratio"]`.
        If `adata.uns["t_ratio"]` does not exist, it is computed with `compute_t_ratio`.
    iter : int, optional (default=1000)
        Number of randomized datasets to generate. Must be at least 1.
    seed : int, optional (default=42)
        The random seed for reproducibility.
    n_jobs : int, optional (default=-1)
        Number of jobs for parallelization. Use -1 to use all available cores.

    Returns
    -------
    float
        The proportion of randomized datasets with a t-ratio greater than the original t-ratio (p-value).

    Raises
    ------
    ValueError
        If `adata.obsm["X_pca"]` is missing or `iter < 1`.
    """
    # Input validation
    if "X_pca" not in adata.obsm:
        raise ValueError("adata.obsm['X_pca'] not found.")
    if iter < 1:
        raise ValueError("`iter` must be at least 1.")
    if "t_ratio" not in adata.uns:
        print("Computing t-ratio...")
        compute_t_ratio(adata)

    X = adata.obsm["X_pca"][:, : adata.uns["PCs"]]
    t_ratio = adata.uns["t_ratio"]
    n_features = X.shape[1]
    n_archetypes = adata.uns["archetypal_analysis"]["Z"].shape[0]

    # One independent child seed per randomization. A single Generator captured in the closure
    # is not safe here: when the tasks are shipped to worker processes each worker gets its own
    # copy of the generator in the same state, so the "random" datasets would repeat.
    child_seeds = np.random.SeedSequence(seed).spawn(iter)

    def compute_randomized_t_ratio(child_seed):
        rng = np.random.default_rng(child_seed)
        # Shuffle each feature independently
        shuffled = np.column_stack([rng.permutation(X[:, j]) for j in range(n_features)])
        # Compute archetypes and t-ratio for randomized data
        Z_rand = AA(n_archetypes=n_archetypes).fit(shuffled).Z
        return compute_t_ratio(shuffled, Z_rand)

    # Parallelize the computation of randomized t-ratios
    rand_ratios = Parallel(n_jobs=n_jobs)(
        delayed(compute_randomized_t_ratio)(child_seed) for child_seed in tqdm(child_seeds, desc="Randomizing")
    )

    # Calculate the p-value
    p_value = np.sum(np.array(rand_ratios) > t_ratio) / iter
    return p_value
def plot_2D(
    X: np.ndarray | sc.AnnData,
    Z: np.ndarray | None = None,
    color_vec: np.ndarray | None = None,
) -> pn.ggplot:
    """
    Scatter the data in its first two dimensions and draw the polygon formed by the archetypes.

    Parameters
    ----------
    X : Union[np.ndarray, sc.AnnData]
        Either a 2D array of shape (n_samples, n_features), or an AnnData object providing
        `X.obsm["X_pca"]`, `X.uns["PCs"]` and `X.uns["archetypal_analysis"]["Z"]`.
    Z : np.ndarray, optional
        Archetype coordinates of shape (n_archetypes, n_features). Required when `X` is an array;
        taken from the AnnData object otherwise.
    color_vec : np.ndarray, optional
        Values of length n_samples used to color the data points.

    Returns
    -------
    pn.ggplot
        2D plot of the data points and the closed archetype polygon.
    """
    # Unpack AnnData input into plain arrays
    if isinstance(X, sc.AnnData):
        if "archetypal_analysis" not in X.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")
        Z = X.uns["archetypal_analysis"]["Z"]
        X = X.obsm["X_pca"][:, : X.uns["PCs"]]

    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")

    if X.shape[1] < 2 or Z.shape[1] < 2:
        raise ValueError("Both X and Z must have at least 2 columns (PCs).")

    data_2d = X[:, :2]
    arch_2d = Z[:, :2]
    points_df = pd.DataFrame(data_2d, columns=["x0", "x1"])

    # Sort the archetypes by their angle around the centroid so the outline does not cross itself
    dy = arch_2d[:, 1] - np.mean(arch_2d[:, 1])
    dx = arch_2d[:, 0] - np.mean(arch_2d[:, 0])
    by_angle = np.argsort(np.arctan2(dy, dx))

    outline_df = pd.DataFrame(arch_2d, columns=["x0", "x1"]).iloc[by_angle].reset_index(drop=True)
    # Repeat the first vertex at the end to close the polygon
    outline_df = pd.concat([outline_df, outline_df.iloc[:1]], ignore_index=True)

    figure = pn.ggplot()

    if color_vec is None:
        figure = figure + pn.geom_point(data=points_df, mapping=pn.aes(x="x0", y="x1"), color="black", alpha=0.5)
    else:
        if len(color_vec) != len(points_df):
            raise ValueError("color_vec must have the same length as X.")
        points_df["color_vec"] = np.array(color_vec)
        figure = figure + pn.geom_point(data=points_df, mapping=pn.aes(x="x0", y="x1", color="color_vec"), alpha=0.5)

    figure = figure + pn.geom_point(data=outline_df, mapping=pn.aes(x="x0", y="x1"), color="red", size=1)
    figure = figure + pn.geom_path(data=outline_df, mapping=pn.aes(x="x0", y="x1"), color="red", size=1)
    figure = figure + pn.labs(x="PC 1", y="PC 2")
    figure = figure + pn.theme_matplotlib()

    return figure
def plot_3D(
    X: np.ndarray | sc.AnnData,
    Z: np.ndarray | None = None,
    color_vec: np.ndarray | None = None,
    marker_size: int = 4,
    color_polyhedron: str = "green",
) -> go.Figure:
    """
    Scatter the data in its first three dimensions and draw the polytope formed by the archetypes.

    Parameters
    ----------
    X : Union[np.ndarray, sc.AnnData]
        Either a 2D array of shape (n_samples, n_features), or an AnnData object providing
        `X.obsm["X_pca"]`, `X.uns["PCs"]` and `X.uns["archetypal_analysis"]["Z"]`.
    Z : np.ndarray, optional
        Archetype coordinates of shape (n_archetypes, n_features). Required when `X` is an array;
        taken from the AnnData object otherwise.
    color_vec : np.ndarray, optional
        Values of length n_samples used to color the data points.
    marker_size : int, optional (default=4)
        Value written to the `marker_size` column that drives the size of the data markers.
    color_polyhedron : str, optional (default="green")
        Color used for the archetype markers, the hull faces and the hull edges.

    Returns
    -------
    go.Figure
        3D plot of the data points and the convex hull of the archetypes.
    """
    # Unpack AnnData input into plain arrays
    if isinstance(X, sc.AnnData):
        if "archetypal_analysis" not in X.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")
        Z = X.uns["archetypal_analysis"]["Z"]
        X = X.obsm["X_pca"][:, : X.uns["PCs"]]

    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")

    if X.shape[1] < 3 or Z.shape[1] < 3:
        raise ValueError("Both X and Z must have at least 3 columns (PCs).")

    data_3d = X[:, :3]
    arch_3d = Z[:, :3]

    points_df = pd.DataFrame(data_3d, columns=["x0", "x1", "x2"])
    points_df["marker_size"] = np.repeat(marker_size, data_3d.shape[0])

    # Arguments shared by the colored and the uncolored scatter
    scatter_kwargs = {
        "x": "x0",
        "y": "x1",
        "z": "x2",
        "labels": {"x0": "PC 1", "x1": "PC 2", "x2": "PC 3"},
        "title": "3D polytope",
        "size": "marker_size",
        "size_max": 10,
        "opacity": 0.5,
    }
    if color_vec is not None:
        if len(color_vec) != len(points_df):
            raise ValueError("color_vec must have the same length as X.")
        points_df["color_vec"] = np.array(color_vec)
        scatter_kwargs["color"] = "color_vec"

    fig = px.scatter_3d(points_df, **scatter_kwargs)

    # Convex hull of the archetypes; its simplices are the triangular faces
    hull = ConvexHull(arch_3d)
    faces = hull.simplices

    # Archetype markers with hover labels
    hover_labels = [f"Archetype {idx}" for idx in range(arch_3d.shape[0])]
    archetype_trace = go.Scatter3d(
        x=arch_3d[:, 0],
        y=arch_3d[:, 1],
        z=arch_3d[:, 2],
        mode="markers",
        text=hover_labels,
        marker={"size": 4, "color": color_polyhedron, "symbol": "circle"},
        hoverinfo="text",
        name="Archetypes",
    )
    fig.add_trace(archetype_trace)

    # Semi-transparent faces of the polytope
    face_trace = go.Mesh3d(
        x=arch_3d[:, 0],
        y=arch_3d[:, 1],
        z=arch_3d[:, 2],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        color=color_polyhedron,
        opacity=0.1,
    )
    fig.add_trace(face_trace)

    # Outline every face; the first vertex is repeated so each triangle is closed
    for face in faces:
        loop = np.append(face, face[0])
        edge_trace = go.Scatter3d(
            x=arch_3d[loop, 0],
            y=arch_3d[loop, 1],
            z=arch_3d[loop, 2],
            mode="lines",
            line={"color": color_polyhedron, "width": 4},
            showlegend=False,
        )
        fig.add_trace(edge_trace)

    fig.update_layout(template="none")
    return fig
def align_archetypes(ref_arch: np.ndarray, query_arch: np.ndarray) -> np.ndarray:
    """
    Align the archetypes of the query to match the order of archetypes in the reference.

    This function uses the Euclidean distance between archetypes in the reference and query sets
    to determine the optimal alignment. The Hungarian algorithm (linear sum assignment) is used
    to find the one-to-one matching with the smallest total distance, and the query archetypes
    are reordered accordingly. Neither input is modified.

    Parameters
    ----------
    ref_arch : np.ndarray
        A 2D array of shape (n_archetypes, n_features) representing the reference archetypes.
    query_arch : np.ndarray
        A 2D array of shape (n_archetypes, n_features) representing the query archetypes.

    Returns
    -------
    np.ndarray
        A 2D array of shape (n_archetypes, n_features) containing the reordered query archetypes,
        such that row i is the query archetype matched to reference archetype i.

    Raises
    ------
    ValueError
        If `ref_arch` and `query_arch` do not have the same shape.
    """
    ref_arch = np.asarray(ref_arch)
    query_arch = np.asarray(query_arch)

    # With a different number of archetypes the assignment would silently drop rows
    if ref_arch.shape != query_arch.shape:
        raise ValueError(
            f"ref_arch and query_arch must have the same shape, got {ref_arch.shape} and {query_arch.shape}."
        )

    # Compute pairwise Euclidean distances (rows: reference, columns: query)
    euclidean_d = cdist(ref_arch, query_arch, metric="euclidean")

    # Find the optimal assignment using the Hungarian algorithm; the returned row indices
    # are sorted, so the column indices directly give the query order matching the reference
    _, query_idx = linear_sum_assignment(euclidean_d)

    # Fancy indexing returns a new array, leaving `query_arch` untouched
    return query_arch[query_idx, :]
def compute_AA(
    adata: sc.AnnData | np.ndarray,
    n_archetypes: int,
    init: str | None = None,
    optim: str | None = None,
    weight: None | str = None,
    max_iter: int | None = None,
    derivative_max_iter: int | None = None,
    tol: float | None = None,
    verbose: bool | None = None,
    save_to_anndata: bool = True,
    archetypes_only: bool = True,
) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray, float, float] | None:
    """
    Fit Archetypal Analysis (AA) on the input data through a single convenience call.

    The function builds an `AA` model, fits it to `adata`, and then either stores the result on
    the AnnData object or returns it. Every option left as `None` falls back to the default
    declared in the signature of `AA.__init__`.

    Parameters
    ----------
    adata : Union[sc.AnnData, np.ndarray]
        The input data, which can be either:
        - An AnnData object containing data in `adata.obsm["X_pca"]`.
        - A 2D numpy array of shape (n_samples, n_features) representing the data matrix.
    n_archetypes : int
        The number of archetypes to compute.
    init : str, optional
        The initialization method for the archetypes, e.g. "random" or "furthest_sum".
    optim : str, optional
        The optimization method for fitting the model, e.g. "projected_gradients",
        "frank_wolfe" or "regularized_nnls".
    weight : str, optional
        The weighting method for the data, e.g. "bisquare".
    max_iter : int, optional
        The maximum number of iterations for the optimization.
    derivative_max_iter : int, optional
        The maximum number of iterations for derivative computation.
    tol : float, optional
        The tolerance for convergence.
    verbose : bool, optional
        Whether to print verbose output during fitting.
    save_to_anndata : bool, optional (default=True)
        Whether to save the results to the AnnData object. When `adata` is not an AnnData object
        a message is printed and the results are returned instead.
    archetypes_only : bool, optional (default=True)
        Whether only the archetypes are returned/saved, or the full set of results.

    Returns
    -------
    Optional[Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray, float, float]]]
        - If `save_to_anndata` is True and `adata` is an AnnData object: `None`; the results are
          saved via `model.save_to_anndata(archetypes_only=archetypes_only)`.
        - Otherwise, if `archetypes_only` is True: the result of `model.archetypes()`.
        - Otherwise, if `archetypes_only` is False: the result of `model.return_all()`
          (unpacked elsewhere in this module as A, B, Z, RSS, varexpl).
    """
    # Defaults declared on AA.__init__, excluding `self` and the mandatory `n_archetypes`
    init_params = inspect.signature(AA.__init__).parameters
    fallback = {name: spec.default for name, spec in init_params.items() if name not in ("self", "n_archetypes")}

    # Explicit arguments win; `None` means "use the AA default"
    requested = {
        "init": init,
        "optim": optim,
        "weight": weight,
        "max_iter": max_iter,
        "derivative_max_iter": derivative_max_iter,
        "tol": tol,
        "verbose": verbose,
    }
    settings = {name: (value if value is not None else fallback[name]) for name, value in requested.items()}

    # Build and fit the model
    model = AA(n_archetypes=n_archetypes, **settings)
    model.fit(adata)

    # Store on the AnnData object when requested and possible
    if save_to_anndata and isinstance(adata, sc.AnnData):
        model.save_to_anndata(archetypes_only=archetypes_only)
        return None

    if save_to_anndata:
        # Saving was requested but there is nowhere to save to
        print("No AnnData object found. Returning results")

    if archetypes_only:
        return model.archetypes()
    return model.return_all()
886# def bootstrap_variance_k_arr(X, n_bootstrap, k_arr, delta=0, seed=42, **kwargs):
887# assert k_arr.min() > 1
888# bootstrap_var = np.array(
889# [
890# bootstrap_variance_single_k(
891# X, n_bootstrap=n_bootstrap, k=k, delta=delta, seed=seed, **kwargs
892# )
893# for k in k_arr
894# ]
895# )
896# plot_df = pd.DataFrame({"k": k_arr, "var": bootstrap_var})
897# p = (
898# pn.ggplot(plot_df, pn.aes(x="k", y="var"))
899# + pn.geom_point(color="blue")
900# + pn.geom_line(color="blue")
901# + pn.labs(x="Number of Archetypes", y="Mean Variance in Archetype Position")
902# )
903# return p
906# Appendix
908# not sure if this ratio of archetype over data variance is useful in any way
909# def compute_var_ratio_vitali(X, Z):
910# # adapted from: https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L1178
911# data_var = X.var(axis=1)
912# arch_var = Z.var(axis=1)
913# return arch_var / data_var
915# see https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L257
916# archetypes = Z.A.T
917# print(archetypes.shape)
918#
919## create random matrix where each row sums to 1
920# n_additional = 100
921# np.random.seed(42)
922# rand_mtx = np.random.rand(n_additional, archetypes.shape[0])
923# rand_mtx /= rand_mtx.sum(axis=1)[:, None]
924#
925# additional_archetypes = (rand_mtx @ archetypes)
926# print(additional_archetypes.shape)
927#
928# stacked_archetypes = np.row_stack([archetypes, additional_archetypes])
929# print(stacked_archetypes.shape)
930#
931## https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L879C39-L879C46
932## Vitali used the "FA" argument, but this is turned on by default I think in the scipy version (FA - report total area and volume),
933## see http://www.qhull.org/html/qhull.htm
934# convhull_archetypes = ConvexHull(stacked_archetypes, qhull_options="QJ")
935# print(convhull_archetypes.volume)