Coverage for src / autoencodix / data / _datasplitter.py: 81%

154 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import itertools 

2import pandas as pd 

3from typing import Dict, Optional, Any, Set, List, Sequence, Tuple 

4 

5import numpy as np 

6from sklearn.model_selection import train_test_split # type: ignore 

7 

8from autoencodix.configs.default_config import DefaultConfig 

9from autoencodix.data.datapackage import DataPackage 

10from collections import defaultdict 

11 

12 

13# internal check done 

14# write tests: done 

class DataSplitter:
    """Split sample indices into train, validation, and test sets.

    The splitter validates the configured ratios on construction and either
    derives random index splits from them or returns user-supplied custom
    splits after checking them.

    Empty splits are allowed (e.g. ``test_ratio=0``). This might raise an
    error later in the pipeline when that split is expected to be non-empty,
    but it allows a more flexible usage of the pipeline (e.g. when the user
    only wants to run the fit step).

    Constraints:
        1. Split ratios must sum to 1 (within floating point tolerance).
        2. Each non-empty split must have at least ``min_samples_per_split``
           samples.
        3. Every split ratio must lie in ``[0, 1]``.
        4. Custom splits must contain the keys ``'train'``, ``'valid'`` and
           ``'test'`` and must not share indices.

    Attributes:
        _config: Configuration object the ratios, seed and minimum split size
            are read from.
        _custom_splits: Optional pre-defined split indices.
        _test_ratio: Fraction of samples assigned to the test split.
        _valid_ratio: Fraction of samples assigned to the validation split.
        _train_ratio: Fraction of samples assigned to the training split.
        _min_samples: Minimum number of samples a non-empty split must hold.
    """

    # Canonical split names; a tuple keeps validation order and error
    # messages deterministic (a set would iterate in hash order).
    _SPLIT_NAMES = ("train", "valid", "test")

    def __init__(
        self,
        # Quoted so the annotation is not evaluated at definition time.
        config: "DefaultConfig",
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
    ):
        """Initialize the splitter and validate its configuration.

        Args:
            config: Configuration object providing ``test_ratio``,
                ``valid_ratio``, ``train_ratio``, ``min_samples_per_split``
                and ``global_seed``.
            custom_splits: Optional mapping with the keys ``'train'``,
                ``'valid'`` and ``'test'`` holding pre-defined indices.

        Raises:
            ValueError: If the ratios or the custom splits are invalid.
        """
        self._config = config
        self._test_ratio = self._config.test_ratio
        self._valid_ratio = self._config.valid_ratio
        self._train_ratio = self._config.train_ratio
        self._min_samples = self._config.min_samples_per_split
        self._custom_splits = custom_splits

        self._validate_ratios()
        if self._custom_splits:
            self._validate_custom_splits(self._custom_splits)

    @staticmethod
    def _empty_indices() -> np.ndarray:
        """Return a fresh, empty integer index array."""
        return np.array([], dtype=int)

    def _validate_ratios(self) -> None:
        """Validate that the splitting ratios meet the required constraints.

        Raises:
            ValueError: If a ratio lies outside ``[0, 1]`` or the ratios do
                not sum to 1.
        """
        labelled_ratios = (
            ("Test", self._test_ratio),
            ("Validation", self._valid_ratio),
            ("Train", self._train_ratio),
        )
        for label, ratio in labelled_ratios:
            if not 0 <= ratio <= 1:
                raise ValueError(
                    f"{label} ratio must be between 0 and 1, got {ratio}"
                )

        # Compare with a tolerance: binary floats rarely sum to exactly 1.0
        # (0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999).
        total = self._test_ratio + self._valid_ratio + self._train_ratio
        if not np.isclose(total, 1.0, rtol=0.0, atol=1e-9):
            raise ValueError("Split ratios must sum to 1")

    def _validate_split_sizes(self, n_samples: int) -> None:
        """Validate that each non-empty split will have sufficient samples.

        The sizes are estimates obtained by flooring ``n_samples * ratio``.

        Args:
            n_samples: Total number of samples in the dataset.

        Raises:
            ValueError: If any split with a non-zero ratio would have fewer
                than the configured minimum number of samples.
        """
        # Use the configured train ratio directly; deriving it as
        # ``1 - test - valid`` accumulates float error and can floor the
        # estimate one sample too low (1 - 0.9 == 0.09999999999999998).
        expected_sizes = (
            ("Training", self._train_ratio),
            ("Validation", self._valid_ratio),
            ("Test", self._test_ratio),
        )
        for label, ratio in expected_sizes:
            if ratio <= 0:
                continue  # empty splits are explicitly allowed
            expected = int(n_samples * ratio)
            if expected < self._min_samples:
                raise ValueError(
                    f"{label} set would have {expected} samples, "
                    f"which is less than minimum required ({self._min_samples})"
                )

    def _validate_custom_splits(self, splits: Dict[str, np.ndarray]) -> None:
        """Validate custom splits for correctness.

        Args:
            splits: Custom split indices keyed by split name.

        Raises:
            ValueError: If a required key is missing, a split is smaller than
                the configured minimum, or two splits share indices.
        """
        missing = [name for name in self._SPLIT_NAMES if name not in splits]
        if missing:
            raise ValueError(
                f"Custom splits must contain all of: {self._SPLIT_NAMES}\n"
                f"Got: {list(splits)}\n"
                "if you want to pass empty splits, pass an empty array"
            )

        # The training split must always satisfy the minimum size.
        n_train = len(splits["train"])
        if n_train < self._min_samples:
            raise ValueError(
                f"Custom training split has {n_train} samples, "
                f"which is less than minimum required ({self._min_samples})"
            )

        # Validation and test splits may be empty, but not undersized.
        for key, label in (("valid", "validation"), ("test", "test")):
            size = len(splits[key])
            if 0 < size < self._min_samples:
                raise ValueError(
                    f"Custom {label} split has {size} samples, "
                    f"which is less than minimum required ({self._min_samples})"
                )

        # Check for overlap between splits; build each index set only once.
        index_sets = {name: set(splits[name]) for name in self._SPLIT_NAMES}
        for first, second in itertools.combinations(self._SPLIT_NAMES, 2):
            intersection = index_sets[first] & index_sets[second]
            if intersection:
                raise ValueError(
                    f"Overlapping indices found between splits '{first}' and '{second}': {intersection}"
                )

    def split(
        self,
        n_samples: int,
    ) -> Dict[str, np.ndarray]:
        """Split data into train, validation, and test sets.

        When custom splits were supplied they are bounds-checked against
        ``n_samples`` and returned unchanged. Otherwise indices
        ``0..n_samples-1`` are divided according to the configured ratios,
        seeded with ``config.global_seed``.

        Args:
            n_samples: Total number of samples in the dataset.

        Returns:
            Dictionary with the keys ``'train'``, ``'valid'`` and ``'test'``
            containing the indices for each split, with empty arrays for
            splits whose ratio is 0.

        Raises:
            ValueError: If the ratio-based splits would violate the minimum
                size constraint.
            AssertionError: If custom split indices fall outside
                ``[0, n_samples - 1]``.
        """
        if self._custom_splits:
            # Custom splits override the ratios, so the ratio-based size
            # estimate does not apply; only the index range is checked here.
            max_index = n_samples - 1
            for split_indices in self._custom_splits.values():
                if len(split_indices) == 0:
                    continue
                if np.max(split_indices) > max_index or np.min(split_indices) < 0:
                    raise AssertionError(
                        f"Custom split indices must be within range [0, {max_index}]"
                    )
            return self._custom_splits

        self._validate_split_sizes(n_samples)
        indices = np.arange(n_samples)
        seed = self._config.global_seed

        no_train = self._train_ratio == 0
        no_valid = self._valid_ratio == 0
        no_test = self._test_ratio == 0

        # Only one non-empty split: it receives every index, unshuffled.
        # (All three being zero is ruled out by _validate_ratios.)
        if no_valid and no_test:
            return {
                "train": indices,
                "valid": self._empty_indices(),
                "test": self._empty_indices(),
            }
        if no_train and no_valid:
            return {
                "train": self._empty_indices(),
                "valid": self._empty_indices(),
                "test": indices,
            }
        if no_train and no_test:
            return {
                "train": self._empty_indices(),
                "valid": indices,
                "test": self._empty_indices(),
            }

        # Exactly one empty split: a single two-way split is sufficient.
        if no_train:
            valid_indices, test_indices = train_test_split(
                indices,
                test_size=self._test_ratio,
                random_state=seed,
            )
            return {
                "train": self._empty_indices(),
                "valid": valid_indices,
                "test": test_indices,
            }

        if no_test:
            train_indices, valid_indices = train_test_split(
                indices,
                test_size=self._valid_ratio,
                random_state=seed,
            )
            return {
                "train": train_indices,
                "valid": valid_indices,
                "test": self._empty_indices(),
            }

        if no_valid:
            train_indices, test_indices = train_test_split(
                indices,
                test_size=self._test_ratio,
                random_state=seed,
            )
            return {
                "train": train_indices,
                "valid": self._empty_indices(),
                "test": test_indices,
            }

        # Normal case: carve off the test set, then split the remainder.
        train_valid_indices, test_indices = train_test_split(
            indices, test_size=self._test_ratio, random_state=seed
        )

        # Rescale the validation ratio to the samples left after the test
        # split so the final validation share matches the configured ratio.
        train_indices, valid_indices = train_test_split(
            train_valid_indices,
            test_size=self._valid_ratio / (1 - self._test_ratio),
            random_state=seed,
        )

        return {"train": train_indices, "valid": valid_indices, "test": test_indices}

269 

270 

class PairedUnpairedSplitter:
    """Performs pairing-aware data splitting across multiple modalities.

    Handles any number of data modalities and groups samples by the exact
    set of modalities they appear in. Each sample is assigned to exactly one
    split (train, valid, or test). If a sample appears in multiple
    modalities, it is guaranteed to appear in the same split across all of
    them.

    Annotation tables are split consistently with the modalities: annotation
    rows whose sample ID is not part of the relevant modality split are left
    out of the annotation indices.

    Attributes:
        datapackage: The input data package containing modalities and
            annotations.
        config: Configuration object with split ratios and random seed.
        membership_groups: Mapping of modality combinations to the set of
            sample IDs belonging to each combination (e.g., RNA+Protein
            pairs).
    """

    _SPLIT_NAMES = ("train", "valid", "test")

    def __init__(self, data_package, config):
        """Initializes the splitter and computes modality membership groups.

        Args:
            data_package: The full data package to split. Must implement
                `_get_sample_ids` and iterable access yielding
                (key, object) pairs.
            config: Split configuration object; ``train_ratio``,
                ``valid_ratio``, ``global_seed`` and ``requires_paired`` are
                read from it.

        Raises:
            TypeError: If `data_package` does not provide `_get_sample_ids`.
        """
        if not hasattr(data_package, "_get_sample_ids"):
            raise TypeError("data_package must be an instance of DataPackage.")

        self.datapackage = data_package
        self.config = config
        self.membership_groups: Dict[Tuple[str, ...], Set[str]] = (
            self._compute_membership_groups()
        )

    def _compute_membership_groups(self) -> Dict[Tuple[str, ...], Set[str]]:
        """Groups samples by the set of modalities in which they appear.

        Keys containing ``"annotation"`` are skipped.

        Returns:
            A mapping from modality combinations (sorted tuples of modality
            keys) to the set of sample IDs belonging to each combination.

        Examples:
            This mapping could look like:
            {
                ("multi_bulk.cna", "multi_bulk.rna"): {"id1", "id3", ...},
                ("img.img", "multi_bulk.rna"): {"id2", "id4", ...},
            }
        """
        sample_to_modalities: Dict[str, Set[str]] = defaultdict(set)
        for modality_key, obj in self.datapackage:
            if "annotation" in modality_key:
                continue
            ids: List[str] = self.datapackage._get_sample_ids(obj)

            for sid in ids:
                sample_to_modalities[sid].add(modality_key)

        # Group samples by identical modality membership.
        groups: Dict[Tuple[str, ...], Set[str]] = defaultdict(set)
        for sid, mods in sample_to_modalities.items():
            groups[tuple(sorted(mods))].add(sid)
        return groups

    def _split_group(self, ids: Sequence[str]) -> Dict[str, np.ndarray]:
        """Splits a homogeneous group of sample IDs into train, valid, and test subsets.

        The IDs are sorted before shuffling so that the result depends only
        on the set of IDs and the seed, not on the (hash-dependent) order in
        which the IDs were handed in.

        Args:
            ids: Collection of sample IDs belonging to the same modality group.

        Returns:
            A mapping with keys ``train``, ``valid``, and ``test``, where each
            value is an object array of sample IDs assigned to that split.
        """
        # Sets iterate in hash order, which varies between interpreter runs
        # for strings; sorting first makes the seeded shuffle reproducible.
        ordered = sorted(ids, key=str)
        n = len(ordered)

        # Permute positions rather than the IDs themselves so the IDs keep
        # their original Python type instead of being coerced by NumPy.
        rng = np.random.default_rng(self.config.global_seed)
        shuffled = [ordered[pos] for pos in rng.permutation(n)]

        train_ratio = self.config.train_ratio
        valid_ratio = self.config.valid_ratio
        n_train = int(n * train_ratio)
        n_valid = int(n * valid_ratio)

        # When train + valid already cover everything, no test split was
        # requested: hand the flooring remainder to valid (or train) instead
        # of leaking samples into the test split.
        if np.isclose(train_ratio + valid_ratio, 1.0, rtol=0.0, atol=1e-9):
            if valid_ratio > 0:
                n_valid = n - n_train
            else:
                n_train = n

        boundary = n_train + n_valid
        return {
            "train": np.array(shuffled[:n_train], dtype=object),
            "valid": np.array(shuffled[n_train:boundary], dtype=object),
            "test": np.array(shuffled[boundary:], dtype=object),
        }

    def split(self) -> Dict[str, Dict[str, Dict[str, np.ndarray]]]:
        """Performs the complete pairing-aware split across modalities and annotations.

        Ensures that:
            - Each sample appears in exactly one split across all modalities
              it belongs to.
            - Fully paired samples are synchronized across modalities.
            - Each annotation table is split consistently with its
              corresponding modality.

        Returns:
            Nested mapping of
            ``{parent_key -> {child_key -> {split_name -> np.ndarray(indices)}}}``.
            Indices are positions within the modality's (or annotation's)
            own sample ordering, sorted ascending. Includes splits for both
            modalities and annotation tables.
        """
        # Sort groups by descending number of modalities (most-paired first)
        # so fully-paired samples are assigned before partially-paired ones.
        sorted_groups = sorted(
            self.membership_groups.items(), key=lambda kv: -len(kv[0])
        )
        assigned_ids: Set[str] = set()

        # Initialize split storage for every key in the data package.
        per_modality_splits: Dict[str, Dict[str, Set[str]]] = {
            mod: {name: set() for name in self._SPLIT_NAMES}
            for mod, _ in self.datapackage
        }

        # Assign each membership group to splits.
        for mods_tuple, sids in sorted_groups:
            sids_to_assign = [sid for sid in sids if sid not in assigned_ids]
            if not sids_to_assign:
                continue

            group_splits = self._split_group(sids_to_assign)
            for split_name, split_ids in group_splits.items():
                for mod in mods_tuple:
                    per_modality_splits[mod][split_name].update(split_ids)
                assigned_ids.update(split_ids)

        # Translate sample IDs into positional indices per modality.
        final_indices: Dict[str, Dict[str, Dict[str, np.ndarray]]] = {}
        for full_key, data_obj in self.datapackage:
            # maxsplit=1 tolerates child keys that themselves contain a dot.
            parent_key, child_key = full_key.split(".", 1)
            if parent_key == "annotation":
                continue  # Annotations are handled separately below.
            original_ids = self.datapackage._get_sample_ids(data_obj)

            id_to_pos = {sid: i for i, sid in enumerate(original_ids)}
            modality_out = final_indices.setdefault(parent_key, {})
            modality_out[child_key] = {}

            for split_name in self._SPLIT_NAMES:
                split_ids_for_mod = per_modality_splits.get(full_key, {}).get(
                    split_name, set()
                )
                modality_out[child_key][split_name] = np.array(
                    sorted(id_to_pos[sid] for sid in split_ids_for_mod),
                    dtype=int,
                )

        if self.datapackage.annotation:
            requires_paired = self.config.requires_paired
            single_annotation = len(self.datapackage.annotation) == 1

            for anno_key, anno_df in self.datapackage.annotation.items():
                anno_id_to_pos = {sid: i for i, sid in enumerate(anno_df.index)}
                anno_out = final_indices.setdefault("annotation", {})
                anno_out[anno_key] = {}

                if requires_paired and (single_annotation or anno_key == "paired"):
                    # Shared annotation table: for each split, take the union
                    # of all sample IDs across modalities.
                    for split_name in self._SPLIT_NAMES:
                        split_ids_union = set().union(
                            *[
                                splits[split_name]
                                for splits in per_modality_splits.values()
                            ]
                        )
                        anno_out[anno_key][split_name] = np.array(
                            sorted(
                                anno_id_to_pos[sid]
                                for sid in split_ids_union
                                if sid in anno_id_to_pos
                            ),
                            dtype=int,
                        )
                    continue  # Skip the per-modality annotation handling.

                # Per-modality annotation: use the splits of the modality
                # whose child key equals the annotation key. Annotation
                # entries in per_modality_splits hold no IDs and are skipped
                # so they do not produce empty arrays in final_indices.
                matching_mods = []
                for mod_name in per_modality_splits:
                    parent_key, child_key = mod_name.split(".", 1)
                    if parent_key != "annotation" and child_key == anno_key:
                        matching_mods.append(mod_name)

                for split_name in self._SPLIT_NAMES:
                    # If several modalities match, the last one wins.
                    for mod_name in matching_mods:
                        split_ids = per_modality_splits[mod_name][split_name]
                        anno_out[anno_key][split_name] = np.array(
                            sorted(
                                anno_id_to_pos[sid]
                                for sid in split_ids
                                if sid in anno_id_to_pos
                            ),
                            dtype=int,
                        )

        return final_indices

471 

472 return final_indices