#include "InterSplineFunc.h"

#include <algorithm>
#include <cstddef>

#include "matrix.h"



// Build an empty, un-fitted spline. Until one of the Init* functions has run,
// operator() prints "not initialized" and returns 0. No min/max clamping is
// enabled by default.
InterSplineFunc::InterSplineFunc()
    : fInitialized(false), fUniform(false), fHaveMin(false), fHaveMax(false)
{
}

// Build a spline and immediately fit it through the points (vx[i], vy[i]).
// s selects the end ("runout") condition. When uniform is true the knots are
// treated as equally spaced and the uniform fitting path is used; otherwise
// the general non-uniform path is used.
InterSplineFunc::InterSplineFunc(vector<double>& vx, vector<double>& vy,
                                 SplineType s /*= linear_runout*/,
                                 bool uniform /*= false*/)
    : fInitialized(false), fUniform(uniform), fHaveMin(false), fHaveMax(false)
{
    if (!uniform) {
        InitNonUniform(vx, vy, s);
    } else {
        InitUniform(vx, vy, s);
    }
}


// Nothing is released explicitly here.
InterSplineFunc::~InterSplineFunc() {}

// Fit the spline through (vx[i], vy[i]) with arbitrary knot spacing and the
// linear runout end condition.
void InterSplineFunc::Init(vector<double>& vx, vector<double>& vy)
{
    const SplineType runout = linear_runout;
    InitNonUniform(vx, vy, runout);
}

// Fit the interpolating cubic spline through (vx[i], vy[i]) for knots that
// are equally spaced in x. The spacing h = vx[1]-vx[0] is used for every
// interval.
//
// M[i] is the second derivative at knot i. The interior M's solve
//     M[i-1] + 4 M[i] + M[i+1] = 6/h^2 (y[i-1] - 2 y[i] + y[i+1]),  i = 1..n-2
// and the end values follow from the runout type s:
//     linear, zero, static: M[0] = M[n-1] = 0
//     parabolic:            M[0] = M[1],          M[n-1] = M[n-2]
//     cubic:                M[0] = 2M[1] - M[2],  M[n-1] = 2M[n-2] - M[n-3]
// Interval i stores the coefficients of A dx^3 + B dx^2 + C dx + D,
// with dx = x - X[i].
//
// Any coefficients from an earlier Init* call are discarded. The function
// aborts when:
//   - fewer than 2 points are given, or vx and vy differ in length;
//   - cubic runout is requested with fewer than 4 points;
//   - the runout is unknown;
//   - the tri-diagonal solve fails.
void InterSplineFunc::InitUniform(vector<double>& vx, vector<double>& vy, 
                                  InterSplineFunc::SplineType s)
{
    const size_t npts = vx.size();
    if (npts < 2 || vy.size() != npts) {
        cerr << "InterSplineFunc: need at least 2 points and equal-length x/y\n";
        abort();
    }

    // Start from a clean state instead of appending to a previous fit.
    // The knots are equally spaced, so operator() may compute the interval
    // index directly.
    X.clear(); Y.clear(); A.clear(); B.clear(); C.clear(); D.clear();
    fUniform = true;
    fRunout = s;

    const double step = vx[1] - vx[0];
    const double coef = 6.0/(step*step);
    const size_t neq = npts - 2;        // interior knots = unknown M's

    // Fill the tri-diagonal matrix and the RHS, one row per interior knot:
    // avec = sub-diagonal, bvec = diagonal, cvec = super-diagonal.
    vector<double> avec(neq > 0 ? neq-1 : 0, 1.0);
    vector<double> bvec(neq, 4.0);
    vector<double> cvec(neq > 0 ? neq-1 : 0, 1.0);
    vector<double> rvec;
    rvec.reserve(neq);
    for (size_t i = 1; i + 1 < npts; ++i)
        rvec.push_back((vy[i-1] - 2.0*vy[i] + vy[i+1])*coef);

    // Fold the end conditions into the first and last rows. Adjustments are
    // additive so that with one interior knot (first row == last row) both
    // ends are accounted for.
    switch (fRunout) {
    case zero_runout:
    case static_runout:
    case linear_runout:
        // M[0] = M[npts-1] = 0: the end terms drop out, matrix unchanged.
        break;
    case parabolic_runout:
        // M[0] = M[1] and M[npts-1] = M[npts-2].
        if (neq > 0) {
            bvec.front() += 1.0;
            bvec.back()  += 1.0;
        }
        break;
    case cubic_runout:
        // M[0] = 2M[1] - M[2] and M[npts-1] = 2M[npts-2] - M[npts-3].
        // These need at least two interior knots.
        if (npts < 4) {
            cerr << "InterSplineFunc: cubic runout needs at least 4 points\n";
            abort();
        }
        bvec.front() += 2.0;
        cvec.front()  = 0.0;
        bvec.back()  += 2.0;
        avec.back()   = 0.0;
        break;
    default:
        cerr << "InterSplineFunc: unkown runout";
        abort();
    }

    vector<double> mvec;
    mvec.push_back(0.0);        // place holder for M[0]
    if (neq == 0) {
        // Two points, no interior knots: the spline is the straight line.
        mvec.push_back(0.0);
    }
    else {
        // NOTE(review): relies on tridiag_solve appending M[1..npts-2] after
        // the placeholder, as the original code did. TODO confirm against
        // matrix.h.
        if (!tridiag_solve(avec,bvec,cvec,rvec,mvec)) abort();

        // Recover the end values M[0] and M[npts-1].
        switch (fRunout) {
        case zero_runout:
        case static_runout:
        case linear_runout:
            // Zero/static runouts only change evaluation outside the knot
            // range. Inside it they use the same natural ends as linear; the
            // last M must still be appended or the final interval is lost.
            mvec[0] = 0.0;
            mvec.push_back(0.0);
            break;
        case parabolic_runout:
            mvec[0] = mvec[1];
            mvec.push_back(mvec.back());
            break;
        case cubic_runout: {
            mvec[0] = 2.0*mvec[1] - mvec[2];
            const size_t msiz = mvec.size();
            mvec.push_back(2.0*mvec[msiz-1] - mvec[msiz-2]);
            break;
        }
        default:
            abort();
        }
    }

    // Per-interval polynomial coefficients. X and Y keep all npts knots.
    for (size_t i = 0; i + 1 < npts; ++i) {
        X.push_back(vx[i]);
        Y.push_back(vy[i]);
        A.push_back((mvec[i+1] - mvec[i])/(6.0*step));
        B.push_back(mvec[i]/2.0);
        C.push_back((vy[i+1] - vy[i])/step - (mvec[i+1] + 2.0*mvec[i])*step/6.0);
        D.push_back(vy[i]);
    }
    X.push_back(vx.back());
    Y.push_back(vy.back());

    fInitialized = true;
}
// Fit the interpolating cubic spline through (vx[i], vy[i]) for arbitrary
// (increasing) knot spacing h[i] = vx[i+1] - vx[i].
//
// M[i] is the second derivative at knot i. The interior M's solve
//     h[i-1] M[i-1] + 2(h[i-1]+h[i]) M[i] + h[i] M[i+1]
//         = 6 ( y[i-1]/h[i-1] - y[i](h[i-1]+h[i])/(h[i-1]h[i]) + y[i+1]/h[i] )
// for i = 1..n-2. The end values follow from the runout type s:
//     linear, zero, static: M[0] = M[n-1] = 0
//     parabolic:            M[0] = M[1],          M[n-1] = M[n-2]
//     cubic:                M[0] = 2M[1] - M[2],  M[n-1] = 2M[n-2] - M[n-3]
// Interval i stores the coefficients of A dx^3 + B dx^2 + C dx + D,
// with dx = x - X[i].
//
// Any coefficients from an earlier Init* call are discarded. The function
// aborts when:
//   - fewer than 2 points are given, or vx and vy differ in length;
//   - cubic runout is requested with fewer than 4 points;
//   - the runout is unknown;
//   - the tri-diagonal solve fails.
void InterSplineFunc::InitNonUniform(vector<double>& vx, vector<double>& vy, 
                                     InterSplineFunc::SplineType s)
{
    const size_t npts = vx.size();
    if (npts < 2 || vy.size() != npts) {
        cerr << "InterSplineFunc: need at least 2 points and equal-length x/y\n";
        abort();
    }

    // Start from a clean state instead of appending to a previous fit, and
    // make operator() search for the interval. The knots are not assumed to
    // be equally spaced, even if this object was constructed with
    // uniform=true.
    X.clear(); Y.clear(); A.clear(); B.clear(); C.clear(); D.clear();
    fUniform = false;
    fRunout = s;

    const size_t neq = npts - 2;        // interior knots = unknown M's

    // Tri-diagonal system: avec = sub-diagonal, bvec = diagonal,
    // cvec = super-diagonal. The M[0] coefficient of the first row (a1) and
    // the M[npts-1] coefficient of the last row (cnm2) are kept aside for
    // the runout adjustments.
    vector<double> avec, bvec, cvec, rvec;
    double a1 = 0, cnm2 = 0;
    for (size_t ind = 0; ind < neq; ++ind) {
        const double hip1 = vx[ind+1] - vx[ind+0];
        const double hip2 = vx[ind+2] - vx[ind+1];
        if (ind == 0)
            a1 = hip1;
        else
            avec.push_back(hip1);
        bvec.push_back(2.0*(hip1+hip2));
        if (ind == neq-1)
            cnm2 = hip2;
        else
            cvec.push_back(hip2);
        rvec.push_back(6.0*(vy[ind+0]/hip1 -
                            vy[ind+1]*(hip1+hip2)/(hip1*hip2) +
                            vy[ind+2]/hip2));
    }

    switch (fRunout) {
    case zero_runout:
    case static_runout:
    case linear_runout:
        // M[0] = M[npts-1] = 0: the end terms drop out, matrix unchanged.
        break;
    case parabolic_runout:
        // M[0] = M[1] and M[npts-1] = M[npts-2]. The last row is
        // bvec.back() (index npts-3), not index npts-1, which is past the end.
        if (neq > 0) {
            bvec.front() += a1;
            bvec.back()  += cnm2;
        }
        break;
    case cubic_runout:
        // M[0] = 2M[1] - M[2] and M[npts-1] = 2M[npts-2] - M[npts-3].
        // These need at least two interior knots (non-empty avec/cvec).
        if (npts < 4) {
            cerr << "InterSplineFunc: cubic runout needs at least 4 points\n";
            abort();
        }
        bvec.front() += 2.0*a1;
        cvec.front() -= a1;
        bvec.back()  += 2.0*cnm2;
        avec.back()  -= cnm2;
        break;
    default:
        abort();
    }

    vector<double> mvec;
    mvec.push_back(0.0);        // place holder for M[0]
    if (neq == 0) {
        // Two points, no interior knots: the spline is the straight line.
        mvec.push_back(0.0);
    }
    else {
        // NOTE(review): relies on tridiag_solve appending M[1..npts-2] after
        // the placeholder, as the original code did. TODO confirm against
        // matrix.h.
        if (!tridiag_solve(avec,bvec,cvec,rvec,mvec)) abort();

        // Recover the end values M[0] and M[npts-1].
        switch (fRunout) {
        case zero_runout:
        case static_runout:
        case linear_runout:
            // Zero/static runouts only change evaluation outside the knot
            // range. Inside it they use the same natural ends as linear; the
            // last M must still be appended or the final interval is lost.
            mvec[0] = 0.0;
            mvec.push_back(0.0);
            break;
        case parabolic_runout:
            mvec[0] = mvec[1];
            mvec.push_back(mvec.back());
            break;
        case cubic_runout: {
            mvec[0] = 2.0*mvec[1] - mvec[2];
            const size_t msiz = mvec.size();
            mvec.push_back(2.0*mvec[msiz-1] - mvec[msiz-2]);
            break;
        }
        default:
            abort();
        }
    }

    // Per-interval polynomial coefficients. X and Y keep all npts knots.
    for (size_t i = 0; i + 1 < npts; ++i) {
        const double step = vx[i+1] - vx[i];
        X.push_back(vx[i]);
        Y.push_back(vy[i]);
        A.push_back((mvec[i+1] - mvec[i])/(6.0*step));
        B.push_back(mvec[i]/2.0);
        C.push_back((vy[i+1] - vy[i])/step - (mvec[i+1] + 2.0*mvec[i])*step/6.0);
        D.push_back(vy[i]);
    }
    X.push_back(vx.back());
    Y.push_back(vy.back());

    fInitialized = true;
}

// Evaluate the spline at x.
//
// If no Init* call has completed, prints a message and returns 0.
// For x <= X[0] or x >= X.back():
//   - zero_runout returns 0;
//   - static_runout returns the nearest end y value;
//   - the other runouts extrapolate the outermost interval's cubic.
// The result is clamped to fMin / fMax when those limits are enabled.
double InterSplineFunc::operator()(double x)
{
    if (!fInitialized) { 
        cerr << "InterSplineFunc: not initialized\n";
        return 0;
    }

    if (fRunout == zero_runout) {
        if (x <= X[0])     return 0;
        if (x >= X.back()) return 0;
    }
    if (fRunout == static_runout) {
        if (x <= X[0])     return Y[0];
        if (x >= X.back()) return Y.back();
    }

    const int last = static_cast<int>(A.size()) - 1;   // final interval index
    int ind = 0;
    if (fUniform) {
        // Clamp in floating point before converting. Casting a double that
        // is out of int range (or NaN) to int is undefined behaviour.
        double pos = (x - X[0])/(X[1] - X[0]);
        if (!(pos > 0.0)) pos = 0.0;   // also maps NaN to the first interval
        if (pos > last)   pos = last;
        ind = static_cast<int>(pos);
    }
    else {
        // Index of the first of X[0..N-2] not below x, minus one (min 0).
        // For sorted knots this is the interval containing x, with the
        // outermost intervals used for extrapolation.
        vector<double>::iterator lb = std::lower_bound(X.begin(), X.end() - 1, x);
        ind = static_cast<int>(lb - X.begin());
        if (ind) --ind;
    }

    const double dx = x - X[ind];
    double retval = dx*(dx*(dx*A[ind] + B[ind]) + C[ind]) + D[ind];
    if (fHaveMin && retval < fMin) retval = fMin;
    if (fHaveMax && retval > fMax) retval = fMax;
    return retval;
}

