Source code for reflectorch.data_generation.scale_curves

# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path

import torch
from torch import Tensor

from reflectorch.data_generation.priors import PriorSampler
from reflectorch.paths import SAVED_MODELS_DIR


class CurvesScaler:
    """Base class for curve scalers.

    Defines the two-method interface shared by the scalers: :meth:`scale`
    and :meth:`restore`. Both raise ``NotImplementedError`` here and are
    meant to be overridden by subclasses.
    """

    def scale(self, curves: Tensor):
        """Scale the given curves (not implemented in the base class).

        Args:
            curves (Tensor): curves to be scaled

        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError()

    def restore(self, curves: Tensor):
        """Restore the given scaled curves (not implemented in the base class).

        Args:
            curves (Tensor): scaled curves to be restored

        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError()
class LogAffineCurvesScaler(CurvesScaler):
    r"""Curve scaler which scales the reflectivity curves according to the
    logarithmic affine transformation:
    :math:`\log_{10}(R + eps) \cdot weight + bias`.

    With the default arguments, a reflectivity of 0 maps to -0.5 and a
    reflectivity of 1 maps to (approximately) 0.5.

    Args:
        weight (float): multiplication factor in the transformation
        bias (float): addition term in the transformation
        eps (float): sets the minimum intensity value of the reflectivity curves which is considered
    """
    # NOTE: the docstring above is a raw string on purpose -- ``\l`` and ``\c``
    # are invalid escape sequences in a normal string literal and trigger a
    # DeprecationWarning / SyntaxWarning at compile time.

    def __init__(self, weight: float = 0.1, bias: float = 0.5, eps: float = 1e-10):
        self.weight = weight
        self.bias = bias
        self.eps = eps

    def scale(self, curves: Tensor):
        """scales the reflectivity curves to a ML-friendly range

        Args:
            curves (Tensor): original reflectivity curves

        Returns:
            Tensor: reflectivity curves scaled to a ML-friendly range
        """
        # eps is added before the logarithm so that zero-valued intensities
        # yield a finite result (log10(eps)) instead of -inf.
        return torch.log10(curves + self.eps) * self.weight + self.bias

    def restore(self, curves: Tensor):
        """restores the physical reflectivity curves

        Exact algebraic inverse of :meth:`scale`.

        Args:
            curves (Tensor): scaled reflectivity curves

        Returns:
            Tensor: reflectivity curves restored to the physical range
        """
        # Undo the affine part, exponentiate, then remove the eps offset.
        return 10 ** ((curves - self.bias) / self.weight) - self.eps
class MeanNormalizationCurvesScaler(CurvesScaler):
    """Curve scaler which scales the reflectivity curves by the precomputed mean of a batch of curves

    Args:
        path (str, optional): path to the precomputed mean of the curves, only used if ``curves_mean`` is None. Defaults to None.
        curves_mean (Tensor, optional): the precomputed mean of the curves. Defaults to None.
        device (torch.device, optional): the Pytorch device. Defaults to 'cuda'.

    Raises:
        ValueError: if neither ``path`` nor ``curves_mean`` is provided.
    """

    def __init__(self, path: str = None, curves_mean: Tensor = None, device: torch.device = 'cuda'):
        if curves_mean is None:
            if path is None:
                # Fail early with a clear message instead of the opaque
                # AttributeError that ``get_path(None)`` would raise.
                raise ValueError(
                    "Either `path` or `curves_mean` must be provided to MeanNormalizationCurvesScaler."
                )
            # NOTE(review): torch.load unpickles the file; only load mean
            # files from trusted locations.
            curves_mean = torch.load(self.get_path(path))
        self.curves_mean = curves_mean.to(device)

    def scale(self, curves: Tensor):
        """scales the reflectivity curves to a ML-friendly range

        Computes ``curves / curves_mean - 1``.

        Args:
            curves (Tensor): original reflectivity curves

        Returns:
            Tensor: reflectivity curves scaled to a ML-friendly range
        """
        # Align the stored mean with the dtype/device of the input; the
        # converted tensor is cached on the instance for subsequent calls.
        self.curves_mean = self.curves_mean.to(curves)
        return curves / self.curves_mean - 1

    def restore(self, curves: Tensor):
        """restores the physical reflectivity curves

        Computes ``(curves + 1) * curves_mean``, the inverse of :meth:`scale`.

        Args:
            curves (Tensor): scaled reflectivity curves

        Returns:
            Tensor: reflectivity curves restored to the physical range
        """
        # Same dtype/device alignment as in ``scale``.
        self.curves_mean = self.curves_mean.to(curves)
        return (curves + 1) * self.curves_mean

    @staticmethod
    def save(prior_sampler: PriorSampler, q: Tensor, path: str, num: int = 16384):
        """computes the mean of a batch of reflectivity curves and saves it

        Args:
            prior_sampler (PriorSampler): the prior sampler
            q (Tensor): the q values
            path (str): the path for saving the mean of the curves
            num (int, optional): the number of curves used to compute the mean. Defaults to 16384.
        """
        params = prior_sampler.sample(num)
        # Average over the first (sample) dimension and move the result to
        # the CPU before saving.
        curves_mean = params.reflectivity(q, log=False).mean(0).cpu()
        torch.save(curves_mean, MeanNormalizationCurvesScaler.get_path(path))

    @staticmethod
    def get_path(path: str) -> Path:
        """builds the file path used for loading/saving the mean of the curves

        Appends the ``.pt`` extension when it is missing and joins the result
        onto ``SAVED_MODELS_DIR``.

        Args:
            path (str): file name or path, with or without the ``.pt`` extension

        Returns:
            Path: ``SAVED_MODELS_DIR / path``
        """
        if not path.endswith('.pt'):
            path = path + '.pt'
        return SAVED_MODELS_DIR / path