# Source code for ASTROMER.core.metrics
import tensorflow as tf
@tf.function
def custom_acc(y_true, y_pred):
    """Return the fraction of rows whose predicted class equals the label.

    Args:
        y_true: Class labels. They are flattened to one label per element and
            cast to int32 (float labels are truncated by the cast). This
            presumably means one label per sample, shape ``(batch,)`` or
            ``(batch, 1)`` -- confirm with callers.
        y_pred: Unnormalised class scores. If the rank is greater than 2, only
            the last position along axis 1 is scored (``y_pred[:, -1, :]``).
            This is presumably ``(batch, steps, classes)`` logits -- TODO
            confirm. Otherwise ``(batch, classes)`` is expected. The class axis
            is assumed to be the last one.

    Returns:
        Scalar float32 tensor: the mean over all label/prediction comparisons
        of 1.0 (match) or 0.0 (mismatch).
    """
    # Use the static rank, which must be known at trace time. Slice *before*
    # any per-class work so that only the last step is processed, rather than
    # every step.
    if y_pred.shape.rank > 2:
        y_pred = y_pred[:, -1, :]

    # softmax is strictly increasing within a row, so argmax over the raw
    # scores picks the same class as argmax over probabilities. Skipping it
    # saves work and avoids spurious ties from exp() saturating in float32.
    predicted = tf.argmax(y_pred, axis=-1, output_type=tf.int32)

    # tf.equal needs matching dtypes. Casting lets int64/float labels work
    # instead of raising.
    labels = tf.cast(tf.reshape(y_true, [-1]), tf.int32)

    correct = tf.cast(tf.math.equal(labels, predicted), tf.float32)
    return tf.reduce_mean(correct)