import torch
from torch import Tensor
from reflectorch.data_generation.utils import (
uniform_sampler,
logdist_sampler,
)
from reflectorch.data_generation.priors.utils import get_max_allowed_roughness
class SamplerStrategy(object):
"""Base class for sampler strategies"""
def sample(self, batch_size: int,
total_min_bounds: Tensor,
total_max_bounds: Tensor,
total_min_delta: Tensor,
total_max_delta: Tensor,
):
raise NotImplementedError
class BasicSamplerStrategy(SamplerStrategy):
"""Sampler strategy with no constraints on the values of the parameters
Args:
logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
"""
def __init__(self, logdist: bool = False):
if logdist:
self.widths_sampler_func = logdist_sampler
else:
self.widths_sampler_func = uniform_sampler
def sample(self, batch_size: int,
total_min_bounds: Tensor,
total_max_bounds: Tensor,
total_min_delta: Tensor,
total_max_delta: Tensor,
):
"""
Args:
batch_size (int): the batch size
total_min_bounds (Tensor): minimum values of the parameters
total_max_bounds (Tensor): maximum values of the parameters
total_min_delta (Tensor): minimum widths of the subprior intervals
total_max_delta (Tensor): maximum widths of the subprior intervals
Returns:
tuple(Tensor): the sampled parameter values and their prior bounds *(params, min_bounds, max_bounds)*. The widths **W** of the subprior intervals are sampled first, then the centers **C**, such that the prior bounds are **C** - **W** / 2 and **C** + **W** / 2; finally the parameters are sampled uniformly from [**C** - **W** / 2, **C** + **W** / 2].
"""
return basic_sampler(
batch_size,
total_min_bounds,
total_max_bounds,
total_min_delta,
total_max_delta,
self.widths_sampler_func,
)
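# Illustrative usage sketch for BasicSamplerStrategy (not part of the library API);
# the bound tensors below are hypothetical placeholders of shape (1, num_params):
#
#     strategy = BasicSamplerStrategy(logdist=False)
#     params, min_bounds, max_bounds = strategy.sample(
#         batch_size=8,
#         total_min_bounds=torch.tensor([[0.0, 0.0]]),
#         total_max_bounds=torch.tensor([[100.0, 10.0]]),
#         total_min_delta=torch.tensor([[1.0, 0.1]]),
#         total_max_delta=torch.tensor([[50.0, 5.0]]),
#     )
#     # every row satisfies min_bounds <= params <= max_bounds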
class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
"""Sampler strategy where the roughnesses are constrained not to exceed a fraction of the two neighboring thicknesses
Args:
thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
logdist (bool, optional): if ``True`` the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
max_thickness_share (float, optional): fraction of the layer thickness that the roughness should not exceed. Defaults to 0.5.
"""
def __init__(self,
thickness_mask: Tensor,
roughness_mask: Tensor,
logdist: bool = False,
max_thickness_share: float = 0.5,
):
super().__init__(logdist=logdist)
self.thickness_mask = thickness_mask
self.roughness_mask = roughness_mask
self.max_thickness_share = max_thickness_share
def sample(self, batch_size: int,
total_min_bounds: Tensor,
total_max_bounds: Tensor,
total_min_delta: Tensor,
total_max_delta: Tensor,
):
"""
Args:
batch_size (int): the batch size
total_min_bounds (Tensor): minimum values of the parameters
total_max_bounds (Tensor): maximum values of the parameters
total_min_delta (Tensor): minimum widths of the subprior intervals
total_max_delta (Tensor): maximum widths of the subprior intervals
Returns:
tuple(Tensor): the sampled parameter values and their prior bounds *(params, min_bounds, max_bounds)*, with the roughnesses constrained as described above. The widths **W** of the subprior intervals are sampled first, then the centers **C**, such that the prior bounds are **C** - **W** / 2 and **C** + **W** / 2; finally the parameters are sampled uniformly from [**C** - **W** / 2, **C** + **W** / 2].
"""
device = total_min_bounds.device
return constrained_roughness_sampler(
batch_size,
total_min_bounds,
total_max_bounds,
total_min_delta,
total_max_delta,
thickness_mask=self.thickness_mask.to(device),
roughness_mask=self.roughness_mask.to(device),
widths_sampler_func=self.widths_sampler_func,
coef=self.max_thickness_share,
)
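# Illustrative usage sketch for ConstrainedRoughnessSamplerStrategy; the parameter
# layout (two thicknesses followed by two roughnesses) and all numbers are
# hypothetical assumptions, not a fixed reflectorch convention:
#
#     strategy = ConstrainedRoughnessSamplerStrategy(
#         thickness_mask=torch.tensor([0, 1]),
#         roughness_mask=torch.tensor([2, 3]),
#         max_thickness_share=0.5,
#     )
#     params, min_bounds, max_bounds = strategy.sample(
#         batch_size=8,
#         total_min_bounds=torch.tensor([[10., 10., 0., 0.]]),
#         total_max_bounds=torch.tensor([[200., 200., 20., 20.]]),
#         total_min_delta=torch.tensor([[1., 1., 0.1, 0.1]]),
#         total_max_delta=torch.tensor([[100., 100., 10., 10.]]),
#     )
#     # the sampled roughnesses stay below max_thickness_share of the
#     # neighboring sampled thicknesses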
class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
"""Sampler strategy where the roughnesses are constrained not to exceed a fraction of the two neighboring thicknesses, and the imaginary slds are constrained not to exceed a fraction of the real slds
Args:
thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
sld_mask (Tensor): indices in the tensors which correspond to real slds
isld_mask (Tensor): indices in the tensors which correspond to imaginary slds
logdist (bool, optional): if ``True`` the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
max_thickness_share (float, optional): fraction of the layer thickness that the roughness should not exceed. Defaults to 0.5
max_sld_share (float, optional): fraction of the real sld that the imaginary sld should not exceed. Defaults to 0.2.
"""
def __init__(self,
thickness_mask: Tensor,
roughness_mask: Tensor,
sld_mask: Tensor,
isld_mask: Tensor,
logdist: bool = False,
max_thickness_share: float = 0.5,
max_sld_share: float = 0.2,
):
super().__init__(logdist=logdist)
self.thickness_mask = thickness_mask
self.roughness_mask = roughness_mask
self.sld_mask = sld_mask
self.isld_mask = isld_mask
self.max_thickness_share = max_thickness_share
self.max_sld_share = max_sld_share
def sample(self, batch_size: int,
total_min_bounds: Tensor,
total_max_bounds: Tensor,
total_min_delta: Tensor,
total_max_delta: Tensor,
):
"""
Args:
batch_size (int): the batch size
total_min_bounds (Tensor): minimum values of the parameters
total_max_bounds (Tensor): maximum values of the parameters
total_min_delta (Tensor): minimum widths of the subprior intervals
total_max_delta (Tensor): maximum widths of the subprior intervals
Returns:
tuple(Tensor): the sampled parameter values and their prior bounds *(params, min_bounds, max_bounds)*, with the roughnesses and imaginary SLDs constrained as described above. The widths **W** of the subprior intervals are sampled first, then the centers **C**, such that the prior bounds are **C** - **W** / 2 and **C** + **W** / 2; finally the parameters are sampled uniformly from [**C** - **W** / 2, **C** + **W** / 2].
"""
device = total_min_bounds.device
return constrained_roughness_and_isld_sampler(
batch_size,
total_min_bounds,
total_max_bounds,
total_min_delta,
total_max_delta,
thickness_mask=self.thickness_mask.to(device),
roughness_mask=self.roughness_mask.to(device),
sld_mask=self.sld_mask.to(device),
isld_mask=self.isld_mask.to(device),
widths_sampler_func=self.widths_sampler_func,
coef_roughness=self.max_thickness_share,
coef_isld=self.max_sld_share,
)
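# Illustrative usage sketch for ConstrainedRoughnessAndImgSldSamplerStrategy; the
# single-layer parameter layout (thickness, roughness, real SLD, imaginary SLD) and
# all numbers are hypothetical assumptions:
#
#     strategy = ConstrainedRoughnessAndImgSldSamplerStrategy(
#         thickness_mask=torch.tensor([0]),
#         roughness_mask=torch.tensor([1]),
#         sld_mask=torch.tensor([2]),
#         isld_mask=torch.tensor([3]),
#         max_thickness_share=0.5,
#         max_sld_share=0.2,
#     )
#     params, min_bounds, max_bounds = strategy.sample(
#         batch_size=8,
#         total_min_bounds=torch.tensor([[10., 0., 1., 0.]]),
#         total_max_bounds=torch.tensor([[200., 20., 40., 2.]]),
#         total_min_delta=torch.tensor([[1., 0.1, 0.5, 0.01]]),
#         total_max_delta=torch.tensor([[100., 10., 20., 1.]]),
#     )
#     # roughnesses are capped relative to the sampled thicknesses, imaginary
#     # SLDs at max_sld_share of the sampled real SLDs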
def basic_sampler(
batch_size: int,
total_min_bounds: Tensor,
total_max_bounds: Tensor,
total_min_delta: Tensor,
total_max_delta: Tensor,
widths_sampler_func,
):
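# Draw the subprior widths W first, then the centers C (kept far enough from the
# total bounds that C +/- W/2 stays inside them), and finally the parameters
# uniformly from [C - W/2, C + W/2].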
delta_vector = total_max_bounds - total_min_bounds
prior_widths = widths_sampler_func(
total_min_delta, total_max_delta,
batch_size, delta_vector.shape[1],
device=total_min_bounds.device, dtype=total_min_bounds.dtype
)
prior_centers = uniform_sampler(
total_min_bounds + prior_widths / 2, total_max_bounds - prior_widths / 2,
*prior_widths.shape,
device=total_min_bounds.device, dtype=total_min_bounds.dtype
)
min_bounds, max_bounds = prior_centers - prior_widths / 2, prior_centers + prior_widths / 2
params = torch.rand(
*min_bounds.shape,
device=min_bounds.device,
dtype=min_bounds.dtype
) * (max_bounds - min_bounds) + min_bounds
return params, min_bounds, max_bounds
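# Quick sanity-check sketch for basic_sampler (hypothetical bounds, not library defaults):
#
#     p, lo, hi = basic_sampler(
#         4,
#         torch.zeros(1, 3), torch.ones(1, 3),
#         0.1 * torch.ones(1, 3), 0.5 * torch.ones(1, 3),
#         widths_sampler_func=uniform_sampler,
#     )
#     assert torch.all((lo <= p) & (p <= hi))
#     assert torch.all((hi - lo >= 0.1) & (hi - lo <= 0.5))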
def constrained_roughness_sampler(
batch_size: int,
total_min_bounds: Tensor,
total_max_bounds: Tensor,
total_min_delta: Tensor,
total_max_delta: Tensor,
thickness_mask: Tensor,
roughness_mask: Tensor,
widths_sampler_func,
coef: float = 0.5,
):
params, min_bounds, max_bounds = basic_sampler(
batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
widths_sampler_func=widths_sampler_func,
)
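# Re-draw the roughness block: the upper roughness bound is tightened to the value
# allowed by the sampled neighboring thicknesses (get_max_allowed_roughness with
# fraction `coef`), capped by the original total bounds; the subprior widths are
# shrunk accordingly and the roughness slice is re-sampled with basic_sampler.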
max_roughness = torch.minimum(
get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef),
total_max_bounds[..., roughness_mask]
)
min_roughness = total_min_bounds[..., roughness_mask]
assert torch.all(min_roughness <= max_roughness)
min_roughness_delta = total_min_delta[..., roughness_mask]
max_roughness_delta = torch.minimum(total_max_delta[..., roughness_mask], max_roughness - min_roughness)
roughnesses, min_r_bounds, max_r_bounds = basic_sampler(
batch_size, min_roughness, max_roughness,
min_roughness_delta, max_roughness_delta,
widths_sampler_func=widths_sampler_func
)
min_bounds[..., roughness_mask], max_bounds[..., roughness_mask] = min_r_bounds, max_r_bounds
params[..., roughness_mask] = roughnesses
return params, min_bounds, max_bounds
def constrained_roughness_and_isld_sampler(
batch_size: int,
total_min_bounds: Tensor,
total_max_bounds: Tensor,
total_min_delta: Tensor,
total_max_delta: Tensor,
thickness_mask: Tensor,
roughness_mask: Tensor,
sld_mask: Tensor,
isld_mask: Tensor,
widths_sampler_func,
coef_roughness: float = 0.5,
coef_isld: float = 0.2,
):
params, min_bounds, max_bounds = basic_sampler(
batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
widths_sampler_func=widths_sampler_func,
)
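# Roughness constraint: same two-stage re-sampling as in constrained_roughness_sampler.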
max_roughness = torch.minimum(
get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
total_max_bounds[..., roughness_mask]
)
min_roughness = total_min_bounds[..., roughness_mask]
assert torch.all(min_roughness <= max_roughness)
min_roughness_delta = total_min_delta[..., roughness_mask]
max_roughness_delta = torch.minimum(total_max_delta[..., roughness_mask], max_roughness - min_roughness)
roughnesses, min_r_bounds, max_r_bounds = basic_sampler(
batch_size, min_roughness, max_roughness,
min_roughness_delta, max_roughness_delta,
widths_sampler_func=widths_sampler_func
)
min_bounds[..., roughness_mask], max_bounds[..., roughness_mask] = min_r_bounds, max_r_bounds
params[..., roughness_mask] = roughnesses
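# Re-draw the imaginary-SLD block: the upper bound is capped at `coef_isld` times the
# sampled real SLDs (and by the original total bounds), the widths are shrunk to fit,
# and the imaginary-SLD slice is re-sampled with basic_sampler.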
max_isld = torch.minimum(
params[..., sld_mask] * coef_isld,
total_max_bounds[..., isld_mask]
)
min_isld = total_min_bounds[..., isld_mask]
assert torch.all(min_isld <= max_isld)
min_isld_delta = total_min_delta[..., isld_mask]
max_isld_delta = torch.minimum(total_max_delta[..., isld_mask], max_isld - min_isld)
islds, min_isld_bounds, max_isld_bounds = basic_sampler(
batch_size, min_isld, max_isld,
min_isld_delta, max_isld_delta,
widths_sampler_func=widths_sampler_func
)
min_bounds[..., isld_mask], max_bounds[..., isld_mask] = min_isld_bounds, max_isld_bounds
params[..., isld_mask] = islds
return params, min_bounds, max_bounds