Coverage for partipy/coreset.py: 100%
23 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 10:38 +0200
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 10:38 +0200
1import numpy as np
2from scipy.spatial.distance import cdist
# Module-wide default for whether coreset indices are drawn with replacement.
# The reference archetypal-analysis coreset implementation samples WITH
# replacement, so that is the default here as well.
REPLACE: bool = True
def construct_coreset(
    X: np.ndarray,
    coreset_size: int,
    seed: int,
    *,
    replace: "bool | None" = None,
):
    """Sample a coreset with probability proportional to squared distance to the mean.

    Row ``i`` of ``X`` is drawn with probability ``p_i`` proportional to its
    squared Euclidean distance to the column mean of ``X``. Each drawn row gets
    the importance weight ``1 / (p_i * coreset_size)``, which corrects for the
    non-uniform sampling. If every row coincides with the mean (all distances
    are zero), sampling falls back to the uniform distribution.

    Parameters
    ----------
    X
        2-D data matrix whose rows are the samples (``cdist`` requires 2-D input).
    coreset_size
        Number of indices to draw.
    seed
        Seed forwarded to ``np.random.default_rng``.
    replace
        Whether to sample with replacement. ``None`` (the default) uses the
        module-level ``REPLACE`` setting.

    Returns
    -------
    coreset_indices : np.ndarray
        Row indices into ``X`` of the sampled points.
    weights_root : np.ndarray
        Square roots of the importance weights, cast to ``float32``.
        (Presumably square-rooted so callers can scale rows for a weighted
        least-squares objective -- confirm against callers.)

    Raises
    ------
    ValueError
        If ``X`` has no rows, or if ``rng.choice`` rejects the request
        (e.g. ``replace=False`` with more indices requested than available).
    """
    if replace is None:
        replace = REPLACE

    n_samples = X.shape[0]
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")

    center = X.mean(axis=0, keepdims=True)
    # "sqeuclidean" yields squared distances directly (no sqrt/square round trip).
    sq_dists = cdist(X, center, metric="sqeuclidean").ravel()
    total = sq_dists.sum()
    if total == 0:
        # Every row equals the mean: the distance distribution is 0/0 (NaN).
        # Any sample reproduces the data, so sample uniformly instead.
        probs = np.full(n_samples, 1.0 / n_samples)
    else:
        probs = sq_dists / total

    rng = np.random.default_rng(seed=seed)
    # NOTE: In the original implementation they sample WITH replacement
    # https://github.com/smair/archetypalanalysis-coreset/blob/6b34fce70ec1c47c9938d1f7887c506a131c94f6/code/coresets.py#L79
    coreset_indices = rng.choice(a=n_samples, size=coreset_size, p=probs, replace=replace)

    weights = 1.0 / (probs[coreset_indices] * coreset_size)
    weights_root = np.sqrt(weights).astype(np.float32)
    return coreset_indices, weights_root
def construct_lightweight_coreset(
    X: np.ndarray,
    coreset_size: int,
    seed: int,
    *,
    replace: "bool | None" = None,
):
    """Sample a lightweight coreset (Bachem et al., 2018).

    Row ``i`` of ``X`` is drawn with probability

        q_i = 0.5 * (1 / n) + 0.5 * d_i / sum_j d_j,

    where ``d_i`` is the squared Euclidean distance of row ``i`` to the column
    mean of ``X``. Each drawn row gets the importance weight
    ``1 / (q_i * coreset_size)``. If every row coincides with the mean (all
    ``d_i`` are zero), the distance term is replaced by the uniform
    distribution, so ``q_i = 1 / n``.

    Parameters
    ----------
    X
        2-D data matrix whose rows are the samples (``cdist`` requires 2-D input).
    coreset_size
        Number of indices to draw.
    seed
        Seed forwarded to ``np.random.default_rng``.
    replace
        Whether to sample with replacement. ``None`` (the default) uses the
        module-level ``REPLACE`` setting.

    Returns
    -------
    coreset_indices : np.ndarray
        Row indices into ``X`` of the sampled points.
    weights_root : np.ndarray
        Square roots of the importance weights, cast to ``float32``.

    Raises
    ------
    ValueError
        If ``X`` has no rows, or if ``rng.choice`` rejects the request.
    """
    if replace is None:
        replace = REPLACE

    n_samples = X.shape[0]
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")

    center = X.mean(axis=0, keepdims=True)
    # "sqeuclidean" yields squared distances directly (no sqrt/square round trip).
    sq_dists = cdist(X, center, metric="sqeuclidean").ravel()
    total = sq_dists.sum()
    if total == 0:
        # Every row equals the mean: the distance term would be 0/0 (NaN).
        dist_probs = np.full(n_samples, 1.0 / n_samples)
    else:
        dist_probs = sq_dists / total
    # Equal-weight mixture of the uniform and the distance-based distribution.
    probs = 0.5 * (1.0 / n_samples) + 0.5 * dist_probs

    rng = np.random.default_rng(seed=seed)
    coreset_indices = rng.choice(a=n_samples, size=coreset_size, p=probs, replace=replace)

    weights = 1.0 / (probs[coreset_indices] * coreset_size)
    weights_root = np.sqrt(weights).astype(np.float32)
    return coreset_indices, weights_root
# NOTE: This is not a true coreset (every point gets the same unit weight);
# it is a uniform-sampling baseline kept for testing purposes.
def construct_uniform_coreset(  # pragma: no cover
    X: np.ndarray,
    coreset_size: int,
    seed: int,
    *,
    replace: "bool | None" = None,
):
    """Sample a mock coreset uniformly at random, with unit weights.

    Parameters
    ----------
    X
        Data matrix; only ``X.shape[0]`` (the number of rows) is used.
    coreset_size
        Number of indices to draw.
    seed
        Seed forwarded to ``np.random.default_rng``.
    replace
        Whether to sample with replacement. ``None`` (the default) uses the
        module-level ``REPLACE`` setting.

    Returns
    -------
    coreset_indices : np.ndarray
        Row indices into ``X`` drawn uniformly.
    weights_root : np.ndarray
        ``float32`` array of ones with length ``coreset_size``.
    """
    if replace is None:
        replace = REPLACE

    n_samples = X.shape[0]
    rng = np.random.default_rng(seed=seed)
    coreset_indices = rng.choice(a=n_samples, size=coreset_size, replace=replace)
    weights_root = np.ones(shape=coreset_size, dtype=np.float32)
    return coreset_indices, weights_root