#!python

import os, sys
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

def parse_lr_lines(lines):
    """Extract ``{step: learning_rate}`` pairs from learning-rate log lines.

    Each record is expected to look like ``... step: <int>, lr: <float>``.
    The text after the last ``'step: '`` marker (or the whole line if the
    marker is absent) is split on ``', lr:'``; the first piece is parsed as
    the step and the second as the learning rate. Lines that are blank or do
    not parse this way are skipped.

    Args:
        lines: Iterable of text lines, e.g. an open text file.

    Returns:
        Dict mapping step number (int) to learning rate (float). If a step
        occurs more than once, the value from the last matching line wins.
    """
    lrs = {}
    for line in lines:
        # Lines from a file keep their '\n', so test the stripped text.
        if not line.strip():
            continue
        fields = line.split('step: ')[-1].split(', lr:')
        try:
            step = int(fields[0].strip())
            lr = float(fields[1].strip())
        except (IndexError, ValueError):
            # Not a learning-rate record (missing marker or non-numeric value).
            continue
        lrs[step] = lr
    return lrs


def read_lr_log(path):
    """Read the log file at *path* and return its ``{step: lr}`` records.

    Parsing rules are those of ``parse_lr_lines``.
    """
    # errors='replace' keeps one undecodable byte from aborting the whole
    # parse; the markers and numbers are ASCII, so affected lines are at
    # worst skipped.
    with open(path, 'r', encoding='utf-8', errors='replace') as f:
        return parse_lr_lines(f)


def plot_learning_rate(lrs):
    """Display a line plot of learning rate against training step.

    Args:
        lrs: Non-empty dict mapping step (int) to learning rate (float).
    """
    steps = sorted(lrs)  # keys are unique, so sorting keys orders the points
    rates = [lrs[step] for step in steps]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_title('Learning Rate')
    ax.set_xlabel("Steps")
    ax.set_ylabel("Learning Rate")

    # A single point has no line segment to draw, so give it a visible marker.
    ax.plot(steps, rates, linewidth=1.5,
            marker='o' if len(steps) == 1 else None)

    # Steps are integers, so keep tick positions on whole numbers.
    ax.xaxis.set_major_locator(MaxNLocator(nbins=10, integer=True))
    ax.tick_params(axis='x', labelrotation=30)
    ax.grid(True, linestyle='--', alpha=0.5)

    fig.tight_layout()
    plt.show()


def main(argv=None):
    """Command-line entry point: plot the learning rate recorded in a log file.

    Args:
        argv: Arguments excluding the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 after plotting; 1 on a usage error, a missing
        (or non-regular) file, or a log with no parsable records.
    """
    arguments = sys.argv[1:] if argv is None else argv
    if not arguments:
        print("Usage: python3 script.py <lr_log_file>", file=sys.stderr)
        return 1

    lr_file = arguments[0]
    # isfile rather than exists: a directory would otherwise crash open().
    if not os.path.isfile(lr_file):
        print(f'{lr_file} not found', file=sys.stderr)
        return 1

    lrs = read_lr_log(lr_file)
    if not lrs:
        print("No valid data found.", file=sys.stderr)
        return 1

    plot_learning_rate(lrs)
    return 0


if __name__ == '__main__':
    sys.exit(main())