Source code for pytomography.transforms.shared.motion
from __future__ import annotations
import torch
import pytomography
from pytomography.transforms import Transform
from torch.nn.functional import grid_sample
class DVFMotionTransform(Transform):
    def __init__(
        self,
        dvf_forward: torch.Tensor | None = None,
        dvf_backward: torch.Tensor | None = None,
    ) -> None:
        """Object-to-object transform that uses a deformation vector field to deform an object.

        Args:
            dvf_forward (torch.Tensor[Lx,Ly,Lz,3] | None, optional): Vector field corresponding to the forward transformation. If None, then no transformation is applied. Defaults to None.
            dvf_backward (torch.Tensor[Lx,Ly,Lz,3] | None, optional): Vector field corresponding to the backward transformation. If None, then no transformation is applied. Defaults to None.
        """
        # Move the DVFs to the configured device/dtype; keep None when no
        # transformation is requested (forward/backward check for None below).
        self.dvf_forward = dvf_forward.to(pytomography.device).to(pytomography.dtype) if dvf_forward is not None else None
        self.dvf_backward = dvf_backward.to(pytomography.device).to(pytomography.dtype) if dvf_backward is not None else None
        # Volume-ratio correction is currently disabled (set to 1); the commented
        # lines show how it would be computed from the DVFs via _get_vol_ratio.
        # self.dvf_forward_vol_ratio = self._get_vol_ratio(self.dvf_forward)
        # self.dvf_backward_vol_ratio = self._get_vol_ratio(self.dvf_backward)
        self.dvf_forward_vol_ratio = 1
        self.dvf_backward_vol_ratio = 1
        super().__init__()
    def _get_vol_ratio(self, DVF):
        """Computes the local volume ratio of the deformation at every voxel, i.e.
        the absolute Jacobian determinant of the mapping x -> x + DVF(x).

        Returns:
            torch.Tensor: Tensor of shape [1,Lx,Ly,Lz] with the volume ratio at each voxel.
        """
        xhat = torch.zeros((3, 1, 1, 1)).to(pytomography.device)
        xhat[0] = 1
        yhat = torch.zeros((3, 1, 1, 1)).to(pytomography.device)
        yhat[1] = 1
        zhat = torch.zeros((3, 1, 1, 1)).to(pytomography.device)
        zhat[2] = 1
        v = DVF.permute((3, 0, 1, 2))
        # delv[j] contains the derivative of the DVF with respect to spatial axis j
        delv = torch.stack(torch.gradient(v, dim=(1, 2, 3)), dim=0)
        # Absolute scalar triple product of the Jacobian columns (e_j + dv/dx_j) gives |det(J)|
        vol_ratio = torch.abs((torch.cross(delv[0] + xhat, delv[1] + yhat, dim=0) * (delv[2] + zhat)).sum(dim=0)).unsqueeze(0)
        return vol_ratio
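    # A worked form of the quantity computed above (a sketch of the math, not
    # library documentation): writing the deformed position as x' = x + v(x),
    # the Jacobian has columns c_j = e_j + dv/dx_j, and
    #     vol_ratio = |det(J)| = |(c_1 x c_2) . c_3|,
    # which is exactly the absolute scalar triple product formed from
    # delv[0]+xhat, delv[1]+yhat, and delv[2]+zhat.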
    def _get_old_coordinates(self):
        """Obtain meshgrid of coordinates corresponding to the object

        Returns:
            torch.Tensor: Tensor of coordinates corresponding to input object
        """
        dim_x, dim_y, dim_z = self.object_meta.shape
        coordinates = torch.stack(
            torch.meshgrid(torch.arange(dim_x), torch.arange(dim_y), torch.arange(dim_z), indexing='ij')
        ).permute((1, 2, 3, 0)).to(pytomography.device).to(pytomography.dtype)
        return coordinates
    def _get_new_coordinates(self, old_coordinates: torch.Tensor, DVF: torch.Tensor):
        """Obtain the new coordinates of each voxel based on the DVF.

        Args:
            old_coordinates (torch.Tensor): Old coordinates of each voxel
            DVF (torch.Tensor): Deformation vector field.

        Returns:
            torch.Tensor: New coordinates of each voxel, normalized to the [-1, 1] range expected by grid_sample.
        """
        dimensions = torch.tensor(self.object_meta.shape).to(pytomography.device)
        new_coordinates = old_coordinates + DVF
        # Rescale voxel indices [0, dim-1] to normalized coordinates [-1, 1]
        new_coordinates = 2 / (dimensions - 1) * new_coordinates - 1
        return new_coordinates
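    # Worked example of the rescaling above (illustrative only): for an axis of
    # length 4, voxel indices [0, 1, 2, 3] map to 2/(4-1)*[0, 1, 2, 3] - 1 =
    # [-1, -1/3, 1/3, 1], i.e. the endpoints land exactly on -1 and 1, consistent
    # with grid_sample(..., align_corners=True) used in _apply_dvf below.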
    def _apply_dvf(self, DVF: torch.Tensor, vol_ratio, object_i: torch.Tensor):
        """Applies the deformation vector field to the object

        Args:
            DVF (torch.Tensor): Deformation vector field
            vol_ratio (torch.Tensor | float): Volume ratio used to scale the deformed object (currently set to 1 in __init__).
            object_i (torch.Tensor): Old object.

        Returns:
            torch.Tensor: Deformed object.
        """
        old_coordinates = self._get_old_coordinates()
        new_coordinates = self._get_new_coordinates(old_coordinates, DVF)
        # grid_sample expects the last grid dimension ordered as (x, y, z) relative to the
        # (W, H, D) layout of the input, hence the flip; vol_ratio adjusts for local
        # stretching/compression of the object.
        return grid_sample(object_i.unsqueeze(0), new_coordinates.unsqueeze(0).flip(dims=[-1]), align_corners=True)[0] * vol_ratio
    def forward(
        self,
        object_i: torch.Tensor,
    ) -> torch.Tensor:
        """Forward transform of deformation vector field

        Args:
            object_i (torch.Tensor): Original object.

        Returns:
            torch.Tensor: Deformed object corresponding to forward transform.
        """
        if self.dvf_forward is None:
            return object_i
        else:
            return self._apply_dvf(self.dvf_forward, self.dvf_forward_vol_ratio, object_i)
    def backward(
        self,
        object_i: torch.Tensor,
    ) -> torch.Tensor:
        """Backward transform of deformation vector field

        Args:
            object_i (torch.Tensor): Original object.

        Returns:
            torch.Tensor: Deformed object corresponding to backward transform.
        """
        if self.dvf_backward is None:
            return object_i
        else:
            return self._apply_dvf(self.dvf_backward, self.dvf_backward_vol_ratio, object_i)
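
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module above). It assumes an object
# tensor of shape [1, Lx, Ly, Lz] and that the transform is configured with
# object metadata via Transform.configure, which the reconstruction algorithm
# normally does; the ObjectMeta constructor arguments shown here are assumed.
# A zero DVF is used, so the deformation reduces to the identity mapping.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pytomography.metadata import ObjectMeta

    shape = (64, 64, 64)
    object_meta = ObjectMeta(dr=(0.4, 0.4, 0.4), shape=shape)  # 4 mm voxels (assumed signature)
    dvf = torch.zeros((*shape, 3))  # zero displacement at every voxel
    transform = DVFMotionTransform(dvf_forward=dvf, dvf_backward=-dvf)
    transform.configure(object_meta, None)  # proj_meta is not used by this transform
    obj = torch.rand((1, *shape)).to(pytomography.device).to(pytomography.dtype)
    obj_deformed = transform.forward(obj)
    print(obj_deformed.shape)  # expected: torch.Size([1, 64, 64, 64])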