Trie树实现三----双数组trie树

来源:互联网 发布:溜冰鞋淘宝网 编辑:程序博客网 时间:2024/05/21 10:11

双数组 Trie 结合了数组形式转换表和链表形式转换表的优点:不仅查询速度快,而且内存占用紧凑。

下面的代码只是 Trie 树的一个简单实现,没有考虑动态扩展 base 和 check 数组。具体实现过程的详细讲解请参考论文 An Efficient Implementation of Trie Structures(Jun-ichi Aoe 等,Software: Practice and Experience,1992)。

附上代码:

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <stdexcept>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <assert.h>
using namespace std;

/*
 * Double-array trie over keys made of the lowercase letters 'a'..'z'
 * (layout follows "An Efficient Implementation of Trie Structures").
 *
 *   - A transition from node s on character c leads to t = base[s] + mapTo(c)
 *     and exists only if check[t] == s.
 *   - base[t] > 0 : t is an internal node; base[t] is its child offset.
 *   - base[t] < 0 : t is a leaf; the rest of its key is stored in `tail`
 *                   starting at index -base[t] and terminated by '#'.
 *   - A key that ends at an internal node is marked by a child reached
 *     through '#' whose base is -1 (tail[1] is always '#').
 *
 * base and check have a fixed capacity of 10024 slots and are never grown;
 * minQ() throws length_error once no suitable free base remains.
 */
class DoubleArrayTrie
{
public:
    typedef string::size_type size_type;
    static const int ROOT = 1;
    // Largest code produced by mapTo(char) for a valid character ('z').
    static const int MAX_CODE = 'z' - 'a' + 2;
    vector<int> base;
    vector<int> check;
    string tail;

    // tail starts as "##": index 0 is never referenced and index 1 is the
    // shared empty suffix that '#' leaves point at (base == -1).
    DoubleArrayTrie() : base(10024, 0), check(10024, 0), tail("##")
    {
        base[ROOT] = 1;
    }

    int mapTo(char ch);
    char mapTo(int d);
    pair<string, string> extractCommonPrefix(const char* const& str1, const char* const& str2);
    pair<vector<int>, vector<int> > extractCollision(const int& from1, const int& from2);
    void resolveCollision(const int& from, const vector<int>& list, const int& app);
    int minQ(string commonPrefix);
    void relabel(int beg, int end);
    void relabel(int beg);
    bool insert(const char*& key);
    bool remove(const char*& key);
    bool retrival(const char*& key);
    void display();

private:
    static bool isValidKey(const char* key);
    void storeLeaf(int node, char ch, const char* rest);
};

const int DoubleArrayTrie::ROOT;
const int DoubleArrayTrie::MAX_CODE;

// Debug helper: prints base[1..30], check[1..30] and the whole tail buffer.
void DoubleArrayTrie::display()
{
    for (int i = 1; i <= 30; ++i)
        cout << base[i] << " ";
    cout << endl;
    for (int i = 1; i <= 30; ++i)
        cout << check[i] << " ";
    cout << endl;
    cout << "tail = " << tail << endl;
}

// Maps the end marker '#' to 1 and 'a'..'z' to 2..27.
int DoubleArrayTrie::mapTo(char ch)
{
    return (ch == '#') ? 1 : (ch - 'a' + 2);
}

// Inverse of mapTo(char): 1 -> '#', 2..27 -> 'a'..'z'.
char DoubleArrayTrie::mapTo(int d)
{
    return (d == 1) ? '#' : static_cast<char>('a' + d - 2);
}

// True when key is non-null and consists only of 'a'..'z' (the empty key is
// allowed). Any other character would map outside 1..MAX_CODE.
bool DoubleArrayTrie::isValidKey(const char* key)
{
    if (!key)
        return false;
    for (const char* p = key; *p; ++p)
    {
        if (*p < 'a' || *p > 'z')
            return false;
    }
    return true;
}

// Turns the freshly claimed slot `node`, reached through character ch, into a
// leaf. An end-marker leaf has no suffix and uses base -1. Otherwise the
// remaining characters `rest` are appended to tail, followed by '#'.
void DoubleArrayTrie::storeLeaf(int node, char ch, const char* rest)
{
    if (ch == '#')
    {
        base[node] = -1;
    }
    else
    {
        base[node] = -static_cast<int>(tail.size());
        tail.append(rest).append(1, '#');
    }
}

// Returns (common prefix of str1 and str2, two-character string holding the
// first differing character of str1 then of str2), with '#' standing for the
// end of a string. The scan stops at str1's terminator, so identical inputs
// are never read past their end (they yield "##").
pair<string, string> DoubleArrayTrie::extractCommonPrefix(const char* const& str1, const char* const& str2)
{
    string first, second;
    int i = 0;
    while (str1[i] != '\0' && str1[i] == str2[i])
    {
        first.append(1, str1[i]);
        ++i;
    }
    second.append(1, (str1[i] == '\0') ? '#' : str1[i]);
    second.append(1, (str2[i] == '\0') ? '#' : str2[i]);
    return pair<string, string>(first, second);
}

// Collects the slots of all children of from1 (first) and of from2 (second).
pair<vector<int>, vector<int> > DoubleArrayTrie::extractCollision(const int& from1, const int& from2)
{
    vector<int> list1, list2;
    for (size_type i = 0; i < check.size(); ++i)
    {
        if (check[i] == from1)
            list1.push_back(static_cast<int>(i));
        else if (check[i] == from2)
            list2.push_back(static_cast<int>(i));
    }
    return pair<vector<int>, vector<int> >(list1, list2);
}

// Smallest base q such that every character of commonPrefix has a free slot
// at q + mapTo(c). Only bases whose whole child range q+1 .. q+MAX_CODE lies
// inside the arrays are offered, so later probes base + mapTo(c) stay in bounds.
// Throws length_error when the fixed capacity is exhausted.
int DoubleArrayTrie::minQ(string commonPrefix)
{
    const int limit = static_cast<int>(check.size()) - MAX_CODE;
    for (int q = 1; q < limit; ++q)
    {
        size_type i = 0;
        while (i < commonPrefix.size() && !check[q + mapTo(commonPrefix[i])])
            ++i;
        if (i == commonPrefix.size())
            return q;
    }
    throw length_error("DoubleArrayTrie: base/check capacity exhausted");
}

// Gives node `from` a new base and moves its children (slots in `list`)
// accordingly. When app != -1, the character that slot `app` encodes relative
// to the old base is also reserved, so that the caller can place a new child there.
void DoubleArrayTrie::resolveCollision(const int& from, const vector<int>& list, const int& app)
{
    string chs;
    for (size_type i = 0; i < list.size(); ++i)
        chs.append(1, mapTo(list[i] - base[from]));
    if (app != -1)
        chs.append(1, mapTo(app - base[from]));

    const int oldBase = base[from];
    base[from] = minQ(chs);

    for (size_type i = 0; i < list.size(); ++i)
    {
        const int oldTo = list[i];
        const int newTo = base[from] + (oldTo - oldBase);
        base[newTo] = base[oldTo];
        check[newTo] = check[oldTo];
        // Grandchildren of a moved internal node must point at its new slot.
        if (base[oldTo] > 0)
            replace(check.begin(), check.end(), oldTo, newTo);
        base[oldTo] = 0;
        check[oldTo] = 0;
    }
}

// Overwrites tail[beg..end] with '?' (marks the characters as unused).
void DoubleArrayTrie::relabel(int beg, int end)
{
    for (int i = beg; i <= end; ++i)
        tail[i] = '?';
}

// Overwrites a tail suffix starting at beg, including its terminating '#',
// with '?'.
void DoubleArrayTrie::relabel(int beg)
{
    int i = beg;
    while (tail.at(i) != '#')
    {
        tail[i] = '?';
        ++i;
    }
    tail[i] = '?';
}

// Inserts key. Returns true if the key was added, and false if it was already
// present or contains a character outside 'a'..'z'. Throws length_error if
// base/check run out of space.
bool DoubleArrayTrie::insert(const char*& key)
{
    if (!isValidKey(key))
        return false;

    const size_type len = strlen(key);
    int from = ROOT;
    // Walk the len real characters and then the end marker '#', so that a key
    // ending at an existing internal node still receives its '#' leaf.
    for (size_type i = 0; i <= len; ++i)
    {
        const char ch = (i < len) ? key[i] : '#';
        const char* rest = (i < len) ? key + i + 1 : key + len;
        int to = base[from] + mapTo(ch);

        if (!check[to])
        {
            // Free slot: the remainder of the key is distinguished from all others.
            check[to] = from;
            storeLeaf(to, ch, rest);
            return true;
        }

        if (check[to] == from)
        {
            if (ch == '#')
                return false;            // key already stored as a '#' leaf
            if (base[to] > 0)
            {
                from = to;               // shared internal node: keep walking
                continue;
            }

            // `to` is a leaf: split it where rest and its stored suffix diverge.
            const int beg = -base[to];
            const size_type end = tail.find('#', beg);
            assert(end != string::npos);
            const string suffix = tail.substr(beg, end - beg);
            if (suffix == rest)
                return false;            // duplicate key

            const pair<string, string> result = extractCommonPrefix(rest, suffix.c_str());
            const string& commonPrefix = result.first;
            const string& twoChars = result.second;
            const size_type shared = commonPrefix.size();

            // One internal node per shared character, each with its own base, so
            // repeated characters (e.g. "aa") never land on the same slot.
            for (size_type j = 0; j < shared; ++j)
            {
                const int q = minQ(commonPrefix.substr(j, 1));
                base[to] = q;
                from = to;
                to = q + mapTo(commonPrefix[j]);
                check[to] = from;
            }

            // Branch node whose two children are the first differing characters.
            base[to] = minQ(twoChars);
            from = to;

            // Side of the key being inserted.
            to = base[from] + mapTo(twoChars[0]);
            check[to] = from;
            storeLeaf(to, twoChars[0], (twoChars[0] == '#') ? rest + shared : rest + shared + 1);

            // Side of the existing key: reuse its suffix in place, skipping the
            // characters that are now represented by nodes.
            to = base[from] + mapTo(twoChars[1]);
            check[to] = from;
            if (twoChars[1] == '#')
                base[to] = -1;
            else
                base[to] = -(beg + static_cast<int>(shared) + 1);
            return true;
        }

        // The slot belongs to another node: relocate the parent with fewer children.
        const int owner = check[to];
        const pair<vector<int>, vector<int> > cols = extractCollision(owner, from);
        const vector<int>& ownerKids = cols.first;
        const vector<int>& fromKids = cols.second;
        if (ownerKids.size() <= fromKids.size())
        {
            // `from` may itself be one of the owner's children and get moved.
            const int oldOwnerBase = base[owner];
            const bool fromMoves = find(ownerKids.begin(), ownerKids.end(), from) != ownerKids.end();
            resolveCollision(owner, ownerKids, -1);
            if (fromMoves)
                from = base[owner] + (from - oldOwnerBase);
            // base[from] is unchanged, so `to` still addresses the freed slot.
        }
        else
        {
            resolveCollision(from, fromKids, to);
            to = base[from] + mapTo(ch);
        }
        assert(!check[to]);
        check[to] = from;
        storeLeaf(to, ch, rest);
        return true;
    }
    return false;   // not reached: the '#' step always returns
}

// Removes key if present. The leaf slot is cleared and its tail suffix is
// relabelled; internal nodes on the path are left in place.
bool DoubleArrayTrie::remove(const char*& key)
{
    if (!retrival(key))
        return false;

    int from = ROOT;
    for (int i = 0; key[i]; ++i)
    {
        const int to = base[from] + mapTo(key[i]);
        from = to;
        if (base[to] < 0)
        {
            relabel(-base[to]);
            base[to] = 0;
            check[to] = 0;
            return true;
        }
    }

    const int to = base[from] + mapTo('#');
    if (base[to] < 0)
    {
        base[to] = 0;
        check[to] = 0;
        return true;
    }
    return false;
}

// Returns true if key is stored in the trie. Null keys and keys with
// characters outside 'a'..'z' are reported as absent.
bool DoubleArrayTrie::retrival(const char*& key)
{
    if (!isValidKey(key))
        return false;

    int from = ROOT;
    int i = 0;
    while (key[i])
    {
        const int to = base[from] + mapTo(key[i]);
        ++i;
        if (check[to] != from)
            return false;
        from = to;
        if (base[from] < 0)
        {
            // Leaf: the rest of the key must equal the stored suffix exactly.
            int pos = -base[from];
            while (key[i] && tail[pos] != '#' && key[i] == tail[pos])
            {
                ++i;
                ++pos;
            }
            return !key[i] && tail[pos] == '#';
        }
    }
    const int to = base[from] + mapTo('#');
    return check[to] == from && base[to] == -1;
}

void Test()
{
    DoubleArrayTrie dat;
    const char* keys[] = {"bachelor", "jar", "badge", "baby", "int", "integer", "flow", "float"};
    int n = 8;
    for (int i = 0; i < n; ++i)
    {
        dat.insert(keys[i]);
    }
    for (int i = 0; i < n; ++i)
    {
        if (dat.retrival(keys[i]))
            cout << "find " << keys[i] << endl;
        else
            cout << "fail to find " << keys[i] << endl;
    }
    cout << "------------------------------" << endl;
    for (int i = 0; i < n; ++i)
    {
        dat.remove(keys[i]);
        for (int j = 0; j < n; ++j)
        {
            if (dat.retrival(keys[j]))
                cout << "find " << keys[j] << endl;
            else
                cout << "fail to find " << keys[j] << endl;
        }
        dat.display();
        cout << "------------------------------" << endl;
    }
}

int main()
{
    Test();
    return 0;
}