#include "AverageElementContainerMatrix.hh"
//////////////////////////////////////////////////////////
AverageElementContainerMatrix::
AverageElementContainerMatrix(){
    // Start with no explicit point list and debug output disabled.
    const Int4 no_debug = 0;
    Initialize( no_debug );
}
//////////////////////////////////////////////////////////
AverageElementContainerMatrix::
AverageElementContainerMatrix( Int4 debug){
    // Only the exact value 1 enables debug mode; any other value is treated as 0.
    Initialize( (debug==1) ? 1 : 0 );
}
//////////////////////////////////////////////////////////
AverageElementContainerMatrix::
AverageElementContainerMatrix(ElementContainerMatrix *ecm){
    // Non-debug initialization, then attach the matrix to operate on.
    // No point list is set, so every pixel of the matrix is used later.
    Initialize( 0 );
    SetTarget( ecm );
}
//////////////////////////////////////////////////////////
AverageElementContainerMatrix::
AverageElementContainerMatrix( ElementContainerMatrix *ecm, std::vector<UInt4> psd_v, std::vector<UInt4> pixel_v ){
    // Non-debug initialization, attach the target matrix, then restrict the
    // calculation to the (psd_v[k], pixel_v[k]) pairs.
    Initialize( 0 );
    SetTarget( ecm );
    SetPoints( psd_v, pixel_v );
}
//////////////////////////////////////////////////////////
AverageElementContainerMatrix::
AverageElementContainerMatrix( ElementContainerMatrix *ecm, std::vector<UInt4> psd_v, std::vector<UInt4> pixel_v, Int4 debug){
    // debug==1 turns debug mode on; every other value means off.
    const Int4 normalized_debug = (debug==1) ? 1 : 0;
    Initialize( normalized_debug );

    SetTarget( ecm );
    SetPoints( psd_v, pixel_v );
}
//////////////////////////////////////////////////////////
AverageElementContainerMatrix::
~AverageElementContainerMatrix()
{
    // Intentionally empty: no explicit cleanup is done here, and the matrix
    // passed to SetTarget() is not deleted by this class.
}
//////////////////////////////////////////////////////////
void AverageElementContainerMatrix::
Initialize(Int4 debug){
    // Record the debug flag and the thread count to use.
    DebugFlag = debug;
    _NumOfMulTh = MlfGetNumOfMulTh();

    // Drop any previously selected points; an empty list means
    // "use every pixel" when the calculation runs.
    pixel_vec.clear();
    psd_vec.clear();
}
//////////////////////////////////////////////////////////
void AverageElementContainerMatrix::
SetPoints(std::vector<UInt4> psd_v, std::vector<UInt4> pixel_v){
    if (psd_v.size()!=pixel_v.size()){
        std::cerr << "AverageElementContainerMatrix::SetPoints >> input std::vector size are invalid." << std::endl;
        return;
    }
    psd_vec.clear();
    pixel_vec.clear();

    for (UInt4 i=0; i<(psd_v.size()); i++){
        psd_vec.push_back(psd_v[i]);
        pixel_vec.push_back(pixel_v[i]);
    }
}

//////////////////////////////////////////////////////////
ElementContainer AverageElementContainerMatrix::
GetAverage(){
    // Average mode. If the calculation fails, a default-constructed
    // ElementContainer (possibly partially filled) is returned after a message.
    ElementContainer averaged;
    const bool ok = _Calculate( &averaged, true );
    if (!ok)
        std::cerr << "AverageElementContainerMatrix::GetAverage Failed. " << std::endl;
    return averaged;
}

//////////////////////////////////////////////////////////
ElementContainer AverageElementContainerMatrix::
GetSum(){
    // Sum mode. If the calculation fails, a default-constructed
    // ElementContainer (possibly partially filled) is returned after a message.
    ElementContainer summed;
    const bool ok = _Calculate( &summed, false );
    if (!ok)
        std::cerr << "AverageElementContainerMatrix::GetSum Failed. " << std::endl;
    return summed;
}

//////////////////////////////////////////////////////////
bool AverageElementContainerMatrix::
_Calculate( ElementContainer* _ec, bool isAve ){

    ElementContainerMatrix *ecm = Put();

    if ((psd_vec.size()==0)||(pixel_vec.size()==0)){
        //throw "ArgumentsError\n";
        for (UInt4 i=0;i<(ecm->PutSize());i++){
            for (UInt4 j=0;j<(ecm->PutPointer(i)->PutSize());j++){
                psd_vec.push_back(i);
                pixel_vec.push_back(j);
            }
        }
    }else{
        for (UInt4 i=0; i<psd_vec.size(); i++){
            if (psd_vec[i]>=(ecm->PutSize())){
                std::cerr << "AverageElementContainerMatrix::_Calculate > given psd_vec is invalid. "<< psd_vec[i] << std::endl;
                return false;
            }else if (pixel_vec[i]>=(ecm->PutPointer(psd_vec[i])->PutSize())){
                std::cerr << "AverageElementContainerMatrix::_Calculate > given pixel_vec is invalid. " << std::endl;
                return false;
            }
        }
    }

    ElementContainer* ec0 = ecm->PutPointer(psd_vec[0])->PutPointer(pixel_vec[0]);
    HeaderBase hh = ec0->PutHeader();
    std::string xkey = ec0->PutXKey();
    std::string ykey = ec0->PutYKey();
    std::string ekey = ec0->PutEKey();
    std::string xunit = ec0->PutUnit( xkey );
    std::string yunit = ec0->PutUnit( ykey );
    std::string eunit = ec0->PutUnit( ekey );

    std::vector<Double> x_vec = ecm->PutPointer(psd_vec[0])->PutPointer(pixel_vec[0])->PutX();

    UInt4 v_size = (UInt4)(x_vec.size())-1;
    UInt4 num_of_pixel = (UInt4)(psd_vec.size());

    std::vector<Double> y_vec( v_size, 0.0 );
    std::vector<Double> e_vec( v_size, 0.0 );
    std::vector<UInt4> c_vec( v_size, 0 );
    std::vector<bool> m_vec( v_size, false );

#ifdef MULTH
    omp_set_num_threads(_NumOfMulTh);
    std::vector< std::vector<Double>* >* YY_vec = new std::vector< std::vector<Double>* >( _NumOfMulTh, NULL );
    std::vector< std::vector<Double>* >* EE_vec = new std::vector< std::vector<Double>* >( _NumOfMulTh, NULL );
    std::vector< std::vector< UInt4 >* >* CC_vec = new std::vector< std::vector< UInt4 >* > ( _NumOfMulTh, NULL );
    std::vector< std::vector< bool >* >* MM_vec = new std::vector< std::vector< bool >* > ( _NumOfMulTh, NULL );
    for (UInt4 i=0;i<_NumOfMulTh;i++){
        YY_vec->at(i) = new std::vector<Double> (v_size, 0.0);
        EE_vec->at(i) = new std::vector<Double> (v_size, 0.0);
        CC_vec->at(i) = new std::vector<UInt4> (v_size, 0 );
        MM_vec->at(i) = new std::vector<bool> (v_size, false );
    }
#pragma omp parallel for
#if (_OPENMP >= 200805)  // OpenMP 3.0 and later
    for (UInt4 i=0;i<num_of_pixel;i++){
#else
    for (Int4 i=0;i<(Int4)num_of_pixel;i++){
#endif
        UInt4 ThNum = omp_get_thread_num();
        ElementContainer* ec = ecm->PutPointer(psd_vec[i])->PutPointer(pixel_vec[i]);
        HeaderBase* h_ec = ec->PutHeaderPointer();
        if ((h_ec->CheckKey("MASKED")==1)&&(h_ec->PutInt4("MASKED")==1)){
            continue;
        }else{
            std::vector<Double> intensity_vec = ec->PutY();
            std::vector<Double> error_vec = ec->PutE();

            if (isAve){
                for (UInt4 j=0; j<(v_size); j++){
                    if (error_vec[j]<0.0){
                    }else{
                        YY_vec->at(ThNum)->at(j) += intensity_vec[j];
                        EE_vec->at(ThNum)->at(j) += error_vec[j]*error_vec[j];
                        CC_vec->at(ThNum)->at(j) += 1;
                    }
                }
            }else{
                for (UInt4 j=0; j<(v_size); j++){
                    if (error_vec[j]<0.0){
                        YY_vec->at(ThNum)->at(j) += intensity_vec[j];
                        EE_vec->at(ThNum)->at(j) += error_vec[j]*error_vec[j];
                        CC_vec->at(ThNum)->at(j) += 1;
                        MM_vec->at(ThNum)->at(j) = true;
                    }else{
                        YY_vec->at(ThNum)->at(j) += intensity_vec[j];
                        EE_vec->at(ThNum)->at(j) += error_vec[j]*error_vec[j];
                        CC_vec->at(ThNum)->at(j) += 1;
                    }
                }
            }

        }
    }

    for (UInt4 i=0;i<_NumOfMulTh;i++){
        for (UInt4 j=0;j<v_size;j++){
            y_vec[j] += YY_vec->at(i)->at(j);
            e_vec[j] += EE_vec->at(i)->at(j);
            c_vec[j] += CC_vec->at(i)->at(j);
            if (MM_vec->at(i)->at(j))
                m_vec[j] = true;
        }
    }

    for (UInt4 i=0;i<_NumOfMulTh;i++){
        if (YY_vec->at(i)!=NULL) delete YY_vec->at(i);
        if (EE_vec->at(i)!=NULL) delete EE_vec->at(i);
        if (CC_vec->at(i)!=NULL) delete CC_vec->at(i);
        if (MM_vec->at(i)!=NULL) delete MM_vec->at(i);
    }
    delete YY_vec;
    delete EE_vec;
    delete CC_vec;
    delete MM_vec;

    //std::cout << "num_of_counts/num_of_pixels = " << num_of_counts << "/" << num_of_pixel << std::endl;
#else
    HeaderBase *h_ec;
    for (UInt4 i=0; i<(num_of_pixel); i++){
        ElementContainer* ec = ecm->PutPointer(psd_vec[i])->PutPointer(pixel_vec[i]);
        h_ec = ec->PutHeaderPointer();
        if ((h_ec->CheckKey("MASKED")==1)&&(h_ec->PutInt4("MASKED")==1)){
            continue;
        }else{
            std::vector<Double> intensity_vec = ec->PutY();
            std::vector<Double> error_vec = ec->PutE();
            //std::cout << "v_size=" << v_size << ", Ysize=" << intensity_vec.size() << std::endl;
            if (isAve){
                for (UInt4 j=0; j<(v_size); j++){
                    if (error_vec[j]<0.0){
                    }else{
                        y_vec[j] += intensity_vec[j];
                        e_vec[j] += error_vec[j]*error_vec[j];
                        c_vec[j] += 1;
                    }
                }
            }else{
                for (UInt4 j=0; j<(v_size); j++){
                    if (error_vec[j]<0.0){
                        y_vec[j] += intensity_vec[j];
                        e_vec[j] += error_vec[j]*error_vec[j];
                        c_vec[j] += 1;
                        m_vec[j] = true;
                    }else{
                        y_vec[j] += intensity_vec[j];
                        e_vec[j] += error_vec[j]*error_vec[j];
                        c_vec[j] += 1;
                    }
                }
            }
        }
    }
#endif

    //if (DebugFlag==1) std::cout << "num_of_counts = " << num_of_counts << std::endl;
    Int4 masked_flag = 0;
    UInt4 sum_of_counts = 0;

    if (isAve){
        for (UInt4 i=0; i<v_size; i++){
            if (c_vec[i]!=0){
                y_vec[i] = y_vec[i]/double( c_vec[i] );
                e_vec[i] = sqrt(e_vec[i])/double( c_vec[i] );
            }
            sum_of_counts += c_vec[i];
        }
        if (sum_of_counts==0) masked_flag = 1;
    }else{
        UInt4 num_of_masked = 0;
        for (UInt4 i=0; i<v_size; i++){
            if (m_vec[i]){
                e_vec[i] = -1.0*sqrt(e_vec[i]);
                num_of_masked++;
            }else{
                e_vec[i] = sqrt(e_vec[i]);
                sum_of_counts += c_vec[i];
            }
        }
        if ((sum_of_counts==0)||(num_of_masked==v_size)) masked_flag = 1;
    }
    /**
    if (num_of_counts!=0){
      for (UInt4 i=0; i<(v_size); i++){
        y_vec[i] = y_vec[i]/double(num_of_counts);
        e_vec[i] = sqrt(e_vec[i])/double(num_of_counts);
      }
    }else{
      masked_flag = 1;
    }
    **/

    if (hh.CheckKey("MASKED")==0){
        hh.Add("MASKED",masked_flag);
    }else{
        hh.OverWrite("MASKED",masked_flag);
    }

    _ec->InputHeader(hh);
    _ec->PutHeaderPointer()->InputString(ecm->PutHeaderPointer()->DumpToString());
    _ec->Add(xkey,x_vec,xunit);
    _ec->Add(ykey,y_vec,yunit);
    _ec->Add(ekey,e_vec,eunit);
    _ec->SetKeys(xkey,ykey,ekey);

    return true;
}
