Coverage for src / autoencodix / data / _datapackage_splitter.py: 28%

57 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import copy 

2from typing import Any, Dict, Optional 

3 

4import numpy as np 

5import pandas as pd 

6from anndata import AnnData # type: ignore 

7from mudata import MuData # type: ignore 

8 

9from autoencodix.data.datapackage import DataPackage 

10from autoencodix.configs.default_config import DefaultConfig 

11 

12 

class DataPackageSplitter:
    """Splits DataPackage objects into training, validation, and testing sets.

    Supports paired and unpaired (translation) splitting.

    Attributes:
        data_package: The original DataPackage to split (stored as
            ``_data_package``).
        config: The configuration settings for the splitting process.
        indices: Nested mapping ``indices[attribute][modality][split]`` holding
            the positional indices for each split ("train", "valid", "test").
    """

    def __init__(
        self,
        data_package: DataPackage,
        config: DefaultConfig,
        indices: Dict[str, Dict[str, Dict[str, np.ndarray]]],
    ) -> None:
        """Store the package to split, the config, and the split indices.

        Args:
            data_package: The DataPackage whose contents will be split.
            config: Configuration controlling the split.
            indices: Nested mapping ``indices[attribute][modality][split]``.
        Raises:
            TypeError: If ``data_package`` is not a DataPackage.
        """
        # Validate before storing anything so a rejected argument never ends
        # up on the instance.
        if not isinstance(data_package, DataPackage):
            raise TypeError(
                f"Expected data_package to be of type DataPackage, got {type(data_package)}"
            )
        self._data_package = data_package
        self.indices = indices
        self.config = config

    def _shallow_copy(self, value: Any) -> Any:
        """Return a shallow copy of ``value``, or ``value`` itself if copying fails.

        Deliberately best-effort: objects whose copy protocol raises are
        returned unchanged instead of aborting the caller.
        """
        try:
            return copy.copy(value)
        except (AttributeError, TypeError, copy.Error):
            return value

    def _indexing(self, obj: Any, indices: np.ndarray) -> Any:
        """Index a pd.DataFrame, list, AnnData, or MuData object by position.

        Args:
            obj: The object to index (pd.DataFrame, list, AnnData, MuData, or None).
            indices: Positional indices to select.
        Returns:
            The indexed object, or None if ``obj`` is None.
        Raises:
            TypeError: If ``obj`` is of an unsupported type.
        """
        if obj is None:
            return None
        if isinstance(obj, pd.DataFrame):
            # Positional selection, independent of the frame's index labels.
            return obj.iloc[indices]
        if isinstance(obj, list):
            return [obj[i] for i in indices]
        if isinstance(obj, (AnnData, MuData)):
            return obj[indices]
        raise TypeError(
            f"Unsupported type for indexing: {type(obj)}. "
            "Supported types are pd.DataFrame, list, AnnData, and MuData."
        )

    def _split_data_package(self, split: str) -> Optional[DataPackage]:
        """Create a new DataPackage containing only the entries of one split.

        Every non-None attribute of the wrapped DataPackage is iterated as a
        mapping of modality name -> data, and each data object is indexed with
        ``self.indices[attribute][modality][split]``.

        Args:
            split: Name of the split to extract ("train", "valid" or "test").
        Returns:
            A new DataPackage whose attributes are indexed for ``split``, or
            None if no indices are available (None or empty).
        Raises:
            KeyError: If ``self.indices`` lacks an attribute, modality or split
                entry needed for a non-empty modality mapping.
            TypeError: If a data object has a type ``_indexing`` does not support.
        """
        # `not` also covers None, which `len()` would have rejected with a TypeError.
        if not self.indices:
            return None

        split_data = {}
        for key, value in self._data_package.__dict__.items():
            if value is None:
                continue
            # NOTE(review): assumes every non-None DataPackage attribute is a
            # dict-like {modality: data} -- confirm against DataPackage.
            split_data[key] = {
                modality: self._indexing(data, self.indices[key][modality][split])
                for modality, data in value.items()
            }
        return DataPackage(**split_data)

    def _split_mudata(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        indices_map: Dict[str, Dict[str, np.ndarray]],
        split: str,
    ) -> MuData:  # ty: ignore[invalid-type-form]
        """Split a MuData object modality by modality.

        Each modality may use its own index set, so every modality is indexed
        separately and a new MuData is assembled from the slices. The input
        ``mudata`` is not modified. The same object can therefore be split once
        per split name without later calls seeing already-reduced modalities.

        Args:
            mudata: The MuData object to split.
            indices_map: Mapping ``indices_map[modality][split]`` -> indices.
            split: The split type ("train", "valid", or "test").
        Returns:
            A new MuData object built from the per-modality splits.
        Raises:
            KeyError: If ``indices_map`` lacks a modality or split entry.
        """
        split_modalities = {
            modality: self._indexing(data, indices_map[modality][split])
            for modality, data in mudata.mod.items()
        }
        # NOTE(review): the result is rebuilt from modalities only, so
        # MuData-level annotations of the input are not carried over -- confirm
        # callers do not rely on them.
        return MuData(split_modalities)

    def _requires_paired(self) -> bool:
        """Return whether paired splitting is required.

        An unset (None) ``config.requires_paired`` is treated as True.
        """
        requires_paired = self.config.requires_paired
        return True if requires_paired is None else bool(requires_paired)

    def split(self) -> Dict[str, Optional[Dict[str, Any]]]:
        """Split the underlying DataPackage into train, valid, and test subsets.

        Returns:
            A dictionary keyed by "train", "valid" and "test". Each entry is
            None when no indices were provided. Otherwise it is a dict with:
            - "data": the split DataPackage (or None if the indices are empty);
            - "indices": the complete nested indices mapping.
        Raises:
            ValueError: If no data package is available for splitting.
            KeyError: If the indices mapping is missing a required entry.
            TypeError: If a data object has a type that cannot be indexed.
        """
        if self._data_package is None:
            raise ValueError("No data package available for splitting")

        splits = ("train", "valid", "test")

        # Without indices there is nothing to split; every entry is None.
        if self.indices is None:
            return {split: None for split in splits}

        return {
            split: {
                "data": self._split_data_package(split=split),
                # NOTE(review): this is the whole nested mapping, not only the
                # slice for this split -- confirm callers expect that.
                "indices": self.indices,
            }
            for split in splits
        }