Metadata-Version: 2.4
Name: mpdistil
Version: 0.1.0
Summary: Meta-Policy Knowledge Distillation for compact student models
Home-page: https://github.com/parmanu-lcs2/mpdistil/
Author: Parmanu, LCS2, IIT Delhi
Author-email: 
License: MIT
Project-URL: Homepage, https://github.com/parmanu-lcs2/mpdistil
Project-URL: Repository, https://github.com/parmanu-lcs2/mpdistil
Keywords: knowledge-distillation,meta-learning,deep-learning,nlp,transformers
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.13.0
Requires-Dist: transformers>=4.28.0
Requires-Dist: datasets>=2.12.0
Requires-Dist: numpy>=1.21.0
Requires-Dist: pandas>=1.3.0
Requires-Dist: scikit-learn>=1.0.0
Requires-Dist: scipy>=1.7.0
Requires-Dist: tqdm>=4.64.0
Requires-Dist: wandb>=0.15.0
Requires-Dist: matplotlib>=3.5.0
Dynamic: home-page
Dynamic: license-file
Dynamic: requires-python

# MPDistil 🎓

**Meta-Policy Knowledge Distillation for Training Compact Student Models**

[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch](https://img.shields.io/badge/PyTorch-1.13+-red.svg)](https://pytorch.org/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

MPDistil is a teacher-student collaborative knowledge distillation framework that enables compact student models to outperform larger teacher models through meta-learning and curriculum learning.

Based on the ICLR 2024 paper: [**A Good Learner can Teach Better: Teacher-Student Collaborative Knowledge Distillation**](https://openreview.net/forum?id=Ixi4j6LtdX)

## 🌟 Key Features

- **📊 Superior Performance**: 6-layer BERT student outperforms 12-layer BERT teacher on 5/6 SuperGLUE tasks
- **🎯 4-Phase Training**: Teacher fine-tuning → PKD → Meta-teacher → Curriculum learning
- **🚀 Simple API**: Easy-to-use `.train()` method with full control over all phases
- **📏 Flexible Metrics**: Built-in support for accuracy, F1, MCC, correlation via HuggingFace `evaluate`
- **🔧 Customizable**: Works with any HuggingFace model and custom datasets
- **💻 Colab-Ready**: GPU-optimized for Google Colab environments
- **📦 Easy Installation**: Single pip command to get started

## 📈 Methodology

<p align="center">
  <img width="720" alt="methodology" src="https://github.com/notmyname16/MPDistil/assets/88495622/e9444194-08fd-43ad-880b-94232302f449">
</p>

MPDistil consists of 4 training phases:

1. **Teacher Fine-tuning**: Fine-tune teacher model on the target task
2. **Student PKD**: Knowledge distillation with Patient Knowledge Distillation
3. **Meta-Teacher Learning**: Collaborative or competitive loss for meta-learning
4. **Curriculum Learning**: Reinforcement learning-based task selection

## 🚀 Installation

### From GitHub (Recommended)

```bash
pip install git+https://github.com/parmanu-lcs2/mpdistil.git
```

### From Source

```bash
git clone https://github.com/parmanu-lcs2/mpdistil.git
cd mpdistil
pip install -e .
```

## 💡 Quick Start

### Basic Usage

```python
from mpdistil import MPDistil, load_superglue_dataset

# Load data
loaders, num_labels = load_superglue_dataset('CB')

# Initialize MPDistil
model = MPDistil(
    task_name='CB',
    num_labels=num_labels,
    metric='f1',  # Options: 'accuracy', 'f1', 'mcc', 'correlation', 'auto'
    teacher_model='bert-base-uncased',
    student_model='bert-base-uncased',
    student_layers=6
)

# Train with all 4 phases
history = model.train(
    train_loader=loaders['train'],
    val_loader=loaders['val'],
    teacher_epochs=10,   # Phase 1
    student_epochs=10,   # Phase 2
    meta_epochs=1        # Phase 3 (NEW!)
)

# Save trained student
model.save_student('./my_student_model')

# Make predictions
predictions = model.predict(loaders['test'])
```

### With Custom Data

```python
from mpdistil import MPDistil, create_simple_dataloader

# Prepare your data
texts = [("This is text A", "This is text B"), ...]
labels = [0, 1, 0, ...]

# Create DataLoader
train_loader = create_simple_dataloader(
    texts=texts,
    labels=labels,
    tokenizer_name='bert-base-uncased',
    max_length=128,
    batch_size=8
)

# Train model
model = MPDistil(task_name='MyTask', num_labels=2, metric='accuracy')
history = model.train(train_loader, val_loader)
```

### With Meta-Learning (Curriculum)

```python
# Load multiple tasks for curriculum learning
cb_loaders, _ = load_superglue_dataset('CB')
rte_loaders, _ = load_superglue_dataset('RTE')
boolq_loaders, _ = load_superglue_dataset('BoolQ')

# Train with curriculum learning
history = model.train(
    train_loader=cb_loaders['train'],
    val_loader=cb_loaders['val'],
    meta_loaders={
        'RTE': rte_loaders['val'],
        'BoolQ': boolq_loaders['val']
    },
    teacher_epochs=10,   # Phase 1
    student_epochs=10,   # Phase 2  
    meta_epochs=3,       # Phase 3 - can train for multiple epochs!
    num_episodes=200     # Phase 4 - curriculum learning episodes
)
```

## 📖 API Reference

### MPDistil Class

#### Constructor

```python
MPDistil(
    task_name: str,              # Name of the main task
    num_labels: int,             # Number of output classes
    metric: str = 'accuracy',    # Metric: 'accuracy', 'f1', 'mcc', 'correlation', 'auto'
    teacher_model: str = 'bert-base-uncased',  # HuggingFace model name
    student_model: str = 'bert-base-uncased',  # HuggingFace model name
    student_layers: int = 6,     # Number of layers for student
    device: str = 'auto',        # 'auto', 'cuda', or 'cpu'
    output_dir: str = './mpdistil_outputs'
)
```

#### Methods

| Method | Description |
|--------|-------------|
| `train(train_loader, val_loader, **kwargs)` | Train the model (all 4 phases) |
| `predict(test_loader)` | Generate predictions |
| `save_student(path)` | Save student model in HuggingFace format |
| `load_student(path)` | Load a saved student model |
| `save_predictions(predictions, path, label_mapping)` | Save predictions to TSV |

### TrainingConfig

Configure training hyperparameters:

```python
from mpdistil import TrainingConfig

config = TrainingConfig(
    # Phase 1: Teacher
    teacher_epochs=10,
    teacher_lr=2e-5,
    
    # Phase 2: Student PKD
    student_epochs=10,
    student_lr=3e-5,
    alpha=0.5,          # Soft loss weight
    beta=100.0,         # PKD loss weight
    temperature=5.0,    # Distillation temperature
    
    # Phase 3: Meta-Teacher
    meta_epochs=3,      # NEW! Meta-teacher can train for multiple epochs
    meta_lr=1e-3,
    use_competitive_loss=False,  # Use collaborative loss
    
    # Phase 4: Curriculum
    num_episodes=200,
    reward_type='binary',  # or 'real'
    
    # General
    batch_size=8,
    seed=42,
    report_to=None      # Options: 'wandb', 'tensorboard', None
)

history = model.train(train_loader, val_loader, config=config)
```

### Training Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `teacher_epochs` | int | 10 | Phase 1: Teacher training epochs |
| `student_epochs` | int | 10 | Phase 2: Student training epochs |
| `meta_epochs` | int | 1 | Phase 3: Meta-teacher training epochs (NEW!) |
| `num_episodes` | int | 200 | Phase 4: Curriculum learning episodes |
| `teacher_lr` | float | 2e-5 | Teacher learning rate |
| `student_lr` | float | 3e-5 | Student learning rate |
| `meta_lr` | float | 1e-3 | Meta-learning rate |
| `alpha` | float | 0.5 | Soft vs hard loss weight |
| `beta` | float | 100.0 | PKD loss weight |
| `temperature` | float | 5.0 | Distillation temperature |
| `use_competitive_loss` | bool | False | Competitive vs collaborative |
| `reward_type` | str | 'binary' | 'binary' or 'real' |
| `batch_size` | int | 8 | Batch size |
| `seed` | int | 42 | Random seed |

## 📊 Results

Performance on SuperGLUE tasks (BERT-base teacher → BERT-6L student):

| Model | BoolQ | CB | COPA | RTE | WiC | WSC |
|-------|-------|-------|------|-----|-----|-----|
| Teacher | 75.3 | 83.9 | 63.0 | 67.1 | 57.1 | 64.4 |
| Student (Undistilled) | 71.6 | 75.0 | 53.0 | 64.6 | 56.0 | 63.5 |
| **MPDistil (Ours)** | **73.4** | **83.9** | **70.0** | **67.5** | **59.6** | **65.4** |

✨ **Student outperforms teacher on 5/6 tasks!**

## 📝 Examples

See the `examples/` directory for Jupyter notebooks:

- **[GLUE.ipynb](examples/GLUE.ipynb)**: Usage with SuperGLUE

## 📊 Evaluation Metrics

MPDistil supports multiple evaluation metrics via HuggingFace `evaluate` library:

### Available Metrics

| Metric | Use Case | Returns |
|--------|----------|---------|
| `'accuracy'` | Standard classification (default) | `{'acc': float}` |
| `'f1'` | Imbalanced datasets, multi-class | `{'acc': float, 'f1': float, 'acc_and_f1': float}` |
| `'mcc'` | Binary classification, imbalanced | `{'mcc': float}` |
| `'correlation'` | Regression tasks (STS-B) | `{'pearson': float, 'spearmanr': float}` |
| `'auto'` | Auto-detect based on task | Task-specific metric |

### Example: Using Different Metrics

```python
# Accuracy (default)
model = MPDistil(task_name='BoolQ', num_labels=2, metric='accuracy')

# F1 score (recommended for CB, MultiRC)
model = MPDistil(task_name='CB', num_labels=3, metric='f1')

# Matthews Correlation (for binary with imbalance)
model = MPDistil(task_name='CoLA', num_labels=2, metric='mcc')

# Correlation (for regression)
model = MPDistil(task_name='STS-B', num_labels=1, metric='correlation', is_regression=True)

# Auto-detect
model = MPDistil(task_name='CB', num_labels=3, metric='auto')  # Uses F1 for CB
```

## 🔬 How It Works

### Phase 1: Teacher Fine-tuning
Fine-tune a large teacher model (e.g., BERT-base) on your target task.

### Phase 2: Student PKD
Train a smaller student model using:
- **Soft targets** from teacher (KL divergence)
- **Hard labels** (cross-entropy)
- **Patient KD** (intermediate feature matching)

Loss: `α * soft_loss + (1-α) * hard_loss + β * pkd_loss`

### Phase 3: Meta-Teacher
Create a meta-teacher that learns from both teacher and student representations:

**Collaborative loss** (default):
```
L = 0.5 * CE(T'(h_teacher), y) + 0.5 * CE(T'(h_student), y)
```

**Competitive loss**:
```
L = -mean(P_teacher) + mean(P_student) + CE_loss
```

### Phase 4: Curriculum Learning
Use reinforcement learning to select which auxiliary tasks help the student learn:
- **Action model** selects next task
- **Reward** based on student improvement over teacher
- **REINFORCE** algorithm updates policy

## 🛠️ Advanced Usage

### Custom Model Architectures

```python
model = MPDistil(
    task_name='MyTask',
    num_labels=3,
    teacher_model='roberta-large',
    student_model='distilbert-base-uncased',
    student_layers=6
)
```

### Weights & Biases Logging

```python
history = model.train(
    train_loader=train_loader,
    val_loader=val_loader,
    report_to='wandb',      # Options: 'wandb', 'tensorboard', None
    wandb_project='my-project',
    wandb_run_name='experiment-1'
)
```

### Access Trained Models

```python
# Access student model
student = model.student

# Access teacher model
teacher = model.teacher

# Use with HuggingFace
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('./my_student_model')
model = AutoModel.from_pretrained('./my_student_model')
```

## 📚 Citation

If you use MPDistil in your research, please cite:

```bibtex
@inproceedings{sengupta2024mpdistil,
  title={A Good Learner can Teach Better: Teacher-Student Collaborative Knowledge Distillation},
  author={Sengupta, Ayan and Dixit, Shantanu and Akhtar, Md Shad and Chakraborty, Tanmoy},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024},
  url={https://openreview.net/forum?id=Ixi4j6LtdX}
}
```

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## 🙏 Acknowledgments

- Original paper: [ICLR 2024](https://openreview.net/forum?id=Ixi4j6LtdX)
- Built with [PyTorch](https://pytorch.org/) and [HuggingFace Transformers](https://huggingface.co/transformers/)

## 📧 Contact

For questions or issues, please open an issue on GitHub or contact the authors.

---

Made with ❤️ for the research community
