//_____________________________________________________________________________
///
/// MatrixCalculator - performs matrix calculations for CandFitTrackSA 
/// track fitter. Calculates the multiple Coulomb scattering covariance 
/// matrix and the propagator matrix; performs a fit iteration by
/// solving the matrix equation that minimizes the chi-square function.
///
/// \author Sergei Avvakumov avva@fnal.gov
///

#include "TMatrixD.h"
#include "TDecompChol.h"

#include "MessageService/MsgService.h"
#include "MessageService/MsgFormat.h"

#include "CandFitTrackSA/ConstFT.h"
#include "CandFitTrackSA/DataFT.h"
#include "CandFitTrackSA/FitResult.h"
#include "CandFitTrackSA/MatrixCalculator.h"

#include <cmath>

using namespace ConstFT;

CVSID("$Id: MatrixCalculator.cxx,v 1.8 2007/02/04 06:10:47 rhatcher Exp $");

///
/// Default constructor: sizes the fit matrices and track-parameter
/// vectors to NTrackParams and zeroes the scalar fit results, so that
/// GetFitResult() is well defined even before the first Solve().
///
MatrixCalculator::MatrixCalculator() : 
        fFitCovM(NTrackParams, NTrackParams), 
        fFitErrM(NTrackParams, NTrackParams),
        fTrackIn(NTrackParams), fTrackOut(NTrackParams)
{
    // These scalars are otherwise only written by Solve(); without this
    // GetFitResult() would read indeterminate values.
    fChi2        = 0.;
    fDChi2       = 0.;
    fNPlanesUsed = 0;
}


///
/// Constructor used by the fitter algorithm. The AlgConfig and
/// TrackContext arguments are currently unused. Sizes the fit matrices
/// and track-parameter vectors to NTrackParams and zeroes the scalar
/// fit results, so that GetFitResult() is well defined before Solve().
///
MatrixCalculator::MatrixCalculator(const AlgConfig& /*ac*/, const TrackContext& /*tc*/) :
        fFitCovM(NTrackParams, NTrackParams), 
        fFitErrM(NTrackParams, NTrackParams),
        fTrackIn(NTrackParams), fTrackOut(NTrackParams)
{
    // These scalars are otherwise only written by Solve(); without this
    // GetFitResult() would read indeterminate values.
    fChi2        = 0.;
    fDChi2       = 0.;
    fNPlanesUsed = 0;
}


///
/// Destructor. The matrix/vector members used in this file are ROOT
/// value objects that release their own storage, so there is nothing
/// to free explicitly here.
///
MatrixCalculator::~MatrixCalculator() { }

///
/// Package the results of the most recent Solve() into a FitResult:
/// the fit error matrix, the fitted track parameters, chi2, dchi2,
/// the number of planes used and the number of used hits (the
/// dimension of the hit covariance matrix V).
///
FitResult MatrixCalculator::GetFitResult() const
{
    const Int_t nHitsUsed = fV.GetNcols();
    return FitResult(fFitErrM,
                     fTrackOut,
                     fChi2,
                     fDChi2,
                     fNPlanesUsed,
                     nHitsUsed);
}


///
/// The main method called by AlgFitTrackSA: using the track data from
/// the DataFT object, build the covariance and propagator matrices and
/// solve the linearized least-squares problem for the track parameters.
///
/// Equations in matrix form:
///
///      Cm - column of measured coordinates
///
///      C  - column of track coordinates (from the swimmer)
///
///      p  - track parameters - u, du/dz, v, dv/dz, q/p
///
///      A  - track propagation matrix
///
///      V  - covariance matrix
///
///      W  - weight matrix - inverse of the covariance matrix
///
///      C = A * p
///
///      minimizing  chi2 =  (C - Cm)T  * W * (C - Cm)
///
///      solution:   p =  inv(AT * W * A)   * AT  * W * Cm
///
///      fit error matrix:  E = inv(AT * W * A)
///
/// On success this updates fFitCovM (= AT * W * A), fFitErrM (= E),
/// fTrackOut (= p), fChi2, and
/// fDChi2 = dpT * (AT * W * A) * dp  with  dp = fTrackOut - fTrackIn.
///
/// \return 0 on success. Otherwise one of the failure codes below; in
///         that case the fit outputs are not (fully) updated and must
///         not be used:
///           1 - fewer used hits than track parameters (underdetermined)
///           2 - building the propagator or covariance matrix failed
///           3 - the covariance matrix V could not be inverted
///           4 - the normal matrix AT * W * A could not be inverted
///
Int_t MatrixCalculator::Solve(const DataFT& data) 
{
    fTrackIn = data.GetTrack();
    
    fNPlanesUsed = data.GetNPlanesUsed();

    // With fewer measurements than fit parameters AT * W * A has rank
    // below NTrackParams and cannot be inverted; refuse up front rather
    // than attempt a meaningless inversion.
    const Int_t nhits = data.GetNUHitsUsed() + data.GetNVHitsUsed();
    if (nhits < NTrackParams) {
        MSGSTREAM("FitTrackSA", Msg::kWarning)
            << "MatrixCalculator::Solve: too few hits (" << nhits
            << ") to fit the track parameters\n";
        return 1;
    }

    if (MakePropagatorMatrix(data) != 0) return 2;
    
    TMatrixD At(TMatrixD::kTransposed, fA);

    if (MakeCovarianceMatrix(data) != 0) return 2;

    // Weight matrix W = inv(V). V is symmetric by construction (and
    // presumably positive definite), so a Cholesky inversion
    // (TDecompChol) could be faster, but it needs a TMatrixDSym; a
    // general inversion is used for every ROOT version.
    TMatrixD W(TMatrixD::kInverted, fV);
    if (!W.IsValid()) {
        MSGSTREAM("FitTrackSA", Msg::kWarning)
            << "MatrixCalculator::Solve: covariance matrix is singular\n";
        return 3;
    }

    // normal matrix AT * W * A
    fFitCovM = TMatrixD( TMatrixD(At,TMatrixD::kMult,W), TMatrixD::kMult, fA);

    // Invert into a temporary so its validity can be checked before
    // fFitErrM is overwritten.
#if ROOT_VERSION_CODE >= ROOT_VERSION(4,0,0)
    TMatrixD errM(TMatrixD::kInverted, fFitCovM);
#else    
    TMatrixD errM(TMatrixD::kInvertedPosDef, fFitCovM);
#endif
    if (!errM.IsValid()) {
        MSGSTREAM("FitTrackSA", Msg::kWarning)
            << "MatrixCalculator::Solve: normal matrix AT*W*A is singular\n";
        return 4;
    }
    fFitErrM = errM;

    MsgStream *mftsa = &MSGSTREAM("FitTrackSA", Msg::kVerbose);
    (*mftsa) << "Solution Error Matrix:\n";
    MsgFormat efmt("%10.2e");
    for (Int_t i = 0; i<fFitErrM.GetNrows(); i++) {
        for (Int_t j = 0; j<fFitErrM.GetNcols(); j++) {
            (*mftsa) << efmt(fFitErrM(i,j));
        }
        (*mftsa) << "\n";
    }

    data.FillVectorC(fC);

    // p = E * AT * W * Cm
    TMatrixD solution( fFitErrM, TMatrixD::kMult, 
        TMatrixD(At, TMatrixD::kMult, TMatrixD(W,TMatrixD::kMult,fC)) );

    fTrackOut = TMatrixDColumn(solution, 0);

    // chi2 = resT * W * res, with the residual column filled by DataFT
    TMatrixD  vRes(1,1);
    data.FillVectorRes(vRes);
    TMatrixD chi2(TMatrixD(TMatrixD::kTransposed,vRes), 
                    TMatrixD::kMult, TMatrixD(W, TMatrixD::kMult, vRes));
    fChi2 = chi2(0,0);

    // dchi2 = dpT * (AT * W * A) * dp for the parameter change dp
    TVectorD vres(fTrackOut);
    for (Int_t i = 0; i<NTrackParams; i++) {
        vres(i) -= fTrackIn(i);
    }
    TVectorD vresT(vres);
    vresT *= fFitCovM;
    
    fDChi2 = vresT*vres;
    
    return 0;
}

///
/// Build the track propagation matrix A: one row per used hit and
/// NTrackParams columns. Rows [0, nU) hold the used U hits and rows
/// [nU, nU+nV) the used V hits, each view ordered by plane. For a hit
/// in plane i, with dz = z(i) - z(0):
///
///   U row:  kU = 1, kdUdZ = dz, kV = 0, kdVdZ = 0,
///           kQoverP = (Uf(i) - u - du/dz * dz) / (q/p)
///   V row:  kV = 1, kdVdZ = dz, kU = 0, kdUdZ = 0,
///           kQoverP = (Vf(i) - v - dv/dz * dz) / (q/p)
///
/// where u, du/dz, v, dv/dz, q/p are the input parameters in fTrackIn.
/// NOTE(review): Uf/Vf are presumably the swum track coordinates; the
/// division by fTrackIn(kQoverP) is not guarded against q/p == 0.
///
/// \return always 0
///
Int_t MatrixCalculator::MakePropagatorMatrix(const DataFT& data) 
{
    const Int_t nU = data.GetNUHitsUsed();
    const Int_t nV = data.GetNVHitsUsed();
    fA.ResizeTo(nU + nV, NTrackParams);

    const Int_t nPlanes = data.GetNPlanesUsed();

    // U-view rows
    Int_t row = 0;
    for (Int_t plane = 0; plane < nPlanes; ++plane) {
        if (!data.UHitUse(plane)) continue;
        const Double_t dz = data.GetZ(plane) - data.GetZ(0);
        fA(row, kU)      = 1.;
        fA(row, kdUdZ)   = dz;
        fA(row, kV)      = 0.;
        fA(row, kdVdZ)   = 0.;
        fA(row, kQoverP) = (data.GetUf(plane) - fTrackIn(kU) - fTrackIn(kdUdZ)*dz)
                           / fTrackIn(kQoverP);
        ++row;
    }

    // V-view rows start right after the nU U-view rows
    row = nU;
    for (Int_t plane = 0; plane < nPlanes; ++plane) {
        if (!data.VHitUse(plane)) continue;
        const Double_t dz = data.GetZ(plane) - data.GetZ(0);
        fA(row, kV)      = 1.;
        fA(row, kdVdZ)   = dz;
        fA(row, kU)      = 0.;
        fA(row, kdUdZ)   = 0.;
        fA(row, kQoverP) = (data.GetVf(plane) - fTrackIn(kV) - fTrackIn(kdVdZ)*dz)
                           / fTrackIn(kQoverP);
        ++row;
    }

    return 0;
}

///
/// Build the hit covariance matrix V (nU+nV square, U hits first, then
/// V hits, each ordered by plane). It is the sum of two parts:
///  - multiple Coulomb scattering: for hits in planes n <= k of the same
///    view, V(n,n) = sum_{i<=n} DiagonalElement(i,n) and
///    V(n,k) = V(k,n) = sum_{i<=n} NonDiagonalElement(i,k,n);
///    the U and V blocks do not couple (cross-view entries stay 0);
///  - measurement resolution: GetSigmaU/GetSigmaV squared added on
///    the diagonal.
///
/// \return always 0
///
Int_t MatrixCalculator::MakeCovarianceMatrix(const DataFT& data) 
{
    const Int_t nPlanes = data.GetNPlanesUsed();
    const Int_t nU = data.GetNUHitsUsed();
    const Int_t nV = data.GetNVHitsUsed();

    fV.ResizeTo(nU + nV, nU + nV);
    fV.Zero();

    // MCS scattering angle for every plane
    TVectorD theta(nPlanes);
    for (Int_t p = 0; p < nPlanes; ++p) {
        theta(p) = data.ThetaMCS(p);
    }

    // MCS part: view 0 is U (block starts at row 0), view 1 is V
    // (block starts at row nU).
    for (Int_t view = 0; view < 2; ++view) {
        const Bool_t isU = (view == 0);
        Int_t row = isU ? 0 : nU;

        for (Int_t n = 0; n < nPlanes; ++n) {
            if (!(isU ? data.UHitUse(n) : data.VHitUse(n))) continue;

            Double_t variance = 0.0;
            for (Int_t i = 0; i <= n; ++i) {
                variance += DiagonalElement(i, n, theta, data);
            }
            fV(row, row) = variance;

            // correlations with every later used hit of the same view
            Int_t col = row + 1;
            for (Int_t k = n + 1; k < nPlanes; ++k) {
                if (!(isU ? data.UHitUse(k) : data.VHitUse(k))) continue;

                Double_t cov = 0.0;
                for (Int_t i = 0; i <= n; ++i) {
                    cov += NonDiagonalElement(i, k, n, theta, data);
                }
                fV(row, col) = cov;
                fV(col, row) = cov;
                ++col;
            }
            ++row;
        }
    }

    // Resolution part: one running hit index, U hits then V hits.
    Int_t hit = 0;
    for (Int_t p = 0; p < nPlanes; ++p) {
        if (!data.UHitUse(p)) continue;
        fV(hit, hit) += pow(data.GetSigmaU(p),2);
        ++hit;
    }
    for (Int_t p = 0; p < nPlanes; ++p) {
        if (!data.VHitUse(p)) continue;
        fV(hit, hit) += pow(data.GetSigmaV(p),2);
        ++hit;
    }

//     MsgStream *mftsa = &MSGSTREAM("FitTrackSA", Msg::kVerbose);
//     (*mftsa) << "Covariance Matrix: \n";
//     MsgFormat ffmt("%10.3e");
// 
//     for (Int_t i=0; i<nU+nV; i++) {
//         for (Int_t j=0; j<nU+nV; j++) {
//             (*mftsa) <<  ffmt(fV(i,j));
//         }
//         (*mftsa) << "\n";
//     }

    return 0;
}

///
/// Calculate one term of a diagonal element of the MCS covariance
/// matrix: the contribution of scattering in plane i to the entry for
/// the hit in plane n,
///
///   theta_i(i)^2 * ( L^2/3 + L*T(i,n) + T(i,n)^2 ),  L = dZSteel(i)/cos(i)
///
/// MakeCovarianceMatrix sums this over i = 0..n.
/// NOTE(review): L is presumably the path length through the steel of
/// plane i and T(i,n) the lever arm from plane i to plane n -- confirm
/// against DataFT.
///
Double_t MatrixCalculator::DiagonalElement(Int_t i, Int_t n,
                            const TVectorD& theta_i, const DataFT& data) const
{
    return  pow(theta_i(i),2) *
        ( pow(data.GetdZSteel(i)/data.GetCos(i),2)/3.+
          data.GetdZSteel(i)/data.GetCos(i)*data.T(i,n) +
          pow(data.T(i,n),2)                              );
}

///
/// Calculate one term of an off-diagonal element of the MCS covariance
/// matrix: the contribution of scattering in plane i to the covariance
/// between the hits in planes n and k (MakeCovarianceMatrix calls it
/// with k > n and sums over i = 0..n),
///
///   theta_i(i)^2 * ( L^2/3 + L*T(i,n) + T(i,n)^2
///                    + |z(k)-z(n)| * (L/2 + T(i,n)) ),
///   L = dZSteel(i)/cos(i)
///
/// i.e. the DiagonalElement bracket plus the extra distance
/// |z(k)-z(n)| times (L/2 + T(i,n)).
/// NOTE(review): L and T(i,n) are presumably the steel path length and
/// the lever arm to plane n -- confirm against DataFT.
///
Double_t MatrixCalculator::NonDiagonalElement(Int_t i, Int_t k, Int_t n,
                             const TVectorD& theta_i, const DataFT& data) const
{
    return pow(theta_i(i),2) *
        (  pow(data.GetdZSteel(i)/data.GetCos(i),2)/3.+
           data.GetdZSteel(i)/data.GetCos(i)*data.T(i,n) +
           pow(data.T(i,n),2) +
           TMath::Abs(data.GetZ(k)-data.GetZ(n))*
           (data.GetdZSteel(i)/data.GetCos(i)/2.+data.T(i,n))  );
}
