#!python
import math
import os, sys
import matplotlib.pyplot as plt
from numpy import ndarray
from matplotlib.ticker import MaxNLocator

def parse_loss_lines(lines):
    """Collect per-metric value series from the lines of a loss log.

    Data lines look like::

        [time] keys_key1: keys_value1, ... -> values_key1: values_value1, values_key2: values_value2

    Only the text after the first ``' -> '`` is parsed, as ``name: number``
    pairs separated by ``', '``. The following lines are skipped:
    epoch header lines (any line containing ``'===='``), blank lines, and
    any other line without ``' -> '``.

    Args:
        lines: iterable of text lines (e.g. an open file object).

    Returns:
        dict mapping metric name -> list of floats, in file order. Keys are
        ordered by first appearance.

    Raises:
        ValueError: if a metric pair is not ``name: number``. The message
            includes the 1-based line number.
    """
    results = {}
    for lineno, line in enumerate(lines, 1):
        # Epoch header, e.g. "====epoch: {epoch}, start train {file_name}====".
        if '====' in line:
            continue
        _, arrow, values_part = line.partition(' -> ')
        if not arrow:
            # Blank or unrelated line: previously crashed with IndexError.
            continue
        for pair in values_part.split(', '):
            if not pair.strip():
                # Tolerate a trailing ", " before the newline.
                continue
            # partition (not split) so a stray ': ' inside a value cannot
            # produce too many items to unpack.
            name, _, raw_value = pair.partition(': ')
            try:
                value = float(raw_value.strip())
            except ValueError as err:
                raise ValueError(
                    f'line {lineno}: malformed metric {pair.strip()!r}'
                ) from err
            results.setdefault(name.strip(), []).append(value)
    return results


def grid_shape(n_plots, max_cols=4):
    """Return ``(rows, cols)`` for a grid holding *n_plots* subplots.

    Up to *max_cols* plots share one row. Beyond that the grid keeps
    *max_cols* columns and adds rows as needed.

    Raises:
        ValueError: if *n_plots* or *max_cols* is less than 1.
    """
    if n_plots < 1 or max_cols < 1:
        raise ValueError('n_plots and max_cols must both be >= 1')
    return math.ceil(n_plots / max_cols), min(n_plots, max_cols)


def plot_results(results, max_cols=4):
    """Draw one line chart per metric, arranged in a grid.

    Args:
        results: non-empty mapping of metric name -> sequence of values.
            The x-axis is the value's index ("Step").
        max_cols: maximum number of subplots per row.

    Returns:
        The created matplotlib Figure. Call ``plt.show()`` to display it.
    """
    rows, cols = grid_shape(len(results), max_cols)
    # squeeze=False always yields a 2-D array of Axes, so the single-plot
    # case needs no special handling before flattening.
    fig, axes = plt.subplots(
        nrows=rows, ncols=cols, figsize=(4 * cols, 4 * rows), squeeze=False
    )
    axes = axes.ravel()

    for ax, (title, values) in zip(axes, results.items()):
        ax.plot(range(len(values)), values)
        ax.set_title(title)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=10))
        ax.tick_params(axis='x', rotation=30)
        ax.set_xlabel("Step")
        ax.set_ylabel(title)

    # The last row may be partially filled; hide the unused Axes.
    for ax in axes[len(results):]:
        ax.set_visible(False)

    fig.tight_layout()
    return fig


def main(argv=None):
    """Command-line entry point: plot the metrics of the loss file in argv[0].

    Args:
        argv: argument list without the program name. Defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if the file is missing or
        contains no metrics, 2 on a usage error.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        print(f'usage: {os.path.basename(sys.argv[0])} LOSS_FILE', file=sys.stderr)
        return 2
    loss_file = args[0]

    if not os.path.isfile(loss_file):
        print(f'{loss_file} not found', file=sys.stderr)
        return 1

    with open(loss_file, 'r', encoding='utf-8') as f:
        results = parse_loss_lines(f)

    if not results:
        # plt.subplots(ncols=0) would raise an unhelpful error.
        print(f'no metrics found in {loss_file}', file=sys.stderr)
        return 1

    plot_results(results)
    plt.show()
    return 0


if __name__ == '__main__':
    sys.exit(main())