/* 
 * File:   main.cpp
 * Author: frioult
 *
 * Created on 18 mai 2010, 22:38
 */

/// calcule le score du classifieur avec les CS
/// prend en entr�e le paste du v�rif et des CSi
#include "mvminer.h"

#define MAXCLASSE 1024 // nb maximum de classe

float *scoretempo;  // score tempo avant vote
t_support bienclasse = 0; // nb d'exemples bien class�s
t_support malclasse = 0; // nb d'exemples mal class�s
t_support *classement; // indique le nombre d'instances de chaque classe bien class�es

/// aire d'un trapeze
float trapeze(int X1, int X2, int Y1, int Y2){
  float heightAVG = 0.5 * (Y1 + Y2);
  float ret = heightAVG * ABS(X1 - X2);
  //  cout << X1 << ' ' << X2  << ' ' << Y1 << ' ' << Y2 << ' ' << ret << endl;
  return ret;
}


/// classe des instances
class Instance{
public:
  /// sa classe
  int classe;
  /// son score pour positif
  float score;

  Instance(int _classe, float *scoreTempo): classe(_classe){
    if (scoreTempo[0] + scoreTempo[1] == 0.0)
      score = 1;
    else
      score = scoreTempo[0] / (scoreTempo[0] + scoreTempo[1]);
  }

  friend ostream &operator<<(ostream &os, const Instance &rhs){
    os << rhs.classe << '-' << rhs.score;
    return os;
  }

};

// -------------------------------------------------------------------
struct ltInstance
{
  bool operator()(const Instance &c1, const Instance &c2)
  {
    return c1.score < c2.score;
  }
};

// -------------------------------------------------------------------
struct gtInstance
{
  bool operator()(const Instance &c1, const Instance &c2)
  {
    return c1.score > c2.score;
  }
};

//--------------------------------------------------
/// classe des matrices de confusion
class c_confusion{
public:
  typedef multiset<Instance, gtInstance> ListInstance ;
  typedef ListInstance::iterator ListInstanceIt;
  ListInstance listInstance;

  /// indique le nombre d'instances de chaque classes bien class�es
  t_support *tab;
  /// nombre de classes
  t_item nclasse;
  /// pr�cision, rappel, fscore, calcul�s par moyenne
  float prec, rapp, fsco;

  t_support rappelDen, rappelNum, precisionDen, precisionNum;

  /// constructeur
  c_confusion(t_item _nclasse) :
    tab(static_cast<t_support *>(calloc(_nclasse * _nclasse, sizeof(t_support)))),
    nclasse(_nclasse),
    prec(0.0),
    rapp(0.0),
    fsco(0.0),
    rappelDen(0), rappelNum(0), precisionDen(0), precisionNum(0)
  {}

  /// calcul static du fscore
  static float calculefs(float precision, float rappel){
    if (precision + rappel)
      return 2 * precision * rappel / (precision + rappel);
    else
      return 0.0;
  }

  /// destructeur
  ~c_confusion(void){ free(tab); }

  /// ajoute une instance de classe classe au classement choisi
  void ecris(t_item classe, t_item classement, float *scoreTempo){
    tab[nclasse * classe + classement] ++;
    Instance instance(classe, scoreTempo);
    listInstance.insert(instance);
 }

  /// indique le nombre d'instances de la classe
  t_support taille(t_item classe) const {
    t_support ret = 0;
    for(t_item j = 0; j < nclasse; j ++)
      ret += tab[nclasse * classe + j];
    return ret;
  }

  /// indique le nombre d'instances class�es dans la classe
  t_support choix(t_item classe) const {
    t_support ret = 0;
    for(t_item i = 0; i < nclasse; i ++)
      ret += tab[nclasse * i + classe];
    return ret;
  }

  /// indique la pr�cision pour la classe
  float precision(t_item classe) {
    t_support ret = choix(classe);
    precisionDen += ret;
    precisionNum += tab[nclasse * classe + classe];
    if (ret)
      return 100 * static_cast<float>(tab[nclasse * classe + classe]) / ret;
    else
      return 0.0;
  }

  /// indique le rappel pour la classe
  float rappel(t_item classe) {
    t_support ret = taille(classe);
    rappelDen += ret;
    rappelNum += tab[nclasse * classe + classe];
    if (ret)
      return 100 *  static_cast<float>(tab[nclasse * classe + classe]) / ret;
    else
      return 0.0;
  }

  void calcule(void){
    for(t_item i = 0; i < nclasse; i ++){
      float precisioni = precision(i), rappeli = rappel(i), fscorei;
      fscorei = calculefs(precisioni, rappeli);
      prec += precisioni;
      rapp += rappeli;
      fsco += fscorei;
    }
    prec /= nclasse;
    rapp /= nclasse;
    //fsco /= nclasse;
    fsco = calculefs(prec, rapp);
  }

  /// retourne l'aire sous la courbe ROC
  void ROCpoints(void){
    int TP = 0;
    int FP = 0;
    int TPprev = 0;
    int FPprev = 0;
    int P = taille(0);
    int N = taille(1);
    float fprev = -1.0;
    float aire = 0.0;
    int a = 0;

    cout << "<roc>" << endl;
    for (ListInstanceIt instance = listInstance.begin(); instance != listInstance.end(); instance ++){
      if (instance->score != fprev){
	aire += trapeze(FP, FPprev, TP, TPprev);
	a ++;
	fprev = instance->score;
	FPprev = FP;
	TPprev = TP;
	cout << "<point x=\"" << static_cast<float>(FP) / N << "\" y=\"" <<  static_cast<float>(TP) / P << "\"/>" << endl;
      }
      if (instance->classe == 0)
	TP ++;
      else
	FP ++;
    }
    cout << "<point x=\"" << static_cast<float>(FP) / N << "\" y=\"" <<  static_cast<float>(TP) / P << "\"/>" << endl;
    aire += trapeze(N, FPprev, P, TPprev);
    a ++;
    cout << "<auc>" << 100.0 * aire / P / N << "</auc>" << endl;
    cout << "</roc>" << endl;
  }

  friend ostream &operator<<(ostream &os, c_confusion &rhs){
    //os << "#class\tpreci\trap\tfscore\tconfusion" << endl;
    for(t_item i = 0; i < rhs.nclasse; i ++){
      float precisioni = rhs.precision(i), rappeli = rhs.rappel(i), fscorei;
      fscorei = c_confusion::calculefs(precisioni, rappeli);
//      rhs.prec += precisioni;
//      rhs.rapp += rappeli;
//      rhs.fsco += fscorei;
      os << "<class index=\"" << static_cast<int>(i + 1) << "\">\n";
      os << "<precision>" << precisioni << "</precision>" << endl;
      os << "<recall>" << rappeli << "</recall>" << endl;
      os << "<fscore>" << fscorei << "</fscore>" << endl;
      for(t_item j = 0; j < rhs.nclasse; j ++)
	os << "<decision class=\"" << static_cast<int>(i + 1) << "\">" << static_cast<int>(rhs.tab[rhs.nclasse * i + j]) << "</decision>" << endl;
      os << "</class>" << endl;
    }
//    rhs.prec /= rhs.nclasse;
//    rhs.rapp /= rhs.nclasse;
//    rhs.fsco /= rhs.nclasse;
    os << "<average>" << endl;
      os << "<precision>" << rhs.prec << "</precision>" << endl;
      os << "<recall>" << rhs.rapp << "</recall>" << endl;
      os << "<fscore>" << rhs.fsco << "</fscore>" << endl;
      os << "</average>" << endl;
    return os;
  }
};

c_confusion *confusion = NULL;

//--------------------------------------------------
int vote(t_item taille){
  // calcule le maximum des scoretempo[] et retourne la classe correspondante
  int ret = 0;
  bool egalite = true;
  for(t_item i = 0; i < taille; i++){
    //cout << i << ':' << scoretempo[i] << ' ';
    if(scoretempo[ret] != scoretempo[i]) egalite = false;
    if(scoretempo[ret] < scoretempo[i]) ret = i;
  }
  //if(egalite) return -1;
  return ret;
}

//--------------------------------------------------
int chercheclasse(char *ligne){
  // retourne la valeur de la classe dans la ligne
  char *s = ligne;
  do ligne++; while(*ligne != ' ');
  *ligne = 0;
  int ret = (t_item)atoi(s);
  *ligne = ' ';
  return ret;
}

//--------------------------------------------------
void lectureTrans(char *ligne,t_support nligne){
  // indique si l'exemple est bien class�
  char *s = NULL;
  bool encore = true;
  int cl = chercheclasse(ligne) - 1;
  //cout << "classe " << cl << " " << ligne << ' ' << endl;

  t_item ind = 0; // indice des scores
  while(encore){
    switch(*ligne){
    case ':' :
      s = ligne + 1;
      break;
    case 0:encore = false;
    case '\t':
      if(s){  // si ce n'est pas le premier '\t'
	*ligne = 0;
	scoretempo[ind] = atof(s);
	ind++;
      }
      break;

    case ' ':case '.':case '+':case '-':case 'e':case 'E':
    case '0':case '1':case '2':case '3':case '4':
    case '5':case '6':case '7':case '8':case '9':
      break;
    default:
      cerr << "score : error line " << nligne << " unknown char " << ligne << '.' << endl;
      exit(-1);
      break;
    }
    ligne++;
  }
  int v = vote(ind);
  //cout << "nclasse = " << (int)ind << endl;
  if(!confusion)
    confusion = new c_confusion(ind);
  //  cout << "max = " << v << endl;
  confusion->ecris(cl, v, scoretempo);
  if( v == cl){
    //        cout << "bien classe " << endl;
    bienclasse ++;
    classement[cl] ++;
  }
}

//----------------------------------------------------------------------------
float fichTrans(istream *fich){
  char *buff;
  t_support nligne = 0;   // nb d'exemples
  t_support nllue = 1;    // n� ligne lue

  while(!fich->eof()){
    buff = fgetline(fich);
    if(*buff){
      if(*buff == '#'){
	//cout << buff << endl;
      }else{
	nligne++;
	//cout << "lecture " << buff << endl;
	lectureTrans(buff,nligne);
      }
    }
    nllue ++;
    delete buff;
  }
  confusion->calcule();
  confusion->ROCpoints();
  return static_cast<float>(100 * bienclasse) / nligne;
}

/////////////////////////////////// MAIN //////////////////////////////
int main(int argc, char **argv){
  cout << "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\" ?>" << endl;
  cout << "<root>" << endl;

  scoretempo = (float *)calloc(MAXCLASSE, sizeof(float));
  classement = (t_support *)calloc(MAXCLASSE, sizeof(float));
  float score = 0.0;
  if(argc == 1){
    //cout << "# score de classification depuis l'entree standard " << endl;
    score = fichTrans(&cin);
  }else{
    ifstream *ref = new ifstream(argv[1]);
    if(!*ref){
      cerr << "error opening the input file : " << argv[1] << endl;
      exit(-1);
    }
    //cout << "# score de classification depuis " << argv[1] << endl;
    score = fichTrans(ref);
    ref->close();
  }
  cout << "<score>" << score << "</score>" << endl;
  cout << *confusion;
  
  cout << "</root>" << endl;
  return (EXIT_SUCCESS);
}

