树状数组的学习小结

来源:互联网 发布:js数组添加数组concat 编辑:程序博客网 时间:2024/07/17 22:18

树状数组,又称二进制索引树,英文名Binary Indexed Tree。

一、树状数组的用途

主要用来求解数列的前缀和,a[0]+a[1]+...+a[n]。

由此引申出三类比较常见问题:

1、单点更新,区间求值。(HDU1166)

2、区间更新,单点求值。(HDU1556)

3、求逆序对。(HDU2838)

 

二、树状数组的表示

1、公式表示

设A[]为一个已知的数列。C[]为树状数组。则会有

C[i]=A[j]+...+A[i];j=i&(-i)=i&(i^(i-1))。

2、图形表示

(注:1、最下面的一行表示数组A,上面的二进制表示的部分是C;

2、图片来源于http://hi.baidu.com/rain_bow_joy/blog/item/569ec380c39730d2bc3e1eae.html)

 

从以上可以发现:

1、树状数组C是表示普通数组A的一部分的和。

2、小标为奇数时,C[i]只能管辖一个A[i]。

3、C[i]的最后一个数一定是A[i]。

 

三、树状数组的关键代码

1、

[cpp] view plaincopy
  1. int lowBit(int x)  
  2. {  
  3.     return x&(-x);  
  4. }  

这段代码可以简单的理解为是树状数组向前或向后衍生是用的。

向后主要是为了找到目前节点的父节点,比如要将C[4]+1,那么4+(4&(-4))=8,C[8]+1,8+(8&(-8))=16,

C[16]+1。

向前主要是为了求前缀和,比如要求A[1]+...+A[12]。那么,C[12]=A[9]+...+A[12];然后12-12&(-12)=8,

C[8]=A[1]+...+A[8]。

 

2、

[cpp] view plaincopy
  1. void modify(int pos,int num)  //pos为数组下标位置,num为要增加的值   
  2. {  
  3.     while(pos<=n)   //n为数组的长度   
  4.     {  
  5.         c[pos]+=num;  
  6.         pos+=lowBit(pos);  
  7.     }  
  8. }  

这段代码是用来更新树状数组的,包括区间更新、单点更新。

就是想刚才所说的,一点更新了,要不断将父节点也更新。

 

3、

[cpp] view plaincopy
  1. int getResult(int pos)  //求A[1]+...+A[pos]   
  2. {  
  3.     int sum=0;   
  4.     while(pos>0)  
  5.     {  
  6.         sum+=c[pos];  
  7.         pos-=lowBit(pos);  
  8.     }  
  9.       
  10.     return sum;   
  11. }                  

这段代码用来求解前缀和的。

就像刚才说的,求解A[1]+...+A[12],也就是C[12]+C[8]搞定。

 

四、树状数组的优点

1、原本的长度为n的数列求和时间复杂度为O(n),更改的时间复杂度为O(1)。

树状数组将其优化为O(logn)。在n较大时,效率更高。

2、树状数组编码简单。

 

五、注意

1、树状数组的下标要从1开始。

2、在学习的过程中遇到这么个问题。不知道为什么pos+pos&(-pos)就到了pos的父节点,也不知道

为什么pos-pos&(-pos)就得到了下一个无联系的节点,从而可以得到前缀和。

我只能说:我不懂如何证明,这是数学问题了,树状数组的发明者应该就是发现了这点才搞出树状

数组的吧。初学者不妨抛开这点,专注于事实,将上面的图形自己计算画一遍,非常有利于理解。

 

六、符代码:

HDU1166

单点更新,区间求值

[cpp] view plaincopy
  1. #include<iostream>  
  2. using namespace std;  
  3.   
  4. const int maxn=50001;  
  5.   
  6. int a[maxn];  
  7. int c[maxn];  
  8. int n;   
  9.   
  10. int lowBit(int t)  
  11. {  
  12.     return t&(-t);  
  13. }  
  14.   
  15. void modify(int t,int num)  
  16. {  
  17.     while(t<=n)  
  18.     {  
  19.         c[t]+=num;  
  20.         t+=lowBit(t);  
  21.     }  
  22. }  
  23.   
  24. int getResult(int t)  
  25. {  
  26.     int num=0;   
  27.     while(t>0)  
  28.     {  
  29.         num+=c[t];  
  30.         t-=lowBit(t);  
  31.     }  
  32.       
  33.     return num;   
  34. }  
  35.   
  36. void init()  
  37. {  
  38.     for(int i=1;i<=n;i++)  
  39.     {  
  40.         scanf("%d",&a[i]);  
  41.           
  42.         modify(i,a[i]);   
  43.     }  
  44. }  
  45.   
  46.   
  47. int main()  
  48. {  
  49.     int cas,Case=1;  
  50.       
  51.     scanf("%d",&cas);   
  52.     while(cas--)  
  53.     {   
  54.         memset(c,0,sizeof(c));   
  55.         printf("Case %d:\n",Case++);  
  56.            
  57.         scanf("%d",&n);  
  58.           
  59.         init();  
  60.           
  61.         char ch[15];  
  62.         int a,b;     
  63.         while(scanf("%s",&ch),strcmp(ch,"End"))  
  64.         {   
  65.             scanf("%d%d",&a,&b);  
  66.                
  67.             switch(ch[0])  
  68.             {  
  69.                 case 'Q':  
  70.                     printf("%d\n",getResult(b)-getResult(a-1));  
  71.                     break;   
  72.                 case 'A':   
  73.                     modify(a,b);  
  74.                     break;  
  75.                 case 'S':  
  76.                     modify(a,-b);  
  77.                     break;  
  78.             }  
  79.          }  
  80.      }  
  81.        
  82.      system("pause");  
  83.      return 0;  
  84. }   

 

HDU1556

区间更新,单点求值

[cpp] view plaincopy
  1. #include<iostream>  
  2. #include<cstring>  
  3. using namespace std;  
  4.   
  5. const int maxn=100001;  
  6.   
  7. int c[maxn];  
  8. int n;  
  9.   
  10. int lowbit(int t)  
  11. {  
  12.     return t&(-t);  
  13. }  
  14.   
  15. void insert(int t,int d)  
  16. {  
  17.     while(t<=n)  
  18.     {  
  19.         c[t]+=d;  
  20.         t+=lowbit(t);  
  21.     }  
  22. }  
  23.   
  24. int getSum(int t)  
  25. {  
  26.     int sum=0;  
  27.     while(t>0)  
  28.     {  
  29.         sum+=c[t];  
  30.         t-=lowbit(t);  
  31.     }  
  32.       
  33.     return sum;  
  34. }  
  35.   
  36. int main()  
  37. {  
  38.     while(cin>>n,n)  
  39.     {  
  40.         int a,b;  
  41.         memset(c,0,sizeof(c));  
  42.           
  43.         for(int i=1;i<=n;i++)  
  44.         {  
  45.             scanf("%d%d",&a,&b);  
  46.               
  47.             insert(a,1);  
  48.             insert(b+1,-1);  
  49.         }  
  50.           
  51.        for(int j=1;j<n;j++)  
  52.        {  
  53.             printf("%d ",getSum(j));  
  54.        }  
  55.        printf("%d\n",getSum(n));  
  56.     }  
  57.       
  58.     system("pause");  
  59.     return 0;  
  60. }  

 

HDU2838

求逆序对

[cpp] view plaincopy
  1. #include<iostream>  
  2. #include<cstring>  
  3. using namespace std;  
  4.   
  5. const int maxn=100001;  
  6.    
  7. struct node  
  8. {  
  9.     int cnt;  
  10.     __int64 sum;  
  11. }tree[maxn];           
  12.           
  13. int n;  
  14.   
  15. int lowBit(int x)  
  16. {  
  17.     return x&(-x);  
  18. }  
  19.   
  20. void modify(int x,int y,int t)  
  21. {  
  22.     while(x<=n)  
  23.     {  
  24.         tree[x].sum+=y;  
  25.         tree[x].cnt+=t;  //tree[].cnt来保存是否出现过a   
  26.         x+=lowBit(x);  
  27.     }  
  28. }  
  29.   
  30. __int64 query_cnt(int x)   //比x小的数的个数   
  31. {  
  32.     __int64 sum=0;  
  33.     while(x>0)  
  34.     {  
  35.         sum+=tree[x].cnt;  
  36.         x-=lowBit(x);  
  37.     }  
  38.       
  39.     return sum;  
  40. }  
  41.   
  42. __int64 query_sum(int x)  //比x小的所有数之和   
  43. {  
  44.     __int64 sum=0;  
  45.     while(x>0)  
  46.     {  
  47.         sum+=tree[x].sum;  
  48.         x-=lowBit(x);  
  49.     }  
  50.       
  51.     return sum;  
  52. }  
  53.   
  54. int main()  
  55. {  
  56.     while(~scanf("%d",&n))  
  57.     {  
  58.         int a;  
  59.         __int64 ans=0;   
  60.         memset(tree,0,sizeof(tree));  
  61.            
  62.         for(int i=1;i<=n;i++)  
  63.         {  
  64.             scanf("%d",&a);  
  65.               
  66.             modify(a,a,1);  //以a为下标更新数组   
  67.               
  68.             __int64 k1=i-query_cnt(a);   //k1为前i个数比a大的数的个数   
  69.             if(k1!=0)  
  70.             {  
  71.                 __int64 k2=query_sum(n)-query_sum(a); //目前所有数的和-目前所有比a小的数的和,为比a大的数的和     
  72.                 ans+=k1*a+k2;   //调换a所需的时间   
  73.             }  
  74.         }  
  75.           
  76.         printf("%I64d\n",ans);   
  77.     }  
  78.       
  79.     system("pause");  
  80.     return 0;  
  81. }   

 

七、二维树状数组

C[x][y]=sum(A[i][j])。其中,x-lowBit(x)+1<=i<=x,y-lowBit(y)+1<=j<=y。

例题:HDU1892

二维树状数组一般就是对矩阵的操作,更新、求值。。。

代码:

[cpp] view plaincopy
  1. #include<iostream>  
  2. #include<cstring>  
  3. using namespace std;  
  4.   
  5. const int maxn=1005;  
  6.   
  7. int c[maxn][maxn];  
  8.   
  9. int lowBit(int x)  
  10. {  
  11.     return x&(-x);  
  12. }  
  13.   
  14. void modify(int x,int y,int val)  
  15. {  
  16.     for(int i=x;i<maxn;i+=lowBit(i))  
  17.     {  
  18.         for(int j=y;j<maxn;j+=lowBit(j))  
  19.         {  
  20.             c[i][j]+=val;  
  21.         }  
  22.     }  
  23. }  
  24.   
  25. int getResult(int x,int y)  
  26. {  
  27.     int sum=0;  
  28.     for(int i=x;i>0;i-=lowBit(i))  
  29.     {  
  30.         for(int j=y;j>0;j-=lowBit(j))  
  31.         {  
  32.             sum+=c[i][j];  
  33.         }  
  34.     }  
  35.       
  36.     return sum;  
  37. }  
  38.   
  39. int getVal(int x,int y)  
  40. {  
  41.     return getResult(x,y)-getResult(x-1,y)-getResult(x,y-1)+getResult(x-1,y-1);  
  42. }  
  43.   
  44. void init()  
  45. {  
  46.     memset(c,0,sizeof(c));  
  47.       
  48.     for(int i=1;i<maxn;i++)  
  49.     {  
  50.         for(int j=1;j<maxn;j++)  
  51.         {  
  52.             modify(i,j,1);  
  53.         }  
  54.     }  
  55. }  
  56.       
  57. int main()  
  58. {  
  59.     int cas,cas1=1,query;  
  60.       
  61.     scanf("%d",&cas);  
  62.     while(cas--)  
  63.     {  
  64.         init();  
  65.           
  66.         scanf("%d",&query);  
  67.   
  68.         printf("Case %d:\n",cas1++);  
  69.         for(int i=1;i<=query;i++)  
  70.         {  
  71.             char ch;  
  72.             int x1,y1,x2,y2,n;  
  73.               
  74.             getchar();  
  75.             scanf("%c",&ch);  
  76.               
  77.             switch(ch)  
  78.             {  
  79.                 case 'S':  
  80.                     {  
  81.                     scanf("%d%d%d%d",&x1,&y1,&x2,&y2);  
  82.                     int x11=min(x1,x2);  
  83.                     int x22=max(x1,x2);  
  84.                     int y11=min(y1,y2);  
  85.                     int y22=max(y1,y2);  
  86.                     printf("%d\n",getResult(x22+1,y22+1)-getResult(x11,y22+1)-getResult(x22+1,y11)+getResult(x11,y11));  
  87.                     break;  
  88.                     }  
  89.                 case 'A':  
  90.                     {  
  91.                     scanf("%d%d%d",&x1,&y1,&n);  
  92.                     modify(x1+1,y1+1,n);  
  93.                     break;  
  94.                     }  
  95.                 case 'M':  
  96.                     {  
  97.                     scanf("%d%d%d%d%d",&x1,&y1,&x2,&y2,&n);  
  98.                     int v=getVal(x1+1,y1+1);  
  99.                     int Min=min(n,v);  
  100.                     modify(x1+1,y1+1,-Min);  
  101.                     modify(x2+1,y2+1,Min);  
  102.                     break;  
  103.                     }  
  104.                 case 'D':  
  105.                     {  
  106.                     scanf("%d%d%d",&x1,&y1,&n);  
  107.                     int v=getVal(x1+1,y1+1);  
  108.                     int Min=min(v,n);  
  109.                     modify(x1+1,y1+1,-Min);  
  110.                     break;  
  111.                     }  
  112.             }  
  113.         }  
  114.     }  
  115.       
  116.     system("pause");  
  117.     return 0;  
  118. }