#!/usr/bin/env python
"""
Parse fastspecfit log files and report timing statistics.

"""
import re
import sys
import argparse
import numpy as np
from pathlib import Path


# Patterns for the extra fields read from the context string of a fit_all entry.
_NOBJ_RE = re.compile(r'nobj=(\d+)')
_PER_OBJ_RE = re.compile(r'([\d.]+)s/obj/core')

# Operations whose duration is stored unchanged, mapped to the record key they fill.
_DURATION_KEYS = {
    'read_spectra': 'read_sec',
    'write_fastspecfit': 'write_sec',
}

# Summary statistics shown for each table row, in column order.
_STAT_KEYS = ('mean', 'median', 'p10', 'p90', 'mn', 'mx')


def _parse_logfile(logfile, parse_line):
    """Extract the timing fields recorded in a single log file.

    Parameters
    ----------
    logfile : path-like
        Log file to read.
    parse_line : callable
        Called with every line of the file. Must return None for lines to
        ignore, or a mapping with 'operation', 'duration_sec' and 'context'
        entries.

    Returns
    -------
    dict
        Whichever of 'fit_all_sec', 'nobj', 'per_obj_sec', 'read_sec' and
        'write_sec' were found. If an operation occurs more than once, the
        last occurrence wins.

    """
    rec = {}
    # Undecodable bytes are replaced rather than raising, so that one stray
    # byte does not cause an otherwise valid log to be discarded.
    with open(logfile, encoding='utf-8', errors='replace') as f:
        for line in f:
            parsed = parse_line(line)
            if parsed is None:
                continue
            op = parsed['operation']
            dur = parsed['duration_sec']
            ctx = parsed['context'] or ''

            if op == 'fit_all':
                rec['fit_all_sec'] = dur
                m = _NOBJ_RE.search(ctx)
                if m:
                    rec['nobj'] = int(m.group(1))
                m = _PER_OBJ_RE.search(ctx)
                if m:
                    rec['per_obj_sec'] = float(m.group(1))
            elif op in _DURATION_KEYS:
                rec[_DURATION_KEYS[op]] = dur
    return rec


def _stats(vals):
    """Summarize the non-None entries of `vals`, or return None if there are none."""
    a = np.array([v for v in vals if v is not None], dtype=float)
    if a.size == 0:
        return None
    return dict(n=len(a), mean=np.mean(a), median=np.median(a),
                p10=np.percentile(a, 10), p90=np.percentile(a, 90),
                mn=np.min(a), mx=np.max(a))


def _fmt_time(v):
    """Format a duration in seconds, switching to minutes from 60 s upward."""
    return f'{v/60:.2f} min' if v >= 60 else f'{v:.2f} s'


def _fmt_rate(v):
    """Format a per-object time in seconds."""
    return f'{v:.2f} s'


def _fmt_count(v):
    """Format a count rounded to an integer, with thousands separators."""
    return f'{v:,.0f}'


def _print_report(records, n_skipped, logdir):
    """Print the summary table for the parsed records to stdout.

    Parameters
    ----------
    records : list of dict
        One record per log file, as returned by `_parse_logfile`.
    n_skipped : int
        Number of files left out of `records`; mentioned only when non-zero.
    logdir : path-like
        Directory that was searched; used in the summary line only.

    """
    rows = [
        ('nobj (per file)', 'nobj', _fmt_count),
        ('fit_all', 'fit_all_sec', _fmt_time),
        ('s/obj/core', 'per_obj_sec', _fmt_rate),
        ('read_spectra', 'read_sec', _fmt_time),
        ('write_fastspecfit', 'write_sec', _fmt_time),
    ]

    # Files without an nobj field contribute nothing to the total.
    total_obj = sum(r.get('nobj', 0) for r in records)

    skip_msg = f', {n_skipped} skipped' if n_skipped else ''
    print(f'\nParsed {len(records)} log files{skip_msg} under {logdir}')
    print(f'Total objects: {total_obj:,}')
    print()

    hdr = f'{"Operation":<22}  {"N":>5}  {"Mean":>10}  {"Median":>10}  {"P10":>10}  {"P90":>10}  {"Min":>10}  {"Max":>10}'
    print(hdr)
    print('-' * len(hdr))

    for label, key, fmt_fn in rows:
        stats = _stats([r.get(key) for r in records])
        if stats is None:
            # No file reported this quantity; omit the row.
            continue
        cols = '  '.join(f'{fmt_fn(stats[k]):>10}' for k in _STAT_KEYS)
        print(f'{label:<22}  {stats["n"]:>5}  {cols}')

    print()


def main():
    """Parse fastspecfit log files under a directory and print timing statistics.

    Command-line arguments are read from ``sys.argv``: a directory that is
    searched recursively, and an optional ``--pattern`` glob selecting the
    log files. Only files containing a ``fit_all`` timing entry are included
    in the report; all others are counted as skipped.

    Exits with status 1 (after a message on stderr) if the directory does not
    exist, no file matches the pattern, or no file has a ``fit_all`` entry.

    """
    parser = argparse.ArgumentParser(
        description='Parse fastspecfit log files and report timing statistics.')
    parser.add_argument('logdir',
                        help='Top-level directory to search recursively for log files.')
    parser.add_argument('--pattern', default='*.log', metavar='GLOB',
                        help='Glob pattern for log files (default: %(default)s).')
    args = parser.parse_args()

    logdir = Path(args.logdir)
    if not logdir.is_dir():
        print(f'Error: {logdir} is not a directory.', file=sys.stderr)
        sys.exit(1)

    # rglob also yields directories whose name matches the pattern; those are
    # not log files and must not be opened or counted as skipped.
    logfiles = sorted(p for p in logdir.rglob(args.pattern) if not p.is_dir())
    if not logfiles:
        print(f'No files matching "{args.pattern}" found under {logdir}.', file=sys.stderr)
        sys.exit(1)

    # Imported only once there is work to do, so --help and the argument
    # checks above do not require fastspecfit to be importable.
    from fastspecfit.util import parse_fsftime

    records = []
    n_skipped = 0

    for logfile in logfiles:
        try:
            rec = _parse_logfile(logfile, parse_fsftime)
        except Exception as e:
            # Best effort: a file that cannot be read or parsed is reported
            # and skipped rather than aborting the whole run.
            print(f'Warning: could not read {logfile}: {e}', file=sys.stderr)
            n_skipped += 1
            continue

        if 'fit_all_sec' in rec:
            records.append(rec)
        else:
            n_skipped += 1

    if not records:
        print('No fsftime fit_all entries found.', file=sys.stderr)
        sys.exit(1)

    _print_report(records, n_skipped, logdir)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()
