#!python

import math
import os, sys
import matplotlib.pyplot as plt
from numpy import ndarray
from matplotlib.ticker import MaxNLocator
import re

# Plot training metrics parsed from a loss log.
#
# A metrics line is expected to look like
#     <meta> -> name: value, name: value, ...
# where <meta> contains "epoch: N", "file: N" and "batch: N" (in any order).
# Lines containing "====" are separators and are ignored, as are blank lines.

_EPOCH_RE = re.compile(r'epoch:\s*(\d+)')
_FILE_RE = re.compile(r'file:\s*(\d+)')
_BATCH_RE = re.compile(r'batch:\s*(\d+)')

# Default maximum number of subplots per row in the metrics figure.
MAX_PLOT_COLS = 4


def parse_loss_line(line):
    """Parse one metrics line of a loss log.

    Args:
        line: A single log line, with or without its trailing newline.

    Returns:
        ``((epoch, file_idx, batch_idx), {metric_name: value})`` on success,
        or ``None`` if the line is not a well-formed metrics line: no single
        `` -> `` separator, a missing epoch/file/batch field, a malformed or
        non-numeric ``name: value`` pair, or no metrics at all.
    """
    parts = line.split(' -> ')
    if len(parts) != 2:
        return None
    meta_part, values_part = parts

    indices = []
    for pattern in (_EPOCH_RE, _FILE_RE, _BATCH_RE):
        match = pattern.search(meta_part)
        if match is None:
            return None
        indices.append(int(match.group(1)))

    metrics = {}
    for item in values_part.split(','):
        item = item.strip()
        if not item:
            # Tolerate a trailing ", " (or a doubled comma) on the line.
            continue
        # rpartition so metric names that themselves contain ':' still parse.
        name, sep, raw_value = item.rpartition(':')
        name = name.strip()
        if not sep or not name:
            return None
        try:
            metrics[name] = float(raw_value)
        except ValueError:
            return None

    if not metrics:
        return None
    return tuple(indices), metrics


def parse_loss_lines(lines):
    """Parse an iterable of loss-log lines.

    Separator lines (containing "====") and blank lines are ignored; every
    other line that fails to parse is counted as skipped. When several lines
    share the same (epoch, file_idx, batch_idx) key, the last one wins.

    Returns:
        ``(records, metric_names, skipped)`` where ``records`` maps
        ``(epoch, file_idx, batch_idx)`` to a ``{metric: value}`` dict,
        ``metric_names`` lists every metric in order of first appearance in
        the input, and ``skipped`` is the number of unparseable lines.
    """
    records = {}
    metric_order = {}  # dict used as an insertion-ordered set
    skipped = 0
    for line in lines:
        if '====' in line or not line.strip():
            continue
        parsed = parse_loss_line(line)
        if parsed is None:
            skipped += 1
            continue
        key, metrics = parsed
        records[key] = metrics
        metric_order.update(dict.fromkeys(metrics))
    return records, list(metric_order), skipped


def build_series(records, metric_names):
    """Turn parsed records into per-metric ``(steps, values)`` series.

    Records are ordered by their (epoch, file_idx, batch_idx) key, and a
    record's position in that order is its step. Each metric keeps the step
    of every record it appears in, so a metric that is missing from some
    records stays aligned with the others on the x-axis.

    Returns:
        A dict, in ``metric_names`` order, mapping each metric that has at
        least one value to a ``(steps, values)`` pair of lists.
    """
    series = {name: ([], []) for name in metric_names}
    for step, key in enumerate(sorted(records)):
        for name, value in records[key].items():
            if name in series:
                steps, values = series[name]
                steps.append(step)
                values.append(value)
    return {name: sv for name, sv in series.items() if sv[1]}


def plot_series(series, max_cols=MAX_PLOT_COLS):
    """Show one line subplot per metric, in a grid at most ``max_cols`` wide.

    Args:
        series: Mapping of metric name to ``(steps, values)``, as returned by
            ``build_series``. Nothing is drawn if it is empty.
        max_cols: Maximum number of subplots per row (values < 1 act as 1).
    """
    count = len(series)
    if count == 0:
        return
    cols = max(1, min(count, max_cols))
    rows = math.ceil(count / cols)

    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid,
    # so the single-subplot return type needs no special case.
    fig, axes = plt.subplots(nrows=rows, ncols=cols,
                             figsize=(4 * cols, 4 * rows), squeeze=False)
    axes = axes.ravel()

    for ax, (title, (steps, values)) in zip(axes, series.items()):
        ax.plot(steps, values)
        ax.set_title(title)
        # Steps are integers; avoid fractional tick labels on short runs.
        ax.xaxis.set_major_locator(MaxNLocator(nbins=10, integer=True))
        ax.tick_params(axis='x', rotation=30)
        ax.set_xlabel("Step")
        ax.set_ylabel(title)

    # Hide the unused cells at the end of the last row.
    for ax in axes[count:]:
        ax.set_visible(False)

    fig.tight_layout()
    plt.show()


def main(argv=None):
    """Command-line entry point: plot the metrics found in a loss log.

    Args:
        argv: Arguments without the program name; defaults to
            ``sys.argv[1:]``. Only the first argument (the log path) is used.

    Returns:
        Process exit status: 0 on success or when the log has no metrics,
        1 if the log file does not exist, 2 if no log path was given.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        prog = os.path.basename(sys.argv[0]) if sys.argv else 'script'
        print(f'usage: {prog} LOSS_FILE', file=sys.stderr)
        return 2

    loss_file = args[0]
    if not os.path.isfile(loss_file):
        print(f'{loss_file} not found', file=sys.stderr)
        return 1

    # errors='replace' keeps a stray non-UTF-8 byte from aborting the parse;
    # at worst the affected line is counted as skipped.
    with open(loss_file, 'r', encoding='utf-8', errors='replace') as f:
        records, metric_names, skipped = parse_loss_lines(f)

    if skipped:
        print(f'warning: skipped {skipped} unparseable line(s) in {loss_file}',
              file=sys.stderr)

    series = build_series(records, metric_names)
    if not series:
        print("No valid data found.")
        return 0

    plot_series(series)
    return 0


if __name__ == '__main__':
    sys.exit(main())