from __future__ import annotations
import torch
import torch.nn as nn
import numpy as np
import pytomography
from pytomography.utils import get_distance, compute_pad_size, pad_object, pad_object_z, unpad_object_z, rotate_detector_z, unpad_object
from pytomography.mappings import MapNet
from pytomography.metadata import ObjectMeta, ImageMeta, PSFMeta
from torchvision.transforms import InterpolationMode
class SPECTPSFNet(MapNet):
    r"""obj2obj network used to model the effects of PSF blurring in SPECT.

    The smoothing kernel used to apply PSF modeling is a Gaussian kernel with
    width :math:`\sigma` dependent on the distance of the point to the
    detector; that information is specified in the ``PSFMeta`` parameter.

    Args:
        psf_meta (PSFMeta): Metadata corresponding to the parameters of PSF blurring.
        device (str, optional): Pytorch device used for computation. If None,
            uses the default device ``pytomography.device``. Defaults to None.
    """

    def __init__(
        self,
        psf_meta: PSFMeta,
        device: str | None = None,
    ) -> None:
        """Initializer that stores the PSF parameters.

        Args:
            psf_meta (PSFMeta): Metadata corresponding to the parameters of PSF blurring.
            device (str, optional): Device forwarded to the ``MapNet`` initializer.
                Defaults to None.
        """
        super().__init__(device)
        self.psf_meta = psf_meta

    def initialize_network(
        self,
        object_meta: ObjectMeta,
        image_meta: ImageMeta,
    ) -> None:
        """Initialize the mapping network using the corresponding object and image metadata.

        Computes the kernel size, then builds one blurring layer per unique
        detector radius in ``image_meta.radii`` and stores them in
        ``self.layers`` keyed by radius.

        Args:
            object_meta (ObjectMeta): Object metadata.
            image_meta (ImageMeta): Image metadata.
        """
        # Presumably the parent call stores object_meta / image_meta on self;
        # compute_kernel_size() and forward() read them from there.
        super().initialize_network(object_meta, image_meta)
        self.kernel_size = self.compute_kernel_size()
        self.layers = {}
        # Projections sharing a radius share a layer, so build each only once.
        for radius in np.unique(image_meta.radii):
            sigma = self.get_sigma(
                radius,
                object_meta.dx,
                object_meta.shape,
                self.psf_meta.collimator_slope,
                self.psf_meta.collimator_intercept,
            )
            # sigma is divided by dx so the transform receives it in pixel units.
            self.layers[radius] = get_PSF_transform(
                sigma / object_meta.dx,
                self.kernel_size,
                kernel_dimensions=self.psf_meta.kernel_dimensions,
                device=self.device,
            )

    def compute_kernel_size(self) -> int:
        """Compute the kernel size used for PSF blurring.

        Uses the ``max_sigmas`` attribute of ``PSFMeta`` so that the kernel
        spans ``max_sigmas`` times the largest sigma on each side of its
        centre. The largest sigma is evaluated at distance ``s/2 * dx``, where
        ``s`` is the first padded object dimension.

        Returns:
            int: The corresponding (odd) kernel size used for PSF blurring.
        """
        # NOTE(review): pixel spacing is read from ``dr[0]`` here but from
        # ``dx`` in initialize_network -- confirm the two agree.
        s = self.object_meta.padded_shape[0]
        dx = self.object_meta.dr[0]
        largest_sigma = (
            self.psf_meta.collimator_slope * (s / 2 * dx)
            + self.psf_meta.collimator_intercept
        )
        # 2 * half_width + 1 keeps the kernel centred on a pixel.
        return int(np.round(largest_sigma / dx * self.psf_meta.max_sigmas) * 2 + 1)

    def get_sigma(
        self,
        radius: float,
        dx: float,
        shape: tuple,
        collimator_slope: float,
        collimator_intercept: float,
    ) -> np.ndarray:
        r"""Get the blurring :math:`\sigma` as a function of distance from the detector.

        It is assumed that ``sigma = collimator_slope*d + collimator_intercept``
        where :math:`d` is the distance from the detector.

        Args:
            radius (float): Detector radius for the projection (an entry of
                ``image_meta.radii``); forwarded to ``get_distance``.
            dx (float): Transaxial plane pixel spacing.
            shape (tuple): Tuple containing (Lx, Ly, Lz): dimensions of object space.
            collimator_slope (float): Slope of the linear sigma(d) relation.
            collimator_intercept (float): Intercept of the linear sigma(d) relation.

        Returns:
            np.ndarray: Sigma evaluated at each distance returned by
            ``get_distance`` -- presumably one value per position along the
            padded first axis of object space.
        """
        # Distances are computed over the padded first dimension.
        dim = shape[0] + 2 * compute_pad_size(shape[0])
        distances = get_distance(dim, radius, dx)
        return collimator_slope * distances + collimator_intercept

    def _apply_psf(self, tensor: torch.Tensor, psf_layer, z_pad_size: int) -> torch.Tensor:
        """Pad ``tensor`` along z, apply ``psf_layer``, then remove the padding."""
        tensor = pad_object_z(tensor, z_pad_size)
        tensor = psf_layer(tensor)
        return unpad_object_z(tensor, pad_size=z_pad_size)

    @torch.no_grad()
    def forward(
        self,
        object_i: torch.Tensor,
        i: int,
        norm_constant: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apply PSF modeling for an object being detected by a detector at the :math:`+x` axis.

        Args:
            object_i (torch.Tensor): Tensor of size [batch_size, Lx, Ly, Lz]
                being projected along its first axis.
            i (int): The projection index; used to look up the detector radius
                (``image_meta.radii[i]``) and hence the blurring layer.
            norm_constant (torch.Tensor, optional): A tensor used to normalize
                the output during back projection. Defaults to None.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The blurred
            object, of size [batch_size, Lx, Ly, Lz], such that projection
            along the first axis corresponds to a PSF corrected projection.
            When ``norm_constant`` is given, a tuple
            ``(object_i, norm_constant)`` is returned with the same blurring
            applied to both.
        """
        # Pad z by the kernel half-width before blurring.
        z_pad_size = (self.kernel_size - 1) // 2
        psf_layer = self.layers[self.image_meta.radii[i]]
        object_i = self._apply_psf(object_i, psf_layer, z_pad_size)
        if norm_constant is None:
            return object_i
        # Adjust normalization constant with the same blurring.
        norm_constant = self._apply_psf(norm_constant, psf_layer, z_pad_size)
        return object_i, norm_constant