from __future__ import print_function
import Cmm
import numpy
import os
import Manyo as mm
import Manyo.Utsusemi as mu
import math

###############################################################################
# Efficiency Correction Util
###############################################################################

def PutIntegralIntensityArray(DAT,energyMin,energyMax):
    """Integrate every pixel's Y values over an open energy window.

    The bin centres are taken from the X list of DAT.Put(0).Put(0) (midpoints
    of consecutive X values) and are shared by all pixels.  For each pixel the
    Y values whose bin centre lies strictly inside (energyMin, energyMax) and
    whose error is non-negative are summed.

    Returns:
        (sums, masks): two nested lists indexed [psd][pixel].  ``sums`` holds
        the window sums; ``masks`` holds the header value "MASKED"
        (1 = masked, 0 = not masked).
    """
    edges = numpy.array(DAT.Put(0).Put(0).PutXList())
    centers = 0.5 * (edges[:-1] + edges[1:])
    inWindow = (centers > energyMin) & (centers < energyMax)

    sums = []
    masks = []
    for psd in range(DAT.PutTableSize()):
        row = DAT.Put(psd)
        rowSums = []
        rowMasks = []
        for pix in range(row.PutTableSize()):
            container = row.Put(pix)
            flag = container.PutHeaderPointer().PutInt4("MASKED")
            ys = numpy.array(container.PutYList())
            es = numpy.array(container.PutEList())
            # negative errors mark invalid bins and are excluded
            keep = inWindow & (es >= 0)
            rowSums.append(numpy.sum(ys[keep]))
            rowMasks.append(flag)
        sums.append(rowSums)
        masks.append(rowMasks)
    return sums, masks
#------------------------------------------------------------------------------
def PutIntegralIntensityArray2(DAT,energyMin,energyMax):
    """Same as PutIntegralIntensityArray, but masked pixels report a sum of 0.

    Delegates the integration to PutIntegralIntensityArray (instead of
    duplicating it) and replaces the sum of every pixel whose "MASKED" header
    value is non-zero with 0.

    Returns:
        (sums, masks): nested lists indexed [psd][pixel].
    """
    sums, masks = PutIntegralIntensityArray(DAT, energyMin, energyMax)
    zeroed = [[s if m == 0 else 0 for s, m in zip(rowS, rowM)]
              for rowS, rowM in zip(sums, masks)]
    return zeroed, masks
#------------------------------------------------------------------------------
def PutAverageIntegralIntensity(integralArray,maskArray):
    """Return the mean of the integrals whose mask flag is 0.

    Prints the sum, the number of unmasked entries and the mean.

    Args:
        integralArray: nested sequence of integrated intensities.
        maskArray: nested sequence of the same shape; 0 means "not masked".
    """
    unmasked = numpy.array(integralArray)[numpy.array(maskArray) == 0]
    total = numpy.sum(unmasked)
    count = numpy.size(unmasked)
    mean = total / count

    print("sumval:", total)
    print("sizeval:", count)
    print("average:", mean)
    return mean
#------------------------------------------------------------------------------
def PutEbinWidth(DAT):
    """Return (and print) the width of the first X bin of DAT.Put(0).Put(0)."""
    xs = DAT.Put(0).Put(0).PutXList()
    lo, hi = xs[0], xs[1]
    width = hi - lo
    print("Ebin Width :", width)
    return width
#------------------------------------------------------------------------------
def WriteMaskFileFromAverageIntensity(output,integralArray,maskArray,averageVal,thresh=0.5):
    """Write a mask file listing pixels that fail the intensity threshold.

    A pixel "passes" when its mask flag is 0 and its integral is at least
    ``averageVal * thresh``.  For each PSD row i the file gets one line:

    * ``"i"`` when no pixel of the row passes, or
    * space-separated ``"i.j"`` entries for every failing pixel j.

    The same text is written to ``output`` and to
    <BaseDir>/<InstCode>/ana/xml/<output>; if both resolve to the same file it
    is written once.  The second path is printed.
    """
    fPath = os.path.join(mu.UtsusemiEnvGetBaseDir(), mu.UtsusemiEnvGetInstCode(), "ana", "xml", output)

    integralArray = numpy.array(integralArray)
    maskArray = numpy.array(maskArray)
    condMask = (maskArray == 0)
    condHalfVal = (integralArray >= averageVal*thresh)
    cond = numpy.logical_and(condHalfVal, condMask)

    lines = []
    for i, row in enumerate(cond):
        if numpy.sum(row) == 0:
            lines.append(str(i) + "\n")
        else:
            vals = ["%d.%d" % (i, j) for j, flag in enumerate(row) if not flag]
            lines.append(" ".join(vals) + "\n")
    text = "".join(lines)

    # Use context managers so handles are closed even on error, and avoid
    # opening the same file twice when output already is the target path.
    targets = [output]
    if os.path.abspath(fPath) != os.path.abspath(output):
        targets.append(fPath)
    for path in targets:
        with open(path, "w") as fh:
            fh.write(text)
    print(fPath)
#------------------------------------------------------------------------------
def CreateMaskFileFromAverageIntensity(DAT,energyMin,energyMax,output,thresh=0.5):
    """Integrate DAT over (energyMin, energyMax) and write a threshold mask file.

    Pixels whose integral is below ``thresh`` times the unmasked mean are
    written to ``output`` by WriteMaskFileFromAverageIntensity.
    """
    sums, flags = PutIntegralIntensityArray(DAT, energyMin, energyMax)
    WriteMaskFileFromAverageIntensity(output, sums, flags,
                                      PutAverageIntegralIntensity(sums, flags),
                                      thresh)
#------------------------------------------------------------------------------
def PutCrossSectionScaleFactor(DAT, energyMin, energyMax,
                               container1, sample_filename1,
                               nist_filename):
    """Return mean_intensity * bin_width / cross_section.

    mean_intensity is the mean unmasked integral of DAT over
    (energyMin, energyMax); bin_width comes from PutEbinWidth(DAT);
    cross_section is GetCrossSection for the given container section of the
    sample file, using the NIST table ``nist_filename``.
    """
    meanIntensity = PutAverageIntegralIntensity(
        *PutIntegralIntensityArray(DAT, energyMin, energyMax))
    crossSection = GetCrossSection(container1, sample_filename1, nist_filename)
    binWidth = PutEbinWidth(DAT)
    return meanIntensity * binWidth / crossSection
#------------------------------------------------------------------------------
def DoCrossSectionScaleFactorCorrection(ECM,scalefactor,sampleMass=1.0,unitMass=1.0,mode=1):
    """Multiply every element container of ECM by a cross-section factor.

    r = d / scalefactor / (sampleMass/unitMass) / AvogadroC, with d fixed at
    1.0.  With mode == 1 each container is multiplied by r; otherwise by 1/r
    (or by 0.0 when r is 0).  Each resulting container's header "TotalCounts"
    is overwritten with the sum of its Y values whose error is non-negative.

    Returns:
        A new mm.ElementContainerMatrix built from the scaled containers.
    """
    AvogadroC = 6.022141*1e23
    # The energy-bin width factor used to be PutEbinWidth(ECM); it is now fixed
    # at 1.0 (changed by M. Matsuura 2015.4.31).
    d = 1.0
    # r is the same for every pixel, so the multiplier is computed once
    # instead of once per element container.
    r = d / scalefactor / (sampleMass/unitMass) / AvogadroC
    if mode == 1:
        factor = r
    elif r != 0:
        factor = 1.0/r
    else:
        factor = 0.0

    ecaSize = ECM.PutTableSize()
    ecm = mm.ElementContainerMatrix(ECM.PutHeader())
    for i in range(ecaSize):
        ECA = ECM.Put(i)
        eca = mm.ElementContainerArray(ECA.PutHeader())
        for j in range(ECA.PutTableSize()):
            ec = ECA.Put(j).Mul(factor)
            ech = ec.PutHeaderPointer()
            ecY = numpy.array(ec.PutYList())
            ecE = numpy.array(ec.PutEList())
            # bins with negative error are treated as invalid
            ech.OverWrite("TotalCounts", numpy.sum(ecY[ecE >= 0]))
            eca.Add(ec)
        ecm.Add(eca)
    return ecm
#------------------------------------------------------------------------------
def PutAverage(DAT, energyMin, energyMax):
    """Return the mean unmasked integral of DAT over (energyMin, energyMax)."""
    return PutAverageIntegralIntensity(
        *PutIntegralIntensityArray(DAT, energyMin, energyMax))
#------------------------------------------------------------------------------
def WriteEfficiencyFile(output,integralArray,maskArray,averageVal):
    """Write per-pixel efficiency factors (integral / averageVal) as CSV.

    One line per PSD row, comma-separated values formatted with "%.4e".  The
    same text is written to ``output`` and to
    <UserDir>/ana/DNA/Efficiency/<output>; if both resolve to the same file it
    is written once.

    Args:
        output: output file name or path.
        integralArray: 2-D nested sequence of integrals indexed [psd][pixel].
        maskArray: accepted for signature compatibility; not used here.
        averageVal: divisor applied to every integral.

    Raises:
        ValueError: if integralArray is not two-dimensional.
    """
    fPath = os.path.join(mu.UtsusemiEnvGetUserDir(), "ana", "DNA", "Efficiency", output)

    integralArray = numpy.array(integralArray)
    if integralArray.ndim != 2:
        raise ValueError("integralArray must be 2-D (psd x pixel), got shape %s"
                         % (integralArray.shape,))

    vals = integralArray/averageVal
    text = "".join(",".join("%.4e" % v for v in row) + "\n" for row in vals)

    # Context managers close the files even on error; skip the second write
    # when output already is the target path.
    targets = [output]
    if os.path.abspath(fPath) != os.path.abspath(output):
        targets.append(fPath)
    for path in targets:
        with open(path, "w") as fh:
            fh.write(text)
#------------------------------------------------------------------------------
def CreateEfficiencyFile(DAT,energyMin,energyMax,output):
    """Integrate DAT over (energyMin, energyMax) and write normalized efficiencies."""
    sums, flags = PutIntegralIntensityArray(DAT, energyMin, energyMax)
    WriteEfficiencyFile(output, sums, flags,
                        PutAverageIntegralIntensity(sums, flags))
#------------------------------------------------------------------------------
def ReadEfficiencyFile(filename):
    """Read a comma-separated efficiency file into a list of float lists.

    Each non-blank line becomes one row (one PSD); blank lines are skipped.
    The file is closed on return.
    """
    sep = ","
    v = []
    with open(filename) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            v.append([float(val) for val in line.split(sep)])
    return v
#------------------------------------------------------------------------------
def ExecEfficiencyCorrection(ECM,d,mode=1):
    """Apply per-pixel efficiency factors d[psd][pixel] to ECM.

    With mode == 1 each element container is multiplied by its factor;
    otherwise it is divided by it (multiplied by 0.0 when the factor is 0).
    Each result's header "TotalCounts" is overwritten with the sum of its Y
    values whose error is non-negative.

    Returns:
        A new mm.ElementContainerMatrix built from the corrected containers.
    """
    nPsd = ECM.PutTableSize()
    corrected = mm.ElementContainerMatrix(ECM.PutHeader())
    for psd in range(nPsd):
        srcArray = ECM.Put(psd)
        dstArray = mm.ElementContainerArray(srcArray.PutHeader())
        for pix in range(srcArray.PutTableSize()):
            eff = d[psd][pix]
            if mode == 1:
                factor = eff
            else:
                factor = 1.0/eff if eff != 0 else 0.0
            scaled = srcArray.Put(pix).Mul(factor)
            header = scaled.PutHeaderPointer()

            yv = numpy.array(scaled.PutYList())
            ev = numpy.array(scaled.PutEList())
            header.OverWrite("TotalCounts", numpy.sum(yv[ev >= 0]))
            dstArray.Add(scaled)
        corrected.Add(dstArray)
    return corrected
#------------------------------------------------------------------------------
def DoEfficiencyCorrection(DAT,filename,mode=1):
    """Read efficiency factors from ``filename`` and apply them to DAT."""
    return ExecEfficiencyCorrection(DAT, ReadEfficiencyFile(filename), mode)

###############################################################################
# Cross Section Utils (from sample data file)
###############################################################################

def _DataExtract(filename,container):

    f = open(filename, 'r')
    dataIn = f.read()
    f.close()

    if container == "cell":
        ini = dataIn.find("// wall")
        oth = dataIn.find("// body")
    else:
        ini = dataIn.find("// body")
        oth = dataIn.find("// wall")

    if ini < oth:
        dataOut=dataIn[ini:oth]
    else:
        dataOut=dataIn[ini:]

    return dataOut

#------------------------------------------------------------------------------

def _GetData(data,tag,ind=0):

    h_i = data.find(tag,ind)

    if h_i<0:
        return -1
    else:

        h_c = data.rfind("//",0,h_i)
        h_n = data.rfind("\n",0,h_i)
        if h_c > h_n:
            return _GetData(data,tag,h_i+1)


        if data[h_i-1]==" " or data[h_i-1]=="\n" or h_i==0:

            h_e = data.find("\n",h_i)
            if  "//" in data[h_i:h_e]:
                h_e = data[h_i:h_e].find("//")+h_i

            if h_i > h_e:
                return -1
            else:
                h = float(data[h_i:h_e].split("=")[-1])
                return h
        else:
            return _GetData(data,tag,h_i+1)
#------------------------------------------------------------------------------

def _GetShape(data,ind=0):

    form_i= data.find("form",ind)

    form_c = data.rfind("//",0,form_i)
    form_n = data.rfind("\n",0,form_i)
    if form_c > form_n:
        return _GetShape(data,form_i+1)

    form_e= data.find("\n",form_i)

    if "cylinder" in data[form_i:form_e]:
        return "cylinder"
    elif "column"  in data[form_i:form_e]:
        return "column"
    else:
        return ""

#------------------------------------------------------------------------------

def _GetAtomicInfo(data):

    # Get Atomic Info Index
    atom_ds = []

    tmp_i=-1
    while 1:
        tmp_i = data.find("\n",tmp_i+1)

        tmp_c=data.find("//",  tmp_i+1)
        tmp_i=data.find("atom",tmp_i+1)
        if tmp_i < 0:
            break
 
        if tmp_c > 0 and tmp_c < tmp_i:
            pass
        else:
            tmp_e = data.find("\n",tmp_i+1)
            if tmp_c > 0 and tmp_c < tmp_e:
                tmp_e = tmp_c 

            atom_ds.append(data[tmp_i:tmp_e])

    # Get Atomic Info Data
    atoms = []
    ratio_all=0.0
    for atom_d in atom_ds:

        a1 = atom_d.find("=")+1
        a2 = atom_d.find(",")
        name = atom_d[a1:a2].strip()

        a3 = atom_d.find("=",a2)+1
        a4 = -1
        try:
            ratio = float(atom_d[a3:a4].strip())
            atoms.append([name,ratio])
            ratio_all += ratio
        except:
            pass

        if "mol" in atom_d:
            mode = "mol"
        else:
            mode = "wt"

    # re - normalize of atomic ratio
    for i in range(len(atoms)):
        atoms[i][1] = atoms[i][1]/ratio_all*100

    return atoms,mode

#------------------------------------------------------------------------------

def _ReadNISTData(atoms,nist_filename):

    g = open(nist_filename)
    data2 = g.read()
    g.close()

    for i in range(len(atoms)):
        b1=data2.find(atoms[i][0])
        if b1<0:
            atoms[i] += [0.0,0.0,0.0,0.0,0.0,0.0]
            print("\n" + atoms[i][0] + " cross section data is not found.")
        else:
            b2=data2.find("\n",b1)
            tmp = data2[b1:b2].split()
            tmp = [ float(a) for a in tmp[1:7]]
            atoms[i] += tmp

    return atoms

#------------------------------------------------------------------------------

def _GetTotalCrossSection(atoms,mode,M_tot):
    AvogadroC =6.022141*1e23

    print("\natoms:")
    if mode == "mol":
        mean_atomic_M = 0.0
        for i in range(len(atoms)):
            ratio     = atoms[i][1]/100.0
            atomic_M  = atoms[i][2]
            mean_atomic_M += atomic_M * ratio

        N_tot = M_tot / mean_atomic_M * AvogadroC # numer of atoms

        for i in range(len(atoms)):
            ratio    = atoms[i][1]/100.0
            CSinc    = atoms[i][5]
            atoms[i].append(ratio * N_tot * CSinc / (4.0*math.pi))
            print(atoms[i][0],":",mode,"%",atoms[i][1],":",CSinc,"[barn]")
    else: # mode == "wt"
        for i in range(len(atoms)):
            ratio    = atoms[i][1]/100.0
            M_each   = M_tot*ratio
            atomic_M = atoms[i][2]
            if atomic_M==0:
                N_each=0
            else:
                N_each   = M_each / atomic_M *AvogadroC  # numer of atoms
            CSinc    = atoms[i][5]
            atoms[i].append(N_each * CSinc / (4.0*math.pi))
            print(atoms[i][0],":",mode,"%",atoms[i][1],":",CSinc,"[barn]")
    sumCS=0.0
    for i in range(len(atoms)):
        sumCS += atoms[i][8]

    return sumCS

#------------------------------------------------------------------------------

def _GetTotalUnitCell(atoms,mode,M_tot):
    AvogadroC =6.022141*1e23

    print("\natoms:")
    if mode == "mol":
        mean_atomic_M = 0.0
        for i in range(len(atoms)):
            ratio     = atoms[i][1]/100.0
            atomic_M  = atoms[i][2]
            mean_atomic_M += atomic_M * ratio

        N_tot = M_tot / mean_atomic_M * AvogadroC # numer of atoms
    else: # mode == "wt"
        N_tot=0
        for i in range(len(atoms)):
            ratio    = atoms[i][1]/100.0
            M_each   = M_tot*ratio
            atomic_M = atoms[i][2]
            if atomic_M==0:
                N_each=0
            else:
                N_each   = M_each / atomic_M *AvogadroC  # numer of atoms
            N_tot += N_each

    return N_tot

#------------------------------------------------------------------------------

def _LoadSampleComposition(container, sample_filename, nist_filename):
    """Read one sample section and return (atoms, mode, m_tot).

    Shared by GetCrossSection and GetNumOfUnitCell.  The sample mass m_tot [g]
    is weight_per_volume [g/cm^3 -> g/m^3 via *1000] times the volume [m^3].
    The volume is taken from "volume" [mm^3] if present.  Otherwise it is
    computed from the form: pi*(r_large^2 - r_small^2) * beam_size, or
    * height when no beam_size is given, with r_small = 0 for a "column".

    Raises:
        ValueError: if no "volume" is given and the form is neither
            "cylinder" nor "column".
    """
    data = _DataExtract(sample_filename, container)

    wpv = _GetData(data,"weight_per_volume") * 1000.0 # g/m^3

    shape = _GetShape(data)
    if shape == "cylinder":
        r_l = _GetData(data,"r_large") # mm
        r_s = _GetData(data,"r_small") # mm
    elif shape == "column":
        r_l = _GetData(data,"r_large") # mm
        r_s = 0.0                     # mm
    else:
        r_l = r_s = None  # only usable when an explicit volume is given

    h  = _GetData(data,"height")       # mm
    bs = _GetData(data,"beam_size")    # mm
    vu = _GetData(data,"volume")       # mm^3

    if vu < 0:
        if r_l is None:
            raise ValueError("sample form must be 'cylinder' or 'column' "
                             "when no volume is given")
        if bs < 0:
            volume = math.pi * (r_l * r_l - r_s * r_s) * h * 1e-9  # m^3
            print("\nsample volume = pi * (r_l * r_l - r_s * r_s) * h [mm ^ 3]")
            print("          r_l =",r_l)
            print("          r_s =",r_s)
            print("            h =",h)
        else:
            volume = math.pi * (r_l * r_l - r_s * r_s) * bs * 1e-9 # m^3
            print("\nsample volume = pi * (r_l * r_l - r_s * r_s) * beam_size [mm ^ 3]")
            print("          r_l =",r_l)
            print("          r_s =",r_s)
            print("    beam_size =",bs)
    else:
        volume = vu * 1e-9                                     # m^3
        print("\nsample volume = volume [mm ^ 3]")
        print("       volume =",vu)

    m_tot = wpv * volume                                       # g

    atoms, mode = _GetAtomicInfo(data)
    atoms = _ReadNISTData(atoms, nist_filename)
    return atoms, mode, m_tot


def GetCrossSection(container, sample_filename, nist_filename):
    """Return the summed N_i*CSinc_i/(4*pi) term for one sample section.

    ``container`` selects the section ("cell" -> "// wall", otherwise
    "// body"); see _LoadSampleComposition for how the mass is derived.
    """
    atoms, mode, m_tot = _LoadSampleComposition(container, sample_filename, nist_filename)
    cs = _GetTotalCrossSection(atoms, mode, m_tot)
    return cs


def GetNumOfUnitCell(container, sample_filename, nist_filename):
    """Return the total number of atoms in one sample section.

    ``container`` selects the section ("cell" -> "// wall", otherwise
    "// body"); see _LoadSampleComposition for how the mass is derived.
    """
    atoms, mode, m_tot = _LoadSampleComposition(container, sample_filename, nist_filename)
    N_tot = _GetTotalUnitCell(atoms, mode, m_tot)
    return N_tot

#------------------------------------------------------------------------------
# added by M. Matsuura 2015.6.8
def DeleteMask_DirectBeam(input, output, psds=(167, 168, 169), lastPixel=119):
    """Rewrite the mask lines of selected PSDs in a mask file.

    For each PSD index in ``psds``, its line (space-separated "psd.pixel"
    tokens) is scanned.  ``first`` is the last index of the leading run whose
    pixel numbers equal their position (0, 1, 2, ...).  ``second`` is the
    pixel number that breaks the run (0 if it is never broken).  The line is
    replaced by pixels 0..first followed by second..lastPixel.  All other
    lines are copied unchanged to ``output``; every output line is also
    printed.

    Args:
        input: path of the mask file to read (name kept for compatibility).
        output: path of the mask file to write.
        psds: line/PSD indices to rewrite (default: the original 167-169).
        lastPixel: highest pixel number written (default 119).

    Raises:
        ValueError: if ``input`` has no line for one of ``psds``.
    """
    with open(input, 'r') as f:
        dat = f.readlines()

    for psd in psds:
        if psd >= len(dat):
            raise ValueError("mask file %s has no line for PSD %d" % (input, psd))

    def _scan(line):
        # (first, second): end of the leading 0,1,2,... run and the pixel
        # number that interrupts it.
        first = 0
        second = 0
        for j, token in enumerate(line.split(" ")):
            psdno, pxlno = token.split(".")
            if int(pxlno) != j:
                second = int(pxlno)
                break
            first = j
        return first, second

    for psd in psds:
        print(psd)
        first, second = _scan(dat[psd])
        print(first, second)
        pixels = list(range(first+1)) + list(range(second, lastPixel))
        dat[psd] = ("".join("%d.%d " % (psd, p) for p in pixels)
                    + "%d.%d\n" % (psd, lastPixel))

    # The original called g.close without parentheses, so the file was never
    # explicitly closed; the context manager guarantees flush + close.
    with open(output, 'w') as g:
        for line in dat:
            g.write(line)
            print(line)
#------------------------------------------------------------------------------
def DeleteMask_Spurious(input, output, psds=(49, 50, 51, 61), resumePixel=44, lastPixel=119):
    """Rewrite the mask lines of selected PSDs in a mask file.

    For each PSD index in ``psds``, its line (space-separated "psd.pixel"
    tokens) is scanned for the first token whose pixel number differs from
    its position.  ``head`` is that position minus one, or 0 if no token
    differs.  The line is replaced by pixels 0..head followed by
    resumePixel..lastPixel.  All other lines are copied unchanged to
    ``output``; every output line is also printed.

    Args:
        input: path of the mask file to read (name kept for compatibility).
        output: path of the mask file to write.
        psds: line/PSD indices to rewrite (default: the original 49, 50, 51, 61).
        resumePixel: first pixel of the trailing masked range (default 44).
        lastPixel: highest pixel number written (default 119).

    Raises:
        ValueError: if ``input`` has no line for one of ``psds``.
    """
    with open(input, 'r') as f:
        dat = f.readlines()

    for psd in psds:
        if psd >= len(dat):
            raise ValueError("mask file %s has no line for PSD %d" % (input, psd))

    def _headEnd(line):
        # Position before the first token whose pixel != its index; 0 if none.
        for j, token in enumerate(line.split(" ")):
            psdno, pxlno = token.split(".")
            if int(pxlno) != j:
                return j-1
        return 0

    for psd in psds:
        print(psd)
        head = _headEnd(dat[psd])
        print(head)
        pixels = list(range(head+1)) + list(range(resumePixel, lastPixel))
        dat[psd] = ("".join("%d.%d " % (psd, p) for p in pixels)
                    + "%d.%d\n" % (psd, lastPixel))

    # The original called g.close without parentheses, so the file was never
    # explicitly closed; the context manager guarantees flush + close.
    with open(output, 'w') as g:
        for line in dat:
            g.write(line)
            print(line)
