# Source code for pytomography.priors.prior
from __future__ import annotations
import torch
import torch.nn as nn
import numpy as np
import abc
import pytomography
from pytomography.metadata import ObjectMeta, ImageMeta
class Prior:
    r"""Base class for a prior :math:`V(f)`.

    :math:`V` is the prior term of the log-posterior probability
    :math:`\ln L(\tilde{f}, f) - \beta V(f)`. Any class inheriting from this one
    should implement a ``compute_gradient`` method that computes the tensor
    :math:`\frac{\partial V}{\partial f_r}`, where :math:`f` is an object tensor.

    Note:
        ``__init__`` and ``compute_gradient`` are marked with
        ``abc.abstractmethod``, but this class does not derive from ``abc.ABC``,
        so the markers are advisory only and are not enforced at instantiation.

    Args:
        beta (float): Used to scale the weight of the prior
    """

    @abc.abstractmethod
    def __init__(self, beta: float):
        self.beta = beta
        # Device is read from the package-level setting when the prior is built.
        self.device = pytomography.device

    def set_beta_scale(self, factor: float) -> None:
        r"""Sets a scale factor for :math:`\beta` required for OSEM when finite subsets are used per iteration.

        The value is stored on ``self.beta_scale_factor``; the attribute does not
        exist until this method has been called.

        Args:
            factor (float): Value by which to scale :math:`\beta`
        """
        self.beta_scale_factor = factor

    def set_object(self, object: torch.Tensor) -> None:
        r"""Sets the object :math:`f_r` used to compute :math:`\frac{\partial V}{\partial f_r}`

        The tensor is stored by reference on ``self.object`` (no copy is made).

        Args:
            object (torch.tensor): Tensor of size [batch_size, Lx, Ly, Lz] representing :math:`f_r`.
        """
        # The parameter name shadows the ``object`` builtin; it is kept because
        # callers may pass it by keyword.
        self.object = object

    @abc.abstractmethod
    def compute_gradient(self):
        """Abstract method to compute the gradient of the prior based on the ``self.object`` attribute.

        The base implementation does nothing and returns ``None``; subclasses
        are expected to override it.
        """
        ...