Metadata-Version: 2.4
Name: torch_export_python
Version: 0.1.0
Summary: Export PyTorch models as readable, editable Python code
Author: Meta Platforms, Inc.
License: BSD 3-Clause License
        
        (c) Meta Platforms, Inc. and affiliates.
        
        Redistribution and use in source and binary forms, with or without modification,
        are permitted provided that the following conditions are met:
        
        1. Redistributions of source code must retain the above copyright notice, this list
        of conditions and the following disclaimer.
        
        2. Redistributions in binary form must reproduce the above copyright notice, this
        list of conditions and the following disclaimer in the documentation
        and/or other materials provided with the distribution.
        
        3. Neither the name of the copyright holder nor the names of its contributors may
        be used to endorse or promote products derived from this software without specific
        prior written permission.
        
        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
        EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
        OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
        SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
        INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
        TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
        BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
        CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
        ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
        DAMAGE.
        
Project-URL: Homepage, https://github.com/meta-pytorch/torch_export_python
Project-URL: Repository, https://github.com/meta-pytorch/torch_export_python
Project-URL: Issues, https://github.com/meta-pytorch/torch_export_python/issues
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: BSD License
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.10.0
Dynamic: license-file

# `torch.export_python`

This is an experimental way of using the PyTorch compiler stack in which the
compiler's output artifacts are entirely readable Python code.  This output
can be checked into a source repository and edited by hand, and we also try
to make it simple to regenerate the code after you upgrade the compiler or
modify the original Python source code.

Why might you want something like this?  Lots of reasons:

- **Precompilation.** You don't want to have to run the compiler every
  time you run your model.  The exported Python code here has no runtime
  compiler dependency.

- **Transparency.** The code generated by the compiler is similar to
  the code you would have written if you had optimized by hand.  So you don't
  have to trust the compiler; you only have to audit, and trust, its output.

- **Portability.** You don't have to regenerate the exported Python code if
  you don't want to; you can upgrade PyTorch separately from upgrading your
  kernels.  The generated code uses only stable PyTorch APIs and is portable
  across PyTorch versions.

As of right now, we intend for export python to cover these layers of the
compiler stack:

- **AOTAutograd.** We can take forward-only PyTorch code and create a custom
  autograd function that binds together a forward and backward implementation.
  This is currently blocked on AOTAutograd codegen of runtime wrappers; see:

  - https://github.com/pytorch/pytorch/pull/176741
  - https://github.com/pytorch/pytorch/pull/179599
  - https://github.com/pytorch/pytorch/pull/179061
  - https://github.com/pytorch/pytorch/pull/178927
  - https://github.com/pytorch/pytorch/pull/178675

- **Inductor.** We can take PyTorch code and turn it into fused Triton kernels.
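
To make the AOTAutograd target concrete, here is a forward/backward pair for the `rms_norm` function from the Usage section, written by hand as a `torch.autograd.Function`. This is a hand-derived illustration of the kind of artifact that layer is meant to produce, not output of this tool; the class name `RMSNormFn` and the `EPS` constant are hypothetical.

```python
import torch

EPS = 1e-6  # same epsilon as the rms_norm example


class RMSNormFn(torch.autograd.Function):
    """Hand-written forward/backward for y = x * rsqrt(mean(x**2) + EPS) * weight."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + EPS)
        ctx.save_for_backward(x, weight, rstd)  # reuse rstd in backward instead of recomputing it
        return x * rstd * weight

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        x, weight, rstd = ctx.saved_tensors
        n = x.shape[-1]
        gw = grad_out * weight
        # Direct term, minus the term from rstd depending on every element of the row.
        grad_x = rstd * gw - x * rstd.pow(3) * (gw * x).sum(-1, keepdim=True) / n
        # weight is broadcast over all leading dims, so reduce them away.
        grad_weight = (grad_out * x * rstd).reshape(-1, n).sum(0)
        return grad_x, grad_weight
```

Calling `RMSNormFn.apply(x, weight)` computes the same values as `rms_norm(x, weight)` but runs the explicit backward. The sketch saves `rstd` rather than recomputing it, trading memory for compute; a compiler may well choose differently.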

In future work, we may also extend support to Dynamo, to handle complicated
input/output conventions and Python side-effect mutation.

## Usage

### Basic export (static shapes)

```python
import torch
from torch_export_python import torch_export_python

def rms_norm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * weight

args = (torch.randn(2, 64, 256, device="cuda"),
        torch.randn(256, device="cuda"))

kernel = torch_export_python(rms_norm, args)
```
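
For reference, `rms_norm` normalizes each vector along the last dimension by its root mean square and then scales it elementwise by `weight`. With $N$ the size of the last dimension (256 in this example) and $\epsilon = 10^{-6}$:

```math
\mathrm{rms\_norm}(x, w)_i = \frac{x_i}{\sqrt{\frac{1}{N}\sum_{j=1}^{N} x_j^2 + \epsilon}}\, w_i
```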

### Dynamic shapes

Use `torch.export.Dim` to mark dimensions that can vary at runtime; `None`
keeps a dimension static.  The keys of `dynamic_shapes` are the function's
parameter names:

```python
from torch.export import Dim

B = Dim("B")
T = Dim("T")
kernel = torch_export_python(rms_norm, args, dynamic_shapes={
    "x": (B, T, None),
    "weight": (None,),
})
```

### Saving and loading

The generated source depends only on `torch`, `triton`, and
`triton.language` at runtime — no `torch._inductor` imports needed.

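```python
# save.py: write the generated code to a module
from pathlib import Path

out = Path("my_package/my_kernel.py")
out.parent.mkdir(parents=True, exist_ok=True)  # create the package directory if needed
out.write_text(kernel.source)  # `kernel` is the ExportedKernel returned by torch_export_python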
```

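```python
# run.py: import and run like any other Python module
import torch
from my_package.my_kernel import call
x = torch.randn(2, 64, 256, device="cuda")  # same shapes as the export example
weight = torch.randn(256, device="cuda")
result = call(x, weight)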
```
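
To try a generated file without setting up a package, it can also be imported straight from a path with the standard library's `importlib`. A minimal sketch: the helper `load_generated_module` and the stand-in source string are hypothetical, and in practice you would pass `kernel.source` instead.

```python
import importlib.util
import sys
import tempfile
from pathlib import Path
from types import ModuleType


def load_generated_module(source: str, name: str) -> ModuleType:
    """Write `source` to a temporary file and import it as module `name`."""
    path = Path(tempfile.mkdtemp()) / f"{name}.py"
    path.write_text(source)
    spec = importlib.util.spec_from_file_location(name, str(path))
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module  # register before executing, as a normal import would
    spec.loader.exec_module(module)
    return module


# Hypothetical stand-in for `kernel.source`; real generated code imports torch and triton.
STAND_IN_SOURCE = "def call(x, weight):\n    return [xi * weight for xi in x]\n"

mod = load_generated_module(STAND_IN_SOURCE, "my_kernel")
result = mod.call([1.0, 2.0], 3.0)
```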

### Running in-memory

For quick iteration without saving to disk:

```python
result = kernel.run(*args)             # convenient: keeps your references live
result = kernel.boxed_run(list(args))  # explicit: input list is consumed
```
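
The two entry points differ in who owns the input list. A minimal sketch of the convention, with hypothetical `boxed_call` and `run` functions standing in for the generated code (not this package's implementation): the boxed callee clears the list it receives, so if the caller holds no other references, an input's memory can be released as soon as the callee is done with it. The unboxed wrapper copies the arguments into a fresh list, which is why your own references stay live.

```python
import torch


def boxed_call(args: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical boxed entry point: takes ownership of the input list."""
    x, weight = args
    args.clear()  # drop the list's references; the locals are now the callee's only handles
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * weight


def run(*args: torch.Tensor) -> torch.Tensor:
    """Unboxed wrapper: copy into a fresh list so the caller's references survive."""
    return boxed_call(list(args))
```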

## API Reference

- `torch_export_python(fn, args, *, dynamic_shapes=None) -> ExportedKernel` — end-to-end pipeline: trace via `torch.export`, compile through Inductor, clean up the output.
- `export_and_codegen(fn, args, *, dynamic_shapes=None) -> str` — stage 1 only: returns raw Inductor source before cleanup.
- `postprocess(raw_src, *, dynamic_shape_names=None, tensor_arg_count=0) -> ExportedKernel` — stage 2 only: clean raw Inductor output and wrap in an `ExportedKernel`.
- `ExportedKernel.source` — the generated Python source code string.
- `ExportedKernel.run(*args)` — execute the kernel with the original function's arguments.
- `ExportedKernel.boxed_run(args_list)` — execute with boxed calling convention (input list is consumed).
- `ExportedKernel.dynamic_shape_names` — list of symbolic dimension names (e.g. `["B", "T"]`).
- `ExportedKernel.tensor_arg_count` — number of tensor arguments expected.

## License

BSD 3-Clause License. See [LICENSE](LICENSE) for details.
