#!/usr/bin/env python
#
# See top-level LICENSE.rst file for Copyright information
#
# -*- coding: utf-8 -*-


"""
This script computes the reference n(z) and per-fiber TSNR2, used by the per-cumulative-tile QA.
"""


import os, sys
from glob import glob
import fitsio
import argparse
from astropy.table import Table, vstack
import fitsio
import numpy as np
import yaml
from importlib import resources
from desispec.tile_qa_plot import (
    get_qa_config,
    get_tracer_nz_program,
    get_zbins,
    get_zhists,
    get_qa_badmsks,
)
from desitarget.geomask import match_to
import matplotlib.pyplot as plt
from desiutil.log import get_logger


def parse(options=None):
    """
    Parse the command-line arguments.

    Args:
        options (optional, defaults to None): list of argument strings to
            parse; if None, argparse falls back to sys.argv[1:].

    Returns:
        argparse.Namespace with the `outdir` (str or None) and `plot` (bool)
        attributes.
    """
    parser = argparse.ArgumentParser(
        description="Generate the reference n(z) and TSNR2 2d-map for tile QA."
    )
    parser.add_argument(
        "--outdir",
        type=str,
        default=None,
        required=False,
        help="Path to output directory, default is desispec/data/qa",
    )
    parser.add_argument("--plot", action="store_true", help="produce control plots?")
    # AR forward `options`, so that the function can be called programmatically
    args = parser.parse_args(options)
    return args


def read_ref_data(program):
    """
    Read reference data to get the reference n(z).

    Args:
        program: program name (e.g. "DARK"); must be one of the programs
            listed for the tracers in the QA config.

    Returns:
        astropy Table: the stacked FIBERQA extensions of the cumulative
        tile-qa files of the reference tiles, with the TILEID, PRIORITY,
        FIBERASSIGN_X and FIBERASSIGN_Y columns added.

    Notes:
        Paths are built from the DESI_ROOT, DESI_SURVEYOPS (not used for
        DARK1B) and DESI_TARGET environment variables.
    """

    config = get_qa_config()
    all_programs = []
    for tracer in config["tile_qa_plot"]["tracers"]:
        all_programs += config["tile_qa_plot"]["tracers"][tracer]["program"].split(",")
    assert program in all_programs, "unexpected program={}; expected one of {}".format(
        program, all_programs
    )

    # AR DARK1B / LGE: pick three daily tiles
    if program == "DARK1B":
        specprod = "daily"
        tileids = [100803, 100808, 100826]
        lastnights = [20250506, 20250506, 20250506]
    # AR other: pick the same tiles as originally in everest
    # AR but now in loa
    else:
        specprod = "loa"
        fn = os.path.join(
            os.getenv("DESI_ROOT"),
            "spectro", "redux", specprod,
            "tiles-{}.csv".format(specprod)
        )
        t = Table.read(fn)
        sel = t["LASTNIGHT"] <= 20210709
        sel &= (t["SURVEY"] == "main") & (t["PROGRAM"] == program.lower())
        # AR restrict to the PASS=0 tiles of that program
        fn = os.path.join(
            os.getenv("DESI_SURVEYOPS"),
            "ops",
            "tiles-main.ecsv"
        )
        t2 = Table.read(fn)
        sel2 = (t2["PROGRAM"] == program) & (t2["PASS"] == 0)
        t2 = t2[sel2]
        # AR np.isin supersedes np.in1d (deprecated in numpy 2.0)
        sel &= np.isin(t["TILEID"], t2["TILEID"])
        tileids, lastnights = t["TILEID"][sel], t["LASTNIGHT"][sel]

    fns = [
        os.path.join(
            os.getenv("DESI_ROOT"),
            "spectro", "redux", specprod, "tiles", "cumulative",
            str(tileid), str(lastnight),
            "tile-qa-{}-thru{}.fits".format(tileid, lastnight)
        ) for tileid, lastnight in zip(tileids, lastnights)
    ]
    ds = []
    for fn in fns:

        tileid = fitsio.read_header(fn, "FIBERQA")["TILEID"]
        tileidpad = "{:06d}".format(tileid)

        d = Table(fitsio.read(fn, "FIBERQA"))
        d["TILEID"] = tileid

        # AR add some fiberassign keys
        fa_fn = os.path.join(
            os.getenv("DESI_TARGET"),
            "fiberassign", "tiles", "trunk", tileidpad[:3],
            "fiberassign-{}.fits.gz".format(tileidpad)
        )
        fa_d = Table(fitsio.read(fa_fn, "FIBERASSIGN"))

        # AR there s been some fiber swapping at some point (3402, 3429)...
        # AR and it can happen that a tile-qa file only has 9 or less petals..
        # AR - match by targetid, to propagate per-target info
        # AR - match by fiber, to propagate per-fiber info
        ii = match_to(fa_d["TARGETID"], d["TARGETID"])
        assert np.all(fa_d["TARGETID"][ii] == d["TARGETID"]), (
            "TARGETID mismatch between {} and {}".format(fa_fn, fn)
        )
        for key in ["PRIORITY"]:
            d[key] = fa_d[key][ii]
        ii = match_to(fa_d["FIBER"], d["FIBER"])
        assert np.all(fa_d["FIBER"][ii] == d["FIBER"]), (
            "FIBER mismatch between {} and {}".format(fa_fn, fn)
        )
        for key in ["FIBERASSIGN_X", "FIBERASSIGN_Y"]:
            d[key] = fa_d[key][ii]

        ds.append(d)
    d = vstack(ds)

    return d


def main():
    """
    Build the reference n(z) and per-fiber TSNR2 files used by the tile QA.

    Writes qa-reference-nz.ecsv and qa-reference-tsnr2.ecsv in the output
    directory (--outdir, or data/qa inside the installed desispec package
    by default); with --plot, also writes the matching .png control plots.
    """

    args = parse()
    log = get_logger()

    # AR outdir
    if args.outdir is None:
        args.outdir = resources.files("desispec").joinpath("data/qa")
    else:
        # AR create a user-requested folder upfront, rather than failing
        # AR at writing time, once all the reference data have been read
        os.makedirs(args.outdir, exist_ok=True)

    # AR output files
    outroot_nz = os.path.join(args.outdir, "qa-reference-nz")
    outroot_tsnr2 = os.path.join(args.outdir, "qa-reference-tsnr2")

    # AR config params
    config = get_qa_config()
    dchi2_min = config["tile_qa_plot"]["dchi2_min"]

    # AR n(z): tracers
    tracers = list(config["tile_qa_plot"]["tracers"].keys())

    # AR reference data (from loa + 3 daily tiles for dark1b)
    # AR pick the first program for each tracer
    programs = np.unique([get_tracer_nz_program(tracer) for tracer in tracers])
    ref_d = {program: read_ref_data(program) for program in programs}

    # AR n(z): redshift grid
    bins = get_zbins()
    nbin = len(bins) - 1

    # AR n(z): looping over tracers
    all_nz = []
    for tracer in tracers:

        # AR pick the first listed program here
        program = get_tracer_nz_program(tracer)
        d = ref_d[program].copy()

        # AR making hist for each TILEID
        tileids = np.unique(d["TILEID"])
        _, zhists = get_zhists(tileids, tracer, dchi2_min, d, fstatus_key="QAFIBERSTATUS", tileid_key="TILEID", nolya=True)

        # AR compressing stats for all TILEIDs
        nz = Table()
        nz["TRACER"] = [tracer for _ in range(nbin)]
        nz["ZMIN"], nz["ZMAX"] = bins[:-1], bins[1:]
        nz["N_MEAN"] = zhists.mean(axis=1).round(4)
        nz["N_MEAN_STD"] = zhists.std(axis=1).round(
            4
        )  # AR I think it s correct to *not* divide by sqrt(ntile)
        all_nz.append(nz)

    nz = vstack(all_nz)

    # AR n(z): writing to ecsv
    hdr = fitsio.FITSHDR()
    hdr["DCHI2MIN"] = dchi2_min
    nz.meta = dict(hdr)
    outfn = "{}.ecsv".format(outroot_nz)
    nz.write(outfn, format="ascii.ecsv", overwrite=True)
    log.info("wrote {}".format(outfn))

    # AR n(z): control plot
    if args.plot:
        fig, ax = plt.subplots()
        for tracer in tracers:
            sel = nz["TRACER"] == tracer
            x = 0.5 * (nz["ZMIN"][sel] + nz["ZMAX"][sel])
            y = nz["N_MEAN"][sel]
            ye = nz["N_MEAN_STD"][sel]
            p = ax.plot(x, y, label=tracer)
            # AR +/- 1-sigma band, same color as the curve
            ax.fill_between(x, y - ye, y + ye, color=p[0].get_color(), alpha=0.3)
        ax.legend()
        ax.grid(True)
        ax.set_xlabel("Z")
        ax.set_ylabel("Per tile fractional count")
        ax.set_xlim(-0.1, 6)
        ax.set_ylim(0, 0.25)
        outpng = "{}.png".format(outroot_nz)
        plt.savefig(outpng, bbox_inches="tight")
        plt.close()
        log.info("wrote {}".format(outpng))

    # AR TSNR2 = f(FIBERASSIGN_X, FIBERASSIGN_Y, FIBER)
    # AR NOTE(review): the TSNR2 column is hard-coded to TSNR2_LRG here,
    # AR whereas the QA config also provides a "tsnr2_key" entry; to be
    # AR confirmed that both are meant to agree
    badqa_val, _ = get_qa_badmsks()
    d = ref_d["DARK"].copy()

    # AR TSNR2: discarding fibers with any bad-QA bit set in QAFIBERSTATUS
    # AR (the PASS=0 selection of the tiles is done in read_ref_data())
    sel = (d["QAFIBERSTATUS"] & badqa_val) == 0
    d = d[sel]

    nfiber = 5000
    fiber_d = Table()
    fiber_d["FIBER"] = np.arange(nfiber, dtype=int)
    keys = ["FIBERASSIGN_X", "FIBERASSIGN_Y", "TSNR2_LRG"]
    for key in keys:
        fiber_d[key] = np.nan

    # AR TSNR2: median value per fiber
    # AR fibers with no valid measurement are left to NaN
    for i in range(nfiber):
        sel = d["FIBER"] == i
        if sel.sum() > 0:
            for key in keys:
                fiber_d[key][i] = np.nanmedian(d[key][sel]).round(2)
    outfn = "{}.ecsv".format(outroot_tsnr2)
    fiber_d.write(outfn, overwrite=True)
    log.info("wrote {}".format(outfn))

    # AR TSNR2: control plot
    if args.plot:
        fig, ax = plt.subplots()
        sc = ax.scatter(
            fiber_d["FIBERASSIGN_X"],
            fiber_d["FIBERASSIGN_Y"],
            c=fiber_d["TSNR2_LRG"],
            s=5,
        )
        ax.set_xlabel("FIBERASSIGN_X [mm]")
        ax.set_ylabel("FIBERASSIGN_Y [mm]")
        ax.grid()
        cbar = plt.colorbar(sc)
        cbar.set_label("TSNR2_LRG")
        outpng = "{}.png".format(outroot_tsnr2)
        plt.savefig(outpng, bbox_inches="tight")
        plt.close()
        log.info("wrote {}".format(outpng))


# AR script entry point: the return value of main() is used as exit status
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
