Coverage for src / autoencodix / configs / default_config.py: 76%
241 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import warnings
2from enum import Enum
3from typing import Any, Dict, List, Literal, Optional, Union
5from pydantic import (
6 BaseModel,
7 ConfigDict,
8 Field,
9 ValidationInfo,
10 field_validator,
11 model_validator,
12)
class SchemaPrinterMixin:
    """Mixin class that adds schema printing functionality to Pydantic models.

    The host class must expose ``model_fields`` (every pydantic ``BaseModel``
    subclass does): a mapping of field name to a field-info object carrying
    ``annotation``, ``default``, ``description`` and, optionally,
    ``default_factory``.
    """

    @classmethod
    def get_params(cls) -> Dict[str, Dict[str, Any]]:
        """
        Get detailed information about all config fields including types and default values.

        For fields declared with a ``default_factory`` the reported default is the
        value returned by calling the factory, because ``field.default`` does not
        hold the real default for such fields.

        Returns
        -------
        Dict[str, Dict[str, Any]]
            Dictionary containing field name, type, default value, and description if available
        """
        fields_info = {}
        for name, field in cls.model_fields.items():  # type: ignore
            factory = getattr(field, "default_factory", None)
            if factory is None:
                default = field.default
            else:
                try:
                    default = factory()
                except TypeError:
                    # Factories that expect arguments cannot be evaluated here;
                    # show a placeholder instead of failing the whole listing.
                    default = f"<default_factory {factory!r}>"
            fields_info[name] = {
                "type": str(field.annotation),
                "default": default,
                "description": field.description or "No description available",
            }
        return fields_info

    @classmethod
    def print_schema(cls, filter_params: Optional[List[str]] = None) -> None:
        """
        Print a human-readable schema of all config parameters.

        Args:
            filter_params: If provided, only print information for these parameters.
                Any iterable of names is accepted (it is materialised once, so
                generators work); a single string is treated as one parameter name.
        """
        wanted: Optional[List[str]] = None
        if filter_params:
            # A bare string would otherwise be matched as a substring.
            if isinstance(filter_params, str):
                wanted = [filter_params]
            else:
                # Materialise once: membership tests below run per field and
                # would exhaust a one-shot iterator.
                wanted = list(filter_params)
            print("Valid Keyword Arguments:")
            print("-" * 50)
        else:
            print(f"\n{cls.__name__} Configuration Parameters:")
            print("-" * 50)

        for name, info in cls.get_params().items():
            if wanted is not None and name not in wanted:
                continue
            print(f"\n{name}:")
            print(f"  Type: {info['type']}")
            print(f"  Default: {info['default']}")
            print(f"  Description: {info['description']}")


class DataCase(str, Enum):
    """Supported combinations of data modalities.

    Each member's value is its human-readable label. Because the enum also
    derives from ``str``, members compare equal to that label.
    """

    # several modalities of one kind, no 'from'/'to' translation pair
    MULTI_SINGLE_CELL = "Multi Single Cell"
    MULTI_BULK = "Multi Bulk"
    # 'from' <-> 'to' pairs (IMG_TO_IMG is also used for image-only setups)
    BULK_TO_BULK = "Bulk<->Bulk"
    IMG_TO_BULK = "IMG<->Bulk"
    SINGLE_CELL_TO_SINGLE_CELL = "Single Cell<->Single Cell"
    SINGLE_CELL_TO_IMG = "Single Cell<->IMG"
    IMG_TO_IMG = "IMG<->IMG"


class ConfigValidationError(Exception):
    """Raised when a configuration's datasets are mutually inconsistent.

    Derives directly from ``Exception`` (not ``ValueError``).
    NOTE(review): pydantic only wraps ValueError/AssertionError raised inside
    validators into a ValidationError, so this presumably reaches callers
    unwrapped — confirm before changing the base class.
    """


class DataInfo(BaseModel, SchemaPrinterMixin):
    """Description and preprocessing options for a single data modality.

    Instances are the values of ``DataConfig.data_info``.
    """

    # general -------------------------------------
    file_path: str = Field(default="", description="Path to raw data file")
    data_type: Literal["NUMERIC", "CATEGORICAL", "IMG", "ANNOTATION"] = Field(
        default="NUMERIC"
    )
    scaling: Literal[
        "STANDARD", "MINMAX", "ROBUST", "MAXABS", "NONE", "NOTSET", "LOG1P"
    ] = Field(
        default="NOTSET",
        description="Setting the scaling here in DataInfo overrides the globally set scaling method for the specific data modality",
    )  # can also be set globally, for all data modalities.

    filtering: Literal["VAR", "MAD", "CORR", "VARCORR", "NOFILT", "NONZEROVAR"] = Field(
        default="VAR"
    )
    sep: Union[str, None] = Field(default=None)  # for pandas read_csv
    extra_anno_file: Union[str, None] = Field(default=None)

    # single cell specific -------------------------
    is_single_cell: bool = Field(default=False)

    min_cells: float = Field(
        default=0.05,
        ge=0,
        le=1,
        description="Minimum fraction of cells a gene must be expressed in to be kept. Genes expressed in fewer cells will be filtered out.",
    )  # Controls gene filtering based on expression prevalence

    min_genes: float = Field(
        default=0.02,
        ge=0,
        le=1,
        description="Minimum fraction of genes a cell must express to be kept. Cells expressing fewer genes will be filtered out.",
    )  # Controls cell quality filtering
    selected_layers: List[str] = Field(default=["X"])

    is_X: bool = Field(default=False)  # only for single cell data
    normalize_counts: bool = Field(
        default=True, description="Whether to normalize by total counts"
    )
    log_transform: bool = Field(
        default=False, description="Whether to apply log1p transformation"
    )
    k_filter: Optional[int] = Field(
        default=None,
        description="Don't set this gets calculated dynamically, based on k_filter in general config ",
    )
    # image specific ------------------------------
    img_width_resize: Union[int, None] = Field(default=64)
    img_height_resize: Union[int, None] = Field(default=64)
    # annotation specific -------------------------
    # xmodalix specific -------------------------
    translate_direction: Union[Literal["from", "to"], None] = Field(default=None)
    pretrain_epochs: Optional[int] = Field(
        default=None,
        description="Number of pretraining epochs. This overwrites the global 'pretraining_epochs' in DefaultConfig class to have different number of pretraining epochs for each data modality",
    )

    @field_validator("selected_layers")
    @classmethod
    def validate_selected_layers(cls, v):
        """Require the base layer ``"X"`` to be part of ``selected_layers``."""
        if "X" not in v:
            raise ValueError('"X" must always be a part of the selected_layers list')
        return v

    @field_validator("k_filter", mode="before")
    @classmethod
    def _forbid_user_k_filter(cls, v: Any, info: ValidationInfo) -> Any:
        """
        Reject any user-supplied ``k_filter``.

        Field validators only run on values passed in by the caller (the
        ``None`` default is not validated), and this model does not enable
        assignment validation, so code can still set ``data_info.k_filter = xx``
        after instantiation.
        """
        if v is not None:
            raise ValueError(
                "k_filter is computed automatically for each data modality, based on global k_filter – remove it from your DataInfo configuration."
            )
        return v

    @field_validator("img_width_resize", "img_height_resize")
    @classmethod
    def validate_image_resize(cls, v, info: ValidationInfo):
        """Require each resize dimension, when given, to be a positive integer."""
        if v is not None and v <= 0:
            raise ValueError("Image resize dimensions must be positive integers")
        return v

    @model_validator(mode="after")
    def validate_square_image_resize(self) -> "DataInfo":
        """Only allow quadratic image resizing.

        This cannot live in the per-field validator: ``info.data`` only holds
        fields validated before the current one, and defaults are not
        validated, so both dimensions are never visible there together.
        Dimensions left as ``None`` are not compared.
        """
        width = self.img_width_resize
        height = self.img_height_resize
        if width is not None and height is not None and width != height:
            raise ValueError("Image width and height must be the same for resizing")
        return self


class DataConfig(BaseModel, SchemaPrinterMixin):
    """Collection of per-modality ``DataInfo`` entries, keyed by modality name.

    ``annotation_columns`` on this class is kept for backward compatibility;
    ``DefaultConfig`` warns that it is deprecated here and syncs it with its own
    ``annotation_columns`` field.
    """

    data_info: Dict[str, DataInfo]
    require_common_cells: Optional[bool] = False
    annotation_columns: Optional[List[str]] = None


# Slack allowed when comparing the summed split ratios against 1.0, so that
# floating-point rounding in the running sum cannot reject a valid split.
_SPLIT_RATIO_TOLERANCE = 1e-12


def _split_translation_datasets(
    data_info: Dict[str, "DataInfo"],
) -> "tuple[List[DataInfo], List[DataInfo]]":
    """Return the DataInfo entries marked 'from' and 'to', in insertion order."""
    from_infos = [
        info for info in data_info.values() if info.translate_direction == "from"
    ]
    to_infos = [info for info in data_info.values() if info.translate_direction == "to"]
    return from_infos, to_infos


class DefaultConfig(BaseModel, SchemaPrinterMixin):
    """Complete configuration for model, training, hardware, and data handling."""

    # Input validation
    model_config = ConfigDict(extra="forbid")
    # Datasets configuration --------------------------------------------------
    data_config: DataConfig = DataConfig(data_info={})
    annotation_columns: Optional[List[str]] = Field(default=None)
    img_path_col: str = Field(
        default="img_paths",
        description="When working with images, we except a column in your annotation file that specifies the path of the image for a particular sample. Here you can define the name of this column",
    )
    requires_paired: Union[bool, None] = Field(
        default_factory=lambda: True,
        description="Indicator if the samples for the xmodalix are paired, based on some sample id",
    )

    data_case: Union[DataCase, None] = Field(
        default_factory=lambda: None,
        description="Data case for the model, will be determined automatically",
    )
    k_filter: Union[int, None] = Field(
        default=None, description="Number of features to keep"
    )
    scaling: Literal["STANDARD", "MINMAX", "ROBUST", "MAXABS", "NONE", "LOG1P"] = Field(
        default="STANDARD",
        description="Setting the scaling here for all data modalities, can per overruled by setting scaling at data modality level per data modality",
    )

    skip_preprocessing: bool = Field(
        default=False, description="If set don't scale, filter or clean the input data."
    )

    class_param: Optional[str] = Field(default=None)

    # Model configuration -----------------------------------------------------
    latent_dim: int = Field(
        default=16, ge=1, description="Dimension of the latent space"
    )
    hidden_dim: int = Field(
        default=16,
        ge=1,
        description="Hidden dimension of image_vae, applies only to image_vae",
    )
    n_layers: int = Field(
        default=3,
        ge=0,
        description="Number of layers in encoder/decoder, without latent layer. If 0, is only the latent layer.",
    )
    enc_factor: float = Field(
        default=4, gt=0, description="Scaling factor for encoder dimensions"
    )
    maskix_hidden_dim: int = Field(
        default=256,
        ge=8,
        description="The Maskix implementation follows https://doi.org/10.1093/bioinformatics/btae020. The authors use a hidden dimension 0f 256 for their neural network, so we set this as default",
    )
    maskix_swap_prob: float = Field(
        default=0.4,
        ge=0,
        le=1,  # it is passed to a Bernoulli distribution, so it must be a probability
        description="For the Maskix input_data masinkg, we sample a probablity if samples within one gene should be swapt. This is done with a Bernoulli distribution, maskix_swap_prob is the probablity passed to the bernoulli distribution ",
    )
    drop_p: float = Field(
        default=0.1, ge=0.0, le=1.0, description="Dropout probability"
    )

    # Training configuration --------------------------------------------------
    save_memory: bool = Field(
        default=False, description="If set to True we don't store TrainingDynamics"
    )
    save_vram: bool = Field(
        default=False,
        description="If set to True we move intermediate results to CPU to save GPU VRAM, but this will be slower",
    )
    learning_rate: float = Field(
        default=0.001, gt=0, description="Learning rate for optimization"
    )
    compile_model: bool = Field(
        default=False,
        description="If set to True we compile the model with torch.compile",
    )
    pin_memory: bool = Field(
        default=False, description="Pin memory for faster data transfer"
    )
    batch_size: int = Field(
        default=32,
        ge=2,
        description="Number of samples per batch, has to be > 1, because we use BatchNorm() Layer",
    )
    epochs: int = Field(default=3, ge=1, description="Number of training epochs")
    weight_decay: float = Field(
        default=0.01, ge=0, description="L2 regularization factor"
    )
    reconstruction_loss: Literal["mse", "bce"] = Field(
        default="mse", description="Type of reconstruction loss"
    )
    default_vae_loss: Literal["kl", "mmd"] = Field(
        default="kl", description="Type of VAE loss"
    )
    loss_reduction: Literal["sum", "mean"] = Field(
        default="sum",
        description="Loss reduction in PyTorch i.e in torch.nn.functional.binary_cross_entropy_with_logits(reduction=loss_reduction)",
    )
    beta: float = Field(
        default=1, ge=0, description="Beta weighting factor for VAE loss"
    )
    beta_mi: float = Field(
        default=1,
        ge=0,
        description="Beta weighting factor for mutual information term in disentangled VAE loss",
    )
    beta_tc: float = Field(
        default=1,
        ge=0,
        description="Beta weighting factor for total correlation term in disentangled VAE loss",
    )
    beta_dimKL: float = Field(
        default=1,
        ge=0,
        description="Beta weighting factor for dimension-wise KL in disentangled VAE loss",
    )
    use_mss: bool = Field(
        default=True,
        description="Using minibatch stratified sampling for disentangled VAE loss calculation (faster estimation)",
    )
    gamma: float = Field(
        default=10.0,
        ge=0,
        description="Gamma weighting factor for Adversial Loss Term i.e. for XModalix Classfier training",
    )
    delta_pair: float = Field(
        default=5.0,
        ge=0,
        description="Delta weighting factor for paired loss term in XModalix Training",
    )
    delta_class: float = Field(
        default=5.0,
        ge=0,
        description="Delta weighting factor for class loss term in XModalix Training",
    )
    delta_mask_predictor: float = Field(
        default=0.7,
        ge=0.0,
        description="Delt weighting factor of the mask predictin loss term for the Maskix",
    )
    delta_mask_corrupted: float = Field(
        default=0.75,
        ge=0.0,
        description="For the Maskix: if >0.5 this gives more weight for the correct reconstruction of corrupted input",
    )
    maskix_architecture: Literal["scMAE", "custom"] = Field(
        default="scMAE",
        description="If you want to customize your maskix architecture \
        via 'n_layers' or 'enc_factor, you need to set this to 'custom'. \
        Otherwise, the architecture for the scMAE from https://doi.org/10.1093/bioinformatics/btae020 is used",
    )
    min_samples_per_split: int = Field(
        default=1, ge=1, description="Minimum number of samples per split"
    )
    anneal_function: Literal[
        "5phase-constant",
        "3phase-linear",
        "3phase-log",
        "logistic-mid",
        "logistic-early",
        "logistic-late",
        "no-annealing",
    ] = Field(
        default="logistic-mid",
        description="Annealing function strategy for VAE loss scheduling",
    )
    pretrain_epochs: int = Field(
        default=0,
        ge=0,
        description="Number of pretraining epochs, can be overwritten in DataInfo to have different number of pretraining epochs for each data modality",
    )

    # Hardware configuration --------------------------------------------------
    device: Literal["cpu", "cuda", "gpu", "tpu", "mps", "auto"] = Field(
        default="auto", description="Device to use"
    )
    n_gpus: int = Field(default=1, ge=1, description="Number of GPUs to use")
    checkpoint_interval: int = Field(
        default=10, ge=1, description="Interval for saving checkpoints"
    )
    float_precision: Literal[
        "transformer-engine",
        "transformer-engine-float16",
        "16-true",
        "16-mixed",
        "bf16-true",
        "bf16-mixed",
        "32-true",
        "64-true",
        "64",
        "32",
        "16",
        "bf16",
    ] = Field(default="32", description="Floating point precision")
    gpu_strategy: Literal[
        "auto",
        "dp",
        "ddp",
        "ddp_spawn",
        "ddp_find_unused_parameters_true",
        "xla",
        "deepspeed",
        "fsdp",
    ] = Field(default="auto", description="GPU parallelization strategy")

    # Data handling configuration ---------------------------------------------
    train_ratio: float = Field(
        default=0.7, ge=0, lt=1, description="Ratio of data for training"
    )
    test_ratio: float = Field(
        default=0.2, ge=0, lt=1, description="Ratio of data for testing"
    )
    valid_ratio: float = Field(
        default=0.1, ge=0, lt=1, description="Ratio of data for validation"
    )

    # General configuration ---------------------------------------------------
    reproducible: bool = Field(
        default=False, description="Whether to ensure reproducibility"
    )
    global_seed: int = Field(default=1, ge=0, description="Global random seed")
    profiling: bool = Field(
        default=False,
        description="Internal Only: if set to true runs torch.profiler on xmodalix trainer",
    )
    profile_logs: str = Field(default="profile")

    ##### VALIDATION ##### -----------------------------------------------------
    ##### ----------------- -----------------------------------------------------

    @model_validator(mode="after")
    def handle_backward_compatibility(self) -> "DefaultConfig":
        """Handle migration of annotation_columns from DataConfig to DefaultConfig.

        After this runs, ``self.annotation_columns`` and
        ``self.data_config.annotation_columns`` hold the same value; the
        DefaultConfig value wins when both were set.
        """
        if self.data_config.annotation_columns is not None:
            if self.annotation_columns is not None:
                warnings.warn(
                    "annotation_columns is set in both DefaultConfig and DataConfig. "
                    "Using the value from DefaultConfig."
                )
                self.data_config.annotation_columns = self.annotation_columns
            else:
                warnings.warn(
                    "annotation_columns in DataConfig is deprecated. "
                    "Please set it directly in DefaultConfig instead."
                )
                self.annotation_columns = self.data_config.annotation_columns
        else:
            self.data_config.annotation_columns = self.annotation_columns
        return self

    @field_validator("data_config")
    @classmethod
    def validate_data_config(cls, data_config: DataConfig) -> DataConfig:
        """Main validation logic for dataset consistency and translation.

        Raises:
            ConfigValidationError: if there is neither a NUMERIC nor an IMG
                dataset, if NUMERIC datasets mix single cell and bulk, if a
                translation setup does not have exactly one 'from' and one
                'to' dataset, or if it translates between single cell and bulk.
        """
        data_info = data_config.data_info

        numeric_datasets = [
            info for info in data_info.values() if info.data_type == "NUMERIC"
        ]
        has_img = any(info.data_type == "IMG" for info in data_info.values())
        if not numeric_datasets and not has_img:
            raise ConfigValidationError("At least one NUMERIC dataset is required.")

        if numeric_datasets:
            is_single_cell = numeric_datasets[0].is_single_cell
            if any(info.is_single_cell != is_single_cell for info in numeric_datasets):
                raise ConfigValidationError(
                    "All numeric datasets must be either single cell or bulk."
                )

        from_infos, to_infos = _split_translation_datasets(data_info)
        # As soon as any dataset declares a direction, both sides must be
        # present exactly once; extra 'from'/'to' entries were previously
        # ignored silently.
        if (from_infos or to_infos) and (len(from_infos) != 1 or len(to_infos) != 1):
            raise ConfigValidationError(
                "Translation requires exactly one 'from' and one 'to' dataset."
            )

        if from_infos and to_infos:
            from_info, to_info = from_infos[0], to_infos[0]
            if (
                from_info.data_type == "NUMERIC"
                and to_info.data_type == "NUMERIC"
                and from_info.is_single_cell != to_info.is_single_cell
            ):
                raise ConfigValidationError(
                    "Cannot translate between single cell and bulk data."
                )

        return data_config

    @model_validator(mode="after")
    def determine_case(self) -> "DefaultConfig":
        """Assign the correct DataCase after model validation.

        Leaves ``data_case`` untouched when ``data_info`` is empty, and warns
        when no case could be determined.
        """
        data_info = self.data_config.data_info

        # Handle empty data_info case
        if not data_info:
            return self

        from_infos, to_infos = _split_translation_datasets(data_info)

        if from_infos and to_infos:
            from_info, to_info = from_infos[0], to_infos[0]
            data_types = {from_info.data_type, to_info.data_type}
            if data_types == {"NUMERIC"}:
                self.data_case = (
                    DataCase.SINGLE_CELL_TO_SINGLE_CELL
                    if from_info.is_single_cell
                    else DataCase.BULK_TO_BULK
                )
            elif data_types == {"IMG"}:
                self.data_case = DataCase.IMG_TO_IMG
            elif "IMG" in data_types:
                numeric_dataset = (
                    from_info if from_info.data_type == "NUMERIC" else to_info
                )
                self.data_case = (
                    DataCase.SINGLE_CELL_TO_IMG
                    if numeric_dataset.is_single_cell
                    else DataCase.IMG_TO_BULK
                )
        else:
            if any(info.data_type == "IMG" for info in data_info.values()):
                self.data_case = DataCase.IMG_TO_IMG

            numeric_datasets = [
                info for info in data_info.values() if info.data_type == "NUMERIC"
            ]
            # Numeric modalities take precedence over images when both exist.
            if numeric_datasets:
                self.data_case = (
                    DataCase.MULTI_SINGLE_CELL
                    if numeric_datasets[0].is_single_cell
                    else DataCase.MULTI_BULK
                )

        if self.data_case is None:
            warnings.warn(message="Could not determine data_case")

        return self

    @field_validator("test_ratio", "valid_ratio")
    @classmethod
    def validate_ratios(cls, v: float, info: ValidationInfo) -> float:
        """Ensure the split ratios validated so far sum to at most 1.0.

        ``info.data`` only contains fields validated before the current one,
        so the check for ``valid_ratio`` sees train + test + valid.
        NOTE(review): field validators do not run on defaults, so a config that
        only overrides ``train_ratio`` is never sum-checked.
        """
        total = (
            sum(
                info.data.get(key, 0)
                for key in ["train_ratio", "test_ratio", "valid_ratio"]
            )
            + v
        )
        if total > 1.0 + _SPLIT_RATIO_TOLERANCE:
            raise ValueError(f"Data split ratios must sum to 1.0 or less (got {total})")
        return v

    # TODO test if other float precisions work with MPS
    @field_validator("float_precision")
    @classmethod
    def validate_float_precision(cls, v, info: ValidationInfo):
        """Validate float precision based on device type."""
        # .get: if 'device' itself failed validation it is absent from
        # info.data; indexing would raise a raw KeyError and hide that error.
        device = info.data.get("device")
        if device == "mps" and v != "32":
            raise ValueError("MPS backend only supports float precision '32'")
        return v

    # gpu strategy needs to be auto for mps # TODO test if other strategies work
    @field_validator("gpu_strategy")
    @classmethod
    def validate_gpu_strategy(cls, v, info: ValidationInfo):
        """Only allow GPU strategy 'auto' on the MPS backend."""
        device = info.data.get("device")
        if device == "mps" and v != "auto":
            raise ValueError("MPS backend only supports GPU strategy 'auto'")
        # Returning the value is required: a field validator's return value
        # becomes the field value.
        return v

    @model_validator(mode="after")
    def validate_k_filter_with_nonzero_var(self) -> "DefaultConfig":
        """Reject a global k_filter when any modality uses NONZEROVAR filtering."""
        if self.k_filter is None:
            return self

        if any(
            info.filtering == "NONZEROVAR"
            for info in self.data_config.data_info.values()
        ):
            raise ValueError(
                "k_filter cannot be combined with DataInfo that has filtering set to 'NONZEROVAR'"
            )

        return self

    #### END VALIDATION #### --------------------------------------------------

    #### READIBILITY #### ------------------------------------------------------
    #### ------------ #### ------------------------------------------------------
    # get_params is inherited unchanged from SchemaPrinterMixin.

    @classmethod
    def print_schema(cls, filter_params: Optional[List[str]] = None) -> None:
        """
        Print a human-readable schema of all config parameters.

        Args:
            filter_params: If provided, only print information for these parameters.
                Any iterable of names is accepted; a single string is treated
                as one parameter name.
        """
        if filter_params:
            # Materialise once so one-shot iterables survive the per-field
            # membership tests in the shared implementation.
            if isinstance(filter_params, str):
                filter_params = [filter_params]
            else:
                filter_params = list(filter_params)
        super().print_schema(filter_params)