Coverage for ParTIpy/arch.py: 12%
107 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-16 10:22 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-16 10:22 +0100
1"""
2Class for archetypal analysis
4Note: notation used X ≈ A · B · X = A · Z
6Code adapted from https://github.com/atmguille/archetypal-analysis (by Guillermo García Cobo)
7"""
9import numpy as np
10import scanpy as sc
12from .const import (
13 DEFAULT_INIT,
14 DEFAULT_OPTIM,
15 DEFAULT_WEIGHT,
16 INIT_ALGS,
17 OPTIM_ALGS,
18 WEIGHT_ALGS,
19)
20from .initialize import _furthest_sum_init, _random_init
21from .optim import (
22 _compute_A_frank_wolfe,
23 _compute_A_projected_gradients,
24 _compute_A_regularized_nnls,
25 _compute_B_frank_wolfe,
26 _compute_B_projected_gradients,
27 _compute_B_regularized_nnls,
28)
29from .weights import compute_bisquare_weights
class AA:
    """
    Archetypal analysis (AA) model.

    Approximates a data matrix ``X`` of shape (n_samples, n_features) as
    ``X ≈ A @ B @ X = A @ Z``, where

    - ``Z = B @ X`` holds the archetypes, shape (n_archetypes, n_features),
    - ``A`` has shape (n_samples, n_archetypes) and expresses every sample
      in terms of the archetypes,
    - ``B`` has shape (n_archetypes, n_samples) and expresses every archetype
      in terms of the samples.

    Parameters
    ----------
    n_archetypes : int
        Number of archetypes to fit.
    init : str
        Initialization algorithm for ``B``. Must be in ``INIT_ALGS``; this
        class dispatches ``"random"`` and ``"furthest_sum"``.
    optim : str
        Solver used for the alternating ``A``/``B`` updates. Must be in
        ``OPTIM_ALGS``; this class dispatches ``"regularized_nnls"``,
        ``"projected_gradients"`` and ``"frank_wolfe"``.
    weight : str or None
        Robust reweighting scheme. Must be in ``WEIGHT_ALGS``. A falsy value
        disables reweighting; ``"bisquare"`` reweights samples from their
        residuals with ``compute_bisquare_weights``.
    max_iter : int
        Maximum number of outer alternating iterations.
    derivative_max_iter : int
        Iteration budget forwarded to the inner ``A``/``B`` solvers.
    tol : float
        Stop when the relative change in RSS between two consecutive outer
        iterations drops below this value.
    verbose : bool
        Stored on the instance; not read by the methods of this class.

    Raises
    ------
    ValueError
        If ``init``, ``optim`` or ``weight`` is not a recognised algorithm.
    """

    def __init__(
        self,
        n_archetypes: int,
        init: str = DEFAULT_INIT,
        optim: str = DEFAULT_OPTIM,
        weight: None | str = DEFAULT_WEIGHT,
        max_iter: int = 100,
        derivative_max_iter: int = 100,
        tol: float = 1e-6,
        verbose: bool = False,
    ):
        # Explicit checks instead of `assert`: asserts are stripped when
        # Python runs with -O, which would silently skip validation.
        if init not in INIT_ALGS:
            raise ValueError(f"init must be one of {INIT_ALGS}, got {init!r}")
        if optim not in OPTIM_ALGS:
            raise ValueError(f"optim must be one of {OPTIM_ALGS}, got {optim!r}")
        if weight not in WEIGHT_ALGS:
            raise ValueError(f"weight must be one of {WEIGHT_ALGS}, got {weight!r}")

        self.n_archetypes = n_archetypes
        self.init = init
        self.optim = optim
        self.weight = weight
        self.max_iter = max_iter
        self.deriv_max_iter = derivative_max_iter
        self.tol = tol
        self.verbose = verbose
        self.A = None
        self.B = None
        self.Z = None  # Archetypes
        self.muA, self.muB = None, None
        self.n_samples, self.n_features = None, None
        self.RSS = None
        # list before fit(); replaced by a numpy array of per-iteration RSS values by fit()
        self.RSS_trace: list[float] | np.ndarray = []
        self.varexpl = None
        self.adata = None

    @staticmethod
    def _random_A(n_samples: int, n_archetypes: int) -> np.ndarray:
        """Return a random (n_samples, n_archetypes) matrix whose rows sum to 1."""
        A = -np.log(np.random.random((n_samples, n_archetypes)))
        A /= np.sum(A, axis=1, keepdims=True)
        return A

    def _select_init(self):
        """Return the B-initialization function named by ``self.init``."""
        if self.init == "random":
            return _random_init
        if self.init == "furthest_sum":
            return _furthest_sum_init
        raise NotImplementedError()

    def _select_optim(self):
        """Return the ``(compute_A, compute_B)`` solver pair named by ``self.optim``."""
        solvers = {
            "regularized_nnls": (_compute_A_regularized_nnls, _compute_B_regularized_nnls),
            "projected_gradients": (_compute_A_projected_gradients, _compute_B_projected_gradients),
            "frank_wolfe": (_compute_A_frank_wolfe, _compute_B_frank_wolfe),
        }
        try:
            return solvers[self.optim]
        except KeyError:
            raise NotImplementedError() from None

    def _select_weight(self):
        """Return the residual-weighting function for ``self.weight``, or None if unweighted."""
        if not self.weight:
            return None
        if self.weight == "bisquare":
            return compute_bisquare_weights
        raise NotImplementedError()

    def fit(self, X: np.ndarray):
        """
        Compute the archetypes and the RSS from the data X and store them in
        the corresponding attributes (A, B, Z, RSS, RSS_trace, varexpl).

        :param X: data matrix with shape (n_samples, n_features), or an AnnData
            object whose ``obsm["X_pca_reduced"]`` holds that matrix (the AnnData
            object is then kept for ``save_to_anndata()``)
        :return: self
        :raises ValueError: if an AnnData object has no ``obsm["X_pca_reduced"]``
        :raises NotImplementedError: if ``init``/``optim``/``weight`` names an
            algorithm that is not dispatched here
        """
        if isinstance(X, sc.AnnData):
            if "X_pca_reduced" not in X.obsm:
                raise ValueError(
                    "X_pca_reduced not in AnnData object. Please use reduce_pca() to add it to the AnnData object."
                )
            self.adata = X
            X = X.obsm["X_pca_reduced"]

        # ensure C-contiguous format for numba
        X = np.ascontiguousarray(X)
        self.n_samples, self.n_features = X.shape

        initialize_B = self._select_init()
        compute_A, compute_B = self._select_optim()
        compute_weights = self._select_weight()

        # initialize B and the archetypes Z
        B = initialize_B(X=X, n_archetypes=self.n_archetypes)
        Z = B @ X

        # randomly initialize A (rows sum to 1)
        A = self._random_A(self.n_samples, self.n_archetypes)

        TSS = np.sum(X * X)
        # RSS of the initial solution, so RSS is always bound (e.g. max_iter == 0)
        RSS = np.linalg.norm(X - A @ Z) ** 2
        prev_RSS = None
        # local accumulator so repeated fit() calls start from an empty trace
        rss_trace: list[float] = []

        # per-sample weights; all ones for the first weighted iteration
        W = np.ones(self.n_samples) if self.weight else None

        for _ in range(self.max_iter):
            # Scale row i of X by W[i]; same result as np.diag(W) @ X without
            # materializing an (n_samples x n_samples) matrix.
            X_w = W[:, None] * X if self.weight else X
            A = compute_A(X_w, Z, A, self.deriv_max_iter)
            B = compute_B(X_w, A, B, self.deriv_max_iter)
            Z = B @ X_w

            # compute residuals using the original (unweighted) data
            A_0 = compute_A(X, Z, A, self.deriv_max_iter) if self.weight else A
            R = X - A_0 @ Z
            if self.weight:
                W = compute_weights(R)

            RSS = np.linalg.norm(R) ** 2
            rss_trace.append(float(RSS))

            # relative-change stopping rule, written so prev_RSS == 0 cannot divide by zero
            if prev_RSS is not None:
                delta = abs(prev_RSS - RSS)
                if delta == 0 or delta < self.tol * prev_RSS:
                    break
            prev_RSS = RSS

        # Recalculate A and B using the unweighted data
        if self.weight:
            A = compute_A(X, Z, A, self.deriv_max_iter)
            B = compute_B(X, A, B, self.deriv_max_iter)
            Z = B @ X
            RSS = np.linalg.norm(X - A @ Z) ** 2

        self.Z = Z
        self.A = A
        self.B = B
        self.RSS = RSS
        self.RSS_trace = np.array(rss_trace)
        # fraction of the (uncentered) total sum of squares explained by A @ Z
        self.varexpl = (TSS - RSS) / TSS
        return self

    def archetypes(self) -> np.ndarray:
        """
        Return the archetypes' matrix (None before fit() has been called).

        :return: archetypes matrix, with shape (n_archetypes, n_features)
        """
        return self.Z

    def transform(self, X: np.ndarray) -> np.ndarray:
        """
        Compute the best convex approximation A of X by the fitted archetypes Z.

        :param X: data matrix with shape (n_samples, n_features); n_samples may
            differ from the number of samples used in fit()
        :return: A matrix, with shape (n_samples, n_archetypes)
        :raises ValueError: if called before fit()
        :raises NotImplementedError: if ``self.optim`` is not dispatched here
        """
        if self.Z is None:
            raise ValueError("No archetypes found. Please call fit() before transform().")

        # ensure C-contiguous format for numba, as in fit()
        X = np.ascontiguousarray(X)

        if self.optim == "regularized_nnls":
            return _compute_A_regularized_nnls(X, self.Z)
        if self.optim == "projected_gradients":
            # size the starting point from the data being transformed, not the training data
            A_init = self._random_A(X.shape[0], self.n_archetypes)
            return _compute_A_projected_gradients(X=X, Z=self.Z, A=A_init)
        if self.optim == "frank_wolfe":
            A_init = self._random_A(X.shape[0], self.n_archetypes)
            return _compute_A_frank_wolfe(X, self.Z, A=A_init)
        raise NotImplementedError()

    def return_all(self) -> tuple:
        """Return optimized matrices: A, B, Z, and fitting stats: RSS, varexpl."""
        return self.A, self.B, self.Z, self.RSS, self.varexpl

    def save_to_anndata(self):
        """
        Save the results (A, B, Z, RSS, varexpl) to ``adata.uns["archetypal_analysis"]``
        of the AnnData object provided in fit().

        :raises ValueError: if fit() was not given an AnnData object
        """
        if self.adata is None:
            raise ValueError("No AnnData object found. Please provide an AnnData object to fit().")

        self.adata.uns["archetypal_analysis"] = {
            "A": self.A,
            "B": self.B,
            "Z": self.Z,
            "RSS": self.RSS,
            "varexpl": self.varexpl,
        }