Implementing Filtered Back Projection

We’ll use the classes of PyTomography to implement filtered back projection in SPECT.

[1]:
from pytomography.projections import SPECTSystemMatrix
from pytomography.metadata import ObjectMeta, ImageMeta
import numpy as np
import matplotlib.pyplot as plt
import torch

The two foundational tools of image reconstruction are

  1. Forward projection \(\sum_{i} c_{ij} a_i\)

  2. Back projection \(\sum_{j} c_{ij} b_j\)

Let’s discuss what these operators actually mean. First, let’s define our quantities: \(c_{ij}\) is known as the system matrix, and may encode physical effects such as attenuation and PSF modeling; \(a_i\) is an arbitrary object and \(b_j\) is an arbitrary image.

It’s worth now discussing what the indices \(i\) and \(j\) actually mean. You might think: objects are three dimensional, so shouldn’t there be at least 3 indices when we’re doing linear operations? Consider the following: because we are in a discrete space, any 3 dimensional object can be flattened into a single (albeit very long) one dimensional vector: a 128x128x128 3D matrix becomes a 1D vector of length \(128^3 = 2097152\). That’s how many voxels there are in object space: you can think of the index \(i\) as labelling a single voxel.

The same can be said for an image. If we have 64 projections of matrix size 128x128, then that can be thought of as a single vector of length \(64 \times 128 \times 128 = 1048576\). That’s how many individual detector elements there are: the index \(j\) labels a single detector element.
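To make the bookkeeping concrete, here is a quick check with placeholder tensors of those shapes (purely illustrative; these are not tensors used later in the tutorial):

[ ]:
# One index i per voxel: flattening a 128x128x128 object
obj = torch.zeros(128, 128, 128)
print(obj.reshape(-1).shape)  # torch.Size([2097152])

# One index j per detector element: flattening 64 projections of size 128x128
img = torch.zeros(64, 128, 128)
print(img.reshape(-1).shape)  # torch.Size([1048576])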

So in forward projection \(\sum_{i} c_{ij} a_i\), the system matrix \(c_{ij}\) maps the contribution from each voxel \(i\) to detector element \(j\). In back projection \(\sum_{j} c_{ij} b_j\), the system matrix maps the intensity in detector element \(j\) back to every voxel \(i\) that could have contributed to it. In reality, however, not every voxel that could have contributed to detector element \(j\) does so with equal intensity; it is for this reason that forward projection followed by back projection does not recover the original object.
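We can demonstrate both operations with a toy dense system matrix (random numbers, purely illustrative, and unrelated to any real SPECT geometry). Forward projection is a matrix-vector product, back projection is a product with the transpose, and the two are adjoint to one another, even though back projection does not invert forward projection:

[ ]:
torch.manual_seed(0)
n_voxels, n_detectors = 8, 6
c = torch.rand(n_detectors, n_voxels)  # toy system matrix c_ij (row j, column i)
a = torch.rand(n_voxels)               # arbitrary object
b = torch.rand(n_detectors)            # arbitrary image

forward = c @ a  # forward projection: sum_i c_ij a_i
back = c.T @ b   # back projection: sum_j c_ij b_j

# Back projection is the adjoint (transpose) of forward projection...
print(torch.isclose(forward @ b, a @ back))  # tensor(True)
# ...but it is not the inverse: back projecting a forward projection
# does not recover the original object
print(torch.allclose(c.T @ (c @ a), a))  # False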

Let’s experiment with these operators. First we’ll make a 3D box in object space:

[2]:
x = torch.linspace(-1,1,128)
y = torch.linspace(-1,1,128)
z = torch.linspace(-1,1,132)
xv, yv, zv = torch.meshgrid([x,y,z], indexing='ij')
object_truth = (xv>-0.2)*(xv<0.2)*(yv>-0.15)*(yv<0.15)*(zv>-0.1)*(zv<0.1)
object_truth = object_truth.to(torch.float).unsqueeze(dim=0) # add batch dimension
object_truth.shape
[2]:
torch.Size([1, 128, 128, 132])
[3]:
plt.figure(figsize=(5,4))
plt.pcolormesh(object_truth[0][:,:,64].T, cmap='Greys_r')
plt.axis('off')
plt.colorbar()
[3]:
<matplotlib.colorbar.Colorbar at 0x7f5955a0d070>
../_images/notebooks_t_fbp_6_1.png

Before we do any projections, we need to create corresponding metadata for our object. In this case, we’ll assume cubic voxels with side length 1 cm. For our image space, we’ll assume 60 projections are taken at an angular spacing of 6 degrees.

[4]:
angles = np.arange(0,360.,6.)
object_meta = ObjectMeta(dr=(1,1,1), shape=object_truth[0].shape)
image_meta = ImageMeta(object_meta, angles=angles)

With this metadata, we can create our system matrix, which implements both forward and back projection. For now we’ll model no additional physical phenomena, so both transform lists are empty.

[5]:
system_matrix = SPECTSystemMatrix(
    obj2obj_transforms=[],
    im2im_transforms=[],
    object_meta=object_meta,
    image_meta=image_meta)

We can now use the forward method of system_matrix to model \(g=Hf\) (convert the object \(f\) into an image \(g\)):

[6]:
image = system_matrix.forward(object_truth)
image.shape
[6]:
torch.Size([1, 60, 128, 132])

We can look at projections at several angles, for example 0, 30, 60, 90, and 120 degrees:

[7]:
fig, axes = plt.subplots(1,5,figsize=(15,4))
for i, proj in enumerate([0,5,10,15,20]):
    axes[i].pcolormesh(image[0][proj].T, cmap='Greys_r')
    axes[i].set_title(f'Angle={image_meta.angles[proj]}')
    axes[i].axis('off')
../_images/notebooks_t_fbp_14_0.png

At off angles like 60 degrees, the box is darkest in the center and lighter on the outside; this is like looking through a semi-transparent box in real life: viewed from an off angle, it appears darkest near the center, where the lines of sight pass through the most material.
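Since no attenuation or PSF modeling is included here, forward projection simply sums voxel values along parallel lines, so every projection should contain roughly the total number of counts in the object (up to interpolation error from rotating the object). A quick sanity check:

[ ]:
# Each projection should roughly preserve the total counts in the object
print(object_truth.sum())  # total counts in the object
print(image[0, 0].sum())   # counts in the 0 degree projection
print(image[0, 10].sum())  # counts in the 60 degree projection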

We can also back project using the backward method of the system_matrix instance. Here we compute \(H^T g\) and the normalization constant \(H^T 1\):

[8]:
object_bp, norm_bp = system_matrix.backward(image, return_norm_constant=True)
object_bp
[8]:
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]])

The printed entries near the edges are all zero, but if we look at a central slice of the new object:

[9]:
axial_slice = object_bp[0][:,:,64].T
[10]:
plt.figure(figsize=(5,4))
plt.pcolormesh(axial_slice, cmap='Greys_r')
plt.axis('off')
plt.colorbar()
[10]:
<matplotlib.colorbar.Colorbar at 0x7f5921114fa0>
../_images/notebooks_t_fbp_20_1.png

We can see that it has been blurred, and that the values are much too large. How can we fix this?
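Before turning to filtered back projection, note that dividing voxel-wise by the normalization constant \(H^T 1\) (computed above as norm_bp) removes the dependence on the number of projections, though the blur remains. A quick experiment:

[ ]:
# Crude normalization by H^T 1; the small epsilon guards against division
# by zero at voxels that no detector element sees
object_bp_norm = object_bp / (norm_bp + 1e-10)
plt.figure(figsize=(5,4))
plt.pcolormesh(object_bp_norm[0][:,:,64].T, cmap='Greys_r')
plt.axis('off')
plt.colorbar()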

Example: Filtered Back Projection. In this case the object estimate is given by

\[\hat{f}_i = \frac{\pi}{N_{\text{proj}}} \sum_j H_{ij} \left( \mathcal{F}^{-1}(|\omega|\mathcal{F}(g)) \right)_j\]

where the term in brackets involves applying a 1D convolution (in this case, multiplication in Fourier space by the ramp filter \(|\omega|\)) to the image along the \(r\) axis.

[11]:
freq_fft = torch.fft.fftfreq(image.shape[-2])       # frequencies along the r axis
ramp_filter = torch.abs(freq_fft).reshape((-1,1))   # ramp filter |omega|, broadcastable over z
image_fft = torch.fft.fft(image, dim=-2)
image_fft = image_fft * ramp_filter
image_filtered = torch.fft.ifft(image_fft, dim=-2).real

Now we can back project and apply the \(\pi / N_{\text{proj}}\) normalization:

[12]:
object_fbp = system_matrix.backward(image_filtered) * np.pi / len(image_meta.angles)
[13]:
plt.figure(figsize=(5,4))
plt.pcolormesh(object_fbp[0][:,:,64].T, cmap='Greys_r')
plt.axis('off')
plt.colorbar()
[13]:
<matplotlib.colorbar.Colorbar at 0x7f592105af40>
../_images/notebooks_t_fbp_26_1.png

The box is no longer blurred, but artifacts are present: the ramp filter amplifies high frequencies, which produces ringing around sharp edges. Such artifacts are not present when using iterative algorithms like OSEM for reconstruction.
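A common way to suppress this ringing (not part of the tutorial above; just a sketch) is to apodize the ramp filter with a smooth window such as a Hann window, trading some resolution for reduced ringing:

[ ]:
# Ramp filter apodized with a Hann window (a standard FBP variant; sketch only)
freq_fft = torch.fft.fftfreq(image.shape[-2])
hann = 0.5 * (1 + torch.cos(np.pi * freq_fft / freq_fft.abs().max()))
ramp_hann = (torch.abs(freq_fft) * hann).reshape((-1,1))
image_fft = torch.fft.fft(image, dim=-2) * ramp_hann
image_filtered_hann = torch.fft.ifft(image_fft, dim=-2).real
object_fbp_hann = system_matrix.backward(image_filtered_hann) * np.pi / len(image_meta.angles)

plt.figure(figsize=(5,4))
plt.pcolormesh(object_fbp_hann[0][:,:,64].T, cmap='Greys_r')
plt.axis('off')
plt.colorbar()

Windows such as Hann, Hamming, or Butterworth are standard choices in FBP; each sets a different trade-off between resolution and noise/ringing.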