Coverage for contextualized/dags/trainers.py: 100%
6 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
1"""
2PyTorch-Lightning trainers used for Contextualized graphs.
3"""
5import torch
6import pytorch_lightning as pl
class GraphTrainer(pl.Trainer):
    """
    PyTorch-Lightning Trainer for Contextualized graph lightning modules.

    Adds :meth:`predict_params`, which runs ``pl.Trainer.predict`` and hands the
    concatenated predictions to the model's own ``_format_params`` method.
    """

    def predict_params(self, model, dataloader, **kwargs):
        """
        Predict graph parameters with model-specific kwargs.

        Args:
            model: The lightning module to predict with. It must provide a
                ``_format_params(preds, **kwargs)`` method.
            dataloader: Dataloader passed through to ``pl.Trainer.predict``.
            **kwargs: Forwarded unchanged to ``model._format_params``.

        Returns:
            Whatever ``model._format_params`` returns when given the predictions
            as a single NumPy array.
        """
        # Assumes predict returns a sequence of tensors (presumably one per
        # batch) that can be stacked along dim 0 -- TODO confirm for
        # multi-dataloader use.
        batch_preds = super().predict(model, dataloader)
        # Tensor.numpy() only works on CPU tensors that do not require grad.
        # Predictions produced on an accelerator can stay on that device, so
        # move them to host memory first (both calls are no-ops on CPU).
        preds = torch.cat(batch_preds, dim=0).detach().cpu().numpy()
        return model._format_params(preds, **kwargs)