#!python

import os
import sys
from typing import Iterable, List


def parse_losses(lines: Iterable[str]) -> List[float]:
    """Extract the loss values to plot from the lines of a training log.

    A line containing ``loss:`` contributes the number that follows the last
    ``loss:`` marker. Any other line that does not contain ``start train``
    discards everything collected so far; a ``start train`` line is skipped
    without discarding anything.

    Args:
        lines: Log lines (for example an open text file).

    Returns:
        The loss values collected since the last reset, in file order. The
        plotting step index of a value is its position in this list.
    """
    losses: List[float] = []

    for line_no, line in enumerate(lines, start=1):
        if 'loss:' not in line:
            # NOTE(review): the reset fires on every non-loss line EXCEPT
            # 'start train' (blank lines included). That reads as inverted if
            # the goal is "plot only the latest run" -- confirm the intent
            # before changing it; behaviour is kept as-is here.
            if 'start train' not in line:
                losses.clear()
            continue

        # (2025-03-19 20:13:44) epoch: 0, file: 1/1, batch: 623/1099, loss: 0.12186837196350098
        raw_value = line.split('loss:')[-1].strip()
        try:
            losses.append(float(raw_value))
        except ValueError:
            # A truncated or malformed line (e.g. the log is still being
            # written) should not abort the whole plot.
            print(
                f'line {line_no}: cannot parse loss value {raw_value!r}, skipped',
                file=sys.stderr,
            )

    return losses


def main(arguments: List[str]) -> int:
    """Plot the loss curve of the log file named by ``arguments[0]``.

    Args:
        arguments: Command-line arguments without the program name.

    Returns:
        Process exit status: 0 on success, 1 when the file is missing or
        holds no loss values, 2 when no file argument was given.
    """
    if not arguments:
        print(f'usage: {sys.argv[0]} <loss_file>', file=sys.stderr)
        return 2

    loss_file = arguments[0]
    if not os.path.exists(loss_file):
        print(f'{loss_file} not found', file=sys.stderr)
        return 1

    # errors='replace' keeps stray undecodable bytes from aborting the read;
    # the markers this script looks for are plain ASCII.
    with open(loss_file, 'r', errors='replace') as f:
        losses = parse_losses(f)

    if not losses:
        print(f'no loss values found in {loss_file}', file=sys.stderr)
        return 1

    # Imported late so argument and file errors are reported without
    # requiring matplotlib to be loaded.
    import matplotlib.pyplot as plt

    plt.xlabel('steps')
    plt.ylabel('loss')

    plt.plot(list(range(len(losses))), losses)
    plt.show()
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
