Source code for pytomography.projectors.shared.kem_system_matrix

from ..system_matrix import SystemMatrix
import torch

[docs]class KEMSystemMatrix(SystemMatrix): """Given a KEM transform :math:`K` and a system matrix :math:`H`, implements the transform :math:`HK` (and backward transform :math:`K^T H^T`) Args: system_matrix (SystemMatrix): System matrix corresponding to a particular imaging system kem_transform (KEMTransform): Transform used to go from coefficient image to real image of predicted counts. """ def __init__(self, system_matrix, kem_transform): self.object_meta = system_matrix.object_meta self.proj_meta = system_matrix.proj_meta self.system_matrix = system_matrix self.kem_transform = kem_transform # Inherit required functions from system matrix self.set_n_subsets = self.system_matrix.set_n_subsets self.get_projection_subset = self.system_matrix.get_projection_subset self.get_weighting_subset = self.system_matrix.get_projection_subset
[docs] def compute_normalization_factor(self, subset_idx : int | None = None): """Function used to get normalization factor :math:`K^T H^T_m 1` corresponding to projection subset :math:`m`. Args: subset_idx (int | None, optional): Index of subset. If none, then considers all projections. Defaults to None. Returns: torch.Tensor: normalization factor :math:`K^T H^T_m 1` """ object = self.system_matrix.compute_normalization_factor(subset_idx) return self.kem_transform.backward(object)
[docs] def forward(self, object, subset_idx=None): r"""Forward transform :math:`HK` Args: object (torch.tensor): Object to be forward projected subset_idx (int, optional): Only uses a subset of angles :math:`g_m` corresponding to the provided subset index :math:`m`. If None, then defaults to the full projections :math:`g`. Returns: torch.tensor: Corresponding projections generated from forward projection """ object = self.kem_transform.forward(object) return self.system_matrix.forward(object, subset_idx)
[docs] def backward(self, proj, subset_idx=None, return_norm_constant = False): r"""Backward transform :math:`K^T H^T` Args: proj (torch.tensor): Projection data to be back projected subset_idx (int, optional): Only uses a subset of angles :math:`g_m` corresponding to the provided subset index :math:`m`. If None, then defaults to the full projections :math:`g`. return_norm_constant (bool, optional): Additionally returns :math:`K^T H^T 1` if true; defaults to False. Returns: torch.tensor: Corresponding object generated from back projection. """ if return_norm_constant: object, norm_constant = self.system_matrix.backward(proj, subset_idx, return_norm_constant) object, norm_constant = self.kem_transform.backward(object, norm_constant) return object, norm_constant else: object = self.system_matrix.backward(proj, subset_idx, return_norm_constant) return self.kem_transform.backward(object)