Source code for lensmodels.potential

import numpy as np
from wolensing.lensmodels.lens import *
from jax import jit

[docs] @jit def geometrical(x1, x2, y): ''' :param x1: x-coordinates of position on lens plane with respect to the window center. :param x2: y-coordinates of position on lens plane with respect to the window center. :param y: numpy array, source positions. :return: geometrical part of the time delay. ''' x = jnp.array([x1, x2], dtype=jnp.float64) geo = (1/2) * jnp.linalg.norm(x-y[:, jnp.newaxis, jnp.newaxis], axis=0)**2 return geo
[docs] def potential(lens_model_list, x1, x2, y, kwargs): ''' :param lens_model_list: list of lens models. :param x1: x-coordinates of position on lens plane with respect to the window center. :param x2: y-coordinates of position on lens plane with respect to the window center. :param y: numpy array, source positions. :kwargs: arguemnts for the lens models. :return: time delay function. ''' potential = jnp.float64(0.0) for lens_type, lens_kwargs in zip(lens_model_list, kwargs): thetaE = lens_kwargs['theta_E'] x_center = lens_kwargs['center_x'] y_center = lens_kwargs['center_y'] if lens_type == 'SIS': potential += Psi_SIS(x1, x2, x_center, y_center, thetaE) # Make sure Psi_SIS is JAX-compatible elif lens_type == 'POINT_MASS': potential += Psi_PM(x1, x2, x_center, y_center, thetaE) # Make sure Psi_PM is JAX-compatible elif lens_type == 'NFW': potential += Psi_NFW(x1, x2, x_center, y_center, thetaE, kappa=3) elif lens_type == 'SIE': e1 = lens_kwargs['e1'] e2 = lens_kwargs['e2'] potential += Psi_SIE(x1, x2, x_center, y_center, thetaE, e1, e2) geo = geometrical(x1, x2, y) fermat_potential = geo - potential return fermat_potential