#!python

import math
import os
import sys
import re

# Try to import plotly
try:
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    HAS_PLOTLY = True
except ImportError:
    HAS_PLOTLY = False

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator


def parse_log_file(loss_file):
    """Parse a training log file into per-step metric dictionaries.

    Only lines of the form ``<meta> -> <name>: <value>, <name>: <value>, ...``
    are used, where ``<meta>`` must contain ``epoch: N``, ``file: N`` and
    ``batch: N`` fields. Lines containing ``====`` and lines that cannot be
    parsed completely are skipped.

    Args:
        loss_file: Path of the log file to read.

    Returns:
        A tuple ``(data_map, all_metric_keys)``:
        ``data_map`` maps ``(epoch, file_idx, batch_idx)`` integer tuples to a
        dict of metric name -> float value (a later line with the same tuple
        replaces an earlier one); ``all_metric_keys`` lists every metric name
        found on an accepted line, in first-seen order.

    Raises:
        SystemExit: With status 0 (after printing a message) if ``loss_file``
            does not exist.
    """
    data_map = {}
    all_metric_keys = []

    if not os.path.exists(loss_file):
        print(f'{loss_file} not found')
        sys.exit(0)

    with open(loss_file, 'r') as f:
        for line in f:
            if '====' in line or '->' not in line:
                continue

            try:
                # Raises ValueError unless there is exactly one ' -> ' separator.
                meta_part, values_part = line.split(' -> ')
            except ValueError:
                continue

            epoch_match = re.search(r'epoch:\s*(\d+)', meta_part)
            file_match = re.search(r'file:\s*(\d+)', meta_part)
            batch_match = re.search(r'batch:\s*(\d+)', meta_part)

            if not (epoch_match and file_match and batch_match):
                continue

            sort_key = (
                int(epoch_match.group(1)),
                int(file_match.group(1)),
                int(batch_match.group(1)),
            )

            current_metrics = {}
            try:
                for values_kv in values_part.split(', '):
                    k, v = values_kv.split(': ')
                    current_metrics[k] = float(v.strip())
            except ValueError:
                # Malformed "name: value" pair or non-numeric value: drop the
                # whole line. Only ValueError is caught so that Ctrl-C and
                # genuine programming errors are not silently swallowed.
                continue

            # Register metric names only once the whole line parsed cleanly, so
            # a discarded line cannot introduce a metric that has no data.
            for k in current_metrics:
                if k not in all_metric_keys:
                    all_metric_keys.append(k)

            data_map[sort_key] = current_metrics

    return data_map, all_metric_keys


def plot_with_plotly(results, separator_indices, loss_file):
    """Render every metric as a stacked Plotly subplot and write it to HTML.

    Args:
        results: Mapping of metric name -> list of values; each list is
            plotted against its own 0-based index.
        separator_indices: Iterable of ``(x_index, kind, label)`` tuples.
            ``kind == 'epoch'`` draws a dashed red vertical line (labelled on
            the first subplot only); ``kind == 'file'`` draws a dotted green
            line. Other kinds are ignored.
        loss_file: Path of the source log; used in the figure title and as the
            prefix of the output file ``<loss_file>.html``.
    """
    print("✨[Plotly is installed] Generating interactive web page using Plotly...")
    metric_names = list(results.keys())
    results_size = len(metric_names)

    if results_size == 0:
        # A subplot grid needs at least one row; nothing to draw.
        print("No metrics to plot.")
        return

    # make_subplots raises ValueError when vertical_spacing exceeds
    # 1 / (rows - 1), which a fixed 0.08 does from 14 rows upwards.
    # 0.5 / rows always stays below that limit.
    vertical_spacing = min(0.08, 0.5 / results_size)

    # Create subplots, vertically stacked
    fig = make_subplots(rows=results_size, cols=1, subplot_titles=metric_names,
                        vertical_spacing=vertical_spacing)

    for idx, metric_name in enumerate(metric_names):
        y = results[metric_name]
        x = list(range(len(y)))

        fig.add_trace(go.Scatter(x=x, y=y, mode='lines', name=metric_name), row=idx + 1, col=1)

        # Add vertical separator lines
        for sep_idx, sep_type, sep_label in separator_indices:
            if sep_type == 'epoch':
                fig.add_vline(x=sep_idx, line_dash="dash", line_color="red", opacity=0.5, row=idx + 1, col=1)
                # Only add text labels to the first row to prevent clutter
                if idx == 0:
                    fig.add_annotation(x=sep_idx, y=1, yref="y domain", text=sep_label, showarrow=False,
                                       textangle=-90, xanchor='right', font=dict(color="red", size=10), row=idx + 1,
                                       col=1)
            elif sep_type == 'file':
                fig.add_vline(x=sep_idx, line_dash="dot", line_color="green", opacity=0.3, row=idx + 1, col=1)

    # Dynamically adjust web page height based on the number of metrics
    fig.update_layout(height=250 * results_size, title_text=f"Training Metrics - {loss_file}", showlegend=False)

    fig.update_yaxes(exponentformat='e')

    out_html = f"{loss_file}.html"
    fig.write_html(out_html, auto_open=True)
    print(f"✅ Generation complete! HTML opened in browser automatically. File saved to: {out_html}")


def plot_with_matplotlib(results, separator_indices):
    """Show every metric in Matplotlib windows, at most six charts per window.

    Args:
        results: Mapping of metric name -> list of values; each list is
            plotted against its own 0-based index.
        separator_indices: Iterable of ``(x_index, kind, label)`` tuples.
            ``'epoch'`` entries become dashed red vertical lines and ``'file'``
            entries dotted green ones; labels are written on the first chart
            of each window only. Other kinds are ignored.
    """
    names = list(results)

    # Cap each window at 6 charts (2 cols x 3 rows) to prevent cropping
    charts_per_window = 6

    separator_styles = {
        'epoch': dict(color='red', linestyle='--', linewidth=1.5, alpha=0.8),
        'file': dict(color='green', linestyle=':', linewidth=1.0, alpha=0.6),
    }

    for window_no, first in enumerate(range(0, len(names), charts_per_window)):
        window_names = names[first:first + charts_per_window]
        chart_count = len(window_names)

        n_cols = 1 if chart_count <= 1 else 2
        n_rows = math.ceil(chart_count / n_cols)

        figure, _ = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(6 * n_cols, 4 * n_rows))
        figure.canvas.manager.set_window_title(f'Metrics Part {window_no + 1}')

        panels = figure.axes

        for position, (panel, name) in enumerate(zip(panels, window_names)):
            values = results[name]
            panel.plot(list(range(len(values))), values, linewidth=1.0, label=name)

            for sep_idx, sep_type, sep_label in separator_indices:
                style = separator_styles.get(sep_type)
                if style is None:
                    continue
                panel.axvline(x=sep_idx, **style)
                # Labels go on the first chart of the window only
                if position == 0:
                    panel.text(sep_idx, panel.get_ylim()[1], sep_label, rotation=90, verticalalignment='top',
                               color=style['color'], fontsize=8)

            panel.set_title(name)
            panel.xaxis.set_major_locator(MaxNLocator(nbins=10))
            panel.tick_params(axis='x', rotation=30)
            panel.set_xlabel("Steps")
            panel.grid(True, linestyle='--', alpha=0.3)

        # Hide the grid cells that have no metric assigned
        for unused in panels[chart_count:]:
            unused.set_visible(False)

        plt.tight_layout()

    plt.show()


def main():
    """Command-line entry point: parse the given log file and plot its metrics.

    Reads the log path from the first command-line argument, builds one value
    series per metric (ordered by epoch, file and batch), records the step
    indices where the epoch or file number changes, and plots with Plotly when
    it is installed or with Matplotlib otherwise. Exits with status 1 when no
    argument is given and status 0 when the log contains no usable data.
    """
    if len(sys.argv) < 2:
        print("Usage: vis_log <log_file>")
        sys.exit(1)

    loss_file = sys.argv[1]
    data_map, all_metric_keys = parse_log_file(loss_file)

    if not data_map:
        print("No valid data found in the log file.")
        sys.exit(0)

    # Data formatting preparation
    ordered_steps = sorted(data_map)
    results = {name: [] for name in all_metric_keys}
    separator_indices = []

    # data_map is non-empty here, so there is always a first step to seed from.
    last_epoch, last_file, _ = ordered_steps[0]
    for position, step in enumerate(ordered_steps):
        recorded = data_map[step]
        for name in all_metric_keys:
            if name in recorded:
                results[name].append(recorded[name])

        epoch, file_idx, _ = step
        # An epoch change takes precedence over a file change at the same step.
        if epoch != last_epoch:
            separator_indices.append((position, 'epoch', f"Ep {epoch}"))
        elif file_idx != last_file:
            separator_indices.append((position, 'file', f"F {file_idx}"))

        last_epoch, last_file = epoch, file_idx

    # Decide display method based on dependencies
    if not HAS_PLOTLY:
        print("⚠️ [Warning] Plotly is not installed (recommend using: `pip install plotly`). Downgrading to Matplotlib multi-window display...")
        plot_with_matplotlib(results, separator_indices)
    else:
        plot_with_plotly(results, separator_indices, loss_file)


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()