Bayesian Priors

[1]:
import os
from pytomography.io.SPECT import simind
from pytomography.priors import RelativeDifferencePrior
from pytomography.projectors import SPECTSystemMatrix
from pytomography.transforms import SPECTAttenuationTransform, SPECTPSFTransform
from pytomography.algorithms import BSREM, OSEM
from pytomography.priors import QuadraticPrior, RelativeDifferencePrior
from pytomography.priors import TopNAnatomyNeighbourWeight
from torch import poisson
import torch
import matplotlib.pyplot as plt
from pytomography.callbacks import CallBack
import numpy as np

Modify the following path to the directory where you saved the tutorial data:

[2]:
# Root directory of the downloaded tutorial data; edit this to match your local copy.
# Later paths are built from it with os.path.join, so the trailing slash is harmless.
path = '/disk1/pytomography_tutorial_data/simind_tutorial/'

The first code cell below is borrowed from the SPECT: Reconstructing SIMIND Data tutorial (multiple-regions case). For a more comprehensive description of this code, please see that tutorial.

[3]:
organs = ['bkg', 'liver', 'l_lung', 'r_lung', 'l_kidney', 'r_kidney', 'salivary', 'bladder']
activities = [2500, 450, 7, 7, 100, 100, 20, 90]  # MBq, in the same order as `organs`

# One SIMIND header per organ for each energy window (photopeak, lower scatter, upper scatter)
headerfiles, headerfiles_lower, headerfiles_upper = (
    [os.path.join(path, 'multi_projections', organ, window_file) for organ in organs]
    for window_file in ('photopeak.h00', 'lowerscatter.h00', 'upperscatter.h00')
)

# Metadata is read from the first photopeak header and assumed identical for every organ
object_meta, proj_meta = simind.get_metadata(headerfiles[0])
photopeak = simind.combine_projection_data(headerfiles, activities)
scatter = simind.combine_scatter_data_TEW(headerfiles, headerfiles_lower, headerfiles_upper, activities)

# The combined data are count rates (CPS); multiply by the acquisition time to obtain counts
dT = 15  # s
photopeak *= dT
scatter *= dT

# Draw Poisson-distributed samples from the expected counts (photopeak first, then scatter)
photopeak_poisson = poisson(photopeak)
scatter_poisson = poisson(scatter)

# Attenuation and PSF modelling transforms that make up the system matrix
attenuation_map = simind.get_attenuation_map(os.path.join(path, 'multi_projections', 'mu208.hct'))
att_transform = SPECTAttenuationTransform(attenuation_map)
psf_meta = simind.get_psfmeta_from_header(headerfiles[0])
psf_transform = SPECTPSFTransform(psf_meta)

system_matrix = SPECTSystemMatrix(
    obj2obj_transforms=[att_transform, psf_transform],
    proj2proj_transforms=[],
    object_meta=object_meta,
    proj_meta=proj_meta,
    n_parallel=8,
)

Once the system matrix is set up, it's time to define the priors for the reconstruction algorithm. PyTomography provides many useful prior functions for SPECT reconstruction. Most of them are derived from NearestNeighbourPrior, meaning the prior term for each voxel depends only on its 26 neighbouring voxels. Priors can be defined as follows:

[4]:
# Nearest-neighbour priors using the default (Euclidean) neighbour weighting.
# Both share the prior strength beta=0.3; the RDP additionally takes gamma=2.
prior_quad, prior_rdp = (
    QuadraticPrior(beta=0.3),
    RelativeDifferencePrior(beta=0.3, gamma=2),
)

By default, the “contribution” (or weight) of each of the 26 neighbouring voxels is scaled by the inverse of its Euclidean distance from the central voxel. The 8 corner voxels (at distance \(\sqrt{3}\)) therefore contribute only \(1/\sqrt{3}\) as much as the 6 face-adjacent voxels (up/down, left/right, front/back), and the 12 edge voxels (at distance \(\sqrt{2}\)) contribute \(1/\sqrt{2}\) as much. The weight in this case is an instance of EuclideanNeighbourWeight.

PyTomography also supports other weighting schemes for neighbouring voxels. For instance, it can use anatomical images (such as attenuation maps or CT images) to derive the weights. TopNAnatomyNeighbourWeight is an extension of EuclideanNeighbourWeight that assigns non-zero weights only to the top-N closest neighbours, where closeness is judged from an external anatomical image:

[5]:
# Anatomy-guided neighbour weighting built from the attenuation map:
# only N_neighbours=8 of each voxel's neighbours receive a non-zero weight.
weight_top8anatomy = TopNAnatomyNeighbourWeight(attenuation_map, N_neighbours=8)

This weighting can be used to create custom priors. Such priors are typically referred to as anatomical priors (AP); below they are designated Quadratic-AP and RDP-AP.

[6]:
# Anatomical-prior (AP) variants: same beta/gamma as the plain priors, but using the
# top-8 anatomy-based neighbour weighting instead of the default Euclidean one.
prior_quad_weighttop8, prior_rdp_weighttop8 = (
    QuadraticPrior(beta=0.3, weight=weight_top8anatomy),
    RelativeDifferencePrior(beta=0.3, gamma=2, weight=weight_top8anatomy),
)

We can now reconstruct this object using no prior and each of the four priors we just defined.

[7]:
class ComputeLogLiklihood(CallBack):
    """Callback that records the Poisson log-likelihood after each call to ``run``.

    The stored value is ``sum(y * log(ybar) - ybar)``, where ``y`` are the measured
    projections and ``ybar`` is the forward projection of the current object estimate
    plus an optional additive term. The constant ``-log(y!)`` is omitted, so only
    differences between iterations are meaningful.

    Args:
        projections: Measured projection data ``y``.
        system_matrix: System matrix used to forward-project the current estimate.
        scatter: Optional additive term (e.g. a scatter estimate) added to the forward
            projection before evaluating the likelihood. The default of 0 reproduces
            the original behaviour (no additive term).

    Attributes:
        liklihoods: List of log-likelihood values (Python floats), one per ``run`` call.
    """

    def __init__(self, projections, system_matrix, scatter=0):
        self.projections = projections
        self.system_matrix = system_matrix
        self.scatter = scatter
        self.liklihoods = []

    def run(self, object, n_iter):
        """Evaluate and store the log-likelihood of ``object``; ``n_iter`` is unused."""
        # Use the system matrix supplied at construction time (previously this read the
        # notebook-global ``system_matrix``, silently ignoring the constructor argument).
        projection_estimate = self.system_matrix.forward(object) + self.scatter
        # Bins with y <= 0 contribute only -ybar. Selecting with torch.where avoids the
        # 0 * log(0) = nan that the y * log(ybar) term would otherwise produce there.
        measured = self.projections > 0
        data_term = torch.where(
            measured,
            self.projections * torch.log(projection_estimate),
            torch.zeros_like(projection_estimate),
        )
        liklihood = data_term - projection_estimate
        self.liklihoods.append(liklihood.sum().item())
[8]:
def reconstruct(
    prior=None,
    n_iters=50,
    n_subsets=8,
    relaxation_function=lambda n: 1,
    scaling_matrix_type='subind_norm',
):
    """Run BSREM on the noisy photopeak data while tracking the log-likelihood.

    Reads the notebook-level ``photopeak_poisson``, ``scatter_poisson`` and
    ``system_matrix``.

    Args:
        prior: Prior passed to BSREM (``None`` for no prior).
        n_iters: Number of iterations.
        n_subsets: Number of subsets.
        relaxation_function: Relaxation sequence ``alpha_n`` as a function of ``n``.
        scaling_matrix_type: Forwarded unchanged to ``BSREM``.

    Returns:
        A 2-tuple ``(reconstruction, callback)``. The first element is whatever the
        BSREM call returns; the second is the ``ComputeLogLiklihood`` callback holding
        the per-iteration log-likelihoods. Unpack it: ``recon, cb = reconstruct(...)``.
    """
    likelihood_tracker = ComputeLogLiklihood(photopeak_poisson, system_matrix)
    algorithm = BSREM(
        projections=photopeak_poisson,
        system_matrix=system_matrix,
        scatter=scatter_poisson,
        prior=prior,
        relaxation_function=relaxation_function,
        scaling_matrix_type=scaling_matrix_type,
    )
    reconstruction = algorithm(n_iters, n_subsets, callback=likelihood_tracker)
    return reconstruction, likelihood_tracker

First, we reconstruct with the RDP prior using a relaxation sequence that corresponds to “unrelaxed” (\(\alpha_n=1\)) and “relaxed” (\(\alpha_n=1/(n/50+1)\)) updates.

[22]:
# Both runs use the RDP prior and 16 subsets; only the relaxation sequence alpha_n differs.
recon_unrelaxed, cb_unrelaxed = reconstruct(prior=prior_rdp, n_subsets=16, relaxation_function=lambda n: 1)
recon_relaxed, cb_relaxed = reconstruct(prior=prior_rdp, n_subsets=16, relaxation_function=lambda n: 1/(n/50+1))
[23]:
# Compare the unrelaxed and relaxed RDP reconstructions at index 70 along the second object axis.
plt.subplots(1, 2, figsize=(4, 5))
relaxation_panels = [(recon_unrelaxed, 'Unrelaxed'), (recon_relaxed, 'Relaxed')]
for panel, (recon, label) in enumerate(relaxation_panels, start=1):
    plt.subplot(1, 2, panel)
    # [0] drops the leading size-1 dimension of the reconstruction
    plt.pcolormesh(recon[0].cpu()[:, 70].T, cmap='nipy_spectral')
    plt.colorbar()
    plt.axis('off')
    plt.title(label)
# Render explicitly, like the other plotting cells, instead of echoing the last Text artist
plt.show()
[23]:
Text(0.5, 1.0, 'Relaxed')
../_images/notebooks_t_spectpriors_17_1.png

We can also plot the log-likelihoods:

[24]:
# Log-likelihood per iteration for both relaxation schemes.
plt.figure(figsize=(4, 2))
plt.plot(cb_unrelaxed.liklihoods, label='Unrelaxed')
plt.plot(cb_relaxed.liklihoods, ls='--', label='Relaxed')
# Zoom in around the best unrelaxed value so late-iteration differences are visible.
# The margins are relative to |best|: identical to 0.9995*best / 1.00001*best for a
# positive maximum, and still a valid (non-inverted) range if the maximum is negative.
best = max(cb_unrelaxed.liklihoods)
plt.ylim(bottom=best - 0.0005 * abs(best), top=best + 0.00001 * abs(best))
plt.xlabel('Iteration')
plt.ylabel(r'$\log L$')  # raw string: '\l' is not a valid Python escape sequence
plt.legend()
plt.show()
../_images/notebooks_t_spectpriors_19_0.png
[12]:
# reconstruct() returns (reconstruction, log-likelihood callback). Unpack it so each
# recon_* name holds the reconstruction itself, which is what the plotting cell indexes;
# leaving the tuple packed makes recon_*[0] a batched volume and pcolormesh fails.
# These runs use reconstruct's defaults: 50 iterations, 8 subsets, alpha_n = 1.
recon_noprior, cb_noprior = reconstruct(None)
recon_quad, cb_quad = reconstruct(prior_quad)
recon_rdp, cb_rdp = reconstruct(prior_rdp)
recon_quad_weighttop8, cb_quad_weighttop8 = reconstruct(prior_quad_weighttop8)
recon_rdp_weighttop8, cb_rdp_weighttop8 = reconstruct(prior_rdp_weighttop8)
[13]:
# One panel per prior, each showing index 70 along the second object axis.
prior_panels = [
    (recon_noprior, 'No Prior'),
    (recon_quad, 'Quadratic'),
    (recon_rdp, 'RDP'),
    (recon_quad_weighttop8, 'Quadratic-AP'),
    (recon_rdp_weighttop8, 'RDP-AP'),
]
plt.subplots(1, 5, figsize=(12, 5))
for column, (volume, heading) in enumerate(prior_panels, start=1):
    plt.subplot(1, 5, column)
    plt.pcolormesh(volume[0].cpu()[:, 70].T, cmap='nipy_spectral')
    plt.colorbar()
    plt.axis('off')
    plt.title(heading)
plt.show()
/tmp/ipykernel_14436/447662742.py:3: UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matricesor `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2981.)
  plt.pcolormesh(recon_noprior[0].cpu()[:,70].T, cmap='nipy_spectral')
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb Cell 22 line <cell line: 3>()
      <a href='vscode-notebook-cell://ssh-remote%2Bgpuvm00004jhubvm01.canadacentral.cloudapp.azure.com/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb#Y114sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a> plt.subplots(1,5,figsize=(12,5))
      <a href='vscode-notebook-cell://ssh-remote%2Bgpuvm00004jhubvm01.canadacentral.cloudapp.azure.com/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb#Y114sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a> plt.subplot(151)
----> <a href='vscode-notebook-cell://ssh-remote%2Bgpuvm00004jhubvm01.canadacentral.cloudapp.azure.com/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb#Y114sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a> plt.pcolormesh(recon_noprior[0].cpu()[:,70].T, cmap='nipy_spectral')
      <a href='vscode-notebook-cell://ssh-remote%2Bgpuvm00004jhubvm01.canadacentral.cloudapp.azure.com/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb#Y114sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a> plt.colorbar()
      <a href='vscode-notebook-cell://ssh-remote%2Bgpuvm00004jhubvm01.canadacentral.cloudapp.azure.com/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb#Y114sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a> plt.axis('off')

File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/pyplot.py:2773, in pcolormesh(alpha, norm, cmap, vmin, vmax, shading, antialiased, data, *args, **kwargs)
   2768 @_copy_docstring_and_deprecators(Axes.pcolormesh)
   2769 def pcolormesh(
   2770         *args, alpha=None, norm=None, cmap=None, vmin=None,
   2771         vmax=None, shading=None, antialiased=False, data=None,
   2772         **kwargs):
-> 2773     __ret = gca().pcolormesh(
   2774         *args, alpha=alpha, norm=norm, cmap=cmap, vmin=vmin,
   2775         vmax=vmax, shading=shading, antialiased=antialiased,
   2776         **({"data": data} if data is not None else {}), **kwargs)
   2777     sci(__ret)
   2778     return __ret

File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/__init__.py:1442, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1439 @functools.wraps(func)
   1440 def inner(ax, *args, data=None, **kwargs):
   1441     if data is None:
-> 1442         return func(ax, *map(sanitize_sequence, args), **kwargs)
   1444     bound = new_sig.bind(ax, *args, **kwargs)
   1445     auto_label = (bound.arguments.get(label_namer)
   1446                   or bound.kwargs.get(label_namer))

File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/axes/_axes.py:6229, in Axes.pcolormesh(self, alpha, norm, cmap, vmin, vmax, shading, antialiased, *args, **kwargs)
   6225     C = C.ravel()
   6227 kwargs.setdefault('snap', mpl.rcParams['pcolormesh.snap'])
-> 6229 collection = mcoll.QuadMesh(
   6230     coords, antialiased=antialiased, shading=shading,
   6231     array=C, cmap=cmap, norm=norm, alpha=alpha, **kwargs)
   6232 collection._scale_norm(norm, vmin, vmax)
   6234 coords = coords.reshape(-1, 2)  # flatten the grid structure; keep x, y

File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/collections.py:1939, in QuadMesh.__init__(self, coordinates, antialiased, shading, **kwargs)
   1936 self._bbox.update_from_data_xy(self._coordinates.reshape(-1, 2))
   1937 # super init delayed after own init because array kwarg requires
   1938 # self._coordinates and self._shading
-> 1939 super().__init__(**kwargs)
   1940 self.set_mouseover(False)

File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/_api/deprecation.py:454, in make_keyword_only.<locals>.wrapper(*args, **kwargs)
    448 if len(args) > name_idx:
    449     warn_deprecated(
    450         since, message="Passing the %(name)s %(obj_type)s "
    451         "positionally is deprecated since Matplotlib %(since)s; the "
    452         "parameter will become keyword-only %(removal)s.",
    453         name=name, obj_type=f"parameter of {func.__name__}()")
--> 454 return func(*args, **kwargs)

File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/collections.py:201, in Collection.__init__(self, edgecolors, facecolors, linewidths, linestyles, capstyle, joinstyle, antialiaseds, offsets, offset_transform, norm, cmap, pickradius, hatch, urls, zorder, **kwargs)
    198 self._offset_transform = offset_transform
    200 self._path_effects = None
--> 201 self._internal_update(kwargs)
    202 self._paths = None

File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/artist.py:1223, in Artist._internal_update(self, kwargs)
   1216 def _internal_update(self, kwargs):
   1217     """
   1218     Update artist properties without prenormalizing them, but generating
   1219     errors as if calling `set`.
   1220
   1221     The lack of prenormalization is to maintain backcompatibility.
   1222     """
-> 1223     return self._update_props(
   1224         kwargs, "{cls.__name__}.set() got an unexpected keyword argument "
   1225         "{prop_name!r}")

File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/artist.py:1199, in Artist._update_props(self, props, errfmt)
   1196             if not callable(func):
   1197                 raise AttributeError(
   1198                     errfmt.format(cls=type(self), prop_name=k))
-> 1199             ret.append(func(v))
   1200 if ret:
   1201     self.pchanged()

File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/collections.py:1982, in QuadMesh.set_array(self, A)
   1980     shape = np.shape(A)
   1981     if shape not in ok_shapes:
-> 1982         raise ValueError(
   1983             f"For X ({width}) and Y ({height}) with {self._shading} "
   1984             f"shading, A should have shape "
   1985             f"{' or '.join(map(str, ok_shapes))}, not {A.shape}")
   1986 return super().set_array(A)

ValueError: For X (129) and Y (385) with flat shading, A should have shape (384, 128, 3) or (384, 128, 4) or (384, 128) or (49152,), not (384, 128, 1)
../_images/notebooks_t_spectpriors_21_2.png