Bayesian Priors
[1]:
import os
from pytomography.io.SPECT import simind
from pytomography.priors import RelativeDifferencePrior
from pytomography.projectors import SPECTSystemMatrix
from pytomography.transforms import SPECTAttenuationTransform, SPECTPSFTransform
from pytomography.algorithms import BSREM, OSEM
from pytomography.priors import QuadraticPrior, RelativeDifferencePrior
from pytomography.priors import TopNAnatomyNeighbourWeight
from torch import poisson
import torch
import matplotlib.pyplot as plt
from pytomography.callbacks import CallBack
import numpy as np
Modify the following path to the directory where you saved the tutorial data:
[2]:
# Root directory of the SIMIND tutorial data. Either edit the default below or
# set the PYTOMOGRAPHY_TUTORIAL_DATA environment variable to point at your copy.
path = os.environ.get(
    'PYTOMOGRAPHY_TUTORIAL_DATA',
    '/disk1/pytomography_tutorial_data/simind_tutorial/',
)
The next cell of code is borrowed from the multiple-regions case of the “SPECT: Reconstructing SIMIND Data” tutorial. For a more comprehensive description of the code below, please see that tutorial.
[3]:
# Organ regions simulated separately in SIMIND, each paired with the activity
# assigned to it (MBq).
organs = ['bkg', 'liver', 'l_lung', 'r_lung', 'l_kidney', 'r_kidney', 'salivary', 'bladder']
activities = [2500, 450, 7, 7, 100, 100, 20, 90]


def _organ_headers(window_file):
    """Return one header path per organ for the given energy-window file name."""
    return [os.path.join(path, 'multi_projections', organ, window_file) for organ in organs]


headerfiles = _organ_headers('photopeak.h00')
headerfiles_lower = _organ_headers('lowerscatter.h00')
headerfiles_upper = _organ_headers('upperscatter.h00')

# Geometry is read from the first organ's header only; every organ is assumed
# to share it.
object_meta, proj_meta = simind.get_metadata(headerfiles[0])
photopeak = simind.combine_projection_data(headerfiles, activities)
# TEW scatter estimate built from the lower/upper scatter windows.
scatter = simind.combine_scatter_data_TEW(headerfiles, headerfiles_lower, headerfiles_upper, activities)

# The combined data is in counts per second; multiply by the acquisition
# duration (seconds) to obtain counts.
dT = 15
photopeak *= dT
scatter *= dT

# Noisy realizations of the expected counts (photopeak drawn first, then scatter).
photopeak_poisson = poisson(photopeak)
scatter_poisson = poisson(scatter)
# Assemble the system matrix from an attenuation transform and a PSF transform,
# both passed as object-to-object transforms; no projection-space transforms.
attenuation_file = os.path.join(path, 'multi_projections', 'mu208.hct')
attenuation_map = simind.get_attenuation_map(attenuation_file)
att_transform = SPECTAttenuationTransform(attenuation_map)

# PSF parameters are read from the first photopeak header.
psf_meta = simind.get_psfmeta_from_header(headerfiles[0])
psf_transform = SPECTPSFTransform(psf_meta)

system_matrix = SPECTSystemMatrix(
    obj2obj_transforms=[att_transform, psf_transform],
    proj2proj_transforms=[],
    object_meta=object_meta,
    proj_meta=proj_meta,
    n_parallel=8,
)
Once the system matrix is set up, it's time to define the priors for the reconstruction algorithm. PyTomography provides many useful prior functions for SPECT reconstruction. Most of them are derived from NearestNeighbourPrior, meaning that the prior information for each voxel depends only on its 26 neighbouring voxels. Priors can be defined as follows:
[4]:
# Two priors with the same strength beta; the relative difference prior
# additionally takes gamma.
prior_beta = 0.3
prior_quad = QuadraticPrior(beta=prior_beta)
prior_rdp = RelativeDifferencePrior(beta=prior_beta, gamma=2)
By default, the “contribution” (or weight) from each of the 26 neighbouring voxels is scaled by the inverse of its Euclidean distance from the central voxel. Compared to the 6 face-adjacent voxels (up/down, left/right, front/back), the 12 edge-adjacent voxels therefore contribute \(1/\sqrt{2}\) and the 8 corner voxels only \(1/\sqrt{3}\). The weight in this case is an instance of EuclideanNeighbourWeight.
PyTomography also provides other weighting schemes for neighbouring voxels. For example, it can use anatomical images (such as attenuation maps and CT images) to obtain a different weighting. The TopNAnatomyNeighbourWeight is an extension of EuclideanNeighbourWeight that assigns non-zero weights only to the top-N closest neighbours, as determined from an external anatomical image:
[5]:
# Anatomy-guided weighting: only the N neighbours judged closest using the
# attenuation map keep a non-zero weight (N = 8 here).
n_anatomy_neighbours = 8
weight_top8anatomy = TopNAnatomyNeighbourWeight(attenuation_map, N_neighbours=n_anatomy_neighbours)
This weighting can be used to create custom priors. Such priors are typically referred to as anatomical priors (AP) and are designated here as Quadratic-AP and RDP-AP:
[6]:
# Anatomical-prior (AP) variants: beta=0.3 (and gamma=2 for the RDP), with
# neighbour contributions weighted by the anatomy-based weight object.
ap_weight = weight_top8anatomy
prior_quad_weighttop8 = QuadraticPrior(beta=0.3, weight=ap_weight)
prior_rdp_weighttop8 = RelativeDifferencePrior(beta=0.3, gamma=2, weight=ap_weight)
We can now reconstruct this object using no prior and each of the four priors we just defined.
[7]:
class ComputeLogLiklihood(CallBack):
    """Callback that records a Poisson log-likelihood each time it is run.

    The recorded value is ``sum(y * log(ybar) - ybar)``, where ``y`` is the
    measured projection data and ``ybar`` is the forward projection of the
    current object estimate (plus ``scatter`` when one is given). The constant
    ``log(y!)`` term is omitted. Bins with ``y <= 0`` contribute ``-ybar``,
    which also avoids the NaN that ``0 * log(0)`` would produce.

    Args:
        projections: Measured projection data ``y``.
        system_matrix: Object whose ``forward`` maps an object estimate into
            projection space.
        scatter: Optional additive term included in ``ybar``. Pass the same
            scatter given to the reconstruction algorithm to track the data
            term that algorithm works with. Defaults to ``None`` (no scatter).

    Attributes:
        liklihoods: List of floats, one appended per call to ``run``.
    """

    def __init__(self, projections, system_matrix, scatter=None):
        self.projections = projections
        self.system_matrix = system_matrix
        self.scatter = scatter
        self.liklihoods = []

    def run(self, object, n_iter):
        """Compute the log-likelihood of ``object`` and append it to ``liklihoods``.

        Args:
            object: Current object estimate supplied by the algorithm.
            n_iter: Iteration index supplied by the algorithm (not used).
        """
        # Use the system matrix this callback was built with, not whatever
        # global ``system_matrix`` happens to exist when ``run`` is called.
        projection_estimate = self.system_matrix.forward(object)
        if self.scatter is not None:
            projection_estimate = projection_estimate + self.scatter
        empty_bins = self.projections <= 0
        liklihood = self.projections * torch.log(projection_estimate) - projection_estimate
        liklihood[empty_bins] = -projection_estimate[empty_bins]
        self.liklihoods.append(liklihood.sum().item())
[8]:
def reconstruct(prior=None, n_iters=50, n_subsets=8, relaxation_function=lambda n: 1, scaling_matrix_type='subind_norm'):
    """Run BSREM on the noisy photopeak data while tracking the log-likelihood.

    Reads the notebook-level ``photopeak_poisson``, ``scatter_poisson`` and
    ``system_matrix``.

    Args:
        prior: Prior handed to BSREM; ``None`` reconstructs without a prior.
        n_iters: Iteration count passed to the algorithm call.
        n_subsets: Subset count passed to the algorithm call.
        relaxation_function: Relaxation sequence alpha_n as a function of n.
        scaling_matrix_type: Forwarded unchanged to BSREM.

    Returns:
        Tuple ``(reconstruction, callback)``; ``callback.liklihoods`` holds one
        log-likelihood value per callback invocation.
    """
    tracker = ComputeLogLiklihood(photopeak_poisson, system_matrix)
    bsrem = BSREM(
        projections=photopeak_poisson,
        system_matrix=system_matrix,
        scatter=scatter_poisson,
        prior=prior,
        relaxation_function=relaxation_function,
        scaling_matrix_type=scaling_matrix_type,
    )
    reconstruction = bsrem(n_iters, n_subsets, callback=tracker)
    return reconstruction, tracker
First, we compare two relaxation sequences, “unrelaxed” (\(\alpha_n=1\)) and “relaxed” (\(\alpha_n=1/(n/50+1)\)), both using the RDP prior with 16 subsets:
[22]:
# Two RDP reconstructions (16 subsets) that differ only in the relaxation
# sequence: constant alpha_n = 1 versus alpha_n = 1/(n/50 + 1).
def _unrelaxed(n):
    return 1


def _relaxed(n):
    return 1/(n/50+1)


recon_unrelaxed, cb_unrelaxed = reconstruct(relaxation_function=_unrelaxed, n_subsets=16, prior=prior_rdp)
recon_relaxed, cb_relaxed = reconstruct(relaxation_function=_relaxed, n_subsets=16, prior=prior_rdp)
[23]:
# Side-by-side view of slice 70 (second axis) of both reconstructions.
relaxation_panels = [(recon_unrelaxed, 'Unrelaxed'), (recon_relaxed, 'Relaxed')]
plt.subplots(1, 2, figsize=(4, 5))
for panel_number, (image, caption) in enumerate(relaxation_panels, start=1):
    plt.subplot(1, 2, panel_number)
    plt.pcolormesh(image[0].cpu()[:, 70].T, cmap='nipy_spectral')
    plt.colorbar()
    plt.axis('off')
    plt.title(caption)
[23]:
Text(0.5, 1.0, 'Relaxed')

We can also plot the log-likelihoods:
[24]:
# Log-likelihood history of both runs, zoomed in just below the best value.
plt.figure(figsize=(4,2))
plt.plot(cb_unrelaxed.liklihoods, label='Unrelaxed')
plt.plot(cb_relaxed.liklihoods, ls='--', label='Relaxed')
# Take the peak over both curves so neither is clipped at the top, and size
# the margins by |peak| so bottom < top even if log L is negative (the former
# 0.9995*max / 1.00001*max form inverts the axis in that case).
peak = max(max(cb_unrelaxed.liklihoods), max(cb_relaxed.liklihoods))
plt.ylim(bottom=peak - 5e-4 * abs(peak), top=peak + 1e-5 * abs(peak))
plt.xlabel('Iteration')
# Raw string: '\l' is an invalid escape sequence (SyntaxWarning on Python 3.12+).
plt.ylabel(r'$\log L$')
plt.legend()
plt.show()

[12]:
# reconstruct() returns a (reconstruction, callback) pair, so unpack it here
# just as in the relaxation comparison. Without unpacking, recon_*[0] is the
# whole reconstruction tensor; its [:,70] slice is 3-D and pcolormesh raises
# "A should have shape (384, 128) ..., not (384, 128, 1)".
recon_noprior, cb_noprior = reconstruct(None)
recon_quad, cb_quad = reconstruct(prior_quad)
recon_rdp, cb_rdp = reconstruct(prior_rdp)
recon_quad_weighttop8, cb_quad_weighttop8 = reconstruct(prior_quad_weighttop8)
recon_rdp_weighttop8, cb_rdp_weighttop8 = reconstruct(prior_rdp_weighttop8)
[13]:
# One panel per prior, each showing slice 70 (second axis) with its own colorbar.
prior_panels = [
    (recon_noprior, 'No Prior'),
    (recon_quad, 'Quadratic'),
    (recon_rdp, 'RDP'),
    (recon_quad_weighttop8, 'Quadratic-AP'),
    (recon_rdp_weighttop8, 'RDP-AP'),
]
n_panels = len(prior_panels)
plt.subplots(1, n_panels, figsize=(12, 5))
for panel_number, (image, caption) in enumerate(prior_panels, start=1):
    plt.subplot(1, n_panels, panel_number)
    plt.pcolormesh(image[0].cpu()[:, 70].T, cmap='nipy_spectral')
    plt.colorbar()
    plt.axis('off')
    plt.title(caption)
plt.show()
/tmp/ipykernel_14436/447662742.py:3: UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matricesor `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2981.)
plt.pcolormesh(recon_noprior[0].cpu()[:,70].T, cmap='nipy_spectral')
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb Cell 22 line <cell line: 3>()
<a href='vscode-notebook-cell://ssh-remote%2Bgpuvm00004jhubvm01.canadacentral.cloudapp.azure.com/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb#Y114sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a> plt.subplots(1,5,figsize=(12,5))
<a href='vscode-notebook-cell://ssh-remote%2Bgpuvm00004jhubvm01.canadacentral.cloudapp.azure.com/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb#Y114sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a> plt.subplot(151)
----> <a href='vscode-notebook-cell://ssh-remote%2Bgpuvm00004jhubvm01.canadacentral.cloudapp.azure.com/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb#Y114sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a> plt.pcolormesh(recon_noprior[0].cpu()[:,70].T, cmap='nipy_spectral')
<a href='vscode-notebook-cell://ssh-remote%2Bgpuvm00004jhubvm01.canadacentral.cloudapp.azure.com/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb#Y114sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a> plt.colorbar()
<a href='vscode-notebook-cell://ssh-remote%2Bgpuvm00004jhubvm01.canadacentral.cloudapp.azure.com/home/gpuvmadm/PyTomography/docs/source/notebooks/t_spectpriors.ipynb#Y114sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a> plt.axis('off')
File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/pyplot.py:2773, in pcolormesh(alpha, norm, cmap, vmin, vmax, shading, antialiased, data, *args, **kwargs)
2768 @_copy_docstring_and_deprecators(Axes.pcolormesh)
2769 def pcolormesh(
2770 *args, alpha=None, norm=None, cmap=None, vmin=None,
2771 vmax=None, shading=None, antialiased=False, data=None,
2772 **kwargs):
-> 2773 __ret = gca().pcolormesh(
2774 *args, alpha=alpha, norm=norm, cmap=cmap, vmin=vmin,
2775 vmax=vmax, shading=shading, antialiased=antialiased,
2776 **({"data": data} if data is not None else {}), **kwargs)
2777 sci(__ret)
2778 return __ret
File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/__init__.py:1442, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
1439 @functools.wraps(func)
1440 def inner(ax, *args, data=None, **kwargs):
1441 if data is None:
-> 1442 return func(ax, *map(sanitize_sequence, args), **kwargs)
1444 bound = new_sig.bind(ax, *args, **kwargs)
1445 auto_label = (bound.arguments.get(label_namer)
1446 or bound.kwargs.get(label_namer))
File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/axes/_axes.py:6229, in Axes.pcolormesh(self, alpha, norm, cmap, vmin, vmax, shading, antialiased, *args, **kwargs)
6225 C = C.ravel()
6227 kwargs.setdefault('snap', mpl.rcParams['pcolormesh.snap'])
-> 6229 collection = mcoll.QuadMesh(
6230 coords, antialiased=antialiased, shading=shading,
6231 array=C, cmap=cmap, norm=norm, alpha=alpha, **kwargs)
6232 collection._scale_norm(norm, vmin, vmax)
6234 coords = coords.reshape(-1, 2) # flatten the grid structure; keep x, y
File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/collections.py:1939, in QuadMesh.__init__(self, coordinates, antialiased, shading, **kwargs)
1936 self._bbox.update_from_data_xy(self._coordinates.reshape(-1, 2))
1937 # super init delayed after own init because array kwarg requires
1938 # self._coordinates and self._shading
-> 1939 super().__init__(**kwargs)
1940 self.set_mouseover(False)
File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/_api/deprecation.py:454, in make_keyword_only.<locals>.wrapper(*args, **kwargs)
448 if len(args) > name_idx:
449 warn_deprecated(
450 since, message="Passing the %(name)s %(obj_type)s "
451 "positionally is deprecated since Matplotlib %(since)s; the "
452 "parameter will become keyword-only %(removal)s.",
453 name=name, obj_type=f"parameter of {func.__name__}()")
--> 454 return func(*args, **kwargs)
File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/collections.py:201, in Collection.__init__(self, edgecolors, facecolors, linewidths, linestyles, capstyle, joinstyle, antialiaseds, offsets, offset_transform, norm, cmap, pickradius, hatch, urls, zorder, **kwargs)
198 self._offset_transform = offset_transform
200 self._path_effects = None
--> 201 self._internal_update(kwargs)
202 self._paths = None
File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/artist.py:1223, in Artist._internal_update(self, kwargs)
1216 def _internal_update(self, kwargs):
1217 """
1218 Update artist properties without prenormalizing them, but generating
1219 errors as if calling `set`.
1220
1221 The lack of prenormalization is to maintain backcompatibility.
1222 """
-> 1223 return self._update_props(
1224 kwargs, "{cls.__name__}.set() got an unexpected keyword argument "
1225 "{prop_name!r}")
File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/artist.py:1199, in Artist._update_props(self, props, errfmt)
1196 if not callable(func):
1197 raise AttributeError(
1198 errfmt.format(cls=type(self), prop_name=k))
-> 1199 ret.append(func(v))
1200 if ret:
1201 self.pchanged()
File /data/anaconda/envs/pytomographytest/lib/python3.9/site-packages/matplotlib/collections.py:1982, in QuadMesh.set_array(self, A)
1980 shape = np.shape(A)
1981 if shape not in ok_shapes:
-> 1982 raise ValueError(
1983 f"For X ({width}) and Y ({height}) with {self._shading} "
1984 f"shading, A should have shape "
1985 f"{' or '.join(map(str, ok_shapes))}, not {A.shape}")
1986 return super().set_array(A)
ValueError: For X (129) and Y (385) with flat shading, A should have shape (384, 128, 3) or (384, 128, 4) or (384, 128) or (49152,), not (384, 128, 1)
