Coverage for ParTIpy/arch.py: 12%

107 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-16 10:22 +0100

1""" 

2Class for archetypal analysis 

3 

4Note: notation used X ≈ A · B · X = A · Z 

5 

6Code adapted from https://github.com/atmguille/archetypal-analysis (by Guillermo García Cobo) 

7""" 

8 

9import numpy as np 

10import scanpy as sc 

11 

12from .const import ( 

13 DEFAULT_INIT, 

14 DEFAULT_OPTIM, 

15 DEFAULT_WEIGHT, 

16 INIT_ALGS, 

17 OPTIM_ALGS, 

18 WEIGHT_ALGS, 

19) 

20from .initialize import _furthest_sum_init, _random_init 

21from .optim import ( 

22 _compute_A_frank_wolfe, 

23 _compute_A_projected_gradients, 

24 _compute_A_regularized_nnls, 

25 _compute_B_frank_wolfe, 

26 _compute_B_projected_gradients, 

27 _compute_B_regularized_nnls, 

28) 

29from .weights import compute_bisquare_weights 

30 

31 

class AA:
    """
    Archetypal analysis (AA) model.

    Approximates a data matrix ``X`` of shape (n_samples, n_features) as
    ``X ≈ A · Z`` with archetypes ``Z = B · X``:

    - ``Z``: archetypes, shape (n_archetypes, n_features)
    - ``A``: shape (n_samples, n_archetypes); row i holds the convex weights that
      reconstruct sample i from the archetypes
    - ``B``: shape (n_archetypes, n_samples); row k holds the weights that build
      archetype k from the samples

    :param n_archetypes: number of archetypes to fit
    :param init: initialization of ``B``, one of ``INIT_ALGS``
        (``fit`` implements "random" and "furthest_sum")
    :param optim: solver for the ``A``/``B`` updates, one of ``OPTIM_ALGS``
        (``fit`` implements "regularized_nnls", "projected_gradients" and "frank_wolfe")
    :param weight: robust sample weighting, one of ``WEIGHT_ALGS``
        (``fit`` implements ``None`` for unweighted fitting and "bisquare")
    :param max_iter: maximum number of alternating A/B update rounds
    :param derivative_max_iter: forwarded to the A/B update solvers
        (presumably their inner iteration limit)
    :param tol: fitting stops once the relative RSS change between two
        consecutive rounds drops below this value
    :param verbose: stored on the instance; not currently used by ``fit``
    :raises ValueError: if ``init``, ``optim`` or ``weight`` is not a supported option

    After ``fit``, the results are available as ``A``, ``B``, ``Z``, ``RSS``
    (residual sum of squares), ``RSS_trace`` (RSS after every round, as an array)
    and ``varexpl`` (``(TSS - RSS) / TSS`` with the uncentered ``TSS = sum(X**2)``).
    """

    def __init__(
        self,
        n_archetypes: int,
        init: str = DEFAULT_INIT,
        optim: str = DEFAULT_OPTIM,
        weight: None | str = DEFAULT_WEIGHT,
        max_iter: int = 100,
        derivative_max_iter: int = 100,
        tol: float = 1e-6,
        verbose: bool = False,
    ):
        # Validate with explicit exceptions: ``assert`` statements are stripped
        # when Python runs with -O, which would silently accept bad options.
        if init not in INIT_ALGS:
            raise ValueError(f"init must be one of {INIT_ALGS}, got {init!r}")
        if optim not in OPTIM_ALGS:
            raise ValueError(f"optim must be one of {OPTIM_ALGS}, got {optim!r}")
        if weight not in WEIGHT_ALGS:
            raise ValueError(f"weight must be one of {WEIGHT_ALGS}, got {weight!r}")

        self.n_archetypes = n_archetypes
        self.init = init
        self.optim = optim
        self.weight = weight
        self.max_iter = max_iter
        self.deriv_max_iter = derivative_max_iter
        self.tol = tol
        self.verbose = verbose  # not currently read by any method of this class
        self.A: np.ndarray | None = None
        self.B: np.ndarray | None = None
        self.Z: np.ndarray | None = None  # Archetypes
        self.muA, self.muB = None, None
        self.n_samples: int | None = None
        self.n_features: int | None = None
        self.RSS = None
        # list before fit(); replaced by an np.ndarray once fit() completes
        self.RSS_trace: list[float] | np.ndarray = []
        self.varexpl = None
        self.adata = None

    @staticmethod
    def _random_A(n_samples: int, n_archetypes: int) -> np.ndarray:
        """Draw a random (n_samples, n_archetypes) matrix whose rows are non-negative and sum to 1."""
        # Normalized Exp(1) draws (-log of uniforms), i.e. rows uniform on the simplex.
        A = -np.log(np.random.random((n_samples, n_archetypes)))
        A /= np.sum(A, axis=1, keepdims=True)
        return A

    def _select_functions(self):
        """Resolve the configured init/optim/weight names to (initialize_B, compute_A, compute_B, compute_weights)."""
        if self.init == "random":
            initialize_B = _random_init
        elif self.init == "furthest_sum":
            initialize_B = _furthest_sum_init
        else:
            raise NotImplementedError(f"init={self.init!r} is not implemented")

        if self.optim == "regularized_nnls":
            compute_A = _compute_A_regularized_nnls
            compute_B = _compute_B_regularized_nnls
        elif self.optim == "projected_gradients":
            compute_A = _compute_A_projected_gradients  # type: ignore[assignment]
            compute_B = _compute_B_projected_gradients  # type: ignore[assignment]
        elif self.optim == "frank_wolfe":
            compute_A = _compute_A_frank_wolfe  # type: ignore[assignment]
            compute_B = _compute_B_frank_wolfe  # type: ignore[assignment]
        else:
            raise NotImplementedError(f"optim={self.optim!r} is not implemented")

        # None means "unweighted fitting"
        compute_weights = None
        if self.weight:
            if self.weight == "bisquare":
                compute_weights = compute_bisquare_weights
            else:
                raise NotImplementedError(f"weight={self.weight!r} is not implemented")

        return initialize_B, compute_A, compute_B, compute_weights

    def fit(self, X: np.ndarray | sc.AnnData) -> "AA":
        """
        Computes the archetypes and the RSS from the data X, which are stored
        in the corresponding attributes.

        If X is an AnnData object, the matrix in ``X.obsm["X_pca_reduced"]`` is
        used and the AnnData object is remembered for ``save_to_anndata()``.
        Calling ``fit`` again re-fits from scratch.

        :param X: data matrix, with shape (n_samples, n_features), or an AnnData object
        :return: self
        :raises ValueError: if an AnnData object lacks ``obsm["X_pca_reduced"]``
        """
        adata = None
        if isinstance(X, sc.AnnData):
            if "X_pca_reduced" not in X.obsm:
                raise ValueError(
                    "X_pca_reduced not in AnnData object. Please use reduce_pca() to add it to the AnnData object."
                )
            adata = X
            X = X.obsm["X_pca_reduced"]

        # ensure C-contiguous format for numba
        X = np.ascontiguousarray(X)
        self.n_samples, self.n_features = X.shape

        initialize_B, compute_A, compute_B, compute_weights = self._select_functions()
        weighted = compute_weights is not None

        # initialize B and the archetypes Z, then randomly initialize A
        B = initialize_B(X=X, n_archetypes=self.n_archetypes)
        Z = B @ X
        A = self._random_A(self.n_samples, self.n_archetypes)

        TSS = np.sum(X * X)
        prev_RSS = None
        RSS = None
        # local list so that repeated fit() calls start from an empty trace
        RSS_trace: list[float] = []

        W = np.ones(self.n_samples) if weighted else None

        for _ in range(self.max_iter):
            # Scale row i of X by W[i]; same result as np.diag(W) @ X but without
            # materialising an (n_samples, n_samples) matrix.
            X_w = W[:, np.newaxis] * X if weighted else X
            A = compute_A(X_w, Z, A, self.deriv_max_iter)
            B = compute_B(X_w, A, B, self.deriv_max_iter)
            Z = B @ X_w

            # compute residuals using the original data
            A_0 = compute_A(X, Z, A, self.deriv_max_iter) if weighted else A
            R = X - A_0 @ Z
            if weighted:
                W = compute_weights(R)

            RSS = np.linalg.norm(R) ** 2
            # record before the convergence check so the final RSS is kept
            RSS_trace.append(float(RSS))
            if prev_RSS is not None:
                # guard against a perfect fit (prev_RSS == 0): treat as converged
                rel_change = abs(prev_RSS - RSS) / prev_RSS if prev_RSS > 0 else 0.0
                if rel_change < self.tol:
                    break
            prev_RSS = RSS

        if weighted:
            # Recalculate A and B using the unweighted data
            A = compute_A(X, Z, A, self.deriv_max_iter)
            B = compute_B(X, A, B, self.deriv_max_iter)
            Z = B @ X
            RSS = np.linalg.norm(X - A @ Z) ** 2
        elif RSS is None:
            # max_iter < 1: the loop never ran; report the RSS of the initialization
            RSS = np.linalg.norm(X - A @ Z) ** 2

        self.Z = Z
        self.A = A
        self.B = B
        self.RSS = RSS
        self.RSS_trace = np.array(RSS_trace)
        self.varexpl = (TSS - RSS) / TSS
        # only remember the AnnData object once fitting succeeded, and forget a
        # previous one when re-fitting on a plain array
        self.adata = adata
        return self

    def archetypes(self) -> np.ndarray:
        """
        Returns the archetypes' matrix
        :return: archetypes matrix, with shape (n_archetypes, n_features); None before fit()
        """
        return self.Z

    def transform(self, X: np.ndarray) -> np.ndarray:
        """
        Computes the best convex approximation A of X by the archetypes Z.

        For the iterative solvers ("projected_gradients", "frank_wolfe") A is
        started from a random row-stochastic matrix and refined using the
        configured ``derivative_max_iter``, as in ``fit``.

        :param X: data matrix, with shape (n_samples, n_features); n_samples may
            differ from the data used in fit()
        :return: A matrix, with shape (n_samples, n_archetypes)
        :raises ValueError: if the model has not been fitted yet
        """
        if self.Z is None:
            raise ValueError("The model has not been fitted yet. Please call fit() first.")

        # ensure C-contiguous format for numba, as in fit()
        X = np.ascontiguousarray(X)

        if self.optim == "regularized_nnls":
            return _compute_A_regularized_nnls(X, self.Z)
        if self.optim == "projected_gradients":
            # size the initial guess from X itself, not from the training data
            A_random = self._random_A(X.shape[0], self.n_archetypes)
            return _compute_A_projected_gradients(X, self.Z, A_random, self.deriv_max_iter)
        if self.optim == "frank_wolfe":
            A_random = self._random_A(X.shape[0], self.n_archetypes)
            return _compute_A_frank_wolfe(X, self.Z, A_random, self.deriv_max_iter)
        raise NotImplementedError(f"optim={self.optim!r} is not implemented")

    def return_all(self) -> tuple:
        """Return optimized matrices: A, B, Z, and fitting stats: RSS, varexpl."""
        return self.A, self.B, self.Z, self.RSS, self.varexpl

    def save_to_anndata(self):
        """
        Saves the results (A, B, Z, RSS, varexpl) to the AnnData object provided in fit().

        The results are stored as a dict under ``adata.uns["archetypal_analysis"]``.

        :raises ValueError: if the last fit() was not given an AnnData object
        """
        if self.adata is None:
            raise ValueError("No AnnData object found. Please provide an AnnData object to fit().")

        self.adata.uns["archetypal_analysis"] = {
            "A": self.A,
            "B": self.B,
            "Z": self.Z,
            "RSS": self.RSS,
            "varexpl": self.varexpl,
        }