# Source code for reflectorch.models.encoders.fno

import torch
import torch.nn as nn
import torch.nn.functional as F

from reflectorch.models.activations import activation_by_name

class SpectralConv1d(nn.Module):
    """1D Fourier layer. It does FFT, linear transform, and inverse FFT.

    Only the lowest ``modes`` Fourier coefficients are multiplied by learned
    complex weights; all higher coefficients are set to zero before
    transforming back.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        modes (int): number of Fourier modes to multiply. If the input has
            fewer than ``modes`` frequency bins (``N // 2 + 1`` for a signal
            of length ``N``), only the available bins are used.
    """

    def __init__(self, in_channels, out_channels, modes):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes  # Number of Fourier modes to multiply, at most floor(N/2) + 1

        # Random complex weights, one (in_channels x out_channels) matrix per retained mode.
        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, modes, dtype=torch.cfloat)
        )

    def compl_mul1d(self, input, weights):
        """Complex multiplication contracting the input-channel axis per mode.

        (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """Apply the spectral convolution along the last dimension of ``x``.

        Args:
            x: real tensor of shape (batch, in_channels, N).

        Returns:
            Real tensor of shape (batch, out_channels, N).
        """
        batchsize = x.shape[0]
        signal_len = x.size(-1)

        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft(x)

        # A short input may provide fewer frequency bins than ``self.modes``;
        # clamp so the slices of the spectrum and of the weights always agree.
        n_modes = min(self.modes, x_ft.size(-1))

        # Multiply relevant Fourier modes; the remaining ones stay zero
        out_ft = torch.zeros(
            batchsize, self.out_channels, signal_len // 2 + 1,
            device=x.device, dtype=torch.cfloat,
        )
        out_ft[:, :, :n_modes] = self.compl_mul1d(
            x_ft[:, :, :n_modes], self.weights1[:, :, :n_modes]
        )

        # Return to physical space (n restores the original, possibly odd, length)
        return torch.fft.irfft(out_ft, n=signal_len)
    

class FnoEncoder(nn.Module):
    """An embedding network based on the Fourier Neural Operator (FNO) architecture

    .. image:: ../docs/fig_reflectometry_embedding_networks.png
        :width: 400px
        :align: center

    Args:
        ch_in (int): number of input channels
        dim_embedding (int): dimension of the output embedding
        modes (int): number of Fourier modes
        width_fno (int): number of channels of the intermediate representations
        n_fno_blocks (int): number of FNO blocks
        activation (str): the activation function
        fusion_self_attention (bool): if ``True`` a fusion layer is used after the FNO blocks
            to produce the final embedding
    """

    def __init__(
        self,
        ch_in: int = 2,
        dim_embedding: int = 128,
        modes: int = 32,
        width_fno: int = 64,
        n_fno_blocks: int = 6,
        activation: str = 'gelu',
        fusion_self_attention: bool = False,
    ):
        super().__init__()

        # Plain configuration attributes.
        self.ch_in = ch_in
        self.dim_embedding = dim_embedding
        self.modes = modes
        self.width_fno = width_fno
        self.n_fno_blocks = n_fno_blocks
        self.fusion_self_attention = fusion_self_attention

        # The activation is looked up by name and instantiated once.
        self.activation = activation_by_name(activation)()

        # Pointwise lifting from the input channels to the FNO width. #(r(q), q)
        self.fc0 = nn.Linear(ch_in, width_fno)

        # Each FNO block pairs a spectral convolution with a kernel-size-1 convolution.
        self.spectral_convs = nn.ModuleList(
            [SpectralConv1d(width_fno, width_fno, modes) for _ in range(n_fno_blocks)]
        )
        self.w_convs = nn.ModuleList(
            [nn.Conv1d(width_fno, width_fno, kernel_size=1) for _ in range(n_fno_blocks)]
        )

        # Projection of the pooled features onto the embedding dimension.
        self.fc_out = nn.Linear(width_fno, dim_embedding)

        if fusion_self_attention:
            self.fusion = FusionSelfAttention(width_fno, 2 * width_fno)

    def forward(self, x):
        """Compute the embedding of ``x``, given as (B, D, S); the result is (B, dim_embedding)."""
        # The linear lifting acts on the last axis: (B, D, S) -> (B, S, D) -> lift -> (B, D, S)
        hidden = self.fc0(x.permute(0, 2, 1)).permute(0, 2, 1)

        # Sum of the spectral and the pointwise branch, followed by the activation.
        for spectral_conv, pointwise_conv in zip(self.spectral_convs, self.w_convs):
            hidden = self.activation(spectral_conv(hidden) + pointwise_conv(hidden))

        # Collapse the sequence axis: plain average, or learned attention weights.
        if not self.fusion_self_attention:
            pooled = hidden.mean(dim=-1)
        else:
            pooled = self.fusion(hidden.permute(0, 2, 1))

        return self.fc_out(pooled)
class FusionSelfAttention(nn.Module):
    """Attention-weighted pooling over the sequence dimension.

    A small network assigns one scalar score to every sequence element; the
    scores are normalized with a softmax over the sequence and used as weights
    for a weighted sum of the elements.

    Args:
        embed_dim (int): size of the last dimension of the input
        hidden_dim (int): width of the hidden layer of the scoring network
        activation: callable returning the activation module placed between
            the two linear layers of the scoring network
    """

    def __init__(
        self,
        embed_dim: int = 64,
        hidden_dim: int = 64,
        activation=nn.Tanh,
    ):
        super().__init__()
        self.fuser = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, 1, bias=False),
        )

    def forward(self, c):
        """Pool ``c`` over its sequence dimension.

        Args:
            c: tensor of shape (batch_size x seq_len x embed_dim)

        Returns:
            Tensor of shape (batch_size x embed_dim).
        """
        scores = self.fuser(c)  # (batch_size x seq_len x 1)
        # softmax is exp(a) / sum(exp(a)) computed in a numerically stable way;
        # the explicit form yields NaN (inf/inf or 0/0) for large-magnitude scores.
        alpha = torch.softmax(scores, dim=1)
        return (alpha * c).sum(dim=1)  # (batch_size x embed_dim)