Metadata-Version: 2.4
Name: aitraining
Version: 0.0.51
Summary: Advanced Machine Learning Training Platform - IN DEVELOPMENT
Author-email: Andrew Correa <andrew@monostate.ai>
License: Apache 2.0
Keywords: automl,autonlp,autotrain,huggingface,machine learning,deep learning,fine-tuning,llm
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.10.0
Requires-Dist: transformers>=5.3.0
Requires-Dist: accelerate>=1.13.0
Requires-Dist: peft>=0.18.1
Requires-Dist: trl>=0.29.0
Requires-Dist: datasets[vision]~=3.2.0
Requires-Dist: tensorboard==2.18.0
Requires-Dist: scikit-learn==1.6.0
Requires-Dist: evaluate==0.4.3
Requires-Dist: sentence-transformers>=5.2.3
Requires-Dist: albumentations==1.4.23
Requires-Dist: Pillow==11.0.0
Requires-Dist: pandas==2.2.3
Requires-Dist: numpy
Requires-Dist: sentencepiece>=0.1.99
Requires-Dist: tiktoken==0.8.0
Requires-Dist: nltk==3.9.1
Requires-Dist: sacremoses==0.1.1
Requires-Dist: seqeval==1.2.2
Requires-Dist: ipadic==1.0.0
Requires-Dist: timm==1.0.22
Requires-Dist: torchmetrics==1.6.0
Requires-Dist: faster-coco-eval==1.7.0
Requires-Dist: pycocotools==2.0.8
Requires-Dist: optuna==4.1.0
Requires-Dist: xgboost==2.1.3
Requires-Dist: huggingface_hub>=1.6.0
Requires-Dist: hf-transfer
Requires-Dist: fastapi>=0.129.0
Requires-Dist: starlette
Requires-Dist: uvicorn==0.34.0
Requires-Dist: python-multipart==0.0.20
Requires-Dist: werkzeug==3.1.3
Requires-Dist: pyngrok==7.2.1
Requires-Dist: authlib==1.4.0
Requires-Dist: itsdangerous==2.2.0
Requires-Dist: httpx==0.28.1
Requires-Dist: loguru==0.7.3
Requires-Dist: requests==2.32.3
Requires-Dist: pydantic>=2.12.5
Requires-Dist: pyyaml==6.0.2
Requires-Dist: tqdm==4.67.1
Requires-Dist: joblib==1.4.2
Requires-Dist: einops==0.8.0
Requires-Dist: packaging==24.2
Requires-Dist: cryptography==44.0.0
Requires-Dist: nvitop==1.3.2
Requires-Dist: py7zr==0.22.0
Requires-Dist: safetensors
Requires-Dist: psutil
Requires-Dist: diffusers
Requires-Dist: ipywidgets
Requires-Dist: wandb>=0.23.0
Requires-Dist: textual==0.86.1
Requires-Dist: rich==13.9.4
Requires-Dist: rouge_score==0.1.2
Requires-Dist: jiwer==3.0.5
Requires-Dist: bitsandbytes>=0.45.0; sys_platform == "linux"
Provides-Extra: dev
Requires-Dist: black; extra == "dev"
Requires-Dist: isort; extra == "dev"
Requires-Dist: flake8; extra == "dev"
Requires-Dist: pytest; extra == "dev"
Requires-Dist: pytest-asyncio; extra == "dev"
Provides-Extra: test
Requires-Dist: pytest>=7.0.0; extra == "test"
Requires-Dist: pytest-asyncio; extra == "test"
Requires-Dist: playwright>=1.40.0; extra == "test"
Requires-Dist: httpx>=0.25.0; extra == "test"
Provides-Extra: vllm
Requires-Dist: vllm>=0.14.0; extra == "vllm"
Provides-Extra: quality
Requires-Dist: black; extra == "quality"
Requires-Dist: isort; extra == "quality"
Requires-Dist: flake8; extra == "quality"
Provides-Extra: docs
Requires-Dist: recommonmark; extra == "docs"
Requires-Dist: sphinx; extra == "docs"
Requires-Dist: sphinx-markdown-tables; extra == "docs"
Requires-Dist: sphinx-rtd-theme; extra == "docs"
Requires-Dist: sphinx-copybutton; extra == "docs"
Dynamic: license-file

<p align="center">
  <img src="https://raw.githubusercontent.com/monostate/aitraining/main/docs/images/terminal-wizard.png" alt="AITraining Interactive Wizard" width="700">
</p>

<p align="center">
  <a href="https://pypi.org/project/aitraining/"><img src="https://img.shields.io/pypi/v/aitraining.svg" alt="PyPI version"></a>
  <a href="https://pypi.org/project/aitraining/"><img src="https://img.shields.io/pypi/pyversions/aitraining.svg" alt="Python versions"></a>
  <a href="https://github.com/monostate/aitraining/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="License"></a>
  <a href="https://docs.monostate.com"><img src="https://img.shields.io/badge/docs-monostate.com-FF6B35.svg" alt="Documentation"></a>
  <a href="https://deepwiki.com/monostate/aitraining"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
</p>

<p align="center">
  <b>Train state-of-the-art ML models with minimal code</b>
</p>

<p align="center">
  English | <a href="README_PTBR.md">Português</a>
</p>

<p align="center">
  📚 <b><a href="https://docs.monostate.com">Full Documentation →</a></b>
</p>

---

> **📖 Comprehensive Documentation Available**
> 
> Visit **[docs.monostate.com](https://docs.monostate.com)** for detailed guides, tutorials, API reference, and examples covering all features including LLM fine-tuning, PEFT/LoRA, DPO/ORPO training, hyperparameter sweeps, and more.

---

AITraining is an advanced machine learning training platform built on top of [AutoTrain Advanced](https://github.com/huggingface/autotrain-advanced). It provides a streamlined interface for fine-tuning LLMs, vision models, and more.

## Highlights

### Automatic Dataset Conversion

Feed a dataset in any supported format and AITraining detects and converts it automatically. Six input formats are recognized:

| Format | Detection | Example Columns |
|--------|-----------|-----------------|
| **Alpaca** | instruction/input/output | `{"instruction": "...", "output": "..."}` |
| **ShareGPT** | from/value pairs | `{"conversations": [{"from": "human", ...}]}` |
| **Messages** | role/content | `{"messages": [{"role": "user", ...}]}` |
| **Q&A** | question/answer variants | `{"question": "...", "answer": "..."}` |
| **DPO** | prompt/chosen/rejected | For preference training |
| **Plain Text** | Single text column | Raw text for pretraining |

```bash
aitraining llm --train --auto-convert-dataset --chat-template gemma3 \
  --data-path tatsu-lab/alpaca --model google/gemma-3-270m-it
```
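As a rough illustration, this kind of detection can be sketched as a check over a sample record's column names. The function below is hypothetical, mirroring the table above rather than AITraining's actual implementation:

```python
def detect_format(record: dict) -> str:
    """Guess a dataset format from a sample record's keys (illustrative sketch)."""
    keys = set(record)
    if {"prompt", "chosen", "rejected"} <= keys:
        return "dpo"            # preference triplets
    if "instruction" in keys and "output" in keys:
        return "alpaca"         # instruction/input/output
    if "conversations" in keys:
        return "sharegpt"       # from/value turn pairs
    if "messages" in keys:
        return "messages"       # role/content turns
    if "question" in keys and "answer" in keys:
        return "qa"             # Q&A variants
    if keys == {"text"}:
        return "plain_text"     # raw text for pretraining
    return "unknown"
```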

### 32 Chat Templates

Comprehensive template library with token-level weight control:

- **Llama family**: llama, llama-3, llama-3.1
- **Gemma family**: gemma, gemma-2, gemma-3, gemma-3n
- **Others**: mistral, qwen-2.5, phi-3, phi-4, chatml, alpaca, vicuna, zephyr

```python
from autotrain.rendering import get_renderer, ChatFormat, RenderConfig

config = RenderConfig(format=ChatFormat.CHATML, only_assistant=True)
renderer = get_renderer('chatml', tokenizer, config)
encoded = renderer.build_supervised_example(conversation)
# Returns: {'input_ids', 'labels', 'token_weights', 'attention_mask'}
```
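For reference, the ChatML layout this renderer targets can be sketched in a few lines. The real renderer additionally handles tokenization, token-level weights, and attention masks:

```python
def render_chatml(messages: list[dict]) -> str:
    """Render role/content messages in the standard ChatML layout (sketch only)."""
    return "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
```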

### GRPO Training with Custom Environments

Train with Group Relative Policy Optimization (GRPO) using your own reward environment. The environment supplies prompts and scores multi-turn episodes; GRPO generates completions, scores them through the environment, and optimizes the policy:

```bash
aitraining llm --train --trainer grpo \
  --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
  --rl-env-module my_envs.hotel_env \
  --rl-env-class HotelEnv \
  --rl-num-generations 4 \
  --rl-max-new-tokens 256
```

Your environment implements three methods:

```python
class HotelEnv:
    def build_dataset(self, tokenizer) -> Dataset:
        """Return HF Dataset with 'prompt' column."""

    def score_episode(self, model, tokenizer, completion, case_idx) -> float:
        """Run multi-turn episode from completion, return 0.0-1.0 score."""

    def get_tools(self) -> list[dict]:
        """Return tool schemas for generation (optional)."""
```
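A toy environment satisfying the `score_episode` contract might look like the sketch below. The `ToyHotelEnv` name and keyword-based scoring are illustrative only; a real environment would also implement `build_dataset` and drive the multi-turn episode against the model:

```python
class ToyHotelEnv:
    """Hypothetical reward environment sketch for the score_episode contract."""

    CASES = [
        {"prompt": "Book a room for 2 nights.", "required": ["confirm", "nights"]},
        {"prompt": "Cancel my reservation.", "required": ["cancel"]},
    ]

    def score_episode(self, model, tokenizer, completion: str, case_idx: int) -> float:
        """Score 0.0-1.0: fraction of required keywords present in the completion."""
        required = self.CASES[case_idx]["required"]
        hits = sum(1 for kw in required if kw in completion.lower())
        return hits / len(required)
```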

### Custom RL Environments (PPO)

Build custom reward functions for PPO training with three environment types:

```bash
# Text generation with custom reward
aitraining llm --train --trainer ppo \
  --rl-env-type text_generation \
  --rl-env-config '{"stop_sequences": ["</s>"]}' \
  --rl-reward-model-path ./reward_model

# Multi-objective rewards (correctness + formatting)
aitraining llm --train --trainer ppo \
  --rl-env-type multi_objective \
  --rl-env-config '{"reward_components": {"correctness": {"type": "keyword"}, "formatting": {"type": "length"}}}' \
  --rl-reward-weights '{"correctness": 1.0, "formatting": 0.1}'
```
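The combination implied by `--rl-reward-weights` amounts to a weighted sum of the component scores. A minimal sketch, assuming a simple linear aggregation (not necessarily the exact internal formula):

```python
def combine_rewards(components: dict[str, float], weights: dict[str, float]) -> float:
    """Weighted sum of reward components; missing weights default to 1.0."""
    return sum(weights.get(name, 1.0) * score for name, score in components.items())
```

With the weights from the example above, a correctness score of 1.0 and a formatting score of 0.5 would combine to 1.0 × 1.0 + 0.1 × 0.5 = 1.05.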

### Hyperparameter Sweeps

Automated optimization with Optuna, random search, or grid search:

```python
from autotrain.utils import HyperparameterSweep, SweepConfig, ParameterRange

config = SweepConfig(
    backend="optuna",
    optimization_metric="eval_loss",
    optimization_mode="minimize",
    num_trials=20,
)

sweep = HyperparameterSweep(
    objective_function=train_model,
    config=config,
    parameters=[
        ParameterRange("learning_rate", "log_uniform", low=1e-5, high=1e-3),
        ParameterRange("batch_size", "categorical", choices=[4, 8, 16]),
    ]
)
result = sweep.run()
# Returns best_params, best_value, trial history
```
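For intuition, `log_uniform` sampling draws uniformly in log space, so each order of magnitude between `low` and `high` is equally likely — the right behavior for learning rates. A minimal sketch (not the library's implementation):

```python
import math
import random

def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Sample uniformly in log space between low and high."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))
```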

### Enhanced Evaluation Metrics

8 metrics beyond loss, with callbacks for periodic evaluation:

| Metric | Type | Use Case |
|--------|------|----------|
| **Perplexity** | Auto-computed | Language model quality |
| **BLEU** | Generation | Translation, summarization |
| **ROUGE** (1/2/L) | Generation | Summarization |
| **BERTScore** | Generation | Semantic similarity |
| **METEOR** | Generation | Translation |
| **F1/Accuracy** | Classification | Standard metrics |
| **Exact Match** | QA | Question answering |

```python
from autotrain.evaluation import Evaluator, EvaluationConfig, MetricType

config = EvaluationConfig(
    metrics=[MetricType.PERPLEXITY, MetricType.BLEU, MetricType.ROUGE, MetricType.BERTSCORE],
    save_predictions=True,
)
evaluator = Evaluator(model, tokenizer, config)
result = evaluator.evaluate(dataset)
```
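Perplexity, listed as auto-computed above, follows the standard definition: the exponential of the mean per-token negative log-likelihood. A minimal sketch of that formula:

```python
import math

def perplexity(token_nlls: list[float]) -> float:
    """Perplexity = exp(mean negative log-likelihood per token)."""
    return math.exp(sum(token_nlls) / len(token_nlls))
```

A model that assigns every token probability 1/4 (NLL = ln 4) has perplexity 4.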

### Auto LoRA Merge

After PEFT training, AITraining can automatically merge LoRA adapters into the base model and save a deployment-ready checkpoint:

```bash
# Default: merges adapters into full model
aitraining llm --train --peft --model meta-llama/Llama-3.2-1B

# Keep adapters separate (smaller files)
aitraining llm --train --peft --no-merge-adapter --model meta-llama/Llama-3.2-1B
```

---

## Screenshots

<p align="center">
  <img src="https://raw.githubusercontent.com/monostate/aitraining/main/docs/images/chat-screenshot.png" alt="Chat interface for testing trained models" width="700">
  <br>
  <em>Built-in chat interface for testing trained models with conversation history</em>
</p>

<p align="center">
  <img src="https://raw.githubusercontent.com/monostate/aitraining/main/docs/images/tui-wandb.png" alt="Terminal UI with W&B LEET integration" width="700">
  <br>
  <em>Terminal UI with real-time W&B LEET metrics visualization</em>
</p>

---

## Installation

```bash
pip install aitraining
```

Requirements: Python >= 3.8, PyTorch >= 1.10

## Quick Start

### Interactive Wizard

```bash
aitraining
```

The wizard guides you through:
1. Trainer type selection (LLM, vision, NLP, tabular)
2. Model selection with curated catalogs from HuggingFace
3. Dataset configuration with auto-format detection
4. Advanced parameters (PEFT, quantization, sweeps)

### Config File

```bash
aitraining --config config.yaml
```

### Python API

```python
from autotrain.trainers.clm import train
from autotrain.trainers.clm.params import LLMTrainingParams

config = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="your-dataset",
    trainer="sft",
    epochs=3,
    batch_size=4,
    lr=2e-5,
    peft=True,
    auto_convert_dataset=True,
    chat_template="llama3",
)

train(config)
```

---

## Comparison

### AITraining vs AutoTrain vs Tinker

| Feature | AutoTrain | AITraining | Tinker |
|---------|-----------|------------|--------|
| **Trainers** | | | |
| SFT/DPO/ORPO | Yes | Yes | Yes |
| PPO (RLHF) | Basic | Enhanced (TRL) | Advanced |
| GRPO | No | Yes (TRL 0.28) | Custom |
| Reward Modeling | Yes | Yes | No |
| Knowledge Distillation | No | Yes (KL + CE loss) | Yes (text-only) |
| **Data** | | | |
| Auto Format Detection | No | Yes (6 formats) | No |
| Chat Template Library | Basic | 32 templates | 5 templates |
| Runtime Column Mapping | No | Yes | No |
| Conversation Extension | No | Yes | No |
| **Training** | | | |
| Hyperparameter Sweeps | No | Yes (Optuna) | Manual |
| Custom RL Environments | No | Yes (3 types) | Yes |
| Multi-objective Rewards | No | Yes | Yes |
| Forward-Backward Pipeline | No | Yes | Yes |
| Async Off-Policy RL | No | No | Yes |
| Stream Minibatch | No | No | Yes |
| **Evaluation** | | | |
| Metrics Beyond Loss | No | 8 metrics | Manual |
| Periodic Eval Callbacks | No | Yes | Yes |
| Custom Metric Registration | No | Yes | No |
| **Interface** | | | |
| Interactive CLI Wizard | No | Yes | No |
| TUI (Experimental) | No | Yes | No |
| W&B LEET Visualizer | No | Yes | Yes |
| **Hardware** | | | |
| Apple Silicon (MPS) | Limited | Full | No |
| Quantization (int4/int8) | Yes | Yes | Unknown |
| Multi-GPU | Yes | Yes | Yes |
| **Task Coverage** | | | |
| Vision Tasks | Yes | Yes | No |
| NLP Tasks | Yes | Yes | No |
| Tabular Tasks | Yes | Yes | No |
| Tool Use Environments | No | Yes (GRPO) | Yes |
| Multiplayer RL | No | No | Yes |

---

## Supported Tasks

| Task | Trainers | Status |
|------|----------|--------|
| LLM Fine-tuning | SFT, DPO, ORPO, PPO, GRPO, Reward, Distillation | Stable |
| Text Classification | Single/Multi-label | Stable |
| Token Classification | NER, POS tagging | Stable |
| Sequence-to-Sequence | Translation, Summarization | Stable |
| Image Classification | Single/Multi-label | Stable |
| Object Detection | YOLO, DETR | Stable |
| VLM Training | Vision-Language Models | Beta |
| Tabular | XGBoost, sklearn | Stable |
| Sentence Transformers | Semantic similarity | Stable |
| Extractive QA | SQuAD format | Stable |

---

## Configuration Example

```yaml
task: llm-sft
base_model: meta-llama/Llama-3.2-1B
project_name: my-finetune

data:
  path: your-dataset
  train_split: train
  auto_convert_dataset: true
  chat_template: llama3

params:
  epochs: 3
  batch_size: 4
  lr: 2e-5
  peft: true
  lora_r: 16
  lora_alpha: 32
  quantization: int4
  mixed_precision: bf16

# Optional: hyperparameter sweep
sweep:
  enabled: true
  backend: optuna
  n_trials: 10
  metric: eval_loss
```

---

## Documentation

**📚 [docs.monostate.com](https://docs.monostate.com)** — Complete documentation with tutorials, API reference, and examples.

### Quick Links

- [Getting Started](https://docs.monostate.com/foundations/quickstart)
- [LLM Fine-tuning Guide](https://docs.monostate.com/cli/llm-training)
- [YAML Configuration](https://docs.monostate.com/cli/yaml-configs)
- [Python SDK Reference](https://docs.monostate.com/api/python-sdk)
- [Advanced Training (DPO/ORPO/PPO)](https://docs.monostate.com/advanced/dpo-training)
- [Changelog](https://docs.monostate.com/changelog)

### Local Docs

- [Interactive Wizard Guide](docs/interactive_wizard.md)
- [Dataset Formats & Conversion](docs/dataset_formats.md)
- [Trainer Reference](docs/trainers/README.md)
- [Python API](docs/api/PYTHON_API.md)
- [RL API Reference](docs/reference/RL_API_REFERENCE.md)

---

## License

Apache 2.0 - See [LICENSE](LICENSE) for details.

Based on [AutoTrain Advanced](https://github.com/huggingface/autotrain-advanced) by Hugging Face.

---

<p align="center">
  <a href="https://monostate.ai">Monostate AI</a>
</p>
