[1]:

from __future__ import annotations import os import sys import numpy as np sys.path.append('/home/gpuvmadm/PyTomography/src') import pytomography from pytomography.metadata import PSFMeta from pytomography.mappings import MapNet from pytomography.utils import pad_object, pad_image, unpad_object, unpad_image, rotate_detector_z from pytomography.metadata import ObjectMeta, ImageMeta from pytomography.projections import ForwardProjectionNet, BackProjectionNet from pytomography.algorithms import OSEMOSL from pytomography.priors import RelativeDifferencePrior from pytomography.mappings import PETAttenuationNet, PETPSFNet from scipy.ndimage import gaussian_filter import matplotlib.pyplot as plt plt.style.use('seaborn-white') from matplotlib import colors import torch from scipy.ndimage import map_coordinates from scipy.interpolate import interpn from scipy.ndimage import zoom device = torch.device("cuda" if torch.cuda.is_available() else "cpu") pytomography.device = device
/tmp/ipykernel_3745/3448874680.py:17: MatplotlibDeprecationWarning: The seaborn styles shipped by Matplotlib are deprecated since 3.6, as they no longer correspond to the styles shipped by seaborn. However, they will remain available as 'seaborn-v0_8-<style>'. Alternatively, directly use the seaborn API instead.
  plt.style.use('seaborn-white')
[2]:
# Directory that holds the raw .bin inputs read in the following cell.
# NOTE(review): absolute, machine-specific path — adjust before running on another machine.
path = '/home/gpuvmadm/PyTomography/test_files'
[3]:
# Read both raw float32 binaries from `path` as flat 1-D arrays
# (GT from 'body2_F18.bin', CT from 'body_atn_511keV.bin').
GT, CT = (
    np.fromfile(os.path.join(path, file_name), dtype=np.float32)
    for file_name in ('body2_F18.bin', 'body_atn_511keV.bin')
)
[4]:
# View each flat buffer as a 751 x 512 x 512 volume.
raw_volume_shape = (751, 512, 512)
GT = np.reshape(GT, raw_volume_shape)
CT = np.reshape(CT, raw_volume_shape)
[5]:
# Per-axis resampling factors (384/751 on the first axis, 256/512 on the other
# two), applied with linear interpolation (order=1) to both volumes.
resample_factors = (384/751, 256/512, 256/512)
GT = zoom(GT, resample_factors, order=1)
CT = zoom(CT, resample_factors, order=1)
[6]:
# Reverse the axis order of both volumes: (0, 1, 2) -> (2, 1, 0).
GT = GT.transpose(2, 1, 0)
CT = CT.transpose(2, 1, 0)
[7]:
# Copy each NumPy volume into a torch tensor and prepend a length-1 leading axis.
GT = torch.unsqueeze(torch.tensor(GT), dim=0)
CT = torch.unsqueeze(torch.tensor(CT), dim=0)
[8]:
# Acquisition description: 574 angles evenly spaced over [0, 180) and a
# constant radius of 44.3 for every angle.
n_angles = 574
projection_angles = np.linspace(0,180,n_angles, endpoint=False)
detector_radii = 44.3*np.ones(n_angles)

# Object grid: spacing (0.15, 0.15, 0.767) on a 256 x 256 x 384 array.
object_meta = ObjectMeta(dr=(0.15,0.15,0.767), shape=[256,256,384])
image_meta = ImageMeta(object_meta, angles=projection_angles, radii=detector_radii)
[9]:
# Scale GT by (voxel volume) x 1.76, where the voxel volume is the product of
# the three spacings stored in object_meta.dr.
dx, dy, dz = object_meta.dr
voxel_volume = dx*dy*dz
scale_factor = voxel_volume * 1.76 #1.76 cps / Bq
GT = scale_factor * GT
[10]:
def kernel_noncol(x,r,R, delta=1e-8):
    """Normalised Gaussian kernel whose width depends on ``r`` and ``R``.

    (Presumably models photon non-collinearity, judging by the name — confirm.)

    Args:
        x: Tensor of positions at which the kernel is evaluated.
        r: Single-element tensor; used in a Python ``if`` and via ``r.shape``.
        R: Scalar compared against ``r``.
        delta: Small constant used as the fallback width and to keep the
            normalisation denominator non-zero.

    Returns:
        Tensor shaped like ``x`` holding ``exp(-x**2 / (2*sigma**2))`` divided
        by ``sum + delta``, where ``sigma = sqrt(R**2 - r**2)/4 * pi/180`` if
        ``r**2 < R**2`` and ``sigma = delta`` otherwise.
    """
    if not (r**2<R**2):
        # On or outside the radius: collapse to a very narrow Gaussian.
        width = torch.zeros(r.shape) + delta
    else:
        width = torch.sqrt(R**2 - r**2)/4 * np.pi / 180
    unnormalised = torch.exp(-x**2/width**2 / 2)
    total = torch.sum(unnormalised)+delta
    return unnormalised / total

def kernel_penetration(x,r,R,mu=0.87, delta=1e-8):
    """Normalised one-sided exponential kernel.

    (Presumably models detector penetration, judging by the name — confirm.)

    The decay length is ``(r/R) * sqrt(1 - (r/R)**2) + delta``. The radicand
    is clamped at zero so that ``|r| > R`` produces a narrow spike at
    ``x == 0`` instead of a kernel full of NaNs, in line with how
    ``kernel_noncol`` handles ``r**2 >= R**2``. Results for ``|r| <= R`` are
    unchanged.

    Args:
        x: Tensor of positions at which the kernel is evaluated.
        r: Single-element tensor; its sign selects which side of ``x`` is kept.
        R: Scalar that ``r`` is divided by.
        mu: Coefficient multiplying ``x`` in the exponent.
        delta: Small constant keeping both denominators non-zero.

    Returns:
        Tensor shaped like ``x``: ``exp(-|mu*x / decay|)``, zeroed for
        ``x > 0`` when ``r >= 0`` (or for ``x < 0`` when ``r < 0``), divided by
        ``sum + delta``.
    """
    ratio = r/R
    # clamp(min=0) avoids sqrt of a negative number when |r| > R.
    decay = ratio*torch.sqrt(torch.clamp(1-ratio**2, min=0)) + delta
    result = torch.exp(-torch.abs(mu*x / decay))
    # Keep only one side of the kernel, chosen by the sign of r.
    if r>=0:
        result = result * (x <= 0)
    else:
        result = result * (x >= 0)
    return result / (torch.sum(result)+delta)

def kernel_scattering(x,r,R,scatter_fact=0.327, delta=1e-8):
    """Normalised Gaussian kernel with width proportional to ``sqrt(1 - (r/R)**2)``.

    (Presumably models inter-crystal scattering, judging by the name — confirm.)

    The radicand is clamped at zero and the resulting sigma is floored at
    ``delta``. For ``|r| >= R`` this yields a narrow spike at ``x == 0``
    instead of NaNs (sqrt of a negative number, or 0/0 at ``r == R``), in line
    with how ``kernel_noncol`` handles ``r**2 >= R**2``. Results are unchanged
    whenever the unclamped sigma is at least ``delta``.

    Args:
        x: Tensor of positions at which the kernel is evaluated.
        r: Tensor that is divided by ``R``.
        R: Scalar that ``r`` is divided by.
        scatter_fact: FWHM scale factor; converted to sigma by dividing by
            ``2*sqrt(2*ln 2)``.
        delta: Lower bound for sigma and guard for the normalisation.

    Returns:
        Tensor shaped like ``x`` holding ``exp(-x**2 / (2*sigma**2))`` divided
        by ``sum + delta``.
    """
    # clamp(min=0) avoids sqrt of a negative number when |r| > R.
    radial_term = torch.sqrt(torch.clamp(1-(r/R)**2, min=0))
    sigma = scatter_fact * radial_term / (2 * np.sqrt(2*np.log(2))) # fwhm -> sigma
    # A zero sigma would give 0/0 at x == 0; floor it like kernel_noncol does.
    sigma = torch.clamp(sigma, min=delta)
    result = torch.exp(-x**2/sigma**2 / 2)
    return result / (torch.sum(result)+delta)
[11]:
def kernel_penetration_doubled(x, r, R):
    """Evaluate ``kernel_penetration`` at twice the offset (``2*x``)."""
    return kernel_penetration(2*x, r, R)

# Kernel callables handed to PETPSFNet, in this order.
kernels = [
    kernel_noncol,
    kernel_penetration,
    kernel_penetration_doubled,
    kernel_scattering,
]
[12]:
# Image-space mapping networks: one built from the CT tensor, one from the kernels.
ctnet = PETAttenuationNet(CT = CT, device=device)
psfnet = PETPSFNet(kernels, device=device)


def _projection_kwargs():
    """Return a fresh set of constructor arguments shared by both projection nets."""
    return dict(
        obj2obj_nets=[],
        im2im_nets=[ctnet,psfnet],
        object_meta=object_meta,
        image_meta=image_meta,
    )


# Forward and back projectors are configured identically; each gets its own lists.
fp_net = ForwardProjectionNet(**_projection_kwargs())
bp_net = BackProjectionNet(**_projection_kwargs())
/data/anaconda/envs/torch/lib/python3.9/site-packages/torch/functional.py:478: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2894.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[13]:
# Plot one 1-D slice of the PSF matrix (index [0][0][0][300]) after moving it to the CPU.
psf_row = psfnet.PSF_matrix[0][0][0][300].cpu()
fig, ax = plt.subplots(figsize=(10,2))
ax.plot(psf_row)
plt.show()
../_images/notebooks_t_2DPET_12_0.png
[14]:
# Draw Poisson samples from GT and forward project them.
# A dedicated, seeded generator makes the draw reproducible under
# Restart & Run All without altering the global torch RNG state.
# NOTE(review): noise is applied to the object before projection rather than
# to the projections themselves — confirm this is intended.
noise_rng = torch.Generator(device=GT.device).manual_seed(0)
projections_analytic = fp_net(torch.poisson(GT, generator=noise_rng))
[15]:
# Show the [0, 0] slice of the projections (transposed) with a colour bar.
fig, ax = plt.subplots(figsize=(3,6))
mesh = ax.pcolormesh(projections_analytic[0,0].cpu().T, cmap='nipy_spectral')
fig.colorbar(mesh, ax=ax)
[15]:
<matplotlib.colorbar.Colorbar at 0x7f2140597fd0>
../_images/notebooks_t_2DPET_14_1.png

Make a reconstruction network

[16]:
# Prior object (beta=20, gamma=2) that is passed as `prior=` to the second OSEMOSL run below.
rdp = RelativeDifferencePrior(beta=20, gamma=2)
[17]:
# Reconstruction without a prior: 2 iterations, 7 subsets.
plain_recon_kwargs = dict(
    image=projections_analytic,
    forward_projection_net=fp_net,
    back_projection_net=bp_net,
)
osem_net = OSEMOSL(**plain_recon_kwargs)
reconstructed_object = osem_net(n_iters=2, n_subsets=7)
[18]:
# Reconstruction with the `rdp` prior: 10 iterations, 7 subsets.
# Note that this rebinds `osem_net`, replacing any earlier network of that name.
rdp_recon_kwargs = dict(
    image=projections_analytic,
    forward_projection_net=fp_net,
    back_projection_net=bp_net,
    prior=rdp,
)
osem_net = OSEMOSL(**rdp_recon_kwargs)
reconstructed_object_rdp = osem_net(n_iters=10, n_subsets=7)
[19]:
# Side-by-side comparison of the two reconstructions: the slice at index 128
# along the second object axis, drawn with a shared colour ceiling (vmax=4.5).
fig, axes = plt.subplots(1,2,figsize=(10,10))
panels = [
    ('2it_7sub_noprior', reconstructed_object),
    ('10it_7sub_rdpbeta20gamma2', reconstructed_object_rdp),
]
for ax, (panel_title, recon) in zip(axes, panels):
    ax.set_title(panel_title)
    mesh = ax.pcolormesh(recon[0][:,128].cpu().T, cmap='nipy_spectral', vmax=4.5)
    cbar = fig.colorbar(mesh, ax=ax)
# Leave the last colour bar as the cell's value, as the original cell did.
cbar
[19]:
<matplotlib.colorbar.Colorbar at 0x7f214045d790>
../_images/notebooks_t_2DPET_19_1.png
[23]:
from torch.nn.functional import pad
[42]:
# Shift a 5x5x5 cube by (-1, +1, +1) along dims (1, 2, 3) without wrap-around:
# zero-pad by one on every side, roll, then crop the padding off again so the
# entries rolled across the border land in the discarded margin.
cube = torch.arange(125).reshape(1,5,5,5)
padded = pad(cube, [1,1,1,1,1,1])
shifted = torch.roll(padded, shifts=(-1,1,1), dims=(1,2,3))
x = shifted[:,1:-1,1:-1,1:-1]
[43]:
# Display the tensor produced by the pad / roll / crop steps in the previous cell.
x
[43]:
tensor([[[[  0,   0,   0,   0,   0],
          [  0,  25,  26,  27,  28],
          [  0,  30,  31,  32,  33],
          [  0,  35,  36,  37,  38],
          [  0,  40,  41,  42,  43]],

         [[  0,   0,   0,   0,   0],
          [  0,  50,  51,  52,  53],
          [  0,  55,  56,  57,  58],
          [  0,  60,  61,  62,  63],
          [  0,  65,  66,  67,  68]],

         [[  0,   0,   0,   0,   0],
          [  0,  75,  76,  77,  78],
          [  0,  80,  81,  82,  83],
          [  0,  85,  86,  87,  88],
          [  0,  90,  91,  92,  93]],

         [[  0,   0,   0,   0,   0],
          [  0, 100, 101, 102, 103],
          [  0, 105, 106, 107, 108],
          [  0, 110, 111, 112, 113],
          [  0, 115, 116, 117, 118]],

         [[  0,   0,   0,   0,   0],
          [  0,   0,   0,   0,   0],
          [  0,   0,   0,   0,   0],
          [  0,   0,   0,   0,   0],
          [  0,   0,   0,   0,   0]]]])
[ ]: