Metadata-Version: 2.4
Name: bonsai-nn-library
Version: 0.5.3
Summary: Common PyTorch utilities for research projects on neural networks.
Author-email: Richard Lange <opensource@bonsai-neuro-ai.com>
License: MIT
Project-URL: Repository, https://github.com/bonsai-neuro-ai/nn-library
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Development Status :: 1 - Planning
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE.txt
Requires-Dist: jsonargparse
Requires-Dist: matplotlib
Requires-Dist: mlflow
Requires-Dist: numpy<2
Requires-Dist: pandas
Requires-Dist: parse
Requires-Dist: pycocotools
Requires-Dist: pydot
Requires-Dist: pyyaml
Requires-Dist: scikit-learn
Requires-Dist: scipy
Requires-Dist: seaborn
Requires-Dist: tabulate
Requires-Dist: torch[opt-einsum]
Requires-Dist: torchmetrics
Requires-Dist: torchvision
Requires-Dist: tqdm
Dynamic: license-file

# NN-Library

Research in the [BONSAI Lab](https://bonsai-neuro-ai.com) on neural networks, among other
things, frequently requires loading, training, and reconfiguring neural network models. This
library is a work-in-progress suite of in-house tools that addresses some pain points we've
encountered in our research workflow.

We make no guarantees about the stability or usability of this library, but we hope that it can be
useful to others in the research community. If you have any questions or suggestions, please feel
free to reach out to us or open an issue on the GitHub repository.

## Installation

Using `pip`:

    pip install bonsai-nn-library

Using [`uv`](https://docs.astral.sh/uv/) (in a project):

    uv add bonsai-nn-library

Or, as a drop-in replacement for `pip`:

    uv pip install bonsai-nn-library

## Usage (Python >= 3.10)

Say you want to use one of our "fancy layers" like a low-rank convolution. You can do so like this:

```python
from torch import nn
from nn_lib.models.fancy_layers import LowRankConv2d

model = nn.Sequential(
    LowRankConv2d(in_channels=3, out_channels=64, kernel_size=3, rank=8),
    nn.ReLU(),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.LazyLinear(10)
)
```
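For intuition, a low-rank convolution saves parameters by routing the computation through a small number of intermediate channels. Below is a conceptual sketch using only plain PyTorch; `LowRankConv2d` in this library may use a different factorization, and `NaiveLowRankConv2d` is a hypothetical name for illustration:

```python
import torch
from torch import nn


class NaiveLowRankConv2d(nn.Module):
    """A k x k conv down to `rank` channels, then a 1 x 1 conv up to `out_channels`."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, rank: int):
        super().__init__()
        self.reduce = nn.Conv2d(in_channels, rank, kernel_size, bias=False)
        self.expand = nn.Conv2d(rank, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.expand(self.reduce(x))


full = nn.Conv2d(64, 64, kernel_size=3)
low_rank = NaiveLowRankConv2d(64, 64, kernel_size=3, rank=8)

n_full = sum(p.numel() for p in full.parameters())     # 64*64*9 + 64 = 36928
n_low = sum(p.numel() for p in low_rank.parameters())  # 64*8*9 + 8*64 + 64 = 5184

# Both produce the same output shape, but the low-rank version has ~7x fewer parameters.
x = torch.randn(1, 64, 16, 16)
assert low_rank(x).shape == full(x).shape
```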

## Useful thing \#1: improved `GraphModule`s.

PyTorch was not originally designed around explicit computation graphs; support was added
later via the `torch.fx` module. Others might use TensorFlow or JAX for this, but we like PyTorch.
The `torch.fx.GraphModule` class is the built-in way to handle computation graphs in PyTorch, but it
lacks some features that we find useful. We have extended it with our `GraphModulePlus` class, which
inherits from `GraphModule` and adds further functionality.

A motivating use case is "stitching" models together or extracting hidden-layer activations. This
is a little tricky to get right using `GraphModule` alone, so we've added utilities such as

* `GraphModulePlus.set_output(layer_name)`: use this to chop off the head of a model and make it
  output from a specific layer.
* `GraphModulePlus.new_from_merge(...)`: use this to merge or "stitch" existing models together. See
   `demos/demo_stitching.py` for a worked-out example.
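For intuition, the kind of graph surgery that "output from a specific layer" involves can be approximated with plain `torch.fx` by rewiring the graph's output node. This is a minimal sketch using only `torch`, not this library's implementation:

```python
import torch
from torch import nn
from torch.fx import symbolic_trace

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
gm = symbolic_trace(model)

# Find the ReLU node (submodule "1" in the Sequential) and the graph's output node.
relu_node = next(n for n in gm.graph.nodes if n.op == "call_module" and n.target == "1")
out_node = next(n for n in gm.graph.nodes if n.op == "output")

# Rewire the graph to return the hidden ReLU activation, drop the now-dead final
# conv, and regenerate the module's forward() from the modified graph.
out_node.args = (relu_node,)
gm.graph.eliminate_dead_code()
gm.recompile()

hidden = gm(torch.randn(1, 3, 16, 16))  # shape (1, 8, 14, 14), all values >= 0
```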

We've also done some metaprogramming trickery so that if you import `GraphModulePlus` anywhere in
your code, it will automatically inject itself into the `torch.fx` module. The surprising but
convenient behavior is:

```python
from torch import nn
from torch.fx import symbolic_trace
from nn_lib.models import GraphModulePlus

my_regular_torch_model = nn.Sequential(
    nn.Conv2d(3, 64, 3),
    nn.ReLU(),
    nn.Conv2d(64, 64, 3),
    nn.ReLU(),
    nn.Flatten(),
    nn.LazyLinear(10)
)

# Natively, symbolic_trace is expected to return a GraphModule, but we've injected GraphModulePlus
graphified_model = symbolic_trace(my_regular_torch_model)
assert isinstance(graphified_model, GraphModulePlus)
```

## Useful thing \#2: Fancy layers.

We have implemented a few "fancy" layers that we find useful in our research, available via
`nn_lib.models` or `nn_lib.models.fancy_layers`. These include:

* `Regressable` linear layers: a `Protocol` that allows linear layers to be initialized by
   least-squares regression. This is useful for initializing a linear layer to approximate a
   function learned by a different model.
* `RegressableLinear`: a regressable version of `nn.Linear`.
* `LowRankLinear`: a regressable linear layer with a low-rank factorization.
* `ProcrustesLinear`: a regressable linear layer constrained to a rotation, with optional shift
   (bias) and optional scaling.
* A `Conv2d` version of each of the above (e.g., `LowRankConv2d`).
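The idea behind regression-based initialization can be sketched with plain `torch`: fit a linear layer's weight and bias to input/output pairs by least squares. This is a conceptual sketch only; the `Regressable` protocol's actual API may differ:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Target function we want a linear layer to approximate (here, another linear map).
target = nn.Linear(16, 4)

# Collect input/output pairs from the target.
x = torch.randn(256, 16)
with torch.no_grad():
    y = target(x)

# Solve the least-squares problem [x, 1] @ [W^T; b] ~= y for weight and bias jointly.
x_aug = torch.cat([x, torch.ones(256, 1)], dim=1)
sol = torch.linalg.lstsq(x_aug, y).solution  # shape (17, 4)

layer = nn.Linear(16, 4)
with torch.no_grad():
    layer.weight.copy_(sol[:-1].T)
    layer.bias.copy_(sol[-1])

# The fitted layer reproduces the target closely on the sampled data.
err = (layer(x) - y).abs().max()
```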

## Useful thing \#3: MLflow utilities.

We use MLflow to track our experiments. We have a few utilities in `nn_lib.utils.mlflow` that remove
a bit of boilerplate from our code.

## Useful thing \#4: Dataset utilities and torchvision wrappers.

We have implemented `LightningDataModule` classes for a few standard vision datasets; see
`nn_lib.datasets`. We've also implemented a few simple helpers for downloading pretrained models
from torch hub using the `torchvision` API; see `nn_lib/models/__init__.py`.

## Useful thing \#5: NTK utilities.

See `nn_lib.analysis.ntk` for some neural tangent kernel utilities.
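As background, the empirical NTK between two inputs is the inner product of the network's parameter gradients at those inputs. Below is a generic sketch with plain `torch` autograd, not this library's API (see `nn_lib.analysis.ntk` for the actual utilities):

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))


def param_grad(x: torch.Tensor) -> torch.Tensor:
    """Flattened gradient of the scalar output with respect to all parameters."""
    model.zero_grad()
    model(x.unsqueeze(0)).squeeze().backward()
    return torch.cat([p.grad.flatten().clone() for p in model.parameters()])


def empirical_ntk(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """NTK entry k(x1, x2) = <grad_theta f(x1), grad_theta f(x2)>."""
    return param_grad(x1) @ param_grad(x2)


x1, x2 = torch.randn(4), torch.randn(4)
k11 = empirical_ntk(x1, x1)  # diagonal entries are squared gradient norms, so >= 0
k12 = empirical_ntk(x1, x2)
k21 = empirical_ntk(x2, x1)  # the kernel is symmetric, so k21 == k12
```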

## Forthcoming/Planned features

* More fancy layers
* Vector Quantization utilities (but see `nn_lib.models.sparse_auto_encoder` which has some already)
* Further analysis utilities especially focused on calculating neural similarity measures.

## Obsolete/deprecated features

* Lightning training and overly complex CLI utilities. Some straggler files may still need to be
  cleaned up.
