Source code for pytomography.transforms.shared.kem

from __future__ import annotations
import numpy as np
from pytomography.utils import get_object_nearest_neighbour
import torch
from pytomography.transforms import Transform


class KEMTransform(Transform):
    r"""Object to object transform used to take in a coefficient image :math:`\alpha` and return an image estimate :math:`f = K\alpha`. This transform implements the matrix :math:`K`.

    Args:
        support_objects (Sequence[torch.Tensor]): Objects used for support when building each basis function. These may correspond to PET/CT/MRI images, for example.
        support_kernels (Sequence[Callable], optional): A list of functions corresponding to the support kernel of each support object. If None, each defaults to :math:`k(v_i, v_j; \sigma) = \exp\left(-\frac{(v_i-v_j)^2}{2\sigma^2} \right)`. Defaults to None.
        support_kernels_params (Sequence[Sequence[float]], optional): A list of lists, where each sublist contains the additional parameters for the corresponding support kernel (the parameters that follow the semicolon in the expression above). For example, if using the default ``support_kernels`` for two different support objects (say CT and PET), one could give ``support_kernels_params=[[40],[5]]``. If None, defaults to ``N*[[1]]`` where ``N`` is the number of support objects. Defaults to None.
        distance_kernel (Callable, optional): Kernel used to weight based on voxel-voxel distance. If None, defaults to :math:`k(x_i, x_j; \sigma) = \exp\left(-\frac{(x_i-x_j)^2}{2\sigma^2} \right)`. Defaults to None.
        distance_kernel_params (Sequence[float], optional): A list of additional parameters for the ``distance_kernel`` (i.e. the parameters that follow the semicolon in the expression above). If None, defaults to :math:`\sigma=1`. Defaults to None.
        size (int, optional): The size of each kernel. Defaults to 5.
    """
    def __init__(
        self,
        support_objects,
        support_kernels = None,
        support_kernels_params = None,
        distance_kernel = None,
        distance_kernel_params = None,
        size: int = 5
    ) -> None:
        super(KEMTransform, self).__init__()
        self.support_objects = support_objects
        if support_kernels is None:
            # If not given, all support kernels default to Gaussian functions
            self.support_kernels = [
                lambda obj_i, obj_j, sigma: torch.exp(-(obj_i - obj_j)**2 / (2*sigma**2))
                for _ in range(len(support_objects))
            ]
        else:
            self.support_kernels = support_kernels
        if support_kernels_params is None:
            # If not given, parameters default to sigma=1 for each support kernel
            self.support_kernel_params = [[1] for _ in range(len(support_objects))]
        else:
            self.support_kernel_params = support_kernels_params
        if distance_kernel is None:
            # If not given, defaults to a Gaussian function
            self.distance_kernel = lambda d, sigma: np.exp(-d**2 / (2*sigma**2))
        else:
            self.distance_kernel = distance_kernel
        if distance_kernel_params is None:
            # If not given, defaults to sigma = 1cm
            self.distance_kernel_params = [1]
        else:
            self.distance_kernel_params = distance_kernel_params
        idx_max = int((size - 1) / 2)
        self.idxs = np.arange(-idx_max, idx_max + 1)
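    # Illustrative note (not part of the original module): with the default Gaussian
    # kernels above, the weight that neighbour j contributes to voxel i is the product
    #   k(v_i, v_j; sigma_s) * k(x_i, x_j; sigma_d)
    #     = exp(-(v_i - v_j)^2 / (2*sigma_s^2)) * exp(-d_ij^2 / (2*sigma_d^2)).
    # For example (values chosen purely for illustration), with sigma_s = 40 and
    # sigma_d = 1 cm, a neighbour whose support value differs by 40 at a distance of
    # 1 cm contributes exp(-0.5) * exp(-0.5) ~ 0.37 of the weight of an identical,
    # co-located voxel.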
    @torch.no_grad()
    def forward(
        self,
        object: torch.Tensor,
    ) -> torch.Tensor:
        r"""Forward transform corresponding to :math:`K\alpha`

        Args:
            object (torch.Tensor): Coefficient image :math:`\alpha`

        Returns:
            torch.Tensor: Image :math:`K\alpha`
        """
        object_return = torch.zeros(object.shape).to(self.device)
        total = 0
        for i in self.idxs:
            for j in self.idxs:
                for k in self.idxs:
                    kernel_component = 1
                    # Distance component
                    d = np.sqrt((self.object_meta.dx*i)**2 + (self.object_meta.dy*j)**2 + (self.object_meta.dz*k)**2)
                    kernel_component *= self.distance_kernel(d, *self.distance_kernel_params)
                    # All support objects
                    for l in range(len(self.support_objects)):
                        neighbour_support_object = get_object_nearest_neighbour(self.support_objects[l], (i,j,k))
                        kernel_component *= self.support_kernels[l](self.support_objects[l], neighbour_support_object, *self.support_kernel_params[l])
                    neighbour = get_object_nearest_neighbour(object, (i,j,k))
                    object_return += kernel_component * neighbour
                    total += kernel_component  # for normalization
        object_return /= total
        return object_return
    @torch.no_grad()
    def backward(
        self,
        object: torch.Tensor,
        norm_constant: torch.Tensor | None = None,
    ) -> torch.Tensor:
        r"""Backward transform corresponding to :math:`K^T\alpha`. Since the matrix is symmetric, the implementation is the same as the forward transform.

        Args:
            object (torch.Tensor): Coefficient image :math:`\alpha`
            norm_constant (torch.Tensor | None, optional): Normalization constant to be transformed alongside ``object``. Defaults to None.

        Returns:
            torch.Tensor: Image :math:`K^T\alpha`
        """
        object = self.forward(object)
        if norm_constant is not None:
            norm_constant = self.forward(norm_constant)
            return object, norm_constant
        else:
            return object
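
A minimal usage sketch (not part of the module) showing how ``KEMTransform`` can be constructed from a single CT-like support volume and applied to a coefficient image :math:`\alpha`. The tensor shapes, the sigma values, the ``SimpleNamespace`` stand-in for the object metadata (voxel spacings ``dx``/``dy``/``dz``), and the manual device assignment are illustrative assumptions only; in normal use these attributes are set on the transform by the reconstruction algorithm that owns it.

import torch
from types import SimpleNamespace
from pytomography.transforms.shared.kem import KEMTransform

# Hypothetical support object (e.g. a CT volume) and coefficient image alpha
ct = torch.rand((1, 64, 64, 64))
alpha = torch.rand((1, 64, 64, 64))

# sigma=40 for the CT support kernel, sigma=2 (cm) for the distance kernel
kem = KEMTransform(
    support_objects=[ct],
    support_kernels_params=[[40]],
    distance_kernel_params=[2],
    size=5,
)

# Stand-in for the configuration normally performed by the reconstruction
# algorithm: voxel spacings (in cm) and the compute device.
kem.object_meta = SimpleNamespace(dx=0.4, dy=0.4, dz=0.4)
kem.device = 'cpu'

f = kem.forward(alpha)        # image estimate f = K alpha
alpha_back = kem.backward(f)  # K^T f (same operation, since K is symmetric)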