#!/usr/bin/python3
# -*- coding: utf-8 -*-

from __future__ import print_function
import os
from xml.dom.minidom import parse

import Manyo as mm
import Manyo.Utsusemi as mu
import utsusemi.ana.Reduction.BaseCommands as BC
import utsusemi.ana.Reduction.UtilsOnChoppers as UTL
import numpy as np


class CorrectABC(object):
    """Re-derive the constants (A, B, C) of the relation R(x) = A / (x + C) - B.

    Three positions obtained with the original constants are converted back
    to R values with that relation.  New constants are then solved so that
    the same three R values map onto the three ideal positions, i.e.
    R_i = newA / (p_i + newC) - newB for i = 1, 2, 3.
    """

    def __init__(self):
        # ideal position
        self.p1 = self.p2 = self.p3 = 0.0
        # original A,B,C
        self.Ain = self.Bin = self.Cin = 0.0
        # actual results
        self.p1in = self.p2in = self.p3in = 0.0

    def SetParam(self, ABC=[0.0, 0.0, 0.0], Pin=[0.0, 0.0, 0.0], Pidel=[0.0, 0.0, 0.0]):
        """Store the inputs of the correction, converting every value to float.

        @param ABC   (sequence) original constants [A, B, C]
        @param Pin   (sequence) three positions actually obtained with ABC
        @param Pidel (sequence) three ideal positions
        @retval None
        """
        # Attribute groups are filled in the order: ideal positions,
        # original constants, actual positions.
        groups = (
            (("p1", "p2", "p3"), Pidel),
            (("Ain", "Bin", "Cin"), ABC),
            (("p1in", "p2in", "p3in"), Pin),
        )
        for attr_names, values in groups:
            for idx, attr_name in enumerate(attr_names):
                setattr(self, attr_name, float(values[idx]))

    def _Calc_R(self, x):
        """Return A / (x + C) - B evaluated with the original constants."""
        shifted = x + self.Cin
        return self.Ain / shifted - self.Bin

    def DoIt(self):
        """Solve for the corrected constants and return them.

        Prints the ideal positions and the intermediate values while it runs.
        Raises ZeroDivisionError when any of the divisions below has a zero
        denominator (for example when two positions coincide).

        @retval [newA, newB, newC]
        """
        # R values corresponding to the three measured positions.
        r1, r2, r3 = [self._Calc_R(pos)
                      for pos in (self.p1in, self.p2in, self.p3in)]
        print("p1,p2,p3(idleal)=", self.p1, self.p2, self.p3)
        print("R1,R2,R3=", r1, r2, r3)

        # Ratios of consecutive differences, in R and in ideal position.
        ratio_r = (r2 - r3) / (r1 - r2)
        ratio_p = (self.p2 - self.p3) / (self.p1 - self.p2)
        print("RR,PP=", ratio_r, ratio_p)

        new_b = (r1 * ratio_r - r3 * ratio_p) / (ratio_p - ratio_r)
        new_a = (self.p1 - self.p2) * (r1 + new_b) * (r2 + new_b) / (r2 - r1)
        new_c = new_a / (r1 + new_b) - self.p1

        return [new_a, new_b, new_c]


class AutoCorrectPsdParams(object):
    def __init__(self, isUpConvex=True):
        """Create an empty corrector.

        @param isUpConvex (bool) True to search each peak window for a maximum,
                                 False to search for a minimum
        """
        # Directory holding the XML files: <inst dir>/ana/xml when the
        # instrument directory is set, otherwise <user dir>/ana/xml.
        base_dir = mu.UtsusemiEnvGetInstDir()
        if base_dir == "":
            base_dir = mu.UtsusemiEnvGetUserDir()
        self.PathToXml = os.path.join(base_dir, "ana", "xml")

        self.isUpConvex = isUpConvex
        self.CommentHead = "AutoCorrectPsdParam >>> "

        # Filled in by SetTarget / SetTargetSimple.
        self.data = None
        self.wiringFile_in = ""
        self.wiringFile_out = ""
        self.peakRange = []
        self.TOFrange = []

        # Solver for the corrected (A, B, C) constants.
        self.CABC = CorrectABC()
        # str(detId) -> [ideal position 1, 2, 3]; filled by ReadIdealPositions.
        self.IdealPosiDic = {}
        # detIds skipped by BeginAll; filled by SetIgnorePsd.
        self.IgnorePsd = []

    def SetTarget(self, dat=None, wiringFile_in="", wiringFile_out="", peakRange=[], TOFmin=0, TOFmax=40000):
        """Register the data and wiring files, resolved inside self.PathToXml.

        Prints a usage banner on every call.  On any invalid argument a message
        is printed and the method returns without changing the object.

        @param dat            (ElementContainerMatrix) data to analyse
        @param wiringFile_in  (str) input wiring file name, relative to PathToXml
        @param wiringFile_out (str) output wiring file name, relative to PathToXml
        @param peakRange      (list) three entries, one per peak
        @param TOFmin         TOF integration lower limit
        @param TOFmax         TOF integration upper limit
        @retval None
        """
        usage = (
            "SetTarget ( data, wFile_in, wFile_out, peakRange, TOFmin, TOFmax )",
            "data : ElementContainerMatrix",
            "wFile_in : initial wiring file",
            "wFile_out : output wiring file",
            "peakRange : [ [xmin_peak1, xmax_peak1], [xmin_peak2, xmax_peak2], [xmin_peak3, xmax_peak3] ]",
            "TOFmin, TOFmax : TOF integration region. defaults: 0, 40000",
        )
        for usage_line in usage:
            print(usage_line)

        if not isinstance(dat, mm.ElementContainerMatrix):
            print("data must be ElementContainerMatrix!")
            return

        names_are_str = isinstance(wiringFile_in, str) and isinstance(wiringFile_out, str)
        if not names_are_str:
            print("wirigin files are needed !")
            return

        src_path = os.path.join(self.PathToXml, wiringFile_in)
        dst_path = os.path.join(self.PathToXml, wiringFile_out)
        if not os.path.exists(src_path):
            print("Cannot find wiring file : ",
                  wiringFile_in, " in ", self.PathToXml)
            return
        if os.path.exists(dst_path):
            # An existing output file is only reported, not treated as an error.
            print("Already there is wiring file : ",
                  wiringFile_out, " in ", self.PathToXml)
            print("Overwrite it.")

        if not isinstance(peakRange, list):
            print("peakRange must be list!")
            return
        if len(peakRange) != 3:
            print("peakRange must have 3 lists!")
            return

        # Everything validated: commit the state.
        self.data = dat
        print("TYPE(dat)=", type(self.data))
        self.wiringFile_in = src_path
        self.wiringFile_out = dst_path
        self.peakRange = peakRange
        self.TOFrange = [TOFmin, TOFmax]

        # Editor opened on the input wiring file.
        self.EditDI = mu.WiringInfoEditorNeunet(self.wiringFile_in)

    def SetTargetSimple(self, dat=None, wiringFile_in="", wiringFile_out="", peakRange=[], TOFmin=0, TOFmax=40000):
        """Register the data and wiring files, using the file paths as given.

        Same checks as SetTarget, except that wiringFile_in / wiringFile_out
        are NOT joined with self.PathToXml.  On any invalid argument a message
        is printed and the method returns without changing the object.

        @param dat            (ElementContainerMatrix) data to analyse
        @param wiringFile_in  (str) path to the input wiring file
        @param wiringFile_out (str) path to the output wiring file
        @param peakRange      (list) three entries, one per peak
        @param TOFmin         TOF integration lower limit
        @param TOFmax         TOF integration upper limit
        @retval None
        """
        # The banner names this method (it used to say "SetTarget").
        print("SetTargetSimple ( data, wFile_in, wFile_out, peakRange, TOFmin, TOFmax )")
        print("data : ElementContainerMatrix")
        print("wFile_in : initial wiring file")
        print("wFile_out : output wiring file")
        print(
            "peakRange : [ [xmin_peak1, xmax_peak1], [xmin_peak2, xmax_peak2], [xmin_peak3, xmax_peak3] ]")
        print("TOFmin, TOFmax : TOF integration region. defaults: 0, 40000")
        if not isinstance(dat, mm.ElementContainerMatrix):
            print("data must be ElementContainerMatrix!")
            return
        if isinstance(wiringFile_in, str) and isinstance(wiringFile_out, str):
            fullpath_wFile_in = wiringFile_in
            fullpath_wFile_out = wiringFile_out
            # PathToXml is not consulted here, so the messages report only the
            # path that was actually checked.
            if not os.path.exists(fullpath_wFile_in):
                print("Cannot find wiring file : ", wiringFile_in)
                return
            if os.path.exists(fullpath_wFile_out):
                print("Already there is wiring file : ", wiringFile_out)
                print("Overwrite it.")
        else:
            print("wirigin files are needed !")
            return

        if not isinstance(peakRange, list):
            print("peakRange must be list!")
            return
        elif len(peakRange) != 3:
            print("peakRange must have 3 lists!")
            return

        self.data = dat
        print("TYPE(dat)=", type(self.data))
        self.wiringFile_in = fullpath_wFile_in
        self.wiringFile_out = fullpath_wFile_out
        self.peakRange = peakRange
        self.TOFrange = [TOFmin, TOFmax]

        # Editor opened on the input wiring file.
        self.EditDI = mu.WiringInfoEditorNeunet(self.wiringFile_in)

    def ReadIdealPositions(self, XMLfile):
        """Load ideal peak positions from an XML file located in self.PathToXml.

        Every <idealPosi detId="N">p1,p2,p3</idealPosi> element found under
        the first <psdCorrect> element is stored as
        self.IdealPosiDic[str(N)] = [p1, p2, p3] (floats).  Existing entries
        with the same key are replaced.  If the file does not exist a message
        is printed and nothing is changed.

        @param XMLfile (str) file name, joined with self.PathToXml
        @retval None
        """
        xml_path = os.path.join(self.PathToXml, XMLfile)
        if not os.path.exists(xml_path):
            print(self.CommentHead + "Can't read ", xml_path)
            return

        document = parse(xml_path)
        root_node = document.getElementsByTagName("psdCorrect")[0]
        for node in root_node.getElementsByTagName("idealPosi"):
            # Going through int() normalises the key (e.g. "007" -> "7").
            key = str(int(node.getAttribute("detId")))
            fields = node.firstChild.data.split(',')
            self.IdealPosiDic[key] = [float(fields[k]) for k in range(3)]

    def SetIgnorePsd(self, psd_list=[]):
        """Append detector IDs to the list of PSDs that BeginAll skips.

        Each entry is converted with int() before being appended to
        self.IgnorePsd; earlier entries are kept.  Anything that is not a
        list is rejected with a printed message.

        @param psd_list (list) detector IDs to ignore
        @retval None
        """
        if isinstance(psd_list, list):
            self.IgnorePsd.extend(int(entry) for entry in psd_list)
            return

        print(self.CommentHead + " psd_list must be list.")

    def BeginAll(self):
        """Correct the PSD parameters of every detector in self.data.

        For each element of self.data the detector ID is read from the
        "PSDID" header entry.  Detectors listed in self.IgnorePsd, and
        detectors whose stored parameters are all -1, are left untouched.
        For the others the three peaks are fitted, new (A, B, C) are solved
        with self.CABC and stored in the wiring editor.  The editor is
        finally written to self.wiringFile_out.

        A detector whose peak fit raises is reported ("No Fitting") and
        skipped, keeping its original parameters.

        Raises KeyError when a processed detector has no entry in
        self.IdealPosiDic.

        @retval None
        """
        num_of_psd = self.data.PutTableSize()
        for i in range(num_of_psd):
            detId = self.data(i).PutHeader().PutInt4("PSDID")
            print("detID = ", detId)
            if detId in self.IgnorePsd:
                continue

            v = self.EditDI.PutPsdParams(detId)
            (PA, PB, PC, LLD1, LLD2) = (v[0], v[1], v[2], v[3], v[4])
            if (PA, PB, PC, LLD1, LLD2) == (-1, -1, -1, -1, -1):
                # No parameters stored for this detector.
                continue

            EC = BC.SumOfTOF(self.data, detId,
                             self.TOFrange[0], self.TOFrange[1])
            try:
                (p0, p1, p2) = self._PeaksFit(EC, self.peakRange)
            except Exception:
                # Without a successful fit there are no peak positions for
                # this detector.  Falling through would either fail on
                # undefined names or silently reuse the previous detector's
                # peaks, so skip it.
                print("detID(%d):No Fitting" % (detId))
                continue

            [IdealP0, IdealP1, IdealP2] = self.IdealPosiDic[str(detId)]
            self.CABC.SetParam([PA, PB, PC], [p0, p1, p2], [
                               IdealP0, IdealP1, IdealP2])
            [newA, newB, newC] = self.CABC.DoIt()
            self.EditDI.SetPsdParams(detId, newA, newB, newC, LLD1, LLD2)
            print("detID(%d):[%10.5f, %10.5f, %10.5f] => [%10.5f, %10.5f, %10.5f]" % (
                detId, PA, PB, PC, newA, newB, newC))

        self.EditDI.Write(self.wiringFile_out)

    def BeginOne(self, detId, peakRange=[], ABC=[], fixList=[-1, -1, -1]):
        """Correct the PSD parameters of a single detector.

        The result is stored in the wiring editor only; this method does not
        write the output file.  Invalid arguments are reported with a printed
        message and the method returns without changing anything.

        @param detId     (int) detector ID
        @param peakRange (list) three entries, one per peak
        @param ABC       (list) starting [A, B, C]; when empty (the default)
                                or None the values stored in the wiring
                                editor are used
        @param fixList   (list) per-peak fixed position, -1 to fit that peak
        @retval None
        """
        v = self.EditDI.PutPsdParams(detId)
        (PA, PB, PC, LLD1, LLD2) = (v[0], v[1], v[2], v[3], v[4])
        # Override the stored constants only when the caller supplied some.
        # Indexing the empty default used to raise IndexError on every call
        # that omitted ABC.
        if ABC is not None and len(ABC) > 0:
            PA = ABC[0]
            PB = ABC[1]
            PC = ABC[2]
        if (PA, PB, PC, LLD1, LLD2) == (-1, -1, -1, -1, -1):
            print(self.CommentHead + " Invalid det-ID for PSD params.")
            return

        # Make sure the data actually contains this detector.
        SIH = mm.SearchInHeader(self.data)
        SIH.SearchArray("PSDID", detId)
        ret_vec = SIH.PutResultIndex(0)
        if ret_vec.size() == 0:
            print(self.CommentHead + " Invalid det-ID for histogram.")
            return

        if not isinstance(peakRange, list):
            print("peakRange must be list!")
            return
        elif len(peakRange) != 3:
            print("peakRange must have 3 lists!")
            return

        EC = BC.SumOfTOF(self.data, detId, self.TOFrange[0], self.TOFrange[1])
        (p0, p1, p2) = self._PeaksFit(EC, peakRange, fixList)
        [IdealP0, IdealP1, IdealP2] = self.IdealPosiDic[str(detId)]
        self.CABC.SetParam([PA, PB, PC], [p0, p1, p2],
                           [IdealP0, IdealP1, IdealP2])
        [newA, newB, newC] = self.CABC.DoIt()
        self.EditDI.SetPsdParams(detId, newA, newB, newC, LLD1, LLD2)
        print("detID(%d):[%10.5f, %10.5f, %10.5f] => [%10.5f, %10.5f, %10.5f]" % (
            detId, PA, PB, PC, newA, newB, newC))

    def _FitOnePeak(self, EC, a1_in, a2_in, a3_in, a4_in, a5_in):
        """Fit one histogram with a straight line minus a Gaussian.

        Model: a1 + a2 * x - a3 * exp(-0.5 * ((x - a4) / a5) ** 2), evaluated
        at the bin centres of EC's x vector against EC's y vector.

        NOTE(review): the Gaussian term is subtracted, so a positive a3
        describes a dip; confirm this is also intended for the isUpConvex
        case, where the caller passes a positive initial amplitude.

        @param EC    histogram container (PutX gives bin edges, PutY values)
        @param a1_in initial value of a1 (constant term)
        @param a2_in initial value of a2 (slope)
        @param a3_in initial value of a3 (Gaussian amplitude)
        @param a4_in initial value of a4 (Gaussian centre)
        @param a5_in initial value of a5 (Gaussian width)
        @retval (a1, a2, a3, a4, a5) parameter values after the fit call
        """
        fitter = UTL.FitOnChoppers()

        # Bin centres from consecutive bin edges.
        edges = EC.PutX()
        centers = np.array([(edges[k] + edges[k + 1]) / 2.0
                            for k in range(edges.size() - 1)])
        counts = np.array(EC.PutY())

        params = [fitter.Parameter(start)
                  for start in (a1_in, a2_in, a3_in, a4_in, a5_in)]
        base, slope, amplitude, center, sigma = params

        def fit_func_peaks(x):
            return base() + (slope() * x) - amplitude() * np.exp(-0.50 * ((x - center()) / sigma())**2)

        fitter.fit(fit_func_peaks, params, counts, centers)

        return tuple(param() for param in params)

    def _PeaksFit(self, EC, peakRangeList, fixList=[-1, -1, -1]):
        """Return the positions of the three peaks of one histogram.

        For each of the three peaks: if fixList[i] is not -1 that value is
        used as the position without fitting.  Otherwise the bins
        peakRangeList[i][0] .. peakRangeList[i][1] (inclusive indices) are
        copied into a new container, initial parameters are estimated from
        that window, and the fitted centre (4th value returned by
        _FitOnePeak) is used.

        @param EC            histogram container (PutX / PutY / PutE)
        @param peakRangeList (list) three [first_index, last_index] pairs
        @param fixList       (list) three fixed positions, -1 meaning "fit"
        @retval (pos0, pos1, pos2)
        """
        xx = EC.PutX()
        yy = EC.PutY()
        ee = EC.PutE()
        peak_posi_list = []
        for i in range(3):
            if fixList[i] != -1:
                peak_posi_list.append(fixList[i])
                continue

            # Convert the window bounds once so every use below agrees.
            first = int(peakRangeList[i][0])
            last = int(peakRangeList[i][1])

            EC_fit = mm.ElementContainer()
            x_fit = mm.MakeDoubleVector()
            y_fit = mm.MakeDoubleVector()
            e_fit = mm.MakeDoubleVector()

            # Extreme value inside the window (max when isUpConvex, else min).
            # The search starts from the first bin of the window so that the
            # initial centre stays inside the window even when the extreme is
            # that first bin or the window is flat (it used to default to 0).
            max_int = yy[first]
            max_p = first
            for p in range(first, last + 1):
                x_fit.append(xx[p])
                y_fit.append(yy[p])
                e_fit.append(ee[p])
                if self.isUpConvex:
                    if yy[p] > max_int:
                        max_int = yy[p]
                        max_p = p
                else:
                    if yy[p] < max_int:
                        max_int = yy[p]
                        max_p = p

            # One more x value than y values: closing edge of the last bin.
            x_fit.append(xx[last + 1])
            EC_fit.Add("x", x_fit)
            EC_fit.Add("y", y_fit)
            EC_fit.Add("e", e_fit)
            EC_fit.SetKeys("x", "y", "e")

            # Initial guesses taken from the two ends of the window.
            first_yy = yy[first]
            last_yy = yy[last]
            base_ini = min(first_yy, last_yy)
            slope_ini = float(last_yy - first_yy) / float(last - first)
            if self.isUpConvex:
                max_ini = max_int - (first_yy + last_yy) / 2.0
            else:
                max_ini = (first_yy + last_yy) / 2.0 - max_int

            width_ini = float(last - first) / 20.0

            # NOTE(review): max_p and width_ini are in bin-index units while
            # the fit runs on the x values of EC; presumably x is the pixel
            # number so the two coincide -- confirm.
            print("base_ini,slope_ini,max_ini,max_p,width_ini",
                  (base_ini, slope_ini, max_ini, max_p, width_ini))
            ret = self._FitOnePeak(
                EC_fit, base_ini, slope_ini, max_ini, max_p, width_ini)

            peak_posi_list.append(ret[3])

        return (peak_posi_list[0], peak_posi_list[1], peak_posi_list[2])
