#!python

"""
Compute "nonlinear dark" model, e.g.

desi_compute_dark_nonlinear --days 20200729 20200730 --camera b0 \
        --darkfile dark-20200729-b0.fits.gz \
        --biasfile bias-20200729-b0.fits.gz
"""

import astropy.io.fits as pyfits

import argparse


def previous_night_or_day(night_or_day) :
    """Return the YEARMMDD integer of the calendar date preceding `night_or_day`.

    Args:
        night_or_day: date as a YEARMMDD int or string; only the first
            eight characters of its string form are used.

    Returns:
        int YEARMMDD of the date one day earlier.

    Note: uses the module-level `datetime` import, so it can only be
    called once that import has been executed.
    """
    yyyymmdd = str(night_or_day)
    current = datetime.date(int(yyyymmdd[0:4]), int(yyyymmdd[4:6]), int(yyyymmdd[6:8]))
    previous = current - datetime.timedelta(days=1)
    return int(previous.strftime('%Y%m%d'))



#- Command line interface.  ArgumentDefaultsHelpFormatter appends the default
#- value of each option to its help text.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Compute a non-linear dark model",
    epilog='''
    Combines a set of ZEROs and DARKs at different exposure times to build

    model(x,y,t) = bias(x,y) + dark(x,y)*t + nonlinear(y,t)

    i.e. the non-linear term is only a function of row (y)

    Data can be grouped by calendar --days (rollover at midnight) or
    --nights (rollover at noon)
''')

#- Input selection: nargs="*" means each of these is None when the option is
#- absent and a (possibly empty) list of ints when it is given
parser.add_argument('--days', type=int, nargs="*",
                    help='YEARMMDD days to use for ZEROs and DARKs')
parser.add_argument('--nights', type=int, nargs="*",
                    help='YEARMMDD nights to use for ZEROs and DARKs')

parser.add_argument('--bias-per-day',action="store_true", help="compute a bias per day; default is per night")

parser.add_argument('--first-expid', type=int, required=False,
                    help='First EXPID to include')
parser.add_argument('--last-expid', type=int, required=False,
                    help='Last EXPID to include')

#- Outputs and camera
parser.add_argument('--darkfile', type=str, required=True,
                    help='output dark model file')
parser.add_argument('--biasfile', type=str, required=True,
                    help='output bias model file')
parser.add_argument('--camera', type=str, required=True,
                    help='Camera to process (e.g. b0, r1, z9)')
parser.add_argument('-t','--tempdir', type=str, required=False,
                    help='directory for intermediate files')

#- Model and data-quality tuning
parser.add_argument('--linexptime', type=float, default=300.0, required=False,
                    help='Model dark current as linear above this exptime')
parser.add_argument('--nskip-zeros', type=int, default=5, required=False,
                    help='Skip N ZEROs per day while flushing charge')
parser.add_argument('--mindarks', type=int, default=5, required=False,
                    help='Minimum number of DARKs per EXPTIME to use that EXPTIME')
parser.add_argument('--min-vccdsec', type=float, default=21600., required=False,
                    help='Minimum VCCDSEC (seconds since VCCD on) to use dark, (default: 21600=6h)')
parser.add_argument('--temp-tolerance', type=float, default=1, required=False,
                    help='Do not create model if maximum CCD temperature difference is above this threshold, (default: 1K)')

#- Parsed at module level, before the heavier imports, for faster --help
args = parser.parse_args()

#- Import after parsing for faster --help
import os
import sys
import glob
import datetime

from astropy.table import Table
from astropy.time import Time
import numpy as np
import fitsio

from desiutil.log import get_logger
import desispec.io.util
from desispec.ccdcalib import compute_dark_file, compute_bias_file
from desispec.ccdcalib import fit_const_plus_dark,fit_dark
from desispec.ccdcalib import model_y1d
from desispec.io import findfile

log = get_logger()

#- nargs="*" allows "--days" or "--nights" to be given without any value,
#- which yields an empty list; treat that the same as not giving the option
if not args.days and not args.nights:
    log.critical('Must specify --days or --nights')
    sys.exit(1)

#- tempdir caches files that could be re-used when rerunning for debugging
#- i.e. it can be cleaned up when done, but isn't completely transient.
#- Default location is a "temp" directory next to the output dark file.
if args.tempdir is None:
    outdir = os.path.dirname(os.path.abspath(args.darkfile))
    tempdir = os.path.join(outdir, 'temp')
else:
    tempdir = args.tempdir

log.debug(f'Writing temporary files to {tempdir}')
#- exist_ok avoids failing if another process creates the directory
#- between an existence check and the creation
os.makedirs(tempdir, exist_ok=True)

#- Exposure log: generated once and cached in tempdir.  An existing
#- speclog.csv is re-used as-is, whatever --days/--nights are given now.
speclog_file = os.path.join(tempdir, 'speclog.csv')
if not os.path.exists(speclog_file):

    #- args.days / args.nights are None when the option was not given
    requested_days = args.days if args.days is not None else []
    requested_nights = args.nights if args.nights is not None else []

    #- each requested day is looked up under its own night and the previous one
    candidate_nights = set(requested_nights)
    for day in requested_days:
        candidate_nights.add(day)
        candidate_nights.add(previous_night_or_day(day))
    nights_for_speclog = sorted(candidate_nights)

    log.info(f'Generating speclog for nights {nights_for_speclog}')
    speclog = desispec.io.util.get_speclog(nights_for_speclog)

    #- Add "DAY" column = rolls over at midnight instead of MST noon
    local_time = Time(speclog['MJD']-7/24, format='mjd')
    speclog['DAY'] = local_time.strftime('%Y%m%d').astype(int)

    #- Trim to just the requested days/nights
    requested = np.zeros(len(speclog), dtype=bool)
    if len(requested_nights) > 0:
        requested |= np.isin(speclog["NIGHT"], requested_nights)
    if len(requested_days) > 0:
        requested |= np.isin(speclog["DAY"], requested_days)
    speclog = speclog[requested]

    #- write to a per-process temporary name first, then rename into place
    tmpfile = speclog_file + '.tmp-' + str(os.getpid())
    speclog.write(tmpfile, format='ascii.csv')
    os.rename(tmpfile, speclog_file)
    log.info(f'Wrote speclog to {speclog_file}')
else:
    log.info(f'Reading speclog from {speclog_file}')
    #- fill_values=None is needed to avoid masking
    speclog = Table.read(speclog_file, format='ascii',fill_values=None)

#- Optional EXPID range cuts; stop as soon as nothing is left
if args.first_expid is not None:
    speclog = speclog[speclog['EXPID'] >= args.first_expid]
    if len(speclog) == 0:
        log.critical('No exposures! exiting...')
        sys.exit(1)

if args.last_expid is not None:
    speclog = speclog[speclog['EXPID'] <= args.last_expid]
    if len(speclog) == 0:
        log.critical('No exposures! exiting...')
        sys.exit(1)

#- Camera-level cuts, applied only if the speclog has the bookkeeping columns:
#- the camera must be in CAMWORD minus BADCAMWORD, must not appear in the
#- BADAMPS entry, and LASTSTEP must be 'all'
colnames = speclog.colnames
if all(name in colnames for name in ('CAMWORD', 'BADCAMWORD', 'BADAMPS')):
    camera_ok = list()
    for entry in speclog:
        goodcams = desispec.io.util.decode_camword(
            desispec.io.util.difference_camwords(entry['CAMWORD'], entry['BADCAMWORD']))
        camera_ok.append(args.camera in goodcams and args.camera not in entry['BADAMPS'])
    keep = np.array(camera_ok)
    keep &= speclog['LASTSTEP']=='all'
    speclog = speclog[keep]
    if len(speclog) == 0:
        log.critical('No exposures! exiting...')
        sys.exit(1)

#- keep only ZERO or DARK (upper or lower case OBSTYPE)
obstype = speclog['OBSTYPE']
keep = (obstype == 'ZERO') | (obstype == 'DARK') | (obstype == 'zero') | (obstype == 'dark')
speclog = speclog[keep]
if len(speclog) == 0:
    log.critical('No exposures! exiting...')
    sys.exit(1)

#- Exposures are grouped by EXPTIME truncated to an integer
speclog['EXPTIME_INT'] = speclog['EXPTIME'].astype(int)


#- Drop every EXPTIME_INT group with fewer than --mindarks exposures.
#- The count is over all rows still in the speclog with that EXPTIME_INT.
well_sampled = np.zeros(len(speclog), dtype=bool)
for exptime in np.unique(speclog['EXPTIME_INT']):
    in_group = (speclog['EXPTIME_INT'] == exptime)
    ndarks = np.count_nonzero(in_group)
    if ndarks < args.mindarks:
        log.warning(f'Only {ndarks}<{args.mindarks} DARKs for EXPTIME {exptime}; dropping')
        continue
    log.info(f'Using {ndarks} exposures with EXPTIME {exptime}')
    well_sampled[in_group] = True

speclog = speclog[well_sampled]

#- Boolean masks for the two exposure types (either letter case)
obstype = speclog['OBSTYPE']
isZero = (obstype == 'ZERO') | (obstype == 'zero')
isDark = (obstype == 'DARK') | (obstype == 'dark')

#- Biases are grouped by NIGHT unless --bias-per-day asks for DAY
daynight_key = "DAY" if args.bias_per_day else "NIGHT"

#- Log a summary per day/night and keep those having both ZEROs and DARKs
selected_daynights = []
for daynight_val in np.unique(speclog[daynight_key]) :
    on_this_daynight = speclog[daynight_key] == daynight_val
    nzeros = np.count_nonzero(on_this_daynight & isZero)
    ndarks = np.count_nonzero(on_this_daynight & isDark)
    darktimes = sorted(set(speclog['EXPTIME_INT'][on_this_daynight & isDark]))
    log.info(f'{daynight_key}={daynight_val} has {nzeros} ZEROs and {ndarks} DARKs with exptimes {darktimes}')
    if nzeros > 0 and ndarks > 0:
        selected_daynights.append(daynight_val)
    else :
        log.warning(f"there are no good observations left for {daynight_key}={daynight_val}, skipping it")

if not selected_daynights:
    log.critical('Nothing left to process')
    sys.exit(1)

#- Combine the ZEROs into one bias file per day (or per night, see daynight_key)
all_zerofiles = list()
#- running CCD temperature range over all exposures accepted so far
min_temp=max_temp=None
for daynight_val in selected_daynights :
    zerofiles = list()
    ii = isZero & (speclog[daynight_key] == daynight_val)
    nzeros = np.count_nonzero(ii)
    #- the first --nskip-zeros ZEROs of each group are not used
    nzeros_good = nzeros - args.nskip_zeros
    if nzeros_good < 5:
        log.critical(f'{nzeros} ZEROS on {daynight_key}={daynight_val} is insufficient when skipping {args.nskip_zeros}')
        continue

    elif nzeros_good < 20:
        log.warning(f'Only {nzeros_good} good ZEROs on {daynight_key}={daynight_val}')
    else:
        log.info(f'Using {nzeros_good} ZEROs on {daynight_key}={daynight_val}')

    for row in speclog[ii][args.nskip_zeros:]:
        rawfile = findfile('raw', row['NIGHT'], row['EXPID'])
        # check cam exists for this file
        with pyfits.open(rawfile) as hdulist :
            if args.camera in hdulist :
                #- reject exposures taken too soon after VCCD was switched on
                if 'VCCDSEC' in hdulist[args.camera].header:
                    if hdulist[args.camera].header['VCCDSEC']<args.min_vccdsec:
                        log.info(f"{args.camera} in {rawfile} had low VCCDSEC, skipping")
                        continue
                else:
                    log.warning(f"no VCCDSEC header on {rawfile}, {args.camera}")
                #- abort the whole run if the CCD temperature range gets too wide
                if 'CCDTEMP' in hdulist[args.camera].header:
                    if min_temp is None:
                        min_temp=max_temp=hdulist[args.camera].header['CCDTEMP']
                    else:
                        min_temp=min([min_temp,hdulist[args.camera].header['CCDTEMP']])
                        max_temp=max([max_temp,hdulist[args.camera].header['CCDTEMP']])
                    if max_temp-min_temp > args.temp_tolerance:
                        log.critical(f"{args.camera} in {rawfile} caused high temperature shift of {max_temp-min_temp:.3f}K")
                        sys.exit(3)

                else:
                    log.warning(f"no CCDTEMP header on {rawfile}, {args.camera}")
                zerofiles.append(rawfile)
                all_zerofiles.append(rawfile)
    if len(zerofiles)<5:
        #- report the group being processed (no per-night variable exists here)
        log.warning(f'{len(zerofiles)} ZEROS on {daynight_key}={daynight_val} is insufficient, reason are VCCD checks with minimum time {args.min_vccdsec}')
        continue
    biasfile = f'{tempdir}/bias-{daynight_key}-{daynight_val}-{args.camera}.fits'
    #- bias files already in tempdir are re-used rather than recomputed
    if os.path.exists(biasfile):
        log.info(f'{biasfile} already exists')
    else:
        log.info(f'Generating {biasfile}')
        compute_bias_file(zerofiles, biasfile, args.camera)

#- Combine all accepted ZEROs into a default BIAS file, unless the output
#- already exists
if os.path.exists(args.biasfile):
    log.info(f'{args.biasfile} already exists')
elif len(all_zerofiles) == 0:
    #- nothing passed the selection; stop here with a clear message instead
    #- of handing an empty file list to compute_bias_file
    log.critical(f'No usable ZEROs to generate {args.biasfile}')
    sys.exit(1)
else:
    log.info(f'Generating {args.biasfile}')
    compute_bias_file(all_zerofiles, args.biasfile, args.camera)

#- Build one combined dark per integer EXPTIME, cached in tempdir
darktimes = np.array(sorted(set(speclog['EXPTIME_INT'][isDark])))

for exptime in darktimes:
    darkfile = f'{tempdir}/dark-{args.camera}-{exptime}.fits'
    if os.path.exists(darkfile):
        log.info(f'{darkfile} already exists')
        continue
    log.info(f'Generating {darkfile}')

    #- raw DARKs and the matching day/night bias, kept in lockstep
    rawfiles = list()
    biasfiles = list()
    for row in speclog[isDark & (speclog['EXPTIME_INT'] == exptime)]:
        daynight_val = row[daynight_key]
        filename = findfile('raw', row['NIGHT'], row['EXPID'], args.camera)
        with pyfits.open(filename) as hdulist :
            #- exposures without this camera are ignored
            if args.camera not in hdulist :
                continue
            camhdr = hdulist[args.camera].header

            #- reject exposures taken too soon after VCCD was switched on
            if 'VCCDSEC' not in camhdr:
                log.warning(f"no VCCDSEC header on {filename}, {args.camera}")
            elif camhdr['VCCDSEC']<args.min_vccdsec:
                log.info(f"{args.camera} in {filename} had low VCCDSEC, skipping")
                continue

            #- update the running temperature range (shared with the ZEROs)
            #- and abort the run if it exceeds the tolerance
            if 'CCDTEMP' not in camhdr:
                log.warning(f"no CCDTEMP header on {filename}, {args.camera}")
            else:
                ccdtemp = camhdr['CCDTEMP']
                if min_temp is None:
                    min_temp = max_temp = ccdtemp
                else:
                    min_temp = min(min_temp, ccdtemp)
                    max_temp = max(max_temp, ccdtemp)
                if max_temp-min_temp > args.temp_tolerance:
                    log.critical(f"{args.camera} in {filename} caused high temperature shift of {max_temp-min_temp:.3f}K")
                    sys.exit(3)

            #- a DARK is only usable if its day/night bias file exists
            biasfile=f'{tempdir}/bias-{daynight_key}-{daynight_val}-{args.camera}.fits'
            if os.path.isfile(biasfile):
                rawfiles.append(filename)
                biasfiles.append(biasfile)

    if not rawfiles:
        log.warning(f"no DARKs with {exptime}s left, skipping that EXPTIME")
        continue

    compute_dark_file(rawfiles, darkfile, args.camera, bias=biasfiles,
        exptime=exptime)



#- Read the individual combined dark images
log.info('Reading darks for individual EXPTIMEs')
darkimages = list()
darkheaders = list()
#- EXPTIMEs actually read; some may have no combined dark file, and
#- darktimes must stay aligned with darkimages / darkheaders
used_darktimes = list()
for exptime in darktimes:
    darkfile = f'{tempdir}/dark-{args.camera}-{exptime}.fits'
    if os.path.isfile(darkfile) :
        img, hdr = fitsio.read(darkfile, 'DARK', header=True)
        #- images are multiplied by exptime before fitting
        darkimages.append(img*exptime)
        darkheaders.append(hdr)
        used_darktimes.append(exptime)
    else:
        log.warning(f'{darkfile} not found; not using EXPTIME {exptime}')

if len(darkimages) == 0:
    log.critical('No combined DARK images to fit')
    sys.exit(2)

#- from here on darktimes only lists the EXPTIMEs that have an image
darktimes = np.array(used_darktimes)
darkimages = np.array(darkimages)

if np.max(darktimes) < args.linexptime:
    log.critical(f'No DARKs with exptime >= args.linexptime={args.linexptime}')
    sys.exit(2)

#- initial fit uses only the exposure times at or above --linexptime
ii = darktimes >= args.linexptime
log.info('Calculating const+dark using exptimes {}'.format(darktimes[ii]))
const, dark = fit_const_plus_dark(darktimes[ii], darkimages[ii])
ny, nx = dark.shape

#- Alternate between (1) 1D row profiles of the residual image - dark*exptime
#- for each half of the image and (2) a refit of the dark from the images
#- with those profiles removed; at most 3 passes
half = nx // 2
for iteration in range(3) :
    log.info("iter={} Assemble 1D models for left & right amps vs. exposure time".format(iteration))
    nonlinear1d = list()
    corrected = list()
    for exptime, image in zip(darktimes, darkimages):
        assert image.shape == (ny,nx)
        #- what the current linear dark model does not explain
        residual = image - dark*exptime
        left = model_y1d(residual[:, :half], smooth=0)
        right = model_y1d(residual[:, half:], smooth=0)
        nonlinear1d.append( np.array([left, right]) )

        #- expand the two 1D profiles to a full-size image
        profile2d = np.zeros((ny,nx))
        profile2d[:, :half] = left[:,None]
        profile2d[:, half:] = right[:,None]
        corrected.append(image - profile2d)
        log.debug("iter={} exptime={} <left>={} <right>={}".format(iteration,exptime,np.mean(left),np.mean(right)))

    log.info('iter={} Calculating dark using exptimes {}'.format(iteration,darktimes[ii]))
    previous_dark = dark
    dark = fit_dark(darktimes[ii], np.array(corrected)[ii])

    #- convergence metric: largest change of the dark, scaled by 1000
    maxdiff = 1000 * np.max(np.abs(dark - previous_dark))
    log.info('iter={} max(|delta dark|)*(1000 sec)={:4.3f} elec'.format(iteration,maxdiff))
    if maxdiff < 0.01:
        break # typically ok after 1 iteration

#- Write the model: a DARK image HDU followed by one T{exptime} HDU per
#- exposure time
log.info(f'Writing {args.darkfile}')
with fitsio.FITS(args.darkfile, 'rw', clobber=True) as fx:
    header = fitsio.FITSHDR()
    header['BUNIT'] = 'electron/s'
    header.add_record(dict(name='DARKFMT', value='v2',
            comment='bias(x,y) + dark(x,y)*t + nonlinear(y,t)'))

    #- Copy keywords from the first DARK header, except EXPTIME, INPUT*
    #- and anything already set above
    first_hdr = darkheaders[0]
    for key in first_hdr.keys():
        if key == 'EXPTIME' or key.startswith('INPUT') or key in header:
            continue
        header.add_record(
            dict(name=key, value=first_hdr[key], comment=first_hdr.get_comment(key))
            )

    #- Renumber the INPUTnnn keywords of all dark headers into one sequence;
    #- reading from a header stops at its first missing INPUTnnn
    ninput = 0
    for dark_hdr in darkheaders:
        for k in range(1000):
            source_key = f'INPUT{k:03d}'
            if source_key not in dark_hdr:
                break
            header[f'INPUT{ninput:03d}'] = dark_hdr[source_key]
            ninput += 1

    #- 2D dark model in electron/s
    fx.write(dark.astype(np.float32), extname='DARK', header=header)

    #- 1D profiles at individual times, in electron [not electron/s]
    for exptime, model1d, profile_hdr in zip(darktimes, nonlinear1d, darkheaders):
        profile_hdr['BUNIT'] = 'electron'
        profile_hdr.add_record(dict(name='BUNIT', value='electron',
            comment='Note: 1D profiles are electron, not electron/s'))
        #- the extension name is supplied to fx.write instead
        profile_hdr.delete('EXTNAME')
        fx.write(model1d.astype(np.float32),
                 extname='T{}'.format(int(exptime)), header=profile_hdr)
