# Source code for reflectorch.data_generation.likelihoods

# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union, Tuple

import torch
from torch import Tensor

from reflectorch.data_generation import (
    PriorSampler,
    Params,
)


class LogLikelihood:
    """Gaussian log likelihood (and log posterior) of thin film parameters.

    Simulated reflectivity curves are compared point-wise with the
    experimental curve under Gaussian noise whose standard deviation is
    ``sigmas``. The Gaussian normalization constant is omitted, so values are
    defined up to an additive constant.

    Args:
        q (Tensor): the q values at which curves are simulated
        exp_curve (Tensor): the experimental reflectivity curve; stored after
            ``torch.atleast_2d``
        priors (PriorSampler): the prior sampler; its ``log_prob`` and
            ``PARAM_CLS`` are used
        sigmas (Union[float, Tensor]): the sigmas (i.e. intensity error bars)
    """

    def __init__(self, q: Tensor, exp_curve: Tensor, priors: PriorSampler, sigmas: Union[float, Tensor]):
        # Presumably promotes a 1D curve to (1, n) so it broadcasts against a
        # batch of simulated curves -- TODO confirm intended shapes.
        self.exp_curve = torch.atleast_2d(exp_curve)
        self.priors: PriorSampler = priors
        self.q = q
        self.sigmas = sigmas
        # Variances are cached because every likelihood evaluation divides by them.
        self.sigmas2 = sigmas ** 2

    def calc_log_likelihood(self, curves: Tensor):
        """Compute the Gaussian log likelihood of ``curves``.

        Args:
            curves (Tensor): simulated curves, broadcast against ``exp_curve``

        Returns:
            Tensor: ``-0.5 * (exp_curve - curves)**2 / sigmas**2`` summed over
            the last dimension.
        """
        squared_residuals = (self.exp_curve - curves) ** 2
        half_chi2 = squared_residuals / self.sigmas2 / 2
        return (-half_chi2).sum(-1)

    def __call__(self, params: Union[Params, Tensor], curves: Tensor = None):
        """Compute the log posterior (log prior + log likelihood) of ``params``.

        Args:
            params (Union[Params, Tensor]): the parameters; anything that is not
                a ``Params`` instance is converted with
                ``self.priors.PARAM_CLS.from_tensor``
            curves (Tensor, optional): precomputed curves matching ``params``.
                When omitted, curves are simulated with ``reflectivity(self.q)``,
                and only for entries whose log prior is finite.

        Returns:
            Tensor: the log prior tensor returned by ``self.priors.log_prob``,
            updated in place with the log likelihood at entries where the log
            prior is finite. All other entries keep their log prior value.
        """
        parsed = params if isinstance(params, Params) else self.priors.PARAM_CLS.from_tensor(params)

        log_posterior: Tensor = self.priors.log_prob(parsed)
        finite = torch.isfinite(log_posterior)

        # No entry has a finite log prior: return it untouched and skip simulation.
        if not finite.any():
            return log_posterior

        finite_params: Params = parsed[finite]
        selected_curves = finite_params.reflectivity(self.q) if curves is None else curves[finite]

        log_posterior[finite] += self.calc_log_likelihood(selected_curves)
        return log_posterior

    calc_log_posterior = __call__

    def get_importance_sampling_weights(
            self, sampled_params: Params, nf_log_probs: Tensor, curves: Tensor = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute self-normalized importance weights for proposal samples.

        Args:
            sampled_params (Params): parameters drawn from the proposal
                distribution (presumably a normalizing flow, given the ``nf``
                prefix -- confirm with callers)
            nf_log_probs (Tensor): log density of the proposal at ``sampled_params``
            curves (Tensor, optional): precomputed curves forwarded to
                ``calc_log_posterior``

        Returns:
            Tuple[Tensor, Tensor, Tensor]: three tensors.
                - The weights, normalized to sum to 1.
                - The log weights, shifted by their maximum.
                - The log posterior values.

        Note:
            If every log weight is ``-inf``, the shift computes
            ``-inf - (-inf)``, and the weights come out as NaN.
        """
        log_posterior = self.calc_log_posterior(sampled_params, curves=curves)
        raw_log_weights = log_posterior - nf_log_probs
        # Subtract the maximum before exponentiating to avoid overflow.
        shifted = raw_log_weights - raw_log_weights.max()
        # Exponentiate in float64, then cast back to shifted's dtype and device.
        unnormalized = shifted.to(torch.float64).exp().to(shifted)
        return unnormalized / unnormalized.sum(), shifted, log_posterior
class PoissonLogLikelihood(LogLikelihood):
    """Computes the Poisson log likelihood of the thin film parameters.

    Args:
        q (Tensor): the q values
        exp_curve (Tensor): the experimental reflectivity curve
        priors (PriorSampler): the prior sampler
        sigmas (Union[float, Tensor]): the sigmas (i.e. intensity error bars)
    """

    def calc_log_likelihood(self, curves: Tensor):
        """Compute the Poisson log likelihood, up to an additive constant.

        The quantity ``exp_curve / sigmas**2 * (exp_curve * log(curves) - curves)``
        equals the Poisson log pmf up to terms independent of ``curves``. That
        pmf has ``k = exp_curve**2 / sigmas**2`` counts and rate
        ``curves * exp_curve / sigmas**2``.

        Args:
            curves (Tensor): simulated curves, broadcast against ``exp_curve``

        Returns:
            Tensor: the log likelihood summed over the last dimension.
            A point with ``exp_curve > 0`` and ``curves == 0`` contributes
            ``-inf``. A point with ``exp_curve > 0`` and ``curves < 0``
            contributes NaN.
        """
        # torch.xlogy returns 0 wherever exp_curve == 0, even when curves == 0.
        # Plain ``exp_curve * torch.log(curves)`` would give 0 * -inf = NaN there
        # and poison the sum, although the outer factor makes that point irrelevant.
        log_term = torch.xlogy(self.exp_curve, curves)
        log_probs = self.exp_curve / self.sigmas2 * (log_term - curves)
        return log_probs.sum(-1)