Metadata-Version: 2.4
Name: jaxkd-cuda
Version: 0.0.0
Summary: Extension package for jaxkd.
Author-Email: Benjamin Dodge <bendodge@stanford.edu>
License-Expression: MIT
Project-URL: Source, https://github.com/dodgebc/jaxkd-cuda
Requires-Python: >=3.10
Requires-Dist: jax[cuda12]
Provides-Extra: dev
Requires-Dist: jaxkd; extra == "dev"
Requires-Dist: scikit-build-core; extra == "dev"
Requires-Dist: ipywidgets>=8.1.7; extra == "dev"
Requires-Dist: jax[cuda12]==0.6.0; extra == "dev"
Requires-Dist: jupyterlab>=4.4.3; extra == "dev"
Requires-Dist: matplotlib>=3.10.3; extra == "dev"
Requires-Dist: pytest>=8.3.5; extra == "dev"
Requires-Dist: ruff>=0.11.11; extra == "dev"
Requires-Dist: tqdm>=4.67.1; extra == "dev"
Requires-Dist: pip>=25.1.1; extra == "dev"
Description-Content-Type: text/markdown

# jaxkd-cuda

This package contains CUDA extensions for [JAX *k*-D](https://github.com/dodgebc/jaxkd). Building it requires JAX, CMake, and a CUDA compiler (nvcc). It is intended to be installed as an optional dependency of JAX *k*-D and used as an add-on:

`python -m pip install jaxkd[cuda]`

Note that the [cudaKDTree](https://github.com/ingowald/cudaKDTree) library is more powerful and flexible, and can be bound to JAX through JAX's foreign function interface (FFI). See the sample bindings in `jaxkd_cuda/cukd` for a rough example of how to do this. [JaxKDTree](https://github.com/EiffL/JaxKDTree) also has an example, though it no longer works with the current JAX API.

This extension uses a slightly different tree-building method from cudaKDTree so that it exactly matches the behavior of the pure-JAX version. It only permutes an index array, and at each node it splits along the dimension with the widest spread of points rather than the dimension of the largest bounding box. Currently, the performance bottleneck is the reduction operations needed to compute this spread. There is also a substantial memory overhead (a few times the number of points), which can probably be reduced in the future.

The neighbor query algorithm follows [[2](https://arxiv.org/abs/2210.12859)], and neighbor counting is a trivial modification of it.
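The tree construction and queries described above can be illustrated with a minimal pure-Python sketch. It is not the package's CUDA code. It assumes a left-balanced tree stored in level order (children of node `i` at `2i+1` and `2i+2`), a median split at every node, and a per-node split dimension. The names `build_tree`, `traverse`, `knn`, and `count_within` are hypothetical and are not the jaxkd or jaxkd-cuda API. The tree is only a permutation of point indices, each split uses the dimension of widest point spread, and queries use a stack-free traversal in the spirit of [2] that tracks just the previous and current node. Counting reuses the same traversal with a fixed radius.

```python
import heapq
import math
import random

# Minimal sketch, NOT the jaxkd-cuda implementation. Assumed (hypothetical) layout:
# a left-balanced tree in level order; node i has children 2i+1 and 2i+2 and
# stores point perm[i] together with its split dimension split_dims[i].


def subtree_size(node, n):
    """Number of nodes in the subtree rooted at `node` of an n-node level-order tree."""
    size, lo, hi = 0, node, node
    while lo < n:
        size += min(hi, n - 1) - lo + 1
        lo, hi = 2 * lo + 1, 2 * hi + 2
    return size


def build_tree(points):
    """Build the tree by permuting an index array only; the points are never moved."""
    n, dims = len(points), len(points[0])
    perm, split_dims = [0] * n, [0] * n

    def build(node, idx):
        if node >= n:
            return
        # Split along the dimension with the widest spread of this node's points.
        spreads = [max(points[i][d] for i in idx) - min(points[i][d] for i in idx)
                   for d in range(dims)]
        dim = spreads.index(max(spreads))
        idx = sorted(idx, key=lambda i: points[i][dim])
        left = subtree_size(2 * node + 1, n)  # fixed by the left-balanced shape
        perm[node], split_dims[node] = idx[left], dim
        build(2 * node + 1, idx[:left])
        build(2 * node + 2, idx[left + 1:])

    build(0, list(range(n)))
    return perm, split_dims


def traverse(points, perm, split_dims, query, visit, max_d2):
    """Stack-free traversal: only (prev, curr) are tracked, no explicit stack."""
    n = len(perm)
    prev, curr = -1, 0
    while True:
        parent = (curr + 1) // 2 - 1  # -1 for the root
        if curr >= n:  # stepped off the tree: go back to the parent
            prev, curr = curr, parent
            continue
        from_parent = prev < curr
        p = points[perm[curr]]
        if from_parent:  # first arrival at this node: test its point
            visit(perm[curr], sum((a - b) ** 2 for a, b in zip(p, query)))
        dim = split_dims[curr]
        diff = query[dim] - p[dim]
        close = 2 * curr + (2 if diff > 0 else 1)
        far = 2 * curr + (1 if diff > 0 else 2)
        if from_parent:
            nxt = close
        elif prev == close and diff * diff <= max_d2():
            nxt = far  # far side may still hold points inside the search radius
        else:
            nxt = parent
        if nxt < 0:
            return
        prev, curr = curr, nxt


def knn(points, perm, split_dims, query, k):
    """Indices of the k nearest points, nearest first."""
    heap = []  # max-heap via negated squared distances: (-d2, index)

    def visit(i, d2):
        if len(heap) < k:
            heapq.heappush(heap, (-d2, i))
        elif d2 < -heap[0][0]:
            heapq.heapreplace(heap, (-d2, i))

    def max_d2():
        return -heap[0][0] if len(heap) == k else math.inf

    traverse(points, perm, split_dims, query, visit, max_d2)
    return [i for _, i in sorted((-neg, i) for neg, i in heap)]


def count_within(points, perm, split_dims, query, r):
    """Number of points within distance r: the same traversal with a fixed radius."""
    count = 0

    def visit(i, d2):
        nonlocal count
        if d2 <= r * r:
            count += 1

    traverse(points, perm, split_dims, query, visit, lambda: r * r)
    return count


random.seed(0)
pts = [(random.random(), random.random(), random.random()) for _ in range(500)]
perm, split_dims = build_tree(pts)
print(knn(pts, perm, split_dims, (0.5, 0.5, 0.5), k=5))
print(count_within(pts, perm, split_dims, (0.5, 0.5, 0.5), r=0.2))
```

Because the split dimension is chosen per node, it has to be stored alongside the permutation. Computing it requires a min/max reduction over every node's points, which is consistent with the reductions being the bottleneck noted above.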