Metadata-Version: 2.4
Name: flash_ansr
Version: 0.4.5
Summary: Flash-ANSR: Fast Amortized Neural Symbolic Regression - Discover symbolic expressions from tabular data using SetTransformer and Transformer architectures
Author: Paul Saegert
License-Expression: MIT
Project-URL: Homepage, https://github.com/psaegert/flash-ansr
Project-URL: Documentation, https://flash-ansr.readthedocs.io/en/latest/
Project-URL: Repository, https://github.com/psaegert/flash-ansr
Project-URL: Issues, https://github.com/psaegert/flash-ansr/issues
Project-URL: PyPI, https://pypi.org/project/flash-ansr/
Project-URL: Bug Reports, https://github.com/psaegert/flash-ansr/issues
Project-URL: Demo Notebook, https://github.com/psaegert/flash-ansr/blob/main/experimental/demo.ipynb
Keywords: symbolic-regression,neural-symbolic-regression,simulation-based-inference,transformer,set-transformer,machine-learning,pytorch,expression-discovery,equation-discovery,tabular-data,mathematical-modeling
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.12
Classifier: Operating System :: OS Independent
Requires-Python: >=3.12
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: absl-py
Requires-Dist: datasets
Requires-Dist: drawdata
Requires-Dist: editdistance
Requires-Dist: einops
Requires-Dist: matplotlib
Requires-Dist: numpy
Requires-Dist: pyyaml
Requires-Dist: schedulefree
Requires-Dist: scikit-learn
Requires-Dist: scipy
Requires-Dist: simplipy
Requires-Dist: torch
Requires-Dist: torch_optimizer
Requires-Dist: tqdm
Requires-Dist: wandb
Requires-Dist: zss
Provides-Extra: dev
Requires-Dist: pre-commit; extra == "dev"
Requires-Dist: pytest; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Requires-Dist: mypy; extra == "dev"
Requires-Dist: flake8; extra == "dev"
Requires-Dist: pygount; extra == "dev"
Requires-Dist: pylint; extra == "dev"
Requires-Dist: types-setuptools; extra == "dev"
Requires-Dist: types-tqdm; extra == "dev"
Requires-Dist: types-toml; extra == "dev"
Requires-Dist: types-PyYAML; extra == "dev"
Requires-Dist: radon; extra == "dev"
Requires-Dist: mkdocs; extra == "dev"
Requires-Dist: mkdocs-material; extra == "dev"
Requires-Dist: mkdocs-autorefs; extra == "dev"
Requires-Dist: mkdocstrings; extra == "dev"
Requires-Dist: mkdocs-get-deps; extra == "dev"
Requires-Dist: mkdocs-material-extensions; extra == "dev"
Requires-Dist: mkdocstrings-python; extra == "dev"
Dynamic: license-file

<h1 align="center" style="margin-top: 0px;">⚡Flash-ANSR:<br>Fast Amortized Neural Symbolic Regression</h1>

<div align="center">

[![PyPI version](https://img.shields.io/pypi/v/flash-ansr.svg)](https://pypi.org/project/flash-ansr/)
[![PyPI license](https://img.shields.io/pypi/l/flash-ansr.svg)](https://pypi.org/project/flash-ansr/)
[![Documentation Status](https://readthedocs.org/projects/flash-ansr/badge/?version=latest)](https://flash-ansr.readthedocs.io/en/latest/?badge=latest)

</div>

<div align="center">

[![pytest](https://github.com/psaegert/flash-ansr/actions/workflows/pytest.yml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/pytest.yml)
[![quality checks](https://github.com/psaegert/flash-ansr/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/pre-commit.yml)
[![CodeQL Advanced](https://github.com/psaegert/flash-ansr/actions/workflows/codeql.yaml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/codeql.yaml)

</div>

# Papers
- Work in progress


# Usage

```sh
pip install flash-ansr
```

```python
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import flash_ansr
from flash_ansr import (
  FlashANSR,
  SoftmaxSamplingConfig,
  install_model,
  get_path,
)

# Select a model from Hugging Face
# https://huggingface.co/models?search=flash-ansr-v23.0
MODEL = "psaegert/flash-ansr-v23.0-120M"

# Download the latest snapshot of the model
# By default, the model is downloaded to the directory `./models/` in the package root
install_model(MODEL)

# Load the model
model = FlashANSR.load(
  directory=get_path('models', MODEL),
  generation_config=SoftmaxSamplingConfig(choices=32),  # or BeamSearchConfig / MCTSGenerationConfig
  n_restarts=8,
).to(device)

# Define data: X of shape (n_samples, n_features), y of shape (n_samples,)
X = ...
y = ...

# Fit the model to the data
model.fit(X, y, verbose=True)

# Show the best expression
print(model.get_expression())

# Predict with the best expression
y_pred = model.predict(X)
```
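
The elided `X` and `y` above stand for your tabular data. As a minimal illustration (the target function below is hypothetical, and the shapes assume the usual samples-by-features layout), synthetic data can be generated with NumPy:

```python
import numpy as np

# Hypothetical ground-truth expression: y = sin(x0) + 0.5 * x1**2
rng = np.random.default_rng(0)
X = rng.uniform(-3.0, 3.0, size=(512, 2))   # 512 support points, 2 variables
y = np.sin(X[:, 0]) + 0.5 * X[:, 1] ** 2    # noiseless targets

print(X.shape, y.shape)  # (512, 2) (512,)
```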

Explore more in the [Demo Notebook](https://github.com/psaegert/flash-ansr/blob/main/experimental/demo.ipynb).

# Overview

### Training

<img src="https://raw.githubusercontent.com/psaegert/flash-ansr/refs/heads/main/assets/images/flash-ansr-training.png" width="500">

> **⚡ANSR Training on Fully Procedurally Generated Data.** Inspired by NeSymReS ([Biggio et al. 2021](https://arxiv.org/abs/2106.06427)).
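
The idea behind procedural training data is to sample random symbolic expressions on the fly rather than relying on a fixed corpus. A simplified sketch of how such a sampler might work (this is illustrative only, not Flash-ANSR's actual generator; the operator sets and probabilities are made up):

```python
import random

# Token vocabulary for the illustrative sampler (hypothetical)
UNARY = ["sin", "cos", "exp"]
BINARY = ["+", "-", "*"]
LEAVES = ["x0", "x1", "C"]  # variables and a constant placeholder

def sample_prefix(depth: int, rng: random.Random) -> list[str]:
    """Recursively sample a random expression in prefix (Polish) notation.
    Depth-limited so recursion always terminates."""
    if depth == 0 or rng.random() < 0.3:
        return [rng.choice(LEAVES)]
    if rng.random() < 0.5:
        return [rng.choice(UNARY)] + sample_prefix(depth - 1, rng)
    return ([rng.choice(BINARY)]
            + sample_prefix(depth - 1, rng)
            + sample_prefix(depth - 1, rng))

rng = random.Random(42)
expr = sample_prefix(depth=3, rng=rng)
print(expr)  # e.g. a prefix token sequence like ['+', 'sin', 'x0', 'C']
```

Each sampled expression can then be evaluated on random support points to produce an `(X, y)` training pair, so the model never sees the same dataset twice.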

### Architecture

<img src="https://raw.githubusercontent.com/psaegert/flash-ansr/refs/heads/main/assets/images/flash-ansr.png">

> **FlashANSR Architecture.** The model combines an upgraded Set Transformer encoder ([Lee et al. 2019](https://arxiv.org/abs/1810.00825)) with a Pre-Norm Transformer decoder ([Vaswani et al. 2017](https://arxiv.org/abs/1706.03762), [Xiong et al. 2020](https://arxiv.org/abs/2002.04745)) that acts as a generative model over symbolic expressions.
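
A key property of the Set Transformer encoder is permutation invariance: reordering the support points must not change the dataset embedding. A minimal NumPy sketch of attention-based set pooling in the spirit of Lee et al.'s PMA block (not the actual Flash-ANSR implementation; the random "seed" stands in for a learned query):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # stabilized softmax
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_pool(X_set, seed):
    """Pool a set of d-dimensional points with a single query vector.
    Softmax attention over set elements is order-independent."""
    d = X_set.shape[-1]
    scores = softmax(seed @ X_set.T / np.sqrt(d))  # (1, n) attention weights
    return scores @ X_set                          # (1, d) pooled embedding

rng = np.random.default_rng(0)
points = rng.normal(size=(16, 8))  # a set of 16 points in R^8
seed = rng.normal(size=(1, 8))     # stands in for a learned seed query

pooled = attention_pool(points, seed)
shuffled = attention_pool(points[rng.permutation(16)], seed)
print(np.allclose(pooled, shuffled))  # True: the set order does not matter
```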

### Results
Coming soon
<!-- <img src="https://raw.githubusercontent.com/psaegert/flash-ansr/refs/heads/main/assets/images/test_time_compute_fastsrb.svg">

> **Test Time Compute scaling.** ⚡ANSR, NeSymReS ([Biggio et al. 2021](https://arxiv.org/abs/2106.06427)), PySR ([Cranmer 2023](https://arxiv.org/abs/2305.01582)), and E2E ([Kamienny et al. 2022](https://arxiv.org/abs/2204.10532)) are evaluated on the FastSRB benchmark with 10 datasets per equation, $n_{support}=512$, noise level 0.0.\
> AMD 9950X (16C32T), RTX 4090 (24GB). -->



# Citation
```bibtex
@mastersthesis{flash-ansr2024-thesis,
  author  = {Paul Saegert},
  title   = {Flash Amortized Neural Symbolic Regression},
  school  = {Heidelberg University},
  year    = {2025},
  url     = {https://github.com/psaegert/flash-ansr-thesis}
}
@software{flash-ansr2024,
  author  = {Paul Saegert},
  title   = {Flash Amortized Neural Symbolic Regression},
  year    = {2024},
  publisher   = {GitHub},
  version = {0.4.5},
  url     = {https://github.com/psaegert/flash-ansr}
}
```
