# Source code for pytomography.algorithms.fbp
"""This module contains classes that implement filtered back projection reconstruction algorithms.
"""
from __future__ import annotations
import pytomography
import torch
from pytomography.metadata import SPECTObjectMeta, SPECTProjMeta
from pytomography.projectors.SPECT import SPECTSystemMatrix
from pytomography.utils import RampFilter
from collections.abc import Sequence
class FilteredBackProjection:
    r"""Implementation of filtered back projection reconstruction :math:`\hat{f} = \frac{\pi}{N_{\text{proj}}} \mathcal{R}^{-1}\mathcal{F}^{-1}\Pi\mathcal{F} g` where :math:`N_{\text{proj}}` is the number of projections, :math:`\mathcal{R}` is the 3D radon transform, :math:`\mathcal{F}` is the 2D Fourier transform (applied to each projection separately), and :math:`\Pi` is the filter applied in Fourier space, which is by default the ramp filter.

    Args:
        projections (torch.Tensor): projection data :math:`g` to be reconstructed. Must have at least four dimensions; dimensions 2 and 3 are used as the projection shape (presumably the layout is ``[batch, angle, r, z]`` -- confirm against the caller).
        angles (Sequence): Angles corresponding to each projection
        filter (Callable, optional): Additional Fourier space filter (applied after Ramp Filter) used during reconstruction. It is called with the FFT sample frequencies and its output is multiplied with the ramp filter.
    """

    def __init__(
        self,
        projections: torch.Tensor,
        angles: Sequence[float],
        filter=None
    ) -> None:
        self.proj = projections
        # The object is built with the size of projection dimension 2 on its
        # first two axes and the size of projection dimension 3 on its third
        # axis; ``dr`` is fixed to (1, 1, 1).
        self.object_meta = SPECTObjectMeta(
            dr=(1, 1, 1),
            shape=(self.proj.shape[2], self.proj.shape[2], self.proj.shape[3]),
        )
        self.proj_meta = SPECTProjMeta(
            projection_shape=self.proj.shape[2:],
            angles=angles,
        )
        self.filter = filter
        # Radon transform equivalent to SPECT system matrix: no object-space
        # or projection-space transforms are attached.
        self.system_matrix = SPECTSystemMatrix(
            obj2obj_transforms=[],
            proj2proj_transforms=[],
            object_meta=self.object_meta,
            proj_meta=self.proj_meta,
        )

    def __call__(self) -> torch.Tensor:
        """Applies reconstruction

        Returns:
            torch.Tensor: Reconstructed object prediction
        """
        # Sample frequencies for the FFT taken along dim -2 of the projections,
        # shaped as a column so they broadcast across the last dimension.
        freq_fft = (
            torch.fft.fftfreq(self.proj.shape[-2])
            .reshape((-1, 1))
            .to(pytomography.device)
        )
        filter_total = RampFilter()(freq_fft)
        if self.filter is not None:
            # Multiply out of place: an in-place ``*=`` raises when the extra
            # filter returns a tensor whose dtype promotes (e.g. complex) or
            # whose shape broadcasts to something larger than the ramp filter.
            filter_total = filter_total * self.filter(freq_fft)
        # Filter each projection in Fourier space along dim -2.
        proj_fft = torch.fft.fft(self.proj, dim=-2) * filter_total
        # Only the real part of the inverse transform is back projected.
        proj_filtered = torch.fft.ifft(proj_fft, dim=-2).real
        # Back project and apply the pi / N_proj normalisation from the formula
        # in the class docstring.
        object_prediction = (
            self.system_matrix.backward(proj_filtered)
            * torch.pi
            / len(self.proj_meta.angles)
        )
        return object_prediction