Coverage for partipy/paretoti.py: 86%
143 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 10:39 +0200
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 10:39 +0200
1import inspect
3import numpy as np
4import pandas as pd
5import scanpy as sc
6from joblib import Parallel, delayed
7from scipy.optimize import linear_sum_assignment
8from scipy.spatial import ConvexHull
9from scipy.spatial.distance import cdist
10from tqdm import tqdm
12from .arch import AA
13from .const import DEFAULT_INIT, DEFAULT_OPTIM
14from .selection import compute_IC
def set_obsm(adata: sc.AnnData, obsm_key: str, n_dimension: int) -> None:
    """
    Select the `obsm` matrix and number of dimensions used for archetypal analysis (AA).

    The function checks that `obsm_key` exists in `adata.obsm` and that `n_dimension` lies
    between 1 and the number of columns of that matrix. The configuration is then stored in
    `adata.uns["aa_config"]`, replacing any previous configuration (a warning is printed).

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object whose `obsm` holds the matrix to be used as input for AA.

    obsm_key : str
        Key in `adata.obsm` pointing to the matrix to be used for AA.

    n_dimension : int
        Number of leading columns of `adata.obsm[obsm_key]` to use. Must be at least 1 and
        at most the number of columns in that matrix.

    Returns
    -------
    None
        The AA configuration is stored in `adata.uns["aa_config"]`.

    Raises
    ------
    ValueError
        If `obsm_key` is not in `adata.obsm`, or if `n_dimension` is smaller than 1 or larger
        than the number of available columns.
    """
    if obsm_key not in adata.obsm:
        raise ValueError(f"'{obsm_key}' not found in adata.obsm. Available keys are: {list(adata.obsm.keys())}")

    # The configuration is later consumed as `[:, :n_dimension]`; zero would select nothing
    # and a negative value would silently drop trailing columns, so reject both here.
    if n_dimension < 1:
        raise ValueError(f"`n_dimension` must be at least 1, but got {n_dimension}.")

    available_dim = adata.obsm[obsm_key].shape[1]
    if n_dimension > available_dim:
        raise ValueError(
            f"Requested {n_dimension} dimensions from '{obsm_key}', but only {available_dim} are available."
        )

    if "aa_config" in adata.uns:
        print("Warning: 'aa_config' already exists in adata.uns and will be overwritten.")

    adata.uns["aa_config"] = {
        "obsm_key": obsm_key,
        "n_dimension": n_dimension,
    }
def _validate_aa_config(adata: sc.AnnData) -> None:
    """
    Check that `adata` carries a usable configuration for archetypal analysis (AA).

    The checks are performed in this order:
        - `adata.uns["aa_config"]` exists and is a dictionary,
        - it has the entries "obsm_key" and "n_dimension",
        - the configured `obsm_key` is present in `adata.obsm`,
        - the configured number of dimensions does not exceed the number of columns of
          `adata.obsm[obsm_key]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object expected to contain the AA configuration in `adata.uns["aa_config"]`.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the configuration is missing, incomplete, or inconsistent with `adata.obsm`.
    """
    uns = adata.uns
    if "aa_config" not in uns:
        raise ValueError("AA configuration not found in `adata.uns['aa_config']`.")

    aa_config = uns["aa_config"]
    if not isinstance(aa_config, dict):
        raise ValueError("`adata.uns['aa_config']` must be a dictionary.")

    absent = {name for name in ("obsm_key", "n_dimension") if name not in aa_config}
    if absent:
        raise ValueError(f"Missing keys in `aa_config`: {absent}")

    key, requested_dim = aa_config["obsm_key"], aa_config["n_dimension"]

    if key not in adata.obsm:
        raise ValueError(f"'{key}' not found in `adata.obsm`. Available keys: {list(adata.obsm.keys())}")

    n_columns = adata.obsm[key].shape[1]
    if requested_dim > n_columns:
        raise ValueError(
            f"Configured number of dimensions ({requested_dim}) exceeds available dimensions ({n_columns}) in `adata.obsm['{key}']`."
        )
def _validate_aa_results(adata: sc.AnnData) -> None:
    """
    Ensure that `adata.uns` contains an "AA_results" entry.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix.

    Raises
    ------
    ValueError
        If the key "AA_results" is absent from `adata.uns`.
    """
    if "AA_results" in adata.uns:
        return

    message = (
        "Result from Archetypal Analysis not found in `adata.uns['AA_results']`. "
        "Please run the AA() function first."
    )
    raise ValueError(message)
def var_explained_aa(
    adata: sc.AnnData,
    min_a: int = 2,
    max_a: int = 10,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    n_jobs: int = -1,
    **kwargs,
) -> None:
    """
    Compute the variance explained by Archetypal Analysis (AA) for a range of archetype counts.

    AA is fitted once for every number of archetypes from `min_a` to `max_a` (inclusive) on the
    matrix selected by `adata.uns["aa_config"]`, i.e. the first `n_dimension` columns of
    `adata.obsm[obsm_key]`. The diagnostics are stored in `adata.uns["AA_var"]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object with a valid `adata.uns["aa_config"]` (see `set_obsm`).
    min_a : int, optional (default=2)
        Minimum number of archetypes to test. Must be at least 2.
    max_a : int, optional (default=10)
        Maximum number of archetypes to test. Must be greater than or equal to `min_a`.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    n_jobs : int, optional (default=-1)
        Number of jobs passed to `joblib.Parallel`. With `n_jobs=1` the fits run in a plain
        loop without joblib.
    **kwargs:
        Additional keyword arguments passed to the `AA` class.

    Returns
    -------
    None
        The results are stored in `adata.uns["AA_var"]` as a DataFrame with the columns:
            - `k`: The number of archetypes.
            - `IC`: Value returned by `compute_IC` for the reconstruction `A @ Z` of the model
              with `k` archetypes.
            - `varexpl`: Variance explained by the AA model with `k` archetypes.
            - `varexpl_ontop`: Difference in `varexpl` to the previous row; the first row holds
              its own `varexpl`.
            - `dist_to_projected`: Distance of each point `(k, varexpl)` to its projection onto
              the line through the first and last point of the curve, used to identify
              "elbow points". All zeros when only a single `k` is tested.

    Raises
    ------
    ValueError
        If the AA configuration is invalid, `min_a < 2`, or `max_a < min_a`.
    """
    # input validation
    _validate_aa_config(adata=adata)
    if min_a < 2:
        raise ValueError("`min_a` must be at least 2.")
    if max_a < min_a:
        raise ValueError("`max_a` must be greater than or equal to `min_a`.")

    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]

    k_arr = np.arange(min_a, max_a + 1)

    # One AA fit per candidate number of archetypes
    def _compute_archetypes(k):
        A, B, Z, RSS, varexpl = AA(n_archetypes=k, optim=optim, init=init, **kwargs).fit(X).return_all()
        return k, {"Z": Z, "A": A, "B": B, "RSS": RSS, "varexpl": varexpl}

    if n_jobs == 1:
        results_list = [_compute_archetypes(k) for k in k_arr]
    else:
        results_list = Parallel(n_jobs=n_jobs)(delayed(_compute_archetypes)(k) for k in k_arr)

    results = dict(results_list)

    IC_values = [
        compute_IC(X=X, X_tilde=results[k]["A"] @ results[k]["Z"], n_archetypes=k) for k in k_arr
    ]
    varexpl_values = np.array([results[k]["varexpl"] for k in k_arr])

    result_df = pd.DataFrame(
        {
            "k": k_arr,
            "IC": IC_values,
            "varexpl": varexpl_values,
            "varexpl_ontop": np.insert(np.diff(varexpl_values), 0, varexpl_values[0]),
        }
    )

    # Distance of every (k, varexpl) point to the chord joining the first and last point.
    # Points are shifted so that the first point is the origin; the chord is then the last row.
    centered = result_df[["k", "varexpl"]].values.astype(float)
    centered = centered - centered[0]
    chord = centered[-1]
    chord_sq_norm = float(chord @ chord)
    if chord_sq_norm > 0:
        projected = np.outer(centered @ chord, chord) / chord_sq_norm
        result_df["dist_to_projected"] = np.linalg.norm(centered - projected, axis=1)
    else:
        # Only one k was tested (min_a == max_a): there is no chord to project onto, and
        # inverting the zero Gram matrix would raise a LinAlgError.
        result_df["dist_to_projected"] = 0.0

    adata.uns["AA_var"] = result_df
def bootstrap_aa(
    adata: sc.AnnData,
    n_bootstrap: int,
    n_archetypes: int,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    seed: int = 42,
    save_to_anndata: bool = True,
    n_jobs: int = -1,
    **kwargs,
) -> None | pd.DataFrame:
    """
    Perform bootstrap sampling to compute archetypes and assess their stability.

    A reference AA model is fitted on the full data. Then `n_bootstrap` resamples (rows drawn
    with replacement) are fitted, and the archetypes of each resample are reordered to match
    the reference archetypes. The variance of the aligned archetype coordinates across the
    resamples is reported as a stability measure.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object with a valid `adata.uns["aa_config"]` (see `set_obsm`); the first
        `n_dimension` columns of `adata.obsm[obsm_key]` are used.
    n_bootstrap : int
        The number of bootstrap samples to generate. Must be at least 1.
    n_archetypes : int
        The number of archetypes to compute for each bootstrap sample.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    seed : int, optional (default=42)
        Seed of the random generator that draws the bootstrap row indices.
    save_to_anndata : bool, optional (default=True)
        If True, store the result in `adata.uns["AA_bootstrap"]` and return None; otherwise
        return the DataFrame.
    n_jobs : int, optional (default=-1)
        Number of jobs passed to `joblib.Parallel` for the bootstrap fits.
    **kwargs:
        Additional keyword arguments passed to the `AA` class.

    Returns
    -------
    None or pd.DataFrame
        DataFrame (stored or returned, see `save_to_anndata`) with the columns:
            - `x0`, `x1`, ...: The archetype coordinates, one column per input dimension.
            - `archetype`: The archetype index (categorical).
            - `iter`: The bootstrap iteration index (0 for the reference archetypes).
            - `reference`: Whether the row belongs to the reference model (`iter == 0`).
            - `mean_variance`: Mean of `variance_per_archetype` over all archetypes.
            - `variance_per_archetype`: Variance across bootstrap samples of each archetype's
              coordinates, averaged over the coordinates.

    Raises
    ------
    ValueError
        If the AA configuration is invalid or `n_bootstrap` is smaller than 1.
    """
    # input validation
    _validate_aa_config(adata=adata)
    if n_bootstrap < 1:
        # Without this check the failure would surface later as an obscure `np.stack` error.
        raise ValueError("`n_bootstrap` must be at least 1.")

    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]

    n_samples, n_features = X.shape
    rng = np.random.default_rng(seed)

    # Reference archetypes
    ref_Z = AA(n_archetypes=n_archetypes, optim=optim, init=init, **kwargs).fit(X).Z

    # One row of resampled indices per bootstrap iteration
    idx_bootstrap = rng.choice(n_samples, size=(n_bootstrap, n_samples), replace=True)

    def compute_bootstrap_z(idx):
        return AA(n_archetypes=n_archetypes, optim=optim, init=init, **kwargs).fit(X[idx, :]).Z

    # Parallel computation of AA on bootstrap samples
    Z_list = Parallel(n_jobs=n_jobs)(delayed(compute_bootstrap_z)(idx) for idx in idx_bootstrap)

    # Reorder each bootstrap solution so that archetype i corresponds to reference archetype i
    Z_list = [_align_archetypes(ref_arch=ref_Z.copy(), query_arch=query_Z.copy()) for query_Z in Z_list]

    # Variance across bootstrap samples, averaged over coordinates -> one value per archetype
    Z_stack = np.stack(Z_list)
    var_per_archetype = Z_stack.var(axis=0).mean(axis=1)
    mean_variance = var_per_archetype.mean()

    # Create result dataframe
    coord_columns = [f"x{i}" for i in range(n_features)]
    archetype_ids = np.arange(n_archetypes)
    bootstrap_data = [
        pd.DataFrame(Z, columns=coord_columns).assign(archetype=archetype_ids, iter=i + 1)
        for i, Z in enumerate(Z_list)
    ]
    bootstrap_df = pd.concat(bootstrap_data)

    df = pd.DataFrame(ref_Z, columns=coord_columns)
    df["archetype"] = archetype_ids
    df["iter"] = 0

    bootstrap_df = pd.concat((bootstrap_df, df), axis=0)
    bootstrap_df["reference"] = bootstrap_df["iter"] == 0
    bootstrap_df["archetype"] = pd.Categorical(bootstrap_df["archetype"])

    bootstrap_df["mean_variance"] = mean_variance

    archetype_variance_map = dict(zip(archetype_ids, var_per_archetype, strict=False))
    bootstrap_df["variance_per_archetype"] = bootstrap_df["archetype"].astype(int).map(archetype_variance_map)

    if save_to_anndata:
        adata.uns["AA_bootstrap"] = bootstrap_df
        return None
    return bootstrap_df
def bootstrap_aa_multiple_k(
    adata: sc.AnnData,
    n_bootstrap: int = 30,
    n_archetypes_list=None,
    save_to_anndata: bool = True,
    n_jobs: int = -1,
    **kwargs,
) -> None | pd.DataFrame:
    """
    Perform bootstrap sampling across multiple numbers of archetypes to assess stability.

    `bootstrap_aa` is called once per entry of `n_archetypes_list`; the per-archetype variance
    of every run is collected into a single table so that stability can be compared across
    numbers of archetypes.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing the data; it is passed unchanged to `bootstrap_aa`.
    n_bootstrap : int, optional (default=30)
        The number of bootstrap samples to generate for each number of archetypes.
    n_archetypes_list : iterable of int, optional (default=range(2, 8))
        The numbers of archetypes to evaluate. Must not be empty.
    save_to_anndata : bool, optional (default=True)
        Whether to save the results to `adata.uns["AA_boostrap_multiple_k"]`. If `False`, the
        result is returned.
    n_jobs : int, optional (default=-1)
        Number of parallel jobs, forwarded to `bootstrap_aa`.
    **kwargs:
        Additional keyword arguments forwarded to `bootstrap_aa`.

    Returns
    -------
    None or pd.DataFrame
        If `save_to_anndata=True`, results are stored in `adata.uns["AA_boostrap_multiple_k"]`
        and None is returned; otherwise the DataFrame is returned. Its columns are:
            - `archetype`: The archetype index.
            - `variance_per_archetype`: The `variance_per_archetype` value reported by
              `bootstrap_aa`.
            - `n_archetypes`: The number of archetypes used for the corresponding bootstrap analysis.

    Raises
    ------
    ValueError
        If `n_archetypes_list` is empty.
    """
    if n_archetypes_list is None:
        n_archetypes_list = list(range(2, 8))
    else:
        # Materialize so that ranges, generators and arrays can be checked for emptiness.
        n_archetypes_list = list(n_archetypes_list)
    if not n_archetypes_list:
        raise ValueError("`n_archetypes_list` must contain at least one number of archetypes.")

    df_list = []
    for k in n_archetypes_list:
        bootstrap_df = bootstrap_aa(
            adata=adata, n_bootstrap=n_bootstrap, n_archetypes=k, save_to_anndata=False, n_jobs=n_jobs, **kwargs
        )
        bootstrap_df["n_archetypes"] = k  # type: ignore[index]
        df_list.append(bootstrap_df)

    df = pd.concat(df_list, axis=0)
    # The variance is repeated on every bootstrap row; keep one row per (archetype, k).
    df = df[["archetype", "variance_per_archetype", "n_archetypes"]].drop_duplicates()

    if save_to_anndata:
        # NOTE: the key spelling ("boostrap") is kept as-is; other code may read this key.
        adata.uns["AA_boostrap_multiple_k"] = df
        return None
    return df
def _project_on_affine_subspace(X, Z) -> np.ndarray:  # pragma: no cover
    """
    Projects a set of points X onto the affine subspace spanned by the vertices Z.

    The first vertex is used as origin and the differences between the remaining vertices and
    the first one as basis vectors. The returned values are the least-squares coefficients of
    each point with respect to that basis.

    Parameters
    ----------
    X : numpy.ndarray
        A (D x n) array of n points in D-dimensional space to be projected.
    Z : numpy.ndarray
        A (D x k) array of k vertices (archetypes) defining the affine subspace in D-dimensional space.

    Returns
    -------
    proj_coord : numpy.ndarray
        A ((k - 1) x n) array with the coordinates of the projected points in the subspace
        defined by Z.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the basis vectors are linearly dependent (singular Gram matrix).
    """
    origin = Z[:, [0]]  # keep as a (D x 1) column so it broadcasts against X and Z

    # Basis of the subspace relative to the first vertex; for k == 2 this is a single column.
    basis = Z[:, 1:] - origin

    # Solve the normal equations (B^T B) c = B^T (x - origin) without forming an explicit inverse.
    return np.linalg.solve(basis.T @ basis, basis.T @ (X - origin))
def _compute_t_ratio(X: np.ndarray, Z: np.ndarray) -> float:  # pragma: no cover
    """
    Compute the t-ratio: volume(polytope defined by Z) / volume(convex hull of X)

    When there are fewer archetypes than needed to span the full space (k < D + 1), both the
    data and the archetypes are first projected onto the affine subspace spanned by the
    archetypes, and the volumes are computed there. If the resulting points are
    one-dimensional (e.g. k == 2), the "volume" is the length of the covered interval,
    because `scipy.spatial.ConvexHull` does not accept 1-D input.

    Parameters
    ----------
    X : np.ndarray, shape (n_samples, n_features)
        Data matrix.
    Z : np.ndarray, shape (n_archetypes, n_features)
        Archetypes matrix.

    Returns
    -------
    float
        The t-ratio.

    Raises
    ------
    ValueError
        If fewer than 2 archetypes are given.
    """
    D, k = X.shape[1], Z.shape[0]

    if k < 2:
        raise ValueError("At least 2 archetypes are required (k >= 2).")

    if k < D + 1:
        points_X = _project_on_affine_subspace(X.T, Z.T).T
        points_Z = _project_on_affine_subspace(Z.T, Z.T).T
    else:
        points_X, points_Z = X, Z

    if points_X.shape[1] == 1:
        # 1-D case: the convex hull is an interval, its volume is the interval length.
        convhull_volume = points_X.max() - points_X.min()
        polytope_volume = points_Z.max() - points_Z.min()
    else:
        convhull_volume = ConvexHull(points_X).volume
        polytope_volume = ConvexHull(points_Z).volume

    return float(polytope_volume / convhull_volume)
def compute_t_ratio(adata) -> float | None:  # pragma: no cover
    """
    Compute the t-ratio for the archetypes stored in an AnnData object.

    The data matrix is taken from `adata.obsm[obsm_key]`, restricted to the first `n_dimension`
    columns as configured in `adata.uns["aa_config"]`; the archetypes are read from
    `adata.uns["AA_results"]["Z"]`.

    Parameters
    ----------
    adata : sc.AnnData
        Must contain a valid `uns["aa_config"]`, the matrix `obsm[obsm_key]` and
        `uns["AA_results"]["Z"]`.

    Returns
    -------
    None
        The t-ratio is stored in `adata.uns["t_ratio"]`.

    Raises
    ------
    ValueError
        If the AA configuration is invalid or the archetypes are missing.
    """
    # input validation
    _validate_aa_config(adata=adata)
    has_archetypes = "AA_results" in adata.uns and "Z" in adata.uns["AA_results"]
    if not has_archetypes:
        raise ValueError("Missing archetypes in `adata.uns['AA_results']['Z']`.")

    aa_config = adata.uns["aa_config"]
    data = adata.obsm[aa_config["obsm_key"]][:, : aa_config["n_dimension"]]
    archetypes = adata.uns["AA_results"]["Z"]

    adata.uns["t_ratio"] = _compute_t_ratio(data, archetypes)
    return None
def t_ratio_significance(adata, iter=1000, seed=42, n_jobs=-1):  # pragma: no cover
    """
    Assesses the significance of the polytope spanned by the archetypes by comparing the t-ratio
    of the original data to t-ratios computed from randomized datasets.

    Each randomized dataset is obtained by permuting every column of the data matrix
    independently; AA with the same number of archetypes is then fitted to it and its t-ratio
    is computed.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object with a valid `adata.uns["aa_config"]` and archetypes in
        `adata.uns["AA_results"]["Z"]`. If `adata.uns["t_ratio"]` does not exist, it is computed
        with `compute_t_ratio` first.
    iter : int, optional (default=1000)
        Number of randomized datasets to generate. Must be at least 1.
    seed : int, optional (default=42)
        The random seed for reproducibility. One independent child seed per randomized dataset
        is derived from it.
    n_jobs : int, optional (default=-1)
        Number of jobs passed to `joblib.Parallel`.

    Returns
    -------
    float
        The proportion of randomized datasets with a t-ratio greater than the original t-ratio (p-value).

    Raises
    ------
    ValueError
        If the AA configuration is invalid or `iter` is smaller than 1.
    """
    # input validation
    _validate_aa_config(adata=adata)
    if iter < 1:
        raise ValueError("`iter` must be at least 1.")

    if "t_ratio" not in adata.uns:
        print("Computing t-ratio...")
        compute_t_ratio(adata)

    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]

    t_ratio = adata.uns["t_ratio"]
    n_features = X.shape[1]
    n_archetypes = adata.uns["AA_results"]["Z"].shape[0]

    # A single generator shared through the closure would be copied into each worker process,
    # so tasks would replay the same random stream. Give every task its own child seed instead.
    child_seeds = np.random.SeedSequence(seed).spawn(iter)

    def compute_randomized_t_ratio(child_seed):
        task_rng = np.random.default_rng(child_seed)
        # Shuffle each feature independently
        SimplexRand1 = np.array([task_rng.permutation(X[:, i]) for i in range(n_features)]).T
        # Compute archetypes and t-ratio for randomized data
        Z_mix = AA(n_archetypes=n_archetypes).fit(SimplexRand1).Z
        return _compute_t_ratio(SimplexRand1, Z_mix)

    # Parallelize the computation of randomized t-ratios
    RandRatio = Parallel(n_jobs=n_jobs)(
        delayed(compute_randomized_t_ratio)(child_seed) for child_seed in tqdm(child_seeds, desc="Randomizing")
    )

    # Calculate the p-value
    p_value = np.sum(np.array(RandRatio) > t_ratio) / iter
    return p_value
def t_ratio_significance_shuffled(adata, iter=1000, seed=42, n_jobs=-1):  # pragma: no cover
    """
    Assesses the significance of the polytope spanned by the archetypes by comparing the t-ratio
    of the original data to t-ratios computed from randomized datasets.

    Each randomized dataset is obtained by permuting every column of the data matrix
    independently and then passing the shuffled matrix through `sc.pp.pca` with `n_dimension`
    components. AA with the same number of archetypes is fitted to that PCA output and its
    t-ratio is computed.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object with a valid `adata.uns["aa_config"]` and archetypes in
        `adata.uns["AA_results"]["Z"]`. If `adata.uns["t_ratio"]` does not exist, it is computed
        with `compute_t_ratio` first.
    iter : int, optional (default=1000)
        Number of randomized datasets to generate. Must be at least 1.
    seed : int, optional (default=42)
        The random seed for reproducibility. One independent child seed per randomized dataset
        is derived from it.
    n_jobs : int, optional (default=-1)
        Number of jobs passed to `joblib.Parallel`.

    Returns
    -------
    float
        The proportion of randomized datasets with a t-ratio greater than the original t-ratio (p-value).

    Raises
    ------
    ValueError
        If the AA configuration is invalid or `iter` is smaller than 1.
    """
    # input validation
    _validate_aa_config(adata=adata)
    if iter < 1:
        raise ValueError("`iter` must be at least 1.")

    if "t_ratio" not in adata.uns:
        print("Computing t-ratio...")
        compute_t_ratio(adata)

    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]

    t_ratio = adata.uns["t_ratio"]
    n_features = X.shape[1]
    n_archetypes = adata.uns["AA_results"]["Z"].shape[0]

    # A single generator shared through the closure would be copied into each worker process,
    # so tasks would replay the same random stream. Give every task its own child seed instead.
    child_seeds = np.random.SeedSequence(seed).spawn(iter)

    def compute_randomized_t_ratio(child_seed):
        task_rng = np.random.default_rng(child_seed)
        # Shuffle each feature independently
        SimplexRand1 = np.array([task_rng.permutation(X[:, i]) for i in range(n_features)]).T
        # Use the local `n_dimensions` so the closure does not have to capture `adata`
        SimplexRand1_pca = sc.pp.pca(SimplexRand1, n_comps=n_dimensions)
        # Compute archetypes and t-ratio for randomized data
        Z_mix = AA(n_archetypes=n_archetypes).fit(SimplexRand1_pca).Z
        return _compute_t_ratio(SimplexRand1_pca, Z_mix)

    # Parallelize the computation of randomized t-ratios
    RandRatio = Parallel(n_jobs=n_jobs)(
        delayed(compute_randomized_t_ratio)(child_seed) for child_seed in tqdm(child_seeds, desc="Randomizing")
    )

    # Calculate the p-value
    p_value = np.sum(np.array(RandRatio) > t_ratio) / iter
    return p_value
def _align_archetypes(ref_arch: np.ndarray, query_arch: np.ndarray) -> np.ndarray:
    """
    Reorder the query archetypes so that they follow the order of the reference archetypes.

    A cost matrix of pairwise Euclidean distances (reference rows x query columns) is built and
    the assignment with minimal total cost is found with `scipy.optimize.linear_sum_assignment`.
    The query rows are then returned in the order given by that assignment.

    Parameters
    ----------
    ref_arch : np.ndarray
        A 2D array of shape (n_archetypes, n_features) representing the reference archetypes.
    query_arch : np.ndarray
        A 2D array of shape (n_archetypes, n_features) representing the query archetypes.

    Returns
    -------
    np.ndarray
        A 2D array of shape (n_archetypes, n_features) containing the reordered query archetypes.
    """
    # cost[i, j] = distance between reference archetype i and query archetype j
    cost = cdist(ref_arch, query_arch, metric="euclidean")

    # Only the matched query indices are needed; they are listed in reference order.
    _, matched_query = linear_sum_assignment(cost)

    return query_arch[matched_query, :]
def compute_archetypes(
    adata: sc.AnnData,
    n_archetypes: int,
    init: str | None = None,
    optim: str | None = None,
    weight: None | str = None,
    max_iter: int | None = None,
    rel_tol: float | None = None,
    verbose: bool | None = None,
    seed: int = 42,
    save_to_anndata: bool = True,
    archetypes_only: bool = True,
    **optim_kwargs,
) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray, list[float] | np.ndarray, float] | None:
    """
    Perform Archetypal Analysis (AA) on the input data.

    Thin wrapper around the `AA` class: it builds a model, fits it on the matrix configured in
    `adata.uns["aa_config"]` (cast to float32), and either stores the result in
    `adata.uns["AA_results"]` or returns it. Every option left at `None` falls back to the
    default declared in the signature of `AA.__init__`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object with a valid `adata.uns["aa_config"]`; the first `n_dimension` columns
        of `adata.obsm[obsm_key]` are used.
    n_archetypes : int
        The number of archetypes to compute.
    init : str, optional
        The initialization method for the archetypes. If not provided, the default from the AA class is used.
        Options include:
            - "uniform": Uniform initialization.
            - "furthest_sum": Furthest sum initialization.
            - "plus_plus": Archetype++ initialization.
    optim : str, optional
        The optimization method for fitting the model. If not provided, the default from the AA class is used.
        Options include:
            - "projected_gradients": Projected gradients optimization.
            - "frank_wolfe": Frank-Wolfe optimization.
            - "regularized_nnls": Regularized non-negative least squares optimization.
    weight : str, optional
        The weighting method for the data. If not provided, the default from the AA class is used.
        Options include:
            - None : default
            - "bisquare": Bisquare weighting.
            - "huber": Huber weighting.
    max_iter : int, optional
        The maximum number of iterations for the optimization. If not provided, the default from the AA class is used.
    rel_tol : float, optional
        The relative tolerance for convergence. If not provided, the default from the AA class is used.
    verbose : bool, optional
        Whether to print verbose output during fitting. If not provided, the default from the AA class is used.
    seed : int, optional (default=42)
        The random seed passed to the `AA` class.
    save_to_anndata : bool, optional (default=True)
        Whether to save the results to `adata.uns["AA_results"]` (and return None) or to return them.
    archetypes_only : bool, optional (default=True)
        If True, only the archetype matrix `Z` is saved/returned; otherwise also `A`, `B`, the
        RSS trace and the variance explained.
    **optim_kwargs:
        Additional keyword arguments passed unchanged to the `AA` constructor.

    Returns
    -------
    None, np.ndarray or tuple
        - `save_to_anndata=True`: returns None. `adata.uns["AA_results"]` is a dict with key
          "Z" (if `archetypes_only`) or with keys "A", "B", "Z", "RSS", "varexpl".
        - `save_to_anndata=False` and `archetypes_only=True`: returns `model.Z`.
        - `save_to_anndata=False` and `archetypes_only=False`: returns the tuple
          `(A, B, Z, RSS_trace, varexpl)` taken from the fitted model.
    """
    # input validation
    _validate_aa_config(adata=adata)

    # Defaults declared by AA.__init__ (everything except `self` and `n_archetypes`)
    aa_defaults = {
        name: parameter.default
        for name, parameter in inspect.signature(AA.__init__).parameters.items()
        if name not in ("self", "n_archetypes")
    }

    # Explicitly passed values win; `None` means "use the AA default"
    requested = {
        "init": init,
        "optim": optim,
        "weight": weight,
        "max_iter": max_iter,
        "rel_tol": rel_tol,
        "verbose": verbose,
    }
    resolved = {name: (value if value is not None else aa_defaults[name]) for name, value in requested.items()}

    model = AA(n_archetypes=n_archetypes, seed=seed, **resolved, **optim_kwargs)

    # Data matrix used to fit the archetypes
    aa_config = adata.uns["aa_config"]
    X = adata.obsm[aa_config["obsm_key"]][:, : aa_config["n_dimension"]].astype(np.float32)

    model.fit(X)

    if save_to_anndata:
        if archetypes_only:
            adata.uns["AA_results"] = {"Z": model.Z}
        else:
            adata.uns["AA_results"] = {
                "A": model.A,
                "B": model.B,
                "Z": model.Z,
                "RSS": model.RSS_trace,
                "varexpl": model.varexpl,
            }
        return None

    if archetypes_only:
        return model.Z
    return model.A, model.B, model.Z, model.RSS_trace, model.varexpl