# Source code for networks.non_linear_net

import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class NonLinearNet(nn.Module):
    """Non-linear classifier: a multi-layer perceptron with ReLU and dropout.

    With the default arguments the network is::

        Linear(input_dim, 50) -> ReLU -> Dropout(0.2)
        -> Linear(50, 10) -> ReLU -> Dropout(0.2)
        -> Linear(10, out_dim)

    Args:
        input_dim: Size of the last dimension of the input tensor.
        out_dim: Number of output features (e.g. number of classes).
        hidden_dims: Widths of the hidden layers, in order. Each hidden
            linear layer is followed by ReLU and Dropout. An empty sequence
            yields a single ``Linear(input_dim, out_dim)`` layer.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, input_dim, out_dim, hidden_dims=(50, 10), dropout=0.2):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(p=dropout)]
            in_features = width
        # The final projection has no activation, so outputs are raw scores
        # (logits); apply softmax/sigmoid or a logits-based loss downstream.
        layers.append(nn.Linear(in_features, out_dim))
        # Layers are created in the same order as the original hand-written
        # Sequential, so default construction consumes the RNG identically and
        # state_dict keys (fc_all.0 / fc_all.3 / fc_all.6) remain compatible.
        self.fc_all = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape ``(*, input_dim)`` to raw scores of shape ``(*, out_dim)``."""
        return self.fc_all(x)