Coverage for src / autoencodix / data / datapackage.py: 21%

191 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from dataclasses import dataclass, field 

2from typing import Dict, List, Optional, Union, Any, Iterator, Tuple, TypeVar 

3 

4import pandas as pd 

5from anndata import AnnData # type: ignore 

6from mudata import MuData # type: ignore 

7 

8from autoencodix.data._imgdataclass import ImgData 

9 

10 

11T = TypeVar("T") # For generic type hints 

12 

13 

14@dataclass 

15class DataPackage: 

16 """Represents a data package containing multiple types of data.""" 

17 

18 multi_sc: Optional[Dict[str, MuData]] = None # # ty: ignore[invalid-type-form] 

19 multi_bulk: Optional[Dict[str, pd.DataFrame]] = None 

20 annotation: Optional[Dict[str, Union[pd.DataFrame, None]]] = None 

21 img: Optional[Dict[str, List[ImgData]]] = None 

22 

23 from_modality: Optional[ 

24 Dict[ 

25 str, 

26 Union[ 

27 pd.DataFrame, 

28 List[ImgData], 

29 MuData, # ty: ignore[invalid-type-form] 

30 AnnData, # ty: ignore[invalid-type-form] 

31 ], 

32 ] # ty: ignore[invalid-type-form] 

33 ] = field(default_factory=dict, repr=False) 

34 to_modality: Optional[ 

35 Dict[ 

36 str, 

37 Union[ 

38 pd.DataFrame, 

39 List[ImgData], 

40 MuData, # ty: ignore[invalid-type-form] 

41 AnnData, # ty: ignore[invalid-type-form] 

42 ], # ty: ignore[invalid-type-form] 

43 ] # ty: ignore[invalid-type-form] 

44 ] = field(default_factory=dict, repr=False) 

45 

46 def __getitem__(self, key: str) -> Any: 

47 """Allow dictionary-like access to top-level attributes.""" 

48 if hasattr(self, key): 

49 return getattr(self, key) 

50 raise KeyError(f"{key} not found in DataPackage.") 

51 

52 def __setitem__(self, key: str, value: Any) -> None: 

53 """Allow dictionary-like item assignment to top-level attributes.""" 

54 if hasattr(self, key): 

55 setattr(self, key, value) 

56 else: 

57 raise KeyError(f"{key} not found in DataPackage.") 

58 

59 def __iter__(self) -> Iterator[Tuple[str, Any]]: 

60 """Make DataPackage iterable, yielding (key, value) pairs. 

61 

62 For dictionary attributes, yields nested items as (parent_key.child_key, value). 

63 """ 

64 for attr_name in self.__annotations__.keys(): 

65 attr_value = getattr(self, attr_name) 

66 

67 if attr_value is None: 

68 continue 

69 if isinstance(attr_value, dict): 

70 for sub_key, sub_value in attr_value.items(): 

71 yield f"{attr_name}.{sub_key}", sub_value 

72 else: 

73 yield attr_name, attr_value 

74 

75 def format_shapes(self) -> str: 

76 """Format the shape dictionary in a clean, readable way.""" 

77 shapes = self.shape() 

78 lines = [] 

79 

80 for data_type, data_info in shapes.items(): 

81 # Skip empty entries 

82 if not data_info: 

83 continue 

84 

85 sub_items = [] 

86 for subtype, shape in data_info.items(): 

87 if shape is not None: 

88 if isinstance(shape, tuple): 

89 sub_items.append( 

90 f"{subtype}: {shape[0]} samples × {shape[1]} features" 

91 ) 

92 else: 

93 sub_items.append(f"{subtype}: {shape} items") 

94 

95 if sub_items: 

96 lines.append(f"{data_type}:") 

97 lines.extend(f" {item}" for item in sub_items) 

98 

99 if not lines: 

100 return "Empty DataPackage" 

101 

102 return "\n".join(lines) 

103 

104 def __str__(self) -> str: 

105 return self.format_shapes() 

106 

107 def __repr__(self) -> str: 

108 return self.__str__() 

109 

110 def is_empty(self) -> bool: 

111 """Check if the data package is empty.""" 

112 return all( 

113 [ 

114 self.multi_sc is None, 

115 self.multi_bulk is None or len(self.multi_bulk) == 0, 

116 self.annotation is None, 

117 self.img is None, 

118 not self.from_modality, 

119 not self.to_modality, 

120 ] 

121 ) 

122 

123 def get_n_samples(self) -> Dict[str, Dict[str, int]]: 

124 """Get the number of samples for each data type in nested dictionary format. 

125 

126 Returns: 

127 Dictionary with nested structure: {modality_type: {sub_key: count}} 

128 """ 

129 n_samples: Dict[str, Dict[str, int]] = {} 

130 

131 # Process each main attribute 

132 for attr_name in self.__annotations__.keys(): 

133 attr_value = getattr(self, attr_name) 

134 

135 if isinstance(attr_value, dict): 

136 # Handle dictionary attributes (multi_sc, multi_bulk, etc.) 

137 sub_counts = {} 

138 for sub_key, sub_value in attr_value.items(): 

139 if sub_value is None or len(sub_value) == 0: 

140 continue 

141 sub_counts[sub_key] = self._get_n_samples(sub_value) 

142 n_samples[attr_name] = sub_counts if sub_counts else {} 

143 else: 

144 # Handle non-dictionary attributes 

145 count = self._get_n_samples(attr_value) 

146 n_samples[attr_name] = {attr_name: count} 

147 

148 paired_count = self._calculate_paired_count() 

149 n_samples["paired_count"] = {"paired_count": paired_count} 

150 

151 return n_samples 

152 

153 def _calculate_paired_count(self) -> int: 

154 """ 

155 Calculate the number of samples that are common across modalities that have data. 

156 

157 Returns: 

158 Number of common samples across modalities with data 

159 """ 

160 all_counts = [] 

161 

162 # Collect all sample counts from modalities that have data 

163 for attr_name in self.__annotations__.keys(): 

164 attr_value = getattr(self, attr_name) 

165 if attr_value is None: 

166 continue 

167 

168 if isinstance(attr_value, dict): 

169 if attr_value: # Non-empty dictionary 

170 for sub_value in attr_value.values(): 

171 if sub_value is not None: 

172 count = self._get_n_samples(sub_value) 

173 if count > 0: 

174 all_counts.append(count) 

175 else: 

176 count = self._get_n_samples(attr_value) 

177 if count > 0: 

178 all_counts.append(count) 

179 

180 # Return minimum count (intersection) or 0 if no data 

181 return min(all_counts) if all_counts else 0 

182 

183 def get_common_ids(self) -> List[str]: 

184 """Get the common sample IDs across modalities that have data. 

185 

186 Returns: 

187 List of sample IDs that are present in all modalities with data 

188 """ 

189 all_ids = [] 

190 

191 # Collect sample IDs from each modality that has data 

192 for attr_name in self.__annotations__.keys(): 

193 attr_value = getattr(self, attr_name) 

194 if attr_value is None: 

195 continue 

196 

197 if isinstance(attr_value, dict): 

198 if attr_value: # Non-empty dictionary 

199 for sub_value in attr_value.values(): 

200 if sub_value is not None: 

201 ids = self._get_sample_ids(sub_value) 

202 if ids: 

203 all_ids.append(set(ids)) 

204 else: 

205 ids = self._get_sample_ids(attr_value) 

206 if ids: 

207 all_ids.append(set(ids)) 

208 

209 # Find intersection of all ID sets 

210 if not all_ids: 

211 return [] 

212 

213 common = all_ids[0] 

214 for id_set in all_ids[1:]: 

215 common = common.intersection(id_set) 

216 

217 return sorted(list(common)) 

218 

219 def _get_sample_ids( 

220 self, 

221 dataobj: Union[ 

222 MuData, # ty: ignore[invalid-type-form] 

223 pd.DataFrame, 

224 List[ImgData], 

225 AnnData, 

226 ], # ty: ignore[invalid-type-form] 

227 ) -> List[str]: 

228 """ 

229 Extract sample IDs from a data object. 

230 

231 Args: 

232 dataobj: Data object to extract IDs from 

233 

234 Returns: 

235 List of sample IDs 

236 """ 

237 if dataobj is None: 

238 return [] 

239 

240 if isinstance(dataobj, pd.DataFrame): 

241 return dataobj.index.astype(str).tolist() 

242 elif isinstance(dataobj, list): 

243 # For lists of ImgData, extract sample_id from each object 

244 return [ 

245 img_data.sample_id 

246 for img_data in dataobj 

247 if hasattr(img_data, "sample_id") 

248 ] 

249 elif isinstance(dataobj, AnnData): 

250 return dataobj.obs.index.astype(str).tolist() 

251 elif isinstance(dataobj, MuData): 

252 # For MuData, we can use the obs.index directly 

253 return dataobj.obs.index.astype(str).tolist() # ty: ignore 

254 else: 

255 return [] 

256 

257 def _get_n_samples( 

258 self, 

259 dataobj: Union[ 

260 MuData, # ty: ignore[invalid-type-form] 

261 pd.DataFrame, 

262 List[ImgData], 

263 AnnData, 

264 Dict, 

265 ], 

266 ) -> int: 

267 """Get the number of samples for a specific attribute.""" 

268 if dataobj is None: 

269 return 0 

270 

271 if isinstance(dataobj, pd.DataFrame): 

272 return dataobj.shape[0] 

273 elif isinstance(dataobj, dict): 

274 if not dataobj: # Empty dict 

275 return 0 

276 first_value = next(iter(dataobj.values())) 

277 return self._get_n_samples(first_value) 

278 elif isinstance(dataobj, list): 

279 return len(dataobj) 

280 elif isinstance(dataobj, AnnData): 

281 return dataobj.obs.shape[0] 

282 elif isinstance(dataobj, MuData): 

283 if not dataobj.mod: # ty: ignore 

284 return 0 

285 return dataobj.n_obs # ty: ignore 

286 else: 

287 raise ValueError( 

288 f"Unknown data type {type(dataobj)} for dataobj. Probably you've implemented a new attribute in the DataPackage class or changed the data type of an existing attribute." 

289 ) 

290 

291 def shape(self) -> Dict[str, Dict[str, Any]]: 

292 """ 

293 Get the shape of the data for each data type in nested dictionary format. 

294 

295 Returns: 

296 Dictionary with nested structure: {modality_type: {sub_key: shape}} 

297 """ 

298 shapes: Dict[str, Dict[str, Any]] = {} 

299 

300 for attr_name in self.__annotations__.keys(): 

301 attr_value = getattr(self, attr_name) 

302 

303 if isinstance(attr_value, dict): 

304 # Handle dictionary attributes 

305 if attr_value is None or len(attr_value) == 0: 

306 # Empty or None dictionary 

307 shapes[attr_name] = {} 

308 else: 

309 sub_dict = self._get_shape_from_dict(attr_value) 

310 shapes[attr_name] = sub_dict 

311 else: 

312 # Handle non-dictionary attributes 

313 shape = self._get_single_shape(attr_value) 

314 shapes[attr_name] = {attr_name: shape} 

315 

316 return shapes 

317 

318 def _get_single_shape(self, dataobj: Any) -> Optional[Union[Tuple, int]]: 

319 """Get shape for a single data object.""" 

320 if dataobj is None: 

321 return None 

322 elif isinstance(dataobj, list): 

323 return len(dataobj) 

324 elif isinstance(dataobj, pd.DataFrame): 

325 return dataobj.shape 

326 elif isinstance(dataobj, AnnData): 

327 return dataobj.shape 

328 elif isinstance(dataobj, MuData): 

329 return (dataobj.n_obs, dataobj.n_vars) 

330 else: 

331 return None 

332 

333 def _get_shape_from_dict(self, data_dict: Dict) -> Dict[str, Any]: 

334 """Recursively process dictionary to extract shapes of contained data objects. 

335 

336 Args:: 

337 data_dict: Dictionary containing data objects 

338 

339 Returns: 

340 Dictionary with shapes information 

341 """ 

342 result: Dict[str, Any] = {} 

343 for key, value in data_dict.items(): 

344 if isinstance(value, pd.DataFrame): 

345 result[key] = value.shape 

346 elif isinstance(value, list): 

347 # For lists of objects, just store the length 

348 result[key] = len(value) 

349 elif isinstance(value, AnnData): 

350 result[key] = value.shape 

351 elif isinstance(value, MuData): 

352 result[key] = (value.n_obs, value.n_vars) 

353 elif isinstance(value, dict): 

354 # Recursively process nested dictionaries 

355 nested_result = self._get_shape_from_dict(value) 

356 result[key] = nested_result 

357 elif value is None: 

358 result[key] = None 

359 else: 

360 # For unknown types, store a descriptive string instead of raising an error 

361 # This is more robust as it won't crash the entire method 

362 result[key] = f"<{type(value).__name__}>" 

363 

364 return result 

365 

366 def get_modality_key(self, direction: str) -> Optional[str]: 

367 """Get the first key for a specific direction's modality. 

368 

369 Args: 

370 direction: Either 'from' or 'to' 

371 

372 Returns: 

373 First key of the modality dictionary or None if empty 

374 """ 

375 if direction not in ["from", "to"]: 

376 raise ValueError(f"Direction must be 'from' or 'to', got {direction}") 

377 

378 modality_dict = self.from_modality if direction == "from" else self.to_modality 

379 if not modality_dict: 

380 return None 

381 

382 return next(iter(modality_dict.keys()), None)