#!python

"""
Hartmann doors data analysis script
"""

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
import fitsio

from desiutil.log import get_logger
from desispec.io import read_xytraceset
from desispec.calibfinder import sp2sm
from desispec.focus import piston_and_tilt_to_gauge_offsets,test_gauge_offsets,RPIXSCALE

# Command-line interface. The epilog documents the whole observing and
# reduction sequence that produces the inputs of this script.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="Hartmann doors analysis",
epilog='''
Hartmann doors data analysis sequence:
1) take 2 series of exposures with arc lamps on the white spots,
   one with the left hartmann door closed, and the other one with
   the right hartmann door closed.
2) preprocess the exposures (see desi_preproc )
3) for all exposures, fit trace shifts using the preproc images and either the default or the most recent
   nightly psf (use desi_compute_trace_shifts --arc-lamps --psf psf-yyy.fits -i preproc-xxx.fits -o psf-xxx.fits)
4) run this desi_hartmann_analysis script, using as input the shifted psf(s) from the previous step,
   specifying the set of exposures with the left door closed and the one with the right door closed

This script will determine an average offset in Y_ccd (in pixel units) that
can be converted into a camera focus offset.

A positive delta Y means a negative value of the 'defocus'
as defined in https://desi.lbl.gov/DocDB/cgi-bin/private/ShowDocument?docid=4585 .

In order to restore a good focus, the cryostat has to be moved in order to increase
the "absolute value" of the gauge readings.

The average ratio of pixel_offsets/defocus = 20.240, 20.322, 20.389 pixel/mm for the BLUE, RED and NIR camera respectively.
'''
)

# nargs="+" (instead of "*") makes argparse itself reject an option given
# without any file, rather than letting an empty list reach the analysis.
parser.add_argument('--left-closed-psf', type = str, required=True, nargs="+",
                    help = 'path to psf with trace coordinates for arc lamp obs with closed left hartmann door')
parser.add_argument('--right-closed-psf', type = str, required=True, nargs="+",
                    help = 'path to psf with trace coordinates for arc lamp obs with closed right hartmann door')
parser.add_argument('--plot', action = 'store_true',
                    help = 'show the measured focus and the best fit plane')
parser.add_argument('--camera', type = str, required = True, help="b r or z")

args = parser.parse_args()

log = get_logger()

# The i-th left-closed psf is compared with the i-th right-closed psf, so
# the two lists have to pair up exactly. (SystemExit is raised directly
# because this script does not import sys.)
if len(args.left_closed_psf) != len(args.right_closed_psf) :
    log.error("got {} left-closed and {} right-closed psf files, the two lists must have the same length".format(
        len(args.left_closed_psf),len(args.right_closed_psf)))
    raise SystemExit(12)

# Fiber sampling grid: every 10th fiber from 0 to 490, plus fiber 499.
# It does not depend on the input files, so it is built once.
fibers=np.arange(0,51)*10
fibers[-1]=499
log.debug(f"use fibers           = {fibers}")

camera=None
x_vals=[]
y_vals=[]
dy_vals=[]
for left_filename,right_filename in zip(args.left_closed_psf,args.right_closed_psf) :

    # a pair is skipped (with a warning) as soon as one of its files is missing
    if not os.path.isfile(left_filename) :
        log.warning("missing "+left_filename)
        continue
    if not os.path.isfile(right_filename) :
        log.warning("missing "+right_filename)
        continue

    # All inputs must come from the same camera: the first file read sets
    # the reference and every other file is compared to it. An explicit
    # check is used because an assert would be stripped by "python -O".
    for filename in (left_filename,right_filename) :
        file_camera=fitsio.read_header(filename,"PSF")["CAMERA"].strip().lower()
        if camera is None :
            camera=file_camera
        elif file_camera != camera :
            log.error("camera '{}' of {} differs from camera '{}' of the previous files".format(file_camera,filename,camera))
            raise SystemExit(12)

    left  = read_xytraceset(left_filename)
    right = read_xytraceset(right_filename)

    # 10 wavelengths, staying 200 wavelength units inside the range of the left trace set
    wave=np.linspace(left.wavemin+200,left.wavemax-200,10)
    log.debug(f"use wavelength = {wave}")

    # On the (fiber,wavelength) grid: mean CCD position of the two traces,
    # and left-minus-right offset along y.
    x=np.zeros((fibers.size,wave.size))
    y=np.zeros((fibers.size,wave.size))
    dy=np.zeros((fibers.size,wave.size))
    for f,fiber in enumerate(fibers) :
        x[f]=(left.x_vs_wave(fiber,wave)+right.x_vs_wave(fiber,wave))/2
        yleft  = left.y_vs_wave(fiber,wave)
        yright = right.y_vs_wave(fiber,wave)
        y[f]=(yleft+yright)/2
        dy[f]=(yleft-yright)

    x_vals.append(x.ravel())
    y_vals.append(y.ravel())
    dy_vals.append(dy.ravel())

    meandy=np.mean(dy)
    rmsdy=np.std(dy)
    log.info("LEFT = {} RIGHT= {} dy = {:.3f} +- {:.3f} pixels".format(left_filename,right_filename,meandy,rmsdy))



# Every input pair may have been skipped (missing files); np.hstack would
# then fail on the empty list with an unhelpful message.
if len(dy_vals)==0 :
    log.error("no valid pair of left and right closed psf files, nothing to analyse")
    raise SystemExit(12)

# From here on dy_vals is one flat array with all the grid samples of all
# the pairs, and nmeas counts those samples (not the number of exposures).
dy_vals = np.hstack(dy_vals)
nmeas=len(dy_vals)

# The central value is the median of the samples. Its uncertainty is the
# scatter divided by sqrt(nmeas-1) and inflated by sqrt(pi/2) (presumably
# the Gaussian-case ratio between the error of a median and of a mean).
meandy=np.median(dy_vals)
if nmeas>=2 :
    errdy=np.sqrt(np.pi/2./(nmeas-1.))*np.std(dy_vals)
else :
    errdy=0.

camera=str(camera).replace("'","").strip(" ")

# pixel_offsets/defocus ratio in pixel/mm, keyed by the first letter of
# the camera name
pixels_per_mm={"b":20.240,"r":20.322,"z":20.389}
if camera[:1] not in pixels_per_mm :
    log.error("camera name '{}' does not start with b,r or z: I don't know what to do".format(camera))
    # SystemExit is raised directly because this script does not import sys
    raise SystemExit(12)
# the minus sign makes a positive delta Y give a negative defocus
focus_pixels2mm = -1/pixels_per_mm[camera[:1]] # mm/pixel

defocus=focus_pixels2mm*meandy
err=errdy*np.abs(focus_pixels2mm)

# the second character of the camera name is the spectrograph number
spectro=int(camera[1])
sm=sp2sm(spectro)
log.info("SM{}-{} LEFT-RIGHT(closed) DELTA = {:+.3f} +- {:.4f} pix (N={})".format(sm,camera,meandy,errdy,nmeas))
log.info("SM{}-{} DEFOCUS = {:+.3f} +- {:.4f} mm (N={}) (the correction to apply is of opposite sign)".format(sm,camera,defocus,err,nmeas))

# Best focus plane: linear least-squares fit of
#   focus = c0 + c1*rx + c2*ry
# to all the samples, solved through the normal equations.
x = np.hstack(x_vals)
y = np.hstack(y_vals)
# focus correction (opposite sign of the defocus) for each sample
focus = - focus_pixels2mm*np.hstack(dy_vals)

# CCD coordinates rescaled with RPIXSCALE and shifted by -1
rx = x/RPIXSCALE -1
ry = y/RPIXSCALE -1

# design matrix, one row per plane coefficient
h = np.vstack((np.ones(rx.size), rx, ry))
a = np.dot(h, h.T)
b = np.dot(h, focus)
ai = np.linalg.inv(a)
focus_plane_coefficients = np.dot(ai, b)
log.debug(focus_plane_coefficients)

# the coefficients are scaled by 1000 before the conversion to gauge offsets
best_focus_gauge_offsets = piston_and_tilt_to_gauge_offsets(
    args.camera,
    focus_plane_coefficients*1000., # conversion from mm to microns
)
names = ["TOP","LEFT","RIGHT"]
best_focus_gauges = np.array([best_focus_gauge_offsets[k] for k in names])
log.info("best focus gauges offsets to add ({},{},{}) = {:5.3f} {:5.3f} {:5.3f} mm".format(*names,*best_focus_gauges))

# value of the fitted plane at each sample
focus_fit = np.dot(focus_plane_coefficients, h)

if args.plot :
    # x, y, focus and focus_fit already hold every sample, so each panel
    # is drawn with a single pair of plot calls (measurement and fit).
    plt.figure("focus")

    # left panel: focus as a function of the CCD x coordinate
    a1= plt.subplot(121)
    a1.plot(x,focus,"o",label="meas")
    a1.plot(x,focus_fit,".",label="best fit plane")
    a1.grid()
    a1.set_xlabel("x ccd (fiber direction)")
    a1.set_ylabel("focus correction to apply (mm)")
    a1.legend()

    # right panel: same quantities as a function of the CCD y coordinate
    a2= plt.subplot(122)
    a2.plot(y,focus,"o",label="meas")
    a2.plot(y,focus_fit,".",label="best fit plane")
    a2.grid()
    a2.set_xlabel("y ccd (wavelength direction)")

    plt.show()
