Coverage for src/autoencodix/data/_datapackage_splitter.py: 28%
57 statements
coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import copy
2from typing import Any, Dict, Optional
4import numpy as np
5import pandas as pd
6from anndata import AnnData # type: ignore
7from mudata import MuData # type: ignore
9from autoencodix.data.datapackage import DataPackage
10from autoencodix.configs.default_config import DefaultConfig
class DataPackageSplitter:
    """Splits DataPackage objects into training, validation, and testing sets.

    Supports paired and unpaired (translation) splitting.

    Attributes:
        _data_package: The original DataPackage to split (given to the
            constructor as ``data_package``).
        config: The configuration settings for the splitting process.
        indices: Positional indices for each split, nested as
            ``indices[attribute][modality][split]``, where ``attribute`` is the
            name of a DataPackage attribute and ``split`` is one of "train",
            "valid", or "test".
    """

    def __init__(
        self,
        data_package: DataPackage,
        config: DefaultConfig,
        indices: Dict[str, Dict[str, Dict[str, np.ndarray]]],
    ) -> None:
        """Store the package to split, the configuration, and the split indices.

        Args:
            data_package: The DataPackage to split.
            config: Configuration for the splitting process.
            indices: Nested split indices (see the class docstring).

        Raises:
            TypeError: If ``data_package`` is not a DataPackage instance.
        """
        if not isinstance(data_package, DataPackage):
            raise TypeError(
                f"Expected data_package to be of type DataPackage, got {type(data_package)}"
            )
        self._data_package = data_package
        self.indices = indices
        self.config = config

    def _shallow_copy(self, value: Any) -> Any:
        """Return a shallow copy of ``value``.

        Falls back to returning ``value`` itself when ``copy.copy`` raises
        ``AttributeError``.
        """
        try:
            return copy.copy(value)
        except AttributeError:
            return value

    def _indexing(self, obj: Any, indices: np.ndarray) -> Any:
        """Select the entries of ``obj`` at the given positional indices.

        Args:
            obj: The object to index (pd.DataFrame, list, AnnData, MuData, or None).
            indices: Positional indices to select.

        Returns:
            The indexed object, or None if ``obj`` is None. DataFrames are
            indexed by position via ``iloc``. Lists are rebuilt element by
            element. AnnData/MuData are indexed with ``obj[indices]``.

        Raises:
            TypeError: If ``obj`` has an unsupported type.
        """
        if obj is None:
            return None
        if isinstance(obj, pd.DataFrame):
            return obj.iloc[indices]
        if isinstance(obj, list):
            return [obj[i] for i in indices]
        if isinstance(obj, (AnnData, MuData)):
            return obj[indices]
        raise TypeError(
            f"Unsupported type for indexing: {type(obj)}. "
            "Supported types are pd.DataFrame, list, AnnData, and MuData."
        )

    def _split_data_package(self, split: str) -> Optional[DataPackage]:
        """Build a new DataPackage that holds only the samples of one split.

        Every non-None attribute of the original package is treated as a
        mapping of modality name to data. Each entry is indexed with
        ``self.indices[attribute][modality][split]``.

        Args:
            split: Name of the split to extract ("train", "valid", or "test").

        Returns:
            A new DataPackage with every attribute indexed for ``split``, or
            None if ``self.indices`` is empty or None.

        Raises:
            KeyError: If ``self.indices`` has no entry for an attribute,
                modality, or split present in the package.
            TypeError: If a data entry has a type ``_indexing`` cannot handle.
        """
        if not self.indices:
            return None

        split_data: Dict[str, Dict[str, Any]] = {}
        for attr_name, modalities in self._data_package.__dict__.items():
            if modalities is None:
                continue
            split_data[attr_name] = {
                modality: self._indexing(data, self.indices[attr_name][modality][split])
                for modality, data in modalities.items()
            }
        return DataPackage(**split_data)

    def _split_mudata(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        indices_map: Dict[str, Dict[str, np.ndarray]],
        split: str,
    ) -> MuData:  # ty: ignore[invalid-type-form]
        """Split a MuData object modality by modality.

        The input ``mudata`` is not modified. Writing the subsets back into
        ``mudata.mod`` would corrupt the source object, so the same MuData
        could not then be split into "train", "valid", and "test" in turn.
        Only the per-modality data is carried over. MuData-level annotations
        are rebuilt by the MuData constructor from the modalities.

        Args:
            mudata: The MuData object to split.
            indices_map: Maps each modality name to its per-split indices.
            split: The split to extract ("train", "valid", or "test").

        Returns:
            A new MuData object containing the indexed modalities.

        Raises:
            KeyError: If ``indices_map`` lacks a modality or split.
        """
        split_modalities = {
            modality: self._indexing(data, indices_map[modality][split])
            for modality, data in mudata.mod.items()
        }
        return MuData(split_modalities)

    def _requires_paired(self) -> bool:
        """Return whether paired splitting is required.

        An unset (``None``) ``config.requires_paired`` is treated as ``True``.
        """
        return self.config.requires_paired is None or bool(self.config.requires_paired)

    def split(self) -> Dict[str, Optional[Dict[str, Any]]]:
        """Split the underlying DataPackage into train, valid, and test subsets.

        Returns:
            A dictionary keyed by "train", "valid", and "test". Each value is
            None when ``indices`` is None. Otherwise it is a dict with:

            - "data": the split DataPackage, or None if ``indices`` is empty.
            - "indices": the full ``indices`` mapping shared by all splits.

        Raises:
            ValueError: If no data package is available for splitting.
            KeyError: If ``indices`` lacks an entry needed for a split.
            TypeError: If a data entry has a type that cannot be indexed.
        """
        if self._data_package is None:
            raise ValueError("No data package available for splitting")

        split_names = ("train", "valid", "test")
        if self.indices is None:
            return {name: None for name in split_names}

        return {
            name: {
                "data": self._split_data_package(split=name),
                "indices": self.indices,
            }
            for name in split_names
        }