#!/usr/bin/env python

import sys
import astropy.io.fits as pyfits
import argparse
import numpy as np
import scipy.signal
import numba

from desiutil.log import get_logger

import scipy.ndimage as ndi
from numba import cfunc, carray
from numba.types import intc, intp, float64, voidptr
from numba.types import CPointer
from scipy import LowLevelCallable
from desispec.pixflat import convolve2d


@cfunc(intc(CPointer(float64), intp,
            CPointer(float64), voidptr))
def maskedmed(values_ptr, len_values, result, data):
    """ Median of the non-zero values of a buffer, as a compiled C callback.

    This function is compiled with numba.cfunc; maskedmedian wraps it in a
    scipy LowLevelCallable and passes it to scipy.ndimage.generic_filter.

    Args:
        values_ptr : pointer to the float64 values of the buffer
        len_values : number of values in the buffer
        result : pointer to the float64 output; result[0] is set to the median
                 of the non-zero values, or to 0. if all values are zero
        data : opaque pointer, not used in this function

    Returns:
        1 in all cases (presumably the "success" status expected by
        generic_filter -- confirm in the scipy documentation)
    """
    # access the raw buffer as a 1D float64 array of len_values elements
    values = carray(values_ptr, (len_values,), dtype=float64)
    # zero is the "masked" marker: only non-zero values enter the median
    ii=(values!=0)
    if np.sum(ii)>0 :
        result[0] = np.median(values[ii])
    else :
        # all values are masked
        result[0] = 0.
    return 1

def maskedmedian(image,shape) :
    """ Sliding median filter that skips masked (zero) pixels.

    For each pixel, the median is computed over the non-zero values found
    in the filtering window; the output is 0 where the whole window is zero.

    Args:
        image : 2D np.array, with zeros marking the values to ignore
        shape : tuple of length 2 giving the shape of the filtering window.

    Returns:
        2D np.array of same shape as input image
    """
    # hand the compiled callback to scipy instead of a (slow) python function
    compiled_median = LowLevelCallable(maskedmed.ctypes)
    filtered = ndi.generic_filter(image, compiled_median, size=shape)
    return filtered




def clipped_mean_image(image_filenames) :
    """ Return a clipped mean of input images after rescaling each image

    Each image is first rescaled to the median image (least-square scale
    factor), then for each pixel the values that deviate from the median
    by more than 4 x 1.4826 x (median absolute deviation) are discarded
    and the remaining values are averaged.

    Args:
        image_filenames : list of preprocessed image path
                          (FITS files with the image in HDU 0 and an "IVAR" HDU)

    Returns:
        mimage : mean image (2D np.array)
        ivar   : sum of the inverse variances of all the rescaled input
                 images (2D np.array); the clipping is not accounted for

    Raises:
        ValueError : if the list is empty or if the scale of an image is not positive
    """

    log = get_logger()

    if len(image_filenames) == 0 :
        raise ValueError("need at least one image to compute a clipped mean")

    log.debug("first median")
    images=[]
    ivars=[]
    for filename in image_filenames  :
        # the context manager closes the file; the data are copied as float
        # arrays so that they outlive the file and can be rescaled in place
        # whatever the type of the data on disk
        with pyfits.open(filename) as h :
            images.append(np.array(h[0].data,dtype=float))
            ivars.append(np.array(h["IVAR"].data,dtype=float))

    mimage=np.median(images,axis=0)
    log.debug("compute a scale per image")
    smimage2=np.sum(mimage**2)
    for i,(img,filename) in enumerate(zip(images,image_filenames)) :
        # least-square scale factor of this image with respect to the median
        a=np.sum(img*mimage)/smimage2
        log.debug("scale %d = %f", i,a)
        if a<=0 :
            raise ValueError("scale = %f for image %s"%(a,filename))
        images[i] /= a
        ivars[i] *= a**2

    images=np.array(images)
    mimage=np.median(images,axis=0)

    log.info("compute mask ...")
    ares=np.abs(images-mimage)
    nsig=4.
    # 1.4826 x median absolute deviation is a robust estimate of the rms
    mask=(ares<nsig*1.4826*np.median(ares,axis=0))
    # average (not median)
    log.info("compute average ...")
    nused=np.sum(mask,axis=0)
    # where no value survives the clipping (e.g. identical values in most
    # images, so that the median absolute deviation is zero) keep the median
    # instead of the undefined ratio 0/0
    mimage=np.where(nused>0,np.sum(images*mask,axis=0)/np.maximum(nused,1),mimage)

    ivar=np.sum(ivars,axis=0)
    return mimage,ivar

@numba.jit
def dilate_mask(mask,d0=1,d1=1) :
    """ Grows the masked (non-zero) area of a 2D mask.

    The value of each non-zero pixel of the input mask is copied to all the
    pixels within +-d0 pixels along axis 0 and +-d1 pixels along axis 1
    (the box is truncated at the edges of the array).

    Args:
        mask : 2D array, non-zero where masked
        d0 : half width of the dilation box along axis 0
        d1 : half width of the dilation box along axis 1

    Returns:
        dilated copy of the mask (the input is not modified)
    """
    nrows=mask.shape[0]
    ncols=mask.shape[1]
    dilated=mask.copy()
    for row in range(nrows) :
        # rows of the box around this pixel, truncated to the array
        row_begin=max(0,row-d0)
        row_end=min(nrows,row+d0+1)
        for col in range(ncols) :
            value=mask[row,col]
            if not value :
                continue
            col_begin=max(0,col-d1)
            col_end=min(ncols,col+d1+1)
            for r in range(row_begin,row_end) :
                for c in range(col_begin,col_end) :
                    dilated[r,c]=value
    return dilated


def _gaussian_line_kernel(prof,slope,axis) :
    """ Returns a 2D kernel holding the 1D profile prof along a (tilted) line.

    Args:
        prof : 1D array with an odd number of elements, the profile
        slope : displacement across the line per pixel along the line
                (0 for a line parallel to the axis)
        axis : 0 if the line runs along axis 0 of the kernel, 1 for axis 1

    Returns:
        2D array, with at least 3 pixels across the line
    """
    npix=prof.size
    along=np.arange(npix)
    drift=(along-npix//2)*slope
    # round to the nearest pixel (half away from zero)
    across=(drift+0.5*(drift>0)-0.5*(drift<0)).astype(int)
    hwb=max(1,np.max(np.abs(across)))
    if axis==0 :
        kernel=np.zeros((npix,2*hwb+1))
        kernel[along,across+hwb]=prof
    else :
        kernel=np.zeros((2*hwb+1,npix))
        kernel[across+hwb,along]=prof
    return kernel

def gaussian_smoothing_1d_per_axis(image,ivar,sigma,npass=2,dxdy=0.,dydx=0.) :
    """Computes a smooth model of the input image using two
    1D convolution with a Gaussian kernel of parameter sigma.
    Can do several passes.

    After each convolution the model is multiplied by the smoothed image
    and the image is divided by it, so that the next convolution acts
    on the residual ratio.

    Args:
        image : 2D array input image (not modified)
        ivar : 2D array of weights of same shape as image (not modified),
               passed as weight to the convolution
        sigma : float number (>0), in pixels; the kernel extends to +-3 sigma
        npass : integer number (>=1)
        dxdy : slope of the line followed by the kernel along axis 0,
               in pixels along axis 1 per pixel along axis 0
        dydx : slope of the line followed by the kernel along axis 1,
               in pixels along axis 0 per pixel along axis 1

    Returns:
        model : 2D array image of same shape as image
    """

    hw=int(3*sigma)
    tmp = image.copy()
    tmpivar = ivar.copy()
    model = np.ones(tmp.shape).astype(float)

    # single Gaussian profile
    u=(np.arange(2*hw+1)-hw)
    prof=np.exp(-u**2/sigma**2/2.)
    prof/=np.sum(prof)

    # two kernels along two axes
    # (a zero slope gives a straight kernel of 3 pixels across)
    kernels=[_gaussian_line_kernel(prof,dxdy,axis=0),
             _gaussian_line_kernel(prof,dydx,axis=1)]

    for _ in range(npass) : # possibly do several passes
        for kernel in kernels : # convolve in 1d on each axis
            res=convolve2d(tmp,kernel,weight=tmpivar)
            model *= res
            # NOTE(review): this update of the weights was marked
            # as uncertain ("?") by the original author
            tmpivar *= res**2
            # res==0 is replaced by 1 to avoid a division by zero
            tmp /= (res+(res==0))

    return model

def filtering(flat,model,width,edge_width,gradmask,reverse_order=False) :
    """ Iterative 1D median filtering of the flat along each axis.

    At each step the model is multiplied by a median filtered version of
    the current flat, and the flat is recomputed as image/model. This
    function reads the module-level variables image, ivar and minflat.

    Args:
        flat : 2D array, current flat
        model : 2D array, current model (modified in place)
        width : width of the median filter window used on the whole image
        edge_width : width of the median filter window for the second pair
                     of steps, applied only to the pixels selected by
                     gradmask; those steps are skipped if edge_width<=1
        gradmask : 2D boolean array selecting the pixels updated in the
                   second pair of steps
        reverse_order : if False, filter along axis 0 then axis 1,
                        otherwise along axis 1 then axis 0

    Returns:
        flat,model : updated flat and model
    """

    log = get_logger()

    def window_shapes(size) :
        # shapes of the two 1D windows, in the order they are applied
        along_axis0 = [size,1]
        along_axis1 = [1,size]
        if reverse_order :
            return along_axis1,along_axis0
        return along_axis0,along_axis1

    def flat_from_model() :
        # flat = image/model where both are valid, 1 elsewhere
        ratio  =  (ivar>0)*(model>minflat)*image/(model*(model>minflat)+(model<=minflat))
        ratio  += (model<=minflat)|(ivar<=0)
        return ratio

    first_shape,second_shape = window_shapes(width)

    log.info("step 1")
    model *= scipy.signal.medfilt2d(flat,first_shape)
    flat = flat_from_model()
    log.info("step 2")
    model *= scipy.signal.medfilt2d(flat,second_shape)
    flat = flat_from_model()

    if edge_width<=1 :
        return flat,model

    # narrower window, only where gradmask is set
    first_shape,second_shape = window_shapes(edge_width)

    log.info("step 3")
    smoothed = scipy.signal.medfilt2d(flat,first_shape)
    model[gradmask] *= smoothed[gradmask]
    flat = flat_from_model()
    log.info("step 4")
    smoothed = scipy.signal.medfilt2d(flat,second_shape)
    model[gradmask] *= smoothed[gradmask]
    flat = flat_from_model()
    return flat,model

def _debiased_band_ratio(band_1,band_2) :
    """ Ratio band_1/band_2 estimated as sqrt( median(1/2) / median(2/1) ) to debias. """
    return np.sqrt( np.median(band_1/band_2) / np.median(band_2/band_1) )

def amplifier_matching(image) :
    """ Rescales the four quadrants of the image to match their levels.

    The image is split in four quadrants at its center (a,b below the
    central row, c,d above; a,c left of the central column, b,d right).
    The ratio of two adjacent quadrants is measured in bands of 10 pixels
    on each side of their boundary, and each side is corrected by the
    square root of the measured ratio.

    Args:
        image : 2D np.array, modified in place

    Returns:
        None
    """

    log = get_logger()

    n0 = image.shape[0]
    n1 = image.shape[1]
    c0 = n0//2
    c1 = n1//2
    w=10   # width of the bands on each side of a boundary
    lw=200 # length of the vertical bands, and margin of the horizontal ones

    # the bands on the right side of the central column are the columns
    # c1 to c1+w (a slice c1::c1+w would select the single column c1)
    vband_a = np.median(image[c0-lw:c0,c1-w:c1],axis=1)
    vband_b = np.median(image[c0-lw:c0,c1:c1+w],axis=1)
    rab = _debiased_band_ratio(vband_a,vband_b)
    log.debug("a/b=%f", rab)
    image[:c0,:c1] /= np.sqrt(rab)
    image[:c0,c1:] *= np.sqrt(rab)

    vband_c = np.median(image[c0:c0+lw,c1-w:c1],axis=1)
    vband_d = np.median(image[c0:c0+lw,c1:c1+w],axis=1)
    rcd = _debiased_band_ratio(vband_c,vband_d)
    log.debug("c/d=%f", rcd)
    image[c0:,:c1] /= np.sqrt(rcd)
    image[c0:,c1:] *= np.sqrt(rcd)

    # now that left and right are matched, match bottom and top halves
    hband_ab = np.median(image[c0-w:c0,lw:-lw],axis=0)
    hband_cd = np.median(image[c0:c0+w,lw:-lw],axis=0)
    rac = _debiased_band_ratio(hband_ab,hband_cd)
    log.debug("a/c = b/d = %f", rac)
    image[:c0] /= np.sqrt(rac)
    image[c0:] *= np.sqrt(rac)



# command line interface
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="""Computes a pixel level flat field image from a set of preprocessed images obtained with the flatfield slit.
An average image is computed if several preprocessed images are given.
The method consists in iteratively dividing the image by a smoothed version of the same image, flat(n+1) = flat(n)/smoothing(flat(n)).
The smoothing consists in 1D median filtering along the wavelength dispersion axis or the fiber axis alternatively, with masking.
"""
)

# nargs='+' : at least one image is needed (with nargs='*' an empty list
# was accepted by the parser and the script failed later on)
parser.add_argument('-i','--images', type = str, nargs='+', default = None, required = True,
                    help = 'path to input preprocessed image fits files, or a single median image')
parser.add_argument('-o','--outfile', type = str, default = None, required = True,
                    help = 'output flatfield image filename')
parser.add_argument('--out-mean', type = str, default = None, required = False,
                    help = 'save median image')
parser.add_argument('-d','--debug', action = 'store_true', help="write a lot of debugging images")
#parser.add_argument('--minflux',type = float, default=0.05,help="minimum flux fraction of median")
#parser.add_argument('--maxerr',type = float, default=0.02,help="maximum error")
parser.add_argument('--min-flat-for-mask',type=float, default=0.99,help="threshold for masking pixels in flat")
parser.add_argument('--max-flat-for-mask',type=float, default=1.02,help="threshold for masking pixels in flat")


args = parser.parse_args()
log = get_logger()

# read the input image (and its inverse variance), or compute it
# from several input images
if len(args.images) == 1 :
    log.info("read a single image")
    # the context manager closes the file; the arrays are copied
    # in memory because they are used (and modified) after that
    with pyfits.open(args.images[0]) as h :
        image=np.array(h[0].data)
        ivar=np.array(h["IVAR"].data)
else :
    log.info("compute a clipped mean of the input images")
    image,ivar=clipped_mean_image(args.images)

if args.out_mean is not None :
    log.info("writing median image %s ..."%args.out_mean)
    h=pyfits.HDUList([pyfits.PrimaryHDU(image),pyfits.ImageHDU(ivar,name="IVAR")])
    h.writeto(args.out_mean,overwrite=True)

# the rest of the script cannot handle nan values
if np.any(np.isnan(image)) :
    log.error("median image has nan values")
    sys.exit(12)

# hardcoded values for the pixel flat field
#############################################################
# minimal normalized model flux value to be used for flat fielding
# (this is just to avoid numerical issues)
minflat=0.0001

# maximum flat value (the output flat is saturated at this value)
maxflat=1.2

# width of the median filtering window along both axes
# in the CCD image area where the illumination gradient is not too large.
# it has to be large in order to avoid filtering out genuine CCD
# features, but it cannot be too large in order to filter out
# the variations of the illumination
median_filter_width=201 # CCD pixels (has to be an odd number)

# width of the median filtering window along both axes
# in the CCD image area where the illumination gradient is large
median_filter_width_in_mask=21 # CCD pixels (has to be an odd number)

# parameters defining the initial mask
maxerr_for_initial_mask=0.015 # fractional error
minfluxfrac_for_initial_mask=0.1 # fraction of median flux
ccd_edge_margin_for_initial_mask=20 # CCD pixels, on edges of CCD
illuminated_region_margin_for_initial_mask=5 # CCD pixels, masked area increase


# pixels of the input image with a value lower or equal to this
# are masked in the iterative masking
minimum_flux_for_second_mask=10. # (in unit of input image values)

# parameters used for the first model that tries to describe
# the illumination pattern as a single spectrum,
# with a fitted transform from wavelength to CCD row coordinate
# as a function of CCD column coordinate, and with varying
# amplitude as a function of CCD column coordinate
#
# this model is not good enough for flat fielding
# but it's a starting point for the iterative median filtering
#
# vertical band width used to estimate the spectrum
# (fiber spectral traces are quasi vertical at the center of the CCD)
pixel_band_width_for_initial_spectral_model=1000 # pixels

# the spectrum has to be smoothed to remove most of the Poisson noise
# we use a fast fft convolution with a Gaussian for this
sigma_of_gaussian_smoothing_for_initial_spectral_model=4 # pixels

# a first stage consists in doing a cross-correlation of the spectral
# model with each column of the CCD to get a first estimate of the
# coordinate transformation to apply
max_pixel_row_offset_for_initial_spectral_model_correlation=90 # pixels

# CCD columns with low correlation coeff (=low flux in fact) are
# discarded from the polynomial fit of the transformation.
minimum_correlation_for_initial_spectral_model=0.1
# Degrees of polynomial of Y_central(X_ccd,Y_ccd)
x_degree_of_coordinate_transform_for_initial_spectral_model=4
y_degree_of_coordinate_transform_for_initial_spectral_model=0

# The ratio of the initial model image to a Gaussian-smoothed version of
# the same image is used to identify the regions with a large illumination
# gradient: the pixels where this ratio differs from 1 by more than the
# threshold are flagged, and the flagged area is then enlarged by the margin
sigma_of_gaussian_smoothing_for_large_gradient_regions=100. # pixels
flat_error_threshold_for_large_gradient_regions=0.1
mask_margin_for_large_gradient_regions=30 # pixels



# Parameters for mask during fit (to avoid haloes around CCD defects)
min_flat_for_fit_mask = args.min_flat_for_mask
max_flat_for_fit_mask = args.max_flat_for_mask
margin_for_fit_mask = 3 # increase mask by 3 pix in each direction
# the masking threshold is relaxed until the fraction of masked
# pixels (among the pixels not in the primary mask) is below this value
max_fraction_of_masked_pixels = 0.05
#############################################################
nstep=12 # this is for log messages
step=0 # this is for log messages

step+=1; log.info("step {}/{} match amps".format(step,nstep))
amplifier_matching(image) # routine to match amplifier to avoid residual feature at the center of CCD (modifies image in place)

# flux threshold of the initial mask, as a fraction of the median of the image
minflux_for_initial_mask=minfluxfrac_for_initial_mask*np.median(image)
# fractional error of the image, error/image ;
# null ivar and image values <=0 are replaced by 1 to avoid divisions by zero
err   = np.sqrt(1./(ivar+(ivar==0)))/(image*(image>0)+(image<=0))

# define mask0
# this is the primary mask, the flat in the masked
# pixels will be set to 1.
# it is based on the illumination pattern of
# the flat slit that does not cover the whole CCD
# parameters that matter for this are
#   maxerr_for_initial_mask
#   minfluxfrac_for_initial_mask
#   ccd_edge_margin_for_initial_mask
#   illuminated_region_margin_for_initial_mask
###################################################
n_rows,n_cols=err.shape
edge=ccd_edge_margin_for_initial_mask

# for each CCD row, first and last well illuminated columns, ignoring
# the edges of the CCD. the default values (used if less than 3 pixels
# are selected) lead to a fully masked row whatever the size of the CCD
y=np.arange(n_rows)
xmin=np.full(n_rows,n_cols,dtype=int)
xmax=np.zeros(n_rows,dtype=int)
for j in range(n_rows) :
    ii=np.where((err[j,edge:-edge]<maxerr_for_initial_mask)&(image[j,edge:-edge]>minflux_for_initial_mask))[0]+edge
    if len(ii)>=3 :
        xmin[j]=ii[0]
        xmax[j]=ii[-1]

# same thing for each CCD column, with the same flux threshold
# as for the rows (a flux and not a fraction of the median flux)
x=np.arange(n_cols)
ymin=np.full(n_cols,n_rows,dtype=int)
ymax=np.zeros(n_cols,dtype=int)
for i in range(n_cols) :
    jj=np.where((err[:,i]<maxerr_for_initial_mask)&(image[:,i]>minflux_for_initial_mask))[0]
    if len(jj)>=3 :
        ymin[i]=jj[0]
        ymax[i]=jj[-1]

# shrink the illuminated region by a margin ; the upper bounds are kept
# positive because a negative value would be read as an index counted
# from the end of the array
xmin += illuminated_region_margin_for_initial_mask
xmax = np.maximum(xmax-illuminated_region_margin_for_initial_mask,0)
ymin += illuminated_region_margin_for_initial_mask
ymax = np.maximum(ymax-illuminated_region_margin_for_initial_mask,0)

# mask everything outside of [xmin,xmax] in each row
# and outside of [ymin,ymax] in each column (bounds are masked too)
mask0=np.zeros(err.shape,dtype=int)
for j in range(n_rows) :
    mask0[j,:xmin[j]+1]=1
    mask0[j,xmax[j]:]=1
for i in range(n_cols) :
    mask0[:ymin[i]+1,i]=1
    mask0[ymax[i]:,i]=1

if args.debug :
    pyfits.writeto("debug-mask-0.fits",mask0,overwrite=True)
    log.info("wrote debug-mask-0.fits")
    pyfits.writeto("debug-image-0.fits",image,overwrite=True)
    log.info("wrote debug-image-0.fits")
    pyfits.writeto("debug-err.fits",err,overwrite=True)
    log.info("wrote debug-err.fits")

# keep a copy of the image because the image is divided
# by the first model below
original_image = image.copy()
# NOTE(review): this value of mask does not seem to be used, the mask is
# recomputed in the iterative masking loop before its first use
mask=(original_image<=minimum_flux_for_second_mask)

# here we compute a first model of the image
# the illumination pattern is modeled as a single spectrum,
# with a fitted transform from wavelength to CCD row coordinate
# as a function of CCD column coordinate, and with varying
# amplitude as a function of CCD column coordinate
step+=1; log.info("step {}/{} first model".format(step,nstep))

# compute the median spectrum in a band of columns centered on the CCD
medspec=np.zeros(image.shape[0])
n0=image.shape[0] # number of CCD rows
n1=image.shape[1] # number of CCD columns
b1=n1//2-pixel_band_width_for_initial_spectral_model//2
e1=n1//2+pixel_band_width_for_initial_spectral_model//2
band=image[:,b1:e1].copy()

if np.any(np.isnan(band)) :
    log.error("band image has nan values")
    sys.exit(12)

# iterate: median spectrum of the band, then rescaling of each column
for loop in range(3) :
    medspec=np.nanmedian(band,axis=1)
    medspec[medspec<1]=1 # floor at 1 to avoid a division by zero below
    for i in range(band.shape[1]) :
        # median ratio of this column to the median spectrum
        norm=np.median(band[:,i]/medspec)
        log.debug("%d %f", i, norm)
        if not np.isnan(norm) :
            # NOTE(review): the column is multiplied (not divided) by its
            # ratio to the median spectrum -- confirm this is intended
            band[:,i] *= norm

medspec=np.median(band,axis=1)

# smooth the spectrum with a gaussian kernel
# (normalized kernel of the same length as the spectrum, centered at n0/2)
y=np.arange(n0)
k=np.exp(-(y-n0/2.)**2/2./sigma_of_gaussian_smoothing_for_initial_spectral_model**2)
k/=np.sum(k)
medspec=scipy.signal.fftconvolve(medspec,k,'same')

# first model of the image, filled column by column below
model = np.zeros(image.shape)

# first estimate of the coordinate transform with cross-correlation for each CCD column
nc = max_pixel_row_offset_for_initial_spectral_model_correlation
# normalization of the correlation (mean square of the spectrum without its edges)
norm = np.mean(medspec[nc:-nc]**2)
offsets=np.zeros(n1)  # best row offset for each column
corrvals=np.zeros(n1) # value of the normalized correlation at this offset
for i in range(n1):
    corr=np.zeros(nc)
    # tested row offsets, from -2 to nc-3 pixels
    dy=np.arange(-2,nc-2).astype(int)
    for c in range(nc) :
        # correlation of the spectrum with the column shifted by dy[c] rows
        corr[c]=np.mean(medspec[nc:-nc]*image[nc+dy[c]:n0-nc+dy[c],i])/norm
    c=np.argmax(corr)
    offsets[i]  = dy[c]
    corrvals[i] = corr[c]
    if i%100 == 0 :
        log.debug("%d %f %f", i,offsets[i],corrvals[i])

# refit the best offset as a polynomial of X to remove statistical jitter
# (using only the columns with a large enough correlation)
x=np.arange(n1)
ok=(corrvals>minimum_correlation_for_initial_spectral_model)
pol=np.poly1d(np.polyfit(x[ok],offsets[ok],x_degree_of_coordinate_transform_for_initial_spectral_model))

# refine this first guess with a linear fit for each CCD column
# potentially including a dilation term or higher order terms as a function of Y (CCD row number)
y=np.arange(n0)
dy=(y-n0//2)/float(n0//2) # reduced row coordinate, about -1 to 1
for loop in range(1) :

    if loop==0 :
        npar=4
    else :
        npar=4+y_degree_of_coordinate_transform_for_initial_spectral_model
    A=np.zeros((npar,npar))
    B=np.zeros(npar)
    H=np.zeros((npar,n0))      # derivatives of the model for one column
    param=np.zeros((n1,npar))  # best fit parameters for each column


    for i in range(n1):

        if not ok[i] :
            # column discarded: flat set to 1 with a null weight
            image[:,i] = 1
            ivar[:,i]  = 0
            continue


        if loop==0 :
            # median spectrum shifted by the offset fitted for this column
            spec=np.interp(y,y+pol(x[i]),medspec)
        else :
            spec=model[:,i]
        # variation of the spectrum for a shift of one row
        dspecdy=np.interp(y,y+1,spec)-spec
        # refit offset and scale.
        # model = a * spec + b * dspecdy + c * dy * dspecdy + ...
        # the model absorbs some chromatic variation of the spectrum
        H[0] = spec
        H[1] = spec*dy
        H[2] = spec*dy**2
        for p in range(npar-3) :
            H[p+3] = dy**p*dspecdy
        # weighted linear least square fit (normal equations)
        sqrtw = np.sqrt(ivar[:,i]) ### high weight for faint values
        sqwH = sqrtw*H
        A = sqwH.dot(sqwH.T)
        B = sqwH.dot(sqrtw*image[:,i])
        Ai = np.linalg.inv(A)
        param[i] = Ai.dot(B)
        # param[i] is an array, it cannot be formatted with %f
        if i%100 == 0 : log.debug("%d %d %s", i,loop,param[i])

    # fit smooth params
    # (each parameter is replaced by a polynomial of degree 6 of the
    # reduced column coordinate, fitted on the valid columns)
    dx=(np.arange(n1)-n1//2)/(n1/2.)
    ii=np.where(ok)[0]
    for p in range(npar) :
        param_pol=np.poly1d(np.polyfit(dx[ii],param[ii,p],6))
        param[:,p]=param_pol(dx)

    for i in range(n1):
        if not ok[i] :
            continue


        if loop==0 :
            spec=np.interp(y,y+pol(x[i]),medspec)
        else :
            spec=model[:,i]

        # use exact transfo instead of derivatives
        # spec' = a*spec + pol(y)*dspecdy
        #       = a*spec(y')
        # y' = y + pol(dy)/a
        a  = param[i,0]
        yp = y.astype(float)
        # NOTE(review): this loop adds param[i,-1]/a, param[i,0]/a, ...
        # which does not obviously match the formula above -- to be
        # confirmed, left unchanged here
        for p in range(npar-1) :
            yp += param[i,p-1]/a
        specp = a*np.interp(yp,y,spec)
        model[:,i] = specp

# divide the image by the first model, column by column, for the valid
# columns only, and scale the inverse variance accordingly
for i in np.flatnonzero(ok) :
    column_model = model[:,i]
    image[:,i] /= column_model
    ivar[:,i] *= column_model**2

if args.debug :
    pyfits.writeto("debug-flat-{}.fits".format(step),image,overwrite=True)
    log.info("wrote debug-flat-{}.fits".format(step))
    pyfits.writeto("debug-model-{}.fits".format(step),model,overwrite=True)
    log.info("wrote debug-model-{}.fits".format(step))

# regions of the image with a large illumination gradient are found
# by comparing the first model with a smoothed version of itself ;
# the width of the median filtering is adapted in those regions
step+=1; log.info("step {}/{} compute a gradient mask".format(step,nstep))
sig=sigma_of_gaussian_smoothing_for_large_gradient_regions
hw=int(3*sig)
# normalized 2D gaussian kernel truncated at 3 sigma
axis=np.linspace(-hw,hw,2*hw+1)
xx,yy=np.meshgrid(axis,axis)
k=np.exp(-(xx**2+yy**2)/2/sig**2)
k/=np.sum(k)
smodel=convolve2d(model,k)
# flag where the model differs too much from its smoothed version,
# then enlarge the flagged area
gradmask=(np.abs(model/smodel-1)>flat_error_threshold_for_large_gradient_regions)
gradmask = dilate_mask(gradmask,mask_margin_for_large_gradient_regions,mask_margin_for_large_gradient_regions)
if args.debug :
    pyfits.writeto("debug-gradmask.fits",gradmask.astype(int),overwrite=True)
    log.info("wrote debug-gradmask.fits")

# starting point of the filtering : the image divided by the first model,
# and a model reset to 1
flat=image.copy()
model=np.ones(flat.shape)


# first filtering pass
# 1D smoothing along the vertical axis, followed by
# 1D smoothing along the horizontal axis
# (which differs from a 2D filtering!)
step+=1; log.info("step {}/{} filtering".format(step,nstep))
flat,model = filtering(flat,model,
                       width=median_filter_width,
                       edge_width=median_filter_width_in_mask,
                       gradmask=gradmask,
                       reverse_order=False)

if args.debug :
    # flat forced to 1 in the primary mask
    tmp=flat*(mask0==0)+(mask0!=0)
    pyfits.writeto("debug-flat-{}.fits".format(step),tmp,overwrite=True)
    log.info("wrote debug-flat-{}.fits".format(step))
    pyfits.writeto("debug-flat-nomask-{}.fits".format(step),flat,overwrite=True)
    log.info("wrote debug-flat-nomask-{}.fits".format(step))

# second filtering, from the same starting point but with the axes
# taken in the reversed order : this cures a bias trail that shows up
# when a dark spot sits in a region with an illumination gradient
step+=1; log.info("step {}/{} reverse order filtering".format(step,nstep))
flat2,model2 = filtering(image.copy(),np.ones(image.shape),
                         width=median_filter_width,
                         edge_width=0,
                         gradmask=gradmask,
                         reverse_order=True)
if args.debug :
    pyfits.writeto("debug-flat2-reversed.fits",flat2,overwrite=True)
    log.info("wrote debug-flat2-reversed.fits")

# iterative mask of pixels and filtering
# also includes a gaussian filtering along with the median filtering
# sequence of operations in each iteration :
#   loop 0 : masking, reset of flat and model, median filtering
#   loop 1 : masking, gaussian smoothing, masked median filtering
#   loop 2 : masking, gaussian smoothing
nloop=3
for loop in range(nloop) :

    step+=1; log.info("step {}/{} masking pixels".format(step,nstep))

    # auto-adjust the mask threshold
    # (the threshold is relaxed until the fraction of masked pixels is
    # small enough ; the last value of nsig is used if it never happens)
    for nsig in [3.,3.5,4.,5.,10.,20.] :
        # mask the pixels with a flat out of range given the error, and the pixels with a low flux
        mask = (flat<(min_flat_for_fit_mask-nsig*err))|(flat>(max_flat_for_fit_mask+nsig*err))|(original_image<=minimum_flux_for_second_mask)
        mask = dilate_mask(mask,margin_for_fit_mask,margin_for_fit_mask)
        # fraction of masked pixels among the pixels that are not in the primary mask
        frac = np.sum((mask>0)&(mask0==0))/float(np.sum(mask0==0)) #float(mask.size)
        if frac<max_fraction_of_masked_pixels :
            break
    log.info("Used nsig = {}, frac = {:4.3f}".format(nsig,frac))

    if args.debug :
        pyfits.writeto("debug-mask-{}.fits".format(step),mask.astype(int),overwrite=True)
        log.info("wrote debug-mask-{}.fits".format(step))

    if loop == 0 :
        log.debug("reset after mask is computed")
        # restart from the image, with the masked pixels replaced
        # by the model obtained with the reverse order filtering
        flat = image*(mask==0)+model2*(mask!=0)
        model=np.ones(flat.shape)
        if args.debug :
            pyfits.writeto("debug-reset-flat.fits",flat,overwrite=True)
            log.info("wrote debug-reset-flat.fits")

    if loop > 0 :
        step+=1; log.info("step {}/{} gaussian smoothing".format(step,nstep))
        # weight is 1 for unmasked pixels and 0 for masked pixels
        w = (mask==0).astype(float)
        model *= gaussian_smoothing_1d_per_axis(flat,w,50.,npass=2)
        # flat = image/model where both are valid, 1 elsewhere
        flat  =  (ivar>0)*(model>minflat)*image/(model*(model>minflat)+(model<=minflat))
        flat  += ((model<=minflat)|(ivar<=0))
        if args.debug:
            tmp=flat.copy()
            # flat forced to 1 in the primary mask
            tmp=flat*(mask0==0)+(mask0!=0)
            pyfits.writeto("debug-flat-{}.fits".format(step),tmp,overwrite=True)
            log.info("wrote debug-flat-{}.fits".format(step))
            pyfits.writeto("debug-flat-nomask-{}.fits".format(step),flat,overwrite=True)
            log.info("wrote debug-flat-nomask-{}.fits".format(step))

    if loop < nloop-1 :
        step+=1; log.info("step {}/{} filtering".format(step,nstep))
        if loop == 0 :
            flat,model=filtering(flat,model,median_filter_width,median_filter_width_in_mask,gradmask,False)
        else :
            # masked pixels are set to zero so that maskedmedian ignores them
            model *= maskedmedian(flat*(mask==0),[median_filter_width,1])
            flat  =  (ivar>0)*(model>minflat)*image/(model*(model>minflat)+(model<=minflat))
            flat  += ((model<=minflat)|(ivar<=0))
            model *= maskedmedian(flat*(mask==0),[1,median_filter_width])
            flat  =  (ivar>0)*(model>minflat)*image/(model*(model>minflat)+(model<=minflat))
            flat  += ((model<=minflat)|(ivar<=0))

    if args.debug :
        tmp=flat.copy()
        # flat forced to 1 in the primary mask
        tmp=flat*(mask0==0)+(mask0!=0)
        pyfits.writeto("debug-flat-{}.fits".format(step),tmp,overwrite=True)
        log.info("wrote debug-flat-{}.fits".format(step))
        pyfits.writeto("debug-flat-nomask-{}.fits".format(step),flat,overwrite=True)
        log.info("wrote debug-flat-nomask-{}.fits".format(step))
        pyfits.writeto("debug-model-{}.fits".format(step),model,overwrite=True)
        log.info("wrote debug-model-{}.fits".format(step))


# inverse variance of the flat ( flat = image/model )
fivar=ivar*model**2
# the flat is restricted to the range [0,maxflat]
flat[flat>maxflat]=maxflat
flat[flat<0]=0.
# in the primary mask the flat is set to 1 with a null inverse variance
flat=flat*(mask0==0)+(mask0!=0)
fivar*=(mask0==0)

h=pyfits.HDUList([pyfits.PrimaryHDU(flat.astype('float32')),pyfits.ImageHDU(fivar.astype('float32'),name="IVAR")])
h[0].header["BUNIT"]=("","adimensional quantify to divide to flat field a CCD frame")
h[0].header["EXTNAME"]="PIXFLAT"
h[1].header["BUNIT"]=("","adimensional quantify, inverse variance")
# overwrite is a boolean option (the string "True" was used before)
h.writeto(args.outfile,overwrite=True)
log.info("wrote {}".format(args.outfile))
