networkvi.model.NETWORKVI
class networkvi.model.NETWORKVI(adata: ~anndata._core.anndata.AnnData, n_genes: int = 0, ensembl_ids_genes: ~numpy.ndarray | None = None, n_regions: int = 0, ensembl_ids_regions: ~numpy.ndarray | None = None, ensembl_ids_proteins: ~numpy.ndarray | None = None, n_patient_covariates: int = 0, modality_weights: ~typing.Literal['equal', 'cell', 'universal', 'moe'] = 'moe', modality_penalty: ~typing.Literal['Jeffreys', 'MMD', 'None'] = 'Jeffreys', n_hidden: int | None = None, n_latent: int | None = None, n_layers_encoder: int = 2, layers_encoder_type: ~typing.Literal['linear', 'go'] = 'go', n_layers_decoder: int = 2, layers_decoder_type: ~typing.Literal['linear'] = 'linear', sparsities: list = [0.9, 0.9], dynamic: bool = True, dynamic_update_rate: int | None = None, dynamic_end_update_rate: str | None = None, gene_interaction_layer_dynamic: bool = False, gene_interaction_layer_pruning_frac: float | None = None, gene_interaction_layer_dynamic_update_rate: int | None = None, gene_interaction_layer_dynamic_end_update_rate: int | None = None, gene_interaction_layer_dynamic_save_path: str | None = None, keep_activations: bool = False, expression_gene_layer_type: ~typing.Literal['none', 'standard', 'interaction'] = 'interaction', accessibility_gene_layer_type: ~typing.Literal['none', 'standard', 'interaction'] = 'interaction', protein_gene_layer_type: ~typing.Literal['none', 'standard', 'interaction'] = 'interaction', gene_layer_interaction_source: str | None = None, standard_gene_size: int = 4, standard_go_size: int = 6, obo_file: str | None = None, map_ensembl_go: list | ~numpy.ndarray | None = None, dropout_rate: float = 0.1, region_factors: bool = True, gene_likelihood: ~typing.Literal['zinb', 'nb', 'poisson'] = 'zinb', dispersion: ~typing.Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', use_batch_norm: ~typing.Literal['encoder', 'decoder', 'none', 'both'] = 'none', use_layer_norm: ~typing.Literal['encoder', 'decoder', 'none', 'both'] = 'both', latent_distribution: ~typing.Literal['normal', 'ln'] = 'normal', deeply_inject_covariates: bool = False, decoder_deeply_inject_covariates: bool = False, first_layer_inject_covariates: bool = False, last_layer_inject_covariates: bool = False, encode_covariates: bool = False, fully_paired: bool = False, protein_dispersion: ~typing.Literal['protein', 'protein-batch', 'protein-label'] = 'protein', activation_fn: ~torch.nn.modules.module.Module = <class 'torch.nn.modules.activation.ReLU'>, use_mean_mixing: bool = False, use_product_of_experts: bool = False, use_mixture_of_experts: bool = True, **model_kwargs)
Integration of multimodal data employing domain knowledge-driven neural networks [Arnoldt et al., 2025].
NetworkVI performs paired and mosaic integration of multiomic datasets using sparse encoders.
Parameters:
adata – AnnData object that has been registered via setup_anndata().
n_regions – The number of accessibility features (genomic regions).
ensembl_ids_regions – ENSEMBL-IDs of accessibility features (genomic regions).
n_genes – The number of gene expression features (genes).
ensembl_ids_genes – ENSEMBL-IDs of gene expression features (genes).
ensembl_ids_proteins – ENSEMBL-IDs of surface proteins (proteins).
n_patient_covariates – The number of patients.
modality_weights – Weighting scheme across modalities. One of the following:
* "equal": Equal weight for each modality.
* "universal": Learn weights across modalities, w_m.
* "cell": Learn weights across modalities and cells, w_{m,c}.
* "moe": Learn weights with a gating network for mixture of experts (MoE).
modality_penalty – Training penalty across modalities. One of the following:
* "Jeffreys": Jeffreys penalty to align modalities.
* "MMD": MMD penalty to align modalities.
* "None": No penalty.
n_hidden – Number of nodes per hidden layer. If None, defaults to the square root of the number of regions.
n_latent – Dimensionality of the latent space. If None, defaults to the square root of n_hidden.
n_layers_encoder – Number of hidden layers used for encoder NNs.
layers_encoder_type – Type of hidden layers used for encoder NNs. One of the following:
* 'linear': Linear layers
* 'go': GO layers
n_layers_decoder – Number of hidden layers used for decoder NNs.
layers_decoder_type – Type of hidden layers used for decoder NNs. One of the following:
* 'linear': Linear layers
* 'go': GO layers
expression_gene_layer_type – Type of expression gene layer. One of the following:
* 'none': No gene layer
* 'standard': Standard gene layer
* 'interaction': Interaction gene layer
accessibility_gene_layer_type – Type of accessibility gene layer. One of the following:
* 'none': No gene layer
* 'standard': Standard gene layer
* 'interaction': Interaction gene layer
protein_gene_layer_type – Type of protein gene layer. One of the following:
* 'none': No gene layer
* 'standard': Standard gene layer
* 'interaction': Interaction gene layer
gene_layer_interaction_source – Gene layer interaction source. One of the following:
* 'pp': Protein-protein interactions
* 'tf': Transcription factors
* 'tad': Topologically associated domains
standard_gene_size – Standard size of gene nodes in gene layers.
standard_go_size – Standard size of GO nodes in GO Layers.
obo_file – Path to the GO .obo file.
map_ensembl_go – List of .gaf files mapping Ensembl IDs to GO terms.
keep_activations – Whether to keep activations in the fully connected encoder layers.
use_mean_mixing – Whether to perform mean modality mixing.
use_product_of_experts – Whether to perform modality mixing with a product of experts (PoE).
use_mixture_of_experts – Whether to perform modality mixing with a mixture of experts (MoE).
dropout_rate – Dropout rate for neural networks.
model_depth – Model sequencing depth / library size.
region_factors – Include region-specific factors in the model.
dispersion – One of the following:
* 'gene': The dispersion parameter of the NB is constant per gene across cells.
* 'gene-batch': The dispersion can differ between batches.
* 'gene-label': The dispersion can differ between labels.
protein_dispersion – One of the following:
* 'protein': The protein dispersion parameter is constant per protein across cells.
* 'protein-batch': The protein dispersion can differ between batches (not tested).
* 'protein-label': The protein dispersion can differ between labels (not tested).
latent_distribution – One of the following:
* 'normal': Normal distribution
* 'ln': Logistic normal distribution (Normal(0, I) transformed by softmax)
deeply_inject_covariates – Whether to deeply inject covariates into all layers of the encoder. If False, covariates will only be included in the input layer.
first_layer_inject_covariates – Whether to deeply inject covariates into all layers of the decoder. If False, covariates will only be included in the input layer.
last_layer_inject_covariates – Whether to inject covariates into all layers of the decoder. If False, covariates will only be included in the input layer.
fully_paired – Allows simplification of the model if the data is fully paired. Currently ignored.
**model_kwargs – Keyword args for NETWORKVAE
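The modality_weights, use_mean_mixing, use_product_of_experts and use_mixture_of_experts options all govern how the per-modality posteriors are combined into a single latent distribution. The NumPy sketch below illustrates the three textbook combination rules for diagonal Gaussian posteriors. It is not NETWORKVI's implementation: the array layout (modalities × cells × latent dims), the function names, and the softmax gate standing in for the MoE gating network are assumptions made for illustration.

```python
import numpy as np


def mean_mixing(mus, variances, weights):
    """Weighted average of per-modality Gaussian parameters.

    mus, variances: arrays of shape (M, C, D) = (modalities, cells, latent dims).
    weights: shape (M,) for one weight per modality (as with "equal"/"universal")
             or (M, C) for per-cell weights (as with "cell").
    """
    w = np.asarray(weights, dtype=float)
    w = w.reshape(w.shape + (1,) * (mus.ndim - w.ndim))  # broadcast against (M, C, D)
    w = w / w.sum(axis=0, keepdims=True)  # normalise over modalities
    return (w * mus).sum(axis=0), (w * variances).sum(axis=0)


def product_of_experts(mus, variances):
    """Product of Gaussian experts: precisions add, means are precision-weighted."""
    precision = 1.0 / variances
    var = 1.0 / precision.sum(axis=0)
    mu = var * (mus * precision).sum(axis=0)
    return mu, var


def mixture_of_experts_sample(mus, variances, gate_logits, rng):
    """Draw one latent sample per cell from a gated mixture of modality posteriors.

    gate_logits: shape (M, C); a hypothetical stand-in for a gating network's output.
    """
    probs = np.exp(gate_logits - gate_logits.max(axis=0, keepdims=True))
    probs /= probs.sum(axis=0, keepdims=True)  # softmax over modalities
    n_mod, n_cells, n_dim = mus.shape
    # Pick one expert (modality) per cell according to the gate probabilities.
    chosen = np.array([rng.choice(n_mod, p=probs[:, c]) for c in range(n_cells)])
    cells = np.arange(n_cells)
    eps = rng.standard_normal((n_cells, n_dim))
    return mus[chosen, cells] + np.sqrt(variances[chosen, cells]) * eps


mus = np.array([[[0.0, 2.0]], [[2.0, 4.0]]])  # 2 modalities, 1 cell, 2 latent dims
variances = np.ones_like(mus)
print(mean_mixing(mus, variances, [1.0, 1.0])[0])  # [[1. 3.]]
mu, var = product_of_experts(mus, variances)
print(mu, var)  # [[1. 3.]] [[0.5 0.5]]
```

With equal variances, mean mixing and PoE agree on the mean, but PoE halves the variance because both experts contribute precision; MoE instead keeps each modality's posterior intact and mixes them stochastically.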
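The two named options for modality_penalty have standard forms. As a hedged illustration rather than the package's code, the Jeffreys divergence between two diagonal Gaussian posteriors is KL(p||q) + KL(q||p), in which the log-variance terms cancel, and a biased Gaussian-kernel estimate of squared MMD compares two sets of latent samples. How NETWORKVI pairs modalities, picks the kernel bandwidth, or weights the penalty in the loss is not stated here, so those choices below are assumptions.

```python
import numpy as np


def jeffreys_diag_gaussian(mu_p, var_p, mu_q, var_q):
    """Jeffreys divergence KL(p||q) + KL(q||p) for diagonal Gaussians, summed over the last axis."""
    sq = (mu_p - mu_q) ** 2
    # The log(var_q / var_p) terms of the two KL directions have opposite signs and cancel.
    return 0.5 * np.sum((var_p + sq) / var_q + (var_q + sq) / var_p - 2.0, axis=-1)


def gaussian_mmd2(x, y, bandwidth=1.0):
    """Biased estimate of squared MMD between samples x (n, d) and y (m, d) with an RBF kernel."""
    def kernel(a, b):
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
        return np.exp(-d2 / (2.0 * bandwidth ** 2))

    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()


# One cell, 1-D latent: N(0, 1) vs N(1, 1) gives KL = 0.5 in each direction.
one = np.ones((1, 1))
print(jeffreys_diag_gaussian(np.zeros((1, 1)), one, one, one))  # [1.]
```

The Jeffreys penalty needs only the posterior parameters of two modalities for the same cell, whereas MMD works on samples and does not require the two sets to be paired.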
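The n_hidden and n_latent defaults chain together: the hidden width follows from the number of regions, and the latent size follows from the hidden width. A minimal sketch of that arithmetic follows. default_sizes is a hypothetical helper, and truncating the square roots to integers is an assumption, since the documentation does not state the rounding rule.

```python
import math


def default_sizes(n_regions, n_hidden=None, n_latent=None):
    """Hypothetical helper mirroring the documented defaults.

    n_hidden defaults to sqrt(n_regions); n_latent defaults to sqrt(n_hidden).
    Truncation to int is an assumption.
    """
    if n_hidden is None:
        n_hidden = int(math.sqrt(n_regions))
    if n_latent is None:
        n_latent = int(math.sqrt(n_hidden))
    return n_hidden, n_latent


print(default_sizes(100_000))  # (316, 17)
```

Because n_latent is derived from n_hidden, overriding only n_hidden also changes the default latent size.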
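With layers_encoder_type='go', the encoder uses GO layers built from the obo_file and map_ensembl_go inputs, and standard_gene_size and standard_go_size set the size of gene and GO nodes. One common way to realise such knowledge-driven sparsity is a linear layer whose weight matrix is multiplied by a fixed binary mask. The NumPy sketch below builds such a mask from a gene-to-GO annotation dictionary. It assumes that each gene is expanded to standard_gene_size nodes and each GO term to standard_go_size nodes, connected block-wise wherever a gene is annotated to a term. The function names and toy identifiers are hypothetical, and NETWORKVI's actual layer construction may differ.

```python
import numpy as np


def go_layer_mask(gene_ids, go_terms, annotations, gene_size=4, go_size=6):
    """Binary mask of shape (len(go_terms) * go_size, len(gene_ids) * gene_size).

    annotations: dict mapping gene ID -> iterable of GO term IDs
    (hypothetical; in practice this would come from parsed .gaf files).
    """
    go_index = {term: i for i, term in enumerate(go_terms)}
    block = np.zeros((len(go_terms), len(gene_ids)), dtype=bool)
    for j, gene in enumerate(gene_ids):
        for term in annotations.get(gene, ()):
            if term in go_index:
                block[go_index[term], j] = True
    # Expand every (GO term, gene) entry into a go_size x gene_size block of connections.
    return block.repeat(go_size, axis=0).repeat(gene_size, axis=1)


def masked_linear_relu(x, weight, mask, bias):
    """Linear layer restricted to the masked connections, followed by ReLU."""
    return np.maximum(x @ (weight * mask).T + bias, 0.0)


genes = ["gene_a", "gene_b", "gene_c"]  # placeholder IDs, not real Ensembl IDs
terms = ["GO:term1", "GO:term2"]  # placeholder GO terms
mask = go_layer_mask(genes, terms, {"gene_a": ["GO:term1"], "gene_c": ["GO:term1", "GO:term2"]})
print(mask.shape, mask.mean())  # (12, 12) 0.5
```

Because masked-out weights never contribute, each GO node's activation depends only on the genes annotated to that term, which is what makes such layers interpretable.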
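For latent_distribution='ln', the parameter list describes a logistic normal: a Gaussian sample passed through a softmax. A minimal NumPy sketch of that transformation follows (the function name is hypothetical). It shows that each transformed latent vector is positive and sums to one.

```python
import numpy as np


def sample_logistic_normal(mu, var, rng):
    """Sample z ~ Normal(mu, var) and map it onto the simplex with a softmax."""
    z = mu + np.sqrt(var) * rng.standard_normal(mu.shape)
    e = np.exp(z - z.max(axis=-1, keepdims=True))  # subtract the max for numerical stability
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
latent = sample_logistic_normal(np.zeros((4, 5)), np.ones((4, 5)), rng)  # Normal(0, I) draws
print(latent.sum(axis=-1))  # each row sums to 1 (up to floating-point rounding)
```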
Examples
>>> import anndata
>>> import networkvi
>>> adata_rna = anndata.read_h5ad(path_to_rna_anndata)
>>> adata_atac = networkvi.data.read_10x_atac(path_to_atac_anndata)
>>> adata_multi = networkvi.data.read_10x_multiome(path_to_multiomic_anndata)
>>> adata_svi = networkvi.data.organize_multiome_anndatas(adata_multi, adata_rna, adata_atac)
>>> networkvi.model.NETWORKVI.setup_anndata(
...     adata_svi,
...     ensembl_ids_genes=adata_rna.var["ensembl_ids"],
...     ensembl_ids_regions=adata_atac.var["ensembl_ids"],
...     batch_key="modality",
... )
>>> vae = networkvi.model.NETWORKVI(adata_svi)
>>> vae.train()
Notes
- The model assumes that the features are organized so that all expression features are consecutive, followed by all accessibility features. For example, if the data has 100 genes and 250 genomic regions, the model assumes that the first 100 features are genes, and the next 250 are the regions.
- The main batch annotation, specified in setup_anndata, should correspond to the modality each cell originated from. This allows the model to focus mixing efforts, using an adversarial component, on mixing the modalities. Other covariates can be specified using the categorical_covariate_keys argument.
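The feature-ordering assumption in the first note can be made concrete with a small NumPy sketch: with 100 genes followed by 250 regions, the expression block is columns 0 to 99 and the accessibility block is columns 100 to 349. split_features is a hypothetical helper for illustration, not part of the networkvi API, and it assumes a matrix containing only these two modalities.

```python
import numpy as np


def split_features(x, n_genes, n_regions):
    """Split a cells x features matrix ordered as [genes | regions]."""
    if x.shape[1] != n_genes + n_regions:
        raise ValueError(f"expected {n_genes + n_regions} features, got {x.shape[1]}")
    return x[:, :n_genes], x[:, n_genes:]


x = np.arange(3 * 350).reshape(3, 350)  # 3 cells, 100 genes followed by 250 regions
genes, regions = split_features(x, n_genes=100, n_regions=250)
print(genes.shape, regions.shape)  # (3, 100) (3, 250)
```

If the columns were interleaved instead, this positional split would silently mix genes and regions, which is why the ordering matters.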
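The second note asks for a main batch annotation that records each cell's modality of origin. A pure-Python sketch of deriving such a label from which measurements a cell has follows. The label strings "paired", "expression" and "accessibility" and the helper name are illustrative assumptions; use whatever values the organized AnnData actually contains in the column passed as batch_key.

```python
def modality_label(has_expression: bool, has_accessibility: bool) -> str:
    """Return an illustrative modality label for a single cell."""
    if has_expression and has_accessibility:
        return "paired"
    if has_expression:
        return "expression"
    if has_accessibility:
        return "accessibility"
    raise ValueError("cell has no measured modality")


cells = [(True, True), (True, False), (False, True)]
labels = [modality_label(rna, atac) for rna, atac in cells]
print(labels)  # ['paired', 'expression', 'accessibility']
```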
__init__(adata: ~anndata._core.anndata.AnnData, n_genes: int = 0, ensembl_ids_genes: ~numpy.ndarray | None = None, n_regions: int = 0, ensembl_ids_regions: ~numpy.ndarray | None = None, ensembl_ids_proteins: ~numpy.ndarray | None = None, n_patient_covariates: int = 0, modality_weights: ~typing.Literal['equal', 'cell', 'universal', 'moe'] = 'moe', modality_penalty: ~typing.Literal['Jeffreys', 'MMD', 'None'] = 'Jeffreys', n_hidden: int | None = None, n_latent: int | None = None, n_layers_encoder: int = 2, layers_encoder_type: ~typing.Literal['linear', 'go'] = 'go', n_layers_decoder: int = 2, layers_decoder_type: ~typing.Literal['linear'] = 'linear', sparsities: list = [0.9, 0.9], dynamic: bool = True, dynamic_update_rate: int | None = None, dynamic_end_update_rate: str | None = None, gene_interaction_layer_dynamic: bool = False, gene_interaction_layer_pruning_frac: float | None = None, gene_interaction_layer_dynamic_update_rate: int | None = None, gene_interaction_layer_dynamic_end_update_rate: int | None = None, gene_interaction_layer_dynamic_save_path: str | None = None, keep_activations: bool = False, expression_gene_layer_type: ~typing.Literal['none', 'standard', 'interaction'] = 'interaction', accessibility_gene_layer_type: ~typing.Literal['none', 'standard', 'interaction'] = 'interaction', protein_gene_layer_type: ~typing.Literal['none', 'standard', 'interaction'] = 'interaction', gene_layer_interaction_source: str | None = None, standard_gene_size: int = 4, standard_go_size: int = 6, obo_file: str | None = None, map_ensembl_go: list | ~numpy.ndarray | None = None, dropout_rate: float = 0.1, region_factors: bool = True, gene_likelihood: ~typing.Literal['zinb', 'nb', 'poisson'] = 'zinb', dispersion: ~typing.Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', use_batch_norm: ~typing.Literal['encoder', 'decoder', 'none', 'both'] = 'none', use_layer_norm: ~typing.Literal['encoder', 'decoder', 'none', 'both'] = 'both', latent_distribution: ~typing.Literal['normal', 'ln'] = 'normal', deeply_inject_covariates: bool = False, decoder_deeply_inject_covariates: bool = False, first_layer_inject_covariates: bool = False, last_layer_inject_covariates: bool = False, encode_covariates: bool = False, fully_paired: bool = False, protein_dispersion: ~typing.Literal['protein', 'protein-batch', 'protein-label'] = 'protein', activation_fn: ~torch.nn.modules.module.Module = <class 'torch.nn.modules.activation.ReLU'>, use_mean_mixing: bool = False, use_product_of_experts: bool = False, use_mixture_of_experts: bool = True, **model_kwargs)
Methods
__init__(adata[, n_genes, ...])
calculate_covariate_attention([...]) – Return attention rollout.
convert_legacy_save(dir_path, output_dir_path) – Converts a legacy saved model (<v0.15.0) to the updated save format.
deregister_manager([adata]) – Deregisters the AnnDataManager instance associated with adata.
differential_accessibility([adata, groupby, ...]) – A unified method for differential accessibility analysis.
differential_expression([adata, groupby, ...]) – A unified method for differential expression analysis.
get_accessibility_estimates([adata, ...]) – Impute the full accessibility matrix.
get_anndata_manager(adata[, required]) – Retrieves the AnnDataManager for a given AnnData object.
get_elbo([adata, indices, batch_size, ...]) – Compute the evidence lower bound (ELBO) on the data.
get_from_registry(adata, registry_key) – Returns the object in AnnData associated with the key in the data registry.
get_go_gene_importances([labels_column, ...]) – Return gene and GO importances.
get_latent_representation([adata, modality, ...]) – Return the latent representation for each cell.
get_library_size_factors([adata, indices, ...]) – Return library size factors.
get_marginal_ll([adata, indices, ...]) – Compute the marginal log-likelihood of the data.
get_normalized_expression([adata, indices, ...]) – Returns the normalized (decoded) gene expression.
get_perturbation_go_gene_importances(...[, ...])
get_protein_foreground_probability([adata, ...]) – Returns the foreground probability for proteins.
get_reconstruction_error([adata, indices, ...]) – Compute the reconstruction error on the data.
get_region_factors() – Return region-specific factors.
load(dir_path[, adata, accelerator, device, ...]) – Instantiate a model from the saved output.
load_query_data(adata, reference_model[, ...]) – Online update of a reference model with the scArches algorithm [Lotfollahi et al., 2022].
load_registry(dir_path[, prefix]) – Return the full registry saved with the model.
prepare_query_anndata(adata, reference_model) – Prepare data for query integration.
register_manager(adata_manager) – Registers an AnnDataManager instance with this model class.
save(dir_path[, prefix, overwrite, ...]) – Save the state of the model.
setup_anndata(adata[, layer, batch_key, ...]) – Sets up the AnnData object for this model.
to_device(device) – Move model to device.
train([max_epochs, lr, accelerator, ...]) – Trains the model using amortized variational inference.
view_anndata_setup([adata, ...]) – Print summary of the setup for the initial AnnData or a given AnnData object.
view_setup_args(dir_path[, prefix]) – Print args used to set up a saved model.
Attributes
adata – Data attached to model instance.
adata_manager – Manager instance associated with self.adata.
device – The current device that the module's params are on.
history – Returns computed metrics during training.
is_trained – Whether the model has been trained.
summary_string – Summary string of the model.
test_indices – Observations that are in the test set.
train_indices – Observations that are in the train set.
validation_indices – Observations that are in the validation set.