The Problem with PyTorch Inference
spaCy is the de facto standard for production NLP pipelines. Models like en_core_web_trf use transformer architectures (RoBERTa, BERT) as feature extractors: on every batch, the transformer does a forward pass and downstream components — NER, POS tagger, lemmatizer — read the resulting hidden states from Doc._.trf_data.
Here's the thing: PyTorch was built for training. Its dynamic computation graph, per-operator dispatch, Python-level overhead, and autograd bookkeeping are exactly what you need to iterate on a research model. For inference, they're dead weight. At production scale, this matters.
The real bottleneck isn't the matrix multiplications — it's everything around them. CUDA kernel launch overhead, unnecessary memory allocations, gradient tracking that nobody asked for. Specialized inference engines like TensorRT and ONNX Runtime eliminate this by compiling the static graph once and running it as a series of fused kernels.
How spaCy Works Under the Hood
Pipeline Architecture
When you call nlp(text), the text flows through a sequence of components. The transformer sits near the front and is the most compute-intensive step:
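You can inspect the order directly in the loaded pipeline (the component list below matches the full-pipeline benchmark setup later in this post):

import spacy

nlp = spacy.load("en_core_web_trf")
print(nlp.pipe_names)
# ['transformer', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer', 'ner']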
The thinc Object Hierarchy
spaCy uses thinc as an abstraction layer over ML frameworks. This creates a nested object hierarchy — and understanding it is the key to the whole approach. Here's what lives inside a transformer component:
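Schematically, it looks like this (wrapper names simplified; the exact classes vary across spacy-transformers versions):

Transformer                ← spaCy pipeline component
└── Model                  ← thinc model
    └── PyTorchWrapper     ← thinc's framework adapter
        └── PyTorchShim    ← the bridge object
            └── shim._model ← the actual nn.Module (RoBERTa / BERT)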
PyTorchShim is the bridge that takes thinc's ragged arrays and calls nn.Module.forward(). The actual model — RoBERTa, BERT, XLM-R — lives in shim._model. That's the single object we replace.
Everything above it in the hierarchy — the thinc model, the spaCy component, all the downstream NER/POS logic — stays completely untouched.
The Approach
The idea is simple in principle: export the model to a static ONNX graph once, run it through TensorRT or ONNX Runtime, and give PyTorchShim a proxy object that looks like an nn.Module but runs on the optimized engine.
From the user's side, it's a one-liner:
import spacy
import spacy_accelerate
nlp = spacy.load("en_core_web_trf")
nlp = spacy_accelerate.optimize(nlp, provider="tensorrt", precision="fp16")
# Works exactly the same as before
doc = nlp("Apple Inc. was founded by Steve Jobs.")
print([(ent.text, ent.label_) for ent in doc.ents])
# [('Apple Inc.', 'ORG'), ('Steve Jobs', 'PERSON')]
The optimize() Pipeline: 9 Steps
1. Validate pipeline: check that the spaCy pipeline has a transformer component. Raise early if not.
2. Find transformer shim: recursively walk the thinc model hierarchy to locate the PyTorchShim (see the sketch after this list).
3. Detect architecture: read state_dict key patterns to identify the architecture (RoBERTa, BERT, XLM-R, DistilBERT) and extract num_layers and hidden_size.
4. Build OptimizeConfig: assemble the configuration object from user parameters and detected architecture info.
5. Check provider availability: verify that TensorrtExecutionProvider (or CUDAExecutionProvider) is registered in the installed onnxruntime-gpu build.
6. Get or export ONNX model: cache hit → load the file. Cache miss → convert weights, export to ONNX, optionally convert to FP16.
7. Create proxy: instantiate IOBindingProxy, ORTProxy, or CPUProxy depending on provider and config.
8. Patch pipeline: save the original model (shim._original_model = shim._model) and swap in the proxy. One assignment, reversible.
9. Warmup inference: run a dummy sentence through the full pipeline to trigger TRT kernel compilation and JIT warm-up. All subsequent calls run at full speed.
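Step 2 deserves a closer look, since everything else hangs off the shim. Here is a minimal sketch of the search using thinc's public Model API (model.shims and model.layers); the helper name is illustrative, not the library's actual code:

from thinc.api import Model
from thinc.shims import PyTorchShim

def find_pytorch_shim(model: Model):
    """Depth-first search for the first PyTorchShim in a thinc model tree."""
    for shim in model.shims:
        if isinstance(shim, PyTorchShim):
            return shim
    for child in model.layers:
        found = find_pytorch_shim(child)
        if found is not None:
            return found
    return None

shim = find_pytorch_shim(nlp.get_pipe("transformer").model)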
Weight Conversion: the Tricky Part
spaCy ships with curated-transformers as its native transformer library, which uses a different weight format than HuggingFace. To export to ONNX, we need a HuggingFace model — so we have to remap the weights.
The Q/K/V Split
The trickiest piece is multi-head attention. In curated-transformers, the Q, K, and V projection matrices are fused into a single weight tensor. HuggingFace keeps them separate. We use torch.chunk to split them:
# Split the fused QKV weight into three separate matrices
mha_in_weight = state_dict[f"layers.{i}.mha.input.weight"]  # (3H, H)
q_w, k_w, v_w = torch.chunk(mha_in_weight, 3, dim=0)
mha_in_bias = state_dict[f"layers.{i}.mha.input.bias"]  # (3H,)
q_b, k_b, v_b = torch.chunk(mha_in_bias, 3, dim=0)
Full Renaming Map
| curated-transformers key | HuggingFace RoBERTa key |
|---|---|
| Embeddings | |
| embeddings.inner.word_embeddings.weight | embeddings.word_embeddings.weight |
| embeddings.inner.position_embeddings.weight | embeddings.position_embeddings.weight |
| embeddings.inner.layer_norm.{weight,bias} | embeddings.LayerNorm.{weight,bias} |
| Per-layer attention (layer i) | |
| layers.{i}.mha.input.weight [chunk 0] | layer.{i}.attention.self.query.weight |
| layers.{i}.mha.input.weight [chunk 1] | layer.{i}.attention.self.key.weight |
| layers.{i}.mha.input.weight [chunk 2] | layer.{i}.attention.self.value.weight |
| layers.{i}.mha.output.weight | layer.{i}.attention.output.dense.weight |
| layers.{i}.attn_output_layernorm.weight | layer.{i}.attention.output.LayerNorm.weight |
| Per-layer FFN | |
| layers.{i}.ffn.intermediate.{weight,bias} | layer.{i}.intermediate.dense.{weight,bias} |
| layers.{i}.ffn.output.{weight,bias} | layer.{i}.output.dense.{weight,bias} |
| layers.{i}.ffn_output_layernorm.weight | layer.{i}.output.LayerNorm.weight |
After remapping, we create a HuggingFace model (RobertaModel, BertModel, etc.), load the state dict with strict=False, and verify parity: run a dummy input through both models and assert max_abs_diff < 1e-4. If that check fails, something went wrong in the mapping.
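A minimal sketch of that parity check, assuming curated_forward wraps the original model and hf_model is the remapped HuggingFace model with the usual (input_ids, attention_mask) signature (both names are illustrative):

import torch

input_ids = torch.randint(0, 50265, (2, 16))  # dummy RoBERTa-range token ids
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    expected = curated_forward(input_ids, attention_mask)           # original weights
    actual = hf_model(input_ids, attention_mask).last_hidden_state  # remapped weights

max_abs_diff = (expected - actual).abs().max().item()
assert max_abs_diff < 1e-4, f"weight remapping drifted: {max_abs_diff:.2e}"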
ONNX Export and FP16 Conversion
Export
import torch

# Dummy inputs: batch size 2 on purpose (see the note below)
dummy_input_ids = torch.ones(2, 16, dtype=torch.int64)
dummy_mask = torch.ones(2, 16, dtype=torch.int64)

torch.onnx.export(
    ONNXWrapper(hf_model),  # wrapper: returns only last_hidden_state
    (dummy_input_ids, dummy_mask),
    onnx_path,
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "seq"},
        "attention_mask": {0: "batch", 1: "seq"},
        "last_hidden_state": {0: "batch", 1: "seq"},
    },
    opset_version=17,
)
Note that the dummy inputs use batch_size=2, not 1. If you export with a batch of 1, the ONNX optimizer may constant-fold the batch dimension to a static value of 1 during export. With batch_size=2 it can't make that inference, so the axis stays dynamic.
FP16 Conversion
We use the ONNX Runtime transformer optimizer, which applies graph-level fusions before converting to FP16:
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions

model = optimizer.optimize_model(
    onnx_path,
    model_type="bert",
    optimization_options=FusionOptions(...),
)
model.convert_float_to_float16(keep_io_types=True)
Two flags matter here:
- keep_io_types=True: inputs (int64 token ids) and the final output (float32) stay in their original types. Only the internal operations run in FP16.
- enable_bias_gelu=False: TensorRT doesn't support the custom BiasGelu op that ORT introduces. Disabling it keeps the graph TRT-compatible. LayerNorm, MHA, and standard GELU fusions are still applied.
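For reference, a sketch of how the second flag is set on the FusionOptions object passed as optimization_options above:

from onnxruntime.transformers.fusion_options import FusionOptions

opts = FusionOptions("bert")   # fusion defaults for BERT-style graphs
opts.enable_bias_gelu = False  # skip the BiasGelu op TensorRT can't run
# LayerNorm, attention, and standard GELU fusions keep their defaults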
Caching
ONNX export takes 20–60 seconds. TRT engine compilation takes 2–5 minutes on first run. Repeating this on every startup isn't an option. Everything gets cached, keyed by a content hash:
cache_key = SHA256({
"name": nlp.meta["name"], # "en_core_web_trf"
"version": nlp.meta["version"], # "3.8.0"
"lang": nlp.meta["lang"], # "en"
"precision": "fp16",
"structure_hash": SHA256(sorted_first_20_keys)[:8],
"first_weight_shape": str(state_dict[first_key].shape),
})[:16]
The structure_hash and first_weight_shape fields guard against in-place weight modifications. If the model weights change at runtime, the key changes and the cache misses.
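A runnable approximation of the key derivation (the helper name is illustrative; the field selection mirrors the pseudocode above):

import hashlib
import json

def make_cache_key(meta: dict, precision: str, state_dict) -> str:
    first_key = next(iter(state_dict))
    structure = "".join(sorted(state_dict.keys())[:20])
    payload = {
        "name": meta["name"],
        "version": meta["version"],
        "lang": meta["lang"],
        "precision": precision,
        "structure_hash": hashlib.sha256(structure.encode()).hexdigest()[:8],
        "first_weight_shape": str(tuple(state_dict[first_key].shape)),
    }
    blob = json.dumps(payload, sort_keys=True).encode()
    return hashlib.sha256(blob).hexdigest()[:16]

The resulting on-disk layout: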
~/.cache/spacy-accelerate/
├── a3f1b2c4d5e6f7a8/
│ ├── model.onnx
│ └── model_fp16.onnx
├── 9d8e7f6a5b4c3d2e/
│ └── model_fp16.onnx
└── trt_engines/
├── TensorrtExecutionProvider_cache_...
└── ...
TRT caches compiled GPU kernels separately in trt_engines/. On subsequent runs, if the input shape matches a cached engine, TRT loads it directly — no recompilation, warmup takes around 10 seconds instead of minutes.
Runtime: IO Binding and Zero-Copy
The default ONNX Runtime CUDA session accepts NumPy arrays. That means every forward pass makes two round-trips through the CPU: input tensors already sitting on the GPU are copied down to host memory before the call, and the outputs are copied from host memory back up to the GPU afterwards.
IO Binding passes GPU memory pointers to ONNX Runtime instead of data. The engine reads inputs directly from the CUDA buffer and writes outputs to a pre-allocated output buffer — both without leaving the GPU. This eliminates the biggest overhead source on short-to-medium sequence lengths.
# IOBindingProxy forward pass (simplified)
binding = session.io_binding()
binding.bind_input(
name="input_ids",
device_type="cuda", device_id=0,
element_type=np.int64,
shape=input_ids.shape,
buffer_ptr=input_ids.data_ptr(), # GPU pointer, no copy
)
binding.bind_output(
name="last_hidden_state",
device_type="cuda", device_id=0,
element_type=np.float32,
shape=output_shape,
buffer_ptr=output_buffer.data_ptr(),
)
session.run_with_iobinding(binding)
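The output_buffer here is assumed to be a pre-allocated tensor on the same device, sized for the bound output shape; reusing it across calls avoids a fresh GPU allocation per batch.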
Batch Bucketing for TensorRT
TensorRT compiles a separate GPU kernel for each input shape. Without bucketing, a new batch size triggers a 100–500ms compilation pause in production — unacceptable latency spikes.
The solution: pre-define a set of bucket sizes at startup and pad every incoming batch up to the smallest bucket that fits it. All kernels compile during warmup, never during inference.
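A minimal sketch of the bucket lookup (the helper name is illustrative):

from bisect import bisect_left

BUCKETS = [8, 16, 64, 128]  # must be sorted ascending

def pick_bucket(batch_size: int) -> int:
    """Smallest bucket that fits the batch (capped at the largest bucket)."""
    i = bisect_left(BUCKETS, batch_size)
    return BUCKETS[min(i, len(BUCKETS) - 1)]

# pick_bucket(3) -> 8, pick_bucket(16) -> 16, pick_bucket(70) -> 128
# Oversize batches would have to be split before this lookup.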
# TensorRT FP16 with batch bucketing
nlp = spacy_accelerate.optimize(
nlp,
provider="tensorrt",
precision="fp16",
trt_max_workspace_size=4 * 1024**3, # 4 GB GPU workspace
trt_builder_optimization_level=3, # 0–5; higher = longer compile
trt_timing_cache=True, # reuse timing across engine builds
batch_buckets=[8, 16, 64, 128], # pre-compile these shapes
fixed_seq_length=144, # 128-token window + 16 overhead
)
The fixed_seq_length=144 setting tells the proxy to always pad sequences to 144 tokens before sending to TRT. For en_core_web_trf, spaCy's thinc processing always produces sequences of exactly this length, so we're not wasting compute — and TRT can compile a kernel that's even more aggressively optimized for the fixed shape.
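A sketch of that padding step, assuming RoBERTa's pad token id of 1 (illustrative helper, not the library's actual code):

import torch

def pad_to_fixed(input_ids, attention_mask, target_len=144, pad_id=1):
    """Right-pad (batch, seq) int64 tensors out to (batch, target_len)."""
    b, t = input_ids.shape
    if t >= target_len:
        return input_ids[:, :target_len], attention_mask[:, :target_len]
    pad_ids = input_ids.new_full((b, target_len - t), pad_id)
    pad_mask = attention_mask.new_zeros((b, target_len - t))
    return (torch.cat([input_ids, pad_ids], dim=1),
            torch.cat([attention_mask, pad_mask], dim=1))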
UniversalTransformerOutput
ONNX exports only last_hidden_state. The problem: spaCy and spacy-transformers access transformer outputs through different attribute names, depending on the exact version and component:
output.embedding_output
output.last_hidden_layer_state
output.all_hidden_layer_states[i]
output.layer_hidden_states
output.all_outputs
Rather than tracking which version uses which name, we create a wrapper that responds to all of them with the same value:
class UniversalTransformerOutput:
def __init__(self, hidden_state, num_layers):
self.embedding_output = hidden_state
self.last_hidden_layer_state = hidden_state
self.all_hidden_layer_states = [hidden_state] * num_layers
self.layer_hidden_states = self.all_hidden_layer_states
self.last_hidden_layer_states = self.all_hidden_layer_states
self.all_outputs = self.all_hidden_layer_states
self.num_layers = num_layers
This is a deliberate simplification: all "layers" return the same final hidden state. For NER and POS tasks — the primary use case — this is correct, since they only use the final layer. If you need probing access to intermediate layers, you'd need a different export strategy.
Benchmarks
en_core_web_trf (RoBERTa-base, spaCy 3.8). NVIDIA RTX 4000 SFF Ada Generation, CUDA 12, CoNLL-2003 test set, batch size 128. Numbers are words per second averaged over 3 measured passes after 1 discarded warm-up pass. Two modes measured: full pipeline and NER-only (tagger, parser, attribute_ruler, lemmatizer disabled).
Full Pipeline
All spaCy components active: tokenizer → transformer → tagger → parser → attribute_ruler → lemmatizer → ner. The non-transformer components are not accelerated, which caps the overall speedup.
| Configuration | WPS | Speedup | Accuracy |
|---|---|---|---|
| PyTorch FP32 (baseline) | 6,241 | 1.00× | 100.0% |
| PyTorch FP16 (autocast) | 6,166 | 0.99× | 100.0% |
| ONNX Runtime CUDA FP32 | 9,910 | 1.59× | 99.9% |
| ONNX Runtime CUDA FP16 | 15,763 | 2.53× | 99.75% |
| TensorRT FP32 | 10,552 | 1.69× | 99.95% |
| TensorRT FP16 | 16,935 | 2.71× | 99.5% |
NER-only
Tagger, parser, attribute_ruler, and lemmatizer disabled — only transformer + NER runs. This isolates the transformer bottleneck and shows the maximum acceleration the engine can deliver on the model itself.
| Configuration | WPS | Speedup | Accuracy |
|---|---|---|---|
| PyTorch FP32 (baseline) | 7,066 | 1.00× | 100.0% |
| PyTorch FP16 (autocast) | 6,859 | 0.97× | 100.0% |
| ONNX Runtime CUDA FP32 | 11,972 | 1.69× | 99.9% |
| ONNX Runtime CUDA FP16 | 22,394 | 3.17× | 99.75% |
| TensorRT FP32 | 13,138 | 1.86× | 99.95% |
| TensorRT FP16 | 24,823 | 3.51× | 99.65% |
NER-only speedup (3.51×) is higher than full-pipeline (2.71×) because of Amdahl's law: in full-pipeline mode, the tagger, parser, and lemmatizer are not accelerated and consume a fixed fraction of pipeline time, capping the overall gain.
Bottom line
TensorRT FP16 processes 24,823 WPS (NER-only) vs 7,066 for PyTorch FP32 — 99.65% accuracy retained on CoNLL-2003. Full pipeline: 16,935 WPS, 2.71×. Tested on en_core_web_trf, NVIDIA RTX 4000 SFF Ada Generation. One line of code.
Usage
Installation
# Step 1: install the package
pip install spacy-accelerate
# Step 2: install the NVIDIA build of onnxruntime-gpu (includes TensorRT EP)
# The standard PyPI build does not include TensorRT EP
pip install --force-reinstall \
--extra-index-url https://pypi.nvidia.com \
onnxruntime-gpu==1.23.2
# Verify providers are available
python -m spacy_accelerate
# Expected output:
# TensorRT EP : OK
# CUDA EP : OK
TensorRT FP16 — Maximum Performance
nlp = spacy_accelerate.optimize(
nlp,
provider="tensorrt",
precision="fp16",
trt_max_workspace_size=4 * 1024**3,
trt_builder_optimization_level=3,
trt_timing_cache=True,
batch_buckets=[8, 16, 64, 128],
fixed_seq_length=144,
)
First run after a clean cache: 2–5 minutes for TRT kernel compilation. Subsequent runs: ~10 seconds warmup, then full speed.
ONNX Runtime CUDA FP16 — Fast Start
nlp = spacy_accelerate.optimize(
nlp,
provider="cuda",
precision="fp16",
)
No TRT compilation overhead, and good throughput (2.53× full pipeline, 3.17× NER-only in the benchmarks above). The right choice when you can't afford the startup latency of TRT compilation, or when TRT isn't available.
CPU
nlp = spacy_accelerate.optimize(
nlp,
provider="cpu",
precision="fp32",
)
No GPU required. ONNX Runtime applies graph-level optimizations that give a modest improvement over raw PyTorch on CPU.
Supported Models
Only en_core_web_trf has been tested end-to-end. Other spaCy English models (en_core_web_sm/md/lg) are not transformer-based and are not compatible. Non-English transformer models have not been validated.
| spaCy model | Architecture | Status | Notes |
|---|---|---|---|
| en_core_web_trf | RoBERTa (curated) | Tested | Only model with benchmark results |
| en_core_web_sm/md/lg | CNN / tok2vec | Not applicable | No transformer component |
| Other language models | Various | Not tested | Weight mapping may not cover all architectures |
Limitations
Inference only. The graph is frozen at export time. Fine-tuning through spacy-accelerate is not possible — there are no gradients.
Final hidden state only. ONNX exports a single output. If a downstream component accesses intermediate layer outputs, it will receive the final layer repeated. Probing intermediate layers requires a different export strategy.
Batch buckets required for TRT. TensorRT needs shapes known at compile time. Batch sizes outside your configured batch_buckets will trigger a compilation pause on first occurrence.
CUDA 12 only. The full stack (PyTorch, TensorRT, cupy, onnxruntime-gpu) is built against CUDA 12. CUDA 11 is not supported.
Single GPU. Multi-GPU is not implemented. device_id selects a specific device, but cross-GPU parallelism isn't supported.