Coverage for partipy/coreset.py: 100%

23 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-09 10:38 +0200

1import numpy as np 

2from scipy.spatial.distance import cdist 

3 

# Whether coreset samplers in this module draw indices WITH replacement.
# NOTE: The original implementation samples WITH replacement
# https://github.com/smair/archetypalanalysis-coreset/blob/6b34fce70ec1c47c9938d1f7887c506a131c94f6/code/coresets.py#L79
REPLACE = True


def construct_coreset(X: np.ndarray, coreset_size: int, seed: int):
    """Construct a coreset by sampling rows proportional to their squared distance to the mean.

    Each row ``x`` of ``X`` is drawn with probability
    ``q(x) = d(x, mean)^2 / sum_x' d(x', mean)^2`` and receives the
    importance weight ``1 / (coreset_size * q(x))``.

    Parameters
    ----------
    X : np.ndarray
        2-D data matrix with one sample per row.
    coreset_size : int
        Number of row indices to draw.
    seed : int
        Seed passed to ``np.random.default_rng``.

    Returns
    -------
    coreset_indices : np.ndarray
        Sampled row indices of ``X`` (drawn with ``replace=REPLACE``).
    weights_root : np.ndarray
        float32 square roots of the importance weights, aligned with
        ``coreset_indices``.

    Notes
    -----
    If every sample coincides with the mean, the distance-based distribution
    is undefined (0/0). In that case sampling falls back to the uniform
    distribution, with weights ``n_samples / coreset_size``, instead of
    failing on NaN probabilities.
    """
    n_samples = X.shape[0]

    # Squared Euclidean distance of each row to the (1, n_features) mean row.
    sq_dists = np.square(cdist(XA=X, XB=X.mean(axis=0, keepdims=True)).ravel())
    total = sq_dists.sum()
    if total == 0:
        # All points sit on the mean, so every point is equally
        # representative. Dividing by the zero total would give NaN
        # probabilities and make rng.choice raise.
        probs = np.ones(n_samples) / n_samples
    else:
        # Non-zero (or NaN) total: NaN input data still fails loudly below.
        probs = sq_dists / total

    rng = np.random.default_rng(seed=seed)
    coreset_indices = rng.choice(a=n_samples, size=coreset_size, p=probs, replace=REPLACE)

    # Importance-sampling weights 1 / (m * q(x)). Only the square root is
    # returned, presumably so callers can scale rows/residuals by it -- confirm with callers.
    weights = 1 / (probs[coreset_indices] * coreset_size)
    weights_root = np.sqrt(weights).astype(np.float32)

    return coreset_indices, weights_root

24 

25 

def construct_lightweight_coreset(X: np.ndarray, coreset_size: int, seed: int):
    """Construct a lightweight coreset (Bachem et al. (2018)).

    Each row ``x`` of ``X`` is drawn with probability
    ``q(x) = 0.5 / n + 0.5 * d(x, mean)^2 / sum_x' d(x', mean)^2``. This is an
    equal mixture of the uniform distribution and the squared-distance-to-mean
    distribution. Each drawn row receives the importance weight
    ``1 / (coreset_size * q(x))``.

    Parameters
    ----------
    X : np.ndarray
        2-D data matrix with one sample per row.
    coreset_size : int
        Number of row indices to draw.
    seed : int
        Seed passed to ``np.random.default_rng``.

    Returns
    -------
    coreset_indices : np.ndarray
        Sampled row indices of ``X`` (drawn with ``replace=REPLACE``).
    weights_root : np.ndarray
        float32 square roots of the importance weights, aligned with
        ``coreset_indices``.

    Notes
    -----
    If every sample coincides with the mean, the distance term is undefined
    (0/0). It is then replaced by the uniform distribution, so ``q(x) = 1 / n``,
    instead of failing on NaN probabilities.
    """
    n_samples = X.shape[0]

    # Squared Euclidean distance of each row to the (1, n_features) mean row.
    sq_dists = np.square(cdist(XA=X, XB=X.mean(axis=0, keepdims=True)).ravel())
    total = sq_dists.sum()
    uniform = np.ones(n_samples) / n_samples
    if total == 0:
        # All points sit on the mean: the distance term would be 0/0 (NaN),
        # so use the uniform distribution for it instead.
        dist_term = uniform
    else:
        dist_term = sq_dists / total
    probs = 0.5 * uniform + 0.5 * dist_term

    rng = np.random.default_rng(seed=seed)
    coreset_indices = rng.choice(a=n_samples, size=coreset_size, p=probs, replace=REPLACE)

    # Importance-sampling weights 1 / (m * q(x)); only the square root is returned.
    weights = 1 / (probs[coreset_indices] * coreset_size)
    weights_root = np.sqrt(weights).astype(np.float32)

    return coreset_indices, weights_root

41 

42 

# NOTE: Uniform sampling does not yield a true coreset; this helper is intended
# for testing purposes only.
def construct_uniform_coreset(X: np.ndarray, coreset_size: int, seed: int):  # pragma: no cover
    """Draw a mock coreset by sampling row indices of ``X`` uniformly at random.

    Parameters
    ----------
    X : np.ndarray
        Data matrix; only ``X.shape[0]`` (the number of rows) is used.
    coreset_size : int
        Number of row indices to draw.
    seed : int
        Seed passed to ``np.random.default_rng``.

    Returns
    -------
    tuple of (np.ndarray, np.ndarray)
        The sampled row indices (drawn with ``replace=REPLACE``) and a float32
        array of ones of length ``coreset_size``, i.e. every sample is weighted
        equally.
    """
    row_count = X.shape[0]
    generator = np.random.default_rng(seed=seed)
    sampled = generator.choice(a=row_count, size=coreset_size, replace=REPLACE)
    unit_weights = np.ones(shape=coreset_size, dtype=np.float32)
    return sampled, unit_weights