Metadata-Version: 2.4
Name: pearl-H
Version: 0.1.0
Summary: PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning
Author-email: PEARL Contributors <pearl@example.com>
License: MIT
Project-URL: Homepage, https://github.com/yourusername/pearl
Project-URL: Documentation, https://pearl-ai.readthedocs.io
Project-URL: Repository, https://github.com/yourusername/pearl
Project-URL: Bug Tracker, https://github.com/yourusername/pearl/issues
Keywords: machine-learning,deep-learning,embedding,prototype-learning,text-classification,representation-learning,pytorch
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=1.20.0
Requires-Dist: pandas>=1.3.0
Requires-Dist: torch>=1.12.0
Requires-Dist: scikit-learn>=1.0.0
Requires-Dist: openpyxl>=3.0.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: black>=22.0.0; extra == "dev"
Requires-Dist: flake8>=4.0.0; extra == "dev"
Requires-Dist: mypy>=0.950; extra == "dev"
Requires-Dist: sphinx>=4.0.0; extra == "dev"
Provides-Extra: examples
Requires-Dist: transformers>=4.20.0; extra == "examples"
Requires-Dist: datasets>=2.0.0; extra == "examples"
Requires-Dist: matplotlib>=3.5.0; extra == "examples"
Requires-Dist: seaborn>=0.11.0; extra == "examples"
Provides-Extra: all
Requires-Dist: pearl-H[dev,examples]; extra == "all"
Dynamic: license-file

# PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning

[![PyPI version](https://badge.fury.io/py/pearl-ai.svg)](https://badge.fury.io/py/pearl-ai)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)

PEARL is a framework for enhancing embeddings through signal extraction and prototype-guided feature augmentation. It improves classification performance on embedding-based tasks by separating discriminative signal from noise and augmenting embeddings with prototype-based similarity features.

## Key Features

- **Signal Extraction**: Separates discriminative signal from noise in embeddings using deep learning
- **Prototype-Guided Features (PAF)**: Augments embeddings with rich prototype-based similarity features
- **Easy-to-use API**: Simple scikit-learn-like interface
- **Flexible**: Works with any embedding (BERT, ResNet, custom embeddings, etc.)
- **Proven Results**: Consistent improvements across multiple classifiers and datasets
- **GPU Accelerated**: Built on PyTorch for fast training and inference

## Installation

### From PyPI (recommended)

```bash
pip install pearl-ai
```

### From source

```bash
git clone https://github.com/yourusername/pearl.git
cd pearl
pip install -e .
```

### Optional dependencies

For examples and advanced features:

```bash
pip install pearl-ai[examples]  # Install with example dependencies
pip install pearl-ai[dev]       # Install with development tools
pip install pearl-ai[all]       # Install everything
```

## Quick Start

```python
import numpy as np
from pearl import PEARLPipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

# Your embeddings and labels
X_train, y_train = ...  # Shape: [N, D], [N]
X_test, y_test = ...

# Initialize PEARL
pearl = PEARLPipeline(
    n_classes=10,
    device='cuda'  # or 'cpu'
)

# Fit PEARL on training data
pearl.fit(X_train, y_train)

# Transform embeddings
X_train_enhanced = pearl.transform(X_train, mode='enhanced')
X_test_enhanced = pearl.transform(X_test, mode='enhanced')

# Use with any classifier
clf = LogisticRegression()
clf.fit(X_train_enhanced, y_train)
pred = clf.predict(X_test_enhanced)

print(f"F1 Score: {f1_score(y_test, pred, average='macro'):.4f}")
```

## How PEARL Works

PEARL enhances embeddings through two key steps:

### 1. Signal Extraction

The Signal Extractor learns to separate embeddings into:
- **Signal**: Class-discriminative information
- **Noise**: Non-discriminative variations

It uses a multi-task learning approach that combines four loss terms (sketched below):
- Reconstruction loss (preserve information)
- Centroid alignment loss (align with class centers)
- Contrastive loss (separate classes)
- Orthogonality loss (decorrelate signal and noise)
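
For intuition, here is a minimal PyTorch sketch of how these four terms could be combined into a single objective. The weights mirror the `recon_weight`, `centroid_weight`, `contrast_weight`, and `ortho_weight` arguments shown under Advanced Usage; the individual formulations (MSE reconstruction, cosine-similarity contrastive term, cross-correlation orthogonality penalty) are illustrative assumptions, not PEARL's exact implementation:

```python
import torch.nn.functional as F

def combined_loss(x, x_recon, signal, noise, centroids, labels,
                  recon_weight=1.0, centroid_weight=2.0,
                  contrast_weight=0.5, ortho_weight=0.5):
    """Illustrative multi-task objective; not PEARL's exact formulation.

    x:         [B, D]  input embeddings
    x_recon:   [B, D]  reconstruction from signal + noise
    signal:    [B, S]  extracted signal representations
    noise:     [B, S'] extracted noise representations
    centroids: [C, S]  per-class centroids in signal space
    labels:    [B]     integer class labels
    """
    # Reconstruction: signal + noise should reproduce the input embedding
    recon = F.mse_loss(x_recon, x)

    # Centroid alignment: pull each signal vector toward its class centroid
    centroid = F.mse_loss(signal, centroids[labels])

    # Contrastive: classify signals by cosine similarity to all centroids,
    # pushing them toward their own class and away from the others
    logits = F.normalize(signal, dim=1) @ F.normalize(centroids, dim=1).T
    contrast = F.cross_entropy(logits / 0.1, labels)  # 0.1 = temperature (illustrative)

    # Orthogonality: penalize cross-correlation between signal and noise
    s = signal - signal.mean(dim=0)
    n = noise - noise.mean(dim=0)
    ortho = (s.T @ n).pow(2).mean()

    return (recon_weight * recon + centroid_weight * centroid
            + contrast_weight * contrast + ortho_weight * ortho)
```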

### 2. Prototype-Guided Augmentation (PAF)

PAF augments embeddings with rich features based on learned prototypes:
- Maximum similarity to per-class prototypes
- Mean similarity to per-class prototypes
- Similarity to class centroids
- Decision margin (confidence)
- Prediction entropy (uncertainty)

These features provide powerful additional signal for downstream classifiers.
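
For intuition, the following NumPy sketch shows one way such features could be computed. The per-class k-means prototypes, cosine similarity, and softmax-based margin/entropy are simplifying assumptions, not the exact `PAFAugmentor` internals:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize

def paf_features(X_train, y_train, X, n_prototypes_per_class=3):
    """Illustrative prototype-guided features; not the library's exact implementation."""
    classes = np.unique(y_train)

    # Prototypes: k-means centers within each class; centroids: per-class means
    prototypes = {c: KMeans(n_clusters=n_prototypes_per_class, n_init=10,
                            random_state=0).fit(X_train[y_train == c]).cluster_centers_
                  for c in classes}
    centroids = np.stack([X_train[y_train == c].mean(axis=0) for c in classes])

    Xn = normalize(X)
    max_sim, mean_sim = [], []
    for c in classes:
        sims = Xn @ normalize(prototypes[c]).T        # cosine sim to class-c prototypes
        max_sim.append(sims.max(axis=1))
        mean_sim.append(sims.mean(axis=1))
    max_sim = np.stack(max_sim, axis=1)               # [N, n_classes]
    mean_sim = np.stack(mean_sim, axis=1)             # [N, n_classes]
    centroid_sim = Xn @ normalize(centroids).T        # [N, n_classes]

    # Decision margin (top-1 minus top-2 centroid similarity) and prediction entropy
    top2 = np.sort(centroid_sim, axis=1)[:, -2:]
    margin = (top2[:, 1] - top2[:, 0])[:, None]
    probs = np.exp(centroid_sim)
    probs /= probs.sum(axis=1, keepdims=True)
    entropy = -(probs * np.log(probs + 1e-12)).sum(axis=1, keepdims=True)

    # Augmented embedding: original features plus all prototype-based features
    return np.hstack([X, max_sim, mean_sim, centroid_sim, margin, entropy])
```

In the full pipeline, `mode='paf'` applies prototype features on top of the signal-extracted (enhanced) embeddings rather than the raw ones.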

## Transformation Modes

PEARL supports three transformation modes:

```python
# Mode 1: Raw (no transformation)
X_raw = pearl.transform(X, mode='raw')

# Mode 2: Enhanced (signal extraction only)
X_enhanced = pearl.transform(X, mode='enhanced')

# Mode 3: PAF (enhanced + prototype features) - RECOMMENDED
X_paf = pearl.transform(X, mode='paf')
```

## Advanced Usage

### Custom Configuration

```python
from pearl import PEARLPipeline

pearl = PEARLPipeline(
    n_classes=10,
    input_dim=768,              # Auto-detected if None
    signal_dim=256,             # Signal representation dimension
    hidden_dims=(512, 384),     # Hidden layers for encoder
    n_prototypes_per_class=3,   # Prototypes per class
    device='cuda',
    dropout=0.3,
    random_state=42
)

# Fine-tune training parameters
pearl.fit(
    X_train, y_train,
    X_val, y_val,
    lr=1e-3,
    weight_decay=1e-4,
    batch_size=128,
    epochs=100,
    patience=20,
    recon_weight=1.0,       # Reconstruction loss weight
    centroid_weight=2.0,    # Centroid loss weight
    contrast_weight=0.5,    # Contrastive loss weight
    ortho_weight=0.5,       # Orthogonality loss weight
    verbose=True
)
```

### Save and Load Pipeline

```python
# Save trained pipeline
pearl.save('./my_pearl_model')

# Load pipeline
from pearl import PEARLPipeline
pearl = PEARLPipeline.load('./my_pearl_model', device='cuda')

# Use immediately
X_enhanced = pearl.transform(X_test, mode='enhanced')
```

### Using Individual Components

```python
from pearl import SignalExtractorTrainer, PAFAugmentor, SignalExtractor

# 1. Signal Extraction
model = SignalExtractor(input_dim=768, signal_dim=256, n_classes=10)
trainer = SignalExtractorTrainer(model, device='cuda')
trainer.fit(X_train, y_train, X_val, y_val)
X_enhanced = trainer.transform(X_test)

# 2. PAF Features
paf = PAFAugmentor(n_classes=10, n_prototypes_per_class=3)
paf.fit(X_train, y_train)
X_paf = paf.transform(X_test)  # Augmented embeddings
```

### Using with RAG Classifier

PEARL also includes a RAG-style (retrieval-augmented) classifier that predicts labels using nearest neighbors retrieved from the training set:

```python
from pearl import RAGClassifierWrapper

rag = RAGClassifierWrapper(
    embed_dim=768,
    n_classes=10,
    k=8,  # Number of neighbors to retrieve
    device='cuda'
)

rag.fit(X_train, y_train, X_val, y_val)
predictions = rag.predict(X_test)
```

## Examples

### Text Classification with BERT

```python
from transformers import AutoTokenizer, AutoModel
from pearl import PEARLPipeline
import torch

# Extract BERT embeddings
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')

def get_embeddings(texts):
    # Tokenizes the whole list as one batch; for large corpora, split into mini-batches
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the [CLS] token representation as the sentence embedding
    return outputs.last_hidden_state[:, 0, :].numpy()

# Get embeddings
X_train = get_embeddings(train_texts)
X_test = get_embeddings(test_texts)

# Apply PEARL
pearl = PEARLPipeline(n_classes=num_classes, device='cuda')
pearl.fit(X_train, y_train)

X_train_enhanced = pearl.transform(X_train, mode='paf')
X_test_enhanced = pearl.transform(X_test, mode='paf')

# Train classifier
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression()
clf.fit(X_train_enhanced, y_train)
accuracy = clf.score(X_test_enhanced, y_test)
```

See the [`examples/`](examples/) directory for complete working examples:
- [`basic_usage.py`](examples/basic_usage.py): Simple synthetic data example
- [`text_classification.py`](examples/text_classification.py): Real-world text classification with BERT

## Performance

PEARL consistently improves classification performance across classifiers; for example, on AG News:

| Dataset | Classifier | Raw F1 | PEARL F1 | Improvement (F1 points) |
|---------|-----------|--------|----------|-------------------------|
| AG News | Logistic  | 0.8542 | 0.8876   | +3.34                   |
| AG News | SVM       | 0.8621 | 0.8912   | +2.91                   |
| AG News | MLP       | 0.8698 | 0.9045   | +3.47                   |
| AG News | RAG       | 0.8823 | 0.9156   | +3.33                   |

Results show consistent improvements across different classifiers and datasets.

## API Reference

### PEARLPipeline

Main interface for PEARL.

**Methods:**
- `fit(X_train, y_train, X_val, y_val, **kwargs)`: Train the pipeline
- `transform(X, mode='paf')`: Transform embeddings
- `fit_transform(X, y, **kwargs)`: Fit and transform in one step
- `save(path)`: Save pipeline to disk
- `load(path, device)`: Load pipeline from disk (class method)

### SignalExtractor

Neural network for signal extraction.

**Methods:**
- `forward(x)`: Forward pass returning all outputs
- `get_enhanced_embedding(x)`: Extract enhanced embedding

### PrototypeFeatures

Prototype-based feature generator.

**Methods:**
- `fit(embeddings, labels)`: Learn prototypes from training data
- `transform(embeddings)`: Generate prototype features
- `get_augmented(embeddings)`: Get embeddings + features
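
A minimal usage sketch based only on the methods listed above (the top-level import path and default constructor are assumptions; adjust to your installation):

```python
from pearl import PrototypeFeatures  # import path assumed

proto = PrototypeFeatures()              # constructor arguments omitted for brevity
proto.fit(X_train, y_train)              # learn prototypes from training embeddings
features = proto.transform(X_test)       # prototype features only
augmented = proto.get_augmented(X_test)  # embeddings concatenated with the features
```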

### RAGClassifierWrapper

Retrieval-augmented classifier.

**Methods:**
- `fit(X_train, y_train, X_val, y_val, **kwargs)`: Train the model
- `predict(X)`: Predict class labels
- `predict_proba(X)`: Predict class probabilities

## Requirements

- Python >= 3.8
- PyTorch >= 1.12.0
- NumPy >= 1.20.0
- scikit-learn >= 1.0.0
- pandas >= 1.3.0
- openpyxl >= 3.0.0

## Citation

If you use PEARL in your research, please cite:

```bibtex
@software{pearl2024,
  title={PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning},
  author={PEARL Contributors},
  year={2024},
  url={https://github.com/yourusername/pearl}
}
```

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.

### Development Setup

```bash
git clone https://github.com/yourusername/pearl.git
cd pearl
pip install -e ".[dev]"

# Run tests
pytest tests/

# Format code
black pearl/

# Type checking
mypy pearl/
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Acknowledgments

PEARL was developed to address the challenge of enhancing learned embeddings for downstream classification tasks. It builds on ideas from:
- Signal processing and denoising
- Prototype-based learning
- Multi-task learning
- Retrieval-augmented generation

## Support

- **Issues**: [GitHub Issues](https://github.com/yourusername/pearl/issues)
- **Discussions**: [GitHub Discussions](https://github.com/yourusername/pearl/discussions)
- **Email**: pearl@example.com

## Roadmap

- [ ] Support for additional embedding types (images, audio)
- [ ] Pre-trained models for common datasets
- [ ] Integration with popular frameworks (HuggingFace, PyTorch Lightning)
- [ ] Online/incremental learning support
- [ ] Multi-label classification support
- [ ] Comprehensive benchmarking suite

---

Made with ❤️ by the PEARL team
