#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
uGao package
Interface for the DetectorViewer

---- Original version
Emaki package
Detector view frame
Author: ITO, Takayoshi (CROSS)
Last update: 11th July, 2013
"""

# Module metadata: author, version string and release date of this module.
__author__ = "ITO, Takayoshi (CROSS)"
__version__ = "23.08.00"
__date__ = "9th Aug. 2023"

import functools
import os
import sys
import time
# The ETS_TOOLKIT environment variable must be set BEFORE any Enthought
# package (traits, traitsui, mayavi, tvtk) is imported, so that Traits uses
# the Qt backend.
# Described in https://docs.enthought.com/mayavi/mayavi/auto/example_qt_embedding.html

# TraitsUI 7.0 renamed the Qt toolkit option from "qt4" to "qt".
# Before TraitsUI 7.0, "qt4" was the only accepted value. From TraitsUI 7.0 on,
# "qt" can also be used, which allows Qt5 to be used with TraitsUI.
# "qt4" is used here.
os.environ['ETS_TOOLKIT'] = 'qt4'

import numpy as np
from traits.api import HasTraits, Instance, on_trait_change
from traitsui.api import View, Item, Group
from mayavi.core.ui.api import MlabSceneModel, SceneEditor
from mayavi.core.ui.mayavi_scene import MayaviScene
from tvtk.pyface.api import Scene
#import Manyo
import Manyo.MLF as mm
import multiprocessing

from uGao.Calculator import PutTwoTheta, RotateWXYZ

# Decorator that prints the processing time of the wrapped function to stdout.
def measure_time(func):
    """
    Decorate `func` so that every call prints its processing time.

    functools.wraps copies __name__, __doc__ etc. from `func` to the wrapper,
    so decorated methods keep their identity for introspection.
    time.perf_counter is used because it is monotonic and has the highest
    available resolution for measuring short durations.

    @param func (callable) function to be timed
    @return (callable) wrapper with the same call signature and return value
    """
    @functools.wraps(func)
    def _measure_time(*argv, **keywords):
        ini = time.perf_counter()
        result = func(*argv, **keywords)
        fin = time.perf_counter()
        print("Processing time for {0}: {1:.4f} sec.".format(func.__name__, fin-ini))
        return result
    return _measure_time

#
# DetectorViewInterface
#
class DetectorViewInterface(HasTraits):
    """
    Interface for the DetectorViewer

    Builds one Mayavi points3d glyph module per detector bank described in the
    DetectorInfo XML file. The WiringInfo XML file, if given, supplies the
    number of pixels per detector. Provides helpers to set per-pixel scalars.

    Attributes:
        bank_info: list of [bank name, bank ID (int), list of detector IDs],
            one entry per non-empty bank in DetectorInfo.
        bank_configurations: list of dicts returned by PutBankConfiguration
            (coordinates, orientation vectors, angles, ...), one per drawn bank.
        bank_modules: list of glyph modules, index-aligned with
            bank_configurations. Banks whose configuration is empty are not
            drawn, so this may be shorter than bank_info.
        bank_module_ids: bank IDs of bank_modules, index-aligned with it.
    """
    scene = Instance(MlabSceneModel, ())
    view = View(Item('scene', editor=SceneEditor(scene_class=MayaviScene), resizable=True, show_label=False), resizable=True)
    # Alternative scene class:
    #view = View(Item('scene', editor=SceneEditor(scene_class=Scene), resizable=True, show_label=False), resizable=True)

    ########################################
    def __init__(self, detectorInfoFilepath, wiringInfoFilepath=None):
        """
        Load the XML info files and build a glyph module for every bank.

        @param detectorInfoFilepath (str) path to the DetectorInfo XML file
        @param wiringInfoFilepath (str|None) path to the WiringInfo XML file.
               When it is None or cannot be loaded, every detector is
               assumed to have 100 pixels.
        """
        HasTraits.__init__(self)
        pool = multiprocessing.Pool()
        # The pool is only needed while the banks are built. Always release
        # its worker processes, even if building a bank raises.
        try:
            self.mx = mm.BoostXmlParser()
            # hasInfo[0]: DetectorInfo loaded, hasInfo[1]: WiringInfo loaded (1 or 0).
            self.hasInfo = [+1, +1]
            if not self.mx.Load("detector", detectorInfoFilepath):
                self.hasInfo[0] = 0
            # Do not hand None to the XML parser when no WiringInfo is given.
            if wiringInfoFilepath is None or not self.mx.Load("wiring", wiringInfoFilepath):
                self.hasInfo[1] = 0

            self.inst_code = self.mx.putTextContent("detector", "detectorInfo", "inst")
            self.bank_info = self.PutBankInfo()  # list of [bank name, bank ID, list of detector IDs]
            self.bank_modules = []  # list of glyph modules
            self.bank_module_ids = []  # bank ID of each entry of bank_modules
            self.bank_configurations = []
            for info in self.bank_info:
                bank_config = self.PutBankConfiguration(info)
                if not bank_config:
                    continue
                self.bank_configurations.append(bank_config)
                self.bank_modules.append(self._MakeBankModule(bank_config, pool))
                self.bank_module_ids.append(info[1])
        finally:
            pool.close()
            pool.join()
        print("End: bank_info loop")

    ########################################
    def _MakeBankModule(self, bank_config, pool):
        """
        Draw one bank as a points3d '2dsquare' glyph module and return it.

        The pixel coordinates are first transformed so that the bank lies in
        the XY plane, with the pixel spacing and the detector width normalised
        to 1, because a '2dsquare' glyph is always square. The actor transform
        then undoes the scaling, the rotations and the translation.

        @param bank_config (dict) a configuration from PutBankConfiguration
        @param pool (multiprocessing.Pool) pool used to rotate the points
        @return glyph module created by mlab.points3d
        """
        ini = time.time()
        coordinates = bank_config["coordinates"].copy()
        # Translate every point so that the bank origin moves to (0, 0, 0).
        coordinates -= bank_config["origin"].reshape(3, -1)

        # Rotate the bank normal onto the Z axis; ps_new has shape (n, 3).
        if bank_config["n_vector"][0] == 0. and bank_config["n_vector"][1] == 0.:
            ps_new = coordinates.T
        else:
            args = [[p, bank_config["theta"], bank_config["r_vector"]] for p in coordinates.T]
            ps_new = np.array(pool.map(RotateWXYZ, args))

        # Rotate about the Z axis by phi (the angle between the rotated
        # u vector and the X axis).
        args = [[p, bank_config["phi"], [0., 0., 1.]] for p in ps_new]
        ps_new = np.array(pool.map(RotateWXYZ, args))
        xs, ys, zs = ps_new.T
        # Make the pixel spacing along the detector (X) and the detector width
        # (Y) equal to 1, because a '2dsquare' glyph can only be square.
        # The actor scale set below undoes this.
        xs /= bank_config["interval"]
        ys /= bank_config["width"]

        module = self.scene.mlab.points3d(xs, ys, zs, bank_config["scalars"],
                                          mode='2dsquare',
                                          scale_factor=1.0,
                                          scale_mode='none',)
        module.glyph.glyph_source.glyph_source.set(filled=True)
        actor = module.actor.actor
        # Restore the original orientation: first the rotation about Z ...
        actor.rotate_z(-bank_config["phi"])
        # ... then the rotation about r_vector.
        actor.rotate_wxyz(-bank_config["theta"], *bank_config["r_vector"])
        # Restore the aspect ratio of the bank.
        actor.scale = [bank_config["interval"], bank_config["width"], 1.]
        # Move the bank back to its original position.
        actor.add_position(*bank_config["origin"])
        module.actor.mapper.use_lookup_table_scalar_range = False
        module.actor.mapper.scalar_range = [0., 10.0]
        fin = time.time()
        print("---- Processing time to make bank_module {0:.4f} sec.".format(fin-ini))
        return module

    ########################################
    def SetPickerCallback(self, picker_callback):
        """
        Register `picker_callback` for mouse picks on cells of the scene.

        The picker tolerance is initialised to 0.001.
        """
        self.picker = self.scene.scene.mayavi_scene.on_mouse_pick(picker_callback, type="cell")
        self.picker.tolerance = 0.001

    ########################################
    def PutPickerTolerance(self):
        """Return the picker tolerance (SetPickerCallback must be called first)."""
        return self.picker.tolerance

    ########################################
    def SetPickerTolerance(self, tolerance):
        """Set the picker tolerance (SetPickerCallback must be called first)."""
        self.picker.tolerance = tolerance

    ########################################
    @staticmethod
    def _ParseIdList(items, label):
        """
        Expand detector-ID tokens such as ['3', '12-10'] into [3, 10, 11, 12].

        @param items (list of str) tokens; each is an integer or an inclusive
               range "a-b" (the bounds may be given in either order)
        @param label (str) XML element name used in the error message
        @return (list of int) detector IDs in token order
        @raise UserWarning if a token is neither an integer nor an a-b range
        """
        ids = []
        for item in items:
            try:
                bounds = sorted(int(s) for s in item.split('-'))
            except ValueError as err:
                raise UserWarning('Invalid "{0}" format.'.format(label)) from err
            if len(bounds) == 1:
                ids.append(bounds[0])
            elif len(bounds) == 2:
                ids.extend(range(bounds[0], bounds[1] + 1))
            else:
                raise UserWarning('Invalid "{0}" format.'.format(label))
        return ids

    ########################################
    def _PutNumPixel(self, detIds):
        """
        Return the number of pixels per detector for a bank from WiringInfo.

        For each detector ID of the bank, in order, the
        wiringInfo/psdBinInfo/positionBin entries are scanned in order. The
        "numPixel" attribute of the first entry that is 'ALL' or lists that
        detector ID is returned.

        @param detIds (list of int) detector IDs of the bank
        @return (int) pixels per detector; 100 when WiringInfo is not loaded
                or no entry matches
        @raise UserWarning if a visited positionBin entry is malformed
        """
        numPixel = 100
        if not self.hasInfo[1]:
            print("[!] No wiringInfo. Set num. pixel to 100.")
            return numPixel
        binPath = "wiringInfo/psdBinInfo/positionBin,i=%d"
        n_psdBinInfo = self.mx.PutNumOfElements("wiring", "wiringInfo/psdBinInfo/positionBin")
        # positionBin index -> None for an 'ALL' entry, else the set of its
        # detector IDs. Each entry is parsed once, the first time it is needed,
        # instead of once per detector.
        members = {}
        for detId in detIds:
            for i_positionBin in range(n_psdBinInfo):
                if i_positionBin not in members:
                    items = self.mx.putTextContent("wiring", binPath % i_positionBin).split(',')
                    if items[0].strip().upper() == 'ALL':
                        members[i_positionBin] = None
                    else:
                        members[i_positionBin] = set(self._ParseIdList(items, 'positionBin'))
                detIdSet = members[i_positionBin]
                if detIdSet is None or detId in detIdSet:
                    # Read numPixel from the entry that matched (also for 'ALL').
                    return int(self.mx.putTextContent("wiring", binPath % i_positionBin, "numPixel"))
        print("[!] Not found in banks. Set num. pixel to 100.")
        return numPixel

    ########################################
    @measure_time
    def PutBankConfiguration(self, info):
        """
        Compute the coordinates and orientation of one detector bank from DetectorInfo.

        Pixel positions are generated for every detector of the bank. The
        vectors and angles needed to rotate the bank into the XY plane for
        drawing (and back again) are derived from them.

        @param info (list) [bank name, bank ID, list of detector IDs]
        @return (dict) bank configuration, or an empty dict when the bank
                cannot be drawn (a detector has a zero u vector, or no detector
                has valid position info). Keys:
                coordinates   (3, n) array of pixel positions
                scalars       (n,) array of initial pixel values (all 0.0)
                width         detector width (diameter for a PSD) of the last valid detector
                interval      pixel spacing along a detector, |U| / n_pxl_per_det
                u_scalar      |U| of the last valid detector
                u_vector      pixel direction of the first valid detector
                v_vector      direction along which the detectors are lined up
                n_vector      bank normal, u_vector x v_vector
                r_vector      rotation axis, n_vector x Z
                theta         angle between n_vector and Z [deg]
                phi           angle between the rotated u_vector and X [deg]
                origin        first pixel of the bank, shape (3,)
                det_id_list   IDs of the detectors with valid position info
                n_pxl_per_det pixels per detector
                n_det         number of detectors drawn
                num           total number of pixels
        """
        xs = []
        ys = []
        zs = []
        ss = []
        detIdList = []
        template = 'detectorInfo/positionInfo/position,detId=%d'
        numPixel = self._PutNumPixel(info[2])
        # Fractional pixel positions along a detector: 0, 1/n, ..., (n-1)/n.
        fractions = np.arange(numPixel) / numPixel

        u_vector = None
        for detId in info[2]:
            content = self.mx.putTextContent("detector", template % detId)
            try:
                nums = [float(s) for s in content.split(',')]
                ## width: detector width; the diameter for a PSD.
                pox, poy, poz, ux, uy, uz, l, width = nums
            except ValueError:
                # Either a non-numeric field or the wrong number of fields.
                print(content)
                print("[!] No info.: {0}".format(detId))
                continue
            o = np.array([pox, poy, poz])
            u = np.array([ux, uy, uz])
            if u_vector is None:
                ## u_vector: direction in which the pixels of a detector are lined
                ## up, taken from the first detector with valid position info.
                u_vector = u
            u_scalar = np.sqrt(ux**2 + uy**2 + uz**2)
            if u_scalar == 0.0:
                return dict()
            p_zero = o - l/u_scalar*u  # p_zero: coordinate of this detector's first pixel.
            pixels = p_zero + fractions[:, np.newaxis]*u  # shape (numPixel, 3)
            xs.extend(pixels[:, 0])
            ys.extend(pixels[:, 1])
            zs.extend(pixels[:, 2])
            ss.extend([0.0] * numPixel)
            detIdList.append(detId)

        if not xs:
            print("[!] No valid detector in bank {0}.".format(info[0]))
            return dict()

        ## v_vector: direction in which the first pixels of the detectors are
        ## lined up (first pixel of the bank -> first pixel of the last detector).
        v_vector = p_zero - np.array([xs[0], ys[0], zs[0]])
        ## n_vector: cross product of u_vector and v_vector, used as the bank normal.
        n_vector = self.PutCross(u_vector, v_vector)
        ## r_vector: cross product of n_vector and the Z axis; the rotation axis
        ## that turns n_vector onto Z.
        r_vector = self.PutCross(n_vector, [0., 0., 1.])
        ## th: angle between n_vector and the Z axis in degrees. A bank with a
        ## single detector has no defined normal (n_vector == 0), so use 0
        ## instead of NaN. Clip to guard arccos against rounding just past +-1.
        n_norm = np.sqrt(np.sum(n_vector**2))
        if n_norm > 0.:
            th = np.degrees(np.arccos(np.clip(n_vector[2]/n_norm, -1., 1.)))
        else:
            th = 0.0

        xs, ys, zs, ss = [np.array(a) for a in (xs, ys, zs, ss)]
        coordinates = np.array([xs, ys, zs])  # shape (3, n)
        bank_origin = np.array([xs[0], ys[0], zs[0]])  # first pixel of the bank, shape (3,)
        ## Apply the rotation that turns the bank normal onto Z to U, then
        ## take the angle between the rotated U and the X axis.
        if n_vector[0] == 0. and n_vector[1] == 0.:
            u_vector_new = u_vector.copy()
        else:
            u_vector_new = RotateWXYZ([u_vector, th, r_vector])
        u_norm = np.sqrt(np.sum(u_vector_new**2))
        phi = np.degrees(np.arccos(np.clip(u_vector_new[0]/u_norm, -1., 1.)))

        print(coordinates.T[0])
        bank_config = {"coordinates": coordinates,
                "scalars": ss,
                "width": width,
                "interval": u_scalar/numPixel,
                "u_scalar": u_scalar,
                "u_vector": u_vector,
                "v_vector": v_vector,
                "n_vector": n_vector,
                "r_vector": r_vector,
                "theta": th,
                "phi": phi,
                "origin": bank_origin,
                "det_id_list": detIdList,
                "n_pxl_per_det": numPixel,
                "n_det": len(ss)//numPixel,
                "num": len(ss)}
        print("---- {0}, bankId: {1}".format(info[0], info[1]))
        print("u: psd direction {0}".format(u_vector))
        print("v: psd alignment {0}".format(v_vector))
        print("n: normal direction {0}".format(n_vector))
        print("r: rotation axis {0}".format(r_vector))
        print("|U|: {0}".format(u_scalar))
        print("interval {0}".format(u_scalar/numPixel))
        print("width {0}".format(width))
        print("rotated U: {0}".format(u_vector_new))
        print("theta: {0}, phi: {1}".format(th, phi))
        print("Num. pxl: bank: {0}, detector: {1}, Num. det: {2}".format(len(ss), numPixel, len(ss)//numPixel))
        print("----")
        return bank_config

    ########################################
    @measure_time
    def PutCross(self, a, b):
        """
        Return the cross product a x b as a numpy array.

        @param a (array-like of 3 numbers)
        @param b (array-like of 3 numbers)
        """
        # np.cross does not modify its inputs, so no defensive copy is needed.
        return np.cross(np.asarray(a), np.asarray(b))

    ########################################
    @measure_time
    def PutBankInfo(self):
        """
        Create and return bank_info from DetectorInfo.

        Reads detectorInfo/bankInfo/bank,i=0..n-1. The text of each bank is a
        comma-separated list of detector IDs and inclusive ranges "a-b".
        "-1" entries are ignored, and banks with empty text are skipped.

        @return (list) [[bank name, bank ID (int), [detector IDs]], ...]
        @raise UserWarning if the bankInfo "n" attribute or a bank's ID list
               is malformed
        """
        try:
            numBank = int(self.mx.putTextContent('detector', 'detectorInfo/bankInfo', 'n'))
        except (TypeError, ValueError) as err:
            raise UserWarning('Invalid "BankInfo" format.') from err
        template = 'detectorInfo/bankInfo/bank,i=%d'
        bank_info = []
        for i in range(numBank):
            path = template % i
            content = self.mx.putTextContent('detector', path)
            name = self.mx.putTextContent('detector', path, 'name')
            bankId = self.mx.putTextContent('detector', path, 'bankId')
            print("{0}, {1}, {2}".format(name, bankId, content))
            content = content.strip()
            if content == '':
                continue
            items = [s.strip() for s in content.split(',') if s.strip() != "-1"]
            bank = self._ParseIdList(items, 'BankInfo')
            bank_info.append([name, int(bankId), bank])
        return bank_info

    ########################################
    @on_trait_change('scene.activated')
    def Preprocess(self):
        """
        Set up the scene once it has been activated.

        Adds a vertical "Intensity" colorbar, an axes marker at the sample
        position (0, 0, 0), and toggles the orientation-axes indicator.
        """
        self.cb = self.scene.mlab.colorbar(title="Intensity", orientation='vertical')
        self.cb.scalar_bar_representation.position = [0.01, 0.1]
        self.cb.scalar_bar_representation.position2 = [0.05, 0.4]
        # Marker at the sample position.
        self.sample = self.scene.mlab.points3d([0],[0],[0],mode='axes',scale_factor=100.0)
        # Orientation-axes indicator (presumably hidden by default, so this shows it).
        self.scene.scene_editor.show_axes = not self.scene.scene_editor.show_axes

    ########################################
    def SetScalarsToBankByBankId(self, bankId, data):
        """
        Set the per-pixel scalars of the drawn bank whose bank ID is `bankId`.

        Looks the ID up in bank_module_ids, which stays index-aligned with
        bank_modules even when some banks of bank_info were not drawn.

        @param bankId (int) bank ID as in bank_info
        @param data (numpy.array|list) scalar value per pixel
        """
        try:
            bankIndex = self.bank_module_ids.index(bankId)
        except ValueError:
            print("[!] No matching bank ID. ID = {0}.".format(bankId))
            return
        self.SetScalarsToBank(bankIndex, data)

    ########################################
    def SetScalarsToBank(self, bankIndex, data):
        """
        Set the per-pixel scalars of one drawn bank.

        When `data` is shorter than the number of pixels of the bank, the
        missing trailing values are filled with 0. When it is longer, nothing
        is set and a message is printed.

        @param bankIndex (int) index into bank_modules
        @param data (numpy.array|list|tuple) scalar value per pixel
        """
        source = self.bank_modules[bankIndex].mlab_source
        n_scalars = len(source.scalars)
        if len(data) == n_scalars:
            source.scalars = data
        elif len(data) < n_scalars:
            print("[!] Fill excess elements with 0.")
            values = np.asarray(data).ravel()
            padded = np.zeros(n_scalars, dtype=values.dtype)
            count = min(values.size, n_scalars)
            padded[:count] = values[:count]
            source.scalars = padded
        else:
            print("Data length does not match.")

    ########################################
    def PutInstCode(self):
        """Return the "inst" attribute of detectorInfo read in __init__."""
        return self.inst_code

    ########################################
    def RemoveSourceTest(self, bankIndex):
        """
        Debug helper: remove the glyph module of bank `bankIndex` from the scene.

        @param bankIndex (int) index into bank_modules
        """
        print(dir(self.scene.scene.scene_editor))
        print(self.bank_modules)
        module = self.bank_modules[bankIndex]
        module.glyph.glyph_source.remove()
        module.remove()
        del self.bank_modules[bankIndex]
        # Keep the ID list index-aligned with bank_modules.
        del self.bank_module_ids[bankIndex]
        print(self.bank_modules)
