Metadata-Version: 2.4
Name: paramax
Version: 0.0.5
Summary: Parameterizations and parameter constraints for JAX PyTrees.
Project-URL: repository, https://github.com/danielward27/paramax
Project-URL: documentation, https://danielward27.github.io/paramax/index.html
Author-email: Daniel Ward <danielward27@outlook.com>
License: The MIT License (MIT)
        
        Copyright (c) 2022 Daniel Ward
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Keywords: equinox,jax,neural-networks
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Information Analysis
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Typing :: Typed
Requires-Python: >=3.10
Requires-Dist: equinox
Requires-Dist: jax
Requires-Dist: jaxtyping
Provides-Extra: dev
Requires-Dist: beartype; extra == 'dev'
Requires-Dist: pytest; extra == 'dev'
Requires-Dist: ruff; extra == 'dev'
Requires-Dist: sphinx; extra == 'dev'
Requires-Dist: sphinx-autodoc-typehints; extra == 'dev'
Requires-Dist: sphinx-book-theme; extra == 'dev'
Requires-Dist: sphinx-copybutton; extra == 'dev'
Description-Content-Type: text/markdown


Paramax
============
Parameterizations and constraints for JAX PyTrees
-----------------------------------------------------------------------

Paramax lets you apply custom constraints or behaviors to PyTree components using
unwrappable placeholders. This can be used for:
- Enforcing positivity (e.g., scale parameters)
- Structured matrices (triangular, symmetric, etc.)
- Applying tricks like weight normalization
- Marking components as non-trainable

Some benefits of the unwrappable pattern:
- Parameterizations can be computed once for a whole model (e.g. at the top of the
  loss function), rather than every time a parameter is accessed.
- It is flexible, e.g. custom parameterizations can be applied to PyTrees from
  external libraries.
- It is concise.
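Below is a minimal sketch of the unwrappable pattern written in plain JAX. It is for illustration only: `Positive`, `Frozen` and `unwrap_all` are hypothetical names, not the paramax API, and paramax's own placeholders are Equinox modules. Each placeholder stores a raw value in the PyTree. A single unwrap call at the top of the loss replaces every placeholder with its constrained value. As a result, the gradient reaches the unconstrained parameter, while the frozen component receives a zero gradient.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class Positive:
    """Hypothetical placeholder: stores an unconstrained array, unwraps to exp(array)."""

    def __init__(self, unconstrained):
        self.unconstrained = unconstrained

    def unwrap(self):
        return jnp.exp(self.unconstrained)

    # Pytree registration lets jax.grad treat the unconstrained array as a leaf.
    def tree_flatten(self):
        return (self.unconstrained,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@register_pytree_node_class
class Frozen:
    """Hypothetical placeholder: unwraps to its value with gradients blocked."""

    def __init__(self, value):
        self.value = value

    def unwrap(self):
        return jax.lax.stop_gradient(self.value)

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def unwrap_all(tree):
    """Replace each placeholder in `tree` with the result of its `unwrap` method."""
    is_placeholder = lambda x: isinstance(x, (Positive, Frozen))
    return tree_map(
        lambda x: x.unwrap() if is_placeholder(x) else x, tree, is_leaf=is_placeholder
    )


def loss(params, x, y):
    params = unwrap_all(params)  # parameterizations computed once, at the top
    pred = params["scale"] * x + params["shift"]
    return jnp.mean((pred - y) ** 2)


params = {
    "scale": Positive(jnp.log(jnp.array(2.0))),  # scale = 2, kept positive
    "shift": Frozen(jnp.array(1.0)),  # shift = 1, excluded from training
}
x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([1.0, 2.0, 3.0])
grads = jax.grad(loss)(params, x, y)  # same structure as params
```

The gradient for `scale` is taken with respect to the stored log-value, so an optimizer updating `params` can never make the scale negative.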

If you find the package useful, please consider giving it a star on GitHub. If you
create `AbstractUnwrappable` subclasses that may be of interest to others, a pull
request would be much appreciated!

## Documentation

Documentation is available [here](https://danielward27.github.io/paramax/).

## Installation
```bash
pip install paramax
```

## Example
```python
>>> import paramax
>>> import jax.numpy as jnp
>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3)))  # Enforce positivity
>>> paramax.unwrap(("abc", 1, scale))
('abc', 1, Array([1., 1., 1.], dtype=float32))
```

## Alternative parameterization patterns
Using properties to access parameterized model components is common but has drawbacks:
- Parameterizations are tied to the class definition, which limits flexibility; e.g.
  they cannot be applied to PyTrees from external libraries.
- It can become verbose with many parameters.
- It often leads to the parameterization being computed repeatedly.
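For contrast, here is a sketch of the property pattern using a hypothetical `Normal` class (not part of paramax). The `exp` constraint is fixed in the class definition. `scale` is also recomputed on every access, which happens twice per `log_prob` call here.

```python
import jax.numpy as jnp


class Normal:
    """Hypothetical distribution using a property to keep `scale` positive."""

    def __init__(self, loc, log_scale):
        self.loc = loc
        self.log_scale = log_scale  # the unconstrained value actually stored

    @property
    def scale(self):
        # Recomputed each time the attribute is read. Switching to another
        # constraint (e.g. softplus) requires editing or subclassing Normal.
        return jnp.exp(self.log_scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale  # first evaluation of exp
        # second evaluation of exp in jnp.log(self.scale)
        return -0.5 * z**2 - jnp.log(self.scale) - 0.5 * jnp.log(2 * jnp.pi)


dist = Normal(loc=0.0, log_scale=jnp.array(0.0))
```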

## Related
- We use the [Equinox](https://arxiv.org/abs/2111.00254) package to register the
  PyTrees used in this package.
- This package grew out of the need for a simple way to apply parameter constraints
  in the distributions package [flowjax](https://github.com/danielward27/flowjax).
