Source code for scitex_ml.feature_extraction.vit

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-11-27 21:36:51 (ywatanabe)"
# File: ./scitex_repo/src/scitex/ai/feature_extraction/vit.py

THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/feature_extraction/vit.py"

"""
Functionality:
    Extracts features from images using Vision Transformer (ViT) models
Input:
    Image arrays of arbitrary dimensions
Output:
    Feature vectors (the model's output vector; 1000-dimensional for the
    *_imagenet1k checkpoints, otherwise the size of the chosen model's head)
Prerequisites:
    torch, PIL, torchvision
"""

import os as _os
from typing import Tuple, Union

import numpy as np
import torch
import torch as _torch
from pytorch_pretrained_vit import ViT
from torchvision import transforms as _transforms

# from scitex.decorators import batch_torch_fn


def _setup_device(device: Union[str, None]) -> str:
    if device is None:
        device = "cuda" if _torch.cuda.is_available() else "cpu"
    return device


class VitFeatureExtractor:
    """Extract feature vectors from 2-D images with a pretrained ViT model.

    Each 2-D slice of the input is replicated to three channels, run through
    ``self.transform`` (to PIL image, resize to the model's input size, to
    tensor, normalize with mean 0.5 / std 0.5) and passed through the model.

    Parameters
    ----------
    model_name : str
        One of the names listed in ``self.valid_models``.
    torch_home : str
        Directory exported as the ``TORCH_HOME`` environment variable
        (presumably where pretrained weights are cached -- confirm against
        the ``pytorch_pretrained_vit`` loader). It must already exist.
    device : str or None
        Device to run the model on; ``None`` selects CUDA when available,
        otherwise CPU.

    Raises
    ------
    ValueError
        If ``model_name`` is not in ``self.valid_models``.
    FileNotFoundError
        If ``torch_home`` does not exist.
    """

    def __init__(
        self,
        model_name="B_16",
        torch_home="./models",
        device=None,
    ):
        self.valid_models = [
            "B_16",
            "B_32",
            "L_16",
            "L_32",
            "B_16_imagenet1k",
            "B_32_imagenet1k",
            "L_16_imagenet1k",
            "L_32_imagenet1k",
        ]
        self.model_name = model_name
        self.torch_home = torch_home
        self.device = _setup_device(device)

        # NOTE: process-wide side effect; set before the model is loaded.
        _os.environ["TORCH_HOME"] = torch_home

        self._validate_inputs()

        self.model = ViT(model_name, pretrained=True).to(self.device).eval()
        self.transform = _transforms.Compose(
            [
                _transforms.ToPILImage(),
                _transforms.Resize(self.model.image_size),
                _transforms.ToTensor(),
                _transforms.Normalize(0.5, 0.5),
            ]
        )

    def _validate_inputs(self):
        """Check the model name and the existence of the model directory.

        Raises
        ------
        ValueError
            If ``self.model_name`` is not in ``self.valid_models``.
        FileNotFoundError
            If ``self.torch_home`` does not exist on disk.
        """
        if self.model_name not in self.valid_models:
            raise ValueError(f"Invalid model name. Choose from: {self.valid_models}")

        if not _os.path.exists(self.torch_home):
            raise FileNotFoundError(f"Model directory not found: {self.torch_home}")

    def _preprocess_array(
        self,
        arr: Union[_torch.Tensor, np.ndarray],
        dim: Tuple[int, int],
        channel_dim: Union[int, None],
    ) -> Tuple[_torch.Tensor, _torch.Size]:
        """Turn an N-D array of 2-D images into a batch of model inputs.

        Parameters
        ----------
        arr : torch.Tensor or numpy.ndarray
            Input whose axes ``dim`` hold the image rows/columns; every other
            axis is treated as a batch axis.
        dim : tuple of int
            The two spatial axes (negative indices allowed). They are moved
            to the end in ascending axis order, regardless of the order in
            which they are listed.
        channel_dim : int or None
            Currently unused; each image is treated as single-channel.

        Returns
        -------
        tuple
            ``(batch, batch_shape)``: ``batch`` stacks the transformed
            images along a new leading axis, and ``batch_shape`` is the
            shape of the non-spatial axes, needed to un-flatten the
            features afterwards.
        """
        # Accept NumPy arrays (and other array-likes) as well as tensors;
        # the permute/reshape calls below are tensor-only APIs.
        arr = _torch.as_tensor(arr)

        n_dims = arr.dim()
        spatial_axes = tuple(d if d >= 0 else n_dims + d for d in dim)

        # Build a permutation that pushes the spatial axes to the end.
        perm = list(range(n_dims))
        for axis in sorted(spatial_axes):
            perm.remove(axis)
            perm.append(axis)
        arr = arr.permute(perm)

        # Flatten all dimensions except the last two (spatial dimensions).
        batch_shape = arr.shape[:-2]
        spatial_shape = arr.shape[-2:]
        arr = arr.reshape(-1, *spatial_shape)

        # Replicate each single-channel image to 3 channels, then apply the
        # per-image transform.
        transformed = [
            self.transform(img.unsqueeze(0).repeat(3, 1, 1)) for img in arr
        ]

        return _torch.stack(transformed), batch_shape

    def extract_features(
        self,
        arr,
        axis=(-2, -1),
        dim=None,
        channel_dim=None,
        batch_size=None,
        device="cuda",
    ):
        """Compute one feature vector per 2-D image contained in ``arr``.

        Parameters
        ----------
        arr : torch.Tensor or numpy.ndarray
            Input holding 2-D images along ``axis``.
        axis : tuple of int
            The two spatial axes of ``arr``.
        dim : optional
            Accepted for backward compatibility; currently ignored
            (``axis`` is what selects the spatial axes).
        channel_dim : int or None
            Forwarded to ``_preprocess_array``, where it is currently unused.
        batch_size : int or None
            Maximum number of images per forward pass. ``None`` (or a
            non-positive value) runs all images in a single pass.
        device : str
            Accepted for backward compatibility; currently ignored. The
            model runs on ``self.device`` chosen at construction time.

        Returns
        -------
        torch.Tensor
            CPU tensor of shape ``(*batch_shape, n_features)``, where
            ``batch_shape`` is the shape of the non-spatial axes of ``arr``.
        """
        processed_arr, batch_shape = self._preprocess_array(
            arr,
            axis,
            channel_dim,
        )

        n_images = processed_arr.shape[0]
        if batch_size is None or batch_size <= 0:
            step = n_images
        else:
            step = int(batch_size)
        step = max(step, 1)

        # Move one chunk at a time to the device so peak device memory is
        # bounded by ``batch_size`` rather than by the total image count.
        with _torch.no_grad():
            chunks = [
                self.model(processed_arr[start : start + step].to(self.device)).cpu()
                for start in range(0, n_images, step)
            ]
        features = _torch.cat(chunks)

        return features.reshape(*batch_shape, -1)
if __name__ == "__main__":
    # Smoke-test demo: extract features from random 32x32 "images" stored in
    # the last two axes of a 6-D array. Requires the pretrained weights and
    # an existing ./models directory.
    extractor = VitFeatureExtractor(model_name="B_16_imagenet1k")

    tensor = torch.randn(3, 2, 4, 5, 32, 32)
    processed = extractor.extract_features(tensor, (-2, -1), None)
    print(processed.shape)

    # NumPy data is converted to a tensor explicitly before extraction.
    arr = np.random.rand(3, 2, 4, 5, 32, 32)
    processed = extractor.extract_features(torch.from_numpy(arr), (-2, -1), None)
    print(processed.shape)
    # The leading axes (3, 2, 4, 5) are kept and the two image axes are
    # replaced by a single feature axis: torch.Size([3, 2, 4, 5, n_features])

# EOF