#!python

import math
import os, sys
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import re


def parse_log_file(loss_file):
    """Parse a training log into per-step metric dictionaries.

    Expected line format::

        ... epoch: 0, file: 1/4, batch: 3/7571 -> loss: 0.5, acc: 0.25

    Lines containing ``====`` and lines that do not have exactly one
    `` -> `` separator are ignored, as are lines whose left-hand side lacks
    any of the ``epoch``/``file``/``batch`` fields or whose right-hand side
    is not a well-formed ``name: float`` list.

    Args:
        loss_file: Path of the log file to read.

    Returns:
        A ``(data_map, all_metric_keys)`` tuple. ``data_map`` maps
        ``(epoch, file_idx, batch_idx)`` to a ``{metric_name: value}`` dict;
        when the same key occurs more than once the last line wins.
        ``all_metric_keys`` lists every metric name that appears in an
        accepted line, in first-seen order.

    Raises:
        SystemExit: If ``loss_file`` does not exist (exit status 1).
    """
    if not os.path.exists(loss_file):
        print(f'{loss_file} not found')
        # A missing input is a failure; do not report success to the shell.
        sys.exit(1)

    # Compile once instead of looking the patterns up for every line.
    epoch_re = re.compile(r'epoch:\s*(\d+)')
    file_re = re.compile(r'file:\s*(\d+)')
    batch_re = re.compile(r'batch:\s*(\d+)')

    data_map = {}
    all_metric_keys = []
    seen_keys = set()

    with open(loss_file, 'r') as f:
        for line in f:
            if '====' in line:
                continue

            parts = line.split(' -> ')
            if len(parts) != 2:
                continue
            meta_part, values_part = parts

            epoch_match = epoch_re.search(meta_part)
            file_match = file_re.search(meta_part)
            batch_match = batch_re.search(meta_part)
            if not (epoch_match and file_match and batch_match):
                continue

            sort_key = (
                int(epoch_match.group(1)),
                int(file_match.group(1)),
                int(batch_match.group(1)),
            )

            current_metrics = {}
            try:
                for values_kv in values_part.split(', '):
                    # Raises ValueError unless there is exactly one ': '.
                    k, v = values_kv.split(': ')
                    # float() tolerates surrounding whitespace/newline.
                    current_metrics[k] = float(v)
            except ValueError:
                # Malformed line: skip it entirely, without registering
                # any of its metric names.
                continue

            # Register names only once the whole line is known to be valid,
            # so rejected lines cannot introduce metrics with no data.
            for k in current_metrics:
                if k not in seen_keys:
                    seen_keys.add(k)
                    all_metric_keys.append(k)

            data_map[sort_key] = current_metrics

    return data_map, all_metric_keys


def _collect_series(data_map, all_metric_keys):
    """Turn the parsed log into per-metric (x, y) series plus separators.

    Args:
        data_map: Mapping of ``(epoch, file_idx, batch_idx)`` to a
            ``{metric_name: value}`` dict.
        all_metric_keys: Metric names to collect, in display order.

    Returns:
        A ``(series, separators)`` tuple. ``series`` maps each metric name
        to ``(xs, ys)`` where ``xs`` are indices into the sorted step list
        and ``ys`` the corresponding values. ``separators`` is a list of
        ``(step_index, kind, label)`` with ``kind`` being ``'epoch'`` or
        ``'file'``, marking the first step of a new epoch or file.
    """
    series = {name: ([], []) for name in all_metric_keys}
    separators = []
    prev_epoch = prev_file = None

    for step, key in enumerate(sorted(data_map)):
        metrics = data_map[key]
        for name in all_metric_keys:
            if name in metrics:
                xs, ys = series[name]
                # Record the global step so metrics that are missing on
                # some lines still line up with the separator positions.
                xs.append(step)
                ys.append(metrics[name])

        curr_epoch, curr_file, _ = key
        if step > 0:
            # An epoch change takes precedence over a file change.
            if curr_epoch != prev_epoch:
                separators.append((step, 'epoch', f"Ep {curr_epoch}"))
            elif curr_file != prev_file:
                separators.append((step, 'file', f"F {curr_file}"))
        prev_epoch, prev_file = curr_epoch, curr_file

    return series, separators


def _plot_series(series, separators):
    """Draw one subplot per metric (at most 4 per row) and show the figure.

    Args:
        series: Mapping of metric name to ``(xs, ys)``; must be non-empty.
        separators: ``(step_index, kind, label)`` tuples drawn as vertical
            lines on every subplot and labelled on the first one only.
    """
    # kind -> (color, linestyle, linewidth, alpha) of the vertical marker.
    separator_styles = {
        'epoch': ('red', '--', 1.5, 0.8),
        'file': ('green', ':', 1.0, 0.6),
    }

    count = len(series)
    cols = min(count, 4)
    rows = math.ceil(count / cols)

    # squeeze=False always yields a 2-D array of Axes, so a single
    # flatten() covers the 1x1, 1xN and MxN layouts alike.
    _, axes = plt.subplots(nrows=rows, ncols=cols,
                           figsize=(5 * cols, 4 * rows), squeeze=False)
    axes = axes.flatten()

    for idx, (metric_name, (xs, ys)) in enumerate(series.items()):
        ax = axes[idx]
        ax.plot(xs, ys, linewidth=1.0, label=metric_name)

        for sep_idx, sep_type, sep_label in separators:
            color, linestyle, linewidth, alpha = separator_styles[sep_type]
            ax.axvline(x=sep_idx, color=color, linestyle=linestyle,
                       linewidth=linewidth, alpha=alpha)
            if idx == 0:
                # Label separators once, on the first subplot.
                ax.text(sep_idx, ax.get_ylim()[1], sep_label, rotation=90,
                        verticalalignment='top', color=color, fontsize=8)

        ax.set_title(metric_name)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=10))
        ax.tick_params(axis='x', rotation=30)
        ax.set_xlabel("Steps")
        ax.grid(True, linestyle='--', alpha=0.3)

    # Hide the unused cells of the last grid row.
    for ax in axes[count:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()


def main():
    """Command-line entry point: plot every metric found in a log file.

    Reads the log path from ``sys.argv[1]``, parses it with
    ``parse_log_file`` and shows one subplot per metric, with vertical
    markers where the epoch or the file index changes.

    Raises:
        SystemExit: With status 1 when no log path is given, and with
            status 0 when the log contains no usable data.
    """
    arguments = sys.argv[1:]
    if not arguments:
        print("Usage: python3 plot_log.py <log_file>")
        sys.exit(1)

    loss_file = arguments[0]
    data_map, all_metric_keys = parse_log_file(loss_file)

    # Also guard against an empty metric list, which would otherwise lead
    # to a zero-column grid and a division by zero.
    if not data_map or not all_metric_keys:
        print("No valid data found.")
        sys.exit(0)

    series, separators = _collect_series(data_map, all_metric_keys)
    _plot_series(series, separators)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()