# Source code for reflectorch.data_generation.q_generator

# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple, Union

import numpy as np

import torch
from torch import Tensor

from reflectorch.data_generation.utils import uniform_sampler
from reflectorch.data_generation.priors import BasicParams
from reflectorch.utils import angle_to_q
from reflectorch.data_generation.priors.no_constraints import DEFAULT_DEVICE, DEFAULT_DTYPE

__all__ = [
    "QGenerator",
    "ConstantQ",
    "VariableQ",
    "EquidistantQ",
    "ConstantAngle",
]


class QGenerator(object):
    """Abstract base class for momentum transfer (q) generators."""

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Produce a batch of q grids; concrete generators override this."""
        pass
class ConstantQ(QGenerator):
    """Q generator producing one fixed q discretization for every batch.

    Args:
        q (Union[Tensor, Tuple[float, float, int]], optional): either a
            precomputed 1D tensor of q values, or a tuple
            ``(q_min, q_max, num_q)`` forwarded to ``torch.linspace``.
            Defaults to (0., 0.2, 128).
        device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
        remove_zero (bool, optional): if True, drop one endpoint of the grid;
            which endpoint is selected by ``fixed_zero``. Defaults to False.
        fixed_zero (bool, optional): consulted only when ``remove_zero`` is
            True — drop the first point (lower end) instead of the last point
            (upper end). Defaults to False.
    """

    def __init__(self,
                 q: Union[Tensor, Tuple[float, float, int]] = (0., 0.2, 128),
                 device=DEFAULT_DEVICE,
                 dtype=DEFAULT_DTYPE,
                 remove_zero: bool = False,
                 fixed_zero: bool = False,
                 ):
        if isinstance(q, (tuple, list)):
            q = torch.linspace(*q, device=device, dtype=dtype)
        if remove_zero:
            # fixed_zero picks which endpoint gets discarded
            q = q[1:] if fixed_zero else q[:-1]
        self.q = q

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Generate a batch of q values.

        Args:
            batch_size (int): the batch size

        Returns:
            Tensor: the fixed q grid expanded to shape (batch_size, num_q)
        """
        return self.q.clone()[None].expand(batch_size, self.q.shape[0])
class VariableQ(QGenerator):
    """Q generator for reflectivity curves with variable discretization.

    Every call samples a fresh minimum and maximum q per curve, and a number
    of points shared within the batch.

    Args:
        q_min_range (list, optional): range for sampling each curve's minimum
            q value. Defaults to [0.01, 0.03].
        q_max_range (list, optional): range for sampling each curve's maximum
            q value. Defaults to [0.1, 0.5].
        n_q_range (list, optional): inclusive range for the number of points
            per curve (equidistant between q_min and q_max; varies between
            batches but is constant within a batch). Defaults to [64, 256].
        device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
    """

    def __init__(self,
                 q_min_range: Tuple[float, float] = (0.01, 0.03),
                 q_max_range: Tuple[float, float] = (0.1, 0.5),
                 n_q_range: Tuple[int, int] = (64, 256),
                 device=DEFAULT_DEVICE,
                 dtype=DEFAULT_DTYPE,
                 ):
        self.q_min_range = q_min_range
        self.q_max_range = q_max_range
        self.n_q_range = n_q_range
        self.device = device
        self.dtype = dtype

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Generate a batch of q grids (point count is constant within the batch).

        Args:
            batch_size (int): the batch size

        Returns:
            Tensor: generated batch of q values, shape (batch_size, n_q)
        """
        lower_ends = np.random.uniform(*self.q_min_range, batch_size)
        upper_ends = np.random.uniform(*self.q_max_range, batch_size)
        lo_n, hi_n = self.n_q_range
        # np.random.randint has an exclusive upper bound, hence the +1
        num_points = lo_n if lo_n == hi_n else np.random.randint(lo_n, hi_n + 1)
        grid = np.linspace(lower_ends, upper_ends, num_points).T
        return torch.from_numpy(grid).to(self.device).to(self.dtype)

    def scale_q(self, q):
        """Scale q values to the range [-1, 1].

        The mapping uses the global extremes q_min_range[0] and
        q_max_range[1], so it is batch-independent.

        Args:
            q (Tensor): unscaled q values

        Returns:
            Tensor: scaled q values
        """
        unit_interval = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])
        return 2.0 * (unit_interval - 0.5)
class ConstantAngle(QGenerator):
    """Q generator for curves measured at fixed, equidistant incident angles.

    Args:
        angle_range (Tuple[float, float, int], optional): (start, stop, num)
            describing the incident-angle grid. Defaults to (0., 0.2, 257).
        wavelength (float, optional): beam wavelength in angstroms.
            Defaults to 1.
        device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
    """

    def __init__(self,
                 angle_range: Tuple[float, float, int] = (0., 0.2, 257),
                 wavelength: float = 1.,
                 device=DEFAULT_DEVICE,
                 dtype=DEFAULT_DTYPE,
                 ):
        angles = np.linspace(*angle_range)
        # convert the angular grid to momentum transfer once, up front
        self.q = torch.from_numpy(angle_to_q(angles, wavelength)).to(device).to(dtype)

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Generate a batch of q values.

        Args:
            batch_size (int): the batch size

        Returns:
            Tensor: the fixed q grid expanded to shape (batch_size, num_q)
        """
        return self.q.clone().unsqueeze(0).expand(batch_size, -1)
class EquidistantQ(QGenerator):
    """Q generator producing equidistant grids with a randomly drawn q_max.

    Args:
        max_range: (low, high) interval from which each curve's q_max is
            drawn uniformly.
        num_values: fixed number of q points, or a (low, high) range to
            sample from per batch.
        device (optional): the Pytorch device.
        dtype (optional): the Pytorch data type. Defaults to torch.float64.
    """

    def __init__(self,
                 max_range: Tuple[float, float],
                 num_values: Union[int, Tuple[int, int]],
                 device=None,
                 dtype=torch.float64
                 ):
        self.max_range = max_range
        self._num_values = num_values
        self.device = device
        self.dtype = dtype

    @property
    def num_values(self) -> int:
        """Number of q points: fixed, or freshly sampled on every access.

        NOTE(review): the upper bound is exclusive here (np.random.randint),
        whereas VariableQ samples its point count inclusively (+1) — confirm
        which is intended before unifying.
        """
        if isinstance(self._num_values, int):
            return self._num_values
        return np.random.randint(*self._num_values)

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Generate a (batch_size, num_values) batch of q grids.

        Each row runs equidistantly from q_max/num_values to q_max
        (zero excluded), with q_max drawn uniformly from ``max_range``.
        """
        num_values = self.num_values
        q_max = uniform_sampler(*self.max_range, batch_size, 1, device=self.device, dtype=self.dtype)
        # drop the leading 0 so the grid starts one step above zero
        norm_qs = torch.linspace(0, 1, num_values + 1, device=self.device, dtype=self.dtype)[1:][None]
        qs = norm_qs * q_max
        return qs


class TransformerQ(QGenerator):
    """Q generator whose grid spacing adapts to the total film thickness.

    The spacing dq per curve is sampled between a thickness-dependent lower
    bound (2*pi / total_thickness / min_dq_ratio) and the coarsest allowed
    spacing; the resulting variable-length grids are zero-padded and a
    padding mask is stored in ``context``.

    Args:
        q_max (float, optional): maximum q value of every grid. Defaults to 0.2.
        num_values (Tuple[int, int], optional): (min, max) allowed number of
            q points. Defaults to (30, 512).
        min_dq_ratio (float, optional): oversampling factor applied to the
            thickness-derived spacing bound. Defaults to 5.
        device (optional): the Pytorch device.
        dtype (optional): the Pytorch data type. Defaults to torch.float64.
    """

    def __init__(self,
                 q_max: float = 0.2,
                 num_values: Union[int, Tuple[int, int]] = (30, 512),
                 min_dq_ratio: float = 5.,
                 device=None,
                 dtype=torch.float64,
                 ):
        self.min_dq_ratio = min_dq_ratio
        self.q_max = q_max
        # spacing bounds implied by the allowed number of points
        self._dq_range = q_max / num_values[1], q_max / num_values[0]
        self._num_values = num_values
        self.device = device
        self.dtype = dtype

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Generate padded q grids; writes mask and sizes into ``context``.

        Args:
            batch_size (int): the batch size
            context (dict): must contain 'params' with per-sample layer
                thicknesses; receives 'tgt_key_padding_mask' and
                'num_q_values' as side effects.

        Returns:
            Tensor: zero-padded batch of q grids.
        """
        assert context is not None
        params: BasicParams = context['params']
        total_thickness = params.thicknesses.sum(-1)
        assert total_thickness.shape[0] == batch_size
        # finest spacing needed to resolve the full stack, oversampled by
        # min_dq_ratio; 0.9 keeps a nonzero sampling window below the max
        min_dqs = torch.clamp(
            2 * np.pi / total_thickness / self.min_dq_ratio,
            self._dq_range[0], self._dq_range[1] * 0.9
        )
        dqs = torch.rand_like(min_dqs) * (self._dq_range[1] - min_dqs) + min_dqs
        num_q_values = torch.clamp(self.q_max // dqs, *self._num_values).to(torch.int)
        q_values, mask = generate_q_padding_mask(num_q_values, self.q_max)
        context['tgt_key_padding_mask'] = mask
        context['num_q_values'] = num_q_values
        return q_values


def generate_q_padding_mask(num_q_values: Tensor, q_max: float):
    """Build zero-padded q grids and the matching padding mask.

    Args:
        num_q_values (Tensor): (batch,) integer tensor of points per curve.
        q_max (float): maximum q value shared by all curves.

    Returns:
        tuple: (q_values, mask) where q_values is (batch, max_n) with padded
        positions set to 0, and mask is True at padded positions.
    """
    batch_size = num_q_values.shape[0]
    dqs = (q_max / num_q_values)[:, None]
    # fix: allocate the index grid on the input's device — the previous
    # CPU-default arange broke when num_q_values lived on another device
    steps = torch.arange(1, num_q_values.max().item() + 1, device=num_q_values.device)
    q_values = steps[None].repeat(batch_size, 1) * dqs
    # half-step tolerance absorbs float rounding at the grid's upper end
    mask = (q_values > q_max + dqs / 2)
    q_values[mask] = 0.
    return q_values, mask