Source code for reflectorch.data_generation.dataset

# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Union
import warnings

from torch import Tensor
import torch

from reflectorch.data_generation.priors import PriorSampler, BasicParams
from reflectorch.data_generation.noise import QNoiseGenerator, IntensityNoiseGenerator
from reflectorch.data_generation.q_generator import QGenerator
from reflectorch.data_generation.scale_curves import CurvesScaler, LogAffineCurvesScaler
from reflectorch.data_generation.smearing import Smearing

BATCH_DATA_TYPE = Dict[str, Union[Tensor, BasicParams]]


[docs] class BasicDataset(object): """Reflectometry dataset. It generates the q positions, samples the thin film parameters from the prior, simulates the reflectivity curves and applies noise to the curves. Args: q_generator (QGenerator): the momentum transfer (q) generator prior_sampler (PriorSampler): the prior sampler intensity_noise (IntensityNoiseGenerator, optional): the intensity noise generator. Defaults to None. q_noise (QNoiseGenerator, optional): the q noise generator. Defaults to None. curves_scaler (CurvesScaler, optional): the reflectivity curve scaler. Defaults to an instance of LogAffineCurvesScaler, which scales the curves to the range [-1, 1], the minimum considered intensity being 1e-10. calc_denoised_curves (bool, optional): whether to add the curves without noise to the dictionary. Defaults to False. smearing (Smearing, optional): curve smearing generator. Defaults to None. """ def __init__(self, q_generator: QGenerator, prior_sampler: PriorSampler, intensity_noise: IntensityNoiseGenerator = None, q_noise: QNoiseGenerator = None, curves_scaler: CurvesScaler = None, calc_denoised_curves: bool = False, smearing: Smearing = None, ): self.q_generator = q_generator self.intensity_noise = intensity_noise self.q_noise = q_noise self.curves_scaler = curves_scaler or LogAffineCurvesScaler() self.prior_sampler = prior_sampler self.smearing = smearing self.calc_denoised_curves = calc_denoised_curves
[docs] def update_batch_data(self, batch_data: BATCH_DATA_TYPE) -> None: """implement in a subclass to edit batch_data dict inplace""" pass
def _sample_from_prior(self, batch_size: int): params: BasicParams = self.prior_sampler.sample(batch_size) scaled_params: Tensor = self.prior_sampler.scale_params(params) return params, scaled_params
[docs] def get_batch(self, batch_size: int) -> BATCH_DATA_TYPE: """get a batch of data as a dictionary with keys ``params``, ``scaled_params``, ``q_values``, ``curves``, ``scaled_noisy_curves`` Args: batch_size (int): the batch size """ batch_data = {} params, scaled_params = self._sample_from_prior(batch_size) batch_data['params'] = params batch_data['scaled_params'] = scaled_params q_values: Tensor = self.q_generator.get_batch(batch_size, batch_data) if self.q_noise: batch_data['original_q_values'] = q_values q_values = self.q_noise.apply(q_values, batch_data) batch_data['q_values'] = q_values curves = self._calc_curves(q_values, params) if self.calc_denoised_curves: batch_data['curves'] = curves noisy_curves = curves if self.intensity_noise: noisy_curves = self.intensity_noise(noisy_curves, batch_data) scaled_noisy_curves = self.curves_scaler.scale(noisy_curves) batch_data['scaled_noisy_curves'] = scaled_noisy_curves is_finite = torch.all(torch.isfinite(scaled_noisy_curves), -1) if not torch.all(is_finite).item(): infinite_indices = ~is_finite warnings.warn(f'Batch with {infinite_indices.sum().item()} curves with infinities skipped.') return self.get_batch(batch_size = batch_size) is_finite = torch.all(torch.isfinite(batch_data['scaled_noisy_curves']), -1) assert torch.all(is_finite).item() self.update_batch_data(batch_data) return batch_data
def _calc_curves(self, q_values: Tensor, params: BasicParams): if self.smearing: curves = self.smearing.get_curves(q_values, params) else: curves = params.reflectivity(q_values) curves = curves.to(q_values) return curves
def _insert_batch_data(tgt_batch_data, add_batch_data, indices): for key in tuple(tgt_batch_data.keys()): value = tgt_batch_data[key] if isinstance(value, BasicParams) or len(value.shape) == 2: value[indices] = add_batch_data[key] else: warnings.warn(f'Ignore {key} while merging batch_data.') if __name__ == '__main__': from reflectorch.data_generation.q_generator import ConstantQ from reflectorch.data_generation.priors import BasicPriorSampler, UniformSubPriorSampler from reflectorch.data_generation.noise import BasicExpIntensityNoise from reflectorch.data_generation.noise import BasicQNoiseGenerator from reflectorch.utils import to_np from time import perf_counter q_generator = ConstantQ((0, 0.2, 65), device='cpu') noise_gen = BasicExpIntensityNoise( relative_errors=(0.05, 0.2), # scale_range=(-1e-2, 1e-2), logdist=True, apply_shift=True, ) q_noise_gen = BasicQNoiseGenerator( shift_std=5e-4, noise_std=(0, 1e-3), ) prior_sampler = UniformSubPriorSampler( thickness_range=(0, 250), roughness_range=(0, 40), sld_range=(0, 60), num_layers=2, device=torch.device('cpu'), dtype=torch.float64, smaller_roughnesses=True, logdist=True, relative_min_bound_width=5e-4, ) smearing = Smearing( sigma_range=(0.8e-3, 5e-3), gauss_num=31, share_smeared=0.5, ) dataset = BasicDataset( q_generator, prior_sampler, noise_gen, q_noise=q_noise_gen, smearing=smearing ) start = perf_counter() batch_data = dataset.get_batch(32) print(f'Total time = {(perf_counter() - start):.3f} sec ') print(batch_data['params'].roughnesses[:10]) print(batch_data['scaled_noisy_curves'].min().item()) scaled_noisy_curves = batch_data['scaled_noisy_curves'] scaled_curves = dataset.curves_scaler.scale( batch_data['params'].reflectivity(q_generator.q) ) try: import matplotlib.pyplot as plt for i in range(16): plt.plot( to_np(q_generator.q.squeeze().cpu().numpy()), to_np(scaled_curves[i]) ) plt.plot( to_np(q_generator.q.squeeze().cpu().numpy()), to_np(scaled_noisy_curves[i]) ) plt.show() except ImportError: pass