from .basedataset import BaseDataset
import torch
import numpy as np
import os
import random
import shutil
import sys
import logging
import pandas as pd
sys.path.append("../")
import torchvision.transforms as transforms
from datasetsdefer.generic_dataset import GenericImageExpertDataset
import requests
import urllib.request
import tarfile
class ChestXrayDataset(BaseDataset):
"""Chest X-ray dataset from NIH with multiple radiologist annotations per point from Google Research"""
def __init__(
self,
non_deferral_dataset,
use_data_aug,
data_dir,
label_chosen,
test_split=0.2,
val_split=0.1,
batch_size=1000,
transforms=None,
):
"""
        See https://nihcc.app.box.com/v/ChestXray-NIHCC for the underlying NIH data.
        non_deferral_dataset (bool): if True, use the full NIH dataset (excluding the human-labeled val/test patients), with a placeholder expert; if False, use the deferral dataset of roughly 4k human-labeled images
        data_dir: directory where data and model files are stored
        label_chosen (int in 0,1,2,3): which label to predict. If non_deferral_dataset is False, 0-3 correspond to Fracture, Pneumothorax, Airspace Opacity, and Nodule/Mass; if True, they correspond to No Finding (or not), Pneumothorax, Effusion, and Nodule/Mass
        use_data_aug: whether to use data augmentation (bool)
test_split: percentage of test data
val_split: percentage of data to be used for validation (from training set)
batch_size: batch size for training
transforms: data transforms
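
        Example (illustrative sketch; "./data" is a placeholder directory, and
        the first call downloads tens of gigabytes of NIH images into it):

            dataset = ChestXrayDataset(
                non_deferral_dataset=False,
                use_data_aug=True,
                data_dir="./data",
                label_chosen=1,  # Pneumothorax
            )
            batch = next(iter(dataset.data_train_loader))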
"""
self.non_deferral_dataset = non_deferral_dataset
self.data_dir = data_dir
self.use_data_aug = use_data_aug
self.label_chosen = label_chosen
self.test_split = test_split
self.val_split = val_split
self.batch_size = batch_size
        self.n_dataset = 2  # binary classification: the chosen label is 0/1
self.train_split = 1 - test_split - val_split
self.transforms = transforms
self.generate_data()
    def generate_data(self):
"""
        Generate the training, validation and test splits, downloading the NIH images and expert-label CSVs if needed.
"""
links = [
"https://nihcc.box.com/shared/static/vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz",
"https://nihcc.box.com/shared/static/i28rlmbvmfjbl8p2n3ril0pptcmcu9d1.gz",
"https://nihcc.box.com/shared/static/f1t00wrtdk94satdfb9olcolqx20z2jp.gz",
"https://nihcc.box.com/shared/static/0aowwzs5lhjrceb3qp67ahp0rd1l1etg.gz",
"https://nihcc.box.com/shared/static/v5e3goj22zr6h8tzualxfsqlqaygfbsn.gz",
"https://nihcc.box.com/shared/static/asi7ikud9jwnkrnkj99jnpfkjdes7l6l.gz",
"https://nihcc.box.com/shared/static/jn1b4mw4n6lnh74ovmcjb8y48h8xj07n.gz",
"https://nihcc.box.com/shared/static/tvpxmn7qyrgl0w8wfh9kqfjskv6nmm1j.gz",
"https://nihcc.box.com/shared/static/upyy3ml7qdumlgk2rfcvlb9k6gvqq2pj.gz",
"https://nihcc.box.com/shared/static/l6nilvfa9cg3s28tqv1qc1olm3gnz54p.gz",
"https://nihcc.box.com/shared/static/hhq8fkdgvcari67vfhs7ppg2w6ni4jze.gz",
"https://nihcc.box.com/shared/static/ioqwiy20ihqwyr8pf4c24eazhh281pbu.gz",
]
        max_links = 12  # the NIH release is split across 12 image archives
        links = links[:max_links]
if not os.path.exists(self.data_dir + "/images_nih"):
logging.info("Downloading NIH dataset")
for idx, link in enumerate(links):
if not os.path.exists(
self.data_dir + "/images_%02d.tar.gz" % (idx + 1)
):
fn = self.data_dir + "/images_%02d.tar.gz" % (idx + 1)
logging.info("downloading " + fn + "...")
urllib.request.urlretrieve(link, fn) # download the zip file
logging.info("Download complete. Please check the checksums")
# make directory
if not os.path.exists(self.data_dir + "/images_nih"):
os.makedirs(self.data_dir + "/images_nih")
# extract files
            for idx in range(max_links):
                fn = self.data_dir + "/images_%02d.tar.gz" % (idx + 1)
                logging.info("Extracting " + fn + "...")
                with tarfile.open(fn) as archive:
                    archive.extractall(self.data_dir + "/images_nih")
                os.remove(fn)
logging.info("Done")
else:
            # verify the extracted image count; if files are missing, wipe and re-download
num_files = len(
[
name
for name in os.listdir(self.data_dir + "/images_nih")
if os.path.isfile(os.path.join(self.data_dir + "/images_nih", name))
]
)
            if num_files != 112120:  # the full NIH release contains 112,120 images
logging.info("Files missing. Re-downloading...")
shutil.rmtree(self.data_dir + "/images_nih")
for idx, link in enumerate(links):
# check if file exists
fn = self.data_dir + "/images_%02d.tar.gz" % (idx + 1)
if not os.path.exists(
self.data_dir + "/images_%02d.tar.gz" % (idx + 1)
):
logging.info("downloading " + fn + "...")
urllib.request.urlretrieve(link, fn)
logging.info("Download complete. Please check the checksums")
# make directory
if not os.path.exists(self.data_dir + "/images_nih"):
os.makedirs(self.data_dir + "/images_nih")
# extract files
                for idx in range(max_links):
                    fn = self.data_dir + "/images_%02d.tar.gz" % (idx + 1)
                    logging.info("Extracting " + fn + "...")
                    with tarfile.open(fn) as archive:
                        archive.extractall(self.data_dir + "/images_nih")
                    os.remove(fn)
logging.info("Done")
# DOWNLOAD CSV DATA FOR LABELS
if (
not os.path.exists(
self.data_dir + "/four_findings_expert_labels_individual_readers.csv"
)
or not os.path.exists(
self.data_dir + "/four_findings_expert_labels_test_labels.csv"
)
or not os.path.exists(
self.data_dir + "/four_findings_expert_labels_validation_labels.csv"
)
or not os.path.exists(self.data_dir + "/Data_Entry_2017_v2020.csv")
):
logging.info("Downloading readers NIH data")
r = requests.get(
"https://storage.googleapis.com/gcs-public-data--healthcare-nih-chest-xray-labels/four_findings_expert_labels/individual_readers.csv",
allow_redirects=True,
)
with open(
self.data_dir + "/four_findings_expert_labels_individual_readers.csv",
"wb",
) as f:
f.write(r.content)
r = requests.get(
"https://storage.googleapis.com/gcs-public-data--healthcare-nih-chest-xray-labels/four_findings_expert_labels/test_labels.csv",
allow_redirects=True,
)
with open(
self.data_dir + "/four_findings_expert_labels_test_labels.csv", "wb"
) as f:
f.write(r.content)
r = requests.get(
"https://storage.googleapis.com/gcs-public-data--healthcare-nih-chest-xray-labels/four_findings_expert_labels/validation_labels.csv",
allow_redirects=True,
)
with open(
self.data_dir + "/four_findings_expert_labels_validation_labels.csv",
"wb",
) as f:
f.write(r.content)
logging.info("Finished Downloading readers NIH data")
r = requests.get(
"https://dl2.boxcloud.com/d/1/b1!54XsklVtK4wWu3E-qHtsc9JEcNkALHT_hWAUX_n2jJE1XNe5w7dXSoPbm8e1OqaA8xcmquQ2M-qbMKV_StRSatX2KMHE__vr4Z16j1mcnDlFWCXF71jpbZum1lRwn_i5iom8wJ7J0bx7Px2MJgefbb8QKUxvuMGEmnA3-e69TegyuT8wLqB0YPx19Bp8Iue4TKEJ457zFnPjtfC2p5le1yQjfIoxzKXi2oSFxZcSTAv8se3Eynm6ssbAKhs7CC9NbJruD1wuJmQYjcy3YQvEAIgdTTaIDItLiX-GWVIqgZkhPwbnnE0XGT2dm9TlS0WKq7saLydgx4ji88SmVTtr4t82V2h1AnL-6KXl2LPam8DWjSIeja-ehu5vSPf2Uyn4ShIwBHrJM7yFZWa9VrePrd7ANGMVKU879rD01gBrXTvoPdDSKe9KRHfPbCsBCkW0B2NN5T_GvKQAtN4-OFnmGF2QQDFLq17X32XKbPyfXHPAU3eICueFNOPo9MPOESoJ48gpSAIkIjB5VpCYDSK_2Bpm4U3EVfFGCx2FOo48B6_jTo1Xw2c04_RWVEDHfm30IBFMn8Qd5vFea1rFLCXQkKUCVipyLO4z2ezkIk8TaeL0u_6UnaN8bfsIhjyhxM1JwfJT0Z150njCPHHd3CmJe-Cg0Qq-TI9-P1yPFAfEBmkYVyeLzQ9H6HjdmAvMPUvXyXck-_EOBCSV9eIEkH_ZOhN1DHdd3kB-feB-2p_d65cXI1o-f-C5Ep15-AW8mnLn9UU6rG_EgsEyMev0kC75Zo0hO2rOXG_fbZ10PHt4fplmV02pYGDfvGvXtPOHF3nG_qqLF9hNgf1D6IqekWTxeS3SGEL-M3OEl6rhI_1TFg0OQLOy9UtibVPerHYmAwTyTJWFxbjp0pHXGdLZixpLWvZZ0H56aYlILyDP6PJcWMtvdXTRvBE3GdMRv0A1wVKK5tVcYeIEbAgR668AN2FIlbEDsrWjdE6DgvfiXtpaCWY8FmxRWu0UkA-GvAgVuMFhS4FFb5xzTnoUS9IBjkPB_hu8jCsC1qJXBavmWnOEVuTRMQIkhXeIdAw_OR2847VNxMpLOvpb2CIVHrRiDYXI25CO5VBR0jFk4YlX-1mfHApalBLgOaVVdQK4YCd_7J2EGij_41HipGcRSqgRUcZfecfgslc3vTP_vDAo4bASKeeHnKa-6WbP7FGrXyhur-Zv-olIQZPWX5iXJetVv4fHm0QoMLq2q1zk_ARioJSXCIaf9pM6a0rkCN5I64-DaugzU2XZ/download"
)
with open(self.data_dir + "/Data_Entry_2017_v2020.csv", "wb") as f:
f.write(r.content)
        else:
            logging.info("Loading readers NIH data")
        # the same four CSVs are needed in both branches, so read them once here
        try:
            readers_data = pd.read_csv(
                self.data_dir + "/four_findings_expert_labels_individual_readers.csv"
            )
            test_data = pd.read_csv(
                self.data_dir + "/four_findings_expert_labels_test_labels.csv"
            )
            validation_data = pd.read_csv(
                self.data_dir + "/four_findings_expert_labels_validation_labels.csv"
            )
            all_dataset_data = pd.read_csv(
                self.data_dir + "/Data_Entry_2017_v2020.csv"
            )
        except Exception:
            logging.error("Failed to load readers NIH data")
            raise
data_labels = {}
for i in range(len(validation_data)):
labels = [
validation_data.iloc[i]["Fracture"],
validation_data.iloc[i]["Pneumothorax"],
validation_data.iloc[i]["Airspace opacity"],
validation_data.iloc[i]["Nodule or mass"],
]
            # convert YES to 1 and anything else to 0
labels = [1 if x == "YES" else 0 for x in labels]
data_labels[validation_data.iloc[i]["Image Index"]] = labels
for i in range(len(test_data)):
labels = [
test_data.iloc[i]["Fracture"],
test_data.iloc[i]["Pneumothorax"],
test_data.iloc[i]["Airspace opacity"],
test_data.iloc[i]["Nodule or mass"],
]
            # convert YES to 1 and anything else to 0
labels = [1 if x == "YES" else 0 for x in labels]
data_labels[test_data.iloc[i]["Image Index"]] = labels
data_human_labels = {}
for i in range(len(readers_data)):
labels = [
readers_data.iloc[i]["Fracture"],
readers_data.iloc[i]["Pneumothorax"],
readers_data.iloc[i]["Airspace opacity"],
readers_data.iloc[i]["Nodule/mass"],
]
            # convert YES to 1 and anything else to 0
labels = [1 if x == "YES" else 0 for x in labels]
if readers_data.iloc[i]["Image ID"] in data_human_labels:
data_human_labels[readers_data.iloc[i]["Image ID"]].append(labels)
else:
data_human_labels[readers_data.iloc[i]["Image ID"]] = [labels]
        # each image has annotations from several readers; sample one reader's labels per image to act as the expert
data_human_labels = {
k: random.sample(v, 1)[0] for k, v in data_human_labels.items()
}
labels_categories = [
"Fracture",
"Pneumothorax",
"Airspace opacity",
"Nodule/mass",
]
        self.label_to_idx = {
            label: i for i, label in enumerate(labels_categories)
        }
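        # maps a finding name to the index expected by label_chosen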
image_to_patient_id = {}
for i in range(len(readers_data)):
image_to_patient_id[readers_data.iloc[i]["Image ID"]] = readers_data.iloc[
i
]["Patient ID"]
patient_ids = list(set(image_to_patient_id.values()))
data_all_nih_label = {}
# the original dataset has the following labels ['Atelectasis' 'Cardiomegaly' 'Consolidation' 'Edema' 'Effusion' 'Emphysema' 'Fibrosis' 'Hernia' 'Infiltration' 'Mass' 'No Finding' 'Nodule' 'Pleural_Thickening' 'Pneumonia' 'Pneumothorax']
        patient_id_set = set(patient_ids)  # set membership is O(1); list lookups are very slow over 112k rows
        for i in range(len(all_dataset_data)):
            if all_dataset_data["Patient ID"][i] not in patient_id_set:
labels = [0, 0, 0, 0]
if "Pneumothorax" in all_dataset_data["Finding Labels"][i]:
labels[1] = 1
if "Effusion" in all_dataset_data["Finding Labels"][i]:
labels[2] = 1
if (
"Mass" in all_dataset_data["Finding Labels"][i]
or "Nodule" in all_dataset_data["Finding Labels"][i]
):
labels[3] = 1
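                # index 0 encodes "any finding": 0 when "No Finding" is present, 1 otherwise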
if "No Finding" in all_dataset_data["Finding Labels"][i]:
labels[0] = 0
else:
labels[0] = 1
data_all_nih_label[all_dataset_data["Image Index"][i]] = labels
        # image transforms, shared by both dataset variants
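        # the normalization mean/std below are the standard ImageNet statistics,
        # matching torchvision models pretrained on ImageNet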
transform_train = transforms.Compose(
[
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
transform_test = transforms.Compose(
[
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
        # honor use_data_aug: when augmentation is disabled, train on the deterministic transform
        if not self.use_data_aug:
            transform_train = transform_test
        if self.non_deferral_dataset:
# iterate over key, value in data_all_nih_label
data_y = []
data_expert = []
image_paths = []
            for key, value in data_all_nih_label.items():
                image_path = self.data_dir + "/images_nih/" + key
                # keep only images that were actually extracted
                if os.path.isfile(image_path):
                    data_y.append(value[self.label_chosen])
                    image_paths.append(image_path)
                    # placeholder expert that simply copies the ground-truth label
                    data_expert.append(value[self.label_chosen])
data_y = np.array(data_y)
data_expert = np.array(data_expert)
image_paths = np.array(image_paths)
random_seed = random.randrange(10000)
test_size = int(self.test_split * len(image_paths))
val_size = int(self.val_split * len(image_paths))
train_size = len(image_paths) - test_size - val_size
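            # reusing the same seed for all three random_split calls keeps
            # image paths, labels, and expert labels aligned across splits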
train_x, val_x, test_x = torch.utils.data.random_split(
image_paths,
[train_size, val_size, test_size],
generator=torch.Generator().manual_seed(random_seed),
)
train_y, val_y, test_y = torch.utils.data.random_split(
data_y,
[train_size, val_size, test_size],
generator=torch.Generator().manual_seed(random_seed),
)
train_h, val_h, test_h = torch.utils.data.random_split(
data_expert,
[train_size, val_size, test_size],
generator=torch.Generator().manual_seed(random_seed),
)
data_train = GenericImageExpertDataset(
train_x.dataset[train_x.indices],
train_y.dataset[train_y.indices],
train_h.dataset[train_h.indices],
transform_train,
to_open=True,
)
data_val = GenericImageExpertDataset(
val_x.dataset[val_x.indices],
val_y.dataset[val_y.indices],
val_h.dataset[val_h.indices],
transform_test,
to_open=True,
)
data_test = GenericImageExpertDataset(
test_x.dataset[test_x.indices],
test_y.dataset[test_y.indices],
test_h.dataset[test_h.indices],
transform_test,
to_open=True,
)
self.data_train_loader = torch.utils.data.DataLoader(
data_train,
batch_size=self.batch_size,
shuffle=True,
num_workers=0,
pin_memory=True,
)
self.data_val_loader = torch.utils.data.DataLoader(
data_val,
batch_size=self.batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
)
self.data_test_loader = torch.utils.data.DataLoader(
data_test,
batch_size=self.batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
)
else:
            # split by patient (not by image) so that no patient appears in more than one split
            random.shuffle(patient_ids)
            # slice into train/test/val according to the configured fractions
train_patient_ids = patient_ids[: int(len(patient_ids) * self.train_split)]
test_patient_ids = patient_ids[
int(len(patient_ids) * self.train_split) : int(
len(patient_ids) * (self.train_split + self.test_split)
)
]
val_patient_ids = patient_ids[
int(len(patient_ids) * (self.train_split + self.test_split)) :
]
# go from patient ids to image ids
train_image_ids = np.array(
[k for k, v in image_to_patient_id.items() if v in train_patient_ids]
)
val_image_ids = np.array(
[k for k, v in image_to_patient_id.items() if v in val_patient_ids]
)
test_image_ids = np.array(
[k for k, v in image_to_patient_id.items() if v in test_patient_ids]
)
# remove images that are not in the directory
train_image_ids = np.array(
[
k
for k in train_image_ids
if os.path.isfile(self.data_dir + "/images_nih/" + k)
]
)
val_image_ids = np.array(
[
k
for k in val_image_ids
if os.path.isfile(self.data_dir + "/images_nih/" + k)
]
)
test_image_ids = np.array(
[
k
for k in test_image_ids
if os.path.isfile(self.data_dir + "/images_nih/" + k)
]
)
logging.info("Finished splitting data into train, test and validation")
# print sizes
logging.info("Train size: {}".format(len(train_image_ids)))
logging.info("Test size: {}".format(len(test_image_ids)))
logging.info("Validation size: {}".format(len(val_image_ids)))
train_y = np.array(
[data_labels[k][self.label_chosen] for k in train_image_ids]
)
val_y = np.array([data_labels[k][self.label_chosen] for k in val_image_ids])
test_y = np.array(
[data_labels[k][self.label_chosen] for k in test_image_ids]
)
train_h = np.array(
[data_human_labels[k][self.label_chosen] for k in train_image_ids]
)
val_h = np.array(
[data_human_labels[k][self.label_chosen] for k in val_image_ids]
)
test_h = np.array(
[data_human_labels[k][self.label_chosen] for k in test_image_ids]
)
train_image_ids = np.array(
[self.data_dir + "/images_nih/" + k for k in train_image_ids]
)
val_image_ids = np.array(
[self.data_dir + "/images_nih/" + k for k in val_image_ids]
)
test_image_ids = np.array(
[self.data_dir + "/images_nih/" + k for k in test_image_ids]
)
data_train = GenericImageExpertDataset(
train_image_ids, train_y, train_h, transform_train, to_open=True
)
data_val = GenericImageExpertDataset(
val_image_ids, val_y, val_h, transform_test, to_open=True
)
data_test = GenericImageExpertDataset(
test_image_ids, test_y, test_h, transform_test, to_open=True
)
self.data_train_loader = torch.utils.data.DataLoader(
data_train,
batch_size=self.batch_size,
shuffle=True,
num_workers=0,
pin_memory=True,
)
self.data_val_loader = torch.utils.data.DataLoader(
data_val,
batch_size=self.batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
)
self.data_test_loader = torch.utils.data.DataLoader(
data_test,
batch_size=self.batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
)