// C++
#include <iostream>
#include <string>
#include <utility>
#include <cassert>

// ROOT
#include "TMath.h"
#include "TH1F.h"
#include "TH2F.h"
#include "TH1D.h"
#include "TH2D.h"
#include "TH3F.h"
#include "TFile.h"
#include "TTree.h"
#include "TCanvas.h"

// MINOS 
#include "Conventions/Detector.h"
#include "MCReweight/MCEventInfo.h"
#include "MCReweight/Zbeam.h"
#include "MCReweight/BeamSys.h"
#include "MCReweight/Zfluk.h"
#include "MCReweight/MCReweight.h"
#include "MCReweight/NuParent.h"
#include "MCReweight/NeugenWeightCalculator.h"

// Local
#include "FitPTRW.h"

ClassImp(FitPTRW)
using namespace std;

//---------------------------------------------------------------------------------------------------
// Default constructor: all option flags off, empty MRCC tables, beam reweighting
// configured for "PiMinus_CedarDaikon" and a NEUGEN cross-section calculator
// attached to the MCReweight instance.
FitPTRW::FitPTRW()
   : fZbeam(),
     fZfluk()
{
  // Fit bookkeeping and option switches: iteration counter at zero,
  // optional NEUGEN / MRCC reweighting and kernel smoothing disabled.
  iterations   = 0;
  fUseNeugen   = false;
  fUseNueMRCC  = false;
  fUseSmooth   = false;
  fSmoothWidth = 0.0;

  // MRCC correction table (per-bin reco-energy start, end and weight) starts empty.
  fMRCCStart.clear();
  fMRCCEnd.clear();
  fMRCCWeight.clear();

  // Beam-systematics reweighting configuration used by fZbeam.
  fZbeam.SetReweightConfig("PiMinus_CedarDaikon");

  // Cross-section reweighting: register a NEUGEN calculator with the shared
  // MCReweight instance and create the Registry later passed to ComputeWeight().
  // NOTE(review): neither the calculator nor rwtconfig is deleted in this class
  // (presumably MCReweight owns the calculator - confirm), and every FitPTRW that
  // is constructed registers one more calculator with the same MCReweight instance.
  mcr = &MCReweight::Instance();
  mcr->AddWeightCalculator(new NeugenWeightCalculator());
  rwtconfig = new Registry();
}

//---------------------------------------------------------------------------------------------------
// Destructor.
FitPTRW::~FitPTRW()
{
  // Nothing is released here.
  // NOTE(review): rwtconfig (allocated with new in the constructor) and the
  // histograms cloned into fDataHist / fMCHist / fReWeightHist are not deleted,
  // so they outlive this object. Confirm ownership and copy semantics in
  // FitPTRW.h before adding cleanup, to avoid double deletes on copies.
}

//---------------------------------------------------------------------------------------------------
// Read the data reconstructed-energy histogram `histname` from `filename` and store a
// detached copy in fDataHist[beam_type], renamed "reco_enu_data_<nu>_<beam>".
// On a missing file or histogram an error is printed and fDataHist is left untouched.
// The TFile is opened only for the copy and is deleted on every exit path.
void FitPTRW::FillDataHist(const string &filename, const string &histname, const FitBeam::FitBeam_t beam_type)
{
   TFile *datafile = new TFile(filename.c_str(), "READ");
   if(!datafile || !datafile -> IsOpen())
   {
      cerr << "Could not find file " << filename << endl;
      delete datafile;
      return;
   }

   cout << "Adding data for " << FitBeam::NeutrinoTypeAsString(beam_type) <<" "<<FitBeam::BeamTypeAsString(beam_type) << endl;

   // Get() returns NULL for a missing key, and dynamic_cast returns NULL when the
   // stored object is not a TH1D (e.g. a TH1F), so both cases are caught here.
   TH1D *h = dynamic_cast<TH1D*>(datafile->Get(histname.c_str()));
   if(!h)
   {
      cerr << "Could not find TH1D " << histname << " in " << filename << endl;
      datafile -> Close();
      delete datafile;
      return;
   }

   // Detach the copy from the file so it survives the file being deleted below.
   TH1D *data = dynamic_cast<TH1D*>(h -> Clone());
   data -> SetDirectory(0);
   data -> SetName(Form("reco_enu_data_%s_%s",FitBeam::NeutrinoTypeAsString(beam_type).c_str(),
                        FitBeam::BeamTypeAsString(beam_type).c_str()));
   // TH1::Sumw2() warns when the sum-of-weights array already exists (it may have been
   // saved with the source histogram), so only create it when absent.
   if(data -> GetSumw2N() == 0) data -> Sumw2();

   // NOTE(review): a histogram previously stored for this beam_type is replaced
   // without being deleted.
   fDataHist[beam_type] = data;

   datafile -> Close();
   delete datafile;
}

//---------------------------------------------------------------------------------------------------
void FitPTRW::FillMCVectors(const string &filename, const string &panname, FitBeam::FitBeam_t beam_type, const bool rwinuk)
{
   TFile *mcfile = new TFile(filename.c_str(),"READ");   
   if(!mcfile || !mcfile -> IsOpen())
   {
      cerr << "Could not find file " << filename << endl;
      return;
   }

   TTree *mctree = dynamic_cast<TTree *> (mcfile->Get(panname.c_str()));
   if(!mctree)
   {
      cerr << "Could not find " << panname << " tree in " << filename << endl;
      return;
   }

   if(!fDataHist[beam_type])
     {
       cerr << "Please fill the data histogram first, since MC histogram bins are determined using data hist" << endl;
       return;
     }
   // check if beam_type already exits
   map<FitBeam::FitBeam_t, vector<FitData> >::iterator beam_it = fMCData.find(beam_type);
   map<FitBeam::FitBeam_t, TH1D *>::iterator hit = fMCHist.find(beam_type);

   if(beam_it == fMCData.end())
   {
     cout << "Adding MC for " << FitBeam::NeutrinoTypeAsString(beam_type)<< " "
	  << FitBeam::BeamTypeAsString(beam_type) << endl;
     
     const pair<FitBeam::FitBeam_t, vector<FitData> > p(beam_type, vector<FitData>());
     beam_it = fMCData.insert(p).first;
     
     TH1D *h=(TH1D*)fDataHist[beam_type]->Clone(
		      Form("start_%s_%s",FitBeam::NeutrinoTypeAsString(beam_type).c_str(),
    			   FitBeam::BeamTypeAsString(beam_type).c_str()));

     h -> Reset();
     h -> Sumw2();
     h -> SetDirectory(0);
     const pair<FitBeam::FitBeam_t, TH1D *> ph(beam_type, h);
     hit = fMCHist.insert(ph).first;
     
     assert(beam_it != fMCData.end() && "Failed to insert an value into map");
   }
   else
   {
     cerr << "This beam type was already filled: " << FitBeam::BeamTypeAsString(beam_type) << endl;
     return;
   }
   
   hit ->second -> Reset();
   // initialize init and best chi2 to 0
   initchi2[beam_type]=0.;
   initndf[beam_type]=0;
   
   bestchi2[beam_type]=0.;
   bestndf[beam_type]=0;
  
   //loop over each tree and fill histograms with new weights based on the values in par;
   //compute chi2 and ndf for each, add up chi2 and ndf, which is what needs to be minimized.

   int pass;
   int ntrack;
   int is_fid;
   //int trk_fit_pass;
   //int trkendu;
   //int trkendv;
   
   float reco_enu;
   float true_enu;
   float reco_emu;
   float reco_eshw;
   float true_eshw;
   float tpx;
   float tpy;
   float tpz;

   float ppdxdz;
   float ppdydz;   
   float pppz;
   //int ntype;
   int tptype;
   int ptype;
   
   //   float dave_cc_pid;
   float reco_y;
   float nu_px;
   float nu_py;
   float nu_pz;
   float tar_e;
   float tar_px;
   float tar_py;
   float tar_pz;
   float true_y;
   float true_x;
   float true_q2;
   float true_w2;
   int inu;
   int process;
   int initial_state;
   int nucleus;
   int resnum;
   int had_fs;
   int cc_nc;
   int is_cev;
   Double_t modReweight;
    
   //for gnumi v18 fluxes we need this to find muon parent
   NuParent* parent=0;
   
   //mycuts
   //int emu_meth;
   //float trkeqp;
   //float trkqp;
   //float evtsigfull,trksigfull,evtsigpart,trksigpart;
   //int duvvtx;
    
   cout<<"Filling MC histograms and vectors!"<<endl;
   
   if(rwinuk) cout<<" reweighting for inuke "<<endl;
  
   if(panname == "pan"){ 
     //set branch addresses to new tree
     mctree->SetBranchAddress("pass",&pass);
     mctree->SetBranchAddress("ntrack",&ntrack);
     mctree->SetBranchAddress("is_fid",&is_fid);
     //mctree->SetBranchAddress("trk_fit_pass",&trk_fit_pass);
     //mctree->SetBranchAddress("duvvtx",&duvvtx);
     //mctree->SetBranchAddress("trkendu",&trkendu);
     // mctree->SetBranchAddress("trkendv",&trkendv);
     //mctree->SetBranchAddress("trkqp",&trkqp);
   
     mctree->SetBranchAddress("reco_enu",&reco_enu);
     mctree->SetBranchAddress("true_enu",&true_enu);
     mctree->SetBranchAddress("reco_emu",&reco_emu);
     mctree->SetBranchAddress("true_eshw",&true_eshw);
     mctree->SetBranchAddress("reco_y",&reco_y);
     mctree->SetBranchAddress("true_x",&true_x);
     mctree->SetBranchAddress("true_y",&true_y);
     mctree->SetBranchAddress("true_q2",&true_q2);
     mctree->SetBranchAddress("true_w2",&true_w2);
     mctree->SetBranchAddress("process",&process);
     mctree->SetBranchAddress("initial_state",&initial_state);
     mctree->SetBranchAddress("nucleus",&nucleus);
     mctree->SetBranchAddress("is_cev",&is_cev);
     mctree->SetBranchAddress("cc_nc",&cc_nc);
     //mctree->SetBranchAddress("evtsigfull",&evtsigfull);
     //mctree->SetBranchAddress("evtsigpart",&evtsigpart);
     //mctree->SetBranchAddress("trksigfull",&trksigfull);
     //mctree->SetBranchAddress("trksigpart",&trksigpart);
     mctree->SetBranchAddress("ntrack",&ntrack);
     //mctree->SetBranchAddress("emu_meth",&emu_meth);
     //mctree->SetBranchAddress("trkeqp",&trkeqp);
     //mctree->SetBranchAddress("trkqp",&trkqp);
   
     mctree->SetBranchAddress("reco_eshw",&reco_eshw);
     mctree->SetBranchAddress("tpx",&tpx);
     mctree->SetBranchAddress("tpy",&tpy);
     mctree->SetBranchAddress("tpz",&tpz);
     mctree->SetBranchAddress("tptype",&tptype);
     mctree->SetBranchAddress("ptype",&ptype);
     mctree->SetBranchAddress("pppz",&pppz);
     mctree->SetBranchAddress("ppdydz",&ppdydz);
     mctree->SetBranchAddress("ppdxdz",&ppdxdz);
   
     //   mctree->SetBranchAddress("ntype",&ntype);
     //mctree->SetBranchAddress("dave_cc_pid",&dave_cc_pid);
     mctree->SetBranchAddress("nu_px",&nu_px);
     mctree->SetBranchAddress("nu_py",&nu_py);
     mctree->SetBranchAddress("nu_pz",&nu_pz);
     mctree->SetBranchAddress("tar_px",&tar_px);
     mctree->SetBranchAddress("tar_py",&tar_py);
     mctree->SetBranchAddress("tar_pz",&tar_pz);
     mctree->SetBranchAddress("tar_e",&tar_e);
     mctree->SetBranchAddress("inu",&inu);
  
     mctree->SetBranchAddress("resnum",&resnum);
     mctree->SetBranchAddress("resnum",&had_fs);
     
     //mctree->SetBranchAddress("duvvtx",&duvvtx);      
   }

   if(panname == "ana_nue"){
    mctree->SetMakeClass(1);
    mctree->SetBranchAddress("srtrack.phCCGeV", &reco_emu);
    mctree->SetBranchAddress("srshower.phCCGeV", &reco_eshw);
                                                                              
    mctree->SetBranchAddress("mctrue.nuEnergy", &true_enu);
    mctree->SetBranchAddress("mctrue.showerEnergy",&true_eshw);
    mctree->SetBranchAddress("mctrue.nuFlavor",&inu);
                                                                              
    mctree->SetBranchAddress("fluxinfo.tpx", &tpx);
    mctree->SetBranchAddress("fluxinfo.tpy", &tpy);
    mctree->SetBranchAddress("fluxinfo.tpz", &tpz);
    mctree->SetBranchAddress("fluxinfo.tptype", &tptype);
                                                                              
    mctree->SetBranchAddress("mctrue.interactionType", &cc_nc);
    mctree->SetBranchAddress("xsecweights.xsecweight", &modReweight);
  }


   int z=0;
  //loop over events in tree
   while(mctree->GetEntry(z)>0)
   {
      if(z % 100000==0){
	 cout<<"on entry ND "<<z<<endl;
      }
       ++z;
       if (CutEvent()) continue;

       if (panname == "ana_nue"){
	 // In the Nue Ntuples the default value is -9999
          if(reco_emu < 0) reco_emu = 0;
          if(reco_eshw < 0) reco_eshw = 0;
       }
 
      const double inuke_reco_eshw = InukeParam(reco_eshw,true_eshw,rwinuk);

      FitData fit_data;

      reco_enu = reco_emu + inuke_reco_eshw;
      
      if (panname == "ana_nue")
	{
	  pppz=tpz;
	  if (tpz!=0.)
	    {
	      ppdxdz=tpx/tpz;
	      ppdydz=tpy/tpz;
	    }
	  else 
	    {
	      ppdxdz=0.;
	      ppdydz=0.;
	    }
	  ptype=tptype;
	}
      
      fit_data.RecoEnergy(reco_enu);
      fit_data.TrueEnergy(true_enu);
      fit_data.RecoTrkEnergy(reco_emu);
      fit_data.RecoShwEnergy(inuke_reco_eshw);
      //      fit_data.TargetPt(sqrt(ppdxdz*ppdxdz+ppdydz*ppdydz)*pppz);
      fit_data.TargetPt(sqrt(tpx*tpx+tpy*tpy));
      fit_data.TargetPz(tpz);
      fit_data.NuType(inu);
      fit_data.ParentType(tptype);
      fit_data.InteractionType(cc_nc);
      fit_data.InteractionProcess(process);
      fit_data.IsContained(is_cev);

      double ge = 1.0;
      if(fUseNeugen&&panname == "pan"){
        MCEventInfo ei;
        ei.UseStoredXSec(true);
        ei.nuE=true_enu;
        ei.nuPx=nu_px;
        ei.nuPy=nu_py;
        ei.nuPz=nu_pz;
        ei.tarE=tar_e;
        ei.tarPx=tar_px;
        ei.tarPy=tar_py;
        ei.tarPz=tar_pz;
        ei.y=true_y;
        ei.x=true_x;
        ei.q2=true_q2;
        ei.w2=true_w2;
        ei.iaction=cc_nc;
        ei.inu=inu;
        ei.iresonance=process;
        ei.initial_state=initial_state;
        ei.nucleus=nucleus;
        ei.had_fs=had_fs;
        NuParent *np=0;
        if(ei.iresonance!=1005){
  	  ge = mcr->ComputeWeight(&ei,np,rwtconfig);
        }
      }
      
      fit_data.GeneratorError(ge-1.0);
      
      double gw = 1.0;
      if(panname == "ana_nue"){
         gw = modReweight;
	 if (cc_nc==0&&fUseNueMRCC)
	   {
	     double mrcc = 1.0;
	     double recoE = reco_enu;
	     for(unsigned int i = 0; i < fMRCCWeight.size(); i++){
	       if(recoE >= fMRCCStart[i] && recoE < fMRCCEnd[i] && fMRCCWeight[i] > 0.05)
                 mrcc = fMRCCWeight[i];
	     }
	     gw *= mrcc;
	   }
      }
      
      fit_data.GeneratorWeight(gw);
      
      (beam_it -> second).push_back(fit_data);
      
      const double dpot  = GetDataPOTS(beam_type);
      const double mcpot = GetMCPOTS(beam_type);

      if(!fUseSmooth) (hit -> second )->Fill(reco_enu,dpot/mcpot*gw);   
      else FillSmooth(hit->second,reco_enu,reco_enu*fSmoothWidth,dpot/mcpot*gw);   
   }
   
   mcfile -> Close();
   if (parent) delete parent;
}

//---------------------------------------------------------------------------------------------------
void FitPTRW::FillRWHist()
{
  for(map<FitBeam::FitBeam_t, vector<FitData> >::const_iterator bit = fMCData.begin(); 
       bit != fMCData.end(); ++bit)
   {
     const FitBeam::FitBeam_t beam_type = bit -> first;
      
      // find MC reweight histogram, create a new one if it does not exist
      map<FitBeam::FitBeam_t, TH1D *>::iterator hit = fReWeightHist.find(beam_type);
      if(hit == fReWeightHist.end())
      {	 
	TH1D *h=(TH1D*)fMCHist[beam_type]->Clone(
		       Form("reweight_%s_%s",FitBeam::NeutrinoTypeAsString(beam_type).c_str(),
			    FitBeam::BeamTypeAsString(beam_type).c_str()));

	h -> Reset();
        h -> Sumw2();
	const pair<FitBeam::FitBeam_t, TH1D *> p(beam_type, h);
	fReWeightHist.insert(p);
	hit=fReWeightHist.find(beam_type);
	assert(hit != fReWeightHist.end() && "Failed to insert histogram");
      }
      hit ->second -> Reset();

      const int Ibeam = FitBeam::AsZbeamCode(beam_type);
      const int Idet  = 1;          //Only using near det. 
     
      const double dpot  = GetDataPOTS(beam_type);
      const double mcpot = GetMCPOTS(beam_type);
    
     
      for(vector<FitData>::const_iterator dit = (bit->second).begin(); 
	  dit != (bit->second).end(); ++dit)
      {
	 const double true_enu  = dit -> TrueEnergy();
	 double reco_enu  = dit -> RecoEnergy();
	 const double reco_emu  = dit -> RecoTrkEnergy();
	 const double reco_eshw = dit -> RecoShwEnergy();
	 const int ntype = dit -> NuType();
	 const int tptype = dit -> ParentType();
	 const double pt = dit -> TargetPt();
	 const double pz = dit -> TargetPz();
	 const double genweight = dit -> GeneratorWeight();
	 const double generror = dit -> GeneratorError();
	 const int cc_nc = dit -> InteractionType();
	 const int process = dit -> InteractionProcess();
	 const int is_cev = dit -> IsContained();
	 
	 double beamsysw = 1.;
	 double misshwcal = reco_eshw;

	 if (reco_eshw>0.) misshwcal+=GetShwOffsetPar();

	 //check that shower energy is not negative
	 if (misshwcal<0.) misshwcal=0.;

	 misshwcal *= (1.0 - GetEshwPar());

	 double mismucal = reco_emu;
	 if(is_cev==1) {mismucal *= (1.0 - GetEmuRangePar());}
	 else if(is_cev==0) {mismucal *= (1.0 - GetEmuCrvPar());}

	 Zbeam::ZbeamData_t zdata;
	 zdata.ntype    = ntype;
	 zdata.true_enu = true_enu;
	 zdata.detector = Detector::kNear;
	 zdata.beam     = BeamType::FromZarko(Ibeam);

	 beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kHorn1Offset   ,zbmParVec[0]);
	 beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kBaffleScraping,zbmParVec[1]);
	 beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kPOT           ,zbmParVec[2]);
	 beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kHornIMiscal   ,zbmParVec[3]);
	 beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kHornIDist     ,zbmParVec[4]);

	 //target z 
	 if (beam_type == FitBeam::kNuMuLE010z170i || beam_type == FitBeam::kAntiNuMuLE010z170i) {
	   beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kTargetZ     ,zbmParVec[5]);
	 } else if (beam_type == FitBeam::kNuMuLE010z185i || beam_type == FitBeam::kAntiNuMuLE010z185i) {
	   beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kTargetZ     ,zbmParVec[6]);
	 } else if (beam_type == FitBeam::kNuMuLE010z200i || beam_type == FitBeam::kAntiNuMuLE010z200i) {
	   beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kTargetZ     ,zbmParVec[7]);
	 } else if (beam_type == FitBeam::kNuMuLE100z200i || beam_type == FitBeam::kAntiNuMuLE100z200i) {
	   beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kTargetZ     ,zbmParVec[8]);
	 } else if (beam_type == FitBeam::kNuMuLE150z200i || beam_type == FitBeam::kAntiNuMuLE150z200i) {
	   beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kTargetZ     ,zbmParVec[9]);
	 } else if (beam_type == FitBeam::kNuMuLE250z200i || beam_type == FitBeam::kAntiNuMuLE250z200i) {
	   beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kTargetZ     ,zbmParVec[10]);
	 } else if (beam_type == FitBeam::kNuMuLE010z185iN || beam_type == FitBeam::kAntiNuMuLE010z185iN) {
	   beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kTargetZ     ,zbmParVec[11]);
	 } else if (beam_type == FitBeam::kNuMuLE250z200iN || beam_type == FitBeam::kAntiNuMuLE250z200iN) {
	   beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kTargetZ     ,zbmParVec[12]);
	 }

	 //beam width correction
	 if (beam_type == FitBeam::kNuMuLE250z200i ||
	     beam_type == FitBeam::kNuMuLE100z200i)
	   beamsysw *= fZbeam.GetWeight(zdata,BeamSys::kBeamWidth,0.);
	 
	 //now change pt and pz
	 double func=1.;
	 func = fZfluk.GetWeight(tptype,pt,pz);
	 func*=(beamsysw*genweight);
	 if(cc_nc==0)
	 {
	    func*=(1-GetNCPar(beam_type));
	 } else {
	   func*=(1+GetXSPar(process)*generror);
	 }
	 reco_enu = (misshwcal+mismucal)*(1-GetNuEMiscalPar(ntype));
	 
	 // float numubar xsec
	 if ((ntype==55||ntype==-14)&&true_enu<GetNumubarXsecPar(1))  
	   {
	     double a=GetNumubarXsecPar(0);
	     double xp=GetNumubarXsecPar(1);
	     func*=a+2.*(1.-a)*true_enu/xp+(a-1.)*true_enu*true_enu/(xp*xp);
	   }
	 // fill reweighted histogram
	 if(!fUseSmooth) (hit -> second )->Fill(reco_enu,func*dpot/mcpot);   
	 else FillSmooth(hit->second,reco_enu,reco_enu*fSmoothWidth,func*dpot/mcpot);   
	 
      }

   }
  
}

//---------------------------------------------------------------------------------------------------
// Optional correction to the reconstructed shower energy as a function of the true
// shower energy E:
//   reco_eshw * (1 + E*(c0 + c1*E - c2*E^2) * exp(-k*E))
// with the hard-coded fit coefficients below. When `doit` is false, reco_eshw is
// returned unchanged.
double FitPTRW::InukeParam(double reco_eshw, double true_eshw, bool doit)
{
  if(!doit) return reco_eshw;

  // Terms evaluated in the same order as the single-expression form.
  const double poly    = -2.60867e-01 + true_eshw*5.71546e-02 -
                         true_eshw*true_eshw*5.79411e-03;
  const double damping = TMath::Exp(-true_eshw*4.06206e-01);
  const double scale   = 1 + true_eshw*poly*damping;
  return reco_eshw*scale;
}

//---------------------------------------------------------------------------------------------------
// Event-selection hook queried once per MC tree entry in FillMCVectors; returning
// true skips the entry. It takes no arguments, and this implementation keeps every
// event.
bool FitPTRW::CutEvent()
{
  const bool reject = false;
  return reject;
}


//---------------------------------------------------------------------------------------------------


// Fraction of a normalised biweight kernel, (15/16)*(1-x^2)^2 with
// x = (t-center)/width on [-1,1], that falls inside [low_edge, high_edge].
// Returns a value in [0,1]:
//  - 0 when the interval does not overlap the kernel support;
//  - for width <= 0 (degenerate, delta-like kernel), 1 if center lies in
//    [low_edge, high_edge) and 0 otherwise. Dividing by a zero width would give NaN.
double FitPTRW::GetArea(const double &center,
                        const double &low_edge,const double &high_edge,
                        const double &width)
{
  if(!(width > 0.0))
    return (center >= low_edge && center < high_edge) ? 1.0 : 0.0;

  // Intersect the kernel support [center-width, center+width] with the interval.
  double ledge = center - width;
  double redge = center + width;

  if(ledge < low_edge) ledge = low_edge;
  if(redge > high_edge) redge = high_edge;

  // Disjoint (or touching) intervals: evaluating the antiderivative outside [-1,1]
  // would give a spurious, possibly negative, area.
  if(ledge >= redge) return 0.0;

  const double x1 = (ledge - center)/width;
  const double x2 = (redge - center)/width;

  // x - 2x^3/3 + x^5/5 is the antiderivative of (1-x^2)^2. 15/16 normalises the
  // full [-1,1] integral to 1.
  const double I1 = x1 - 2.0*x1*x1*x1/3.0 + x1*x1*x1*x1*x1/5.0;
  const double I2 = x2 - 2.0*x2*x2*x2/3.0 + x2*x2*x2*x2*x2/5.0;

  return 15.0*(I2 - I1)/16.0;
}



//---------------------------------------------------------------------------------------------------


// Fill `h` with `weight` spread over the bins covered by a biweight kernel of
// half-width `width` centred on `center` (see GetArea).
//  - Entries whose centre falls in underflow/overflow are skipped.
//  - Every bin's share is divided by the kernel fraction that lies inside the
//    histogram range, so the total deposited weight equals `weight` even near the
//    histogram edges.
//  - If that fraction is not positive (e.g. a degenerate width), the weight is
//    filled unsmeared at `center`.
void FitPTRW::FillSmooth(TH1D* &h,const double &center,const double &width,
                         const double &weight)
{
  if(!h) return;

  TAxis *axis = h->GetXaxis();
  const int nbins = h->GetNbinsX();
  const int bin = axis->FindBin(center);

  //   skip overflow and underflow
  if(bin<1 || bin>nbins) return;

  // normalize to total area in histogram
  const double norm = GetArea(center,
                              axis->GetBinLowEdge(1),
                              axis->GetBinUpEdge(nbins),
                              width);
  if(!(norm > 0.0))
  {
    h->Fill(center, weight);
    return;
  }

  // central bin
  double area = GetArea(center, axis->GetBinLowEdge(bin), axis->GetBinUpEdge(bin), width);
  area /= norm;
  h->Fill(center, weight*area);

  const double ledge = center - width;
  const double redge = center + width;

  // Lower bins: keep stepping down while the kernel extends below the current bin.
  for(int b = bin; b > 1 && axis->GetBinLowEdge(b) > ledge; --b)
  {
    area = GetArea(center, axis->GetBinLowEdge(b-1), axis->GetBinUpEdge(b-1), width);
    area /= norm;
    h->Fill(axis->GetBinCenter(b-1), weight*area);
  }

  // Higher bins: keep stepping up while the kernel extends above the current bin.
  for(int b = bin; b < nbins && axis->GetBinUpEdge(b) < redge; ++b)
  {
    area = GetArea(center, axis->GetBinLowEdge(b+1), axis->GetBinUpEdge(b+1), width);
    area /= norm;
    h->Fill(axis->GetBinCenter(b+1), weight*area);
  }
}

