Source code for pytomography.transforms.shared.spatial

from __future__ import annotations
import torch
from torch.nn.functional import pad
from pytomography.transforms import Transform
from kornia.geometry.transform import rotate
import numpy as np

[docs]class RotationTransform(Transform): r"""obj2obj transform used to rotate an object to angle :math:`\beta` in the DICOM reference frame. (Note that an angle of ) Args: mode (str): Interpolation mode used in the rotation. """ def __init__( self, mode: str = 'bilinear' )-> None: super(RotationTransform, self).__init__() self.mode = mode @torch.no_grad()
[docs] def forward( self, object: torch.Tensor, angles: torch.Tensor, )-> torch.Tensor: r"""Rotates an object to angle :math:`\beta` in the DICOM reference frame. Note that the scanner angle :math:`\beta` is related to :math:`\phi` (azimuthal angle) by :math:`\phi = 3\pi/2 - \beta`. Args: object (torch.tensor): Tensor of size [batch_size, Lx, Ly, Lz] being rotated. angles (torch.Tensor): Tensor of size [batch_size] corresponding to the rotation angles. Returns: torch.tensor: Tensor of size [batch_size, Lx, Ly, Lz] where each element in the batch dimension is rotated by the corresponding angle. """ return rotate(object.permute(0,3,1,2), angles, mode=self.mode).permute(0,2,3,1)
@torch.no_grad()
[docs] def backward( self, object: torch.Tensor, angles: torch.Tensor ) -> torch.Tensor: r"""Forward projection :math:`A:\mathbb{U} \to \mathbb{U}` of attenuation correction. Args: object (torch.tensor): Tensor of size [batch_size, Lx, Ly, Lz] being rotated. angles (torch.Tensor): Tensor of size [batch_size] corresponding to the rotation angles. Returns: torch.tensor: Tensor of size [batch_size, Lx, Ly, Lz] where each element in the batch dimension is rotated by the corresponding angle. """ return rotate(object.permute(0,3,1,2), -angles, mode=self.mode).permute(0,2,3,1)