#!/usr/bin/env python

"""
Plot preproc image thumbnail(s) centered on a specified fiber and wavelength
"""

import os, sys
import argparse

import numpy as np
import matplotlib.pyplot as plt
import fitsio

from astropy.table import Table
from astropy.visualization import ZScaleInterval

from desiutil.log import get_logger
from desispec.io import findfile, read_xytraceset

def get_cameras(fiber, wavelength):
    """
    Return the cameras whose wavelength range includes `wavelength`,
    restricted to the spectrograph that `fiber` belongs to

    Args:
        fiber (int): fiber number 0-4999; 500 fibers per spectrograph
        wavelength (float): wavelength in Angstroms

    Returns (list of str): camera names in b, r, z order, e.g. ['b3', 'r3'].
        Ranges are inclusive at both ends and overlap slightly, so a
        wavelength can map to two cameras; it may also map to none.
    """
    # (band, min wavelength, max wavelength), inclusive
    coverage = (
        ('b', 3600, 5800),
        ('r', 5760, 7620),
        ('z', 7520, 9824),
    )
    petal = fiber // 500
    return [f'{band}{petal}' for band, lo, hi in coverage
            if lo <= wavelength <= hi]

def _centered_cutout(data, x, y, half, fill):
    """
    Return the (2*half+1)x(2*half+1) cutout of 2D `data` centered on
    column x, row y; pixels outside `data` are set to `fill`
    """
    n = 2*half + 1
    out = np.full((n, n), fill, dtype=data.dtype)
    ny, nx = data.shape
    y0, x0 = y - half, x - half

    #- overlap of the requested window with the image; clipping here avoids
    #- negative slice starts, which numpy would interpret as wrap-around
    ylo, yhi = max(y0, 0), min(y0 + n, ny)
    xlo, xhi = max(x0, 0), min(x0 + n, nx)
    if ylo < yhi and xlo < xhi:
        out[ylo-y0:yhi-y0, xlo-x0:xhi-x0] = data[ylo:yhi, xlo:xhi]

    return out

def read_thumbnail(night, expid, camera, fiber, wavelength, size=51):
    """
    Read thumbnail image for night,expid,camera centered on fiber,wavelength

    The trace position comes from the per-exposure psf if it exists,
    otherwise from the nightly psfnight.

    Args:
        night (int): YEARMMDD night
        expid (int): exposure ID
        camera (str): spectrograph camera, e.g. b0, r1, z9
        fiber (int): fiber number 0-4999
        wavelength (float): wavelength in Angstrom

    Optional:
        size (int): size of thumbnail in pixels

    Returns: (thumbnail_image, thumbnail_mask), each a square 2D array with
        side 2*(size//2)+1 (i.e. `size` when it is odd), centered on the
        trace pixel.  Parts of the thumbnail that fall off the CCD are
        filled with 0 in the image and a nonzero value in the mask.
    """
    psffile = findfile('psf', night, expid, camera, readonly=True)

    # fall back to psfnight if per-exposure psf isn't available
    if not os.path.exists(psffile):
        psffile = findfile('psfnight', night=night, camera=camera, readonly=True)

    xy = read_xytraceset(psffile)

    #- round rather than truncate to pick the pixel nearest the trace
    #- NOTE(review): assumes integer coordinates are pixel centers -- confirm
    x = int(round(float(xy.x_vs_wave(fiber%500, wavelength))))
    y = int(round(float(xy.y_vs_wave(fiber%500, wavelength))))

    preprocfile = findfile('preproc', night, expid, camera, readonly=True)
    with fitsio.FITS(preprocfile) as fx:
        img = fx['IMAGE'].read()
        mask = fx['MASK'].read()

    #- pad instead of slicing directly so thumbnails near the CCD edges keep
    #- their shape and the trace stays at the center (where the crosshair goes)
    half = size//2
    thumb_img = _centered_cutout(img, x, y, half, fill=0)
    thumb_mask = _centered_cutout(mask, x, y, half, fill=1)  # any nonzero = masked
    return thumb_img, thumb_mask

def plot_thumbnail(img, mask):
    """
    Plot a thumbnail image on the current axes, with every pixel whose mask
    value is nonzero shaded in semi-transparent red

    Args:
        img (2D array): thumbnail image, displayed with zscale limits
        mask (2D array): mask matching img; nonzero means masked
    """
    lo, hi = ZScaleInterval().get_limits(img)
    plt.imshow(img, vmin=lo, vmax=hi)

    #- RGBA layer: fully transparent except red at alpha 0.5 where masked
    masked = (mask != 0)
    rgba = np.zeros(mask.shape + (4,), dtype=float)
    rgba[masked, 0] = 1.0
    rgba[masked, 3] = 0.5
    plt.imshow(rgba)

def plot_crosshair(x, y):
    """
    Draw a magenta crosshair around (x, y) on the current axes.

    The crosshair is four short arms running from 2 to 5 units out from the
    center, which leaves the center pixel itself unobscured.
    """
    inner, outer = 2, 5
    arms = (
        ([x - outer, x - inner], [y, y]),   # left
        ([x + inner, x + outer], [y, y]),   # right
        ([x, x], [y - outer, y - inner]),   # below
        ([x, x], [y + inner, y + outer]),   # above
    )
    for xs, ys in arms:
        plt.plot(xs, ys, 'm-')

def main():
    """
    Command-line entry point: find the exposures for --tileid or --expid,
    read a thumbnail for every exposure and camera covering --fiber at
    --wavelength, and plot them side by side

    Returns:
        int: process exit status, 0 on success and 1 on failure
    """
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('-f', '--fiber', type=int, required=True, help="fiber number")
    p.add_argument('-w', '--wavelength', type=float, required=True, help="wavelength in Angstrom")

    group = p.add_mutually_exclusive_group(required=True)
    group.add_argument('-t', '--tileid', type=int, help="Tile ID")
    group.add_argument('-e', '--expid', type=int, help="Exposure ID")

    p.add_argument('-n', '--night', type=int, help="YEARMMDD night; only required with --expid for non-science exposures")
    p.add_argument('--size', type=int, default=51, help="Size of thumbnail to plot")
    p.add_argument('-s', '--specprod', help="Override $SPECPROD")
    p.add_argument('-o', '--outfile', help="Save figure to OUTFILE")
    p.add_argument('--debug', action="store_true", help="Start an IPython shell after plotting")
    args = p.parse_args()

    log = get_logger()

    #- to avoid having to pass specprod everywhere
    if args.specprod is not None:
        os.environ['SPECPROD'] = args.specprod

    #- find which exposures observed this tile
    expfile = findfile('exposures', readonly=True)
    exps = Table.read(expfile)
    if args.tileid is not None:
        keep = (exps['TILEID'] == args.tileid)
    else:
        keep = (exps['EXPID'] == args.expid)

    exps = exps[keep]

    if len(exps) == 0:
        if args.tileid is not None:
            log.critical(f'no exposures found for tile {args.tileid}')
            return 1
        elif args.night is not None:
            #- exposure not in the exposures table (e.g. non-science);
            #- trust the user-supplied night, with a placeholder tile ID
            exps = Table(dict(TILEID=[-99,], NIGHT=[args.night,], EXPID=[args.expid,]))
        else:
            log.critical(f'expid {args.expid} not found')
            return 1

    if args.tileid is not None:
        tileid = args.tileid
    else:
        tileid = exps['TILEID'][0]  # found via --expid instead of --tileid

    #- a wavelength outside every camera's range would otherwise be
    #- reported misleadingly as "no images found"
    cameras = get_cameras(args.fiber, args.wavelength)
    if len(cameras) == 0:
        log.critical(f'no camera covers wavelength {args.wavelength}')
        return 1

    #- Read the individual thumbnails for each NIGHT,EXPID,CAMERA
    results = list()
    for night, expid in exps['NIGHT', 'EXPID']:
        for camera in cameras:
            try:
                img, mask = read_thumbnail(night, expid, camera, args.fiber, args.wavelength, size=args.size)
            except OSError as err:
                #- e.g. missing psf/preproc file for this exposure/camera;
                #- skip it rather than abort the remaining thumbnails
                log.warning(f'skipping {night}/{expid:08d}/{camera}: {err}')
                continue
            results.append( (img, mask, night, expid, camera) )

    #- Did we find any images?
    n = len(results)
    if n == 0:
        log.critical('no images found')
        return 1

    #- Plot thumbnail(s)
    plt.figure(figsize=(3*n, 3))

    for i, (img, mask, night, expid, camera) in enumerate(results):
        plt.subplot(1, n, i+1)
        plot_thumbnail(img, mask)
        plot_crosshair(args.size//2, args.size//2)
        plt.title(f'{night}/{expid:08d}/{camera}')
        plt.axis('off')

    #- Longer super-title if room to fit
    if n>1:
        plt.suptitle(f'Tile {tileid} Fiber {args.fiber} @ {args.wavelength:.1f}A')
    else:
        plt.suptitle(f'{tileid}/{args.fiber} @ {args.wavelength:.1f}A')

    plt.tight_layout()
    if args.outfile:
        plt.savefig(args.outfile)
        log.info(f'Wrote {args.outfile}')
    else:
        plt.show()

    if args.debug:
        import IPython; IPython.embed()

    return 0


if __name__ == '__main__':
    #- propagate main()'s return value as the process exit status
    status = main()
    sys.exit(status)
