#!/usr/bin/env python
import os
import numpy as np
from pathlib import Path
import hdf5maker as h5m
from hdf5maker.tools import find_suffix
import hdf5plugin
import h5py
import multiprocessing as mp
import time
from argparse import ArgumentParser
import sys

from importlib.metadata import version #to read version

# Use a non-interactive backend so the script can run without an X server
import matplotlib
matplotlib.use("Agg")


def convert_data_file(fname_in, fname_out, start, stop, fastquad):
    """
    Read frames ``start`` .. ``stop - 1`` from a raw file and write them to one
    hdf5 data file.

    Intended to be run in a worker process (one call per output data file).

    Parameters
    ----------
    fname_in : path of the raw master file, opened with ``h5m.RawFile``.
    fname_out : path of the hdf5 data file to write.
    start, stop : frame range to convert; ``stop - start`` frames are read
        after seeking to ``start``.
    fastquad : forwarded unchanged to ``h5m.RawFile``.

    Returns
    -------
    dict with wall-clock durations in seconds: ``"reading"`` (open, seek and
    read) and ``"writing"`` (writing the hdf5 data file).
    """
    t0 = time.time()
    raw = h5m.RawFile(fname_in, fastquad=fastquad)
    try:
        raw.seek(start)
        data = raw.read(stop - start)
        t1 = time.time()
        # image_nr_low is 1-based here (start + 1); presumably the first
        # image number stored in the data file — confirm against write_data_file
        h5m.write_data_file(fname_out, data, image_nr_low=start + 1)
        t2 = time.time()
    finally:
        # Release the raw file handle even if reading/writing fails; keep it
        # open until after writing so `data` is never used after close.
        raw.close()
    return {"reading": t1 - t0, "writing": t2 - t1}


if __name__ == "__main__":
    # Only needed to narrow the version lookup failure below
    from importlib.metadata import PackageNotFoundError

    t_start = time.time()

    parser = h5m.cmd_line.get_argument_parser()
    args = parser.parse_args()

    fname_in = args.input
    # parents=True so a nested, not-yet-existing output directory also works
    args.path_out.mkdir(parents=True, exist_ok=True)
    fname_out = h5m.get_output_fname(fname_in, args.path_out, args.custom_name)

    # Report the installed package version; fall back to 'developer' when
    # the package metadata is not available (e.g. running from a checkout).
    try:
        ver = version("hdf5maker")
    except PackageNotFoundError:
        ver = "developer"

    print(h5m.color.info(f"hdfmaker-{ver} raw2hdf5 converter"))

    print("\n=== Converting Eiger Data ===")
    # Open the raw file to find out how much data we have
    if not fname_in.exists():
        print(h5m.color.red(f"Could not find: {fname_in}, exiting."))
        sys.exit(1)

    print(h5m.color.ok(f"Opening: {fname_in}"))
    print(f"{args.fastquad=}")

    # Replace the Total Frames value in the raw file
    # Needed for pre 6.3 since we don't have Frames in File
    h5m.replace_total_frames(fname_in)
    try:
        raw = h5m.RawFile(fname_in, fastquad=args.fastquad)
    except IOError as e:
        print(h5m.color.error(e))
        sys.exit(1)

    # Half-open frame ranges (start, stop) for each raw data file, built from
    # the cumulative frame counts.
    num_data_files = len(raw.frames_per_file)
    start = np.zeros(num_data_files + 1, dtype=np.int64)
    start[1:] = np.cumsum(raw.frames_per_file)
    frames_to_read = list(zip(start, start[1:]))

    print(f"Image size: {raw.image_size}")
    print(f"Dynamic range: {raw.dr}")
    print(f"Total frames: {raw.total_frames}")

    # Mask for defective pixels, as a default contains only module gaps
    pixel_mask = raw.module_gaps

    # Look for additional pixels to mask
    fname_bp = Path("bad_pixels.txt") if args.bad_pixels is None else args.bad_pixels
    if fname_bp.exists():
        print(h5m.color.ok(f"Loading pixels to mask from: {fname_bp}"))
        try:
            pixels = h5m.read_bad_pixels(fname_bp)
            print(h5m.color.ok(f"Setting pixel_mask true for: {pixels}"))
            # Work on a copy so a failure part-way through leaves the mask
            # untouched, matching the "no pixels masked" warning below.
            updated_mask = pixel_mask.copy()
            for p in pixels:
                updated_mask[p] = True
        except Exception:
            print(
                h5m.color.warning(
                    f"Could not parse bad pixel file: {fname_bp}, no pixels masked"
                )
            )
        else:
            pixel_mask = updated_mask
    else:
        print(h5m.color.warning("No bad pixels file found"))

    if args.graphs:
        print("Image output")
        image = raw.read(1)[0]
        image[pixel_mask] = 0
        h5m.im_plot(f"{fname_out.parent/fname_out.stem}_img", image)
    # Metadata attributes (frames, module_gaps, master, ...) are still used
    # below after closing; only the file handle is released here.
    raw.close()

    # Generate filenames for data out files ("_master" is removed from the stem)
    first, _, last = fname_out.stem.rpartition("_master")
    stem = "".join((first, last))
    fname_out_data = [
        fname_out.parent / f"{stem}_{i:06d}.h5" for i in range(num_data_files)
    ]
    if args.sample:
        # Convert at most the first 100 frames into a single data file
        fname_out_data = fname_out_data[0:1]
        n_frames = min(100, raw.total_frames)
        frames_to_read = [(0, n_frames)]
        print(f"Option -s used, converting {frames_to_read}")

    # Run conversion in parallel, one task per output data file. The context
    # manager makes sure the worker processes are cleaned up afterwards.
    conversion_args = [
        (fname_in, fname, first_frame, stop_frame, args.fastquad)
        for (first_frame, stop_frame), fname in zip(frames_to_read, fname_out_data)
    ]
    with mp.Pool(args.num_threads) as pool:
        pool.starmap(convert_data_file, conversion_args)

    # Try to read Mythen3 I0 (could be old .raw or new .json)
    i0 = None
    if args.with_i0:
        fname_i0 = Path(f"{raw.master.base}_I0_master_{raw.master.run_id}")
        fname_i0 = find_suffix(fname_i0)
        print("\n=== Mythen3 I0 Data ===")
        if fname_i0.exists():
            print(h5m.color.ok(f"Reading I0 from: {fname_i0}"))
            data = None
            try:
                r = h5m.RawFile(fname_i0)
                try:
                    data = r.read()
                finally:
                    r.close()
            except IOError as e:
                # Previously this fell through and crashed with NameError on
                # `data`; now the master file is written without I0 instead.
                print(h5m.color.red(f"{e}"))
                print(h5m.color.warning("No I0 data written to file"))

            if data is not None:
                bc = Path("bad_channels.txt")
                if bc.exists():
                    bad_channels = h5m.read_bad_channels(bc)
                    print(h5m.color.ok(f"Masking channels: {bad_channels}"))
                    # Zero the bad channels (columns) in every row of data
                    mask_i0 = np.zeros(data.shape[1], bool)
                    for c in bad_channels:
                        mask_i0[c] = True
                    cm = np.broadcast_to(mask_i0[np.newaxis, :], data.shape)
                    data[cm] = 0
                else:
                    print(h5m.color.warning("No bad channels file found"))

                i0 = data.sum(axis=1)
                if args.graphs:
                    h5m.i0_graph(f"{fname_out.parent/fname_out.stem}_i0", data, i0)
        else:
            print(
                h5m.color.warning(
                    f"Mythen3 I0 file: {fname_i0} not found!\nNo I0 data written to file"
                )
            )

    # Write hdf5 master file
    print("\n=== Master File ===")
    h5m.write_master_file(
        fname_out,
        fname_out_data,
        pixel_mask=pixel_mask,
        i0=i0,
        raw_master_file=raw.master,
    )
    t_stop = time.time()

    # Guard against an empty frame list (no data files) instead of IndexError
    n_frames_converted = frames_to_read[-1][1] if frames_to_read else 0
    print(f"Processing took: {t_stop-t_start:.2f}s for {n_frames_converted} frames")
