Source code for pytomography.projections.projection

from __future__ import annotations
import torch.nn as nn
import abc
import pytomography
from pytomography.mappings import MapNet
from pytomography.metadata import ObjectMeta, ImageMeta

class ProjectionNet(nn.Module):
    r"""Abstract parent class for projection networks.

    Any subclass of this network must implement the ``forward`` method.

    Args:
        obj2obj_nets (list): Sequence of object mappings that occur before projection.
        im2im_nets (list): Sequence of image mappings that occur after projection.
        object_meta (ObjectMeta): Object metadata.
        image_meta (ImageMeta): Image metadata.
        device (str, optional): Pytorch device used for computation. If None, uses
            the default device ``pytomography.device``. Defaults to None.
    """

    def __init__(
        self,
        obj2obj_nets: list[MapNet],
        im2im_nets: list[MapNet],
        object_meta: ObjectMeta,
        image_meta: ImageMeta,
        device: str = None,
    ) -> None:
        super(ProjectionNet, self).__init__()
        self.device = pytomography.device if device is None else device
        # Stored as plain Python lists (not nn.ModuleList), so PyTorch does not
        # register these mappings as submodules of this network.
        self.obj2obj_nets = obj2obj_nets
        self.im2im_nets = im2im_nets
        self.object_meta = object_meta
        self.image_meta = image_meta
        self.initialize_correction_nets()

    def initialize_correction_nets(self):
        """Initializes all mapping networks with the required object and image
        metadata corresponding to the projection network.

        Calls ``initialize_network(object_meta, image_meta)`` on every network in
        ``obj2obj_nets`` and then on every network in ``im2im_nets``, in order.
        """
        for net in self.obj2obj_nets:
            net.initialize_network(self.object_meta, self.image_meta)
        for net in self.im2im_nets:
            net.initialize_network(self.object_meta, self.image_meta)

    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """Abstract method that must be implemented by any subclass of this class.

        Raises:
            NotImplementedError: Always, when a subclass does not override it.
        """
        # ProjectionNet does not use abc.ABCMeta as its metaclass, so the decorator
        # above is not enforced at instantiation time. Raise explicitly so that a
        # missing override fails loudly instead of silently returning None.
        raise NotImplementedError(
            f"{type(self).__name__} must implement the 'forward' method"
        )

    def foward(self, *args, **kwargs):
        """Deprecated, misspelled alias of :meth:`forward`, kept for backward
        compatibility. Delegates all arguments to ``forward``.
        """
        return self.forward(*args, **kwargs)