from __future__ import print_function
import Manyo as mm
import Manyo.LevmarTools as ml
import os
import shutil
import numpy as np
import uGao.MPlot as mp
import utsusemi.DNA.ana.Reduction.FunctionBase as FB
import utsusemi.DNA.ana.Reduction.BaseCommandsDNA as BaseComDNA
import math

################################################################################################################

def PrintFileNametype():
    """
    Print a legend of the MEM work files (TEMP1.DAT ... TEMP6.DAT,
    outputSqe.txt) and of the seven columns stored in outputSqe.txt.

    @retval None
    """
    # Each entry is one printed line; "" entries produce the blank lines.
    legend = (
        "",
        "TEMP1.DAT    :EXPERIMENTAL DATA FROM SAMPLE",
        "TEMP2.DAT    :RESOLUTION FUNCTION",
        "TEMP3.DAT    :initial S(q,e)(Lorentzian)",
        "TEMP4.DAT    :initial S(q,e) times resolution",
        "TEMP5.DAT    :present S(q,e)",
        "TEMP6.DAT    :present S(q,e) times resolution",
        "outputSqe.txt:final calc result of S(q,e)",
        "",
        "outputSqe.txt contents bellow:",
        "e[j], s[j], th1[j], thcal[j], sect*s[j], sect*th1[j], sect*thcal[j]",
        "1  :  e          :  observation data of energy",
        "2  :  s          :  observation data of scattering intensity",
        "3  :  th1        :  S(q,e)",
        "4  :  thcal      :  S(q,e) times resolution",
        "5  :  sect*s     :  sect times observation data of scattering intensity",
        "6  :  sect*th1   :  sect times S(q,e)",
        "7  :  sect*thcal :  sect times S(q,e) times resolution",
        "",
        "normalization factor of spectrum:",
        "sect = integration s[e] from e[0] to e[ndat-1]",
        "",
    )
    print("\n".join(legend))

################################################################################################################

def LoadTextData(filename="TEMP6.DAT",SaveDataPath = "/usr/local/mlf/DNA/tmp/Mem/"):
    """
    Load a whitespace-separated text file as a histogram ElementContainer.

    Each non-blank line is "x y [err]", where x is a bin centre.
    - Errors: if the first data line has at least three columns, the third
      column is read as the error, and any later line lacking it gets 0.0.
      Otherwise every error is 0.0.
    - Intensities that parse to NaN (any spelling accepted by float()) are
      stored as 0.0.
    - Bin boundaries are rebuilt from the centres. The first boundary is
      1.5*x0 - 0.5*x1, and each following one mirrors the previous boundary
      about the current centre (2*x_j - previous).

    Files of this kind:
      TEMP1.DAT    :EXPERIMENTAL DATA FROM SAMPLE
      TEMP2.DAT    :RESOLUTION FUNCTION
      TEMP3.DAT    :initial S(q,e)(Lorentzian)
      TEMP4.DAT    :initial S(q,e) times resolution
      TEMP5.DAT    :present S(q,e)
      TEMP6.DAT    :present S(q,e) times resolution
    (outputSqe.txt has a header and more columns; use LoadFinalData for it.)

    @param filename      (string) file name, appended directly to SaveDataPath
    @param SaveDataPath  (string) directory prefix; must end with a separator
    @retval ec (ElementContainer) keys "xval", "yval", "eval"; left empty
            when the file is missing or has fewer than two data lines
    """
    ec = mm.ElementContainer()
    path = SaveDataPath+filename
    if not os.path.exists(path):
        print("file is not found\n")
        return ec

    with open(path) as f:
        # skip blank lines (e.g. trailing empty lines) instead of crashing on them
        rows = [ln.split() for ln in f if ln.strip()]

    parsed = _ParseTextDataRows(rows)
    if parsed is None:
        print("file has fewer than two data lines\n")
        return ec
    xline, yline, eline = parsed

    ec.Add("xval",xline)
    ec.Add("yval",yline)
    ec.Add("eval",eline)
    ec.SetKeys("xval","yval","eval")
    return ec


def _ParseTextDataRows(rows):
    """
    Convert split "x y [err]" rows into (bin boundaries, intensities, errors).

    @param rows  (list of list of string) non-empty token lists, one per line
    @retval (xline, yline, eline) lists of floats, where xline has one more
            entry than the others; None when fewer than two rows are given
            (two centres are needed to place the first boundary)
    """
    if len(rows) < 2:
        return None
    has_error = len(rows[0]) >= 3
    xline = [float(rows[0][0])*1.5 - float(rows[1][0])*0.5]
    yline = []
    eline = []
    for j, row in enumerate(rows):
        xline.append(float(row[0])*2 - xline[j])
        value = float(row[1])
        yline.append(0.0 if math.isnan(value) else value)
        if has_error and len(row) >= 3:
            eline.append(float(row[2]))
        else:
            eline.append(0.0)
    return xline, yline, eline
################################################################################################################

def LoadFinalData(filename="outputSqe.txt",SaveDataPath = "/usr/local/mlf/DNA/tmp/Mem/",linenumber=5):
    """
    Load one column of the final MEM result file as a histogram ElementContainer.

    The first three lines of the file are skipped as a header, and blank lines
    after that are ignored. Column 1 holds the energy (bin centres), and column
    ``linenumber`` is stored as the intensity. The columns of outputSqe.txt are:
      1 e, 2 s, 3 th1 = S(q,e), 4 thcal = S(q,e) times resolution,
      5 sect*s, 6 sect*th1, 7 sect*thcal
      (sect = integration s[e] from e[0] to e[ndat-1]).

    Other details:
    - Intensities that parse to NaN are stored as 0.0.
    - The error is sqrt(|y|). The absolute value keeps negative intensities
      from raising ValueError in math.sqrt.
    - Bin boundaries are rebuilt from the centres in the same way as in
      LoadTextData.

    @param filename      (string) file name, appended directly to SaveDataPath
    @param SaveDataPath  (string) directory prefix; must end with a separator
    @param linenumber    (int) 1-based column to load as intensity (1-7)
    @retval ec (ElementContainer) keys "xval", "yval", "eval"; left empty when
            the file is missing, linenumber is out of range, or fewer than
            two data lines follow the header
    """
    ec = mm.ElementContainer()
    path = SaveDataPath+filename
    if not os.path.exists(path):
        print("file is not found\n")
        return ec
    if (linenumber<1) or (linenumber>7):
        print("cannot accept input-linenumber\n")
        # message now matches the range actually accepted above
        print("please input 1 - 7 \n")
        return ec

    with open(path) as f:
        lines = f.readlines()
    # header is the first three raw lines; blank data lines are skipped
    rows = [ln.split() for ln in lines[3:] if ln.strip()]

    arrays = _BuildFinalDataArrays(rows, linenumber)
    if arrays is None:
        print("file has fewer than two data lines\n")
        return ec
    xline, yline, eline = arrays

    ec.Add("xval",xline)
    ec.Add("yval",yline)
    ec.Add("eval",eline)
    ec.SetKeys("xval","yval","eval")
    return ec


def _BuildFinalDataArrays(rows, linenumber):
    """
    Turn split outputSqe.txt data rows into (bin boundaries, y, error) lists.

    @param rows        (list of list of string) data rows, header already removed
    @param linenumber  (int) 1-based column used as y
    @retval (xline, yline, eline) or None when fewer than two rows are given
    """
    if len(rows) < 2:
        return None
    col = linenumber-1
    xline = [float(rows[0][0])*1.5 - float(rows[1][0])*0.5]
    yline = []
    for j, row in enumerate(rows):
        xline.append(float(row[0])*2 - xline[j])
        value = float(row[col])
        yline.append(0.0 if math.isnan(value) else value)
    eline = [math.sqrt(abs(c)) for c in yline]
    return xline, yline, eline


################################################################################################################

def OutputECText(ec,tag):
    """
    Write an ElementContainer as a text file of "centre y error" lines.

    The file is /usr/local/mlf/DNA/tmp/EC_<RUNNUMBER>.txt when tag is "_",
    and /usr/local/mlf/DNA/tmp/EC_<RUNNUMBER><tag>.txt otherwise. RUNNUMBER
    comes from the header and falls back to "tmp" when the key is absent.
    An existing file with that name is first renamed to <name>_bk1.

    Each line holds a bin centre (mean of adjacent X boundaries), Y and E.
    Centres within 1e-10 of zero are written as 0.0.

    @param ec      (ElementContainer) histogram data (len(X) == len(Y) + 1)
    @param tag     (string) file-name suffix; "_" means no suffix
    @retval None
    """
    mh = ec.PutHeader()
    if mh.CheckKey("RUNNUMBER"):
        runno = mh.PutString("RUNNUMBER")
    else:
        runno = "tmp"

    if tag == "_":
        filenameECbin = "/usr/local/mlf/DNA/tmp/EC_"+runno+".txt"
    else:
        filenameECbin = "/usr/local/mlf/DNA/tmp/EC_"+runno+tag+".txt"

    # Read the data before touching the filesystem. Otherwise a failing
    # accessor would leave the previous output renamed and an empty file behind.
    xbin = ec.PutX()
    ybin = ec.PutY()
    ebin = ec.PutE()

    if os.path.exists(filenameECbin):
        os.rename(filenameECbin,filenameECbin+"_bk1")

    with open(filenameECbin, "w") as f:
        for i in range(len(ybin)):
            center = (xbin[i]+xbin[i+1])*0.5
            # snap round-off noise around zero to an exact 0.0
            if -1e-10 < center < 1e-10:
                center = 0.0
            f.write(str(center)+" "+str(ybin[i])+" "+str(ebin[i])+"\n")

    print("")
    print("output >> ",filenameECbin)

    return

################################################################################################################

def EditInitLorentzParam(NumOfLor = 2, omega = "1.0,2.0", ene0 = "1.0,2.0", val = "1.0,2.0", initParamFilePath = "/usr/local/mlf/DNA/tmp/Mem/"):
    """
    Write the initial Lorentzian parameters to <initParamFilePath>initParam.txt.

    The file gets one "omega ene0 val" line per Lorentzian function. It is
    written only when omega, ene0 and val each contain exactly NumOfLor
    comma-separated entries; otherwise a warning is printed and nothing is
    written. A summary of the inputs is printed in both cases. Whitespace
    around each entry is stripped, so "1.0, 2.0" is accepted.

    @param NumOfLor           (int)    number of Lorentzian functions
    @param omega              (string) comma-separated omega values
    @param ene0               (string) comma-separated energy-centre values
    @param val                (string) comma-separated relative intensities
    @param initParamFilePath  (string) directory prefix, concatenated directly
                                       with "initParam.txt"
    @retval None
    """
    omegaArray = [s.strip() for s in omega.split(",")]
    ene0Array = [s.strip() for s in ene0.split(",")]
    valArray = [s.strip() for s in val.split(",")]

    if ((len(omegaArray)!=NumOfLor) or (len(ene0Array)!=NumOfLor) or (len(valArray)!=NumOfLor)):
        print ("number of parameter are incorrect.")
    else:
        filename = "initParam.txt"
        with open(initParamFilePath+filename, 'w') as f:
            for i in range(NumOfLor):
                f.write(omegaArray[i]+" "+ene0Array[i]+" "+valArray[i]+"\n")
    print ("Input number of Lorentz function       is "+str(NumOfLor)+".")
    print ("Input number of omega parameter        is "+str(len(omegaArray))+". : ["+omega+"]")
    print ("Input number of enegy-center parameter is "+str(len(ene0Array))+". : ["+ene0+"]")
    print ("Input number of relative intensity     is "+str(len(valArray))+". : ["+val+"]")

    return

################################################################################################################

def ReadECAndWriteText(ec,filename= "sTest"):
    """
    Dump ``ec`` as text via OutputECText(ec, "_"), then copy that text file
    to /usr/local/mlf/DNA/tmp/MemData/<filename>.

    The source is /usr/local/mlf/DNA/tmp/EC_<RUNNUMBER>.txt, where RUNNUMBER
    falls back to "tmp" when the header has no such key.

    @param ec        (ElementContainer)
    @param filename  (string) destination name inside MemData/
    @retval None
    """
    header = ec.PutHeader()
    runno = header.PutString("RUNNUMBER") if header.CheckKey("RUNNUMBER") else "tmp"

    source = "/usr/local/mlf/DNA/tmp/EC_" + runno + ".txt"
    destination = "/usr/local/mlf/DNA/tmp/MemData/" + filename

    OutputECText(ec, "_")
    shutil.copy(source, destination)
    print("copy   >>  " + destination)

################################################################################################################

def MemExecute(DataFileName = "sDET05",ResFileName = "sV_5",memData = 1,NumOfLor = 2,columnParam = "2,2,3,4,1,2,3",Itteration = 10,MemThreshold = 1.0,savepath = "/usr/local/mlf/DNA/tmp/Mem/"):
    """
    Run the AdvInsMEM (maximum entropy method) driver on a data file and a
    resolution file, both read from /usr/local/mlf/DNA/tmp/MemData/.

    The arguments are forwarded to
    AdvInsMEM::Begin(DataFileName, ResFileName, memData, NumOfLor,
    cmlineDataParam, KcorEinParam, KcorSinParam, KcorRinParam,
    cmlineResParam, LcorEInParam, LcorNInParam, Itteration, MemThreshold,
    savepath).

    @param DataFileName  (string) experimental data file (ascii) in MemData/
    @param ResFileName   (string) resolution data file (ascii) in MemData/
    @param memData       (int)    initial MEM data: 1 = combination of
                                  Lorentzian functions, 2 = observed data
    @param NumOfLor      (int)    number of Lorentzian (fitting) functions
    @param columnParam   (string) seven comma-separated integers, in order:
                                  1. comment-line count of the data file
                                  2-4. energy, intensity and statistical-error
                                       columns of the data file
                                  5. comment-line count of the resolution file
                                  6-7. energy and intensity columns of the
                                       resolution file
                                  "-1" is shorthand for "0,1,2,3,0,1,2"
                                  (the Utsusemi layout). Entries after the
                                  seventh are ignored.
    @param Itteration    (int)    maximum number of iterations
    @param MemThreshold  (float)  MEM threshold value
    @param savepath      (string) temporary save file path
    @retval None
    @raise ValueError  one of the first seven columnParam entries is not an integer
    @raise IndexError  columnParam has fewer than seven entries
    """
    dataDir = "/usr/local/mlf/DNA/tmp/MemData/"

    # "-1" selects the column layout used for Utsusemi output
    layout = "0,1,2,3,0,1,2" if columnParam == "-1" else columnParam
    print(layout)
    cols = [int(token) for token in layout.split(",")[:7]]

    cmlineDataParam = cols[0]
    KcorEinParam = cols[1]
    KcorSinParam = cols[2]
    KcorRinParam = cols[3]
    cmlineResParam = cols[4]
    LcorEInParam = cols[5]
    LcorNInParam = cols[6]

    driver = ml.AdvInsMEM()
    driver.Begin(dataDir + DataFileName, dataDir + ResFileName,
                 memData, NumOfLor,
                 cmlineDataParam, KcorEinParam, KcorSinParam, KcorRinParam,
                 cmlineResParam, LcorEInParam, LcorNInParam,
                 Itteration, MemThreshold, savepath)

################################################################################################################

def MemExecuteManual():
    """
    Call Manual() on a fresh AdvInsMEM driver (presumably prints the driver's
    usage notes; confirm against the Manyo.LevmarTools documentation).

    @retval None
    """
    ml.AdvInsMEM().Manual()

################################################################################################################

def MemDeconvolution(ecData, ecRes, memData = 1,NumOfLor = 2,Itteration = 10,MemThreshold = 1.0,savepath = "/usr/local/mlf/DNA/tmp/Mem/"):
    """
    Run the MEM deconvolution directly on two ElementContainers.

    Steps:
    1. ecData and ecRes are written to MemData/ as "sDET" and "sRES"
       (via ReadECAndWriteText).
    2. MemExecute is run on those files with the "0,1,2,3,0,1,2" column layout.
    3. <savepath>TEMP5.DAT_<Itteration> is loaded back with LoadTextData.
    4. A legend of the other work files is printed.

    @param ecData        (ElementContainer) measured data
    @param ecRes         (ElementContainer) resolution function
    @param memData       (int)    initial MEM data: 1 = Lorentzians, 2 = observed data
    @param NumOfLor      (int)    number of Lorentzian functions
    @param Itteration    (int)    maximum number of iterations
    @param MemThreshold  (float)  MEM threshold value
    @param savepath      (string) MEM work directory (must end with a separator)
    @retval EC           (ElementContainer) contents of TEMP5.DAT_<Itteration>
    """
    for container, name in ((ecData, "sDET"), (ecRes, "sRES")):
        ReadECAndWriteText(container, name)

    MemExecute("sDET", "sRES", memData, NumOfLor, "0,1,2,3,0,1,2",
               Itteration, MemThreshold, savepath)

    resultFile = "TEMP5.DAT_" + str(Itteration)
    EC = LoadTextData(resultFile, savepath)

    for note in ("output ElementContainer(" + resultFile + ")\n",
                 "if you want to output other file, please exutute LoadTextData(filename,savepath) command\n",
                 "file name list is bellow\n"):
        print(note)

    PrintFileNametype()
    return EC

################################################################################################################

def GslFFTConvolution(ecData, ecRes):
    """
    Convolve ecData with ecRes using the Manyo AdvConvDeconv FFT driver.

    Both containers are first passed through AdvConvDeconvPre with
    EQUAL_SPACING. ecRes is added at index 0 and ecData at index 1, and
    element 1 of the driver's result array is returned (presumably the
    processed data; confirm against the LevmarTools documentation).

    @param ecData           (ElementContainer)
    @param ecRes            (ElementContainer)
    @retval ec              (ElementContainer)
    """
    pair = mm.ElementContainerArray()
    for member in (ecRes, ecData):
        pair.Add(member)

    pre = ml.AdvConvDeconvPre(pair, ml.EQUAL_SPACING)
    pre.execute()
    equalSpaced = pre.getResult()

    print("**** convolution ****")
    driver = ml.AdvConvDeconv(equalSpaced, ml.CONVOLUTION_BY_FFT)
    driver.execute()
    return driver.getResult().Put(1)

################################################################################################################

def GslFFTDeconvolution(ecData, ecRes):
    """
    Deconvolve ecData by ecRes using the Manyo AdvConvDeconv FFT driver.

    Both containers are first passed through AdvConvDeconvPre with
    EQUAL_SPACING. ecRes is added at index 0 and ecData at index 1, and
    element 1 of the driver's result array is returned (presumably the
    processed data; confirm against the LevmarTools documentation).

    @param ecData           (ElementContainer)
    @param ecRes            (ElementContainer)
    @retval ec              (ElementContainer)
    """
    pair = mm.ElementContainerArray()
    for member in (ecRes, ecData):
        pair.Add(member)

    pre = ml.AdvConvDeconvPre(pair, ml.EQUAL_SPACING)
    pre.execute()
    equalSpaced = pre.getResult()

    print("**** deconvolution ****")
    driver = ml.AdvConvDeconv(equalSpaced, ml.DECONVOLUTION_BY_FFT)
    driver.execute()
    return driver.getResult().Put(1)

################################################################################################################

def NumpyFFT(ec0):
    """
    Return the FFT amplitude spectrum of an ElementContainer's Y and E arrays.

    Y and E are each transformed independently with numpy.fft.fft, and every
    complex coefficient is replaced by its modulus. The X array is copied
    unchanged; it is NOT converted to a frequency axis.

    @param ec0            (ElementContainer)
    @retval ec            (ElementContainer) keys "x", "y", "e"
    """
    def modulus(coeffs):
        return [float(np.sqrt(z.real ** 2 + z.imag ** 2)) for z in coeffs]

    xAxis = ec0.PutX()
    yVals = ec0.PutY()
    eVals = ec0.PutE()

    yMod, eMod = (modulus(np.fft.fft(v)) for v in (yVals, eVals))

    spectrum = mm.ElementContainer()
    spectrum.Add("x", xAxis)
    spectrum.Add("y", yMod)
    spectrum.Add("e", eMod)
    spectrum.SetKeys("x","y","e")
    return spectrum


################################################################################################################

def NumpyIFFT(ec0):
    """
    Return the inverse-FFT amplitude of an ElementContainer's Y and E arrays.

    Y and E are each transformed independently with numpy.fft.ifft, and every
    complex value is replaced by its modulus. The X array is copied unchanged.

    @param ec0            (ElementContainer)
    @retval ec            (ElementContainer) keys "x", "y", "e"
    """
    def modulus(coeffs):
        return [float(np.sqrt(z.real ** 2 + z.imag ** 2)) for z in coeffs]

    xAxis = ec0.PutX()
    yVals = ec0.PutY()
    eVals = ec0.PutE()

    yMod, eMod = (modulus(np.fft.ifft(v)) for v in (yVals, eVals))

    result = mm.ElementContainer()
    result.Add("x", xAxis)
    result.Add("y", yMod)
    result.Add("e", eMod)
    result.SetKeys("x","y","e")
    return result

################################################################################################################

def NumpyFFTConvolution(ecDet,ecRes):
    """
    Convolve ecDet with ecRes using numpy FFTs.

    ecRes is first rebinned onto ecDet's X bins. Then:
    - Y is the inverse FFT of F(y_det) * F(y_res).
    - E is the inverse FFT of F(e_det) * F(y_res) + F(e_res) * F(y_det).
    The modulus of each complex result is stored with ecDet's X bins, and the
    container is passed through FB.HalfRotationX (half-phase shift) before
    being returned.

    @param ecDet            (ElementContainer)
    @param ecRes            (ElementContainer)
    @retval ec            (ElementContainer)
    """
    resRebinned = ecRes.ReBin(ecDet.PutX())
    yResF = np.fft.fft(resRebinned.PutY())
    eResF = np.fft.fft(resRebinned.PutE())

    xDet = ecDet.PutX()
    yDet = ecDet.PutY()
    yDetF = np.fft.fft(yDet)
    eDetF = np.fft.fft(ecDet.PutE())

    npts = yDet.size()
    convY = [yDetF[k] * yResF[k] for k in range(npts)]
    convE = [eDetF[k] * yResF[k] + eResF[k] * yDetF[k] for k in range(npts)]

    def modulus(values):
        return [float(np.sqrt(z.real ** 2 + z.imag ** 2)) for z in values]

    amplitude = mm.ElementContainer()
    amplitude.Add("x", xDet)
    amplitude.Add("y", modulus(np.fft.ifft(convY)))
    amplitude.Add("e", modulus(np.fft.ifft(convE)))
    amplitude.SetKeys("x","y","e")

    # shift the half phase
    return FB.HalfRotationX(amplitude)

################################################################################################################

def NumpyFFTDeconvolution(ecDet,ecRes):
    """
    Deconvolve ecDet by ecRes using numpy FFTs (F_det / F_res).

    ecRes is first rebinned onto ecDet's X bins. The transforms are computed
    by _DeconvolveSpectraFFT. The moduli of the inverse transforms are stored
    with ecDet's X bins, and the container is passed through FB.HalfRotationX
    (half-phase shift) before being returned.

    @param ecDet            (ElementContainer) measured spectrum
    @param ecRes            (ElementContainer) resolution function
    @retval ec            (ElementContainer) keys "x", "y", "e"
    """
    ecResRB = ecRes.ReBin(ecDet.PutX())
    y_Amplitude, e_Amplitude = _DeconvolveSpectraFFT(
        ecDet.PutY(), ecDet.PutE(), ecResRB.PutY(), ecResRB.PutE())

    ec_Amplitude = mm.ElementContainer()
    ec_Amplitude.Add("x",ecDet.PutX())
    ec_Amplitude.Add("y",y_Amplitude)
    ec_Amplitude.Add("e",e_Amplitude)
    ec_Amplitude.SetKeys("x","y","e")

    # shift the half phase
    return FB.HalfRotationX(ec_Amplitude)


def _DeconvolveSpectraFFT(yDet, eDet, yRes, eRes):
    """
    Compute |ifft(F_det/F_res)| and its linearised error spectrum.

    With F = numpy.fft.fft:
    - Y = |ifft(F(yDet) / F(yRes))|
    - E = |ifft(F(yDet)/F(yRes) * (F(eDet)/F(yDet) - F(eRes)/F(yRes)))|

    Each quotient is formed as num*conj(den)/|den|^2. Any component whose
    *Fourier* divisor is exactly zero is set to 0 instead of producing inf/nan.
    The resolution spectra are truncated to the length of the data.

    @param yDet, eDet  (sequence of float) data intensities and errors
    @param yRes, eRes  (sequence of float) resolution intensities and errors
    @retval (yAmplitude, eAmplitude) lists of float, len(yDet) each
    """
    fyDet = np.fft.fft(yDet)
    feDet = np.fft.fft(eDet)
    npts = len(fyDet)
    fyRes = np.fft.fft(yRes)[:npts]
    feRes = np.fft.fft(eRes)[:npts]

    def safe_ratio(num, den):
        # num/den with 0 wherever the Fourier coefficient den is exactly 0
        out = np.zeros(len(den), dtype=complex)
        nz = np.abs(den) != 0
        out[nz] = num[nz] * np.conjugate(den[nz]) / np.abs(den[nz]) ** 2
        return out

    yFFT = safe_ratio(fyDet, fyRes)            # F_det / F_res
    relErrDet = safe_ratio(feDet, fyDet)       # F(e_det) / F_det
    relErrRes = safe_ratio(feRes, fyRes)       # F(e_res) / F_res
    eFFT = yFFT * (relErrDet - relErrRes)

    yAmplitude = np.abs(np.fft.ifft(yFFT)).tolist()
    eAmplitude = np.abs(np.fft.ifft(eFFT)).tolist()
    return yAmplitude, eAmplitude

################################################################################################################

def OutputPartOfEC(ec,initial,final):
    """
    Forward to BaseCommandsDNA.OutputPartOfEC with unchanged arguments.

    @param  ec         (ElementContainer)
    @param  initial    (double)
    @param  final      (double)
    @retval whatever BaseCommandsDNA.OutputPartOfEC returns
    """
    delegate = BaseComDNA.OutputPartOfEC
    return delegate(ec, initial, final)

################################################################################################################
