#!/usr/bin/env python

import os,sys
import numpy as np
import argparse

from desispec.io import read_image,write_image,read_xytraceset
from desispec.calibfinder import findcalibfile

# Command-line interface. Every option keeps its original name, type and
# default; only the help texts were corrected (typo in --into-image, missing
# closing parenthesis in --scale).
parser = argparse.ArgumentParser(
    description="copy part of an image and paste it into another")
parser.add_argument("--from-image", type=str, required=True,
                    help="input preproc image from which part of the image is copied")
parser.add_argument("--into-image", type=str, required=True,
                    help="input preproc image into which part of the other image is pasted")
parser.add_argument("-o", "--outfile", type=str, required=True,
                    help="output preproc image")
# Comma-separated list of wmin:wmax pairs; each pair is processed in turn.
parser.add_argument("--wrange", type=str, required=True,
                    help="wmin1:wmax1,wmin2:wmax2 ... wavelength range")
parser.add_argument("--trace", type=str, default=None, required=False,
                    help="PSF or xytraceset file defining the fiber trace coordinates (default uses calib finder)")
# When --scale-ivar is omitted, the ivar scale is derived from --scale as 1/scale**2.
parser.add_argument("--scale", type=float, default=None, required=False,
                    help="apply scale to signal of pasted region (and ivar accordingly if no --scale-ivar)")
parser.add_argument("--scale-ivar", type=float, default=None, required=False,
                    help="apply scale to ivar of pasted region")
# Default behaviour is to co-add the pasted region onto the destination image.
parser.add_argument("--replace", action="store_true",
                    help="replace instead of coadding")

args = parser.parse_args()

# Load the source image (copied from) and the destination image (pasted into).
from_image = read_image(args.from_image)
into_image = read_image(args.into_image)

# The pixel selection is built from the shape of from_image and then used to
# index both images, so they must have identical shapes. Fail early with a
# clear message instead of an indexing error further down.
if from_image.pix.shape != into_image.pix.shape :
    print("ERROR: images have different shapes {} and {}".format(from_image.pix.shape,into_image.pix.shape))
    sys.exit(12)

# Without an explicit trace file, look one up with the calibration finder
# using the header of the source image.
if args.trace is None :
    args.trace = findcalibfile([from_image.meta],"PSF")
    print("will use {}".format(args.trace))

tset = read_xytraceset(args.trace)

# Signal scale defaults to 1 (no scaling).
scale=args.scale
if scale is None :
    scale = 1.

# The ivar scale defaults to 1/scale**2, i.e. the inverse variance of a
# signal multiplied by 'scale'. That default is undefined for scale == 0,
# so require an explicit --scale-ivar in that case rather than crashing with
# a ZeroDivisionError.
scale_ivar=args.scale_ivar
if scale_ivar is None :
    if scale == 0 :
        print("ERROR: --scale 0 requires an explicit --scale-ivar")
        sys.exit(12)
    scale_ivar = 1./scale**2

# Image dimensions: n0 along axis 0 (rows, y), n1 along axis 1 (columns, x).
n0=from_image.pix.shape[0]
n1=from_image.pix.shape[1]

# Pixel coordinates shared by all wavelength ranges (hoisted out of the loop).
# 'rows' is a column vector so that comparisons against per-column limits
# broadcast to the full (n0,n1) image shape without materializing an index grid.
x=np.arange(n1)
rows=np.arange(n0)[:,None]

for tmp in args.wrange.split(",") :
    # Decode 'wmin:wmax'. A wrong number of fields and a non-numeric field
    # both raise ValueError and are reported the same way.
    try :
        wmin_str,wmax_str=tmp.split(":")
        wmin=float(wmin_str)
        wmax=float(wmax_str)
    except ValueError :
        print("ERROR: cannot decode '{}' from '{}' as wmin:wmax".format(tmp,args.wrange))
        sys.exit(12)

    print("copy/paste wavelength region {}:{}".format(wmin,wmax))

    # CCD coordinates of both ends of the wavelength range, for each fiber.
    fiber_xmin = np.zeros(tset.nspec)
    fiber_ymin = np.zeros(tset.nspec)
    fiber_xmax = np.zeros(tset.nspec)
    fiber_ymax = np.zeros(tset.nspec)
    for i in range(tset.nspec) :
        fiber_xmin[i] = tset.x_vs_wave(fiber=i,wavelength=wmin)
        fiber_ymin[i] = tset.y_vs_wave(fiber=i,wavelength=wmin)
        fiber_xmax[i] = tset.x_vs_wave(fiber=i,wavelength=wmax)
        fiber_ymax[i] = tset.y_vs_wave(fiber=i,wavelength=wmax)

    # Interpolate across fibers to get, for every image column, the first and
    # last row of the region, and convert to int = pixel index.
    # NOTE(review): np.interp expects increasing x coordinates; this assumes
    # fiber x positions increase with fiber index -- confirm for the trace set.
    ymin=np.interp(x,fiber_xmin,fiber_ymin).astype(int)
    ymax=np.interp(x,fiber_xmax,fiber_ymax).astype(int)

    # Select, in each column, the rows in [ymin,ymax).
    mask=(rows>=ymin[None,:])&(rows<ymax[None,:])

    # An empty selection (e.g. wmin and wmax given in the wrong order) would
    # make every operation below a silent no-op; say so explicitly.
    if not mask.any() :
        print("WARNING: no pixel selected for wavelength region {}:{}".format(wmin,wmax))
        continue

    if args.replace :
        # Overwrite the destination with the (scaled) source.
        into_image.pix[mask]  = scale*from_image.pix[mask]
        into_image.ivar[mask] = scale_ivar*from_image.ivar[mask]
        into_image.mask[mask] = from_image.mask[mask]
    else :
        # Co-add: sum the signals and OR the mask bits.
        into_image.pix[mask]  += scale*from_image.pix[mask]
        # A pixel with no valid source measurement gets ivar=0 in the output.
        into_image.ivar[mask] *= (from_image.ivar[mask]>0)
        into_image.mask[mask] |= from_image.mask[mask]
        # Where both inputs are valid, the variances add; the variance of the
        # pasted signal is 1/(scale_ivar*ivar_from).
        mask2=mask&(into_image.ivar>0)&(from_image.ivar>0)
        into_image.ivar[mask2] = 1/(1./into_image.ivar[mask2]+1./from_image.ivar[mask2]/scale_ivar)


# Save the modified destination image and report the output path.
outfile = args.outfile
write_image(outfile, into_image)
print("wrote {}".format(outfile))
