////////////////////////////////////////////////////////////////////////
// $Id: AlgFitThruMuonList.cxx,v 1.22 2002/08/15 11:22:28 miyagawa Exp $
//
// AlgFitThruMuonList
//
// Begin_Html<img src="../../pedestrians.gif" align=center>
// <a href="../source_warning.html">Warning for beginners</a>.<br> 
//
// This is an Algorithm class to improve the straight-line fit of
// tracks in a CandThruMuonList. It rejects digits that lie too far
// away from the fitted tracks.
//
// Author:  P.S. Miyagawa 10/2000
//
// Also see <a href="../../root_crib/index.html">The ROOT Crib</a> and 
// <a href="../ALG_Classes.html"> Algorithm Classes</a> (part of
// <a href="../index.html">The MINOS Class User Guide</a>)End_Html
////////////////////////////////////////////////////////////////////////

#include <cassert>

#include "Algorithm/AlgConfig.h"
#include "Algorithm/AlgFactory.h"
#include "Algorithm/AlgHandle.h"
#include "Candidate/CandContext.h"
#include "MessageService/MsgService.h"
#include "Plex/PlexStripEndId.h"
#include "UgliGeometry/UgliGeomHandle.h"
#include "BubbleSpeak/AlgFitThruMuonList.h"
#include "BubbleSpeak/CandDigiPairHandle.h"
#include "BubbleSpeak/CandStraightClusterHandle.h"
#include "BubbleSpeak/CandThruMuon.h"
#include "BubbleSpeak/CandThruMuonHandle.h"
#include "BubbleSpeak/CandThruMuonListHandle.h"

// ROOT class-implementation macro for AlgFitThruMuonList.
// NOTE(review): presumably paired with a ClassDef in AlgFitThruMuonList.h
// — confirm against the header.
ClassImp(AlgFitThruMuonList)

//......................................................................

// Record this file's CVS revision identifier via the project's CVSID macro
// (assumed to embed the $Id$ string for version queries — TODO confirm).
CVSID("$Id: AlgFitThruMuonList.cxx,v 1.22 2002/08/15 11:22:28 miyagawa Exp $");

//......................................................................

AlgFitThruMuonList::AlgFitThruMuonList()
{
   // Default constructor.
   // This implementation performs no work of its own; the body is
   // intentionally empty.
}

//......................................................................

AlgFitThruMuonList::~AlgFitThruMuonList()
{
   // Destructor.
   // Nothing is allocated by this class's own code outside of RunAlg's
   // local scope, so there is nothing to release here.
}

//......................................................................

void AlgFitThruMuonList::RunAlg(AlgConfig &ac, CandHandle &ch,
                                            CandContext &cx)
{
//
//  Purpose:  Improve straight-line fit to a muon track by rejecting
//            digits whose strips do not lie within a specified number
//            of strip widths of the fitted track.
//
//  Arguments:
//    ac        in    AlgConfig containing muon fit parameters
//                    ("FitSpan" and "FitIterMax").
//    ch        in    Handle to the new CandThruMuonList to fill.
//    cx        in    CandContext containing the CandThruMuonList whose
//                    track fits are to be improved.
//
//  Return:   n/a
//
//  Notes:    A muon is skipped (no refitted daughter is added) when either
//            of its U/V clusters is missing or cannot be refitted into a
//            non-vertical straight track.
//

   MSG("BubAlg", Msg::kVerbose)
      << "Starting AlgFitThruMuonList::RunAlg()" << endl;

// Check for CandThruMuonListHandle input.
   assert(cx.GetDataIn());
   assert(cx.GetDataIn()->InheritsFrom("CandThruMuonListHandle"));
   const CandThruMuonListHandle *cmlh =
            dynamic_cast<const CandThruMuonListHandle*>(cx.GetDataIn());

// Save config parameters for fit rejection.
   Float_t fspan = ac.GetDouble("FitSpan");
   Int_t   fimax = ac.GetInt("FitIterMax");

// General setup for creating new CandThruMuons.
   MSG("BubAlg", Msg::kVerbose) << "Get AlgThruMuon instance from "
      << "AlgThruMuonFactory." << endl;
   AlgFactory &af = AlgFactory::GetInstance();
   AlgHandle ah = af.GetAlgHandle("AlgThruMuon", "default");

   MSG("BubAlg", Msg::kVerbose)
      << "Create CandContext instance." << endl;
   CandContext cxx(this, cx.GetMom());
   cxx.SetCandRecord(cx.GetCandRecord());

// Iterate over muons.
   TIter cmhItr(cmlh->GetDaughterIterator());
   while (CandThruMuonHandle *cmh =
             dynamic_cast<CandThruMuonHandle *>(cmhItr())) {

      // Slots: 0 = U-view track digits, 1 = V-view track digits,
      // 2 = veto shield / cosmic counter digits (optional).
      // The arrays stored here are created for this muon only and are
      // released with Delete() below; they do not own the digit handles.
      TObjArray mpair(3);

// Fit cluster U.  The getter returns a pointer, so guard against a
// missing cluster before dereferencing it.
      const CandStraightClusterHandle *cch = cmh->GetClusterU();
      if (!cch) {
         MSG("BubAlg", Msg::kWarning)
            << "CandThruMuon has no U cluster; skipping." << endl;
         continue;
      }
      TObjArray *hay = 0;
      RunFindStraightTrackAlg(fspan, fimax, *cch, hay);
      if (!hay) continue;
      mpair.AddAt(hay, 0);

// Fit cluster V.
      cch = cmh->GetClusterV();
      if (!cch) {
         MSG("BubAlg", Msg::kWarning)
            << "CandThruMuon has no V cluster; skipping." << endl;
         mpair.Delete();
         continue;
      }
      RunFindStraightTrackAlg(fspan, fimax, *cch, hay);
      if (!hay) {
         delete mpair.RemoveAt(0);
         continue;
      }
      mpair.AddAt(hay, 1);

// Add veto shield / Caldet cosmic counter digits, if any.  The plane view
// is only queried for non-veto-shield digits (short-circuit ||), matching
// the order of checks used previously.
      TObjArray *cosm = 0;
      TIter chhItr(cmh->GetDaughterIterator());
      while (CandDigiPairHandle *chh =
             dynamic_cast<CandDigiPairHandle *>(chhItr())) {
         Bool_t isExtra = chh->GetStripEndId().IsVetoShield();
         if (!isExtra) {
            PlaneView::PlaneView_t vw = chh->GetPlaneView();
            isExtra = (vw == PlaneView::kA) || (vw == PlaneView::kB);
         }
         if (isExtra) {
            if (!cosm) cosm = new TObjArray();
            cosm->Add(chh);
         }
      }
      if (cosm) mpair.AddAt(cosm, 2);

// Create CandThruMuon.
      cxx.SetDataIn(&mpair);
      CandThruMuonHandle cfh = CandThruMuon::MakeCandidate(ah, cxx);
      ch.AddDaughterLink(cfh);

      // Free the per-muon arrays (not the digit handles they reference).
      mpair.Delete();
   }
}

//......................................................................

void AlgFitThruMuonList::RunFindStraightTrackAlg(Float_t span,
   Int_t iterMax, const CandStraightClusterHandle &cch, TObjArray *&hay)
{
//
//  Purpose:  Find a straight track from a cluster by iterating until
//            all digits are within a specified number of strip widths
//            of the fitted track.
//
//  Arguments:
//    span      in    Number of strip widths for difference limit.
//    iterMax   in    Maximum number of iterations for fit.
//    cch       in    Handle to cluster from which straight track to be
//                    found.
//    hay       out   New array in which to store digits associated with
//                    straight track.
//
//  Returns:  n/a
//

// Set default return value.
   hay = 0;

// Set fit mode.
   int modeZT = cch.GetFitMode();

// Initialize stats.
   Float_t xsum, x2sum, ysum;
   Float_t xysum = cch.GetTZsum();
   Float_t wtsum = cch.GetWtSum();
   if (modeZT) {
      xsum  = cch.GetTsum();
      x2sum = cch.GetT2sum();
      ysum  = cch.GetZsum();
   }
   else {
      xsum  = cch.GetZsum();
      x2sum = cch.GetZ2sum();
      ysum  = cch.GetTsum();
   }

// Check for negative determinant.
   Float_t det = cch.GetFitDet();
   if (det < 0) {
      MSG("BubAlg", Msg::kWarning)
         << "Determinant of fit negative." << endl;
      return;
   }

// Initialize fit parameters.
   Float_t intercept = cch.GetFitInter();
   Float_t slope = cch.GetFitSlope();
   Float_t slope2_1 = slope * slope + 1;

// Find strip width using first digitized strip.
   Float_t cellwidth = 0;
   TIter chhItr(cch.GetDaughterIterator());
   while (const CandDigiPairHandle *chht =
             dynamic_cast<const CandDigiPairHandle *>(chhItr())) {
      UgliGeomHandle ugh(*chht->GetVldContext());
      UgliStripHandle ush = ugh.GetStripHandle(chht->GetStripEndId());
      if (ush.IsValid()) {
         cellwidth = ush.GetHalfWidth();
         cellwidth *= 2.0;
         break;
      }
   }

// Set difference check limit.
   Float_t d2check = span * cellwidth;
   d2check *= d2check;

// Create temporary clusters of good and bad digit pairs.
   TObjArray *fitCls = new TObjArray(40);
   TObjArray unfit;

// Iterate over cluster digit pairs.
   chhItr.Reset();
   while (CandDigiPairHandle *chh =
             dynamic_cast<CandDigiPairHandle *>(chhItr())) {

// Retrieve data.
      Float_t xdat, ydat;
      if (modeZT) {
         xdat = chh->GetTPos();
         ydat = chh->GetZPos();
      }
      else {
         xdat = chh->GetZPos();
         ydat = chh->GetTPos();
      }

// Check fit.
      Float_t d2diff = ydat - intercept - slope * xdat;
      d2diff *= d2diff;
      d2diff /= slope2_1;
      if (d2diff <= d2check) {

// Good fit, so add to fit cluster.
         fitCls->Add(chh);
      }

// Bad fit, so add to unfit array and update stats.
      else {
         unfit.Add(chh);
         Float_t wt = chh->GetCharge();
         xsum  -= wt * xdat;
         ysum  -= wt * ydat;
         xysum -= wt * xdat * ydat;
         x2sum -= wt * xdat * xdat;
         wtsum -= wt;
      }
   }

// Check that there are enough points included.
   if (fitCls->GetEntries() <= 1) {
      MSG("BubAlg", Msg::kWarning)
         << "Not enough points for fit." << endl;
      return;
   }

// Loop until no digits added or removed.
   Bool_t changed;
   Int_t  numiter = 0;
   do {
      changed = kFALSE;
      numiter++;

// Check for negative determinant.
      det = wtsum * x2sum - xsum * xsum;
      if (det < 0) {
         MSG("BubAlg", Msg::kWarning)
            << "Determinant of fit negative." << endl;
         return;
      }

// Calculate fit parameters.
      if (det != 0) {
         intercept = (-xysum * xsum + ysum * x2sum) / det;
         slope = (-xsum * ysum + wtsum * xysum) / det;
         slope2_1 = slope * slope + 1;
      }
      else {
         intercept = ysum / wtsum;
         slope = 0;
         slope2_1 = 1;
      }

// Iterate over digits in rejects to check for new additions.
      TIter unfhItr(&unfit);
      while (CandDigiPairHandle *unfh =
                dynamic_cast<CandDigiPairHandle *>(unfhItr())) {

// Retrieve data.
         Float_t xdat, ydat;
         if (modeZT) {
            xdat = unfh->GetTPos();
            ydat = unfh->GetZPos();
         }
         else {
            xdat = unfh->GetZPos();
            ydat = unfh->GetTPos();
         }

// Check fit.
         Float_t d2diff = ydat - intercept - slope * xdat;
         d2diff *= d2diff;
         d2diff /= slope2_1;
         if (d2diff <= d2check) {

// Good fit, so add to fit cluster and update stats.
            fitCls->Add(unfit.Remove(unfh));
            Float_t wt = unfh->GetCharge();
            xsum  += wt * xdat;
            ysum  += wt * ydat;
            xysum += wt * xdat * ydat;
            x2sum += wt * xdat * xdat;
            wtsum += wt;
            changed = kTRUE;
         }
      }

// Iterate over digits in fit cluster to check for new rejects.
      TIter fithItr(fitCls);
      while (CandDigiPairHandle *fith =
                dynamic_cast<CandDigiPairHandle *>(fithItr())) {

// Retrieve data.
         Float_t xdat, ydat;
         if (modeZT) {
            xdat = fith->GetTPos();
            ydat = fith->GetZPos();
         }
         else {
            xdat = fith->GetZPos();
            ydat = fith->GetTPos();
         }

// Check fit.
         Float_t d2diff = ydat - intercept - slope * xdat;
         d2diff *= d2diff;
         d2diff /= slope2_1;
         if (d2diff > d2check) {

// Bad fit, so remove from fit cluster and update stats.
            unfit.Add(fitCls->Remove(fith));
            Float_t wt = fith->GetCharge();
            xsum  -= wt * xdat;
            ysum  -= wt * ydat;
            xysum -= wt * xdat * ydat;
            x2sum -= wt * xdat * xdat;
            wtsum -= wt;
            changed = kTRUE;
         }
      }
   } while (changed && (numiter <= iterMax));

// Reject vertical clusters.
   TIter checkItr(fitCls->MakeIterator());
   CandDigiPairHandle *chh =
                       dynamic_cast<CandDigiPairHandle *>(checkItr());
   if(!chh) return;
   Int_t currpln = chh->GetPlane();
   while ( ( chh = dynamic_cast<CandDigiPairHandle *>(checkItr()) ) ) {
      if (chh->GetPlane() != currpln) {

// Found straight non-vertical track.
         hay = fitCls;
         return;
      }
   }
}

//......................................................................

void AlgFitThruMuonList::Trace(const char *c) const
{
   // Trace hook: writes a begin banner and an end banner, both tagged with
   // the caller-supplied string c, to the "BubCand" stream at debug level.
   // No algorithm state is reported between the two banners.
   MSG("BubCand", Msg::kDebug)
      << "**********Begin AlgFitThruMuonList::Trace(\"" << c << "\")" << endl
      << "**********End AlgFitThruMuonList::Trace(\"" << c << "\")" << endl;
}
