from __future__ import annotations
import warnings
import os
import collections.abc
from pathlib import Path
from typing import Sequence
import numpy as np
import numpy.linalg as npl
from scipy.ndimage import affine_transform
import torch
import pydicom
from pydicom.dataset import Dataset
from pydicom.uid import generate_uid
import pytomography
from pytomography.metadata import SPECTObjectMeta, SPECTProjMeta, SPECTPSFMeta
from pytomography.utils import get_blank_below_above, compute_TEW, get_mu_from_spectrum_interp
from ..CT import get_HU2mu_conversion, open_CT_file, compute_max_slice_loc_CT
from ..shared import create_ds
def parse_projection_dataset(ds: Dataset) -> Sequence[torch.Tensor, np.array, np.array, dict]:
    """Gets projections with corresponding radii and angles from a DICOM SPECT projection dataset.

    Args:
        ds (Dataset): pydicom dataset object.

    Returns:
        (torch.tensor[EWindows, TimeWindows, Ltheta, Lr, Lz], np.array, np.array, dict): Returns (i) projection
        data, (ii) angles in degrees (detector convention, sorted ascending), (iii) radii (``RadialPosition``
        divided by 10, i.e. mm -> cm) ordered like the angles, and (iv) flags for whether or not multiple
        energy windows/time slots were detected.
    """
    flags = {'multi_energy_window': False, 'multi_time_slot': False}
    pixel_array = ds.pixel_array
    # Per-frame energy window and detector labels
    energy_window_vector = np.array(ds.EnergyWindowVector)
    detector_vector = np.array(ds.DetectorVector)
    # TimeSlotVector is only present for gated data; otherwise treat every frame as time slot 1
    try:
        time_slot_vector = np.array(ds.TimeSlotVector)
    except AttributeError:
        time_slot_vector = np.ones(len(detector_vector)).astype(int)
    energy_windows = np.unique(energy_window_vector)
    time_slots = np.unique(time_slot_vector)
    # Update flags
    if len(energy_windows) > 1:
        flags['multi_energy_window'] = True
    if len(time_slots) > 1:
        flags['multi_time_slot'] = True
    # Rotation parameters do not depend on the detector, so read them once
    rotation_info = ds.RotationInformationSequence[0]
    n_angles = rotation_info.NumberOfFramesInRotation
    delta_angle = rotation_info.AngularStep
    counter_clockwise = rotation_info.RotationDirection in ('CC', 'CCW')
    # Get radii and angles, detector by detector (frames are grouped per detector)
    radii = np.array([])
    angles = np.array([])
    for detector in np.unique(detector_vector):
        # Prefer a per-detector start angle; fall back to the rotation-level one when absent
        try:
            start_angle = ds.DetectorInformationSequence[detector - 1].StartAngle
        except (AttributeError, IndexError):
            start_angle = rotation_info.StartAngle
        if counter_clockwise:
            detector_angles = start_angle + delta_angle * np.arange(n_angles)
        else:
            detector_angles = start_angle - delta_angle * np.arange(n_angles)
        angles = np.concatenate([angles, detector_angles])
        radial_positions_detector = ds.DetectorInformationSequence[detector - 1].RadialPosition
        # A single radius (circular orbit) is repeated for every angle
        if not isinstance(radial_positions_detector, collections.abc.Sequence):
            radial_positions_detector = n_angles * [radial_positions_detector]
        radii = np.concatenate([radii, radial_positions_detector])
    # Group frames into [energy window][time slot] stacks
    projections = np.array([
        [
            pixel_array[(time_slot_vector == time_slot) & (energy_window_vector == energy_window)]
            for time_slot in time_slots
        ]
        for energy_window in energy_windows
    ])
    angles = (angles + 180) % 360  # to detector angle convention
    sorted_idxs = np.argsort(angles)
    # Order frames by angle, flip each frame's row axis, then swap the last two axes -> (Lr, Lz)
    projections = np.transpose(projections[:, :, sorted_idxs, ::-1], (0, 1, 2, 4, 3)).astype(np.float32)
    projections = torch.tensor(projections.copy()).to(pytomography.dtype).to(pytomography.device)
    return (projections,
            angles[sorted_idxs],
            radii[sorted_idxs] / 10,
            flags)
def get_projections(
    file: str,
    index_peak: None | int = None,
    index_time: None | int = None,
    print_shape: bool = True
) -> torch.Tensor:
    """Gets projections from a .dcm file.

    Args:
        file (str): Path to the .dcm file of SPECT projection data.
        index_peak (int): If not none, then the returned projections correspond to the index of this energy window. Otherwise returns all energy windows. Defaults to None.
        index_time (int): If not none, then the returned projections correspond to the index of the time slot in gated SPECT. Otherwise returns all time slots. Defaults to None
        print_shape (bool): If true, then prints the shape of the projections returned. Defaults to true.

    Returns:
        torch.Tensor: Projections of shape ``[N_energywindows, N_timeslots, Ltheta, Lr, Lz]``. The energy-window
        and/or time-slot axis is removed when only one of them is returned. If both are removed, a leading
        axis of size 1 is added instead (``[1, Ltheta, Lr, Lz]``).
    """
    ds = pydicom.dcmread(file, force=True)
    projections, _, _, flags = parse_projection_dataset(ds)
    if index_peak is not None:
        projections = projections[index_peak].unsqueeze(dim=0)
        flags['multi_energy_window'] = False
    if index_time is not None:
        projections = projections[:, index_time].unsqueeze(dim=1)
        flags['multi_time_slot'] = False
    # Drop singleton time-slot / energy-window axes only; a blanket squeeze() would also remove
    # spatial axes that happen to have size 1.
    if projections.shape[1] == 1:
        projections = projections.squeeze(dim=1)
    if projections.shape[0] == 1:
        projections = projections.squeeze(dim=0)
    dimension_list = ['Ltheta', 'Lr', 'Lz']
    if flags['multi_time_slot']:
        dimension_list = ['N_timeslots'] + dimension_list
        if print_shape: print('Multiple time slots found')
    if flags['multi_energy_window']:
        dimension_list = ['N_energywindows'] + dimension_list
        if print_shape: print('Multiple energy windows found')
    if len(dimension_list) == 3:
        # Single window and single time slot: keep a leading axis of size 1
        dimension_list = ['1'] + dimension_list
        projections = projections.unsqueeze(dim=0)
    if print_shape: print(f'Returned projections have dimensions ({" ".join(dimension_list)})')
    return projections
def get_window_width(ds: Dataset, index: int) -> float:
    """Returns the width (upper limit minus lower limit) of one energy window.

    Args:
        ds (Dataset): DICOM dataset of SPECT projection data.
        index (int): Index into the ``EnergyWindowInformationSequence`` attribute of ``ds``.

    Returns:
        float: Width of the energy window, in the units of the stored limits (keV in DICOM).
    """
    window_range = ds.EnergyWindowInformationSequence[index].EnergyWindowRangeSequence[0]
    lower = window_range.EnergyWindowLowerLimit
    upper = window_range.EnergyWindowUpperLimit
    return upper - lower
def get_scatter_from_TEW(
    file: str,
    index_peak: int,
    index_lower: int,
    index_upper: int
) -> torch.Tensor:
    """Gets an estimate of scatter projection data from a DICOM file using the triple energy window method.

    Args:
        file (str): Filepath of the DICOM file
        index_peak (int): Index of the ``EnergyWindowInformationSequence`` DICOM attribute corresponding to the photopeak.
        index_lower (int): Index of the ``EnergyWindowInformationSequence`` DICOM attribute corresponding to lower scatter window.
        index_upper (int): Index of the ``EnergyWindowInformationSequence`` DICOM attribute corresponding to upper scatter window.

    Returns:
        torch.Tensor: Scatter estimate returned by ``compute_TEW`` for the lower/upper window projections, moved
        to ``pytomography.device``. NOTE(review): the exact shape depends on ``compute_TEW`` and on the number of
        time slots — confirm.
    """
    # Only the energy-window header is needed here; get_projections reads the pixel data itself
    ds = pydicom.dcmread(file, force=True, stop_before_pixels=True)
    ww_peak = get_window_width(ds, index_peak)
    ww_lower = get_window_width(ds, index_lower)
    ww_upper = get_window_width(ds, index_upper)
    projections_all = get_projections(file, print_shape=False)
    scatter = compute_TEW(projections_all[index_lower], projections_all[index_upper], ww_lower, ww_upper, ww_peak)
    return scatter.to(pytomography.device)
def get_attenuation_map_from_file(file_AM: str) -> torch.Tensor:
    """Gets an attenuation map from a DICOM file. This data is usually provided by the manufacturer of the SPECT scanner.

    Args:
        file_AM (str): File name of attenuation map

    Returns:
        torch.Tensor: Tensor of shape [1, Lx, Ly, Lz] corresponding to the attenuation map in units of cm:math:`^{-1}`
    """
    ds = pydicom.dcmread(file_AM, force=True)
    # Some files carry a private (0033,1038) element; when present, stored pixel values are divided by it
    if (0x033, 0x1038) in ds:
        scale_factor = 1 / ds[0x033, 0x1038].value
    else:
        scale_factor = 1
    attenuation_map = ds.pixel_array * scale_factor
    # Reverse the axis order of the stored volume and add a leading batch axis
    return torch.tensor(np.transpose(attenuation_map, (2, 1, 0))).unsqueeze(dim=0).to(pytomography.dtype).to(pytomography.device)
def CT_to_mumap(
    CT: torch.tensor,
    files_CT: Sequence[str],
    file_NM: str,
    index_peak=0
) -> torch.tensor:
    """Converts a CT image to a mu-map given SPECT projection data. The CT data must be aligned with the projection data already; this is a helper function for ``get_attenuation_map_from_CT_slices``.

    The SPECT photon energy is taken as the midpoint of the photopeak energy window. The CT tube voltage is read
    from the ``KVP`` attribute of the first CT slice.

    Args:
        CT (torch.tensor): CT object in units of HU (``get_attenuation_map_from_CT_slices`` passes the numpy
            array returned by ``affine_transform``).
        files_CT (Sequence[str]): Filepaths of all CT slices
        file_NM (str): Filepath of SPECT projection data
        index_peak (int, optional): Index of ``EnergyWindowInformationSequence`` corresponding to the photopeak. Defaults to 0.

    Returns:
        torch.tensor: Attenuation map in units of 1/cm
    """
    # Only header attributes are needed from either file, so skip reading pixel data
    ds_NM = pydicom.dcmread(file_NM, stop_before_pixels=True)
    window_range = ds_NM.EnergyWindowInformationSequence[index_peak].EnergyWindowRangeSequence[0]
    window_lower = window_range.EnergyWindowLowerLimit
    window_upper = window_range.EnergyWindowUpperLimit
    E_SPECT = (window_lower + window_upper) / 2
    KVP = pydicom.dcmread(files_CT[0], stop_before_pixels=True).KVP
    HU2mu_conversion = get_HU2mu_conversion(files_CT, KVP, E_SPECT)
    return HU2mu_conversion(CT)
def get_attenuation_map_from_CT_slices(
    files_CT: Sequence[str],
    file_NM: str | None = None,
    index_peak: int = 0,
    keep_as_HU: bool = False,
    mode: str = 'nearest'
) -> torch.Tensor:
    """Converts a sequence of DICOM CT files (corresponding to a single scan) into a torch.Tensor object usable as an attenuation map in PyTomography.

    Args:
        files_CT (Sequence[str]): List of all files corresponding to an individual CT scan
        file_NM (str): File corresponding to raw PET/SPECT data (required to align CT with projections). If None,
            then no alignment is done. In that case the CT is also returned in Hounsfield units, whatever the value
            of ``keep_as_HU``. Defaults to None.
        index_peak (int, optional): Index corresponding to photopeak in projection data. Defaults to 0.
        keep_as_HU (bool): If True, then don't convert to linear attenuation coefficient and keep as Hounsfield units. Defaults to False
        mode (str): Boundary mode passed to ``scipy.ndimage.affine_transform`` when resampling the CT onto the
            projection grid. Defaults to 'nearest'.

    Returns:
        torch.Tensor: Tensor of shape [1, Lx, Ly, Lz] corresponding to the attenuation map (or CT in HU).
    """
    CT_HU = open_CT_file(files_CT)
    if file_NM is None:
        return torch.tensor(CT_HU[:, :, ::-1].copy()).unsqueeze(dim=0).to(pytomography.dtype).to(pytomography.device)
    # Only Rows/Columns are needed from the projection header
    ds_NM = pydicom.dcmread(file_NM, stop_before_pixels=True)
    # Align with SPECT: M maps projection-grid voxel indices to CT voxel indices via patient coordinates
    M_CT = _get_affine_CT(files_CT)
    M_NM = _get_affine_spect_projections(file_NM)
    M = npl.inv(M_CT) @ M_NM
    # -1500 HU (below air) is meant to fill samples outside the CT so they map to mu=0.
    # NOTE: scipy only uses cval for mode='constant'/'grid-constant'; with the default
    # mode='nearest' the edge values of the CT are extended instead.
    CT_HU = affine_transform(CT_HU, M, output_shape=(ds_NM.Rows, ds_NM.Rows, ds_NM.Columns), mode=mode, cval=-1500)
    if keep_as_HU:
        CT = CT_HU
    else:
        # Converted at the photopeak energy of window ``index_peak``
        CT = CT_to_mumap(CT_HU, files_CT, file_NM, index_peak)
    # Reverse the z axis and add a leading batch axis
    CT = torch.tensor(CT[:, :, ::-1].copy()).unsqueeze(dim=0).to(pytomography.dtype).to(pytomography.device)
    return CT
def _get_affine_spect_projections(filename: str) -> np.array:
    """Computes an affine matrix corresponding to the coordinate system of a SPECT DICOM file of projections.

    The matrix maps homogeneous voxel indices ``(i, j, k, 1)`` of the reconstruction grid defined by the
    projections to patient coordinates (units of ``PixelSpacing``/``ImagePositionPatient``, mm in DICOM). The
    grid is centred in x/y on ``ImagePositionPatient``, with y measured from table height 0.

    Args:
        filename (str): Filepath of the DICOM file of projection data

    Returns:
        np.array: 4x4 affine matrix
    """
    # Note: per DICOM convention z actually decreases as the z-index increases (initial z slices start with the head)
    # Only header attributes are used, so skip reading pixel data
    ds = pydicom.dcmread(filename, stop_before_pixels=True)
    Sx, Sy, Sz = ds.DetectorInformationSequence[0].ImagePositionPatient
    dx = dy = ds.PixelSpacing[0]
    dz = ds.PixelSpacing[1]
    # Shift from the image centre to the corner voxel of the (Rows x Rows) transaxial grid
    Sx -= ds.Rows / 2 * dx
    Sy -= ds.Rows / 2 * dy
    # Y-origin at table height 0
    Sy -= ds.RotationInformationSequence[0].TableHeight
    M = np.zeros((4, 4))
    M[0] = np.array([dx, 0, 0, Sx])
    M[1] = np.array([0, dy, 0, Sy])
    M[2] = np.array([0, 0, -dz, Sz])
    M[3] = np.array([0, 0, 0, 1])
    return M
def _get_affine_CT(filenames: Sequence[str]):
    """Computes an affine matrix corresponding to the coordinate system of a CT scan stored as one DICOM file per slice.

    In-plane orientation, spacing and position are read from the first file. Because the slices are separate
    files, the z-origin is the maximum slice location across all of them, as computed by
    ``compute_max_slice_loc_CT``.

    Args:
        filenames (Sequence[str]): Filepaths of all axial slices that make up the CT scan

    Returns:
        np.array: 4x4 affine matrix corresponding to the CT scan.
    """
    # Note: per DICOM convention z actually decreases as the z-index increases (initial z slices start with the head)
    # Only header attributes are used, so skip reading pixel data
    ds = pydicom.dcmread(filenames[0], stop_before_pixels=True)
    max_z = compute_max_slice_loc_CT(filenames)
    M = np.zeros((4, 4))
    M[0:3, 0] = np.array(ds.ImageOrientationPatient[0:3]) * ds.PixelSpacing[0]
    M[0:3, 1] = np.array(ds.ImageOrientationPatient[3:]) * ds.PixelSpacing[1]
    # z index steps downward by one slice thickness
    M[0:3, 2] = - np.array([0, 0, 1]) * ds.SliceThickness
    M[0:2, 3] = np.array(ds.ImagePositionPatient)[0:2]
    M[2, 3] = max_z
    M[3, 3] = 1
    return M
def stitch_multibed(
    recons: torch.Tensor,
    files_NM: Sequence[str],
    method: str ='midslice'
) -> torch.Tensor:
    """Stitches together multiple reconstructed objects corresponding to different bed positions.

    Args:
        recons (torch.Tensor[n_beds, Lx, Ly, Lz]): Reconstructed objects. The first index of the tensor corresponds to different bed positions
        files_NM (list): List of length ``n_beds`` corresponding to the DICOM file of each reconstruction
        method (str, optional): Method to perform stitching (see https://doi.org/10.1117/12.2254096 for all methods described). Available methods include ``'midslice'``, ``'average'``, ``'crossfade'``, and ``'TEM'`` (transition error minimization).

    Returns:
        torch.Tensor[1, Lx, Ly, Lz']: Stitched together DICOM file. Note the new z-dimension size :math:`L_z'`.

    Raises:
        ValueError: If ``method`` is not one of the available methods.
    """
    if method not in ('midslice', 'average', 'crossfade', 'TEM'):
        raise ValueError(
            f"Unknown stitching method '{method}'; expected 'midslice', 'average', 'crossfade' or 'TEM'."
        )
    # Only header information is needed from each bed position's file
    dss = [pydicom.dcmread(file_NM, stop_before_pixels=True) for file_NM in files_NM]
    zs = np.array([ds.DetectorInformationSequence[0].ImagePositionPatient[-1] for ds in dss])
    # Sort by increasing z-position
    order = np.argsort(zs)
    dss = [dss[i] for i in order]
    zs = zs[order]
    recons = recons[order]
    # Convert to voxel offsets relative to the lowest bed
    zs = np.round((zs - zs[0]) / dss[0].PixelSpacing[1]).astype(int)
    new_z_height = zs[-1] + recons.shape[-1]
    recon_aligned = torch.zeros((1, dss[0].Rows, dss[0].Rows, new_z_height)).to(pytomography.device)
    # get_projections returns the projection tensor itself (not a tuple), so pass it directly
    blank_below, blank_above = get_blank_below_above(get_projections(files_NM[0], print_shape=False))
    for i in range(len(zs)):
        recon_aligned[:, :, :, zs[i] + blank_below:zs[i] + blank_above] = recons[i, :, :, blank_below:blank_above]
    # Apply stitching method to each overlap between consecutive beds
    for i in range(1, len(zs)):
        zmin = zs[i] + blank_below  # first non-blank slice of the upper bed (global index)
        zmax = zs[i - 1] + blank_above  # one past the last non-blank slice of the lower bed (global index)
        dL = zmax - zmin  # number of overlapping slices
        half = round(dL / 2)
        if zmax > zmin + 1:  # at least two voxels apart
            # Overlapping slabs in each bed's own z indices; both have length dL
            r1 = recons[i - 1][:, :, blank_above - dL:blank_above]
            r2 = recons[i][:, :, blank_below:blank_below + dL]
            if method == 'midslice':
                recon_aligned[:, :, :, zmin:zmin + half] = r1[:, :, :half]
                recon_aligned[:, :, :, zmin + half:zmax] = r2[:, :, half:]
            elif method == 'average':
                recon_aligned[:, :, :, zmin:zmax] = 0.5 * (r1 + r2)
            elif method == 'crossfade':
                # Linear weights from the lower bed (r1) to the upper bed (r2) across the overlap
                idx = torch.arange(dL).to(pytomography.device) + 0.5
                recon_aligned[:, :, :, zmin:zmax] = ((dL - idx) * r1 + idx * r2) / dL
            elif method == 'TEM':
                # Per (x, y) column, switch from r1 to r2 at the slice where |r1 - r2| is smallest
                stitch_index = torch.min(torch.abs(r1 - r2), dim=2)[1]
                range_tensor = torch.arange(dL).unsqueeze(0).unsqueeze(0).to(pytomography.device)
                mask_tensor = range_tensor < stitch_index.unsqueeze(-1)
                expanded_mask = mask_tensor.expand(*stitch_index.shape, dL)
                recon_aligned[:, :, :, zmin:zmax][expanded_mask.unsqueeze(0)] = r1[expanded_mask]
                recon_aligned[:, :, :, zmin:zmax][~expanded_mask.unsqueeze(0)] = r2[~expanded_mask]
    return recon_aligned
def save_dcm(
    save_path: str,
    object: torch.Tensor,
    file_NM: str,
    recon_name: str = '',
    scale_factor: float = 1024,
) -> None:
    """Saves the reconstructed object `object` to a series of DICOM files in the folder given by `save_path`. Requires the filepath of the projection data `file_NM` to get Study information.

    One file is written per axial slice, named ``<SOPInstanceUID>.dcm``. Voxel values are multiplied by
    ``scale_factor``, clipped to the unsigned 16-bit range and stored as integers. ``RescaleSlope`` is set to
    ``1/scale_factor`` so readers recover the original values.

    Args:
        save_path (str): Location of folder where to save the DICOM output files. Must not already exist.
        object (torch.Tensor): Reconstructed object of shape [1,Lx,Ly,Lz].
        file_NM (str): File path of the projection data corresponding to the reconstruction.
        recon_name (str): Type of reconstruction performed. Obtained from the `recon_method_str` attribute of a reconstruction algorithm class.
        scale_factor (float, optional): Amount by which to scale output data so that it can be converted into a 16 bit integer. Defaults to 1024.

    Raises:
        FileExistsError: If the folder ``save_path`` already exists.
    """
    try:
        Path(save_path).resolve().mkdir(parents=True, exist_ok=False)
    except FileExistsError as err:
        # Other OS errors (e.g. permissions) propagate unchanged instead of being misreported
        raise FileExistsError(f'Folder {save_path} already exists; new folder name is required.') from err
    # Reorder to (z, y, x) so pixel_data[i] is one axial slice, then scale into the uint16 range
    scaled = (torch.permute(object.squeeze(), (2, 1, 0)) * scale_factor).cpu().numpy()
    uint16_max = np.iinfo(np.uint16).max
    if scaled.size and (scaled.min() < 0 or scaled.max() > uint16_max):
        # A plain astype(np.uint16) would silently wrap out-of-range values
        warnings.warn(
            f'Scaled voxel values fall outside [0, {uint16_max}] and will be clipped; '
            'consider a smaller scale_factor.'
        )
    pixel_data = np.clip(scaled, 0, uint16_max).astype(np.uint16)
    # Get affine information
    ds_NM = pydicom.dcmread(file_NM)
    Sx, Sy, Sz = ds_NM.DetectorInformationSequence[0].ImagePositionPatient
    dx = dy = ds_NM.PixelSpacing[0]
    dz = ds_NM.PixelSpacing[1]
    Sx -= ds_NM.Rows / 2 * dx
    Sy -= ds_NM.Rows / 2 * dy
    # Y-Origin point at tableheight=0
    Sy -= ds_NM.RotationInformationSequence[0].TableHeight
    # Sz now refers to location of lowest slice
    Sz -= (pixel_data.shape[0] - 1) * dz
    SOP_instance_UID = generate_uid()
    SOP_class_UID = '1.2.840.10008.5.1.4.1.1.128'  # Positron Emission Tomography Image Storage SOP class
    modality = 'PT'  # written as a PT series to match the SOP class above
    ds = create_ds(ds_NM, SOP_instance_UID, SOP_class_UID, modality)
    ds.Rows, ds.Columns = pixel_data.shape[1:]
    ds.SeriesNumber = 1
    ds.NumberOfSlices = pixel_data.shape[0]
    ds.PixelSpacing = [dx, dy]
    ds.SliceThickness = dz
    ds.ImageOrientationPatient = [1, 0, 0, 0, 1, 0]
    ds.RescaleSlope = 1 / scale_factor
    # Set other things
    ds.BitsAllocated = 16
    ds.BitsStored = 16
    ds.SamplesPerPixel = 1
    ds.PhotometricInterpretation = 'MONOCHROME2'
    ds.PixelRepresentation = 0
    ds.ReconstructionMethod = recon_name
    # Per-slice UIDs replace the last three characters of the series UID with the 1-based slice number
    base_UID = ds.SOPInstanceUID[:-3]
    # Create all slices
    for i, slice_pixels in enumerate(pixel_data):
        ds_i = ds.copy()
        ds_i.InstanceNumber = i + 1
        ds_i.ImagePositionPatient = [Sx, Sy, Sz + i * dz]
        SOP_instance_UID_slice = f'{base_UID}{i+1:03d}'
        ds_i.SOPInstanceUID = SOP_instance_UID_slice
        ds_i.file_meta.MediaStorageSOPInstanceUID = SOP_instance_UID_slice
        # Set the pixel data
        ds_i.PixelData = slice_pixels.tobytes()
        # Name the file after this slice's own UID so slices never overwrite each other
        ds_i.save_as(os.path.join(save_path, f'{SOP_instance_UID_slice}.dcm'))