# Source code for reflectorch.data_generation.priors.base
# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.
from torch import Tensor
from reflectorch.data_generation.priors.params import Params
# Public API of this module: only the base prior sampler is exported.
__all__ = ["PriorSampler"]
[docs]
class PriorSampler:
    """Base class for prior samplers.

    Every sampling / scaling / domain method here raises
    ``NotImplementedError`` and is meant to be overridden by subclasses.
    ``param_dim``, ``filter_params`` and ``__repr__`` are implemented in terms
    of the other members.
    """

    # Parameter container class used to translate a layer count into a
    # parameter count (see ``param_dim``). Declared as a class attribute so
    # subclasses can override it.
    PARAM_CLS = Params

    @property
    def param_dim(self) -> int:
        """Number of parameters, i.e. the parameter dimensionality.

        Computed as ``PARAM_CLS.layers_num2size(self.max_num_layers)``.
        """
        # Resolve the converter before reading ``max_num_layers`` to keep the
        # same evaluation order as a single nested call expression.
        to_size = self.PARAM_CLS.layers_num2size
        return to_size(self.max_num_layers)

    @property
    def max_num_layers(self) -> int:
        """Number of layers (maximum, per the name); subclasses must implement."""
        raise NotImplementedError

    def sample(self, batch_size: int) -> Params:
        """Draw a batch of ``batch_size`` parameter sets; subclasses must implement."""
        raise NotImplementedError

    def scale_params(self, params: Params) -> Tensor:
        """Map ``params`` to an ML-friendly range; subclasses must implement.

        Presumably the inverse of :meth:`restore_params` — confirm in subclasses.
        """
        raise NotImplementedError

    def restore_params(self, scaled_params: Tensor) -> Params:
        """Map ``scaled_params`` back to their original range; subclasses must implement."""
        raise NotImplementedError

    def log_prob(self, params: Params) -> Tensor:
        """Log-probability of ``params`` under the prior (subclass-defined); not implemented here."""
        raise NotImplementedError

    def get_indices_within_domain(self, params: Params) -> Tensor:
        """Selector for the entries of ``params`` inside the prior's domain.

        The result is used directly as an index into ``params`` by
        :meth:`filter_params`; whether it is an integer index tensor or a
        boolean mask is up to the subclass (TODO confirm per implementation).
        """
        raise NotImplementedError

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        """Selector for the entries of ``params`` inside the prior's bounds; subclasses must implement."""
        raise NotImplementedError

    def filter_params(self, params: Params) -> Params:
        """Return ``params`` indexed by :meth:`get_indices_within_domain`."""
        return params[self.get_indices_within_domain(params)]

    def clamp_params(self, params: Params) -> Params:
        """Clamp ``params`` into the allowed range; subclasses must implement."""
        raise NotImplementedError

    def __repr__(self):
        """Class name followed by every instance attribute, each value's
        ``str`` truncated to its first 10 characters."""
        parts = [f'{name}={str(value)[:10]}' for name, value in vars(self).items()]
        joined = ', '.join(parts)
        return f'{self.__class__.__name__}({joined})'