#include <cassert>
#include <cmath>
#include <cstdlib> 
#include <ctime>

#include <gsl/gsl_bspline.h>
#include <gsl/gsl_multifit.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_statistics.h>

#include "Header.hh"
#include "HeaderBase.hh"
#include "ElementContainer.hh"

#include "Domain.hh"
#include "ParamSet.hh"

//#include "BSpline.hh"
//#include "MethodType.hh"
#include "MethodFactory.hh"



/**
 * Gaussian-shaped peak  a * exp(-((x - c) / w)^2).
 *
 * @param x  abscissa at which the peak is evaluated
 * @param a  amplitude (the value at x == c)
 * @param c  peak centre
 * @param w  width parameter: the curve falls to a/e at |x - c| == w
 *           (so w == sqrt(2) * sigma). w == 0 is a division by zero.
 * @return   the peak value at x
 */
Double gaussian(const Double x, const Double a, const Double c, const Double w) {
    // Square by multiplication: exactly rounded and far cheaper than pow(t, 2.0).
    const Double t = (x - c)/w;
    return a*std::exp(-t*t);
}

/**
 * Sum of N Gaussian peaks evaluated at x.
 *
 * Peak k has amplitude a[k], centre c[k] and width w[k] (see gaussian()).
 * The arrays a, c and w must each hold at least N elements.
 *
 * @return  sum over k in [0, N) of gaussian(x, a[k], c[k], w[k]);
 *          0.0 when N == 0
 */
Double dist(const Double x, const UInt4 N, const Double a[], const Double c[], const Double w[]) {
    Double total = 0.0;
    UInt4 k = 0;
    while (k < N) {
        // Accumulate in ascending index order so rounding matches a forward sum.
        total += gaussian(x, a[k], c[k], w[k]);
        ++k;
    }
    return total;
}

/**
 * Build a synthetic histogram for exercising the fitting methods.
 *
 * The signal is the sum of three Gaussian peaks (see dist()). Each bin value
 * is the signal at the bin centre plus Gaussian noise whose standard deviation
 * is 10% of that signal; the same 10% is stored as the bin's error.
 *
 * The random generator is configured through gsl_rng_env_setup(), so (per the
 * GSL documentation) its type and seed may be selected with the GSL_RNG_TYPE
 * and GSL_RNG_SEED environment variables.
 *
 * @param xmin  lower edge of the first bin
 * @param xmax  upper edge of the last bin
 * @param nDiv  number of bins; must be > 0
 * @return      a container holding nDiv+1 "TOF" bin edges and nDiv
 *              "Intensity" / "Error" values
 */
ElementContainer init_src(const Double xmin, const Double xmax, const UInt4 nDiv) {

    // Peak amplitudes, centres and widths, in the layout expected by dist().
    const Double a[]={ 100000.0, 30000.0, 5000.0 };
    const Double c[]={      7.5,     3.0,   10.0 };
    const Double w[]={      2.5,     3.5,    4.0 };
    const UInt4 N_PEAK = sizeof(a)/sizeof(a[0]);

    // With no bins the bin width below would be a division by zero.
    assert(nDiv > 0);

    // nDiv+1 equally spaced bin edges from xmin to xmax.
    const Double delta = (xmax - xmin)/nDiv;
    vector<Double> x;
    x.reserve(nDiv + 1);
    for (UInt4 i=0; i<=nDiv; ++i) {
        x.push_back(xmin + delta*i);
    }

    gsl_rng_env_setup();
    gsl_rng *r = gsl_rng_alloc(gsl_rng_default);

    // One noisy sample per bin, evaluated at the bin centre.
    vector<Double> y, e;
    y.reserve(nDiv);
    e.reserve(nDiv);
    for (UInt4 i=0; i<nDiv; ++i) {
        const Double xi     = (x[i] + x[i+1])/2.0;
        const Double signal = dist(xi, N_PEAK, a, c, w);
        const Double sigma  = 0.1 * signal;   // 10% relative error

        y.push_back(signal + gsl_ran_gaussian(r, sigma));
        e.push_back(sigma);
    }

    // Release the generator now that sampling is finished.
    gsl_rng_free(r);

    // Built by value: returning a copy of a heap object would leak it.
    ElementContainer ec;
    ec.AddToHeader("RUNNUMBER", 1    );
    ec.AddToHeader("level",     0    );
    ec.AddToHeader("Inst.",     "DNA");
    ec.Add("TOF",       x, "sec"  );
    ec.Add("Intensity", y, "count");
    ec.Add("Error",     e, "count");
    ec.SetKeys("TOF", "Intensity", "Error");

    return ec;
}

/**
 * Print one MethodType descriptor on a single line of std::cout as
 * "value=<value>, name=<name>, symbol=<symbol>", then flush via endl.
 */
void outputMethodType(MethodType methodType) {
    std::cout << "value="    << methodType.value
              << ", name="   << methodType.name
              << ", symbol=" << methodType.symbol
              << endl;
}

int main(int argc, char *argv[]) {

    Domain domain;
    ParamSet param;

    outputMethodType(BSPLINE);
    outputMethodType(MOVING_AVERAGE);
    outputMethodType(LEVMAR);
    outputMethodType(NEW_LEVMAR);

    //------------            source data            ------------//
    //--------         element conatiner        --------//
    Double xMin, xMax;
    UInt4  nDiv;

    std::cout << "xmin = "; std::cin >> xMin;
    std::cout << "xmax = "; std::cin >> xMax;
    std::cout << "number of division = "; std::cin >> nDiv;

    ElementContainer src=init_src(xMin, xMax, nDiv);
    src.Dump();

    //--------        domain        --------//
    domain.setSource(src);
    domain.setRange(xMin, xMax);
    std::cout << "domain ["       << domain.getLowerBound()   << ", " << domain.getUpperBound()   << ")" << endl;
    std::cout << "domain index [" << domain.getLowerBoundID() << ", " << domain.getUpperBoundID() << "]" << endl;

    //--------        parameter set for B-spline        --------//
    param.add(BSpline::ORDER,                    3U   );
    param.add(BSpline::AUTOMATIC_KNOTS,          false);
    param.add(BSpline::USE_UNIFORM_BREAK_POINTS, true );
    param.add(BSpline::NUMBER_OF_BREAK_POINTS,   10U  );
    param.dump();

    //------------            initalise for B-spline            ------------//
    //


    clock_t start_time, stop_time, total_time;
    total_time=0;
    MethodFactory *factory = MethodFactory::getInstance();
    Method *method=factory->createMethod(BSPLINE);
    std::cout << "method = " << method->getMethodName() << endl;
    method->setDefaultParam(src).dump();
    method->checkParam(src, domain, param); 

    UInt4 N;
    std::cout << "number of iterations for test = "; std::cin >> N;

    printf("%10s %10s %10s %10s\n", "start", " stop",  "diff", " totak");
    for (UInt4 i=0; i<N; ++i) {
        start_time=clock();
        method->toInnerForm(src, domain, param); 
        method->fit();
        method->eval();
        stop_time=clock();

        total_time += (stop_time-start_time);
        printf("%10lu %10lu %10lu %10lu\n", start_time, stop_time, stop_time-start_time, total_time);
    }

    std::cout << "average time:   " << total_time/N   << endl;
    std::cout << "clocks per sec: " << CLOCKS_PER_SEC << endl;

    /*
    BSpline *conqMethod = dynamic_cast<BSpline*>(method);
    vector< vector<Double> > inc_dec_table=conqMethod->getTrend();

    list< vector<Double> > peakList;
    Double xi, h, w, xl, xr;
    for (UInt4 i=0; i<inc_dec_table[0].size()-1; ++i) {
        if (inc_dec_table[1][i] > 0.0 && inc_dec_table[1][i+1] < 0.0 &&
            inc_dec_table[2][i] < 0.0 && inc_dec_table[2][i+1] < 0.0 ) {

            vector<Double> *v=new vector<Double>();
            xi = src.Put(src.PutXKey(), i+1);
            h=(inc_dec_table[0][i] + inc_dec_table[0][i+1])/2.0;
            w=0.0;
            v->push_back(xi);
            v->push_back(h);
            v->push_back(w);
            for (int j=i; j>0; --j) {
                if (inc_dec_table[0][j-1] < h/2.0 && h/2.0 < inc_dec_table[0][j]) {
                    xl=src.Put(src.PutXKey(), j);
                    break;
                }
            }
            for (int j=i+1; j<inc_dec_table[0].size(); ++j) {
                if (inc_dec_table[0][j] > h/2.0 && h/2.0 > inc_dec_table[0][j+1]) {
                    xr=src.Put(src.PutXKey(), j);
                    break;
                }
            }

            peakList.push_back(*v);

            std::cout << "found peak: x=" << xi << ", h=" << h << "(xl, xr)=(" << xl << ", " << xr << ")" << endl;
            std::cout << "found peak: xm=" << (xl+xr)/2.0 << " w=" << (xr-xl)/2.0 << endl;


        } else if (0.0 < inc_dec_table[1][i] && 0.0 < inc_dec_table[1][i+1] &&
                   inc_dec_table[2][i] < 0.0 && 0.0 < inc_dec_table[2][i+1] ) {  // sholder

            vector<Double> *v=new vector<Double>();
            xi = src.Put(src.PutXKey(), i+1);
            h=(inc_dec_table[0][i] + inc_dec_table[0][i+1])/2.0;
            w=0.0;
            v->push_back(xi);
            v->push_back(h);
            v->push_back(w);
            for (int j=i; j>0; --j) {
                if (inc_dec_table[0][j-1] < h/2.0 && h/2.0 < inc_dec_table[0][j]) {
                    xl=src.Put(src.PutXKey(), j);
                    break;
                }
            }
            peakList.push_back(*v);
            std::cout << "found inc. sholder: x=" << xi << ", h=" << h << " xl=" << xl << " w=" << xi-xl << endl;

        } else if (0.0 > inc_dec_table[1][i] && 0.0 > inc_dec_table[1][i+1] &&
                   inc_dec_table[2][i] > 0.0 && 0.0 > inc_dec_table[2][i+1] ) {  // sholder

            vector<Double> *v=new vector<Double>();
            xi = src.Put(src.PutXKey(), i+1);
            h=(inc_dec_table[0][i] + inc_dec_table[0][i+1])/2.0;
            w=0.0;
            v->push_back(xi);
            v->push_back(h);
            v->push_back(w);
            peakList.push_back(*v);

            for (int j=i+1; j < inc_dec_table[0].size()-1; ++j) {
                if (inc_dec_table[0][j] > h/2.0 && h/2.0 > inc_dec_table[0][j+1]) {
                    xr=src.Put(src.PutXKey(), j+1);
                    break;
                }
            }

            std::cout << "found dec. sholder: x=" << xi << ", h=" << h << " xr=" << xr << " w=" << xr-xi << endl;
        }
    }
    */

    ElementContainer *result=new ElementContainer(src.PutHeader());
    //method.toElementContainer(src, *result);
    method->toElementContainer(src, *result);

    //result->dump();
    //ParamSet newParam=method.getFittedParam();
    //ParamSet newParam=method->getFittedParam();
    //newParam.dump();

    //std::cerr << "analyticallly differentiable: " <<  method.differentiable() << endl;
    std::cerr << "analyticallly differentiable: " <<  method->differentiable() << endl;

    return EXIT_SUCCESS;
}
