Coverage for ParTIpy/paretoti_funcs.py: 13%

206 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-16 10:20 +0100

1import numpy as np 

2import pandas as pd 

3import plotly.express as px 

4import plotly.graph_objects as go 

5import plotnine as pn 

6import scanpy as sc 

7from joblib import Parallel, delayed 

8from scipy.optimize import linear_sum_assignment 

9from scipy.spatial import ConvexHull 

10from scipy.spatial.distance import cdist 

11from tqdm import tqdm 

12 

13from .arch import AA 

14from .const import DEFAULT_INIT, DEFAULT_OPTIM 

15 

16####TODO#### 

17# add/fix t-ratio function 

18# Function mean archetype variance for different n_archetypes 

19############ 

20 

21 

def reduce_pca(adata: sc.AnnData, n_pcs: int) -> None:
    """
    Reduce the PCA representation in `adata.obsm["X_pca"]` to the first `n_pcs` components.

    If `adata.obsm["X_pca"]` does not exist, PCA is computed first with `sc.pp.pca`. It is
    restricted to the genes flagged in `adata.var["highly_variable"]` when that column exists,
    and uses all genes otherwise. The reduced PCA representation is stored in
    `adata.obsm["X_pca_reduced"]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing single-cell data.
    n_pcs : int
        The number of principal components (PCs) to retain. Must be at least 1 and at most the
        number of available PCs in `adata.obsm["X_pca"]`.

    Returns
    -------
    None
        The results are stored in `adata.obsm["X_pca_reduced"]`.

    Raises
    ------
    ValueError
        If `n_pcs` is smaller than 1 or larger than the number of available PCs.
    """
    # Validate before any (potentially expensive) PCA computation. A non-positive n_pcs would
    # otherwise silently yield an empty slice (0) or drop trailing PCs (negative values).
    if n_pcs < 1:
        raise ValueError(f"`n_pcs` must be at least 1, got {n_pcs}.")

    if "X_pca" not in adata.obsm:
        print("X_pca not found in adata.obsm. Computing PCA...")
        # Only restrict to highly variable genes if that annotation actually exists, rather than
        # handing scanpy a column name that is not present in adata.var.
        mask_var = "highly_variable" if "highly_variable" in adata.var.columns else None
        sc.pp.pca(adata, mask_var=mask_var)

    n_available = adata.obsm["X_pca"].shape[1]
    if n_pcs > n_available:
        raise ValueError(f"Requested {n_pcs} PCs, but only {n_available} PCs are available.")

    adata.obsm["X_pca_reduced"] = adata.obsm["X_pca"][:, :n_pcs]

50 

51 

def _dist_to_first_last_line(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Return the Euclidean distance of each point ``(x[i], y[i])`` to the straight line through
    the first and the last point.

    With fewer than two points there is no line to project onto, so all distances are 0.
    """
    points = np.column_stack([x, y]).astype(float)
    if len(points) < 2:
        return np.zeros(len(points))
    # Shift so the first point is the origin; the line is then spanned by the last point.
    centered = points - points[0]
    direction = centered[-1]
    # Orthogonal projection of every centered point onto `direction`.
    projected = np.outer(centered @ direction / (direction @ direction), direction)
    return np.linalg.norm(centered - projected, axis=1)


def var_explained_aa(
    adata: sc.AnnData,
    min_a: int = 2,
    max_a: int = 10,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    n_jobs: int = -1,
) -> None:
    """
    Compute the variance explained by Archetypal Analysis (AA) for a range of archetypes.

    AA is fitted for every number of archetypes from `min_a` to `max_a` (inclusive). The data is
    `adata.obsm["X_pca_reduced"]` if present, otherwise `adata.obsm["X_pca"]`. The fits run in
    parallel via joblib. The results are stored in `adata.uns["AA_var"]`.

    Parameters
    ----------
    adata: sc.AnnData
        AnnData object containing adata.obsm["X_pca_reduced"] or adata.obsm["X_pca"].
    min_a : int, optional (default=2)
        Minimum number of archetypes to test. Must be at least 2.
    max_a : int, optional (default=10)
        Maximum number of archetypes to test. Must be greater than or equal to `min_a`.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    n_jobs : int, optional (default=-1)
        Number of jobs for parallel computation. `-1` uses all available cores.

    Returns
    -------
    None
        The results are stored in `adata.uns["AA_var"]` as a DataFrame with the following columns:
        - `k`: The number of archetypes.
        - `varexpl`: The variance explained by the model.
        - `varexpl_ontop`: The additional variance explained compared to the model with `k-1`
          archetypes. For the first row (`k == min_a`) this equals `varexpl` itself.
        - `dist_to_projected`: The distance of each (k, varexpl) point to its orthogonal
          projection onto the straight line through the first (k = min_a) and last (k = max_a)
          point. It is 0 when `min_a == max_a`.

    Raises
    ------
    ValueError
        If `min_a < 2`, if `max_a < min_a`, or if neither `X_pca_reduced` nor `X_pca` is
        present in `adata.obsm`.
    """
    # Validation input
    if min_a < 2:
        raise ValueError("`min_a` must be at least 2.")
    if max_a < min_a:
        raise ValueError("`max_a` must be greater than or equal to `min_a`.")

    if "X_pca_reduced" in adata.obsm:
        X = adata.obsm["X_pca_reduced"]
    elif "X_pca" in adata.obsm:
        print("No reduced PCA found. Calculating with all available PCs from X_pca")
        X = adata.obsm["X_pca"]
    else:
        raise ValueError("Neither `X_pca_reduced` nor `X_pca` found in `adata.obsm`. Please compute PCA first.")

    k_arr = np.arange(min_a, max_a + 1)

    def compute_varexpl(k):
        # Only the explained variance is used below; returning just that avoids transferring
        # the (potentially large) A/B/Z matrices back from the parallel workers.
        *_, varexpl = AA(n_archetypes=k, optim=optim, init=init).fit(X).return_all()
        return varexpl

    # joblib's Parallel returns results in submission order, i.e. aligned with k_arr.
    varexpl_values = np.array(Parallel(n_jobs=n_jobs)(delayed(compute_varexpl)(k) for k in k_arr))

    plot_df = pd.DataFrame(
        {
            "k": k_arr,
            "varexpl": varexpl_values,
            "varexpl_ontop": np.insert(np.diff(varexpl_values), 0, varexpl_values[0]),
        }
    )

    # Elbow heuristic: distance of each point to the line joining the first and last point.
    plot_df["dist_to_projected"] = _dist_to_first_last_line(
        plot_df["k"].to_numpy(), plot_df["varexpl"].to_numpy()
    )

    adata.uns["AA_var"] = plot_df

139 

140 

def plot_var_explained_aa(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Generate an elbow plot of the variance explained by Archetypal Analysis (AA) for a range of archetypes.

    The plot shows the variance explained by AA models with different numbers of archetypes,
    together with a gray reference line connecting the first and last k. That is the same line
    the `dist_to_projected` column of `adata.uns["AA_var"]` is measured against. The data is
    retrieved from `adata.uns["AA_var"]`. If `AA_var` is not found, `var_explained_aa` is
    called with its default arguments.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing the variance explained data in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        A ggplot object showing the variance explained plot.
    """
    # Validation input
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    plot_df = adata.uns["AA_var"]

    # Reference line through the points with the smallest and largest k. Using min/max of
    # `varexpl` instead would draw a different line whenever the explained variance is not
    # monotonically increasing in k.
    diag_data = plot_df.sort_values("k").iloc[[0, -1]][["k", "varexpl"]].reset_index(drop=True)

    p = (
        pn.ggplot(plot_df)
        + pn.geom_line(mapping=pn.aes(x="k", y="varexpl"), color="black")
        + pn.geom_point(mapping=pn.aes(x="k", y="varexpl"), color="black")
        + pn.geom_line(data=diag_data, mapping=pn.aes(x="k", y="varexpl"), color="gray")
        + pn.labs(x="Number of Archetypes (k)", y="Variance Explained")
        + pn.lims(y=[0, 1])
        + pn.scale_x_continuous(breaks=np.arange(plot_df["k"].min(), plot_df["k"].max() + 1))
        + pn.theme_matplotlib()
    )
    return p

186 

187 

def plot_projected_dist(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Bar chart of the `dist_to_projected` column against the number of archetypes.

    Reads `adata.uns["AA_var"]`. When that entry is missing, it is first created by calling
    `var_explained_aa` with its default arguments.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object holding (or receiving) the variance-explained table in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        Column chart with k on the x axis and the distance to the projected point on the y axis.
    """
    # Compute the table on demand so the plot can be requested directly.
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    var_df = adata.uns["AA_var"]
    k_breaks = np.arange(var_df["k"].min(), var_df["k"].max() + 1)

    bars = pn.geom_col(mapping=pn.aes(x="k", y="dist_to_projected"))
    x_scale = pn.scale_x_continuous(breaks=k_breaks)
    axis_labels = pn.labs(x="Number of Archetypes (k)", y="Distance to Projected Point")

    return pn.ggplot(var_df) + bars + x_scale + axis_labels + pn.theme_matplotlib()

221 

222 

def plot_var_on_top(
    adata: sc.AnnData,
) -> pn.ggplot:
    """
    Plot the extra variance explained when going from `k-1` to `k` archetypes.

    The values come from the `varexpl_ontop` column of `adata.uns["AA_var"]`. If that entry is
    missing, it is first created by calling `var_explained_aa` with its default arguments.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object holding (or receiving) the variance-explained table in `adata.uns["AA_var"]`.

    Returns
    -------
    pn.ggplot
        Line-and-point plot of the variance explained on top of the (k-1) model, with the y axis
        starting at 0.
    """
    # Compute the table on demand so the plot can be requested directly.
    if "AA_var" not in adata.uns:
        print("AA_var not found in adata.uns. Computing variance explained by archetypal analysis...")
        var_explained_aa(adata=adata)

    var_df = adata.uns["AA_var"]
    k_breaks = np.arange(var_df["k"].min(), var_df["k"].max() + 1)

    p = pn.ggplot(var_df)
    for layer in (
        pn.geom_point(pn.aes(x="k", y="varexpl_ontop"), color="black"),
        pn.geom_line(pn.aes(x="k", y="varexpl_ontop"), color="black"),
        pn.labs(x="Number of Archetypes (k)", y="Variance Explained on Top of (k-1) Model"),
        pn.scale_x_continuous(breaks=k_breaks),
        pn.lims(y=(0, None)),
        pn.theme_matplotlib(),
    ):
        p += layer

    return p

258 

259 

def bootstrap_aa(
    adata: sc.AnnData,
    n_bootstrap: int,
    n_archetypes: int,
    optim: str = DEFAULT_OPTIM,
    init: str = DEFAULT_INIT,
    seed: int = 42,
) -> None:
    """
    Perform bootstrap sampling to compute archetypes and assess their stability.

    A reference AA model is fitted on the full data. AA is then refitted on `n_bootstrap`
    resamples of the cells (drawn with replacement). Each bootstrap solution is reordered to
    match the reference archetypes (`align_archetypes`), and everything is stored in
    `adata.uns["AA_bootstrap"]`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object. The PCA-reduced data should be stored in `adata.obsm["X_pca_reduced"]`. If not
        found, the full PCA representation (`adata.obsm["X_pca"]`) is used.
    n_bootstrap : int
        The number of bootstrap samples to generate. Must be at least 1.
    n_archetypes : int
        The number of archetypes to compute for each bootstrap sample.
    optim : str, optional (default=DEFAULT_OPTIM)
        The optimization function to use for Archetypal Analysis.
    init : str, optional (default=DEFAULT_INIT)
        The initialization function to use for Archetypal Analysis.
    seed : int, optional (default=42)
        Seed for drawing the bootstrap sample indices.

    Returns
    -------
    None
        The results are stored in `adata.uns["AA_bootstrap"]` as a DataFrame with a fresh
        0..N-1 index and the following columns:
        - `pc_i`: The coordinates of the archetypes in the i-th principal component.
        - `archetype`: The archetype index (categorical).
        - `iter`: The bootstrap iteration index (1..n_bootstrap; 0 for the reference archetypes).
        - `reference`: A boolean indicating whether the archetype is from the reference model.
        - `mean_variance`: The variance of the aligned archetype coordinates across bootstrap
          samples, averaged over features and archetypes (same value in every row).

    Raises
    ------
    ValueError
        If `n_bootstrap < 1`, or if neither `X_pca_reduced` nor `X_pca` is in `adata.obsm`.
    """
    # Validation input
    if n_bootstrap < 1:
        raise ValueError("`n_bootstrap` must be at least 1.")

    if "X_pca_reduced" in adata.obsm:
        X = adata.obsm["X_pca_reduced"]
    elif "X_pca" in adata.obsm:
        print("No reduced PCA found. Calculating with all available PCs from `X_pca`.")
        X = adata.obsm["X_pca"]
    else:
        raise ValueError("Neither `X_pca_reduced` nor `X_pca` found in `adata.obsm`. Please compute PCA first.")

    n_samples, n_features = X.shape
    rng = np.random.default_rng(seed)
    pc_columns = [f"pc_{j}" for j in range(n_features)]

    # Reference archetypes
    ref_Z = AA(n_archetypes=n_archetypes, optim=optim, init=init).fit(X).Z

    # One row of cell indices (drawn with replacement) per bootstrap sample
    idx_bootstrap = rng.choice(n_samples, size=(n_bootstrap, n_samples), replace=True)
    Z_list = [AA(n_archetypes=n_archetypes, optim=optim, init=init).fit(X[idx, :]).Z for idx in idx_bootstrap]

    # Archetype order is arbitrary per fit; align each solution to the reference before comparing.
    Z_list = [align_archetypes(ref_arch=ref_Z, query_arch=query_Z) for query_Z in Z_list]

    # Variance across bootstrap samples -> mean over features -> mean over archetypes
    Z_stack = np.stack(Z_list)
    var_per_archetype = Z_stack.var(axis=0).mean(axis=1)
    mean_variance = var_per_archetype.mean()

    # Create result dataframe: bootstrap iterations 1..n_bootstrap first, reference (iter 0) last
    frames = [
        pd.DataFrame(Z, columns=pc_columns).assign(archetype=np.arange(n_archetypes), iter=it)
        for it, Z in enumerate(Z_list, start=1)
    ]
    ref_df = pd.DataFrame(ref_Z, columns=pc_columns).assign(archetype=np.arange(n_archetypes), iter=0)

    bootstrap_df = pd.concat([*frames, ref_df], axis=0, ignore_index=True)
    bootstrap_df["reference"] = bootstrap_df["iter"] == 0
    bootstrap_df["archetype"] = pd.Categorical(bootstrap_df["archetype"])
    bootstrap_df["mean_variance"] = mean_variance

    adata.uns["AA_bootstrap"] = bootstrap_df

347 

348 

def plot_bootstrap_aa(adata: sc.AnnData) -> go.Figure:
    """
    Create an interactive 3D scatter plot of the archetype positions in `adata.uns["AA_bootstrap"]`.

    The first three principal components (`pc_0`, `pc_1`, `pc_2`) are plotted. Points are
    colored by archetype, and the marker symbol distinguishes the reference fit
    (`reference=True`) from the bootstrap fits.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data object containing the archetype bootstrap data in `adata.uns["AA_bootstrap"]`.

    Returns
    -------
    go.Figure
        3D plot of bootstrap results for the archetypes.

    Raises
    ------
    ValueError
        If `adata.uns["AA_bootstrap"]` is missing, or if it contains fewer than three PC columns.
    """
    # Validation input
    if "AA_bootstrap" not in adata.uns:
        raise ValueError("AA_bootstrap not found in adata.uns. Please run bootstrap_aa() to compute")

    bootstrap_df = adata.uns["AA_bootstrap"]
    # A 3D plot needs three PCs; fail with a clear message instead of a plotly column error.
    missing = [col for col in ("pc_0", "pc_1", "pc_2") if col not in bootstrap_df.columns]
    if missing:
        raise ValueError(
            f"A 3D plot needs at least 3 PCs, but columns {missing} are missing from adata.uns['AA_bootstrap']."
        )

    # Generate the 3D scatter plot
    fig = px.scatter_3d(
        bootstrap_df,
        x="pc_0",
        y="pc_1",
        z="pc_2",
        color="archetype",
        symbol="reference",
        labels={
            "pc_0": "PC 1",
            "pc_1": "PC 2",
            "pc_2": "PC 3",
        },
        title="Archetypes on bootstrapped data",
        hover_data=["iter", "archetype", "reference"],
        opacity=0.5,
    )
    fig.update_layout(template="none")

    return fig

390 

391 

def project_on_affine_subspace(X, Z):
    """
    Project a set of points X onto the affine subspace spanned by the vertices Z.

    The subspace is ``Z[:, 0] + span(Z[:, j] - Z[:, 0] for j = 1..k-1)``. Each point is expressed
    in the (generally non-orthogonal) basis of these edge vectors by a least-squares fit. This
    yields the coordinates of its orthogonal projection onto the subspace. Solving the
    least-squares problem directly is numerically more stable than inverting the normal
    equations ``(B^T B)^{-1} B^T``. If the edge vectors are linearly dependent, the
    minimum-norm coordinates are returned instead of raising.

    Parameters
    ----------
    X : numpy.ndarray
        A (D x n) array of n points in D-dimensional space to be projected.
    Z : numpy.ndarray
        A (D x k) array of k vertices (archetypes) defining the affine subspace in D-dimensional space.

    Returns
    -------
    proj_coord : numpy.ndarray
        A ((k-1) x n) array with the coordinates of the projected points in the subspace defined
        by Z. ``Z[:, [0]] + (Z[:, 1:] - Z[:, [0]]) @ proj_coord`` maps them back to D dimensions.
    """
    origin = Z[:, [0]]
    # Basis of the affine subspace relative to the first vertex; covers k == 2 (a line) as well.
    basis = Z[:, 1:] - origin
    proj_coord, *_ = np.linalg.lstsq(basis, X - origin, rcond=None)
    return proj_coord

422 

423 

def compute_t_ratio(X, Z=None):
    """
    Compute the t-ratio: the volume of the polytope spanned by the archetypes Z divided by the
    volume of the convex hull of the data X.

    When there are fewer archetypes than ``D + 1``, the archetypes span only a (k-1)-dimensional
    affine subspace. In that case both the data and the archetypes are projected onto that
    subspace, and volumes are compared there. For k == 2 that subspace is a line, and the
    "volumes" are segment lengths.

    Parameters
    ----------
    X : numpy.ndarray or sc.AnnData
        Either an (n x D) array of n data points in D-dimensional space, or an AnnData object
        providing the data in `X.obsm["X_pca_reduced"]` and the archetypes in
        `X.uns["archetypal_analysis"]["Z"]`.
    Z : numpy.ndarray, optional
        A (k x D) array of k archetypes. Required when `X` is a numpy array. Ignored when `X`
        is an AnnData object (the stored archetypes are used).

    Returns
    -------
    float or None
        The t-ratio when `X` is a numpy array. When `X` is an AnnData object, the t-ratio is
        stored in `X.uns["t_ratio"]` and None is returned.

    Raises
    ------
    ValueError
        If `X` is a numpy array and `Z` is not given, or if fewer than 2 archetypes are given.
    """
    adata = None
    if isinstance(X, np.ndarray):
        if Z is None:
            raise ValueError("Z must be provided when X is a numpy.ndarray.")
    else:
        adata = X
        X = adata.obsm["X_pca_reduced"]
        Z = adata.uns["archetypal_analysis"]["Z"]

    # Extract dimensions D (PCs), and number of archetypes
    D, k = X.shape[1], Z.shape[0]

    # Input validation
    if k < 2:
        raise ValueError("k must satisfy 2 <= k, meaning you need at least 2 archetypes.")

    if k < D + 1:
        # The polytope is flat in D dimensions; measure volumes inside the affine subspace
        # spanned by Z. Both volumes use the same basis, so their ratio is basis-independent.
        proj_X = project_on_affine_subspace(X.T, Z.T).T
        proj_Z = project_on_affine_subspace(Z.T, Z.T).T

        if proj_X.shape[1] == 1:
            # k == 2: qhull cannot build a hull from 1-D data, so use segment lengths instead.
            convhull_volume = proj_X.max() - proj_X.min()
            polytope_volume = proj_Z.max() - proj_Z.min()
        else:
            convhull_volume = ConvexHull(proj_X).volume
            polytope_volume = ConvexHull(proj_Z).volume
    else:
        # Compute the convex hull volumes directly
        convhull_volume = ConvexHull(X).volume
        polytope_volume = ConvexHull(Z).volume

    t_ratio = polytope_volume / convhull_volume

    if adata is not None:
        adata.uns["t_ratio"] = t_ratio
        return None
    return t_ratio

475 

476 

def t_ratio_significance(adata, iter=1000, seed=42, n_jobs=-1):
    """
    Assess the significance of the polytope spanned by the archetypes.

    The observed t-ratio is compared to t-ratios computed from randomized datasets. In each
    randomized dataset, every column (PC) of `adata.obsm["X_pca_reduced"]` is permuted
    independently. This keeps the marginal distributions but destroys the joint structure. AA is
    then refitted with the same number of archetypes, and its t-ratio is computed.

    Parameters
    ----------
    adata : sc.AnnData
        An AnnData object containing `adata.obsm["X_pca_reduced"]`,
        `adata.uns["archetypal_analysis"]["Z"]`, and optionally `adata.uns["t_ratio"]`.
        If the t-ratio does not exist, it is computed with `compute_t_ratio`.
    iter : int, optional (default=1000)
        Number of randomized datasets to generate. Must be at least 1.
    seed : int, optional (default=42)
        The random seed for reproducibility. Results do not depend on `n_jobs`.
    n_jobs : int, optional (default=-1)
        Number of jobs for parallelization. Use -1 to use all available cores.

    Returns
    -------
    float
        The proportion of randomized datasets with a t-ratio strictly greater than the original
        t-ratio (p-value).

    Raises
    ------
    ValueError
        If `iter < 1`, or if `X_pca_reduced` or the archetypal analysis result is missing.
    """
    # Input validation
    if iter < 1:
        raise ValueError("`iter` must be at least 1.")
    if "X_pca_reduced" not in adata.obsm:
        raise ValueError("adata.obsm['X_pca_reduced'] not found.")
    if "archetypal_analysis" not in adata.uns:
        raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")
    if "t_ratio" not in adata.uns:
        print("Computing t-ratio...")
        compute_t_ratio(adata)

    X = adata.obsm["X_pca_reduced"]
    t_ratio = adata.uns["t_ratio"]
    n_features = X.shape[1]
    n_archetypes = adata.uns["archetypal_analysis"]["Z"].shape[0]

    # One independent child seed per randomization. A single Generator captured by the closure
    # would be pickled into each joblib worker with the same state, so different tasks could
    # draw identical permutations. Spawned seeds keep the draws independent and make the
    # result reproducible regardless of n_jobs.
    child_seeds = np.random.SeedSequence(seed).spawn(iter)

    def compute_randomized_t_ratio(child_seed):
        rng = np.random.default_rng(child_seed)
        # Shuffle each feature independently
        X_rand = np.column_stack([rng.permutation(X[:, i]) for i in range(n_features)])
        # Compute archetypes and t-ratio for randomized data
        Z_rand = AA(n_archetypes=n_archetypes).fit(X_rand).Z
        return compute_t_ratio(X_rand, Z_rand)

    # Parallelize the computation of randomized t-ratios
    rand_ratios = Parallel(n_jobs=n_jobs)(
        delayed(compute_randomized_t_ratio)(child_seed) for child_seed in tqdm(child_seeds, desc="Randomizing")
    )

    # Calculate the p-value
    p_value = np.sum(np.array(rand_ratios) > t_ratio) / iter
    return p_value

526 

527 

def plot_2D(
    X: np.ndarray | sc.AnnData,
    Z: np.ndarray | None = None,
    color_vec: np.ndarray | None = None,
) -> pn.ggplot:
    """
    Scatter the data in X on its first two PCs and outline the polygon formed by the archetypes Z.

    The archetypes are sorted by polar angle around their centroid, and the first one is repeated
    at the end so that the red path closes.

    Parameters
    ----------
    X : Union[np.ndarray, sc.AnnData]
        Either an array of shape (n_samples, n_features) with the data points, or an AnnData
        object providing the data in `.obsm["X_pca_reduced"]` and the archetypes in
        `.uns["archetypal_analysis"]["Z"]`.
    Z : np.ndarray, optional
        Archetype coordinates of shape (n_archetypes, n_features). Needed when `X` is an array.
    color_vec : np.ndarray, optional
        Values of length n_samples used to color the data points.

    Returns
    -------
    pn.ggplot
        2D plot of X and polytope enclosed by Z
    """
    # Unpack AnnData input into the data matrix and the fitted archetypes.
    if isinstance(X, sc.AnnData):
        if "archetypal_analysis" not in X.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")
        Z = X.uns["archetypal_analysis"]["Z"]
        X = X.obsm["X_pca_reduced"]

    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")
    if min(X.shape[1], Z.shape[1]) < 2:
        raise ValueError("Both X and Z must have at least 2 columns (PCs).")

    data_2d = X[:, :2]
    arch_2d = Z[:, :2]
    points_df = pd.DataFrame(data_2d, columns=["x0", "x1"])

    # Sort archetypes by polar angle around their centroid, then close the loop.
    center_x = np.mean(arch_2d[:, 0])
    center_y = np.mean(arch_2d[:, 1])
    angles = np.arctan2(arch_2d[:, 1] - center_y, arch_2d[:, 0] - center_x)
    outline_df = pd.DataFrame(arch_2d[np.argsort(angles)], columns=["x0", "x1"])
    outline_df = pd.concat([outline_df, outline_df.iloc[:1]], ignore_index=True)

    if color_vec is None:
        data_layer = pn.geom_point(data=points_df, mapping=pn.aes(x="x0", y="x1"), color="black", alpha=0.5)
    else:
        if len(color_vec) != len(points_df):
            raise ValueError("color_vec must have the same length as X.")
        points_df["color_vec"] = np.array(color_vec)
        data_layer = pn.geom_point(data=points_df, mapping=pn.aes(x="x0", y="x1", color="color_vec"), alpha=0.5)

    return (
        pn.ggplot()
        + data_layer
        + pn.geom_point(data=outline_df, mapping=pn.aes(x="x0", y="x1"), color="red", size=1)
        + pn.geom_path(data=outline_df, mapping=pn.aes(x="x0", y="x1"), color="red", size=1)
        + pn.labs(x="PC 1", y="PC 2")
        + pn.theme_matplotlib()
    )

594 

595 

def plot_3D(
    X: np.ndarray | sc.AnnData,
    Z: np.ndarray | None = None,
    color_vec: np.ndarray | None = None,
    marker_size: int = 4,
    color_polyhedron: str = "green",
) -> go.Figure:
    """
    3D plot of the datapoints in X and the polytope enclosed by the archetypes in Z.

    Only the first three columns (PCs) of X and Z are plotted. The polytope is drawn as follows:
    - 4 or more archetypes: the convex hull of the archetypes.
    - 3 archetypes: a single triangle.
    - 2 archetypes: a line segment.

    Parameters
    ----------
    X : Union[np.ndarray, sc.AnnData]
        The input data, which can be either:
        - A 2D array of shape (n_samples, n_features) representing the data points.
        - An AnnData object containing the PCA-reduced data in `.obsm["X_pca_reduced"]` and archetypes in `.uns["archetypal_analysis"]["Z"]`.
    Z : np.ndarray, optional
        A 2D array of shape (n_archetypes, n_features) representing the archetype coordinates.
        Required if `X` is not an AnnData object.
    color_vec : np.ndarray, optional
        A 1D array of shape (n_samples,) containing values for coloring the data points in `X`.
    marker_size : int, optional (default=4)
        The size of the markers for the data points in `X`.
    color_polyhedron : str, optional (default="green")
        The color of the polytope (convex hull) defined by the archetypes.

    Returns
    -------
    go.Figure
        3D plot of X and polytope enclosed by Z

    Raises
    ------
    ValueError
        In any of these cases:
        - An AnnData input lacks the archetypal analysis result.
        - `Z` is missing.
        - `X` or `Z` has fewer than 3 columns.
        - `color_vec` has the wrong length.
    """
    # Validation input
    if isinstance(X, sc.AnnData):
        if "archetypal_analysis" not in X.uns:
            raise ValueError("Result from Archetypal Analysis not found in adata.uns. Please run AA()")

        Z = X.uns["archetypal_analysis"]["Z"]
        X = X.obsm["X_pca_reduced"]

    if Z is None:
        raise ValueError("Please add the archetypes coordinates as input Z")

    if X.shape[1] < 3 or Z.shape[1] < 3:
        raise ValueError("Both X and Z must have at least 3 columns (PCs).")

    X_plot, Z_plot = X[:, :3], Z[:, :3]
    n_archetypes = Z_plot.shape[0]

    plot_df = pd.DataFrame(X_plot, columns=["x0", "x1", "x2"])
    plot_df["marker_size"] = np.repeat(marker_size, X_plot.shape[0])

    # Create the 3D scatter plot (only add a color mapping when colors are provided)
    scatter_kwargs = {}
    if color_vec is not None:
        if len(color_vec) != len(plot_df):
            raise ValueError("color_vec must have the same length as X.")
        plot_df["color_vec"] = np.array(color_vec)
        scatter_kwargs["color"] = "color_vec"

    fig = px.scatter_3d(
        plot_df,
        x="x0",
        y="x1",
        z="x2",
        labels={"x0": "PC 1", "x1": "PC 2", "x2": "PC 3"},
        title="3D polytope",
        size="marker_size",
        size_max=10,
        opacity=0.5,
        **scatter_kwargs,
    )

    # Triangular faces of the polytope. qhull needs at least 4 points to build a hull in 3D,
    # so the 3-archetype (single triangle) and 2-archetype (no face) cases are handled directly.
    if n_archetypes >= 4:
        faces = ConvexHull(Z_plot).simplices
    elif n_archetypes == 3:
        faces = np.array([[0, 1, 2]])
    else:
        faces = np.empty((0, 3), dtype=int)

    # Add archetypes to the plot
    archetype_labels = [f"Archetype {i}" for i in range(n_archetypes)]
    fig.add_trace(
        go.Scatter3d(
            x=Z_plot[:, 0],
            y=Z_plot[:, 1],
            z=Z_plot[:, 2],
            mode="markers",
            marker={"size": 4, "color": color_polyhedron, "symbol": "circle"},
            text=archetype_labels,
            hoverinfo="text",
            name="Archetypes",
        )
    )

    if len(faces) > 0:
        # Add the polytope surface to the plot
        fig.add_trace(
            go.Mesh3d(
                x=Z_plot[:, 0],
                y=Z_plot[:, 1],
                z=Z_plot[:, 2],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color=color_polyhedron,
                opacity=0.1,
            )
        )
        # Closed outline of every triangular face
        edge_paths = [np.append(face, face[0]) for face in faces]
    else:
        # No face to draw: connect the archetypes directly (a segment for 2 archetypes)
        edge_paths = [np.arange(n_archetypes)]

    # Add edges of the polytope to the plot
    for path in edge_paths:
        fig.add_trace(
            go.Scatter3d(
                x=Z_plot[path, 0],
                y=Z_plot[path, 1],
                z=Z_plot[path, 2],
                mode="lines",
                line={"color": color_polyhedron, "width": 4},
                showlegend=False,
            )
        )

    fig.update_layout(template="none")
    return fig

724 

725 

def align_archetypes(ref_arch: np.ndarray, query_arch: np.ndarray) -> np.ndarray:
    """
    Reorder the query archetypes so that row i best matches row i of the reference.

    The matching minimizes the total Euclidean distance between paired archetypes. It is solved
    exactly with the Hungarian algorithm (`scipy.optimize.linear_sum_assignment`).

    Parameters
    ----------
    ref_arch : np.ndarray
        Reference archetypes, shape (n_archetypes, n_features).
    query_arch : np.ndarray
        Archetypes to reorder, shape (n_archetypes, n_features).

    Returns
    -------
    np.ndarray
        The rows of `query_arch` permuted to follow the reference order, shape
        (n_archetypes, n_features). The input array is not modified.
    """
    cost = cdist(ref_arch, query_arch, metric="euclidean")
    # For a square cost matrix the returned row indices are 0..n-1, so the column indices alone
    # give the matched query archetype for each reference archetype in turn.
    matched_query_rows = linear_sum_assignment(cost)[1]
    return query_arch[matched_query_rows, :]

753 

754 

755# def bootstrap_variance_k_arr(X, n_bootstrap, k_arr, delta=0, seed=42, **kwargs): 

756# assert k_arr.min() > 1 

757# bootstrap_var = np.array( 

758# [ 

759# bootstrap_variance_single_k( 

760# X, n_bootstrap=n_bootstrap, k=k, delta=delta, seed=seed, **kwargs 

761# ) 

762# for k in k_arr 

763# ] 

764# ) 

765# plot_df = pd.DataFrame({"k": k_arr, "var": bootstrap_var}) 

766# p = ( 

767# pn.ggplot(plot_df, pn.aes(x="k", y="var")) 

768# + pn.geom_point(color="blue") 

769# + pn.geom_line(color="blue") 

770# + pn.labs(x="Number of Archetypes", y="Mean Variance in Archetype Position") 

771# ) 

772# return p 

773 

774 

775# Appendix 

776 

777# not sure if this ratio of archetype over data variance is useful in any way 

778# def compute_var_ratio_vitali(X, Z): 

779# # adapted from: https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L1178 

780# data_var = X.var(axis=1) 

781# arch_var = Z.var(axis=1) 

782# return arch_var / data_var 

783 

784# see https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L257 

785# archetypes = Z.A.T 

786# print(archetypes.shape) 

787# 

788## create random matrix where each row sums to 1 

789# n_additional = 100 

790# np.random.seed(42) 

791# rand_mtx = np.random.rand(n_additional, archetypes.shape[0]) 

792# rand_mtx /= rand_mtx.sum(axis=1)[:, None] 

793# 

794# additional_archetypes = (rand_mtx @ archetypes) 

795# print(additional_archetypes.shape) 

796# 

797# stacked_archetypes = np.row_stack([archetypes, additional_archetypes]) 

798# print(stacked_archetypes.shape) 

799# 

800## https://github.com/vitkl/ParetoTI/blob/510990630da589101c6a8313571c96f7544879da/R/fit_pch.R#L879C39-L879C46 

801## Vitali used the "FA" argument, but this is turned on by default I think in the scipy version (FA - report total area and volume), 

802## see http://www.qhull.org/html/qhull.htm 

803# convhull_archetypes = ConvexHull(stacked_archetypes, qhull_options="QJ") 

804# print(convhull_archetypes.volume)