Coverage for src / autoencodix / data / _datasplitter.py: 81%
154 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import itertools
2import pandas as pd
3from typing import Dict, Optional, Any, Set, List, Sequence, Tuple
5import numpy as np
6from sklearn.model_selection import train_test_split # type: ignore
8from autoencodix.configs.default_config import DefaultConfig
9from autoencodix.data.datapackage import DataPackage
10from collections import defaultdict
13# internal check done
14# write tests: done
class DataSplitter:
    """Split sample indices into train, validation, and test sets and validate them.

    Custom (pre-defined) splits can be supplied instead of ratio-based splitting.
    Empty splits are allowed (e.g. ``test_ratio=0``). This might raise an error
    later in the pipeline when that split is expected to be non-empty, but it
    allows a more flexible usage of the pipeline (e.g. when the user only wants
    to run the fit step).

    Constraints:
        1. Split ratios must sum to 1 (within floating-point tolerance)
        2. Each non-empty split must have at least min_samples_per_split samples
        3. Every split ratio must lie within [0, 1]
        4. Custom splits must contain 'train', 'valid', and 'test' keys and
           non-overlapping indices

    Attributes:
        _config: Configuration object the ratios, seed and minimum size are read from
        _custom_splits: Optional pre-defined split indices
        _test_ratio: Fraction of samples for the test split
        _valid_ratio: Fraction of samples for the validation split
        _train_ratio: Fraction of samples for the training split
        _min_samples: Minimum number of samples a non-empty split must contain
    """

    # Fixed order so that validation errors are reported deterministically.
    _SPLIT_NAMES = ("train", "valid", "test")

    def __init__(
        self,
        config: "DefaultConfig",
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
    ):
        """Initialize DataSplitter with configuration and optional custom splits.

        Args:
            config: Configuration object providing ``test_ratio``, ``valid_ratio``,
                ``train_ratio``, ``min_samples_per_split`` and ``global_seed``.
            custom_splits: Pre-defined split indices keyed by 'train', 'valid', 'test'.

        Raises:
            ValueError: If the ratios or the custom splits violate the constraints.
        """
        self._config = config
        self._test_ratio = self._config.test_ratio
        self._valid_ratio = self._config.valid_ratio
        self._train_ratio = self._config.train_ratio
        self._min_samples = self._config.min_samples_per_split
        self._custom_splits = custom_splits

        self._validate_ratios()
        # Truthiness on purpose: ``None`` and an empty dict both mean "no custom splits".
        if self._custom_splits:
            self._validate_custom_splits(self._custom_splits)

    def _validate_ratios(self) -> None:
        """Validate that the splitting ratios meet the required constraints.

        Raises:
            ValueError: If a ratio is outside [0, 1] or the ratios do not sum to 1.
        """
        if not 0 <= self._test_ratio <= 1:
            raise ValueError(
                f"Test ratio must be between 0 and 1, got {self._test_ratio}"
            )
        if not 0 <= self._valid_ratio <= 1:
            raise ValueError(
                f"Validation ratio must be between 0 and 1, got {self._valid_ratio}"
            )
        if not 0 <= self._train_ratio <= 1:
            raise ValueError(
                f"Train ratio must be between 0 and 1, got {self._train_ratio}"
            )

        # Compare with a tight absolute tolerance instead of ``!= 1``: sums of
        # decimal fractions are frequently off by one ulp in binary floating point.
        total = self._test_ratio + self._valid_ratio + self._train_ratio
        if not np.isclose(total, 1.0, rtol=0.0, atol=1e-9):
            raise ValueError("Split ratios must sum to 1")

    def _validate_split_sizes(self, n_samples: int) -> None:
        """Validate that each non-empty ratio-based split will have enough samples.

        Args:
            n_samples: Total number of samples in the dataset.

        Raises:
            ValueError: If any split with a ratio > 0 would have fewer than
                ``min_samples_per_split`` samples.
        """
        # Expected sizes; the training size is derived from the other two ratios.
        n_train = int(n_samples * (1 - self._test_ratio - self._valid_ratio))
        n_valid = int(n_samples * self._valid_ratio) if self._valid_ratio > 0 else 0
        n_test = int(n_samples * self._test_ratio) if self._test_ratio > 0 else 0

        checks = (
            ("Training", self._train_ratio, n_train),
            ("Validation", self._valid_ratio, n_valid),
            ("Test", self._test_ratio, n_test),
        )
        for label, ratio, size in checks:
            if ratio > 0 and size < self._min_samples:
                raise ValueError(
                    f"{label} set would have {size} samples, "
                    f"which is less than minimum required ({self._min_samples})"
                )

    def _validate_custom_splits(self, splits: Dict[str, np.ndarray]) -> None:
        """Validate custom splits for required keys, minimum sizes and overlap.

        Args:
            splits: Custom split indices.

        Raises:
            ValueError: If a required key is missing, a split is too small, or
                two splits share an index.
        """
        if not all(key in splits for key in self._SPLIT_NAMES):
            raise ValueError(
                f"Custom splits must contain all of: {list(self._SPLIT_NAMES)}. "
                f"Got: {list(splits)}. "
                "If you want to pass empty splits, pass an empty array."
            )

        # The training split must always satisfy the minimum size.
        if len(splits["train"]) < self._min_samples:
            raise ValueError(
                f"Custom training split has {len(splits['train'])} samples, "
                f"which is less than minimum required ({self._min_samples})"
            )

        # Validation and test splits may be empty; if not, they need the minimum size.
        if 0 < len(splits["valid"]) < self._min_samples:
            raise ValueError(
                f"Custom validation split has {len(splits['valid'])} samples, "
                f"which is less than minimum required ({self._min_samples})"
            )

        if 0 < len(splits["test"]) < self._min_samples:
            raise ValueError(
                f"Custom test split has {len(splits['test'])} samples, "
                f"which is less than minimum required ({self._min_samples})"
            )

        # Pairwise overlap check, iterated in a fixed order so the reported pair
        # does not depend on set iteration order.
        for k1, k2 in itertools.combinations(self._SPLIT_NAMES, 2):
            intersection = set(splits[k1]) & set(splits[k2])
            if intersection:
                raise ValueError(
                    f"Overlapping indices found between splits '{k1}' and '{k2}': {intersection}"
                )

    def _validate_custom_bounds(
        self, splits: Dict[str, np.ndarray], n_samples: int
    ) -> None:
        """Check that every custom split index lies within ``[0, n_samples - 1]``.

        Raises:
            AssertionError: If any index is out of range. The exception type is
                kept as-is for backward compatibility with existing callers.
        """
        max_index = n_samples - 1
        for part in splits.values():
            if len(part) == 0:
                continue
            if np.max(part) > max_index or np.min(part) < 0:
                raise AssertionError(
                    f"Custom split indices must be within range [0, {max_index}]"
                )

    @staticmethod
    def _assemble(
        train: Optional[np.ndarray] = None,
        valid: Optional[np.ndarray] = None,
        test: Optional[np.ndarray] = None,
    ) -> Dict[str, np.ndarray]:
        """Build the result mapping, using a fresh empty int array for omitted splits."""
        return {
            "train": np.array([], dtype=int) if train is None else train,
            "valid": np.array([], dtype=int) if valid is None else valid,
            "test": np.array([], dtype=int) if test is None else test,
        }

    def split(
        self,
        n_samples: int,
    ) -> Dict[str, np.ndarray]:
        """Split data into train, validation, and test sets.

        When custom splits were supplied they are bounds-checked against
        ``n_samples`` and returned unchanged (the same dict object). The
        ratio-based size validation is skipped in that case because the ratios
        are not used to build the result.

        Args:
            n_samples: Total number of samples in the dataset.

        Returns:
            Dictionary containing indices for each split, with empty arrays for
            splits with ratio=0.

        Raises:
            ValueError: If the ratio-based splits would violate size constraints.
            AssertionError: If custom split indices are out of range.
        """
        if self._custom_splits:
            self._validate_custom_bounds(self._custom_splits, n_samples)
            return self._custom_splits

        self._validate_split_sizes(n_samples)
        indices = np.arange(n_samples)
        seed = self._config.global_seed

        no_train = self._train_ratio == 0
        no_valid = self._valid_ratio == 0
        no_test = self._test_ratio == 0

        # Exactly one non-empty split: it receives all indices, unshuffled.
        # (All three being zero is ruled out by _validate_ratios.)
        if no_test and no_valid:
            return self._assemble(train=indices)
        if no_train and no_valid:
            return self._assemble(test=indices)
        if no_train and no_test:
            return self._assemble(valid=indices)

        # Exactly one empty split: a single shuffled two-way split.
        if no_train:
            valid_indices, test_indices = train_test_split(
                indices,
                test_size=self._test_ratio,
                random_state=seed,
            )
            return self._assemble(valid=valid_indices, test=test_indices)

        if no_test:
            train_indices, valid_indices = train_test_split(
                indices,
                test_size=self._valid_ratio,
                random_state=seed,
            )
            return self._assemble(train=train_indices, valid=valid_indices)

        if no_valid:
            train_indices, test_indices = train_test_split(
                indices,
                test_size=self._test_ratio,
                random_state=seed,
            )
            return self._assemble(train=train_indices, test=test_indices)

        # Normal case: carve off the test set, then split the rest into
        # train/valid with the validation ratio rescaled to the remainder.
        train_valid_indices, test_indices = train_test_split(
            indices, test_size=self._test_ratio, random_state=seed
        )
        train_indices, valid_indices = train_test_split(
            train_valid_indices,
            test_size=self._valid_ratio / (1 - self._test_ratio),
            random_state=seed,
        )

        return {"train": train_indices, "valid": valid_indices, "test": test_indices}
class PairedUnpairedSplitter:
    """Performs pairing-aware data splitting across multiple modalities.

    Handles any number of data modalities and automatically identifies
    fully paired and partially paired samples. Each sample is assigned
    to exactly one split (train, valid, or test). If a sample appears in
    multiple modalities, it is guaranteed to appear in the same split
    across all of them.

    Annotation tables are split consistently with the modalities: annotation
    rows whose IDs were not assigned to a split are left out of the result.

    Attributes:
        datapackage: The input data package containing modalities and annotations.
        config: Configuration object with split ratios and random seed.
        membership_groups: Mapping of modality combinations to the set of
            sample IDs belonging to each combination (e.g., RNA+Protein pairs).
    """

    _SPLIT_NAMES = ("train", "valid", "test")

    def __init__(self, data_package, config):
        """Initializes the splitter and computes modality membership groups.

        Args:
            data_package: The full data package to split. Must implement
                `_get_sample_ids`, expose an `annotation` mapping, and be
                iterable yielding ("parent.child" key, object) pairs.
            config: Split configuration object providing `train_ratio`,
                `valid_ratio`, `global_seed` and `requires_paired`
                (`test_ratio` is read when present).

        Raises:
            TypeError: If `data_package` does not provide `_get_sample_ids`.
        """
        if not hasattr(data_package, "_get_sample_ids"):
            raise TypeError("data_package must be an instance of DataPackage.")

        self.datapackage = data_package
        self.config = config
        self.membership_groups: Dict[Tuple[str, ...], Set[str]] = (
            self._compute_membership_groups()
        )

    @staticmethod
    def _is_annotation_key(full_key: str) -> bool:
        """Return True if the "parent.child" key belongs to the annotation parent."""
        return full_key.partition(".")[0] == "annotation"

    def _compute_membership_groups(self) -> Dict[Tuple[str, ...], Set[str]]:
        """Groups samples by the set of modalities in which they appear.

        Annotation entries are skipped; they are identified by their parent key,
        the same rule `split` uses.

        Returns:
            A mapping from modality combinations (sorted tuples of modality keys)
            to the set of sample IDs belonging to each combination.

        Examples:
            This mapping could look like:
            {("multi_bulk.cna", "multi_bulk.rna"): {"id1", "id3", ...},
             ("img.img", "multi_bulk.rna"): {"id2", "id4", ...}
            }
        """
        sample_to_modalities: Dict[str, Set[str]] = defaultdict(set)
        for modality_key, obj in self.datapackage:
            if self._is_annotation_key(modality_key):
                continue
            ids: List[str] = self.datapackage._get_sample_ids(obj)
            for sid in ids:
                sample_to_modalities[sid].add(modality_key)

        # Group samples by identical modality membership
        groups: Dict[Tuple[str, ...], Set[str]] = defaultdict(set)
        for sid, mods in sample_to_modalities.items():
            groups[tuple(sorted(mods))].add(sid)
        return groups

    def _split_group(self, ids: Sequence[str]) -> Dict[str, np.ndarray]:
        """Splits a homogeneous group of sample IDs into train, valid, and test subsets.

        The IDs are shuffled with a generator seeded from `config.global_seed`,
        then cut into consecutive train / valid / test slices.

        Args:
            ids: Collection of sample IDs belonging to the same modality group.
                The result depends on the order of `ids`.

        Returns:
            A mapping with keys ``train``, ``valid``, and ``test``, where each value
            is an object array of sample IDs assigned to that split.
        """
        id_list = list(ids)
        n = len(id_list)
        rng = np.random.default_rng(self.config.global_seed)
        # Shuffle positions rather than the IDs themselves so the IDs keep their
        # original Python type instead of being coerced into a NumPy string dtype.
        shuffled = [id_list[pos] for pos in rng.permutation(n)]

        n_train = int(n * self.config.train_ratio)
        n_valid = int(n * self.config.valid_ratio)

        # Truncation leaves a remainder that the final slice absorbs. When the
        # test split is configured to be empty, hand the remainder to train
        # (or valid if train is empty too) so that test really stays empty.
        if getattr(self.config, "test_ratio", None) == 0:
            remainder = n - n_train - n_valid
            if self.config.train_ratio > 0:
                n_train += remainder
            else:
                n_valid += remainder

        return {
            "train": np.array(shuffled[:n_train], dtype=object),
            "valid": np.array(shuffled[n_train : n_train + n_valid], dtype=object),
            "test": np.array(shuffled[n_train + n_valid :], dtype=object),
        }

    @staticmethod
    def _positions(id_to_pos: Dict[str, int], ids) -> np.ndarray:
        """Return the sorted positions of those `ids` that are present in `id_to_pos`."""
        return np.array(
            sorted(id_to_pos[sid] for sid in ids if sid in id_to_pos), dtype=int
        )

    def split(self) -> Dict[str, Dict[str, Dict[str, np.ndarray]]]:
        """Performs the complete pairing-aware split across modalities and annotations.

        Ensures that:
            - Each sample appears in exactly one split across all modalities it belongs to.
            - Fully paired samples are synchronized across modalities.
            - Each annotation table is split consistently with its corresponding modality.

        Returns:
            Nested mapping of ``{parent_key -> {child_key -> {split_name -> np.ndarray(indices)}}}``.
            Includes splits for both modalities and their corresponding annotation
            entries. An annotation key that matches no modality child key and is
            not handled as "paired" maps to an empty dict.
        """
        # Sort groups by descending number of modalities (most-paired first)
        sorted_groups = sorted(
            self.membership_groups.items(), key=lambda kv: -len(kv[0])
        )
        assigned_ids: Set[str] = set()

        # Initialize split storage for every key the data package yields
        per_modality_splits: Dict[str, Dict[str, Set[str]]] = {
            mod: {name: set() for name in self._SPLIT_NAMES}
            for mod, _ in self.datapackage
        }

        # Assign each membership group to splits
        for mods_tuple, sids in sorted_groups:
            # Sort before shuffling: set iteration order of strings varies between
            # processes, which would make the seeded shuffle irreproducible.
            sids_to_assign = sorted(sid for sid in sids if sid not in assigned_ids)
            if not sids_to_assign:
                continue

            group_splits = self._split_group(sids_to_assign)
            for split_name, split_ids in group_splits.items():
                for mod in mods_tuple:
                    per_modality_splits[mod][split_name].update(split_ids)
                assigned_ids.update(split_ids)

        # Translate sample IDs into positional indices for each modality
        final_indices: Dict[str, Dict[str, Dict[str, np.ndarray]]] = {}
        for full_key, data_obj in self.datapackage:
            parent_key, child_key = full_key.split(".")
            if parent_key == "annotation":
                continue  # Handle annotations separately below
            original_ids = self.datapackage._get_sample_ids(data_obj)

            id_to_pos = {sid: i for i, sid in enumerate(original_ids)}
            final_indices.setdefault(parent_key, {})[child_key] = {}

            for split_name in self._SPLIT_NAMES:
                split_ids_for_mod = per_modality_splits.get(full_key, {}).get(
                    split_name, set()
                )
                final_indices[parent_key][child_key][split_name] = np.array(
                    sorted([id_to_pos[sid] for sid in split_ids_for_mod]), dtype=int
                )

        annotations = self.datapackage.annotation
        if annotations:
            for anno_key, anno_df in annotations.items():
                anno_id_to_pos = {sid: i for i, sid in enumerate(anno_df.index)}
                final_indices.setdefault("annotation", {})[anno_key] = {}
                anno_out = final_indices["annotation"][anno_key]

                shared_annotation = (
                    len(annotations) == 1 and self.config.requires_paired
                ) or (anno_key == "paired" and self.config.requires_paired)
                if shared_annotation:
                    # One annotation table serves all modalities: for each split,
                    # take the union of the sample IDs across modalities.
                    for split_name in self._SPLIT_NAMES:
                        split_ids_union = set().union(
                            *[
                                splits[split_name]
                                for splits in per_modality_splits.values()
                            ]
                        )
                        anno_out[split_name] = self._positions(
                            anno_id_to_pos, split_ids_union
                        )
                    continue  # Skip the per-modality annotation handling

                for split_name in self._SPLIT_NAMES:
                    for mod_name in per_modality_splits:
                        parent_key, child_key = mod_name.split(".")
                        # per_modality_splits holds no IDs for annotation keys;
                        # skip them so they cannot overwrite a real match with
                        # an empty array.
                        if parent_key == "annotation":
                            continue
                        if child_key == anno_key:
                            # Take split IDs from the corresponding modality only
                            anno_out[split_name] = self._positions(
                                anno_id_to_pos,
                                per_modality_splits[mod_name][split_name],
                            )

        return final_indices