Coverage for contextualized/dags/trainers.py: 100%

6 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2PyTorch-Lightning trainers used for Contextualized graphs. 

3""" 

4 

5import torch 

6import pytorch_lightning as pl 

7 

8 

class GraphTrainer(pl.Trainer):
    """
    Lightning trainer for the contextualized.graphs lightning_modules
    """

    def predict_params(self, model, dataloader, **kwargs):
        """
        Run prediction with the model over the dataloader and return the
        graph parameters, formatted by the model with model-specific kwargs
        """
        # Concatenate the tensors returned by predict() along dim 0, then
        # hand the NumPy view of the result to the model for formatting.
        batch_outputs = super().predict(model, dataloader)
        stacked = torch.cat(batch_outputs, dim=0)
        return model._format_params(stacked.numpy(), **kwargs)