#include <math.h>
#include <sstream>
#include <iostream>
#include <TH1F.h>
#include <TGraphAsymmErrors.h>
#include "efficiencies.h"

// This should really be a class not a set of functions

template <typename T1, typename T2>
std::pair<double, std::pair<double, double> > GetEfficiency(T1 passed, T2 total){
  double theEff = double(passed)/double(total);
  double theEffError = sqrt(theEff*(1.+theEff)/double(total));
  std::pair<double, double> theEffErrors = std::make_pair<double, double>(theEffError, theEffError);  
  if((theEff + theEffError) >= 1. || (theEff - theEffError <= 0.)){
  //if(theEff == 1. || theEff == 0.){
    //do one bin histo && extract bayes eff
    TH1F* theNumHisto = new TH1F("theNumHisto","theNumHisto",1,0,1);
    theNumHisto->SetBinContent(1,passed);
    theNumHisto->Sumw2();
    TH1F* theDenHisto = new TH1F("theDenHisto","",1,0,1);
    theDenHisto->SetBinContent(1,total);
    theDenHisto->Sumw2();
    TGraphAsymmErrors* bayesEff = new TGraphAsymmErrors();
    bayesEff->BayesDivide(theNumHisto,theDenHisto,"b");
    double effErrorHigh = bayesEff->GetErrorYhigh(0);
    double effErrorLow = bayesEff->GetErrorYlow(0);
    theEffErrors = std::make_pair<double, double>(effErrorLow, effErrorHigh); 
 
    delete theNumHisto;
    delete theDenHisto;
    delete bayesEff;
  }
  std::pair<double,std::pair<double,double> > theEffAndError = std::make_pair<double,std::pair<double,double> >(theEff,theEffErrors);
  return theEffAndError;
}

// Product of two efficiencies, each given as (value, error), with uncorrelated
// Gaussian error propagation:
//   eff   = e1 * e2
//   sigma = sqrt((e2*s1)^2 + (e1*s2)^2)
// Returns (eff, sigma).
std::pair<double, double> GetEfficiencyProd(std::pair<double,double> eff1, std::pair<double,double> eff2){
  const double eff = eff1.first * eff2.first;
  const double effError = sqrt((eff2.first*eff2.first * eff1.second*eff1.second) +
                               (eff1.first*eff1.first*eff2.second*eff2.second));
  // Let make_pair deduce its types: explicit <double,double> makes the
  // parameters double&& in C++11 and later, which cannot bind these lvalues.
  return std::make_pair(eff, effError);
}

// Ratio of two efficiencies, each given as (value, error), with uncorrelated
// Gaussian error propagation:
//   R     = e1 / e2
//   sigma = (1/e2) * sqrt(s1^2 + R^2 * s2^2)
// Returns (R, sigma). e2 == 0 is not guarded and yields non-finite values.
std::pair<double,double> GetEfficiencyRatio(std::pair<double,double> eff1, std::pair<double,double> eff2){
  const double effRatio = eff1.first/eff2.first;
  const double sigma = (1./eff2.first)*sqrt(eff1.second*eff1.second + effRatio*effRatio*eff2.second*eff2.second);
  // Deduced make_pair: explicit <double,double> does not accept lvalues in C++11+.
  return std::make_pair(effRatio, sigma);
}

// Formats passed/total as the string "<efficiency %> +- <error %>".
//
// Both numbers are printed with two digits after the decimal point. When the
// efficiency is below 0.01 (i.e. under 1%) scientific notation is used,
// otherwise fixed notation. The quoted error is the larger of the low/high
// errors returned by GetEfficiency.
template <typename T1, typename T2>
std::string PrintEfficiency(T1 passed, T2 total){
  const std::pair<double, std::pair<double, double> > result = GetEfficiency(passed, total);
  const double percent = result.first*100.;
  const double errorPercent = std::max(result.second.first, result.second.second)*100.;

  std::stringstream out;
  out.setf(std::ios::showpoint);
  out.precision(2);
  if(result.first < 0.01){
    out.setf(std::ios::scientific, std::ios::floatfield);
  } else {
    out.setf(std::ios::fixed, std::ios::floatfield);
  }
  out << percent << " +- " << errorPercent;
  return out.str();
}

// Formats a (possibly weighted) event count as text.
//
// Values <= 1 are written in scientific notation with one digit after the
// decimal point (e.g. "5.0e-01"). Larger values are written in fixed notation
// with two decimals (e.g. "12.50").
std::string PrintWeightedEvent(double theEventNumber){
  std::ostringstream text;
  text.setf(std::ios::showpoint);
  if(theEventNumber <= 1.){
    text.setf(std::ios::scientific, std::ios::floatfield);
    text.precision(1);
  } else {
    text.setf(std::ios::fixed, std::ios::floatfield);
    text.precision(2);
  }
  text << theEventNumber;
  return text.str();
}

// Number of events surviving a selection, with its uncertainty.
//
// nEvents    : (count, error) before the selection.
// efficiency : (eff, (lowError, highError)), as returned by GetEfficiency.
// Returns (eff * count, sqrt((count*effErr)^2 + (eff*countErr)^2)).
// effErr is the low-side efficiency error. If that is exactly zero, the
// high-side error is used instead so the efficiency uncertainty is not dropped.
std::pair<double, double> GetSurvivingEvents(std::pair<double, double> nEvents, std::pair<double, std::pair<double, double> >efficiency){
  const double eff = efficiency.first;
  const std::pair<double, double> effError = efficiency.second;
  const double nSurviving = eff*nEvents.first;
  const double effErrorUsed = (effError.first == 0.) ? effError.second : effError.first;
  const double errorFirstTerm = nEvents.first*effErrorUsed;
  const double errorSecondTerm = eff*nEvents.second;
  const double nSurvivingError = sqrt((errorFirstTerm*errorFirstTerm)+(errorSecondTerm*errorSecondTerm));
  // Deduced make_pair: explicit <double,double> does not accept lvalues in C++11+.
  return std::make_pair(nSurviving, nSurvivingError);
}

// Number of events surviving a selection, with its uncertainty, when the
// input count is a raw count whose error is taken as sqrt(nEvents).
//
// nEvents    : event count before the selection (assumed >= 0; a negative
//              value makes sqrt(nEvents) NaN).
// efficiency : (eff, effError).
// Returns (eff * nEvents, sqrt((nEvents*effError)^2 + (eff*sqrt(nEvents))^2)).
std::pair<double, double> GetSurvivingEvents(double nEvents, std::pair<double, double> efficiency){
  const double eff = efficiency.first;
  const double effError = efficiency.second;
  const double nSurviving = eff*nEvents;
  const double errorFirstTerm = nEvents*effError;
  const double errorSecondTerm = eff*sqrt(nEvents);
  const double nSurvivingError = sqrt((errorFirstTerm*errorFirstTerm)+(errorSecondTerm*errorSecondTerm));
  // Deduced make_pair: explicit <double,double> does not accept lvalues in C++11+.
  return std::make_pair(nSurviving, nSurvivingError);
}

// Ratio nA/nB of two independent Poisson counts, and its uncertainty.
//
// With sigma_A = sqrt(nA) and sigma_B = sqrt(nB):
//   sigma_R^2 = (sigma_A/nB)^2 + (nA*sigma_B/nB^2)^2 = nA*(nA+nB)/nB^3,
// which equals R*sqrt(1/nA + 1/nB) but stays finite for nA == 0.
// Note: total = nA + nB and nB are correlated (cov = var(nB)). Propagating
// through (total, nB) as if they were independent overestimates the error.
// nB == 0 is not guarded and yields non-finite values.
// Returns (ratio, ratioError).
std::pair<double,double> GetRatio(double nA, double nB){
  const double total = nA + nB;
  const double ratio = nA / nB;
  const double ratioError = sqrt(nA*total/(nB*nB*nB));
  // Deduced make_pair: explicit <double,double> does not accept lvalues in C++11+.
  return std::make_pair(ratio, ratioError);
}

// explicit instatiation of the templates
template std::pair<double, std::pair<double, double> > GetEfficiency(int, int);
template std::pair<double, std::pair<double, double> > GetEfficiency(unsigned int, unsigned int);
template std::pair<double, std::pair<double, double> > GetEfficiency(float, float);
template std::pair<double, std::pair<double, double> > GetEfficiency(double, double);
template std::pair<double, std::pair<double, double> > GetEfficiency(double, int);
template std::pair<double, std::pair<double, double> > GetEfficiency(double, long);
template std::pair<double, std::pair<double, double> > GetEfficiency(double, long long);
template std::pair<double, std::pair<double, double> > GetEfficiency(float, int);
template std::string PrintEfficiency(int, int);
template std::string PrintEfficiency(unsigned int, unsigned int);
template std::string PrintEfficiency(float, float);
template std::string PrintEfficiency(float, int);
template std::string PrintEfficiency(double, double);
template std::string PrintEfficiency(double, int);
template std::string PrintEfficiency(double, long);
template std::string PrintEfficiency(double, long long);