Metadata-Version: 2.4
Name: rwkv-emb
Version: 0.1.2
Summary: The EmbeddingRWKV Model
Author: Haowen HOU
Project-URL: Homepage, https://github.com/howard-hou/EmbeddingRWKV
Project-URL: Bug Tracker, https://github.com/howard-hou/EmbeddingRWKV/issues
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE
Dynamic: license-file

# EmbeddingRWKV

A high-efficiency text embedding and reranking model based on RWKV architecture.

## 🚀 Quick Start (End-to-End)

Get text embeddings in just a few lines. The tokenizer and model are designed to work seamlessly together.

> **Note**: Always set `add_eos=True` during tokenization. The model relies on the EOS token (`65535`) to mark the end of a sentence for correct embedding generation.

```python
import os
# Set environment for JIT compilation (Optional, set to '1' for CUDA acceleration)
os.environ["RWKV_CUDA_ON"] = '1'

from rwkv_emb.tokenizer import RWKVTokenizer
from rwkv_emb.model import EmbeddingRWKV

# 1. Initialize Tokenizer & Model
model = EmbeddingRWKV(model_path='/path/to/model.pth')
tokenizer = RWKVTokenizer()

# 2. Tokenize (Text -> Tokens)
text = "Hello world! This is RWKV embedding."
# Important: Enable add_eos=True to append the required EOS token (65535)
tokens = tokenizer.encode(text, add_eos=True)
print(f"Tokens: {tokens}") 

# 3. Inference (Tokens -> Embedding)
embedding, state = model.forward(tokens, None)

print(f"Embedding shape: {embedding.shape}")
print(embedding.shape)
```

## ⚡ Batch Inference & Performance Guide

For production use cases, running inference in batches is significantly faster.

### ⚠️ Critical Performance Tip: Pad to Same Length

While the model supports batches with variable sequence lengths, **we strongly recommend padding all sequences to the same length** for maximum GPU throughput.

  - **Pad Token**: `0`
  - **Performance**: Fixed-length batches allow the CUDA kernel to parallelize computation efficiently. Variable-length batches will trigger a slower execution path.

### Batch Example (Recommended)

```python
# Example: Batching two sentences of different lengths
sentences = [
    "Short sentence.",
    "This is a slightly longer sentence for demonstration."
]

# 1. Tokenize all
batch_tokens = [tokenizer.encode(s, add_eos=True) for s in sentences]

# 2. Left pad to the longest sequence in the batch using 0
max_len = max(len(t) for t in batch_tokens)

for i in range(len(batch_tokens)):
    pad_len = max_len - len(batch_tokens[i])
    # insert 0s to the beginning
    batch_tokens[i] = [0] * pad_len + batch_tokens[i]

# batch_tokens is now a rectangular matrix (List of Lists with same length)
print(f"Padded Batch: {batch_tokens}")

# 3. Fast Inference
embeddings, states = model.forward(batch_tokens, None)

# embeddings shape: [Batch_Size, Embedding_Dim]
print("Batch Embeddings:", embeddings.shape)
print("Batch States:", states[0].shape, states[1].shape)
```

## 🎯 RWKVReRanker (State-based Reranker)

The `RWKVReRanker` utilizes the final hidden state produced by the main `EmbeddingRWKV` model to score the relevance between a query and a document.

### Online Mode

#### Workflow
1.  **Format** Query and Document based on Online template.
2.  Run the **Embedding Model** to generate the final State.
3.  Feed the **Attention State** (`state[1]`) into the **ReRanker** to get a relevance score.

#### 📝 Online Mode Usage Example

```python
import torch
from rwkv_emb.tokenizer import RWKVTokenizer
from rwkv_emb.model import EmbeddingRWKV, RWKVReRanker

# 1. Load Models
# The ReRanker weights are stored in the differernt checkpoint
emb_model = EmbeddingRWKV(model_path='/path/to/EmbeddingRWKV.pth')
reranker = RWKVReRanker(model_path='/path/to/RWKVReRanker.pth')

tokenizer = RWKVTokenizer()

# 2. Prepare Data (Query + Candidate Documents)
query = "What represents the end of a sequence?"
documents = [
    "The EOS token is used to mark the end of a sentence.",
    "Apples are red and delicious fruits.",
    "Machine learning requires large datasets."
]

# 3. Construct Input Pairs
# We treat the Query and Document as a single sequence.
pairs = []
online_template = "Instruct: Given a query, retrieve documents that answer the query\nDocument: {document}\nQuery: {query}"
for doc in documents:
    # Format: Instruct + Document + Query
    text = online_template.format(document=doc, query=query)
    pairs.append(text)

# 4. Tokenize & Pad (Critical for Batch Performance)
batch_tokens = [tokenizer.encode(p, add_eos=True) for p in pairs]

# Left pad to same length for efficiency
max_len = max(len(t) for t in batch_tokens)
for i in range(len(batch_tokens)):
    batch_tokens[i] = [0] * (max_len - len(batch_tokens[i])) + batch_tokens[i]

# 5. Get States from Embedding Model
# We don't need the embedding output here, we only need the final 'state'
_, state = emb_model.forward(batch_tokens, None)

# 6. Score with ReRanker
# The ReRanker expects the Attention State: state[1]
# state[1] shape: [Layers, Batch, Heads, HeadSize, HeadSize]
logits = reranker.forward(state[1])
scores = torch.sigmoid(logits) # Convert logits to probabilities (0-1)

# 7. Print Results
print("\nRWKVReRanker Online Scores:")
for doc, score in zip(documents, scores):
    print(f"[{score:.4f}] {doc}")
```

### Offline Mode (Cached Doc State)
For scenarios where documents are static but queries change (e.g., Search Engines, RAG), you can **pre-compute and cache the document states**. This reduces query-time latency from $O(L_{doc} + L_{query})$ to just $O(L_{query})$.

#### Workflow

1.  **Indexing**: Process `Instruct + Document` -\> Save State.
2.  **Querying**: Load State -\> Process `Query` -\> Score.

#### 📝 Offline Mode Usage Example

```python
# --- Phase 1: Indexing (Pre-computation) ---
# Note: Do NOT add EOS here, because the sequence continues with the query later.
doc_template = "Instruct: Given a query, retrieve documents that answer the query\nDocument: {document}\n"
cached_states = []

print("Indexing documents...")
for doc in documents:
    text = doc_template.format(document=doc)
    # add_eos=False is CRITICAL here
    tokens = tokenizer.encode(text, add_eos=False) 
    
    # Forward pass
    _, state = emb_model.forward(tokens, None)
    
    # Move state to CPU to save GPU memory during storage
    # State structure: [Tensor(Tokenshift), Tensor(TimeMix)]
    cpu_state = [s.cpu() for s in state]
    cached_states.append(cpu_state)
# Save cached states to disk (optional)
torch.save(cached_states, 'cached_doc_states.pth')

# --- Phase 2: Querying (Fast Retrieval) ---
query_template = "Query: {query}"
query_text = query_template.format(query=query)
# Now we add EOS to mark the end of the full sequence
query_tokens = tokenizer.encode(query_text, add_eos=True)

print(f"Processing query: '{query}' against {len(cached_states)} cached docs...")

# We can batch the query processing against multiple document states
# 1. Prepare a batch of states (Move back to GPU)
#    Note: We must CLONE/DEEPCOPY because RWKV modifies state in-place!
batch_states = [[], []]
for cpu_s in cached_states:
    batch_states[0].append(cpu_s[0].clone().cuda()) # Tokenshift State
    batch_states[1].append(cpu_s[1].clone().cuda()) # TimeMix State

# Stack into batch tensors
# State[0]: [Layers, 2, 1, Hidden] -> Stack dim 2 -> [Layers, 2, Batch, Hidden]
# State[1]: [Layers, 1, Heads, HeadSize, HeadSize] -> Stack dim 1 -> [Layers, Batch, Heads, ...]
state_input = [
    torch.stack(batch_states[0], dim=2).squeeze(3), 
    torch.stack(batch_states[1], dim=1).squeeze(2)
]

# 2. Prepare query tokens (Broadcast query to batch size)
batch_size = len(documents)
batch_query_tokens = [query_tokens] * batch_size

# 3. Fast Forward (Only processing query tokens!)
_, final_state = emb_model.forward(batch_query_tokens, state_input)
logits = reranker.forward(final_state[1])
scores = torch.sigmoid(logits)

print("\nRWKVReRanker Offline Scores:")
for doc, score in zip(documents, scores):
    print(f"[{score:.4f}] {doc}")
```
