#!python

import os
import sys
import re

# Plotly is an optional dependency: try to import it and record the outcome in
# HAS_PLOTLY so main() can fall back to Matplotlib when it is unavailable.
try:
    import plotly.graph_objects as go

    HAS_PLOTLY = True
except ImportError:
    HAS_PLOTLY = False

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator


def parse_lr_file(lr_file):
    """Parse a training log and return its learning-rate schedule.

    Each line is searched for a ``step: <int>`` field followed, later on the
    same line, by an ``lr: <float>`` field. Lines that do not contain both
    fields are ignored. When a step appears more than once, the value from
    the last matching line wins.

    Args:
        lr_file: Path to the log file.

    Returns:
        A list of ``(step, lr)`` tuples sorted by ascending step.

    Raises:
        SystemExit: With status 1 if the file is missing or cannot be read.
    """
    # Spell out the float grammar instead of using a loose character class:
    # this accepts signed exponents such as "1e+05" and stops before trailing
    # punctuation ("0.001."), both of which would otherwise crash float().
    pattern = re.compile(
        r'step:\s*(\d+).*?lr:\s*'
        r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    )

    lrs = {}
    try:
        # Only ASCII fields are parsed, so undecodable bytes elsewhere in the
        # log are replaced rather than allowed to abort the whole run.
        with open(lr_file, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                match = pattern.search(line)
                if match:
                    lrs[int(match.group(1))] = float(match.group(2))
    except FileNotFoundError:
        print(f"Error: '{lr_file}' not found.", file=sys.stderr)
        sys.exit(1)
    except OSError as err:
        # E.g. the path is a directory or is not readable.
        print(f"Error: cannot read '{lr_file}': {err}", file=sys.stderr)
        sys.exit(1)

    # Dict keys are unique, so sorting the items orders them by step.
    return sorted(lrs.items())


def plot_with_plotly(sorted_data, lr_file):
    """Render the schedule as an interactive Plotly page and open it.

    Args:
        sorted_data: Sequence of ``(step, lr)`` pairs in plotting order.
        lr_file: Path of the source log. Its basename is shown in the title
            and the HTML is written to ``<lr_file>.html``.
    """
    print("✨ [Plotly is installed] Generating interactive web page using Plotly...")

    steps = [step for step, _ in sorted_data]
    rates = [rate for _, rate in sorted_data]

    lr_trace = go.Scatter(
        x=steps,
        y=rates,
        mode='lines',
        name='Learning Rate',
        line={'width': 2, 'color': 'blue'},
    )
    fig = go.Figure(data=[lr_trace])

    layout_options = {
        'title': f"Learning Rate Schedule - {os.path.basename(lr_file)}",
        'xaxis_title': "Steps",
        'yaxis_title': "Learning Rate",
        'hovermode': "x unified",
        'template': "plotly_white",
        # Show y-axis tick values in exponent ("e") notation.
        'yaxis': {'exponentformat': 'e'},
    }
    fig.update_layout(**layout_options)
    # Optional: switch the y-axis to a log scale.
    # fig.update_yaxes(type="log")

    out_html = f"{lr_file}.html"
    fig.write_html(out_html, auto_open=True)
    print(f"✅ Generation complete! HTML opened in browser automatically. File saved to: {out_html}")


def plot_with_matplotlib(sorted_data, lr_file):
    """Show the schedule in a Matplotlib window (fallback when Plotly is absent).

    Args:
        sorted_data: Sequence of ``(step, lr)`` pairs in plotting order.
        lr_file: Path of the source log; its basename is shown in the title.
    """
    steps = [step for step, _ in sorted_data]
    rates = [rate for _, rate in sorted_data]

    plt.figure(figsize=(10, 6))
    ax = plt.gca()

    # The title carries the file name, matching the Plotly output.
    ax.set_title(f'Learning Rate Schedule - {os.path.basename(lr_file)}')
    ax.set_xlabel("Steps")
    ax.set_ylabel("Learning Rate")

    ax.plot(steps, rates, linewidth=1.5, color='blue')

    # Optional: switch the y-axis to a log scale.
    # ax.set_yscale('log')

    tick_locator = MaxNLocator(nbins=10)
    ax.xaxis.set_major_locator(tick_locator)
    plt.xticks(rotation=30)
    ax.grid(True, linestyle='--', alpha=0.5)

    plt.tight_layout()
    plt.show()


def main():
    """Command-line entry point: ``vis_lr <lr_log_file>``.

    Reads the log path from the first command-line argument, parses it with
    ``parse_lr_file`` and plots the result with Plotly when it is installed,
    otherwise with Matplotlib.

    Raises:
        SystemExit: With status 1 when no argument is given, and with status
            0 when the log contains no parsable ``step``/``lr`` entries.
    """
    arguments = sys.argv[1:]
    if not arguments:
        print("Usage: vis_lr <lr_log_file>")
        sys.exit(1)

    lr_file = arguments[0]
    sorted_data = parse_lr_file(lr_file)

    if not sorted_data:
        print("No valid data found in the log file.")
        sys.exit(0)

    # Decide display method based on dependencies
    if HAS_PLOTLY:
        plot_with_plotly(sorted_data, lr_file)
    else:
        print("⚠️ [Warning] Plotly is not installed (recommend using: `pip install plotly`). Using Matplotlib for display...")
        # lr_file is required by plot_with_matplotlib for the figure title.
        plot_with_matplotlib(sorted_data, lr_file)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()