#!/usr/bin/env python

import sys
import astropy.io.fits as pyfits
import argparse
import numpy as np
from desispec.pixflat import convolve2d
import scipy.ndimage

def flatten(img, sigma=20.):
    """Normalize an image by its row medians, column medians and a smoothed copy of itself.

    The array is modified in place: every row is divided by its median, then
    every column is divided by its median, and finally the image is divided
    by its own convolution with a normalized 2D Gaussian kernel.

    Args:
      img : 2D np.array image (modified in place)
    Returns:
      the same array object after normalization (2D np.array)
    Options:
      sigma : sigma of 2D Gaussian convolution in pixels
    """

    # keepdims keeps the reduced axis so the medians broadcast against the
    # image; the column medians are taken on the row-normalized image.
    img /= np.median(img, axis=1, keepdims=True)
    img /= np.median(img, axis=0, keepdims=True)

    # Square Gaussian kernel extending 3 sigma on each side of its center,
    # normalized to unit sum.
    half_width = int(3*sigma)
    offsets = np.linspace(-half_width, half_width, 2*half_width+1)
    dx = offsets[np.newaxis, :]
    dy = offsets[:, np.newaxis]
    kernel = np.exp(-dx**2/2/sigma**2-dy**2/2/sigma**2)
    kernel /= np.sum(kernel)

    # Divide out the large-scale structure.
    img /= convolve2d(img, kernel, weight=None)

    return img


# Command line interface: input flat, output flat field, camera and smoothing scale.
parser = argparse.ArgumentParser(
    description="""Computes a pixel level flat field image from a flat image obtained during the CCD qualification tests, like /global/cfs/cdirs/desi/spectro/teststand/rawdata/v0/M1-28_report/Flats/flat_median_700.fits. This code is designed for LBL CCDs in the RED cameras for now.
""",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Mandatory options (argparse already defaults them to None).
parser.add_argument('-i', '--infile', type=str, required=True,
                    help='path to input flat image fits file')
parser.add_argument('-o', '--outfile', type=str, required=True,
                    help='output flatfield image filename')
parser.add_argument('-c', '--camera', type=str, required=True,
                    help='camera: only RED implemented for now')

# Optional tuning of the Gaussian smoothing scale.
parser.add_argument('--sigma', type=float, default=20.,
                    help='sigma of 2D Gaussian convolution in pixels')

args = parser.parse_args()

# Reject unsupported cameras before touching the input file, so a bad request
# fails immediately instead of after reading a large image.
if args.camera.upper() != "RED" :
    print("error, can only flatfield red cameras for now")
    sys.exit(12)

# Read the primary HDU as floating point. The context manager closes the file
# handle; astype(float) returns a new in-memory array, so `img` stays valid
# after the file is closed.
with pyfits.open(args.infile) as hdulist :
    img = hdulist[0].data.astype(float)

# Frames of shape (4114,4200), e.g.
# /global/cfs/cdirs/desi/spectro/teststand/rawdata/v0/M1-22-1_report/Flats/flat_median_700.fits,
# are shorter than expected. The original author's note suggests a band of
# central rows is missing (discontinuity in the flat around pixel 2055,2055);
# this is unconfirmed. The frame is re-expanded to 4130 rows by moving the two
# halves apart and filling the 16-row gap with the nearest valid row of each half.
if img.shape == (4114, 4200) :
    short_rows = 4114
    full_rows = 4130
    half = short_rows // 2          # rows in each half of the short frame
    middle = full_rows // 2         # first row of the upper part of the gap
    upper_start = full_rows - half  # first row of the relocated upper half

    img2 = np.zeros((full_rows, 4200))
    img2[:half] = img[:half]
    img2[upper_start:] = img[half:]

    # Lower part of the gap repeats the last row of the lower half, upper part
    # repeats the first row of the upper half (single rows broadcast over the slice).
    img2[half:middle] = img2[half-1]
    img2[middle:upper_start] = img2[upper_start]

    img = img2



# Output flat field, initialized to 1 everywhere.
flat = np.ones((4128, 4114))
n0, n1 = flat.shape
half0 = n0 // 2
half1 = n1 // 2

# Row and column offsets of the upper/right quadrants in the input image.
i0 = 2065
i1 = 2136

# One entry per quadrant: progress message, first input row, first input
# column, destination rows and destination columns in `flat`. The lower/left
# input offsets (1 and 7) are hard-coded for this detector format.
# NOTE(review): the meaning of these offsets is not documented here - confirm.
quadrants = (
    ("flatten quadrant 1", i0, i1, slice(None, half0), slice(None, half1)),
    ("flatten quadrant 2", 1, i1, slice(half0, None), slice(None, half1)),
    ("flatten quadrant 3", 1, 7, slice(half0, None), slice(half1, n1)),
    ("flatten quadrant 4", i0, 7, slice(None, half0), slice(half1, n1)),
)

for message, row_start, col_start, out_rows, out_cols in quadrants:
    print(message)
    # flatten() works in place on the view of img and returns it; the result
    # is stored flipped along both axes.
    tmp = flatten(img[row_start:row_start+half0, col_start:col_start+half1], sigma=args.sigma)
    flat[out_rows, out_cols] = tmp[::-1, ::-1]

print("edges and central rows")

# Divide each of the five rows around the horizontal middle of the flat by a
# running median of itself (window of `width` pixels).
width = 20
for row in range(n0//2-2, n0//2+3):
    # Pad the row with `width` samples on each side: ones on the left, a
    # repeat of the last pixel of the row on the right.
    padded = np.concatenate((np.ones(width), flat[row], np.full(width, flat[row, -1])))
    running_median = scipy.ndimage.median_filter(padded, width, mode='constant')
    flat[row] /= running_median[width:-width]

# Reset the borders of the flat to 1: 40 columns on the left, 20 columns on
# the right, and 20 rows at the bottom and at the top.
flat[:20] = 1
flat[-20:] = 1
flat[:, :40] = 1
flat[:, -20:] = 1

# Store the flat as a single-precision image in the primary HDU, together with
# a few descriptive header keywords.
hdu = pyfits.PrimaryHDU(flat.astype('float32'))
# Fixed the typo "quantify" -> "quantity" in the BUNIT comment written to the file.
hdu.header["BUNIT"] = ("", "adimensional quantity to divide to flat field a CCD frame")
hdu.header["INPUT"] = (args.infile, "input file")
hdu.header["EXTNAME"] = "PIXFLAT"

# Any existing file with the same name is replaced.
h = pyfits.HDUList([hdu])
h.writeto(args.outfile, overwrite=True)
print("wrote", args.outfile)
