#include "I0LambdaCorrection.hh"

////////////////////////////////////////
/// Default constructor: switches off calcErrorPropagation and sets the
/// message prefix used by this class. No target matrix is registered here.
I0LambdaCorrection::I0LambdaCorrection()
    : calcErrorPropagation(false),
      MessageTag("I0(lambda) correction>> ")
{
}

////////////////////////////////////////
/// Constructor that switches off calcErrorPropagation, sets the message
/// prefix, and registers @p ecm as the target through SetTarget().
I0LambdaCorrection::I0LambdaCorrection(ElementContainerMatrix* ecm)
    : calcErrorPropagation(false),
      MessageTag("I0(lambda) correction>> ")
{
    SetTarget(ecm);
}

////////////////////////////////////////
/// Stores a copy of @p ec in the i0Lambda member; Execute() reads its
/// "Lamb", "Intensity" and "Error" vectors from there.
void I0LambdaCorrection::SetI0Lambda(ElementContainer ec)
{
    this->i0Lambda = ec;
}

//////////////////////////////////////////////////////////
bool I0LambdaCorrection::
SetPolarizationTable( std::string _pathToData ){
    std::string foundPath=FindParamFilePath( _pathToData );
    if (foundPath==""){
        UtsusemiError( MessageTag+"SetPolarizationTable >> No such file ("+_pathToData+")" );
        return false;
    }
    std::ifstream ifs( foundPath.c_str() );
    if (ifs.fail()){
        UtsusemiError( MessageTag+"SetPolarizationTable >> Failed to open file ("+_pathToData+")" );
        return false;
    }
    StringTools st;
    std::string aline;
    _pol_lam.clear();
    _pol_val.clear();
    while(getline(ifs,aline)){
        if (aline.substr(0,1)!="#"){
            std::vector<std::string> conts = st.SplitString( aline, "," );
            if (conts.size()>2){
                _pol_lam.push_back( st.StringToDouble( conts[0] ) );
                _pol_val.push_back( st.StringToDouble( conts[1] ) );
            }
        }
    }
    ifs.close();
    return true;
}

//////////////////////////////////////////////////////////
/**
 * Divides the intensity and error vectors of every ElementContainer in
 * @p _ecm by the polarization Pn, linearly interpolated from the table
 * (@p _lambs, @p _vals).
 *
 * Pn is evaluated once, on the "Lamb" vector of the first ElementContainer,
 * and reused for all containers. The intensity/error keys are also taken from
 * that first container (PutYKey()/PutEKey()). Every container's intensity
 * vector must have the same length as that "Lamb" vector.
 *
 * All inputs are validated before any data is modified, so a failing call
 * leaves @p _ecm unchanged.
 *
 * Wavelengths below the first / above the last table wavelength use the
 * first / last table value (previously Pn stayed 0 there and the division
 * produced inf/NaN); a warning with the number of such points is printed.
 * NOTE(review): clamping to the end values is a policy choice - confirm that
 * it is preferred over rejecting or masking such points.
 *
 * @param _ecm    data to be corrected in place
 * @param _lambs  table wavelengths; the interval search only matches when
 *                they are in ascending order
 * @param _vals   table polarization values, one per entry of @p _lambs
 * @return true on success, false (with a UtsusemiError message) when the
 *         table or the data are unusable.
 */
bool I0LambdaCorrection::
_ExecPolarizationCorrection( ElementContainerMatrix* _ecm, std::vector<Double> _lambs,  std::vector<Double> _vals ){

    if ((_lambs.empty())||(_vals.empty())) {
        UtsusemiError( MessageTag+"_ExecPolarizationCorrection >> Polarization data is not loaded." );
        return false;
    }
    // Interpolation reads _vals[j+1] for every interval of _lambs, so both
    // vectors must have the same length and hold at least one interval.
    if ((_lambs.size()!=_vals.size())||(_lambs.size()<2)){
        UtsusemiError( MessageTag+"_ExecPolarizationCorrection >> Polarization table is invalid (needs two or more points and equal numbers of wavelengths and values)." );
        return false;
    }
    // The first container is used as the reference below; make sure it exists.
    if ((!_ecm)||(_ecm->PutSize()==0)||(_ecm->PutPointer(0)->PutSize()==0)){
        UtsusemiError( MessageTag+"_ExecPolarizationCorrection >> Target data is empty." );
        return false;
    }

    ElementContainer* first_ec = _ecm->PutPointer(0)->PutPointer(0);
    std::vector<Double>* lamb_v = first_ec->PutP("Lamb");
    // Defensive: assumes PutP() yields a null pointer for a missing key - TODO confirm.
    if (!lamb_v){
        UtsusemiError( MessageTag+"_ExecPolarizationCorrection >> Lamb vector is not found in data." );
        return false;
    }

    // --- Pn for every wavelength point of the reference container ---
    std::vector<Double> Pn_v;
    Pn_v.reserve( lamb_v->size() );
    UInt4 n_outside = 0;
    for (UInt4 i=0; i<(lamb_v->size()); i++){
        const Double lamb = (*lamb_v)[i];
        Double Pn = 0.0;
        bool found = false;
        for (UInt4 j=0; (j+1)<_lambs.size(); j++){
            if ((lamb>=_lambs[j])&&(lamb<=_lambs[j+1])){
                if (_lambs[j+1]==_lambs[j]){
                    // Zero-width interval: avoid 0/0 in the slope.
                    Pn = _vals[j];
                }else{
                    Pn = _vals[j] + (_vals[j+1]-_vals[j])/(_lambs[j+1]-_lambs[j])*(lamb - _lambs[j]);
                }
                found = true;
                break;
            }
        }
        if (!found){
            if (lamb<_lambs.front()){
                Pn = _vals.front();
                n_outside++;
            }else if (lamb>_lambs.back()){
                Pn = _vals.back();
                n_outside++;
            }
            // Otherwise (e.g. NaN wavelength or a table that is not ascending)
            // Pn stays 0 and the division below yields inf/NaN, as before.
        }
        Pn_v.push_back(Pn);
    }

    const std::string Ykey = first_ec->PutYKey();
    const std::string Ekey = first_ec->PutEKey();

    // --- Pass 1: validate every container before touching any data ---
    for (UInt4 i=0; i<_ecm->PutSize(); i++){
        ElementContainerArray* eca = _ecm->PutPointer(i);
        for (UInt4 j=0; j<(eca->PutSize()); j++){
            ElementContainer* ec = eca->PutPointer(j);
            std::vector<Double>* ii = ec->PutP( Ykey );
            std::vector<Double>* ee = ec->PutP( Ekey );
            if ((!ii)||((ii->size())!=Pn_v.size())){
                UtsusemiError( MessageTag+"_ExecPolarizationCorrection >> Intensity size is not match with Polarization points" );
                return false;
            }
            if ((!ee)||((ee->size())<(ii->size()))){
                UtsusemiError( MessageTag+"_ExecPolarizationCorrection >> Error size is not match with Intensity size" );
                return false;
            }
        }
    }

    if (n_outside>0){
        std::cout << MessageTag << "_ExecPolarizationCorrection >> Warning: " << n_outside
                  << " wavelength point(s) are outside the polarization table; the nearest table value was used." << std::endl;
    }

    // --- Pass 2: apply the correction in place ---
    for (UInt4 i=0; i<_ecm->PutSize(); i++){
        ElementContainerArray* eca = _ecm->PutPointer(i);
        for (UInt4 j=0; j<(eca->PutSize()); j++){
            ElementContainer* ec = eca->PutPointer(j);
            std::vector<Double>& yy = *(ec->PutP( Ykey ));
            std::vector<Double>& ee = *(ec->PutP( Ekey ));
            for (UInt4 k=0; k<yy.size(); k++){
                yy[k] = yy[k]/Pn_v[k];
                ee[k] = ee[k]/Pn_v[k];
            }
        }
    }
    return true;
}

////////////////////////////////////////
/**
 * Runs the I0(lambda) correction with @p ec as I0 data and, when
 * @p _pathToData is not empty, the polarization correction afterwards.
 *
 * Steps: SetI0Lambda(ec) -> Execute() -> SetPolarizationTable(_pathToData)
 * -> _ExecPolarizationCorrection() on the target returned by Put().
 *
 * @param ec           I0(lambda) data
 * @param _pathToData  polarization table file; an empty string skips the
 *                     polarization step
 * @return true only when every executed step succeeded.
 */
bool I0LambdaCorrection::
Execute(ElementContainer ec, std::string _pathToData){
    SetI0Lambda(ec);

    // I0(lambda) normalization comes first; stop when it fails.
    if (!Execute()){
        return false;
    }
    // No table given: nothing more to do.
    if (_pathToData.empty()){
        return true;
    }
    if (!SetPolarizationTable( _pathToData )){
        return false;
    }
    const bool corrected = _ExecPolarizationCorrection( Put(), _pol_lam, _pol_val );
    return corrected;
}

////////////////////////////////////////
bool I0LambdaCorrection::
Execute(){
    ElementContainerMatrix *ecm = Put();

    bool ret = false;

    HeaderBase* h_ecm = ecm->PutHeaderPointer();
    std::vector<std::string> process = h_ecm->PutStringVector("DATAPROCESSED");

    std::vector<Double> lamb_vec_corr0 = i0Lambda.Put("Lamb");
    std::vector<Double> int_vec_corr0 = i0Lambda.Put("Intensity");
    std::vector<Double> err_vec_corr0 = i0Lambda.Put("Error");
    std::vector<Double> lamb_vec_corr;
    std::vector<Double> int_vec_corr;
    std::vector<Double> err_vec_corr;
    UInt4 i = 0;
    while(i < lamb_vec_corr0.size()){
        if (err_vec_corr0[i] >= 0){
            lamb_vec_corr.push_back(lamb_vec_corr0[i]);
            int_vec_corr.push_back(int_vec_corr0[i]);
            err_vec_corr.push_back(err_vec_corr0[i]);
        }
        i++;
    }
    std::string XKey_ec = ecm->PutPointer(0)->PutPointer(0)->PutXKey();

    if (lamb_vec_corr.size() < 2){
        return false;
    }

#ifdef MULTH
    omp_set_num_threads( MULTH );
#endif

    std::cout << MessageTag << "ECM size: " << ecm->PutSize() << std::endl;
    std::cout << MessageTag << "Processing ";
    for (UInt4 i=0; i< ecm->PutSize(); ++i){
        std::cout << ".";
        ElementContainerArray *eca = ecm->PutPointer(i);
#pragma omp parallel for
#if (_OPENMP >= 200805)  // OpenMP 3.0 and later
        for (UInt4 j=0; j<eca->PutSize(); ++j){
#else
        for (Int4 j=0; j<eca->PutSize(); ++j){
#endif
            ElementContainer *ec = eca->PutPointer(j);

            std::vector<Double> lamb_vec = ec->Put("Lamb");
            std::vector<Double> int_vec = ec->Put("Intensity");
            std::vector<Double> err_vec = ec->Put("Error");
            UInt4 lambda_size = lamb_vec.size();
            std::vector<Double> new_int_vec(lambda_size, 0.0);
            std::vector<Double> new_err_vec(lambda_size, 0.0);
            for (UInt4 k=0; k < lambda_size; ++k){
                if (int_vec[k] == 0.0){
                    new_err_vec[k] = err_vec[k];
                } else {
                    Double factor(1.), error(0.);
                    std::vector<Double>::iterator it_lamb1 = lower_bound(lamb_vec_corr.begin(), lamb_vec_corr.end(), lamb_vec[k]);
                    std::vector<Double>::iterator it_int1 = int_vec_corr.begin();
                    std::vector<Double>::iterator it_err1 = err_vec_corr.begin();
                    advance(it_int1, distance(lamb_vec_corr.begin(), it_lamb1));
                    advance(it_err1, distance(lamb_vec_corr.begin(), it_lamb1));
                    if (*it_lamb1 == lamb_vec[k] ){
                        factor = *it_int1;
                        if (calcErrorPropagation){
                            error = *it_err1;
                        }
                    } else {
                        Double lamb0 = *(it_lamb1-1);
                        Double int0 = *(it_int1-1);
                        Double err0 = *(it_err1-1);
                        factor = (*it_int1 - int0)/(*it_lamb1 - lamb0) * (lamb_vec[k] - lamb0) + int0;
                        if (calcErrorPropagation){
                            Double f = (lamb_vec[k] - lamb0)/(*it_lamb1 - lamb0);
                            error = sqrt(pow(f* (*it_err1), 2.)+pow( (1.-f)*err0, 2.));
                        }
                    }
                    new_int_vec[k] = int_vec[k]/factor;
                    Double flag = 1.0;
                    if (err_vec[k] < 0.0) flag = -1.0;
                    new_err_vec[k] = flag * sqrt( pow(err_vec[k]/factor, 2.0) +
                                                  pow(new_int_vec[k]*error/factor, 2.0) );
                }
            }
            ec->Replace("Intensity", new_int_vec);
            ec->Replace("Error", new_err_vec);
            ec->SetUnit("Intensity", "counts");
            ec->SetUnit("Error", "counts");
            ec->SetKeys(XKey_ec, "Intensity", "Error");
            /*
            HeaderBase * h_ec = ec->PutHeaderPointer();
            h_ec->Erase("TotalCounts");
            h_ec->Add("TotalCounts", ec->Sum("Intensity"));
            */
        }
    }
    std::cout << std::endl;

    process.push_back("I0LambdaCorrection");
    h_ecm->OverWrite("DATAPROCESSED",process);

    return true;
}

