# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from reflectorch.data_generation import BATCH_DATA_TYPE
from reflectorch.ml.basic_trainer import Trainer
from reflectorch.ml.dataloaders import ReflectivityDataLoader

__all__ = [
    'RealTimeSimTrainer',
    'DenoisingAETrainer',
    'VAETrainer',
    'PointEstimatorTrainer',
]

class RealTimeSimTrainer(Trainer):
    """Trainer with functionality to customize the sampled batch of data"""

    loader: ReflectivityDataLoader
    def get_batch_by_idx(self, batch_num: int):
        """Gets a batch of data with the default batch size (``batch_num`` is
        unused, since batches are simulated on the fly rather than indexed)"""
        batch_data = self.loader.get_batch(self.batch_size)
        return self._get_batch(batch_data)
    def get_batch_by_size(self, batch_size: int):
        """Gets a batch of data with a custom batch size"""
        batch_data = self.loader.get_batch(batch_size)
        return self._get_batch(batch_data)

    def _get_batch(self, batch_data: BATCH_DATA_TYPE):
        """Modify the batch of data sampled from the data loader"""
        raise NotImplementedError
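
# Illustrative sketch (hypothetical subclass, not part of the library API):
# the minimal override a concrete trainer needs is ``_get_batch``, which turns
# the raw dictionary produced by the loader into the tensors the model expects.
class _ExampleCurveOnlyTrainer(RealTimeSimTrainer):
    """Hypothetical example: forwards only the scaled noisy curves."""

    def _get_batch(self, batch_data: BATCH_DATA_TYPE):
        return batch_data['scaled_noisy_curves'].to(torch.float32)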

class PointEstimatorTrainer(RealTimeSimTrainer):
    """Trainer for the regression inverse problem with incorporation of prior bounds"""

    add_sigmas_to_context: bool = False

    def _get_batch(self, batch_data: BATCH_DATA_TYPE):
        """Splits the sampled batch into scaled parameters, prior bounds,
        curves and (optionally) scaled q values"""
        scaled_params = batch_data['scaled_params'].to(torch.float32)
        scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)

        if self.train_with_q_input:
            q_values = batch_data['q_values'].to(torch.float32)
            scaled_q_values = self.loader.q_generator.scale_q(q_values)
        else:
            scaled_q_values = None

        # each parameter is accompanied by two prior-bound columns, hence 3 columns per parameter
        num_params = scaled_params.shape[-1] // 3
        assert num_params * 3 == scaled_params.shape[-1]

        scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)

        return scaled_params, scaled_bounds, scaled_curves, scaled_q_values
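
    # Illustrative example of the split above (shapes are hypothetical): with
    # 2 physical parameters, ``scaled_params`` arrives with 3 * 2 = 6 columns;
    # the first 2 columns are the target values, the remaining 4 the scaled
    # prior bounds:
    #
    #     >>> t = torch.arange(6.).unsqueeze(0)              # shape (1, 6)
    #     >>> params, bounds = torch.split(t, [2, 4], dim=-1)
    #     >>> params.shape, bounds.shape
    #     (torch.Size([1, 2]), torch.Size([1, 4]))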
    def get_loss_dict(self, batch_data):
        """Computes the MSE loss between the predicted and the ground-truth scaled parameters"""
        scaled_params, scaled_bounds, scaled_curves, scaled_q_values = batch_data

        if self.train_with_q_input:
            predicted_params = self.model(scaled_curves, scaled_bounds, scaled_q_values)
        else:
            predicted_params = self.model(scaled_curves, scaled_bounds)

        loss = self.mse(predicted_params, scaled_params)
        return {'loss': loss}
    def init(self):
        """Initialization hook: instantiates the MSE criterion"""
        self.mse = nn.MSELoss()
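
# Illustrative usage sketch (the ``Trainer`` constructor signature is not shown
# in this module, so the keyword arguments below are assumptions):
#
#     trainer = PointEstimatorTrainer(model=network, loader=loader, ...)
#     batch = trainer.get_batch_by_size(512)
#     losses = trainer.get_loss_dict(batch)   # {'loss': <scalar MSE tensor>}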

class DenoisingAETrainer(RealTimeSimTrainer):
    """Trainer for a denoising autoencoder model; overrides the ``_get_batch`` and ``get_loss_dict`` methods"""
    def init(self):
        """Initialization hook: sets the MSE criterion and makes the loader also compute noise-free curves"""
        self.criterion = nn.MSELoss()
        self.loader.calc_denoised_curves = True

    def _get_batch(self, batch_data: BATCH_DATA_TYPE):
        """Returns the scaled curves with and without noise"""
        scaled_noisy_curves, curves = batch_data['scaled_noisy_curves'], batch_data['curves']
        scaled_curves = self.loader.curves_scaler.scale(curves)
        scaled_noisy_curves, scaled_curves = scaled_noisy_curves.to(torch.float32), scaled_curves.to(torch.float32)
        return scaled_noisy_curves, scaled_curves
    def get_loss_dict(self, batch_data):
        """Returns the reconstruction loss of the autoencoder"""
        scaled_noisy_curves, scaled_curves = batch_data
        restored_curves = self.model(scaled_noisy_curves)
        loss = self.criterion(scaled_curves, restored_curves)
        return {'loss': loss}
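
# Illustrative model stub (hypothetical, for shape orientation only): any
# ``nn.Module`` that maps a batch of scaled noisy curves to restored curves of
# the same shape satisfies the contract used in ``get_loss_dict`` above.
class _ExampleDenoisingAE(nn.Module):
    """Hypothetical fully-connected autoencoder over curves with 128 points."""

    def __init__(self, n_points: int = 128, latent_dim: int = 16):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(n_points, latent_dim), nn.ReLU())
        self.decoder = nn.Linear(latent_dim, n_points)

    def forward(self, curves: torch.Tensor) -> torch.Tensor:
        # (batch, n_points) -> (batch, latent_dim) -> (batch, n_points)
        return self.decoder(self.encoder(curves))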

class VAETrainer(DenoisingAETrainer):
    """Trainer for a variational autoencoder model; overrides the ``_get_batch`` and ``get_loss_dict`` methods"""

    def init(self):
        """Initialization hook: makes the loader also compute noise-free curves and sets the free-bits constant (in bits)"""
        self.loader.calc_denoised_curves = True
        self.freebits = 0.05
    def calc_kl(self, z_mu, z_logvar):
        """Per-dimension KL divergence D_KL(N(mu, exp(logvar)) || N(0, 1)) = 0.5 * (mu^2 + exp(logvar) - 1 - logvar)"""
        return 0.5 * (z_mu ** 2 + torch.exp(z_logvar) - 1 - z_logvar)

    def gaussian_log_prob(self, z, mu, logvar):
        """Log-density of z under the diagonal Gaussian N(mu, exp(logvar))"""
        return -0.5 * (np.log(2 * np.pi) + logvar + (z - mu) ** 2 / torch.exp(logvar))
    def get_loss_dict(self, batch_data):
        """Returns the negative ELBO (in bits) together with its reconstruction and KL components"""
        scaled_noisy_curves, scaled_curves = batch_data
        _, (z_mu, z_logvar, restored_curves_mu, restored_curves_logvar) = self.model(scaled_noisy_curves)

        # negative log-likelihood of the clean curves under the decoder distribution
        l_rec = -torch.mean(self.gaussian_log_prob(scaled_curves, restored_curves_mu, restored_curves_logvar), dim=-1)
        # free-bits KL: each latent dimension contributes at least `freebits` bits (np.log(2) converts bits to nats)
        l_kl = torch.mean(F.relu(self.calc_kl(z_mu, z_logvar) - self.freebits * np.log(2)) + self.freebits * np.log(2), dim=-1)

        loss = torch.mean(l_rec + l_kl) / np.log(2)  # nats -> bits
        l_rec = torch.mean(l_rec)
        l_kl = torch.mean(l_kl)

        return {'loss': loss, 'loss_rec': l_rec, 'loss_kl': l_kl}
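
# Illustrative numeric check of the free-bits clamp used above (assumption:
# ``freebits`` is expressed in bits, hence the ``np.log(2)`` factor converting
# it to nats). A latent dimension whose KL is already below the free-bits
# floor contributes exactly the floor, so it receives no gradient pressure:
#
#     >>> fb = 0.05 * np.log(2)                    # floor in nats
#     >>> kl = torch.tensor([0.01, 0.50])          # per-dimension KL in nats
#     >>> F.relu(kl - fb) + fb
#     tensor([0.0347, 0.5000])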