Metadata-Version: 2.4
Name: flax-weightwatcher
Version: 0.1.1
Summary: A smaller, faster, cleaner WeightWatcher for FLAX/JAX with XLA
Home-page: https://github.com/jaisidhsingh/flax-weightwatcher
Author: Jaisidh Singh
Author-email: jaisidh.singh@student.uni-tuebingen.com
Requires-Python: >=3.12
Description-Content-Type: text/markdown
Requires-Dist: flax>=0.12.0
Requires-Dist: jax>=0.7.2
Requires-Dist: jaxtyping>=0.3.3
Requires-Dist: numpy>=2.3.3
Requires-Dist: pandas>=2.3.3
Requires-Dist: powerlaw>=1.5
Dynamic: author
Dynamic: author-email
Dynamic: home-page
Dynamic: requires-python

# Flax-WeightWatcher

Unfortunately, <a href="https://github.com/CalculatedContent/WeightWatcher">CalculatedContent's WeightWatcher</a> does not support FLAX models, and it could run faster on an accelerated linear algebra (XLA) backend.

Since I found the process of making a PR to the original WeightWatcher repository too tedious, I wrote my own version for FLAX models, which is what I'm working with at the moment. It also helps that JAX uses XLA.

Flax-WeightWatcher is not meant to be a one-to-one match with the original, but it is designed to be extensible. Contributions that extend its functionality, perhaps to match or even exceed the original, are welcome.

## Installation

A simple `pip install flax-weightwatcher` installs this tool as a Python library, along with the following dependencies:

```
jax
flax
numpy
pandas
powerlaw
jaxtyping
```

## Usage

Usage is intended to mirror the original WeightWatcher, with some minor differences.

```python
from flax import nnx
from flax_weightwatcher import FlaxWeightWatcher

# A small two-layer MLP: 784 -> 128 -> 10
model = nnx.Sequential(
    nnx.Linear(28 * 28, 128, rngs=nnx.Rngs(0)),
    nnx.Linear(128, 10, rngs=nnx.Rngs(0)),
)

# details_format="df" returns a pandas DataFrame; use "dict" to get a dictionary instead
watcher = FlaxWeightWatcher(model=model, details_format="df")
details = watcher.analyze()
print(details.head())
```
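Since `details_format="df"` gives back a pandas DataFrame, the results can be filtered with ordinary pandas operations. The sketch below flags layers whose `alpha` falls outside the 2 to 6 range that the original WeightWatcher uses as a rule of thumb for well-trained layers. The helper `flag_layers` and its default thresholds are hypothetical and not part of this package; the sample values are copied from the example output shown below.

```python
import pandas as pd


def flag_layers(details: pd.DataFrame, low: float = 2.0, high: float = 6.0) -> pd.DataFrame:
    """Hypothetical helper: return rows whose power-law exponent alpha lies outside [low, high].

    The [2, 6] default is a heuristic borrowed from the original WeightWatcher,
    not a hard threshold defined by flax-weightwatcher.
    """
    outside = (details["alpha"] < low) | (details["alpha"] > high)
    return details.loc[outside, ["layer_name", "weight_shape", "alpha"]]


if __name__ == "__main__":
    # Values copied from this README's example output (a freshly initialised model).
    # With a real model you would use: details = watcher.analyze()
    details = pd.DataFrame(
        {
            "layer_name": ["layers.0", "layers.1"],
            "weight_shape": ["784,128", "128,10"],
            "alpha": [6.143005, 5.302257],
        }
    )
    print(flag_layers(details))
```

For an untrained model like the one above, a large `alpha` is expected, since random weights have no heavy-tailed structure yet.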

This should print something like the following:

```
   layer_index layer_name weight_shape     alpha  num_eigenvals_fit  num_eigenvals  stable_ranks  effective_ranks  ranks
0           13   layers.0      784,128  6.143005                 38            128      8.157706       117.944939    128
1           25   layers.1       128,10  5.302257                  8             10      2.552803         9.536791     10
```
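The columns above are spectral statistics of each layer's weight matrix. To give a rough idea of what they measure, here is a minimal sketch that recomputes similar quantities for a single matrix. It assumes conventions from the original WeightWatcher and the wider literature, not this package's exact code: the empirical spectral density (ESD) is the set of eigenvalues of `W^T W`, the stable rank is the sum of those eigenvalues divided by the largest one, the effective rank is the exponential of the entropy of the normalised singular values (Roy and Vetterli), and `alpha` comes from a Clauset-style power-law fit of the ESD tail (continuous maximum-likelihood exponent, with `xmin` chosen to minimise the Kolmogorov-Smirnov distance). This package depends on `powerlaw` for its fit, so its numbers may differ from this sketch. The function names `spectral_stats` and `fit_power_law` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit  # compiled with XLA; retraced once per input shape
def spectral_stats(w: jax.Array):
    """ESD eigenvalues of W^T W plus stable, effective, and numerical rank (assumed definitions)."""
    s = jnp.linalg.svd(w, compute_uv=False)  # singular values of W
    evals = jnp.sort(s**2)                   # eigenvalues of W^T W, ascending
    # Stable rank: ||W||_F^2 / ||W||_2^2 = sum(eigenvalues) / max(eigenvalue)
    stable_rank = evals.sum() / evals[-1]
    # Effective rank: exp of the Shannon entropy of the normalised singular values
    p = s / s.sum()
    effective_rank = jnp.exp(-jnp.sum(jnp.where(p > 0, p * jnp.log(p), 0.0)))
    # Numerical rank with the same default tolerance as numpy.linalg.matrix_rank
    tol = s.max() * max(w.shape) * jnp.finfo(w.dtype).eps
    rank = jnp.sum(s > tol)
    return evals, stable_rank, effective_rank, rank


def fit_power_law(evals, min_tail: int = 5):
    """Clauset-style fit: try each eigenvalue as xmin, keep the one with the smallest KS distance.

    Returns (alpha, number of eigenvalues in the fitted tail).
    """
    x = np.sort(np.asarray(evals, dtype=np.float64))
    x = x[x > 0]
    best_ks, best_alpha, best_n = np.inf, np.nan, 0
    for i in range(len(x) - min_tail + 1):
        xmin, tail = x[i], x[i:]
        n = tail.size
        log_sum = np.sum(np.log(tail / xmin))
        if log_sum <= 0:  # all tail values equal xmin; the MLE is undefined
            continue
        alpha = 1.0 + n / log_sum  # continuous power-law MLE for this xmin
        emp_cdf = np.arange(1, n + 1) / n
        fit_cdf = 1.0 - (tail / xmin) ** (1.0 - alpha)
        ks = np.max(np.abs(emp_cdf - fit_cdf))
        if ks < best_ks:
            best_ks, best_alpha, best_n = ks, alpha, n
    return best_alpha, best_n


if __name__ == "__main__":
    # A random 784 x 128 matrix, roughly the shape of layers.0 in the example above
    w = jax.random.normal(jax.random.key(0), (784, 128)) / jnp.sqrt(784.0)
    evals, stable_rank, effective_rank, rank = spectral_stats(w)
    alpha, n_fit = fit_power_law(np.asarray(evals))
    print(f"alpha={alpha:.3f}, fit on {n_fit} of {evals.size} eigenvalues")
    print(f"stable rank={float(stable_rank):.2f}, effective rank={float(effective_rank):.2f}, rank={int(rank)}")
```

The original WeightWatcher also normalises `W^T W` by its dimension; a constant rescaling of the eigenvalues like that leaves both `alpha` and the stable rank unchanged.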

## Features to be added

- additional metrics computed by the original WeightWatcher
- ESD plotting utilities
