#include "AdvModelDistribution.hh"
#include "AdvModelParam.hh"
#include "AdvModelParamSet.hh"

#include <gsl/gsl_sf_gamma.h>

/**
 * Builds an empty distribution bound to the given model.
 *
 * The object starts as a single (non-bimodal) distribution with one
 * division; the component functions are allocated later by SetAttr().
 *
 * @param mdl  model handed to the component distributions when they are created
 */
AdvModelDistribution::AdvModelDistribution(AdvModel* mdl)
{
    // No component distributions exist until SetAttr() creates them.
    func1   = NULL;
    func2   = NULL;

    ndiv    = 1;
    bimodal = false;

    model   = mdl;
}

/**
 * Releases the component distributions created by SetAttr().
 */
AdvModelDistribution::~AdvModelDistribution()
{
    // Deleting a null pointer is a no-op, so no explicit checks are needed.
    delete func1;
    delete func2;
}

/**
 * Configures this distribution from an XML element.
 *
 * When the element's "distribution" attribute starts with "bi"
 * (case-insensitive), the element is treated as bi-modal: its first two
 * "Distribution" descendants configure func1 and func2, and the "ndiv"
 * attribute gives the number of divisions.  Otherwise (including when the
 * attribute is absent) the element itself configures func1 as a single
 * distribution.
 *
 * @param node  XML element to read; a NULL node leaves the object untouched
 */
void AdvModelDistribution::SetAttr( mxml_node_t* node )
{
    if (node==NULL) return;

    // The attribute is optional, so it may be NULL and must not be passed to
    // strlen() unchecked.  tolower() needs an unsigned char value to be
    // well-defined for non-ASCII bytes.
    const char* kind = mxmlElementGetAttr(node, "distribution");
    const bool isBimodal =
           kind!=NULL
        && std::strlen(kind)>=2
        && tolower(static_cast<unsigned char>(kind[0]))=='b'
        && tolower(static_cast<unsigned char>(kind[1]))=='i';

    if (!isBimodal) {
        bimodal = false;
        if (func1==NULL) func1 = new AdvModelDist(model);
        func1->SetAttr(node);
        return;
    }

    /* bi-modal distribution */
    mxml_node_t* n1 = mxmlFindElement(node, node, "Distribution", NULL, NULL, MXML_DESCEND);
    if (n1==NULL) {
      std::printf("Error: Not found 1st \"Distribution\" element for bi-modal distribution\n");
      return;
    }
    // Continue the search from the first match to locate the second one.
    mxml_node_t* n2 = mxmlFindElement(n1, node, "Distribution", NULL, NULL, MXML_DESCEND);
    if (n2==NULL) {
      std::printf("Error: Not found 2nd \"Distribution\" element for bi-modal distribution\n");
      return;
    }

    bimodal = true;
    if (func1==NULL) func1 = new AdvModelDist(model);
    if (func2==NULL) func2 = new AdvModelDist(model);
    func1->SetAttr(n1);
    func2->SetAttr(n2);

    const char* divisions = mxmlElementGetAttr(node, "ndiv");
    if (divisions==NULL) {
      // Keep the current ndiv rather than crashing in atoi(NULL).
      std::printf("Error: \"ndiv\" attribute is not given for bi-modal distribution\n");
      return;
    }
    ndiv = atoi(divisions);
    if (ndiv<=1) {
      // Values below 1 are clamped to 1 by GetWeight(), so they are just as
      // inadequate as ndiv=1 itself.
      std::printf("Error: ndiv=%d is not adequate for bi-modal distribution\n", ndiv);
    }
}

/**
 * Writes this distribution's configuration into an XML element.
 *
 * A bi-modal distribution sets distribution="bimodal" and "ndiv" on the
 * element and appends two new "Distribution" children, one per component.
 * A single distribution delegates to func1, which writes into the element
 * itself.
 *
 * Nothing is written when the element is NULL or when the component
 * distributions have not been created yet (SetAttr() was never called).
 *
 * @param node  XML element to fill in
 */
void AdvModelDistribution::GetAttr( mxml_node_t* node )
{
    // Validate everything up front so a half-written element is never produced.
    if (node==NULL || func1==NULL) return;
    if (bimodal && func2==NULL) return;

    if (!bimodal) {
        func1->GetAttr(node);
        return;
    }

    char s[250];
    mxmlElementSetAttr(node, "distribution", "bimodal");
    std::snprintf(s,sizeof(s),"%d",ndiv);
    mxmlElementSetAttr(node, "ndiv", s);

    mxml_node_t* n1 = mxmlNewElement(node, "Distribution");
    func1->GetAttr(n1);
    mxml_node_t* n2 = mxmlNewElement(node, "Distribution");
    func2->GetAttr(n2);
}

std::vector< std::pair<Double, Double> > AdvModelDistribution::GetWeight()
{
    std::vector< std::pair<Double, Double> >  v;
    v.clear();

    if (func1==NULL) return v;

    if (bimodal) {
      Int4 n = ndiv;
      if (n<1) n = 1;
      
      std::pair<Double,Double> range1 = func1->GetDefaultRange(1./n);
      std::pair<Double,Double> range2 = func2->GetDefaultRange(1./n);

      Double l1, l2, u1, u2, low, up;
      l1 = range1.first;
      l2 = range2.first;
      low = l1 < l2 ? l1 : l2 ;
      u1 = range1.second;
      u2 = range2.second;
      up  = u1 > u2 ? u1 : u2 ;

      Double x0, dx;
      if (n>1) {
        x0 = low;
        dx = (up-low)/(n-1);
      } else {
        x0 = (low+up)*0.5;
        dx = 0.;
      }

      v.resize(n);

      for (Int4 i = 0; i < n; i++) {
        Double x = x0 + dx * i;
        v[i].first  = x;
        v[i].second = func1->ValueAtWeight(x)
                    + func2->ValueAtWeight(x);    
      }
    } else {
      Int4 n = ndiv;
      if (n<1) n = 1;
      
      std::pair<Double,Double> range1 = func1->GetDefaultRange(1./n);

      Double low, up;
      low = range1.first;
      up  = range1.second;

      Double x0, dx;
      if (n>1) {
        x0 = low;
        dx = (up-low)/(n-1);
      } else {
        x0 = (low+up)*0.5;
        dx = 0.;
      }

      v.resize(n);

      for (Int4 i = 0; i < n; i++) {
        Double x = x0 + dx * i;
        v[i].first  = x;
        v[i].second = func1->ValueAt(x);
      }
    }

    return v;
}

/**
 * Creates an undistributed (NONE) parameter attached to the given model.
 *
 * Only the central value parameter is allocated here, from the model's
 * parameter set; the remaining parameters start out as NULL.
 *
 * @param mdl  model whose parameter set supplies the parameters
 */
AdvModelDist::AdvModelDist(AdvModel* mdl)
{
    model  = mdl;
    dtype  = NONE;

    // Optional parameters do not exist yet.
    sigma  = NULL;
    upper  = NULL;
    lower  = NULL;
    weight = NULL;

    value  = mdl->GetParamSet()->NewParam();
}

/**
 * Hands every parameter this object holds back to the model's parameter set.
 */
AdvModelDist::~AdvModelDist()
{
    AdvModelParamSet* owner = model->GetParamSet();

    // Parameters that were never allocated are NULL and are skipped.
    if (value  != NULL) { owner->DeleteParam(value);  }
    if (sigma  != NULL) { owner->DeleteParam(sigma);  }
    if (upper  != NULL) { owner->DeleteParam(upper);  }
    if (lower  != NULL) { owner->DeleteParam(lower);  }
    if (weight != NULL) { owner->DeleteParam(weight); }
}

/**
 * Evaluates the distribution function at x.
 *
 * With x0 = value and s = sigma:
 *   - NONE:       1 everywhere
 *   - GATE:       1/(2s) for x0-s <= x <= x0+s, else 0
 *   - LORENTZ:    (s/PI) / ((x-x0)^2 + s^2)
 *   - GAUSSIAN:   exp(-(x-x0)^2 / (2 s^2)) / (sqrt(2 PI) s)
 *   - COSINE:     (1 + cos((x-x0)/s)) / (2 PI s) for |(x-x0)/s| <= PI, else 0
 *   - LOG_NORMAL: exp(-(ln x - ln x0)^2 / (2 ln(s)^2)) / (sqrt(2 PI) x ln(s)),
 *                 and 0 for x <= 0
 *   - SCHULZ:     q^s (s+1)^(s+1) exp(-(s+1) q) / (x0 Gamma(s+1)) with q = x/x0,
 *                 and 0 for q < 0
 *
 * @param x  point at which to evaluate
 * @return   function value; 0 when a non-NONE type has no sigma parameter
 */
Double AdvModelDist::ValueAt(Double x)
{
    // An undistributed parameter never gets a sigma (it stays NULL), so this
    // case must be answered before sigma is dereferenced.
    if (dtype == NONE) return 1.0;
    if (sigma == NULL) return 0.0;

    const Double x0 = value->GetVal();
    const Double s  = sigma->GetVal();

    Double p = 0.0;
    Double c,d,e,ls,phase,q,s1;

    switch (dtype)
    {
      case GATE:
        if (x0-s <= x && x <= x0+s) {
          p = 1.0 / (2.0*s);
        }
        break;
      case LORENTZ:
        c = s/PI;
        d = (x-x0)*(x-x0) + s*s;
        p = c / d;
        break;
      case GAUSSIAN:
        c = sqrt(2.*PI) * s;
        d = (x-x0)*(x-x0)/(2.0*s*s);
        p = exp(-d) / c;
        break;
      case COSINE:
        phase = (x-x0)/s;
        if ( -PI <= phase && phase <= PI ) {
          p = (1.+cos(phase))/(2.*PI*s);
        }
        break;
      case LOG_NORMAL:
        // log(x) is undefined for x <= 0 and the formula degenerates to NaN
        // there; the log-normal density is zero outside x > 0.
        if (x <= 0.0) break;
        ls = log(s);
        c = sqrt(2.*PI)*x*ls;
        d = log(x)-log(x0);
        e = 0.5*d*d/(ls*ls);
        p = exp(-e) / c;
        break;
      case SCHULZ:
        q = x/x0;
        // pow() of a negative base with a non-integer exponent is NaN; the
        // distribution has no weight on that side.
        if (q < 0.0) break;
        s1 = s + 1.0;
        c = pow(q,s);
        d = pow(s1,s1);
        e = gsl_sf_gamma(s1);
        p = c * d * exp(-s1*q)/(x0*e);
        break;
      default:
        break;
    }

    return p;
}

/**
 * Evaluates the distribution at x, scaled by the weight parameter.
 *
 * @param x  point at which to evaluate
 * @return   weight * ValueAt(x), or plain ValueAt(x) when no weight
 *           parameter is attached
 */
Double AdvModelDist::ValueAtWeight(Double x)
{
    if (weight != NULL) {
      return weight->GetVal() * ValueAt(x);
    }

    // No weight parameter: the component counts with unit weight.
    return ValueAt(x);
}

/**
 * Determines the x-interval over which the distribution is worth sampling.
 *
 *   - NONE (or no sigma parameter): the degenerate interval [x0, x0]
 *   - GATE: [x0-s, x0+s]
 *   - otherwise: the interval around x0 on which ValueAt() stays above the
 *     threshold.  On each side a point below the threshold is searched for by
 *     repeatedly scaling x (by 0.9 downward, 1.1 upward), then the crossing is
 *     refined by bisection to an absolute width of 1e-5.  If ValueAt(x0) itself
 *     is not above the threshold the interval is [x0, x0].
 *
 * Both searches are iteration-bounded.  When no point below the threshold is
 * found on a side, the last probed x is used as that side's limit.
 *
 * NOTE(review): the multiplicative stepping and the clamp of a negative lower
 * start point to 0.5*x0 look like they assume x0 > 0 -- confirm with callers.
 *
 * @param threshold  cut-off value; values <= 0 are replaced by 0.01
 * @return           (lower limit, upper limit)
 */
std::pair<Double,Double>  AdvModelDist::GetDefaultRange (Double threshold) {
    const int    kMaxBracketSteps = 1000;
    const int    kMaxBisections   = 200;
    const Double kTolerance       = 1.0e-5;

    const Double x0 = value->GetVal();
    std::pair<Double,Double> p(x0, x0);

    // sigma is never allocated for an undistributed parameter, so it must not
    // be dereferenced in that case.
    if (dtype == NONE || sigma == NULL) return p;

    const Double s = sigma->GetVal();

    if (dtype == GATE) {
      p.first  = x0-s;
      p.second = x0+s;
      return p;
    }

    const Double thr = (threshold <= 0.0) ? 0.01 : threshold;

    // Written as !(>=) so that a NaN peak value also yields the degenerate
    // interval instead of entering the searches below.
    if (!(ValueAt(x0) >= thr)) return p;

    Double limit[2];
    for (int side = 0; side < 2; ++side) {
      const bool   lowerSide = (side == 0);
      const Double factor    = lowerSide ? 0.9 : 1.1;

      // Starting guess one sigma away from the centre.
      Double xout = lowerSide ? x0-s : x0+s;
      if (lowerSide && xout < 0.0) xout = 0.5*x0;

      // Step outwards until the value drops below the threshold.  The step
      // count is capped because scaling cannot make progress from x == 0 and
      // a NaN value never compares below the threshold.
      bool bracketed = false;
      for (int i = 0; i < kMaxBracketSteps; ++i) {
        if (ValueAt(xout) < thr) { bracketed = true; break; }
        xout *= factor;
      }
      if (!bracketed) {
        limit[side] = xout;
        continue;
      }

      // Bisection between the centre (above threshold) and xout (below).
      // Unlike secant-style interpolation, it always shrinks the bracket, so
      // the width test is guaranteed to be reached.
      Double xin = x0;
      for (int i = 0; i < kMaxBisections; ++i) {
        const Double width = (xout > xin) ? xout - xin : xin - xout;
        if (!(width > kTolerance)) break;
        const Double mid = 0.5*(xin + xout);
        if (ValueAt(mid) > thr) {
          xin = mid;
        } else {
          xout = mid;
        }
      }
      limit[side] = 0.5*(xin + xout);
    }

    p.first  = limit[0];
    p.second = limit[1];
    return p;
}

/**
 * Configures this distribution from an XML element.
 *
 * The first three characters of the element's "distribution" attribute select
 * the type, case-insensitively: "non", "gat", "lor", "gau", "cos", "log",
 * "sch".  A missing, shorter or unrecognised attribute selects NONE.
 *
 * For NONE the element itself is handed to the value parameter.  For every
 * other type the "param1" descendant configures the value parameter and the
 * "param2" descendant configures sigma (allocated here on first use); a
 * message is printed for each one that is missing.
 *
 * @param node  XML element to read; a NULL node leaves the object untouched
 */
void AdvModelDist::SetAttr( mxml_node_t *node )
{
    if (node == NULL) return;

    const char* cp = mxmlElementGetAttr(node, "distribution");

    dtype = NONE;

    if ( cp != NULL && std::strlen(cp) > 2 ) {
      // tolower() is only defined for values representable as unsigned char,
      // so convert before the call to stay safe with non-ASCII bytes.
      const int c0 = tolower(static_cast<unsigned char>(cp[0]));
      const int c1 = tolower(static_cast<unsigned char>(cp[1]));
      const int c2 = tolower(static_cast<unsigned char>(cp[2]));

      if      ( c0=='n' && c1=='o' && c2=='n' ) { dtype = NONE; }
      else if ( c0=='g' && c1=='a' && c2=='t' ) { dtype = GATE; }
      else if ( c0=='l' && c1=='o' && c2=='r' ) { dtype = LORENTZ; }
      else if ( c0=='g' && c1=='a' && c2=='u' ) { dtype = GAUSSIAN; }
      else if ( c0=='c' && c1=='o' && c2=='s' ) { dtype = COSINE; }
      else if ( c0=='l' && c1=='o' && c2=='g' ) { dtype = LOG_NORMAL; }
      else if ( c0=='s' && c1=='c' && c2=='h' ) { dtype = SCHULZ; }
    }

    if (!value) value = model->GetParamSet()->NewParam();

    if (dtype==NONE) {
      value->SetXML( node );
      return;
    }

    mxml_node_t *node2;

    node2 = mxmlFindElement(node, node, "param1", NULL, NULL, MXML_DESCEND);
    if (node2) {
      value->SetXML( node2 );
    } else {
      std::printf("param1 is not given for distributable.\n");
    }

    if (!sigma) sigma = model->GetParamSet()->NewParam();
    node2 = mxmlFindElement(node, node, "param2", NULL, NULL, MXML_DESCEND);
    if (node2) {
      sigma->SetXML( node2 );
    } else {
      std::printf("param2 is not given for distributable.\n");
    }
}

/**
 * Writes this distribution's configuration into an XML element.
 *
 * For NONE the value parameter writes directly into the element.  Otherwise
 * the "distribution" attribute is set to the type's name, and the value and
 * sigma parameters are written into the "param1" and "param2" descendants,
 * which are created when they do not exist yet.
 *
 * @param node  XML element to fill in
 */
void AdvModelDist::GetAttr( mxml_node_t *node )
{
    if (dtype==NONE) {
      value->GetXML( node );
      return;
    }

    // Name written to the "distribution" attribute for each type.
    const char* label = "unknown";
    switch (dtype) {
      case GATE:       label = "Gate";       break;
      case LORENTZ:    label = "Lorenz";     break;
      case GAUSSIAN:   label = "Gaussian";   break;
      case COSINE:     label = "Cosine";     break;
      case LOG_NORMAL: label = "Log_normal"; break;
      case SCHULZ:     label = "Schulz";     break;
      default:                               break;
    }
    mxmlElementSetAttr(node , "distribution", label);

    // Reuse existing parameter elements; create them on first write.
    mxml_node_t *centre = mxmlFindElement(node, node, "param1", NULL, NULL, MXML_DESCEND);
    if (centre == NULL) {
      centre = mxmlNewElement(node, "param1");
    }
    value->GetXML( centre );

    mxml_node_t *width = mxmlFindElement(node, node, "param2", NULL, NULL, MXML_DESCEND);
    if (width == NULL) {
      width = mxmlNewElement(node, "param2");
    }
    sigma->GetXML( width );
}
