#!/usr/bin/env python
# -*- coding: utf-8 -*-

#from __future__ import print_function
#from __future__ import division
#from __future__ import absolute_import


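# Read a YODA file containing theory predictions for the full Run 2 m4l analysis
# (nominal plus scale and, optionally, PDF weight variations) and write the
# theory (/THY/) YODA files for Rivet/Contur. NOTE(review): this script does not
# write any dummy REF data.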


import rivet, yoda, sys, os
import numpy as np
from optparse import OptionParser

# Behaviour switches. These are the defaults for the -v/--verbose and --pdf
# command-line flags below, which can switch them on.
verbose=False
pdf_vars=False

# This script has no module docstring, so the usage text is given explicitly.
parser = OptionParser(
    usage="%prog [options]\n\n"
          "Read a YODA file with theory predictions and weight variations and "
          "write /THY/ YODA files with scale (and optionally PDF) uncertainties.")
parser.add_option("-i", "--input", dest="INPUT",
                  default="EW_SM+QCD_SM-8TeV.yoda",
                  help="input file")
parser.add_option("-o", "--output", dest="OUTPUT",
                  default="./Theory-8-scale-pdf-variation",
                  help="output directory for the theory YODA files")
parser.add_option("-v", "--verbose", dest="VERBOSE", action="store_true",
                  default=verbose, help="print debugging information")
parser.add_option("--pdf", dest="PDF_VARS", action="store_true",
                  default=pdf_vars,
                  help="also compute PDF uncertainties from variations [7]-[106]")

# Command line parsing (positional arguments are not used)
opts, _ = parser.parse_args()
verbose = opts.VERBOSE
pdf_vars = opts.PDF_VARS

# Open the input file
aos = yoda.read(opts.INPUT)

# NOTE(review): not used anywhere in this script.
thyObjects = []

# output directory for theory yoda
theoryrivet_out = opts.OUTPUT

print("Opened {} with {} YODA objects".format(opts.INPUT, len(aos)))

# Variation indices in the input file: the object at "<nominal path>[i]" holds
# weight variation i of that nominal.
_SCALE_VAR_IDS = range(1, 7)    # QCD scale variations [1]..[6]
_PDF_VAR_IDS = range(7, 107)    # PDF variations [7]..[106]


def _yvals(scatter):
    """Return the y values of a scatter's points as a numpy array."""
    return np.array([p.y() for p in scatter.points()])


def _variation_yvals(nominal_path, ivar):
    """Return the y values of weight variation ``ivar`` of the object at ``nominal_path``.

    The variation is read from ``aos`` under the key ``"<nominal_path>[<ivar>]"``
    and converted with ``mkScatter()``. Raises KeyError if it is missing.
    """
    histvar = aos["{}[{}]".format(nominal_path, ivar)]
    if verbose: print(histvar)
    return _yvals(histvar.mkScatter())


# analysis name -> list of /THY/ scatters to write for it
analyses = {}

for path, ao in aos.items():

    # skip RAW histos
    if rivet.isRawPath(ao.path()):
        continue

    ref = ao.mkScatter() if ao.type() == "Histo1D" else ao

    # Only 2D scatters are handled. Objects named "...[i]" are weight variations;
    # they are read while processing their nominal, not on their own.
    if ref.type() != "Scatter2D" or ref.name().endswith(']'):
        continue

    analysisName = ref.path().strip("/").split("/")[0]
    print("Found {}, {}.".format(analysisName, ref.type()))
    ref.setPath("/THY/" + analysisName + "/" + ref.name())
    ref.setTitle("PowhegBoxZpWp")

    noms = _yvals(ref)

    # Scale uncertainty: envelope of the nominal and variations [1]..[6], stored
    # as signed differences from the nominal (scaleup >= 0, scaledown <= 0).
    # Variations are looked up from the original key `path`, so objects in
    # sub-directories of an analysis are found too.
    scale_vals = np.vstack([noms] + [_variation_yvals(path, ivar)
                                     for ivar in _SCALE_VAR_IDS])
    scaleup = scale_vals.max(axis=0) - noms
    scaledown = scale_vals.min(axis=0) - noms
    if verbose: print("scaleup", scaleup)
    if verbose: print("scaledown", scaledown)

    if pdf_vars:
        # PDF uncertainty per arXiv:1510.03865 eqs. 21 & 22: standard deviation
        # of the N variations about their mean, with 1/(N-1) normalisation.
        pdf_vals = np.vstack([_variation_yvals(path, ivar) for ivar in _PDF_VAR_IDS])
        if verbose: print("pdfMean is", pdf_vals.mean(axis=0))
        pdfErr = pdf_vals.std(axis=0, ddof=1)
        if verbose: print("pdfErr = ", pdfErr)
    else:
        pdfErr = np.zeros_like(noms)

    # per-point error breakdown, stored as an annotation on the scatter
    ErrorBreakdown = {}
    if verbose: print("Looping over {} points".format(ref.numPoints()))
    if verbose: print(len(scaleup), len(pdfErr))
    for ipoint, point in enumerate(ref.points()):
        if verbose: print(ipoint, point)

        stat_dn, stat_up = point.yErrs()
        sc_dn = scaledown[ipoint]
        sc_up = scaleup[ipoint]
        pdf = pdfErr[ipoint]

        # plain floats rather than numpy scalars, so the annotation value does
        # not carry numpy-specific reprs/tags when serialized
        breakdown = {"stat": {"up": float(stat_up), "dn": float(stat_dn)},
                     "scale": {"up": float(sc_up), "dn": float(sc_dn)}}
        if pdf_vars:
            breakdown["pdf"] = {"up": float(pdf), "dn": -1 * float(pdf)}
        ErrorBreakdown[ipoint] = breakdown

        if verbose:
            print("Stat error: {}".format(point.yErrs()))
            print(sc_dn, sc_up)
            print(-1 * pdf, pdf)

        point.setErrMinus(2, sc_dn, "scale")
        point.setErrPlus(2, sc_up, "scale")
        # total y error: stat, scale and pdf added in quadrature
        point.setYErrs(np.sqrt(sc_dn**2 + pdf**2 + stat_dn**2),
                       np.sqrt(sc_up**2 + pdf**2 + stat_up**2))
        point.setErrMinus(2, -1.0 * pdf, "pdf")
        point.setErrPlus(2, pdf, "pdf")
        if verbose:
            print("Total error: {}".format(point.yErrs()))

    if verbose: print("hasValidErrorBreakdown:", ref.hasValidErrorBreakdown())
    ref.setAnnotation("ErrorBreakdown", ErrorBreakdown)

    analyses.setdefault(analysisName, []).append(ref)



# Write one "<analysis>-Theory.yoda" file per analysis into the output directory.
# The directory is created first if needed, since nothing else in this script
# creates it.
if analyses and not os.path.isdir(theoryrivet_out):
    os.makedirs(theoryrivet_out)

for ana, anaObjects in analyses.items():
    theoryRef = os.path.join(theoryrivet_out, ana + "-Theory.yoda")
    print("Writing THY data to {}".format(theoryRef))
    yoda.write(anaObjects, theoryRef)
