#!/usr/bin/python

from math import *
import numpy as np
import Manyo as m
import NewLevmar as Adv
#import Adv
import time

####    linear combination of gaussian    ####
#def dist(x, n, a, c, w):
def dist(x, param):
    """Evaluate a constant baseline plus a sum of Gaussian-shaped peaks at x.

    param[0] is the baseline value.  param[1] is a sequence of
    (amplitude, centre, width) triples; each one contributes
    amplitude * exp(-((x - centre) / width) ** 2) to the result.
    """
    baseline, peaks = param[0], param[1]
    total = baseline
    for amplitude, centre, width in peaks:
        # Accumulate term by term, in list order.
        total = total + amplitude*exp(-1.0*((x - centre)/width)**2)
    return total

####    sign function    ####
def sign(r):
    """Return 1.0 when r is positive, 0.0 when r equals zero, otherwise -1.0."""
    if r > 0.0:
        return 1.0
    if r == 0.0:
        return 0.0
    # Everything that is neither > 0 nor == 0 ends up here.
    return -1.0

####    initialize an ElementContainer    ####
#def initElementContainer(xmin, xmax, nDiv):
def initElementContainer(xmin, xmax, nDiv, param):
    """Build an ElementContainer holding a synthetic, randomly perturbed histogram.

    The interval [xmin, xmax] is split into nDiv equal bins.  For every bin
    the ideal value p = dist(bin centre, param) is perturbed by a random
    amount of random sign whose scale is p/10, and p/10 is stored as the
    error of that bin.

    Parameters:
        xmin, xmax -- lower and upper bound of the binned range
        nDiv       -- number of bins
        param      -- peak description in the form accepted by dist()

    Returns:
        an ElementContainer with the vectors "TOF" (nDiv+1 bin edges),
        "Intensity" and "Error" (nDiv values each) registered as the
        x, y and error keys.
    """
    # float() keeps the bin width exact when integer bounds are passed
    # (a plain "/" between two ints is floor division under Python 2).
    delta = float(xmax - xmin)/nDiv

    edges = m.MakeDoubleVector(nDiv+1)
    for i in range(nDiv+1):
        edges[i] = xmin + delta*i

    y = m.MakeDoubleVector(nDiv)
    e = m.MakeDoubleVector(nDiv)
    for i in range(nDiv):
        p = dist((edges[i] + edges[i+1])/2.0, param)
        sigma = p/10.0
        # First random draw selects the sign, the second one the magnitude.
        # Scalars are drawn instead of 1-element arrays so that the math
        # module functions receive plain floats.
        direction = sign(np.random.rand() - 0.5)
        # rand() samples [0, 1); 1 - u lies in (0, 1], so log() never sees 0.
        magnitude = sqrt(-1.0*log(1.0 - np.random.rand()))
        y[i] = p + direction*sigma*magnitude
        e[i] = sigma

    ec = m.ElementContainer()
    ec.AddToHeader("run number", 1)
    ec.AddToHeader("level",      1)
    ec.AddToHeader("Inst.",      "manyo")
    ec.Add("TOF",       edges, "sec." )
    ec.Add("Intensity", y,     "count")
    ec.Add("Error",     e,     "count")
    ec.SetKeys("TOF", "Intensity", "Error")

    return ec

def outputElementContainer(ec):
    """Print the x/y/error vectors of ec as a right-justified text table.

    One header line is followed by one line per y entry showing the index,
    the half-open bin "[ lower , upper )", the bin centre, the intensity
    and the error.
    """
    edges  = ec.Put(ec.PutXKey())
    counts = ec.Put(ec.PutYKey())
    errors = ec.Put(ec.PutEKey())

    # Each line is assembled as one string; joining with a single blank
    # gives the same text as a comma-separated print statement.
    header = [
        "No.".rjust(4),
        "----------- bin -----------".rjust(27),
        "xc".rjust(10),
        "Intensity".rjust(23),
        "error".rjust(23),
    ]
    print(" ".join(header))

    for i in range(ec.PutSize(ec.PutYKey())):
        row = [
            str(i).rjust(4),
            "[",
            str(edges[i]).rjust(10),
            ",",
            str(edges[i+1]).rjust(10),
            ")",
            str((edges[i]+edges[i+1])/2.0).rjust(10),
            str(counts[i]).rjust(23),
            str(errors[i]).rjust(23),
        ]
        print(" ".join(row))

####    plot    ####
def knownTerminal(type):
    """Tell whether `type` names one of the supported gnuplot terminals.

    Supported names are "emf", "jpeg", "png", "postscript" and "x11".
    A diagnostic line is printed for the empty string and for any other
    unsupported value, and the result is always echoed to standard output.

    Parameters:
        type -- terminal name (the parameter name shadows the builtin; it is
                kept because callers may pass it by keyword)

    Returns:
        True for a supported name, False otherwise (including "").
    """
    retval = type in ("emf", "jpeg", "png", "postscript", "x11")
    if not retval:
        if type == "":
            print("knownTerminal: empty string interpreted x11\n")
        else:
            print("knownTerminal: unknown terminal type:  %s" % (type,))
    # Single-string print keeps the output identical under Python 2 and 3.
    print("knownTerminal:  %s" % (retval,))
    return retval

def createTerminalCmd(terminalType):
    """Return the gnuplot "set terminal ..." command for terminalType.

    "emf", "jpeg" and "png" map to "set terminal <name>", "postscript" maps
    to "set terminal postscript eps".  "x11", the empty string and unknown
    names yield an empty command; an unknown non-empty name is additionally
    reported on standard output.  The resulting command is always echoed.

    Parameters:
        terminalType -- terminal name (a string)

    Returns:
        the command string, possibly empty.
    """
    commands = {
        "emf":        "set terminal emf",
        "jpeg":       "set terminal jpeg",
        "png":        "set terminal png",
        "postscript": "set terminal postscript eps",
        "x11":        "",
    }
    cmd = commands.get(terminalType)
    if cmd is None:
        cmd = ""
        if terminalType != "":
            # Report the offending argument (the old code printed the
            # builtin `type` here by mistake).
            print("createTerminalCmd: unknown terminal type:  %s" % (terminalType,))
    print("createTerminalCmd:  %s" % (cmd,))
    return cmd


def createOutputCmd(terminalType, fileNameStem):
    """Return the gnuplot "set output '<file>'" command for terminalType.

    The file name is fileNameStem plus an extension chosen by the terminal:
    ".emf", ".jpeg", ".png", or ".eps" for "postscript".  "x11", the empty
    string and unknown names yield an empty command; an unknown non-empty
    name is additionally reported on standard output.  The resulting
    command is always echoed.

    Parameters:
        terminalType -- terminal name (a string)
        fileNameStem -- output file name without extension (a string)

    Returns:
        the command string, possibly empty.
    """
    extensions = {
        "emf":        ".emf",
        "jpeg":       ".jpeg",
        "png":        ".png",
        "postscript": ".eps",
    }
    cmd = ""
    if terminalType in extensions:
        cmd = "set output '" + fileNameStem + extensions[terminalType] + "'"
    elif terminalType != "x11" and terminalType != "":
        # Report the offending argument (the old code printed the builtin
        # `type` and misspelled the function name in this message).
        print("createOutputCmd: unknown terminal type:  %s" % (terminalType,))
    print("createOutputCmd:  %s" % (cmd,))
    return cmd

def plotPeakData(src, smoothed, peakData, terminal):
    """Plot the source and smoothed curves with annotations for every peak.

    For each peak in peakData a vertical guide line and labels for its
    position and height are emitted; where the left/right values of the
    peak's full vector differ from the centre, an arrow and a width label
    are added at half height.

    Parameters:
        src      -- container providing the x bins, keys, units and y values
        smoothed -- container whose y values are sent as the third column
        peakData -- peak collection (size(), getPeak(i), toVector())
        terminal -- terminal name; when knownTerminal() accepts it the
                    terminal and an output file with stem "peakData" are set.
                    "x11" or "" makes the function prompt on the console and
                    redraw until a non-empty line is entered.
    """
    xKey = src.PutXKey()
    srcBin = src.Put(xKey)

    g = m.GnuplotInterface()

    if knownTerminal(terminal):
        g.e(createTerminalCmd(terminal))
        g.e(createOutputCmd(terminal, "peakData"))

    g.e("set xlabel '%s / %s'" % (xKey, src.PutUnit(xKey)))
    g.e("set xrange[%s:%s]" % (srcBin.front(), srcBin.back()))

    yKey = src.PutYKey()
    g.e("set ylabel '%s / %s'" % (yKey, src.PutUnit(yKey)))
    g.e("set autoscale y")

    # The result of this call is replaced inside the loop below; the call
    # itself is kept as it was part of the original sequence.
    vec = peakData.toVector()
    for i in range(peakData.size()):
        vec = peakData.getPeak(i).toFullVector()
        height = vec[0]   # height
        centre = vec[1]   # position
        lower = vec[3]    # left  pos at HWHM
        upper = vec[4]    # right pos at HWHM
        half = height/2.0

        # guide line for a peak or a shoulder
        g.e("set arrow %s from %s, 0 to %s, 150000 nohead\n" % (3*i+1, centre, centre))
        # the position of a peak or a shoulder
        g.e("set label %s '%s' at %s,10000" % (4*i+1, centre, centre))
        # the height of a peak or a shoulder
        g.e("set label %s '%s' at %s,%s front" % (4*i+2, height, centre, height))

        if centre - lower > 0.0:
            # guide line and width label for the lower bound at HWHM
            g.e("set arrow %s from %s,%s to %s,%s" % (3*i+2, centre, half, lower, half))
            g.e("set label %s '%s' at %s,%s center front" % (4*i+3, centre - lower, (centre + lower)/2.0, half))
        if upper - centre > 0.0:
            # guide line and width label for the upper bound at HWHM
            g.e("set arrow %s from %s,%s to %s,%s" % (3*i+3, centre, half, upper, half))
            g.e("set label %s '%s' at %s,%s center front" % (4*i+4, upper - centre, (centre + upper)/2.0, half))

    smoothedYKey = smoothed.PutYKey()
    interactive = terminal == "x11" or terminal == ""
    while True:
        # NOTE(review): the command names two '-' inline sources but only one
        # data block is sent per pass; gnuplot's documentation describes one
        # block per '-'.  Confirm against GnuplotInterface before changing.
        g.e("plot '-' using 1:2 with lines linewidth 1 title 'source', '-' using 1:3 with lines linewidth 1 title 'smoothed'")

        for i in range(src.PutSize(yKey)):
            binCentre = (srcBin[i]+srcBin[i+1])/2.0
            g.e("%s %s %s" % (binCentre, src.Put(yKey, i), smoothed.Put(smoothedYKey, i)))
        g.e("e")

        if not interactive:
            g.e("replot")
            break
        if raw_input("hit any key except return for next\n") != "":
            break

    del g
    del srcBin


def plotNorm(method, terminal):
    """Plot the R-factor convergence history of a fitting method on a log scale.

    The histories for iteration count, R factor, residual error norm,
    gradient norm and parameter-difference norm are read from `method`.
    All five columns are sent to gnuplot, but the plot command only draws
    column 2 (the R factor) against column 1 (the iteration count).

    Parameters:
        method   -- object offering getConvergenceHistoryForInt4() and
                    getConvergenceHistoryForDouble()
        terminal -- not used by this function
    """
    keys = Adv.NewLevmar
    ic = method.getConvergenceHistoryForInt4(keys.ITERATION_COUNT)
    rf = method.getConvergenceHistoryForDouble(keys.R_FACTOR)
    rn = method.getConvergenceHistoryForDouble(keys.RESIDUAL_ERR_NORM)
    gn = method.getConvergenceHistoryForDouble(keys.GRADIENT_NORM)
    pn = method.getConvergenceHistoryForDouble(keys.PARAM_DIFF_NORM)
    waitTime = method.getConvergenceHistoryForInt4(keys.ITERATION_TIME).back()

    g = m.GnuplotInterface()

    g.e("set xlabel 'iteration'")
    if ic.size() != 1:
        xrangeCmd = "set xrange [%r:%r]" % (ic.front(), ic.back())
    else:
        # a single sample would give an empty range, so centre it instead
        xrangeCmd = "set xrange [-0.5:0.5]"
    g.e(xrangeCmd)

    g.e("set ylabel 'norm'")
    g.e("set autoscale y")
    g.e("set logscale y 10.0")

    g.e("plot" + "  '-' using 1:2 with linespoints title '" + Adv.NewLevmar.R_FACTOR + "'")

    for i in range(ic.size()):
        columns = [str(ic[i]/1.0).rjust(5)]
        for history in (rf, rn, gn, pn):
            columns.append(str(history[i]).rjust(23))
        g.e(" ".join(columns))
    g.e("e")

    # NOTE(review): if waitTime is an integer this is floor division under
    # Python 2, so values below 1000000 give "pause 0" -- confirm intent.
    g.e("pause " + str(waitTime/1000000))

    del g, ic, rf, rn, gn, pn, waitTime


def plot(src, smoothed, fittied, terminal):
    """Plot original, smoothed and fitted intensities against the bin centres.

    The plot is re-sent every time an empty line is entered at the console
    prompt; any non-empty input ends the function.

    Parameters:
        src      -- container providing the x bins, keys, units and y values
        smoothed -- container whose y values form the third data column
        fittied  -- container whose y values form the fourth data column
        terminal -- not used by this function
    """
    xKey = src.PutXKey()
    yKey = src.PutYKey()
    srcBin      = src.Put(xKey)
    srcY        = src.Put(yKey)
    smoothedBin = smoothed.Put(smoothed.PutXKey())
    smoothedY   = smoothed.Put(smoothed.PutYKey())
    fittiedBin  = fittied.Put(fittied.PutXKey())
    fittiedY    = fittied.Put(fittied.PutYKey())

    g = m.GnuplotInterface()

    g.e("set xlabel '%s / %s'" % (xKey, src.PutUnit(xKey)))
    g.e("set xrange[%s:%s]" % (srcBin.front(), srcBin.back()))

    g.e("set ylabel '%s / %s'" % (yKey, src.PutUnit(yKey)))
    g.e("set autoscale y")

    plotCmd = "plot '-' using 1:2 with lines linewidth 1 title 'original', '-' using 1:3 with lines linewidth 2 title 'smoothed', '-' using 1:4 with lines linewidth 2 title 'fittied'"
    while True:
        # NOTE(review): the command names three '-' inline sources but only
        # one data block is sent per pass; gnuplot's documentation describes
        # one block per '-'.  Confirm against GnuplotInterface before changing.
        g.e(plotCmd)

        for i in range(srcY.size()):
            binCentre = (srcBin[i]+srcBin[i+1])/2.0
            g.e("%s %s %s %s" % (binCentre, srcY[i], smoothedY[i], fittiedY[i]))
        g.e("e")

        g.e("replot")
        if raw_input("hit any key except return\n") != "":
            break

    del g
    del srcBin
    del srcY
    del fittiedBin
    del fittiedY

####    main routine    ####
# Demo driver: generate a synthetic three-peak histogram, locate the peaks
# with a B-spline peak search, then refine them with a Levenberg-Marquardt
# fit and plot the outcome.
xmin = 0.0
xmax = 15.0
nDiv = 1500
# [baseline, [[amplitude, centre, width], ...]] in the form read by dist()
param1 = [1000.0, [[100000.0, 5.0, 1.5], [100000.0, 10.0, 1.5], [35000.0, 2.8, 1.0]]]

src = initElementContainer(xmin, xmax, nDiv, param1)
# outputElementContainer(src) prints the generated histogram when debugging.

peakSearch = Adv.PeakSearch(src, Adv.BSPLINE)
peakSearch.setParam(Adv.BSpline.NUMBER_OF_BREAK_POINTS, 16)
if not peakSearch.checkParam():
    # `smoothed` and `peakData` feed every later step; stop with a clear
    # message instead of failing further down with a NameError.
    raise SystemExit("peak search: parameter check failed, cannot continue")
peakSearch.execute()
smoothed = peakSearch.getResult()
peakData = peakSearch.getPeaks()
plotPeakData(src, smoothed, peakData, "postscript")
plotPeakData(src, smoothed, peakData, "")

domain = Adv.Domain()
domain.setSource(src)
domain.setRange(xmin, xmax)
# Each line is printed as one pre-joined string so the text is the same
# under Python 2 and Python 3.
print(" ".join(["[", str(domain.getLowerBound()).rjust(10), ",", str(domain.getUpperBound()).rjust(10), "]"]))
print(" ".join(["[", str(domain.getLowerBoundID()).rjust(10), ",", str(domain.getUpperBoundID()).rjust(10), "]"]))
print("")
print("type of the domain %s" % (domain.getType(),))

levmar = Adv.NewLevmar()

paramSet = levmar.setDefaultParam(src)
paramSet.dump()

# initial values of fitting parameters
param = peakData.toVector()

# box constraint for fitting parameters: 0.9x / 1.1x of each initial value
lb = m.DoubleVector()
ub = m.DoubleVector()
for i in range(param.size()):
    lb.push_back(param[i]*0.9)
    ub.push_back(param[i]*1.1)

paramSet.add(Adv.NewLevmar.PARAMETER_VALUES, param)
paramSet.add(Adv.NewLevmar.LOWER_BOUNDS,     lb)
paramSet.add(Adv.NewLevmar.UPPER_BOUNDS,     ub)

# combination of fitting functions: one " g" entry per detected peak
funcStr = " g" * peakData.size()

parser = Adv.FuncParser(funcStr)  # a string expression for list of function names or symbols
paramSet.add(Adv.NewLevmar.FUNCTIONS, parser.parse())

paramSet.replace(Adv.NewLevmar.CONSTRAIN,          Adv.NewLevmar.NO_CONSTRAIN)  # default Adv.NewLevmar.BOX
paramSet.replace(Adv.NewLevmar.USE_NUMERICAL_DIFF, False)                       # default True
# Optional overrides, left at their defaults here:
#paramSet.replace(Adv.NewLevmar.DIFF_METHOD,        Adv.NewLevmar.CENTRAL)       # default FOWARD
#paramSet.replace(Adv.NewLevmar.USE_DATA_WEIGHTS,   False)                       # default True
paramSet.replace(Adv.NewLevmar.MAX_ITERATIONS,     1000)                        # default 1000
paramSet.replace(Adv.NewLevmar.OUTPUT_INTERVAL,    50)                          # default 50

paramSet.replace(Adv.NewLevmar.SCALING_FACTOR,     Adv.NewLevmar.DEFAULT_SCALING_FACTOR)     # default 0.001
paramSet.replace(Adv.NewLevmar.TOLERANCE,          Adv.NewLevmar.DEFAULT_TOLERANCE)          # default 1.0e-17
paramSet.replace(Adv.NewLevmar.GRADIENT_TOLERANCE, Adv.NewLevmar.DEFAULT_GRADIENT_TOLERANCE) # default 1.0e-17
paramSet.replace(Adv.NewLevmar.RELATIVE_TOLERANCE, Adv.NewLevmar.DEFAULT_RELATIVE_TOLERANCE) # default 1.0e-17
paramSet.dump()

# Run the consistency check once and reuse the answer for both the report
# and the decision to fit.
paramsOk = levmar.checkParam(src, domain, paramSet)
print("parameter check:  %s" % (paramsOk,))

if paramsOk:
    levmar.toInnerForm(src, domain, paramSet)  # translate to inner form
    levmar.fit()                               # start fitting

    if levmar.isMultiThreaded():               # a multi-threaded fit runs in the background
        ct = -1
        while levmar.isFitting():              # poll until the fit stops
            time.sleep(1)
            s = levmar.getHistorySize()
            if s > ct:                         # redraw only when the history size grew
                plotNorm(levmar, "")
                ct = s
        plotNorm(levmar, "")

    levmar.eval()                              # evaluate the functions using the fitted parameters
    fitted = m.ElementContainer()
    levmar.toElementContainer(src, fitted)     # copy fitted results to an element container
    plot(src, smoothed, fitted, "")
    # outputElementContainer(fitted) prints the fitted histogram when debugging.

