Metadata-Version: 2.4
Name: phasecurvefit
Version: 0.1.0
Summary: Walk through phase-space observations
Project-URL: Bug Tracker, https://github.com/GalacticDynamics/phasecurvefit/issues
Project-URL: Changelog, https://github.com/GalacticDynamics/phasecurvefit/releases
Project-URL: Homepage, https://github.com/GalacticDynamics/phasecurvefit
Author-email: GalacticDynamics <nstarman@users.noreply.github.com>, Nathaniel Starkman <nstarman@users.noreply.github.com>
License: MIT License
        
        Copyright (c) 2025, Nathaniel Starkman.
        All rights reserved.
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering
Classifier: Typing :: Typed
Requires-Python: >=3.12
Requires-Dist: equinox>=0.11.0
Requires-Dist: galax
Requires-Dist: jax-bounded-while>=0.1
Requires-Dist: jax-tqdm>=0.4.0
Requires-Dist: jax>=0.7.2
Requires-Dist: jaxtyping>=0.3.5
Requires-Dist: optax>=0.2.4
Requires-Dist: optional-dependencies>=0.4.0
Requires-Dist: orbax-checkpoint>=0.11.32
Requires-Dist: plum-dispatch>=2.6.1
Requires-Dist: quax>=0.2.1
Requires-Dist: unxt
Requires-Dist: zeroth>=1.0.1
Provides-Extra: interop
Requires-Dist: unxt>=1.10.0; extra == 'interop'
Provides-Extra: kdtree
Requires-Dist: jaxkd>=0.0.5; extra == 'kdtree'
Description-Content-Type: text/markdown

# phasecurvefit: Construct Paths through Phase-Space Points

[![PyPI version](https://img.shields.io/pypi/v/phasecurvefit.svg)](https://pypi.org/project/phasecurvefit/)
[![Python versions](https://img.shields.io/pypi/pyversions/phasecurvefit.svg)](https://pypi.org/project/phasecurvefit/)

Construct paths through phase-space points, with support for several
different algorithms.

## Installation

Install the core package:

```bash
pip install phasecurvefit
```

Or with uv:

```bash
uv add phasecurvefit
```

<details>
  <summary>from source, using uv</summary>

```bash
uv add git+https://github.com/GalacticDynamics/phasecurvefit.git@main
```

You can customize the branch by replacing `main` with any other branch name.

</details>
<details>
  <summary>building from source</summary>

```bash
cd /path/to/parent
git clone https://github.com/GalacticDynamics/phasecurvefit.git
cd phasecurvefit
uv pip install -e .  # editable mode
```

</details>

### Optional Dependencies

phasecurvefit has optional dependencies for extended functionality:

- **interop** (`unxt`): Physical-units support for phase-space calculations
- **kdtree** (`jaxkd`): Spatial KD-tree neighbor queries for large datasets

Install with optional dependencies:

```bash
pip install "phasecurvefit[interop]"  # Install with unxt for unit support
pip install "phasecurvefit[kdtree]"  # Install with jaxkd for the KD-tree strategy
```

Or with uv:

```bash
uv add phasecurvefit --extra interop
uv add phasecurvefit --extra kdtree
```
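To see which optional extras are usable in the current environment, you can check whether their modules are importable. The standard-library sketch below does that; the `available_extras` helper and the extra-to-module mapping are written for illustration and are not part of `phasecurvefit`.

```python
import importlib.util

# Optional extra -> top-level module it provides (mirrors the install commands above)
EXTRA_MODULES = {"interop": "unxt", "kdtree": "jaxkd"}


def available_extras() -> dict[str, bool]:
    """Report, for each optional extra, whether its module can be imported."""
    # find_spec returns None (rather than raising) when a top-level module is absent
    return {
        extra: importlib.util.find_spec(module) is not None
        for extra, module in EXTRA_MODULES.items()
    }


print(available_extras())
```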

## Quick Start

```python
import jax.numpy as jnp
import phasecurvefit as pcf

# Create phase-space observations as dictionaries (Cartesian coordinates)
pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}

# Order the observations, using a KD-tree to prefilter spatial neighbors
# (requires the kdtree extra: pip install "phasecurvefit[kdtree]")
config = pcf.WalkConfig(strategy=pcf.strats.KDTree(k=2))
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
print(result.indices)  # Array([0, 1, 2], dtype=int32)
```
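`result.indices` is the order in which the walk visits the observations. To make that idea concrete, here is a simplified pure-Python sketch of a greedy nearest-neighbor walk: start at `start_idx`, then repeatedly step to the closest point not yet visited. This is an illustrative assumption, not the `walk_local_flow` implementation; `to_points` and `greedy_walk` are hypothetical helpers, and plain spatial distance (`math.dist`) stands in for the configured metric.

```python
import math


def to_points(components):
    """Re-pack a dict of equal-length component lists into one tuple per point."""
    keys = sorted(components)  # fixed key order so every tuple is aligned
    return list(zip(*(components[k] for k in keys)))


def greedy_walk(points, start_idx=0, distance=math.dist):
    """Visit every point once, always stepping to the nearest unvisited one."""
    order = [start_idx]
    unvisited = set(range(len(points))) - {start_idx}
    while unvisited:
        current = points[order[-1]]
        # min() evaluates the key immediately, so `current` is the latest point
        nxt = min(unvisited, key=lambda j: distance(current, points[j]))
        order.append(nxt)
        unvisited.remove(nxt)
    return order


pos = {"x": [0.0, 1.0, 2.0], "y": [0.0, 0.5, 1.0]}
print(greedy_walk(to_points(pos), start_idx=0))  # [0, 1, 2]
```

Passing a different `distance` callable is the sketch's analogue of swapping the metric in `WalkConfig`.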

### With Physical Units

When `unxt` is installed (the `interop` extra requires `unxt>=1.10.0`), you can
pass quantities with physical units:

```python
import phasecurvefit as pcf
import unxt as u

# Create phase-space observations with units
pos = {"x": u.Q([0.0, 1.0, 2.0], "kpc"), "y": u.Q([0.0, 0.5, 1.0], "kpc")}
vel = {"x": u.Q([1.0, 1.0, 1.0], "km/s"), "y": u.Q([0.5, 0.5, 0.5], "km/s")}

# Units are preserved throughout the calculation
metric_scale = u.Q(1.0, "kpc")
result = pcf.walk_local_flow(
    pos, vel, start_idx=0, metric_scale=metric_scale, usys=u.unitsystems.galactic
)
```
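In the example above, positions are in kpc and velocities in km/s, which differ greatly in scale. If passing `usys=u.unitsystems.galactic` means values are expressed in that system's base units (an assumption here; `unxt`'s galactic system uses kpc, Myr, and solar masses), velocities end up in kpc/Myr. The pure-Python sketch below computes that conversion factor from the IAU definitions of the astronomical unit and parsec; it does not call `unxt` or `phasecurvefit`, and `kms_to_kpc_per_myr` is a hypothetical helper.

```python
import math

AU_M = 149_597_870_700.0           # astronomical unit in metres (exact IAU definition)
PC_M = AU_M * 648_000.0 / math.pi  # parsec in metres (exact IAU definition)
KPC_KM = PC_M                      # 1 kpc = 1e3 pc and 1 km = 1e3 m, so the factors cancel
MYR_S = 1e6 * 365.25 * 86_400.0    # megayear of Julian years, in seconds


def kms_to_kpc_per_myr(v_kms: float) -> float:
    """Convert a speed from km/s to kpc/Myr."""
    return v_kms * MYR_S / KPC_KM


print(kms_to_kpc_per_myr(1.0))    # roughly 1.0227e-3
print(kms_to_kpc_per_myr(220.0))  # roughly 0.225
```

So 1 km/s is only about 0.001 kpc/Myr, a scale worth keeping in mind when choosing `metric_scale`.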

## Features

- **JAX-powered**: Fully compatible with JAX transformations (`jit`, `vmap`,
  `grad`)
- **Cartesian coordinates**: Works in Cartesian coordinate space (x, y, z)
- **Physical units**: Optional support via `unxt` for unit-aware calculations
- **Pluggable metrics**: Customizable distance metrics for different physical
  interpretations
- **Type-safe**: Comprehensive type hints with `jaxtyping`
- **GPU-ready**: Runs on CPU, GPU, or TPU via JAX
- **Spatial KD-tree option**: Use [jaxkd](https://github.com/dodgebc/jaxkd) to
  prefilter neighbors

## Distance Metrics

The algorithm supports pluggable distance metrics that control how points are
ordered. The default, `FullPhaseSpaceDistanceMetric`, uses the Euclidean
distance over positions and velocities combined (6D for 3D data):

```python
import jax.numpy as jnp
import phasecurvefit as pcf
from phasecurvefit.metrics import FullPhaseSpaceDistanceMetric

# Define simple Cartesian arrays (not quantities)
pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}

# Configure with full phase-space metric (the default)
config = pcf.WalkConfig(metric=FullPhaseSpaceDistanceMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
```
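The metric is described above only in words. As an assumption, the sketch below uses the same formula as the custom-metric example later in this README, `sqrt(|dx|^2 + metric_scale^2 * |dv|^2)`, and computes it in pure Python over the same dict-of-components layout. `full_phase_space_distance` is a hypothetical helper, not part of `phasecurvefit`.

```python
import math


def full_phase_space_distance(current_pos, current_vel, positions, velocities, metric_scale):
    """Hypothetical sketch: sqrt(|dx|^2 + metric_scale^2 * |dv|^2) to every point.

    current_pos / current_vel map component names to floats; positions /
    velocities map the same names to equal-length lists of floats.
    """
    n = len(next(iter(positions.values())))
    distances = []
    for i in range(n):
        dpos_sq = sum((positions[k][i] - current_pos[k]) ** 2 for k in positions)
        dvel_sq = sum((velocities[k][i] - current_vel[k]) ** 2 for k in velocities)
        # metric_scale weights the velocity term relative to the position term
        distances.append(math.sqrt(dpos_sq + metric_scale**2 * dvel_sq))
    return distances


pos = {"x": [0.0, 1.0, 2.0], "y": [0.0, 0.5, 1.0]}
vel = {"x": [1.0, 3.0, 1.0], "y": [0.5, 0.5, 0.5]}  # point 1 moves faster in x
start_pos, start_vel = {"x": 0.0, "y": 0.0}, {"x": 1.0, "y": 0.5}
print(full_phase_space_distance(start_pos, start_vel, pos, vel, metric_scale=0.0))
print(full_phase_space_distance(start_pos, start_vel, pos, vel, metric_scale=1.0))
```

With `metric_scale=0.0` only positions matter and point 1 is nearest. With `metric_scale=1.0`, point 1's velocity mismatch makes point 2 (distance sqrt(5), about 2.236) nearer than point 1 (sqrt(5.25), about 2.291).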

### Using Different Metrics

`phasecurvefit` provides three built-in metrics:

1. **FullPhaseSpaceDistanceMetric** (default): True 6D Euclidean distance in
   phase space
2. **AlignedMomentumDistanceMetric**: Combines spatial distance with velocity
   alignment (NN+p metric)
3. **SpatialDistanceMetric**: Pure spatial distance, ignoring velocity

```python
import jax.numpy as jnp
import phasecurvefit as pcf
from phasecurvefit.metrics import FullPhaseSpaceDistanceMetric, SpatialDistanceMetric

# Define simple Cartesian arrays (not quantities)
pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}

# Pure spatial ordering (ignores velocity)
config_spatial = pcf.WalkConfig(metric=SpatialDistanceMetric())
result = pcf.walk_local_flow(
    pos, vel, config=config_spatial, start_idx=0, metric_scale=0.0
)

# Full 6D phase-space distance
config_phase = pcf.WalkConfig(metric=FullPhaseSpaceDistanceMetric())
result = pcf.walk_local_flow(
    pos, vel, config=config_phase, start_idx=0, metric_scale=1.0
)
```
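The README does not give a formula for `AlignedMomentumDistanceMetric`. Purely as an assumption, the sketch below uses one common nearest-neighbor-plus-momentum form: the spatial distance plus a penalty `metric_scale * (1 - cos(theta))`, where `theta` is the angle between the current velocity and the displacement to the candidate. `aligned_momentum_distance` is hypothetical. With `metric_scale=0` it reduces to pure spatial distance.

```python
import math


def aligned_momentum_distance(cur_pos, cur_vel, cand_pos, metric_scale):
    """Hypothetical NN+p distance from one point to one candidate.

    Positions and velocities are tuples of floats. The penalty is 0 when the
    candidate lies along the velocity direction and 2 * metric_scale when it
    lies directly behind.
    """
    disp = [c - p for c, p in zip(cand_pos, cur_pos)]
    dist = math.sqrt(sum(d * d for d in disp))
    speed = math.sqrt(sum(v * v for v in cur_vel))
    if dist == 0.0 or speed == 0.0:
        return dist  # direction undefined; fall back to spatial distance
    cos_theta = sum(d * v for d, v in zip(disp, cur_vel)) / (dist * speed)
    return dist + metric_scale * (1.0 - cos_theta)


# A point moving in +x: a candidate ahead beats an equally distant one behind
ahead = aligned_momentum_distance((0.0, 0.0), (1.0, 0.0), (1.0, 0.0), metric_scale=1.0)
behind = aligned_momentum_distance((0.0, 0.0), (1.0, 0.0), (-1.0, 0.0), metric_scale=1.0)
print(ahead, behind)  # 1.0 3.0
```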

### Custom Metrics

You can define custom metrics by subclassing `AbstractDistanceMetric` and
implementing `__call__`:

```python
import jax
import jax.numpy as jnp
import phasecurvefit as pcf
from phasecurvefit.metrics import AbstractDistanceMetric


class WeightedPhaseSpaceMetric(AbstractDistanceMetric):
    """Custom weighted phase-space metric."""

    def __call__(self, current_pos, current_vel, positions, velocities, metric_scale):
        # Compute position distance
        pos_diff = jax.tree.map(jnp.subtract, positions, current_pos)
        pos_dist_sq = sum(jax.tree.leaves(jax.tree.map(jnp.square, pos_diff)))

        # Compute velocity distance
        vel_diff = jax.tree.map(jnp.subtract, velocities, current_vel)
        vel_dist_sq = sum(jax.tree.leaves(jax.tree.map(jnp.square, vel_diff)))

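        # Weight the velocity term by metric_scale**2 relative to the position term
        return jnp.sqrt(pos_dist_sq + (metric_scale**2) * vel_dist_sq)


# Example data in the same dict-of-arrays layout as above
pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}

# Use the custom metric via WalkConfig
config = pcf.WalkConfig(metric=WeightedPhaseSpaceMetric())
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)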
```

See the
[Metrics Guide](https://phasecurvefit.readthedocs.io/en/latest/guides/metrics.html)
for more details and examples.

## KD-tree Strategy

For large datasets, you can enable spatial KD-tree prefiltering to accelerate
neighbor selection:

```python
# Install the optional dependency first:
# pip install "phasecurvefit[kdtree]"

import jax.numpy as jnp
import phasecurvefit as pcf

# Define simple Cartesian arrays (not quantities)
pos = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([0.0, 0.5, 1.0])}
vel = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([0.5, 0.5, 0.5])}

# Use KD-tree strategy and query 2 spatial neighbors per step
config = pcf.WalkConfig(strategy=pcf.strats.KDTree(k=2))
result = pcf.walk_local_flow(pos, vel, config=config, start_idx=0, metric_scale=1.0)
```
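The README does not detail how the KD-tree prefilter interacts with the distance metric. One plausible reading, sketched below as an assumption, is a two-stage step: find the `k` spatially nearest unvisited points, then choose among only those candidates using a phase-space distance. A brute-force `heapq.nsmallest` stands in for the KD-tree query, and `prefiltered_step` is a hypothetical helper.

```python
import heapq
import math


def prefiltered_step(points, velocities, current, unvisited, k, metric_scale):
    """One hypothetical walk step with spatial prefiltering.

    1. Keep only the k spatially nearest unvisited points (a brute-force
       stand-in for a KD-tree query).
    2. Among those candidates, return the index with the smallest
       phase-space distance sqrt(|dx|^2 + metric_scale^2 * |dv|^2).
    """
    cur_p, cur_v = points[current], velocities[current]
    candidates = heapq.nsmallest(k, unvisited, key=lambda j: math.dist(cur_p, points[j]))

    def phase_distance(j):
        dx = math.dist(cur_p, points[j])
        dv = math.dist(cur_v, velocities[j])
        return math.sqrt(dx**2 + metric_scale**2 * dv**2)

    return min(candidates, key=phase_distance)


# Point 1 is spatially closest to point 0 but moves the opposite way
points = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
velocities = [(1.0, 0.0), (-5.0, 0.0), (1.0, 0.0)]
print(prefiltered_step(points, velocities, 0, {1, 2}, k=1, metric_scale=1.0))  # 1
print(prefiltered_step(points, velocities, 0, {1, 2}, k=2, metric_scale=1.0))  # 2
```

A small `k` makes each step cheaper but can exclude the best phase-space match: with `k=1` the step is forced to point 1 even though point 2 matches far better in velocity.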
