ext/hash_map:进一步提高以字符串为键的哈希表的性能

来源:互联网 发布:好东东网络课 编辑:程序博客网 时间:2024/06/02 01:36

http://hi.baidu.com/ah__fu/item/c8a1d4e17e79e5f52b09a458

当我们在ext/hash_map中使用string或const char*作为键的时候,通常需要用一个HASH函数将字符串转换为一个整型值(size_t,在32位平台上是32位),然后再对桶的数量(一个质数)取模,最终将节点分布到不同的桶里面去。
     可是,当作为键的字符串很长时,每次进行插入、查找和删除操作,都要对字符串调用一次HASH函数。于是可以想到,如果我们预先计算出字符串的HASH值,每次都通过HASH值来查找,就省去了使用HASH函数的开销。
     首先,在设计的时候,字符串的哈希表应看成 string -> struct ,就是通过字符串去索引一个节点。而字符串本身需要存储空间,所以可以把字符串需要的空间存储在节点上。可以这样设计这个节点:
      // A hash-table node that carries its own key string together with the key's
      // precomputed hash value and length, so they need not be recomputed per operation.
      struct HashNode
      {
          char Key[100];     // string that indexes this node; must be unique
          size_t HashCode;      // hash of Key, computed once with the string hash function
          short KeyLength;      // length of Key, cached alongside the hash value
          // any other per-node data follows these fields
      };

      在想好如何存储后,下面就需要考虑如何用节点中的HASH值代替字符串HASH函数计算出来的HASH值。其实,ext/hash_map有五个模板参数,分别是:
       hash_map<键的类型,值的类型,HASH函数,比较函数,分配器>
       我们只需要写一个函数对象来产生HASH值就行了:
struct NodeHasher
{
      size_t operator()(const HashNode* node) const
      {
          return node->HashCode;      //直接返回计算好的HASH值就行,避免使用HASH函数
      }
};
       有了HASH函数,我们还必须写一个节点比较的函数对象,用于判断两个节点是否相等:
struct NodeCompare
{
      bool operator()(const HashNode* lsh, const HashNode* rsh) const
      {
          return lsh->HashCode == rsh->HashCode &&  //为什么不只比较HASH值就可以了呢?
                    lsh->KeyLength == rsh->KeyLength &&  //因为两个不同的字符串可能产生同样的HASH值
                    0==strcmp(lsh->Key, rsh->Key);               //所以在发生HASH冲突的时候一定要比较内容
      }
};

       OK, 下面就可以这样建立哈希表了:
           hash_map<HashNode*, HashNode*, NodeHasher, NodeCompare> hash;
       要注意:使用预先计算HASH值的方法时,一定要在添加、查找和删除之前算好节点的HASH值。字符串的HASH值可以这样计算(__stl_hash_string是libstdc++内部的非标准函数):
        size_t hash_code = __gnu_cxx::__stl_hash_string("string");

        预先计算HASH值真的能提高HASH表的性能吗?请看下面的测试代码:

#include <assert.h>
#include <stdio.h>
#include <string.h>
#ifdef _MSC_VER
#include <intrin.h>   /* __rdtsc() used by rdtsc() */
#endif

#include <functional>
#include <utility>
#include <ext/hash_map>

using namespace __gnu_cxx;
using namespace std;

// Size of a user-name buffer in bytes, including the terminating NUL.
#define MAX_USERNAME_LEN 40
// Debug print: prefixes the message with source file, function and line, and appends a newline.
// Uses the GNU ##__VA_ARGS__ extension so P("text") with no extra arguments also compiles.
#define P(format, ...) printf("%s %s %d " format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)

//调用INTEL CPU指令RDTSC来获得时间计数,便于得到代码段的性能指标
// Read the x86 time-stamp counter (RDTSC instruction) so a code section can be timed in
// CPU clock ticks: call it before and after, and subtract.
// Note: RDTSC is not a serializing instruction, so out-of-order execution can blur very
// short measurements; it is adequate for the long loops timed in this file.
unsigned long long rdtsc()
{
#ifdef _MSC_VER /* msvc compiler */
      // Previously this emitted the raw opcode bytes (0F 31) via __asm and had no return
      // statement, relying on EDX:EAX being returned implicitly. x64 MSVC has no inline
      // assembly at all. The __rdtsc() intrinsic works on both x86 and x64.
      return __rdtsc();
#else /* gcc compiler */
      unsigned int low, high;
      __asm__ __volatile__("rdtsc" : "=a" (low), "=d" (high));
      // RDTSC places the 64-bit counter in EDX:EAX; recombine the two halves.
      return (static_cast<unsigned long long>(high) << 32) | low;
#endif
}

// One user record loaded from users.txt. The cached-hash table also uses a pointer to
// this record as its key.
struct UserInfo
{
      char UserName[MAX_USERNAME_LEN];   // user name; test() copies at most MAX_USERNAME_LEN-1 chars into it
      short Length;                      // strlen(UserName), cached by test()
      unsigned int HashCode;             // __stl_hash_string(UserName), cached by test(); the size_t result is narrowed to unsigned int
};

// Capacity (in records) of the pUsers array that test() allocates.
#define MAX_USERS 600000
UserInfo* pUsers = NULL;   // user array allocated and freed by test()
int UserCount = 0;         // number of pUsers entries filled by test()

// Benchmarks for the three table variants (defined below); each times insert, then find, then erase.
void make_hash(UserInfo* Users, int UserCount);
void make_char_hash(UserInfo* Users, int UserCount);
void make_CachedHash(UserInfo* Users, int UserCount);

void test()
{
      //读入用户
      FILE* fp = fopen("users.txt", "r");
      if (NULL==fp)
      {
          P("open file error");
          return;
      }
      pUsers = new UserInfo[MAX_USERS];
      char temp_username[MAX_USERNAME_LEN];
      while (NULL!=fgets(temp_username, sizeof(temp_username), fp))
      {
          strncpy(pUsers[UserCount].UserName, temp_username, sizeof(pUsers[0].UserName)-1);
          pUsers[UserCount].Length = strlen(pUsers[UserCount].UserName);
          pUsers[UserCount].HashCode = __gnu_cxx::__stl_hash_string(pUsers[UserCount].UserName);
          UserCount++;
      }
      fclose(fp);
      fp = NULL;
      P("user count=%d", UserCount);
      //
      make_hash(pUsers, UserCount);
      make_char_hash(pUsers, UserCount);
      make_CachedHash(pUsers, UserCount);
      delete[] pUsers;
      pUsers = NULL;
}

int main()
{
      test();
      return 1;
}

//=============================================================================

#include <string>
using namespace std;

// Hash functor for std::string keys: runs libstdc++'s __stl_hash_string over the
// string's C representation on every call.
struct str_hash
{
      size_t operator()(const string& s) const
      {
          const char* text = s.c_str();
          return __gnu_cxx::__stl_hash_string(text);
      }
};

// std::string-keyed table of user records; keys are hashed with str_hash.
typedef hash_map<string, UserInfo*, str_hash> StringHash;
typedef StringHash::iterator StringHashIterator;

void test_string_find(StringHash& hash, UserInfo* Users, int UserCount)
{
      int i;
      StringHashIterator it;
      unsigned long long start, end;
      start = rdtsc();
      for (i=0; i<UserCount; i++)
      {
          it = hash.find(Users[i].UserName);
          if (it == hash.end())
          {
              P("not found %s", Users[i].UserName);
          }
      }
      end = rdtsc();
      end -= start;
      P("string find \t= %llu", end);
}

// Time erasing every user name from the table and print the elapsed RDTSC ticks,
// then assert that the table ended up empty.
void test_string_erase(StringHash& hash, UserInfo* Users, int UserCount)
{
      const unsigned long long begin = rdtsc();
      int idx = 0;
      while (idx < UserCount)
      {
          hash.erase(Users[idx].UserName);
          ++idx;
      }
      const unsigned long long elapsed = rdtsc() - begin;
      P("string erase \t= %llu", elapsed);
      assert(hash.size()==0);
}

// Benchmark the std::string-keyed table: time inserting every user, then run the
// find and erase benchmarks on the filled table.
void make_hash(UserInfo* Users, int UserCount)
{
      StringHash table;   // same type as hash_map<string, UserInfo*, str_hash>
      const unsigned long long begin = rdtsc();
      for (int idx = 0; idx < UserCount; ++idx)
      {
          UserInfo* record = Users + idx;
          table.insert(make_pair(record->UserName, record));
      }
      const unsigned long long elapsed = rdtsc() - begin;
      P("string spend\t= %llu", elapsed);
      test_string_find(table, Users, UserCount);
      test_string_erase(table, Users, UserCount);
}
//=============================================================================
// Make std::equal_to<const char*> compare C strings by content (strcmp) instead of by
// pointer value. This lets a hash_map<const char*, ...> that uses the default
// key-equality functor treat two different buffers holding the same text as one key.
// NOTE(review): the standard only allows specializing std templates for program-defined
// types, so specializing for const char* is technically undefined behavior. It also changes
// equal_to<const char*> for this whole translation unit. A dedicated functor passed to
// hash_map as its 4th template argument would be cleaner.
namespace std
{
      template <>
      struct equal_to<const char*>
      {
          // These typedefs used to come from std::binary_function, which was deprecated in
          // C++11 and removed in C++17; declaring them directly keeps adaptor compatibility.
          typedef const char* first_argument_type;
          typedef const char* second_argument_type;
          typedef bool result_type;

          bool operator()(const char* str1, const char* str2) const
          {
              return 0==strcmp(str1, str2);
          }
      };
}

// Hash functor for NUL-terminated C strings using libstdc++'s __stl_hash_string.
// NOTE(review): not used by the visible code. CharHash uses hash<const char*> instead, and
// the only mention of char_hash is a commented-out declaration in make_char_hash.
struct char_hash
{
      size_t operator()(const char* text) const
      {
          const size_t code = __gnu_cxx::__stl_hash_string(text);
          return code;
      }
};

// Table keyed by raw C-string pointers into the user array. The hash functor is the SGI/libstdc++
// extension __gnu_cxx::hash<const char*>. Key equality uses the default
// std::equal_to<const char*>, which this file specializes to compare with strcmp.
// The hash is namespace-qualified on purpose: with both `using namespace std;` and
// `using namespace __gnu_cxx;` in effect, an unqualified `hash` is ambiguous with C++11's
// std::hash and fails to compile.
typedef hash_map<const char*, UserInfo*, __gnu_cxx::hash<const char*> > CharHash;
typedef CharHash::iterator CharHashIterator;

void test_char_find(CharHash& hash, UserInfo* Users, int UserCount)
{
      int i;
      CharHashIterator it;
      unsigned long long start, end;
      start = rdtsc();
      for (i=0; i<UserCount; i++)
      {
          it = hash.find(Users[i].UserName);
          if (it == hash.end())
          {
              P("not found %s", Users[i].UserName);
          }
      }
      end = rdtsc();
      end -= start;
      P("char hash find \t= %llu", end);
}

// Time erasing every user name from the C-string table and print the elapsed RDTSC
// ticks, then assert that the table ended up empty.
void test_char_erase(CharHash& hash, UserInfo* Users, int UserCount)
{
      const unsigned long long begin = rdtsc();
      int idx = 0;
      while (idx < UserCount)
      {
          hash.erase(Users[idx].UserName);
          ++idx;
      }
      const unsigned long long elapsed = rdtsc() - begin;
      P("char erase \t= %llu", elapsed);
      assert(hash.size()==0);
}

// Benchmark the C-string-keyed table: time inserting every user, then run the find and
// erase benchmarks on the filled table. The keys point into the Users array, so Users must
// outlive the table.
void make_char_hash(UserInfo* Users, int UserCount)
{
      // Alternative declaration using the local functor:
      //hash_map<const char*, UserInfo*, char_hash, std::equal_to<const char*> > hh;
      CharHash table;
      const unsigned long long begin = rdtsc();
      for (int idx = 0; idx < UserCount; ++idx)
      {
          UserInfo* record = Users + idx;
          table.insert(make_pair(record->UserName, record));
      }
      const unsigned long long elapsed = rdtsc() - begin;
      P("char spend\t= %llu", elapsed);
      test_char_find(table, Users, UserCount);
      test_char_erase(table, Users, UserCount);
}

//=============================================================================
//预先缓存HASH值
struct UserInfoHasher
{
      size_t operator()(const UserInfo* node) const
      {
          return node->HashCode;
      }
};

struct UserInfoCompare
{
      bool operator()(const UserInfo* lsh, const UserInfo* rsh) const
      {
          return lsh->HashCode == rsh->HashCode &&
                  lsh->Length == rsh->Length &&
                  strcmp(lsh->UserName, rsh->UserName)==0;
      }
};

typedef hash_map<UserInfo*, UserInfo*, UserInfoHasher, UserInfoCompare> CachedHash;
typedef hash_map<UserInfo*, UserInfo*, UserInfoHasher, UserInfoCompare>::iterator CachedHashIterator;

void test_cached_find(CachedHash& hash, UserInfo* Users, int UserCount)
{
      int i;
      CachedHashIterator it;
      unsigned long long start, end;
      start = rdtsc();
      for (i=0; i<UserCount; i++)
      {
          it = hash.find(Users+i);
          if (it == hash.end())
          {
              P("not found %s", Users[i].UserName);
          }
      }
      end = rdtsc();
      end -= start;
      P("cached find \t= %llu", end);
}

// Time erasing every user record from the cached-hash table and print the elapsed RDTSC
// ticks, then assert that the table ended up empty.
void test_cached_erase(CachedHash& hash, UserInfo* Users, int UserCount)
{
      const unsigned long long begin = rdtsc();
      int idx = 0;
      while (idx < UserCount)
      {
          hash.erase(Users + idx);
          ++idx;
      }
      const unsigned long long elapsed = rdtsc() - begin;
      P("cached erase \t= %llu", elapsed);
      assert(hash.size()==0);
}

// Benchmark the cached-hash table: time inserting every user record (used as both key and
// value), then run the find and erase benchmarks on the filled table.
void make_CachedHash(UserInfo* Users, int UserCount)
{
      CachedHash table;
      const unsigned long long begin = rdtsc();
      for (int idx = 0; idx < UserCount; ++idx)
      {
          UserInfo* record = Users + idx;
          table.insert(make_pair(record, record));
      }
      const unsigned long long elapsed = rdtsc() - begin;
      P("cached spend\t= %llu", elapsed);
      test_cached_find(table, Users, UserCount);
      test_cached_erase(table, Users, UserCount);
}

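/*
g++ -o string_hash.o -c string_hash.cpp -g -Wall
g++ -o string_hash.exe string_hash.o
(When comparing timings, also build with optimization, e.g. -O2; unoptimized numbers are not representative.)
*/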

      下面是几种HASH方式的性能对比(数值是rdtsc指令得到的时钟周期数)。可以明显地发现,预先计算HASH值后性能确实有所提升,而且字符串越长,提升的效果越明显。

      方式            插入            查找            删除
      string          2635731072      1263602970      1847298339
      const char*     2198134872      807559812       1156917744
      cache_hash      2015849286      740948535       1109622195