from __future__ import annotations
from collections.abc import Sequence
import pytomography
from pytomography.metadata import ObjectMeta
import torch
import numpy as np
import numpy.linalg as npl
import uproot
from scipy.ndimage import affine_transform
from ..shared import get_header_value, get_attenuation_map_interfile
def _centered_voxel_affine(voxel_size, shape) -> np.ndarray:
    """Return the 4x4 homogeneous matrix mapping voxel indices to mm for a grid centered on the origin."""
    voxel_size = np.asarray(voxel_size, dtype=float)
    # Center of voxel 0 sits half a grid-width (minus half a voxel) below the origin on every axis
    origin = (-np.asarray(shape) / 2 + 0.5) * voxel_size
    affine = np.eye(4)
    affine[:3, :3] = np.diag(voxel_size)
    affine[:3, 3] = origin
    return affine


def get_aligned_attenuation_map(
    headerfile: str,
    object_meta: ObjectMeta
) -> torch.tensor:
    """Returns an aligned attenuation map in units of inverse mm for reconstruction.

    This assumes that the attenuation map shares the same center point with the
    reconstruction space. The map is resampled onto the reconstruction grid with
    linear interpolation (``order=1``).

    Args:
        headerfile (str): Filepath to the interfile header of the attenuation map. Voxel
            sizes are read from its ``scaling factor (mm/pixel) [1..3]`` entries.
        object_meta (ObjectMeta): Object metadata providing spatial information about the
            reconstructed dimensions (``dr`` and ``shape``).

    Returns:
        torch.Tensor: Aligned attenuation map of shape ``(1, *object_meta.shape)`` on
        ``pytomography.device``.
    """
    amap = get_attenuation_map_interfile(headerfile)[0].cpu().numpy()
    with open(headerfile) as f:
        headerdata = np.array(f.readlines())
    dr_amap = tuple(
        get_header_value(headerdata, f'scaling factor (mm/pixel) [{axis}]', np.float32)
        for axis in (1, 2, 3)
    )
    M_CT = _centered_voxel_affine(dr_amap, amap.shape)
    M_PET = _centered_voxel_affine(object_meta.dr, object_meta.shape)
    # affine_transform expects the output->input mapping:
    # PET voxel index -> mm (M_PET) -> attenuation-map voxel index (inv(M_CT))
    amap = affine_transform(amap, npl.inv(M_CT) @ M_PET, output_shape=object_meta.shape, order=1)
    # Original note says this yields mm^-1, i.e. the stored map is presumably in cm^-1
    return torch.tensor(amap, device=pytomography.device).unsqueeze(0) / 10
def _centered_offsets(step, count) -> np.ndarray:
    """Return ``count`` multiples of ``step`` centered on zero."""
    return step * (np.arange(0, count) - (count - 1) / 2)


def _offset_copies(coords, offsets):
    """Copy every current crystal once per offset; after ravel the new copy index varies fastest."""
    return tuple(
        (c[:, np.newaxis] + o).ravel()
        for c, o in zip(coords, offsets)
    )


def _ring_copies(coords, phi):
    """Rotate every current crystal about the z-axis by each angle in ``phi``; the angle index varies fastest."""
    x, y, z = (c[:, np.newaxis] for c in coords)
    cos_phi, sin_phi = np.cos(phi), np.sin(phi)
    return (
        (cos_phi * x - sin_phi * y).ravel(),
        (sin_phi * x + cos_phi * y).ravel(),
        np.repeat(z, len(phi), axis=1).ravel(),
    )


def get_scanner_LUT(
    path: str,
    init_volume_name: str = 'crystal',
    final_volume_name: str = 'world',
    mean_interaction_depth: float = 0,
    return_info: bool = False
) -> torch.Tensor | tuple[torch.Tensor, dict] | None:
    """Returns the scanner lookup table.

    Row ``i`` of the lookup table holds the x, y, and z position of the detector whose
    id is ``i``. The accumulated coordinates are negated before being returned
    (presumably to match the reconstruction coordinate convention; confirm).

    Args:
        path (str): Path to the .mac file where the scanner geometry is defined in GATE.
        init_volume_name (str, optional): Volume name corresponding to the lowest-level element in the GATE geometry. Defaults to 'crystal'.
        final_volume_name (str, optional): Volume name corresponding to the highest-level element in the GATE geometry. Defaults to 'world'.
        mean_interaction_depth (float, optional): Average interaction depth of photons in the crystals in mm. Defaults to 0.
        return_info (bool, optional): Also return a dict mapping each volume that has repeaters
            to the total number of copies its repeaters produce. Defaults to False.

    Returns:
        torch.Tensor: ``(N_detectors, 3)`` scanner lookup table (plus the info dict if
        ``return_info``), or ``None`` if the volume hierarchy from ``init_volume_name`` to
        ``final_volume_name`` cannot be resolved.
    """
    with open(path) as f:
        headerdata = np.array(f.readlines())

    # Walk up the volume hierarchy (daughter -> mother) until the final volume is reached
    parents = [init_volume_name]
    volume = init_volume_name
    while volume != final_volume_name:
        volume = get_header_value(headerdata, f'/daughters/name.*{volume}', split_substr='/', split_idx=2, dtype=str)
        if not volume:
            return None
        parents.append(volume)

    # Position of the first crystal: sum of the placement translations of every level
    positions = []
    for parent in parents:
        try:
            x, y, z = (
                get_header_value(headerdata, f'/gate/{parent}/placement/setTranslation', split_substr=None, split_idx=i)
                for i in (1, 2, 3)
            )
        except Exception:
            # Best effort: an unparseable translation is treated as no offset
            x = y = z = 0
        positions.append([x, y, z])
    x0, y0, z0 = np.array(positions).sum(axis=0)
    x_crystal = np.array([x0])
    # TODO (from original): assumes the crystal is initially oriented along +x. Shift from
    # the crystal center to its -x face, then add the mean interaction depth.
    crystal_length = get_header_value(headerdata, f'/gate/{init_volume_name}/geometry/setXLength', split_substr=None, split_idx=1)
    coords = (x_crystal - crystal_length / 2 + mean_interaction_depth, np.array([y0]), np.array([z0]))

    # Apply every repeater, from the innermost volume outwards
    info = {}
    for parent in parents:
        repeaters = get_header_value(headerdata, f'/gate/{parent}/repeaters/insert', split_substr=None, split_idx=1, dtype=str, return_all=True)
        if not repeaters:
            continue
        for repeater in repeaters:
            prefix = f'/gate/{parent}/{repeater}'
            if repeater == 'cubicArray':
                counts = [get_header_value(headerdata, f'{prefix}/setRepeatNumber{axis}', split_substr=None, split_idx=1) for axis in ('X', 'Y', 'Z')]
                step = [get_header_value(headerdata, f'{prefix}/setRepeatVector', split_substr=None, split_idx=i) for i in (1, 2, 3)]
                # meshgrid(x, z, y) with default 'xy' indexing gives arrays of shape (Nz, Nx, Ny),
                # so after ravel the copy index runs y fastest, then x, then z.
                # (Original note: "GATE convention is that z changes first".)
                xv, zv, yv = np.meshgrid(
                    _centered_offsets(step[0], counts[0]),
                    _centered_offsets(step[2], counts[2]),
                    _centered_offsets(step[1], counts[1]),
                )
                coords = _offset_copies(coords, (xv.ravel(), yv.ravel(), zv.ravel()))
                n_copies = int(np.prod(counts))
            elif repeater == 'linear':
                count = get_header_value(headerdata, f'{prefix}/setRepeatNumber', split_substr=None, split_idx=1)
                step = [get_header_value(headerdata, f'{prefix}/setRepeatVector', split_substr=None, split_idx=i) for i in (1, 2, 3)]
                coords = _offset_copies(coords, [_centered_offsets(s, count) for s in step])
                n_copies = int(count)
            elif repeater == 'ring':
                count = get_header_value(headerdata, f'{prefix}/setRepeatNumber', split_substr=None, split_idx=1)
                first_angle_deg = get_header_value(headerdata, f'{prefix}/setFirstAngle', split_substr=None, split_idx=1)
                # Check for a missing value before converting degrees to radians
                first_angle = first_angle_deg * np.pi / 180 if first_angle_deg else 0
                phi = np.linspace(first_angle, first_angle + 2 * np.pi, int(count), endpoint=False)
                coords = _ring_copies(coords, phi)
                n_copies = int(count)
            else:
                # Unsupported repeater types are skipped; positions will not reflect them
                continue
            # Key counts by the volume they belong to; multiple repeaters on one volume multiply
            info[parent] = info.get(parent, 1) * n_copies

    scanner_LUT = torch.tensor(-np.vstack(coords).T)
    if return_info:
        return scanner_LUT, info
    return scanner_LUT
def get_N_components(mac_file: str) -> tuple:
    """Obtains the number of gantries, rsectors, modules, submodules, and crystals per level from a GATE macro file.

    Args:
        mac_file (str): Path to the GATE macro file.

    Returns:
        tuple: number of gantries (always 1), rsectors, modules, submodules (1 if the
        geometry has no submodule repeater), and crystals.

    Raises:
        ValueError: If the volume hierarchy in ``mac_file`` cannot be resolved.
        KeyError: If the 'module', 'rsector' or 'crystal' repeat count is missing.
    """
    scanner = get_scanner_LUT(mac_file, return_info=True)
    if scanner is None:
        raise ValueError(f'Could not resolve the GATE volume hierarchy in {mac_file}')
    geom_info = scanner[1]
    N_gantry = 1
    N_module = geom_info['module']
    # The submodule level is optional in GATE geometries
    N_submodule = geom_info.get('submodule', 1)
    N_rsector = geom_info['rsector']
    N_crystal = geom_info['crystal']
    return N_gantry, N_rsector, N_module, N_submodule, N_crystal
def get_detector_ids(
    paths: Sequence[str],
    mac_file: str,
    TOF: bool = False,
    TOF_bin_edges: np.array = None,
    substr: str = 'Coincidences',
    same_source_pos: bool = False
) -> np.array:
    """Obtains the detector IDs from a sequence of ROOT files.

    Each detector ID is
    ``((((crystal * N_submodule + submodule) * N_module + module) * N_rsector + rsector) * N_gantry + gantry)``.
    IDs are stored as int16 when every ID fits, otherwise in the narrowest wider integer type.

    Args:
        paths (Sequence[str]): sequence of root file paths
        mac_file (str): GATE geometry macro file
        TOF (bool, optional): Whether or not to get TOF binning information. Defaults to False.
        TOF_bin_edges (np.array, optional): TOF bin edges in mm; required if TOF is True. Defaults to None.
        substr (str, optional): Substring to index for in ROOT files. Defaults to 'Coincidences'.
        same_source_pos (bool, optional): Only include coincidences that correspond to the same source position. This can be used to filter randoms. Defaults to False.

    Returns:
        np.array: ``(N, 2)`` array of detector ID pairs for all detected LORs, or ``(N, 3)``
        with the TOF bin index as last column when ``TOF`` is True. Empty ``paths`` gives N = 0.

    Raises:
        ValueError: If ``TOF`` is True but ``TOF_bin_edges`` is None.
    """
    if TOF and TOF_bin_edges is None:
        raise ValueError('If using TOF, must provide TOF bin edges for binning')
    N_gantry, N_rsector, N_module, N_submodule, N_crystal = get_N_components(mac_file)
    N_detectors = N_gantry * N_rsector * N_module * N_submodule * N_crystal
    # int16 overflows silently beyond 32767 detectors; widen only when needed
    id_dtype = next(
        dtype for dtype in (np.int16, np.int32, np.int64)
        if N_detectors - 1 <= np.iinfo(dtype).max
    )
    first_ids, second_ids, tof_bins = [], [], []
    for path in paths:
        with uproot.open(path) as f:
            tree = f[substr]
            if same_source_pos:
                same_location_idxs = (
                    (tree['sourcePosX1'].array(library='np') == tree['sourcePosX2'].array(library='np'))
                    & (tree['sourcePosY1'].array(library='np') == tree['sourcePosY2'].array(library='np'))
                    & (tree['sourcePosZ1'].array(library='np') == tree['sourcePosZ2'].array(library='np'))
                )
            for j, out in ((1, first_ids), (2, second_ids)):
                detector_id = (
                    tree[f'crystalID{j}'].array(library="np") * N_submodule * N_module * N_rsector * N_gantry
                    + tree[f'submoduleID{j}'].array(library="np") * N_module * N_rsector * N_gantry
                    + tree[f'moduleID{j}'].array(library="np") * N_rsector * N_gantry
                    + tree[f'rsectorID{j}'].array(library="np") * N_gantry
                    + tree[f'gantryID{j}'].array(library="np")
                )
                if same_source_pos:
                    detector_id = detector_id[same_location_idxs]
                out.append(detector_id.astype(id_dtype))
            if TOF:
                t1 = tree['time1'].array(library='np')
                t2 = tree['time2'].array(library='np')
                # 1e12 presumably converts GATE times (s) to ps; 0.15 mm/ps ~ c/2 gives mm
                tof_pos = 1e12 * (t2 - t1) * 0.15
                tof_bin = np.digitize(-tof_pos, TOF_bin_edges) - 1
                if same_source_pos:
                    tof_bin = tof_bin[same_location_idxs]
                tof_bins.append(tof_bin)
    if not first_ids:
        # No files: np.concatenate would fail on an empty list
        return np.empty((0, 3 if TOF else 2), dtype=np.intp if TOF else id_dtype)
    columns = [first_ids, second_ids] + ([tof_bins] if TOF else [])
    return np.array([np.concatenate(chunks) for chunks in columns]).T
def get_radius(detector_ids: torch.tensor, scanner_LUT: torch.tensor) -> torch.tensor:
    """Gets the transaxial distance of every LOR from the z-axis.

    Each LOR is projected onto the x-y plane. The result is the distance from the origin
    to the line through the two projected endpoints. When both endpoints project onto
    the same x-y point, that point's distance from the origin is used instead.

    Args:
        detector_ids (torch.tensor): Detector ID pairs corresponding to LORs (columns 0 and 1 are used)
        scanner_LUT (torch.tensor): scanner look up table with x, y, z per detector

    Returns:
        torch.tensor: radii of all detector ID pairs provided
    """
    start = scanner_LUT[detector_ids[:, 0]]
    end = scanner_LUT[detector_ids[:, 1]]
    xa, ya = start[:, 0], start[:, 1]
    xb, yb = end[:, 0], end[:, 1]
    coincident = (xa == xb) & (ya == yb)
    point_radius = torch.sqrt(xa ** 2 + ya ** 2)
    # |p1 x p2| / |p1 - p2|: distance from the origin to the line through p1 and p2
    line_radius = torch.abs(xa * yb - ya * xb) / torch.sqrt((xa - xb) ** 2 + (ya - yb) ** 2)
    return torch.where(coincident, point_radius, line_radius)
def get_table(det_ids: torch.tensor, mac_file: str) -> torch.tensor:
    r"""Obtains the symmetry class of each LOR, used when computing normalization :math:`\eta`.

    For each detector pair the table lists crystal1ID, crystal2ID, submodule1ID,
    submodule2ID, :math:`\Delta` moduleID and :math:`\Delta` rsectorID. Here :math:`\Delta`
    is the absolute difference between the two detectors. Pairs are first ordered so the
    smaller detector ID comes first.

    Args:
        det_ids (torch.tensor): :math:`N \times 2` (non-TOF) or :math:`N \times 3` (TOF) tensor that provides detector ID pairs (and TOF bin) for coincidence events. Only the first two columns are used.
        mac_file (str): GATE macro file that defines detector geometry

    Returns:
        torch.tensor: :math:`N \times 6` tensor with the columns listed above.
    """
    N_gantry, N_rsector, N_module, N_submodule, N_crystal = get_N_components(mac_file)
    # Ignore a TOF column (if any) so it cannot be mistaken for a detector ID
    pairs = det_ids[:, :2]
    # The LOR is symmetric: put the larger ID second
    pairs = torch.stack([pairs.min(dim=1).values, pairs.max(dim=1).values], dim=1)
    per_crystal = N_submodule * N_module * N_rsector * N_gantry
    per_submodule = N_module * N_rsector * N_gantry
    per_module = N_rsector * N_gantry
    cry_ids = pairs // per_crystal
    subM_ids = pairs % per_crystal // per_submodule
    M_ids = pairs % per_submodule // per_module
    R_ids = pairs % per_module // N_gantry
    deltaM_ids = torch.abs(torch.diff(M_ids, dim=1))
    deltaR_ids = torch.abs(torch.diff(R_ids, dim=1))
    return torch.cat([cry_ids, subM_ids, deltaM_ids, deltaR_ids], dim=1)
def get_eta_cylinder_calibration(
    paths: Sequence[str],
    mac_file: str,
    cylinder_radius: float,
    same_source_pos: bool = False,
    mean_interaction_depth: float = 0
) -> torch.tensor:
    r"""Obtain normalization :math:`\eta` from a calibration scan consisting of a cylindrical shell.

    Coincidence counts are histogrammed by the symmetry classes of ``get_table``. Each
    class count is divided by the number of LORs in that class. The result is multiplied
    by the geometric factor :math:`1/(\sqrt{1-(r/R)^2}+\delta)`, where :math:`r` is the
    LOR's transaxial distance from the z-axis and :math:`R` is ``cylinder_radius``. LORs
    with :math:`r > R` yield NaN (square root of a negative number).

    Args:
        paths (Sequence[str]): paths of all ROOT files containing data
        mac_file (str): GATE macro file that defines scanner geometry
        cylinder_radius (float): The radius of the cylindrical shell used in calibration
        same_source_pos (bool, optional): Only include coincidence events with same source position; can be used to filter out randoms. Defaults to False.
        mean_interaction_depth (float, optional): Mean interaction depth of photons in detector crystals. Defaults to 0.

    Returns:
        torch.tensor: :math:`\eta` for every detector pair, in the order of
        ``torch.combinations(torch.arange(N_detectors), 2)``.
    """
    N_gantry, N_rsector, N_module, N_submodule, N_crystal = get_N_components(mac_file)
    N_detectors = N_gantry * N_rsector * N_module * N_submodule * N_crystal
    all_LOR_ids = torch.combinations(torch.arange(N_detectors).to(torch.int32), 2)
    # Geometry correction factor for non-uniform exposure from the cylindrical shell.
    # NOTE(review): counts along an LOR grow with its path length through the shell
    # (~1/sqrt(1-(r/R)^2)); confirm that multiplying (not dividing) matches how eta is consumed.
    scanner_LUT = torch.as_tensor(get_scanner_LUT(mac_file, mean_interaction_depth=mean_interaction_depth))
    radii = get_radius(all_LOR_ids, scanner_LUT)
    geometric_correction_factor = 1 / (torch.sqrt(1 - (radii / cylinder_radius) ** 2) + pytomography.delta)
    # Detector correction factor (exploits symmetries). Bin edges are centered on integer class values.
    bins = [
        torch.arange(n + 1).to(torch.float32) - 0.5
        for n in (N_crystal, N_crystal, N_submodule, N_submodule, N_module, N_rsector)
    ]
    H = torch.zeros(N_crystal, N_crystal, N_submodule, N_submodule, N_module, N_rsector)
    for path in paths:
        det_ids = torch.tensor(get_detector_ids([path], mac_file, same_source_pos=same_source_pos))
        vals = get_table(det_ids, mac_file)
        H += torch.histogramdd(vals.to(torch.float32), bins)[0]
    vals_all_pairs = get_table(all_LOR_ids, mac_file)
    N_bins = torch.histogramdd(vals_all_pairs.to(torch.float32), bins)[0]
    # If you want to test this later, also return H and N_bins separately
    eta_per_class = H / N_bins
    return eta_per_class[tuple(vals_all_pairs.unbind(dim=1))] * geometric_correction_factor
def remove_events_out_of_bounds(
    detector_ids: torch.tensor,
    scanner_LUT: torch.tensor,
    object_meta: 'ObjectMeta'
) -> torch.tensor:
    r"""Removes all detected LORs that do not intersect the reconstructed volume given by ``object_meta``.

    The volume is the axis-aligned box centered on the origin with half-widths
    ``object_meta.shape * object_meta.dr / 2``. Each LOR is treated as the infinite line
    through its two detector positions and tested against the box with the slab method.
    The line misses the box when its latest slab entry comes after its earliest slab exit.

    Args:
        detector_ids (torch.tensor): :math:`N \times 2` (non-TOF) or :math:`N \times 3` (TOF) tensor that provides detector ID pairs (and TOF bin) for coincidence events.
        scanner_LUT (torch.tensor): scanner lookup table that provides spatial coordinates for all detector ID pairs
        object_meta (ObjectMeta): object metadata providing the region of reconstruction

    Returns:
        torch.tensor: the rows of ``detector_ids`` whose LORs intersect the volume
    """
    origin = scanner_LUT[detector_ids[:, 0]]
    direction = scanner_LUT[detector_ids[:, 1]] - origin
    half_extent = torch.tensor(object_meta.shape) * torch.tensor(object_meta.dr) / 2
    half_extent = half_extent.to(origin.device)
    # Line parameters at which each coordinate crosses the two faces of its slab.
    # A zero direction component gives +-inf, which min/max order correctly even for -0.0.
    t_a = (-half_extent - origin) / direction
    t_b = (half_extent - origin) / direction
    t_enter = torch.minimum(t_a, t_b).amax(dim=1)
    t_exit = torch.maximum(t_a, t_b).amin(dim=1)
    misses = t_enter > t_exit
    return detector_ids[~misses]