/* 
$Id$

Copyright (C) 2007 by The Regents of the University of California

Redistribution of this file is permitted under
the terms of the *GNU* Public License (*GPL*) 	

Date: 04/08/2007
Author: Yiming Lu <yimingl@ics.uci.edu>
*/

#include <sys/types.h>
#include <time.h>
#include <iostream>
#include <fstream>
#include <iomanip>
#include "filtertree.h"
#include "../util/ed.h"
using std::setw;

/*
 * Creates an empty filter tree over a dataset.
 *
 * q    - gram parameter forwarded to str2grams() and used in the
 *        common-gram threshold formulas
 * data - the strings to index; only the pointer is stored and the
 *        destructor does not delete it
 * f    - selects which edit-distance routine (ed or edSwap) verifies
 *        candidates
 *
 * The index itself is not built here; call build() afterwards.
 */
FilterTree::FilterTree(unsigned q, 
					   const vector<string> *data, 
					   EDFunction f):
                       leftSign('#'),
                       rightSign('$'), 
					   q(q), 
					   function(f),
					   data(data) 					    					 
{
    this->headBuckets = new Buckets();
    // The count tables are only allocated by build(); record that none
    // exist yet so the destructor never touches an indeterminate pointer.
    this->countTables = NULL;
    //for analysis
    totalTime = 0;
    candidateNo = 0;
    shortCandidateNo = 0;
    patternNo = 0;
    resultNo = 0;
    queryNo = 0;
    shortQueryNo = 0;
}

/*
 * Creates a filter tree over a dataset and restores its index from a
 * file previously written by saveIndex().
 *
 * data      - the strings the index refers to; not owned
 * indexFile - path of the saved index
 *
 * NOTE(review): q and function are assigned only by loadIndex(); what
 * they hold when indexFile cannot be read should be confirmed there.
 */
FilterTree::FilterTree(const vector<string> *data,
                       const string& indexFile):
                       leftSign('#'),
                       rightSign('$'),  					  
					   data(data)
{
    this->headBuckets = new Buckets();
    // Defined state before loadIndex() allocates the tables.
    this->countTables = NULL;
    loadIndex(indexFile);
    //for analysis
    totalTime = 0;   	      
    candidateNo = 0;
    shortCandidateNo = 0;
    patternNo = 0;
    resultNo = 0;
    queryNo = 0;
    shortQueryNo = 0;
}

/*
 * Releases the per-string count tables (if build() or loadIndex()
 * created them) and the head bucket table. The dataset is not freed.
 */
FilterTree::~FilterTree()
{
    // countTables stays NULL when the tree was constructed but never built.
    if (countTables != NULL)
    {
        for(unsigned i = 0; i < data->size(); i++)
        {
            delete countTables[i];
        }
        delete[] countTables;
    }
    delete headBuckets;		
}

/*
 * Builds the index: fills the gram buckets from the dataset, then
 * allocates one CountTable for every string in the dataset.
 */
void FilterTree::build()
{
    buildBuckets();

    // One table per dataset entry, indexed by string id.
    const unsigned tableTotal = data->size();
    countTables = new CountTable*[data->size()];
    for (unsigned slot = 0; slot < tableTotal; ++slot)
    {
        countTables[slot] = new CountTable();
    }
}

void FilterTree::saveIndex(const string &indexFile) const
{
    //open Index file
    ofstream ofs ( indexFile.c_str() );
    //output q value
    ofs << q;
    ofs << " ";
    //output function value
    ofs << (unsigned)function;
    ofs << " ";
    //output headBuckets
    ofs << (*headBuckets);
    ofs.close();
}

void FilterTree::loadIndex(const string &indexFile)
{
    //load the value of q & function
    ifstream ifs ( indexFile.c_str(),  ifstream::in );
    ifs >> q;
    unsigned func;
    ifs >> func;
    function = (EDFunction)func;
    //initiate the countTable
    this->countTables = new CountTable*[data->size()];
    for(unsigned i = 0; i < data->size(); i++)
    {
        countTables[i] = new CountTable();
    }
    //load the headBuckets
    ifs >> (*headBuckets);
    ifs.close();	
}

/*
 * Runs one approximate-match query and accumulates its CPU time.
 *
 * query    - the pattern to look up
 * editdist - edit-distance threshold passed to getApproximateMatch()
 * results  - ids of matching dataset strings are appended here
 *
 * Each call consumes one pattern id (patternNo is post-incremented) and
 * adds the elapsed clock() time, in seconds, to totalTime.
 */
void FilterTree::search(const string &query,
                        const unsigned editdist,
                        vector<unsigned> &results)
{
    // Keep the raw tick counts in clock_t, the type clock() returns;
    // convert to seconds only once the difference has been taken.
    const clock_t startTicks = clock();
    getApproximateMatch(query, patternNo++, editdist, results);
    const clock_t endTicks = clock();

    totalTime += static_cast<double>(endTicks - startTicks)/CLOCKS_PER_SEC;
}

void FilterTree::buildBuckets()
{
    unsigned i,j;
    unsigned idSize = data->size();
    
    for ( i=0; i< idSize; i++ ) 
    {
        vector<string> grams;
        str2grams( (*data)[i], grams, q, leftSign, rightSign );
        if (grams.size() > 0) 
        {
            for(j=0; j< grams.size(); j++)
            {				
		    /*insert gram in the grams[j] corresponding 
		     *to the string data[i] into hash table
		     */
		    this->headBuckets->addToHashTable(grams[j].c_str(), 
				                              i, 
				                              j, 
				                              (*data)[i].size());
            }
        }
    }	
}


/*
 * Finds dataset strings within edThreshold of targetWord.
 *
 * targetWord  - the query string
 * patternId   - id of this query, used to address the count tables
 * edThreshold - maximum edit distance accepted
 * results     - ids of matching strings are appended here
 *
 * The query is split into grams; for every gram that has a non-empty
 * bucket in the index, the bucket and the gram's position in the query
 * are collected. Queries longer than the common-gram threshold are
 * handled by mergeSort() (queryNo is incremented), all others by
 * processShortStrings() (shortQueryNo is incremented).
 */
void FilterTree::getApproximateMatch(const string &targetWord, 
									 unsigned patternId, 
									 unsigned edThreshold, 
									 vector<unsigned>& results)
{
    // Threshold used to tell short queries from long ones; the formula
    // depends on which edit-distance function is in use.
    unsigned commonGramsNo;
    if(this-> function == standardFunction)
    {
        commonGramsNo = edThreshold * q - q + 1;
    }
    else
    {
        commonGramsNo = edThreshold * (q+1) - q +1;
    }

    vector<string> grams;
    str2grams(targetWord, grams, q, leftSign, rightSign);

    // Scratch lists live on the stack so they are released even when a
    // callee throws; positions[n] belongs to lengthBuckets[n].
    vector<unsigned> positions;
    vector<LengthBucket *> lengthBuckets;

    for(unsigned j = 0; j < grams.size(); j++)
    {
        LengthBucket *lengthBucket =
            headBuckets->getMatchedBuckets(grams[j].c_str());
        if( (lengthBucket != NULL)
            && (lengthBucket->isEmpty() == false) )
        {
            lengthBuckets.push_back(lengthBucket);
            positions.push_back(j);
        }
    }

    //If the query is not short query, then call mergeSort
    if ( targetWord.length() > commonGramsNo  )
    {
        mergeSort(targetWord,
                  patternId,
                  &positions,
                  edThreshold,
                  &lengthBuckets,
                  results);
        queryNo++;
    }
    //If the query is short query, then call processShortStrings
    else
    {
        processShortStrings(targetWord,
                            patternId,
                            &positions,
                            edThreshold,
                            &lengthBuckets,
                            results);
        shortQueryNo++;
    }
}

/*
 * Answers a query that getApproximateMatch() classified as short
 * (its length does not exceed the common-gram threshold).
 *
 * targetWord    - the query string
 * patternId     - id of this query, used to address the count tables
 * positions     - positions->at(n) is the query position of the gram
 *                 whose bucket is lengthBuckets->at(n)
 * edThreshold   - maximum edit distance accepted
 * lengthBuckets - index buckets of the query grams found in the index
 * results       - ids of matching strings are appended here
 *
 * Works in two passes:
 *   1. Every dataset string whose length is at most the common-gram
 *      threshold and within edThreshold of the query length is verified
 *      directly with ed()/edSwap().
 *   2. Dataset strings longer than the threshold are found through the
 *      gram buckets, counting shared grams per string in countTables and
 *      verifying a string once its count reaches the required minimum.
 *
 * Updates shortCandidateNo for every verification and resultNo for
 * every match.
 */
void FilterTree::processShortStrings(const string &targetWord,
                                     unsigned patternId,
                                     const vector<unsigned> *positions,
                                     unsigned edThreshold,
                                     const vector<LengthBucket *> *lengthBuckets,
                                     vector<unsigned>& results)
{
    unsigned i;
    unsigned distance;
    unsigned sourceLength;
    unsigned targetLength;
    unsigned sourcePosition;  // the position corresponding to "sourceId"
    unsigned targetPosition;  // the position corresponding to "targetWord"
    unsigned commonGramsNo;
    const string *sourceWord;
    signed differentLength;
    
    
    // Same threshold formula as in getApproximateMatch() and mergeSort().
    if(this-> function == standardFunction)
    {
        commonGramsNo = edThreshold * q - q + 1;
    }
    else
    {
	    commonGramsNo = edThreshold * (q+1) - q +1;
    } 
    
    // Pass 1: linear scan over the whole dataset for short strings.
    targetLength = targetWord.length();
    for ( i = 0; i < data->size(); i++ )
    {
        sourceWord = &((*data)[i]);
        sourceLength = sourceWord->length();
        differentLength = (signed)sourceLength - (signed)targetLength;
        
        // Only strings no longer than the threshold and whose length
        // differs from the query by at most edThreshold are verified.
        if ( (sourceLength <= commonGramsNo) &&
             ((differentLength) <= (signed)edThreshold) &&
             ((- differentLength) <= (signed)edThreshold) )
        {
            shortCandidateNo ++ ;
            if(this-> function == standardFunction)
            {
                distance = ed( (*sourceWord), targetWord );
            }
            else
            {
                distance = edSwap( (*sourceWord), targetWord );
            }
            
            if ( distance <= edThreshold )
            {
                resultNo ++ ;
                results.push_back( i );
            }
        }  // end if
    }  // end for 	
    
    // Pass 2: gram-based filtering for dataset strings longer than the
    // threshold.
    unsigned count = 0;
    unsigned sourceId; // ID of strings in the dataset (dictionary)	
    unsigned j,k;
    LengthBucket *lengthBucket;
    
    for (j=0; j< lengthBuckets->size(); j++) 
    {
        // gram level: one bucket per query gram found in the index
        lengthBucket = lengthBuckets->at(j);
        targetPosition = positions->at(j);
        
        // Admissible windows for string length and gram position, kept
        // signed because the lower bounds can be negative.
        signed lowBoundLen   =(signed)targetLength
                             -(signed)edThreshold;
        signed upperBoundLen =(signed)targetLength
                             +(signed)edThreshold;
        signed lowBoundPos   =(signed)targetPosition
                             -(signed)edThreshold;
        signed upperBoundPos =(signed)targetPosition
                             +(signed)edThreshold;
        unsigned numOfGroups = lengthBucket->getSize();
        for(i = 0; i < numOfGroups; i++)
        {
            // length level: one group per string length
            sourceLength = (lengthBucket->getGroups())[i]->getLength();
            // Strings at or below the threshold were handled in pass 1.
            if( (signed)sourceLength > (signed)commonGramsNo 
            && (signed)sourceLength >= lowBoundLen 
            && (signed)sourceLength <= upperBoundLen )
            {
                StringPosition** strPos = 
                (lengthBucket->getGroups())[i]->getPositions();
                unsigned numOfPositions = 
                (lengthBucket->getGroups())[i]->getSize();
                for(unsigned index = 0; index < numOfPositions; index++)
                {
                    // position level: one node per gram position
                    sourcePosition = strPos[index]->getPosition();
                    if( (signed)sourcePosition >= lowBoundPos 
                    && (signed)sourcePosition <= upperBoundPos )
                    {
                        unsigned numOfStrs = strPos[index]->getSize();
                        for( k = 0; k < numOfStrs; k++)
                        {
                            sourceId = (strPos[index]->getStrIDs())[k];
                            sourceWord = &(data->at(sourceId));
                            count = (countTables[sourceId])->getCount(patternId);
                            // Skip strings whose count already equals the
                            // table's maximum; note these length checks
                            // compare as unsigned, unlike the group test.
                            if ( (count != (countTables[sourceId])->getMaxCount())
                            && ((sourceLength > commonGramsNo)||
                            (targetLength > commonGramsNo)) )
                            {
                                count =
                                (countTables[sourceId])->addCount(patternId);
                                // Verify once the shared-gram count reaches
                                // the minimum required for both lengths.
                                if(((signed)count
                                >= ((signed)sourceLength - (signed)commonGramsNo))
                                &&
                                ((signed)count
                                >= ((signed)targetLength - (signed)commonGramsNo))) 
                                {
                                    shortCandidateNo ++ ;
                                    if(this-> function == standardFunction)
                                        distance = ed((*sourceWord), targetWord);
                                    else
                                        distance = edSwap((*sourceWord), targetWord);
                                    //If it is already candidate, then set the count to MAXCOUNT
                                    (countTables[sourceId])->resetCount(patternId);
                                    if (distance <= (unsigned)edThreshold) 
                                    {
                                        resultNo ++ ;
                                        results.push_back(sourceId);
                                    }
                                }
                            }
                        }
                    }
                    // Past the upper position bound: skip the remaining
                    // position nodes (presumably stored in ascending
                    // order - TODO confirm against StringPosition storage).
                    else if ((signed)sourcePosition > upperBoundPos)
                    {
                        break;
                    }
                }
            }
            // Past the upper length bound: skip the remaining length
            // groups (presumably stored in ascending order - TODO confirm).
            else if((signed)sourceLength > upperBoundLen)
            {
                break;
            }
        }
    }	
 }
 
void FilterTree::mergeSort(const string & targetWord,
                           unsigned patternId,
                           const vector<unsigned> *positions,
                           unsigned edThreshold,
                           const vector<LengthBucket *> *lengthBuckets,
                           vector<unsigned>& results)
{
    unsigned count = 0;
    unsigned distance;      
    unsigned sourceId;        // ID of strings in the dataset (dictionary)
    unsigned sourcePosition;  // the position corresponding to "sourceId"
    unsigned targetPosition;  // the position corresponding to "targetPattern"
    unsigned targetLength;    // the length of the pattern
    unsigned sourceLength;    // the length of a string in the dataset
    const string   *sourceWord;	  // a string in the dataset (dictionary)	
    unsigned commonGramsNo;
    
    if(this-> function == standardFunction)
        commonGramsNo = edThreshold * q - q + 1;
    else
        commonGramsNo = edThreshold * (q+1) - q +1; 
    
    unsigned i,j,k;
    LengthBucket *lengthBucket;
    targetLength = targetWord.length();  	
	
    for (j=0; j< lengthBuckets->size(); j++) {
        //traverse in gram level
        lengthBucket = lengthBuckets->at(j);
        targetPosition = positions->at(j);
       
	    signed lowBoundLen = (signed)targetLength - (signed)edThreshold;
	    signed upperBoundLen = (signed)targetLength + (signed)edThreshold;
	    
	    signed lowBoundPos = (signed)targetPosition - (signed)edThreshold;
	    signed upperBoundPos = (signed)targetPosition + (signed)edThreshold;
	    
	    unsigned numOfGroups = lengthBucket->getSize();
	    for(i = 0; i < numOfGroups; i++)
	    {
	        //traverse in length level
	        sourceLength = (lengthBucket->getGroups())[i]->getLength();
	        if( (signed)sourceLength >= lowBoundLen 
	        && (signed)sourceLength <= upperBoundLen )
	        {
	            StringPosition** strPos =
	            (lengthBucket->getGroups())[i]->getPositions();
	            unsigned numOfPositions =
	            (lengthBucket->getGroups())[i]->getSize();
	            for(unsigned index = 0; index < numOfPositions; index++)
	            {
	                //traverse in position level
	                sourcePosition = strPos[index]->getPosition();
	                if((signed)sourcePosition >= lowBoundPos 
	                && (signed)sourcePosition <= upperBoundPos )
	                {
	                    unsigned numOfStrs = strPos[index]->getSize();
	                    for( k = 0; k < numOfStrs; k++)
	                    {
	                        sourceId = (strPos[index]->getStrIDs())[k];
	                        sourceWord = &(data->at(sourceId));
	                        count = (countTables[sourceId])->getCount(patternId);
	                        
	                        if ( (count != (countTables[sourceId])->getMaxCount())
	                        && ((sourceLength > commonGramsNo) ||
	                        (targetLength > commonGramsNo)))
	                        {
	                            count = (countTables[sourceId])->addCount(patternId);
	                            if (((signed)count
	                            >= ((signed)sourceLength - (signed)commonGramsNo))
	                            &&
	                            ((signed)count
	                            >= ((signed)targetLength - (signed)commonGramsNo))) 
	                            {
	                                candidateNo ++ ;
	                                if(this-> function == standardFunction)
	                                    distance = ed((*sourceWord), targetWord);
	                                else
	                                    distance = edSwap((*sourceWord), targetWord);
	                                //If it is already candidate, then set the count to MAXCOUNT
	                                (countTables[sourceId])->resetCount(patternId);
	                                
	                                if (distance <= (unsigned)edThreshold) {
	                                	resultNo ++ ;
	                                	results.push_back(sourceId);
	                                }
	                            }
	                        }
	                    }
	                }
	                //Skip the rest nodes in position level
	                else if ((signed)sourcePosition > upperBoundPos)
	                {
	                    break;
	                }
	            }
	        }
	        //Skip the rest nodes in length level
	        else if((signed)sourceLength > upperBoundLen)
	        {
	            break;
	        }
	    }
    }	
}
 
void FilterTree::reportAnalyses()
{
    cout<< "\n==============================================================" << endl;
    cout<< "= Number of strings in workload = " << patternNo  << endl;
    cout<< "= Total CPU time for approximate string processing = " << totalTime << " s" << endl;
    cout<< "= Total Number of candidates for long strings: " << (unsigned)candidateNo << endl;
    cout<< "= Total Number of candidates for short strings: " << (unsigned)shortCandidateNo << endl;
    cout<< "= Total Number of long queries: " << (unsigned)queryNo << endl;
    cout<< "= Total Number of short queries: " << (unsigned)shortQueryNo << endl;
    cout<< "= Total Number of results: " << (double)resultNo << endl;
    cout<< "= Total Memory: " << headBuckets->getMemory() <<endl;
    cout<< "==============================================================" << endl;
}
