"""Source code for ``reflectorch.ml.dataloaders``: dataloaders for reflectivity data."""
from torch import Tensor
from reflectorch.data_generation import BasicDataset
from reflectorch.data_generation.reflectivity import kinematical_approximation
from reflectorch.data_generation.priors import BasicParams
from reflectorch.ml.basic_trainer import DataLoader
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["ReflectivityDataLoader", "MultilayerDataLoader"]
# [docs]
class ReflectivityDataLoader(BasicDataset, DataLoader):
    """Dataloader for reflectivity data.

    Combines the functionality of two base classes:

    * ``BasicDataset`` -- the basic dataset class for reflectivity.
    * ``DataLoader`` -- which inherits from ``TrainerCallback``.

    No members are added or overridden here; all behavior comes from the
    bases, resolved in MRO order (``BasicDataset`` before ``DataLoader``).
    """
class MultilayerDataLoader(ReflectivityDataLoader):
    """Dataloader for reflectivity curves simulated using the kinematical approximation.

    Supplies its own ``_sample_from_prior`` and ``_calc_curves`` hooks.
    NOTE(review): these presumably override hooks defined on ``BasicDataset``,
    which is not visible here; confirm against the parent's contract.
    """

    def _sample_from_prior(self, batch_size: int):
        """Draw ``batch_size`` samples via the prior sampler's ``optimized_sample``.

        Whatever ``optimized_sample`` returns is passed through unchanged.
        """
        sampler = self.prior_sampler
        return sampler.optimized_sample(batch_size)

    def _calc_curves(self, q_values: Tensor, params: BasicParams):
        """Simulate reflectivity curves with the kinematical approximation.

        Args:
            q_values: q grid on which the curves are evaluated
                (presumably momentum transfer; shape not fixed here).
            params: sampled parameters whose ``thicknesses``, ``roughnesses``
                and ``slds`` attributes are forwarded positionally.

        Returns:
            The output of ``kinematical_approximation``, unchanged.
        """
        thicknesses = params.thicknesses
        roughnesses = params.roughnesses
        slds = params.slds
        return kinematical_approximation(q_values, thicknesses, roughnesses, slds)