Coverage for src / autoencodix / modeling / _captum_forward.py: 47%
15 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2import torch.nn as nn
3from autoencodix.base._base_autoencoder import BaseAutoencoder
4from autoencodix.utils._model_output import ModelOutput
class CaptumForward(nn.Module):
    """Expose a single latent dimension of an autoencoder as the model output.

    Calling the wrapper runs the wrapped model's full forward pass and returns
    only column ``dim`` of its latent space. Given the module name
    (``_captum_forward``), this is presumably meant to be passed as the forward
    function of a Captum attribution method, so that attributions are computed
    with respect to one latent unit.

    Args:
        model: Autoencoder whose ``forward(x=...)`` returns an object with a
            ``latentspace`` attribute (a ``ModelOutput``).
        dim: Index of the latent dimension to return.
    """

    def __init__(self, model: "BaseAutoencoder", dim: int):
        # nn.Module.__init__ must run before a submodule is assigned;
        # otherwise PyTorch raises on the assignment below.
        super().__init__()
        # Registered as a submodule, so wrapper.to(...) also moves the model.
        self.model = model
        self.dim = dim

    @property
    def device(self) -> torch.device:
        """Device of the wrapped model's parameters, looked up on each access.

        This is resolved lazily rather than cached in ``__init__``. A cached
        value goes stale once the wrapper or the model is moved with
        ``.to(...)``, and inputs would then be sent to the wrong device. A
        model without parameters falls back to CPU instead of raising
        ``StopIteration``.
        """
        first_param = next(self.model.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return latent column ``self.dim`` for the input batch ``x``.

        Args:
            x: Model input. It is moved to the model's current device before
                the call.

        Returns:
            ``latentspace[:, dim]`` with a singleton axis inserted at
            position 1, placed on CPU. Assumes ``latentspace`` is
            (batch, latent_dim) (TODO: confirm), in which case the result is
            (batch, 1).
        """
        mp: ModelOutput = self.model(x=x.to(self.device))
        latent = mp.latentspace
        output = latent[:, self.dim]
        # NOTE(review): the result is always moved to CPU, and the reason is
        # not visible here. `.to` is differentiable, so gradients w.r.t. `x`
        # still flow back through this move.
        return output.unsqueeze(1).to("cpu")