Metadata-Version: 2.1
Name: wassersteinwormhole
Version: 0.3.7
Summary: Transformer based embeddings for Wasserstein Distances
License: MIT
Author: Doron Haviv
Requires-Python: >=3.9,<4.0
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Dist: clu (>=0.0.12,<0.0.13)
Requires-Dist: flax (>=0.10.6,<0.11.0)
Requires-Dist: ott-jax (>=0.4.9,<0.5.0)
Requires-Dist: scanpy (>=1.11.2,<2.0.0)
Requires-Dist: tqdm (>=4.67.1,<5.0.0)
Description-Content-Type: text/markdown

WassersteinWormhole
======================

Embedding point-clouds by preserving Wasserstein distances with the Wormhole.
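Roughly speaking, "preserving Wasserstein distances" means that the Euclidean distance between two embeddings should approximate the Wasserstein distance between the corresponding point clouds. The sketch below makes the target quantity concrete by computing the exact 2-Wasserstein distance between tiny point clouds. It assumes uniform weights and equal cloud sizes, so an optimal transport plan can be taken to be a permutation and is found by brute force. The helpers `wasserstein2_uniform` and `pairwise_wasserstein` are hypothetical illustrations, not part of the package, which computes its transport costs with OTT-JAX and may use a different ground cost or entropic regularization.

```python
import itertools

import numpy as np


def wasserstein2_uniform(x: np.ndarray, y: np.ndarray) -> float:
    """Exact 2-Wasserstein distance between two uniformly weighted point clouds
    of the same shape (n_points, n_dims).

    With uniform weights and equal sizes, an optimal plan can be taken to be a
    permutation (Birkhoff-von Neumann), so we search over all matchings.
    Brute force is O(n!), so this is only meant for tiny illustrative clouds.
    """
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape (n_points, n_dims)")
    n = x.shape[0]
    # Squared Euclidean ground cost: cost[i, j] = ||x_i - y_j||^2
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    # Each point carries mass 1/n, so a matching's transport cost is the mean cost
    best = min(
        cost[np.arange(n), list(perm)].mean()
        for perm in itertools.permutations(range(n))
    )
    return float(np.sqrt(best))


def pairwise_wasserstein(clouds: list[np.ndarray]) -> np.ndarray:
    """Symmetric matrix of W2 distances between every pair of point clouds."""
    k = len(clouds)
    dist = np.zeros((k, k))
    for i in range(k):
        for j in range(i + 1, k):
            dist[i, j] = dist[j, i] = wasserstein2_uniform(clouds[i], clouds[j])
    return dist


x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
print(wasserstein2_uniform(x, x + np.array([3.0, 4.0])))  # 5.0: a pure translation by (3, 4)
```

A matrix such as `pairwise_wasserstein(clouds)` is the kind of target the Euclidean distances between embeddings are meant to reproduce; for realistic cloud sizes an assignment or Sinkhorn solver replaces the brute-force search.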

This implementation is written in Python 3 and relies on JAX, Flax, and OTT-JAX.


To install JAX with NVIDIA GPU (CUDA 12) support, run:

    pip install --upgrade pip
    pip install -U "jax[cuda12]"

Then install WassersteinWormhole along with the rest of its requirements:

    pip install wassersteinwormhole

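Running the Wormhole on your own set of point clouds is then as simple as:

    from wassersteinwormhole import Wormhole

    # point_clouds: your own collection of point clouds (see the tutorial for the accepted input format)
    WormholeModel = Wormhole(point_clouds=point_clouds)
    WormholeModel.train()
    # Embed the training clouds; masks mark which entries of each (padded) cloud are real points
    Embeddings = WormholeModel.encode(WormholeModel.point_clouds, WormholeModel.masks)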
 
For more details, see the tutorial at [https://wassersteinwormhole.readthedocs.io](https://wassersteinwormhole.readthedocs.io/en/latest/).
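After training, one quick sanity check is to compare the pairwise Euclidean distances between embeddings with reference Wasserstein distances computed for the same clouds. The sketch below assumes the embeddings form an array with one row per point cloud (a JAX array can be converted with `np.asarray`). `distance_preservation` is a hypothetical helper, not part of the package API, and the two summary statistics are only one reasonable choice.

```python
import numpy as np


def distance_preservation(embeddings: np.ndarray, reference: np.ndarray) -> dict:
    """Compare Euclidean distances between embeddings to a reference distance matrix.

    embeddings: (k, d) array, one row per point cloud (assumed layout).
    reference:  (k, k) symmetric matrix of target distances between the clouds.
    """
    embeddings = np.asarray(embeddings, dtype=float)
    reference = np.asarray(reference, dtype=float)
    # Pairwise Euclidean distances between embedding vectors
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    emb_dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Use each unordered pair once (upper triangle, diagonal excluded)
    iu = np.triu_indices(len(embeddings), k=1)
    a, b = emb_dist[iu], reference[iu]
    return {
        # Linear agreement between the two sets of distances
        "pearson_r": float(np.corrcoef(a, b)[0, 1]),
        # Average absolute gap, in the units of the reference distances
        "mean_abs_error": float(np.abs(a - b).mean()),
    }


emb = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
ref = np.array([[0.0, 5.0, 10.0], [5.0, 0.0, 5.0], [10.0, 5.0, 0.0]])
print(distance_preservation(emb, ref))
```

A correlation near 1 with a small mean absolute error suggests the embedding geometry tracks the reference distances; the Pearson correlation alone ignores a global rescaling, which the absolute error does not.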

