#include "MessageService/MsgService.h"
#include "MCReweight/Zfluk.h"

#include "ZflukWeight.h"
#include "Parsers.h"
#include "FitData.h"
#include "FitPar.h"
#include "GenericFactory.h"

#include "TFile.h"
#include "TDirectory.h"
#include "TGraph.h"
#include "TH1D.h"
#include "TH2F.h"

#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/classification.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/algorithm/string.hpp>

// Register ZflukWeight with the generic object factory (GenericFactory.h)
// under SkzpWeight. Presumably this lets fit configurations instantiate this
// weight by name -- NOTE(review): confirm against GenericFactory.h.
REGISTER_OBJECT(SkzpWeight,ZflukWeight)

  using namespace boost;
using namespace Parsers;
// Source-control revision tag (empty here); the CVSID macro is defined
// outside this file -- presumably in MessageService, TODO confirm.
CVSID("");


ZflukWeight::ZflukWeight()
  : fName("ZflukWeight")
  , fDataFile("")
  , fConfig("PiMinus_CedarDaikon")
  , fUseNa49(true)
  , fNa49DataFile("data/na49.root")
  , fPiPlusPtPenalty(15.)
  , fKPlusPtPenalty(15.)
  , fXfSlices(4)
  , fPiBinWidth(15.)
  , fKBinWidth(15.)
  , fFlukaError(0.2)
  , fNa49Error(0.05)
{
  // The Zfluk reweighter does the actual work; this object owns it and
  // releases it in the destructor.
  fZfluk=new Zfluk();

  // Default edges of the pT slices used by GetPT() (five edges -> four
  // slices). Units presumably GeV/c -- NOTE(review): confirm against Zfluk.
  // Can be overridden with the "PtSlices" property.
  const double kDefaultPtEdges[]={0., 0.1, 0.3, 0.5, 1.0};
  const int kNumDefaultPtEdges=
    (int)(sizeof(kDefaultPtEdges)/sizeof(kDefaultPtEdges[0]));
  for (int iEdge=0; iEdge<kNumDefaultPtEdges; ++iEdge)
    fPtSlices.push_back(kDefaultPtEdges[iEdge]);
}

ZflukWeight::~ZflukWeight()
{
  // Release the Zfluk reweighter allocated in the constructor. The member is
  // cleared first so no dangling pointer survives inside the dying object.
  Zfluk* owned=fZfluk;
  fZfluk=0;
  delete owned;
}


bool ZflukWeight::Init() 
{
  TDirectory *dir=gDirectory;
  if(fUseNa49) {
    TFile* f=new TFile(fNa49DataFile.c_str());
    if (f) {
      TGraph *pina49=dynamic_cast<TGraph*> (f->Get("pina49"));
      for (int i=0;i<pina49->GetN();i++) {
        fNa49Pz.push_back(pina49->GetX()[i]);
        fNa49PiRatio.push_back(pina49->GetY()[i]);
      }
    }
    else {
      MSG("ZflukWeight",Msg::kWarning)<<"Cannot find Na49 datafile at "<<fNa49DataFile
				    <<" for ZflukWeight function "<<fName<<endl;
      return false;
    }
  }
  dir->cd();
  return true;
  
}
    
  

bool ZflukWeight::SetProperty(std::string prop) 
{
  
  string property=prop.substr(0,prop.find("="));
  string val=prop.substr(prop.find("=")+1);
  
  
  if(property=="Parameters") {
    vector<std::string> params;

    parse_val(val,params);
    std::vector<std::string>::iterator iter=params.begin();
    std::vector<std::string>::iterator iter_end=params.end();
    for( ; iter!=iter_end ; ++iter) {
      // remove the trailing '
      //cout<<*iter<<endl;
      std::vector<string> vals;
      split(vals, *iter, is_any_of("() "), token_compress_on);
      
      FitPar par(vals);
      fParameters.push_back(par);
    }
  }
  else if(property=="UseNa49")         parse_val(val,fUseNa49);
  else if(property=="Na49DataFile")    parse_val(val,fNa49DataFile);
  else if(property=="PiPlusPtPenalty") parse_val(val,fPiPlusPtPenalty);
  else if(property=="KPlusPtPenalty")  parse_val(val,fKPlusPtPenalty);
  else if(property=="XfSlices")        parse_val(val,fXfSlices);
  else if(property=="PiBinWidth")      parse_val(val,fPiBinWidth);
  else if(property=="KBinWidth")       parse_val(val,fKBinWidth);
  else if(property=="PtSlices")        parse_val(val,fPtSlices);
  else if(property=="FlukaError")      parse_val(val,fFlukaError);
  else if(property=="Na49Error")       parse_val(val,fNa49Error);
  else {
    MSG("ZflukWeight",Msg::kWarning)<<"Cannot set property "<<property
				    <<" for ZflukWeight function "<<fName<<endl;
    return false;
  } 
   
  return true;
  
}


/// Return the Zfluk reweighting factor for one fit entry.
/// The factor is looked up from the entry's ParentType(), TargetPt() and
/// TargetPz(); judging by the names these are the parent hadron type and its
/// pT/pz at the target -- NOTE(review): confirm in FitData.
double ZflukWeight::GetWeight(FitData &data)
{
  const double fluk_weight=fZfluk->GetWeight(data.ParentType(),
                                             data.TargetPt(),
                                             data.TargetPz());
  MSG("ZflukWeight",Msg::kVerbose)<<"Fluk Weight="<<fluk_weight<<endl;
  return fluk_weight;
}



double ZflukWeight::GetPT() const
{
  double penalty=0;

  // term for pi+
  double ptshift=fZfluk->GetPTshift(8);
  if(fPiPlusPtPenalty!=0.) penalty+=ptshift*ptshift/
			     (fPiPlusPtPenalty*fPiPlusPtPenalty);
  // term for k+
  ptshift=fZfluk->GetPTshift(11);
  if(fKPlusPtPenalty!=0.) penalty+=ptshift*ptshift/
			    (fKPlusPtPenalty*fKPlusPtPenalty);

  int n_pt_slices=(int)fPtSlices.size()-1;

  // fractional changes in number of pi and K
  for (int i=0;i<fXfSlices;i++) 
    for (int j=0;j<n_pt_slices;j++) {
      
      // pi+/pi- ratio from fluka
      if(!fUseNa49) {
	double npp1=fZfluk->GetNFrac(8,double(i)*fPiBinWidth,
					  double(i+1)*fPiBinWidth,
					  fPtSlices[j],fPtSlices[j+1]);
	double npm1=fZfluk->GetNFrac(9,double(i)*fPiBinWidth,
					  double(i+1)*fPiBinWidth,
					  fPtSlices[j],fPtSlices[j+1]);
	penalty+=(npp1/npm1-1.)*(npp1/npm1-1.)/(fFlukaError*fFlukaError)/
	  (double(fXfSlices)*double(n_pt_slices)); 
      }
      
      
      double npp2=fZfluk->GetNFrac(8,double(i)*fKBinWidth,
					double(i+1)*fKBinWidth,
                                        fPtSlices[j],fPtSlices[j+1]);
      double npm2=fZfluk->GetNFrac(9,double(i)*fKBinWidth,
					double(i+1)*fKBinWidth,
                                        fPtSlices[j],fPtSlices[j+1]);
      double nkp=fZfluk->GetNFrac(11,double(i)*fKBinWidth,
				       double(i+1)*fKBinWidth,
                                       fPtSlices[j],fPtSlices[j+1]);
      double nkm=fZfluk->GetNFrac(12,double(i)*fKBinWidth,
				       double(i+1)*fKBinWidth,
                                       fPtSlices[j],fPtSlices[j+1]);

      // K+/K- ratio
      penalty+=(nkp/nkm-1.)*(nkp/nkm-1.)/(fFlukaError*fFlukaError)/
	(double(fXfSlices)*double(n_pt_slices));  

      // K+/pi+ ratio
      penalty+=(nkp/npp2-1.)*(nkp/npp2-1.)/(fFlukaError*fFlukaError)/
	(double(fXfSlices)*double(n_pt_slices));   
      
      // K-/pi-
      penalty+=(nkm/npm2-1.)*(nkm/npm2-1.)/(fFlukaError*fFlukaError)/
	(double(fXfSlices)*double(n_pt_slices));  

    }


  if(fUseNa49) {
    // pi+/pi- ratio penalty term (NA49)
    TH2F* piplus_2d=fZfluk->GetReweightedPTXF(8);
    TH1D* piratio=piplus_2d->ProjectionX("piratio",1,piplus_2d->GetNbinsY());

    TH2F* piminus_2d=fZfluk->GetReweightedPTXF(9);
    TH1D* piminus=piminus_2d->ProjectionX("piminus",1,piminus_2d->GetNbinsY());

    piratio->Divide(piminus);

    for (unsigned int i=0;i<fNa49Pz.size();i++) {
      int bin = piratio->FindBin(fNa49Pz[i]);
      double rwrat=piratio->GetBinContent(bin);
      penalty+= (rwrat-fNa49PiRatio[i])*(rwrat-fNa49PiRatio[i])/
	(fNa49Error*fNa49PiRatio[i]*fNa49Error*fNa49PiRatio[i]);
    }
  }
    
  return penalty;
}

  
void ZflukWeight::Write() 
{
  ////////////////////////////////
  // write out histograms for
  // pi+,pi-,K0L,K+,K-
  ///////////////////////////////
  for(int i=8; i<=12; ++i) {

    (fZfluk->GetPTXF(i))->Write();
    (fZfluk->GetReweightedPTXF(i)->Write());
    (fZfluk->GetWeightHistogram(i))->Write();
  }


  ///////////////////////////
  // Get Ratios 
  ///////////////////////////

  // From fluka05
  TH1D* piplus_f05= dynamic_cast<TH1D*>
    ( (fZfluk->GetPTXF(8))->ProjectionX()->Clone("piplus_f05"));
  TH1D* piminus_f05= dynamic_cast<TH1D*>
    ( (fZfluk->GetPTXF(9))->ProjectionX()->Clone("piminus_f05"));
  TH1D* Kplus_f05= dynamic_cast<TH1D*>
    ( (fZfluk->GetPTXF(11))->ProjectionX()->Clone("Kplus_f05"));
  TH1D* Kminus_f05= dynamic_cast<TH1D*>
    ( (fZfluk->GetPTXF(12))->ProjectionX()->Clone("Kminus_f05"));
  

  // From the fit
  TH1D* piplus_fit= dynamic_cast<TH1D*>
    ( (fZfluk->GetReweightedPTXF(8))->ProjectionX()->Clone("piplus_fit"));
  TH1D* piminus_fit= dynamic_cast<TH1D*>
    ( (fZfluk->GetReweightedPTXF(9))->ProjectionX()->Clone("piminus_fit"));
  TH1D* Kplus_fit= dynamic_cast<TH1D*>
    ( (fZfluk->GetReweightedPTXF(11))->ProjectionX()->Clone("Kplus_fit"));
  TH1D* Kminus_fit= dynamic_cast<TH1D*>
    ( (fZfluk->GetReweightedPTXF(12))->ProjectionX()->Clone("Kminus_fit"));

  // pi+/pi-
  TH1D* piplus_piminus_f05=dynamic_cast<TH1D*>
    (piplus_f05->Clone("piplus_piminus_f05")); 
  piplus_piminus_f05->Divide(piminus_f05); 
  piplus_piminus_f05->Write();

  TH1D* piplus_piminus_fit=dynamic_cast<TH1D*>
    (piplus_fit->Clone("piplus_piminus_fit")); 
  piplus_piminus_fit->Divide(piminus_fit); 
  piplus_piminus_fit->Write();

  // K+/K-
  TH1D* Kplus_Kminus_f05=dynamic_cast<TH1D*>
    (Kplus_f05->Clone("Kplus_Kminus_f05")); 
  Kplus_Kminus_f05->Divide(Kminus_f05); 
  Kplus_Kminus_f05->Write();

  TH1D* Kplus_Kminus_fit=dynamic_cast<TH1D*>
    (Kplus_fit->Clone("Kplus_Kminus_fit")); 
  Kplus_Kminus_fit->Divide(Kminus_fit); 
  Kplus_Kminus_fit->Write();

  // K+/pi+
  TH1D* Kplus_piplus_f05=dynamic_cast<TH1D*>
    (Kplus_f05->Clone("Kplus_piplus_f05")); 
  Kplus_piplus_f05->Divide(piplus_f05); 
  Kplus_piplus_f05->Write();

  TH1D* Kplus_piplus_fit=dynamic_cast<TH1D*>
    (Kplus_fit->Clone("Kplus_piplus_fit")); 
  Kplus_piplus_fit->Divide(piplus_fit); 
  Kplus_piplus_fit->Write();

  // K-/pi-
  TH1D* Kminus_piminus_f05=dynamic_cast<TH1D*>
    (Kminus_f05->Clone("Kminus_piminus_f05")); 
  Kminus_piminus_f05->Divide(piminus_f05); 
  Kminus_piminus_f05->Write();

  TH1D* Kminus_piminus_fit=dynamic_cast<TH1D*>
    (Kminus_fit->Clone("Kminus_piminus_fit")); 
  Kminus_piminus_fit->Divide(piminus_fit); 
  Kminus_piminus_fit->Write();
  
  
}
