Metadata-Version: 2.4
Name: flax_orbax
Version: 0.1.4
Summary: A package for saving JAX-compatible dataclasses with Orbax
Author-email: Dibya Ghosh <dibya@berkeley.edu>
Requires-Python: >=3.10
Requires-Dist: pyyaml
Provides-Extra: dev
Requires-Dist: ipykernel; extra == 'dev'
Provides-Extra: jax
Requires-Dist: flax; extra == 'jax'
Requires-Dist: orbax-checkpoint; extra == 'jax'
Description-Content-Type: text/markdown

## Orbax 🤝 Dataclasses

A convenient way to serialize dataclasses with Orbax in an easier-to-read format (no pickle!).

Usage:

It's pretty common in JAX scripts to store parameters and model information in the same object, e.g.

```python
import dataclasses

import flax.linen as nn
import jax


@jax.tree_util.register_dataclass
@dataclasses.dataclass
class Model:
    module: nn.Module = dataclasses.field(metadata=dict(static=True))
    params: dict

# or e.g. a flax.training.train_state.TrainState
```

If we try saving it using standard Orbax, the restored value is a plain dict rather than a `Model`:
```python
import orbax.checkpoint as ocp

ckptr = ocp.StandardCheckpointer()
ckptr.save('/tmp/checkpoint', model)
ckptr.restore('/tmp/checkpoint') # -> we'll get a dict: {'params': ...} back, the object is lost
```

The idea of this library is to seamlessly store and restore the full object:
```python
import flax_orbax

ckptr = flax_orbax.Checkpointer()
ckptr.save('/tmp/checkpoint', model)
ckptr.restore('/tmp/checkpoint') # -> we will get a Model() back
```
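
To make the idea concrete, here is a minimal standard-library sketch of how a dataclass could be split into a readable record and rebuilt later. It is only an illustration, not how `flax_orbax` is implemented: the `register`, `to_readable`, and `from_readable` helpers and the `Config` class are hypothetical, and JSON stands in for whatever on-disk format the library actually uses.

```python
import dataclasses
import json

# Hypothetical registry mapping class names to classes, so a saved record
# can be turned back into the right type without pickling.
_REGISTRY = {}


def register(cls):
    _REGISTRY[cls.__name__] = cls
    return cls


@register
@dataclasses.dataclass
class Config:
    # Stand-in for a model object: one static field, one data field.
    name: str = dataclasses.field(metadata=dict(static=True))
    params: dict


def to_readable(obj):
    """Split a dataclass into class name, static fields, and data fields."""
    static, data = {}, {}
    for f in dataclasses.fields(obj):
        target = static if f.metadata.get("static", False) else data
        target[f.name] = getattr(obj, f.name)
    return {"class": type(obj).__name__, "static": static, "data": data}


def from_readable(record):
    """Look the class up by name and rebuild the instance from its fields."""
    cls = _REGISTRY[record["class"]]
    return cls(**record["static"], **record["data"])


original = Config(name="mlp", params={"w": [1.0, 2.0]})
text = json.dumps(to_readable(original), indent=2)  # human-readable on disk
restored = from_readable(json.loads(text))
```

The saved text names the class explicitly, which is what lets the restore step return a `Config` instead of a bare dict.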

## A few sharp edges

Some objects in a typical JAX workflow are not serializable, e.g. an optax gradient transformation. Wrapping the constructor with `flax_orbax.wrap` handles this:

```python
# instead of optax.adam(1e-3), do:
tx = flax_orbax.wrap(optax.adam)(1e-3)
```
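
One plausible way a wrapper like this can work is to remember which function was called and with which arguments, so the object can be rebuilt from plain data instead of being serialized directly. The sketch below shows that pattern using only the standard library. The `Wrapped` class, this `wrap` function, and `make_scaler` are hypothetical stand-ins, and nothing here is taken from the `flax_orbax` source.

```python
class Wrapped:
    """Holds a constructor and its arguments so the result can be rebuilt."""

    def __init__(self, fn, args, kwargs):
        self.fn, self.args, self.kwargs = fn, args, kwargs
        self.value = fn(*args, **kwargs)  # the real, possibly unserializable object

    def spec(self):
        # Plain-data description of the call, easy to write to disk.
        return {
            "fn": f"{self.fn.__module__}.{self.fn.__qualname__}",
            "args": list(self.args),
            "kwargs": dict(self.kwargs),
        }


def wrap(fn):
    """Hypothetical analogue of flax_orbax.wrap: call fn, but record the call."""
    def factory(*args, **kwargs):
        return Wrapped(fn, args, kwargs)
    return factory


def make_scaler(factor):
    # Stand-in for optax.adam: returns a closure, which is not plain data.
    return lambda x: factor * x


scaler = wrap(make_scaler)(3)
spec = scaler.spec()

# Rebuilding from the recorded arguments gives an equivalent object.
rebuilt = wrap(make_scaler)(*spec["args"], **spec["kwargs"])
```

The closure itself never needs to be written out; only the recorded call does.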

To restore onto different devices (i.e. with a different sharding), pass the target sharding via `RestoreArgs`:

```python
ckptr = flax_orbax.Checkpointer()
ckptr.restore('/tmp/checkpoint', args=flax_orbax.RestoreArgs(sharding=sharding))
```
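
The `sharding` value above is not defined in the snippet. As a sketch, one way to build a sharding with the standard JAX API is shown below. It replicates every array across all visible devices. Whether `RestoreArgs` expects a single sharding like this or a pytree of shardings is not stated here, so treat that part as an assumption.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over every device visible to this process.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# An empty PartitionSpec replicates each array across the whole mesh.
sharding = NamedSharding(mesh, PartitionSpec())
```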
