Source code for ASTROMER.core.astromer

import tensorflow as tf
import logging
import os, sys

from ASTROMER.core.output    import RegLayer
from ASTROMER.core.tboard    import save_scalar, draw_graph
from ASTROMER.core.losses    import custom_rmse, custom_bce
from ASTROMER.core.metrics   import custom_acc
from ASTROMER.core.encoder   import Encoder

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import Model
from tqdm import tqdm

logging.getLogger('tensorflow').setLevel(logging.ERROR)  # suppress warnings
# Clear the screen only for an interactive terminal: running `clear`
# unconditionally at import time injects escape codes into piped/redirected
# output, and the `clear` command does not exist on Windows (it uses `cls`).
if sys.stdout is not None and sys.stdout.isatty():
    os.system('cls' if os.name == 'nt' else 'clear')

def get_ASTROMER(num_layers=2,
                 d_model=200,
                 num_heads=2,
                 dff=256,
                 base=10000,
                 dropout=0.1,
                 use_leak=False,
                 no_train=True,
                 maxlen=100,
                 batch_size=None):
    """Build the ASTROMER Keras model: an ``Encoder`` followed by a ``RegLayer`` head.

    Args:
        num_layers: forwarded to ``Encoder`` (number of encoder layers).
        d_model: forwarded to ``Encoder`` (model dimension).
        num_heads: forwarded to ``Encoder`` (number of attention heads).
        dff: forwarded to ``Encoder`` (presumably the feed-forward width —
            confirm in ``Encoder``).
        base: forwarded to ``Encoder`` as ``base`` (presumably the
            positional-encoding base — TODO confirm).
        dropout: forwarded to ``Encoder`` as ``rate``.
        use_leak: forwarded to ``Encoder``.
        no_train: if True, the encoder is frozen (``trainable = False``) so
            only the regression head is trainable.
        maxlen: length of the sequence axis of every input.
        batch_size: optional fixed batch size for all inputs; ``None`` keeps
            the batch dimension dynamic.

    Returns:
        A ``tf.keras.Model`` named ``"ASTROMER"`` taking a dict with keys
        ``'input'``, ``'mask_in'``, ``'times'`` (float, shape
        ``(maxlen, 1)``) and ``'length'`` (int32, shape ``(maxlen,)``).
    """
    # batch_size used to be accepted but silently ignored; forward it so
    # callers can actually fix the batch dimension.
    serie = Input(shape=(maxlen, 1), batch_size=batch_size, name='input')
    times = Input(shape=(maxlen, 1), batch_size=batch_size, name='times')
    mask = Input(shape=(maxlen, 1), batch_size=batch_size, name='mask')
    length = Input(shape=(maxlen,), batch_size=batch_size,
                   dtype=tf.int32, name='length')

    placeholder = {'input': serie,
                   'mask_in': mask,
                   'times': times,
                   'length': length}

    encoder = Encoder(num_layers,
                      d_model,
                      num_heads,
                      dff,
                      base=base,
                      rate=dropout,
                      use_leak=use_leak,
                      name='encoder')

    if no_train:
        encoder.trainable = False

    x = encoder(placeholder)
    x = RegLayer(name='regression')(x)

    return Model(inputs=placeholder, outputs=x, name="ASTROMER")
@tf.function
def train_step(model, batch, opt):
    """Run one optimization step on ``batch`` and return its masked loss.

    Args:
        model: Keras model, called on the whole ``batch`` dict.
        batch: dict holding the model inputs plus ``'output'`` (the target)
            and ``'mask_out'`` (the mask handed to ``custom_rmse``).
        opt: optimizer used to apply the gradients.

    Returns:
        The loss value returned by ``custom_rmse`` (named ``mse`` for
        historical reasons).
    """
    with tf.GradientTape() as tape:
        # training=True so layers such as Dropout run in training mode; when
        # it is omitted in a custom loop Keras falls back to the inference
        # learning phase and dropout is never applied.
        x_pred = model(batch, training=True)
        mse = custom_rmse(y_true=batch['output'],
                          y_pred=x_pred,
                          mask=batch['mask_out'])

    grads = tape.gradient(mse, model.trainable_weights)
    opt.apply_gradients(zip(grads, model.trainable_weights))
    return mse
@tf.function
def valid_step(model, batch, return_pred=False, normed=False):
    """Evaluate ``model`` on ``batch`` without updating any weights.

    Args:
        model: Keras model, called on the whole ``batch`` dict.
        batch: dict holding the model inputs plus ``'output'`` (the target)
            and ``'mask_out'`` (the mask handed to ``custom_rmse``).
        return_pred: if True, also return the prediction and the target.
        normed: unused; kept for backward compatibility with callers.

    Returns:
        The ``custom_rmse`` loss, or ``(loss, x_pred, x_true)`` when
        ``return_pred`` is True.
    """
    # No GradientTape here: evaluation needs no gradients, and recording the
    # forward pass on a tape only costs memory.
    x_pred = model(batch, training=False)
    x_true = batch['output']
    mse = custom_rmse(y_true=x_true,
                      y_pred=x_pred,
                      mask=batch['mask_out'])
    if return_pred:
        return mse, x_pred, x_true
    return mse
def train(model,
          train_dataset,
          valid_dataset,
          patience=20,
          exp_path='./experiments/test',
          epochs=1,
          finetuning=False,
          use_random=True,
          num_cls=2,
          lr=1e-3,
          verbose=1):
    """Train ``model`` with Adam and early stopping, checkpointing the best weights.

    Each epoch does the following:

    - runs ``train_step`` over ``train_dataset`` and ``valid_step`` over
      ``valid_dataset``;
    - logs the mean of the per-batch losses to TensorBoard under
      ``<exp_path>/logs/train`` and ``<exp_path>/logs/valid``;
    - saves the weights to ``<exp_path>/weights`` whenever the validation
      loss improves.

    Training stops once ``patience`` consecutive epochs pass without
    improvement.

    Args:
        model: Keras model to train.
        train_dataset: iterable of training batches; must support ``.take``.
        valid_dataset: iterable of validation batches.
        patience: number of non-improving epochs tolerated before stopping.
        exp_path: experiment directory (created if missing).
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        finetuning, use_random, num_cls, verbose: unused; kept for backward
            compatibility with existing callers.
    """
    os.makedirs(exp_path, exist_ok=True)

    # Tensorboard: one writer per split so both curves share the 'mse' tag.
    train_writter = tf.summary.create_file_writer(
        os.path.join(exp_path, 'logs', 'train'))
    valid_writter = tf.summary.create_file_writer(
        os.path.join(exp_path, 'logs', 'valid'))

    try:
        # Trace the model graph once using the first training batch.
        batch = list(train_dataset.take(1))[0]
        draw_graph(model, batch, train_writter, exp_path)

        # Optimizer
        optimizer = tf.keras.optimizers.Adam(lr,
                                             beta_1=0.9,
                                             beta_2=0.98,
                                             epsilon=1e-9)

        # Running means of the per-batch losses; reset after every epoch.
        train_mse = tf.keras.metrics.Mean(name='train_mse')
        valid_mse = tf.keras.metrics.Mean(name='valid_mse')

        # inf (not a large finite sentinel) guarantees the first epoch is
        # always checkpointed, whatever the loss scale.
        best_loss = float('inf')
        es_count = 0
        pbar = tqdm(range(epochs), desc='epoch')
        for epoch in pbar:
            for train_batch in train_dataset:
                train_mse.update_state(
                    train_step(model, train_batch, optimizer))

            for valid_batch in valid_dataset:
                valid_mse.update_state(valid_step(model, valid_batch))

            train_loss = float(train_mse.result())
            valid_loss = float(valid_mse.result())

            msg = 'EPOCH {} - ES COUNT: {}/{} train mse: {:.4f} - val mse: {:.4f}'.format(
                epoch, es_count, patience, train_loss, valid_loss)
            pbar.set_description(msg)

            save_scalar(train_writter, train_mse, epoch, name='mse')
            save_scalar(valid_writter, valid_mse, epoch, name='mse')

            if valid_loss < best_loss:
                best_loss = valid_loss
                es_count = 0
                model.save_weights(os.path.join(exp_path, 'weights'))
            else:
                es_count += 1
                # '>=' rather than '==' so patience=0 stops on the first
                # non-improving epoch instead of never.
                if es_count >= patience:
                    print('[INFO] Early Stopping Triggered')
                    break

            train_mse.reset_states()
            valid_mse.reset_states()
    finally:
        # Flush pending TensorBoard events even if training raises.
        train_writter.close()
        valid_writter.close()
def predict(model, dataset, conf, predic_proba=False):
    """Run ``valid_step`` over every batch of ``dataset`` and collect the outputs.

    Args:
        model: Keras model passed through to ``valid_step``.
        dataset: iterable of batch dicts. Besides what ``valid_step`` reads,
            each batch must provide ``'times'`` and ``'mask_out'``.
        conf: unused.
        predic_proba: unused.

    Returns:
        A dict with these keys:

        - ``'mse'``: mean of the per-batch losses, as a NumPy value.
        - ``'x_pred'``, ``'x_true'``, ``'mask'``, ``'time'``: the per-batch
          tensors concatenated along axis 0.
    """
    per_batch_loss = []
    gathered = {'x_pred': [], 'x_true': [], 'mask': [], 'time': []}

    for _, batch in tqdm(enumerate(dataset), desc='prediction'):
        loss, prediction, target = valid_step(model, batch,
                                              return_pred=True,
                                              normed=True)
        per_batch_loss.append(loss)
        gathered['x_pred'].append(prediction)
        gathered['x_true'].append(target)
        gathered['mask'].append(batch['mask_out'])
        gathered['time'].append(batch['times'])

    # Average first, then stitch each collected series back together.
    result = {'mse': tf.reduce_mean(per_batch_loss).numpy()}
    for key, pieces in gathered.items():
        result[key] = tf.concat(pieces, 0)
    return result