# Source code for networks.non_linear_net
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class NonLinearNet(nn.Module):
    """
    NonLinear Classifier.

    A fully connected network whose layers are stored in ``self.fc_all``
    (an ``nn.Sequential``). Every hidden layer is ``Linear -> ReLU ->
    Dropout``, and a final ``Linear`` maps to ``out_dim`` features with no
    activation or dropout after it.

    With the default arguments the stack is
    ``input_dim -> 50 -> 10 -> out_dim`` with dropout ``p=0.2``. The child
    indices inside ``fc_all`` are then 0..6, so ``state_dict`` keys
    (``fc_all.0.*``, ``fc_all.3.*``, ``fc_all.6.*``) match the
    fixed-architecture version of this class.
    """

    def __init__(
        self,
        input_dim: int,
        out_dim: int,
        hidden_dims=(50, 10),
        dropout: float = 0.2,
    ):
        """
        Build the layer stack.

        Args:
            input_dim: Number of input features of the first ``nn.Linear``.
            out_dim: Number of output features of the last ``nn.Linear``.
            hidden_dims: Iterable of hidden-layer widths, in order. Defaults
                to ``(50, 10)``. An empty iterable yields a single
                ``nn.Linear(input_dim, out_dim)``.
            dropout: Dropout probability used after every hidden ReLU.
                Defaults to ``0.2``. Out-of-range values are rejected by
                ``nn.Dropout`` itself.
        """
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [
                nn.Linear(in_features, width),
                nn.ReLU(),
                nn.Dropout(p=dropout),
            ]
            in_features = width
        # Output projection: intentionally no activation/dropout after it,
        # matching the original fixed layer stack.
        layers.append(nn.Linear(in_features, out_dim))
        self.fc_all = nn.Sequential(*layers)