#include "filter_mask.h"
#include "utils.h" // for TinyVector comparison

// Sets the descriptions of the two threshold parameters and appends them
// as filter arguments under the names "min" and "max".
void FilterGenMask::init(){
  min.set_description("lower threshold");
  append_arg(min,"min"); // "min" is appended before "max"
  max.set_description("upper threshold");
  append_arg(max,"max");
}

// Overwrites 'data' with a binary mask: an element becomes 1.0 if its value
// lies within [min,max] (both limits inclusive) and 0.0 otherwise.
// Every element is tested on its own, 'prot' is not used.
// Always returns true.
bool FilterGenMask::process(Data<float,4>& data, Protocol& prot) const {

/*
  Disabled variant: produces a mask with a single time step in which a voxel
  is only kept if the condition is met for all time steps.

  TinyVector<int,4> datashape=data.shape();
  TinyVector<int,4> maskshape=data.shape();
  maskshape(timeDim)=1;

  Data<float,4> mask(maskshape); mask=1.0;
  TinyVector<int,4> maskindex;
  TinyVector<int,4> dataindex;
  for(int i=0; i<mask.size(); i++) {
    maskindex=mask.create_index(i);
    dataindex=maskindex;
    for(int irep=0; irep<datashape(timeDim); irep++) {
      dataindex(timeDim)=irep;
      float val=data(dataindex);
      if(val<min || val>max) mask(maskindex)=0.0;
    }
  }

  data.reference(mask);
*/

  // Array view of the input used on the right-hand side of the element-wise assignment
  Array<float,4> values(data);

  data=where(values>=min && values<=max, float(1.0), float(0.0));

  return true;
}


///////////////////////////////////////////////////////////////////////////


// Overwrites 'data' with a binary mask using a threshold derived from its histogram:
//  1. the values are sorted into 100 equally wide bins of width max(data)/100
//  2. the histogram is scanned from the lowest bin upwards until a bin is
//     followed by a bin with a higher count; the lower edge of that bin is the
//     threshold (0.0 if the counts never rise)
//  3. elements above the threshold become 1.0, all others 0.0
// 'prot' is not used. Always returns true.
bool FilterAutoMask::process(Data<float,4>& data, Protocol& prot) const {
  Log<Filter> odinlog(c_label(),"process");

  const int nbins=100;

  float binwidth=secureDivision(max(data),nbins);

  ODINLOG(odinlog,normalDebug) << "nslots/step=" << nbins << "/" << binwidth << STD_endl;

  Data<float,1> counts(nbins);
  counts=0.0;

  // Fill the histogram, values which fall outside of the bin range are not counted
  const int nvalues=data.size();
  for(int ival=0; ival<nvalues; ival++) {
    float val=data(data.create_index(ival));
    int ibin=int(secureDivision(val,binwidth));
    ODINLOG(odinlog,normalDebug) << "val/slot=" << val << "/" << ibin << STD_endl;
    if(ibin<0 || ibin>=nbins) continue;
    counts(ibin)+=1.0f;
  }
//  counts.write_asc_file("hist.asc");

  // Advance while the histogram does not rise from one bin to the next
  int ifirst=0;
  while(ifirst<(nbins-1) && !(counts(ifirst+1)>counts(ifirst))) ifirst++;

  // Without any rise, the threshold remains zero
  float thresh=0.0;
  if(ifirst<(nbins-1)) thresh=ifirst*binwidth;

  ODINLOG(odinlog,normalDebug) << "thresh=" << thresh << STD_endl;

  Array<float,4> values(data);
  data=where(values>thresh, float(1.0), float(0.0));

  return true;
}


///////////////////////////////////////////////////////////////////////////


// List of 4-dimensional array indices
typedef STD_list<TinyVector<int,4> > QuantilIndexList;
// Associates a float key with a list of array indices
// (presumably STD_map is std::map, i.e. the keys are kept in ascending order - TODO confirm)
typedef STD_map<float, QuantilIndexList > QuantilIndexMap;


// Configures the 'fraction' parameter (min/max values 0.0 and 1.0, description "quantil")
// and appends it as filter argument under the name "fraction".
void FilterQuantilMask::init(){
  fraction.set_minmaxval(0.0,1.0).set_description("quantil");
  append_arg(fraction,"fraction");
}

// Replaces 'data' by a binary mask of the same shape in which the elements
// with the highest values are set to 1.0 and all others to 0.0.
// The target number of marked elements is (1-fraction)*size, rounded to the
// nearest integer. Elements are marked in groups of equal value, starting with
// the highest value, until the target is reached; since a group is always
// marked completely, the final count can exceed the target.
// 'prot' is not used. Always returns true.
bool FilterQuantilMask::process(Data<float,4>& data, Protocol& prot) const {
  Log<Filter> odinlog(c_label(),"process");

  int nelements=data.size();

  float frac=fraction;
  check_range<float>(frac,0.0,1.0);
  int ntarget=int((1.0-frac)*nelements+0.5);

  ODINLOG(odinlog,normalDebug) << "ntotal/nmask=" << nelements << "/" << ntarget << STD_endl;

  // Group the indices of all elements by their value
  QuantilIndexMap groups;
  for(int ielem=0; ielem<nelements; ielem++) {
    TinyVector<int,4> pos=data.create_index(ielem);
    groups[data(pos)].push_back(pos);
  }
  int ngroups=groups.size();

  ODINLOG(odinlog,normalDebug) << "ntotal/nmap=" << nelements << "/" << ngroups << STD_endl;

  Data<float,4> mask(data.shape());
  mask=0.0;

  // Walk through the groups from the highest to the lowest value
  const QuantilIndexMap& sorted=groups;
  int nmarked=0;
  for(QuantilIndexMap::const_reverse_iterator groupiter=sorted.rbegin(); groupiter!=sorted.rend() && nmarked<ntarget; ++groupiter) {
    const QuantilIndexList& positions=groupiter->second;
    ODINLOG(odinlog,normalDebug) << "indexmap(" << groupiter->first << ")=" << positions.size() << STD_endl;
    for(QuantilIndexList::const_iterator positer=positions.begin(); positer!=positions.end(); ++positer) {
      mask(*positer)=1.0;
      nmarked++;
    }
  }

  data.reference(mask);

  return true;
}


///////////////////////////////////////////////////////////////////////////

// Sets the description of the 'fname' parameter and appends it
// as filter argument under the name "fname".
void FilterUseMask::init() {

  fname.set_description("filename");
  append_arg(fname,"fname");
}


// Extracts those values of 'data' which are selected by a mask read from file 'fname'.
// The mask must have the same shape as 'data', except for the time dimension.
// An element of 'data' is selected if the mask, taken at time index 0, is
// non-zero at the same position. On success, 'data' is resized to the shape
// (1,n,1,1) and holds the n selected values in the order they were visited.
// 'prot' is not used.
// Returns false if reading the mask fails or if the shapes do not match, true otherwise.
bool FilterUseMask::process(Data<float,4>& data, Protocol& prot) const {
  Log<Filter> odinlog(c_label(),"process");

  // Read the mask from the external file
  Data<float,4> maskdata;
  if(maskdata.autoread(fname)<0) return false;

  // Compare the shapes with the extent of the time dimension set to one in both
  TinyVector<int,4> maskextent=maskdata.shape();
  maskextent(timeDim)=1;
  TinyVector<int,4> dataextent=data.shape();
  dataextent(timeDim)=1;
  if(maskextent!=dataextent) {
    ODINLOG(odinlog,errorLog) << "shape mismatch: " << maskextent << "!=" << dataextent << STD_endl;
    return false;
  }

  // Gather the values at all positions where the mask is non-zero
  fvector selected;
  const int nelements=data.size();
  for(int ielem=0; ielem<nelements; ielem++) {
    TinyVector<int,4> dataindex=data.create_index(ielem);
    TinyVector<int,4> maskindex=dataindex;
    maskindex(timeDim)=0; // only the first time step of the mask is used
    if(maskdata(maskindex)==0.0f) continue;
    selected.push_back(data(dataindex));
  }

  // Store the selected values along the second dimension
  data.resize(1,selected.size(),1,1);
  data(0,Range::all(),0,0)=Data<float,1>(selected);

  return true;
}
