Coverage for src / autoencodix / utils / example_data.py: 0%

136 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import numpy as np 

2import pandas as pd 

3from typing import Tuple 

4import torch 

5from sklearn.datasets import make_blobs # type: ignore 

6from sklearn.model_selection import train_test_split # type: ignore 

7from autoencodix.data._datasetcontainer import DatasetContainer 

8from autoencodix.data._numeric_dataset import NumericDataset 

9from autoencodix.configs.default_config import DefaultConfig 

10from autoencodix.data.datapackage import DataPackage 

11from autoencodix.configs.default_config import DataCase 

12import mudata 

13import anndata 

14 

# Module-level default configuration. Its ``global_seed`` is read once, when
# the generator functions below are defined, as their default ``random_seed``.
config: DefaultConfig = DefaultConfig()

16 

17 

def _split_example_indices(n_samples, labels, valid_ratio, test_ratio, random_seed):
    """Split ``range(n_samples)`` into stratified (train, valid, test) index arrays.

    First carves off ``valid_ratio + test_ratio`` of the samples as a hold-out,
    then divides the hold-out between validation and test, stratifying by
    ``labels`` in both stages. A ratio of 0 yields an empty index array for that
    split instead of the ``ValueError`` sklearn raises for a ``test_size`` of
    0.0 or 1.0.
    """
    all_idx = np.arange(n_samples)
    empty_idx = np.empty(0, dtype=all_idx.dtype)

    holdout_ratio = valid_ratio + test_ratio
    if holdout_ratio <= 0:
        # Nothing held out: every sample is training data.
        return all_idx, empty_idx, empty_idx

    train_idx, holdout_idx = train_test_split(
        all_idx,
        test_size=holdout_ratio,
        random_state=random_seed,
        stratify=labels,  # keep cluster proportions balanced across splits
    )
    if valid_ratio <= 0:
        return train_idx, empty_idx, holdout_idx
    if test_ratio <= 0:
        return train_idx, holdout_idx, empty_idx

    # Share of the hold-out that becomes validation data: with valid=0.1 and
    # test=0.2, validation gets 1/3 of the hold-out and test the other 2/3.
    valid_share = valid_ratio / holdout_ratio
    val_idx, test_idx = train_test_split(
        holdout_idx,
        test_size=1 - valid_share,  # the remainder of the hold-out goes to test
        random_state=random_seed,
        stratify=labels[holdout_idx],
    )
    return train_idx, val_idx, test_idx


def generate_example_data(
    n_samples=1000, n_features=30, n_clusters=5, random_seed=config.global_seed
) -> DatasetContainer:
    """Generate synthetic clustered data for autoencoder testing and examples.

    Samples are drawn with ``sklearn.datasets.make_blobs``. Per-sample metadata
    is attached, some of it correlated with the cluster assignment. The samples
    are then split into train/validation/test subsets using ``valid_ratio`` and
    ``test_ratio`` from a fresh ``DefaultConfig``, stratified by cluster.

    Side effect: seeds NumPy's and PyTorch's global RNGs with ``random_seed``.

    Args:
        n_samples: Number of samples to generate, default=1000.
        n_features: Number of features for each sample, default=30.
        n_clusters: Number of clusters to generate, default=5.
        random_seed: Random seed for reproducibility. The default is read from
            the module-level config once, when the function is defined.

    Returns:
        DatasetContainer: Container with train, validation and test datasets.
        A split whose configured ratio is 0 is empty.

    Raises:
        ValueError: Propagated from ``train_test_split`` if the configured
            ratios are invalid (e.g. ``valid_ratio + test_ratio >= 1``).
    """
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    X, cluster_labels = make_blobs(
        n_samples=n_samples,
        n_features=n_features,
        centers=n_clusters,
        cluster_std=1.8,
        random_state=random_seed,
    )

    # Convert integer cluster labels to string labels like "Cluster_1".
    str_cluster_labels = np.array([f"Cluster_{i + 1}" for i in cluster_labels])

    # Metadata features, most of which correlate with the clusters. Dict values
    # are evaluated in order, which fixes the order of the RNG draws.
    metadata_df = pd.DataFrame(
        {
            "cluster": str_cluster_labels,
            "age": np.random.normal(30, 10, n_samples)
            + cluster_labels * 5,  # Age correlates with cluster
            "size": np.random.uniform(0, 10, n_samples)
            + cluster_labels * 2,  # Size correlates with cluster
            "density": np.random.exponential(1, n_samples)
            * (cluster_labels + 1)
            / 3,  # Density correlates with cluster
            "category": np.random.choice(
                ["A", "B", "C", "D", "E"], n_samples, p=[0.2, 0.2, 0.2, 0.2, 0.2]
            ),  # Random category
        }
    )

    # Pure-noise metadata feature that does not correlate with the clusters.
    metadata_df["random_feature"] = np.random.normal(0, 1, n_samples)

    ids = [f"sample_{i}" for i in range(n_samples)]
    metadata_df["sample_id"] = ids
    metadata_df.index = ids  # also index the metadata by sample id
    data_tensor = torch.tensor(X, dtype=torch.float32)

    # Split ratios come from the default configuration.
    split_config = DefaultConfig()
    train_idx, val_idx, test_idx = _split_example_indices(
        n_samples,
        cluster_labels,
        split_config.valid_ratio,
        split_config.test_ratio,
        random_seed,
    )

    feature_ids = [f"feature_{i}" for i in range(n_features)]

    def _build_dataset(idx):
        """Wrap rows ``idx`` in a NumericDataset with a full-length split mask."""
        # Boolean mask over all n_samples marking the rows of this split.
        split_mask = np.zeros(n_samples, dtype=bool)
        split_mask[idx] = True
        return NumericDataset(
            data=data_tensor[idx],
            config=DefaultConfig(),  # each dataset gets its own config instance
            sample_ids=[ids[i] for i in idx],
            metadata=metadata_df.iloc[idx],
            split_indices=split_mask,
            feature_ids=list(feature_ids),  # independent list per dataset
        )

    return DatasetContainer(
        train=_build_dataset(train_idx),
        valid=_build_dataset(val_idx),
        test=_build_dataset(test_idx),
    )

137 

138 

def generate_multi_bulk_example(
    n_samples=500,
    n_features_modality1=100,
    n_features_modality2=80,
    random_seed=config.global_seed,
) -> Tuple[DataPackage, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Generate example data for the MULTI_BULK case.

    Two bulk modalities ("transcriptomics" and "proteomics") are produced as
    noisy linear projections of one shared 10-dimensional standard-normal
    latent matrix, so they share sample-level structure. NumPy's global RNG
    is seeded with ``random_seed``.

    Args:
        n_samples: Number of samples (rows) to generate.
        n_features_modality1: Number of ``gene_*`` features in the first modality.
        n_features_modality2: Number of ``protein_*`` features in the second modality.
        random_seed: Random seed for reproducibility. The default is read from
            the module-level config once, when the function is defined.

    Returns:
        ``(data_package, df_mod1, df_mod2, annotation_df)``:
            data_package: ``multi_bulk`` maps "transcriptomics"/"proteomics" to
                the two DataFrames; ``annotation`` holds a separate copy of
                ``annotation_df`` under each of those keys.
            df_mod1: samples x ``gene_*`` DataFrame.
            df_mod2: samples x ``protein_*`` DataFrame.
            annotation_df: per-sample ``condition``, ``batch`` and
                ``quality_score`` columns.
    """
    np.random.seed(random_seed)

    n_latent = 10
    shared_latent = np.random.normal(0, 1, (n_samples, n_latent))
    loadings_1 = np.random.normal(0, 1, (n_latent, n_features_modality1))
    loadings_2 = np.random.normal(0, 1, (n_latent, n_features_modality2))

    # Observed values = latent projection + Gaussian noise (sd 0.5). The noise
    # for modality 1 is drawn before modality 2, which keeps results reproducible.
    values_1 = shared_latent.dot(loadings_1)
    values_1 = values_1 + np.random.normal(
        0, 0.5, (n_samples, n_features_modality1)
    )
    values_2 = shared_latent.dot(loadings_2)
    values_2 = values_2 + np.random.normal(
        0, 0.5, (n_samples, n_features_modality2)
    )

    sample_ids = [f"sample_{k}" for k in range(n_samples)]
    df_mod1 = pd.DataFrame(
        values_1,
        index=sample_ids,
        columns=[f"gene_{k}" for k in range(n_features_modality1)],
    )
    df_mod2 = pd.DataFrame(
        values_2,
        index=sample_ids,
        columns=[f"protein_{k}" for k in range(n_features_modality2)],
    )

    # "condition" records which of the first three latent factors is largest,
    # so it correlates with the data; batch and quality are independent noise.
    dominant = shared_latent[:, :3].argmax(axis=1)
    conditions = [f"condition_{g}" for g in dominant]
    batches = np.random.choice(["batch_1", "batch_2", "batch_3"], size=n_samples)
    quality = np.random.uniform(0.6, 1.0, n_samples)
    annotation_df = pd.DataFrame(
        {"condition": conditions, "batch": batches, "quality_score": quality},
        index=sample_ids,
    )

    data_package = DataPackage()
    data_package.multi_bulk = {"transcriptomics": df_mod1, "proteomics": df_mod2}
    data_package.annotation = {
        modality: annotation_df.copy()
        for modality in ("transcriptomics", "proteomics")
    }

    return data_package, df_mod1, df_mod2, annotation_df

204 

205 

def generate_default_bulk_bulk_example(
    n_samples=500, n_features=200, random_seed=config.global_seed
):
    """Placeholder for BULK_TO_BULK example data (not implemented yet).

    Seeds NumPy's global RNG with ``random_seed`` and then raises;
    ``n_samples`` and ``n_features`` are currently unused.

    Raises:
        NotImplementedError: Always.
    """
    np.random.seed(random_seed)
    msg = "BULK_TO_BULK example generation not yet implemented"
    raise NotImplementedError(msg)

212 

213 

def generate_default_sc_sc_example(
    n_cells=1000, n_features=500, random_seed=config.global_seed
):
    """Placeholder for SINGLE_CELL_TO_SINGLE_CELL example data (not implemented yet).

    Seeds NumPy's global RNG with ``random_seed`` and then raises;
    ``n_cells`` and ``n_features`` are currently unused.

    Raises:
        NotImplementedError: Always.
    """
    np.random.seed(random_seed)
    msg = "SINGLE_CELL_TO_SINGLE_CELL example generation not yet implemented"
    raise NotImplementedError(msg)

222 ) 

223 

224 

def generate_default_img_bulk_example(n_samples=100, random_seed=config.global_seed):
    """Placeholder for IMG_TO_BULK example data (not implemented yet).

    Seeds NumPy's global RNG with ``random_seed`` and then raises;
    ``n_samples`` is currently unused.

    Raises:
        NotImplementedError: Always.
    """
    np.random.seed(random_seed)
    msg = "IMG_TO_BULK example generation not yet implemented"
    raise NotImplementedError(msg)

231 

232 

def generate_default_sc_img_example(n_samples=100, random_seed=config.global_seed):
    """Placeholder for SINGLE_CELL_TO_IMG example data (not implemented yet).

    Seeds NumPy's global RNG with ``random_seed`` and then raises;
    ``n_samples`` is currently unused.

    Raises:
        NotImplementedError: Always.
    """
    np.random.seed(random_seed)
    msg = "SINGLE_CELL_TO_IMG example generation not yet implemented"
    raise NotImplementedError(msg)

239 ) 

240 

241 

def generate_default_img_img_example(n_samples=100, random_seed=config.global_seed):
    """Placeholder for IMG_TO_IMG example data (not implemented yet).

    Seeds NumPy's global RNG with ``random_seed`` and then raises;
    ``n_samples`` is currently unused.

    Raises:
        NotImplementedError: Always.
    """
    np.random.seed(random_seed)
    msg = "IMG_TO_IMG example generation not yet implemented"
    raise NotImplementedError(msg)

247 

248 

def generate_raw_datapackage(data_case: DataCase, **kwargs):
    """Dispatch to the example-data generator registered for ``data_case``.

    Args:
        data_case: The DataCase to generate example data for.
        **kwargs: Forwarded unchanged to the selected generator function.

    Returns:
        Whatever the selected generator returns. For example, MULTI_BULK yields
        ``(DataPackage, df_mod1, df_mod2, annotation_df)`` and
        MULTI_SINGLE_CELL yields ``(DataPackage, rna_adata, mdata)``.

    Raises:
        ValueError: If no generator is registered for ``data_case``.
        NotImplementedError: From generators that are still placeholders.
    """
    dispatch = {
        DataCase.MULTI_BULK: generate_multi_bulk_example,
        DataCase.MULTI_SINGLE_CELL: generate_multi_sc_example,
        DataCase.BULK_TO_BULK: generate_default_bulk_bulk_example,
        DataCase.SINGLE_CELL_TO_SINGLE_CELL: generate_default_sc_sc_example,
        DataCase.IMG_TO_BULK: generate_default_img_bulk_example,
        DataCase.SINGLE_CELL_TO_IMG: generate_default_sc_img_example,
        DataCase.IMG_TO_IMG: generate_default_img_img_example,
    }

    generator = dispatch.get(data_case)
    if generator is None:
        raise ValueError(f"Unknown data case: {data_case}")
    return generator(**kwargs)

274 

275 

def generate_multi_sc_example(
    n_cells=1000, n_genes=500, n_proteins=200, random_seed=config.global_seed
):
    """Generate example data for the MULTI_SINGLE_CELL case.

    Simulates an RNA count modality and a protein modality for ``n_cells``
    cells drawn from five synthetic cell types, and wraps both in a MuData
    object. NumPy's global RNG is seeded with ``random_seed``.

    Args:
        n_cells: Number of cells to generate.
        n_genes: Number of genes for the RNA modality.
        n_proteins: Number of proteins for the protein modality.
        random_seed: Random seed for reproducibility. The default is read from
            the module-level config once, when the function is defined.

    Returns:
        ``(data_package, rna_adata, mdata)``:
            data_package: DataPackage whose ``multi_sc`` is ``{"multi_sc": mdata}``.
            rna_adata: AnnData of RNA counts (cells x genes, CSR sparse ``X``).
            mdata: MuData with modalities ``"rna"`` and ``"protein"``.
    """
    from scipy import sparse

    np.random.seed(random_seed)

    cell_ids = [f"cell_{i}" for i in range(n_cells)]
    gene_names = [f"gene_{i}" for i in range(n_genes)]
    protein_names = [f"protein_{i}" for i in range(n_proteins)]

    # Cell-type frequencies are themselves random (flat Dirichlet prior).
    n_cell_types = 5
    cell_type_probabilities = np.random.dirichlet(np.ones(n_cell_types), size=1)[0]
    cell_types = np.random.choice(
        np.arange(n_cell_types), size=n_cells, p=cell_type_probabilities
    )

    # Dict values are evaluated in order, which fixes the order of the RNG draws.
    cell_metadata = pd.DataFrame(
        {
            "cell_type": [f"type_{t}" for t in cell_types],
            "batch": np.random.choice(["batch1", "batch2", "batch3"], size=n_cells),
            "donor": np.random.choice(
                ["donor1", "donor2", "donor3", "donor4"], size=n_cells
            ),
            "cell_cycle": np.random.choice(["G1", "S", "G2M"], size=n_cells),
        },
        index=cell_ids,
    )

    # RNA: each cell type up-regulates a random 20% of genes. Per-cell rates add
    # gamma noise and are scaled by a log-normal sequencing depth. Counts are
    # then drawn from a Poisson and subjected to random dropout.
    gene_programs = np.zeros((n_cell_types, n_genes))
    for i in range(n_cell_types):
        high_expr_genes = np.random.choice(
            n_genes, size=int(n_genes * 0.2), replace=False
        )
        gene_programs[i, high_expr_genes] = np.random.gamma(
            5, 2, size=len(high_expr_genes)
        )

    lambda_values = np.zeros((n_cells, n_genes))
    for i, cell_type in enumerate(cell_types):
        # One noise draw per cell, in cell order (keeps the RNG stream stable).
        lambda_values[i] = gene_programs[cell_type] + np.random.gamma(
            0.5, 0.5, n_genes
        )

    seq_depth = np.random.lognormal(mean=8, sigma=0.5, size=n_cells)
    lambda_values = lambda_values * seq_depth[:, np.newaxis]

    rna_counts = np.random.poisson(lambda_values)
    dropout_prob = 0.3
    # Mask entries are 0 with probability dropout_prob, zeroing that count.
    dropout_mask = np.random.binomial(1, 1 - dropout_prob, size=rna_counts.shape)
    rna_counts = rna_counts * dropout_mask

    # Protein: each cell type up-regulates a random 30% of proteins. Levels then
    # get a per-batch multiplicative effect and additive Gaussian noise, and are
    # clipped at zero.
    protein_programs = np.zeros((n_cell_types, n_proteins))
    for i in range(n_cell_types):
        high_expr_proteins = np.random.choice(
            n_proteins, size=int(n_proteins * 0.3), replace=False
        )
        protein_programs[i, high_expr_proteins] = np.random.gamma(
            3, 1, size=len(high_expr_proteins)
        )

    protein_levels = np.zeros((n_cells, n_proteins))
    for i, cell_type in enumerate(cell_types):
        protein_levels[i] = protein_programs[cell_type] + np.random.gamma(
            0.2, 0.3, n_proteins
        )

    # Per-cell multiplier looked up from each cell's batch label (vectorized).
    batch_map = {"batch1": 0.8, "batch2": 1.0, "batch3": 1.2}
    batch_effect = cell_metadata["batch"].map(batch_map).to_numpy(dtype=float)

    protein_levels = protein_levels * batch_effect[:, np.newaxis]
    protein_levels = protein_levels + np.random.normal(
        0, 0.1, size=protein_levels.shape
    )
    protein_levels = np.maximum(0, protein_levels)

    var_rna = pd.DataFrame(index=gene_names)
    var_protein = pd.DataFrame(index=protein_names)

    # Convert to scipy sparse matrices to avoid copy issues.
    rna_counts_sparse = sparse.csr_matrix(rna_counts)
    protein_levels_sparse = sparse.csr_matrix(protein_levels)

    # NOTE(review): anndata may keep a reference or shallow copy of ``obs``.
    # Giving each modality its own deep copy keeps their obs frames independent.
    rna_adata = anndata.AnnData(
        X=rna_counts_sparse, obs=cell_metadata.copy(), var=var_rna
    )
    protein_adata = anndata.AnnData(
        X=protein_levels_sparse, obs=cell_metadata.copy(), var=var_protein
    )

    mdata = mudata.MuData({"rna": rna_adata, "protein": protein_adata})

    data_package = DataPackage()
    data_package.multi_sc = {"multi_sc": mdata}
    # NOTE: data_package.annotation is intentionally left unset for this case;
    # the per-cell metadata is carried on each modality's ``.obs``.

    return data_package, rna_adata, mdata

388 

389 

# Pre-generated example data for direct import.
# NOTE: these are built at import time, so importing this module runs all three
# generators. Doing so reseeds NumPy's global RNG, and also PyTorch's via
# generate_example_data. The module-level ``config`` defined near the top of
# this file is reused rather than creating a second DefaultConfig. The explicit
# seed below therefore comes from the same instance that supplied the
# generators' ``random_seed`` defaults.
EXAMPLE_PROCESSED_DATA = generate_example_data(random_seed=config.global_seed)
# raw_rna / raw_protein are the transcriptomics / proteomics DataFrames.
EXAMPLE_MULTI_BULK, raw_rna, raw_protein, annotation = generate_raw_datapackage(
    data_case=DataCase.MULTI_BULK
)
EXAMPLE_MULTI_SC, sample_adata, sample_mudata = generate_raw_datapackage(
    data_case=DataCase.MULTI_SINGLE_CELL
)