# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple, Dict, Type, List, Optional
import torch
from torch import Tensor
from reflectorch.data_generation.priors.base import PriorSampler
from reflectorch.data_generation.priors.params import AbstractParams
from reflectorch.data_generation.priors.no_constraints import (
DEFAULT_DEVICE,
DEFAULT_DTYPE,
)
from reflectorch.data_generation.priors.parametric_models import (
MULTILAYER_MODELS,
ParametricModel,
)
from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
class BasicParams(AbstractParams):
"""Parameter class compatible with different parameterizations of the SLD profile. It stores the parameters as well as their minimum and maximum subprior bounds.
Args:
parameters (Tensor): the values of the thin film parameters
min_bounds (Tensor): the minimum subprior bounds of the parameters
max_bounds (Tensor): the maximum subprior bounds of the parameters
max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). Defaults to None.
param_model (ParametricModel, optional): the parametric model. Defaults to the box model parameterization with number of layers given by max_num_layers.
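
    Example:
        A minimal construction sketch, assuming ``model`` is an already
        instantiated ``ParametricModel`` (the shapes and values below are
        illustrative only)::

            batch_size = 4
            min_bounds = torch.zeros(batch_size, model.param_dim)
            max_bounds = torch.ones(batch_size, model.param_dim)
            parameters = 0.5 * (min_bounds + max_bounds)  # values inside the subprior box
            params = BasicParams(parameters, min_bounds, max_bounds, param_model=model)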
"""
__slots__ = (
'parameters',
'min_bounds',
'max_bounds',
'max_num_layers',
'param_model',
)
PARAM_NAMES = __slots__
PARAM_MODEL_CLS: Type[ParametricModel]
MAX_NUM_LAYERS: int = 30
def __init__(self,
parameters: Tensor,
min_bounds: Tensor,
max_bounds: Tensor,
                 max_num_layers: Optional[int] = None,
                 param_model: Optional[ParametricModel] = None,
):
max_num_layers = max_num_layers or self.MAX_NUM_LAYERS
self.param_model = param_model or self.PARAM_MODEL_CLS(max_num_layers)
self.max_num_layers = max_num_layers
self.parameters = parameters
self.min_bounds = min_bounds
self.max_bounds = max_bounds
def get_param_labels(self) -> List[str]:
"""gets the parameter labels"""
return self.param_model.get_param_labels()
def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
r"""computes the reflectivity curves directly from the parameters
Args:
q (Tensor): the q values
log (bool, optional): whether to apply logarithm to the curves. Defaults to False.
Returns:
Tensor: the simulated reflectivity curves
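
        Example:
            A usage sketch, assuming ``params`` is a batched ``BasicParams``
            instance (the q grid below is illustrative)::

                q = torch.linspace(0.02, 0.3, 128)  # momentum transfer values
                curves = params.reflectivity(q, log=True)  # one curve per parameter set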
"""
return self.param_model.reflectivity(q, self.parameters, log=log, **kwargs)
@property
    def max_layer_num(self) -> int:  # kept for backward compatibility; TODO: unify the API among the different parameter classes
"""gets the maximum number of layers"""
return self.max_num_layers
@property
def num_params(self) -> int:
"""get the number of parameters (parameter dimensionality)"""
return self.param_model.param_dim
@property
def thicknesses(self):
"""gets the thicknesses"""
params = self.param_model.to_standard_params(self.parameters)
return params['thickness']
@property
def roughnesses(self):
"""gets the roughnesses"""
params = self.param_model.to_standard_params(self.parameters)
return params['roughness']
@property
def slds(self):
"""gets the slds"""
params = self.param_model.to_standard_params(self.parameters)
return params['sld']
    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor,
            context: Tensor,
            inference: bool = False,
            from_params: bool = False,
    ):
        """moves the scaled subprior bounds from the parameter tensor into the context tensor

        During training, ``scaled_params`` holds the scaled parameters followed by the
        scaled min and max bounds; the bounds are split off and appended to ``context``.
        At inference, ``scaled_params`` (or only its bounds part, when ``from_params``
        is True) is appended to ``context``.
        """
        if inference:
            if from_params:
                # scaled_params = [params | min_bounds | max_bounds]; keep only the bounds
                num_params = scaled_params.shape[-1] // 3
                scaled_params = scaled_params[:, num_params:]
            context = torch.cat([context, scaled_params], dim=-1)
            return context
        num_params = scaled_params.shape[-1] // 3
        assert num_params * 3 == scaled_params.shape[-1]
        # split into the parameters and the concatenated (min, max) bounds
        scaled_params, bound_context = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
        context = torch.cat([context, bound_context], dim=-1)
        return scaled_params, context
    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """re-attaches the scaled subprior bounds stored in the context to the scaled parameters"""
        num_params = scaled_params.shape[-1]
        # the last 2 * num_params features of the context hold the scaled (min, max) bounds
        scaled_bounds = context[:, -2 * num_params:]
        scaled_params = torch.cat([scaled_params, scaled_bounds], dim=-1)
        return scaled_params
def as_tensor(self, add_bounds: bool = True, **kwargs) -> Tensor:
"""converts the instance of the class to a Pytorch tensor
Args:
add_bounds (bool, optional): whether to add the subprior bounds to the tensor. Defaults to True.
Returns:
Tensor: the Pytorch tensor obtained from the instance of the class
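
        Example:
            The returned tensor is laid out as ``[parameters | min_bounds | max_bounds]``
            along the last dimension::

                t = params.as_tensor()  # shape: (batch_size, 3 * num_params)
                t_no_bounds = params.as_tensor(add_bounds=False)  # shape: (batch_size, num_params)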
"""
if not add_bounds:
return self.parameters
return torch.cat([self.parameters, self.min_bounds, self.max_bounds], -1)
@classmethod
def from_tensor(cls, params: Tensor, **kwargs):
"""initializes an instance of the class from a Pytorch tensor
Args:
params (Tensor): Pytorch tensor containing the parameter values, min subprior bounds and max subprior bounds
Returns:
BasicParams: the instance of the class
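
        Example:
            ``from_tensor`` inverts ``as_tensor``; keyword arguments such as
            ``param_model`` are forwarded to the constructor::

                t = params.as_tensor()  # shape: (batch_size, 3 * num_params)
                restored = BasicParams.from_tensor(t, param_model=params.param_model)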
"""
num_params = params.shape[-1] // 3
params, min_bounds, max_bounds = torch.split(
params, [num_params, num_params, num_params], dim=-1
)
return cls(
params,
min_bounds,
max_bounds,
**kwargs
)
def scale_with_q(self, q_ratio: float):
"""scales the parameters based on the q ratio
Args:
q_ratio (float): the scaling ratio
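
        Example:
            A minimal sketch; the parameters and their subprior bounds are
            rescaled in place, consistently with scaling the q axis by ``q_ratio``::

                params.scale_with_q(2.0)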
"""
self.parameters = self.param_model.scale_with_q(self.parameters, q_ratio)
self.min_bounds = self.param_model.scale_with_q(self.min_bounds, q_ratio)
self.max_bounds = self.param_model.scale_with_q(self.max_bounds, q_ratio)
class SubpriorParametricSampler(PriorSampler, ScalerMixin):
PARAM_CLS = BasicParams
def __init__(self,
param_ranges: Dict[str, Tuple[float, float]],
bound_width_ranges: Dict[str, Tuple[float, float]],
model_name: str,
device: torch.device = DEFAULT_DEVICE,
dtype: torch.dtype = DEFAULT_DTYPE,
max_num_layers: int = 50,
logdist: bool = False,
                 scale_params_by_ranges: bool = False,
scaled_range: Tuple[float, float] = (-1., 1.),
**kwargs
):
"""Prior sampler for the parameters of a parametric model and their subprior bounds
Args:
param_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with its range
bound_width_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with the range for sampling the widths of the subprior interval
model_name (str): the name of the parametric model
device (torch.device, optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
dtype (torch.dtype, optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). Defaults to 50.
logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
scale_params_by_ranges (bool, optional): if True the parameters are scaled with respect to their ranges instead of being scaled with respect to their prior bounds. Defaults to False.
scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to (-1., 1.)
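
        Example:
            A minimal construction sketch; the model name and the dictionary
            keys below are illustrative and must match those expected by the
            chosen parametric model::

                sampler = SubpriorParametricSampler(
                    param_ranges={'thicknesses': (1., 500.), 'roughnesses': (0., 60.), 'slds': (0., 50.)},
                    bound_width_ranges={'thicknesses': (1e-2, 500.), 'roughnesses': (1e-2, 60.), 'slds': (1e-2, 50.)},
                    model_name='standard_model',  # hypothetical key of MULTILAYER_MODELS
                    max_num_layers=2,
                )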
"""
self.scaled_range = scaled_range
self.param_model: ParametricModel = MULTILAYER_MODELS[model_name](
max_num_layers,
logdist=logdist,
**kwargs
)
self.device = device
self.dtype = dtype
self.num_layers = max_num_layers
self.PARAM_CLS.PARAM_MODEL_CLS = MULTILAYER_MODELS[model_name]
self.PARAM_CLS.MAX_NUM_LAYERS = max_num_layers
self._param_dim = self.param_model.param_dim
self.min_bounds, self.max_bounds, self.min_delta, self.max_delta = self.param_model.init_bounds(
param_ranges, bound_width_ranges, device=device, dtype=dtype
)
self.param_ranges = param_ranges
self.bound_width_ranges = bound_width_ranges
self.model_name = model_name
self.logdist = logdist
self.scale_params_by_ranges = scale_params_by_ranges
@property
def max_num_layers(self) -> int:
"""gets the maximum number of layers"""
return self.num_layers
@property
def param_dim(self) -> int:
"""get the number of parameters (parameter dimensionality)"""
return self._param_dim
def sample(self, batch_size: int) -> BasicParams:
"""sample a batch of parameters
Args:
batch_size (int): the batch size
Returns:
BasicParams: sampled parameters
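
        Example:
            A usage sketch, assuming ``sampler`` is a constructed
            ``SubpriorParametricSampler``::

                params = sampler.sample(batch_size=512)
                # params.parameters has shape (512, sampler.param_dim)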
"""
params, min_bounds, max_bounds = self.param_model.sample(
batch_size, self.min_bounds, self.max_bounds, self.min_delta, self.max_delta
)
params = BasicParams(
parameters=params,
min_bounds=min_bounds,
max_bounds=max_bounds,
max_num_layers=self.max_num_layers,
param_model=self.param_model,
)
return params
def scale_params(self, params: BasicParams) -> Tensor:
"""scale the parameters to a ML-friendly range
Args:
params (BasicParams): the parameters to be scaled
Returns:
Tensor: the scaled parameters
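
        Example:
            A sketch of the expected layout; the scaled parameters, min bounds,
            and max bounds are concatenated along the last dimension::

                params = sampler.sample(8)
                scaled = sampler.scale_params(params)  # shape: (8, 3 * sampler.param_dim)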
"""
        if self.scale_params_by_ranges:
            # both the parameters and the subprior bounds are scaled with respect to the total parameter ranges
            scaled_params = torch.cat([
                self._scale(params.parameters, self.min_bounds, self.max_bounds),
                self._scale(params.min_bounds, self.min_bounds, self.max_bounds),
                self._scale(params.max_bounds, self.min_bounds, self.max_bounds),
            ], -1)
        else:
            # each parameter is scaled with respect to its own subprior bounds,
            # while the subprior bounds themselves are scaled with respect to the total parameter ranges
            scaled_params = torch.cat([
                self._scale(params.parameters, params.min_bounds, params.max_bounds),
                self._scale(params.min_bounds, self.min_bounds, self.max_bounds),
                self._scale(params.max_bounds, self.min_bounds, self.max_bounds),
            ], -1)
        return scaled_params
def restore_params(self, scaled_params: Tensor) -> BasicParams:
"""restore the parameters to their original range
Args:
scaled_params (Tensor): the scaled parameters
Returns:
BasicParams: the parameters restored to their original range
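
        Example:
            ``restore_params`` inverts ``scale_params``::

                scaled = sampler.scale_params(params)
                restored = sampler.restore_params(scaled)  # BasicParams with original-scale values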
"""
num_params = scaled_params.shape[-1] // 3
scaled_params, scaled_min_bounds, scaled_max_bounds = torch.split(
scaled_params, num_params, -1
)
        # the subprior bounds are always restored with respect to the total parameter ranges
        min_bounds = self._restore(scaled_min_bounds, self.min_bounds, self.max_bounds)
        max_bounds = self._restore(scaled_max_bounds, self.min_bounds, self.max_bounds)
        if self.scale_params_by_ranges:
            params = self._restore(scaled_params, self.min_bounds, self.max_bounds)
        else:
            # each parameter is restored with respect to its own subprior bounds
            params = self._restore(scaled_params, min_bounds, max_bounds)
return BasicParams(
parameters=params,
min_bounds=min_bounds,
max_bounds=max_bounds,
max_num_layers=self.max_num_layers,
param_model=self.param_model,
)
    def scale_bounds(self, bounds: Tensor) -> Tensor:
        """scales bound tensors with respect to the total parameter ranges"""
        return self._scale(bounds, self.min_bounds, self.max_bounds)
    def log_prob(self, params: BasicParams) -> Tensor:
        """computes the (unnormalized) log probability: 0 within the subprior bounds, -inf outside"""
        log_prob = torch.zeros(params.batch_size, device=self.device, dtype=self.dtype)
        log_prob[~self.get_indices_within_bounds(params)] = -float('inf')
        return log_prob
    def get_indices_within_domain(self, params: BasicParams) -> Tensor:
        """equivalent to ``get_indices_within_bounds``"""
        return self.get_indices_within_bounds(params)

    def get_indices_within_bounds(self, params: BasicParams) -> Tensor:
        """gets a boolean mask of the batch entries whose parameters all lie within their subprior bounds"""
        return (
            torch.all(params.parameters >= params.min_bounds, -1) &
            torch.all(params.parameters <= params.max_bounds, -1)
        )
    def filter_params(self, params: BasicParams) -> BasicParams:
        """keeps only the batch entries whose parameters lie within their subprior bounds"""
        indices = self.get_indices_within_domain(params)
        return params[indices]
    def clamp_params(
            self, params: BasicParams, inplace: bool = False
    ) -> BasicParams:
        """clamps the parameters to their subprior bounds"""
        if inplace:
            params.parameters = torch.clamp_(params.parameters, params.min_bounds, params.max_bounds)
            return params
        return BasicParams(
            parameters=torch.clamp(params.parameters, params.min_bounds, params.max_bounds),
            min_bounds=params.min_bounds.clone(),
            max_bounds=params.max_bounds.clone(),
            max_num_layers=self.max_num_layers,
            param_model=self.param_model,
        )