#!python


import sys,string, os
import astropy.io.fits as pyfits
import argparse
import datetime
import numpy as np
import astropy.table as t

from scipy.stats import iqr
from scipy.ndimage import binary_closing

from desispec import io
from desiutil.log import get_logger

from desispec.preproc import parse_sec_keyword, get_readout_mode, get_amp_ids
from desispec.maskbits import ccdmask
from desispec.workflow.tableio import load_table, load_tables
from desispec.calibfinder import CalibFinder
from desispec.io.raw import read_raw_primary_header

# Command-line interface of the script.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Compute a mask using dark images",
    epilog='''Input is a list of raw dark images (Possibly with different exposure times). Raw images are fully preprocessed then for sets of N darks the min absolute value is computed as a first guess of the mask. From M such guesses an or operation is performed, finally binary_closing fills in gaps and columns with more than a certain fraction of masking will be filtered''',
)

# --- selection of the input dark exposures --------------------------------
parser.add_argument('-f', '--firstnight', type=str,
                    help='first night to use for automatic discovery')
parser.add_argument('-l', '--lastnight', type=str,
                    help='last night to use for automatic discovery')
parser.add_argument('-i', '--images', type=str, nargs="*",
                    help='paths to raw dark image fits files (either this or --firstnight/--lastnight needs to be set)')
parser.add_argument('--input-specprod', type=str,
                    help='use this specific specprod to read the exposure tables (default is $SPECPROD)')
parser.add_argument('--min-exptime', type=float, default=1000.,
                    help='minimum exposure time of dark exposures')
parser.add_argument('--use-only-morning-darks', action='store_true',
                    help='only use morning darks')

# --- preprocessing and output locations -----------------------------------
parser.add_argument('--specprod-dir', type=str, required=True,
                    help='use this specific specprod dir for the preprocessed files (has to be different from main processing)')
parser.add_argument('--save-preproc', action='store_true',
                    help='save preproc file (default is false)')
parser.add_argument('-o', '--outfile', type=str, required=True,
                    help='output mask filename')
parser.add_argument('--camera', type=str, required=True,
                    help='header HDU (int or string)')
parser.add_argument('--old-mask', type=str,
                    help='Path to a previous mask file. If given, the output is the bitwise OR of the input mask and the mask created by this script.')

# --- tuning of the masking algorithm --------------------------------------
parser.add_argument('--colfrac', type=float, default=0.4,
                    help='If more than this fraction of pixels are blocked in a column, the whole column in that amplifier is blocked.')
parser.add_argument('--closeiter', type=int, default=15,
                    help='Number of iterations for the binary closing operation')
parser.add_argument('--frames-per-coadd', type=int, default=4,
                    help='number of dark frames to compute the min(abs(preprocessed dark)) from')
parser.add_argument('--coadds-per-mask', type=int, default=2,
                    help='number of masks to combine via or')
parser.add_argument('--threshold', type=float, default=0.007,
                    help='threshold for mask (in counts per s, absolute value)')
parser.add_argument('--ignore-pixflat', action='store_true',
                    help='dont set mask bit for anomalous pixflat')

args = parser.parse_args()
log = get_logger()

# A specprod must be known, either from the environment or from
# --input-specprod; the message below explains it is used to locate the
# exposure tables.
if args.input_specprod is None and "SPECPROD" not in os.environ:
    log.error("Need either the env. variable SPECPROD set or the optional argument --input-specprod SPECPROD to get the exposure tables")
    sys.exit(12)

if args.images is not None:
    # An explicit list of raw files wins over the night range.
    if args.firstnight is not None or args.lastnight is not None:
        log.warning("Will ignore--firstnight/--lastnight because input images provided with --images")
else:
    # No explicit list: discover the dark exposures from the exposure tables
    # of every night in [firstnight, lastnight].
    if args.firstnight is None or args.lastnight is None:
        missing_input_msg = "Need to supply either --firstnight and --lastnight or a list of images in --images"
        log.critical(missing_input_msg)
        raise ValueError(missing_input_msg)

    one_day = datetime.timedelta(days=1)
    first_day = datetime.datetime.strptime(args.firstnight, "%Y%m%d")
    last_day = datetime.datetime.strptime(args.lastnight, "%Y%m%d")

    # One YYYYMMDD string per night, last night included.
    all_days = np.arange(first_day, last_day + one_day, one_day, dtype='datetime64[D]')
    night_strings = np.char.replace(np.datetime_as_string(all_days), '-', '')

    exptab_names = [io.findfile('exptable', night_string, specprod=args.input_specprod)
                    for night_string in night_strings]
    if not exptab_names:
        log.error("Empty list of exposure tables")
        sys.exit(12)

    exptab = t.vstack(load_tables(exptab_names))

    # Keep the dark exposures, optionally restricted to the morning darks.
    is_selected = exptab['OBSTYPE'] == 'dark'
    if args.use_only_morning_darks:
        is_selected &= exptab['PROGRAM'] == 'morning darks'
    exptab = exptab[is_selected]

    args.images = [io.findfile('raw', night=n, expid=e)
                   for n, e in zip(exptab['NIGHT'], exptab['EXPID'])]

# Refuse to start with fewer candidate files than one full set of coadds.
number_of_images_needed = args.frames_per_coadd * args.coadds_per_mask
if number_of_images_needed > len(args.images):
    log.error(f"Found {len(args.images)} images but need at least {args.frames_per_coadd}x{args.coadds_per_mask}={number_of_images_needed}")
    sys.exit(12)
# All candidate files are kept: the list is not truncated to
# number_of_images_needed.

# Sort the file list in place (raw_filenames and args.images are the same list).
args.images.sort()
raw_filenames = args.images

print("Images to read:")
for filename in raw_filenames:
    print(filename)


log.info("read images ...")

# State accumulated over all accepted dark exposures.
shape = None             # shape of the most recently accepted image
readout_mode = None      # readout mode, must be the same for all exposures
cam_header = None        # header of the last camera extension that was found
images = []              # accepted images, each divided by its exposure time
nights = []              # NIGHT of each accepted exposure
expids = []              # EXPID of each accepted exposure
filenames = []           # preproc file name associated with each accepted image
pixflat_filenames = []   # pixflat files reported by the calibration finder

# Hardware/readout keywords that must agree between exposures
# (CCDTEMP only within 5 units); filled from the first accepted exposure.
hardware_and_readout = dict()

# The extension name only depends on the requested camera.
extname = str(args.camera).upper()

for raw_filename in raw_filenames:

    log.info(f"Reading file {raw_filename}")

    # Missing files are skipped, not fatal.
    if not os.path.exists(raw_filename):
        log.warning(f"{raw_filename} not found, continuing")
        continue

    primary_header = read_raw_primary_header(raw_filename)
    exptime = primary_header["EXPTIME"]

    # Too short exposures are skipped.
    if exptime < args.min_exptime:
        expid = primary_header["EXPID"]
        log.warning(f"Discard dark exposure {expid} because EXPTIME={exptime} <{args.min_exptime}s")
        continue

    # Exposures without the requested camera are skipped.
    with pyfits.open(raw_filename) as fitsfile:
        try:
            cam_header = fitsfile[extname].header
        except KeyError:
            log.warning(f"No {args.camera} in {raw_filename}")
            continue

    # All exposures must share one readout mode (2-amp / 4-amp detector modes).
    current_readout_mode = get_readout_mode(cam_header)
    if readout_mode is None:
        readout_mode = current_readout_mode
    elif current_readout_mode != readout_mode:
        log.critical("A set of spectra with different readout modes was submitted")
        raise TypeError("A set of spectra with different readout modes was submitted")

    night = primary_header["NIGHT"]
    expid = primary_header["EXPID"]

    # Remember the pixflat of this exposure, if the calibration finder has one.
    cfinder = CalibFinder([primary_header, cam_header])
    if cfinder.haskey("PIXFLAT"):
        pixflat_filenames.append(cfinder.findfile("PIXFLAT"))

    # Hardware / readout configuration must match the first accepted exposure.
    # Explicit exceptions are used instead of assert so that the checks survive
    # `python -O` and report which keyword and file are at fault.
    for k in ["DETECTOR", "CCDCFG", "CCDTMING", "CCDTEMP"]:
        if k not in hardware_and_readout:
            hardware_and_readout[k] = cam_header[k]
            continue
        reference_value = hardware_and_readout[k]
        if k == "CCDTEMP":
            is_consistent = np.abs(reference_value - cam_header[k]) < 5.
        else:
            is_consistent = reference_value == cam_header[k]
        if not is_consistent:
            msg = f"Inconsistent {k} in {raw_filename}: {cam_header[k]} but previous exposures have {reference_value}"
            log.critical(msg)
            raise ValueError(msg)

    preproc_filename = io.findfile('preproc', night=night, expid=expid, camera=args.camera, specprod_dir=args.specprod_dir)

    if os.path.isfile(preproc_filename):
        # Reuse an existing preprocessed image.
        log.info(f"Read existing file {preproc_filename}")
        with pyfits.open(preproc_filename) as fitsfile:
            img = fitsfile['IMAGE'].data
    else:
        # Preprocess the raw frame with the options listed below.
        log.info(f"Preprocess {preproc_filename}")
        image = io.read_raw(raw_filename, args.camera,
                            bias=True,
                            nogain=False,
                            nocosmic=True,
                            mask=False,
                            dark=True,
                            pixflat=False,
                            nocrosstalk=True,
                            ccd_calibration_filename=None)
        img = image.pix
        if args.save_preproc:
            io.write_image(preproc_filename, image)
            log.info(f"wrote {preproc_filename}")

    shape = img.shape
    # Report the file being processed (the original logged a stale variable
    # left over from the listing loop above).
    log.info(f"adding dark {raw_filename} divided by exposure time {exptime} s")
    expids.append(expid)
    nights.append(night)
    filenames.append(preproc_filename)
    images.append(img / exptime)

# Sort all per-exposure arrays consistently by EXPID.
ii = np.argsort(expids)
expids = np.array(expids)[ii]
nights = np.array(nights)[ii]
filenames = np.array(filenames)[ii]
images = np.array(images)[ii]

# Split the sorted frames into consecutive groups of --frames-per-coadd
# frames; each group yields one first-guess mask. The last group holds the
# remainder and may be shorter. A non-positive --frames-per-coadd gives one
# frame per group.
#
# NOTE(review): the previous version of this loop also had a "switched to a
# different night" condition, but it compared against a variable that was
# never updated, so it never fired. Groups therefore can (and always could)
# span several nights. If per-night grouping is wanted it has to be added on
# purpose, keeping in mind that it produces shorter groups and that the
# minimum of |dark| over fewer frames flags more pixels.
frames_per_group = max(1, args.frames_per_coadd)
images_list_of_lists = [
    images[first:first + frames_per_group]
    for first in range(0, len(images), frames_per_group)
]

log.info("compute mask ...")

# Without any group of images there is nothing to build a mask from; stop with
# a clear message instead of failing later on a None mask.
if len(images_list_of_lists) == 0:
    log.critical("No usable dark image left after selection, cannot compute a mask")
    sys.exit(12)

# For each group of dark frames, a pixel is flagged BAD when even the smallest
# absolute value over the group exceeds the threshold. The flags of all groups
# are combined with a bitwise OR.
mask = None
for image_group in images_list_of_lists:
    tmp_image = np.min(np.abs(image_group), axis=0)
    if mask is None:
        mask = np.zeros(tmp_image.shape, dtype=np.int32)
    above_threshold = tmp_image > args.threshold
    mask[above_threshold] |= ccdmask.BAD

mask[0, :] |= mask[1, :]    # first row inherits the mask from the second
mask[-1, :] |= mask[-2, :]  # last row inherits the mask from the next to last

log.info("close gaps via binary closing ...")

# The closing is run on a copy of the mask surrounded by a frame of ones, so
# that it also works along the image edges. The frame is kept wider than the
# number of closing iterations: scipy treats everything outside the array as 0
# by default (border_value=0), and the erosion coming in from that outer
# border must not reach the real image. For the default --closeiter this is
# the historical margin of 50 pixels.
margin = max(50, args.closeiter + 2)
larger_mask = np.ones((mask.shape[0] + 2 * margin, mask.shape[1] + 2 * margin), dtype=np.int32)
larger_mask[margin:-margin, margin:-margin] = mask
closed_mask = binary_closing(larger_mask, iterations=args.closeiter, structure=np.ones([2, 2]).astype(np.int32))

# closed_mask is boolean; flag the closed pixels with the BAD bit explicitly
# instead of OR-ing the boolean array in, which only works if BAD equals 1.
mask[closed_mask[margin:-margin, margin:-margin]] |= ccdmask.BAD

# Column-level masking: when the fraction of flagged pixels in a column
# reaches --colfrac, the whole column (within the amplifier region it belongs
# to) is flagged BAD.
log.info("check which columns have a large fraction masked")
print(readout_mode)
amp_ids = get_amp_ids(cam_header)  # NOTE(review): result is currently unused
bad_pix = (mask > 0)
half = bad_pix.shape[0] // 2
if readout_mode == "4Amp" or readout_mode == "2AmpUpDown":
    # Columns are evaluated separately in the row ranges [0, half) and
    # [half, nrows). np.mean divides by the actual number of rows of each
    # range, which also gives the right fraction for an odd number of rows.
    bad_cols_first_half = np.mean(bad_pix[:half, :], axis=0) >= args.colfrac
    bad_cols_second_half = np.mean(bad_pix[half:, :], axis=0) >= args.colfrac
    mask[:half, bad_cols_first_half] |= ccdmask.BAD
    mask[half:, bad_cols_second_half] |= ccdmask.BAD
elif readout_mode == "2AmpLeftRight":
    # Columns are evaluated over the full height of the image.
    bad_cols_all = np.mean(bad_pix, axis=0) >= args.colfrac
    mask[:, bad_cols_all] |= ccdmask.BAD
else:
    msg = f'Unknown {readout_mode=}'
    log.critical(msg)
    raise ValueError(msg)

# Unless disabled, flag pixels with an anomalous pixel flat value.
if not args.ignore_pixflat:
    log.info("flat field threshold")
    # Each distinct pixflat file is applied once; with no pixflat file at all
    # np.unique() returns an empty array and the loop does nothing.
    for pixflat_filename in np.unique(pixflat_filenames):
        print(f"Reading {pixflat_filename}")
        with pyfits.open(pixflat_filename) as pixflat_hdus:
            flat = pixflat_hdus[0].data
            too_low = flat < 0.5
            too_high = flat > 1.15
            # low response gets its own bit, high response counts as a bad pixel
            mask[too_low] |= ccdmask.PIXFLATLOW
            mask[too_high] |= ccdmask.BAD

log.info("incorporate previous mask")
# When --old-mask is given, OR the data of its primary HDU into the new mask.
previous_mask_path = args.old_mask
if previous_mask_path is not None:
    with pyfits.open(previous_mask_path) as previous_hdus:
        mask |= previous_hdus[0].data
        log.info("Taken bitwise OR of input mask and the generated mask")


log.info(f"writing {args.outfile}")
# Percentage of pixels with at least one mask bit set.
# mask.size replaces np.product(mask.shape): np.product no longer exists in
# NumPy 2.0.
mask_percent = np.sum(mask > 0) * 100 / mask.size

# Write the mask as a 16-bit integer primary HDU, together with the
# hardware/readout keywords of the input exposures and the parameters used.
hdu = pyfits.PrimaryHDU(mask.astype("int16"))
for k, value in hardware_and_readout.items():
    hdu.header[k] = value
hdu.header["NFRAMES"] = (args.frames_per_coadd, 'number of frames used for minimum absolute')
hdu.header["NMASKS"] = (args.coadds_per_mask, 'number of masks combined with or')
hdu.header["COLFRAC"] = (args.colfrac, 'Bad pixel fraction for blocked column in amp')
hdu.header["CAMERA"] = (args.camera, 'header HDU (int or string)')
# NOTE(review): CLOSEITER has 9 characters, one more than a standard FITS
# keyword; confirm how readers of this file expect it to be stored.
hdu.header["CLOSEITER"] = (args.closeiter, 'binary cloing iterations')
hdu.header["MASKFRAC"] = (mask_percent, 'Percent of pixels masked')
if args.old_mask is not None:
    hdu.header["OLDMASK"] = (args.old_mask, 'Path to input mask')
# Provenance: nights, exposure ids and preproc file names of the inputs.
hdu.header["NIGHTS"] = (" ".join(np.unique(nights.astype(str)).flatten()), 'NIGHTS included')
hdu.header["EXPIDS"] = (" ".join(expids.astype(str).flatten()), 'EXPIDs included')
for i, filename in enumerate(filenames):
    hdu.header[f"INPUT{i:03d}"] = os.path.basename(filename)
hdu.header["EXTNAME"] = "MASK"
hdu.header["BUNIT"] = ""

hdu.writeto(args.outfile, overwrite=True)
log.info("Saved mask")
