# GPU Neural Networks via Vulkan

## Building and Training Neural Networks

Define models with a PyTorch-like Module API, pick a loss function and optimizer, and train end-to-end.

### Step 1: Define a Module

Subclass nn.Module and implement forward(). Child modules must be registered in self._modules so that parameters() can find them:

```python
import grilly.nn as nn
import numpy as np

class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)
        self.fc2 = nn.Linear(4, 1)
        # Register submodules for parameters() discovery
        self._modules['fc1'] = self.fc1
        self._modules['fc2'] = self.fc2

    def forward(self, x):
        x = nn.ReLU()(self.fc1(x))
        return self.fc2(x)
```
> **Note: `_modules` registration.** Unlike PyTorch, grilly does not auto-register Module attributes set in `__init__`. You must add each child module to `self._modules` explicitly; this ensures `model.parameters()` and `model.state_dict()` see all trainable weights.
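To confirm registration worked, count the discovered parameters. A minimal check, assuming `parameters()` yields each registered `Linear` layer's weight and bias:

```python
model = SimpleNN()

# With fc1 and fc2 registered, parameters() should yield both layers'
# weights and biases: two weight matrices plus two bias vectors.
params = list(model.parameters())
print(len(params))  # expected: 4
```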

#### Sequential Shorthand

nn.Sequential handles registration automatically and supports kernel fusion for back-to-back Linear + activation layers:

```python
model = nn.Sequential(
    nn.Linear(2, 4),
    nn.ReLU(),
    nn.Linear(4, 1),
)

x = np.array([[0.0, 1.0]], dtype=np.float32)
output = model(x)
print(f"Output shape: {output.shape}")
Output:
```
Output shape: (1, 1)
```
> **Tip:** `Sequential` auto-fuses a `Linear` followed by `ReLU`, `GELU`, or `SiLU` into a single GPU dispatch when the fused SPIR-V shader is available. This halves the number of GPU kernel launches for those patterns.
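To make the fusion concrete, here is the same computation written out in plain numpy. A fused shader evaluates the whole expression in one kernel instead of dispatching the matmul and the activation separately (an illustration of the arithmetic, not grilly's actual shader code):

```python
import numpy as np

W = np.random.randn(4, 2).astype(np.float32)  # Linear(2, 4) weight
b = np.zeros(4, dtype=np.float32)             # Linear(2, 4) bias
x = np.array([[0.0, 1.0]], dtype=np.float32)

# Unfused: one dispatch for the matmul, a second for the activation.
h = x @ W.T + b
y = np.maximum(h, 0.0)

# Fused: a single dispatch computes max(x @ W.T + b, 0) directly.
y_fused = np.maximum(x @ W.T + b, 0.0)
assert np.allclose(y, y_fused)
```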

### Step 2: Prepare Data

Data is plain numpy. Here we set up the XOR problem: four binary input pairs with their expected outputs.

```python
# XOR dataset
X_train = np.array([
    [0.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 1.0],
], dtype=np.float32)

y_train = np.array([
    [0.0],
    [1.0],
    [1.0],
    [0.0],
], dtype=np.float32)
```

### Step 3: Loss Function and Optimizer

Choose a loss function and optimizer. The optimizer reads gradients from each parameter's .grad attribute:

```python
import grilly.optim as optim

model = SimpleNN()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
```
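For intuition, a plain SGD update is just a loop over parameters. A minimal sketch, assuming each parameter stores its values in a `.data` array alongside the `.grad` attribute mentioned above (`.data` is an illustrative name, not confirmed grilly API):

```python
def sgd_step(params, lr):
    # param <- param - lr * grad: the update optim.SGD applies
    # (before momentum or weight decay are added).
    for p in params:
        if p.grad is not None:
            p.data -= lr * p.grad
```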

#### Available Loss Functions

| Loss | Use Case |
|------|----------|
| `nn.MSELoss()` | Regression, continuous targets |
| `nn.CrossEntropyLoss()` | Multi-class classification (expects class indices) |
| `nn.BCELoss()` | Binary classification (expects probabilities) |
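For example, treating XOR as two-class classification would change only the loss and the target format; per the table, `CrossEntropyLoss` expects integer class indices rather than one-hot vectors (a sketch):

```python
loss_fn = nn.CrossEntropyLoss()
y_classes = np.array([0, 1, 1, 0], dtype=np.int64)  # class indices, not one-hot
```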

#### Available Optimizers

| Optimizer | Constructor |
|-----------|-------------|
| `optim.SGD` | `SGD(params, lr, momentum=0, weight_decay=0)` |
| `optim.Adam` | `Adam(params, lr=1e-3, betas=(0.9, 0.999))` |
| `optim.AdamW` | `AdamW(params, lr=1e-3, weight_decay=0.01)` |
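Swapping optimizers is a one-line change. Using the constructors listed above:

```python
optimizer = optim.Adam(model.parameters(), lr=1e-3)             # adaptive learning rates
# or, with decoupled weight decay:
optimizer = optim.AdamW(model.parameters(), weight_decay=0.01)
```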

### Step 4: Training Loop

Each iteration follows: forward → loss → `loss_fn.backward` → `zero_grad` → `model.backward` → `step`.

```python
for epoch in range(100):
    model.train()

    # Forward pass
    outputs = model(X_train)
    loss = loss_fn(outputs, y_train)

    # Compute gradient of loss w.r.t. outputs
    grad = loss_fn.backward(np.ones_like(loss), outputs, y_train)

    # Backpropagate through the model
    model.zero_grad()
    model.backward(grad)

    # Update weights
    optimizer.step()

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d}, Loss: {float(np.mean(loss)):.4f}")
Output:
```
Epoch  20, Loss: 0.2481
Epoch  40, Loss: 0.2377
Epoch  60, Loss: 0.2201
Epoch  80, Loss: 0.1842
Epoch 100, Loss: 0.1294
```
> **vs PyTorch:** In PyTorch, `loss.backward()` computes all gradients automatically. In grilly, you call `loss_fn.backward(...)` to get the gradient at the output, then `model.backward(grad)` to propagate it through the layers. The explicit flow gives you full control over gradient computation.
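For readers coming from PyTorch, one iteration of the same loop looks like this (standard PyTorch, shown for comparison; requires `torch`):

```python
import torch
import torch.nn as tnn

tmodel = tnn.Sequential(tnn.Linear(2, 4), tnn.ReLU(), tnn.Linear(4, 1))
loss_fn = tnn.MSELoss()
opt = torch.optim.SGD(tmodel.parameters(), lr=0.1)

X = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = torch.tensor([[0.], [1.], [1.], [0.]])

opt.zero_grad()
loss = loss_fn(tmodel(X), y)
loss.backward()  # autograd fills every parameter's .grad in one call
opt.step()
```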

### Step 5: Testing the Model

Switch to eval mode and run inference with no gradient overhead:

```python
model.eval()

test_data = np.array([
    [0.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 1.0],
], dtype=np.float32)

predictions = model(test_data)
print(f"Predictions:\n{predictions}")
Output:
```
Predictions:
[[0.0832]
 [0.7541]
 [0.8129]
 [0.2013]]
```
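The raw outputs are regression values. To read them as XOR labels, threshold at 0.5 (plain numpy post-processing, not part of the grilly API):

```python
labels = (predictions > 0.5).astype(np.int32)
print(labels.ravel())  # expected: [0 1 1 0]
```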

### Saving and Loading

Use `state_dict()` to serialize all trainable parameters and `load_state_dict()` to restore them. The checkpoint helpers below bundle the model state with optimizer state and training metadata:

```python
from grilly.utils import save_checkpoint, load_checkpoint

# Save model + optimizer state
save_checkpoint(
    model=model,
    optimizer=optimizer,
    epoch=100,
    loss=0.1294,
    filepath='xor_model.npz',
)

# Load later
state = load_checkpoint('xor_model.npz')
model.load_state_dict(state['model_state'])
print(f"Resumed from epoch {state['epoch']}")

### Available Layers

| Layer | Constructor | Notes |
|-------|-------------|-------|
| `nn.Linear` | `Linear(in, out, bias=True)` | Fully connected, Xavier init |
| `nn.Conv2d` | `Conv2d(in_ch, out_ch, kernel, ...)` | See CNNs howto |
| `nn.LSTM` | `LSTM(input, hidden, num_layers)` | Long short-term memory |
| `nn.GRU` | `GRU(input, hidden, num_layers)` | Gated recurrent unit |
| `nn.Embedding` | `Embedding(vocab, dim)` | Lookup table |
| `nn.LayerNorm` | `LayerNorm(features, eps=1e-5)` | Layer normalization |
| `nn.BatchNorm2d` | `BatchNorm2d(channels)` | Batch normalization |
| `nn.Dropout` | `Dropout(p=0.5)` | Inverted dropout |
| `nn.ReLU` | `ReLU()` | Rectified linear |
| `nn.GELU` | `GELU()` | Gaussian error linear unit |
| `nn.SiLU` | `SiLU()` | Sigmoid linear unit (Swish) |
| `nn.Softmax` | `Softmax(dim=-1)` | Probability normalization |
| `nn.Sequential` | `Sequential(*modules)` | Auto-fuses Linear+activation |
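Any of these compose inside `Sequential`. For example, a small classifier head built from the constructors in the table (layer sizes are arbitrary):

```python
head = nn.Sequential(
    nn.Linear(16, 32),
    nn.GELU(),            # eligible for Linear+GELU fusion
    nn.Dropout(p=0.1),
    nn.Linear(32, 4),
    nn.Softmax(dim=-1),   # normalize outputs to class probabilities
)
```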