#!/usr/bin/env python

"""
Compare a new bias to a default bias
"""

import os, sys
import numpy as np
import fitsio
from desiutil.log import get_logger
import desispec.io
from desispec.calibfinder import CalibFinder
from desispec.ccdcalib import compare_bias

#-------------------------------------------------------------------------
import argparse

# Command-line interface.
#
# --bias is mandatory: the script unconditionally opens that file to read its
# CAMERA keyword, so a missing value is reported by argparse as a usage error
# instead of surfacing later as an obscure failure inside the FITS reader.
# --night and --expid are forwarded unchanged to desispec.io.findfile, which
# is left to validate them.
p = argparse.ArgumentParser(
        description="Compare bias file to default bias for raw data")
p.add_argument('-n', '--night', type=int, help='YEARMMDD night')
p.add_argument('-e', '--expid', type=int, help='Exposure ID')
p.add_argument('-b', '--bias', type=str, required=True,
        help='bias file to compare')
p.add_argument('--debug', action='store_true', help='start ipython at end')
args = p.parse_args()

# The bias file's own header says which camera it belongs to; normalize the
# keyword value to upper case without surrounding whitespace so it can be
# used as the HDU name in the raw file.
bias_header = fitsio.read_header(args.bias)
camera = bias_header['CAMERA'].upper().strip()

# Locate the raw exposure and read the two headers handed to CalibFinder:
# the 'SPEC' HDU header and the header of this camera's HDU (in that order).
rawfile = desispec.io.findfile('raw', args.night, args.expid)
with fitsio.FITS(rawfile) as fx:
    rawhdr, camhdr = [fx[extname].read_header() for extname in ('SPEC', camera)]

# Ask the calibration finder which BIAS file applies to these headers.
cf = CalibFinder([rawhdr, camhdr])
defaultbias = cf.findfile('BIAS')

# Two difference maps come back: judging by the plot labels used later in
# this script, mdiff1 is for the bias given on the command line and mdiff2
# is for the default bias.
(mdiff1, mdiff2) = compare_bias(rawfile, args.bias, defaultbias)

def _diff_stats(diff):
    """Return (mean, max absolute value, standard deviation) of *diff*."""
    return np.mean(diff), np.max(np.abs(diff)), np.std(diff)


# Summary statistics for each difference map.  The maxabs*/std* values are
# printed on the histograms; mean1/mean2 are computed as well but are not
# referenced again in this script.
mean1, maxabs1, std1 = _diff_stats(mdiff1)
mean2, maxabs2, std2 = _diff_stats(mdiff2)

import matplotlib.pyplot as plt

# Display ranges depend on the camera: names starting with 'B' get a wider
# image color scale and histogram range than the other cameras.
if camera.startswith('B'):
    img_limit, hist_limit = 4, 2
else:
    img_limit, hist_limit = 0.25, 0.25

imgopts = dict(vmin=-img_limit, vmax=img_limit, aspect='auto')
histopts = dict(bins=50, range=(-hist_limit, hist_limit))

# One entry per figure row:
#   (difference map, y-axis label, image title, histogram title,
#    max |diff|, std dev, extra keyword arguments for the text annotations)
# Only the top row carries titles and an explicit annotation font size.
panels = (
    (mdiff1, 'Diff w/ nightly bias', f'{args.night}/{args.expid} {camera}',
     'median(image-bias) in patches', maxabs1, std1, dict(fontsize=10)),
    (mdiff2, 'Diff w/ default bias', None, None, maxabs2, std2, dict()),
)

plt.figure()
for irow, (diff, ylabel, imgtitle, histtitle, maxabs, std, textopts) in enumerate(panels):
    # Left column: the difference map shown as an image.
    plt.subplot(2, 2, 2*irow + 1)
    plt.imshow(diff, **imgopts)
    if imgtitle is not None:
        plt.title(imgtitle)
    plt.ylabel(ylabel)

    # Right column: histogram of the same values, annotated in axes-relative
    # coordinates so the text stays in the upper-left corner of the panel.
    ax = plt.subplot(2, 2, 2*irow + 2)
    plt.hist(diff.ravel(), **histopts)
    if histtitle is not None:
        plt.title(histtitle)
    plt.text(0.03, 0.9, f'maxabs {maxabs:.2f}', transform=ax.transAxes, **textopts)
    plt.text(0.03, 0.8, f'stddev {std:.2f}', transform=ax.transAxes, **textopts)

plt.show()

# With --debug, open an interactive IPython session after the plot window
# has been shown.
if args.debug:
    import IPython
    IPython.embed()



