Source code for reflectorch.data_generation.smearing
import torch
from torch import Tensor
from reflectorch.data_generation.priors.parametric_subpriors import BasicParams
[docs]
class Smearing:
    """Class which applies resolution smearing to the reflectivity curves.

    The intensity at a q point will be the average of the intensities of
    neighbouring q points, weighted by a gaussian profile.

    Args:
        sigma_range (tuple, optional): the range for sampling the standard deviation of the gaussians.
            Defaults to (1e-4, 5e-3).
        constant_dq (bool, optional): whether the smearing is constant for each q point. Defaults to True.
        gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
        share_smeared (float, optional): the share of curves in the batch for which the resolution smearing
            is applied. Values outside [0, 1] are clamped when the number of smeared curves is computed.
            Defaults to 0.2.
    """

    def __init__(self,
                 sigma_range: tuple = (1e-4, 5e-3),
                 constant_dq: bool = True,
                 gauss_num: int = 31,
                 share_smeared: float = 0.2,
                 ):
        self.sigma_min, self.sigma_max = sigma_range
        # Width of the sampling interval; used to rescale uniform samples.
        self.sigma_delta = self.sigma_max - self.sigma_min
        self.constant_dq = constant_dq
        self.gauss_num = gauss_num
        self.share_smeared = share_smeared

    def __repr__(self):
        return f'Smearing(({self.sigma_min}, {self.sigma_max}))'

    def generate_resolutions(self, batch_size: int, device=None, dtype=None):
        """Sample resolutions and pick which curves of the batch get smeared.

        Args:
            batch_size (int): the number of curves in the batch.
            device (optional): the device on which the tensors are created.
            dtype (optional): the data type of the sampled resolutions.

        Returns:
            tuple: ``(dq, indices)`` where ``dq`` has shape ``(num_smeared, 1)`` and holds values drawn
            uniformly between ``sigma_min`` and ``sigma_max``, and ``indices`` is a boolean mask of shape
            ``(batch_size,)`` with exactly ``num_smeared`` entries set to True. Returns ``(None, None)``
            when no curve is selected for smearing.
        """
        # Clamp so that a share outside [0, 1] can neither request more curves
        # than the batch holds nor produce a negative tensor size.
        num_smeared = max(0, min(batch_size, int(batch_size * self.share_smeared)))
        if not num_smeared:
            return None, None

        dq = torch.rand(num_smeared, 1, device=device, dtype=dtype) * self.sigma_delta + self.sigma_min
        indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
        # The first `num_smeared` entries of a random permutation give a
        # uniformly random subset of the batch without repetition.
        indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
        return dq, indices

    @staticmethod
    def _select_q(q_values: Tensor, mask: Tensor) -> Tensor:
        """Return the rows of ``q_values`` selected by ``mask`` when it holds one row per curve."""
        if q_values.dim() == 2 and q_values.shape[0] > 1:
            return q_values[mask]
        # A 1D tensor or a single row is passed through unchanged for all curves.
        return q_values

    def get_curves(self, q_values: Tensor, params: "BasicParams"):
        """Compute the curves of the batch, smearing a random subset of them.

        Args:
            q_values (Tensor): the q values; either shared by all curves or a 2D tensor with one row per curve.
            params (BasicParams): the batch of parameters; its ``reflectivity`` method is used
                to compute the curves.

        Returns:
            Tensor: the curves, with shape ``(params.batch_size, q_values.shape[-1])`` when at least one
            curve is smeared; otherwise the direct result of ``params.reflectivity(q_values, log=False)``.
        """
        batch_size = params.batch_size
        dq, indices = self.generate_resolutions(batch_size, device=params.device, dtype=params.dtype)

        if dq is None:
            return params.reflectivity(q_values, log=False)

        # `generate_resolutions` sets exactly one mask entry per row of `dq`,
        # so the counts are known here without a host/device synchronisation.
        num_smeared = dq.shape[0]
        curves = torch.empty(batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)

        if num_smeared < batch_size:
            unsmeared = ~indices
            curves[unsmeared] = params[unsmeared].reflectivity(
                self._select_q(q_values, unsmeared), log=False
            )

        if num_smeared:
            curves[indices] = params[indices].reflectivity(
                self._select_q(q_values, indices),
                dq=dq,
                constant_dq=self.constant_dq,
                log=False,
                gauss_num=self.gauss_num,
            )

        return curves