Source code for datasetsdefer.cifar_synth

import math
import torch
from .basedataset import BaseDataset
import numpy as np
import os
import random
import shutil
import time
import sys
from sklearn.preprocessing import StandardScaler
import pickle
from sklearn.gaussian_process.kernels import RBF
from scipy.stats import multivariate_normal
import scipy.stats as st
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics.pairwise import rbf_kernel
from scipy.special import expit
import torch.distributions as D
import logging

sys.path.append("../")
import torchvision.datasets as datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
from datasetsdefer.generic_dataset import GenericImageExpertDataset


class CifarSynthExpert:
    """Simple synthetic expert on CIFAR-10.

    k: number of classes the expert predicts correctly
    n_classes: total number of classes (10 for CIFAR-10)
    """

    def __init__(self, k, n_classes):
        self.k = k
        self.n_classes = n_classes

    def predict(self, labels):
        batch_size = len(labels)
        outs = [0] * batch_size
        for i in range(batch_size):
            if labels[i] <= self.k - 1:
                # the expert is correct on the first k classes
                outs[i] = labels[i]
            else:
                # otherwise the expert guesses a class uniformly at random
                outs[i] = random.randint(0, self.n_classes - 1)
        return outs
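
# A quick illustration (not part of the original module) of how the synthetic expert
# behaves: it returns the true label whenever it lies in the first k classes and a
# uniformly random class otherwise. Sketch only; the values below are hypothetical.
#
#   expert = CifarSynthExpert(k=5, n_classes=10)
#   expert.predict([0, 3, 9])   # -> [0, 3, r] where r is a random class in {0, ..., 9}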
class CifarSynthDataset(BaseDataset):
    """CIFAR-K synthetic expert on top of CIFAR-10, following
    "Consistent Estimators for Learning to Defer to an Expert"
    (https://arxiv.org/abs/2006.01862)."""

    def __init__(
        self,
        expert_k,
        use_data_aug,
        test_split=0.2,
        val_split=0.1,
        batch_size=1000,
        n_dataset=10,
        transforms=None,
    ):
        """
        expert_k: number of classes the expert predicts correctly
        use_data_aug: whether to use data augmentation (bool)
        test_split: NOT USED FOR CIFAR, since we have a fixed test set
        val_split: fraction of the training set held out for validation
        batch_size: batch size for training
        n_dataset: number of classes (10 for CIFAR-10, 100 for CIFAR-100)
        transforms: data transforms
        """
        self.expert_k = expert_k
        self.use_data_aug = use_data_aug
        self.n_dataset = n_dataset
        self.expert_fn = CifarSynthExpert(expert_k, self.n_dataset).predict
        self.test_split = test_split
        self.val_split = val_split
        self.batch_size = batch_size
        self.train_split = 1 - test_split - val_split
        self.transforms = transforms
        self.generate_data()
    def generate_data(self):
        """Generate the training, validation and test loaders."""
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
        )
        if self.use_data_aug:
            # reflection-pad by 4 pixels, then random 32x32 crop and horizontal flip
            transform_train = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Lambda(
                        lambda x: F.pad(
                            x.unsqueeze(0), (4, 4, 4, 4), mode="reflect"
                        ).squeeze()
                    ),
                    transforms.ToPILImage(),
                    transforms.RandomCrop(32),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        else:
            transform_train = transforms.Compose(
                [
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        if self.n_dataset == 10:
            dataset = "cifar10"
        elif self.n_dataset == 100:
            dataset = "cifar100"
        else:
            raise ValueError("n_dataset must be 10 or 100")

        train_dataset_all = datasets.__dict__[dataset.upper()](
            "../data", train=True, download=True, transform=transform_train
        )
        # split the official training set into train and validation subsets
        train_size = int((1 - self.val_split) * len(train_dataset_all))
        val_size = len(train_dataset_all) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset_all, [train_size, val_size]
        )
        test_dataset = datasets.__dict__[dataset.upper()](
            "../data", train=False, transform=transform_test, download=True
        )

        # wrap each split together with the synthetic expert's predictions
        dataset_train = GenericImageExpertDataset(
            np.array(train_dataset.dataset.data)[train_dataset.indices],
            np.array(train_dataset.dataset.targets)[train_dataset.indices],
            self.expert_fn(
                np.array(train_dataset.dataset.targets)[train_dataset.indices]
            ),
            transform_train,
        )
        dataset_val = GenericImageExpertDataset(
            np.array(val_dataset.dataset.data)[val_dataset.indices],
            np.array(val_dataset.dataset.targets)[val_dataset.indices],
            self.expert_fn(np.array(val_dataset.dataset.targets)[val_dataset.indices]),
            transform_test,
        )
        dataset_test = GenericImageExpertDataset(
            test_dataset.data,
            test_dataset.targets,
            self.expert_fn(test_dataset.targets),
            transform_test,
        )

        self.data_train_loader = DataLoader(
            dataset=dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=True,
        )
        self.data_val_loader = DataLoader(
            dataset=dataset_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )
        self.data_test_loader = DataLoader(
            dataset=dataset_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )
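
# Usage sketch (not part of the original module): build the CIFAR-10 version with an
# expert that is perfect on the first 5 classes and inspect one training batch. The
# unpacking below assumes GenericImageExpertDataset yields (image, label, expert
# prediction) triples, matching the arguments it is constructed with above.
if __name__ == "__main__":
    cifar = CifarSynthDataset(expert_k=5, use_data_aug=False, batch_size=128)
    images, labels, expert_preds = next(iter(cifar.data_train_loader))
    print(images.shape, labels.shape, expert_preds.shape)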