Metadata-Version: 2.1
Name: faster-qwen3-tts-batch
Version: 0.2.0
Summary: Dynamic batching for Qwen3-TTS with CUDA graph acceleration
Author: jbang2004
Project-URL: Homepage, https://github.com/jbang2004/faster-qwen3-tts-batch
Project-URL: Repository, https://github.com/jbang2004/faster-qwen3-tts-batch
Keywords: tts,qwen3,cuda-graph,batching
Classifier: Development Status :: 4 - Beta
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.4
Requires-Dist: transformers>=4.47
Requires-Dist: numpy
Requires-Dist: soundfile
Requires-Dist: faster-qwen3-tts>=0.2.4
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"

# faster-qwen3-tts-batch

Dynamic batching for Qwen3-TTS with CUDA graph acceleration. Extends [faster-qwen3-tts](https://github.com/jbang2004/faster-qwen3-tts) to support batched inference, achieving 2-4x throughput improvement via shared weight computation and batched CUDA graphs.

## Features

- **Batched CUDA Graphs**: Predictor and talker decode steps captured as batched CUDA graphs for minimal kernel launch overhead
- **Batched Prefill**: Single forward pass for all B samples with left-padding alignment (vs B separate prefills)
- **Continuous Batching**: Mid-generation slot replacement — when a sample finishes (EOS), its slot is immediately filled with the next pending request
- **Async Scheduler**: `BatchScheduler` with dedicated GPU thread, smart batching with configurable `max_wait_ms`, and async `infer()` API
- **Two Execution Modes**:
  - `continuous=False` (default): Batch-and-wait — collect batch, run to completion, optimal for consumer GPUs
  - `continuous=True`: Slot replacement — fill freed slots mid-generation, optimal for high-end GPUs (H100/A100)

## Installation

```bash
pip install -e .
```

Requires `faster-qwen3-tts >= 0.2.4` and a CUDA-capable GPU.

## Quick Start

### Direct Batched Generation

```python
from faster_qwen3_tts_batch import BatchedFasterQwen3TTS

model = BatchedFasterQwen3TTS.from_pretrained(
    "path/to/Qwen3-TTS-12Hz-0.6B-Base",
    max_batch_size=4,
    max_seq_len=512,
)
model.warmup(prefill_len=100)

requests = [
    {"text": "你好世界", "language": "Auto", "ref_audio": "ref.wav", "ref_text": ""},
    {"text": "今天天气不错", "language": "Auto", "ref_audio": "ref.wav", "ref_text": ""},
]
outputs = model.generate_voice_clone_batch(requests, max_new_tokens=500)

for audio_arrays, sr in outputs:
    # audio_arrays[0] is a NumPy waveform; sr is the sample rate (24000 Hz)
    ...
```
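
The outputs can be written to disk with `soundfile`, which is already a dependency; the filenames below are just an example:

```python
import soundfile as sf

for i, (audio_arrays, sr) in enumerate(outputs):
    # audio_arrays[0] is the waveform for this request; sr is 24000
    sf.write(f"output_{i}.wav", audio_arrays[0], sr)
```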

### Async Scheduler (for serving)

```python
import asyncio
from faster_qwen3_tts_batch import BatchedFasterQwen3TTS, BatchScheduler

model = BatchedFasterQwen3TTS.from_pretrained("path/to/model", max_batch_size=4)
model.warmup()

scheduler = BatchScheduler(
    model,
    max_batch_size=4,
    max_wait_ms=50,         # wait up to 50ms to collect a batch
    continuous=False,        # batch-and-wait mode (default)
    gen_kwargs={"max_new_tokens": 500, "repetition_penalty": 1.1},
)

async def serve():
    await scheduler.start()

    # Concurrent requests are automatically batched
    audio, sr = await scheduler.infer(
        text="你好世界",
        language="Auto",
        ref_audio="ref.wav",
    )

    await scheduler.stop()

asyncio.run(serve())
```
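
Because every `infer()` call just enqueues a request, concurrent callers are batched automatically. A small extension of the snippet above (reusing the same `scheduler`; the texts are placeholders):

```python
async def serve_many():
    await scheduler.start()

    texts = ["First sentence.", "Second sentence.", "Third sentence.", "Fourth sentence."]
    # All four requests arrive within max_wait_ms, so the scheduler can run them as one batch of 4.
    results = await asyncio.gather(
        *(scheduler.infer(text=t, language="Auto", ref_audio="ref.wav") for t in texts)
    )
    for audio, sr in results:
        ...  # same (audio, sr) shape as a single infer() call

    await scheduler.stop()

asyncio.run(serve_many())
```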

## Architecture

```
                     ┌──────────────┐
                     │  infer()     │  ← async API (multiple callers)
                     │  (event loop)│
                     └──────┬───────┘
                            │ queue.Queue (thread-safe)
                     ┌──────▼───────┐
                     │  GPU Thread  │  ← dedicated engine loop
                     │              │
                     │  collect     │  wait up to max_wait_ms
                     │  batch       │  or until max_batch_size
                     │              │
              ┌──────┴──────────────┴──────┐
              │                            │
      ┌────────▼───────────┐   ┌────────────▼───────────┐
      │ batch-and-wait     │   │ continuous batching    │
      │ (continuous=False) │   │ (continuous=True)      │
      │                    │   │                        │
      │ generate_batch()   │   │ continuous_generate()  │
      │ → decode audio     │   │ → slot replacement     │
      │ → deliver all      │   │ → decode after done    │
      └────────────────────┘   └────────────────────────┘
```
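
The async-to-sync bridge in the diagram follows a common pattern: callers await an `asyncio.Future`, while a dedicated worker thread drains a `queue.Queue` and resolves the futures via `call_soon_threadsafe`. The sketch below only illustrates that pattern, it is not the actual `BatchScheduler` code (class and field names are invented):

```python
import asyncio
import queue
import threading
import time
from dataclasses import dataclass


@dataclass
class _Request:
    payload: dict
    future: asyncio.Future
    loop: asyncio.AbstractEventLoop


class MiniScheduler:
    """Illustrative async->sync bridge with smart batching."""

    def __init__(self, run_batch, max_batch_size=4, max_wait_s=0.05):
        self._run_batch = run_batch           # sync function: list[dict] -> list[result]
        self._max_batch_size = max_batch_size
        self._max_wait_s = max_wait_s
        self._queue: queue.Queue = queue.Queue()
        threading.Thread(target=self._worker, daemon=True).start()

    async def infer(self, **payload):
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        self._queue.put(_Request(payload, fut, loop))
        return await fut                      # resolved from the worker thread

    def _worker(self):
        while True:
            batch = [self._queue.get()]       # block until the first request arrives
            deadline = time.monotonic() + self._max_wait_s
            while len(batch) < self._max_batch_size:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    batch.append(self._queue.get(timeout=remaining))
                except queue.Empty:
                    break
            results = self._run_batch([r.payload for r in batch])
            for req, res in zip(batch, results):
                # hand the result back on the caller's event loop, thread-safely
                req.loop.call_soon_threadsafe(req.future.set_result, res)
```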

### Key Components

| File | Description |
|------|-------------|
| `model.py` | `BatchedFasterQwen3TTS` — main API, wraps base model with batched graphs |
| `scheduler.py` | `BatchScheduler` — async scheduler with dedicated GPU thread |
| `continuous_generate.py` | Continuous batched generation with per-slot state tracking |
| `batched_generate.py` | Batch-and-wait generation (all samples run to completion) |
| `batched_talker_graph.py` | Batched CUDA graph for talker decode with slot replacement support |
| `batched_predictor_graph.py` | Batched CUDA graph for code predictor |
| `batched_sampling.py` | Batched top-k/top-p sampling with per-slot EOS suppression |

## Benchmark Results (RTX 3060, 0.6B model, 8 requests)

```
Sequential (8x B=1):        ~9.0s   throughput=2.9x realtime
Batch-and-Wait (2x B=4):    ~5.4s   throughput=5.2x realtime
Continuous Batching (B=4):  ~10.6s  throughput=3.8x realtime  (4 replacements)
```

**Key finding**: On consumer GPUs (RTX 3060/4090), single-sample eager prefill for slot replacement (~500ms) costs more than it saves. **Batch-and-wait is the optimal strategy for consumer GPUs**. Continuous batching benefits high-end GPUs (H100/A100) where prefill latency is ~50ms.

## Development History

### Phase 1: Batched CUDA Graphs + Scheduler (v0.1.0)

The initial version established the core infrastructure:

- **`BatchedPredictorGraph`**: CUDA-graphed batched code predictor — takes `[B, 2, H]` input (past_hidden + last_token_embed), produces `[B, 15]` codebook tokens in a single graph replay
- **`BatchedTalkerGraph`**: CUDA-graphed batched talker decode — left-padding alignment so all samples share a single `cache_position` scalar, with per-sample differences handled by `attention_mask` and `rope_deltas`
- **`batched_fast_generate()`**: Batched autoregressive loop with shared prefill, CUDA-graphed decode, and batched sampling
- **`BatchScheduler`**: Async scheduler that collects requests with smart waiting (max_batch_size or max_wait_ms timeout), then dispatches to GPU
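
The graph classes above follow the standard `torch.cuda.CUDAGraph` capture-and-replay recipe: allocate static input/output buffers, warm up on a side stream, capture one decode step, then copy new inputs into the static buffers and replay. The sketch below shows that recipe with a stand-in predictor module (shapes `B=4`, `H=1024`, 15 codebooks mirror the description; it is not the package's `BatchedPredictorGraph`):

```python
import torch
import torch.nn as nn

B, H, NUM_CODEBOOKS = 4, 1024, 15
# Stand-in for the code predictor; the real module is far more complex.
predictor = nn.Sequential(nn.Linear(2 * H, H), nn.GELU(), nn.Linear(H, NUM_CODEBOOKS)).cuda().eval()

# Static buffer: CUDA graphs replay fixed memory addresses, so each step copies
# new inputs into this tensor instead of allocating fresh ones.
static_in = torch.zeros(B, 2, H, device="cuda")

# Warm up on a side stream (required before capture), then capture one step.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side), torch.no_grad():
    for _ in range(3):
        predictor(static_in.flatten(1))
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_out = predictor(static_in.flatten(1))   # [B, 15]

def predict_step(batch_input: torch.Tensor) -> torch.Tensor:
    static_in.copy_(batch_input)   # refresh inputs in place
    graph.replay()                 # replays all captured kernels with a single cheap call
    return static_out.clone()
```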

Key design decisions:
- **Left-padding alignment**: All samples padded to `max_seq_len` in the batch, enabling a single shared `cache_position` for CUDA graph capture
- **StaticCache**: Required for CUDA graphs (no dynamic memory allocation during replay)
- **Per-batch rope_deltas**: Each sample carries its own rope_delta to compensate for padding differences
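
A minimal sketch of the left-padding idea (not the package's actual prefill code; in particular the rope_delta convention shown here is just one plausible choice):

```python
import torch

def left_pad_batch(prompts: list[list[int]], pad_id: int, max_seq_len: int):
    """Left-pad prompts of different lengths so every sample's last prompt token
    lands at the same index. All samples then advance in lockstep, so a single
    shared cache_position scalar can drive every decode step; padding is hidden
    by attention_mask and compensated per sample by rope_deltas."""
    B = len(prompts)
    input_ids = torch.full((B, max_seq_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros(B, max_seq_len, dtype=torch.long)
    pad_lens = []
    for i, ids in enumerate(prompts):
        pad = max_seq_len - len(ids)
        input_ids[i, pad:] = torch.as_tensor(ids)
        attention_mask[i, pad:] = 1
        pad_lens.append(pad)
    # One common convention: shift rotary positions back by the pad length so real
    # tokens see positions 0..len-1 (the actual rope_delta math in the package may differ).
    rope_deltas = -torch.tensor(pad_lens).unsqueeze(1)
    return input_ids, attention_mask, rope_deltas
```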

### Phase 2: Batched Prefill (PR #2)

Replaced B separate talker prefill forward passes with a single batched forward:

- **Before**: Loop `B` times calling `talker.forward()` individually, each producing a separate DynamicCache
- **After**: Left-pad all inputs to `max_seq_len`, single `talker.forward()` call, one DynamicCache for the whole batch

Added `prefill_kv_batched()` to `BatchedTalkerGraph` for direct batched-cache-to-StaticCache copy (vs the original `prefill_kv()` which left-pads individual caches).

Also introduced `set_generation_state()` with `padding_in_rope_deltas` flag — when rope_deltas come from a batched prefill (with attention_mask), they already account for padding, so no additional offset is needed.
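
A schematic of the change (assuming an HF-style `talker(input_ids=..., attention_mask=..., use_cache=True)` call; the real call site works on embeddings and passes more arguments):

```python
import torch
import torch.nn.functional as F

def prefill_per_sample(talker, prompts):
    # Phase 1 behaviour (schematic): B separate forwards, B separate DynamicCaches.
    return [talker(input_ids=p.unsqueeze(0), use_cache=True).past_key_values for p in prompts]

def prefill_batched(talker, prompts, pad_id):
    # Phase 2 behaviour (schematic): left-pad once, one forward, one DynamicCache.
    max_len = max(p.numel() for p in prompts)
    input_ids = torch.stack([F.pad(p, (max_len - p.numel(), 0), value=pad_id) for p in prompts])
    attention_mask = (input_ids != pad_id).long()   # assumes pad_id never occurs in prompts
    out = talker(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)
    # Because attention_mask was passed in, the resulting rope_deltas already
    # account for the left padding, which is what padding_in_rope_deltas=True signals.
    return out.past_key_values
```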

### Phase 3: Continuous Batching (v0.2.0, current)

The most architecturally significant change — enabling mid-generation slot replacement:

**On-the-fly attention mask computation** (`batched_talker_graph.py`):
- Replaced the precomputed `attn_mask_table[position]` lookup with per-step `_compute_and_set_mask()` using HuggingFace's `create_causal_mask` / `create_sliding_window_causal_mask`
- This eliminated the need to rebuild the entire mask table when a slot is replaced — just update `_attention_mask_2d[slot_idx]` and recompute
- Trade-off: Slightly more compute per decode step, but enables slot replacement without O(max_seq_len) mask rebuilds
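
For illustration, the same effect can be written with plain tensor ops (this is not HuggingFace's `create_causal_mask`, just the shape of the computation for a single decode step):

```python
import torch

def build_step_mask(attention_mask_2d: torch.Tensor, cache_position: int,
                    dtype=torch.float16) -> torch.Tensor:
    """attention_mask_2d: [B, max_seq_len], 1 for positions the slot may attend to,
    0 for left padding. Returns an additive [B, 1, 1, max_seq_len] mask for the
    single query token at cache_position."""
    B, max_seq_len = attention_mask_2d.shape
    key_positions = torch.arange(max_seq_len, device=attention_mask_2d.device)
    causal = key_positions <= cache_position             # keys in the future are invisible
    visible = attention_mask_2d.bool() & causal          # [B, max_seq_len]
    mask = torch.zeros(B, 1, 1, max_seq_len, dtype=dtype, device=attention_mask_2d.device)
    mask.masked_fill_(~visible[:, None, None, :], torch.finfo(dtype).min)
    return mask

# After a slot replacement only _attention_mask_2d[slot_idx] changes; the next
# step's mask is recomputed from it instead of rebuilding a precomputed table.
```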

**Slot replacement primitives** (`batched_talker_graph.py`):
- `replace_slot_kv()`: Zero out a slot's KV cache in StaticCache, then write the new sample's left-padded KV data at `[current_pos - prefill_len, current_pos)`
- `update_slot_state()`: Update `_attention_mask_2d` and `rope_deltas` for the replaced slot
- Key discovery: HuggingFace `StaticCache` layers use `layer.keys` / `layer.values` (not `layer.key_cache` / `layer.value_cache`)
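
A per-layer sketch of the KV replacement (the real `replace_slot_kv()` iterates over every StaticCache layer, reading the buffers from `layer.keys` / `layer.values` as noted above):

```python
import torch

def replace_slot_kv_one_layer(static_keys: torch.Tensor, static_values: torch.Tensor,
                              new_keys: torch.Tensor, new_values: torch.Tensor,
                              slot_idx: int, current_pos: int) -> None:
    """static_keys/values: [B, num_heads, max_seq_len, head_dim] static buffers.
    new_keys/values: [num_heads, prefill_len, head_dim] from the replacement
    sample's single-sample prefill."""
    prefill_len = new_keys.shape[1]
    start = current_pos - prefill_len
    # 1) Wipe the finished sample's cache so stale keys/values can never be attended to.
    static_keys[slot_idx].zero_()
    static_values[slot_idx].zero_()
    # 2) Write the new sample's KV so its last prefill token sits just before
    #    current_pos, i.e. it is left-aligned against the running batch position.
    static_keys[slot_idx, :, start:current_pos] = new_keys
    static_values[slot_idx, :, start:current_pos] = new_values
```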

**Per-slot generation state** (`continuous_generate.py`):
- `gen_steps` (per-slot): Controls trailing_text_hidden injection timing — each slot independently indexes into its text hidden sequence
- `slot_steps` (per-slot): Controls min_new_tokens EOS suppression — replacement samples start from 0
- Dynamic `trailing_text_padded` buffer with auto-expansion when replacement samples have longer text
- `id()`-based tracking of completed tags, since the `_PendingRequest` dataclass is unhashable and cannot be stored directly in a set
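
An illustrative shape for that per-slot state (field names mirror the bullets above, not necessarily the exact variables in `continuous_generate.py`):

```python
import torch
from dataclasses import dataclass

@dataclass
class SlotState:
    tag: object          # identifies the request currently occupying this slot
    gen_step: int = 0    # index into this slot's trailing_text_hidden sequence
    slot_step: int = 0   # tokens generated since this sample entered the slot

def suppress_eos(logits: torch.Tensor, slots: list[SlotState],
                 eos_token_id: int, min_new_tokens: int) -> torch.Tensor:
    # A replacement sample restarts slot_step at 0, so its EOS stays suppressed
    # even while longer-running slots are already free to finish.
    for i, slot in enumerate(slots):
        if slot.slot_step < min_new_tokens:
            logits[i, eos_token_id] = float("-inf")
    return logits
```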

**Deferred audio decode** (`scheduler.py`):
- `on_slot_done` callback collects raw `(tag, codec_ids.clone(), timing)` tuples during generation
- Audio decode (`speech_tokenizer.decode()`) runs *after* `continuous_batched_generate()` returns
- This avoids CUDA state conflicts between the speech tokenizer and CUDA-graphed generation loop

**Dual-mode scheduler**:
- `continuous=False` (default): Delegates to `generate_voice_clone_batch()` — proven faster on consumer GPUs
- `continuous=True`: Uses `continuous_batched_generate()` with `try_get_replacement` callback
- Dedicated GPU thread with `queue.Queue` for thread-safe async→sync bridging

**Performance insight**: Continuous batching is not universally beneficial. On RTX 3060, single-sample eager prefill for slot replacement (~500ms) pauses all active slots, costing more time than it saves. The `continuous` flag defaults to `False` to reflect this finding.

## License

MIT
