Source code for pytomography.transforms.SPECT.atteunation

from __future__ import annotations
from typing import Sequence
import numpy.linalg as npl
import torch
import torch.nn as nn
from pytomography.utils import rotate_detector_z, rev_cumsum, pad_object
from pytomography.transforms import Transform
from pytomography.io.SPECT import open_CT_file
from pytomography.metadata import SPECTObjectMeta, SPECTImageMeta
from pytomography.io.SPECT import get_attenuation_map_from_CT_slices

def get_prob_of_detection_matrix(attenuation_map: torch.Tensor, dx: float) -> torch.Tensor:
    r"""Converts an attenuation map of :math:`\text{cm}^{-1}` to a probability of photon
    detection matrix (scanner at +x).

    Note that this requires the attenuation map to be at the energy of the photons being
    emitted.

    Args:
        attenuation_map (torch.Tensor): Tensor of size [batch_size, Lx, Ly, Lz] corresponding
            to the attenuation coefficient in :math:`\text{cm}^{-1}`.
        dx (float): Axial plane pixel spacing (presumably in cm, so that
            ``attenuation_map * dx`` is unitless -- confirm against callers).

    Returns:
        torch.Tensor: Tensor of size [batch_size, Lx, Ly, Lz] corresponding to the probability
        of a photon being detected at the detector on the +x axis.
    """
    # Per-voxel optical depth (mu * dx) is accumulated with ``rev_cumsum`` (a reverse
    # cumulative sum; see its definition for the exact axis and boundary handling), and the
    # exponential of its negative gives the Beer-Lambert transmission probability.
    return torch.exp(-rev_cumsum(attenuation_map * dx))
class SPECTAttenuationTransform(Transform):
    r"""obj2obj transform used to model the effects of attenuation in SPECT.

    Args:
        attenuation_map (torch.Tensor, optional): Tensor of size [batch_size, Lx, Ly, Lz]
            corresponding to the attenuation coefficient in :math:`\text{cm}^{-1}` at the
            photon energy corresponding to the particular scan. Assumed to already be aligned
            with the SPECT projections. Defaults to None.
        filepath (Sequence[str], optional): Paths to the CT slice files from which the
            attenuation map is built (and aligned with the SPECT projections) when
            :meth:`configure` is called. If both arguments are given, ``filepath`` takes
            precedence and ``attenuation_map`` is ignored. Defaults to None.

    Raises:
        ValueError: If neither ``attenuation_map`` nor ``filepath`` is supplied.
    """

    def __init__(
        self,
        attenuation_map: torch.Tensor | None = None,
        filepath: Sequence[str] | None = None,
    ) -> None:
        super().__init__()
        # ``configure`` checks this attribute to decide whether to build the map from CT.
        self.filepath = filepath
        if attenuation_map is None and filepath is None:
            raise ValueError(
                "Please supply only one of `attenuation_map` or `filepath` as arguments"
            )
        if filepath is None:
            # Assumes CT is aligned with SPECT projections
            self.attenuation_map = attenuation_map.to(self.device)
        else:
            # TODO: offer support for all input types
            # The raw CT is stored here but is not used within this class; the aligned
            # attenuation map is built from ``self.filepath`` in ``configure``.
            self.CT_unaligned_numpy = open_CT_file(filepath)

    def configure(
        self,
        object_meta: SPECTObjectMeta,
        image_meta: SPECTImageMeta,
    ) -> None:
        """Function used to initialize the transform using corresponding object and image metadata.

        When the transform was constructed from CT files, this is where the attenuation map
        is built from them, using ``image_meta.filepath`` and ``image_meta.index_peak``.

        Args:
            object_meta (SPECTObjectMeta): Object metadata.
            image_meta (SPECTImageMeta): Image metadata.
        """
        super().configure(object_meta, image_meta)
        # TODO: handle the case where the CT extends beyond the SPECT boundaries.
        if self.filepath is not None:
            # Align CT with SPECT and rescale units.
            attenuation_map = get_attenuation_map_from_CT_slices(
                self.filepath, image_meta.filepath, image_meta.index_peak
            )
            # Place on the transform's device, matching the ``attenuation_map`` constructor
            # path (a no-op if the tensor is already there).
            self.attenuation_map = attenuation_map.to(self.device)

    def _compute_norm_factor(
        self,
        object_i: torch.Tensor,
        ang_idx: torch.Tensor,
    ) -> torch.Tensor:
        """Detection-probability factors for each entry of ``object_i`` at angles ``ang_idx``.

        Shared by :meth:`forward` and :meth:`backward`, which apply the same diagonal factor.
        """
        attenuation_map = pad_object(self.attenuation_map)
        # One copy of the map per batch entry (assumes the stored map has batch size 1 --
        # TODO confirm), rotated with the angles selected by ``ang_idx``.
        rotated_map = rotate_detector_z(
            attenuation_map.repeat(object_i.shape[0], 1, 1, 1),
            self.image_meta.angles[ang_idx],
        )
        return get_prob_of_detection_matrix(rotated_map, self.object_meta.dx)

    @torch.no_grad()
    def forward(
        self,
        object_i: torch.Tensor,
        ang_idx: torch.Tensor,
    ) -> torch.Tensor:
        r"""Forward projection :math:`A:\mathbb{U} \to \mathbb{U}` of attenuation correction.

        Args:
            object_i (torch.Tensor): Tensor of size [batch_size, Lx, Ly, Lz] being projected
                along ``axis=1``. It is modified in place.
            ang_idx (torch.Tensor): The projection indices: used to find the corresponding
                angle in image space corresponding to each projection angle in ``object_i``.

        Returns:
            torch.Tensor: ``object_i`` (the same tensor, scaled in place) of size
            [batch_size, Lx, Ly, Lz] such that projection of this tensor along the first axis
            corresponds to an attenuation corrected projection.
        """
        object_i *= self._compute_norm_factor(object_i, ang_idx)
        return object_i

    @torch.no_grad()
    def backward(
        self,
        object_i: torch.Tensor,
        ang_idx: torch.Tensor,
        norm_constant: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        r"""Back projection :math:`A^T:\mathbb{U} \to \mathbb{U}` of attenuation correction.

        Since the matrix is diagonal, the implementation is the same as forward projection.
        The only difference is the optional normalization parameter.

        Args:
            object_i (torch.Tensor): Tensor of size [batch_size, Lx, Ly, Lz] being projected
                along ``axis=1``. It is modified in place.
            ang_idx (torch.Tensor): The projection indices: used to find the corresponding
                angle in image space corresponding to each projection angle in ``object_i``.
            norm_constant (torch.Tensor, optional): A tensor used to normalize the output
                during back projection; if given, it is also scaled in place. Defaults to None.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ``object_i`` scaled in place, or
            ``(object_i, norm_constant)`` when ``norm_constant`` is supplied.
        """
        norm_factor = self._compute_norm_factor(object_i, ang_idx)
        object_i *= norm_factor
        if norm_constant is None:
            return object_i
        norm_constant *= norm_factor
        return object_i, norm_constant