Metadata-Version: 2.4
Name: pilot-optimizer
Version: 0.1.0
Summary: PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
Author-email: Sattam Altuuaim <sattam.tuuaim@kaust.edu.sa>, Lama Ayash <lama.ayash@kaust.edu.sa>, Muhammad Mubashar <muhammad.mubashar@strath.ac.uk>, Naeemullah Khan <naeemullah.khan@kaust.edu.sa>
License: MIT
Project-URL: Homepage, https://sattamaltwaim.github.io/PILOT/
Project-URL: Repository, https://github.com/SattamAltwaim/PILOT
Project-URL: Paper, https://arxiv.org/abs/submit/7629402
Keywords: optimizer,deep-learning,pytorch,meta-learning,adaptive-optimization
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: numpy; extra == "dev"
Dynamic: license-file

# PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training

> **PILOT** is an online adaptive optimizer that adjusts its update behavior during training using gradient-direction agreement as a signal of local optimization stability.

---

## Overview

Most optimizers use a fixed update structure throughout training — a static balance between momentum, normalization, and sign-based updates that cannot respond to how the loss landscape evolves.

**PILOT** introduces a learnable policy that continuously modulates three core update primitives:
- **Momentum reliance** — how much to rely on accumulated gradient history vs. the current gradient
- **Variance-normalization strength** — how aggressively to apply adaptive scaling
- **Sign-based behavior** — how much to compress gradient magnitudes toward ±1

The policy is conditioned on a smoothed gradient-direction agreement signal, which serves as a compact online descriptor of local update consistency. The policy parameters are updated online during training with a one-step meta-gradient estimate: there is no offline search, no meta-training phase, and no second-order estimation.

![Loss Landscape — CIFAR-10 / SmallCNN](fig2_landscape.png)
*PILOT follows a distinct trajectory through the loss surface and converges to a lower-loss region than Adam, AdamW, Lion, and Sophia.*

---

## Key Results

### CNN Architecture

| Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
|---|---|---|---|---|
| FashionMNIST | Adam | 93.28 | 0.1957 | **0.0033** |
| FashionMNIST | AdamW | 93.22 | 0.1944 | 0.0034 |
| FashionMNIST | Lion | 92.91 | 0.2091 | 0.0041 |
| FashionMNIST | AdaBelief | 93.66 | 0.1822 | 0.0046 |
| FashionMNIST | **PILOT (Ours)** | **94.13** | **0.1719** | 0.0045 |
| CIFAR-10 | Adam | 79.91 | 0.5794 | 0.0103 |
| CIFAR-10 | Lion | 80.87 | 0.5487 | 0.0105 |
| CIFAR-10 | **PILOT (Ours)** | **81.94** | **0.5302** | **0.0073** |

### ResNet-18 Architecture

| Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
|---|---|---|---|---|
| FashionMNIST | AdaBelief | 95.33 | **0.1711** | 0.0056 |
| FashionMNIST | **PILOT (Ours)** | **95.71** | 0.2690 | **0.0030** |
| CIFAR-10 | Adam | 93.18 | **0.2140** | 0.0073 |
| CIFAR-10 | AdamW | 92.90 | 0.2514 | 0.0066 |
| CIFAR-10 | **PILOT (Ours)** | **93.42** | 0.2496 | **0.0001** |

---

## Method

### Gradient-Direction Agreement

At each step, PILOT computes the cosine similarity between successive gradients:

$$r_t = \frac{g_t^\top g_{t-1}}{\|g_t\|_2 \, \|g_{t-1}\|_2 + \epsilon}$$

This is smoothed via an exponential moving average:

$$\rho_t = \gamma \rho_{t-1} + (1 - \gamma) r_t$$

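Positive values indicate that successive gradients are aligned, a sign of stable progress. Values near zero indicate nearly orthogonal, noise-dominated gradients. Negative values indicate directional disagreement, where successive gradients point in opposing directions.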
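A minimal plain-Python sketch of the agreement signal follows. It assumes gradients are flattened into a single vector, that $\rho$ starts at 0, and an illustrative $\epsilon$; the function names are hypothetical and not the package's API. A PyTorch version would apply the same arithmetic to flattened parameter gradients.

```python
import math
from typing import Sequence


def cosine_agreement(g_t: Sequence[float], g_prev: Sequence[float], eps: float = 1e-8) -> float:
    """r_t = <g_t, g_{t-1}> / (||g_t|| * ||g_{t-1}|| + eps) on flattened gradients."""
    dot = sum(a * b for a, b in zip(g_t, g_prev))
    norm_t = math.sqrt(sum(a * a for a in g_t))
    norm_prev = math.sqrt(sum(b * b for b in g_prev))
    return dot / (norm_t * norm_prev + eps)


def update_agreement(rho_prev: float, r_t: float, gamma: float = 0.95) -> float:
    """rho_t = gamma * rho_{t-1} + (1 - gamma) * r_t (exponential moving average)."""
    return gamma * rho_prev + (1.0 - gamma) * r_t


# Toy gradient stream: two aligned steps, then a full direction reversal.
grads = [[1.0, 0.5], [0.9, 0.6], [-0.9, -0.6]]
rho = 0.0  # assumed initial value of the smoothed signal
for g_prev, g_t in zip(grads, grads[1:]):
    r = cosine_agreement(g_t, g_prev)
    rho = update_agreement(rho, r, gamma=0.9)
    print(f"r_t={r:+.3f}  rho_t={rho:+.3f}")
```

With $\gamma = 0.9$, the single reversal at the end is enough to pull $\rho_t$ slightly below zero, while a larger $\gamma$ would react more slowly.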

### Learnable Policy

The smoothed signal $\rho_t$ is passed through three degree-$d$ polynomials $f(\rho; \phi) = \sum_{k=0}^{d} \phi_k \rho^k$, each followed by a sigmoid $\sigma$, to produce three scalar control variables:

$$p_{m,t} = \sigma(f(\rho_t; \phi_m)), \quad p_{v,t} = \tfrac{1}{2}\sigma(f(\rho_t; \phi_v)), \quad p_{s,t} = \sigma(f(\rho_t; \phi_s))$$

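Here $p_{m,t}, p_{s,t} \in (0, 1)$, and the factor $\tfrac{1}{2}$ restricts $p_{v,t}$ to $(0, \tfrac{1}{2})$. Each polynomial has $d+1$ coefficients, so the policy has $3(d+1)$ learnable parameters in total, where $d$ is the polynomial degree.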
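The policy map can be sketched in a few lines of plain Python. This assumes the coefficient $\phi_k$ multiplies $\rho^k$ and uses zero-initialized coefficients purely for illustration; `poly` and `policy` are hypothetical names, not the package's API.

```python
import math
from typing import Sequence, Tuple


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def poly(rho: float, phi: Sequence[float]) -> float:
    """f(rho; phi) = sum_k phi[k] * rho**k: degree d with d + 1 coefficients."""
    return sum(c * rho ** k for k, c in enumerate(phi))


def policy(rho: float, phi_m: Sequence[float], phi_v: Sequence[float],
           phi_s: Sequence[float]) -> Tuple[float, float, float]:
    p_m = sigmoid(poly(rho, phi_m))        # momentum reliance, in (0, 1)
    p_v = 0.5 * sigmoid(poly(rho, phi_v))  # normalization exponent, in (0, 0.5)
    p_s = sigmoid(poly(rho, phi_s))        # sign compression, in (0, 1)
    return p_m, p_v, p_s


d = 2
phi_m = [0.0] * (d + 1)
phi_v = [0.0] * (d + 1)
phi_s = [0.0] * (d + 1)
n_params = len(phi_m) + len(phi_v) + len(phi_s)  # 3 * (d + 1) = 9
print(policy(0.3, phi_m, phi_v, phi_s))  # zero coefficients give (0.5, 0.25, 0.5)
```

Zero coefficients place every control variable at the midpoint of its range, independent of $\rho_t$; the policy only becomes state-dependent once the non-constant coefficients move away from zero.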

### Update Rule

$$\theta_{t+1} = \theta_t - \eta \frac{(|n_t| + \epsilon_n)^{1 - p_{s,t}} \odot \text{sign}(n_t)}{\hat{v}_t^{\,p_{v,t}} + \epsilon}$$

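where $n_t = p_{m,t} \hat{m}_t + (1 - p_{m,t}) g_t$ is the policy-controlled blend of momentum and current gradient, $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected first- and second-moment estimates (as in Adam, with decay rates given by `betas`), $\eta$ is the learning rate, and $\epsilon$, $\epsilon_n$ are small constants for numerical stability.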

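This formulation recovers Adam ($p_m=1, p_v=0.5, p_s=0$) and sign-based updates ($p_s=1, p_v=0$) as limiting cases. Because the control variables are sigmoid outputs, these exact values are approached but never reached.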
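The element-wise sketch below computes the update direction $u_t$ (so that $\theta_{t+1} = \theta_t - \eta\, u_t$) and checks both limiting cases numerically. It assumes $\hat{m}_t$ and $\hat{v}_t$ are already available, treats $\text{sign}(0) = 0$, uses illustrative values for $\epsilon$ and $\epsilon_n$, and omits weight decay; `pilot_direction` is a hypothetical name. The exact limit values of $p_m$, $p_v$, $p_s$ are passed directly only to illustrate the limits.

```python
import math
from typing import List, Sequence


def pilot_direction(m_hat: Sequence[float], v_hat: Sequence[float], g: Sequence[float],
                    p_m: float, p_v: float, p_s: float,
                    eps: float = 1e-8, eps_n: float = 1e-12) -> List[float]:
    """Element-wise u_t = (|n_t| + eps_n)^(1 - p_s) * sign(n_t) / (v_hat^p_v + eps)."""
    out = []
    for m_i, v_i, g_i in zip(m_hat, v_hat, g):
        n_i = p_m * m_i + (1.0 - p_m) * g_i          # blend of momentum and current gradient
        mag = (abs(n_i) + eps_n) ** (1.0 - p_s)      # magnitude is compressed toward 1 as p_s -> 1
        sign = math.copysign(1.0, n_i) if n_i != 0 else 0.0
        out.append(mag * sign / (v_i ** p_v + eps))  # variance normalization with exponent p_v
    return out


m_hat = [0.2, -0.05]
v_hat = [0.04, 0.01]
g = [0.3, 0.1]

# Adam limit: p_m = 1, p_v = 0.5, p_s = 0 gives m_hat / (sqrt(v_hat) + eps), up to eps_n.
adam_like = pilot_direction(m_hat, v_hat, g, p_m=1.0, p_v=0.5, p_s=0.0)
adam_ref = [m / (math.sqrt(v) + 1e-8) for m, v in zip(m_hat, v_hat)]

# Sign limit: p_s = 1, p_v = 0 gives sign(n_t) / (1 + eps), here with n_t = m_hat.
sign_like = pilot_direction(m_hat, v_hat, g, p_m=1.0, p_v=0.0, p_s=1.0)
print(adam_like, adam_ref, sign_like)
```

In the sign limit the second coordinate is negative even though its current gradient is positive, because $p_m = 1$ makes $n_t$ follow the momentum term alone.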

---

## Installation

```bash
pip install pilot-optimizer
```

Or install from source:

```bash
git clone https://github.com/SattamAltwaim/PILOT.git
cd PILOT
pip install -e .
```

---

## Usage

```python
from pilot import PILOT

optimizer = PILOT(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=1e-4,
    gamma=0.95,        # smoothing coefficient for agreement signal
    lr_phi=0.01,       # policy learning rate
    degree=2           # polynomial degree
)

for x, y in dataloader:  # model, criterion, and dataloader are assumed to be defined
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

---

## Hyperparameters

| Parameter | Description | Typical Range |
|---|---|---|
| `lr` | Model learning rate | `1e-4` – `1e-3` |
| `betas` | Decay rates for the first- and second-moment estimates | `(0.9, 0.999)` |
| `gamma` | Agreement signal smoothing | `0.85` – `0.99` |
| `lr_phi` | Policy learning rate | `5e-4` – `5e-2` |
| `degree` | Polynomial degree | `1` – `4` |

### Per-Configuration Settings

| Dataset | Architecture | γ (`gamma`) | η_φ (`lr_phi`) | Degree (`degree`) |
|---|---|---|---|---|
| CIFAR-10 | CNN | 0.882 | 0.00312 | 1 |
| CIFAR-10 | ResNet-18 | 0.950 | 0.00500 | 2 |
| FashionMNIST | CNN | 0.950 | 0.01000 | 2 |
| FashionMNIST | ResNet-18 | 0.957 | 0.00273 | 3 |
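To reuse these settings programmatically, one option is a small lookup table. The sketch below simply mirrors the table above; `PILOT_SETTINGS` and `pilot_kwargs` are hypothetical helpers, and the lowercase keys are assumed to match the `--dataset` and `--arch` values used by `train.py`.

```python
from typing import Dict, Tuple, Union

# Hypothetical lookup table mirroring the per-configuration settings above.
PILOT_SETTINGS: Dict[Tuple[str, str], Dict[str, Union[float, int]]] = {
    ("cifar10", "cnn"): {"gamma": 0.882, "lr_phi": 0.00312, "degree": 1},
    ("cifar10", "resnet18"): {"gamma": 0.950, "lr_phi": 0.00500, "degree": 2},
    ("fashionmnist", "cnn"): {"gamma": 0.950, "lr_phi": 0.01000, "degree": 2},
    ("fashionmnist", "resnet18"): {"gamma": 0.957, "lr_phi": 0.00273, "degree": 3},
}


def pilot_kwargs(dataset: str, arch: str) -> Dict[str, Union[float, int]]:
    """Return a copy of the policy settings for one dataset/architecture pair."""
    key = (dataset.lower(), arch.lower())
    if key not in PILOT_SETTINGS:
        raise KeyError(f"no reported settings for {key}")
    return dict(PILOT_SETTINGS[key])
```

With the package installed, the result could be unpacked into the constructor, for example `PILOT(model.parameters(), lr=1e-3, **pilot_kwargs("cifar10", "resnet18"))`. The table does not list `lr`, so the value shown is only the default from the Usage example.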

---

## Experiments

All experiments train for 30 epochs with cross-entropy loss, a cosine-annealing learning-rate schedule, batch size 128, and automatic mixed precision (AMP). ResNet-18 configurations add a 3-epoch linear warmup.

```bash
# CNN on CIFAR-10
python train.py --dataset cifar10 --arch cnn --optimizer pilot

# ResNet-18 on FashionMNIST
python train.py --dataset fashionmnist --arch resnet18 --optimizer pilot
```
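The schedule described above can be sketched as a per-epoch function. This is an assumption-laden illustration: it supposes the warmup ramps linearly up to the base rate, cosine annealing runs over the remaining epochs toward `min_lr`, and the rate changes once per epoch; `lr_at_epoch` is a hypothetical name, and the exact behavior of `train.py` is not specified here.

```python
import math


def lr_at_epoch(epoch: int, base_lr: float, total_epochs: int = 30,
                warmup_epochs: int = 0, min_lr: float = 0.0) -> float:
    """Learning rate for a 0-indexed epoch: optional linear warmup, then cosine annealing."""
    if epoch < warmup_epochs:
        # Linear ramp: base_lr / W, 2 * base_lr / W, ..., base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine decay over the remaining epochs, from base_lr toward min_lr.
    decay_epochs = total_epochs - warmup_epochs
    progress = (epoch - warmup_epochs) / decay_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# ResNet-18-style schedule (3-epoch warmup) vs. CNN-style schedule (no warmup).
resnet_sched = [lr_at_epoch(e, 1e-3, warmup_epochs=3) for e in range(30)]
cnn_sched = [lr_at_epoch(e, 1e-3) for e in range(30)]
print([round(lr, 6) for lr in resnet_sched[:6]])
```

In PyTorch, a comparable schedule can be composed from `torch.optim.lr_scheduler.LinearLR`, `CosineAnnealingLR`, and `SequentialLR`; whether the training script steps its scheduler per epoch or per iteration is not stated.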

