Building and Training Neural Networks
Define models with a PyTorch-like Module API, pick a loss function and optimizer, and train end-to-end.
Step 1: Define a Module
Subclass nn.Module and implement forward(). Child modules must be registered in self._modules so that parameters() can find them:
import grilly.nn as nn
import numpy as np
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)
        self.fc2 = nn.Linear(4, 1)
        # Register submodules for parameters() discovery
        self._modules['fc1'] = self.fc1
        self._modules['fc2'] = self.fc2

    def forward(self, x):
        x = nn.ReLU()(self.fc1(x))
        return self.fc2(x)
Module attributes set in __init__ are not registered automatically: you must add each child module to self._modules explicitly. This ensures model.parameters() and model.state_dict() see all trainable weights.
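To verify that registration worked, iterate model.parameters() and inspect the state_dict() keys. A minimal sketch; the exact key names and parameter count printed here are assumptions about grilly's internals:

model = SimpleNN()

# Every registered submodule contributes its weights here
print(f"Trainable parameter tensors: {len(list(model.parameters()))}")

# state_dict() keys should include entries for both fc1 and fc2
print(list(model.state_dict().keys()))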
Sequential Shorthand
nn.Sequential handles registration automatically and supports kernel fusion for back-to-back Linear + activation layers:
model = nn.Sequential(
    nn.Linear(2, 4),
    nn.ReLU(),
    nn.Linear(4, 1),
)
x = np.array([[0.0, 1.0]], dtype=np.float32)
output = model(x)
print(f"Output shape: {output.shape}")
Output shape: (1, 1)
nn.Sequential fuses a Linear followed by ReLU, GELU, or SiLU into a single GPU dispatch when the fused SPIR-V shader is available. This halves the number of GPU kernel launches for those patterns.
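Fusion applies per Linear/activation pair, so it helps most when every Linear is immediately followed by its activation. A sketch of a deeper MLP laid out so each pair is eligible for the fused path; the fusion itself is transparent and needs no extra API:

# Each Linear + activation pair can be dispatched as one fused kernel
mlp = nn.Sequential(
    nn.Linear(2, 16),
    nn.GELU(),
    nn.Linear(16, 16),
    nn.SiLU(),
    nn.Linear(16, 1),
)

out = mlp(np.array([[0.5, -0.5]], dtype=np.float32))
print(out.shape)  # (1, 1)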
Step 2: Prepare Data
Data is plain NumPy arrays. Here we set up the XOR problem: four binary input pairs with their expected outputs:
# XOR dataset
X_train = np.array([
    [0.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 1.0],
], dtype=np.float32)

y_train = np.array([
    [0.0],
    [1.0],
    [1.0],
    [0.0],
], dtype=np.float32)
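Keep inputs as float32 with shape (batch, features); a quick check before training catches shape and dtype mistakes early (plain NumPy, no grilly API involved):

# Sanity-check shapes and dtypes before training
assert X_train.shape == (4, 2) and y_train.shape == (4, 1)
assert X_train.dtype == np.float32 and y_train.dtype == np.float32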
Step 3: Loss Function and Optimizer
Choose a loss function and optimizer. The optimizer reads gradients from each parameter's .grad attribute:
import grilly.optim as optim
model = SimpleNN()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
Available Loss Functions
| Loss | Use Case |
|---|---|
| nn.MSELoss() | Regression, continuous targets |
| nn.CrossEntropyLoss() | Multi-class classification (expects class indices) |
| nn.BCELoss() | Binary classification (expects probabilities) |
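Target formats differ by loss. A minimal sketch of the expected shapes, assuming each loss is called as loss_fn(outputs, targets) like MSELoss above; the class-index and probability conventions are taken from the table:

# Multi-class: targets are integer class indices, one per sample
ce = nn.CrossEntropyLoss()
logits = np.random.randn(4, 3).astype(np.float32)  # (batch, num_classes)
classes = np.array([0, 2, 1, 2])                   # (batch,)
ce_loss = ce(logits, classes)

# Binary: outputs and targets are probabilities in [0, 1]
bce = nn.BCELoss()
probs = np.array([[0.9], [0.2], [0.7], [0.4]], dtype=np.float32)
labels = np.array([[1.0], [0.0], [1.0], [0.0]], dtype=np.float32)
bce_loss = bce(probs, labels)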
Available Optimizers
| Optimizer | Constructor |
|---|---|
| optim.SGD | SGD(params, lr, momentum=0, weight_decay=0) |
| optim.Adam | Adam(params, lr=1e-3, betas=(0.9, 0.999)) |
| optim.AdamW | AdamW(params, lr=1e-3, weight_decay=0.01) |
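Swapping optimizers only changes the constructor call; the training loop below stays the same. A sketch using the signatures from the table:

# SGD with momentum and L2 weight decay
sgd = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

# Adam with the default betas
adam = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

# AdamW applies decoupled weight decay
adamw = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)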
Step 4: Training Loop
The training loop follows: forward → loss → loss gradient → zero_grad → backward → step.
for epoch in range(100):
    model.train()

    # Forward pass
    outputs = model(X_train)
    loss = loss_fn(outputs, y_train)

    # Compute gradient of loss w.r.t. outputs
    grad = loss_fn.backward(np.ones_like(loss), outputs, y_train)

    # Backpropagate through the model
    model.zero_grad()
    model.backward(grad)

    # Update weights
    optimizer.step()

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d}, Loss: {float(np.mean(loss)):.4f}")
Epoch  20, Loss: 0.2481
Epoch  40, Loss: 0.2377
Epoch  60, Loss: 0.2201
Epoch  80, Loss: 0.1842
Epoch 100, Loss: 0.1294
In PyTorch, loss.backward() computes all gradients automatically. In grilly, you call loss_fn.backward(...) to get the output gradient, then model.backward(grad) to propagate it through the layers. The explicit flow gives you full control over gradient computation.
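Because the flow is explicit, you can inspect gradients between model.backward(grad) and optimizer.step(). A minimal sketch, assuming each object yielded by model.parameters() exposes its .grad as a NumPy array (the attribute the optimizer reads, per Step 3):

outputs = model(X_train)
loss = loss_fn(outputs, y_train)
grad = loss_fn.backward(np.ones_like(loss), outputs, y_train)

model.zero_grad()
model.backward(grad)

# Gradients are now populated; inspect them before stepping
for i, p in enumerate(model.parameters()):
    print(f"param {i}: mean |grad| = {float(np.mean(np.abs(p.grad))):.6f}")

optimizer.step()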
Step 5: Testing the Model
Switch to eval mode and run inference with no gradient overhead:
model.eval()
test_data = np.array([
    [0.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 1.0],
], dtype=np.float32)
predictions = model(test_data)
print(f"Predictions:\n{predictions}")
Predictions:
[[0.0832]
 [0.7541]
 [0.8129]
 [0.2013]]
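For XOR the targets are binary, so thresholding the raw outputs at 0.5 gives hard class predictions and an accuracy figure (plain NumPy, no extra API needed):

expected = np.array([[0.0], [1.0], [1.0], [0.0]], dtype=np.float32)

# Threshold regression outputs into 0/1 predictions
predicted_labels = (predictions > 0.5).astype(np.float32)
accuracy = float(np.mean(predicted_labels == expected))
print(f"Accuracy: {accuracy:.2f}")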
Saving and Loading
Use state_dict() to serialize all trainable parameters, and load_state_dict() to restore them:
from grilly.utils import save_checkpoint, load_checkpoint
# Save model + optimizer state
save_checkpoint(
    model=model,
    optimizer=optimizer,
    epoch=100,
    loss=0.1294,
    filepath='xor_model.npz',
)
# Load later
state = load_checkpoint('xor_model.npz')
model.load_state_dict(state['model_state'])
print(f"Resumed from epoch {state['epoch']}")
Available Layers
| Layer | Constructor | Notes |
|---|---|---|
| nn.Linear | Linear(in, out, bias=True) | Fully connected, Xavier init |
| nn.Conv2d | Conv2d(in_ch, out_ch, kernel, ...) | See CNNs howto |
| nn.LSTM | LSTM(input, hidden, num_layers) | Long short-term memory |
| nn.GRU | GRU(input, hidden, num_layers) | Gated recurrent unit |
| nn.Embedding | Embedding(vocab, dim) | Lookup table |
| nn.LayerNorm | LayerNorm(features, eps=1e-5) | Layer normalization |
| nn.BatchNorm2d | BatchNorm2d(channels) | Batch normalization |
| nn.Dropout | Dropout(p=0.5) | Inverted dropout |
| nn.ReLU | ReLU() | Rectified linear |
| nn.GELU | GELU() | Gaussian error linear |
| nn.SiLU | SiLU() | Sigmoid linear (Swish) |
| nn.Softmax | Softmax(dim=-1) | Probability normalization |
| nn.Sequential | Sequential(*modules) | Auto-fuses Linear+activation |
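These layers compose through the same Module/Sequential machinery shown above. A sketch of a small classifier mixing several of them; constructor arguments follow the table, and the layer ordering is just one reasonable choice:

classifier = nn.Sequential(
    nn.Linear(16, 32),
    nn.GELU(),
    nn.LayerNorm(32),
    nn.Dropout(p=0.3),
    nn.Linear(32, 4),
    nn.Softmax(dim=-1),
)

x = np.random.randn(8, 16).astype(np.float32)
classifier.eval()        # disables dropout for inference
probs = classifier(x)    # (8, 4); each row sums to ~1 thanks to Softmax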