Coverage for src/autoencodix/data/datapackage.py: 21%
191 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union, Any, Iterator, Tuple, TypeVar

import pandas as pd
from anndata import AnnData  # type: ignore
from mudata import MuData  # type: ignore

from autoencodix.data._imgdataclass import ImgData
# Generic type variable for type hints.
# NOTE(review): not referenced in the visible part of this module -- confirm
# it is used elsewhere before removing it.
T = TypeVar("T")
@dataclass
class DataPackage:
    """Represents a data package containing multiple types of data.

    Each field is ``None`` or a dictionary keyed by a sub-key (for example a
    modality name) whose values are data objects: ``pd.DataFrame``, lists of
    ``ImgData``, ``AnnData`` or ``MuData``. The package offers dictionary-style
    access to its fields (``pkg["multi_bulk"]``) and iterates as
    ``(key, value)`` pairs, flattening dictionary fields into
    ``"field.sub_key"`` keys.
    """

    multi_sc: Optional[Dict[str, MuData]] = None  # ty: ignore[invalid-type-form]
    multi_bulk: Optional[Dict[str, pd.DataFrame]] = None
    annotation: Optional[Dict[str, Union[pd.DataFrame, None]]] = None
    img: Optional[Dict[str, List[ImgData]]] = None

    from_modality: Optional[
        Dict[
            str,
            Union[
                pd.DataFrame,
                List[ImgData],
                MuData,  # ty: ignore[invalid-type-form]
                AnnData,  # ty: ignore[invalid-type-form]
            ],
        ]  # ty: ignore[invalid-type-form]
    ] = field(default_factory=dict, repr=False)
    to_modality: Optional[
        Dict[
            str,
            Union[
                pd.DataFrame,
                List[ImgData],
                MuData,  # ty: ignore[invalid-type-form]
                AnnData,  # ty: ignore[invalid-type-form]
            ],  # ty: ignore[invalid-type-form]
        ]  # ty: ignore[invalid-type-form]
    ] = field(default_factory=dict, repr=False)

    # ------------------------------------------------------------------ #
    # Internal field helpers
    # ------------------------------------------------------------------ #
    def _field_names(self) -> List[str]:
        """Return the dataclass field names in declaration order.

        ``dataclasses.fields`` also reports fields inherited from a parent
        dataclass, whereas ``self.__annotations__`` only lists annotations
        declared on the concrete class.
        """
        return [fld.name for fld in fields(self)]

    def _iter_data_objects(self) -> Iterator[Any]:
        """Yield every non-``None`` data object stored in the package.

        Dictionary fields contribute each of their non-``None`` values; any
        other non-``None`` field value is yielded as-is.
        """
        for name in self._field_names():
            value = getattr(self, name)
            if value is None:
                continue
            if isinstance(value, dict):
                yield from (item for item in value.values() if item is not None)
            else:
                yield value

    @staticmethod
    def _is_empty_value(value: Any) -> bool:
        """Return True when a field value is ``None`` or an empty dictionary."""
        if value is None:
            return True
        if isinstance(value, dict):
            return len(value) == 0
        return False

    # ------------------------------------------------------------------ #
    # Dictionary-like protocol
    # ------------------------------------------------------------------ #
    def __getitem__(self, key: str) -> Any:
        """Allow dictionary-like access to the package's data fields.

        Args:
            key: Name of a dataclass field, e.g. ``"multi_bulk"``.

        Returns:
            The field's current value.

        Raises:
            KeyError: If ``key`` is not a dataclass field. Methods and other
                attributes are deliberately not reachable this way.
        """
        if key in self._field_names():
            return getattr(self, key)
        raise KeyError(f"{key} not found in DataPackage.")

    def __setitem__(self, key: str, value: Any) -> None:
        """Allow dictionary-like assignment to the package's data fields.

        Restricting keys to dataclass fields prevents ``pkg["shape"] = ...``
        from silently replacing a method on the instance.

        Raises:
            KeyError: If ``key`` is not a dataclass field.
        """
        if key in self._field_names():
            setattr(self, key, value)
        else:
            raise KeyError(f"{key} not found in DataPackage.")

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        """Make DataPackage iterable, yielding (key, value) pairs.

        ``None`` fields are skipped. Dictionary fields are flattened and yield
        ``("field.sub_key", value)`` for each entry; other fields yield
        ``(field, value)``.
        """
        for name in self._field_names():
            value = getattr(self, name)
            if value is None:
                continue
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    yield f"{name}.{sub_key}", sub_value
            else:
                yield name, value

    # ------------------------------------------------------------------ #
    # String representation
    # ------------------------------------------------------------------ #
    def format_shapes(self) -> str:
        """Format the output of :meth:`shape` as readable, indented lines.

        Tuple shapes are rendered as ``"<rows> samples × <cols> features"``,
        other non-``None`` shapes as ``"<value> items"``. Returns
        ``"Empty DataPackage"`` when nothing can be reported.
        """
        lines: List[str] = []

        for data_type, data_info in self.shape().items():
            if not data_info:  # skip empty entries
                continue

            sub_items = []
            for subtype, shape in data_info.items():
                if shape is None:
                    continue
                if isinstance(shape, tuple):
                    sub_items.append(
                        f"{subtype}: {shape[0]} samples × {shape[1]} features"
                    )
                else:
                    sub_items.append(f"{subtype}: {shape} items")

            if sub_items:
                lines.append(f"{data_type}:")
                lines.extend(f"  {item}" for item in sub_items)

        if not lines:
            return "Empty DataPackage"
        return "\n".join(lines)

    def __str__(self) -> str:
        return self.format_shapes()

    def __repr__(self) -> str:
        return self.__str__()

    # ------------------------------------------------------------------ #
    # Content queries
    # ------------------------------------------------------------------ #
    def is_empty(self) -> bool:
        """Check if the data package is empty.

        Every field is judged by the same rule: it is empty when it is
        ``None`` or an empty dictionary. (Previously an empty dict counted as
        empty only for ``multi_bulk``, ``from_modality`` and ``to_modality``.)
        """
        return all(
            self._is_empty_value(getattr(self, name)) for name in self._field_names()
        )

    def get_n_samples(self) -> Dict[str, Dict[str, int]]:
        """Get the number of samples for each data type in nested dictionary format.

        Dictionary fields map to ``{sub_key: count}`` and skip entries that are
        ``None`` or have length 0. Non-dictionary fields are reported as
        ``{field_name: count}`` (0 for ``None``). An extra ``"paired_count"``
        entry holds the result of :meth:`_calculate_paired_count`.

        Returns:
            Dictionary with nested structure: {modality_type: {sub_key: count}}
        """
        n_samples: Dict[str, Dict[str, int]] = {}

        for name in self._field_names():
            value = getattr(self, name)
            if isinstance(value, dict):
                n_samples[name] = {
                    sub_key: self._get_n_samples(sub_value)
                    for sub_key, sub_value in value.items()
                    if sub_value is not None and len(sub_value) > 0
                }
            else:
                n_samples[name] = {name: self._get_n_samples(value)}

        n_samples["paired_count"] = {"paired_count": self._calculate_paired_count()}
        return n_samples

    def _calculate_paired_count(self) -> int:
        """Return the smallest non-zero sample count among the stored data objects.

        NOTE: this is the minimum per-object count, not the size of the
        sample-ID intersection computed by :meth:`get_common_ids`. The two
        agree when every modality holds the same samples.

        Returns:
            Minimum positive sample count, or 0 if no object holds samples.
        """
        counts = [
            count
            for count in map(self._get_n_samples, self._iter_data_objects())
            if count > 0
        ]
        return min(counts, default=0)

    def get_common_ids(self) -> List[str]:
        """Get the common sample IDs across modalities that have data.

        Data objects that yield no sample IDs are ignored rather than making
        the intersection empty.

        Returns:
            Sorted list of sample IDs present in every data object with IDs.
        """
        id_sets = [
            set(ids)
            for ids in map(self._get_sample_ids, self._iter_data_objects())
            if ids
        ]
        if not id_sets:
            return []
        return sorted(set.intersection(*id_sets))

    def _get_sample_ids(
        self,
        dataobj: Optional[
            Union[
                MuData,  # ty: ignore[invalid-type-form]
                pd.DataFrame,
                List[ImgData],
                AnnData,
            ]
        ],  # ty: ignore[invalid-type-form]
    ) -> List[str]:
        """Extract sample IDs from a data object.

        Args:
            dataobj: Data object to extract IDs from.

        Returns:
            List of sample IDs. DataFrame/AnnData/MuData indices are cast to
            ``str``; list elements contribute their ``sample_id`` attribute
            when present. Unsupported types (including ``None``) give ``[]``.
        """
        if dataobj is None:
            return []

        if isinstance(dataobj, pd.DataFrame):
            return dataobj.index.astype(str).tolist()
        if isinstance(dataobj, list):
            # For lists of ImgData, extract sample_id from each object.
            return [
                img_data.sample_id
                for img_data in dataobj
                if hasattr(img_data, "sample_id")
            ]
        if isinstance(dataobj, AnnData):
            return dataobj.obs.index.astype(str).tolist()
        if isinstance(dataobj, MuData):
            return dataobj.obs.index.astype(str).tolist()  # ty: ignore
        return []

    def _get_n_samples(
        self,
        dataobj: Optional[
            Union[
                MuData,  # ty: ignore[invalid-type-form]
                pd.DataFrame,
                List[ImgData],
                AnnData,
                Dict,
            ]
        ],
    ) -> int:
        """Get the number of samples held by a single data object.

        A dictionary is measured by its first value; a MuData without any
        modalities counts as 0.

        Raises:
            ValueError: If ``dataobj`` has an unsupported type.
        """
        if dataobj is None:
            return 0

        if isinstance(dataobj, pd.DataFrame):
            return dataobj.shape[0]
        if isinstance(dataobj, dict):
            if not dataobj:
                return 0
            return self._get_n_samples(next(iter(dataobj.values())))
        if isinstance(dataobj, list):
            return len(dataobj)
        if isinstance(dataobj, AnnData):
            return dataobj.obs.shape[0]
        if isinstance(dataobj, MuData):
            if not dataobj.mod:  # ty: ignore
                return 0
            return dataobj.n_obs  # ty: ignore
        raise ValueError(
            f"Unknown data type {type(dataobj)} for dataobj. Probably you've implemented a new attribute in the DataPackage class or changed the data type of an existing attribute."
        )

    # ------------------------------------------------------------------ #
    # Shapes
    # ------------------------------------------------------------------ #
    def shape(self) -> Dict[str, Dict[str, Any]]:
        """Get the shape of the data for each data type in nested dictionary format.

        Returns:
            Dictionary with nested structure: {modality_type: {sub_key: shape}}.
            Empty dictionary fields map to ``{}``; non-dictionary fields map to
            ``{field_name: shape}``.
        """
        shapes: Dict[str, Dict[str, Any]] = {}
        for name in self._field_names():
            value = getattr(self, name)
            if isinstance(value, dict):
                # _get_shape_from_dict({}) is {}, so empty dicts need no special case.
                shapes[name] = self._get_shape_from_dict(value)
            else:
                shapes[name] = {name: self._get_single_shape(value)}
        return shapes

    def _get_single_shape(self, dataobj: Any) -> Optional[Union[Tuple, int]]:
        """Get the shape of a single data object.

        Returns ``len`` for lists, ``.shape`` for DataFrame/AnnData,
        ``(n_obs, n_vars)`` for MuData, and ``None`` for ``None`` or
        unsupported types.
        """
        if dataobj is None:
            return None
        if isinstance(dataobj, list):
            return len(dataobj)
        if isinstance(dataobj, (pd.DataFrame, AnnData)):
            return dataobj.shape
        if isinstance(dataobj, MuData):
            return (dataobj.n_obs, dataobj.n_vars)
        return None

    def _get_shape_from_dict(self, data_dict: Dict) -> Dict[str, Any]:
        """Recursively process a dictionary to extract shapes of contained data objects.

        Args:
            data_dict: Dictionary containing data objects.

        Returns:
            Dictionary mapping each key to its shape. Nested dictionaries are
            processed recursively, ``None`` stays ``None``, and values of
            unsupported types are described as ``"<TypeName>"`` instead of
            raising, so one odd entry does not break the whole report.
        """
        result: Dict[str, Any] = {}
        for key, value in data_dict.items():
            if isinstance(value, dict):
                result[key] = self._get_shape_from_dict(value)
            elif value is None:
                result[key] = None
            else:
                leaf_shape = self._get_single_shape(value)
                result[key] = (
                    leaf_shape
                    if leaf_shape is not None
                    else f"<{type(value).__name__}>"
                )
        return result

    # ------------------------------------------------------------------ #
    # Modality helpers
    # ------------------------------------------------------------------ #
    def get_modality_key(self, direction: str) -> Optional[str]:
        """Get the first key for a specific direction's modality.

        Args:
            direction: Either 'from' or 'to'.

        Returns:
            First key of the modality dictionary, or None if it is empty or None.

        Raises:
            ValueError: If ``direction`` is neither 'from' nor 'to'.
        """
        if direction not in ("from", "to"):
            raise ValueError(f"Direction must be 'from' or 'to', got {direction}")

        modality_dict = self.from_modality if direction == "from" else self.to_modality
        return next(iter(modality_dict or {}), None)