Coverage for src / autoencodix / modeling / _captum_forward.py: 47%

15 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2import torch.nn as nn 

3from autoencodix.base._base_autoencoder import BaseAutoencoder 

4from autoencodix.utils._model_output import ModelOutput 

5 

6 

class CaptumForward(nn.Module):
    """Expose a single latent feature of an autoencoder as a forward function.

    Calling this module runs the wrapped model on ``x`` and returns only the
    ``dim``-th column of the resulting ``latentspace``. Given the class name,
    it is presumably passed to Captum attribution methods as ``forward_func``.

    Args:
        model: Autoencoder whose call ``model(x=...)`` returns an object with
            a ``latentspace`` tensor.
        dim: Index of the latent feature to expose. It indexes the second
            axis of ``latentspace``, so negative indices count from the end.
    """

    def __init__(self, model: "BaseAutoencoder", dim: int):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule, so
        # ``.to()`` / ``.eval()`` on this wrapper propagate to ``model``.
        self.model = model
        self.dim = dim

    @property
    def device(self) -> torch.device:
        """Return the device the wrapped model currently lives on.

        The device is looked up on every access rather than cached at
        construction. A later ``.to(...)`` on the wrapper or the model would
        otherwise leave inputs on a stale device. The lookup uses the first
        parameter, then the first buffer, and falls back to the CPU for
        models that have neither.
        """
        tensor = next(self.model.parameters(), None)
        if tensor is None:
            tensor = next(self.model.buffers(), None)
        return tensor.device if tensor is not None else torch.device("cpu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the selected latent feature for a batch of inputs.

        Args:
            x: Model input. It is moved to the model's current device before
                the call.

        Returns:
            ``latentspace[:, dim]`` with a singleton axis inserted at position
            1. For a 2-D ``(batch, latent_dim)`` latent space this has shape
            ``(batch, 1)``. The result is always moved to the CPU.
        """
        mp: ModelOutput = self.model(x=x.to(self.device))
        latent = mp.latentspace
        # NOTE(review): assumes latentspace is at least 2-D with features on
        # axis 1 -- confirm for all BaseAutoencoder subclasses.
        output = latent[:, self.dim]
        # unsqueeze (rather than slicing dim:dim+1) keeps negative ``dim``
        # values working. The CPU move is kept from the original behaviour,
        # presumably to match CPU-resident attribution inputs -- confirm.
        return output.unsqueeze(1).to("cpu")