Coverage for src / autoencodix / data / _sc_filter.py: 14%
162 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
3import mudata as md # type: ignore
4import numpy as np
5import pandas as pd
6import scanpy as sc # type: ignore
7from anndata import AnnData # type: ignore
9from autoencodix.data._filter import DataFilter
10from autoencodix.configs.default_config import DataInfo, DefaultConfig
if TYPE_CHECKING:
    import mudata as md  # type: ignore

    # For static type checkers, alias the real MuData class (mudata exposes it
    # as ``mudata.MuData``; there is no nested ``MuData.MuData`` attribute).
    MuData = md.MuData
else:
    # At runtime annotations that mention MuData simply resolve to Any.
    MuData = Any
class SingleCellFilter:
    """Filter and scale single-cell data, returning MuData objects with synchronized metadata.

    Attributes:
        data_info: Configuration for filtering and scaling; either a single
            DataInfo shared by all modalities or a dict of DataInfo per modality.
        total_features: Total number of features to keep across all modalities
            (read from ``config.k_filter``; may be None).
        config: Configuration object containing settings for data processing.
        _is_data_info_dict: Internal flag indicating if data_info is a dictionary.
    """

    def __init__(
        self, data_info: Union[Dict[str, DataInfo], DataInfo], config: DefaultConfig
    ):
        """Initialize the single-cell filter.

        Args:
            data_info: Either a single DataInfo object used for all modalities or
                a dictionary of DataInfo objects keyed by modality.
            config: Configuration object containing settings for data processing.
        """
        self.data_info = data_info
        self.total_features = config.k_filter
        self._is_data_info_dict = isinstance(data_info, dict)
        self.config = config

    def _get_data_info_for_modality(self, mod_key: str) -> DataInfo:
        """Return the DataInfo configuration for a specific modality.

        Args:
            mod_key: The modality key (e.g., "RNA", "METH").

        Returns:
            The modality's own DataInfo when ``data_info`` is a dict, otherwise
            the single shared DataInfo object.

        Raises:
            ValueError: If ``data_info`` is a dict with no entry for ``mod_key``.
        """
        if not self._is_data_info_dict:
            return self.data_info  # type: ignore
        info = self.data_info.get(mod_key)  # type: ignore
        if info is None:
            raise ValueError(f"No data info found for modality {mod_key}")
        return info

    def _get_layers_for_modality(self, mod_key: str, mod_data) -> List[str]:
        """Return the layers to process for a specific modality.

        Selected layers that do not exist in ``mod_data.layers`` are skipped with
        a printed warning. The pseudo-layer "X" is always accepted.

        Args:
            mod_key: The modality key (e.g., "RNA", "METH").
            mod_data: The AnnData object for the modality.

        Returns:
            List of layer names to process. If the selection is None, empty, or
            contains no valid layers, returns ``["X"]`` (the default matrix).
        """
        data_info = self._get_data_info_for_modality(mod_key)
        selected_layers = data_info.selected_layers
        # A missing selection (None) behaves like an empty one: fall back to X.
        if selected_layers is None:
            selected_layers = []

        available_layers = set(mod_data.layers.keys())
        valid_layers = []
        for layer in selected_layers:
            if layer == "X" or layer in available_layers:
                valid_layers.append(layer)
            else:
                print(
                    f"Warning: Layer '{layer}' not found in modality '{mod_key}'. Skipping."
                )
        return valid_layers or ["X"]

    def _presplit_processing(
        self,
        mudata: MuData,  # type: ignore[invalid-type-form]
    ) -> MuData:  # type: ignore[invalid-type-form]
        """Filter cells and optionally log-transform every modality in place.

        For each modality, cells are filtered with ``sc.pp.filter_cells`` using
        the modality's ``min_genes``. If ``log_transform`` is set,
        ``sc.pp.log1p`` is then applied to X and/or each selected layer.

        Args:
            mudata: Multi-modal container whose modalities are modified in place.

        Returns:
            The same MuData object, with processed modalities.
        """
        print(f"mudata: {mudata}")
        for mod_key, mod_data in mudata.mod.items():
            data_info = self._get_data_info_for_modality(mod_key)
            if data_info is not None:
                sc.pp.filter_cells(mod_data, min_genes=data_info.min_genes)
                # NOTE(review): filter_cells subsets this modality in place; the
                # MuData-level obs is not refreshed here — confirm callers expect that.
                # Resolve layers unconditionally so missing-layer warnings are
                # still reported even when no transform is applied.
                layers_to_process = self._get_layers_for_modality(mod_key, mod_data)
                if data_info.log_transform:
                    for layer in layers_to_process:
                        if layer == "X":
                            sc.pp.log1p(mod_data)
                        else:
                            # Transform the layer on a scratch copy so that
                            # mod_data.X and mod_data.uns stay untouched.
                            scratch = mod_data.copy()
                            scratch.X = mod_data.layers[layer].copy()
                            sc.pp.log1p(scratch)
                            mod_data.layers[layer] = scratch.X.copy()
            mudata.mod[mod_key] = mod_data

        return mudata

    def presplit_processing(
        self,
        multi_sc: Union[MuData, Dict[str, MuData]],  # ty: ignore[invalid-type-form]
    ) -> Union[MuData, Dict[str, MuData]]:  # ty: ignore[invalid-type-form]
        """Filter cells (min_genes) and optionally log-transform before splitting.

        Args:
            multi_sc: Either a single MuData object or a dictionary of MuData objects.

        Returns:
            The processed MuData when a single MuData is given; otherwise a
            dictionary mapping the same keys to processed MuData objects.
        """
        if isinstance(multi_sc, md.MuData):
            return self._presplit_processing(mudata=multi_sc)
        return {
            key: self._presplit_processing(mudata=value)
            for key, value in multi_sc.items()
        }

    def _to_dataframe(self, mod_data, layer=None) -> pd.DataFrame:
        """Convert a modality's matrix (X or a named layer) to a dense DataFrame.

        Args:
            mod_data: AnnData object of the modality.
            layer: Layer to convert. None or "X" selects ``mod_data.X``.

        Returns:
            DataFrame indexed by ``obs_names`` with ``var_names`` as columns.
        """
        if layer is None or layer == "X":
            data = mod_data.X
        else:
            data = mod_data.layers[layer]

        # Sparse matrices expose .toarray(); anything else is treated as dense.
        matrix = data.toarray() if hasattr(data, "toarray") else np.asarray(data)

        return pd.DataFrame(
            matrix, columns=mod_data.var_names, index=mod_data.obs_names
        )

    def _from_dataframe(self, df: pd.DataFrame, mod_data, layer=None):
        """Write DataFrame values back into a subset copy of a modality.

        The AnnData is first subset to the DataFrame's rows and columns, so
        ``obs`` and ``var`` metadata stay synchronized with the values.

        Args:
            df: DataFrame containing the updated values.
            mod_data: Modality AnnData to take metadata from.
            layer: Layer to write. None or "X" writes ``X``.

        Returns:
            New AnnData object holding the DataFrame values.
        """
        filtered_mod_data = mod_data[df.index, df.columns].copy()

        if layer is None or layer == "X":
            filtered_mod_data.X = df.values
        else:
            filtered_mod_data.layers[layer] = df.values

        return filtered_mod_data

    def sc_postsplit_processing(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        gene_map: Optional[
            Dict[str, List[str]]
        ] = None,  # ty: ignore[invalid-type-form]
    ) -> Tuple[MuData, Dict[str, List[str]]]:  # ty: ignore[invalid-type-form]
        """Filter genes per modality based on X, then apply the same mask to all layers.

        Args:
            mudata: Input multi-modal data container.
            gene_map: Optional override of genes to keep per modality.

        Returns:
            - New MuData with filtered modalities (X optionally normalized with
              ``sc.pp.normalize_total``; layers are copied unnormalized).
            - Mapping of modality key to kept gene names.

        Raises:
            ValueError: If a modality has no data info.
        """
        kept_genes: Dict[str, List[str]] = {}
        processed_mods = {}

        for mod_key, adata in mudata.mod.items():
            info = self._get_data_info_for_modality(mod_key)
            if info is None:
                raise ValueError(f"No data info for modality '{mod_key}'")

            if gene_map and mod_key in gene_map:
                # Use the provided gene list (e.g. from the training split).
                genes_to_keep = gene_map[mod_key]
                var_mask = adata.var_names.isin(genes_to_keep)
            else:
                # With inplace=False filter_genes only returns the mask, so no
                # defensive copy of the whole AnnData is needed.
                var_mask = sc.pp.filter_genes(
                    adata, min_cells=info.min_cells, inplace=False
                )[0]
                genes_to_keep = adata.var_names[var_mask].tolist()

            kept_genes[mod_key] = genes_to_keep

            filtered_adata = AnnData(
                X=adata.X[:, var_mask],
                obs=adata.obs.copy(),
                var=adata.var[var_mask].copy(),
                uns=adata.uns.copy(),
                obsm=adata.obsm.copy(),
            )

            if info.normalize_counts:
                sc.pp.normalize_total(filtered_adata)

            # Layers are taken from the original (unnormalized) modality.
            for layer in self._get_layers_for_modality(mod_key, adata):
                if layer == "X":
                    continue
                if layer not in adata.layers:
                    raise ValueError(
                        f"Layer '{layer}' not found in modality '{mod_key}'"
                    )
                filtered_adata.layers[layer] = adata.layers[layer][:, var_mask].copy()

            processed_mods[mod_key] = filtered_adata

        return md.MuData(processed_mods), kept_genes

    def _apply_general_filtering(
        self, df: pd.DataFrame, data_info: DataInfo, gene_list: Optional[List]
    ) -> Tuple[Union[pd.Series, pd.DataFrame], List]:
        """Delegate feature filtering of ``df`` to ``DataFilter.filter``.

        Returns:
            Whatever ``DataFilter.filter`` returns: the filtered data and the
            list of kept features.
        """
        data_processor = DataFilter(data_info=data_info, config=self.config)
        return data_processor.filter(df=df, genes_to_keep=gene_list)

    def _apply_scaling(
        self, df: pd.DataFrame, data_info: DataInfo, scaler: Any
    ) -> Tuple[Union[pd.Series, pd.DataFrame], Any]:
        """Scale ``df`` with ``scaler``, fitting a new one via DataFilter if None.

        Returns:
            Tuple of (scaled data, scaler used).
        """
        data_processor = DataFilter(data_info=data_info, config=self.config)
        if scaler is None:
            scaler = data_processor.fit_scaler(df=df)
        scaled_df = data_processor.scale(df=df, scaler=scaler)
        return scaled_df, scaler

    def general_postsplit_processing(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        gene_map: Optional[Dict[str, List]],
        scaler_map: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> Tuple[
        MuData,  # ty: ignore[invalid-type-form]
        Dict[str, List],
        Dict[str, Dict[str, Any]],  # ty: ignore[invalid-type-form]
    ]:  # ty: ignore[invalid-type-form]
        """Filter and scale every modality, keeping layers aligned with X.

        X is filtered (via DataFilter) and scaled; every selected layer is then
        restricted to the same columns as the filtered X and scaled with its
        own scaler.

        Args:
            mudata: Input multi-modal data container.
            gene_map: Optional override of genes to keep per modality.
            scaler_map: Optional pre-fitted scalers per modality and layer.
                Missing modalities/layers get a freshly fitted scaler.

        Returns:
            Tuple of (new MuData with filtered and scaled modalities, gene list
            per modality, scaler per modality and layer).

        Raises:
            ValueError: If a modality has no data info.
        """
        feature_distribution = self.distribute_features_across_modalities(
            mudata, self.total_features
        )
        out_gene_map = {}
        out_scaler_map = {mod_key: {} for mod_key in mudata.mod.keys()}
        processed_modalities = {}

        for mod_key, original_mod in mudata.mod.items():
            data_info = self._get_data_info_for_modality(mod_key)
            # Validate before touching the object (previously the attribute was
            # assigned first, so a None data_info crashed before this check).
            if data_info is None:
                raise ValueError(f"No data info found for modality {mod_key}")
            # NOTE: with a shared DataInfo this mutates the same object for
            # every modality; it is set right before use in each iteration.
            data_info.k_filter = feature_distribution[mod_key]

            # Pre-fitted scalers for this modality; tolerate partial maps.
            mod_scalers = (scaler_map or {}).get(mod_key) or {}

            mod_data = original_mod.copy()

            # Filter X and remember which genes were kept.
            x_df = self._to_dataframe(mod_data, layer=None)
            filtered_x, gene_list = self._apply_general_filtering(
                df=x_df,
                gene_list=gene_map.get(mod_key) if gene_map else None,
                data_info=data_info,
            )
            out_gene_map[mod_key] = gene_list

            scaled_x, x_scaler = self._apply_scaling(
                df=filtered_x,
                data_info=data_info,
                scaler=mod_scalers.get("X"),
            )
            out_scaler_map[mod_key]["X"] = x_scaler

            processed_adata = self._create_new_adata(
                scaled_x,
                original_adata=mod_data,
                obs_names=mod_data.obs_names.tolist(),
                var_names=filtered_x.columns.tolist(),
            )

            for layer in self._get_layers_for_modality(mod_key, mod_data):
                if layer == "X":
                    continue
                layer_df = self._to_dataframe(mod_data, layer=layer)
                # Keep exactly the genes retained for X.
                filtered_layer = layer_df[filtered_x.columns]
                scaled_layer, layer_scaler = self._apply_scaling(
                    df=filtered_layer,
                    data_info=data_info,
                    scaler=mod_scalers.get(layer),
                )
                out_scaler_map[mod_key][layer] = layer_scaler
                processed_adata.layers[layer] = scaled_layer.values

            processed_modalities[mod_key] = processed_adata

        new_mudata = md.MuData(processed_modalities)
        return new_mudata, out_gene_map, out_scaler_map

    @staticmethod
    def _align_mapping(mapping, axis_names: pd.Index, selected_names) -> Any:
        """Row-subset an obsm/varm-style mapping to ``selected_names``.

        Returns a plain copy when the mapping is empty or the selection equals
        the current axis; otherwise each entry is subset by position.

        Raises:
            ValueError: If a selected name is absent from ``axis_names``.
        """
        if len(mapping) == 0 or axis_names.equals(pd.Index(selected_names)):
            return mapping.copy()
        positions = axis_names.get_indexer(selected_names)
        if (positions < 0).any():
            raise ValueError(
                "Cannot align annotations: some selected names are not present "
                "in the original AnnData."
            )
        aligned = {}
        for key, value in mapping.items():
            if isinstance(value, pd.DataFrame):
                aligned[key] = value.iloc[positions].copy()
            else:
                aligned[key] = value[positions].copy()
        return aligned

    def _create_new_adata(self, df, original_adata, obs_names, var_names):
        """Build an AnnData from ``df`` carrying over obs/uns/obsm/varm metadata.

        ``obsm``/``varm`` entries are subset to ``obs_names``/``var_names`` so
        their first dimension matches the new AnnData (a filtered gene set
        would otherwise leave e.g. PCA loadings in ``varm`` misaligned).
        """
        return AnnData(
            X=df.values,
            obs=original_adata.obs.loc[obs_names],
            var=pd.DataFrame(index=var_names),
            layers={},
            uns=original_adata.uns.copy(),
            obsm=self._align_mapping(
                original_adata.obsm, original_adata.obs_names, obs_names
            ),
            varm=self._align_mapping(
                original_adata.varm, original_adata.var_names, var_names
            ),
        )

    def distribute_features_across_modalities(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        total_features: Optional[int],  # ty: ignore[invalid-type-form]
    ) -> Dict[str, Optional[int]]:
        """Distribute a total number of features evenly across modalities.

        Earlier modalities receive one extra feature while a remainder is left.
        Each modality's DataInfo also gets its ``k_filter`` set accordingly.

        Args:
            mudata: Multi-modal data object.
            total_features: Total number of features to distribute; None means
                no limit.

        Returns:
            Dictionary mapping modality keys to the number of features to keep
            (all None when ``total_features`` is None; empty without modalities).
        """
        valid_modalities = list(mudata.mod.keys())
        if total_features is None:
            return {k: None for k in valid_modalities}
        n_modalities = len(valid_modalities)
        if n_modalities == 0:
            return {}

        base_features, remainder = divmod(total_features, n_modalities)

        feature_distribution: Dict[str, Optional[int]] = {}
        for i, mod_key in enumerate(valid_modalities):
            n_features = base_features + (1 if i < remainder else 0)
            feature_distribution[mod_key] = n_features

            data_info = self._get_data_info_for_modality(mod_key)
            if data_info is not None:
                data_info.k_filter = n_features

        return feature_distribution