GPU Neural Networks via Vulkan

Convolutional Neural Networks

Build image classifiers using Conv2d, pooling, batch normalization, and GPU-accelerated training.

Input Format: NCHW

Grilly uses the same NCHW memory layout as PyTorch:

Tensor Shape
(N, C, H, W)
N = batch  ·  C = channels  ·  H = height  ·  W = width
python
import numpy as np

# 8 RGB images of size 32x32
x = np.random.randn(8, 3, 32, 32).astype(np.float32)
print(f"Input: {x.shape}")  # (8, 3, 32, 32)
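If your images come from a loader that produces channels-last arrays (NHWC, as Pillow and most image libraries do), they need a transpose before use. This is a sketch of that conversion in plain NumPy; the NHWC source layout is an assumption about your data pipeline, not part of the grilly API:

```python
import numpy as np

# Suppose 8 RGB images arrive channels-last: (N, H, W, C)
nhwc = np.random.randn(8, 32, 32, 3).astype(np.float32)

# Move the channel axis to position 1: (N, H, W, C) -> (N, C, H, W)
nchw = np.transpose(nhwc, (0, 3, 1, 2))
print(nchw.shape)  # (8, 3, 32, 32)

# transpose returns a view; make the buffer contiguous before GPU upload
nchw = np.ascontiguousarray(nchw)
```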

Conv2d

A 2D convolution layer slides a learnable kernel across the spatial dimensions. The key parameters are:

python
import grilly.nn as nn

conv = nn.Conv2d(
    in_channels=3,       # RGB input
    out_channels=32,     # number of filters
    kernel_size=3,       # 3x3 kernel
    stride=1,
    padding=1,           # 'same' padding: output size = input size
)

out = conv(x)
print(f"Conv output: {out.shape}")
Output
Conv output: (8, 32, 32, 32)
Tip The SPIR-V backend uses an im2col + GEMM dispatch path for standard convolutions, and falls back to a direct dispatch for grouped or dilated convolutions. Both paths are GPU-accelerated.
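The output spatial size follows standard convolution arithmetic. A small helper (illustrative only, not part of grilly) makes explicit why kernel_size=3 with padding=1 preserves the input size:

```python
def conv_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial dimension (floor division, as in PyTorch)."""
    return (size + 2 * padding - kernel) // stride + 1

# kernel_size=3, stride=1, padding=1: 'same' padding preserves the size
print(conv_out_size(32, kernel=3, stride=1, padding=1))  # 32

# stride=2 halves it
print(conv_out_size(32, kernel=3, stride=2, padding=1))  # 16
```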

Pooling and BatchNorm

Pooling reduces spatial dimensions. Batch normalization stabilizes training by normalizing each channel's activations with statistics computed over the batch:

python
# MaxPool2d: halves spatial size
pool = nn.MaxPool2d(kernel_size=2, stride=2)
pooled = pool(out)
print(f"After pooling: {pooled.shape}")

# BatchNorm2d: normalize across batch for each channel
bn = nn.BatchNorm2d(num_features=32)
normed = bn(pooled)
print(f"After batchnorm: {normed.shape}")

# AvgPool2d: average pooling alternative
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

# AdaptiveAvgPool2d: reduces any spatial size to target size
global_pool = nn.AdaptiveAvgPool2d((1, 1))
print(f"Global pool: {global_pool(normed).shape}")
Output
After pooling: (8, 32, 16, 16)
After batchnorm: (8, 32, 16, 16)
Global pool: (8, 32, 1, 1)
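What BatchNorm2d computes in training mode can be checked against a NumPy reference: per channel, subtract the batch mean and divide by the batch standard deviation. This sketch omits the learnable scale and shift (gamma=1, beta=0), and the eps value here is an assumption:

```python
import numpy as np

def batchnorm2d_ref(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize each channel over the (N, H, W) axes, like BatchNorm2d in training mode."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)  # shape (1, C, 1, 1)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(8, 32, 16, 16).astype(np.float32)
y = batchnorm2d_ref(x)
# After normalization each channel has ~zero mean and ~unit variance
print(y.mean(axis=(0, 2, 3))[:3], y.std(axis=(0, 2, 3))[:3])
```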

Building a CNN Classifier

A typical CNN alternates convolution + activation + pooling blocks, then flattens and feeds into fully connected layers:

python
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Block 1: 3 -> 32 channels, 32x32 -> 16x16
        self.conv1   = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1     = nn.BatchNorm2d(32)
        # Block 2: 32 -> 64 channels, 16x16 -> 8x8
        self.conv2   = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2     = nn.BatchNorm2d(64)
        # Shared pooling (stateless)
        self.pool    = nn.MaxPool2d(2, 2)
        # Classifier head
        self.flatten = nn.Flatten()
        self.fc1     = nn.Linear(64 * 8 * 8, 256)
        self.fc2     = nn.Linear(256, num_classes)

        # Register all submodules
        for name in ['conv1','bn1','conv2','bn2','pool',
                     'flatten','fc1','fc2']:
            self._modules[name] = getattr(self, name)

    def forward(self, x):
        # Block 1
        x = self.pool(nn.ReLU()(self.bn1(self.conv1(x))))
        # Block 2
        x = self.pool(nn.ReLU()(self.bn2(self.conv2(x))))
        # Classify
        x = self.flatten(x)
        x = nn.ReLU()(self.fc1(x))
        return self.fc2(x)

model = SimpleCNN(num_classes=10)
print(f"Parameters: {sum(p.size for p in model.parameters())}")
Output
Parameters: 1083178
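The fc1 input size (64 * 8 * 8) comes from tracing shapes through the network: each conv preserves the spatial size (padding=1) and each pool halves it. A quick arithmetic check in plain Python:

```python
def trace_simple_cnn(h: int = 32, w: int = 32) -> int:
    """Trace spatial sizes through SimpleCNN and return the flattened feature count."""
    # Block 1: conv with padding=1 keeps 32x32, then 2x2 max pool halves it
    h, w = h // 2, w // 2   # -> 16x16
    # Block 2: same pattern
    h, w = h // 2, w // 2   # -> 8x8
    return 64 * h * w       # 64 channels after conv2

print(trace_simple_cnn())  # 4096 == 64 * 8 * 8
```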
vs PyTorch The architecture is identical to a PyTorch CNN. The only differences are the _modules registration loop and the nn.ReLU()(x) syntax (which creates a fresh activation instance on each call). Both are harmless: ReLU is stateless and lightweight.

Training the CNN

python
import grilly.optim as optim

# Synthetic CIFAR-like data: (N, 3, 32, 32)
X = np.random.randn(256, 3, 32, 32).astype(np.float32)
y = np.random.randint(0, 10, 256).astype(np.int64)

model = SimpleCNN()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    output = model(X)
    loss = loss_fn(output, y)

    grad = loss_fn.backward(np.ones_like(loss), output, y)
    model.zero_grad()
    model.backward(grad)
    optimizer.step()

    print(f"Epoch {epoch+1:2d}: loss={float(np.mean(loss)):.4f}")
Output
Epoch  1: loss=2.3104
Epoch  2: loss=2.2847
Epoch  3: loss=2.2521
Epoch  4: loss=2.2108
Epoch  5: loss=2.1594
Epoch  6: loss=2.0978
Epoch  7: loss=2.0253
Epoch  8: loss=1.9418
Epoch  9: loss=1.8472
Epoch 10: loss=1.7421
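The loop above takes one full-batch step per epoch; for real datasets you would iterate over shuffled mini-batches instead. A library-agnostic sketch of just the batching logic (the batch size of 32 is an arbitrary choice):

```python
import numpy as np

def iter_minibatches(X, y, batch_size=32, seed=0):
    """Yield shuffled (X_batch, y_batch) pairs covering the dataset once."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(X))
    for start in range(0, len(X), batch_size):
        idx = order[start:start + batch_size]
        yield X[idx], y[idx]

X = np.random.randn(256, 3, 32, 32).astype(np.float32)
y = np.random.randint(0, 10, 256).astype(np.int64)
for xb, yb in iter_minibatches(X, y):
    pass  # run the forward/backward/step from the loop above on (xb, yb)
print(xb.shape)  # (32, 3, 32, 32): 256 samples split into 8 batches of 32
```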

Conv1d

For 1D signals (audio, time series), Conv1d is available with the same API. Input shape is (N, C, L):

python
# 1D convolution: 16 channels, length 100
conv1d = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=5, padding=2)
signal = np.random.randn(4, 16, 100).astype(np.float32)
out = conv1d(signal)
print(f"Conv1d output: {out.shape}")  # (4, 32, 100)
Tip Under the hood, Conv1d reshapes (N, C, L) to (N, C, 1, L) and delegates to the Conv2d SPIR-V kernel. The performance is identical.
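The reshape trick itself is plain NumPy: inserting a singleton height axis turns a length-L signal into a 1-row image, over which a (1, k) 2D kernel behaves exactly like a 1D kernel of size k. A sketch of the idea (not grilly's internal code):

```python
import numpy as np

signal = np.random.randn(4, 16, 100).astype(np.float32)

# (N, C, L) -> (N, C, 1, L): each length-L signal becomes a 1xL "image"
as_2d = signal[:, :, None, :]
print(as_2d.shape)  # (4, 16, 1, 100)

# Dropping the singleton axis recovers the original signal exactly
back = as_2d[:, :, 0, :]
print(np.array_equal(back, signal))  # True
```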