Source code for pytomography.transforms.shared.kem
from __future__ import annotations
import numpy as np
from pytomography.utils import get_object_nearest_neighbour
import torch
import pytomography
from pytomography.transforms import Transform
from pytomography.metadata import ObjectMeta, ProjMeta
class KEMTransform(Transform):
r"""Object to object transform used to take in a coefficient image :math:`\alpha` and return an image estimate :math:`f = K\alpha`. This transform implements the matrix :math:`K`.
Args:
support_objects (Sequence[torch.tensor]): Objects used for support when building each basis function. These may correspond to PET/CT/MRI images, for example.
support_kernels (Sequence[Callable], optional): A list of functions corresponding to the support kernel of each support object. If none, defaults to :math:`k(v_i, v_j; \sigma) = \exp\left(-\frac{(v_i-v_j)^2}{2\sigma^2} \right)` for each support object. Defaults to None.
support_kernels_params (Sequence[Sequence[float]], optional): A list of lists, where each sublist contains the additional parameters for the corresponding support kernel (the parameters that follow the semi-colon in the expression above). As an example, if using the default ``support_kernels`` for two different support objects (say CT and PET), one could give ``support_kernels_params=[[40],[5]]``. If none, defaults to ``N*[[1]]`` where ``N`` is the number of support objects. Defaults to None.
distance_kernel (Callable, optional): Kernel used to weight based on voxel-voxel distance. If none, defaults to :math:`k(x_i, x_j; \sigma) = \exp\left(-\frac{(x_i-x_j)^2}{2\sigma^2} \right)`. Defaults to None.
distance_kernel_params (Sequence[float], optional): A list of additional parameters for the ``distance_kernel`` (i.e. the parameters that follow the semi-colon in the expression above). If none, defaults to :math:`\sigma=1`. Defaults to None.
size (int, optional): The size (number of voxels along each axis) of each kernel neighbourhood. Defaults to 5.
top_N (int, optional): If given, only the ``top_N`` neighbours with the largest support-based weights are kept for each voxel; the remaining kernel entries are set to zero. Defaults to None.
kernel_on_gpu (bool, optional): Whether to store the computed kernel on the GPU. Defaults to False, since the kernel can be very large.
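Example:
A minimal illustrative sketch (the support object, its shape, and the parameter value below are assumptions for illustration, not library defaults):
>>> ct = torch.rand((1, 128, 128, 96))  # hypothetical CT support object
>>> kem = KEMTransform([ct], support_kernels_params=[[40]], size=5)
>>> kem.configure(object_meta, proj_meta)  # object_meta/proj_meta: metadata from the reconstruction setup (placeholders here)
>>> alpha = torch.rand((1, 128, 128, 96))  # coefficient image
>>> f = kem.forward(alpha)  # image estimate :math:`K\alpha`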
"""
def __init__(
self,
support_objects,
support_kernels = None,
support_kernels_params = None,
distance_kernel = None,
distance_kernel_params = None,
size: int = 5,
top_N: int | None = None,
kernel_on_gpu: bool = False
) -> None:
super(KEMTransform, self).__init__()
self.support_objects = support_objects
if support_kernels is None:
# If not given, all default to Gaussian functions
self.support_kernels = [lambda obj_i, obj_j, sigma: torch.exp(-(obj_i - obj_j)**2 / (2*sigma**2)) for _ in range(len(support_objects))]
else:
self.support_kernels = support_kernels
if support_kernels_params is None:
# If not given, parameters default to sigma=1 for each kernel
self.support_kernel_params = [[1] for _ in range(len(support_objects))]
else:
self.support_kernel_params = support_kernels_params
if distance_kernel is None:
# If not given, defaults to Gaussian function
self.distance_kernel = lambda d, sigma: np.exp(-d**2 / (2*sigma**2))
else:
self.distance_kernel = distance_kernel
if distance_kernel_params is None:
# If not given, defaults to sigma = 1cm
self.distance_kernel_params = [1]
else:
self.distance_kernel_params = distance_kernel_params
self.size = size
self.idx_max = int((size - 1) / 2)
self.idxs = torch.arange(-self.idx_max, self.idx_max+1)
self.top_N = top_N
self.kernel_on_gpu = kernel_on_gpu
def compute_kernel(self):
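r"""Computes the kernel tensor used by :meth:`forward` and :meth:`backward`. The result has shape ``(size, size, size, *object_shape)``: for every voxel, each of the :math:`\text{size}^3` entries is the product of all support kernels evaluated between the voxel and the corresponding shifted neighbour, optionally truncated to the ``top_N`` largest support-based weights, then scaled by the distance kernel and normalized so that each voxel's neighbourhood weights sum to one. The kernel is kept on the CPU unless ``kernel_on_gpu`` is True."""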
shape = self.support_objects[0].shape[1:]
# Keep kernel on CPU until it's used (it can be too large for GPU memory)
self.kernel = torch.ones((self.size, self.size, self.size, *shape)).to(pytomography.dtype)
for i in self.idxs:
for j in self.idxs:
for k in self.idxs:
kernel_component = 1
# All support objects:
for l in range(len(self.support_objects)):
neighbour_support_object = get_object_nearest_neighbour(self.support_objects[l].cpu(), (i,j,k))
kernel_component *= self.support_kernels[l](self.support_objects[l].cpu(), neighbour_support_object, *self.support_kernel_params[l])
self.kernel[i+self.idx_max,j+self.idx_max,k+self.idx_max] = kernel_component
# Get only top N entries (based on anatomical support)
if self.top_N is not None:
self.kernel *= (torch.argsort(torch.argsort(self.kernel.reshape((self.size**3,*shape)), dim=0), dim=0)>=self.size**3 - self.top_N).reshape((self.size, self.size, self.size,*shape))
# Scale by distance
xv, yv, zv = torch.meshgrid([self.object_meta.dx*self.idxs, self.object_meta.dy*self.idxs, self.object_meta.dz*self.idxs])
d = torch.sqrt(xv**2+yv**2+zv**2)
self.kernel *= self.distance_kernel(d, *self.distance_kernel_params).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
# Normalize
self.kernel /= self.kernel.sum(dim=(0,1,2))
# Put on GPU (defaults to false since very large)
if self.kernel_on_gpu:
self.kernel = self.kernel.to(pytomography.device)
def configure(
self,
object_meta: ObjectMeta,
proj_meta: ProjMeta
) -> None:
"""Function used to initialize the transform using the corresponding object and projection metadata.
Args:
object_meta (ObjectMeta): Object metadata.
proj_meta (ProjMeta): Projection metadata.
"""
super(KEMTransform, self).configure(object_meta, proj_meta)
self.compute_kernel()
@torch.no_grad()
def forward(
self,
object: torch.Tensor,
) -> torch.Tensor:
r"""Forward transform corresponding to :math:`K\alpha`
Args:
object (torch.Tensor): Coefficient image :math:`\alpha`
Returns:
torch.Tensor: Image :math:`K\alpha`
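Concretely, each output voxel is the weighted neighbourhood sum :math:`(K\alpha)_i = \sum_{j \in N(i)} k_{ij} \alpha_j`, where :math:`N(i)` is the :math:`\text{size}^3` neighbourhood of voxel :math:`i` and :math:`k_{ij}` are the precomputed kernel weights.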
"""
object_return = torch.zeros(object.shape).to(self.device)
for i in self.idxs:
for j in self.idxs:
for k in self.idxs:
neighbour = get_object_nearest_neighbour(object, (i,j,k))
object_return += self.kernel[i+self.idx_max,j+self.idx_max,k+self.idx_max].to(pytomography.device) * neighbour
return object_return
@torch.no_grad()
def backward(
self,
object: torch.Tensor,
norm_constant: torch.Tensor | None = None,
) -> torch.Tensor:
r"""Backward transform corresponding to :math:`K^T\alpha`. Since :math:`K` is symmetric, the implementation is the same as :meth:`forward`.
Args:
object (torch.Tensor): Coefficient image :math:`\alpha`
norm_constant (torch.Tensor, optional): Normalization constant; if provided, it is transformed in the same way and returned alongside the object. Defaults to None.
Returns:
torch.Tensor: Image :math:`K^T\alpha` (returned together with the transformed ``norm_constant`` when one is given)
"""
object = self.forward(object)
if norm_constant is not None:
norm_constant = self.forward(norm_constant)
return object, norm_constant
else:
return object