Metadata-Version: 2.4
Name: aquin
Version: 0.0.1
Summary: Record training runs locally and push to Aquin for post-hoc inspection — loss curves, SAE diffs, model diffs, and more.
License: MIT
Project-URL: Homepage, https://aquin.app
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Requires-Dist: requests>=2.28

# Aquin SDK

Record your training runs locally and push them to [Aquin](https://aquin.app) for post-hoc inspection — loss curves, learning rate, grad norm, epoch summaries, SAE feature diffs, model behaviour diffs, and more.

## Install

```bash
pip install aquin
```

## Quickstart

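```python
import aquin
import torch

# `model`, `optimizer`, `scheduler`, `dataloader` and `compute_loss` stand in for
# your own training setup; they are not provided by aquin.
run = aquin.init(
    base_model="meta-llama/Llama-3.2-1B-Instruct",
    run_name="my-lora-run",
    config={
        "lr": 2e-4, "epochs": 3, "rank": 16, "lora_alpha": 32,
        "method": "qlora", "per_device_train_batch_size": 2,
        "gradient_accumulation_steps": 8, "dataset": "data.jsonl",
    },
)

global_step = 0  # counts across epochs; enumerate(dataloader) would restart at 0
for epoch in range(3):
    for batch in dataloader:
        loss = compute_loss(model, batch)
        loss.backward()
        # Clip after backward() and before step(); returns the total norm before clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()
        optimizer.step()
        lr = scheduler.get_last_lr()[0]  # LR used by the optimizer step above
        scheduler.step()
        optimizer.zero_grad()
        global_step += 1

        run.log(
            global_step,
            loss=loss.item(),
            learning_rate=lr,
            grad_norm=grad_norm,
            epoch=epoch,
        )

run.checkpoint(model, step=global_step)  # once, after training
run.finish()
```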

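Then, from the same directory, bundle the run and push it. If you haven't saved your API key yet, run `aquin login` first: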

```bash
aquin package
aquin push
```

Your run appears in the Aquin dashboard under **CLI runs** with the full inspection suite.

## API

### `aquin.init(base_model, run_name, config)`

Starts a new run and returns the run object used in the calls below. Creates `aquin_run/` in the current working directory.

| Param | Description |
|---|---|
| `base_model` | HuggingFace model ID, e.g. `"meta-llama/Llama-3.2-1B-Instruct"` |
| `run_name` | Display name for the run |
| `config` | Optional dict of training hyperparameters. You can pass it to `run.finish()` instead. |
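
What happens when `aquin_run/` already exists isn't documented here. If you start several runs from the same directory, a cautious option is to move an earlier `aquin_run/` aside before calling `aquin.init()`, so that `aquin package` only bundles the run you mean to push. A minimal sketch using a hypothetical helper, `archive_previous_run`, built only on the standard library (it is not part of the SDK):

```python
import shutil
import time
from pathlib import Path
from typing import Optional


def archive_previous_run(run_dir: str = "aquin_run") -> Optional[Path]:
    """Hypothetical helper: rename an existing run directory so a new run starts clean.

    Returns the archived directory's new path, or None if there was nothing to move.
    """
    src = Path(run_dir)
    if not src.exists():
        return None
    # Append a timestamp (e.g. aquin_run.20240101-120000) so earlier runs are kept.
    stamp = time.strftime("%Y%m%d-%H%M%S")
    dst = src.with_name(f"{src.name}.{stamp}")
    n = 1
    while dst.exists():  # two archives within the same second get a counter suffix
        dst = src.with_name(f"{src.name}.{stamp}.{n}")
        n += 1
    shutil.move(str(src), str(dst))
    return dst
```

Call `archive_previous_run()` just before `aquin.init(...)`. Renaming rather than deleting keeps the earlier run on disk in case you still want to package it.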

### `run.log(step, *, loss, ...)`

Records metrics for one training step. Call it once per step inside your training loop.

| Param | Description |
|---|---|
| `step` | Global training step (required) |
| `loss` | Scalar training loss (required) |
| `learning_rate` | Current LR — enables LR chart |
| `grad_norm` | Gradient norm — enables grad norm chart |
| `epoch` | Current epoch — enables epoch summary table |
| `momentum_norm` | Optimizer momentum norm — enables momentum chart |
| `step_ms` | Wall-clock time for this step in ms |
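
Your own code has to measure the `step_ms` value. A minimal sketch using a hypothetical `step_timer` context manager (not part of the SDK), built on `time.perf_counter()`, which is monotonic and suited to measuring intervals:

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def step_timer() -> Iterator[Dict[str, int]]:
    """Hypothetical helper: time the enclosed block in whole milliseconds.

    The value is stored under "step_ms" and is only available after the block exits.
    """
    timing: Dict[str, int] = {}
    start = time.perf_counter()
    try:
        yield timing
    finally:
        # Recorded even if the block raises, so a failed step still has a duration.
        timing["step_ms"] = round((time.perf_counter() - start) * 1000)
```

Wrap the forward pass, backward pass and optimizer step in `with step_timer() as t:` and pass `step_ms=t["step_ms"]` to `run.log()` after the block exits. On a GPU, keep a synchronizing call such as `loss.item()` inside the block. CUDA kernels run asynchronously, so without one the timer can stop before the work finishes.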

### `run.checkpoint(model, step)`

Saves a model checkpoint locally. Each run keeps a single checkpoint, and every call replaces the previous save, so call it once at the end of training. The checkpoint is included in the push and is used for SAE diff and model diff analysis.
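
Because only the last saved checkpoint is kept, the pushed weights are whatever you saved last. If you would rather push the best weights than the final ones (for example, those with the lowest validation loss), one option is to keep a copy of the best state in memory and restore it before the single `checkpoint()` call. A minimal sketch, assuming your model exposes `state_dict()` and `load_state_dict()` as PyTorch modules do. `BestStateTracker` is a hypothetical helper, not part of the SDK:

```python
import copy
from typing import Any, Optional


class BestStateTracker:
    """Hypothetical helper: keep a copy of the state with the lowest metric seen so far.

    Works with anything copy.deepcopy can handle, e.g. a PyTorch model.state_dict().
    """

    def __init__(self) -> None:
        self.best_metric: Optional[float] = None
        self.best_step: Optional[int] = None
        self.best_state: Any = None

    def update(self, metric: float, state: Any, step: int) -> bool:
        """Store a deep copy of `state` if `metric` is lower than the best so far."""
        if self.best_metric is None or metric < self.best_metric:
            self.best_metric = metric
            self.best_step = step
            # Deep copy so later training updates don't overwrite the saved weights.
            self.best_state = copy.deepcopy(state)
            return True
        return False
```

During training, call `tracker.update(val_loss, model.state_dict(), global_step)`. After training, run `model.load_state_dict(tracker.best_state)`, then `run.checkpoint(model, step=tracker.best_step)`. A deep copy of a GPU state dict doubles its memory on that device, so for large models consider copying tensors to CPU instead.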

### `run.finish(config)`

Flushes all metrics to disk. Pass `config` here if you didn't pass it to `aquin.init()`.
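
Whether you pass `config` to `aquin.init()` or `run.finish()`, it helps if the dict comes from the same object that drives your training loop. Otherwise, for example, a hard-coded `range(3)` can drift away from the logged `"epochs"` value. A minimal sketch that keeps hyperparameters in a dataclass and converts them with `dataclasses.asdict()`. `TrainConfig` is hypothetical, and its fields simply mirror the Quickstart's config keys:

```python
from dataclasses import asdict, dataclass


@dataclass
class TrainConfig:
    """Hypothetical hyperparameter container; fields mirror the Quickstart config."""

    lr: float = 2e-4
    epochs: int = 3
    rank: int = 16
    lora_alpha: int = 32
    method: str = "qlora"
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 8
    dataset: str = "data.jsonl"


cfg = TrainConfig(epochs=2)
config = asdict(cfg)  # plain dict for the `config` argument
```

Drive the loop from the same object (`for epoch in range(cfg.epochs):`) and pass `config` to either `aquin.init(..., config=config)` or `run.finish(config)`.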

## CLI

```bash
aquin login       # save your API key
aquin package     # bundle aquin_run/ into aquin_run.tar.gz
aquin push        # push to Aquin
aquin whoami      # check which account you're logged in as
```
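
`aquin package` writes `aquin_run.tar.gz`. Before running `aquin push`, you can check what the archive contains, for example whether the checkpoint made it in and how large it is, using Python's standard `tarfile` module. A minimal sketch, assuming the archive is a gzip-compressed tarball as its `.tar.gz` name suggests. The internal layout of the archive isn't documented here, so the code lists whatever files it finds:

```python
import tarfile
from typing import Dict


def bundle_contents(path: str = "aquin_run.tar.gz") -> Dict[str, int]:
    """Map each regular file in a .tar.gz archive to its size in bytes."""
    with tarfile.open(path, mode="r:gz") as tar:
        return {m.name: m.size for m in tar.getmembers() if m.isfile()}
```

For a quick look, print `sorted(bundle_contents().items())` from the directory where you ran `aquin package`.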

## Using with HuggingFace Trainer / TRL

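`Trainer` exposes its training loop through callbacks, so you can forward its logs to Aquin with a `TrainerCallback`. TRL trainers such as `SFTTrainer` subclass `Trainer`, so the same callback works with them. Pass an instance when you build the trainer, e.g. `callbacks=[AquinCallback(run)]`: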

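```python
import time

from transformers import TrainerCallback


class AquinCallback(TrainerCallback):
    """Forwards Trainer training logs to an Aquin run and saves a checkpoint at the end."""

    def __init__(self, run):
        self.run = run
        self._step_start = 0.0

    def on_step_begin(self, args, state, control, **kwargs):
        self._step_start = time.perf_counter()

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Record from the main process only, and skip eval and end-of-training logs,
        # which carry "eval_loss" / "train_loss" rather than "loss".
        if not state.is_world_process_zero or not logs or "loss" not in logs:
            return
        self.run.log(
            step=state.global_step,
            loss=float(logs["loss"]),
            learning_rate=float(logs["learning_rate"]) if "learning_rate" in logs else None,
            grad_norm=float(logs["grad_norm"]) if "grad_norm" in logs else None,
            epoch=int(state.epoch) if state.epoch is not None else None,
            # Duration of the most recent step only, even when logging_steps > 1.
            step_ms=round((time.perf_counter() - self._step_start) * 1000),
        )

    def on_train_end(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        model = kwargs.get("model")
        if model is not None:  # explicit None check instead of nn.Module truthiness
            self.run.checkpoint(model, step=state.global_step)
        self.run.finish()  # flush metrics to disk
```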
