Project Structure:
├── .git
├── .github
│   └── workflows
│       └── python-package.yml
├── ssrlib
│   ├── core
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── pipeline.py
│   │   └── registry.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── celeba.py
│   │   ├── cifar10.py
│   │   ├── food101.py
│   │   ├── hf_mixin.py
│   │   ├── hf_registry.py
│   │   ├── hf_vision.py
│   │   ├── imagenet100.py
│   │   ├── kaggle_mixin.py
│   │   └── synthtest_dataset.py
│   ├── embedders
│   │   ├── cv
│   │   │   ├── __init__.py
│   │   │   ├── clip.py
│   │   │   ├── dino.py
│   │   │   ├── dinov2.py
│   │   │   └── vicreg.py
│   │   ├── nlp
│   │   │   ├── __init__.py
│   │   │   ├── bert.py
│   │   │   ├── bert_base.py
│   │   │   ├── e5.py
│   │   │   └── modernbert.py
│   │   ├── __init__.py
│   │   └── base.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── contrastive_loss.py
│   │   ├── deepinfomax_loss.py
│   │   ├── infonce_loss.py
│   │   └── triplet_loss.py
│   ├── processing
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── covariance.py
│   │   ├── effective_rank.py
│   │   ├── leverage_scores.py
│   │   ├── pairwise_stats.py
│   │   ├── spectrum.py
│   │   ├── stable_rank.py
│   │   └── zca.py
│   ├── storage
│   │   ├── __init__.py
│   │   └── tensor_storage.py
│   └── __init__.py
├── .gitignore
├── README.md
├── basic_pipeline.py
├── example_with_storage.py
├── pyproject.toml
├── requirements-dev.txt
├── requirements.txt
├── setup.py
└── uv.lock


File: example_with_storage.py
from ssrlib.datasets import SynthTestDataset
from ssrlib.embedders.cv import DINOv2Embedder, CLIPEmbedder
from ssrlib.processing import CovarianceProcessor, ZCAProcessor


from ssrlib import Pipeline


def example_with_storage():
    """Example showing pipeline with storage caching."""

    # Create pipeline
    pipeline = Pipeline(
        [
            (
                "datasets",
                [
                    SynthTestDataset(tensors_num=50, tensor_shape=(3, 224, 224), seed=1),
                    SynthTestDataset(tensors_num=30, tensor_shape=(3, 224, 224), seed=2),
                ],
            ),
            (
                "embedders",
                [
                    DINOv2Embedder("dinov2_vitb14"),
                    CLIPEmbedder("clip-vit-large-patch14"),
                ],
            ),
            ("processors", [CovarianceProcessor(), ZCAProcessor(epsilon=1e-6)]),
        ]
    )

    print("=== First execution (cache miss) ===")
    # First execution - will compute and cache all embeddings
    results1 = pipeline.execute(
        use_storage=True,
        storage_dir="./cache/pipeline_test",
        storage_description="Test pipeline with two synthetic datasets",
    )

    print(f"Cache hit rate: {results1.metadata.get('cache_hit_rate', 0):.2%}")
    print(f"Storage info: {results1.storage_info}")

    print("\n=== Second execution (cache hit) ===")
    # Second execution - should load from cache
    results2 = pipeline.execute(
        use_storage=True,
        storage_dir="./cache/pipeline_test",
        force_recompute=False,  # Use cache
    )

    print(f"Cache hit rate: {results2.metadata.get('cache_hit_rate', 0):.2%}")
    print(f"Timing comparison:")
    print(f"  First run: {results1.timing['total_time']:.2f}s")
    print(f"  Second run: {results2.timing['total_time']:.2f}s")
    print(f"  Speedup: {results1.timing['total_time']/results2.timing['total_time']:.1f}x")

    print("\n=== Force recompute ===")
    # Third execution - force recompute
    results3 = pipeline.execute(
        use_storage=True,
        storage_dir="./cache/pipeline_test",
        force_recompute=True,  # Ignore cache
    )

    print(f"Cache hit rate: {results3.metadata.get('cache_hit_rate', 0):.2%}")


if __name__ == "__main__":
    example_with_storage()


File: setup.py
from setuptools import setup, find_packages

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setup(
    name="ssrlib",
    version="0.1.0",
    author="Mikhail Kuznetov",
    author_email="mmkuznecov2002@gmail.com",
    description="A modular Python framework for Self-Supervised Learning with automatic component discovery and intelligent caching",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/mmkuznecov/ssrlib",
    packages=find_packages(),
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    python_requires=">=3.8",
    install_requires=[
        # Core dependencies
        "torch>=2.0.0",
        "torchvision>=0.15.0",
        "numpy>=1.21.0",
        "pandas>=1.3.0",
        "scipy>=1.7.0",
        # Image processing
        "Pillow>=8.3.0",
        # Model loading
        "transformers>=4.20.0",
        "huggingface-hub>=0.16.0",
        # Utilities
        "tqdm>=4.62.0",
        "pyyaml>=5.4.0",
        "requests>=2.28.0",
        # Data handling
        "safetensors>=0.3.0",
    ],
    extras_require={
        "dev": [
            "pytest>=7.0",
            "pytest-cov>=3.0",
            "black>=22.0",
            "isort>=5.10",
            "pylint>=2.15",
        ],
        "examples": [
            "matplotlib>=3.5.0",
            "seaborn>=0.11.0",
            "jupyter>=1.0.0",
            "ipywidgets>=8.0.0",
        ],
        "all": [
            # Optional advanced features
            "scikit-learn>=1.0.0",
            "sentencepiece>=0.1.96",  # For some NLP embedders
        ],
    },
)


File: basic_pipeline.py
import numpy as np
import torch
from ssrlib import Pipeline, Config
from ssrlib.datasets import SynthTestDataset
from ssrlib.embedders.cv import DINOv2Embedder, CLIPEmbedder
from ssrlib.processing import CovarianceProcessor, ZCAProcessor


def basic_single_pipeline():
    """Basic pipeline with single dataset and embedder."""

    print("=== Basic Single Pipeline ===")

    # Create a simple pipeline with a custom configuration
    config = Config({"batch_size": 32, "device": "cpu"})
    pipeline = Pipeline(
        [
            ("dataset", SynthTestDataset(tensors_num=50, seed=42)),
            ("embedder", DINOv2Embedder("dinov2_vitb14")),
            ("processor", CovarianceProcessor()),
        ],
        config=config,
    )

    # Execute with a runtime override of the batch size
    results = pipeline.execute(config_override={"batch_size": 16})

    # Access results
    dataset_name = "SynthTest"
    embedder_name = "DINOv2_dinov2_vitb14"
    processor_name = "Covariance"

    embeddings = results.get_embeddings(dataset_name, embedder_name)
    covariance = results.get_processed(dataset_name, embedder_name, processor_name)

    print(f"Embeddings shape: {embeddings.shape}")
    print(f"Covariance matrix shape: {covariance.shape}")
    print(f"Pipeline timing: {results.timing}")


def multi_component_pipeline():
    """Pipeline with multiple datasets, embedders, and processors."""

    print("\n=== Multi-Component Pipeline ===")

    # Create pipeline with multiple components
    pipeline = Pipeline(
        [
            (
                "datasets",
                [
                    SynthTestDataset(tensors_num=20, tensor_shape=(3, 224, 224), seed=1),
                    SynthTestDataset(tensors_num=20, tensor_shape=(3, 224, 224), seed=2),
                ],
            ),
            (
                "embedders",
                [
                    DINOv2Embedder("dinov2_vitb14"),
                    CLIPEmbedder("clip-vit-large-patch14"),
                ],
            ),
            ("processors", [CovarianceProcessor(), ZCAProcessor(epsilon=1e-6)]),
        ]
    )

    # This creates 2×2×2 = 8 different combinations
    results = pipeline.execute()

    print(f"Total embeddings computed: {len(results.embeddings)}")
    print(f"Total processed outputs: {len(results.processed)}")

    # Show all combinations
    for (dataset, embedder), emb in results.embeddings.items():
        print(f"{dataset} + {embedder}: {emb.shape}")

    for (dataset, embedder, processor), proc in results.processed.items():
        print(f"{dataset} + {embedder} + {processor}: {proc.shape}")


def configuration_driven_pipeline():
    """Example using configuration files and overrides."""

    print("\n=== Configuration-Driven Pipeline ===")

    # Create configuration
    config = Config(
        {
            "device": "cpu",
            "batch_size": 64,
            "output_dir": "./results",
            "model": {
                "dinov2_variant": "dinov2_vitb14",
                "clip_variant": "clip-vit-large-patch14",
            },
            "processing": {"zca_epsilon": 1e-9, "compute_covariance": True},
        }
    )

    # Create pipeline using config values
    embedders = []
    if config.get("model.dinov2_variant"):
        embedders.append(DINOv2Embedder(config.get("model.dinov2_variant")))
    if config.get("model.clip_variant"):
        embedders.append(CLIPEmbedder(config.get("model.clip_variant")))

    processors = []
    if config.get("processing.compute_covariance"):
        processors.append(CovarianceProcessor())
    if config.get("processing.zca_epsilon"):
        processors.append(ZCAProcessor(epsilon=config.get("processing.zca_epsilon")))

    pipeline = Pipeline(
        [
            (
                "dataset",
                SynthTestDataset(tensors_num=30, tensor_shape=(3, 224, 224), seed=123),
            ),
            ("embedders", embedders),
            ("processors", processors),
        ],
        config=config,
    )

    # Execute with runtime overrides
    results = pipeline.execute(config_override={"batch_size": 32})

    print(f"Used batch size: {config.get('batch_size')}")
    print(f"Results metadata keys: {list(results.metadata.keys())}")


def analyze_results():
    """Example of analyzing pipeline results."""

    print("\n=== Analyzing Results ===")

    # Simple pipeline for analysis
    pipeline = Pipeline(
        [
            (
                "dataset",
                SynthTestDataset(tensors_num=30, tensor_shape=(3, 224, 224), seed=123),
            ),
            ("embedder", DINOv2Embedder("dinov2_vits14")),  # Smaller model
            ("processors", [CovarianceProcessor(), ZCAProcessor()]),
        ]
    )

    results = pipeline.execute()

    # Analyze embeddings
    dataset_name = "SynthTest"
    embedder_name = "DINOv2_dinov2_vits14"

    embeddings = results.get_embeddings(dataset_name, embedder_name)
    covariance = results.get_processed(dataset_name, embedder_name, "Covariance")
    whitened = results.get_processed(dataset_name, embedder_name, "ZCA")

    print(f"\nEmbedding Analysis:")
    print(f"  Shape: {embeddings.shape}")
    print(f"  Mean: {np.mean(embeddings):.4f}")
    print(f"  Std: {np.std(embeddings):.4f}")
    print(f"  Min: {np.min(embeddings):.4f}")
    print(f"  Max: {np.max(embeddings):.4f}")

    print(f"\nCovariance Matrix Analysis:")
    print(f"  Shape: {covariance.shape}")
    print(f"  Trace: {np.trace(covariance):.4f}")
    print(f"  Determinant: {np.linalg.det(covariance):.4e}")
    print(f"  Max eigenvalue: {np.max(np.linalg.eigvals(covariance)):.4f}")
    print(f"  Min eigenvalue: {np.min(np.linalg.eigvals(covariance)):.4f}")

    print(f"\nWhitened Embeddings Analysis:")
    print(f"  Shape: {whitened.shape}")
    print(f"  Mean: {np.mean(whitened):.4f}")
    print(f"  Std: {np.std(whitened):.4f}")

    # Check if whitening worked (covariance should be close to identity)
    whitened_cov = np.cov(whitened.T)
    identity_error = np.mean((whitened_cov - np.eye(whitened_cov.shape[0])) ** 2)
    print(f"  Whitening quality (MSE from identity): {identity_error:.6f}")


def main():
    """Run all examples."""

    print("ssrlib Framework - Usage Examples")
    print("================================")

    try:
        basic_single_pipeline()
        multi_component_pipeline()
        configuration_driven_pipeline()
        analyze_results()

        print("\n=== All Examples Completed Successfully! ===")

    except Exception as e:
        print(f"\nExample failed with error: {str(e)}")
        print("Note: Make sure you have the required datasets downloaded.")
        print("For testing without real data, see test_examples.py")


if __name__ == "__main__":
    main()


File: ssrlib/__init__.py
from .core.pipeline import Pipeline, PipelineResults
from .core.config import Config

__version__ = "0.1.0"
__all__ = ["Pipeline", "PipelineResults", "Config"]


File: ssrlib/losses/triplet_loss.py
import torch

from .base import ContrastiveLossBase, DistanceMetric


class TripletLoss(ContrastiveLossBase):
    """
    Triplet Loss implementation with various distance metrics.

    The triplet loss encourages embeddings where the distance between anchor
    and positive is smaller than the distance between anchor and negative by
    at least a margin.

    Loss = max(0, d(anchor, positive) - d(anchor, negative) + margin)

    Args:
        margin: Minimum margin between positive and negative pairs
        distance_metric: Distance metric ('euclidean', 'cosine', 'squared_euclidean', 'manhattan')
        reduction: Loss reduction ('mean', 'sum', 'none')
        normalize: Whether to L2 normalize embeddings

    References:
        Schroff et al. "FaceNet: A Unified Embedding for Face Recognition and Clustering"
    """

    def __init__(
        self,
        margin: float = 1.0,
        distance_metric: DistanceMetric = DistanceMetric.EUCLIDEAN,
        reduction: str = "mean",
        normalize: bool = False,
    ):
        """Initialize triplet loss.

        Args:
            margin: Minimum margin between positive and negative pairs
            distance_metric: Distance metric to use
            reduction: Loss reduction method
            normalize: Whether to L2 normalize embeddings
        """
        super().__init__(
            margin=margin,
            distance_metric=distance_metric,
            reduction=reduction,
            normalize=normalize,
        )

    def forward(
        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass for triplet loss.

        Args:
            anchor: Anchor embeddings (N, D)
            positive: Positive embeddings (N, D)
            negative: Negative embeddings (N, D)

        Returns:
            Triplet loss
        """
        # Apply normalization if enabled
        anchor, positive, negative = self.apply_normalization(anchor, positive, negative)

        # Compute distances
        pos_dist = self.compute_distance(anchor, positive, self.distance_metric)
        neg_dist = self.compute_distance(anchor, negative, self.distance_metric)

        # Triplet loss: max(0, d(a,p) - d(a,n) + margin)
        loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0.0)

        return self.apply_reduction(loss)


class TripletLossWithMining(TripletLoss):
    """
    Triplet Loss with hard negative mining.

    This extends the basic triplet loss by automatically selecting
    hard negative examples within each batch.

    Args:
        margin: Minimum margin between positive and negative pairs
        distance_metric: Distance metric to use
        reduction: Loss reduction method
        normalize: Whether to L2 normalize embeddings
        mining_strategy: Strategy for selecting hard negatives ('hardest', 'semi_hard', 'all')
    """

    def __init__(
        self,
        margin: float = 1.0,
        distance_metric: DistanceMetric = DistanceMetric.EUCLIDEAN,
        reduction: str = "mean",
        normalize: bool = False,
        mining_strategy: str = "hardest",
    ):
        """Initialize triplet loss with mining.

        Args:
            mining_strategy: Mining strategy ('hardest', 'semi_hard', 'all')
        """
        super().__init__(margin, distance_metric, reduction, normalize)

        assert mining_strategy in [
            "hardest",
            "semi_hard",
            "all",
        ], f"Invalid mining strategy: {mining_strategy}"
        self.mining_strategy = mining_strategy

    def mine_triplets(self, embeddings: torch.Tensor, labels: torch.Tensor) -> tuple:
        """
        Mine triplets from a batch of embeddings and labels.

        Args:
            embeddings: Batch of embeddings (N, D)
            labels: Corresponding labels (N,)

        Returns:
            Tuple of (anchor_idx, positive_idx, negative_idx)
        """
        # Compute pairwise distances
        pairwise_dist = self.compute_pairwise_distance(embeddings, self.distance_metric)

        # Create masks for positive and negative pairs
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
        labels_not_equal = ~labels_equal

        # Remove diagonal (self-comparisons)
        labels_equal.fill_diagonal_(False)

        anchors, positives, negatives = [], [], []

        for i in range(len(embeddings)):
            # Find positive examples (same label, not self)
            positive_mask = labels_equal[i]
            if not positive_mask.any():
                continue  # Skip if no positives available

            # Find negative examples (different label)
            negative_mask = labels_not_equal[i]
            if not negative_mask.any():
                continue  # Skip if no negatives available

            # Get distances for this anchor
            pos_dists = pairwise_dist[i][positive_mask]
            neg_dists = pairwise_dist[i][negative_mask]

            pos_indices = torch.where(positive_mask)[0]
            neg_indices = torch.where(negative_mask)[0]

            if self.mining_strategy == "hardest":
                # Hardest positive (farthest positive)
                hardest_pos_idx = pos_indices[torch.argmax(pos_dists)]
                # Hardest negative (closest negative)
                hardest_neg_idx = neg_indices[torch.argmin(neg_dists)]

                anchors.append(i)
                positives.append(hardest_pos_idx)
                negatives.append(hardest_neg_idx)

            elif self.mining_strategy == "semi_hard":
                # Semi-hard negatives: d(a,p) < d(a,n) < d(a,p) + margin
                hardest_pos_dist = torch.max(pos_dists)
                semi_hard_mask = (neg_dists > hardest_pos_dist) & (
                    neg_dists < hardest_pos_dist + self.margin
                )

                if semi_hard_mask.any():
                    # Choose random semi-hard negative
                    semi_hard_negs = neg_indices[semi_hard_mask]
                    chosen_neg = semi_hard_negs[torch.randint(len(semi_hard_negs), (1,)).item()]

                    anchors.append(i)
                    positives.append(pos_indices[torch.argmax(pos_dists)])
                    negatives.append(chosen_neg)

            elif self.mining_strategy == "all":
                # All valid combinations
                for pos_idx in pos_indices:
                    for neg_idx in neg_indices:
                        anchors.append(i)
                        positives.append(pos_idx)
                        negatives.append(neg_idx)

        if not anchors:
            # Return empty tensors if no valid triplets found
            return (
                torch.tensor([], dtype=torch.long),
                torch.tensor([], dtype=torch.long),
                torch.tensor([], dtype=torch.long),
            )

        return torch.tensor(anchors), torch.tensor(positives), torch.tensor(negatives)

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with automatic triplet mining.

        Args:
            embeddings: Batch of embeddings (N, D)
            labels: Corresponding labels (N,)

        Returns:
            Triplet loss
        """
        # Apply normalization if enabled
        embeddings = self.apply_normalization(embeddings)[0]

        # Mine triplets
        anchor_idx, pos_idx, neg_idx = self.mine_triplets(embeddings, labels)

        if len(anchor_idx) == 0:
            # Return zero loss if no valid triplets found
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        # Extract triplets
        anchors = embeddings[anchor_idx]
        positives = embeddings[pos_idx]
        negatives = embeddings[neg_idx]

        # Compute triplet loss
        return super().forward(anchors, positives, negatives)

    def get_config(self) -> dict:
        """Get configuration including mining strategy."""
        config = super().get_config()
        config["mining_strategy"] = self.mining_strategy
        return config
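
# --- Usage sketch (editor's illustrative addition, not part of the original file) ---
# Exercises TripletLoss and TripletLossWithMining as documented above on random
# data; the tensor sizes and number of classes below are arbitrary assumptions.
def _triplet_loss_demo():
    anchor = torch.randn(8, 128)
    positive = torch.randn(8, 128)
    negative = torch.randn(8, 128)

    # Plain triplet loss on pre-formed (anchor, positive, negative) triplets
    loss_fn = TripletLoss(margin=1.0, distance_metric=DistanceMetric.EUCLIDEAN)
    print(f"triplet loss: {loss_fn(anchor, positive, negative).item():.4f}")

    # Mining variant: triplets are selected automatically from a labeled batch
    embeddings = torch.randn(16, 128)
    labels = torch.randint(0, 4, (16,))
    mining_fn = TripletLossWithMining(margin=1.0, mining_strategy="hardest")
    print(f"mined triplet loss: {mining_fn(embeddings, labels).item():.4f}")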


File: ssrlib/losses/__init__.py
"""Loss functions for ssrlib with automatic discovery."""

import logging
from pathlib import Path
from typing import Dict, List, Type, Any
import warnings

logger = logging.getLogger(__name__)

# Import base classes and registry system
from .base import BaseLoss, ContrastiveLossBase, DistanceMetric
from ..core.registry import BaseRegistry, discover_components

# Type alias
LossRegistry = BaseRegistry[BaseLoss]


def categorize_loss(cls: Type[BaseLoss]) -> str:
    """Determine category for a loss function."""
    # Check inheritance hierarchy
    if issubclass(cls, ContrastiveLossBase):
        return "contrastive"

    # Check name patterns
    class_name = cls.__name__.lower()
    if "contrastive" in class_name or "triplet" in class_name:
        return "contrastive"
    elif "info" in class_name or "nce" in class_name:
        return "information_theory"
    elif "deepinfomax" in class_name:
        return "mutual_information"
    else:
        return "general"


def discover_loss_classes() -> LossRegistry:
    """Discover all loss classes in the losses module."""
    registry = LossRegistry("loss")

    return discover_components(
        package_path=Path(__file__).parent,
        package_name=__name__,
        base_class=BaseLoss,
        registry=registry,
        get_category_func=categorize_loss,
    )


# Perform discovery at import time
logger.debug("Starting loss function discovery...")
_loss_registry = discover_loss_classes()


# Convenience functions
def get_available_losses() -> Dict[str, Type[BaseLoss]]:
    """Get dictionary of all available loss functions."""
    return _loss_registry._items.copy()


def get_loss_descriptions() -> Dict[str, str]:
    """Get dictionary of loss descriptions."""
    return _loss_registry._descriptions.copy()


def list_losses(category: str = None) -> List[str]:
    """List available loss functions."""
    if category:
        return _loss_registry.list_by_category(category).get(category, [])
    return _loss_registry.list_all()


def get_loss_info(name: str) -> Dict[str, Any]:
    """Get detailed information about a loss function."""
    return _loss_registry.get_info(name)


def print_available_losses() -> None:
    """Print all available loss functions with descriptions."""
    _loss_registry.print_registry()


def create_loss(name: str, **kwargs) -> BaseLoss:
    """Create a loss function by name."""
    loss_class = _loss_registry.get(name)
    if loss_class is None:
        available = ", ".join(_loss_registry.list_all())
        raise ValueError(f"Unknown loss function '{name}'. Available: {available}")
    return loss_class(**kwargs)


# Create dynamic exports
_exported_classes = {}
for name, loss_class in _loss_registry._items.items():
    _exported_classes[name] = loss_class

# Update module globals
globals().update(_exported_classes)

# Create __all__ dynamically
__all__ = [
    "BaseLoss",
    "ContrastiveLossBase",
    "DistanceMetric",
    "get_available_losses",
    "get_loss_descriptions",
    "list_losses",
    "get_loss_info",
    "print_available_losses",
    "create_loss",
    *_loss_registry.list_all(),
]

# Log results
if logger.isEnabledFor(logging.INFO):
    logger.info(f"Loss discovery complete: {len(_loss_registry.list_all())} losses found")
    for category, losses in _loss_registry.list_by_category().items():
        logger.info(f"  {category}: {', '.join(losses)}")

# Warn about errors
if _loss_registry._discovery_errors:
    warnings.warn(
        f"Some loss modules failed to import: {len(_loss_registry._discovery_errors)} errors. "
        f"Run logging.getLogger('{__name__}').setLevel(logging.DEBUG) for details.",
        ImportWarning,
    )
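
# --- Usage sketch (editor's illustrative addition, not part of the original file) ---
# Shows how the discovery registry is typically consumed: list discovered losses,
# then build one by name via create_loss(). The key "TripletLoss" is an assumption
# (registry keys are taken to be class names); check list_losses() for actual names.
def _registry_demo():
    print("contrastive losses:", list_losses("contrastive"))
    print("all losses:", list_losses())
    loss = create_loss("TripletLoss", margin=0.5)
    print(loss)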


File: ssrlib/losses/infonce_loss.py
import torch
import torch.nn.functional as F
from typing import Optional

from .base import BaseLoss


class InfoNCE(BaseLoss):
    """
    InfoNCE (Information Noise Contrastive Estimation) loss for self-supervised learning.

    This contrastive loss enforces the embeddings of similar (positive) samples to be close
    and those of different (negative) samples to be distant. A query embedding is compared
    with one positive key and with one or more negative keys.

    References:
        https://arxiv.org/abs/1807.03748v2
        https://arxiv.org/abs/2010.05113

    Args:
        temperature: Logits are divided by temperature before calculating the cross entropy.
        reduction: Reduction method applied to the output.
            Value must be one of ['none', 'sum', 'mean'].
        negative_mode: Determines how the (optional) negative_keys are handled.
            Value must be one of ['paired', 'unpaired'].
            If 'paired', then each query sample is paired with a number of negative keys.
            If 'unpaired', then the set of negative keys are all unrelated to any positive key.
        normalize: Whether to normalize embeddings before computing similarities.

    Input shape:
        query: (N, D) Tensor with query samples (e.g. embeddings of the input).
        positive_key: (N, D) Tensor with positive samples (e.g. embeddings of augmented input).
        negative_keys (optional): Tensor with negative samples (e.g. embeddings of other inputs)
            If negative_mode = 'paired', then negative_keys is a (N, M, D) Tensor.
            If negative_mode = 'unpaired', then negative_keys is a (M, D) Tensor.
            If None, then the negative keys for a sample are the positive keys for the other samples.

    Returns:
         Value of the InfoNCE Loss.

    Examples:
        >>> loss = InfoNCE()
        >>> batch_size, num_negative, embedding_size = 32, 48, 128
        >>> query = torch.randn(batch_size, embedding_size)
        >>> positive_key = torch.randn(batch_size, embedding_size)
        >>> negative_keys = torch.randn(num_negative, embedding_size)
        >>> output = loss(query, positive_key, negative_keys)
    """

    def __init__(
        self,
        temperature: float = 0.1,
        reduction: str = "mean",
        negative_mode: str = "unpaired",
        normalize: bool = True,
    ):
        """Initialize InfoNCE loss.

        Args:
            temperature: Temperature scaling parameter
            reduction: Loss reduction method
            negative_mode: How to handle negative keys ('paired' or 'unpaired')
            normalize: Whether to normalize embeddings
        """
        super().__init__(
            reduction=reduction,
            normalize=normalize,
            temperature=temperature,
            negative_mode=negative_mode,
        )

        assert negative_mode in [
            "paired",
            "unpaired",
        ], f"Invalid negative_mode: {negative_mode}"

        self.negative_mode = negative_mode

    def forward(
        self,
        query: torch.Tensor,
        positive_key: torch.Tensor,
        negative_keys: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for InfoNCE loss.

        Args:
            query: Query embeddings (N, D)
            positive_key: Positive key embeddings (N, D)
            negative_keys: Optional negative key embeddings

        Returns:
            InfoNCE loss
        """
        return info_nce(
            query=query,
            positive_key=positive_key,
            negative_keys=negative_keys,
            temperature=self.temperature,
            reduction=self.reduction,
            negative_mode=self.negative_mode,
            normalize=self.normalize,
        )


def info_nce(
    query: torch.Tensor,
    positive_key: torch.Tensor,
    negative_keys: Optional[torch.Tensor] = None,
    temperature: float = 0.1,
    reduction: str = "mean",
    negative_mode: str = "unpaired",
    normalize: bool = True,
) -> torch.Tensor:
    """
    Functional interface for InfoNCE loss.

    Args:
        query: Query embeddings (N, D)
        positive_key: Positive key embeddings (N, D)
        negative_keys: Optional negative key embeddings
        temperature: Temperature scaling parameter
        reduction: Loss reduction method
        negative_mode: How to handle negative keys ('paired' or 'unpaired')
        normalize: Whether to normalize embeddings

    Returns:
        InfoNCE loss
    """
    # Input validation
    if query.dim() != 2:
        raise ValueError("<query> must have 2 dimensions.")
    if positive_key.dim() != 2:
        raise ValueError("<positive_key> must have 2 dimensions.")
    if negative_keys is not None:
        if negative_mode == "unpaired" and negative_keys.dim() != 2:
            raise ValueError(
                "<negative_keys> must have 2 dimensions if <negative_mode> == 'unpaired'."
            )
        if negative_mode == "paired" and negative_keys.dim() != 3:
            raise ValueError(
                "<negative_keys> must have 3 dimensions if <negative_mode> == 'paired'."
            )

    # Check matching number of samples
    if len(query) != len(positive_key):
        raise ValueError("<query> and <positive_key> must have the same number of samples.")
    if negative_keys is not None:
        if negative_mode == "paired" and len(query) != len(negative_keys):
            raise ValueError(
                "If negative_mode == 'paired', then <negative_keys> must have the same number of samples as <query>."
            )

    # Check embedding dimensions match
    if query.shape[-1] != positive_key.shape[-1]:
        raise ValueError(
            "Vectors of <query> and <positive_key> should have the same number of components."
        )
    if negative_keys is not None:
        if query.shape[-1] != negative_keys.shape[-1]:
            raise ValueError(
                "Vectors of <query> and <negative_keys> should have the same number of components."
            )

    # Normalize to unit vectors if requested
    if normalize:
        query, positive_key, negative_keys = _normalize(query, positive_key, negative_keys)

    if negative_keys is not None:
        # Explicit negative keys provided

        # Cosine similarity between positive pairs
        positive_logit = torch.sum(query * positive_key, dim=1, keepdim=True)

        if negative_mode == "unpaired":
            # Cosine similarity between all query-negative combinations
            negative_logits = query @ _transpose(negative_keys)

        elif negative_mode == "paired":
            # Each query paired with its corresponding negative keys
            query_expanded = query.unsqueeze(1)  # (N, 1, D)
            negative_logits = query_expanded @ _transpose(
                negative_keys
            )  # (N, 1, D) @ (N, D, M) -> (N, 1, M)
            negative_logits = negative_logits.squeeze(1)  # (N, M)

        # Concatenate positive and negative logits
        # First column contains positive logits, rest are negative
        logits = torch.cat([positive_logit, negative_logits], dim=1)
        labels = torch.zeros(len(logits), dtype=torch.long, device=query.device)
    else:
        # Negative keys are implicitly other positive keys in the batch

        # Cosine similarity between all query-positive_key combinations
        logits = query @ _transpose(positive_key)

        # Positive keys are on the diagonal
        labels = torch.arange(len(query), device=query.device)

    # Apply temperature scaling and compute cross-entropy loss
    return F.cross_entropy(logits / temperature, labels, reduction=reduction)


def _transpose(x: torch.Tensor) -> torch.Tensor:
    """Transpose last two dimensions."""
    return x.transpose(-2, -1)


def _normalize(*xs: torch.Tensor) -> tuple:
    """Normalize tensors to unit vectors."""
    return tuple(None if x is None else F.normalize(x, dim=-1) for x in xs)
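
# --- Usage sketch (editor's illustrative addition, not part of the original file) ---
# When negative_keys is omitted, the other positive keys in the batch act as
# implicit negatives and matching pairs sit on the logits diagonal. Batch and
# embedding sizes below are arbitrary assumptions.
def _infonce_in_batch_demo():
    query = torch.randn(32, 128)
    positive_key = torch.randn(32, 128)

    loss_fn = InfoNCE(temperature=0.1)
    in_batch_loss = loss_fn(query, positive_key)  # no explicit negatives

    # Equivalent functional call
    same_loss = info_nce(query, positive_key, temperature=0.1)
    print(f"in-batch InfoNCE: {in_batch_loss.item():.4f} / {same_loss.item():.4f}")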


File: ssrlib/losses/contrastive_loss.py
import torch
import torch.nn.functional as F

from .base import ContrastiveLossBase, DistanceMetric


class ContrastiveLoss(ContrastiveLossBase):
    """
    Contrastive Loss for learning embeddings.

    The contrastive loss pulls together embeddings of similar samples (positive pairs)
    and pushes apart embeddings of dissimilar samples (negative pairs).

    For positive pairs (label=0): minimize distance
    For negative pairs (label=1): maximize distance up to margin

    Loss = (1-label) * d^2 + label * max(0, margin - d)^2

    Args:
        margin: Minimum margin for negative pairs
        distance_metric: Distance metric to use
        reduction: Loss reduction method
        normalize: Whether to L2 normalize embeddings

    References:
        Hadsell et al. "Dimensionality Reduction by Learning an Invariant Mapping"
    """

    def __init__(
        self,
        margin: float = 2.0,
        distance_metric: DistanceMetric = DistanceMetric.EUCLIDEAN,
        reduction: str = "mean",
        normalize: bool = False,
    ):
        """Initialize contrastive loss.

        Args:
            margin: Margin for negative pairs (default: 2.0)
            distance_metric: Distance metric to use
            reduction: Reduction method ('mean', 'sum', 'none')
            normalize: Whether to L2 normalize embeddings
        """
        super().__init__(
            margin=margin,
            distance_metric=distance_metric,
            reduction=reduction,
            normalize=normalize,
        )

    def forward(
        self, output1: torch.Tensor, output2: torch.Tensor, label: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass for contrastive loss.

        Args:
            output1: First set of embeddings (N, D)
            output2: Second set of embeddings (N, D)
            label: Binary labels (N,) where 0=similar, 1=dissimilar

        Returns:
            Contrastive loss
        """
        # Apply normalization if enabled
        output1, output2 = self.apply_normalization(output1, output2)

        # Compute distance
        distance = self.compute_distance(output1, output2, self.distance_metric)

        # Contrastive loss computation
        # For similar pairs (label=0): penalize large distances
        pos_loss = (1 - label) * torch.pow(distance, 2)

        # For dissimilar pairs (label=1): penalize small distances (below margin)
        neg_loss = label * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)

        # Combined loss
        loss = pos_loss + neg_loss

        return self.apply_reduction(loss)


# Reference implementation kept for backward compatibility
class ContrastiveLossOriginal(torch.nn.Module):
    """
    Original contrastive loss implementation (kept for compatibility).

    Matches the original reference implementation exactly, using plain
    Euclidean distance and mean reduction.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLossOriginal, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        pos = (1 - label) * torch.pow(euclidean_distance, 2)
        neg = (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        loss_contrastive = torch.mean(pos + neg)
        return loss_contrastive
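
# --- Usage sketch (editor's illustrative addition, not part of the original file) ---
# Applies ContrastiveLoss to a batch of embedding pairs with binary labels
# (0 = similar, 1 = dissimilar), following the convention documented above.
# Shapes and the margin value are arbitrary assumptions.
def _contrastive_loss_demo():
    output1 = torch.randn(8, 64)
    output2 = torch.randn(8, 64)
    label = torch.randint(0, 2, (8,)).float()

    loss_fn = ContrastiveLoss(margin=2.0, distance_metric=DistanceMetric.EUCLIDEAN)
    print(f"contrastive loss: {loss_fn(output1, output2, label).item():.4f}")

    # The compatibility implementation should give a comparable value
    legacy_fn = ContrastiveLossOriginal(margin=2.0)
    print(f"legacy contrastive loss: {legacy_fn(output1, output2, label).item():.4f}")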


File: ssrlib/losses/deepinfomax_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from .base import BaseLoss


# Mock discriminator classes for testing and default usage
class MockGlobalDiscriminator(nn.Module):
    """Mock global discriminator for testing purposes."""

    def __init__(self, output_dim: int = 1):
        super().__init__()
        self.linear = nn.Linear(1, output_dim)  # Minimal implementation

    def forward(self, y: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Forward pass for global discriminator.

        Args:
            y: Encoded representations [batch_size, encoding_dim]
            M: Feature maps [batch_size, channels, height, width]

        Returns:
            Discriminator output [batch_size, output_dim]
        """
        batch_size = y.shape[0]
        return torch.randn(batch_size, 1, device=y.device)


class MockLocalDiscriminator(nn.Module):
    """Mock local discriminator for testing purposes."""

    def __init__(self, output_channels: int = 1):
        super().__init__()
        self.conv = nn.Conv2d(1, output_channels, kernel_size=1)  # Minimal implementation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for local discriminator.

        Args:
            x: Combined feature maps [batch_size, channels, height, width]

        Returns:
            Discriminator output [batch_size, output_channels, height, width]
        """
        batch_size, channels, height, width = x.shape
        return torch.randn(batch_size, 1, height, width, device=x.device)


class MockPriorDiscriminator(nn.Module):
    """Mock prior discriminator for testing purposes."""

    def __init__(self, input_dim: int = 64, output_dim: int = 1):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for prior discriminator.

        Args:
            x: Input representations [batch_size, input_dim]

        Returns:
            Discriminator output [batch_size, output_dim] with sigmoid activation
        """
        batch_size = x.shape[0]
        return torch.sigmoid(torch.randn(batch_size, 1, device=x.device))


class DeepInfoMaxLoss(BaseLoss):
    """
    DeepInfoMax Loss for self-supervised representation learning.

    The DeepInfoMaxLoss maximizes mutual information between input and learned
    representations using adversarial training with three discriminators:
    - Global: Captures global mutual information
    - Local: Captures local mutual information
    - Prior: Matches representations to a prior distribution

    References:
        Hjelm et al. "Learning deep representations by mutual information estimation and maximization"
        https://arxiv.org/pdf/1808.06670.pdf

    Args:
        global_discriminator: Global discriminator instance
        local_discriminator: Local discriminator instance
        prior_discriminator: Prior discriminator instance
        alpha: Weight for global loss term
        beta: Weight for local loss term
        gamma: Weight for prior loss term
        reduction: Loss reduction method
    """

    def __init__(
        self,
        global_discriminator: Optional[nn.Module] = None,
        local_discriminator: Optional[nn.Module] = None,
        prior_discriminator: Optional[nn.Module] = None,
        alpha: float = 0.5,
        beta: float = 1.0,
        gamma: float = 0.1,
        reduction: str = "mean",
    ):
        """Initialize DeepInfoMax loss.

        Args:
            global_discriminator: Global discriminator (uses mock if None)
            local_discriminator: Local discriminator (uses mock if None)
            prior_discriminator: Prior discriminator (uses mock if None)
            alpha: Weight for global loss term
            beta: Weight for local loss term
            gamma: Weight for prior loss term
            reduction: Loss reduction method
        """
        super().__init__(reduction=reduction, alpha=alpha, beta=beta, gamma=gamma)

        # Use provided discriminators or create default ones
        if global_discriminator is not None:
            self.global_d = global_discriminator
        else:
            try:
                # Try to import actual discriminators
                from models import GlobalDiscriminator

                self.global_d = GlobalDiscriminator()
            except ImportError:
                # Fall back to mock discriminator
                self.global_d = MockGlobalDiscriminator()

        if local_discriminator is not None:
            self.local_d = local_discriminator
        else:
            try:
                from models import LocalDiscriminator

                self.local_d = LocalDiscriminator()
            except ImportError:
                self.local_d = MockLocalDiscriminator()

        if prior_discriminator is not None:
            self.prior_d = prior_discriminator
        else:
            try:
                from models import PriorDiscriminator

                self.prior_d = PriorDiscriminator()
            except ImportError:
                self.prior_d = MockPriorDiscriminator()

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def forward(self, y: torch.Tensor, M: torch.Tensor, M_prime: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of DeepInfoMaxLoss.

        Args:
            y: Encoded representations [batch_size, encoding_dim]
            M: Feature maps [batch_size, channels, height, width]
            M_prime: Shuffled/rotated feature maps [batch_size, channels, height, width]

        Returns:
            Combined loss (LOCAL + GLOBAL + PRIOR)
        """
        # Expand y to match spatial dimensions of feature maps
        # See appendix 1A of https://arxiv.org/pdf/1808.06670.pdf
        y_exp = y.unsqueeze(-1).unsqueeze(-1)  # [batch_size, encoding_dim, 1, 1]

        # Get spatial dimensions from feature maps
        _, _, height, width = M.shape
        y_exp = y_exp.expand(-1, -1, height, width)  # [batch_size, encoding_dim, height, width]

        # Concatenate feature maps with expanded representations
        y_M = torch.cat((M, y_exp), dim=1)  # [batch_size, channels + encoding_dim, height, width]
        y_M_prime = torch.cat((M_prime, y_exp), dim=1)

        # Local discriminator loss
        # E_j = E[log(T(y, x_j))] where (y, x_j) are joint samples
        # E_m = E[log(T(y, x_m))] where (y, x_m) are marginal samples
        Ej_local = -F.softplus(-self.local_d(y_M)).mean()
        Em_local = F.softplus(self.local_d(y_M_prime)).mean()
        LOCAL = (Em_local - Ej_local) * self.beta

        # Global discriminator loss
        Ej_global = -F.softplus(-self.global_d(y, M)).mean()
        Em_global = F.softplus(self.global_d(y, M_prime)).mean()
        GLOBAL = (Em_global - Ej_global) * self.alpha

        # Prior discriminator loss
        # Encourages representations to match a prior distribution
        prior = torch.rand_like(y)
        term_a = torch.log(self.prior_d(prior)).mean()
        term_b = torch.log(1.0 - self.prior_d(y)).mean()
        PRIOR = -(term_a + term_b) * self.gamma

        # Combined loss
        total_loss = LOCAL + GLOBAL + PRIOR

        return self.apply_reduction(total_loss.unsqueeze(0))

    def get_config(self) -> dict:
        """Get configuration for the loss function."""
        config = super().get_config()
        config.update(
            {
                "alpha": self.alpha,
                "beta": self.beta,
                "gamma": self.gamma,
                "global_discriminator": type(self.global_d).__name__,
                "local_discriminator": type(self.local_d).__name__,
                "prior_discriminator": type(self.prior_d).__name__,
            }
        )
        return config


# Convenience function for creating shuffled/rotated versions of feature maps
def create_negative_samples(M: torch.Tensor, method: str = "rotate") -> torch.Tensor:
    """
    Create negative samples from feature maps for DeepInfoMax.

    Args:
        M: Feature maps [batch_size, channels, height, width]
        method: Method for creating negatives ('rotate', 'shuffle')

    Returns:
        Negative feature maps with same shape as M
    """
    if method == "rotate":
        # Rotate batch dimension (each sample paired with different sample's features)
        return torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
    elif method == "shuffle":
        # Shuffle along batch dimension
        indices = torch.randperm(M.shape[0])
        return M[indices]
    else:
        raise ValueError(f"Unknown method: {method}")


# Example usage function
def example_usage():
    """Example of how to use DeepInfoMaxLoss."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create loss with mock discriminators
    loss_fn = DeepInfoMaxLoss(
        global_discriminator=MockGlobalDiscriminator(),
        local_discriminator=MockLocalDiscriminator(),
        prior_discriminator=MockPriorDiscriminator(),
        alpha=0.5,
        beta=1.0,
        gamma=0.1,
    )

    # Example data
    batch_size = 4
    encoding_dim = 64
    feature_channels = 128
    feature_size = 26

    y = torch.randn(batch_size, encoding_dim, device=device)
    M = torch.randn(batch_size, feature_channels, feature_size, feature_size, device=device)
    M_prime = create_negative_samples(M, method="rotate")

    # Compute loss
    try:
        loss = loss_fn(y, M, M_prime)
        print(f"DeepInfoMax loss: {loss.item():.4f}")
        return loss
    except Exception as e:
        print(f"Error computing loss: {e}")
        return None


if __name__ == "__main__":
    example_usage()


File: ssrlib/losses/base.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from enum import Enum


class DistanceMetric(Enum):
    """Distance metrics for loss functions."""

    EUCLIDEAN = "euclidean"
    SQUARED_EUCLIDEAN = "squared_euclidean"
    COSINE = "cosine"
    MANHATTAN = "manhattan"


class BaseLoss(nn.Module, ABC):
    """Base class for ssrlib loss functions with common functionality."""

    def __init__(
        self,
        reduction: str = "mean",
        normalize: bool = False,
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """Initialize base loss.

        Args:
            reduction: Reduction method ('mean', 'sum', 'none')
            normalize: Whether to L2 normalize embeddings
            temperature: Temperature scaling factor
            **kwargs: Additional loss-specific parameters
        """
        super().__init__()

        assert reduction in ["mean", "sum", "none"], f"Invalid reduction: {reduction}"

        self.reduction = reduction
        self.normalize = normalize
        self.temperature = temperature
        self.loss_params = kwargs

    def apply_normalization(self, *tensors: torch.Tensor) -> tuple:
        """Apply L2 normalization to tensors if enabled.

        Args:
            *tensors: Input tensors to normalize

        Returns:
            Tuple of normalized tensors
        """
        if self.normalize:
            return tuple(F.normalize(t, p=2, dim=-1) for t in tensors)
        return tensors

    def apply_temperature(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply temperature scaling to logits.

        Args:
            logits: Input logits

        Returns:
            Temperature-scaled logits
        """
        if self.temperature is not None:
            return logits / self.temperature
        return logits

    def apply_reduction(self, loss: torch.Tensor) -> torch.Tensor:
        """Apply reduction to loss tensor.

        Args:
            loss: Loss tensor

        Returns:
            Reduced loss
        """
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:  # 'none'
            return loss

    def compute_distance(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        metric: DistanceMetric = DistanceMetric.EUCLIDEAN,
    ) -> torch.Tensor:
        """Compute distance between tensors using specified metric.

        Args:
            x1, x2: Input tensors of shape (..., D)
            metric: Distance metric to use

        Returns:
            Distance tensor
        """
        if metric == DistanceMetric.EUCLIDEAN:
            return torch.norm(x1 - x2, p=2, dim=-1)
        elif metric == DistanceMetric.SQUARED_EUCLIDEAN:
            return torch.sum((x1 - x2) ** 2, dim=-1)
        elif metric == DistanceMetric.MANHATTAN:
            return torch.norm(x1 - x2, p=1, dim=-1)
        elif metric == DistanceMetric.COSINE:
            # Ensure normalized for cosine distance
            x1_norm = F.normalize(x1, p=2, dim=-1)
            x2_norm = F.normalize(x2, p=2, dim=-1)
            return 1 - torch.sum(x1_norm * x2_norm, dim=-1)
        else:
            raise ValueError(f"Unknown distance metric: {metric}")

    def compute_pairwise_distance(
        self, x: torch.Tensor, metric: DistanceMetric = DistanceMetric.EUCLIDEAN
    ) -> torch.Tensor:
        """Compute pairwise distances within a batch.

        Args:
            x: Input tensor of shape (N, D)
            metric: Distance metric to use

        Returns:
            Pairwise distance matrix of shape (N, N)
        """
        if metric == DistanceMetric.EUCLIDEAN:
            return torch.cdist(x, x, p=2)
        elif metric == DistanceMetric.SQUARED_EUCLIDEAN:
            return torch.cdist(x, x, p=2) ** 2
        elif metric == DistanceMetric.MANHATTAN:
            return torch.cdist(x, x, p=1)
        elif metric == DistanceMetric.COSINE:
            x_norm = F.normalize(x, p=2, dim=1)
            cosine_sim = torch.mm(x_norm, x_norm.t())
            return 1 - cosine_sim
        else:
            raise ValueError(f"Unknown distance metric: {metric}")

    @abstractmethod
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Forward pass of the loss function."""
        pass

    def get_config(self) -> Dict[str, Any]:
        """Get loss function configuration.

        Returns:
            Configuration dictionary
        """
        return {
            "reduction": self.reduction,
            "normalize": self.normalize,
            "temperature": self.temperature,
            **self.loss_params,
        }

    def __repr__(self) -> str:
        """String representation of the loss function."""
        config_str = ", ".join(f"{k}={v}" for k, v in self.get_config().items() if v is not None)
        return f"{self.__class__.__name__}({config_str})"


class ContrastiveLossBase(BaseLoss):
    """Base class for contrastive-style losses."""

    def __init__(
        self,
        margin: float = 1.0,
        distance_metric: DistanceMetric = DistanceMetric.EUCLIDEAN,
        **kwargs,
    ):
        """Initialize contrastive loss base.

        Args:
            margin: Margin for contrastive learning
            distance_metric: Distance metric to use
            **kwargs: Additional parameters for BaseLoss
        """
        super().__init__(**kwargs)
        self.margin = margin
        self.distance_metric = distance_metric

    def get_config(self) -> Dict[str, Any]:
        """Get configuration including contrastive-specific parameters."""
        config = super().get_config()
        config.update({"margin": self.margin, "distance_metric": self.distance_metric.value})
        return config
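

# --- Illustrative sketch -----------------------------------------------------
# A minimal concrete loss built on the helpers above. "SimplePairLoss" is a
# hypothetical example class (not shipped with ssrlib); it only shows how a
# subclass of ContrastiveLossBase can combine apply_normalization,
# compute_distance, the margin, and apply_reduction.
class SimplePairLoss(ContrastiveLossBase):
    """Penalizes anchor-positive pairs whose distance exceeds the margin."""

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor) -> torch.Tensor:
        anchor, positive = self.apply_normalization(anchor, positive)
        distances = self.compute_distance(anchor, positive, self.distance_metric)
        loss = F.relu(distances - self.margin)
        return self.apply_reduction(loss)


if __name__ == "__main__":
    # Smoke test on random embeddings.
    anchor = torch.randn(8, 16)
    positive = torch.randn(8, 16)
    loss_fn = SimplePairLoss(margin=0.5, normalize=True, reduction="mean")
    print(loss_fn)  # repr is built from get_config()
    print(f"loss: {loss_fn(anchor, positive).item():.4f}")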


File: ssrlib/core/config.py
from typing import Dict, Any
import yaml
import json


class Config:
    """Configuration management class for ssrlib."""

    def __init__(self, config_dict: Dict = None):
        """Initialize configuration.

        Args:
            config_dict: Dictionary containing configuration
        """
        self._config = config_dict or {}

    def get(self, key: str, default: Any = None) -> Any:
        """Get configuration value.

        Args:
            key: Configuration key (supports dot notation like 'model.batch_size')
            default: Default value if key not found

        Returns:
            Configuration value
        """
        keys = key.split(".")
        value = self._config

        for k in keys:
            if isinstance(value, dict) and k in value:
                value = value[k]
            else:
                return default

        return value

    def set(self, key: str, value: Any) -> None:
        """Set configuration value.

        Args:
            key: Configuration key (supports dot notation)
            value: Value to set
        """
        keys = key.split(".")
        config = self._config

        for k in keys[:-1]:
            if k not in config:
                config[k] = {}
            config = config[k]

        config[keys[-1]] = value

    @classmethod
    def from_file(cls, config_path: str) -> "Config":
        """Load configuration from file.

        Args:
            config_path: Path to configuration file (.yaml or .json)

        Returns:
            Config instance
        """
        with open(config_path, "r") as f:
            if config_path.endswith(".yaml") or config_path.endswith(".yml"):
                config_dict = yaml.safe_load(f)
            elif config_path.endswith(".json"):
                config_dict = json.load(f)
            else:
                raise ValueError(f"Unsupported config file format: {config_path}")

        return cls(config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to dictionary."""
        return self._config.copy()
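

# --- Illustrative sketch -----------------------------------------------------
# Dot-notation access: get() walks nested dictionaries and falls back to the
# default for missing keys, while set() creates intermediate dictionaries on
# the fly. The keys below are examples only.
if __name__ == "__main__":
    cfg = Config({"model": {"batch_size": 32}})
    print(cfg.get("model.batch_size"))          # 32
    print(cfg.get("model.lr", 1e-3))            # 0.001 (default; key missing)
    cfg.set("training.optimizer.name", "adam")  # creates nested dicts
    print(cfg.to_dict())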


File: ssrlib/core/registry.py
"""Generic registry system for ssrlib components."""

import importlib
import inspect
import pkgutil
import logging
from pathlib import Path
from typing import Dict, List, Type, Any, Optional, TypeVar, Generic, Callable

logger = logging.getLogger(__name__)

T = TypeVar("T")  # Generic type for the component class


class BaseRegistry(Generic[T]):
    """Generic registry for dynamically discovered components."""

    def __init__(self, component_type_name: str):
        """
        Args:
            component_type_name: Name of component type (e.g., 'dataset', 'embedder')
        """
        self.component_type_name = component_type_name
        self._items: Dict[str, Type[T]] = {}
        self._descriptions: Dict[str, str] = {}
        self._categories: Dict[str, List[str]] = {}
        self._properties: Dict[str, Dict[str, Any]] = {}
        self._discovery_errors: List[str] = []

        # Optional features
        self._modalities: Optional[Dict[str, str]] = None
        self._output_types: Optional[Dict[str, str]] = None

    def enable_modalities(self) -> "BaseRegistry[T]":
        """Enable modality tracking (for datasets, embedders)."""
        self._modalities = {}
        return self

    def enable_output_types(self) -> "BaseRegistry[T]":
        """Enable output type tracking (for processors)."""
        self._output_types = {}
        return self

    def register(
        self,
        name: str,
        item_class: Type[T],
        description: str = "",
        category: str = "general",
        modality: Optional[str] = None,
        output_type: Optional[str] = None,
        properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Register a component."""
        self._items[name] = item_class
        self._descriptions[name] = description
        self._properties[name] = properties or {}

        if category not in self._categories:
            self._categories[category] = []
        self._categories[category].append(name)

        if self._modalities is not None and modality:
            self._modalities[name] = modality

        if self._output_types is not None and output_type:
            self._output_types[name] = output_type

    def get(self, name: str) -> Optional[Type[T]]:
        """Get component class by name."""
        return self._items.get(name)

    def get_description(self, name: str) -> str:
        """Get description for a component."""
        return self._descriptions.get(name, "No description available.")

    def get_properties(self, name: str) -> Dict[str, Any]:
        """Get properties for a component."""
        return self._properties.get(name, {})

    def list_all(self) -> List[str]:
        """List all available component names."""
        return list(self._items.keys())

    def list_by_category(self, category: str = None) -> Dict[str, List[str]]:
        """List components by category."""
        if category:
            return {category: self._categories.get(category, [])}
        return self._categories.copy()

    def list_by_modality(self, modality: str = None) -> Dict[str, List[str]]:
        """List components by modality (if enabled)."""
        if self._modalities is None:
            return {}

        modality_groups = {}
        for name, mod in self._modalities.items():
            if mod not in modality_groups:
                modality_groups[mod] = []
            modality_groups[mod].append(name)

        if modality:
            return {modality: modality_groups.get(modality, [])}
        return modality_groups

    def list_by_output_type(self, output_type: str = None) -> Dict[str, List[str]]:
        """List components by output type (if enabled)."""
        if self._output_types is None:
            return {}

        output_groups = {}
        for name, out_type in self._output_types.items():
            if out_type not in output_groups:
                output_groups[out_type] = []
            output_groups[out_type].append(name)

        if output_type:
            return {output_type: output_groups.get(output_type, [])}
        return output_groups

    def get_info(self, name: str) -> Dict[str, Any]:
        """Get comprehensive information about a component."""
        item_class = self.get(name)
        if not item_class:
            return {}

        info = {
            "name": name,
            "class": item_class.__name__,
            "module": item_class.__module__,
            "description": self.get_description(name),
            "docstring": item_class.__doc__ or "",
            "base_classes": [cls.__name__ for cls in item_class.__mro__[1:]],
            "is_abstract": inspect.isabstract(item_class),
            "properties": self.get_properties(name),
        }

        # Add modality if available
        if self._modalities is not None and name in self._modalities:
            info["modality"] = self._modalities[name]

        # Add output type if available
        if self._output_types is not None and name in self._output_types:
            info["output_type"] = self._output_types[name]

        # Try to get initialization signature
        try:
            sig = inspect.signature(item_class.__init__)
            info["parameters"] = {
                param_name: {
                    "default": param.default if param.default != param.empty else None,
                    "annotation": (
                        str(param.annotation) if param.annotation != param.empty else None
                    ),
                }
                for param_name, param in sig.parameters.items()
                if param_name not in ["self", "args", "kwargs"]
            }
        except Exception as e:
            info["parameters"] = f"Error extracting parameters: {e}"

        # Check for AVAILABLE_MODELS
        if hasattr(item_class, "AVAILABLE_MODELS"):
            info["available_models"] = list(item_class.AVAILABLE_MODELS.keys())

        return info

    def print_registry(self) -> None:
        """Print formatted registry information."""
        print(f"Available {self.component_type_name.title()}s:")
        print("=" * 50)

        # Print by category
        for category, items in self._categories.items():
            print(f"\n{category.upper()}:")
            for item_name in sorted(items):
                description = self.get_description(item_name)

                # Add modality or output type if available
                suffix = ""
                if self._modalities is not None and item_name in self._modalities:
                    suffix = f" ({self._modalities[item_name]})"
                elif self._output_types is not None and item_name in self._output_types:
                    suffix = f" [{self._output_types[item_name]}]"

                print(f"  {item_name}{suffix}: {description}")

        # Print by modality if enabled
        if self._modalities is not None:
            print(f"\nBY MODALITY:")
            modality_groups = self.list_by_modality()
            for modality, items in modality_groups.items():
                print(f"  {modality}: {', '.join(sorted(items))}")

        # Print by output type if enabled
        if self._output_types is not None:
            print(f"\nBY OUTPUT TYPE:")
            output_groups = self.list_by_output_type()
            for output_type, items in output_groups.items():
                print(f"  {output_type}: {', '.join(sorted(items))}")

        if self._discovery_errors:
            print(f"\nDiscovery Errors ({len(self._discovery_errors)}):")
            for error in self._discovery_errors:
                print(f"  - {error}")


def extract_description(cls: Type) -> str:
    """Extract description from class docstring."""
    if not cls.__doc__:
        return "No description available."

    docstring = inspect.cleandoc(cls.__doc__)
    lines = docstring.split("\n")

    if lines:
        first_line = lines[0].strip()
        if first_line:
            return first_line

    for line in lines:
        line = line.strip()
        if line and not line.startswith("Args:") and not line.startswith("Returns:"):
            return line

    return "No description available."


def discover_components(
    package_path: Path,
    package_name: str,
    base_class: Type[T],
    registry: BaseRegistry[T],
    get_category_func: Optional[Callable[[Type[T]], str]] = None,
    get_modality_func: Optional[Callable[[Type[T]], str]] = None,
    get_output_type_func: Optional[Callable[[Type[T]], str]] = None,
    get_properties_func: Optional[Callable[[Type[T]], Dict[str, Any]]] = None,
    skip_modules: Optional[List[str]] = None,
) -> BaseRegistry[T]:
    """
    Generic discovery function for any component type.

    Args:
        package_path: Path to the package directory
        package_name: Full package name
        base_class: Base class to filter for
        registry: Registry instance to populate
        get_category_func: Optional function to determine category from class
        get_modality_func: Optional function to determine modality from class
        get_output_type_func: Optional function to determine output type from class
        get_properties_func: Optional function to extract properties from class
        skip_modules: List of module names to skip (default: ['__init__', 'base'])
    """
    if skip_modules is None:
        skip_modules = ["__init__", "base"]

    logger.debug(f"Discovering components in {package_path}")

    # Iterate through all modules in the package
    for module_info in pkgutil.walk_packages([str(package_path)], prefix=f"{package_name}."):
        module_name = module_info.name

        # Skip specified modules
        if any(module_name.endswith(f".{skip}") for skip in skip_modules):
            continue

        try:
            module = importlib.import_module(module_name)
            logger.debug(f"Scanning module: {module_name}")

            # Find all classes in the module
            for name, obj in inspect.getmembers(module, inspect.isclass):
                if (
                    inspect.isclass(obj)
                    and issubclass(obj, base_class)
                    and not inspect.isabstract(obj)
                    and obj != base_class
                    and obj.__module__ == module_name
                ):
                    # Extract information
                    description = extract_description(obj)

                    # Get category
                    category = "general"
                    if get_category_func:
                        category = get_category_func(obj)
                    elif hasattr(obj, f"get_{registry.component_type_name}_category"):
                        category = getattr(obj, f"get_{registry.component_type_name}_category")()

                    # Get modality
                    modality = None
                    if get_modality_func:
                        modality = get_modality_func(obj)
                    elif hasattr(obj, f"get_{registry.component_type_name}_modality"):
                        modality = getattr(obj, f"get_{registry.component_type_name}_modality")()

                    # Get output type
                    output_type = None
                    if get_output_type_func:
                        output_type = get_output_type_func(obj)

                    # Get properties
                    properties = {}
                    if get_properties_func:
                        properties = get_properties_func(obj)
                    elif hasattr(obj, f"get_{registry.component_type_name}_properties"):
                        properties = getattr(
                            obj, f"get_{registry.component_type_name}_properties"
                        )()

                    # Register the component
                    registry.register(
                        name,
                        obj,
                        description,
                        category,
                        modality,
                        output_type,
                        properties,
                    )

                    logger.debug(f"Registered: {name} ({category})")

        except Exception as e:
            error_msg = f"Failed to import {module_name}: {e}"
            logger.warning(error_msg)
            registry._discovery_errors.append(error_msg)
            continue

    logger.info(f"Discovered {len(registry.list_all())} components")
    return registry
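

# --- Illustrative sketch -----------------------------------------------------
# Manual registration without automatic discovery. "Widget" and the registered
# name below are hypothetical examples, not ssrlib components.
if __name__ == "__main__":

    class Widget:
        """A tiny example component."""

    registry = BaseRegistry("widget").enable_modalities()
    registry.register(
        "example_widget",
        Widget,
        description="A tiny example component.",
        category="demo",
        modality="image",
    )
    print(registry.list_all())                            # ['example_widget']
    print(registry.get_info("example_widget")["class"])   # 'Widget'
    registry.print_registry()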


File: ssrlib/core/__init__.py
from .pipeline import Pipeline, PipelineResults
from .config import Config
from .registry import BaseRegistry, discover_components, extract_description

__all__ = [
    "Pipeline",
    "PipelineResults",
    "Config",
    "BaseRegistry",
    "discover_components",
    "extract_description",
]


File: ssrlib/core/pipeline.py
from typing import List, Tuple, Any, Dict, Union, Optional, Iterator
import time
import numpy as np
import hashlib
import os
from pathlib import Path
import json
import logging

from ..datasets.base import BaseDataset
from ..embedders.base import BaseEmbedder
from ..processing.base import BaseProcessor
from .config import Config
from ..storage.tensor_storage import TensorStorage

logger = logging.getLogger(__name__)


class PipelineResults:
    """Container for pipeline execution results."""

    def __init__(self):
        # (dataset_key, embedder_name) -> embeddings
        self.embeddings: Dict[Tuple[str, str], np.ndarray] = {}
        # (dataset_key, embedder_name, processor_name) -> processed_data
        self.processed: Dict[Tuple[str, str, str], np.ndarray] = {}
        # General metadata
        self.metadata: Dict[str, Any] = {}
        # Timing information
        self.timing: Dict[str, float] = {}
        # Mapping from dataset_key to original dataset name
        self.dataset_key_mapping: Dict[str, str] = {}
        # Storage information
        self.storage_info: Optional[Dict[str, Any]] = None

    def get_embeddings(self, dataset_key: str, embedder_name: str) -> Optional[np.ndarray]:
        """Get embeddings for a specific dataset-embedder combination (None if missing)."""
        return self.embeddings.get((dataset_key, embedder_name))

    def get_processed(
        self, dataset_key: str, embedder_name: str, processor_name: str
    ) -> Optional[np.ndarray]:
        """Get processed data for a specific dataset-embedder-processor combination (None if missing)."""
        return self.processed.get((dataset_key, embedder_name, processor_name))

    def list_dataset_keys(self) -> List[str]:
        """List all unique dataset keys."""
        return list(self.dataset_key_mapping.keys())

    def get_original_dataset_name(self, dataset_key: str) -> str:
        """Get original dataset name from dataset key."""
        return self.dataset_key_mapping.get(dataset_key, dataset_key)


class Pipeline:
    """Main pipeline class for orchestrating ssrlib components with storage support."""

    def __init__(self, components: List[Tuple[str, Any]], config: Config = None):
        """Initialize pipeline.

        Args:
            components: List of (component_type, component) tuples
            config: Configuration object
        """
        self.components = components
        self.config = config or Config()

        # Organize components by type
        self.datasets = []
        self.embedders = []
        self.processors = []

        self._organize_components()

    def _organize_components(self) -> None:
        """Organize components by type from the input list."""
        for comp_type, comp in self.components:
            if comp_type in ["dataset", "datasets"]:
                if isinstance(comp, list):
                    self.datasets.extend(comp)
                else:
                    self.datasets.append(comp)
            elif comp_type in ["embedder", "embedders"]:
                if isinstance(comp, list):
                    self.embedders.extend(comp)
                else:
                    self.embedders.append(comp)
            elif comp_type in ["processor", "processors"]:
                if isinstance(comp, list):
                    self.processors.extend(comp)
                else:
                    self.processors.append(comp)

    def add_dataset(self, dataset: BaseDataset) -> "Pipeline":
        """Add dataset to pipeline."""
        self.datasets.append(dataset)
        return self

    def add_embedder(self, embedder: BaseEmbedder) -> "Pipeline":
        """Add embedder to pipeline."""
        self.embedders.append(embedder)
        return self

    def add_processor(self, processor: BaseProcessor) -> "Pipeline":
        """Add processor to pipeline."""
        self.processors.append(processor)
        return self

    def execute(
        self,
        config_override: Dict = None,
        use_storage: bool = False,
        storage_dir: Optional[str] = None,
        force_recompute: bool = False,
        storage_description: str = "",
    ) -> PipelineResults:
        """
        Execute the pipeline with optional storage caching.

        Refactored version with reduced complexity (target: ~8).

        Args:
            config_override: Configuration overrides
            use_storage: Whether to use storage for caching embeddings
            storage_dir: Directory for storage (auto-generated if None)
            force_recompute: Whether to force recomputation even if cached
            storage_description: Description for storage

        Returns:
            PipelineResults containing all computed embeddings and processed data
        """
        start_time = time.time()

        try:
            # Stage 1: Validation and preparation
            self._validate_configuration()
            results = self._initialize_results()

            # Stage 2: Apply configuration
            dataset_keys = self._prepare_dataset_keys(results)
            self._apply_config_overrides(config_override)
            batch_size = self.config.get("batch_size", 32)

            # Stage 3: Setup storage
            storage = self._setup_storage_system(
                use_storage, storage_dir, storage_description, results
            )

            # Stage 4: Load data and models
            self._download_datasets()
            self._load_embedders()

            # Stage 5: Extract embeddings (with caching)
            storage_keys_map = self._prepare_storage_keys(dataset_keys)
            cached_embeddings = self._load_cached_embeddings(
                storage, storage_keys_map, use_storage, force_recompute
            )

            embedding_time = time.time()
            embeddings_to_save = self._extract_embeddings(
                dataset_keys,
                storage_keys_map,
                cached_embeddings,
                batch_size,
                use_storage,
                results,
            )
            results.timing["embedding_time"] = time.time() - embedding_time

            # Stage 6: Save new embeddings
            self._save_embeddings_to_cache(
                embeddings_to_save,
                storage,
                storage_dir,
                storage_description,
                use_storage,
                results,
            )

            # Stage 7: Log cache statistics
            self._log_cache_statistics(cached_embeddings, embeddings_to_save, use_storage, results)

            # Stage 8: Process embeddings
            if self.processors:
                self._process_embeddings(dataset_keys, results)

            # Stage 9: Collect final results
            self._collect_metadata(dataset_keys, storage, results)

            results.timing["total_time"] = time.time() - start_time
            logger.info(f"Pipeline execution completed in {results.timing['total_time']:.2f}s")
            print(f"Pipeline execution completed in {results.timing['total_time']:.2f}s")

            return results

        except Exception as e:
            logger.error(f"Pipeline execution failed: {str(e)}", exc_info=True)
            raise RuntimeError(f"Pipeline execution failed: {str(e)}") from e

    def _validate_configuration(self) -> None:
        """Validate pipeline configuration. Complexity: 2"""
        if not self.datasets:
            raise ValueError("Pipeline requires at least one dataset")

        if not self.embedders:
            raise ValueError("Pipeline requires at least one embedder")

    def _initialize_results(self) -> PipelineResults:
        """Initialize results container. Complexity: 1"""
        return PipelineResults()

    def _prepare_dataset_keys(self, results: PipelineResults) -> Dict[Any, str]:
        """
        Create unique keys for dataset instances.
        Complexity: 3
        """
        dataset_keys = self._create_unique_dataset_keys()

        # Store mapping in results
        for dataset, unique_key in dataset_keys.items():
            results.dataset_key_mapping[unique_key] = dataset.name

        return dataset_keys

    def _create_unique_dataset_keys(self) -> Dict[Any, str]:
        """Create unique keys for each dataset instance. Complexity: 4"""
        dataset_counts = {}
        dataset_keys = {}

        for dataset in self.datasets:
            base_name = dataset.name

            if base_name not in dataset_counts:
                dataset_counts[base_name] = 0
            else:
                dataset_counts[base_name] += 1

            # Create unique key
            if dataset_counts[base_name] == 0:
                unique_key = base_name
            else:
                unique_key = f"{base_name}[{dataset_counts[base_name]}]"

            dataset_keys[dataset] = unique_key

        return dataset_keys

    def _apply_config_overrides(self, config_override: Optional[Dict]) -> None:
        """Apply configuration overrides. Complexity: 2"""
        if not config_override:
            return

        for key, value in config_override.items():
            self.config.set(key, value)

    def _setup_storage_system(
        self,
        use_storage: bool,
        storage_dir: Optional[str],
        storage_description: str,
        results: PipelineResults,
    ) -> Optional[TensorStorage]:
        """
        Setup storage system if requested.
        Complexity: 4
        """
        if not use_storage:
            results.storage_info = {"enabled": False}
            return None

        # Auto-generate storage directory if needed
        if storage_dir is None:
            timestamp = int(time.time())
            storage_dir = f"./storage/pipeline_cache_{timestamp}"

        os.makedirs(storage_dir, exist_ok=True)
        storage = self._setup_storage(storage_dir, storage_description)

        results.storage_info = {
            "enabled": True,
            "directory": storage_dir,
            "description": storage_description,
        }

        return storage

    def _setup_storage(self, storage_dir: str, description: str = "") -> Optional[TensorStorage]:
        """Load existing storage if present, otherwise return None. Complexity: 3"""
        metadata_path = os.path.join(storage_dir, "metadata", "metadata.json")

        if os.path.exists(storage_dir) and os.path.exists(metadata_path):
            logger.info(f"Loading existing storage from {storage_dir}")
            return TensorStorage(storage_dir)

        logger.info(f"Will create new storage in {storage_dir}")
        return None

    def _download_datasets(self) -> None:
        """Download all datasets. Complexity: 2"""
        logger.info("Downloading datasets...")
        print("Downloading datasets...")

        for dataset in self.datasets:
            try:
                dataset.download()
            except Exception as e:
                logger.error(f"Failed to download dataset {dataset.name}: {str(e)}")
                raise

    def _load_embedders(self) -> None:
        """Load all embedder models. Complexity: 2"""
        logger.info("Loading embedders...")
        print("Loading embedders...")

        for embedder in self.embedders:
            try:
                embedder.load_model()
            except Exception as e:
                logger.error(f"Failed to load embedder {embedder.name}: {str(e)}")
                raise

    def _prepare_storage_keys(self, dataset_keys: Dict[Any, str]) -> Dict[Tuple[Any, Any], str]:
        """
        Create storage keys for all dataset-embedder combinations.
        Complexity: 3
        """
        storage_keys_map = {}

        for dataset in self.datasets:
            dataset_key = dataset_keys[dataset]
            for embedder in self.embedders:
                storage_key = self._create_storage_key(dataset_key, embedder.name, dataset)
                storage_keys_map[(dataset, embedder)] = storage_key

        return storage_keys_map

    def _create_storage_key(
        self, dataset_key: str, embedder_name: str, dataset: BaseDataset
    ) -> str:
        """
        Create unique storage key for dataset-embedder combination.
        Complexity: 2
        """
        dataset_config = {
            "name": dataset.name,
            "size": len(dataset),
            "metadata": dataset.get_metadata(),
        }

        config_str = json.dumps(dataset_config, sort_keys=True)
        config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]

        return f"{dataset_key}_{embedder_name}_{config_hash}"

    def _load_cached_embeddings(
        self,
        storage: Optional[TensorStorage],
        storage_keys_map: Dict[Tuple[Any, Any], str],
        use_storage: bool,
        force_recompute: bool,
    ) -> Dict[str, np.ndarray]:
        """
        Load cached embeddings from storage if available.
        Complexity: 3
        """
        if not use_storage or not storage or force_recompute:
            return {}

        storage_keys_needed = list(storage_keys_map.values())
        return self._load_embeddings_from_storage(storage, storage_keys_needed)

    def _load_embeddings_from_storage(
        self, storage: TensorStorage, storage_keys_needed: List[str]
    ) -> Dict[str, np.ndarray]:
        """Load embeddings from storage if they exist. Complexity: 4"""
        loaded_embeddings = {}

        if storage is None or storage.metadata_df is None or storage.metadata_df.empty:
            return loaded_embeddings

        for storage_key in storage_keys_needed:
            matches = storage.metadata_df[storage.metadata_df["storage_key"] == storage_key]
            if len(matches) > 0:
                tensor_idx = matches.iloc[0]["tensor_idx"]
                embeddings = storage[tensor_idx]
                loaded_embeddings[storage_key] = embeddings
                logger.info(f"Loaded embeddings for {storage_key} from cache")
                print(f"Loaded embeddings for {storage_key} from cache")

        return loaded_embeddings

    def _extract_embeddings(
        self,
        dataset_keys: Dict[Any, str],
        storage_keys_map: Dict[Tuple[Any, Any], str],
        cached_embeddings: Dict[str, np.ndarray],
        batch_size: int,
        use_storage: bool,
        results: PipelineResults,
    ) -> Dict[str, Tuple[np.ndarray, Dict]]:
        """
        Extract embeddings for all dataset-embedder combinations.
        Complexity: 5
        """
        logger.info("Extracting embeddings...")
        print("Extracting embeddings...")

        embeddings_to_save = {}

        for dataset in self.datasets:
            dataset_key = dataset_keys[dataset]

            for embedder in self.embedders:
                storage_key = storage_keys_map[(dataset, embedder)]

                # Try to use cache or compute
                if storage_key in cached_embeddings:
                    embeddings = cached_embeddings[storage_key]
                    logger.info(f"Using cached embeddings for {dataset_key} + {embedder.name}")
                    print(f"Using cached embeddings for {dataset_key} + {embedder.name}")
                else:
                    embeddings = self._compute_embeddings(
                        dataset, embedder, dataset_key, batch_size
                    )

                    # Prepare for storage
                    if use_storage:
                        metadata = self._create_embedding_metadata(
                            dataset_key, dataset, embedder, embeddings, batch_size
                        )
                        embeddings_to_save[storage_key] = (embeddings, metadata)

                # Store in results
                self._store_embedding_results(results, dataset_key, embedder.name, embeddings)

        return embeddings_to_save

    def _compute_embeddings(
        self,
        dataset: BaseDataset,
        embedder: BaseEmbedder,
        dataset_key: str,
        batch_size: int,
    ) -> np.ndarray:
        """Compute embeddings for a dataset. Complexity: 2"""
        logger.info(f"Computing embeddings for {dataset_key} + {embedder.name}")
        print(f"Computing embeddings for {dataset_key} + {embedder.name}")

        try:
            return embedder.embed_dataset(dataset, batch_size)
        except Exception as e:
            logger.error(f"Failed to compute embeddings: {str(e)}")
            raise

    def _create_embedding_metadata(
        self,
        dataset_key: str,
        dataset: BaseDataset,
        embedder: BaseEmbedder,
        embeddings: np.ndarray,
        batch_size: int,
    ) -> Dict[str, Any]:
        """Create metadata for embedding storage. Complexity: 1"""
        return {
            "dataset_key": dataset_key,
            "dataset_name": dataset.name,
            "embedder_name": embedder.name,
            "embeddings_shape": embeddings.shape,
            "dataset_size": len(dataset),
            "batch_size": batch_size,
            "timestamp": time.time(),
        }

    def _store_embedding_results(
        self,
        results: PipelineResults,
        dataset_key: str,
        embedder_name: str,
        embeddings: np.ndarray,
    ) -> None:
        """Store embeddings in results. Complexity: 1"""
        results.embeddings[(dataset_key, embedder_name)] = embeddings
        results.metadata[f"{dataset_key}_{embedder_name}_shape"] = embeddings.shape
        results.metadata[f"{dataset_key}_{embedder_name}_dtype"] = str(embeddings.dtype)

    def _save_embeddings_to_cache(
        self,
        embeddings_to_save: Dict[str, Tuple[np.ndarray, Dict]],
        storage: Optional[TensorStorage],
        storage_dir: str,
        storage_description: str,
        use_storage: bool,
        results: PipelineResults,
    ) -> None:
        """
        Save new embeddings to storage.
        Complexity: 5
        """
        if not use_storage or not embeddings_to_save:
            return

        save_time = time.time()

        if storage is None:
            storage = self._save_embeddings_to_storage(
                embeddings_to_save, storage_dir, storage_description
            )
        else:
            # Adding to existing storage not yet implemented
            logger.warning("Adding to existing storage not implemented, creating new storage")
            print("Warning: Adding to existing storage not implemented, creating new storage")
            storage = self._save_embeddings_to_storage(
                embeddings_to_save, storage_dir + "_new", storage_description
            )

        results.timing["storage_save_time"] = time.time() - save_time

        if storage:
            results.storage_info.update(storage.get_storage_info())

    def _save_embeddings_to_storage(
        self,
        embeddings_data: Dict[str, Tuple[np.ndarray, Dict]],
        storage_dir: str,
        description: str = "",
    ) -> Optional[TensorStorage]:
        """Save embeddings to storage. Complexity: 3"""
        if not embeddings_data:
            return None

        def tensor_iterator() -> Iterator[np.ndarray]:
            for embeddings, _ in embeddings_data.values():
                yield embeddings

        def metadata_iterator() -> Iterator[Dict[str, Any]]:
            for storage_key, (_, metadata) in embeddings_data.items():
                metadata_dict = metadata.copy()
                metadata_dict["storage_key"] = storage_key
                yield metadata_dict

        logger.info(f"Saving {len(embeddings_data)} embeddings to storage...")
        print(f"Saving {len(embeddings_data)} embeddings to storage...")

        storage = TensorStorage.create_storage(
            storage_dir=storage_dir,
            data_iterator=tensor_iterator(),
            metadata_iterator=metadata_iterator(),
            description=description,
        )

        return storage

    def _log_cache_statistics(
        self,
        cached_embeddings: Dict[str, np.ndarray],
        embeddings_to_save: Dict[str, Tuple[np.ndarray, Dict]],
        use_storage: bool,
        results: PipelineResults,
    ) -> None:
        """Log cache hit/miss statistics. Complexity: 3"""
        if not use_storage:
            return

        cache_hits = len(cached_embeddings)
        cache_misses = len(embeddings_to_save)
        total_combinations = len(self.datasets) * len(self.embedders)

        logger.info(
            f"Cache statistics: {cache_hits} hits, {cache_misses} misses out of {total_combinations} total"
        )
        print(
            f"Cache statistics: {cache_hits} hits, {cache_misses} misses out of {total_combinations} total"
        )

        results.metadata["cache_hits"] = cache_hits
        results.metadata["cache_misses"] = cache_misses
        results.metadata["cache_hit_rate"] = (
            cache_hits / total_combinations if total_combinations > 0 else 0
        )

    def _process_embeddings(self, dataset_keys: Dict[Any, str], results: PipelineResults) -> None:
        """
        Apply all processors to embeddings.
        Complexity: 4
        """
        logger.info("Processing embeddings...")
        print("Processing embeddings...")

        processing_time = time.time()

        for dataset in self.datasets:
            dataset_key = dataset_keys[dataset]

            for embedder in self.embedders:
                embeddings = results.embeddings[(dataset_key, embedder.name)]

                for processor in self.processors:
                    self._apply_single_processor(
                        processor, embeddings, dataset_key, embedder.name, results
                    )

        results.timing["processing_time"] = time.time() - processing_time

    def _apply_single_processor(
        self,
        processor: BaseProcessor,
        embeddings: np.ndarray,
        dataset_key: str,
        embedder_name: str,
        results: PipelineResults,
    ) -> None:
        """Apply a single processor. Complexity: 2"""
        logger.info(f"Processing {dataset_key}-{embedder_name} with {processor.name}")
        print(f"Processing {dataset_key}-{embedder_name} with {processor.name}")

        try:
            processed = processor.process(embeddings)
            results.processed[(dataset_key, embedder_name, processor.name)] = processed

            # Store metadata
            key_prefix = f"{dataset_key}_{embedder_name}_{processor.name}"
            results.metadata[f"{key_prefix}_shape"] = processed.shape
            results.metadata[f"{key_prefix}_dtype"] = str(processed.dtype)
        except Exception as e:
            logger.error(f"Failed to process with {processor.name}: {str(e)}")
            raise

    def _collect_metadata(
        self,
        dataset_keys: Dict[Any, str],
        storage: Optional[TensorStorage],
        results: PipelineResults,
    ) -> None:
        """Collect final metadata. Complexity: 3"""
        # Dataset metadata
        results.metadata["datasets"] = []
        for dataset in self.datasets:
            dataset_meta = dataset.get_metadata().copy()
            dataset_meta["pipeline_key"] = dataset_keys[dataset]
            results.metadata["datasets"].append(dataset_meta)

        # Embedder metadata
        results.metadata["embedders"] = [e.get_metadata() for e in self.embedders]

        # Processor metadata
        results.metadata["processors"] = [p.get_metadata() for p in self.processors]

        # Config
        results.metadata["config"] = self.config.to_dict()

        # Storage metadata
        if storage:
            results.metadata["storage"] = storage.get_storage_info()


File: ssrlib/processing/leverage_scores.py
import numpy as np
from typing import Dict, Any, Optional
from .base import BaseProcessor


class LeverageScoresProcessor(BaseProcessor):
    """
    Computes row leverage scores diag(U_k U_k^T) from a rank-k SVD of centered X.
    Sum of scores equals k. Useful for landmark selection, importance sampling,
    and spotting outliers/anomalies.

    If rank is None, chooses the smallest k giving the desired energy threshold.
    """

    def __init__(
        self,
        rank: Optional[int] = None,
        energy: float = 0.9,
        center: bool = True,
        **kwargs,
    ):
        """
        Args:
            rank: target rank k (1..min(n,d)). If None, choose via 'energy'.
            energy: fraction of spectral energy (sum s_i^2) to retain when rank=None.
            center: mean-center before SVD (recommended).
        """
        super().__init__("LeverageScores", **kwargs)
        if rank is not None and rank <= 0:
            raise ValueError("rank must be positive when provided.")
        if not (0.0 < energy <= 1.0):
            raise ValueError("energy must be in (0, 1].")

        self.rank = rank
        self.energy = float(energy)
        self.center = bool(center)

        self._metadata.update(
            {
                "processor_type": "leverage_scores",
                "rank": self.rank,
                "energy": self.energy,
                "center": self.center,
                "output_type": "row_scores",
            }
        )

    def process(self, embeddings: np.ndarray) -> np.ndarray:
        if embeddings.ndim != 2:
            raise ValueError(f"Expected 2D embeddings, got shape {embeddings.shape}")

        X = embeddings.astype(np.float64, copy=False)
        if self.center:
            X = X - X.mean(axis=0, keepdims=True)

        # Full SVD is simple & robust for moderate sizes.
        # Swap in a randomized SVD if you need scalability later.
        U, s, _ = np.linalg.svd(X, full_matrices=False)

        if U.size == 0:
            scores = np.zeros((X.shape[0],), dtype=np.float64)
            chosen_k = 0
        else:
            if self.rank is None:
                # Choose k for desired energy in s^2
                energy_spectrum = s**2
                cum = np.cumsum(energy_spectrum)
                total = cum[-1] if cum.size > 0 else 0.0
                if total <= 0:
                    chosen_k = 0
                else:
                    k = int(np.searchsorted(cum, self.energy * total) + 1)
                    chosen_k = max(1, min(k, U.shape[1]))
            else:
                chosen_k = max(1, min(self.rank, U.shape[1]))

            Uk = U[:, :chosen_k] if chosen_k > 0 else np.zeros((X.shape[0], 0), dtype=np.float64)
            scores = np.sum(Uk * Uk, axis=1)  # diag(Uk Uk^T)

        self._metadata.update(
            {
                "input_shape": embeddings.shape,
                "n_vectors": int(embeddings.shape[0]),
                "n_features": int(embeddings.shape[1]),
                "chosen_rank": int(self._metadata.get("rank") or chosen_k),
                "scores_sum": float(scores.sum()),
                "scores_min": float(scores.min() if scores.size else 0.0),
                "scores_max": float(scores.max() if scores.size else 0.0),
            }
        )

        return scores
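

# --- Illustrative sketch -----------------------------------------------------
# Leverage scores sum to the chosen rank k; rows with unusually large scores
# are natural candidates for landmarks or outliers.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    embeddings = rng.normal(size=(500, 64))

    processor = LeverageScoresProcessor(rank=10)
    scores = processor.process(embeddings)
    print(scores.shape)                    # (500,)
    print(round(float(scores.sum()), 4))   # 10.0 (equals the chosen rank)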


File: ssrlib/processing/stable_rank.py
import numpy as np
from typing import Dict, Any
from .base import BaseProcessor


class StableRankProcessor(BaseProcessor):
    """
    Stable rank of the (optionally centered) data matrix X:
        srank = ||X||_F^2 / ||X||_2^2
    where ||X||_2 is the top singular value.
    """

    def __init__(self, center: bool = True, epsilon: float = 1e-12, **kwargs):
        """
        Args:
            center: mean-center rows before computing norms (often desirable).
            epsilon: small floor for top singular value squared.
        """
        super().__init__("StableRank", **kwargs)
        self.center = bool(center)
        self.epsilon = float(epsilon)

        self._metadata.update(
            {
                "processor_type": "stable_rank",
                "center": self.center,
                "epsilon": self.epsilon,
                "output_type": "scalar_statistic",
            }
        )

    def process(self, embeddings: np.ndarray) -> np.ndarray:
        if embeddings.ndim != 2:
            raise ValueError(f"Expected 2D embeddings, got shape {embeddings.shape}")

        X = embeddings.astype(np.float64, copy=False)
        if self.center:
            X = X - X.mean(axis=0, keepdims=True)

        fro2 = float(np.sum(X * X))
        # Top singular value via SVD
        # (full SVD is fine for moderate sizes; replace with randomized SVD if needed)
        s = np.linalg.svd(X, full_matrices=False, compute_uv=False)
        s1_sq = float(s[0] ** 2) if s.size > 0 else 0.0

        denom = max(s1_sq, self.epsilon)
        srank = fro2 / denom if denom > 0 else 0.0

        self._metadata.update(
            {
                "input_shape": embeddings.shape,
                "n_vectors": int(embeddings.shape[0]),
                "n_features": int(embeddings.shape[1]),
                "frobenius_sq": fro2,
                "top_singular_sq": s1_sq,
                "stable_rank": srank,
            }
        )
        return np.array([srank], dtype=np.float64)
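

# --- Illustrative sketch -----------------------------------------------------
# For isotropic Gaussian data the stable rank is a sizeable fraction of the
# feature dimension (approaching it as n grows); collapsing most variance onto
# a single direction pulls it toward 1.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    isotropic = rng.normal(size=(2000, 32))
    processor = StableRankProcessor()
    print(float(processor.process(isotropic)[0]))   # well above 1, below 32

    direction = rng.normal(size=(1, 32))
    collapsed = rng.normal(size=(2000, 1)) @ direction + 0.01 * isotropic
    print(float(processor.process(collapsed)[0]))   # close to 1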


File: ssrlib/processing/covariance.py
import numpy as np
from typing import Dict, Any

from .base import BaseProcessor


class CovarianceProcessor(BaseProcessor):
    """Processor for computing covariance matrix of embeddings."""

    def __init__(self, **kwargs):
        """Initialize covariance processor."""
        super().__init__("Covariance", **kwargs)

        self._metadata.update({"processor_type": "covariance", "output_type": "covariance_matrix"})

    def process(self, embeddings: np.ndarray) -> np.ndarray:
        """Compute covariance matrix of embeddings.

        Args:
            embeddings: Input embeddings of shape (n_vectors, n_features)

        Returns:
            Covariance matrix of shape (n_features, n_features)
        """
        if embeddings.ndim != 2:
            raise ValueError(f"Expected 2D embeddings, got shape {embeddings.shape}")

        # Compute feature covariance matrix
        # np.cov expects features in rows, so we transpose
        covariance_matrix = np.cov(embeddings.T)

        # Update metadata with shape info
        self._metadata.update(
            {
                "input_shape": embeddings.shape,
                "output_shape": covariance_matrix.shape,
                "n_vectors": embeddings.shape[0],
                "n_features": embeddings.shape[1],
            }
        )

        return covariance_matrix
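

# --- Illustrative sketch -----------------------------------------------------
# The covariance matrix is square in the feature dimension, regardless of how
# many vectors went in, and is symmetric by construction.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    embeddings = rng.normal(size=(256, 16))

    processor = CovarianceProcessor()
    cov = processor.process(embeddings)
    print(cov.shape)                # (16, 16)
    print(np.allclose(cov, cov.T))  # True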


File: ssrlib/processing/__init__.py
"""Processing for ssrlib with automatic discovery."""

import logging
from pathlib import Path
from typing import Dict, List, Type, Any
import warnings

logger = logging.getLogger(__name__)

# Import base class and registry system
from .base import BaseProcessor
from ..core.registry import BaseRegistry, discover_components

# Type alias
ProcessorRegistry = BaseRegistry[BaseProcessor]


def discover_processor_classes() -> ProcessorRegistry:
    """Discover all processor classes in the processing module."""
    registry = ProcessorRegistry("processor")

    return discover_components(
        package_path=Path(__file__).parent,
        package_name=__name__,
        base_class=BaseProcessor,
        registry=registry,
    )


# Perform discovery at import time
logger.debug("Starting processor discovery...")
_processor_registry = discover_processor_classes()


# Public API functions
def get_available_processors() -> Dict[str, Type[BaseProcessor]]:
    """Get dictionary of all available processors.

    Returns:
        Dictionary mapping processor names to their classes
    """
    return _processor_registry._items.copy()


def get_processor_descriptions() -> Dict[str, str]:
    """Get dictionary of processor descriptions.

    Returns:
        Dictionary mapping processor names to their descriptions
    """
    return _processor_registry._descriptions.copy()


def list_processors() -> List[str]:
    """List all available processor names.

    Returns:
        List of processor names
    """
    return _processor_registry.list_all()


def get_processor_info(name: str) -> Dict[str, Any]:
    """Get detailed information about a processor.

    Args:
        name: Name of the processor

    Returns:
        Dictionary containing processor information

    Raises:
        ValueError: If processor not found
    """
    return _processor_registry.get_info(name)


def print_available_processors() -> None:
    """Print all available processors with descriptions."""
    _processor_registry.print_registry()


def create_processor(name: str, **kwargs) -> BaseProcessor:
    """Create a processor instance by name.

    Args:
        name: Name of the processor
        **kwargs: Processor-specific initialization arguments

    Returns:
        Initialized processor instance

    Raises:
        ValueError: If processor not found
    """
    processor_class = _processor_registry.get(name)
    if processor_class is None:
        available = ", ".join(_processor_registry.list_all())
        raise ValueError(f"Unknown processor '{name}'. Available: {available}")
    return processor_class(**kwargs)


# Create dynamic exports
_exported_classes = {}
for name, processor_class in _processor_registry._items.items():
    _exported_classes[name] = processor_class

# Update module globals for direct imports
globals().update(_exported_classes)

# Create __all__ dynamically
__all__ = [
    "BaseProcessor",
    "get_available_processors",
    "get_processor_descriptions",
    "list_processors",
    "get_processor_info",
    "print_available_processors",
    "create_processor",
    *_processor_registry.list_all(),
]

# Log results
if logger.isEnabledFor(logging.INFO):
    processors = _processor_registry.list_all()
    logger.info(f"Processor discovery complete: {len(processors)} processors found")
    logger.info(f"  Available: {', '.join(sorted(processors))}")

# Warn about errors
if _processor_registry._discovery_errors:
    warnings.warn(
        f"Some processor modules failed to import: {len(_processor_registry._discovery_errors)} errors. "
        f"Run logging.getLogger('{__name__}').setLevel(logging.DEBUG) for details.",
        ImportWarning,
    )
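

# --- Illustrative sketch -----------------------------------------------------
# Processors are registered under their class names, so instances can be
# created by name. "CovarianceProcessor" is used here because it is defined in
# this package; any name returned by list_processors() works the same way.
if __name__ == "__main__":
    print(list_processors())
    if "CovarianceProcessor" in list_processors():
        cov_processor = create_processor("CovarianceProcessor")
        print(get_processor_info("CovarianceProcessor")["description"])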


File: ssrlib/processing/pairwise_stats.py
import numpy as np
from typing import Optional
from .base import BaseProcessor


class PairwiseDistanceStatsProcessor(BaseProcessor):
    """
    Computes summary statistics of pairwise distances between embeddings
    on a (possibly subsampled) set of points.

    Returns a 1D array [mean, std, min, max] for the chosen distance metric.
    """

    def __init__(
        self,
        metric: str = "cosine",
        max_samples: int = 4096,
        center: bool = False,
        seed: Optional[int] = 0,
        **kwargs,
    ):
        """
        Args:
            metric: 'cosine' or 'euclidean'.
            max_samples: maximum number of points to use for pairwise stats.
                         If n_vectors > max_samples, a random subset is taken.
            center: whether to mean-center before computing distances.
            seed: random seed for subsampling (None = use np.random default).
        """
        super().__init__("PairwiseDistanceStats", **kwargs)
        metric = metric.lower()
        if metric not in ("cosine", "euclidean"):
            raise ValueError("metric must be 'cosine' or 'euclidean'")

        if max_samples <= 1:
            raise ValueError("max_samples must be > 1")

        self.metric = metric
        self.max_samples = int(max_samples)
        self.center = bool(center)
        self.seed = seed

        self._metadata.update(
            {
                "processor_type": "pairwise_distance_stats",
                "metric": self.metric,
                "max_samples": self.max_samples,
                "center": self.center,
                "seed": self.seed,
                "output_type": "distance_summary",
                "stats_order": ["mean", "std", "min", "max"],
            }
        )

    def _subsample(self, X: np.ndarray) -> np.ndarray:
        n = X.shape[0]
        if n <= self.max_samples:
            return X

        rng = np.random.default_rng(self.seed)
        idx = rng.choice(n, size=self.max_samples, replace=False)
        return X[idx]

    def process(self, embeddings: np.ndarray) -> np.ndarray:
        if embeddings.ndim != 2:
            raise ValueError(f"Expected 2D embeddings, got shape {embeddings.shape}")

        X = embeddings.astype(np.float64, copy=False)
        if self.center:
            X = X - X.mean(axis=0, keepdims=True)

        X = self._subsample(X)
        m = X.shape[0]

        if m < 2:
            # Nothing to compute; return zeros
            stats = np.zeros(4, dtype=np.float64)
            self._metadata.update(
                {
                    "input_shape": embeddings.shape,
                    "used_samples": int(m),
                    "pairwise_count": 0,
                }
            )
            return stats

        # Compute pairwise distances / similarities
        if self.metric == "cosine":
            # Normalize rows to unit length
            norms = np.linalg.norm(X, axis=1, keepdims=True)
            norms = np.where(norms == 0.0, 1.0, norms)
            Y = X / norms
            sim = Y @ Y.T  # (m, m), cosine similarity
            # Convert cosine similarity to distance
            dist = 1.0 - sim
        else:  # 'euclidean'
            # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 x_i x_j^T
            sq_norms = np.sum(X * X, axis=1, keepdims=True)  # (m, 1)
            sq_dists = sq_norms + sq_norms.T - 2.0 * (X @ X.T)
            # Due to floating-point rounding, sq_dists can dip slightly below zero
            sq_dists = np.maximum(sq_dists, 0.0)
            dist = np.sqrt(sq_dists)

        # Keep only the upper triangle (excluding the diagonal)
        iu = np.triu_indices(m, k=1)
        dist_vec = dist[iu]

        mean = float(dist_vec.mean())
        std = float(dist_vec.std())
        dmin = float(dist_vec.min())
        dmax = float(dist_vec.max())

        self._metadata.update(
            {
                "input_shape": embeddings.shape,
                "used_samples": int(m),
                "pairwise_count": int(dist_vec.size),
                "mean": mean,
                "std": std,
                "min": dmin,
                "max": dmax,
            }
        )

        return np.array([mean, std, dmin, dmax], dtype=np.float64)
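

# --- Illustrative sketch -----------------------------------------------------
# Returns [mean, std, min, max] of pairwise distances on a random subsample.
# For random Gaussian vectors the mean cosine distance is close to 1.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    embeddings = rng.normal(size=(1000, 64))

    processor = PairwiseDistanceStatsProcessor(metric="cosine", max_samples=512)
    mean, std, dmin, dmax = processor.process(embeddings)
    print(f"mean={mean:.3f} std={std:.3f} min={dmin:.3f} max={dmax:.3f}")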


File: ssrlib/processing/spectrum.py
import numpy as np
from typing import Dict, Any
from .base import BaseProcessor


class SpectrumProcessor(BaseProcessor):
    """
    Computes the eigenvalue spectrum of the (optionally centered) covariance
    matrix of embeddings.

    By default returns the raw eigenvalues λ_i sorted in descending order.
    With normalize=True, returns the normalized spectrum
    (explained-variance ratios) instead.
    """

    def __init__(
        self,
        center: bool = True,
        epsilon: float = 1e-12,
        normalize: bool = False,
        **kwargs,
    ):
        """
        Args:
            center: mean-center before covariance.
            epsilon: small floor for eigenvalues to avoid numerical issues.
            normalize: if True, return explained-variance ratios λ_i / sum_j λ_j
                       instead of raw eigenvalues.
        """
        super().__init__("Spectrum", **kwargs)
        self.center = bool(center)
        self.epsilon = float(epsilon)
        self.normalize = bool(normalize)

        self._metadata.update(
            {
                "processor_type": "spectrum",
                "center": self.center,
                "epsilon": self.epsilon,
                "normalize": self.normalize,
                "output_type": "eigenvalue_spectrum",
            }
        )

    def process(self, embeddings: np.ndarray) -> np.ndarray:
        if embeddings.ndim != 2:
            raise ValueError(f"Expected 2D embeddings, got shape {embeddings.shape}")

        X = embeddings.astype(np.float64, copy=False)
        if self.center:
            X = X - X.mean(axis=0, keepdims=True)

        # Covariance and eigenvalues
        C = np.cov(X.T)
        evals = np.linalg.eigvalsh(C)  # ascending
        evals = np.maximum(evals, 0.0)

        # Sort descending for convenience
        evals = evals[::-1]

        trace = float(evals.sum())
        top = float(evals[0]) if evals.size > 0 else 0.0

        if self.normalize and trace > self.epsilon:
            spectrum = evals / (trace + self.epsilon)
        else:
            spectrum = evals

        self._metadata.update(
            {
                "input_shape": embeddings.shape,
                "n_vectors": int(embeddings.shape[0]),
                "n_features": int(embeddings.shape[1]),
                "trace": trace,
                "spectral_norm": top,
                "n_eigenvalues": int(evals.size),
            }
        )

        return spectrum.astype(np.float64, copy=False)
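

# Illustrative usage sketch (not part of the library): with normalize=True the
# returned values are explained-variance ratios, so they are non-increasing and
# sum to roughly 1 on well-conditioned data.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    emb = rng.normal(size=(200, 16))
    spec = SpectrumProcessor(normalize=True).process(emb)
    print("top-3 ratios:", spec[:3], "sum:", spec.sum())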


File: ssrlib/processing/effective_rank.py
import numpy as np
from typing import Dict, Any
from .base import BaseProcessor


class EffectiveRankProcessor(BaseProcessor):
    """
    Returns the 'effective rank' of the (centered) covariance:
        erank = exp( - sum_i p_i log p_i ), where p_i = λ_i / sum_j λ_j
    A soft, scale-invariant dimensionality proxy.
    """

    def __init__(self, epsilon: float = 1e-12, center: bool = True, **kwargs):
        """
        Args:
            epsilon: small floor for eigenvalues and probs.
            center: whether to mean-center before covariance.
        """
        super().__init__("EffectiveRank", **kwargs)
        self.epsilon = float(epsilon)
        self.center = bool(center)

        self._metadata.update(
            {
                "processor_type": "effective_rank",
                "epsilon": self.epsilon,
                "center": self.center,
                "output_type": "scalar_statistic",
            }
        )

    def process(self, embeddings: np.ndarray) -> np.ndarray:
        if embeddings.ndim != 2:
            raise ValueError(f"Expected 2D embeddings, got shape {embeddings.shape}")

        X = embeddings.astype(np.float64, copy=False)
        if self.center:
            X = X - X.mean(axis=0, keepdims=True)

        C = np.cov(X.T)
        evals = np.linalg.eigvalsh(C)  # sorted ascending
        evals = np.maximum(evals, 0.0)
        total = evals.sum()

        if not np.isfinite(total) or total <= self.epsilon:
            erank = 0.0
        else:
            p = evals / (total + self.epsilon)
            p = np.clip(p, self.epsilon, 1.0)  # avoid log(0)
            H = -np.sum(p * np.log(p))
            erank = float(np.exp(H))

        self._metadata.update(
            {
                "input_shape": embeddings.shape,
                "n_vectors": int(embeddings.shape[0]),
                "n_features": int(embeddings.shape[1]),
                "effective_rank": erank,
            }
        )
        return np.array([erank], dtype=np.float64)
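

# Illustrative usage sketch (not part of the library): isotropic Gaussian data has
# a nearly flat eigenvalue spectrum, so its effective rank approaches the ambient
# dimension, while (near) rank-1 data is pulled down towards 1.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    iso = rng.normal(size=(1000, 8))                              # erank close to 8
    low = rng.normal(size=(1000, 1)) @ rng.normal(size=(1, 8))    # erank close to 1
    proc = EffectiveRankProcessor()
    print("isotropic:", proc.process(iso)[0], "rank-1:", proc.process(low)[0])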


File: ssrlib/processing/base.py
from abc import ABC, abstractmethod
from typing import Dict, Any
import numpy as np


class BaseProcessor(ABC):
    """Base class for all processors in ssrlib."""

    def __init__(self, name: str, **kwargs):
        """Initialize processor.

        Args:
            name: Name of the processor
            **kwargs: Additional processor-specific parameters
        """
        self.name = name
        self._metadata = {}

    @abstractmethod
    def process(self, embeddings: np.ndarray) -> np.ndarray:
        """Process embeddings.

        Args:
            embeddings: Input embeddings of shape (n_vectors, n_features)

        Returns:
            Processed embeddings or computed features
        """
        pass

    def get_metadata(self) -> Dict[str, Any]:
        """Get processor metadata."""
        return {"name": self.name, **self._metadata}
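

# Minimal illustrative sketch (not part of the library): the contract is simply
# "implement process(); optionally enrich self._metadata so get_metadata() reports
# what was computed". Kept under the __main__ guard to avoid import-time side effects.
if __name__ == "__main__":

    class MeanProcessor(BaseProcessor):
        """Toy processor returning the per-feature mean of the embeddings."""

        def __init__(self, **kwargs):
            super().__init__("Mean", **kwargs)

        def process(self, embeddings: np.ndarray) -> np.ndarray:
            self._metadata.update({"input_shape": embeddings.shape})
            return embeddings.mean(axis=0)

    proc = MeanProcessor()
    print(proc.process(np.ones((4, 3))), proc.get_metadata())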


File: ssrlib/processing/zca.py
import numpy as np
from typing import Dict, Any
from scipy.linalg import eigh

from .base import BaseProcessor


class ZCAProcessor(BaseProcessor):
    """Processor for ZCA whitening of embeddings."""

    def __init__(self, epsilon: float = 1e-9, **kwargs):
        """Initialize ZCA processor.

        Args:
            epsilon: Regularization parameter for numerical stability
        """
        super().__init__("ZCA", **kwargs)
        self.epsilon = epsilon

        self._metadata.update(
            {
                "processor_type": "zca_whitening",
                "epsilon": epsilon,
                "output_type": "whitened_embeddings",
            }
        )

    def process(self, embeddings: np.ndarray) -> np.ndarray:
        """Apply ZCA whitening to embeddings.

        Args:
            embeddings: Input embeddings of shape (n_vectors, n_features)

        Returns:
            ZCA whitened embeddings of shape (n_vectors, n_features)
        """
        if embeddings.ndim != 2:
            raise ValueError(f"Expected 2D embeddings, got shape {embeddings.shape}")

        # Center the data
        mean = np.mean(embeddings, axis=0, keepdims=True)
        centered = embeddings - mean

        # Compute covariance matrix
        cov = np.cov(centered.T)

        # Eigendecomposition
        eigenvalues, eigenvectors = eigh(cov)

        # Sort eigenvalues and eigenvectors in descending order
        idx = np.argsort(eigenvalues)[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]

        # ZCA whitening matrix
        # W = U * D^(-1/2) * U^T, where U holds the eigenvectors and D is the
        # diagonal matrix of eigenvalues
        d_inv_sqrt = np.diag(1.0 / np.sqrt(eigenvalues + self.epsilon))
        zca_matrix = eigenvectors @ d_inv_sqrt @ eigenvectors.T

        # Apply whitening
        whitened = centered @ zca_matrix.T

        # Update metadata
        self._metadata.update(
            {
                "input_shape": embeddings.shape,
                "output_shape": whitened.shape,
                "n_vectors": embeddings.shape[0],
                "n_features": embeddings.shape[1],
                "mean_eigenvalue": float(np.mean(eigenvalues)),
                "condition_number": float(eigenvalues[0] / eigenvalues[-1]),
            }
        )

        return whitened
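

# Illustrative self-check (not part of the library): the defining property of ZCA
# whitening is that the sample covariance of the output is (approximately) the
# identity matrix, up to the epsilon regularization added to the eigenvalues.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    emb = rng.normal(size=(500, 6)) @ rng.normal(size=(6, 6))  # correlated features
    white = ZCAProcessor().process(emb)
    print("max |cov(white) - I| =", np.abs(np.cov(white.T) - np.eye(6)).max())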


File: ssrlib/datasets/food101.py
"""Food101 dataset from HuggingFace."""

from torchvision import transforms
from typing import Dict, Any, ClassVar

from .hf_vision import HFVisionDataset


class Food101Dataset(HFVisionDataset):
    """Food-101 Dataset from HuggingFace Hub."""

    # Class-level metadata
    _dataset_category: ClassVar[str] = "vision"
    _dataset_modality: ClassVar[str] = "vision"
    _dataset_properties: ClassVar[Dict[str, Any]] = {
        "num_classes": 101,
        "image_format": "jpg",
        "processed_image_size": (224, 224),
        "task_type": "multi_class_classification",
        "source": "huggingface",
        "hf_id": "ethz/food101",
    }

    DEFAULT_TRANSFORM = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    def __init__(self, split: str = "train", **kwargs):
        """Initialize Food101 dataset."""
        super().__init__(name="Food101", registry_key="food101", split=split, **kwargs)
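

# Illustrative usage sketch (not part of the library): constructing the dataset
# loads the HuggingFace split during __init__ (network or local cache access is
# assumed), after which items come back as (image_tensor, label) pairs.
if __name__ == "__main__":
    ds = Food101Dataset(split="train")
    image, label = ds[0]
    print(image.shape, int(label), ds.get_classes()["num_classes"])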


File: ssrlib/datasets/__init__.py
"""Datasets for ssrlib with automatic discovery."""

import logging
from pathlib import Path
from typing import Dict, List, Type, Any, Optional
import warnings

logger = logging.getLogger(__name__)

# Import base class and registry system
from .base import BaseDataset
from ..core.registry import BaseRegistry, discover_components

# Import HF registry utilities
from .hf_registry import list_hf_datasets, get_hf_dataset_info

# Type alias for clarity
DatasetRegistry = BaseRegistry[BaseDataset]


def discover_dataset_classes() -> DatasetRegistry:
    """Discover all dataset classes in the datasets module."""
    registry = DatasetRegistry("dataset").enable_modalities()

    return discover_components(
        package_path=Path(__file__).parent,
        package_name=__name__,
        base_class=BaseDataset,
        registry=registry,
    )


# Perform discovery at import time
logger.debug("Starting dataset discovery...")
_dataset_registry = discover_dataset_classes()


# Convenience functions
def get_available_datasets() -> Dict[str, Type[BaseDataset]]:
    """Get dictionary of all available datasets."""
    return _dataset_registry._items.copy()


def get_dataset_descriptions() -> Dict[str, str]:
    """Get dictionary of dataset descriptions."""
    return _dataset_registry._descriptions.copy()


def list_datasets(category: Optional[str] = None, modality: Optional[str] = None) -> List[str]:
    """List available datasets with optional filtering."""
    if category:
        return _dataset_registry.list_by_category(category).get(category, [])
    elif modality:
        return _dataset_registry.list_by_modality(modality).get(modality, [])
    return _dataset_registry.list_all()


def get_dataset_info(name: str) -> Dict[str, Any]:
    """Get detailed information about a dataset."""
    return _dataset_registry.get_info(name)


def print_available_datasets() -> None:
    """Print all available datasets with descriptions."""
    _dataset_registry.print_registry()


def create_dataset(name: str, **kwargs) -> BaseDataset:
    """Create a dataset by name."""
    dataset_class = _dataset_registry.get(name)
    if dataset_class is None:
        available = ", ".join(_dataset_registry.list_all())
        raise ValueError(f"Unknown dataset '{name}'. Available: {available}")
    return dataset_class(**kwargs)


def get_vision_datasets() -> List[str]:
    """Get list of vision datasets."""
    return list_datasets(modality="vision")


def get_text_datasets() -> List[str]:
    """Get list of text datasets."""
    return list_datasets(modality="text")


def get_audio_datasets() -> List[str]:
    """Get list of audio datasets."""
    return list_datasets(modality="audio")


def get_synthetic_datasets() -> List[str]:
    """Get list of synthetic datasets."""
    return list_datasets(modality="synthetic")


def get_datasets_by_category(category: str) -> List[str]:
    """Get datasets by category."""
    return list_datasets(category=category)


def get_dataset_categories() -> List[str]:
    """Get list of all available categories."""
    return list(_dataset_registry._categories.keys())


def get_dataset_modalities() -> List[str]:
    """Get list of all available modalities."""
    if _dataset_registry._modalities:
        return list(set(_dataset_registry._modalities.values()))
    return []


# HuggingFace dataset utilities
def get_hf_datasets() -> List[str]:
    """Get list of available HuggingFace datasets."""
    return list_hf_datasets()


# Create dynamic exports
_exported_classes = {}
for name, dataset_class in _dataset_registry._items.items():
    _exported_classes[name] = dataset_class

# Update module globals
globals().update(_exported_classes)

# Create __all__ dynamically
__all__ = [
    "BaseDataset",
    "get_available_datasets",
    "get_dataset_descriptions",
    "list_datasets",
    "get_dataset_info",
    "print_available_datasets",
    "create_dataset",
    "get_vision_datasets",
    "get_text_datasets",
    "get_audio_datasets",
    "get_synthetic_datasets",
    "get_datasets_by_category",
    "get_dataset_categories",
    "get_dataset_modalities",
    "get_hf_datasets",
    "get_hf_dataset_info",
    "list_hf_datasets",
    *_dataset_registry.list_all(),
]

# Log results
if logger.isEnabledFor(logging.INFO):
    logger.info(f"Dataset discovery complete: {len(_dataset_registry.list_all())} datasets found")
    for category, datasets in _dataset_registry.list_by_category().items():
        logger.info(f"  {category}: {', '.join(datasets)}")

# Warn about errors
if _dataset_registry._discovery_errors:
    warnings.warn(
        f"Some dataset modules failed to import: {len(_dataset_registry._discovery_errors)} errors. "
        f"Run logging.getLogger('{__name__}').setLevel(logging.DEBUG) for details.",
        ImportWarning,
    )
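

# Illustrative usage sketch (not part of the library): the registry is populated at
# import time, so datasets can be listed and instantiated by name. The "synthtest"
# key below is an assumption about how SynthTestDataset is registered; substitute
# any name reported by list_datasets().
if __name__ == "__main__":
    print("vision datasets:", get_vision_datasets())
    ds = create_dataset("synthtest", tensors_num=4, seed=0)
    print(ds)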


File: ssrlib/datasets/hf_registry.py
"""Registry for Hugging Face datasets."""

from dataclasses import dataclass
from typing import Dict


@dataclass
class HFDatasetInfo:
    """Information about a HuggingFace dataset."""

    hf_id: str  # HuggingFace dataset identifier
    num_classes: int
    train_split: str = "train"
    test_split: str = "test"
    image_key: str = "image"
    label_key: str = "label"
    description: str = ""


# Registry of supported HuggingFace datasets
HF_DATASET_REGISTRY: Dict[str, HFDatasetInfo] = {
    "food101": HFDatasetInfo(
        hf_id="ethz/food101",
        num_classes=101,
        train_split="train",
        test_split="validation",  # Food101 uses 'validation' instead of 'test'
        image_key="image",
        label_key="label",
        description="Food-101 dataset with 101 food categories",
    ),
    "cifar10": HFDatasetInfo(
        hf_id="uoft-cs/cifar10",
        num_classes=10,
        train_split="train",
        test_split="test",
        image_key="img",
        label_key="label",
        description="CIFAR-10 dataset with 10 classes",
    ),
    "cifar100": HFDatasetInfo(
        hf_id="uoft-cs/cifar100",
        num_classes=100,
        train_split="train",
        test_split="test",
        image_key="img",
        label_key="fine_label",  # CIFAR-100 uses 'fine_label'
        description="CIFAR-100 dataset with 100 fine-grained classes",
    ),
    "sun397": HFDatasetInfo(
        hf_id="tanganke/sun397",
        num_classes=397,
        train_split="train",
        test_split="test",
        image_key="image",
        label_key="label",
        description="SUN397 scene recognition dataset with 397 categories",
    ),
    "stanford_cars": HFDatasetInfo(
        hf_id="tanganke/stanford_cars",
        num_classes=196,
        train_split="train",
        test_split="test",
        image_key="image",
        label_key="label",
        description="Stanford Cars dataset with 196 car models",
    ),
    "dtd": HFDatasetInfo(
        hf_id="tanganke/dtd",
        num_classes=47,
        train_split="train",
        test_split="test",
        image_key="image",
        label_key="label",
        description="Describable Textures Dataset with 47 texture classes",
    ),
    "oxford_pets": HFDatasetInfo(
        hf_id="timm/oxford-iiit-pet",
        num_classes=37,
        train_split="train",
        test_split="test",
        image_key="image",
        label_key="label",
        description="Oxford-IIIT Pet dataset with 37 pet breeds",
    ),
    "caltech101": HFDatasetInfo(
        hf_id="flwrlabs/caltech101",
        num_classes=101,
        train_split="train",
        test_split="train",  # Only has train split
        image_key="image",
        label_key="label",
        description="Caltech-101 dataset with 101 object categories",
    ),
    "flowers102": HFDatasetInfo(
        hf_id="Donghyun99/Oxford-Flower-102",
        num_classes=102,
        train_split="train",
        test_split="test",
        image_key="image",
        label_key="label",
        description="Oxford Flowers-102 dataset with 102 flower species",
    ),
}


def get_hf_dataset_info(dataset_name: str) -> HFDatasetInfo:
    """
    Get HuggingFace dataset information.

    Args:
        dataset_name: Dataset name

    Returns:
        HFDatasetInfo object

    Raises:
        ValueError: If dataset not in registry
    """
    if dataset_name not in HF_DATASET_REGISTRY:
        available = ", ".join(HF_DATASET_REGISTRY.keys())
        raise ValueError(f"Unknown HuggingFace dataset: {dataset_name}. " f"Available: {available}")

    return HF_DATASET_REGISTRY[dataset_name]


def list_hf_datasets() -> list:
    """Get list of available HuggingFace datasets."""
    return list(HF_DATASET_REGISTRY.keys())
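

# Illustrative usage sketch (not part of the library): the registry lookup is the
# single source of truth for split and column names -- e.g. CIFAR-100 stores its
# labels under 'fine_label' rather than 'label'.
if __name__ == "__main__":
    info = get_hf_dataset_info("cifar100")
    print(info.hf_id, info.test_split, info.label_key)
    print("registered datasets:", list_hf_datasets())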


File: ssrlib/datasets/celeba.py
import os
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from typing import Iterator, Dict, Any, Optional, Tuple, List, Union, ClassVar
from pathlib import Path
import logging

from .base import BaseDataset
from .kaggle_mixin import KaggleDatasetMixin

logger = logging.getLogger(__name__)


class CelebADataset(KaggleDatasetMixin, BaseDataset):
    """CelebA Dataset for ssrlib framework."""

    # Class-level metadata
    _dataset_category: ClassVar[str] = "vision"
    _dataset_modality: ClassVar[str] = "vision"
    _dataset_properties: ClassVar[Dict[str, Any]] = {
        "num_attributes": 40,
        "image_format": "jpg",
        "default_image_size": (178, 218),
        "processed_image_size": (224, 224),
        "num_identities": 10177,
        "total_images": 202599,
        "supports_multi_label": True,
        "task_type": "binary_classification",
    }

    # Expected files and directories
    REQUIRED_FILES = ["list_eval_partition.csv", "list_attr_celeba.csv"]
    REQUIRED_DIRS = ["img_align_celeba/img_align_celeba"]

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        task_name: str = "Attractive",
        transform: Optional[transforms.Compose] = None,
        **kwargs,
    ):
        """Initialize CelebA dataset."""
        super().__init__("CelebA", **kwargs)

        self.root = Path(root) / "CelebA"
        self.split = split
        self.task_name = task_name

        # Set default transform
        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ]
            )
        else:
            self.transform = transform

        # File paths (set after verification)
        self.split_csv = None
        self.attr_csv = None
        self.images_dir = None

        self.data = None
        self.attr_names = None

        self._metadata.update({"split": split, "task_name": task_name, "root": str(self.root)})

        # Download if needed
        if not self._check_exists():
            logger.info(f"CelebA not found, downloading to {self.root}")
            print(f"CelebA not found, downloading to {self.root}")
            self._download()

        self._load_data()
        self._downloaded = True

    def _get_kaggle_dataset_id(self) -> str:
        """Get Kaggle dataset ID."""
        return "jessicali9530/celeba-dataset"

    def _check_exists(self) -> bool:
        """Check if dataset exists and is properly structured."""
        if not self.root.exists():
            return False

        return self._verify_structure()

    def _verify_structure(self) -> bool:
        """
        Verify CelebA structure is correct.

        Expected structure:
        CelebA/
        ├── list_eval_partition.csv
        ├── list_attr_celeba.csv
        └── img_align_celeba/
            └── img_align_celeba/
                ├── 000001.jpg
                ├── 000002.jpg
                └── ...
        """
        logger.info("Verifying CelebA structure...")

        # Find required files
        self.split_csv = self._find_required_file("list_eval_partition.csv")
        self.attr_csv = self._find_required_file("list_attr_celeba.csv")
        self.images_dir = self._find_required_directory("img_align_celeba")

        # Check all found
        if not (self.split_csv and self.attr_csv and self.images_dir):
            return False

        # Verify images exist
        image_files = list(self.images_dir.glob("*.jpg"))
        if len(image_files) == 0:
            logger.error("No images found in img_align_celeba directory")
            return False

        logger.info(f"✓ Structure verified: {len(image_files)} images found")
        print(f"✓ CelebA structure verified: {len(image_files)} images found")

        return True

    def _find_required_file(self, filename: str) -> Optional[Path]:
        """Find required file and move to root if needed."""
        # Check if already in root
        target_path = self.root / filename
        if target_path.exists():
            logger.info(f"✓ Found {filename}")
            return target_path

        # Search for file
        found_path = self._find_file(filename)
        if found_path:
            logger.info(f"Found {filename} at {found_path}, moving to root")
            return self._move_to_root(found_path, filename)

        logger.error(f"✗ {filename} not found")
        return None

    def _find_required_directory(self, dirname: str) -> Optional[Path]:
        """
        Find required image directory and organize if needed.

        Handles various extraction patterns:
        - img_align_celeba/img_align_celeba/  (correct)
        - img_align_celeba/                  (needs nesting)
        - Images directly in root             (needs organization)
        """
        target_path = self.root / "img_align_celeba" / "img_align_celeba"

        # Already correct structure
        if target_path.exists() and target_path.is_dir():
            logger.info(f"✓ Found {dirname}")
            return target_path

        # Find img_align_celeba directory
        parent_dir = self.root / "img_align_celeba"

        if parent_dir.exists():
            # Check if images are directly in parent
            image_files = list(parent_dir.glob("*.jpg"))

            if image_files:
                # Need to create nested structure
                logger.info("Creating nested img_align_celeba directory...")
                target_path.mkdir(parents=True, exist_ok=True)

                # Move images to nested directory
                for img_file in image_files:
                    img_file.rename(target_path / img_file.name)

                logger.info(f"✓ Organized {len(image_files)} images")
                return target_path

            elif target_path.exists():
                # Nested directory exists
                return target_path

        # Try to find directory anywhere
        found_dir = self._find_directory("img_align_celeba")
        if found_dir:
            logger.info(f"Found img_align_celeba at {found_dir}")

            # Check if it needs nesting
            if found_dir.parent != self.root:
                found_dir = self._move_to_root(found_dir, "img_align_celeba")

            # Ensure nested structure
            if found_dir == parent_dir:
                image_files = list(found_dir.glob("*.jpg"))
                if image_files:
                    target_path.mkdir(parents=True, exist_ok=True)
                    for img_file in image_files:
                        img_file.rename(target_path / img_file.name)

            return target_path

        logger.error(f"✗ {dirname} directory not found")
        return None

    def download(self) -> None:
        """Download CelebA dataset if not present."""
        if self._downloaded:
            return

        if not self._check_exists():
            self._download()
            self._load_data()

        self._downloaded = True

    def _download(self) -> None:
        """Download dataset from Kaggle."""
        self._download_from_kaggle(zip_filename="celeba_dataset.zip")

    def _load_data(self) -> None:
        """Load dataset metadata."""
        if not self.split_csv or not self.attr_csv:
            raise RuntimeError("Dataset files not found. Run download() first.")

        # Load split information
        split_df = pd.read_csv(self.split_csv)
        split_map = {"train": 0, "valid": 1, "test": 2}
        split_df = split_df[split_df["partition"] == split_map[self.split]]

        # Load attributes
        attr_df = pd.read_csv(self.attr_csv)
        self.attr_names = list(attr_df.columns[1:])

        # Validate task name
        if self.task_name not in self.attr_names:
            raise ValueError(f"Unknown task '{self.task_name}'. Available: {self.attr_names}")

        # Merge data
        self.data = pd.merge(
            split_df,
            attr_df[["image_id", self.task_name]],
            on="image_id",
            how="left",
        )

        logger.info(f"Loaded {len(self.data)} samples for split '{self.split}'")

    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Get item(s) by index."""
        if self.data is None:
            self.download()

        if isinstance(idx, slice):
            indices = range(*idx.indices(len(self.data)))
            return [self._get_single_item(i) for i in indices]
        else:
            return self._get_single_item(idx)

    def _get_single_item(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a single item by index."""
        if idx >= len(self.data) or idx < -len(self.data):
            raise IndexError(f"Index {idx} out of range for dataset of size {len(self.data)}")

        if idx < 0:
            idx = len(self.data) + idx

        row = self.data.iloc[idx]
        img_path = self.images_dir / row["image_id"]

        if not img_path.exists():
            raise FileNotFoundError(f"Image {img_path} not found")

        # Load and transform image
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Get target (convert from -1/1 to 0/1)
        target = torch.tensor(1 if row[self.task_name] == 1 else 0, dtype=torch.long)

        return image, target

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Iterate over dataset returning image tensors only."""
        if self.data is None:
            self.download()

        for idx in range(len(self.data)):
            try:
                image, _ = self._get_single_item(idx)
                yield image
            except Exception as e:
                logger.warning(f"Skipping sample {idx}: {str(e)}")
                continue

    def __len__(self) -> int:
        """Return dataset size."""
        if self.data is None:
            self.download()
        return len(self.data)

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"CelebADataset(split='{self.split}', task='{self.task_name}', "
            f"size={len(self) if self.data is not None else 'Unknown'}, "
            f"root='{self.root}')"
        )

    def get_classes(self) -> Dict[str, Any]:
        """Get class information."""
        return {
            "task_name": self.task_name,
            "num_classes": 2,
            "class_names": ["No", "Yes"],
            "class_to_idx": {"No": 0, "Yes": 1},
        }

    def get_all_attributes(self) -> List[str]:
        """Get list of all available attributes."""
        if self.attr_names is None:
            self.download()
        return self.attr_names.copy()

    def get_sample_info(self, idx: int) -> Dict[str, Any]:
        """Get detailed information about a specific sample."""
        if self.data is None:
            self.download()

        if idx >= len(self.data) or idx < -len(self.data):
            raise IndexError(f"Index {idx} out of range")

        if idx < 0:
            idx = len(self.data) + idx

        row = self.data.iloc[idx]
        img_path = self.images_dir / row["image_id"]

        return {
            "index": idx,
            "image_id": row["image_id"],
            "image_path": str(img_path),
            "target_value": row[self.task_name],
            "target_class": "Yes" if row[self.task_name] == 1 else "No",
            "split": self.split,
            "exists": img_path.exists(),
        }

    def get_metadata(self) -> Dict[str, Any]:
        """Get dataset metadata."""
        metadata = super().get_metadata()
        if self.data is not None:
            metadata.update(
                {
                    "num_samples": len(self.data),
                    "image_shape": "(3, 224, 224)",
                    "split": self.split,
                    "task_name": self.task_name,
                    "num_attributes": len(self.attr_names) if self.attr_names else 0,
                }
            )
        return metadata
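

# Illustrative usage sketch (not part of the library): construction downloads and
# extracts the Kaggle archive on first use (working Kaggle access is assumed) and
# exposes each image with a binary 0/1 target for the chosen attribute.
if __name__ == "__main__":
    ds = CelebADataset(root="data", split="valid")
    image, target = ds[0]
    print(ds, image.shape, int(target))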


File: ssrlib/datasets/hf_vision.py
"""Base class for HuggingFace vision datasets."""

import torch
from PIL import Image
from torchvision import transforms
from typing import Iterator, Dict, Any, Optional, Tuple, List, Union, ClassVar
import logging

from .base import BaseDataset
from .hf_mixin import HFDatasetMixin
from .hf_registry import get_hf_dataset_info

logger = logging.getLogger(__name__)


class HFVisionDataset(HFDatasetMixin, BaseDataset):
    """Base class for HuggingFace vision datasets."""

    # Subclasses must override
    HF_REGISTRY_KEY: ClassVar[str] = None
    DEFAULT_TRANSFORM: ClassVar[transforms.Compose] = None

    def __init__(
        self,
        name: str,
        registry_key: str,
        split: str = "train",
        transform: Optional[transforms.Compose] = None,
        cache_dir: Optional[str] = None,
        **kwargs,
    ):
        """Initialize HF vision dataset."""
        super().__init__(name, **kwargs)

        self.split = split
        self.cache_dir = cache_dir

        # Get dataset info from registry
        self.hf_info = get_hf_dataset_info(registry_key)

        # Set transform
        if transform is None:
            self.transform = self.DEFAULT_TRANSFORM or self._default_transform()
        else:
            self.transform = transform

        # Will be set after loading
        self.hf_dataset = None
        self.image_key = None
        self.label_key = None
        self.hf_keys = None

        self._metadata.update(
            {
                "split": split,
                "hf_id": self.hf_info.hf_id,
                "num_classes": self.hf_info.num_classes,
            }
        )

        # Load dataset
        self.download()
        self._downloaded = True

    def _default_transform(self) -> transforms.Compose:
        """Default transform for vision datasets."""
        return transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def _get_hf_dataset_id(self) -> str:
        """Get HuggingFace dataset ID."""
        return self.hf_info.hf_id

    def _get_hf_split_name(self, split: str) -> str:
        """Map split name to HF split."""
        if split == "train":
            return self.hf_info.train_split
        elif split in ["test", "val", "validation"]:
            return self.hf_info.test_split
        else:
            return split

    def _get_hf_keys(self) -> Dict[str, str]:
        """Get HF dataset column keys."""
        return {
            "image": self.hf_info.image_key,
            "label": self.hf_info.label_key,
        }

    def download(self) -> None:
        """Load dataset from HuggingFace."""
        if self._downloaded:
            return

        logger.info(f"Loading {self.name} dataset (split: {self.split})")
        print(f"Loading {self.name} dataset (split: {self.split})")

        self._load_from_huggingface(self.split, self.cache_dir)
        self._downloaded = True

    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Get item(s) by index."""
        if self.hf_dataset is None:
            self.download()

        if isinstance(idx, slice):
            indices = range(*idx.indices(len(self.hf_dataset)))
            return [self._get_single_item(i) for i in indices]
        else:
            return self._get_single_item(idx)

    def _get_single_item(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a single item by index."""
        if idx >= len(self.hf_dataset) or idx < -len(self.hf_dataset):
            raise IndexError(f"Index {idx} out of range for dataset of size {len(self.hf_dataset)}")

        if idx < 0:
            idx = len(self.hf_dataset) + idx

        example = self.hf_dataset[idx]

        # Get image
        image = example[self.image_key]

        # Convert to PIL if needed
        if not isinstance(image, Image.Image):
            if hasattr(image, "convert"):
                image = image.convert("RGB")
            else:
                image = Image.fromarray(image).convert("RGB")
        else:
            image = image.convert("RGB")

        # Apply transform
        if self.transform:
            image = self.transform(image)

        # Get label
        label = example[self.label_key]
        label = self._convert_label(label)

        # Convert to tensor
        if not isinstance(label, torch.Tensor):
            label = torch.tensor(label, dtype=torch.long)

        return image, label

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Iterate over dataset returning image tensors only."""
        if self.hf_dataset is None:
            self.download()

        for idx in range(len(self.hf_dataset)):
            try:
                image, _ = self._get_single_item(idx)
                yield image
            except Exception as e:
                logger.warning(f"Skipping sample {idx}: {str(e)}")
                continue

    def __len__(self) -> int:
        """Return dataset size."""
        if self.hf_dataset is None:
            self.download()
        return len(self.hf_dataset)

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"{self.name}Dataset(split='{self.split}', "
            f"size={len(self) if self.hf_dataset is not None else 'Unknown'})"
        )

    def get_classes(self) -> Dict[str, Any]:
        """Get class information."""
        class_names = self._get_class_names()

        if class_names:
            class_to_idx = {name: idx for idx, name in enumerate(class_names)}
            idx_to_class = {idx: name for name, idx in class_to_idx.items()}
        else:
            class_to_idx = {}
            idx_to_class = {}

        return {
            "num_classes": self._get_num_classes(),
            "class_names": class_names,
            "class_to_idx": class_to_idx,
            "idx_to_class": idx_to_class,
        }

    def get_sample_info(self, idx: int) -> Dict[str, Any]:
        """Get detailed information about a specific sample."""
        if self.hf_dataset is None:
            self.download()

        if idx >= len(self.hf_dataset) or idx < -len(self.hf_dataset):
            raise IndexError(f"Index {idx} out of range")

        if idx < 0:
            idx = len(self.hf_dataset) + idx

        example = self.hf_dataset[idx]

        return {
            "index": idx,
            "label": example[self.label_key],
            "label_idx": self._convert_label(example[self.label_key]),
            "split": self.split,
            "has_image": self.image_key in example,
        }

    def get_metadata(self) -> Dict[str, Any]:
        """Get dataset metadata."""
        metadata = super().get_metadata()
        if self.hf_dataset is not None:
            metadata.update(
                {
                    "num_samples": len(self.hf_dataset),
                    "num_classes": self._get_num_classes(),
                    "split": self.split,
                    "hf_id": self.hf_info.hf_id,
                }
            )
        return metadata


File: ssrlib/datasets/kaggle_mixin.py
"""Simplified mixin for downloading datasets from Kaggle."""

import zipfile
import shutil
import requests
from pathlib import Path
from typing import Optional
from abc import abstractmethod
import logging

logger = logging.getLogger(__name__)


class KaggleDatasetMixin:
    """
    Simplified mixin for Kaggle dataset downloads.

    Subclasses must implement:
    - _get_kaggle_dataset_id() -> str
    - _verify_structure() -> bool
    """

    @abstractmethod
    def _get_kaggle_dataset_id(self) -> str:
        """Return Kaggle dataset ID (e.g., 'username/dataset-name')."""
        pass

    @abstractmethod
    def _verify_structure(self) -> bool:
        """Verify dataset structure is correct after extraction."""
        pass

    def _download_from_kaggle(
        self,
        dataset_id: Optional[str] = None,
        zip_filename: Optional[str] = None,
    ) -> None:
        """
        Download and extract dataset from Kaggle.

        Simplified flow:
        1. Download zip file
        2. Extract to root
        3. Verify structure
        4. Clean up
        """
        dataset_id = dataset_id or self._get_kaggle_dataset_id()

        # Generate zip filename
        if zip_filename is None:
            zip_filename = f"{dataset_id.split('/')[-1]}.zip"

        # Ensure root directory exists
        self.root.mkdir(parents=True, exist_ok=True)
        zip_path = self.root / zip_filename

        try:
            # Step 1: Download
            logger.info(f"Downloading {dataset_id} from Kaggle...")
            print(f"Downloading {dataset_id} from Kaggle...")
            self._download_file(dataset_id, zip_path)

            # Step 2: Extract
            logger.info("Extracting dataset...")
            print("Extracting dataset...")
            self._extract_zip(zip_path)

            # Step 3: Verify
            if not self._verify_structure():
                raise RuntimeError("Dataset structure verification failed")

            # Step 4: Clean up
            logger.info("Cleaning up...")
            print("Cleaning up...")
            zip_path.unlink()

            logger.info("Dataset download completed successfully")
            print("Dataset download completed successfully")

        except requests.exceptions.RequestException as e:
            self._handle_download_error(e, dataset_id)
            raise

        except zipfile.BadZipFile as e:
            self._handle_zip_error(e, zip_path)
            raise

        except Exception as e:
            self._cleanup_on_error(zip_path)
            logger.error(f"Unexpected error: {str(e)}")
            raise RuntimeError(f"Dataset download failed: {str(e)}") from e

    def _download_file(self, dataset_id: str, output_path: Path) -> None:
        """Download file from Kaggle API."""
        kaggle_url = f"https://www.kaggle.com/api/v1/datasets/download/{dataset_id}"

        response = requests.get(kaggle_url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get("content-length", 0))

        with open(output_path, "wb") as f:
            if total_size > 0:
                downloaded = 0
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        progress = (downloaded / total_size) * 100
                        print(f"\rProgress: {progress:.1f}%", end="", flush=True)
                print()  # New line after progress
            else:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        logger.info(f"Download completed: {output_path}")

    def _extract_zip(self, zip_path: Path) -> None:
        """Extract zip file to root directory."""
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(self.root)

    def _handle_download_error(self, error: Exception, dataset_id: str) -> None:
        """Handle download errors with helpful messages."""
        logger.error(f"Download failed: {str(error)}")
        print(f"\n❌ Error downloading dataset: {str(error)}")
        print("\n📝 Please check:")
        print("   1. Internet connection")
        print("   2. Kaggle API credentials (~/.kaggle/kaggle.json)")
        print(f"   3. Dataset access: https://www.kaggle.com/datasets/{dataset_id}")
        print("\n💡 Manual download:")
        print(f"   1. Visit: https://www.kaggle.com/datasets/{dataset_id}")
        print(f"   2. Download to: {self.root}")
        print("   3. Extract the files")

    def _handle_zip_error(self, error: Exception, zip_path: Path) -> None:
        """Handle zip extraction errors."""
        logger.error(f"Zip extraction failed: {str(error)}")
        print(f"\n❌ Error extracting zip file: {str(error)}")

        if zip_path.exists():
            zip_path.unlink()
            print("Removed corrupted zip file")

    def _cleanup_on_error(self, zip_path: Path) -> None:
        """Clean up files after error."""
        if zip_path.exists():
            try:
                zip_path.unlink()
                logger.info("Cleaned up partial download")
            except Exception as e:
                logger.warning(f"Failed to clean up: {str(e)}")

    def _find_file(self, filename: str) -> Optional[Path]:
        """
        Find a file anywhere in root directory.
        Returns first match or None.
        """
        for path in self.root.rglob(filename):
            if path.is_file():
                return path
        return None

    def _find_directory(self, dirname: str) -> Optional[Path]:
        """
        Find a directory anywhere in root directory.
        Returns first match or None.
        """
        for path in self.root.rglob(dirname):
            if path.is_dir():
                return path
        return None

    def _move_to_root(self, source: Path, target_name: str) -> Path:
        """Move file or directory to root with new name."""
        target = self.root / target_name

        if target.exists():
            if target.is_dir():
                shutil.rmtree(target)
            else:
                target.unlink()

        shutil.move(str(source), str(target))
        logger.info(f"Moved {source.name} to {target_name}")

        return target
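

# Minimal illustrative sketch (not part of the library): a dataset class mixes this
# in alongside BaseDataset and supplies the two hooks below; download, extraction,
# verification and cleanup then come from _download_from_kaggle(). The dataset id
# used here is purely hypothetical.
class _ToyKaggleDataset(KaggleDatasetMixin):
    def __init__(self, root: Path):
        self.root = root  # the mixin expects a Path-valued 'root' attribute

    def _get_kaggle_dataset_id(self) -> str:
        return "someuser/some-dataset"  # hypothetical Kaggle dataset id

    def _verify_structure(self) -> bool:
        return self.root.exists() and any(self.root.iterdir())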


File: ssrlib/datasets/synthtest_dataset.py
import torch
import numpy as np
from typing import Iterator, Dict, Any, Optional, Union, List, ClassVar

from .base import BaseDataset


class SynthTestDataset(BaseDataset):
    """Synthetic test dataset that generates random image-like tensors."""

    # Class-level metadata
    _dataset_category: ClassVar[str] = "synthetic"
    _dataset_modality: ClassVar[str] = "synthetic"
    _dataset_properties: ClassVar[Dict[str, Any]] = {
        "default_tensor_shape": (3, 224, 224),
        "default_num_tensors": 100,
        "deterministic": True,
        "task_type": "testing",
        "supports_custom_shapes": True,
        "value_range": (-2.0, 2.0),
    }

    def __init__(
        self,
        tensors_num: int = 100,
        seed: Optional[int] = None,
        tensor_shape: tuple = (3, 224, 224),
        **kwargs,
    ):
        """Initialize synthetic test dataset.

        Args:
            tensors_num: Number of tensors to generate
            seed: Random seed for reproducibility (optional)
            tensor_shape: Shape of generated tensors (default: (3, 224, 224))
            **kwargs: Additional arguments passed to BaseDataset
        """
        super().__init__("SynthTest", **kwargs)

        self.tensors_num = tensors_num
        self.seed = seed
        self.tensor_shape = tensor_shape

        # Validate inputs
        if tensors_num <= 0:
            raise ValueError(f"tensors_num must be positive, got {tensors_num}")

        if len(tensor_shape) != 3:
            raise ValueError(f"tensor_shape must have 3 dimensions, got {len(tensor_shape)}")

        # Update metadata
        self._metadata.update(
            {
                "tensors_num": tensors_num,
                "tensor_shape": tensor_shape,
                "seed": seed,
                "synthetic": True,
                "dataset_type": "synthetic_test",
            }
        )

        # Mark as already "downloaded" since no actual download is needed
        self._downloaded = True

    def download(self) -> None:
        """No-op download method for synthetic data."""
        if not self._downloaded:
            print(f"Synthetic dataset {self.name} ready (no download needed)")
            self._downloaded = True

    def __getitem__(self, idx: Union[int, slice]) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Get item(s) by index."""
        if isinstance(idx, slice):
            indices = range(*idx.indices(self.tensors_num))
            return [self._get_single_item(i) for i in indices]
        else:
            return self._get_single_item(idx)

    def _get_single_item(self, idx: int) -> torch.Tensor:
        """Get a single tensor by index."""
        if idx >= self.tensors_num or idx < -self.tensors_num:
            raise IndexError(f"Index {idx} out of range for dataset of size {self.tensors_num}")

        if idx < 0:
            idx = self.tensors_num + idx

        # Generate deterministic tensor based on index and seed
        if self.seed is not None:
            generator = torch.Generator()
            generator.manual_seed(self.seed + idx)
            tensor = torch.randn(*self.tensor_shape, generator=generator)
        else:
            tensor = torch.randn(*self.tensor_shape)

        # Clamp to reasonable range
        tensor = torch.clamp(tensor, -2.0, 2.0)
        return tensor

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Generate random tensors.

        Note: iteration re-seeds the global RNG and draws sequentially, so the
        yielded tensors differ from the per-index tensors returned by __getitem__.
        """
        if self.seed is not None:
            torch.manual_seed(self.seed)
            np.random.seed(self.seed)

        for i in range(self.tensors_num):
            tensor = torch.randn(*self.tensor_shape)
            tensor = torch.clamp(tensor, -2.0, 2.0)
            yield tensor

    def __len__(self) -> int:
        """Return number of tensors in dataset."""
        return self.tensors_num

    def __repr__(self) -> str:
        """String representation of dataset."""
        return (
            f"SynthTestDataset(size={self.tensors_num}, "
            f"shape={self.tensor_shape}, seed={self.seed})"
        )

    def get_sample_info(self, idx: int) -> Dict[str, Any]:
        """Get information about a specific sample."""
        if idx >= self.tensors_num or idx < -self.tensors_num:
            raise IndexError(f"Index {idx} out of range")

        if idx < 0:
            idx = self.tensors_num + idx

        return {
            "index": idx,
            "tensor_shape": self.tensor_shape,
            "seed_used": self.seed,
            "deterministic": self.seed is not None,
            "synthetic": True,
        }

    def regenerate(self, new_seed: Optional[int] = None) -> None:
        """Force regeneration with new seed."""
        if new_seed is not None:
            self.seed = new_seed
            self._metadata["seed"] = new_seed
        else:
            self.seed = None
            self._metadata["seed"] = None

    def set_seed(self, seed: int) -> None:
        """Set random seed for reproducible generation."""
        self.seed = seed
        self._metadata["seed"] = seed

    def get_metadata(self) -> Dict[str, Any]:
        """Get dataset metadata."""
        metadata = super().get_metadata()
        metadata.update(
            {
                "num_samples": self.tensors_num,
                "image_shape": f"{self.tensor_shape}",
                "synthetic": True,
                "deterministic": self.seed is not None,
                "seed_used": self.seed,
            }
        )
        return metadata
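

# Illustrative self-check (not part of the library): with a fixed seed the per-index
# generator (seed + idx) makes __getitem__ reproducible across independently
# constructed datasets.
if __name__ == "__main__":
    a = SynthTestDataset(tensors_num=4, seed=123)
    b = SynthTestDataset(tensors_num=4, seed=123)
    print("deterministic:", all(torch.equal(a[i], b[i]) for i in range(4)))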


File: ssrlib/datasets/cifar10.py
"""CIFAR-10 dataset from HuggingFace."""

from torchvision import transforms
from typing import Dict, Any, ClassVar

from .hf_vision import HFVisionDataset


class CIFAR10Dataset(HFVisionDataset):
    """CIFAR-10 Dataset from HuggingFace Hub."""

    # Class-level metadata
    _dataset_category: ClassVar[str] = "vision"
    _dataset_modality: ClassVar[str] = "vision"
    _dataset_properties: ClassVar[Dict[str, Any]] = {
        "num_classes": 10,
        "image_format": "png",
        "original_image_size": (32, 32),
        "processed_image_size": (224, 224),
        "task_type": "multi_class_classification",
        "source": "huggingface",
        "hf_id": "uoft-cs/cifar10",
    }

    # Upscale from 32x32 to 224x224
    DEFAULT_TRANSFORM = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    def __init__(self, split: str = "train", **kwargs):
        """Initialize CIFAR-10 dataset."""
        super().__init__(name="CIFAR10", registry_key="cifar10", split=split, **kwargs)

    def get_metadata(self) -> Dict[str, Any]:
        """Get dataset metadata with CIFAR-10 specific info."""
        metadata = super().get_metadata()
        metadata["original_size"] = "(3, 32, 32)"
        metadata["image_shape"] = "(3, 224, 224)"
        return metadata
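

# Illustrative usage sketch (not part of the library): although CIFAR-10 images are
# natively 32x32, DEFAULT_TRANSFORM upsamples them, so every item comes back as a
# 224x224 tensor (HuggingFace download or cache access is assumed).
if __name__ == "__main__":
    ds = CIFAR10Dataset(split="test")
    image, label = ds[0]
    print(image.shape, ds.get_metadata()["original_size"])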


File: ssrlib/datasets/imagenet100.py
import os
import json
import glob
from pathlib import Path
from typing import Iterator, Dict, Any, Optional, List, Tuple, Union, ClassVar
import torch
from PIL import Image
from torchvision import transforms
import logging

from .base import BaseDataset
from .kaggle_mixin import KaggleDatasetMixin

logger = logging.getLogger(__name__)


class ImageNet100Dataset(KaggleDatasetMixin, BaseDataset):
    """ImageNet100 Dataset for ssrlib framework."""

    # Class-level metadata
    _dataset_category: ClassVar[str] = "vision"
    _dataset_modality: ClassVar[str] = "vision"
    _dataset_properties: ClassVar[Dict[str, Any]] = {
        "num_classes": 100,
        "image_format": "JPEG",
        "processed_image_size": (224, 224),
        "task_type": "multi_class_classification",
        "supports_train_val_split": True,
        "hierarchical_labels": True,
    }

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        labels_path: Optional[str] = None,
        combine_train_splits: bool = True,
        transform: Optional[transforms.Compose] = None,
        **kwargs,
    ):
        """Initialize ImageNet100 dataset."""
        super().__init__("ImageNet100", **kwargs)

        self.root = Path(root) / "ImageNet100"
        self.split = split
        self.combine_train_splits = combine_train_splits
        self.labels_path = labels_path

        # Set default transform
        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ]
            )
        else:
            self.transform = transform

        # Will be loaded after verification
        self.samples = []
        self.synset_to_class = {}
        self.class_names = []
        self.class_to_idx = {}
        self.idx_to_class = {}

        self._metadata.update(
            {
                "split": split,
                "combine_train_splits": combine_train_splits,
                "root": str(self.root),
            }
        )

        # Download if needed
        if not self._check_exists():
            logger.info(f"ImageNet100 not found, downloading to {self.root}")
            print(f"ImageNet100 not found, downloading to {self.root}")
            self._download()

        self._load_data()
        self._downloaded = True

    def _get_kaggle_dataset_id(self) -> str:
        """Get Kaggle dataset ID."""
        return "ambityga/imagenet100"

    def _check_exists(self) -> bool:
        """Check if dataset exists and is properly structured."""
        if not self.root.exists():
            return False

        return self._verify_structure()

    def _verify_structure(self) -> bool:
        """
        Verify ImageNet100 structure is correct.

        Expected structure:
        ImageNet100/
        ├── train.X1/
        │   ├── n01440764/
        │   ├── n01443537/
        │   └── ...
        ├── train.X2/ (optional)
        ├── val.X/
        │   ├── n01440764/
        │   ├── n01443537/
        │   └── ...
        └── Labels.json (optional)
        """
        logger.info("Verifying ImageNet100 structure...")

        # Find training directories
        train_dirs = self._find_train_directories()
        if not train_dirs:
            logger.error("No training directories found")
            return False

        # Find validation directory
        val_dir = self._find_val_directory()
        if not val_dir:
            logger.error("No validation directory found")
            return False

        # Check for images
        has_images = self._check_has_images(train_dirs + [val_dir])
        if not has_images:
            logger.error("No images found in directories")
            return False

        # Find labels file (optional)
        if self.labels_path is None:
            self.labels_path = self._find_labels_file()

        logger.info("✓ Structure verified")
        print("✓ ImageNet100 structure verified")
        print(f"  - Training dirs: {len(train_dirs)}")
        print(f"  - Validation dir: {val_dir.name}")
        print(f"  - Labels file: {'Found' if self.labels_path else 'Not found'}")

        return True

    def _find_train_directories(self) -> List[Path]:
        """Find training directories (train.X1, train.X2, etc.)."""
        train_dirs = []

        # Look for train.X* pattern
        for pattern in ["train.X*", "train*", "Train*"]:
            found = list(self.root.glob(pattern))
            for path in found:
                if path.is_dir() and self._has_synset_subdirs(path):
                    train_dirs.append(path)

        # Sort by name
        train_dirs.sort(key=lambda x: x.name)

        if train_dirs:
            logger.info(f"Found {len(train_dirs)} training directories")

        return train_dirs

    def _find_val_directory(self) -> Optional[Path]:
        """Find validation directory (val.X or similar)."""
        # Try exact name first
        val_dir = self.root / "val.X"
        if val_dir.exists() and self._has_synset_subdirs(val_dir):
            logger.info("Found validation directory: val.X")
            return val_dir

        # Try various patterns
        for pattern in ["val*", "Val*", "validation*", "valid*"]:
            found = list(self.root.glob(pattern))
            for path in found:
                if path.is_dir() and self._has_synset_subdirs(path):
                    logger.info(f"Found validation directory: {path.name}")
                    return path

        return None

    def _has_synset_subdirs(self, directory: Path) -> bool:
        """Check if directory contains synset subdirectories (n01...)."""
        subdirs = [d for d in directory.iterdir() if d.is_dir()]
        if not subdirs:
            return False

        # Check if any subdirectory starts with 'n' (synset pattern)
        return any(d.name.startswith("n") for d in subdirs)

    def _check_has_images(self, directories: List[Path]) -> bool:
        """Check if directories contain image files."""
        for directory in directories:
            for synset_dir in directory.iterdir():
                if synset_dir.is_dir():
                    # Check for image files
                    image_files = (
                        list(synset_dir.glob("*.JPEG"))
                        + list(synset_dir.glob("*.jpg"))
                        + list(synset_dir.glob("*.png"))
                    )
                    if image_files:
                        return True
        return False

    def _find_labels_file(self) -> Optional[str]:
        """Find Labels.json file if it exists."""
        for pattern in ["Labels.json", "labels.json", "LABELS.json"]:
            labels_path = self.root / pattern
            if labels_path.exists():
                logger.info(f"Found labels file: {pattern}")
                return str(labels_path)

        return None

    def download(self) -> None:
        """Download ImageNet100 dataset if not present."""
        if self._downloaded:
            return

        if not self._check_exists():
            self._download()
            self._load_data()

        self._downloaded = True

    def _download(self) -> None:
        """Download dataset from Kaggle."""
        self._download_from_kaggle(zip_filename="imagenet100.zip")

    def _load_data(self) -> None:
        """Load dataset structure and samples."""
        # Load labels if available
        if self.labels_path and os.path.exists(self.labels_path):
            self._load_labels()

        # Load samples based on split
        if self.split in ["train", "training"]:
            self._load_train_samples()
        elif self.split in ["val", "valid", "validation"]:
            self._load_val_samples()
        else:
            raise ValueError(f"Unknown split: {self.split}")

        # Create class mappings
        self._create_class_mappings()

        logger.info(f"Loaded {len(self.samples)} samples for split '{self.split}'")

    def _load_labels(self) -> None:
        """Load synset to class name mapping."""
        with open(self.labels_path, "r") as f:
            labels = json.load(f)

        self.synset_to_class = {
            synset: desc.split(",")[0].strip() for synset, desc in labels.items()
        }

        logger.info(f"Loaded {len(self.synset_to_class)} class labels")

    def _load_train_samples(self) -> None:
        """Load training samples."""
        train_dirs = self._find_train_directories()

        if not train_dirs:
            raise ValueError(f"No training directories found in {self.root}")

        # Use all or just first directory based on combine_train_splits
        if self.combine_train_splits:
            dirs_to_load = train_dirs
        else:
            dirs_to_load = train_dirs[:1]

        logger.info(f"Loading from {len(dirs_to_load)} training directories")

        for train_dir in dirs_to_load:
            self._load_samples_from_directory(train_dir)

    def _load_val_samples(self) -> None:
        """Load validation samples."""
        val_dir = self._find_val_directory()

        if not val_dir:
            raise ValueError(f"Validation directory not found in {self.root}")

        logger.info(f"Loading from validation directory: {val_dir.name}")
        self._load_samples_from_directory(val_dir)

    def _load_samples_from_directory(self, directory: Path) -> None:
        """Load all samples from a directory with synset structure."""
        for synset_dir in directory.iterdir():
            if not synset_dir.is_dir():
                continue

            synset_id = synset_dir.name
            class_name = self.synset_to_class.get(synset_id, synset_id)

            # Find all image files
            image_extensions = ["*.JPEG", "*.jpg", "*.jpeg", "*.png"]
            for ext in image_extensions:
                for img_path in synset_dir.glob(ext):
                    self.samples.append((str(img_path), class_name, synset_id))

    def _create_class_mappings(self) -> None:
        """Create class to index mappings."""
        if self.synset_to_class:
            unique_classes = sorted(set(self.synset_to_class.values()))
        else:
            unique_classes = sorted(set(sample[1] for sample in self.samples))

        self.class_names = unique_classes
        self.class_to_idx = {name: idx for idx, name in enumerate(unique_classes)}
        self.idx_to_class = {idx: name for name, idx in self.class_to_idx.items()}

        logger.info(f"Created mappings for {len(self.class_names)} classes")

    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Get item(s) by index."""
        if not self._downloaded:
            self.download()

        if isinstance(idx, slice):
            indices = range(*idx.indices(len(self.samples)))
            return [self._get_single_item(i) for i in indices]
        else:
            return self._get_single_item(idx)

    def _get_single_item(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a single item by index."""
        if idx >= len(self.samples) or idx < -len(self.samples):
            raise IndexError(f"Index {idx} out of range for dataset of size {len(self.samples)}")

        if idx < 0:
            idx = len(self.samples) + idx

        img_path, class_name, synset_id = self.samples[idx]

        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image {img_path} not found")

        # Load and transform image
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Get class index
        class_idx = self.class_to_idx[class_name]
        target = torch.tensor(class_idx, dtype=torch.long)

        return image, target

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Iterate over dataset returning image tensors only."""
        if not self._downloaded:
            self.download()

        for idx in range(len(self.samples)):
            try:
                image, _ = self._get_single_item(idx)
                yield image
            except Exception as e:
                logger.warning(f"Skipping sample {idx}: {str(e)}")
                continue

    def __len__(self) -> int:
        """Return dataset size."""
        if not self._downloaded:
            self.download()
        return len(self.samples)

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"ImageNet100Dataset(split='{self.split}', "
            f"size={len(self.samples) if self.samples else 'Unknown'}, "
            f"num_classes={len(self.class_names)}, root='{self.root}')"
        )

    def get_classes(self) -> Dict[str, Any]:
        """Get class information."""
        return {
            "num_classes": len(self.class_names),
            "class_names": self.class_names.copy(),
            "class_to_idx": self.class_to_idx.copy(),
            "idx_to_class": self.idx_to_class.copy(),
            "synset_to_class": (self.synset_to_class.copy() if self.synset_to_class else {}),
        }

    def get_class_name(self, idx: int) -> str:
        """Get class name from index."""
        if idx not in self.idx_to_class:
            raise ValueError(f"Class index {idx} not found")
        return self.idx_to_class[idx]

    def get_class_index(self, class_name: str) -> int:
        """Get index from class name."""
        if class_name not in self.class_to_idx:
            raise ValueError(f"Class name '{class_name}' not found")
        return self.class_to_idx[class_name]

    def get_sample_info(self, idx: int) -> Dict[str, Any]:
        """Get detailed information about a specific sample."""
        if not self._downloaded:
            self.download()

        if idx >= len(self.samples) or idx < -len(self.samples):
            raise IndexError(f"Index {idx} out of range")

        if idx < 0:
            idx = len(self.samples) + idx

        img_path, class_name, synset_id = self.samples[idx]

        return {
            "index": idx,
            "image_path": img_path,
            "class_name": class_name,
            "class_index": self.class_to_idx[class_name],
            "synset_id": synset_id,
            "split": self.split,
            "exists": os.path.exists(img_path),
        }

    def get_metadata(self) -> Dict[str, Any]:
        """Get dataset metadata."""
        metadata = super().get_metadata()
        metadata.update(
            {
                "num_samples": len(self.samples),
                "num_classes": len(self.class_names),
                "class_names": self.class_names[:10],  # First 10 for brevity
                "image_shape": "(3, 224, 224)",
                "split": self.split,
                "combine_train_splits": self.combine_train_splits,
            }
        )
        return metadata
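

# --- Usage sketch (illustrative, not part of the library) ---------------------
# A minimal example of how this dataset might be used. The constructor
# arguments shown (root, split, transform) are assumed from the attributes the
# class relies on (self.root, self.split, self.transform) and may not match the
# actual __init__ signature defined earlier in this file.
#
#     from torchvision import transforms
#
#     transform = transforms.Compose([
#         transforms.Resize(256),
#         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#     ])
#     dataset = ImageNet100Dataset(root="data/imagenet100", split="train",
#                                  transform=transform)
#     image, target = dataset[0]          # image tensor and class index tensor
#     info = dataset.get_sample_info(0)   # path, synset id, class name, ...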


File: ssrlib/datasets/base.py
from abc import ABC, abstractmethod
from typing import Iterator, Dict, Any, ClassVar
import torch


class BaseDataset(ABC):
    """Base class for all datasets in ssrlib with self-describing metadata."""

    # Class-level metadata - subclasses should override these
    _dataset_category: ClassVar[str] = "general"
    _dataset_modality: ClassVar[str] = "unknown"
    _dataset_properties: ClassVar[Dict[str, Any]] = {}

    def __init__(self, name: str, **kwargs):
        """Initialize dataset.

        Args:
            name: Name of the dataset
            **kwargs: Additional dataset-specific parameters
        """
        self.name = name
        self._metadata = {}
        self._downloaded = False

    @abstractmethod
    def download(self) -> None:
        """Download dataset if not already present."""
        pass

    @abstractmethod
    def __iter__(self) -> Iterator[torch.Tensor]:
        """Iterate over dataset returning tensors."""
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Return dataset size."""
        pass

    def get_metadata(self) -> Dict[str, Any]:
        """Get dataset metadata."""
        return {
            "name": self.name,
            "size": len(self),
            "downloaded": self._downloaded,
            **self._metadata,
        }

    @classmethod
    def get_dataset_category(cls) -> str:
        """Get dataset category."""
        return cls._dataset_category

    @classmethod
    def get_dataset_modality(cls) -> str:
        """Get dataset modality."""
        return cls._dataset_modality

    @classmethod
    def get_dataset_properties(cls) -> Dict[str, Any]:
        """Get dataset properties."""
        return cls._dataset_properties.copy()
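

if __name__ == "__main__":
    # Illustrative only: a tiny in-memory dataset that satisfies the abstract
    # interface above (download / __iter__ / __len__). It is defined under the
    # __main__ guard so it never becomes part of the importable module.
    class ToyTensorDataset(BaseDataset):
        _dataset_category = "example"
        _dataset_modality = "vision"

        def __init__(self, num_samples: int = 8):
            super().__init__(name="toy")
            self._num_samples = num_samples
            self._data: list = []

        def download(self) -> None:
            # Nothing to fetch; materialize random tensors once.
            if not self._downloaded:
                self._data = [torch.randn(3, 32, 32) for _ in range(self._num_samples)]
                self._downloaded = True

        def __iter__(self) -> Iterator[torch.Tensor]:
            self.download()
            yield from self._data

        def __len__(self) -> int:
            self.download()
            return self._num_samples

    ds = ToyTensorDataset()
    print(len(ds), ds.get_metadata())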


File: ssrlib/datasets/hf_mixin.py
"""Mixin for loading datasets from Hugging Face Hub."""

from pathlib import Path
from typing import Optional, Dict, Any
from abc import abstractmethod
import logging

logger = logging.getLogger(__name__)


class HFDatasetMixin:
    """
    Mixin for datasets loaded from Hugging Face Hub.

    Subclasses must implement:
    - _get_hf_dataset_id() -> str
    - _get_hf_split_name(split: str) -> str
    - _get_hf_keys() -> Dict[str, str]
    """

    @abstractmethod
    def _get_hf_dataset_id(self) -> str:
        """Return HuggingFace dataset ID (e.g., 'ethz/food101')."""
        pass

    @abstractmethod
    def _get_hf_split_name(self, split: str) -> str:
        """
        Map split name to HF dataset split.

        Args:
            split: Requested split ('train', 'test', 'val')

        Returns:
            Actual HF split name
        """
        pass

    @abstractmethod
    def _get_hf_keys(self) -> Dict[str, str]:
        """
        Get HF dataset column keys.

        Returns:
            Dict with 'image' and 'label' keys
        """
        pass

    def _load_from_huggingface(self, split: str, cache_dir: Optional[str] = None) -> None:
        """
        Load dataset from Hugging Face Hub.

        Args:
            split: Dataset split to load
            cache_dir: Optional cache directory
        """
        try:
            from datasets import load_dataset
        except ImportError:
            raise ImportError(
                "The 'datasets' package is required for HuggingFace datasets. "
                "Install it with: pip install datasets"
            )

        dataset_id = self._get_hf_dataset_id()
        hf_split = self._get_hf_split_name(split)

        logger.info(f"Loading HuggingFace dataset: {dataset_id}")
        print(f"Loading HuggingFace dataset: {dataset_id}")
        print(f"Split: {split} -> {hf_split}")

        try:
            self.hf_dataset = load_dataset(
                dataset_id,
                split=hf_split,
                cache_dir=cache_dir,
            )

            logger.info(f"✓ Loaded {len(self.hf_dataset)} examples")
            print(f"✓ Loaded {len(self.hf_dataset)} examples")

            # Get column keys
            self.hf_keys = self._get_hf_keys()
            self.image_key = self.hf_keys["image"]
            self.label_key = self.hf_keys["label"]

            # Setup label mapping if needed
            self._setup_label_mapping()

        except ValueError as e:
            if "trust_remote_code" in str(e).lower():
                logger.error(
                    f"Dataset {dataset_id} requires a loading script (deprecated). "
                    f"The dataset may need to be updated on HuggingFace Hub."
                )
                raise RuntimeError(
                    f"Cannot load dataset '{dataset_id}'. "
                    f"It uses a deprecated loading script. "
                    f"Please check if the dataset has been updated to Parquet format."
                ) from e
            else:
                raise

        except Exception as e:
            logger.error(f"Failed to load dataset {dataset_id}: {str(e)}")
            raise RuntimeError(f"Failed to load HuggingFace dataset: {str(e)}") from e

    def _setup_label_mapping(self) -> None:
        """
        Setup label mapping for string labels.

        Some datasets use string labels instead of integers.
        Creates bidirectional mapping: label <-> index.
        """
        self.label_to_idx = None
        self.idx_to_label = None

        try:
            # Check first label type
            if len(self.hf_dataset) == 0:
                logger.warning("Empty dataset, skipping label mapping")
                return

            first_label = self.hf_dataset[0][self.label_key]

            # If already numeric, no mapping needed
            if isinstance(first_label, (int, float)):
                logger.info("Labels are numeric, no mapping needed")
                return

            # Create mapping for string labels
            if isinstance(first_label, str):
                logger.info("Detected string labels, creating mapping...")

                # Try to get from features first (faster)
                all_labels = self._get_labels_from_features()

                # Fallback: collect from dataset
                if all_labels is None:
                    all_labels = self._collect_unique_labels()

                # Create bidirectional mapping
                sorted_labels = sorted(list(all_labels))
                self.label_to_idx = {label: idx for idx, label in enumerate(sorted_labels)}
                self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}

                logger.info(f"✓ Created label mapping for {len(self.label_to_idx)} classes")
                print(f"✓ Label mapping: {len(self.label_to_idx)} classes")

        except Exception as e:
            logger.warning(f"Could not setup label mapping: {str(e)}")

    def _get_labels_from_features(self) -> Optional[set]:
        """Try to get labels from dataset features."""
        try:
            if hasattr(self.hf_dataset, "features"):
                label_feature = self.hf_dataset.features.get(self.label_key)
                if hasattr(label_feature, "names"):
                    labels = set(label_feature.names)
                    logger.info(f"Got {len(labels)} labels from features")
                    return labels
        except Exception as e:
            logger.debug(f"Could not get labels from features: {e}")

        return None

    def _collect_unique_labels(self) -> set:
        """Collect unique labels from dataset."""
        logger.info("Collecting unique labels from dataset...")

        try:
            # Try to get entire column (fast)
            label_column = self.hf_dataset[self.label_key]
            return set(label_column)
        except Exception:
            # Fallback: sample from dataset
            logger.info("Sampling labels from dataset...")
            all_labels = set()
            sample_size = min(1000, len(self.hf_dataset))
            step = max(1, len(self.hf_dataset) // sample_size)

            for i in range(0, len(self.hf_dataset), step):
                try:
                    label = self.hf_dataset[i][self.label_key]
                    all_labels.add(label)
                except Exception:
                    continue

            logger.info(f"Collected {len(all_labels)} unique labels from sampling")
            return all_labels

    def _convert_label(self, label: Any) -> int:
        """
        Convert label to integer index.

        Args:
            label: Label from dataset (string or int)

        Returns:
            Integer label index
        """
        # If string label and we have mapping
        if isinstance(label, str) and self.label_to_idx is not None:
            return self.label_to_idx[label]

        # If already int
        if isinstance(label, (int, float)):
            return int(label)

        # Fallback
        return label

    def _get_num_classes(self) -> int:
        """
        Get number of classes in dataset.

        Returns:
            Number of unique classes
        """
        if self.label_to_idx is not None:
            return len(self.label_to_idx)

        # Try to get from features
        try:
            if hasattr(self.hf_dataset, "features"):
                label_feature = self.hf_dataset.features.get(self.label_key)
                if hasattr(label_feature, "num_classes"):
                    return label_feature.num_classes
                elif hasattr(label_feature, "names"):
                    return len(label_feature.names)
        except Exception:
            pass

        # Fallback: count unique labels
        logger.warning("Could not determine number of classes, counting unique labels...")
        try:
            unique_labels = set(self.hf_dataset[self.label_key])
            return len(unique_labels)
        except Exception:
            return 0

    def _get_class_names(self) -> Optional[list]:
        """
        Get class names if available.

        Returns:
            List of class names or None
        """
        # From label mapping
        if self.idx_to_label is not None:
            return [self.idx_to_label[i] for i in range(len(self.idx_to_label))]

        # From features
        try:
            if hasattr(self.hf_dataset, "features"):
                label_feature = self.hf_dataset.features.get(self.label_key)
                if hasattr(label_feature, "names"):
                    return label_feature.names
        except Exception:
            pass

        return None
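

# --- Illustrative sketch (not part of the library) -----------------------------
# How a concrete dataset class might fill in the three abstract hooks. The
# class and column names below are hypothetical; 'ethz/food101' is simply the
# example dataset id mentioned in the docstring above.
#
#     class Food101Dataset(BaseDataset, HFDatasetMixin):
#         def _get_hf_dataset_id(self) -> str:
#             return "ethz/food101"
#
#         def _get_hf_split_name(self, split: str) -> str:
#             # Map the library's 'val'/'test' splits onto the HF 'validation' split.
#             return {"train": "train", "val": "validation", "test": "validation"}.get(split, split)
#
#         def _get_hf_keys(self) -> Dict[str, str]:
#             return {"image": "image", "label": "label"}
#
#     # Loading then happens via the mixin, e.g. inside download():
#     #     self._load_from_huggingface(split="train", cache_dir=None)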


File: ssrlib/storage/tensor_storage.py
import os
import numpy as np
from typing import List, Iterator, Dict, Any, Optional, Union, Tuple
import json
import logging
from tqdm import tqdm
import pandas as pd
import shutil

logger = logging.getLogger(__name__)


class TensorStorage:
    """
    TensorStorage that stores tensor data in chunks and metadata in parquet.
    Provides mapping between tensor indices and additional metadata parameters.
    """

    def __init__(
        self,
        storage_dir: str,
        description: str = "",
        chunk_size: Optional[int] = None,
        return_metadata: bool = False,
    ):
        """
        Initialize the TensorStorage.

        Args:
            storage_dir (str): Directory where the storage will be created or loaded from.
            description (str): Optional description of the storage (used only for new storage)
            chunk_size (Optional[int]): Size of each chunk in bytes. If None, will be loaded from metadata
                                      or set to default value.
            return_metadata (bool): If True, __getitem__ will return (tensor, metadata) pairs
        """
        self.storage_dir = storage_dir
        self.return_metadata = return_metadata
        self.chunks_dir = os.path.join(storage_dir, "chunks")
        self.metadata_dir = os.path.join(storage_dir, "metadata")
        self.metadata_file = os.path.join(self.metadata_dir, "metadata.json")
        self.parquet_file = os.path.join(self.metadata_dir, "tensor_metadata.parquet")

        # Create directories if they don't exist
        os.makedirs(self.chunks_dir, exist_ok=True)
        os.makedirs(self.metadata_dir, exist_ok=True)

        self.metadata = self._load_metadata()

        # Set description: use saved description if available, otherwise use parameter
        self.description = self.metadata.get("description", description)

        # Set chunk size with priority: provided > metadata > default
        default_chunk_size = 3 * 2**20 * np.dtype(np.float32).itemsize
        self.chunk_size = chunk_size or self.metadata.get("chunk_size") or default_chunk_size

        self.loaded_chunks = {}
        self.current_window = []

        # Load parquet metadata
        self.metadata_df = self.load_metadata_table()

    def _load_parquet_metadata(self):
        """Load the parquet metadata if it exists."""
        if os.path.exists(self.parquet_file):
            self.metadata_df = pd.read_parquet(self.parquet_file)
        else:
            self.metadata_df = pd.DataFrame()

    def _load_metadata(self):
        """Load metadata from the JSON file if it exists, otherwise return an empty dict."""
        if os.path.exists(self.metadata_file):
            with open(self.metadata_file, "r") as f:
                return json.load(f)
        return {}

    def _save_metadata(self):
        """Save the current metadata to the JSON file."""
        with open(self.metadata_file, "w") as f:
            json.dump(self.metadata, f)

    def _get_chunk_filename(self, chunk_id: int) -> str:
        """Generate the filename for a given chunk ID."""
        return os.path.join(self.chunks_dir, f"ch{chunk_id}")

    def _load_chunk(self, chunk_id: int) -> np.ndarray:
        """
        Load a chunk into memory if it's not already loaded.

        Args:
            chunk_id (int): ID of the chunk to load.

        Returns:
            np.ndarray: The loaded chunk data.
        """
        if chunk_id not in self.loaded_chunks:
            chunk_file = self._get_chunk_filename(chunk_id)
            self.loaded_chunks[chunk_id] = np.fromfile(chunk_file, dtype=np.float32)
        return self.loaded_chunks[chunk_id]

    def _unload_chunk(self, chunk_id: int):
        """Remove a chunk from memory."""
        if chunk_id in self.loaded_chunks:
            del self.loaded_chunks[chunk_id]

    def _update_window(self, needed_chunks: List[int]):
        """
        Update the window of loaded chunks based on what's needed.
        Unload unnecessary chunks and load new ones.

        Args:
            needed_chunks (List[int]): List of chunk IDs that are needed.
        """
        new_window = set(needed_chunks)
        for chunk_id in self.current_window:
            if chunk_id not in new_window:
                self._unload_chunk(chunk_id)
        for chunk_id in new_window:
            if chunk_id not in self.current_window:
                self._load_chunk(chunk_id)
        self.current_window = list(new_window)

    def __getitem__(self, idx: int) -> Union[np.ndarray, Tuple[np.ndarray, Dict[str, Any]]]:
        """
        Retrieve a tensor from storage by its index.

        Args:
            idx (int): Index of the tensor to retrieve.

        Returns:
            Union[np.ndarray, Tuple[np.ndarray, Dict[str, Any]]]:
                If return_metadata is False: just the tensor
                If return_metadata is True: tuple of (tensor, metadata_dict)

        Raises:
            IndexError: If the index is not found in storage.
        """
        if str(idx) not in self.metadata["elements"]:
            raise IndexError(f"Index {idx} not found in storage")

        item_meta = self.metadata["elements"][str(idx)]
        chunks_info = item_meta["chunks"]
        shape = item_meta["shape"]

        needed_chunks = [chunk_info[0] for chunk_info in chunks_info]
        self._update_window(needed_chunks)

        data = []
        for chunk_id, start_idx, end_idx in chunks_info:
            chunk_data = self._load_chunk(chunk_id)[start_idx:end_idx]
            data.append(chunk_data)

        tensor = np.concatenate(data).reshape(shape)

        if not self.return_metadata:
            return tensor

        # Get metadata if requested
        if self.metadata_df is not None and not self.metadata_df.empty:
            metadata = self.metadata_df[self.metadata_df["tensor_idx"] == idx].iloc[0].to_dict()
        else:
            metadata = {}

        return tensor, metadata

    def __len__(self):
        """Return the number of tensors in the storage."""
        return len(self.metadata["elements"])

    def __repr__(self) -> str:
        """Return string representation of the storage."""
        total_size = 0
        if os.path.exists(self.chunks_dir):
            for chunk_file in os.listdir(self.chunks_dir):
                total_size += os.path.getsize(os.path.join(self.chunks_dir, chunk_file))

        info = [
            f"TensorStorage at '{self.storage_dir}'",
            f"Description: {self.description}",
            f"Number of tensors: {len(self)}",
            f"Chunk size: {self.chunk_size / (1024 * 1024):.2f} MB",
            f"Total storage size: {total_size / (1024 * 1024):.2f} MB",
        ]

        # Add shape information if available
        if len(self) > 0:
            first_tensor_meta = self.metadata["elements"]["0"]
            info.append(f"Tensor shape: {first_tensor_meta['shape']}")

        return "\n".join(info)

    def load_metadata_table(self) -> Optional[pd.DataFrame]:
        """
        Load the metadata table from parquet file.

        Returns:
            Optional[pd.DataFrame]: DataFrame with metadata or None if file doesn't exist
        """
        if os.path.exists(self.parquet_file):
            return pd.read_parquet(self.parquet_file)
        return None

    def get_storage_info(self) -> Dict[str, Any]:
        """Get detailed information about the storage."""
        total_size = 0
        if os.path.exists(self.chunks_dir):
            for chunk_file in os.listdir(self.chunks_dir):
                total_size += os.path.getsize(os.path.join(self.chunks_dir, chunk_file))

        # Ensure chunk_size is never None for calculations
        chunk_size = self.chunk_size or (3 * 2**20 * np.dtype(np.float32).itemsize)

        info = {
            "storage_dir": self.storage_dir,
            "description": self.description,
            "num_tensors": len(self),
            "chunk_size_mb": chunk_size / (1024 * 1024),
            "total_size_mb": total_size / (1024 * 1024),
        }

        if len(self) > 0:
            first_tensor_meta = self.metadata["elements"]["0"]
            info["tensor_shape"] = first_tensor_meta["shape"]

        return info

    def get_tensor_by_param(self, param_name: str, param_value: Any) -> Optional[np.ndarray]:
        """
        Retrieve a tensor by querying a parameter in the metadata.

        Args:
            param_name (str): Name of the parameter to query
            param_value (Any): Value to search for

        Returns:
            Optional[np.ndarray]: The tensor if found, None otherwise
        """
        # Check if the parameter exists in the metadata DataFrame
        if self.metadata_df is None or self.metadata_df.empty:
            return None

        if param_name not in self.metadata_df.columns:
            return None  # Parameter doesn't exist

        matches = self.metadata_df[self.metadata_df[param_name] == param_value]
        if len(matches) == 0:
            return None

        tensor_idx = matches.iloc[0]["tensor_idx"]
        return self[tensor_idx]

    def get_params_for_tensor(self, tensor_idx: int) -> Dict[str, Any]:
        """
        Get all parameters associated with a tensor.

        Args:
            tensor_idx (int): Index of the tensor

        Returns:
            Dict[str, Any]: Dictionary of parameters
        """
        if self.metadata_df is None or self.metadata_df.empty:
            return {}

        matches = self.metadata_df[self.metadata_df["tensor_idx"] == tensor_idx]
        if len(matches) == 0:
            return {}

        return matches.iloc[0].to_dict()

    def get_tensors_by_batch(self, batch_id: int) -> List[np.ndarray]:
        """
        Retrieve all tensors associated with a specific batch ID.

        Args:
            batch_id (int): The batch ID to query

        Returns:
            List[np.ndarray]: List of tensors in the batch
        """
        if self.metadata_df is None or self.metadata_df.empty:
            return []

        if "batch_id" not in self.metadata_df.columns:
            return []

        matches = self.metadata_df[self.metadata_df["batch_id"] == batch_id]
        return [self[idx] for idx in matches["tensor_idx"]]

    def filter_tensors(self, **kwargs) -> List[int]:
        """
        Filter tensors based on multiple metadata parameters.

        Args:
            **kwargs: Key-value pairs of metadata parameters to filter by

        Returns:
            List[int]: List of tensor indices matching all criteria
        """
        if self.metadata_df is None or self.metadata_df.empty:
            return []

        filtered_df = self.metadata_df
        for key, value in kwargs.items():
            if key not in filtered_df.columns:
                # If any key doesn't exist, no matches possible
                return []
            filtered_df = filtered_df[filtered_df[key] == value]
        return filtered_df["tensor_idx"].tolist()

    @staticmethod
    def create_storage(
        storage_dir: str,
        data_iterator: Iterator[np.ndarray],
        metadata_iterator: Iterator[Dict[str, Any]],
        chunk_size: Optional[int] = None,
        description: str = "",
    ) -> "TensorStorage":
        """
        Create a new TensorStorage from iterators of numpy arrays and metadata.

        Args:
            storage_dir (str): Directory where the storage will be created.
            data_iterator (Iterator[np.ndarray]): Iterator yielding numpy arrays to store.
            metadata_iterator (Iterator[Dict[str, Any]]): Iterator yielding metadata dicts.
            chunk_size (Optional[int]): Size of each chunk in bytes. If None, the storage default is used.
            description (str): Optional description of the storage.

        Returns:
            TensorStorage: The created storage instance.
        """
        os.makedirs(os.path.join(storage_dir, "chunks"), exist_ok=True)
        os.makedirs(os.path.join(storage_dir, "metadata"), exist_ok=True)

        storage = TensorStorage(storage_dir, description, chunk_size)

        # Ensure chunk_size is set to actual value, not None
        if chunk_size is None:
            chunk_size = storage.chunk_size

        logging.info(f"Creating storage in directory: {storage_dir}")
        logging.info(f"Chunk size: {storage.chunk_size / (1024 * 1024):.2f} MB")

        current_chunk = []
        current_chunk_size = 0
        chunk_id = 0
        elements_metadata = {}
        metadata_records = []
        total_elements = 0

        progress_bar = tqdm(desc="Processing arrays", unit="array")

        for idx, (arr, metadata_dict) in enumerate(zip(data_iterator, metadata_iterator)):
            total_elements += 1
            progress_bar.update(1)

            flat_arr = arr.flatten()
            arr_size = flat_arr.nbytes

            if arr_size > storage.chunk_size:
                raise ValueError(f"Array at index {idx} is larger than the maximum chunk size")

            # Handle chunk storage
            if current_chunk_size + arr_size > storage.chunk_size:
                chunk_filename = storage._get_chunk_filename(chunk_id)
                np.concatenate(current_chunk).astype(np.float32).tofile(chunk_filename)
                chunk_id += 1
                current_chunk = []
                current_chunk_size = 0

            start_idx = current_chunk_size // np.dtype(np.float32).itemsize
            end_idx = start_idx + flat_arr.size
            current_chunk.append(flat_arr)
            current_chunk_size += arr_size

            # Store tensor metadata
            elements_metadata[str(idx)] = {
                "shape": arr.shape,
                "chunks": [(chunk_id, start_idx, end_idx)],
            }

            # Store additional metadata
            metadata_dict["tensor_idx"] = idx
            metadata_records.append(metadata_dict)

        # Save the last chunk if there's any data left
        if current_chunk:
            chunk_filename = storage._get_chunk_filename(chunk_id)
            np.concatenate(current_chunk).astype(np.float32).tofile(chunk_filename)

        progress_bar.close()

        # Save the tensor metadata
        storage.metadata = {
            "chunk_size": chunk_size,
            "total_elements": total_elements,
            "total_chunks": chunk_id + 1,
            "elements": elements_metadata,
            "description": description,
        }
        storage._save_metadata()

        # Save the parquet metadata
        metadata_df = pd.DataFrame(metadata_records)
        metadata_df.to_parquet(storage.parquet_file, index=False)
        storage.metadata_df = metadata_df

        logging.info(f"Storage creation complete. Total elements: {total_elements}")
        logging.info(f"Total chunks created: {chunk_id + 1}")
        return storage

    def close(self):
        """Clean up resources and ensure all metadata is saved."""
        self._save_metadata()
        self.loaded_chunks.clear()
        self.current_window.clear()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()

    def rebuild_storage(
        self,
        new_storage_dir: Optional[str] = None,
        new_chunk_size: Optional[int] = None,
        description: Optional[str] = None,
        inplace: bool = False,
    ) -> "TensorStorage":
        """
        Rebuild the storage with new chunk size.

        Args:
            new_storage_dir (Optional[str]): Directory for the new storage.
                                           If None and inplace=True, will rebuild in place.
            new_chunk_size (Optional[int]): New chunk size in bytes. If None, uses current chunk size
            description (Optional[str]): New description. If None, uses current description
            inplace (bool): If True, will rebuild the storage in place, replacing current storage

        Returns:
            TensorStorage: New storage instance with updated chunk size

        Raises:
            ValueError: If inplace is True and new_storage_dir is provided
        """
        if inplace and new_storage_dir is not None:
            raise ValueError("Cannot specify new_storage_dir when inplace=True")

        if inplace:
            # Create temporary directory for rebuilding
            temp_dir = os.path.join(
                os.path.dirname(self.storage_dir),
                f"{os.path.basename(self.storage_dir)}_temp",
            )
        else:
            if new_storage_dir is None:
                raise ValueError("Must specify new_storage_dir when inplace=False")
            temp_dir = new_storage_dir

        # Use current values if new ones not provided
        new_chunk_size = new_chunk_size or self.chunk_size
        new_description = description or self.description

        try:
            # Create iterators for current data
            def tensor_iterator() -> Iterator[np.ndarray]:
                for i in range(len(self)):
                    item = self[i]
                    # __getitem__ returns (tensor, metadata) when return_metadata=True
                    yield item[0] if self.return_metadata else item

            def metadata_iterator() -> Iterator[Dict[str, Any]]:
                for i in range(len(self)):
                    yield self.get_params_for_tensor(i)

            # Create new storage with different chunk size
            new_storage = TensorStorage.create_storage(
                storage_dir=temp_dir,
                data_iterator=tensor_iterator(),
                metadata_iterator=metadata_iterator(),
                chunk_size=new_chunk_size,
                description=new_description,
            )

            if inplace:
                # Verify the new storage before replacing
                self._verify_rebuilt_storage(new_storage)

                # Close both storages to ensure all files are written
                self.close()
                new_storage.close()

                # Replace old storage with new one
                backup_dir = f"{self.storage_dir}_backup"
                os.rename(self.storage_dir, backup_dir)
                try:
                    os.rename(temp_dir, self.storage_dir)
                except Exception as e:
                    # If something goes wrong, restore from backup
                    os.rename(backup_dir, self.storage_dir)
                    raise e

                # Remove backup after successful replacement
                shutil.rmtree(backup_dir)

                # Reinitialize self with new storage
                self.__init__(
                    storage_dir=self.storage_dir,
                    description=new_description,
                    chunk_size=new_chunk_size,
                    return_metadata=self.return_metadata,
                )
                return self
            else:
                return new_storage

        except Exception as e:
            # Clean up temporary directory if something goes wrong
            if inplace and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            raise e

    def _verify_rebuilt_storage(self, new_storage: "TensorStorage") -> bool:
        """
        Verify that the rebuilt storage contains the same data.

        Args:
            new_storage: The newly built storage to verify

        Returns:
            bool: True if verification passes

        Raises:
            ValueError: If verification fails
        """
        # Verify basic properties
        if len(self) != len(new_storage):
            raise ValueError("New storage has different number of elements")

        # Verify sample of tensors
        sample_size = min(100, len(self))  # Check up to 100 random tensors
        indices = np.random.choice(len(self), sample_size, replace=False)

        for idx in indices:
            original_tensor = self[idx]
            if self.return_metadata:
                # __getitem__ returns (tensor, metadata) in this mode
                original_tensor = original_tensor[0]
            new_tensor = new_storage[idx]
            if not np.allclose(original_tensor, new_tensor):
                raise ValueError(f"Data mismatch at index {idx}")

            original_meta = self.get_params_for_tensor(idx)
            new_meta = new_storage.get_params_for_tensor(idx)
            if original_meta != new_meta:
                raise ValueError(f"Metadata mismatch at index {idx}")

        return True
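

if __name__ == "__main__":
    # Illustrative usage sketch: build a small storage from random arrays plus
    # per-tensor metadata, then read tensors back and query by metadata.
    # Writing the parquet file requires a parquet engine (e.g. pyarrow).
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        arrays = [np.random.rand(4, 8).astype(np.float32) for _ in range(10)]
        metadata = [{"batch_id": i // 5, "source": "demo"} for i in range(10)]

        storage = TensorStorage.create_storage(
            storage_dir=os.path.join(tmp_dir, "demo_storage"),
            data_iterator=iter(arrays),
            metadata_iterator=iter(metadata),
            description="toy storage",
        )

        print(storage)                              # storage summary
        print(storage[0].shape)                     # -> (4, 8)
        print(storage.filter_tensors(batch_id=1))   # tensor indices in batch 1
        print(storage.get_params_for_tensor(3))     # metadata row for tensor 3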


File: ssrlib/storage/__init__.py
from .tensor_storage import TensorStorage

__all__ = ["TensorStorage"]


File: ssrlib/embedders/__init__.py
"""Embedder implementations for ssrlib with automatic discovery."""

import logging
from pathlib import Path
from typing import Dict, List, Optional, Type, Any
import warnings

logger = logging.getLogger(__name__)

# Import base class and registry system
from .base import BaseEmbedder
from ..core.registry import BaseRegistry, discover_components

# Type alias
EmbedderRegistry = BaseRegistry[BaseEmbedder]


def discover_embedder_classes() -> EmbedderRegistry:
    """Discover all embedder classes in the embedders module."""
    registry = EmbedderRegistry("embedder").enable_modalities()

    return discover_components(
        package_path=Path(__file__).parent,
        package_name=__name__,
        base_class=BaseEmbedder,
        registry=registry,
    )


# Perform discovery at import time
logger.debug("Starting embedder discovery...")
_embedder_registry = discover_embedder_classes()


# Convenience functions
def get_available_embedders() -> Dict[str, Type[BaseEmbedder]]:
    """Get dictionary of all available embedders."""
    return _embedder_registry._items.copy()


def get_embedder_descriptions() -> Dict[str, str]:
    """Get dictionary of embedder descriptions."""
    return _embedder_registry._descriptions.copy()


def list_embedders(category: Optional[str] = None, modality: Optional[str] = None) -> List[str]:
    """List available embedders with optional filtering."""
    if category:
        return _embedder_registry.list_by_category(category).get(category, [])
    elif modality:
        return _embedder_registry.list_by_modality(modality).get(modality, [])
    return _embedder_registry.list_all()


def get_embedder_info(name: str) -> Dict[str, Any]:
    """Get detailed information about an embedder."""
    return _embedder_registry.get_info(name)


def print_available_embedders() -> None:
    """Print all available embedders with descriptions."""
    _embedder_registry.print_registry()


def create_embedder(name: str, **kwargs) -> BaseEmbedder:
    """Create an embedder by name."""
    embedder_class = _embedder_registry.get(name)
    if embedder_class is None:
        available = ", ".join(_embedder_registry.list_all())
        raise ValueError(f"Unknown embedder '{name}'. Available: {available}")
    return embedder_class(**kwargs)


def get_vision_embedders() -> List[str]:
    """Get list of vision embedders."""
    return list_embedders(modality="vision")


def get_text_embedders() -> List[str]:
    """Get list of text embedders."""
    return list_embedders(modality="text")


def get_audio_embedders() -> List[str]:
    """Get list of audio embedders."""
    return list_embedders(modality="audio")


def get_multimodal_embedders() -> List[str]:
    """Get list of multimodal embedders."""
    return list_embedders(modality="multimodal")


def get_embedders_by_category(category: str) -> List[str]:
    """Get embedders by category."""
    return list_embedders(category=category)


def get_embedder_categories() -> List[str]:
    """Get list of all available categories."""
    return list(_embedder_registry._categories.keys())


def get_embedder_modalities() -> List[str]:
    """Get list of all available modalities."""
    if _embedder_registry._modalities:
        return list(set(_embedder_registry._modalities.values()))
    return []


# Create dynamic exports
_exported_classes = {}
for name, embedder_class in _embedder_registry._items.items():
    _exported_classes[name] = embedder_class

# Update module globals
globals().update(_exported_classes)

# Create __all__ dynamically
__all__ = [
    "BaseEmbedder",
    "get_available_embedders",
    "get_embedder_descriptions",
    "list_embedders",
    "get_embedder_info",
    "print_available_embedders",
    "create_embedder",
    "get_vision_embedders",
    "get_text_embedders",
    "get_audio_embedders",
    "get_multimodal_embedders",
    "get_embedders_by_category",
    "get_embedder_categories",
    "get_embedder_modalities",
    *_embedder_registry.list_all(),
]

# Log results
if logger.isEnabledFor(logging.INFO):
    logger.info(
        f"Embedder discovery complete: {len(_embedder_registry.list_all())} embedders found"
    )
    for category, embedders in _embedder_registry.list_by_category().items():
        logger.info(f"  {category}: {', '.join(embedders)}")

# Warn about errors
if _embedder_registry._discovery_errors:
    warnings.warn(
        f"Some embedder modules failed to import: {len(_embedder_registry._discovery_errors)} errors. "
        f"Run logging.getLogger('{__name__}').setLevel(logging.DEBUG) for details.",
        ImportWarning,
    )
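

# --- Usage sketch (illustrative) ------------------------------------------------
# Typical interaction with the discovery registry. The concrete embedder names
# depend on which modules were discovered at import time.
#
#     from ssrlib import embedders
#
#     embedders.print_available_embedders()
#     names = embedders.list_embedders(modality="text")
#     model = embedders.create_embedder(names[0], device="cpu")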


File: ssrlib/embedders/base.py
"""Base embedder implementation for ssrlib."""

import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, ClassVar
import numpy as np
import logging

logger = logging.getLogger(__name__)


class BaseEmbedder(ABC):
    """Base class for all embedders in ssrlib with self-describing metadata."""

    # Class-level metadata - subclasses should override these
    _embedder_category: ClassVar[str] = "general"
    _embedder_modality: ClassVar[str] = "unknown"
    _embedder_properties: ClassVar[Dict[str, Any]] = {}

    def __init__(self, name: str, device: str = "cpu", batch_size: int = 32, **kwargs):
        """Initialize embedder.

        Args:
            name: Model name
            device: Device to use ('cpu' or 'cuda')
            batch_size: Default batch size for processing
            **kwargs: Additional configuration
        """
        self.name = name
        self.device = torch.device(device)
        self.batch_size = batch_size
        self.model = None
        self._loaded = False
        self._metadata = {
            "category": self.get_embedder_category(),
            "modality": self.get_embedder_modality(),
        }
        self._metadata.update(kwargs)

    @abstractmethod
    def load_model(self) -> None:
        """Load the pretrained model."""
        pass

    @abstractmethod
    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Extract embeddings from a batch.

        Args:
            batch: Input tensor

        Returns:
            Embeddings tensor
        """
        pass

    @abstractmethod
    def get_embedding_dim(self) -> int:
        """Get the dimension of embeddings produced by this model.

        Returns:
            Embedding dimension
        """
        pass

    def embed_dataset(self, dataset, batch_size: Optional[int] = None) -> np.ndarray:
        """Extract embeddings for entire dataset with batching.

        Args:
            dataset: Dataset to embed (iterable yielding tensors)
            batch_size: Batch size for processing (uses default if None)

        Returns:
            Embeddings array of shape (n_samples, embedding_dim)
        """
        if not self._loaded:
            self.load_model()

        batch_size = batch_size or self.batch_size
        embeddings = []
        current_batch = []

        logger.info(f"Extracting embeddings with {self.name} (batch_size={batch_size})")

        for i, sample in enumerate(dataset):
            current_batch.append(sample)

            if len(current_batch) == batch_size:
                batch_tensor = torch.stack(current_batch).to(self.device)
                batch_embeddings = self.forward(batch_tensor)
                embeddings.append(batch_embeddings.cpu().numpy())
                current_batch = []

                if (i + 1) % (batch_size * 10) == 0:
                    logger.info(f"  Processed {i + 1} samples")

        # Process remaining samples
        if current_batch:
            batch_tensor = torch.stack(current_batch).to(self.device)
            batch_embeddings = self.forward(batch_tensor)
            embeddings.append(batch_embeddings.cpu().numpy())

        result = np.concatenate(embeddings, axis=0)
        logger.info(f"  Extracted embeddings shape: {result.shape}")
        return result

    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None

        self._loaded = False

        # Clear GPU cache if using CUDA
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        logger.info(f"Unloaded model: {self.name}")

    def get_metadata(self) -> Dict[str, Any]:
        """Get embedder metadata.

        Returns:
            Dictionary with metadata
        """
        return {
            "name": self.name,
            "device": str(self.device),
            "loaded": self._loaded,
            "embedding_dim": self.get_embedding_dim(),
            **self._metadata,
        }

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model.

        Returns:
            Dictionary with model information
        """
        info = {
            "name": self.name,
            "category": self.get_embedder_category(),
            "modality": self.get_embedder_modality(),
            "device": str(self.device),
            "is_loaded": self._loaded,
            "embedding_dim": self.get_embedding_dim(),
        }

        if self._loaded and self.model is not None:
            # Count parameters
            total_params = sum(p.numel() for p in self.model.parameters())
            info["total_parameters"] = total_params
            info["trainable_parameters"] = sum(
                p.numel() for p in self.model.parameters() if p.requires_grad
            )

        return info

    @classmethod
    def get_embedder_category(cls) -> str:
        """Get embedder category."""
        return cls._embedder_category

    @classmethod
    def get_embedder_modality(cls) -> str:
        """Get embedder modality."""
        return cls._embedder_modality

    @classmethod
    def get_embedder_properties(cls) -> Dict[str, Any]:
        """Get embedder properties."""
        return cls._embedder_properties.copy()

    def __del__(self):
        """Cleanup on deletion."""
        try:
            self.unload_model()
        except Exception:
            pass

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"{self.__class__.__name__}(name='{self.name}', "
            f"device='{self.device}', loaded={self._loaded})"
        )
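

if __name__ == "__main__":
    # Illustrative only: a minimal embedder satisfying the abstract interface
    # above (load_model / forward / get_embedding_dim). The projection is a
    # random linear layer, so the embeddings carry no semantic meaning; the
    # point is to show how embed_dataset() batches an iterable of tensors.
    class RandomProjectionEmbedder(BaseEmbedder):
        _embedder_category = "example"
        _embedder_modality = "vision"

        def __init__(self, input_dim: int = 3 * 32 * 32, embedding_dim: int = 16, **kwargs):
            super().__init__(name="random-projection", **kwargs)
            self.input_dim = input_dim
            self.embedding_dim = embedding_dim

        def load_model(self) -> None:
            self.model = nn.Linear(self.input_dim, self.embedding_dim).to(self.device)
            self.model.eval()
            self._loaded = True

        def forward(self, batch: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                return self.model(batch.reshape(batch.shape[0], -1))

        def get_embedding_dim(self) -> int:
            return self.embedding_dim

    samples = [torch.randn(3, 32, 32) for _ in range(10)]
    embedder = RandomProjectionEmbedder(batch_size=4)
    embeddings = embedder.embed_dataset(samples)
    print(embeddings.shape)  # -> (10, 16)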


File: ssrlib/embedders/nlp/e5.py
"""E5 Multilingual embedder implementation."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from typing import Dict, Any, ClassVar, Optional

from ..base import BaseEmbedder


class E5Embedder(BaseEmbedder):
    """E5 Multilingual embedder for natural language processing."""

    # Class-level metadata
    _embedder_category: ClassVar[str] = "nlp"
    _embedder_modality: ClassVar[str] = "text"
    _embedder_properties: ClassVar[Dict[str, Any]] = {
        "model_family": "E5",
        "source": "Microsoft",
        "multilingual": True,
        "instruction_following": True,
        "contrastive_learning": True,
        "max_sequence_length": 512,
        "pooling_strategy": "mean",
        "normalize_embeddings": True,
        "supports_94_languages": True,
    }

    AVAILABLE_MODELS = {
        "e5-small": {
            "embedding_dim": 384,
            "hf_name": "intfloat/e5-small-v2",
            "multilingual": False,
        },
        "e5-base": {
            "embedding_dim": 768,
            "hf_name": "intfloat/e5-base-v2",
            "multilingual": False,
        },
        "e5-large": {
            "embedding_dim": 1024,
            "hf_name": "intfloat/e5-large-v2",
            "multilingual": False,
        },
        "multilingual-e5-small": {
            "embedding_dim": 384,
            "hf_name": "intfloat/multilingual-e5-small",
            "multilingual": True,
        },
        "multilingual-e5-base": {
            "embedding_dim": 768,
            "hf_name": "intfloat/multilingual-e5-base",
            "multilingual": True,
        },
        "multilingual-e5-large": {
            "embedding_dim": 1024,
            "hf_name": "intfloat/multilingual-e5-large",
            "multilingual": True,
        },
        "multilingual-e5-large-instruct": {
            "embedding_dim": 1024,
            "hf_name": "intfloat/multilingual-e5-large-instruct",
            "multilingual": True,
            "instruction_following": True,
        },
    }

    def __init__(
        self,
        model_name: str = "multilingual-e5-base",
        device: str = "cpu",
        normalize: bool = True,
        **kwargs,
    ):
        """Initialize E5 embedder.

        Args:
            model_name: Name of the E5 model to use
            device: Device to run on ('cpu' or 'cuda')
            normalize: Whether to normalize embeddings
            **kwargs: Additional arguments
        """
        super().__init__(f"E5_{model_name}", device, **kwargs)

        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(
                f"Unknown model {model_name}. " f"Available: {list(self.AVAILABLE_MODELS.keys())}"
            )

        self.model_name = model_name
        self.hf_name = self.AVAILABLE_MODELS[model_name]["hf_name"]
        self.embedding_dim = self.AVAILABLE_MODELS[model_name]["embedding_dim"]
        self.normalize = normalize
        self.tokenizer = None
        self.is_instruct = self.AVAILABLE_MODELS[model_name].get("instruction_following", False)

        # Update metadata
        self._metadata.update(
            {
                "model_name": model_name,
                "hf_name": self.hf_name,
                "embedding_dim": self.embedding_dim,
                "normalize": normalize,
                "model_family": "E5",
                "multilingual": self.AVAILABLE_MODELS[model_name]["multilingual"],
                "instruction_following": self.is_instruct,
            }
        )

    def get_embedding_dim(self) -> int:
        """Get embedding dimension."""
        return self.embedding_dim

    def load_model(self) -> None:
        """Load E5 model from Hugging Face."""
        if self._loaded:
            return

        print(f"Loading E5 model: {self.hf_name}")
        try:
            self.model = AutoModel.from_pretrained(self.hf_name)
            self.tokenizer = AutoTokenizer.from_pretrained(self.hf_name)
            self.model = self.model.to(self.device)
            self.model.eval()
            self._loaded = True
            print(f"Successfully loaded {self.model_name}")
        except Exception as e:
            raise RuntimeError(f"Failed to load {self.model_name}: {str(e)}")

    def _average_pool(
        self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Perform average pooling on token embeddings.

        This is the official pooling method for E5 models.

        Args:
            last_hidden_states: Token embeddings from model
            attention_mask: Attention mask

        Returns:
            Pooled embeddings
        """
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Forward pass through E5 model.

        Note: This expects pre-tokenized input_ids as tensors.
        For text input, use embed_texts() method instead.

        Args:
            batch: Input batch of token IDs of shape (batch_size, seq_len)

        Returns:
            Embeddings of shape (batch_size, embedding_dim)
        """
        if not self._loaded:
            self.load_model()

        self.model.eval()
        with torch.no_grad():
            # Build the attention mask from the tokenizer's padding id
            # (not always 0; e.g. the XLM-R based multilingual models use 1)
            pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
            attention_mask = (batch != pad_id).long()

            outputs = self.model(input_ids=batch, attention_mask=attention_mask)

            # Use average pooling (E5's official method)
            embeddings = self._average_pool(outputs.last_hidden_state, attention_mask)

            # Normalize if requested
            if self.normalize:
                embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings

    def embed_texts(
        self, texts: list, max_length: int = 512, task_instruction: Optional[str] = None
    ) -> torch.Tensor:
        """Embed a list of texts.

        Args:
            texts: List of text strings
            max_length: Maximum sequence length
            task_instruction: Optional task instruction (for instruct models)
                Example: "Given a web search query, retrieve relevant passages"

        Returns:
            Embeddings tensor of shape (len(texts), embedding_dim)
        """
        if not self._loaded:
            self.load_model()

        # Add instruction prefix if using instruct model
        if self.is_instruct and task_instruction:
            texts = [f"Instruct: {task_instruction}\nQuery: {text}" for text in texts]

        # Tokenize texts
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

            # Use average pooling (E5's official method)
            embeddings = self._average_pool(outputs.last_hidden_state, attention_mask)

            # Normalize if requested
            if self.normalize:
                embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings

    def embed_queries_and_documents(
        self, queries: list, documents: list, task_instruction: Optional[str] = None
    ) -> tuple:
        """Embed queries and documents separately (for retrieval tasks).

        For instruct models, queries get the instruction prefix but documents don't.

        Args:
            queries: List of query strings
            documents: List of document strings
            task_instruction: Task instruction for queries (instruct models only)

        Returns:
            Tuple of (query_embeddings, document_embeddings)
        """
        # Embed queries with instruction
        query_embeddings = self.embed_texts(queries, task_instruction=task_instruction)

        # Embed documents without instruction
        document_embeddings = self.embed_texts(documents, task_instruction=None)

        return query_embeddings, document_embeddings
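

# --- Usage sketch (illustrative, not part of the library API) ---------------
# A minimal retrieval example. Assumptions: the constructor follows the same
# (model_name, device) pattern as the other embedders in this package, and an
# instruct-style entry such as "multilingual-e5-large-instruct" is registered
# in AVAILABLE_MODELS above. Note that vanilla (non-instruct) E5 checkpoints
# expect "query: " / "passage: " prefixes per their model cards, which this
# class does not add automatically.
if __name__ == "__main__":
    embedder = E5Embedder(model_name="multilingual-e5-large-instruct", device="cpu")
    query_emb, doc_emb = embedder.embed_queries_and_documents(
        queries=["how does self-supervised learning work?"],
        documents=["Self-supervised learning trains models on unlabeled data."],
        task_instruction="Given a web search query, retrieve relevant passages",
    )
    # With normalization enabled (self.normalize), the dot product is cosine similarity.
    print((query_emb @ doc_emb.T).shape)  # (num_queries, num_documents)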


File: ssrlib/embedders/nlp/modernbert.py
"""ModernBERT embedder implementation."""

from typing import Dict, Any, ClassVar

from .bert_base import TransformerEmbedderBase


class ModernBERTEmbedder(TransformerEmbedderBase):
    """ModernBERT embedder for natural language processing."""

    # Class-level metadata
    _embedder_category: ClassVar[str] = "nlp"
    _embedder_modality: ClassVar[str] = "text"
    _embedder_properties: ClassVar[Dict[str, Any]] = {
        "model_family": "ModernBERT",
        "source": "Answer.AI",
        "modernized_architecture": True,
        "rotary_embeddings": True,
        "max_sequence_length": 8192,
        "pooling_strategy": "cls",
    }

    MODEL_FAMILY_NAME: ClassVar[str] = "ModernBERT"
    DEFAULT_MODEL: ClassVar[str] = "modernbert-base"

    AVAILABLE_MODELS = {
        "modernbert-base": {
            "embedding_dim": 768,
            "hf_name": "answerdotai/ModernBERT-base",
            "num_layers": 22,
        },
        "modernbert-large": {
            "embedding_dim": 1024,
            "hf_name": "answerdotai/ModernBERT-large",
            "num_layers": 28,
        },
    }

    def __init__(
        self,
        model_name: str = "modernbert-base",
        pooling: str = "cls",
        device: str = "cpu",
        **kwargs,
    ):
        """Initialize ModernBERT embedder.

        Args:
            model_name: Name of the ModernBERT model to use
            pooling: Pooling strategy ('cls' or 'mean')
            device: Device to run on ('cpu' or 'cuda')
            **kwargs: Additional arguments
        """
        super().__init__(model_name, pooling, device, **kwargs)

    def _get_model_metadata(self) -> Dict[str, Any]:
        """Get ModernBERT-specific metadata."""
        return {
            "model_name": self.model_name,
            "hf_name": self.hf_name,
            "embedding_dim": self.embedding_dim,
            "pooling": self.pooling,
            "model_family": "ModernBERT",
            "num_layers": self.AVAILABLE_MODELS[self.model_name]["num_layers"],
        }

    def _get_default_max_length(self) -> int:
        """ModernBERT default max length is 512 (same as BERT)."""
        return 512

    def _clamp_max_length(self, max_length: int) -> int:
        """ModernBERT supports up to 8192 tokens."""
        return min(max_length, 8192)
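

# --- Usage sketch (illustrative only) ----------------------------------------
# embed_texts() defaults to a 512-token limit (_get_default_max_length above);
# pass max_length explicitly to use ModernBERT's 8192-token context. Weights
# are downloaded lazily from Hugging Face on first use.
if __name__ == "__main__":
    embedder = ModernBERTEmbedder(model_name="modernbert-base", pooling="cls")
    long_document = "ModernBERT handles long inputs with rotary embeddings. " * 400
    embeddings = embedder.embed_texts([long_document], max_length=8192)
    print(embeddings.shape)  # (1, 768)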


File: ssrlib/embedders/nlp/__init__.py
"""Natural language processing embedders."""

from .bert import BERTEmbedder
from .modernbert import ModernBERTEmbedder
from .e5 import E5Embedder

__all__ = ["BERTEmbedder", "ModernBERTEmbedder", "E5Embedder"]


File: ssrlib/embedders/nlp/bert.py
"""BERT embedder implementation."""

from typing import Dict, Any, ClassVar

from .bert_base import TransformerEmbedderBase


class BERTEmbedder(TransformerEmbedderBase):
    """BERT embedder for natural language processing."""

    # Class-level metadata
    _embedder_category: ClassVar[str] = "nlp"
    _embedder_modality: ClassVar[str] = "text"
    _embedder_properties: ClassVar[Dict[str, Any]] = {
        "model_family": "BERT",
        "source": "Google",
        "masked_language_model": True,
        "bidirectional": True,
        "architecture": "Transformer",
        "max_sequence_length": 512,
        "pooling_strategy": "cls",
    }

    MODEL_FAMILY_NAME: ClassVar[str] = "BERT"
    DEFAULT_MODEL: ClassVar[str] = "bert-base-uncased"

    AVAILABLE_MODELS = {
        "bert-base-uncased": {
            "embedding_dim": 768,
            "hf_name": "bert-base-uncased",
            "languages": ["en"],
        },
        "bert-base-cased": {
            "embedding_dim": 768,
            "hf_name": "bert-base-cased",
            "languages": ["en"],
        },
        "bert-large-uncased": {
            "embedding_dim": 1024,
            "hf_name": "bert-large-uncased",
            "languages": ["en"],
        },
        "bert-large-cased": {
            "embedding_dim": 1024,
            "hf_name": "bert-large-cased",
            "languages": ["en"],
        },
        "bert-base-multilingual-cased": {
            "embedding_dim": 768,
            "hf_name": "bert-base-multilingual-cased",
            "languages": ["multilingual"],
        },
    }

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        pooling: str = "cls",
        device: str = "cpu",
        **kwargs,
    ):
        """Initialize BERT embedder.

        Args:
            model_name: Name of the BERT model to use
            pooling: Pooling strategy ('cls' or 'mean')
            device: Device to run on ('cpu' or 'cuda')
            **kwargs: Additional arguments
        """
        super().__init__(model_name, pooling, device, **kwargs)

    def _get_model_metadata(self) -> Dict[str, Any]:
        """Get BERT-specific metadata."""
        return {
            "model_name": self.model_name,
            "hf_name": self.hf_name,
            "embedding_dim": self.embedding_dim,
            "pooling": self.pooling,
            "model_family": "BERT",
            "languages": self.AVAILABLE_MODELS[self.model_name]["languages"],
        }
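

# --- Usage sketch (illustrative only) ----------------------------------------
# The pooling argument selects between the [CLS] token embedding and masked
# mean pooling over all tokens (see TransformerEmbedderBase._apply_pooling).
if __name__ == "__main__":
    texts = ["BERT produces contextual token embeddings."]
    cls_embeddings = BERTEmbedder(model_name="bert-base-uncased", pooling="cls").embed_texts(texts)
    mean_embeddings = BERTEmbedder(model_name="bert-base-uncased", pooling="mean").embed_texts(texts)
    print(cls_embeddings.shape, mean_embeddings.shape)  # both (1, 768)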


File: ssrlib/embedders/nlp/bert_base.py
"""Base class for BERT-style transformer embedders."""

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from typing import Dict, Any, ClassVar, Optional
from abc import ABC, abstractmethod

from ..base import BaseEmbedder


class TransformerEmbedderBase(BaseEmbedder, ABC):
    """Base class for BERT-style transformer embedders (BERT, ModernBERT, etc.)."""

    # Subclasses must define these
    AVAILABLE_MODELS: ClassVar[Dict[str, Dict[str, Any]]] = {}
    DEFAULT_MODEL: ClassVar[str] = ""
    MODEL_FAMILY_NAME: ClassVar[str] = ""

    def __init__(
        self,
        model_name: str,
        pooling: str = "cls",
        device: str = "cpu",
        **kwargs,
    ):
        """Initialize transformer embedder.

        Args:
            model_name: Name of the model to use
            pooling: Pooling strategy ('cls' or 'mean')
            device: Device to run on ('cpu' or 'cuda')
            **kwargs: Additional arguments
        """
        # Validate model name
        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(
                f"Unknown model {model_name}. " f"Available: {list(self.AVAILABLE_MODELS.keys())}"
            )

        # Validate pooling
        if pooling not in ["cls", "mean"]:
            raise ValueError(f"Pooling must be 'cls' or 'mean', got {pooling}")

        # Initialize base with family-prefixed name
        super().__init__(f"{self.MODEL_FAMILY_NAME}_{model_name}", device, **kwargs)

        self.model_name = model_name
        self.hf_name = self.AVAILABLE_MODELS[model_name]["hf_name"]
        self.embedding_dim = self.AVAILABLE_MODELS[model_name]["embedding_dim"]
        self.pooling = pooling
        self.tokenizer = None

        # Update metadata with model-specific info
        self._metadata.update(self._get_model_metadata())

    @abstractmethod
    def _get_model_metadata(self) -> Dict[str, Any]:
        """Get model-specific metadata to update.

        Subclasses should override this to add model-specific metadata fields.

        Returns:
            Dictionary with metadata to update
        """
        pass

    def get_embedding_dim(self) -> int:
        """Get embedding dimension."""
        return self.embedding_dim

    def load_model(self) -> None:
        """Load model from Hugging Face."""
        if self._loaded:
            return

        print(f"Loading {self.MODEL_FAMILY_NAME} model: {self.hf_name}")
        try:
            self.model = AutoModel.from_pretrained(self.hf_name)
            self.tokenizer = AutoTokenizer.from_pretrained(self.hf_name)
            self.model = self.model.to(self.device)
            self.model.eval()
            self._loaded = True
            print(f"Successfully loaded {self.model_name}")
        except Exception as e:
            raise RuntimeError(f"Failed to load {self.model_name}: {str(e)}")

    def _mean_pooling(
        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Perform mean pooling on token embeddings.

        Args:
            token_embeddings: Token embeddings from model
            attention_mask: Attention mask

        Returns:
            Pooled embeddings
        """
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _apply_pooling(
        self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Apply the configured pooling strategy.

        Args:
            last_hidden_state: Model output hidden states
            attention_mask: Attention mask

        Returns:
            Pooled embeddings
        """
        if self.pooling == "cls":
            # Use [CLS] token embedding (first token)
            return last_hidden_state[:, 0, :]
        elif self.pooling == "mean":
            # Use mean pooling over all tokens
            return self._mean_pooling(last_hidden_state, attention_mask)
        else:
            raise ValueError(f"Unknown pooling strategy: {self.pooling}")

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Forward pass through model.

        Note: This expects pre-tokenized input_ids as tensors.
        For text input, use embed_texts() method instead.

        Args:
            batch: Input batch of token IDs of shape (batch_size, seq_len)

        Returns:
            Embeddings of shape (batch_size, embedding_dim)
        """
        if not self._loaded:
            self.load_model()

        self.model.eval()
        with torch.no_grad():
            # Create attention mask (assuming padding token is 0)
            attention_mask = (batch != 0).long()

            outputs = self.model(input_ids=batch, attention_mask=attention_mask)
            embeddings = self._apply_pooling(outputs.last_hidden_state, attention_mask)

        return embeddings

    def embed_texts(self, texts: list, max_length: Optional[int] = None) -> torch.Tensor:
        """Embed a list of texts.

        Args:
            texts: List of text strings
            max_length: Maximum sequence length (uses model default if None)

        Returns:
            Embeddings tensor of shape (len(texts), embedding_dim)
        """
        if not self._loaded:
            self.load_model()

        # Get max length for this model
        if max_length is None:
            max_length = self._get_default_max_length()
        else:
            max_length = self._clamp_max_length(max_length)

        # Tokenize texts
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = self._apply_pooling(outputs.last_hidden_state, attention_mask)

        return embeddings

    def _get_default_max_length(self) -> int:
        """Get default max length for this model. Override if needed."""
        return 512

    def _clamp_max_length(self, max_length: int) -> int:
        """Clamp max length to model's maximum. Override if needed."""
        return max_length
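

# --- Subclassing sketch (illustrative only) -----------------------------------
# A hypothetical minimal subclass: concrete embedders supply the class-level
# model registry and _get_model_metadata(); tokenization, pooling, and lazy
# loading are inherited. The "toy-bert" entry and its hf_name are made up for
# illustration and do not refer to a real checkpoint.
#
# class ToyEmbedder(TransformerEmbedderBase):
#     MODEL_FAMILY_NAME: ClassVar[str] = "Toy"
#     DEFAULT_MODEL: ClassVar[str] = "toy-bert"
#     AVAILABLE_MODELS = {
#         "toy-bert": {"embedding_dim": 768, "hf_name": "example-org/toy-bert"},
#     }
#
#     def _get_model_metadata(self) -> Dict[str, Any]:
#         return {
#             "model_name": self.model_name,
#             "hf_name": self.hf_name,
#             "embedding_dim": self.embedding_dim,
#             "pooling": self.pooling,
#             "model_family": "Toy",
#         }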


File: ssrlib/embedders/cv/vicreg.py
"""VICReg embedder implementation."""

import torch
import torch.nn as nn
from typing import Dict, Any, ClassVar

from ..base import BaseEmbedder


class VICRegEmbedder(BaseEmbedder):
    """VICReg embedder for computer vision."""

    # Class-level metadata
    _embedder_category: ClassVar[str] = "vision"
    _embedder_modality: ClassVar[str] = "vision"
    _embedder_properties: ClassVar[Dict[str, Any]] = {
        "model_family": "VICReg",
        "source": "facebookresearch",
        "self_supervised": True,
        "variance_invariance_covariance": True,
        "architecture": "ResNet",
        "input_size": (224, 224),
    }

    AVAILABLE_MODELS = {
        "resnet50": {"embedding_dim": 2048},
        "resnet50x2": {"embedding_dim": 2048},
        "resnet50x4": {"embedding_dim": 2048},
    }

    def __init__(self, model_name: str = "resnet50", device: str = "cpu", **kwargs):
        """Initialize VICReg embedder.

        Args:
            model_name: Name of the VICReg model to use
            device: Device to run on ('cpu' or 'cuda')
            **kwargs: Additional arguments
        """
        super().__init__(f"VICReg_{model_name}", device, **kwargs)

        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(
                f"Unknown model {model_name}. " f"Available: {list(self.AVAILABLE_MODELS.keys())}"
            )

        self.model_name = model_name
        self.embedding_dim = self.AVAILABLE_MODELS[model_name]["embedding_dim"]

        # Update metadata
        self._metadata.update(
            {
                "model_name": model_name,
                "embedding_dim": self.embedding_dim,
                "model_family": "VICReg",
            }
        )

    def get_embedding_dim(self) -> int:
        """Get embedding dimension."""
        return self.embedding_dim

    def load_model(self) -> None:
        """Load VICReg model from torch hub."""
        if self._loaded:
            return

        print(f"Loading VICReg model: {self.model_name}")
        try:
            self.model = torch.hub.load("facebookresearch/vicreg:main", self.model_name)
            self.model = self.model.to(self.device)
            self.model.eval()
            self._loaded = True
            print(f"Successfully loaded {self.model_name}")
        except Exception as e:
            raise RuntimeError(f"Failed to load {self.model_name}: {str(e)}")

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Forward pass through VICReg model.

        Args:
            batch: Input batch of shape (batch_size, 3, H, W)

        Returns:
            Embeddings of shape (batch_size, embedding_dim)
        """
        if not self._loaded:
            self.load_model()

        self.model.eval()
        with torch.no_grad():
            embeddings = self.model(batch)

        return embeddings


File: ssrlib/embedders/cv/dinov2.py
"""DINOv2 embedder implementation."""

import torch
import torch.nn as nn
from typing import Dict, Any, ClassVar

from ..base import BaseEmbedder


class DINOv2Embedder(BaseEmbedder):
    """DINOv2 embedder for computer vision."""

    # Class-level metadata
    _embedder_category: ClassVar[str] = "vision"
    _embedder_modality: ClassVar[str] = "vision"
    _embedder_properties: ClassVar[Dict[str, Any]] = {
        "model_family": "DINOv2",
        "source": "facebookresearch",
        "self_supervised": True,
        "supports_registration": True,
        "input_size": (224, 224),
        "architecture": "ViT",
    }

    AVAILABLE_MODELS = {
        "dinov2_vits14": {"embedding_dim": 384},
        "dinov2_vitb14": {"embedding_dim": 768},
        "dinov2_vitl14": {"embedding_dim": 1024},
        "dinov2_vitg14": {"embedding_dim": 1536},
        "dinov2_vits14_reg": {"embedding_dim": 384},
        "dinov2_vitb14_reg": {"embedding_dim": 768},
        "dinov2_vitl14_reg": {"embedding_dim": 1024},
        "dinov2_vitg14_reg": {"embedding_dim": 1536},
    }

    def __init__(self, model_name: str = "dinov2_vitb14", device: str = "cpu", **kwargs):
        """Initialize DINOv2 embedder.

        Args:
            model_name: Name of the DINOv2 model to use
            device: Device to run on ('cpu' or 'cuda')
            **kwargs: Additional arguments
        """
        super().__init__(f"DINOv2_{model_name}", device, **kwargs)

        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(
                f"Unknown model {model_name}. " f"Available: {list(self.AVAILABLE_MODELS.keys())}"
            )

        self.model_name = model_name
        self.embedding_dim = self.AVAILABLE_MODELS[model_name]["embedding_dim"]

        # Update metadata
        self._metadata.update(
            {
                "model_name": model_name,
                "embedding_dim": self.embedding_dim,
                "model_family": "DINOv2",
            }
        )

    def get_embedding_dim(self) -> int:
        """Get embedding dimension."""
        return self.embedding_dim

    def load_model(self) -> None:
        """Load DINOv2 model from torch hub."""
        if self._loaded:
            return

        print(f"Loading DINOv2 model: {self.model_name}")
        try:
            self.model = torch.hub.load("facebookresearch/dinov2", self.model_name)
            self.model = self.model.to(self.device)
            self.model.eval()
            self._loaded = True
            print(f"Successfully loaded {self.model_name}")
        except Exception as e:
            raise RuntimeError(f"Failed to load {self.model_name}: {str(e)}")

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Forward pass through DINOv2 model.

        Args:
            batch: Input batch of shape (batch_size, 3, H, W)

        Returns:
            Embeddings of shape (batch_size, embedding_dim)
        """
        if not self._loaded:
            self.load_model()

        self.model.eval()
        with torch.no_grad():
            embeddings = self.model(batch)

        return embeddings
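

# --- Usage sketch (illustrative only) ----------------------------------------
# Weights are fetched lazily from torch.hub on the first forward pass. Inputs
# are (batch, 3, H, W) float tensors; DINOv2 uses patch size 14, so H and W
# should be multiples of 14 (e.g. 224x224). forward() is called directly here
# to stay agnostic about whether BaseEmbedder is an nn.Module.
if __name__ == "__main__":
    embedder = DINOv2Embedder(model_name="dinov2_vits14", device="cpu")
    images = torch.randn(2, 3, 224, 224)  # stand-in for a preprocessed batch
    features = embedder.forward(images)
    print(features.shape)  # (2, 384)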


File: ssrlib/embedders/cv/__init__.py
from .dinov2 import DINOv2Embedder
from .clip import CLIPEmbedder
from .vicreg import VICRegEmbedder
from .dino import DINOEmbedder

__all__ = ["DINOv2Embedder", "CLIPEmbedder", "VICRegEmbedder", "DINOEmbedder"]


File: ssrlib/embedders/cv/clip.py
"""CLIP embedder implementation."""

import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from typing import Dict, Any, ClassVar

from ..base import BaseEmbedder


class CLIPEmbedder(BaseEmbedder):
    """CLIP embedder for computer vision."""

    # Class-level metadata
    _embedder_category: ClassVar[str] = "vision"
    _embedder_modality: ClassVar[str] = "multimodal"
    _embedder_properties: ClassVar[Dict[str, Any]] = {
        "model_family": "CLIP",
        "source": "OpenAI",
        "multimodal": True,
        "supports_text": True,
        "supports_images": True,
        "contrastive_learning": True,
        "architecture": "ViT",
    }

    AVAILABLE_MODELS = {
        "clip-vit-large-patch14": {
            "embedding_dim": 768,
            "hf_name": "openai/clip-vit-large-patch14",
        },
        "clip-vit-base-patch32": {
            "embedding_dim": 512,
            "hf_name": "openai/clip-vit-base-patch32",
        },
        "clip-vit-base-patch16": {
            "embedding_dim": 512,
            "hf_name": "openai/clip-vit-base-patch16",
        },
    }

    def __init__(self, model_name: str = "clip-vit-large-patch14", device: str = "cpu", **kwargs):
        """Initialize CLIP embedder.

        Args:
            model_name: Name of the CLIP model to use
            device: Device to run on ('cpu' or 'cuda')
            **kwargs: Additional arguments
        """
        super().__init__(f"CLIP_{model_name}", device, **kwargs)

        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(
                f"Unknown model {model_name}. " f"Available: {list(self.AVAILABLE_MODELS.keys())}"
            )

        self.model_name = model_name
        self.hf_name = self.AVAILABLE_MODELS[model_name]["hf_name"]
        self.embedding_dim = self.AVAILABLE_MODELS[model_name]["embedding_dim"]
        self.processor = None

        # Update metadata
        self._metadata.update(
            {
                "model_name": model_name,
                "hf_name": self.hf_name,
                "embedding_dim": self.embedding_dim,
                "model_family": "CLIP",
            }
        )

    def get_embedding_dim(self) -> int:
        """Get embedding dimension."""
        return self.embedding_dim

    def load_model(self) -> None:
        """Load CLIP model from Hugging Face."""
        if self._loaded:
            return

        print(f"Loading CLIP model: {self.hf_name}")
        try:
            self.model = CLIPModel.from_pretrained(self.hf_name)
            self.processor = CLIPProcessor.from_pretrained(self.hf_name)
            self.model = self.model.to(self.device)
            self.model.eval()
            self._loaded = True
            print(f"Successfully loaded {self.model_name}")
        except Exception as e:
            raise RuntimeError(f"Failed to load {self.model_name}: {str(e)}")

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Forward pass through CLIP model.

        Args:
            batch: Input batch of shape (batch_size, 3, H, W)
                  Expected to be normalized with ImageNet stats

        Returns:
            Embeddings of shape (batch_size, embedding_dim)
        """
        if not self._loaded:
            self.load_model()

        self.model.eval()
        with torch.no_grad():
            # CLIPModel.get_image_features expects pixel values normalized with
            # CLIP's own statistics (what CLIPProcessor would produce), not raw
            # [0, 1] values. Undo the ImageNet normalization applied upstream,
            # then re-normalize with the OpenAI CLIP mean/std.
            imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(batch.device)
            imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(batch.device)
            clip_mean = (
                torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(batch.device)
            )
            clip_std = (
                torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(batch.device)
            )

            pixel_values = torch.clamp(batch * imagenet_std + imagenet_mean, 0, 1)
            pixel_values = (pixel_values - clip_mean) / clip_std

            # Get image features
            embeddings = self.model.get_image_features(pixel_values=pixel_values)

        return embeddings
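

# --- Usage sketch (illustrative only) ----------------------------------------
# forward() takes ImageNet-normalized batches (the convention shared by the
# other vision embedders here) and converts them to CLIP's own preprocessing
# internally, so the same dataloader output can be reused across embedders.
if __name__ == "__main__":
    embedder = CLIPEmbedder(model_name="clip-vit-base-patch32", device="cpu")
    batch = torch.randn(2, 3, 224, 224)  # stand-in for an ImageNet-normalized batch
    features = embedder.forward(batch)
    print(features.shape)  # (2, 512)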


File: ssrlib/embedders/cv/dino.py
"""DINO (original) embedder implementation."""

import torch
import torch.nn as nn
from transformers import ResNetModel, ViTModel
from typing import Dict, Any, ClassVar

from ..base import BaseEmbedder


class DINOEmbedder(BaseEmbedder):
    """DINO (original) embedder for computer vision."""

    # Class-level metadata
    _embedder_category: ClassVar[str] = "vision"
    _embedder_modality: ClassVar[str] = "vision"
    _embedder_properties: ClassVar[Dict[str, Any]] = {
        "model_family": "DINO",
        "source": "Facebook AI",
        "self_supervised": True,
        "distillation": True,
        "supports_resnet": True,
        "supports_vit": True,
    }

    AVAILABLE_MODELS = {
        "dino_resnet50": {
            "embedding_dim": 2048,
            "hf_name": "Ramos-Ramos/dino-resnet-50",
            "architecture": "resnet",
        },
        "dino_vitb8": {
            "embedding_dim": 768,
            "hf_name": "facebook/dino-vitb8",
            "architecture": "vit",
        },
        "dino_vitb16": {
            "embedding_dim": 768,
            "hf_name": "facebook/dino-vitb16",
            "architecture": "vit",
        },
        "dino_vits8": {
            "embedding_dim": 384,
            "hf_name": "facebook/dino-vits8",
            "architecture": "vit",
        },
        "dino_vits16": {
            "embedding_dim": 384,
            "hf_name": "facebook/dino-vits16",
            "architecture": "vit",
        },
    }

    def __init__(self, model_name: str = "dino_vitb16", device: str = "cpu", **kwargs):
        """Initialize DINO embedder.

        Args:
            model_name: Name of the DINO model to use
            device: Device to run on ('cpu' or 'cuda')
            **kwargs: Additional arguments
        """
        super().__init__(f"DINO_{model_name}", device, **kwargs)

        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(
                f"Unknown model {model_name}. " f"Available: {list(self.AVAILABLE_MODELS.keys())}"
            )

        self.model_name = model_name
        self.hf_name = self.AVAILABLE_MODELS[model_name]["hf_name"]
        self.architecture = self.AVAILABLE_MODELS[model_name]["architecture"]
        self.embedding_dim = self.AVAILABLE_MODELS[model_name]["embedding_dim"]

        # Update metadata
        self._metadata.update(
            {
                "model_name": model_name,
                "hf_name": self.hf_name,
                "architecture": self.architecture,
                "embedding_dim": self.embedding_dim,
                "model_family": "DINO",
            }
        )

    def get_embedding_dim(self) -> int:
        """Get embedding dimension."""
        return self.embedding_dim

    def load_model(self) -> None:
        """Load DINO model from Hugging Face."""
        if self._loaded:
            return

        print(f"Loading DINO model: {self.hf_name}")
        try:
            if self.architecture == "resnet":
                self.model = ResNetModel.from_pretrained(self.hf_name)
            elif self.architecture == "vit":
                self.model = ViTModel.from_pretrained(self.hf_name)
            else:
                raise ValueError(f"Unknown architecture: {self.architecture}")

            self.model = self.model.to(self.device)
            self.model.eval()
            self._loaded = True
            print(f"Successfully loaded {self.model_name}")
        except Exception as e:
            raise RuntimeError(f"Failed to load {self.model_name}: {str(e)}")

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Forward pass through DINO model.

        Args:
            batch: Input batch of shape (batch_size, 3, H, W)

        Returns:
            Embeddings of shape (batch_size, embedding_dim)
        """
        if not self._loaded:
            self.load_model()

        self.model.eval()
        with torch.no_grad():
            if self.architecture == "resnet":
                outputs = self.model(pixel_values=batch)
                # ResNetModel's pooler_output keeps singleton spatial dims
                # (batch, channels, 1, 1); flatten to (batch, embedding_dim).
                embeddings = outputs.pooler_output.flatten(start_dim=1)
            elif self.architecture == "vit":
                outputs = self.model(pixel_values=batch)
                # Use mean pooling over sequence dimension for ViT
                embeddings = outputs.last_hidden_state.mean(dim=1)
            else:
                raise ValueError(f"Unknown architecture: {self.architecture}")

        return embeddings

