#!/usr/bin/python
# -*- coding: utf-8 -*-

import matplotlib
matplotlib.use('Agg')

import pylab
import numpy
try:
    import numpy.ma as ma
except:
    import numpy.core.ma as ma
try:
    from scipy import signal
except:
    pass
from matplotlib.colors import LogNorm
import Manyo.Utsusemi as mu
from vis.M2Plot import D2Matrix

class CuiD2Chart(object):
    """
    Render a 2-D map as a matplotlib pcolormesh image with a colour bar,
    optional log colour scale and optional Gaussian smoothing.

    Typical use::

        chart = CuiD2Chart(data, 0.0, 50.0, ["Main", "a;b", "X", "Y"])
        chart.SetSmooth(True, 1.0)
        chart.MakePlot(pngFile="out.png")
    """

    # Default [main title, sub title, x-axis label, y-axis label].
    _DEFAULT_TITLES = ("MainTitle", "", "X", "Y")

    def __init__(self, map=None, zMin=0, zMax=1, Titles=None):
        """
        @param map     A list/tuple (arData, erData, XX, YY) of arrays, or a
                       Manyo.ElementContainerArray / ElementContainerMatrix
                       (for a matrix, its first array is used), which is
                       converted with vis.M2Plot.D2Matrix.ReadMatrix.
                       arData is plotted on the XX/YY mesh; erData is stored
                       but not used here (presumably errors - confirm).
        @param zMin    Initial lower limit of the colour range.
        @param zMax    Initial upper limit of the colour range.
        @param Titles  [main title, sub title, x label, y label]. The list is
                       copied, so SetTitles/SetAxisLabel never modify the
                       caller's list or the shared defaults.
        @raise UserWarning  if map cannot be turned into a list/tuple.
        """
        # Manyo is only needed to convert container objects; plain
        # (arData, erData, XX, YY) input works without it.
        if not isinstance(map, (list, tuple)):
            import Manyo
            if isinstance(map, Manyo.ElementContainerArray):
                map = D2Matrix().ReadMatrix(map)
            elif isinstance(map, Manyo.ElementContainerMatrix):
                map = D2Matrix().ReadMatrix(map.Put(0))

        if not isinstance(map, (list, tuple)):
            print( type(map) )
            raise UserWarning("CuiD2Chart >> invalid 1st argument ")

        self.arData = map[0]
        self.erData = map[1]
        self.XX = map[2]
        self.YY = map[3]

        if Titles is None:
            Titles = self._DEFAULT_TITLES
        self.Titles = list(Titles)

        # Cells whose value is >= MASKVALUE-1 are treated as masked. Use the
        # Manyo.MLF constant when that module is available.
        self.MASKVALUE = 1.0E+15
        try:
            import Manyo.MLF
            self.MASKVALUE = Manyo.MLF.MLF_MASKVALUE
        except (ImportError, AttributeError):
            pass
        self.ZZ = ma.masked_where(self.arData >= self.MASKVALUE - 1.0, self.arData)

        self.xrange = (self.XX.min(), self.XX.max())
        self.yrange = (self.YY.min(), self.YY.max())
        self.zrange = (zMin, zMax)

        self.isLog = False
        self.isSmooth = False
        self.isSmooth_win = 1.0

        self._initial()

    def _initial(self):
        """
        Set layout parameters (self.target) and the colour-map segment table
        (self.cdict) used by MakePlot.
        """
        sModeDic = {}
        sModeDic["MainPosi"] = [0.13, 0.105, 0.75, 0.75]      # main axes [left, bottom, w, h]
        sModeDic["ColorBarPosi"] = [0.90, 0.15, 0.05, 0.62]   # colour-bar axes
        sModeDic["ColorBarOri"] = 'vertical'
        sModeDic["TickLabelSize"] = 10
        sModeDic["LabelSize"] = 12

        self.target = sModeDic

        # Original colour map: segment data for LinearSegmentedColormap.
        self.cdict = {'blue': ((0.0, 0.0, 0.0),
                               (0.16, 1, 1),
                               (0.377, 1, 1),
                               (0.67, 0, 0),
                               (1, 0, 0)),
                      'green': ((0.0, 0, 0),
                                (0.174, 0, 0),
                                (0.410, 1, 1),
                                (0.66, 1, 1),
                                (0.915, 0, 0),
                                (1, 0, 0)),
                      'red': ((0.0, 0, 0),
                              (0.387, 0, 0),
                              (0.680, 1, 1),
                              (0.896, 1, 1),
                              (1, 0.5, 0.5))
                      }

    def SetTitles(self, mainT, subT=""):
        """Set the main title and the sub title (';' = column, '\\n' = line)."""
        self.Titles[0] = str(mainT)
        self.Titles[1] = str(subT)

    def SetAxisLabel(self, xlabel, ylabel):
        """Set the x- and y-axis labels."""
        self.Titles[2] = str(xlabel)
        self.Titles[3] = str(ylabel)

    def SetSmooth(self, flag, window=1.0):
        """
        Enable/disable smoothing. The kernel half-width is int(abs(window));
        a half-width of 0 means no smoothing.
        @raise UserWarning  if flag is not a bool.
        """
        if isinstance(flag, bool):
            self.isSmooth = flag
            self.isSmooth_win = float(window)
        else:
            raise UserWarning( "Invalid Arguments" )

    def SetLog(self, flag):
        """
        Select a logarithmic (True) or linear (False) colour scale.
        @raise UserWarning  if flag is not a bool.
        """
        if isinstance(flag, bool):
            self.isLog = flag
        else:
            raise UserWarning( "Invalid Arguments" )

    def MakePlot(self, zrange=None, pngFile=""):
        """
        Draw the map into a new figure, stored as self.fig.

        @param zrange   (zmin, zmax) colour range. When it is omitted or has
                        fewer than 2 elements, the min/max of the plotted data
                        is used. The range actually used is stored in
                        self.zrange.
        @param pngFile  If not empty, the figure is saved to this path.
        """
        # Close the figure from a previous call so that repeated plotting
        # does not accumulate open pyplot figures.
        previous = getattr(self, "fig", None)
        if previous is not None:
            pylab.close(previous)

        fig = pylab.figure(figsize=(6.0, 6.4))

        # Main title
        fig.text(0.1, 0.97, self.Titles[0], fontsize=10)

        # Sub title: ';' separates columns laid out left to right, and '\n'
        # separates lines within a column.
        colms = self.Titles[1].split(';')
        colm_offset = float(int(900 / len(colms))) / 1000.0
        xxt = 0.1
        for colm in colms:
            yyt = 0.945
            for line in colm.split('\n'):
                fig.text(xxt, yyt, line, fontsize=9)
                yyt = yyt - 0.025
            xxt = xxt + colm_offset

        # Axes for the map and for the colour bar.
        ax = fig.add_subplot(111)
        ax.set_position(self.target["MainPosi"])
        axc = fig.add_axes(self.target["ColorBarPosi"])

        # Builds self.Zm (smoothed when enabled).
        self.DoSmooth()

        cmap = matplotlib.colors.LinearSegmentedColormap('default_utsusemi', self.cdict, 256)
        cmap.set_bad('w', 1.0)   # masked cells are drawn white

        if zrange is not None and len(zrange) >= 2:
            self.zrange = (zrange[0], zrange[1])
        else:
            self.zrange = (numpy.min(self.Zm), numpy.max(self.Zm))

        if self.isLog:
            norm = LogNorm(vmin=self.zrange[0], vmax=self.zrange[1])
        else:
            norm = pylab.Normalize(vmin=self.zrange[0], vmax=self.zrange[1])
        pc = ax.pcolormesh(self.XX, self.YY, self.Zm, cmap=cmap, norm=norm)

        # Colour bar linked to the mesh.
        fig.colorbar(pc, cax=axc, orientation=self.target["ColorBarOri"])
        # Axis ranges.
        ax.set_xlim(self.xrange[0], self.xrange[1])
        ax.set_ylim(self.yrange[0], self.yrange[1])
        # Small tick labels on the colour bar.
        pylab.setp(pylab.getp(axc, 'xticklabels'), size=6)
        pylab.setp(pylab.getp(axc, 'yticklabels'), size=6)

        # Tick-label size of the main axes.
        pylab.setp(ax.get_xticklabels(), size=self.target["TickLabelSize"])
        pylab.setp(ax.get_yticklabels(), size=self.target["TickLabelSize"])

        # Axis labels.
        ax.set_xlabel(self.Titles[2], fontsize=self.target["LabelSize"])
        ax.yaxis.tick_left()
        ax.set_ylabel(self.Titles[3], fontsize=self.target["LabelSize"])

        self.fig = fig

        if pngFile != "":
            self.SaveFile(pngFile)

    def SaveFile(self, filepath="test.png"):
        """
        Save the figure built by MakePlot to filepath via Figure.savefig.
        (The former default "test.pnt" has an extension matplotlib does not
        support.)
        """
        self.fig.savefig(filepath)

    def DoSmooth(self):
        """
        Build self.Zm, the masked 2-D array that MakePlot draws.
        ##[inamura 170306]

        When smoothing is enabled and the half-width h = int(abs(window)) is
        at least 1, the data are convolved with gauss_kern(h):
          * Each masked cell is first replaced by the mean of the unmasked
            cells within 2*h rows/columns of it, so mask values cannot leak
            into valid neighbours.
          * The array is then padded by 2*h cells on every side. Each pad cell
            takes the mean of the two outermost data rows/columns, so edges
            are not dragged toward zero.
        Masked cells stay masked in the result. Otherwise self.Zm is just the
        masked input.
        """
        threshold = self.MASKVALUE - 1.0
        half = int(abs(self.isSmooth_win))
        if not (self.isSmooth and half >= 1 and isinstance(self.ZZ, numpy.ndarray)):
            self.Zm = ma.masked_where(self.ZZ >= threshold, self.ZZ)
            return

        # Work on a plain float copy of the values: assigning masked elements
        # into a float array would otherwise turn them into NaN.
        raw = numpy.array(ma.getdata(self.ZZ), dtype=float)
        bad = (raw >= threshold) | ma.getmaskarray(self.ZZ)
        numZx, numZy = raw.shape

        # Reach 2*h: a cell left unfilled has no valid cell within 2*h, so no
        # valid output (kernel radius h) ever sees it.
        filled = self._fill_masked(raw, bad, 2 * half)
        padded = self._pad_edges(filled, 2 * half)

        kernel = self.gauss_kern(half, None)
        smoothed = signal.convolve(padded, kernel, mode='valid')
        # 'valid' trims h cells per side; drop the remaining h of padding.
        out = smoothed[half:half + numZx, half:half + numZy]
        self.Zm = ma.masked_where(bad, out)

    @staticmethod
    def _fill_masked(raw, bad, reach):
        """Return a copy of raw where each bad cell holds the mean of the good
        cells within `reach` rows/columns of it (0.0 if there are none)."""
        filled = raw.copy()
        good = ~bad
        for i, j in zip(*numpy.nonzero(bad)):
            rows = slice(max(i - reach, 0), i + reach + 1)
            cols = slice(max(j - reach, 0), j + reach + 1)
            window_good = good[rows, cols]
            if window_good.any():
                filled[i, j] = raw[rows, cols][window_good].mean()
            else:
                # No valid neighbour: 0.0 instead of the huge mask value,
                # which FFT-based convolution would smear over the whole map.
                filled[i, j] = 0.0
        return filled

    @staticmethod
    def _pad_edges(data, width):
        """Return data padded by `width` cells on each side; every pad cell is
        the mean of the (up to) two outermost rows/columns next to it."""
        nx, ny = data.shape
        padded = numpy.zeros((nx + 2 * width, ny + 2 * width), float)
        padded[width:width + nx, width:width + ny] = data
        # Left/right pads of the data rows.
        padded[width:width + nx, :width] = data[:, :2].mean(axis=1)[:, numpy.newaxis]
        padded[width:width + nx, width + ny:] = data[:, -2:].mean(axis=1)[:, numpy.newaxis]
        # Top/bottom pads span the full width, so the corners are filled too.
        band = padded[width:width + nx, :]
        padded[:width, :] = band[:2].mean(axis=0)
        padded[width + nx:, :] = band[-2:].mean(axis=0)
        return padded

    def gauss_kern(self, size, sizey=None):
        """
        Return a normalized (sum == 1) kernel of shape
        (2*size+1, 2*sizey+1) with weights exp(-(x**2/size + y**2/sizey)).
        sizey defaults to size when None or 0. A half-width of 0 gives a
        kernel with no spread along that axis instead of NaN (0/0).
        """
        size = abs(int(size))
        if not sizey:
            sizey = size
        else:
            sizey = abs(int(sizey))
        x, y = numpy.mgrid[-size:size + 1, -sizey:sizey + 1]
        # When a half-width is 0 the corresponding coordinates are all 0,
        # so dividing by 1 instead of 0 leaves the weights unchanged.
        g = numpy.exp(-(x ** 2 / float(max(size, 1)) + y ** 2 / float(max(sizey, 1))))
        return g / g.sum()


if __name__ == '__main__':
    # Demo: reduce run 373, slice it with CuiCtrlVisContQ, then plot the
    # slice with CuiD2Chart and write it to test.png.
    import DR
    dat = DR.GetDataOfMonochroEi2(373, 45.56, 0.4, -4.0, 44.0, "mask.txt", 0.0)

    import ana.Reduction.CuiCtrlVisContQ as CCV
    cui = CCV.CuiCtrlVisContQ("Parameters.xml", type="OutToD4Mat")
    cui.SetECM(dat)
    cui.DoSlice()

    main_title = "Test title"
    sub_title = "AAA;bbb;ccc;DDD"
    x_label = "Momentum Transfer (1/A)"
    y_label = "Energy (meV)"
    cc = CuiD2Chart(cui.map, 0.0, 50.0, [main_title, sub_title, x_label, y_label])
    cc.MakePlot(pngFile="test.png")

    # Drop the references explicitly, in the same order as before.
    del dat, cui, cc

    
