Source code for reflectorch.inference.torch_fitter

from tqdm import trange

import torch
from torch import nn, Tensor

from reflectorch.data_generation import LogLikelihood, reflectivity, PriorSampler


class ReflGradientFit:
    """Directly optimizes the thin film parameters using a Pytorch optimizer

    Args:
        q (Tensor): the q positions
        exp_curve (Tensor): the experimental reflectivity curve
        prior_sampler (PriorSampler): the prior sampler
        params (Tensor): the initial thin film parameters
        fit_indices (Tensor): the indices of the thin film parameters which are to be fitted
        sigmas (Tensor, optional): error bars of the reflectivity curve, if not provided they are derived
            from ``rel_err`` and ``abs_err``. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): the Pytorch optimizer class. When None,
            ``torch.optim.Adam`` is used. Defaults to None.
        lr (float, optional): the learning rate. Defaults to 1e-2.
        rel_err (float, optional): the relative error in the reflectivity curve. Defaults to 0.1.
        abs_err (float, optional): the absolute error in the reflectivity curve. Defaults to 1e-7.
    """

    def __init__(self,
                 q: Tensor,
                 exp_curve: Tensor,
                 prior_sampler: 'PriorSampler',
                 params: Tensor,
                 fit_indices: Tensor,
                 sigmas: Tensor = None,
                 optim_cls=None,
                 lr: float = 1e-2,
                 rel_err: float = 0.1,
                 abs_err: float = 1e-7,
                 ):
        self.q = q

        if sigmas is None:
            # Fall back to error bars built from a relative and an absolute part.
            sigmas = exp_curve * rel_err + abs_err

        self.likelihood = LogLikelihood(q, exp_curve, prior_sampler, sigmas)

        # ``refl`` splits the last dimension into chunks of N, N + 1 and N + 1
        # entries (3N + 2 in total), so floor division by 3 recovers N.
        self.num_layers = params.shape[-1] // 3

        self.fit_indices = fit_indices

        # Detached copy: the starting values are constants of the fit, so the
        # backward passes in ``run`` must not reach into the caller's tensor or
        # into whatever autograd graph produced it.
        self.init_params = params.detach().clone()
        self.params_to_fit = nn.Parameter(self.init_params[fit_indices].clone())

        optim_cls = optim_cls or torch.optim.Adam
        self.optim = optim_cls([self.params_to_fit], lr)

        # One entry per iteration performed by ``run``.
        self.losses = []

    @property
    def params(self) -> Tensor:
        """The full parameter tensor with the current fitted values inserted.

        A fresh clone of the initial parameters is built on every access and the
        entries selected by ``fit_indices`` are overwritten with ``params_to_fit``;
        ``init_params`` itself is never modified.
        """
        merged = self.init_params.clone()
        merged[self.fit_indices] = self.params_to_fit
        return merged

    def calc_log_likelihood(self):
        """Evaluates the log likelihood object on the currently simulated curve."""
        return self.likelihood.calc_log_likelihood(self.refl())

    def calc_log_prob_loss(self):
        """Returns the loss to minimize: the negated mean of the log likelihood."""
        return - self.calc_log_likelihood().mean()

    def refl(self):
        """Simulates the reflectivity curve for the current parameters.

        The last dimension of ``params`` is split into ``d``, ``sigma`` and ``rho``
        chunks of ``num_layers``, ``num_layers + 1`` and ``num_layers + 1`` entries,
        which are passed to ``reflectivity`` together with ``q``.
        """
        # NOTE(review): presumably thicknesses, roughnesses and scattering length
        # densities -- confirm against ``reflectivity``.
        d, sigma, rho = torch.split(
            self.params,
            [self.num_layers, self.num_layers + 1, self.num_layers + 1],
            -1,
        )
        return reflectivity(self.q, d, sigma, rho)

    def run(self, num_iterations: int = 500, disable_tqdm: bool = False) -> None:
        """Runs the optimization process

        Each iteration zeroes the gradients, computes the loss, backpropagates and
        takes one optimizer step. The loss value of every iteration (computed
        before that iteration's step) is appended to ``losses``.

        Args:
            num_iterations (int, optional): number of iterations the optimization is run for. Defaults to 500.
            disable_tqdm (bool, optional): whether to disable the progress bar. Defaults to False.
        """
        pbar = trange(num_iterations, disable=disable_tqdm)

        for _ in pbar:
            self.optim.zero_grad()
            loss = self.calc_log_prob_loss()
            loss.backward()
            self.optim.step()

            # Read the scalar once and reuse it for the history and the label.
            loss_value = loss.item()
            self.losses.append(loss_value)
            pbar.set_description(f'Loss = {loss_value:.2e}')

    def clear(self) -> None:
        """Empties the recorded loss history.

        Only ``losses`` is cleared; the fitted parameters and the optimizer are
        left as they are.
        """
        self.losses.clear()