/*
  $Id: simmetric.cc 4099 2008-11-07 21:50:29Z abehm $

  Copyright (C) 2007 by The Regents of the University of California

  Redistribution of this file is permitted under the terms of the
  BSD license

  Date: 04/15/2008
  Author: Rares Vernica <rares (at) ics.uci.edu>, Alexander Behm
*/

#include "simmetric.h"
#include "util/misc.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <set>
#include <string>
#include <vector>

using namespace std;

// ------------------------------ SimMetric       ------------------------------

// Boolean form of the similarity test: scores the pair with the
// two-argument operator() and accepts it when that score is at least
// `threshold`. (Presumably that operator is virtual, so the concrete
// metric's score is used -- confirm in simmetric.h.)
bool SimMetric::operator()(const string &s1, const string &s2, float threshold) 
  const 
{
  return threshold <= (*this)(s1, s2);
}

// ------------------------------ SimMetricEd     ------------------------------

// Levenshtein distance (unit-cost insertion, deletion and substitution)
// between s1 and s2, returned as a float to fit the SimMetric interface.
//
// Classic dynamic program keeping only two rows of the table, so extra
// memory is O(|s2|). The rows are std::vectors rather than a
// variable-length array: VLAs are not standard C++ and live on the stack.
float SimMetricEd::operator()(const string &s1, const string &s2) 
  const 
{  
  const uint
    n = s1.length(), 
    m = s2.length();

  if (n == 0)
    return m;
  if (m == 0)
    return n;

  // prev holds row i - 1 of the table, crt the row being computed.
  vector<uint> prev(m + 1), crt(m + 1);
  for (uint j = 0; j <= m; j++)
    prev[j] = j;

  for (uint i = 1; i <= n; i++) {
    crt[0] = i;
    for (uint j = 1; j <= m; j++)
      crt[j] = min(min(prev[j] + 1,        // delete s1[i - 1]
                       crt[j - 1] + 1),    // insert s2[j - 1]
                   prev[j - 1] + (s1[i - 1] == s2[j - 1] ? 0 : 1)); // match / substitute
    prev.swap(crt);
  }

  return prev[m];
}

// Decides whether ed(s1, s2) <= threshold, where threshold is truncated to
// an integer T, often without computing the whole distance.
//
// After cheap checks on empty strings and on the length difference, the
// (n+1) x (m+1) Levenshtein table is filled one anti-diagonal (cells with
// constant i + j) at a time:
//   phase 1: the diagonals starting at (ii, 1), ii = 1..n;
//   phase 2: the diagonals starting at (n, jj), jj = 2..m.
// After each diagonal, dmin is the smallest value among the cells just
// computed and the cells directly above them. If it exceeds T, the function
// returns false early.
// Rationale: table values never decrease along an alignment path, and a
// path to (n, m) must cross one of those cells, or a boundary cell whose
// value is no smaller.
//
// The table is one heap buffer with cell (i, j) at d[i * w + j]. A
// variable-length array would be non-standard C++ and would put
// O(n * m) words on the stack, which overflows for long strings.
bool SimMetricEd::operator()(
  const string &s1, const string &s2, float threshold) 
  const 
{
  const uint T = static_cast<uint>(threshold);

  const uint
    n = s1.length(), 
    m = s2.length();

  if (n == 0)
    return m <= T;
  if (m == 0)
    return n <= T;
  // The length difference alone is a lower bound on the distance.
  if ((n > m && n - m > T) ||  
      (m > n && m - n > T))
    return false;

  const size_t w = static_cast<size_t>(m) + 1;
  vector<uint> d((static_cast<size_t>(n) + 1) * w);
  const uint dmax = T + 1;
  uint i, j, dmin;

  // Row 0 and column 0: distance to / from the empty prefix.
  for (i = 0; i <= n; i++)
    d[i * w] = i;
  for (j = 1; j <= m; j++)
    d[j] = j;

  // Phase 1: diagonals starting at (ii, 1) and moving up-right.
  for (uint ii = 1; ii <= n; ii++) {
    dmin = dmax;
    for (j = 1; j <= min(ii, m); j++) {
      i = ii - j + 1;
      const size_t cell = i * w + j;
      const size_t up = cell - w;             // (i - 1, j)
      d[cell] = min(min(d[up] + 1,
                        d[cell - 1] + 1),
                    d[up - 1] + (s1[i - 1] == s2[j - 1] ? 0 : 1));
      dmin = min(dmin, min(d[cell], d[up]));
    }
    if (dmin > T)
      return false;
  }
  
  // Phase 2: diagonals starting at (n, jj) and moving up-right.
  for (uint jj = 2; jj <= m; jj++) {
    dmin = dmax;
    for (j = jj; j <= min(n + jj - 1, m); j++) {
      i = n - (j - jj);
      const size_t cell = i * w + j;
      const size_t up = cell - w;             // (i - 1, j)
      d[cell] = min(min(d[up] + 1,
                        d[cell - 1] + 1),
                    d[up - 1] + (s1[i - 1] == s2[j - 1] ? 0 : 1));
      dmin = min(dmin, min(d[cell], d[up]));
    }
    if (dmin > T)
      return false;
  }

  return d[n * w + m] <= T;
}

// Minimum number of query grams a candidate must share to possibly be
// within edit distance T = (uint)simThreshold. This is the count-filter
// bound numGrams - q * T, with q the gram length.
//
// When q * T >= numGrams the bound is not positive and the filter can
// prune nothing, so 0 is returned. A plain unsigned subtraction would
// instead wrap to a huge value and reject every candidate.
uint SimMetricEd::getMergeThreshold(
  const string& query, 
  const vector<uint>& queryGramCodes, 
  const float simThreshold) 
  const 
{
  const uint edThreshold = static_cast<uint>(simThreshold);
  const uint q = gramGen.getGramLength();
  const uint numGrams = queryGramCodes.size();
  const uint lost = q * edThreshold;
  return numGrams > lost ? numGrams - lost : 0;
}

// Computes [lbound, ubound] for the requested filter, given an edit-distance
// threshold T = (uint)simThreshold.
//
//  FT_LENGTH   : length range [len - T - 1, len + T - 1], with the lower end
//                clamped at 0. The comparison is done before subtracting so a
//                short query with a large T cannot wrap around.
//                NOTE(review): both ends are shifted down by one relative to
//                [len - T, len + T]; presumably this matches how lengths are
//                keyed by the filter -- confirm. For len == 0 and T == 0 the
//                upper bound wraps to UINT_MAX (kept as is; it is permissive).
//  FT_CHECKSUM : checksum(query) +/- T * CHECKSUM_ASCII_MAX, with the lower
//                end clamped at 0.
//  otherwise   : [0, 0].
void SimMetricEd::getFilterBounds(
  const string& query,
  const float simThreshold,
  const FilterType filterType,
  uint& lbound,
  uint& ubound) 
  const 
{
  const uint edThreshold = static_cast<uint>(simThreshold);
  switch(filterType) {

  case FT_LENGTH: {
    const uint len = query.length();
    lbound = (len <= edThreshold + 1) ? 0 : len - edThreshold - 1;
    ubound = len + edThreshold - 1;
  } break;

  case FT_CHECKSUM: { 
    const uint sum = checksum(query);
    const uint delta = edThreshold * CHECKSUM_ASCII_MAX;
    // Clamp explicitly instead of relying on a signed reinterpretation of a
    // wrapped unsigned value.
    lbound = (sum > delta) ? sum - delta : 0;
    ubound = sum + delta;
  } break;

  default: {
    lbound = 0;
    ubound = 0;
  } break;

  }
}
  
// Not supported for plain edit distance: prints a diagnostic to stderr and
// terminates the process with status 1.
float SimMetricEd::getSimMin(
  uint noGramsQuery, 
  uint noGramsData, 
  uint noGramsCommon) 
  const 
{
  const char *const msg = "SimMetricEd::getSimMin Not Implemented";
  cerr << msg << endl;
  exit(1);
}

// Upper bound returned for the edit-distance metric: the number of query
// grams not found in common, divided by the gram length.
// NOTE(review): assumes noGramsCommon <= noGramsQuery; otherwise the
// unsigned difference wraps.
float SimMetricEd::getSimMax(
  uint lenQuery, 
  uint noGramsQuery, 
  uint noGramsData, 
  uint noGramsCommon) 
  const 
{
  const uint missing = noGramsQuery - noGramsCommon;
  const float numerator = static_cast<float>(missing);
  return numerator / gramGen.getGramLength();
}

// Minimum common-gram count for an edit-distance threshold `sim`:
// floor(noGramsQuery - sim * q), but never less than 1.
uint SimMetricEd::getNoGramsMin(
  uint lenQuery, 
  uint noGramsMin, 
  uint noGramsQuery, 
  float sim)
  const 
{
  const float th = noGramsQuery - sim * gramGen.getGramLength();
  if (th > 1)
    return static_cast<uint>(floor(th));
  return 1;
}

// ------------------------------ SimMetricEdNorm ------------------------------

// Normalized edit-distance similarity:
//   1 - ed(s1, s2) / max(|s1|, |s2|),   a value in [0, 1].
// Two empty strings are identical and get similarity 1, instead of the
// 0 / 0 = NaN the bare formula would produce.
float SimMetricEdNorm::operator()(const string &s1, const string &s2) 
  const 
{
  const size_t len = max(s1.length(), s2.length());
  if (len == 0)
    return 1;
  return 1 - static_cast<float>(SimMetricEd::operator()(s1, s2)) / len;
}

// Threshold test intended to equal `operator()(s1, s2) >= threshold`. It
// delegates to the early-terminating edit-distance check.
//
// sim >= t  <=>  ed <= (1 - t) * L, with L = max(|s1|, |s2|). Since ed is an
// integer, the largest admissible distance is floor((1 - t) * L). Computed in
// float that value can land just off an integer: t = 0.8f, L = 5 gives
// 0.99999994, and truncating rejects ed = 1 although the scalar operator()
// returns exactly 0.8f for it. The candidate bound is therefore nudged by
// one in either direction until it agrees with the same float formula the
// scalar operator() uses.
bool SimMetricEdNorm::operator()(
  const string &s1, const string &s2, float threshold) 
  const 
{
  const size_t len = max(s1.length(), s2.length());
  const float maxEd = (1 - threshold) * len;
  // threshold > 1: no non-empty pair can reach it. This also avoids casting
  // a negative float to uint, which is undefined behavior.
  if (maxEd < 0)
    return false;

  float edThreshold = floor(maxEd);
  if (edThreshold >= 1 && 1 - edThreshold / len < threshold)
    edThreshold -= 1;
  else if (edThreshold + 1 <= len && 1 - (edThreshold + 1) / len >= threshold)
    edThreshold += 1;

  return SimMetricEd::operator()(s1, s2, edThreshold);
}

// Not supported for normalized edit distance: prints a diagnostic to stderr
// and terminates the process with status 1.
uint SimMetricEdNorm::getMergeThreshold(
  const string& query, 
  const vector<uint>& queryGramCodes,
  const float simThreshold)
  const 
{
  const char *const msg = "SimMetricEdNorm::getMergeThreshold Not Implemented";
  cerr << msg << endl;
  exit(1);
}

// Not supported for normalized edit distance: prints a diagnostic to stderr
// and terminates the process with status 1. lbound/ubound are not written.
void SimMetricEdNorm::getFilterBounds(
  const string& query,
  const float simThreshold,
  const FilterType filterType,
  uint& lbound,
  uint& ubound) 
  const 
{
  const char *const msg = "SimMetricEdNorm::getFilterBounds Not Implemented";
  cerr << msg << endl;
  exit(1);
}

// Lower bound on the normalized similarity: always the trivial bound 0,
// regardless of the gram counts.
float SimMetricEdNorm::getSimMin(
  uint noGramsQuery, 
  uint noGramsData, 
  uint noGramsCommon) 
  const 
{
  const float trivialLowerBound = 0.0f;
  return trivialLowerBound;
}

// Upper bound on the normalized similarity from gram counts:
//   1 - (noGramsQuery - noGramsCommon) / (q * lenQuery).
// NOTE(review): lenQuery == 0 divides by zero; presumably callers never pass
// an empty query -- confirm.
float SimMetricEdNorm::getSimMax(
  uint lenQuery, 
  uint noGramsQuery, 
  uint noGramsData, 
  uint noGramsCommon) 
  const 
{
  const float missing = static_cast<float>(noGramsQuery - noGramsCommon);
  return 1 - missing / (gramGen.getGramLength() * lenQuery);
}

// Minimum common-gram count for a normalized threshold `sim`. The allowed
// edits (1 - sim) * lenQuery, each worth q grams, are subtracted from the
// query's gram count. The result is floored and never less than 1.
uint SimMetricEdNorm::getNoGramsMin(
  uint lenQuery, 
  uint noGramsMin, 
  uint noGramsQuery, 
  float sim)
  const 
{
  const float allowedEdits = (1 - sim) * lenQuery;
  const float th = noGramsQuery - allowedEdits * gramGen.getGramLength();
  if (th > 1)
    return static_cast<uint>(floor(th));
  return 1;
}

// ------------------------------ SimMetricEdSwap ------------------------------

// Edit distance variant in which position (i, j) also counts as a zero-cost
// match when the two adjacent characters appear swapped, i.e.
//   s1[i-1] == s2[j-2] && s1[i-2] == s2[j-1].
// NOTE(review): this is not the textbook Damerau/OSA recurrence, which
// charges one edit via d[i-2][j-2]; here the swap is absorbed by the
// diagonal step. Confirm this is intended.
//
// Two-row dynamic program. The rows are std::vectors rather than a
// variable-length array, because VLAs are not standard C++ and live on the stack.
float SimMetricEdSwap::operator()(const string &s1, const string &s2) 
  const 
{
  const uint
    n = s1.length(), 
    m = s2.length();

  // prev holds row i - 1 of the table, crt the row being computed.
  vector<uint> prev(m + 1), crt(m + 1);
  for (uint j = 0; j <= m; j++)
    prev[j] = j;

  for (uint i = 1; i <= n; i++) {
    crt[0] = i;
    for (uint j = 1; j <= m; j++) {
      const bool match =
        s1[i - 1] == s2[j - 1] ||
        (i > 1 &&
         j > 1 &&
         s1[i - 1] == s2[j - 2] &&
         s1[i - 2] == s2[j - 1]);
      crt[j] = min(min(prev[j] + 1,
                       crt[j - 1] + 1),
                   prev[j - 1] + (match ? 0 : 1));
    }
    prev.swap(crt);
  }
  
  return prev[m];
}

// Accepts the pair when its swap-aware edit distance does not exceed
// `threshold`. This is a distance, so smaller is more similar.
bool SimMetricEdSwap::operator()(
  const string &s1, const string &s2, float threshold) 
  const 
{
  return threshold >= (*this)(s1, s2);
}

// Not supported for the swap-aware edit distance: prints a diagnostic to
// stderr and terminates the process with status 1.
uint SimMetricEdSwap::getMergeThreshold(
  const string& query, 
  const vector<uint>& queryGramCodes,
  const float simThreshold)
  const 
{
  const char *const msg = "SimMetricEdSwap::getMergeThreshold Not Implemented";
  cerr << msg << endl;
  exit(1);
}

// Not supported for the swap-aware edit distance: prints a diagnostic to
// stderr and terminates the process with status 1. The bounds are not written.
void SimMetricEdSwap::getFilterBounds(
  const string& query,
  const float simThreshold,
  const FilterType filterType,
  uint& lbound,
  uint& ubound) 
  const 
{
  const char *const msg = "SimMetricEdSwap::getFilterBounds Not Implemented";
  cerr << msg << endl;
  exit(1);
}

// Not supported for the swap-aware edit distance: prints a diagnostic to
// stderr and terminates the process with status 1.
float SimMetricEdSwap::getSimMin(
  uint noGramsQuery, 
  uint noGramsData, 
  uint noGramsCommon) 
  const 
{
  const char *const msg = "SimMetricEdSwap::getSimMin Not Implemented";
  cerr << msg << endl;
  exit(1);
}

// Not supported for the swap-aware edit distance: prints a diagnostic to
// stderr and terminates the process with status 1.
float SimMetricEdSwap::getSimMax(
  uint lenQuery, 
  uint noGramsQuery, 
  uint noGramsData, 
  uint noGramsCommon) 
  const 
{
  const char *const msg = "SimMetricEdSwap::getSimMax Not Implemented";
  cerr << msg << endl;
  exit(1);
}

// Not supported for the swap-aware edit distance: prints a diagnostic to
// stderr and terminates the process with status 1. The message now names
// the actual method, getNoGramsMin.
uint SimMetricEdSwap::getNoGramsMin(
  uint lenQuery, 
  uint noGramsMin, 
  uint noGramsQuery, 
  float sim)
  const 
{
  cerr << "SimMetricEdSwap::getNoGramsMin Not Implemented" << endl;
  exit(1);
}

// ------------------------------ SimMetricGram   ------------------------------

// ------------------------------ SimMetricJacc   ------------------------------

// Jaccard similarity of the two strings' gram sets:
//   |A intersect B| / |A union B|.
// Returns 0 when either string is empty. It also returns 0 when the union is
// empty (e.g. if decompose() yields no grams for very short strings), instead
// of 0 / 0 = NaN.
float SimMetricJacc::operator()(const string &s1, const string &s2) 
  const 
{
  if (s1.empty() || s2.empty())
    return 0;

  set<uint> s1Gram, s2Gram, sUni;
  gramGen.decompose(s1, s1Gram);
  gramGen.decompose(s2, s2Gram);

  set_union(s1Gram.begin(), s1Gram.end(),
            s2Gram.begin(), s2Gram.end(), 
            inserter(sUni, sUni.begin()));

  if (sUni.empty())
    return 0;

  // |A intersect B| = |A| + |B| - |A union B|
  const uint interSize = s1Gram.size() + s2Gram.size() - sUni.size();
  return static_cast<float>(interSize) / sUni.size();
}
 
// Jaccard similarity from precomputed gram counts:
//   common / (query + data - common).
float SimMetricJacc::operator()(
  uint noGramsData, 
  uint noGramsQuery, 
  uint noGramsCommon) 
  const 
{
  const uint unionSize = noGramsQuery + noGramsData - noGramsCommon;
  const float common = static_cast<float>(noGramsCommon);
  return common / unionSize;
}
     
// Bounds for the length and checksum filters under a Jaccard threshold t.
//
// Jaccard >= t implies the candidate's gram count lies in
// [floor(t * numGrams), ceil(numGrams / t)].
//  FT_LENGTH   : converts that gram-count range to a string-length range.
//                With pre/post padding, length = grams - q + 1; without it,
//                length = grams + q - 1.
//  FT_CHECKSUM : widens the query checksum by CHECKSUM_ASCII_MAX per gram of
//                difference. The lower end is clamped at 0, because an unsigned
//                wrap would otherwise leave lbound > ubound and reject every
//                candidate.
//                NOTE(review): the Cos/Dice versions also multiply by the
//                gram length here; confirm which is intended.
//  otherwise   : [0, 0] plus a warning on stdout.
// NOTE(review): simThreshold == 0 divides by zero; callers presumably pass
// t in (0, 1] -- confirm.
void SimMetricJacc::getFilterBounds(
  const string& query,
  const float simThreshold,
  const FilterType filterType,
  uint& lbound,
  uint& ubound) 
  const 
{
  const uint numGrams = gramGen.getNumGrams(query);
  const uint gramLength = gramGen.getGramLength();
  switch(filterType) {
    
  case FT_LENGTH: {
    const uint minGrams =
      static_cast<uint>(floor(static_cast<float>(numGrams) * simThreshold));
    const uint maxGrams =
      static_cast<uint>(ceil(static_cast<float>(numGrams) / simThreshold));
    if(gramGen.prePost) {
      // Compare before subtracting so the lower bound clamps at 0.
      lbound = (minGrams + 1 > gramLength) ? minGrams + 1 - gramLength : 0;
      ubound = maxGrams - gramLength + 1;
    }
    else {
      lbound = minGrams + gramLength - 1;
      ubound = maxGrams + gramLength - 1;
    }
  } break;
      
  case FT_CHECKSUM: {
    const uint queryChecksum = checksum(query);
    const uint minGrams =
      static_cast<uint>(floor(static_cast<float>(numGrams) * simThreshold));
    const uint maxGrams =
      static_cast<uint>(ceil(static_cast<float>(numGrams) / simThreshold));
    const uint below = (numGrams - minGrams) * CHECKSUM_ASCII_MAX;
    lbound = (queryChecksum > below) ? queryChecksum - below : 0;
    ubound = queryChecksum + ((maxGrams - numGrams) * CHECKSUM_ASCII_MAX);
  } break;
      
  default: {
    lbound = 0;
    ubound = 0;
    cout << "WARNING: unknown filter passed to distancemeasure." << endl;
  } break;
      
  }   
}

// Minimum number of common grams for a Jaccard threshold: a candidate needs
// at least floor(simThreshold * |query grams|) of the query's grams.
uint SimMetricJacc::getMergeThreshold(
  const string& query, 
  const vector<uint>& queryGramCodes,
  const float simThreshold) 
  const 
{
  const uint gramCount = queryGramCodes.size();
  const float minCommon = gramCount * simThreshold;
  return static_cast<uint>(floor(minCommon));
}
 
// Minimum common-gram count for a Jaccard threshold `sim`. It is the larger of
//   sim * noGramsQuery   and   (noGramsQuery + noGramsMin) / (1 + 1 / sim),
// rounded up and never less than 1.
uint SimMetricJacc::getNoGramsMin(
  uint lenQuery, 
  uint noGramsMin, 
  uint noGramsQuery, 
  float sim)
  const 
{
  const float byQuery = sim * noGramsQuery;
  const float byBoth = (noGramsQuery + noGramsMin) / (1 + 1 / sim);
  const float th = (byQuery < byBoth) ? byBoth : byQuery;
  if (th > 1)
    return static_cast<uint>(ceil(th));
  return 1;
}

// ------------------------------ SimMetricCos    ------------------------------

// Cosine similarity of the two strings' gram sets:
//   |A intersect B| / sqrt(|A| * |B|).
// Returns 0 when either string is empty. It also returns 0 when either gram set
// is empty, instead of 0 / 0 = NaN. The size product is formed in double so
// it cannot overflow a 32-bit size_t.
float SimMetricCos::operator()(const string &s1, const string &s2) 
  const 
{
  if (s1.empty() || s2.empty())
    return 0;

  set<uint> s1Gram, s2Gram, sInt;
  gramGen.decompose(s1, s1Gram);
  gramGen.decompose(s2, s2Gram);

  if (s1Gram.empty() || s2Gram.empty())
    return 0;

  set_intersection(s1Gram.begin(), s1Gram.end(),
                   s2Gram.begin(), s2Gram.end(), 
                   inserter(sInt, sInt.begin()));
  
  return static_cast<float>(sInt.size()) / 
    sqrt(static_cast<double>(s1Gram.size()) * s2Gram.size());
}
      
// Cosine similarity from gram counts: common / sqrt(query * data).
// The product is formed in double: as a uint product it silently wraps once
// query * data exceeds 2^32.
float SimMetricCos::operator()(
  uint noGramsData, 
  uint noGramsQuery, 
  uint noGramsCommon) 
  const 
{
  return noGramsCommon /
    sqrt(static_cast<double>(noGramsQuery) * noGramsData);
}
     
// Minimum number of common grams for a cosine threshold t:
// floor(t^2 * |query grams|).
uint SimMetricCos::getMergeThreshold(
  const string& query, 
  const vector<uint>& queryGramCodes,
  const float simThreshold)
  const 
{
  const uint gramCount = queryGramCodes.size();
  const float simSquared = simThreshold * simThreshold;
  return static_cast<uint>(floor(simSquared * gramCount));
}

// Bounds for the length and checksum filters under a cosine threshold t.
//
// Cosine >= t implies the candidate's gram count lies in
// [t^2 * numGrams, numGrams / t^2].
//  FT_LENGTH   : converts that range to string lengths the same way the
//                Jaccard and Dice filters do. With pre/post padding,
//                length = grams - q + 1; without it, length = grams + q - 1.
//                Both ends are clamped at 0 *before* the float-to-uint cast,
//                since converting a negative float to uint is undefined behavior.
//  FT_CHECKSUM : widens the query checksum by CHECKSUM_ASCII_MAX * q per gram
//                of difference. The lower end is clamped at 0 instead of wrapping.
//  otherwise   : [0, 0] plus a warning on stdout.
// NOTE(review): simThreshold == 0 divides by zero; callers presumably pass
// t in (0, 1] -- confirm.
void SimMetricCos::getFilterBounds(
  const string& query,
  const float simThreshold,
  const FilterType filterType,
  uint& lbound,
  uint& ubound) 
  const 
{
  const uint numGrams = gramGen.getNumGrams(query);
  const uint gramLength = gramGen.getGramLength();
    
  switch(filterType) {
      
  case FT_LENGTH: {
    const float shift = gramGen.prePost
      ? 1.0f - static_cast<float>(gramLength)
      : static_cast<float>(gramLength) - 1.0f;
    const float lo = floor(static_cast<float>(numGrams) * simThreshold * simThreshold
                           + shift);
    const float hi = ceil(static_cast<float>(numGrams) / (simThreshold * simThreshold)
                          + shift);
    lbound = (lo > 0) ? static_cast<uint>(lo) : 0;
    ubound = (hi > 0) ? static_cast<uint>(hi) : 0;
  } break;
      
  case FT_CHECKSUM: {
    const uint queryChecksum = checksum(query);
    const uint minGrams = static_cast<uint>(
      floor(static_cast<float>(numGrams) * simThreshold * simThreshold));
    const uint maxGrams = static_cast<uint>(
      ceil(static_cast<float>(numGrams) / (simThreshold * simThreshold)));
    const uint below = (numGrams - minGrams) * CHECKSUM_ASCII_MAX * gramLength;
    lbound = (queryChecksum > below) ? queryChecksum - below : 0;
    ubound = queryChecksum + ((maxGrams - numGrams) * CHECKSUM_ASCII_MAX * gramLength);
  } break;
      
  default: {
    lbound = 0;
    ubound = 0;
    cout << "WARNING: unknown filter passed to distancemeasure." << endl;
  } break;
      
  }
}

// Minimum common-gram count for a cosine threshold `sim`:
// ceil(sim * sqrt(noGramsQuery * noGramsMin)), never less than 1.
// The product is formed in double so it cannot wrap as a 32-bit uint.
uint SimMetricCos::getNoGramsMin(
  uint lenQuery, 
  uint noGramsMin, 
  uint noGramsQuery, 
  float sim)
  const 
{
  const float th = sim * sqrt(static_cast<double>(noGramsQuery) * noGramsMin);
  return th > 1 ? static_cast<uint>(ceil(th)) : 1;
}

// ------------------------------ SimMetricDice   ------------------------------

// Dice coefficient of the two strings' gram sets:
//   2 * |A intersect B| / (|A| + |B|).
// Returns 0 when either string is empty. It also returns 0 when neither string
// yields any gram, instead of 0 / 0 = NaN.
float SimMetricDice::operator()(const string &s1, const string &s2) 
  const 
{
  if (s1.empty() || s2.empty())
    return 0;

  set<uint> s1Gram, s2Gram, sInt;
  gramGen.decompose(s1, s1Gram);
  gramGen.decompose(s2, s2Gram);

  const size_t total = s1Gram.size() + s2Gram.size();
  if (total == 0)
    return 0;

  set_intersection(s1Gram.begin(), s1Gram.end(),
                   s2Gram.begin(), s2Gram.end(), 
                   inserter(sInt, sInt.begin()));
  
  return 2. * sInt.size() / total;
}
      
// Dice coefficient from gram counts: 2 * common / (query + data). This
// mirrors the string overload, just as the Jaccard and cosine count
// overloads mirror theirs. Returns 0 when both counts are 0 (no grams at
// all) instead of dividing by zero.
float SimMetricDice::operator()(
  uint noGramsData, 
  uint noGramsQuery, 
  uint noGramsCommon) 
  const 
{
  // Sum in double so query + data cannot wrap as a uint.
  const double total = static_cast<double>(noGramsQuery) + noGramsData;
  if (total == 0)
    return 0;
  return static_cast<float>(2. * noGramsCommon / total);
}
     
// Not supported for the Dice metric: prints a diagnostic to stderr and
// terminates the process with status 1. The message now names this class
// rather than SimMetricEdNorm, which was a copy-paste error.
uint SimMetricDice::getMergeThreshold(
  const string& query, 
  const vector<uint>& queryGramCodes,
  const float simThreshold)
  const 
{
  cerr << "SimMetricDice::getMergeThreshold Not Implemented" << endl;
  exit(1);
}

// Bounds for the length and checksum filters under a Dice threshold t.
//
// Dice >= t implies the candidate's gram count lies in
// [t * numGrams / (2 - t), (2 - t) * numGrams / t].
//  FT_LENGTH   : converts that range to string lengths. With pre/post
//                padding, length = grams - q + 1; without it,
//                length = grams + q - 1. The lower end is clamped at 0.
//  FT_CHECKSUM : widens the query checksum by CHECKSUM_ASCII_MAX * q per gram
//                of difference. The lower end is clamped at 0, because an unsigned
//                wrap would leave lbound > ubound and reject every candidate.
//  otherwise   : [0, 0] plus a warning on stdout.
// NOTE(review): simThreshold == 0 divides by zero; callers presumably pass
// t in (0, 1] -- confirm.
void SimMetricDice::getFilterBounds(
  const string& query,
  const float simThreshold,
  const FilterType filterType,
  uint& lbound,
  uint& ubound) 
  const 
{
  const uint numGrams = gramGen.getNumGrams(query);
  const uint gramLength = gramGen.getGramLength();
  switch(filterType) {
    
  case FT_LENGTH: {
    const uint minGrams = static_cast<uint>(
      floor((static_cast<float>(numGrams) * simThreshold) / (2.0f - simThreshold)));
    const uint maxGrams = static_cast<uint>(
      ceil(((2.0f - simThreshold) * numGrams) / simThreshold));
    if(gramGen.prePost) {
      // Compare before subtracting so the lower bound clamps at 0.
      lbound = (minGrams + 1 > gramLength) ? minGrams + 1 - gramLength : 0;
      ubound = maxGrams - gramLength + 1;
    }
    else {
      lbound = minGrams + gramLength - 1;
      ubound = maxGrams + gramLength - 1;
    }
  } break;
    
  case FT_CHECKSUM: {
    const uint queryChecksum = checksum(query);
    const uint minGrams = static_cast<uint>(
      floor((static_cast<float>(numGrams) * simThreshold) / (2.0f - simThreshold)));
    const uint maxGrams = static_cast<uint>(
      ceil(((2.0f - simThreshold) * numGrams) / simThreshold));
    const uint below = (numGrams - minGrams) * CHECKSUM_ASCII_MAX * gramLength;
    lbound = (queryChecksum > below) ? queryChecksum - below : 0;
    ubound = queryChecksum + ((maxGrams - numGrams) * CHECKSUM_ASCII_MAX * gramLength);
  } break;
      
  default: {
    lbound = 0;
    ubound = 0;
    cout << "WARNING: unknown filter passed to distancemeasure." << endl;
  } break;
      
  }   
}

// Not supported for the Dice metric: prints a diagnostic to stderr and
// terminates the process with status 1. The message now names the actual
// method, getNoGramsMin.
uint SimMetricDice::getNoGramsMin(
  uint lenQuery, 
  uint noGramsMin, 
  uint noGramsQuery, 
  float sim)
  const 
{
  cerr << "SimMetricDice::getNoGramsMin Not Implemented" << endl;
  exit(1);
}

// ------------------------------ SimMetricGramCount ------------------------------

// Number of distinct grams the two strings have in common, returned as a
// float to fit the SimMetric interface. Returns 0 if either string is empty.
float SimMetricGramCount::operator()(
  const string &s1, 
  const string &s2) 
  const 
{
  if (s1.length() == 0 || s2.length() == 0)
    return 0;

  set<uint> grams1, grams2;
  gramGen.decompose(s1, grams1);
  gramGen.decompose(s2, grams2);

  // Both sets are sorted, so walk them in lockstep and count equal keys.
  size_t common = 0;
  set<uint>::const_iterator a = grams1.begin(), b = grams2.begin();
  while (a != grams1.end() && b != grams2.end()) {
    if (*a < *b) {
      ++a;
    } else if (*b < *a) {
      ++b;
    } else {
      ++common;
      ++a;
      ++b;
    }
  }
  return common;
}
      
// Gram-count "similarity" from precomputed counts: just the number of
// common grams. The query and data counts are ignored.
float SimMetricGramCount::operator()(
  uint noGramsData, 
  uint noGramsQuery, 
  uint noGramsCommon) 
  const 
{
  const float score = static_cast<float>(noGramsCommon);
  return score;
}
     
// Not supported for the gram-count metric: prints a diagnostic to stderr
// and terminates the process with status 1.
uint SimMetricGramCount::getMergeThreshold(
  const string& query, 
  const vector<uint>& queryGramCodes,
  const float simThreshold) 
  const 
{
  const char *const msg = "SimMetricGramCount::getMergeThreshold Not Implemented";
  cerr << msg << endl;
  exit(1);
}

// Not supported for the gram-count metric: prints a diagnostic to stderr
// and terminates the process with status 1. The bounds are not written.
void SimMetricGramCount::getFilterBounds(
  const string& query,
  const float simThreshold,
  const FilterType filterType,
  uint& lbound,
  uint& ubound) 
  const 
{
  const char *const msg = "SimMetricGramCount::getFilterBounds Not Implemented";
  cerr << msg << endl;
  exit(1);
}

// For the gram-count metric the threshold is used directly as the required
// number of common grams, truncated toward zero.
uint SimMetricGramCount::getNoGramsMin(
  uint lenQuery, 
  uint noGramsMin, 
  uint noGramsQuery, 
  float sim)
  const 
{
  const uint required = static_cast<uint>(sim);
  return required;
}
