#include "TransmittanceCorrection.hh"
#include <cmath>

////////////////////////////////////////
// Default constructor: exponential model selected, no coefficients or
// uncertainties stored yet (populate them with SetParameters()), error
// propagation and polar-angle dependence both switched off.
TransmittanceCorrection::
TransmittanceCorrection()
    : params(),
      errors(),
      funcType(TRANSMIT_TYPE_EXP),
      calcErrorPropagation(false),
      useAngleDependence(false),
      MessageTag("Transmittance correction>> ")
{
}

////////////////////////////////////////
// Constructor that attaches the given matrix as the correction target via
// SetTarget(). All other members get the same defaults as the default
// constructor (exponential model, empty coefficient/uncertainty vectors,
// no error propagation, no polar-angle dependence).
TransmittanceCorrection::
TransmittanceCorrection(ElementContainerMatrix* ecm)
    : params(),
      errors(),
      funcType(TRANSMIT_TYPE_EXP),
      calcErrorPropagation(false),
      useAngleDependence(false),
      MessageTag("Transmittance correction>> ")
{
    SetTarget(ecm);
}

////////////////////////////////////////
void TransmittanceCorrection::
SetParameters(std::vector<Double> params, std::vector<Double> errors, std::string funcType, bool _useAngDep){
    this->params.clear();
    this->errors.clear();
    copy( params.begin(), params.end(), back_inserter(this->params) );
    copy( errors.begin(), errors.end(), back_inserter(this->errors) );
    this->funcType = funcType;
    this->useAngleDependence = _useAngDep;
    return;
}

////////////////////////////////////////
bool TransmittanceCorrection::
Execute(){
    // Apply the correction to the target matrix. The "PolyApproximation"
    // model has its own thickness-aware path; every other funcType is handled
    // by _Execute(), which evaluates the "exp" and "poly3" models.
    // Returns the result of the selected path (true on success).
    const bool isPolyApprox = (funcType == TRANSMIT_TYPE_POLYAPPROX);
    if (!isPolyApprox){
        return _Execute();
    }
    return _ExecutePolyApproximation();
}

////////////////////////////////////////
bool TransmittanceCorrection::
_Execute(){
    ElementContainerMatrix *ecm = Put();

    bool ret = false;

    if (params.size() == 0){
        return false;
    }
    HeaderBase* h_ecm = ecm->PutHeaderPointer();
    std::vector<std::string> process = h_ecm->PutStringVector("DATAPROCESSED");
    std::string XKEY_EC = ecm->PutPointer(0)->PutPointer(0)->PutXKey();

#ifdef MULTH
    omp_set_num_threads( MULTH );
#endif

    std::cout << MessageTag << "ECM size: " << ecm->PutSize() << std::endl;
    std::cout << MessageTag << "Function type: " << funcType << std::endl;
    std::cout << MessageTag << "Processing ";
    for (UInt4 i=0; i< ecm->PutSize(); ++i){
        std::cout << ".";
        ElementContainerArray *eca = ecm->PutPointer(i);
#pragma omp parallel for
#if (_OPENMP >= 200805)  // OpenMP 3.0 and later
        for (UInt4 j=0; j<eca->PutSize(); ++j){
#else
        for (Int4 j=0; j<eca->PutSize(); ++j){
#endif
            ElementContainer *ec = eca->PutPointer(j);

            std::vector<Double> lamb_vec = ec->Put("Lamb");
            std::vector<Double> int_vec = ec->Put("Intensity");
            std::vector<Double> err_vec = ec->Put("Error");
            UInt4 lambda_size = lamb_vec.size();
            std::vector<Double> new_int_vec(lambda_size, 0.0);
            std::vector<Double> new_err_vec(lambda_size, 0.0);
            std::vector<Double> pol_ang = ec->PutHeaderPointer()->PutDoubleVector(UTSUSEMI_KEY_HEAD_PIXELPOLARANGLES);
            for (UInt4 k=0; k < lambda_size; ++k){
                Double factor(1.), error(0.);
                if (funcType == TRANSMIT_TYPE_EXP){
                    Exponential(lamb_vec[k], &factor, &error);
                } else if (funcType == TRANSMIT_TYPE_POLY3) {
                    Polynomial3(lamb_vec[k], &factor, &error);
                }
                if (useAngleDependence){
                    Double a_lamb = -1.0 * log(factor);
                    Double f_tt = (1.0 / cos(pol_ang[0] / 180.0 * M_PI)) - 1.0;
                    factor = (1.0 / (a_lamb * f_tt)) * exp( -1.0 * a_lamb / cos(pol_ang[0] / 180.0 * M_PI)) * (exp(a_lamb * f_tt) - 1.0);
                }
                new_int_vec[k] = int_vec[k]/factor;
                Double err_abs = fabs(err_vec[k]);
                Double flag = 1.0;
                if (err_vec[k] < 0.0) flag = -1.0;
                new_err_vec[k] = flag * sqrt( pow(err_abs/factor, 2.0) +
                                              pow(new_int_vec[k]*error/factor/factor, 2.0) );
            }
            ec->Replace("Intensity", new_int_vec);
            ec->Replace("Error", new_err_vec);
            ec->SetUnit("Intensity", "counts");
            ec->SetUnit("Error", "counts");
            ec->SetKeys(XKEY_EC, "Intensity", "Error");
            /*
            HeaderBase * h_ec = ec->PutHeaderPointer();
            h_ec->Erase("TotalCounts");
            h_ec->Add("TotalCounts", ec->Sum("Intensity"));
            */
        }
    }
    std::cout << std::endl;

    process.push_back("TransmittanceCorrection");
    h_ecm->OverWrite("DATAPROCESSED",process);

    return true;
}

////////////////////////////////////////
bool TransmittanceCorrection::
_ExecutePolyApproximation(){
    UtsusemiMessage("SAS::TransmittanceCorrection::_ExecutePolyApploximation start");
    ElementContainerMatrix *ecm = Put();

    bool ret = false;

    if (params.size() == 0){
        return false;
    }
    HeaderBase* h_ecm = ecm->PutHeaderPointer();
    std::vector<std::string> process = h_ecm->PutStringVector("DATAPROCESSED");
    std::string XKEY_EC = ecm->PutPointer(0)->PutPointer(0)->PutXKey();

    Double thick;
    if (h_ecm->CheckKey("TransmittancePolyApproxThick") == 1){
        thick = h_ecm->PutDouble("TransmittancePolyApproxThick");
    }else{
        UtsusemiError("SAS::TransmittanceCorrection::ExecutePolyApproximation >> required argument of thickness");
        return ret;
    }

#ifdef MULTH
    omp_set_num_threads( MULTH );
#endif

    std::cout << MessageTag << "ECM size: " << ecm->PutSize() << std::endl;
    std::cout << MessageTag << "Function type: " << funcType << std::endl;
    std::cout << MessageTag << "Processing ";
    for (UInt4 i=0; i< ecm->PutSize(); ++i){
        std::cout << ".";
        ElementContainerArray *eca = ecm->PutPointer(i);
#pragma omp parallel for
#if (_OPENMP >= 200805)  // OpenMP 3.0 and later
        for (UInt4 j=0; j<eca->PutSize(); ++j){
#else
        for (Int4 j=0; j<eca->PutSize(); ++j){
#endif
            ElementContainer *ec = eca->PutPointer(j);

            std::vector<Double> *lamb_vec = ec->PutP("Lamb");
            std::vector<Double> *int_vec = ec->PutP("Intensity");
            std::vector<Double> *err_vec = ec->PutP("Error");
            std::vector<Double> pol_ang = ec->PutHeaderPointer()->PutDoubleVector(UTSUSEMI_KEY_HEAD_PIXELPOLARANGLES);
            UInt4 lambda_size = lamb_vec->size();
            std::vector<Double> new_int_vec(lambda_size, 0.0);
            std::vector<Double> new_err_vec(lambda_size, 0.0);
            Double thick_corr = (thick / 2.0) * (1.0 + (1.0 / cos(pol_ang[0] / 180.0 * M_PI)));
            for (UInt4 k=0; k < lambda_size; ++k){
                Double factor(1.), error(0.);
                PolynominalApproximation(lamb_vec->at(k), &factor, &error, thick_corr);
                new_int_vec[k] = int_vec->at(k) / factor;
                Double err_abs = fabs(err_vec->at(k));
                Double flag = 1.0;
                if (err_vec->at(k) < 0.0) flag = -1.0;
                new_err_vec[k] = flag * sqrt(pow(err_abs / factor, 2.0) +
                                             pow(new_int_vec[k] * error / factor / factor, 2.0) );
            }
            ec->Replace("Intensity", new_int_vec);
            ec->Replace("Error", new_err_vec);
            ec->SetUnit("Intensity", "counts");
            ec->SetUnit("Error", "counts");
            ec->SetKeys(XKEY_EC, "Intensity", "Error");
        }
    }
    std::cout << std::endl;

    process.push_back("TransmittanceCorrection");
    h_ecm->OverWrite("DATAPROCESSED",process);

    return true;
}

////////////////////////////////////////
bool TransmittanceCorrection::
DumpFactorIntoContainer(ElementContainerMatrix *ecm_in){
    // Fill ecm_in with a copy of the target's structure in which each pixel's
    // Y/E vectors hold the PolyApproximation transmittance factor and its
    // propagated error, instead of corrected intensities. The model is always
    // PolynominalApproximation, regardless of funcType. The target matrix
    // itself is only read. Returns false (leaving ecm_in untouched) if the
    // coefficients, output container, target, or thickness are missing.
    UtsusemiMessage("SAS::TransmittanceCorrection::DumpFactorIntoContainer start");
    ElementContainerMatrix *ecm = Put();

    bool ret = false;

    if (params.size() == 0){
        return false;
    }
    // PolynominalApproximation() reads params/errors[0..3] unchecked.
    if ((params.size() < 4) || (calcErrorPropagation && (errors.size() < 4))){
        UtsusemiError("SAS::TransmittanceCorrection::DumpFactorIntoContainer >> PolyApproximation requires 4 parameters and errors");
        return ret;
    }
    if (ecm_in == NULL){
        UtsusemiError("SAS::TransmittanceCorrection::DumpFactorIntoContainer >> output container is NULL");
        return ret;
    }
    // The X key is read from the first pixel, so at least one must exist.
    if ((ecm == NULL) || (ecm->PutSize() == 0) || (ecm->PutPointer(0)->PutSize() == 0)){
        UtsusemiError("SAS::TransmittanceCorrection::DumpFactorIntoContainer >> target data is empty");
        return ret;
    }

    HeaderBase* h_ecm = ecm->PutHeaderPointer();
    std::string XKEY_EC = ecm->PutPointer(0)->PutPointer(0)->PutXKey();

    Double thick = 0.0;
    if (h_ecm->CheckKey("TransmittancePolyApproxThick") == 1){
        thick = h_ecm->PutDouble("TransmittancePolyApproxThick");
    }else{
        UtsusemiError("SAS::TransmittanceCorrection::DumpFactorIntoContainer >> required argument of thickness");
        return ret;
    }

    // Copy the matrix header only once every check has passed, so a failed
    // call does not leave ecm_in partially modified.
    ecm_in->PutHeaderPointer()->InputString(h_ecm->DumpToString());

#ifdef MULTH
    omp_set_num_threads( MULTH );
#endif

    std::cout << MessageTag << "ECM size: " << ecm->PutSize() << std::endl;
    std::cout << MessageTag << "Function type: " << funcType << std::endl;
    std::cout << MessageTag << "Processing ";
    ecm_in->Allocate(ecm->PutSize());
    for (UInt4 i=0; i< ecm->PutSize(); ++i){
        ElementContainerArray *eca = ecm->PutPointer(i);
        HeaderBase* h_eca = eca->PutHeaderPointer();
        // NOTE(review): assumes ecm_in->Set() takes ownership of eca_in - confirm.
        ElementContainerArray *eca_in = new ElementContainerArray(*h_eca);
        eca_in->Allocate(eca->PutSize());
        std::cout << ".";

#pragma omp parallel for
#if (_OPENMP >= 200805)  // OpenMP 3.0 and later
        for (UInt4 j=0; j<eca->PutSize(); ++j){
#else
        for (Int4 j=0; j<eca->PutSize(); ++j){
#endif
            ElementContainer *ec = eca->PutPointer(j);

            std::vector<Double> *lamb_vec = ec->PutP("Lamb");
            std::vector<Double> pol_ang = ec->PutHeaderPointer()->PutDoubleVector(UTSUSEMI_KEY_HEAD_PIXELPOLARANGLES);
            UInt4 lambda_size = lamb_vec->size();
            std::vector<Double> new_int_vec(lambda_size, 0.0);
            std::vector<Double> new_err_vec(lambda_size, 0.0);
            // Same effective path length as _ExecutePolyApproximation()
            // (pol_ang[0] in degrees; not bounds-checked, as before).
            Double thick_corr = (thick / 2.0) * (1.0 + (1.0 / cos(pol_ang[0] / 180.0 * M_PI)));
            for (UInt4 k=0; k < lambda_size; ++k){
                Double factor(1.), error(0.);
                PolynominalApproximation(lamb_vec->at(k), &factor, &error, thick_corr);
                new_int_vec[k] = factor;
                new_err_vec[k] = error;
            }
            // NOTE(review): assumes eca_in->Set() takes ownership of ec_in - confirm.
            ElementContainer *ec_in = new ElementContainer(ec->PutHeader());
            ec_in->Add("Lamb", *lamb_vec);
            ec_in->Add(XKEY_EC, ec->PutX());
            ec_in->Add(ec->PutYKey(), new_int_vec, "none");
            ec_in->Add(ec->PutEKey(), new_err_vec, "none");
            ec_in->SetKeys(XKEY_EC, ec->PutYKey(), ec->PutEKey());
            eca_in->Set(j, ec_in);
        }
        ecm_in->Set(i, eca_in);
    }
    std::cout << std::endl;

    return true;
}
////////////////////////////////////////
void TransmittanceCorrection::
Exponential(Double lamb, Double *factor, Double *error){
    // Exponential transmittance model:
    //   T(lamb) = p0 * exp(-p1 * lamb) + p2
    // Writes T to *factor. When calcErrorPropagation is set, writes the
    // first-order uncertainty from errors[0..2] to *error; otherwise *error
    // keeps the caller's value. Reads params[0..2] / errors[0..2] unchecked.
    const Double decay = exp(-params[1]*lamb);
    *factor = params[0]*decay + params[2];
    if (!calcErrorPropagation){
        return;
    }
    // Partial derivatives: dT/dp0 = exp(-p1*lamb), dT/dp1 = -lamb*p0*exp(-p1*lamb), dT/dp2 = 1.
    const Double term_p0 = decay*errors[0];
    const Double term_p1 = -lamb*params[0]*decay*errors[1];
    const Double term_p2 = errors[2];
    *error = sqrt( pow(term_p0, 2.) + pow(term_p1, 2.) + pow(term_p2, 2.) );
}

////////////////////////////////////////
void TransmittanceCorrection::
Polynomial3(Double lamb, Double *factor, Double *error){
    // Cubic transmittance model:
    //   T(lamb) = p0 + p1*lamb + p2*lamb^2 + p3*lamb^3
    // Writes T to *factor. When calcErrorPropagation is set, writes the
    // first-order uncertainty from errors[0..3] to *error (dT/dp_i = lamb^i);
    // otherwise *error keeps the caller's value. Reads params/errors[0..3] unchecked.
    const Double linear    = params[1]*lamb;
    const Double quadratic = params[2]*lamb*lamb;
    const Double cubic     = params[3]*lamb*lamb*lamb;
    *factor = params[0] + linear + quadratic + cubic;
    if (!calcErrorPropagation){
        return;
    }
    const Double sig0 = errors[0];
    const Double sig1 = lamb*errors[1];
    const Double sig2 = lamb*lamb*errors[2];
    const Double sig3 = lamb*lamb*lamb*errors[3];
    *error = sqrt( pow(sig0, 2.) + pow(sig1, 2.) + pow(sig2, 2.) + pow(sig3, 2.) );
}

////////////////////////////////////////
void TransmittanceCorrection::
PolynominalApproximation(Double lamb, Double *factor, Double *error, Double thick){
    // Transmittance from a cubic attenuation polynomial scaled by a path length:
    //   S(lamb) = (p0 + p1*lamb + p2*lamb^2 + p3*lamb^3) * thick
    //   T(lamb) = exp(-S(lamb))
    // `thick` is the path-length multiplier (callers in this file pass the
    // angle-corrected thickness). Writes T to *factor. When
    // calcErrorPropagation is set, also writes the first-order uncertainty of
    // T from errors[0..3] to *error; otherwise *error keeps the caller's value.
    // Reads params[0..3] / errors[0..3] unchecked.
    const Double lamb2 = lamb * lamb;
    const Double lamb3 = lamb2 * lamb;
    const Double St = (params[0] + (params[1] * lamb) + (params[2] * lamb2) + (params[3] * lamb3)) * thick;
    *factor = exp(-1. * St);
    if (calcErrorPropagation){
        // dT/dp_i = -thick * lamb^i * T, hence
        //   sigma_T = |thick * T| * sqrt( sum_i (errors[i] * lamb^i)^2 ).
        // Every coefficient uncertainty comes from `errors` (the member);
        // `error` is only the output pointer.
        const Double dT = thick * (*factor);
        const Double s0 = errors[0];
        const Double s1 = errors[1] * lamb;
        const Double s2 = errors[2] * lamb2;
        const Double s3 = errors[3] * lamb3;
        *error = fabs(dT) * sqrt(s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3);
    }
}
// Model names accepted as funcType (see SetParameters) and compared in
// Execute()/_Execute(): exponential, cubic polynomial, and the
// thickness-aware exponential-of-cubic approximation.
const std::string TransmittanceCorrection::TRANSMIT_TYPE_EXP("exp");
const std::string TransmittanceCorrection::TRANSMIT_TYPE_POLY3("poly3");
const std::string TransmittanceCorrection::TRANSMIT_TYPE_POLYAPPROX("PolyApproximation");
