Metadata-Version: 2.4
Name: torch-surgeon
Version: 0.1.0
Summary: Real-time gradient pathology detection for PyTorch
Author-email: Ashmit Singh <connect.ashmit.singh@proton.me>
License: MIT
Project-URL: Homepage, https://github.com/Ashmit-Singh/torch-surgeon
Project-URL: Repository, https://github.com/Ashmit-Singh/torch-surgeon
Keywords: pytorch,deep-learning,debugging,gradients,training
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.9.0
Requires-Dist: matplotlib>=3.3.0
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Dynamic: license-file

# torch-surgeon

[![PyPI version](https://badge.fury.io/py/torch-surgeon.svg)](https://badge.fury.io/py/torch-surgeon)
[![Python](https://img.shields.io/pypi/pyversions/torch-surgeon)](https://pypi.org/project/torch-surgeon)
[![CI](https://github.com/Ashmit-Singh/torch-surgeon/actions/workflows/ci.yml/badge.svg)](https://github.com/Ashmit-Singh/torch-surgeon/actions)
[![Coverage](https://img.shields.io/badge/coverage-94%25-brightgreen)]()

**Real-time gradient pathology detection for PyTorch — in 2 lines.**

A loss curve is a lagging indicator. By the time it shows a problem, vanishing or
exploding gradients have been compounding for hundreds of steps. torch-surgeon attaches
diagnostic hooks to your model and surfaces per-layer pathologies in real time, before
they snowball into an unrecoverable run.

## Install

```bash
pip install torch-surgeon
```

## Usage

```python
from torch_surgeon import Surgeon

surgeon = Surgeon(model, rules="default")
surgeon.attach()

# ... your existing training loop, unchanged ...
for epoch in range(epochs):
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

report = surgeon.report()   # per-layer stats dict
surgeon.detach()            # clean removal of all hooks
```
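
The README describes `report()` only as a "per-layer stats dict", so the exact schema may differ between versions; purely as a hypothetical illustration, you can walk it generically:

```python
# Hypothetical shape: report() is only described as a per-layer stats dict,
# so this iterates it without assuming any specific keys.
for layer_name, layer_stats in report.items():
    print(layer_name, layer_stats)
```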

## What it detects

| Pathology | Detection method |
|---|---|
| **Vanishing gradients** | Per-layer norm ratio drops below threshold vs EMA baseline |
| **Exploding gradients** | Per-layer norm ratio exceeds threshold vs EMA baseline |
| **Stagnant layers** | Norm near-zero for N consecutive steps — layer stopped learning |
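
The stagnant-layer rule in the last row is the easiest to sketch. A minimal illustration (ours, not the library's internals; the near-zero epsilon is an assumed value, while the 50-step window matches the default shown under "Custom rules"):

```python
EPS = 1e-8            # assumed "near-zero" cutoff
STAGNANT_STEPS = 50   # default from the Custom rules section below

stagnant_counts = {}  # per-layer count of consecutive near-zero steps

def is_stagnant(name, grad_norm):
    """Flag a layer whose gradient norm has been ~0 for STAGNANT_STEPS steps."""
    if grad_norm < EPS:
        stagnant_counts[name] = stagnant_counts.get(name, 0) + 1
    else:
        stagnant_counts[name] = 0
    return stagnant_counts[name] >= STAGNANT_STEPS
```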

## Custom rules

```python
surgeon = Surgeon(model, rules={
    "vanishing_threshold": 0.01,   # default
    "exploding_threshold": 100.0,  # default
    "stagnant_steps": 50,          # default
    "log_every": 10,               # print summary every N steps
    "plot": True,                  # live matplotlib plot
    "verbose": True,
})
```

## How it works

torch-surgeon uses PyTorch's `register_full_backward_hook` API to intercept gradients
at every leaf layer during the backward pass. Statistics (mean, std, norm) are computed
**inside the hook** and the raw gradient tensor is discarded immediately — keeping overhead
under 1% on typical training loops.
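
In outline, the pattern looks like this (a condensed sketch built on the same public API, not torch-surgeon's actual source):

```python
import torch
import torch.nn as nn

grad_stats = {}  # scalar stats only; the gradient tensors themselves are dropped

def make_hook(name):
    def hook(module, grad_input, grad_output):
        g = grad_output[0]
        if g is not None:
            # Reduce to scalars inside the hook; nothing holds a reference
            # to the tensor after this returns.
            grad_stats[name] = {
                "mean": g.mean().item(),
                "std": g.std().item(),
                "norm": g.norm().item(),
            }
    return hook

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
handles = [
    m.register_full_backward_hook(make_hook(name))
    for name, m in model.named_modules()
    if not list(m.children())  # leaf layers only
]

model(torch.randn(4, 8)).sum().backward()
print(grad_stats)

for h in handles:
    h.remove()  # removing handles is the clean-detach mechanism
```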

Pathology detection uses an exponential moving average (EMA) baseline per layer rather
than fixed thresholds — so it generalises across architectures without manual tuning.
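
Concretely, each step's check might look like the following sketch (the smoothing factor and update rule are our assumptions; the 0.01 and 100.0 ratios match the defaults shown above):

```python
ema_norms = {}  # per-layer running baseline of gradient norms

def classify(name, norm, alpha=0.1, vanishing=0.01, exploding=100.0):
    """Compare the current norm to the layer's EMA baseline, then update it."""
    baseline = ema_norms.get(name)
    verdict = "ok"
    if baseline is not None and baseline > 0:
        ratio = norm / baseline
        if ratio < vanishing:
            verdict = "vanishing"
        elif ratio > exploding:
            verdict = "exploding"
    # EMA update: new_baseline = alpha * norm + (1 - alpha) * old_baseline
    ema_norms[name] = norm if baseline is None else alpha * norm + (1 - alpha) * baseline
    return verdict
```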

## Performance

Overhead stays under 1% on standard training loops, validated with 100-step timing
benchmarks on Linear/ReLU networks. Statistics are reduced to scalars inside the hook;
no gradient tensors are stored between steps.
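
To sanity-check the overhead on your own model, a simple A/B timing like the sketch below works (ours, not the project's benchmark harness):

```python
import time

import torch
import torch.nn as nn
from torch_surgeon import Surgeon

def time_steps(model, steps=100):
    """Wall-clock time for `steps` SGD iterations on dummy data."""
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    x, y = torch.randn(64, 128), torch.randn(64, 1)
    start = time.perf_counter()
    for _ in range(steps):
        loss = ((model(x) - y) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return time.perf_counter() - start

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1))
baseline = time_steps(model)

surgeon = Surgeon(model, rules="default")
surgeon.attach()
hooked = time_steps(model)
surgeon.detach()

print(f"overhead: {100 * (hooked - baseline) / baseline:.2f}%")
```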

## License

MIT
