Metadata-Version: 2.4
Name: jax-mcp
Version: 0.1.0
Summary: MCP server for JAX documentation
Requires-Python: >=3.10
Requires-Dist: httpx>=0.27.0
Requires-Dist: mcp>=1.0.0
Requires-Dist: pyyaml>=6.0
Provides-Extra: dev
Requires-Dist: pytest-asyncio>=0.23.0; extra == 'dev'
Requires-Dist: pytest>=7.0; extra == 'dev'
Description-Content-Type: text/markdown

# jax-mcp

MCP (Model Context Protocol) server for [JAX](https://github.com/google/jax) documentation.

Gives LLM clients access to current JAX documentation, plus a checker that flags common mistakes in generated JAX code.

## Installation

```bash
pip install -e .
```

Or run directly:

```bash
python -m jax_mcp
```

## Usage with Claude Code

```bash
# Add as MCP server
claude mcp add -t stdio -s user jax -- python -m jax_mcp

# Or, if the package is installed and the `jax-mcp` script is on your PATH
claude mcp add -t stdio -s user jax -- jax-mcp
```

## Configuration

| Environment Variable | Default | Description |
|---------------------|---------|-------------|
| `JAX_DOCS_PATH` | (none) | Path to local JAX docs directory (offline mode) |
| `JAX_MCP_CACHE_DIR` | `~/.cache/jax-mcp` | Cache directory for online mode |
| `JAX_MCP_CACHE_TTL` | `24` | Cache TTL in hours |
| `JAX_MCP_NO_CACHE` | `0` | Set to `1` to disable caching |
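The lookup rules in the table can be summarized as a small Python sketch. `ServerConfig` and `load_config` are hypothetical names chosen for illustration, not this package's API. The defaults and the `"1"` check mirror the table above.

```python
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Mapping, Optional


@dataclass(frozen=True)
class ServerConfig:
    # Hypothetical container for the documented settings.
    docs_path: Optional[Path]  # set => offline mode
    cache_dir: Path
    cache_ttl_hours: float
    use_cache: bool

    @property
    def offline(self) -> bool:
        return self.docs_path is not None


def load_config(env: Mapping[str, str] = os.environ) -> ServerConfig:
    """Resolve the documented environment variables, falling back to the table defaults."""
    docs = env.get("JAX_DOCS_PATH")
    return ServerConfig(
        docs_path=Path(docs).expanduser() if docs else None,
        cache_dir=Path(env.get("JAX_MCP_CACHE_DIR", "~/.cache/jax-mcp")).expanduser(),
        cache_ttl_hours=float(env.get("JAX_MCP_CACHE_TTL", "24")),
        # Only the value "1" disables caching.
        use_cache=env.get("JAX_MCP_NO_CACHE", "0") != "1",
    )
```

Passing a plain dict, e.g. `load_config({"JAX_MCP_NO_CACHE": "1"})`, makes the behavior easy to check without touching the real environment.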

### Offline Mode

Point `JAX_DOCS_PATH` at the `docs/` directory of a local JAX clone to read documentation without network access:

```bash
export JAX_DOCS_PATH=/path/to/jax/docs
python -m jax_mcp
```

### Online Mode (Default)

When `JAX_DOCS_PATH` is unset, the server fetches docs from GitHub and caches them locally (see `JAX_MCP_CACHE_DIR` and `JAX_MCP_CACHE_TTL`):

```bash
python -m jax_mcp
# Fetches from: raw.githubusercontent.com/google/jax/main/docs/
```
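A minimal sketch of how a TTL-based cache check could work in online mode. It assumes cached files mirror the docs tree under the cache directory and that freshness is judged by file modification time. The helper names and cache layout are hypothetical, not taken from `jax_mcp`; only the base URL comes from the example above.

```python
import time
from pathlib import Path
from typing import Optional

# Base URL from the online-mode example; everything else here is illustrative.
RAW_BASE = "https://raw.githubusercontent.com/google/jax/main/docs/"


def doc_url(relative_path: str) -> str:
    """Build the raw GitHub URL for a file path relative to docs/."""
    return RAW_BASE + relative_path.lstrip("/")


def cache_file(cache_dir: Path, relative_path: str) -> Path:
    # Hypothetical layout: mirror the docs tree under the cache directory.
    return cache_dir / relative_path.lstrip("/")


def is_fresh(path: Path, ttl_hours: float, now: Optional[float] = None) -> bool:
    """True if `path` exists and was written less than `ttl_hours` ago."""
    if not path.exists():
        return False
    now = time.time() if now is None else now
    age_seconds = now - path.stat().st_mtime
    return age_seconds < ttl_hours * 3600
```

A server following this pattern would serve `cache_file(...)` when `is_fresh(...)` is true and otherwise re-download from `doc_url(...)`.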

## Available Tools

| Tool | Description |
|------|-------------|
| `list-sections` | List all documentation sections by category |
| `get-documentation` | Fetch specific documentation content |
| `jax-checker` | Validate JAX code for common gotchas |

### Documentation Categories

- **concepts**: Core JAX concepts (pytrees, transformations, tracing)
- **gotchas**: Common mistakes and how to avoid them
- **transforms**: jit, vmap, grad, scan patterns
- **advanced**: Distributed computing, custom pytrees
- **performance**: GPU tips, profiling, benchmarking
- **api**: Module overviews (jax.numpy, jax.lax, jax.random)
- **examples**: Practical code examples

### JAX Checker

The `jax-checker` tool catches common JAX mistakes:

- In-place array mutations (`array[idx] = value`)
- Side effects in jitted functions (print, globals)
- PRNG key reuse without splitting
- Python control flow in traced code
- Missing `block_until_ready()` when timing code (JAX dispatches work asynchronously)
- Float64 usage without enabling `jax_enable_x64`
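Most of the patterns in this list have standard JAX fixes. The sketch below shows one for each; it illustrates the idioms only and says nothing about how `jax-checker` detects them. It assumes `jax` is installed, which is not one of this package's dependencies.

```python
import time

import jax
import jax.numpy as jnp

# 1. In-place mutation: JAX arrays are immutable; use the functional .at[] API.
x = jnp.zeros(5)
x = x.at[2].set(7.0)  # instead of x[2] = 7.0, which raises TypeError

# 2. PRNG key reuse: split the key so each draw gets a fresh subkey.
key = jax.random.PRNGKey(0)
key, k1, k2 = jax.random.split(key, 3)
a = jax.random.normal(k1, (3,))
b = jax.random.normal(k2, (3,))  # independent of `a`, unlike reusing k1


# 3. Python control flow on traced values: use lax.cond inside jit.
@jax.jit
def abs_via_cond(v):
    # `if v > 0:` would fail here because v is a tracer under jit.
    return jax.lax.cond(v > 0, lambda t: t, lambda t: -t, v)


# 4. Side effects: print() runs only while tracing; jax.debug.print runs on every call.
@jax.jit
def scaled(v):
    jax.debug.print("sum = {}", v.sum())
    return 2.0 * v


# 5. Benchmarking: dispatch is asynchronous, so block before stopping the clock.
scaled(jnp.ones(1000)).block_until_ready()  # warm-up call triggers compilation
start = time.perf_counter()
y = scaled(jnp.ones(1000)).block_until_ready()
elapsed = time.perf_counter() - start

# 6. Float64: arrays default to float32 unless x64 is enabled, typically at startup:
# jax.config.update("jax_enable_x64", True)
print(x, a.dtype, abs_via_cond(-3.0), elapsed)
```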

## Development

```bash
# Install dev dependencies
pip install -e ".[dev]"

# Run tests
pytest
```

## License

MIT
