Coverage for partipy/arch.py: 82%
176 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 10:41 +0200
1"""
2Class for archetypal analysis
4Note: notation used X ≈ A B X = A Z
6Code adapted from https://github.com/atmguille/archetypal-analysis (by Guillermo García Cobo)
7"""
9import numpy as np
11from .const import (
12 DEFAULT_INIT,
13 DEFAULT_OPTIM,
14 DEFAULT_WEIGHT,
15 INIT_ALGS,
16 MIN_ITERATIONS,
17 OPTIM_ALGS,
18 WEIGHT_ALGS,
19)
20from .coreset import construct_coreset, construct_lightweight_coreset, construct_uniform_coreset
21from .initialize import _init_A, _init_furthest_sum, _init_plus_plus, _init_uniform
22from .optim import (
23 _compute_A_frank_wolfe,
24 _compute_A_projected_gradients,
25 _compute_A_regularized_nnls,
26 _compute_B_frank_wolfe,
27 _compute_B_projected_gradients,
28 _compute_B_regularized_nnls,
29 _compute_RSS_AZ,
30)
31from .weights import compute_bisquare_weights, compute_huber_weights
class AA:
    """
    Archetypal Analysis (AA).

    Archetypal Analysis approximates data points as convex combinations of a set of
    archetypes, which are themselves convex combinations of the data points. The goal
    is to find the best approximation for a given number of archetypes, representing
    the structure of the data in a lower-dimensional space.

    The model is defined as

        X ≈ A B X = A Z

    where:
    - X is the data matrix with shape (n_samples, n_features).
    - A is the coefficient matrix (n_samples, n_archetypes) mapping each data point
      to a convex combination of archetypes.
    - B is the coefficient matrix (n_archetypes, n_samples) mapping each archetype
      to a convex combination of data points.
    - Z = B X is the matrix (n_archetypes, n_features) of archetype coordinates.

    The optimization problem minimizes the residual sum of squares (RSS)

        RSS = ||X - A Z||^2

    subject to the constraints that A and B are non-negative and their rows sum to 1,
    ensuring convex combinations.

    Parameters
    ----------
    n_archetypes : int
        Number of archetypes to compute. Must be at least 1.
    init : str, optional (default: ``DEFAULT_INIT``)
        Initialization method for the archetypes. Must be one of ``INIT_ALGS``;
        ``fit`` implements:
        - "uniform": ``_init_uniform``.
        - "furthest_sum": the furthest sum algorithm.
        - "plus_plus": ``_init_plus_plus`` (presumably k-means++-style seeding).
    optim : str, optional (default: ``DEFAULT_OPTIM``)
        Optimization algorithm. Must be one of ``OPTIM_ALGS``; ``fit`` implements:
        - "regularized_nnls": regularized non-negative least squares.
        - "projected_gradients": projected gradient descent (PCHA).
        - "frank_wolfe": Frank-Wolfe algorithm.
    weight : str or None, optional (default: ``DEFAULT_WEIGHT``)
        Weighting scheme for robust archetypal analysis. Must be one of
        ``WEIGHT_ALGS``:
        - None: no weighting.
        - "bisquare": bisquare weighting.
        - "huber": Huber weighting.
        Requires ``early_stopping=False`` and ``use_coreset=False``.
    max_iter : int, optional (default: 500)
        Maximum number of iterations. Must be non-negative.
    rel_tol : float, optional (default: 1e-4)
        Convergence tolerance. With early stopping enabled, optimization stops
        (after at least ``MIN_ITERATIONS`` iterations) once the mean relative change
        in RSS over the last (up to) 20 iterations is non-negative or its absolute
        value falls below this threshold.
    early_stopping : bool, optional (default: True)
        Whether to stop early based on ``rel_tol``. Must be False when ``weight``
        is set.
    use_coreset : bool, optional (default: False)
        Optimize on a weighted coreset instead of the full data, then lift the
        solution back to the full data.
    coreset_flavor : str, optional (default: "default")
        Coreset construction: "default", "lightweight_kmeans" or "uniform".
    coreset_fraction : float, optional (default: 0.1)
        Coreset size as a fraction of n_samples; used when ``coreset_size`` is None.
    coreset_size : int or None, optional (default: None)
        Absolute coreset size; takes precedence over ``coreset_fraction``.
    centering : bool, optional (default: True)
        Subtract the feature means before optimization.
    scaling : bool, optional (default: True)
        Divide the (centered) data by its global root-mean-square value before
        optimization.
    verbose : bool, optional (default: False)
        If True, print progress during optimization.
    seed : int, optional (default: 42)
        Random seed.
    **optim_kwargs
        Extra keyword arguments passed to the compute_A and compute_B functions.
    """

    def __init__(
        self,
        n_archetypes: int,
        init: str = DEFAULT_INIT,
        optim: str = DEFAULT_OPTIM,
        weight: None | str = DEFAULT_WEIGHT,
        max_iter: int = 500,
        rel_tol: float = 1e-4,
        early_stopping: bool = True,
        use_coreset: bool = False,
        coreset_flavor: str = "default",
        coreset_fraction: float = 0.1,
        coreset_size: None | int = None,
        centering: bool = True,
        scaling: bool = True,
        verbose: bool = False,
        seed: int = 42,
        **optim_kwargs,
    ):
        # hyperparameters (fit() never modifies them, so refitting is safe)
        self.n_archetypes = n_archetypes
        self.init = init
        self.optim = optim
        self.weight = weight
        self.max_iter = max_iter
        self.rel_tol = rel_tol
        self.early_stopping = early_stopping
        self.use_coreset = use_coreset
        self.coreset_flavor = coreset_flavor
        self.coreset_fraction = coreset_fraction
        self.coreset_size = coreset_size
        self.centering = centering
        self.scaling = scaling
        self.verbose = verbose
        self.seed = seed
        self.optim_kwargs = optim_kwargs

        # checks (run before allocating anything sized by them, e.g. RSS_trace,
        # so that invalid values produce the descriptive error below)
        if self.n_archetypes < 1:
            raise ValueError(f"n_archetypes must be at least 1, got {self.n_archetypes}.")

        if self.init not in INIT_ALGS:
            raise ValueError(f"Initialization method '{self.init}' is not supported. Must be one of {INIT_ALGS}.")

        if self.optim not in OPTIM_ALGS:
            raise ValueError(f"Optimization algorithm '{self.optim}' is not supported. Must be one of {OPTIM_ALGS}.")

        if self.weight not in WEIGHT_ALGS:
            raise ValueError(f"Weighting method '{self.weight}' is not supported. Must be one of {WEIGHT_ALGS}.")

        if self.max_iter < 0:
            raise ValueError(f"max_iter must be non-negative, got {self.max_iter}.")

        if self.weight is not None and early_stopping is not False:
            raise ValueError(
                "Early stopping must be disabled (early_stopping=False) when using weighted/robust "
                "archetypal analysis. This is because optimization with weights does not lead to RSS reduction."
            )

        if self.use_coreset and self.weight:
            raise ValueError(
                "It is not yet implemented to use robust archetypal analysis and coresets at the same time"
            )

        # fitted state, populated by fit()
        # NOTE: I don't want to use here type annotation np.ndarray: None | np.ndarray
        # because it makes little sense for downstream type checking
        self.A: np.ndarray = None  # type: ignore[assignment]
        self.B: np.ndarray = None  # type: ignore[assignment]
        self.Z: np.ndarray = None  # type: ignore[assignment]
        self.n_samples: int = None  # type: ignore[assignment]
        self.n_features: int = None  # type: ignore[assignment]
        self.RSS: float | None = None
        self.RSS_trace: np.ndarray = np.zeros(self.max_iter, dtype=np.float32)
        self.varexpl: float = None  # type: ignore[assignment]
        self.fitting_info: dict

    def _resolve_functions(self) -> tuple:
        """
        Map the string options to their implementation functions.

        Returns
        -------
        tuple
            (initialize_B, compute_A, compute_B, compute_weights, construct_coreset_fn),
            where compute_weights is None without weighting and
            construct_coreset_fn is None without coresets.

        Raises
        ------
        NotImplementedError
            If an option has no implementation in this class.
        """
        init_fns = {
            "uniform": _init_uniform,
            "furthest_sum": _init_furthest_sum,
            "plus_plus": _init_plus_plus,
        }
        optim_fns = {
            "regularized_nnls": (_compute_A_regularized_nnls, _compute_B_regularized_nnls),
            "projected_gradients": (_compute_A_projected_gradients, _compute_B_projected_gradients),
            "frank_wolfe": (_compute_A_frank_wolfe, _compute_B_frank_wolfe),
        }
        weight_fns = {
            "bisquare": compute_bisquare_weights,
            "huber": compute_huber_weights,
        }
        coreset_fns = {
            "default": construct_coreset,
            "lightweight_kmeans": construct_lightweight_coreset,
            "uniform": construct_uniform_coreset,
        }

        if self.init not in init_fns:
            raise NotImplementedError(f"Initialization method '{self.init}' is not implemented.")
        if self.optim not in optim_fns:
            raise NotImplementedError(f"Optimization algorithm '{self.optim}' is not implemented.")
        compute_A, compute_B = optim_fns[self.optim]

        compute_weights = None
        if self.weight:
            if self.weight not in weight_fns:
                raise NotImplementedError(f"Weighting method '{self.weight}' is not implemented.")
            compute_weights = weight_fns[self.weight]

        construct_coreset_fn = None
        if self.use_coreset:
            if self.coreset_flavor not in coreset_fns:
                raise NotImplementedError(f"Coreset flavor '{self.coreset_flavor}' is not implemented.")
            construct_coreset_fn = coreset_fns[self.coreset_flavor]

        return init_fns[self.init], compute_A, compute_B, compute_weights, construct_coreset_fn

    @staticmethod
    def _mean_relative_rss_change(rss_trace: np.ndarray, n_iter: int, window: int = 20) -> float:
        """
        Mean relative RSS change over the last (up to) `window` iterations.

        Averages (RSS[i] - RSS[i-1]) / RSS[i-1] for the most recent steps ending at
        `n_iter`; returns NaN at n_iter == 0 where no previous value exists.
        """
        if n_iter <= 0:
            return np.nan
        width = min(n_iter, window)
        previous = rss_trace[n_iter - width : n_iter]
        current = rss_trace[n_iter - width + 1 : n_iter + 1]
        return float(np.mean((current - previous) / previous))

    def fit(self, X: np.ndarray):
        """
        Computes the archetypes and the RSS from the data X, which are stored
        in the corresponding attributes.

        The input array is never modified: all preprocessing (centering, scaling)
        happens on a private float32 copy.

        Parameters
        ----------
        X : np.ndarray
            Data matrix with shape (n_samples, n_features).

        Returns
        -------
        self : AA
            The instance of the AA class, with computed archetypes and RSS stored as attributes.

        Raises
        ------
        ValueError
            If X is not two-dimensional.
        NotImplementedError
            If a configured option has no implementation.
        """
        initialize_B, compute_A, compute_B, compute_weights, construct_coreset_fn = self._resolve_functions()

        # private C-contiguous float32 copy (numba-friendly); the copy guarantees that
        # the in-place preprocessing below never touches the caller's array
        X = np.array(X, dtype=np.float32, order="C")
        if X.ndim != 2:
            raise ValueError(f"X must be a 2D array of shape (n_samples, n_features), got shape {X.shape}.")
        self.n_samples, self.n_features = X.shape

        # center X by subtracting the feature means
        if self.centering:
            feature_means = X.mean(axis=0, keepdims=True)
            X -= feature_means

        # scale X globally (needs to happen before we compute weights, otherwise the weights are off)
        # TODO: Test whether we can also just apply the same scaling to the weights
        if self.scaling:
            global_scale = np.linalg.norm(X) / np.sqrt(X.size)
            X /= global_scale

        # full preprocessed data; X may be rebound to the coreset below
        X_full = X

        # construct the coreset (if requested) and initialize A
        coreset_indices = None
        coreset_size = None
        W = None
        if self.use_coreset:
            # local value: the coreset_size hyperparameter is left untouched so that
            # refitting on data of a different size recomputes it from the fraction
            coreset_size = (
                self.coreset_size
                if self.coreset_size is not None
                else int(self.n_samples * self.coreset_fraction)
            )
            coreset_indices, W = construct_coreset_fn(X=X_full, coreset_size=coreset_size, seed=self.seed)

            if self.verbose:
                print(f"coreset size = {coreset_size} | coreset flavor = {self.coreset_flavor}")

            # fancy indexing already returns a (C-contiguous) copy
            X = X_full[coreset_indices, :]

        A = _init_A(n_samples=X.shape[0], n_archetypes=self.n_archetypes, seed=self.seed)

        # initialize B and the archetypes Z
        B, initial_indices = initialize_B(X=X, n_archetypes=self.n_archetypes, seed=self.seed, return_indices=True)
        Z = B @ X

        # initialize weights
        if self.weight:
            W = np.ones(X.shape[0], dtype=np.float32)
        elif self.use_coreset:
            # with a coreset the weights are fixed, so X only has to be weighted once
            WX = W[:, None] * X  # same as np.diag(W) @ X

        TSS = RSS = np.sum(X * X)

        # fresh buffer per fit (the attribute is truncated at the end of each fit)
        rss_trace = np.zeros(self.max_iter, dtype=np.float32)
        convergence_flag = False
        n_iter = -1  # stays -1 when max_iter == 0
        for n_iter in range(self.max_iter):
            if self.weight:
                WX = W[:, None] * X
                A = compute_A(WX, Z, A, **self.optim_kwargs)
                B = compute_B(WX, A, B, **self.optim_kwargs)
                Z = B @ WX

                # recompute the weights from the residuals of the original (unweighted) data
                A_0 = compute_A(X, Z, A, **self.optim_kwargs)
                R = X - A_0 @ Z
                W = compute_weights(R)

            elif self.use_coreset:
                # compute A using the unweighted data X
                A = compute_A(X=X, Z=Z, A=A, **self.optim_kwargs)
                WA = W[:, None] * A
                B = compute_B(X=X, A=WA, B=B, WX=WX, **self.optim_kwargs)
                Z = B @ X

            else:
                A = compute_A(X, Z, A, **self.optim_kwargs)
                B = compute_B(X, A, B, **self.optim_kwargs)
                Z = B @ X

            # compute RSS and check for convergence
            RSS = _compute_RSS_AZ(X=X, A=A, Z=Z)
            rss_trace[n_iter] = float(RSS)
            rel_delta_RSS_mean_last_n = self._mean_relative_rss_change(rss_trace, n_iter)
            if self.verbose:
                print(
                    f"\riter: {n_iter} | RSS: {RSS:.3f} | rel_delta_RSS: {rel_delta_RSS_mean_last_n:.6f}",
                    end="",
                    flush=True,
                )
            if not np.isfinite(RSS):
                print("\nWarning: RSS is NaN or Inf. Stopping optimization.")
                break

            if (n_iter >= MIN_ITERATIONS) and self.early_stopping:
                # stop when RSS no longer decreases on average, or decreases too slowly
                if (rel_delta_RSS_mean_last_n >= 0.0) or (np.abs(rel_delta_RSS_mean_last_n) < self.rel_tol):
                    convergence_flag = True
                    break

        n_completed = n_iter + 1
        if self.verbose:
            status = "converged" if convergence_flag else "did not converge"
            print(f"\nAlgorithm {status} after {n_completed} iterations.")

        if self.use_coreset:
            # lift B from coreset columns to all samples; add.at accumulates duplicate indices
            B_full = np.zeros((self.n_archetypes, self.n_samples))
            np.add.at(B_full, (slice(None), np.asarray(coreset_indices)), B)
            B = B_full
            Z = B @ X_full
            # TODO: change to projected gradients or frank-wolfe here!
            A = _compute_A_regularized_nnls(X=X_full, Z=Z, A=None)
            # report the fit statistics for the returned full-data A and Z
            TSS = np.sum(X_full * X_full)
            RSS = np.linalg.norm(X_full - A @ Z) ** 2

        # If using weights, we need to recalculate A and B using the unweighted data
        if self.weight:
            A = compute_A(X, Z, A, **self.optim_kwargs)
            B = compute_B(X, A, B, **self.optim_kwargs)
            Z = B @ X
            RSS = np.linalg.norm(X - A @ Z) ** 2

        # map the archetypes back to the original data space (X itself is a private copy)
        if self.scaling:
            Z *= global_scale

        if self.centering:
            Z += feature_means

        self.Z = Z
        self.A = A
        self.B = B
        self.RSS = float(RSS)
        self.RSS_trace = rss_trace[:n_completed]
        self.varexpl = float((TSS - RSS) / TSS)
        self.fitting_info = {
            "conv": convergence_flag if self.max_iter > 0 else None,
            "n_iter": n_iter if self.max_iter > 0 else None,
            "coreset_indices": coreset_indices if self.use_coreset else None,
            "coreset_size": coreset_size if self.use_coreset else None,
            "weights": W if (self.use_coreset or self.weight) else None,
            # key spelling kept as-is for backward compatibility
            "inital_indices": initial_indices,
        }
        return self

    def archetypes(self) -> None | np.ndarray:
        """
        Returns the archetypes' matrix.

        Returns
        -------
        np.ndarray or None
            The archetypes matrix with shape (n_archetypes, n_features),
            or None if the archetypes have not been computed yet.
        """
        return self.Z

    def transform(self, X: np.ndarray) -> np.ndarray:
        """
        Computes the best convex approximation A of X by the archetypes Z.

        Uses the same optimizer (and ``optim_kwargs``) as ``fit``; iterative
        optimizers start from the seeded ``_init_A`` initialization.

        Parameters
        ----------
        X : np.ndarray
            Data matrix with shape (n_samples, n_features); n_samples may differ
            from the number of samples used in ``fit``.

        Returns
        -------
        np.ndarray
            The matrix A with shape (n_samples, n_archetypes).

        Raises
        ------
        RuntimeError
            If the model has not been fitted yet.
        ValueError
            If X is not two-dimensional or has a different number of features
            than the fitted data.
        NotImplementedError
            If the configured optimizer has no implementation.
        """
        if self.Z is None:
            raise RuntimeError("The model has not been fitted yet; call fit() before transform().")

        # same input format as used in fit (C-contiguous float32 for numba)
        X = np.ascontiguousarray(X, dtype=np.float32)
        if X.ndim != 2 or X.shape[1] != self.n_features:
            raise ValueError(f"X must have shape (n_samples, {self.n_features}), got shape {X.shape}.")

        if self.optim == "regularized_nnls":
            return _compute_A_regularized_nnls(X, self.Z, **self.optim_kwargs)

        iterative_fns = {
            "projected_gradients": _compute_A_projected_gradients,
            "frank_wolfe": _compute_A_frank_wolfe,
        }
        if self.optim not in iterative_fns:
            raise NotImplementedError(f"Optimization algorithm '{self.optim}' is not implemented.")

        # size the start point by the rows of the new X, not by the training data
        A_init = _init_A(n_samples=X.shape[0], n_archetypes=self.n_archetypes, seed=self.seed)
        return iterative_fns[self.optim](X=X, Z=self.Z, A=A_init, **self.optim_kwargs)

    def return_all(self) -> tuple:
        """
        Returns the optimized matrices and fitting statistics.

        Returns
        -------
        tuple
            A tuple containing:
            - A : np.ndarray
                Coefficient matrix with shape (n_samples, n_archetypes).
            - B : np.ndarray
                Coefficient matrix with shape (n_archetypes, n_samples).
            - Z : np.ndarray
                Archetype matrix with shape (n_archetypes, n_features).
            - RSS_trace : np.ndarray
                Residual sum of squares per completed iteration (computed on the
                preprocessed data used during optimization).
            - varexpl : float
                Variance explained by the model.
        """
        return self.A, self.B, self.Z, self.RSS_trace, self.varexpl