Coverage for src/autoencodix/data/_sc_filter.py: 14%

162 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union 

2 

3import mudata as md # type: ignore 

4import numpy as np 

5import pandas as pd 

6import scanpy as sc # type: ignore 

7from anndata import AnnData # type: ignore 

8 

9from autoencodix.data._filter import DataFilter 

10from autoencodix.configs.default_config import DataInfo, DefaultConfig 

11 

if TYPE_CHECKING:
    # Give static type checkers the real class so ``MuData`` annotations are
    # valid type forms. ``MuData`` is exported at package level by ``mudata``
    # (``md.MuData.MuData`` does not name a class).
    from mudata import MuData  # type: ignore
else:
    # At runtime the name is only a placeholder for annotations.
    MuData = Any

18 

19 

class SingleCellFilter:
    """Filter and scale single-cell data, returning MuData objects with synchronized metadata.

    Attributes:
        data_info: Configuration for filtering and scaling (either a single
            DataInfo shared by all modalities or a dict of DataInfo per modality).
        total_features: Total number of features to keep across all modalities
            (taken from ``config.k_filter``).
        config: Configuration object containing settings for data processing.
        _is_data_info_dict: Internal flag indicating if data_info is a dictionary.
    """

    def __init__(
        self, data_info: Union[Dict[str, DataInfo], DataInfo], config: DefaultConfig
    ):
        """Initialize the single-cell filter.

        Args:
            data_info: Either a single DataInfo object used for all modalities or
                a dictionary of DataInfo objects keyed by modality.
            config: Configuration object containing settings for data processing.
        """
        self.data_info = data_info
        self.total_features = config.k_filter
        self._is_data_info_dict = isinstance(data_info, dict)
        self.config = config

    def _get_data_info_for_modality(self, mod_key: str) -> DataInfo:
        """Return the DataInfo configuration for a specific modality.

        Args:
            mod_key: The modality key (e.g., "RNA", "METH").

        Returns:
            The per-modality DataInfo when ``data_info`` is a dict, otherwise the
            single shared DataInfo object.

        Raises:
            ValueError: If ``data_info`` is a dict without an entry for ``mod_key``.
        """
        if self._is_data_info_dict:
            info = self.data_info.get(mod_key)  # type: ignore
            if info is None:
                raise ValueError(f"No data info found for modality {mod_key}")
            return info
        return self.data_info  # type: ignore

    def _get_layers_for_modality(self, mod_key: str, mod_data) -> List[str]:
        """Return the layers to process for a specific modality.

        Layers listed in ``data_info.selected_layers`` that do not exist in
        ``mod_data.layers`` are skipped with a printed warning; ``"X"`` is always
        accepted.

        Args:
            mod_key: The modality key (e.g., "RNA", "METH").
            mod_data: The AnnData object for the modality.

        Returns:
            List of layer names to process. If no valid layer is selected
            (including when ``selected_layers`` is None or empty), ``["X"]``.
        """
        data_info = self._get_data_info_for_modality(mod_key)
        # Treat a missing selection like an empty one so the documented ["X"]
        # fallback applies instead of failing on iteration over None.
        selected_layers = data_info.selected_layers or []

        available_layers = set(mod_data.layers.keys())
        valid_layers: List[str] = []
        for layer in selected_layers:
            if layer == "X" or layer in available_layers:
                valid_layers.append(layer)
            else:
                print(
                    f"Warning: Layer '{layer}' not found in modality '{mod_key}'. Skipping."
                )
        return valid_layers or ["X"]

    def _presplit_processing(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
    ) -> MuData:  # ty: ignore[invalid-type-form]
        """Filter cells and optionally log-transform every modality in place.

        For each modality, cells with fewer than ``data_info.min_genes`` genes are
        removed with ``sc.pp.filter_cells``; then, if ``data_info.log_transform``
        is set, ``log1p`` is applied to X and/or each selected layer.

        Args:
            mudata: Multi-modal container whose modalities are modified in place.

        Returns:
            The same MuData object, with its modalities processed.
        """
        for mod_key, mod_data in mudata.mod.items():
            data_info = self._get_data_info_for_modality(mod_key)
            if data_info is not None:
                sc.pp.filter_cells(mod_data, min_genes=data_info.min_genes)
                # Resolve layers even without a transform so missing-layer
                # warnings are still reported.
                layers_to_process = self._get_layers_for_modality(mod_key, mod_data)
                if data_info.log_transform:
                    for layer in layers_to_process:
                        if layer == "X":
                            sc.pp.log1p(mod_data)
                        else:
                            # Transform only this layer's matrix (copy=True leaves
                            # the stored matrix untouched until reassigned) rather
                            # than copying the whole AnnData object per layer.
                            mod_data.layers[layer] = sc.pp.log1p(
                                mod_data.layers[layer], copy=True
                            )
            mudata.mod[mod_key] = mod_data

        return mudata

    def presplit_processing(
        self,
        multi_sc: Union[MuData, Dict[str, MuData]],  # ty: ignore[invalid-type-form]
    ) -> Union[MuData, Dict[str, MuData]]:  # ty: ignore[invalid-type-form]
        """Filter cells (min_genes) and optionally log-transform each modality.

        Args:
            multi_sc: Either a single MuData object or a dictionary of MuData objects.

        Returns:
            The processed MuData when a single MuData is given; otherwise a
            dictionary mapping the input keys to processed MuData objects.
        """
        if isinstance(multi_sc, md.MuData):
            return self._presplit_processing(mudata=multi_sc)
        return {
            key: self._presplit_processing(mudata=value)
            for key, value in multi_sc.items()
        }

    def _to_dataframe(self, mod_data, layer=None) -> pd.DataFrame:
        """Convert X or one layer of an AnnData object to a dense pandas DataFrame.

        Args:
            mod_data: Modality AnnData to convert.
            layer: Layer to convert. If None or "X", uses ``mod_data.X``.

        Returns:
            DataFrame indexed by ``obs_names`` with ``var_names`` as columns.
        """
        data = mod_data.X if layer is None or layer == "X" else mod_data.layers[layer]
        # Sparse matrices expose ``toarray``; any other array-like is coerced to
        # a dense ndarray (a no-op for ndarrays).
        matrix = data.toarray() if hasattr(data, "toarray") else np.asarray(data)
        return pd.DataFrame(
            matrix, columns=mod_data.var_names, index=mod_data.obs_names
        )

    def _from_dataframe(self, df: pd.DataFrame, mod_data, layer=None):
        """Return a copy of ``mod_data`` subset to ``df``'s rows/columns holding ``df``'s values.

        Subsetting by ``df.index`` and ``df.columns`` keeps ``obs`` and ``var``
        metadata in sync with the DataFrame.

        Args:
            df: DataFrame containing the updated values.
            mod_data: Modality AnnData to subset and update.
            layer: Layer to write the values into. If None or "X", writes X.

        Returns:
            The updated (copied) AnnData object.
        """
        filtered_mod_data = mod_data[df.index, df.columns].copy()
        if layer is None or layer == "X":
            filtered_mod_data.X = df.values
        else:
            filtered_mod_data.layers[layer] = df.values
        return filtered_mod_data

    def sc_postsplit_processing(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        gene_map: Optional[Dict[str, List[str]]] = None,
    ) -> Tuple[MuData, Dict[str, List[str]]]:  # ty: ignore[invalid-type-form]
        """Filter genes per modality based on X, then apply the same mask to all layers.

        Args:
            mudata: Input multi-modal data container.
            gene_map: Optional override of genes to keep per modality. Modalities
                not present in it are filtered with ``sc.pp.filter_genes`` using
                ``min_cells`` from their DataInfo.

        Returns:
            A tuple of:
            - a new MuData built from the filtered modalities (X optionally
              normalized with ``sc.pp.normalize_total`` when ``normalize_counts``
              is set; layers are only gene-filtered), and
            - a mapping of modality key to the kept gene names.

        Raises:
            ValueError: If a modality has no DataInfo or a selected layer is missing.
        """
        kept_genes = {}
        processed_mods = {}

        for mod_key, adata in mudata.mod.items():
            info = self._get_data_info_for_modality(mod_key)
            if info is None:
                raise ValueError(f"No data info for modality '{mod_key}'")

            if gene_map and mod_key in gene_map:
                genes_to_keep = gene_map[mod_key]
                var_mask = adata.var_names.isin(genes_to_keep)
            else:
                # With inplace=False, filter_genes only computes the boolean mask
                # (and per-gene counts) without modifying adata, so no defensive
                # copy is needed.
                var_mask, _ = sc.pp.filter_genes(
                    adata, min_cells=info.min_cells, inplace=False
                )
                genes_to_keep = adata.var_names[var_mask].tolist()

            kept_genes[mod_key] = genes_to_keep

            filtered_adata = AnnData(
                X=adata.X[:, var_mask],
                obs=adata.obs.copy(),
                var=adata.var[var_mask].copy(),
                uns=adata.uns.copy(),
                obsm=adata.obsm.copy(),
            )

            if info.normalize_counts:
                sc.pp.normalize_total(filtered_adata)

            for layer in self._get_layers_for_modality(mod_key, adata):
                if layer == "X":
                    continue
                if layer not in adata.layers:
                    raise ValueError(
                        f"Layer '{layer}' not found in modality '{mod_key}'"
                    )
                filtered_adata.layers[layer] = adata.layers[layer][:, var_mask].copy()

            processed_mods[mod_key] = filtered_adata

        return md.MuData(processed_mods), kept_genes

    def _apply_general_filtering(
        self, df: pd.DataFrame, data_info: DataInfo, gene_list: Optional[List]
    ) -> Tuple[Union[pd.Series, pd.DataFrame], List]:
        """Filter ``df`` with a DataFilter configured by ``data_info``.

        Args:
            df: Data to filter (cells x features).
            data_info: Configuration passed to DataFilter.
            gene_list: Optional genes to keep, forwarded as ``genes_to_keep``.

        Returns:
            Whatever ``DataFilter.filter`` returns: the filtered data and the
            list of kept genes.
        """
        data_processor = DataFilter(data_info=data_info, config=self.config)
        return data_processor.filter(df=df, genes_to_keep=gene_list)

    def _apply_scaling(
        self, df: pd.DataFrame, data_info: DataInfo, scaler: Any
    ) -> Tuple[Union[pd.Series, pd.DataFrame], Any]:
        """Scale ``df``, fitting a new scaler when none is supplied.

        Args:
            df: Data to scale.
            data_info: Configuration passed to DataFilter.
            scaler: Pre-fitted scaler, or None to fit one on ``df``.

        Returns:
            Tuple of the scaled data and the scaler that was used.
        """
        data_processor = DataFilter(data_info=data_info, config=self.config)
        if scaler is None:
            scaler = data_processor.fit_scaler(df=df)
        scaled_df = data_processor.scale(df=df, scaler=scaler)
        return scaled_df, scaler

    def general_postsplit_processing(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        gene_map: Optional[Dict[str, List]],
        scaler_map: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> Tuple[
        MuData,  # ty: ignore[invalid-type-form]
        Dict[str, List],
        Dict[str, Dict[str, Any]],
    ]:
        """Filter and scale X and the selected layers of every modality.

        The feature budget (``total_features``) is first split across modalities.
        X is then filtered with ``DataFilter.filter``, and the resulting columns
        are reused for every other selected layer, so all matrices of a modality
        share the same genes. X and each layer get their own scaler.

        Args:
            mudata: Input multi-modal data container. Each modality is copied
                before processing.
            gene_map: Optional genes to keep per modality, passed to the filter.
            scaler_map: Optional pre-fitted scalers per modality and layer. When
                given, it is indexed by every modality key (a missing key raises
                KeyError).

        Returns:
            Tuple ``(new_mudata, out_gene_map, out_scaler_map)``:
            - the MuData built from the filtered and scaled modalities,
            - the gene list returned by the filter for each modality,
            - the scaler used for each modality and layer ("X" included).

        Raises:
            ValueError: If no DataInfo is available for a modality.
        """
        feature_distribution = self.distribute_features_across_modalities(
            mudata, self.total_features
        )
        out_gene_map: Dict[str, List] = {}
        out_scaler_map: Dict[str, Dict[str, Any]] = {
            mod_key: {} for mod_key in mudata.mod.keys()
        }
        processed_modalities = {}

        for mod_key, original_mod in mudata.mod.items():
            data_info = self._get_data_info_for_modality(mod_key)
            # Validate before the attribute assignment below; otherwise a missing
            # config surfaces as an AttributeError on None.
            if data_info is None:
                raise ValueError(f"No data info found for modality {mod_key}")
            # Re-assign right before use: a single shared DataInfo holds only the
            # last modality's value after distribute_features_across_modalities.
            data_info.k_filter = feature_distribution[mod_key]

            mod_data = original_mod.copy()

            # Filter genes on X; the resulting columns define the modality's genes.
            x_df = self._to_dataframe(mod_data, layer=None)
            filtered_x, gene_list = self._apply_general_filtering(
                df=x_df,
                gene_list=gene_map.get(mod_key) if gene_map else None,
                data_info=data_info,
            )
            out_gene_map[mod_key] = gene_list

            scaled_x, x_scaler = self._apply_scaling(
                df=filtered_x,
                data_info=data_info,
                scaler=scaler_map[mod_key].get("X") if scaler_map else None,
            )
            out_scaler_map[mod_key]["X"] = x_scaler

            processed_adata = self._create_new_adata(
                scaled_x,
                original_adata=mod_data,
                obs_names=mod_data.obs_names.tolist(),
                var_names=filtered_x.columns.tolist(),
            )

            for layer in self._get_layers_for_modality(mod_key, mod_data):
                if layer == "X":
                    continue
                layer_df = self._to_dataframe(mod_data, layer=layer)
                # Restrict the layer to exactly the genes kept for X.
                filtered_layer = layer_df[filtered_x.columns]
                scaled_layer, layer_scaler = self._apply_scaling(
                    df=filtered_layer,
                    data_info=data_info,
                    scaler=scaler_map[mod_key].get(layer) if scaler_map else None,
                )
                out_scaler_map[mod_key][layer] = layer_scaler
                processed_adata.layers[layer] = scaled_layer.values

            processed_modalities[mod_key] = processed_adata

        return md.MuData(processed_modalities), out_gene_map, out_scaler_map

    def _create_new_adata(self, df, original_adata, obs_names, var_names):
        """Build a new AnnData from ``df`` carrying over metadata of ``original_adata``.

        Args:
            df: DataFrame whose values become X; its rows/columns are expected to
                follow ``obs_names``/``var_names``.
            original_adata: Source AnnData for obs, uns, obsm and varm.
            obs_names: Observation names to keep, in the row order of ``df``.
            var_names: Variable names to keep, in the column order of ``df``.

        Returns:
            A new AnnData with an empty-column ``var`` indexed by ``var_names``
            and no layers.
        """
        # Subset through a (lazy) view so obsm/varm entries match the kept
        # observations/variables; copying the unsubset mappings would give varm
        # arrays with the original number of variables, which AnnData rejects.
        subset = original_adata[obs_names, var_names]
        return AnnData(
            X=df.values,
            obs=original_adata.obs.loc[obs_names],
            var=pd.DataFrame(index=var_names),
            layers={},
            uns=original_adata.uns.copy(),
            obsm=subset.obsm.copy(),
            varm=subset.varm.copy(),
        )

    def distribute_features_across_modalities(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        total_features: Optional[int],
    ) -> Dict[str, Optional[int]]:
        """Split a total feature budget evenly across modalities.

        The first ``total_features % n_modalities`` modalities (in ``mudata.mod``
        order) receive one extra feature. Each modality's DataInfo gets its share
        stored as ``k_filter``.

        Args:
            mudata: Multi-modal data object.
            total_features: Total number of features to distribute, or None for
                no limit.

        Returns:
            Mapping of modality key to its number of features (all None when
            ``total_features`` is None; empty when there are no modalities).
        """
        modality_keys = list(mudata.mod.keys())
        if total_features is None:
            return {key: None for key in modality_keys}
        n_modalities = len(modality_keys)
        if n_modalities == 0:
            return {}

        base_features, remainder = divmod(total_features, n_modalities)

        feature_distribution: Dict[str, Optional[int]] = {}
        for index, mod_key in enumerate(modality_keys):
            n_features = base_features + (1 if index < remainder else 0)
            feature_distribution[mod_key] = n_features

            data_info = self._get_data_info_for_modality(mod_key)
            if data_info is not None:
                # Plain assignment creates the attribute when it is missing.
                data_info.k_filter = n_features

        return feature_distribution