经典字典树求异或最大值

来源:互联网 发布:备案域名转让 编辑:程序博客网 时间:2024/05/18 18:16

给出n个数,求在这n个数中取两个数使得他们异或最大

解题思路

非常经典利用字典树求异或的最大值。 
读懂这个题目之后,对于初学者我们可能会思考下面的问题: 
1. 两层循环o(n^2)肯定能求出了,但是肯定会超时 
2. 这道题怎么做呢? 
3. 思考一阵(在不知道用什么算法的情况下),会不会用把每个数转换为二进制,然后按位进行比对 
4. 就算转换为二进制之后,怎么进行按个比对呢?

上面的几个问题就是我,看到这个题目的时候思考的几个问题,我们知道字典是拥有公共前缀的,对应于一个十进制数的二进制来讲,每个二进制数能转换到字典树里面。 
因为题目中对数据的要求不超过10^9那么转换为二进制也就是不会超过2^32这个数,那么我们对每一个数的保存时通过把这个数变为32位数比如1保存的就是0(31)1。建立之后我们在通过一次循环找这个数的二进制的一根分支上有多少不同的,然后取最大的。不理解没有关系,我们用题目中的第一个样例来举例: 
样例是 

1 2 3 4 
那么我们先要把1这个转换为字典中的节点,一位要转换为32位的,所以前面的三十位都应该是0,最后一位是1那么转换为图就是下面这个图: 
建立第一个节点,然后我们就要在这个字典树中建立第二个节点也就是2,2对应的32位二进制位0(30)10,对应如下图 
建立第二个节点, 
一次的建立第三个节点和第四个节点,第三个节点如下: 
第三个节点
第四个节点: 
第四个节点, 
把整个字典树完全做好之后就变为下面的图: 
总图

由图我们可以知道这棵树一共有有32层,每一个节点最多拥有两个节点分别是0和1。那么我们怎么用代码把这个字典树建立起来呢? 
因为一个节点最多有两个节点,所以我们就用一个二维数组son[MAXN][2]来保存每个节点的值,son[i][alp]这个数组的意义表示编号为i的节点里面保存的是alp(0或者1)是孩子节点(存就保存的是下一个节点的编号)

void insert(long long a)    {        int x=0,alp;//x代表每个节点的标号        for(int i=31;i>=0;i--)        {            alp = (a>>i)&1;            if(!son[x][alp]) son[x][alp] = ++cnt;//如果这个节点没有孩子节点就创建它            x = son[x][alp];//将指针后移            //printf("[x=%d,alp=%d]\n",x,alp);        }    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

这样一颗完美的字典树就建立好了!,那么怎么查询出异或的最大值呢? 
查询的话就相对来说比较简单了! 
循环输入的n个数,每次找他的反方向,比如最开始的1,0(31)1这个32位二进制数我们按照常理来说应该是先找1因为1和0(31)1的第一位是不同的,但是这个节点不存在就没有办法访问到,我们只能在存在的分支上找和输出的数不同的数,我们还是按照第一个输入的这个十进制1来说明,看下图: 
查询过程
对应的查询代码如下:

long long find(int a)    {        int x=0,alp;        long long ret=0;        for(int i=31;i>=0;i--)        {            alp = !((a>>i)&1);  //取反查找            ret<<=1;     //因为是按照位的,所以是*2            if(son[x][alp]) x = son[x][alp],ret++;//如果和原来的那一为相反的存在的话,返回值就加上,并且在这个支路走            else x = son[x][!alp];  //按照相同的顺序找            //printf("{x=%d,alp=%d,ret=%d}\n",x,alp,ret);        }        //printf("\n------\n");        return ret;    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

AC代码

#include<cstdio>#include<cstring>#include<algorithm>#include<cmath>using namespace std;const int MAXN = 100005;long long a[MAXN];struct Trie{    int son[MAXN*13][3],cnt;//这里必须是最大值*13+    void init()    {        memset(son,0,sizeof son);        cnt = 0;    }    void insert(long long a)    {        int x=0,alp;//x代表每个节点的标号        for(int i=31;i>=0;i--)        {            alp = (a>>i)&1;            if(!son[x][alp]) son[x][alp] = ++cnt;//如果这个节点没有孩子节点就创建它            x = son[x][alp];//将指针后移            //printf("[x=%d,alp=%d]\n",x,alp);        }    }    long long find(int a)    {        int x=0,alp;        long long ret=0;        for(int i=31;i>=0;i--)        {            alp = !((a>>i)&1);  //取反查找            ret<<=1;     //因为是按照位的,所以是*2            if(son[x][alp]) x = son[x][alp],ret++;//如果和原来的那一为相反的存在的话,返回值就加上,并且在这个支路走            else x = son[x][!alp];  //按照相同的顺序找            //printf("{x=%d,alp=%d,ret=%d}\n",x,alp,ret);        }        //printf("\n------\n");        return ret;    }}trie;int main(){    int t,n;    scanf("%d",&t);    while(t--)    {        trie.init();        scanf("%d",&n);        for(int i=1;i<=n;i++)scanf("%lld",&a[i]),trie.insert(a[i]);        long long  ans = -1;        for(int i=1;i<=n;i++) ans = max(ans,trie.find(a[i]));        printf("%lld\n",ans);    }    return 0;}#include<iostream>using namespace std;const int N = 1e6 + 10;int a[N];struct TrieNode{    int count;    TrieNode* next[2];    TrieNode(){        count = 0;        next[0] = NULL;        next[1] = NULL;    }};void Insert(TrieNode* root, int value){    TrieNode* p = root;    for(int i=31; i>=0; i--) {        int temp = (value >> i) & 1;        if (p->next[temp] == NULL){            p->next[temp] = new TrieNode();        }        p = p->next[temp];        p->count++;    }}long long  Find(TrieNode* root, int value, int m) {    TrieNode *p = root;    long long result = 0;    for(int i=31; i>=0; i--) {        int val = (value>>i) & 1;        int mval = (m>>i) & 1;        if(mval == 1) {            p = p->next[val^1];        } else {            if (p->next[val^1] != NULL) {                result += p->next[val^1]->count;            }            p = p->next[val];        }        if (p == NULL)            break;    }    return result;}int main(){    int m,n;    long long result = 0;    cin>>n>>m;    TrieNode* root = new TrieNode();    for(int i=0; i<n; i++) {        cin>>a[i];        Insert(root, a[i]);    }    for(int i=0; i<n; i++) {        result += Find(root, a[i], m);    }    cout<<(result>>1)<<endl;    return 0;}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

如果不理解的话,可以把我注释的程序代码,取消注释就能直观的反映变化