Coverage for partipy/paretoti_funcs.py: 12%

226 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-16 12:01 +0100

1import inspect 

2 

3import numpy as np 

4import pandas as pd 

5import plotly.express as px 

6import plotly.graph_objects as go 

7import plotnine as pn 

8import scanpy as sc 

9from joblib import Parallel, delayed 

10from scipy.optimize import linear_sum_assignment 

11from scipy.spatial import ConvexHull 

12from scipy.spatial.distance import cdist 

13from tqdm import tqdm 

14 

15from .arch import AA 

16from .const import DEFAULT_INIT, DEFAULT_OPTIM 

17 

18 

def set_dimension(adata: sc.AnnData, n_pcs: int) -> None:
    """
    Set the number of principal components (PCs) used for archetypal analysis.

    Downstream functions use ``adata.obsm["X_pca"][:, : adata.uns["PCs"]]``.
    If ``adata.obsm["X_pca"]`` does not exist, PCA is first computed with
    ``sc.pp.pca(adata, mask_var="highly_variable")`` and stored there.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing single-cell data. Modified in place.
    n_pcs : int
        The number of principal components (PCs) to retain. Must satisfy
        ``1 <= n_pcs <= adata.obsm["X_pca"].shape[1]``.

    Returns
    -------
    None
        The number of PCs is stored in ``adata.uns["PCs"]``.

    Raises
    ------
    ValueError
        If ``n_pcs`` is smaller than 1 or larger than the number of available PCs.
    """
    # A non-positive value would later produce an empty (n_pcs == 0) or silently
    # truncated (negative slice) PC matrix, so reject it before doing any work.
    if n_pcs < 1:
        raise ValueError(f"`n_pcs` must be at least 1, got {n_pcs}.")

    if "X_pca" not in adata.obsm:
        print("X_pca not found in adata.obsm. Computing PCA on highly variable genes...")
        sc.pp.pca(adata, mask_var="highly_variable")

    n_available = adata.obsm["X_pca"].shape[1]
    if n_pcs > n_available:
        raise ValueError(f"Requested {n_pcs} PCs, but only {n_available} PCs are available.")

    adata.uns["PCs"] = n_pcs

47 

48 

def var_explained_aa(
    adata: sc.AnnData,
    min_a: int = 2,
    max_a: int = 10,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    n_jobs: int = -1,
) -> None:
    """
    Compute the variance explained by Archetypal Analysis (AA) for a range of archetypes.

    AA is fitted for every number of archetypes from `min_a` to `max_a` (inclusive)
    on the first `adata.uns["PCs"]` columns of `adata.obsm["X_pca"]`. The results are
    stored in `adata.uns["AA_var"]`.

    Parameters
    ----------
    adata: sc.AnnData
        AnnData object containing adata.obsm["X_pca"] and adata.uns["PCs"]
        (see `set_dimension`).
    min_a : int, optional (default=2)
        Minimum number of archetypes to test. Must be at least 2.
    max_a : int, optional (default=10)
        Maximum number of archetypes to test. Must be >= `min_a`.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    n_jobs : int, optional (default=-1)
        Number of jobs for parallel computation. `-1` uses all available cores.

    Returns
    -------
    None
        The results are stored in `adata.uns["AA_var"]` as a DataFrame with the following columns:
        - `k`: The number of archetypes.
        - `varexpl`: The variance explained by the model.
        - `varexpl_ontop`: The additional variance explained compared to the model with `k-1` archetypes
          (for the first row, the variance explained itself).
        - `dist_to_projected`: The distance between the (k, varexpl) point and its orthogonal projection
          on the line connecting the first and last (k, varexpl) points. It is 0 for every row when
          only a single k is evaluated (`min_a == max_a`).

    Raises
    ------
    ValueError
        If `min_a < 2`, `max_a < min_a`, or `adata.uns["PCs"]` is missing.
    """
    # Validation input
    if min_a < 2:
        raise ValueError("`min_a` must be at least 2.")
    if max_a < min_a:
        raise ValueError("`max_a` must be greater than or equal to `min_a`.")

    if "PCs" not in adata.uns:
        raise ValueError(
            "PCs not found in adata.uns. Please set the dimension for archetypal analysis with set_dimension()"
        )

    X = adata.obsm["X_pca"][:, : adata.uns["PCs"]]

    k_arr = np.arange(min_a, max_a + 1)

    def compute_varexpl(k):
        # Only the variance explained is used below; returning just that avoids
        # sending the (possibly large) A/B matrices back from the workers.
        *_, varexpl = AA(n_archetypes=k, optim=optim, init=init).fit(X).return_all()
        return varexpl

    # joblib.Parallel returns results in the same order as the inputs.
    varexpl_values = np.array(Parallel(n_jobs=n_jobs)(delayed(compute_varexpl)(k) for k in k_arr))

    plot_df = pd.DataFrame(
        {
            "k": k_arr,
            "varexpl": varexpl_values,
            # First entry is varexpl itself (gain over an "empty" model), then successive gains.
            "varexpl_ontop": np.diff(varexpl_values, prepend=0.0),
        }
    )

    # Distance of each (k, varexpl) point to its orthogonal projection on the chord
    # through the first and last points (elbow heuristic).
    points = plot_df[["k", "varexpl"]].to_numpy(dtype=float)
    origin = points[0]
    direction = points[-1] - origin
    squared_length = float(direction @ direction)
    if squared_length == 0.0:
        # Only one k was evaluated: the chord degenerates to a point, which would make
        # the projection matrix singular. Every point lies on it, so the distance is 0.
        dist_to_projected = np.zeros(len(points))
    else:
        relative = points - origin
        projected = origin + np.outer(relative @ direction / squared_length, direction)
        dist_to_projected = np.linalg.norm(points - projected, axis=1)
    plot_df["dist_to_projected"] = dist_to_projected

    adata.uns["AA_var"] = plot_df

135 

136 

def plot_var_explained_aa(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Generate an elbow plot of the variance explained by Archetypal Analysis (AA) for a range of archetypes.

    This function creates a plot showing the variance explained by AA models with different numbers of archetypes.
    The data is retrieved from `adata.uns["AA_var"]`. If `AA_var` is not found, `var_explained_aa` is called
    with its default arguments.

    The gray reference line connects the first and last (k, varexpl) points, i.e. the same line that
    `dist_to_projected` in `adata.uns["AA_var"]` is measured against. Rows are assumed to be ordered by `k`,
    as written by `var_explained_aa`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing the variance explained data in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        A ggplot object showing the variance explained plot.
    """
    # Validation input
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    plot_df = adata.uns["AA_var"]

    # Reference line through the first and last evaluated k. Using min/max of varexpl
    # instead would diverge from the projection line when varexpl is not monotone.
    diag_data = plot_df[["k", "varexpl"]].iloc[[0, -1]].reset_index(drop=True)

    p = (
        pn.ggplot(plot_df)
        + pn.geom_line(mapping=pn.aes(x="k", y="varexpl"), color="black")
        + pn.geom_point(mapping=pn.aes(x="k", y="varexpl"), color="black")
        + pn.geom_line(data=diag_data, mapping=pn.aes(x="k", y="varexpl"), color="gray")
        + pn.labs(x="Number of Archetypes (k)", y="Variance Explained")
        + pn.lims(y=[0, 1])
        + pn.scale_x_continuous(breaks=list(np.arange(plot_df["k"].min(), plot_df["k"].max() + 1)))
        + pn.theme_matplotlib()
    )
    return p

182 

183 

def plot_projected_dist(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Bar plot of the distance to the projection line for each number of archetypes.

    Reads the `dist_to_projected` column of `adata.uns["AA_var"]`; when that entry
    is missing, `var_explained_aa` is run first with its default arguments.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object holding (or receiving) the variance explained table in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        Column chart with one bar per number of archetypes `k`.
    """
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    var_df = adata.uns["AA_var"]
    k_lo = var_df["k"].min()
    k_hi = var_df["k"].max()
    k_ticks = list(np.arange(k_lo, k_hi + 1))

    figure = pn.ggplot(var_df)
    figure = figure + pn.geom_col(mapping=pn.aes(x="k", y="dist_to_projected"))
    figure = figure + pn.scale_x_continuous(breaks=k_ticks)
    figure = figure + pn.labs(x="Number of Archetypes (k)", y="Distance to Projected Point")
    figure = figure + pn.theme_matplotlib()
    return figure

217 

218 

def plot_var_on_top(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Plot the additional variance explained when going from `k-1` to `k` archetypes.

    The values come from the `varexpl_ontop` column of `adata.uns["AA_var"]`. If that
    entry does not exist yet, `var_explained_aa` is called with its default arguments.
    The y axis starts at 0.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing the variance explained data in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        Point-and-line plot of the variance explained on top of the (k-1) model.
    """
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    var_df = adata.uns["AA_var"]
    ontop_aes = pn.aes(x="k", y="varexpl_ontop")
    k_ticks = list(np.arange(var_df["k"].min(), var_df["k"].max() + 1))

    layers = [
        pn.geom_point(mapping=ontop_aes, color="black"),
        pn.geom_line(mapping=ontop_aes, color="black"),
        pn.labs(x="Number of Archetypes (k)", y="Variance Explained on Top of (k-1) Model"),
        pn.scale_x_continuous(breaks=k_ticks),
        pn.lims(y=(0, None)),
        pn.theme_matplotlib(),
    ]
    figure = pn.ggplot(var_df)
    for layer in layers:
        figure = figure + layer
    return figure

254 

255 

def bootstrap_aa(
    adata: sc.AnnData,
    n_bootstrap: int,
    n_archetypes: int,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    seed: int = 42,
) -> None:
    """
    Perform bootstrap sampling to compute archetypes and assess their stability.

    A reference AA model is fitted on the full data. Then `n_bootstrap` resamples of the cells
    (with replacement) are drawn, an AA model is fitted on each, and its archetypes are reordered
    to match the reference archetypes (see `align_archetypes`). The results are stored in
    `adata.uns["AA_bootstrap"]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object. The PCA data should be stored in `adata.obsm["X_pca"]` and the number of
        PCs in `adata.uns["PCs"]` (see `set_dimension`).
    n_bootstrap : int
        The number of bootstrap samples to generate. Must be at least 1.
    n_archetypes : int
        The number of archetypes to compute for each bootstrap sample.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    seed : int, optional (default=42)
        The random seed used to draw the bootstrap resamples.

    Returns
    -------
    None
        The results are stored in `adata.uns["AA_bootstrap"]` as a DataFrame (with a fresh
        0..N-1 index) with the following columns:
        - `pc_i`: The coordinates of the archetypes in the i-th principal component.
        - `archetype`: The archetype index (categorical).
        - `iter`: The bootstrap iteration index (0 for the reference archetypes).
        - `reference`: A boolean indicating whether the archetype is from the reference model.
        - `mean_variance`: The variance of archetype coordinates across bootstrap samples, averaged
          over PCs and archetypes (same value in every row).

    Raises
    ------
    ValueError
        If `n_bootstrap < 1` or `adata.uns["PCs"]` is missing.
    """
    # Validation input (before the expensive reference fit)
    if n_bootstrap < 1:
        raise ValueError("`n_bootstrap` must be at least 1.")
    if "PCs" not in adata.uns:
        raise ValueError(
            "PCs not found in adata.uns. Please set the dimension for archetypal analysis with set_dimension()"
        )

    X = adata.obsm["X_pca"][:, : adata.uns["PCs"]]

    n_samples, n_features = X.shape
    rng = np.random.default_rng(seed)
    pc_columns = [f"pc_{i}" for i in range(n_features)]

    # Reference archetypes
    ref_Z = AA(n_archetypes=n_archetypes, optim=optim, init=init).fit(X).Z

    # Fit AA on each bootstrap resample and reorder its archetypes to match the reference.
    # align_archetypes only reads its inputs, so no defensive copies are needed.
    idx_bootstrap = rng.choice(n_samples, size=(n_bootstrap, n_samples), replace=True)
    Z_list = [
        align_archetypes(
            ref_arch=ref_Z,
            query_arch=AA(n_archetypes=n_archetypes, optim=optim, init=init).fit(X[idx, :]).Z,
        )
        for idx in idx_bootstrap
    ]

    # Variance across bootstrap samples, averaged over PCs, then over archetypes
    Z_stack = np.stack(Z_list)
    var_per_archetype = Z_stack.var(axis=0).mean(axis=1)
    mean_variance = var_per_archetype.mean()

    # Create result dataframe: bootstrap iterations are 1..n_bootstrap, reference is 0
    archetype_ids = np.arange(n_archetypes)
    bootstrap_frames = [
        pd.DataFrame(Z, columns=pc_columns).assign(archetype=archetype_ids, iter=iteration)
        for iteration, Z in enumerate(Z_list, start=1)
    ]
    ref_frame = pd.DataFrame(ref_Z, columns=pc_columns).assign(archetype=archetype_ids, iter=0)

    # ignore_index avoids repeating the 0..n_archetypes-1 labels for every iteration
    bootstrap_df = pd.concat([*bootstrap_frames, ref_frame], axis=0, ignore_index=True)
    bootstrap_df["reference"] = bootstrap_df["iter"] == 0
    bootstrap_df["archetype"] = pd.Categorical(bootstrap_df["archetype"])

    bootstrap_df["mean_variance"] = mean_variance

    adata.uns["AA_bootstrap"] = bootstrap_df

341 

342 

def plot_bootstrap_aa(adata: sc.AnnData) -> go.Figure:
    """
    Create an interactive 3D scatter plot of archetype positions across bootstrap samples.

    The data is read from `adata.uns["AA_bootstrap"]` (created by `bootstrap_aa`). The first three
    PC columns (`pc_0`, `pc_1`, `pc_2`) are used as axes. Points are colored by archetype, and the
    marker symbol distinguishes reference archetypes from bootstrap ones.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object containing the archetype bootstrap data in `adata.uns["AA_bootstrap"]`.

    Returns
    -------
    go.Figure:
        3D plot of bootstrap results for the archetypes.

    Raises
    ------
    ValueError
        If `adata.uns["AA_bootstrap"]` is missing or has fewer than three PC columns.
    """
    # Validation input
    if "AA_bootstrap" not in adata.uns:
        raise ValueError("AA_bootstrap not found in adata.uns. Please run bootstrap_aa() to compute")

    bootstrap_df = adata.uns["AA_bootstrap"]

    # A 3D plot needs three PCs; fail with a clear message instead of an opaque plotly error.
    missing_cols = [col for col in ("pc_0", "pc_1", "pc_2") if col not in bootstrap_df.columns]
    if missing_cols:
        raise ValueError(
            f"A 3D plot needs at least 3 PCs, but {missing_cols} are missing from adata.uns['AA_bootstrap']."
        )

    # Generate the 3D scatter plot
    fig = px.scatter_3d(
        bootstrap_df,
        x="pc_0",
        y="pc_1",
        z="pc_2",
        color="archetype",
        symbol="reference",
        labels={
            "pc_0": "PC 1",
            "pc_1": "PC 2",
            "pc_2": "PC 3",
        },
        title="Archetypes on bootstrapped data",
        size_max=10,
        hover_data=["iter", "archetype", "reference"],
        opacity=0.5,
    )
    fig.update_layout(template="none")

    return fig

384 

385 

def project_on_affine_subspace(X, Z) -> np.ndarray:
    """
    Project a set of points X onto the affine subspace spanned by the vertices Z.

    The subspace is parameterized relative to the first vertex ``z0 = Z[:, 0]`` with the basis
    ``V = Z[:, 1:] - z0``. For each point ``x`` the returned coordinates ``c`` minimize
    ``||x - z0 - V @ c||`` (least squares). The same formula covers k == 2, where ``V`` is a
    single column (a line).

    Parameters
    ----------
    X : numpy.ndarray
        A (D x n) array of n points in D-dimensional space to be projected.
    Z : numpy.ndarray
        A (D x k) array of k vertices (archetypes) defining the affine subspace in D-dimensional
        space. The vertices must be affinely independent, otherwise the normal equations are
        singular and ``numpy.linalg.LinAlgError`` is raised.

    Returns
    -------
    proj_coord : numpy.ndarray
        A ((k-1) x n) array with the coordinates of the projected points in the basis ``V``.
    """
    origin = Z[:, [0]]  # (D x 1), keeps the 2-D shape for broadcasting against X
    basis = Z[:, 1:] - origin  # (D x (k-1))

    # Solve the normal equations directly rather than forming an explicit inverse,
    # which is cheaper and numerically more stable.
    proj_coord = np.linalg.solve(basis.T @ basis, basis.T @ (X - origin))

    return proj_coord

416 

417 

def _convex_hull_volume(points: np.ndarray) -> float:
    """Return the convex-hull volume of `points` (n_points x dim); for dim == 1 this is the interval length."""
    if points.shape[1] == 1:
        # Qhull (and therefore scipy.spatial.ConvexHull) requires at least 2-D input;
        # the 1-D "hull" of a point set is simply the interval it covers.
        return float(np.ptp(points[:, 0]))
    return ConvexHull(points).volume


def compute_t_ratio(
    X: sc.AnnData | np.ndarray,
    Z: np.ndarray | None = None,
) -> float | None:
    """
    Compute the t-ratio: the volume of the polytope defined by the archetypes (Z) divided by the
    volume of the convex hull of the data points (X).

    When there are fewer archetypes than D + 1 (D = number of features), both X and Z are first
    projected onto the affine subspace spanned by Z and the volumes are computed there. For
    k == 2 this subspace is a line and the "volumes" are interval lengths.

    Parameters
    ----------
    X : Union[sc.AnnData, np.ndarray]
        The input data, which can be either:
        - An AnnData object containing the following attributes:
            - `adata.obsm["X_pca"]`: A 2D array of shape (n_samples, n_features) representing the PCA coordinates of the data.
            - `adata.uns["PCs"]`: The number of principal components used for AA.
            - `adata.uns["archetypal_analysis"]["Z"]`: A 2D array of shape (n_archetypes, n_features) representing the archetypes.
        - A 2D numpy array of shape (n_samples, n_features) representing the data matrix. In this case, `Z` must be provided.
    Z : np.ndarray, optional
        A 2D array of shape (n_archetypes, n_features) representing the archetypes. Required if `X` is a numpy array.

    Returns
    -------
    Optional[float]
        - If `X` is an AnnData object, the t-ratio is stored in `X.uns["t_ratio"]` and nothing is returned.
        - If `X` is a numpy array, the t-ratio is returned as a float.

    Raises
    ------
    ValueError
        If `Z` is missing for array input, AA results are missing for AnnData input, or fewer than
        2 archetypes are given.
    """
    adata = None
    if isinstance(X, np.ndarray):
        if Z is None:
            raise ValueError("Z must be provided when X is a numpy.ndarray.")
    else:
        adata = X
        if "archetypal_analysis" not in adata.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")
        X = adata.obsm["X_pca"][:, : adata.uns["PCs"]]
        Z = adata.uns["archetypal_analysis"]["Z"]

    # Extract dimensions D (PCs), and number of archetypes
    D, k = X.shape[1], Z.shape[0]  # type: ignore[union-attr]

    # Input validation
    if k < 2:
        raise ValueError("k must satisfy 2 <= k, meaning you need at least 2 archetypes.")

    if k < D + 1:
        # Project onto the (k-1)-dimensional affine subspace spanned by Z. The change of basis is
        # linear, so it scales both volumes by the same factor and leaves the ratio unchanged.
        proj_X = project_on_affine_subspace(X.T, Z.T).T  # type: ignore[union-attr]
        proj_Z = project_on_affine_subspace(Z.T, Z.T).T  # type: ignore[union-attr]

        convhull_volume = _convex_hull_volume(proj_X)
        polytope_volume = _convex_hull_volume(proj_Z)
    else:
        # Compute the convex hull volumes directly
        convhull_volume = ConvexHull(X).volume
        polytope_volume = ConvexHull(Z).volume

    t_ratio = polytope_volume / convhull_volume

    if isinstance(adata, sc.AnnData):
        adata.uns["t_ratio"] = t_ratio
        return None
    return t_ratio

480 

481 

def t_ratio_significance(adata, iter=1000, seed=42, n_jobs=-1):
    """
    Assess the significance of the polytope spanned by the archetypes.

    The t-ratio of the data is compared with t-ratios of randomized datasets. In each randomized
    dataset every PC column is permuted independently across cells. Archetypes are then refitted
    with `AA(n_archetypes=...)`, using the AA defaults for all other options.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object containing `adata.obsm["X_pca"]`, `adata.uns["PCs"]` and
        `adata.uns["archetypal_analysis"]["Z"]`, optionally `adata.uns["t_ratio"]`. If
        `adata.uns["t_ratio"]` does not exist, it is computed with `compute_t_ratio`.
    iter : int, optional (default=1000)
        Number of randomized datasets to generate. Must be at least 1.
    seed : int, optional (default=42)
        The random seed for reproducibility. Each randomization gets its own independent stream
        derived from this seed, so results do not depend on how joblib distributes tasks.
    n_jobs : int, optional (default=-1)
        Number of jobs for parallelization. Use -1 to use all available cores.

    Returns
    -------
    float
        The proportion of randomized datasets with a t-ratio strictly greater than the original
        t-ratio (empirical p-value).

    Raises
    ------
    ValueError
        If `iter < 1`, or `adata.obsm["X_pca"]` / `adata.uns["PCs"]` is missing.
    """
    # Input validation
    if iter < 1:
        raise ValueError("`iter` must be at least 1.")
    if "X_pca" not in adata.obsm:
        raise ValueError("adata.obsm['X_pca'] not found.")
    if "PCs" not in adata.uns:
        raise ValueError(
            "PCs not found in adata.uns. Please set the dimension for archetypal analysis with set_dimension()"
        )
    if "t_ratio" not in adata.uns:
        print("Computing t-ratio...")
        compute_t_ratio(adata)

    X = adata.obsm["X_pca"][:, : adata.uns["PCs"]]
    t_ratio = adata.uns["t_ratio"]
    n_features = X.shape[1]
    n_archetypes = adata.uns["archetypal_analysis"]["Z"].shape[0]

    # One independent seed per randomization. A single Generator captured in the closure is
    # pickled into every worker batch by process-based joblib backends with the same state,
    # which would make different tasks draw identical permutations.
    child_seeds = np.random.SeedSequence(seed).spawn(iter)

    def compute_randomized_t_ratio(seed_seq):
        rng = np.random.default_rng(seed_seq)
        # Shuffle each feature independently
        X_rand = np.column_stack([rng.permutation(X[:, i]) for i in range(n_features)])
        # Compute archetypes and t-ratio for randomized data
        Z_rand = AA(n_archetypes=n_archetypes).fit(X_rand).Z
        return compute_t_ratio(X_rand, Z_rand)

    # Parallelize the computation of randomized t-ratios
    rand_ratios = Parallel(n_jobs=n_jobs)(
        delayed(compute_randomized_t_ratio)(seed_seq) for seed_seq in tqdm(child_seeds, desc="Randomizing")
    )

    # Calculate the p-value
    p_value = np.sum(np.array(rand_ratios) > t_ratio) / iter
    return p_value

531 

532 

def plot_2D(
    X: np.ndarray | sc.AnnData,
    Z: np.ndarray | None = None,
    color_vec: np.ndarray | None = None,
) -> pn.ggplot:
    """
    Scatter the data in its first two PCs and outline the polygon formed by the archetypes.

    Archetypes are connected in order of their polar angle around the archetype centroid,
    and the path is closed back to the first archetype.

    Parameters
    ----------
    X : Union[np.ndarray, sc.AnnData]
        Either a (n_samples, n_features) data array, or an AnnData object holding the PCA
        coordinates in `X.obsm["X_pca"]`, the number of PCs in `X.uns["PCs"]` and the
        archetypes in `X.uns["archetypal_analysis"]["Z"]`.
    Z : np.ndarray, optional
        (n_archetypes, n_features) archetype coordinates. Required unless `X` is an AnnData object.
    color_vec : np.ndarray, optional
        Length-n_samples values used to color the data points.

    Returns
    -------
    pn.ggplot
        2D plot of X and the polytope enclosed by Z.
    """
    if isinstance(X, sc.AnnData):
        adata = X
        if "archetypal_analysis" not in adata.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")
        Z = adata.uns["archetypal_analysis"]["Z"]
        X = adata.obsm["X_pca"][:, : adata.uns["PCs"]]

    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")
    if min(X.shape[1], Z.shape[1]) < 2:
        raise ValueError("Both X and Z must have at least 2 columns (PCs).")

    data_xy = X[:, :2]
    arch_xy = Z[:, :2]
    point_df = pd.DataFrame(data_xy, columns=["x0", "x1"])

    # Sort archetypes counter-clockwise around their centroid and repeat the first one
    # at the end so that geom_path draws a closed outline.
    angles = np.arctan2(arch_xy[:, 1] - np.mean(arch_xy[:, 1]), arch_xy[:, 0] - np.mean(arch_xy[:, 0]))
    sorted_arch = pd.DataFrame(arch_xy[np.argsort(angles)], columns=["x0", "x1"])
    outline_df = pd.concat([sorted_arch, sorted_arch.iloc[:1]], ignore_index=True)

    if color_vec is None:
        data_layer = pn.geom_point(data=point_df, mapping=pn.aes(x="x0", y="x1"), color="black", alpha=0.5)
    else:
        if len(color_vec) != len(point_df):
            raise ValueError("color_vec must have the same length as X.")
        point_df["color_vec"] = np.array(color_vec)
        data_layer = pn.geom_point(data=point_df, mapping=pn.aes(x="x0", y="x1", color="color_vec"), alpha=0.5)

    return (
        pn.ggplot()
        + data_layer
        + pn.geom_point(data=outline_df, mapping=pn.aes(x="x0", y="x1"), color="red", size=1)
        + pn.geom_path(data=outline_df, mapping=pn.aes(x="x0", y="x1"), color="red", size=1)
        + pn.labs(x="PC 1", y="PC 2")
        + pn.theme_matplotlib()
    )

599 

600 

def plot_3D(
    X: np.ndarray | sc.AnnData,
    Z: np.ndarray | None = None,
    color_vec: np.ndarray | None = None,
    marker_size: int = 4,
    color_polyhedron: str = "green",
) -> go.Figure:
    """
    3D plot of the datapoints in X and the 3D polytope enclosed by the archetypes in Z.

    Only the first three PCs are plotted. With four or more archetypes the polytope is their
    convex hull. With three archetypes it is a single triangle, and with two a line segment,
    because a 3-D convex hull cannot be built from fewer than four points.

    Parameters
    ----------
    X : Union[np.ndarray, sc.AnnData]
        The input data, which can be either:
        - A 2D array of shape (n_samples, n_features) representing the data points.
        - An AnnData object containing the PCA data in `X.obsm["X_pca"]`, the number of PCs in `X.uns["PCs"]`
          and archetypes in `X.uns["archetypal_analysis"]["Z"]`.
    Z : np.ndarray, optional
        A 2D array of shape (n_archetypes, n_features) representing the archetype coordinates.
        Required if `X` is not an AnnData object.
    color_vec : np.ndarray, optional
        A 1D array of shape (n_samples,) containing values for coloring the data points in `X`.
    marker_size : int, optional (default=4)
        The size of the markers for the data points in `X`.
    color_polyhedron : str, optional (default="green")
        The color of the polytope (convex hull) defined by the archetypes.

    Returns
    -------
    go.Figure
        3D plot of X and polytope enclosed by Z
    """
    # Validation input
    if isinstance(X, sc.AnnData):
        if "archetypal_analysis" not in X.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")

        Z = X.uns["archetypal_analysis"]["Z"]
        X = X.obsm["X_pca"][:, : X.uns["PCs"]]

    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")

    if X.shape[1] < 3 or Z.shape[1] < 3:
        raise ValueError("Both X and Z must have at least 3 columns (PCs).")

    X_plot, Z_plot = X[:, :3], Z[:, :3]

    plot_df = pd.DataFrame(X_plot, columns=["x0", "x1", "x2"])
    plot_df["marker_size"] = np.repeat(marker_size, X_plot.shape[0])

    # Create the 3D scatter plot (optionally colored by color_vec)
    color_kwargs = {}
    if color_vec is not None:
        if len(color_vec) != len(plot_df):
            raise ValueError("color_vec must have the same length as X.")
        plot_df["color_vec"] = np.array(color_vec)
        color_kwargs["color"] = "color_vec"

    fig = px.scatter_3d(
        plot_df,
        x="x0",
        y="x1",
        z="x2",
        labels={"x0": "PC 1", "x1": "PC 2", "x2": "PC 3"},
        title="3D polytope",
        size="marker_size",
        size_max=10,
        opacity=0.5,
        **color_kwargs,
    )

    # Triangular faces of the polytope. Qhull needs at least 4 points in 3-D, so the
    # triangle (3 archetypes) and segment (2 archetypes) cases are handled explicitly.
    n_arch = Z_plot.shape[0]
    if n_arch >= 4:
        faces = ConvexHull(Z_plot).simplices
    elif n_arch == 3:
        faces = np.array([[0, 1, 2]])
    else:
        faces = np.empty((0, 3), dtype=int)

    # Add archetypes to the plot
    archetype_labels = [f"Archetype {i}" for i in range(n_arch)]
    fig.add_trace(
        go.Scatter3d(
            x=Z_plot[:, 0],
            y=Z_plot[:, 1],
            z=Z_plot[:, 2],
            mode="markers",
            text=archetype_labels,
            marker={"size": 4, "color": color_polyhedron, "symbol": "circle"},
            hoverinfo="text",
            name="Archetypes",
        )
    )

    # Add the polytope surface to the plot
    if len(faces) > 0:
        fig.add_trace(
            go.Mesh3d(
                x=Z_plot[:, 0],
                y=Z_plot[:, 1],
                z=Z_plot[:, 2],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color=color_polyhedron,
                opacity=0.1,
            )
        )

    # Add edges of the polytope: each face outline is closed back to its first vertex.
    if n_arch == 2:
        edge_paths = [np.array([0, 1])]
    else:
        edge_paths = [np.append(face, face[0]) for face in faces]

    for path in edge_paths:
        fig.add_trace(
            go.Scatter3d(
                x=Z_plot[path, 0],
                y=Z_plot[path, 1],
                z=Z_plot[path, 2],
                mode="lines",
                line={"color": color_polyhedron, "width": 4},
                showlegend=False,
            )
        )

    fig.update_layout(template="none")
    return fig

729 

730 

def align_archetypes(ref_arch: np.ndarray, query_arch: np.ndarray) -> np.ndarray:
    """
    Reorder the query archetypes so that row i best matches reference archetype i.

    A cost matrix of Euclidean distances between every reference and every query archetype
    is built, and the Hungarian algorithm (``scipy.optimize.linear_sum_assignment``) picks the
    one-to-one matching with minimal total distance.

    Parameters
    ----------
    ref_arch : np.ndarray
        (n_archetypes, n_features) reference archetypes.
    query_arch : np.ndarray
        (n_archetypes, n_features) archetypes to reorder. Not modified.

    Returns
    -------
    np.ndarray
        (n_archetypes, n_features) new array with the query archetypes in reference order.
    """
    cost = cdist(ref_arch, query_arch, metric="euclidean")
    # Row indices come back sorted, so the column indices give the query row for each reference row.
    _, matched_query = linear_sum_assignment(cost)
    return query_arch[matched_query, :]

758 

759 

def compute_AA(
    adata: sc.AnnData | np.ndarray,
    n_archetypes: int,
    init: str | None = None,
    optim: str | None = None,
    weight: None | str = None,
    max_iter: int | None = None,
    derivative_max_iter: int | None = None,
    tol: float | None = None,
    verbose: bool | None = None,
    save_to_anndata: bool = True,
    archetypes_only: bool = True,
) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray, float, float] | None:
    """
    Perform Archetypal Analysis (AA) on the input data.

    This function is a wrapper for the AA class, providing a simplified interface for fitting the model,
    and returning the desired outputs or saving them to the AnnData object. Every option left as `None`
    falls back to the default declared in the signature of `AA.__init__`.

    Parameters
    ----------
    adata : Union[sc.AnnData, np.ndarray]
        The input data, which can be either:
        - An AnnData object containing data in `adata.obsm["X_pca"]`.
        - A 2D numpy array of shape (n_samples, n_features) representing the data matrix.
    n_archetypes : int
        The number of archetypes to compute.
    init : str, optional
        The initialization method for the archetypes. If not provided, the default from the AA class is used.
        Options include:
        - "random": Random initialization.
        - "furthest_sum": Furthest sum initialization.
    optim : str, optional
        The optimization method for fitting the model. If not provided, the default from the AA class is used.
        Options include:
        - "projected_gradients": Projected gradients optimization.
        - "frank_wolfe": Frank-Wolfe optimization.
        - "regularized_nnls": Regularized non-negative least squares optimization.
    weight : str, optional
        The weighting method for the data. If not provided, the default from the AA class is used.
        Options include:
        - "bisquare": Bisquare weighting.
    max_iter : int, optional
        The maximum number of iterations for the optimization. If not provided, the default from the AA class is used.
    derivative_max_iter : int, optional
        The maximum number of iterations for derivative computation. If not provided, the default from the AA class is used.
    tol : float, optional
        The tolerance for convergence. If not provided, the default from the AA class is used.
    verbose : bool, optional
        Whether to print verbose output during fitting. If not provided, the default from the AA class is used.
    save_to_anndata : bool, optional (default=True)
        Whether to save the results to the AnnData object. If `adata` is not an AnnData object, this is ignored
        and the results are returned instead.
    archetypes_only : bool, optional (default=True)
        Whether to return only the archetypes matrix. If `save_to_anndata` is True, this parameter determines
        whether only the archetypes are saved to the AnnData object.

    Returns
    -------
    Optional[Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray, float, float]]]
        The output depends on the values of `save_to_anndata` and `archetypes_only`:
        - If `archetypes_only` is True:
            - Only the archetype matrix (Z) is returned/ saved
        - If `archetypes_only` is False:
            - returns/ saves a tuple containing:
                - A: The matrix of weights for the data points (n_samples, n_archetypes).
                - B: The matrix of weights for the archetypes (n_archetypes, n_samples).
                - Z: The archetypes matrix (n_archetypes, n_features).
                - RSS: The residual sum of squares.
                - varexpl: The variance explained by the model.
        - If `save_to_anndata` is True (and `adata` is an AnnData object):
            - Returns `None`. Results are saved to `adata.uns["archetypal_analysis"]`.
        - Otherwise:
            - Returns the results.
    """
    # Defaults declared by AA.__init__ (excluding `self` and the required `n_archetypes`)
    aa_defaults = {
        name: param.default
        for name, param in inspect.signature(AA.__init__).parameters.items()
        if name not in ("self", "n_archetypes")
    }

    # Use the provided values or fall back to the defaults
    requested = {
        "init": init,
        "optim": optim,
        "weight": weight,
        "max_iter": max_iter,
        "derivative_max_iter": derivative_max_iter,
        "tol": tol,
        "verbose": verbose,
    }
    aa_kwargs = {name: aa_defaults[name] if value is None else value for name, value in requested.items()}

    # Create and fit the AA model with the specified parameters
    model = AA(n_archetypes=n_archetypes, **aa_kwargs)
    model.fit(adata)

    # Save the results to the AnnData object if specified (only possible for AnnData input)
    if save_to_anndata:
        if not isinstance(adata, sc.AnnData):
            print("No AnnData object found. Returning results")
            save_to_anndata = False
        else:
            model.save_to_anndata(archetypes_only=archetypes_only)

    # Return based on the flags
    if save_to_anndata:
        return None  # Results are saved to AnnData, so return nothing
    if archetypes_only:
        return model.archetypes()  # Return only the archetypes matrix
    return model.return_all()  # Return the full fitted model

884 

885 

886# def bootstrap_variance_k_arr(X, n_bootstrap, k_arr, delta=0, seed=42, **kwargs): 

887# assert k_arr.min() > 1 

888# bootstrap_var = np.array( 

889# [ 

890# bootstrap_variance_single_k( 

891# X, n_bootstrap=n_bootstrap, k=k, delta=delta, seed=seed, **kwargs 

892# ) 

893# for k in k_arr 

894# ] 

895# ) 

896# plot_df = pd.DataFrame({"k": k_arr, "var": bootstrap_var}) 

897# p = ( 

898# pn.ggplot(plot_df, pn.aes(x="k", y="var")) 

899# + pn.geom_point(color="blue") 

900# + pn.geom_line(color="blue") 

901# + pn.labs(x="Number of Archetypes", y="Mean Variance in Archetype Position") 

902# ) 

903# return p 

904 

905 

906# Appendix 

907 

908# not sure if this ratio of archetype over data variance is useful in any way

909# def compute_var_ratio_vitali(X, Z): 

910# # adapted from: https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L1178 

911# data_var = X.var(axis=1) 

912# arch_var = Z.var(axis=1) 

913# return arch_var / data_var 

914 

915# see https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L257 

916# archetypes = Z.A.T 

917# print(archetypes.shape) 

918# 

919## create random matrix where each row sums to 1 

920# n_additional = 100 

921# np.random.seed(42) 

922# rand_mtx = np.random.rand(n_additional, archetypes.shape[0]) 

923# rand_mtx /= rand_mtx.sum(axis=1)[:, None] 

924# 

925# additional_archetypes = (rand_mtx @ archetypes) 

926# print(additional_archetypes.shape) 

927# 

928# stacked_archetypes = np.row_stack([archetypes, additional_archetypes]) 

929# print(stacked_archetypes.shape) 

930# 

931## https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L879C39-L879C46 

932## Vitali used the "FA" argument, but this is turned on by default I think in the scipy version (FA - report total area and volume), 

933## see http://www.qhull.org/html/qhull.htm 

934# convhull_archetypes = ConvexHull(stacked_archetypes, qhull_options="QJ") 

935# print(convhull_archetypes.volume)