LICENSE
README.md
pyproject.toml
tokamax/__init__.py
tokamax/autotuning.py
tokamax/benchmarking.py
tokamax/config.py
tokamax/tokamax_test.py
tokamax.egg-info/PKG-INFO
tokamax.egg-info/SOURCES.txt
tokamax.egg-info/dependency_links.txt
tokamax.egg-info/requires.txt
tokamax.egg-info/top_level.txt
tokamax/_src/__init__.py
tokamax/_src/ad.py
tokamax/_src/ad_test.py
tokamax/_src/batching.py
tokamax/_src/batching_test.py
tokamax/_src/benchmarking.py
tokamax/_src/benchmarking_test.py
tokamax/_src/config.py
tokamax/_src/config_test.py
tokamax/_src/conftest.py
tokamax/_src/gpu_utils.py
tokamax/_src/gpu_utils_test.py
tokamax/_src/hlo_utils.py
tokamax/_src/hlo_utils_test.py
tokamax/_src/jaxtyping.py
tokamax/_src/jaxtyping_test.py
tokamax/_src/mosaic_tpu.py
tokamax/_src/numerics.py
tokamax/_src/numerics_test.py
tokamax/_src/precision.py
tokamax/_src/precision_test.py
tokamax/_src/pydantic.py
tokamax/_src/pydantic_test.py
tokamax/_src/quantization.py
tokamax/_src/shape.py
tokamax/_src/shape_test.py
tokamax/_src/test_utils.py
tokamax/_src/test_utils_test.py
tokamax/_src/utils.py
tokamax/_src/autotuning/__init__.py
tokamax/_src/autotuning/api.py
tokamax/_src/autotuning/api_test.py
tokamax/_src/autotuning/arg_spec.py
tokamax/_src/autotuning/autotuner.py
tokamax/_src/autotuning/cache.py
tokamax/_src/autotuning/cache_test.py
tokamax/_src/ops/__init__.py
tokamax/_src/ops/op.py
tokamax/_src/ops/op_test.py
tokamax/_src/ops/attention/__init__.py
tokamax/_src/ops/attention/api.py
tokamax/_src/ops/attention/api_sharding_test.py
tokamax/_src/ops/attention/api_test.py
tokamax/_src/ops/attention/arg_specs.py
tokamax/_src/ops/attention/base.py
tokamax/_src/ops/attention/base_test.py
tokamax/_src/ops/attention/bench.py
tokamax/_src/ops/attention/jax_nn.py
tokamax/_src/ops/attention/jax_nn_test.py
tokamax/_src/ops/attention/pallas_mosaic_gpu.py
tokamax/_src/ops/attention/pallas_mosaic_gpu_common.py
tokamax/_src/ops/attention/pallas_mosaic_gpu_kernel_sm90.py
tokamax/_src/ops/attention/pallas_mosaic_gpu_test.py
tokamax/_src/ops/attention/pallas_mosaic_gpu_vjp.py
tokamax/_src/ops/attention/pallas_mosaic_gpu_vjp_common.py
tokamax/_src/ops/attention/pallas_mosaic_gpu_vjp_kernel_sm90.py
tokamax/_src/ops/attention/pallas_mosaic_tpu.py
tokamax/_src/ops/attention/pallas_mosaic_tpu_test.py
tokamax/_src/ops/attention/pallas_triton.py
tokamax/_src/ops/attention/pallas_triton_test.py
tokamax/_src/ops/attention/pallas_triton_vjp.py
tokamax/_src/ops/attention/test_base.py
tokamax/_src/ops/attention/xla_chunked.py
tokamax/_src/ops/attention/xla_chunked_test.py
tokamax/_src/ops/experimental/tpu/splash_attention/base.py
tokamax/_src/ops/experimental/tpu/splash_attention/ring_attention_kernel.py
tokamax/_src/ops/experimental/tpu/splash_attention/ring_attention_kernel_test.py
tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel.py
tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel_sharded_test.py
tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel_test.py
tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_mask.py
tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_mask_info.py
tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_mask_test.py
tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_test_utils.py
tokamax/_src/ops/flex_attention/__init__.py
tokamax/_src/ops/flex_attention/api.py
tokamax/_src/ops/flex_attention/arg_specs.py
tokamax/_src/ops/flex_attention/base.py
tokamax/_src/ops/flex_attention/base_test.py
tokamax/_src/ops/flex_attention/pallas_triton.py
tokamax/_src/ops/flex_attention/pallas_triton_test.py
tokamax/_src/ops/flex_attention/test_base.py
tokamax/_src/ops/flex_attention/wrapper.py
tokamax/_src/ops/flex_attention/wrapper_test_base.py
tokamax/_src/ops/gated_linear_unit/__init__.py
tokamax/_src/ops/gated_linear_unit/api.py
tokamax/_src/ops/gated_linear_unit/api_test.py
tokamax/_src/ops/gated_linear_unit/arg_specs.py
tokamax/_src/ops/gated_linear_unit/base.py
tokamax/_src/ops/gated_linear_unit/base_test.py
tokamax/_src/ops/gated_linear_unit/pallas_triton.py
tokamax/_src/ops/gated_linear_unit/pallas_triton_test.py
tokamax/_src/ops/gated_linear_unit/test_base.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/__init__.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/api.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/api_test.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/arg_specs.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/base.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/base_test.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_tpu.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_tpu_test.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/reference.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/reference_test.py
tokamax/_src/ops/linear_softmax_cross_entropy_loss/test_utils.py
tokamax/_src/ops/normalization/__init__.py
tokamax/_src/ops/normalization/api.py
tokamax/_src/ops/normalization/api_test.py
tokamax/_src/ops/normalization/arg_specs.py
tokamax/_src/ops/normalization/base.py
tokamax/_src/ops/normalization/base_test.py
tokamax/_src/ops/normalization/bench.py
tokamax/_src/ops/normalization/pallas_triton.py
tokamax/_src/ops/normalization/pallas_triton_config.py
tokamax/_src/ops/normalization/pallas_triton_test.py
tokamax/_src/ops/normalization/pallas_triton_vjp.py
tokamax/_src/ops/normalization/pallas_triton_vjp_config.py
tokamax/_src/ops/normalization/test_base.py
tokamax/_src/ops/ragged_dot/__init__.py
tokamax/_src/ops/ragged_dot/api.py
tokamax/_src/ops/ragged_dot/api_test.py
tokamax/_src/ops/ragged_dot/arg_specs.py
tokamax/_src/ops/ragged_dot/base.py
tokamax/_src/ops/ragged_dot/base_test.py
tokamax/_src/ops/ragged_dot/bench.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_gpu.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_gpu_common.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_gpu_kernel_sm100.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_gpu_kernel_sm100_quant.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_gpu_kernel_sm100_quant_post_scale.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_gpu_kernel_sm90.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_gpu_kernel_sm90_quant.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_gpu_kernel_sm90_quant_ws.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_gpu_kernel_sm90_quant_ws_async_store.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_gpu_test.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_kernel.py
tokamax/_src/ops/ragged_dot/pallas_mosaic_tpu_test.py
tokamax/_src/ops/ragged_dot/pallas_triton.py
tokamax/_src/ops/ragged_dot/pallas_triton_test.py
tokamax/_src/ops/ragged_dot/test_base.py
tokamax/_src/ops/triangle_multiplication/__init__.py
tokamax/_src/ops/triangle_multiplication/api.py
tokamax/_src/ops/triangle_multiplication/api_test.py
tokamax/_src/ops/triangle_multiplication/arg_specs.py
tokamax/_src/ops/triangle_multiplication/base.py
tokamax/_src/ops/triangle_multiplication/base_test.py
tokamax/_src/pallas/__init__.py
tokamax/_src/pallas/block.py
tokamax/_src/pallas/block_test.py
tokamax/_src/pallas/grid.py
tokamax/benchmarks/attention.py
tokamax/benchmarks/triangle_multiplication.py
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/dot_product_attention.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/dot_product_attention_vjp.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/flex_attention.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/gated_linear_unit.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/jax_nn_dot_product_attention.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/linear_softmax_cross_entropy_loss.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/linear_softmax_cross_entropy_loss_vjp.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/normalization.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/normalization_vjp.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/pallas_mosaic_gpu_flash_attention.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/pallas_mosaic_gpu_flash_attention_vjp.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/pallas_mosaic_gpu_ragged_dot.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/pallas_triton_flash_attention.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/pallas_triton_flash_attention_vjp.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/pallas_triton_gated_linear_unit.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/pallas_triton_normalization.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/pallas_triton_normalization_vjp.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/pallas_triton_ragged_dot.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/ragged_dot.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/triangle_multiplication.json
tokamax/data/autotuning/nvidia_h100_80gb_hbm3/xla_chunked_dot_product_attention.json
tokamax/data/autotuning/tpu7x/dot_product_attention.json
tokamax/data/autotuning/tpu7x/dot_product_attention_vjp.json
tokamax/data/autotuning/tpu7x/flex_attention.json
tokamax/data/autotuning/tpu7x/gated_linear_unit.json
tokamax/data/autotuning/tpu7x/jax_nn_dot_product_attention.json
tokamax/data/autotuning/tpu7x/linear_softmax_cross_entropy_loss.json
tokamax/data/autotuning/tpu7x/linear_softmax_cross_entropy_loss_vjp.json
tokamax/data/autotuning/tpu7x/normalization.json
tokamax/data/autotuning/tpu7x/normalization_vjp.json
tokamax/data/autotuning/tpu7x/pallas_mosaic_tpu_linear_softmax_cross_entropy_loss.json
tokamax/data/autotuning/tpu7x/pallas_mosaic_tpu_linear_softmax_cross_entropy_loss_vjp.json
tokamax/data/autotuning/tpu7x/pallas_mosaic_tpu_ragged_dot.json
tokamax/data/autotuning/tpu7x/ragged_dot.json
tokamax/data/autotuning/tpu7x/triangle_multiplication.json
tokamax/data/autotuning/tpu7x/xla_chunked_dot_product_attention.json
tokamax/experimental/utils/tuning/tpu/ragged_dot_benchmarking.py
tokamax/experimental/utils/tuning/tpu/splash_attention_benchmarking.py