Coverage for partipy/paretoti.py: 86%

143 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-09 10:39 +0200

1import inspect 

2 

3import numpy as np 

4import pandas as pd 

5import scanpy as sc 

6from joblib import Parallel, delayed 

7from scipy.optimize import linear_sum_assignment 

8from scipy.spatial import ConvexHull 

9from scipy.spatial.distance import cdist 

10from tqdm import tqdm 

11 

12from .arch import AA 

13from .const import DEFAULT_INIT, DEFAULT_OPTIM 

14from .selection import compute_IC 

15 

16 

def set_obsm(adata: sc.AnnData, obsm_key: str, n_dimension: int) -> None:
    """
    Select the `obsm` matrix and the number of dimensions used as input for archetypal analysis (AA).

    The function checks that `obsm_key` exists in `adata.obsm` and that `n_dimension` is a usable
    column count for that matrix, then stores the selection in `adata.uns["aa_config"]`.
    An existing configuration is overwritten (a warning is printed).

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing single-cell data. The specified `obsm_key` should refer to
        a matrix in `adata.obsm` to be used as input for AA.

    obsm_key : str
        Key in `adata.obsm` pointing to the matrix to be used for AA.

    n_dimension : int
        Number of leading columns to retain from `adata.obsm[obsm_key]`. Must be at least 1 and
        at most the number of columns in that matrix.

    Returns
    -------
    None
        The AA configuration is stored in `adata.uns["aa_config"]`.

    Raises
    ------
    ValueError
        If `obsm_key` is not in `adata.obsm`, if `n_dimension` is smaller than 1, or if
        `n_dimension` exceeds the number of available columns.
    """
    if obsm_key not in adata.obsm:
        raise ValueError(f"'{obsm_key}' not found in adata.obsm. Available keys are: {list(adata.obsm.keys())}")

    # Consumers slice with `[:, :n_dimension]`; zero would select nothing and a negative
    # value would silently drop trailing columns instead of failing.
    if n_dimension < 1:
        raise ValueError(f"`n_dimension` must be at least 1, but got {n_dimension}.")

    available_dim = adata.obsm[obsm_key].shape[1]
    if n_dimension > available_dim:
        raise ValueError(
            f"Requested {n_dimension} dimensions from '{obsm_key}', but only {available_dim} are available."
        )

    if "aa_config" in adata.uns:
        print("Warning: 'aa_config' already exists in adata.uns and will be overwritten.")

    adata.uns["aa_config"] = {
        "obsm_key": obsm_key,
        "n_dimension": n_dimension,
    }

59 

60 

61def _validate_aa_config(adata: sc.AnnData) -> None: 

62 """ 

63 Validates that the AnnData object is properly configured for archetypal analysis (AA). 

64 

65 This function checks that: 

66 - `adata.uns["aa_config"]` exists, 

67 - it contains the keys "obsm_key" and "n_dimension", 

68 - the specified `obsm_key` exists in `adata.obsm`, 

69 - and that the requested number of dimensions does not exceed the available dimensions. 

70 

71 Parameters 

72 ---------- 

73 adata : sc.AnnData 

74 AnnData object expected to contain AA configuration in `adata.uns["aa_config"]`. 

75 

76 Returns 

77 ------- 

78 None 

79 

80 Raises 

81 ------ 

82 ValueError 

83 If the configuration is missing, incomplete, or inconsistent with the contents of `adata.obsm`. 

84 """ 

85 if "aa_config" not in adata.uns: 

86 raise ValueError("AA configuration not found in `adata.uns['aa_config']`.") 

87 

88 config = adata.uns["aa_config"] 

89 

90 if not isinstance(config, dict): 

91 raise ValueError("`adata.uns['aa_config']` must be a dictionary.") 

92 

93 required_keys = {"obsm_key", "n_dimension"} 

94 missing = required_keys - config.keys() 

95 if missing: 

96 raise ValueError(f"Missing keys in `aa_config`: {missing}") 

97 

98 obsm_key = config["obsm_key"] 

99 n_dimension = config["n_dimension"] 

100 

101 if obsm_key not in adata.obsm: 

102 raise ValueError(f"'{obsm_key}' not found in `adata.obsm`. Available keys: {list(adata.obsm.keys())}") 

103 

104 available_dim = adata.obsm[obsm_key].shape[1] 

105 if n_dimension > available_dim: 

106 raise ValueError( 

107 f"Configured number of dimensions ({n_dimension}) exceeds available dimensions ({available_dim}) in `adata.obsm['{obsm_key}']`." 

108 ) 

109 

110 

111def _validate_aa_results(adata: sc.AnnData) -> None: 

112 """ 

113 Validates that the result from Archetypal Analysis is present in the AnnData object. 

114 

115 Parameters 

116 ---------- 

117 adata : sc.AnnData 

118 Annotated data matrix. 

119 

120 Raises 

121 ------ 

122 ValueError 

123 If the archetypal analysis result is not found in `adata.uns["AA_results"]`. 

124 """ 

125 if "AA_results" not in adata.uns: 

126 raise ValueError( 

127 "Result from Archetypal Analysis not found in `adata.uns['AA_results']`. " 

128 "Please run the AA() function first." 

129 ) 

130 

131 

def var_explained_aa(
    adata: sc.AnnData,
    min_a: int = 2,
    max_a: int = 10,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    n_jobs: int = -1,
    **kwargs,
) -> None:
    """
    Compute the variance explained by Archetypal Analysis (AA) for a range of archetype counts.

    AA is fitted once for every `k` in `min_a..max_a` (inclusive) on the matrix selected by
    `adata.uns["aa_config"]` (the first `n_dimension` columns of `adata.obsm[obsm_key]`).
    The per-`k` diagnostics are stored in `adata.uns["AA_var"]`.

    Parameters
    ----------
    adata: sc.AnnData
        AnnData object with a valid `adata.uns["aa_config"]` and the matrix it refers to.
    min_a : int, optional (default=2)
        Minimum number of archetypes to test. Must be at least 2.
    max_a : int, optional (default=10)
        Maximum number of archetypes to test. Must be greater than or equal to `min_a`.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    n_jobs : int, optional (default=-1)
        Number of jobs passed to `joblib.Parallel`. With `n_jobs=1` the fits run in a plain
        loop without joblib.
    **kwargs:
        Additional keyword arguments passed to `AA` class.

    Returns
    -------
    None
        The results are stored in `adata.uns["AA_var"]` as a DataFrame with the following columns:
        - `k`: The number of archetypes.
        - `IC`: Value returned by `compute_IC` for the reconstruction `A @ Z` of the model with `k` archetypes.
        - `varexpl`: Variance explained by the AA model with `k` archetypes.
        - `varexpl_ontop`: Incremental variance explained compared to `k-1` archetypes
          (for the first row, the variance explained itself).
        - `dist_to_projected`: Distance from each point of the (`k`, `varexpl`) curve to its projection
          on the line connecting the first and last points of the curve, used to identify "elbow points".
          It is 0 when only a single `k` is evaluated.

    Raises
    ------
    ValueError
        If the AA configuration is invalid, `min_a < 2`, or `max_a < min_a`.
    """
    # input validation
    _validate_aa_config(adata=adata)
    if min_a < 2:
        raise ValueError("`min_a` must be at least 2.")
    if max_a < min_a:
        raise ValueError("`max_a` must be greater than or equal to `min_a`.")

    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]

    k_arr = np.arange(min_a, max_a + 1)

    # One independent AA fit per archetype count
    def _compute_archetypes(k):
        A, B, Z, RSS, varexpl = AA(n_archetypes=k, optim=optim, init=init, **kwargs).fit(X).return_all()
        return k, {"Z": Z, "A": A, "B": B, "RSS": RSS, "varexpl": varexpl}

    if n_jobs == 1:
        results_list = [_compute_archetypes(k) for k in k_arr]
    else:
        results_list = Parallel(n_jobs=n_jobs)(delayed(_compute_archetypes)(k) for k in k_arr)

    results = dict(results_list)

    IC_values = [
        compute_IC(X=X, X_tilde=results[k]["A"] @ results[k]["Z"], n_archetypes=k) for k in k_arr
    ]
    varexpl_values = np.array([results[k]["varexpl"] for k in k_arr])

    result_df = pd.DataFrame(
        {
            "k": k_arr,
            "IC": IC_values,
            "varexpl": varexpl_values,
            "varexpl_ontop": np.insert(np.diff(varexpl_values), 0, varexpl_values[0]),
        }
    )

    # Distance of every curve point to the chord from the first to the last point.
    curve = result_df[["k", "varexpl"]].to_numpy(dtype=float)
    centered = curve - curve[0]
    chord = centered[-1]
    chord_sq_norm = float(chord @ chord)
    if chord_sq_norm > 0:
        projected = np.outer(centered @ chord, chord) / chord_sq_norm
        result_df["dist_to_projected"] = np.linalg.norm(centered - projected, axis=1)
    else:
        # Only one k was evaluated (min_a == max_a): the chord has zero length, so there is
        # no line to project onto. Inverting the 1x1 zero Gram matrix used to raise LinAlgError.
        result_df["dist_to_projected"] = 0.0

    adata.uns["AA_var"] = result_df

228 

229 

def bootstrap_aa(
    adata: sc.AnnData,
    n_bootstrap: int,
    n_archetypes: int,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    seed: int = 42,
    save_to_anndata: bool = True,
    n_jobs: int = -1,
    **kwargs,
) -> None | pd.DataFrame:
    """
    Perform bootstrap sampling to compute archetypes and assess their stability.

    A reference AA model is fitted on the full data matrix selected by `adata.uns["aa_config"]`.
    Then `n_bootstrap` samples (rows drawn with replacement, same size as the data) are fitted,
    and the archetypes of every bootstrap fit are reordered to match the reference archetypes.

    Parameters
    ----------
    adata : sc.AnnData
        The AnnData object containing the data to fit the archetypes. The data should be available in
        `adata.obsm[obsm_key]`.
    n_bootstrap : int
        The number of bootstrap samples to generate. Must be at least 1.
    n_archetypes : int
        The number of archetypes to compute for each bootstrap sample.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    seed : int, optional (default=42)
        Seed of the random generator that draws the bootstrap row indices.
    save_to_anndata : bool, optional (default=True)
        If True, store the result in `adata.uns["AA_bootstrap"]` and return None.
        If False, return the result DataFrame and leave `adata.uns` untouched.
    n_jobs : int, optional (default=-1)
        The number of jobs passed to `joblib.Parallel` for the bootstrap fits.
    **kwargs:
        Additional keyword arguments passed to `AA` class.

    Returns
    -------
    None or pd.DataFrame
        A DataFrame (stored or returned, see `save_to_anndata`) with the following columns:
        - `x{i}`: The i-th coordinate of the archetype (one column per feature of the input matrix).
        - `archetype`: The archetype index (categorical).
        - `iter`: The bootstrap iteration index (0 for the reference archetypes, 1..n_bootstrap otherwise).
        - `reference`: A boolean indicating whether the archetype is from the reference model.
        - `mean_variance`: The mean of `variance_per_archetype` over all archetypes.
        - `variance_per_archetype`: For each archetype, the variance of its coordinates across the
          bootstrap fits (reference excluded), averaged over the coordinates.

    Raises
    ------
    ValueError
        If the AA configuration is invalid or `n_bootstrap` is smaller than 1.
    """
    # input validation
    _validate_aa_config(adata=adata)
    if n_bootstrap < 1:
        # Without this check np.stack below fails with an unrelated-looking error message.
        raise ValueError("`n_bootstrap` must be at least 1.")

    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]

    n_samples, n_features = X.shape
    rng = np.random.default_rng(seed)
    coord_columns = [f"x{i}" for i in range(n_features)]

    # Reference archetypes
    ref_Z = AA(n_archetypes=n_archetypes, optim=optim, init=init, **kwargs).fit(X).Z

    # One row of indices per bootstrap sample
    idx_bootstrap = rng.choice(n_samples, size=(n_bootstrap, n_samples), replace=True)

    def compute_bootstrap_z(idx):
        return AA(n_archetypes=n_archetypes, optim=optim, init=init, **kwargs).fit(X[idx, :]).Z

    # Parallel computation of AA on bootstrap samples
    Z_list = Parallel(n_jobs=n_jobs)(delayed(compute_bootstrap_z)(idx) for idx in idx_bootstrap)

    # Archetype order is arbitrary per fit; reorder each result to match the reference
    Z_list = [_align_archetypes(ref_arch=ref_Z.copy(), query_arch=query_Z.copy()) for query_Z in Z_list]

    # Variance across bootstrap fits (axis 0), averaged over coordinates -> one value per archetype
    Z_stack = np.stack(Z_list)
    var_per_archetype = Z_stack.var(axis=0).mean(axis=1)
    mean_variance = var_per_archetype.mean()

    # Create result dataframe
    bootstrap_data = [
        pd.DataFrame(Z, columns=coord_columns).assign(archetype=np.arange(n_archetypes), iter=i + 1)
        for i, Z in enumerate(Z_list)
    ]
    bootstrap_df = pd.concat(bootstrap_data)

    df = pd.DataFrame(ref_Z, columns=coord_columns)
    df["archetype"] = np.arange(n_archetypes)
    df["iter"] = 0

    bootstrap_df = pd.concat((bootstrap_df, df), axis=0)
    bootstrap_df["reference"] = bootstrap_df["iter"] == 0
    bootstrap_df["archetype"] = pd.Categorical(bootstrap_df["archetype"])

    bootstrap_df["mean_variance"] = mean_variance

    archetype_variance_map = dict(zip(np.arange(n_archetypes), var_per_archetype, strict=False))
    bootstrap_df["variance_per_archetype"] = bootstrap_df["archetype"].astype(int).map(archetype_variance_map)

    if save_to_anndata:
        adata.uns["AA_bootstrap"] = bootstrap_df
        return None
    else:
        return bootstrap_df

337 

338 

def bootstrap_aa_multiple_k(
    adata: sc.AnnData,
    n_bootstrap: int = 30,
    n_archetypes_list=None,
    save_to_anndata: bool = True,
    n_jobs: int = -1,
    **kwargs,
):
    """
    Run the archetype bootstrap for several archetype counts and collect the stability metrics.

    `bootstrap_aa` is called once per entry of `n_archetypes_list`; the per-archetype variances
    of all runs are combined into a single table so that stability can be compared across
    model complexities.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing the data; it is passed unchanged to `bootstrap_aa`.
    n_bootstrap : int, optional (default=30)
        The number of bootstrap samples to generate for each number of archetypes.
    n_archetypes_list : list of int, optional (default=range(2, 8))
        A list specifying the numbers of archetypes to evaluate.
    save_to_anndata : bool, optional (default=True)
        Whether to save the results to `adata.uns["AA_boostrap_multiple_k"]`. If `False`, the
        result is returned.
    n_jobs : int, optional (default=-1)
        Number of parallel jobs, forwarded to `bootstrap_aa`.
    **kwargs:
        Additional keyword arguments forwarded to `bootstrap_aa`.

    Returns
    -------
    None or pd.DataFrame
        If `save_to_anndata=True`, results are stored in `adata.uns["AA_boostrap_multiple_k"]` and
        None is returned; otherwise the DataFrame is returned. Its columns are:
        - `archetype`: The archetype index.
        - `variance_per_archetype`: The mean variance of each archetype's coordinates across bootstrap samples.
        - `n_archetypes`: The number of archetypes used for the corresponding bootstrap analysis.
    """
    archetype_counts = list(range(2, 8)) if n_archetypes_list is None else n_archetypes_list

    per_k_frames = []
    for n_arch in archetype_counts:
        frame = bootstrap_aa(
            adata=adata,
            n_bootstrap=n_bootstrap,
            n_archetypes=n_arch,
            save_to_anndata=False,
            n_jobs=n_jobs,
            **kwargs,
        )
        frame["n_archetypes"] = n_arch  # type: ignore[index]
        per_k_frames.append(frame)

    # Every bootstrap row repeats the per-archetype variance; keep one row per (k, archetype).
    combined = pd.concat(per_k_frames, axis=0)
    summary = combined[["archetype", "variance_per_archetype", "n_archetypes"]].drop_duplicates()

    if not save_to_anndata:
        return summary

    adata.uns["AA_boostrap_multiple_k"] = summary
    return None

395 

396 

397def _project_on_affine_subspace(X, Z) -> np.ndarray: # pragma: no cover 

398 """ 

399 Projects a set of points X onto the affine subspace spanned by the vertices Z. 

400 

401 Parameters 

402 ---------- 

403 X : numpy.ndarray 

404 A (D x n) array of n points in D-dimensional space to be projected. 

405 Z : numpy.ndarray 

406 A (D x k) array of k vertices (archetypes) defining the affine subspace in D-dimensional space. 

407 

408 Returns 

409 ------- 

410 proj_coord : numpy.ndarray 

411 The coordinates of the projected points in the subspace defined by Z. 

412 """ 

413 D, k = Z.shape 

414 

415 # Compute the projection vectors (basis for the affine subspace) 

416 if k == 2: 

417 # For a line (k=2), the projection vector is simply the difference between the two vertices 

418 proj_vec = (Z[:, 1] - Z[:, 0])[:, None] 

419 else: 

420 # For higher dimensions, compute the projection vectors relative to the first vertex 

421 proj_vec = Z[:, 1:] - Z[:, 0][:, None] 

422 

423 # Compute the coordinates of the projected points in the subspace 

424 proj_coord = np.linalg.inv(proj_vec.T @ proj_vec) @ proj_vec.T @ (X - Z[:, 0][:, None]) 

425 

426 return proj_coord 

427 

428 

429def _compute_t_ratio(X: np.ndarray, Z: np.ndarray) -> float: # pragma: no cover 

430 """ 

431 Compute the t-ratio: volume(polytope defined by Z) / volume(convex hull of X) 

432 

433 Parameters 

434 ---------- 

435 X : np.ndarray, shape (n_samples, n_features) 

436 Data matrix. 

437 Z : np.ndarray, shape (n_archetypes, n_features) 

438 Archetypes matrix. 

439 

440 Returns 

441 ------- 

442 float 

443 The t-ratio. 

444 """ 

445 D, k = X.shape[1], Z.shape[0] 

446 

447 if k < 2: 

448 raise ValueError("At least 2 archetypes are required (k >= 2).") 

449 

450 if k < D + 1: 

451 proj_X = _project_on_affine_subspace(X.T, Z.T).T 

452 proj_Z = _project_on_affine_subspace(Z.T, Z.T).T 

453 convhull_volume = ConvexHull(proj_X).volume 

454 polytope_volume = ConvexHull(proj_Z).volume 

455 else: 

456 convhull_volume = ConvexHull(X).volume 

457 polytope_volume = ConvexHull(Z).volume 

458 

459 return polytope_volume / convhull_volume 

460 

461 

def compute_t_ratio(adata) -> float | None:  # pragma: no cover
    """
    Compute the t-ratio for the archetypes stored on an AnnData object.

    The data matrix is taken from `adata.obsm` according to `adata.uns["aa_config"]`
    (first `n_dimension` columns of `obsm_key`) and the archetypes from
    `adata.uns["AA_results"]["Z"]`.

    Parameters
    ----------
    adata : sc.AnnData
        Must contain a valid `uns["aa_config"]`, the matrix it refers to, and `uns["AA_results"]["Z"]`.

    Returns
    -------
    None
        The t-ratio is stored in `adata.uns["t_ratio"]`; nothing is returned.

    Raises
    ------
    ValueError
        If the AA configuration is invalid or no archetypes are stored in `adata.uns["AA_results"]["Z"]`.
    """
    # input validation
    _validate_aa_config(adata=adata)
    has_archetypes = "AA_results" in adata.uns and "Z" in adata.uns["AA_results"]
    if not has_archetypes:
        raise ValueError("Missing archetypes in `adata.uns['AA_results']['Z']`.")

    config = adata.uns["aa_config"]
    data = adata.obsm[config["obsm_key"]][:, : config["n_dimension"]]
    archetypes = adata.uns["AA_results"]["Z"]

    adata.uns["t_ratio"] = _compute_t_ratio(data, archetypes)
    return None

489 

490 

def t_ratio_significance(adata, iter=1000, seed=42, n_jobs=-1):  # pragma: no cover
    """
    Assesses the significance of the polytope spanned by the archetypes by comparing the t-ratio of the original data to t-ratios computed from randomized datasets.

    Each randomized dataset is obtained by permuting every column of the data matrix
    independently; AA is refitted on it and its t-ratio is computed.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object with a valid `adata.uns["aa_config"]`, the matrix it refers to, and
        `adata.uns["AA_results"]["Z"]`. If `adata.uns["t_ratio"]` does not exist yet, it is
        computed with `compute_t_ratio` first.
    iter : int, optional (default=1000)
        Number of randomized datasets to generate. Must be at least 1.
    seed : int, optional (default=42)
        The random seed for reproducibility. One independent random stream per randomized
        dataset is derived from it, so the result does not depend on `n_jobs`.
    n_jobs : int, optional (default=-1)
        Number of jobs passed to `joblib.Parallel`.

    Returns
    -------
    float
        The proportion of randomized datasets with a t-ratio greater than the original t-ratio (p-value).

    Raises
    ------
    ValueError
        If the AA configuration is invalid or `iter` is smaller than 1.
    """
    # input validation
    _validate_aa_config(adata=adata)
    if iter < 1:
        raise ValueError("`iter` must be at least 1.")

    if "t_ratio" not in adata.uns:
        print("Computing t-ratio...")
        compute_t_ratio(adata)

    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]

    t_ratio = adata.uns["t_ratio"]
    n_features = X.shape[1]
    n_archetypes = adata.uns["AA_results"]["Z"].shape[0]

    # A single Generator captured by the worker function is pickled with identical state for
    # process-based workers, so different workers would draw the same permutations. Spawning
    # one child seed per iteration gives every randomized dataset its own independent stream.
    child_seeds = np.random.SeedSequence(seed).spawn(iter)

    def compute_randomized_t_ratio(child_seed):
        rng = np.random.default_rng(child_seed)
        # Shuffle each feature independently
        SimplexRand1 = np.array([rng.permutation(X[:, i]) for i in range(n_features)]).T
        # Compute archetypes and t-ratio for randomized data
        Z_mix = AA(n_archetypes=n_archetypes).fit(SimplexRand1).Z
        return _compute_t_ratio(SimplexRand1, Z_mix)

    # Parallelize the computation of randomized t-ratios
    RandRatio = Parallel(n_jobs=n_jobs)(
        delayed(compute_randomized_t_ratio)(child_seed) for child_seed in tqdm(child_seeds, desc="Randomizing")
    )

    # Calculate the p-value
    p_value = np.sum(np.array(RandRatio) > t_ratio) / iter
    return p_value

543 

544 

def t_ratio_significance_shuffled(adata, iter=1000, seed=42, n_jobs=-1):  # pragma: no cover
    """
    Assesses the significance of the polytope spanned by the archetypes by comparing the t-ratio of the original data to t-ratios computed from randomized datasets.

    Each randomized dataset is obtained by permuting every column of the data matrix
    independently and then running `sc.pp.pca` on the permuted matrix; AA is refitted on the
    resulting components and the t-ratio is computed on them.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object with a valid `adata.uns["aa_config"]`, the matrix it refers to, and
        `adata.uns["AA_results"]["Z"]`. If `adata.uns["t_ratio"]` does not exist yet, it is
        computed with `compute_t_ratio` first.
    iter : int, optional (default=1000)
        Number of randomized datasets to generate. Must be at least 1.
    seed : int, optional (default=42)
        The random seed for reproducibility. One independent random stream per randomized
        dataset is derived from it, so the permutations do not depend on `n_jobs`.
    n_jobs : int, optional (default=-1)
        Number of jobs passed to `joblib.Parallel`.

    Returns
    -------
    float
        The proportion of randomized datasets with a t-ratio greater than the original t-ratio (p-value).

    Raises
    ------
    ValueError
        If the AA configuration is invalid or `iter` is smaller than 1.
    """
    # input validation
    _validate_aa_config(adata=adata)
    if iter < 1:
        raise ValueError("`iter` must be at least 1.")

    if "t_ratio" not in adata.uns:
        print("Computing t-ratio...")
        compute_t_ratio(adata)

    obsm_key = adata.uns["aa_config"]["obsm_key"]
    n_dimensions = adata.uns["aa_config"]["n_dimension"]
    X = adata.obsm[obsm_key][:, :n_dimensions]

    t_ratio = adata.uns["t_ratio"]
    n_features = X.shape[1]
    n_archetypes = adata.uns["AA_results"]["Z"].shape[0]

    # A single Generator captured by the worker function is pickled with identical state for
    # process-based workers, so different workers would draw the same permutations. Spawning
    # one child seed per iteration gives every randomized dataset its own independent stream.
    child_seeds = np.random.SeedSequence(seed).spawn(iter)

    def compute_randomized_t_ratio(child_seed):
        rng = np.random.default_rng(child_seed)
        # Shuffle each feature independently
        SimplexRand1 = np.array([rng.permutation(X[:, i]) for i in range(n_features)]).T
        SimplexRand1_pca = sc.pp.pca(SimplexRand1, n_comps=n_dimensions)
        # Compute archetypes and t-ratio for randomized data
        Z_mix = AA(n_archetypes=n_archetypes).fit(SimplexRand1_pca).Z
        return _compute_t_ratio(SimplexRand1_pca, Z_mix)

    # Parallelize the computation of randomized t-ratios
    RandRatio = Parallel(n_jobs=n_jobs)(
        delayed(compute_randomized_t_ratio)(child_seed) for child_seed in tqdm(child_seeds, desc="Randomizing")
    )

    # Calculate the p-value
    p_value = np.sum(np.array(RandRatio) > t_ratio) / iter
    return p_value

598 

599 

600def _align_archetypes(ref_arch: np.ndarray, query_arch: np.ndarray) -> np.ndarray: 

601 """ 

602 Align the archetypes of the query to match the order of archetypes in the reference. 

603 

604 This function uses the Euclidean distance between archetypes in the reference and query sets 

605 to determine the optimal alignment. The Hungarian algorithm (linear sum assignment) is used 

606 to find the best matching pairs, and the query archetypes are reordered accordingly. 

607 

608 Parameters 

609 ---------- 

610 ref_arch : np.ndarray 

611 A 2D array of shape (n_archetypes, n_features) representing the reference archetypes. 

612 query_arch : np.ndarray 

613 A 2D array of shape (n_archetypes, n_features) representing the query archetypes. 

614 

615 Returns 

616 ------- 

617 np.ndarray 

618 A 2D array of shape (n_archetypes, n_features) containing the reordered query archetypes. 

619 """ 

620 # Compute pairwise Euclidean distances 

621 euclidean_d = cdist(ref_arch, query_arch.copy(), metric="euclidean") 

622 

623 # Find the optimal assignment using the Hungarian algorithm 

624 ref_idx, query_idx = linear_sum_assignment(euclidean_d) 

625 

626 return query_arch[query_idx, :] 

627 

628 

def compute_archetypes(
    adata: sc.AnnData,
    n_archetypes: int,
    init: str | None = None,
    optim: str | None = None,
    weight: None | str = None,
    max_iter: int | None = None,
    rel_tol: float | None = None,
    verbose: bool | None = None,
    seed: int = 42,
    save_to_anndata: bool = True,
    archetypes_only: bool = True,
    **optim_kwargs,
) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray, list[float] | np.ndarray, float] | None:
    """
    Perform Archetypal Analysis (AA) on the input data.

    Convenience wrapper around the `AA` class: it builds the model, fits it on the matrix
    selected by `adata.uns["aa_config"]` (cast to float32), and either stores the results on the
    AnnData object or returns them. Every argument left at `None` falls back to the default
    declared in the signature of `AA.__init__`.

    Parameters
    ----------
    adata : sc.AnnData
        The AnnData object containing the data to fit the archetypes. The data should be available in
        `adata.obsm[obsm_key]`.
    n_archetypes : int
        The number of archetypes to compute.
    init : str, optional
        The initialization method for the archetypes. If not provided, the default from the AA class is used.
        Options include:
        - "uniform": Uniform initialization.
        - "furthest_sum": Furthest sum initialization (default).
        - "plus_plus": Archetype++ initialization.
    optim : str, optional
        The optimization method for fitting the model. If not provided, the default from the AA class is used.
        Options include:
        - "projected_gradients": Projected gradients optimization.
        - "frank_wolfe": Frank-Wolfe optimization.
        - "regularized_nnls": Regularized non-negative least squares optimization.
    weight : str, optional
        The weighting method for the data. If not provided, the default from the AA class is used.
        Options include:
        - None : default
        - "bisquare": Bisquare weighting.
        - "huber": Huber weighting.
    max_iter : int, optional
        The maximum number of iterations for the optimization. If not provided, the default from the AA class is used.
    rel_tol : float, optional
        The relative tolerance for convergence. If not provided, the default from the AA class is used.
    verbose : bool, optional
        Whether to print verbose output during fitting. If not provided, the default from the AA class is used.
    seed : int, optional (default=42)
        The random seed for reproducibility; passed to the `AA` constructor.
    save_to_anndata : bool, optional (default=True)
        Whether to save the results to `adata.uns["AA_results"]`. If False, the results are returned instead.
    archetypes_only : bool, optional (default=True)
        Whether to save/return only the archetypes matrix `Z` (if set to True) or the full outputs, including
        the matrices `A`, `B`, the `RSS` trace, and the variance explained `varexpl`.
    **optim_kwargs:
        Additional keyword arguments forwarded unchanged to the `AA` constructor.

    Returns
    -------
    None, np.ndarray or tuple
        The output depends on the values of `save_to_anndata` and `archetypes_only`:
        - If `save_to_anndata` is True:
            - Returns `None`. Results are saved to `adata.uns["AA_results"]` as a dictionary with the key
              "Z" (`archetypes_only=True`) or the keys "A", "B", "Z", "RSS", "varexpl" (`archetypes_only=False`).
        - If `save_to_anndata` is False:
            - With `archetypes_only=True`, only the archetype matrix Z is returned.
            - With `archetypes_only=False`, a tuple `(A, B, Z, RSS, varexpl)` is returned:
                - A: The matrix of weights for the data points (n_samples, n_archetypes).
                - B: The matrix of weights for the archetypes (n_archetypes, n_samples).
                - Z: The archetypes matrix (n_archetypes, n_features).
                - RSS: The model's `RSS_trace` attribute (residual sum of squares).
                - varexpl: The variance explained by the model.
    """
    # input validation
    _validate_aa_config(adata=adata)

    # Defaults declared by AA.__init__ for everything except `self` and `n_archetypes`
    init_params = inspect.signature(AA.__init__).parameters
    aa_defaults = {
        name: init_params[name].default for name in init_params if name not in ("self", "n_archetypes")
    }

    def _or_default(value, name):
        # Only consult the AA default when the caller did not provide a value
        return aa_defaults[name] if value is None else value

    model = AA(
        n_archetypes=n_archetypes,
        init=_or_default(init, "init"),
        optim=_or_default(optim, "optim"),
        weight=_or_default(weight, "weight"),
        max_iter=_or_default(max_iter, "max_iter"),
        rel_tol=_or_default(rel_tol, "rel_tol"),
        verbose=_or_default(verbose, "verbose"),
        seed=seed,
        **optim_kwargs,
    )

    # Data matrix selected by the AA configuration, fitted as float32
    config = adata.uns["aa_config"]
    data = adata.obsm[config["obsm_key"]][:, : config["n_dimension"]].astype(np.float32)
    model.fit(data)

    if not save_to_anndata:
        if archetypes_only:
            return model.Z
        return model.A, model.B, model.Z, model.RSS_trace, model.varexpl

    if archetypes_only:
        adata.uns["AA_results"] = {
            "Z": model.Z,
        }
    else:
        adata.uns["AA_results"] = {
            "A": model.A,
            "B": model.B,
            "Z": model.Z,
            "RSS": model.RSS_trace,
            "varexpl": model.varexpl,
        }
    return None