Metadata-Version: 2.4
Name: embedata
Version: 0.1.0
Summary: Unified API for loading 29 vision datasets with async embeddings from any PyTorch model
License: MIT
License-File: LICENSE
Keywords: datasets,computer-vision,embeddings,pytorch,torchvision
Author: Ayoub G.
Author-email: dev@ayghri.me
Requires-Python: >=3.10
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Image Recognition
Provides-Extra: hf
Requires-Dist: datasets (>=2.14) ; extra == "hf"
Requires-Dist: numpy (>=1.24)
Requires-Dist: torch (>=2.0)
Requires-Dist: torchvision (>=0.15)
Requires-Dist: tqdm (>=4.60)
Project-URL: Homepage, https://github.com/ayghri/embedata
Project-URL: Issues, https://github.com/ayghri/embedata/issues
Project-URL: Repository, https://github.com/ayghri/embedata
Description-Content-Type: text/markdown

# embedata

Unified API for loading 29 vision datasets and extracting embeddings from any PyTorch model.
Handles torchvision, HuggingFace, and custom ImageFolder datasets through a single interface.
Ships with 12 built-in embedding models (CLIP, DINOv2, DINOv3) and supports async extraction on secondary GPUs.

## Install

```bash
pip install embedata              # core (torchvision datasets)
pip install "embedata[hf]"        # + HuggingFace datasets (ImageNet, CUB, RESISC45)
```

From source:

```bash
git clone https://github.com/ayghri/embedata.git && cd embedata
pip install -e ".[hf]"
```

## Quick start

```python
from embedata import list_datasets, get_dataset, get_datasets, get_dataloaders

print(list_datasets())  # all 29 datasets

train_ds = get_dataset("cifar10", split="train", root_dir="./data")
train_ds, val_ds = get_datasets("cifar10", root_dir="./data")
train_loader, val_loader = get_dataloaders("cifar10", batch_size=128, root_dir="./data")
```

Most datasets auto-download on first use. Datasets that require manual setup (Kaggle downloads, frame extraction, etc.) document their steps via `get_spec("dataset_name").notes`.
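For example, to check a dataset's setup notes before loading it:

```python
from embedata import get_spec

# prints the manual-setup steps (Kaggle download, frame extraction, ...), if any
print(get_spec("ucf101").notes)
```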

## Extracting embeddings

Pass any PyTorch model that maps images to feature vectors. Embeddings are saved as `.npy` files:

```python
import torch
from embedata import get_dataloaders, extract

model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14")
model.eval().to("cuda:0")

train_loader, val_loader = get_dataloaders("cifar10", batch_size=256, root_dir="./data")

extract(train_loader, model, device=torch.device("cuda:0"),
        save_dir="./representations/cifar10/dinov2", suffix="train")
extract(val_loader, model, device=torch.device("cuda:0"),
        save_dir="./representations/cifar10/dinov2", suffix="val")
```

Load them back as a PyTorch Dataset:

```python
from embedata import load_embeddings

ds = load_embeddings("cifar10", "dinov2", repr_dir="./representations", split="train")
feat, label = ds[0]
```
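The returned object is a standard map-style dataset, so it plugs straight into `torch.utils.data.DataLoader`. A minimal linear-probe sketch (the probe and hyperparameters are illustrative, not part of embedata):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader

loader = DataLoader(ds, batch_size=512, shuffle=True)

feat_dim = ds[0][0].shape[-1]    # e.g. 1536 for DINOv2 ViT-g/14
probe = nn.Linear(feat_dim, 10)  # CIFAR-10 has 10 classes
opt = torch.optim.Adam(probe.parameters(), lr=1e-3)

for feats, labels in loader:
    loss = nn.functional.cross_entropy(probe(feats.float()), labels)
    opt.zero_grad()
    loss.backward()
    opt.step()
```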

For on-the-fly extraction without disk I/O, `EmbeddingDataLoader` runs a model on a secondary device in a background thread:

```python
from embedata import EmbeddingDataLoader, get_dataset

dataset = get_dataset("cifar10", split="train", root_dir="./data")

loader = EmbeddingDataLoader(dataset, model, device="cuda:1", batch_size=256, prefetch=2)

for embeddings, labels in loader:
    ...
```
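The `prefetch` argument bounds how many batches are embedded ahead of consumption, so extraction on `cuda:1` can overlap with whatever the loop body does on the primary GPU. If the loop trains a head on a different device, move the yielded tensors there first (e.g. `embeddings.to("cuda:0")`).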

## Built-in models

The package includes 12 pre-registered models accessible via `load_model`:
eight CLIP variants (`clipRN50`, `clipRN101`, `clipRN50x4`, `clipRN50x16`, `clipRN50x64`, `clipvitB32`, `clipvitB16`, `clipvitL14`),
DINOv2 ViT-g/14 (`dinov2`), and three DINOv3 variants (`dinov3s`, `dinov3b`, `dinov3l`).

```python
from embedata import list_models, load_model

model, preprocess = load_model("clipvitL14", device="cuda:0", models_dir="./models")
```

The returned `preprocess` is the model's image transform; pass it to `get_dataloaders` or `get_dataset`. For DINOv2, `preprocess` is `None`; use `get_default_transforms()` instead.
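
For instance, wiring a model's transform into the loaders might look like the sketch below (the `transform=` keyword name is an assumption; check the `get_dataloaders` signature):

```python
from embedata import get_dataloaders, get_default_transforms, load_model

model, preprocess = load_model("clipvitL14", device="cuda:0", models_dir="./models")
transform = preprocess or get_default_transforms()  # fall back when preprocess is None (e.g. dinov2)

# NOTE: the `transform=` keyword is assumed here, not confirmed by this README
train_loader, val_loader = get_dataloaders(
    "cifar10", batch_size=256, root_dir="./data", transform=transform
)
```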

CLIP models require `pip install git+https://github.com/openai/CLIP.git`. DINOv3 models require `pip install transformers`.

## Custom datasets

Register your own dataset loader with the `@register` decorator. The function receives a torchvision transform and a data path, and returns `(train_dataset, val_dataset)`:

```python
from embedata import register

@register("my_dataset", notes="Setup instructions shown by get_spec()")
def _my_dataset(transform, data_path):
    train_ds = ...
    val_ds = ...
    return train_ds, val_ds
```
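Once registered, the dataset resolves through the same entry points as the built-ins:

```python
from embedata import get_datasets

train_ds, val_ds = get_datasets("my_dataset", root_dir="./data")
```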

## HuggingFace streaming

Three datasets (ImageNet, CUB, RESISC45) load via HuggingFace and support streaming mode (requires the `hf` extra: `pip install "embedata[hf]"`). ImageNet-1k is gated and requires `huggingface-cli login`.

```python
train_ds, val_ds = get_datasets("imagenet", streaming=True)
```

## Dataset preparation

Some datasets require manual download and preparation before use. Python preparation scripts are bundled with the package and run as modules:

```bash
python -m embedata.prepare.birdsnap     --root_dir ROOT_DIR
python -m embedata.prepare.fer2013      --root_dir ROOT_DIR
python -m embedata.prepare.ucf101       --root_dir ROOT_DIR [--download]
python -m embedata.prepare.hatefulmemes --root_dir ROOT_DIR
```

Shell scripts for `cars`, `eurosat`, and `sun397` are included under `embedata/prepare/`. All scripts expect raw data under `ROOT_DIR/datasets/{dataset_name}/` and write prepared splits to the same location.

## Available datasets

| Dataset | Train | Val/Test | Classes | Size | Source | Notes |
|---------|------:|----------:|--------:|------|--------|-------|
| aircraft | 6,667 | 3,333 | 100 | variable | FGVCAircraft | trainval / test |
| birdsnap | ~25,000 | ~24,829 | 500 | variable | ImageFolder | manual download + prepare |
| caltech101 | ~3,060 | ~5,587 | 101 | variable | Caltech101 | 30/class for train |
| cars | 8,144 | 8,041 | 196 | variable | StanfordCars | Kaggle download |
| cifar10 | 50,000 | 10,000 | 10 | 32x32 | CIFAR10 | auto-download |
| cifar100 | 50,000 | 10,000 | 100 | 32x32 | CIFAR100 | auto-download |
| clevr | 70,000 | 15,000 | 11 | 320x240 | CLEVRClassification | count 0-10 |
| country211 | 42,200 | 21,100 | 211 | variable | Country211 | train+valid / test |
| cub | 5,994 | 5,794 | 200 | variable | HuggingFace | CUB-200-2011 |
| dtd | 3,760 | 1,880 | 47 | variable | DTD | train+val / test |
| eurosat | 10,000 | 5,000 | 10 | 64x64 | EuroSAT | 1k+500 per class |
| fashionmnist | 60,000 | 10,000 | 10 | 28x28 | FashionMNIST | auto-download |
| fer2013 | 28,709 | 3,589 | 7 | 48x48 | FER2013 | Kaggle + prepare |
| flowers | 2,040 | 6,149 | 102 | variable | Flowers102 | train+val / test |
| food101 | 75,750 | 25,250 | 101 | variable | Food101 | auto-download |
| gtsrb | 26,640 | 12,630 | 43 | variable | GTSRB | auto-download |
| hatefulmemes | ~8,500 | ~500 | 2 | variable | ImageFolder | Kaggle + prepare |
| imagenet | 1,281,167 | 50,000 | 1,000 | variable | HuggingFace | gated, HF login |
| imagenette | 9,469 | 3,925 | 10 | variable | Imagenette | ImageNet subset |
| kinetics700 | varies | varies | 700 | variable | ImageFolder | frame extraction |
| kitti | varies | varies | 4 | variable | ImageFolder | manual prep |
| mnist | 60,000 | 10,000 | 10 | 28x28 | MNIST | auto-download |
| pcam | 294,912 | 32,768 | 2 | 96x96 | PCAM | train+val / test |
| pets | 3,680 | 3,669 | 37 | variable | OxfordIIITPet | auto-download |
| resisc45 | 25,200 | 6,300 | 45 | 256x256 | HuggingFace | remote sensing |
| sst | 7,792 | 1,821 | 2 | variable | RenderedSST2 | train+val / test |
| stl10 | 5,000 | 8,000 | 10 | 96x96 | STL10 | auto-download |
| sun397 | ~19,850 | ~19,850 | 397 | variable | SUN397 | manual + prepare |
| ucf101 | varies | varies | 101 | variable | ImageFolder | frame extraction + prepare |

Counts reflect splits as loaded by embedata (some merge train+val for training).
"variable" means images have different native resolutions; all are resized by the transform (default 224x224). Datasets marked "auto-download" are fetched on first use.

## License

MIT

