#!/usr/bin/python3
# -*- coding: utf-8 -*-

import Cmm
import Manyo
import Manyo.MLF as mm

def RadialCollimatorCorrect(dat,vdat,integRange="All",pixelRange="All"):
    """
    Data Correction for radial collimator on SIK

    Each detector i of dat is multiplied in place by
    <mean vanadium value> / <vanadium value of detector i>.
    The vanadium value of a detector is taken from the average of its unmasked
    pixels (header "MASKED" == 0) inside pixelRange, summed over integRange and
    divided by the number of those pixels. Detectors whose pixels are all masked
    or whose value is zero are excluded from the mean and left unchanged.
    Invalid arguments, a size mismatch between dat and vdat, or the absence of
    any usable vanadium detector raise UserWarning.

    @param dat[Def:DAT] (ElementContainerMatrix) Target data
    @param vdat[Def:DAT] (ElementContainerMatrix) Vanadium data
    @param integRange (string) Integrate range "All" or "<start>:<end>"
    @param pixelRange (string) Pixel range to be used "All" or "<start>:<end>"
    @retval None
    """
    if dat.PutSize()!=vdat.PutSize():
        raise UserWarning("ERROR : Data sizes are different between dat and vdat.")
    num_det = vdat.PutSize()
    if num_det == 0:
        # Nothing to correct, and dat(0) below would be out of range.
        raise UserWarning("ERROR : dat and vdat contain no data.")

    # --- integration range: None means "sum over the whole x range" ---
    x_range = None
    v_integRange = integRange.split(":")
    if len(v_integRange)==2:
        try:
            x_range = [float(v_integRange[0]), float(v_integRange[1])]
        except ValueError:
            raise UserWarning("ERROR : invalid integRange argument") from None
    elif integRange.upper()!="ALL":
        raise UserWarning("ERROR : invalid integRange argument")

    # --- pixel range, inclusive on both ends ---
    # NOTE(review): the pixel count is taken from dat(0) but used to index vdat(i,j);
    # assumes every vdat detector has at least that many pixels — confirm.
    num_pixels = dat(0).PutSize()
    start_pix = 0
    end_pix = num_pixels-1
    if pixelRange.upper()!="ALL":
        v_pixel = pixelRange.split(":")
        if len(v_pixel)!=2:
            raise UserWarning("ERROR : invalid pixelRange argument")
        try:
            start_pix = int(v_pixel[0])
            end_pix = int(v_pixel[1])
        except ValueError:
            raise UserWarning("ERROR : invalid pixelRange argument") from None
        if start_pix>end_pix:
            start_pix, end_pix = end_pix, start_pix
        end_pix = min(end_pix, num_pixels-1)
        if start_pix<0 or start_pix>end_pix:
            raise UserWarning("ERROR : invalid pixelRange argument")

    # --- per-detector vanadium value; 0.0 marks a detector excluded from the correction ---
    list_vals = []
    for i in range(num_det):
        num_of_ec = 0
        det_v = Manyo.MakeUInt4Vector()
        pix_v = Manyo.MakeUInt4Vector()
        for j in range(start_pix,end_pix+1):
            # Only unmasked pixels contribute to the vanadium value.
            if vdat(i,j).PutHeaderPointer().PutInt4("MASKED")==0:
                det_v.append(i)
                pix_v.append(j)
                num_of_ec += 1
        if num_of_ec==0:
            # Every pixel in range is masked: skip the (empty) average entirely.
            list_vals.append(0.0)
            continue
        averaged = mm.AverageElementContainerMatrix(vdat,det_v,pix_v)
        ec = averaged.GetSum()
        if x_range is None:
            val = ec.Sum()
        else:
            val = ec.Sum(x_range[0],x_range[1]).first
        if val==0.0:
            list_vals.append(0.0)
            continue
        # NOTE(review): confirm whether GetSum() is already per-pixel averaged;
        # if so, this division normalizes twice.
        list_vals.append(val/float(num_of_ec))

    valid_vals = [val for val in list_vals if val!=0]
    if not valid_vals:
        raise UserWarning("ERROR : no valid vanadium intensity (all pixels masked or zero).")
    ave_vals = sum(valid_vals)/float(len(valid_vals))

    # Scale each usable detector so its vanadium value matches the mean.
    for i,val in enumerate(list_vals):
        if val!=0:
            dat(i).MulMySelf(1.0/val*ave_vals)

