树状数组解析与例题

来源:互联网 发布:查看linux系统编码 编辑:程序博客网 时间:2024/06/06 11:48
   树状数组是对一个数组改变某个元素和求和比较实用的数据结构。两中操作都是O(logn)。 

传统数组(共n个元素)的元素修改和连续元素求和的复杂度分别为O(1)和O(n)。树状数组通过将线性结构转换成伪树状结构(线性结构只能逐个扫描元素,而树状结构可以实现跳跃式扫描),使得修改和求和复杂度均为O(lgn),大大提高了整体效率。

给定序列(数列)A,我们设一个数组C满足

C[i] = A[i–2^k+ 1] + … + A[i]

其中,k为i在二进制下末尾0的个数,i从1开始算!

则我们称C为树状数组。

下面的问题是,给定i,如何求2^k?

答案很简单:2^k=i&(i^(i-1)) ,也就是i&(-i)    为什么呢?? 请看下面:

整数运算 x&(-x),当x为0时结果为0;x为奇数时,结果为1;x为偶数时,结果为x中2的最大次方的因子。
       因为:x &(-x) 就是整数x与其相反数(负号取反)的按位与:1&1=1,0&1 =0, 0&0 =1。具体分析如下:
       □ 当x为0时,x&(-x) 即 0 & 0,结果为0;
       □ 当x不为0时,x和-x必有一个为正。不失一般性,设x为正。
       ●当x为奇数时,最后一个比特为1,取反加1没有进位,故x和-x除最后一位外前面的位正好相反,按位与结果为0。最后一位都为1,故结果为       1。
       ●当x为偶数,且为2的m次方(m>0)时,x的二进制表示中只有一位是1(从右往左的第m+1位),其右边有m位0,左边也都是0(个数由表示   x的字        节数决定),故x取反加1后,从右到左第有m个0,第m+1位及其左边全是1。这样,x& (-x) 得到的就是x。 
       ●当x为偶数,却不为2的m次方的形式时,可以写作x= y * (2^k)。其中,y的最低位为1。实际上就是把x用一个奇数左移k位来表示。这时,x的   二进制         表示最右边有k个0,从右往左第k+1位为1。当对x取反时,最右边的k位0变成1,第k+1位变为0;再加1,最右边的k位就又变成了0,第   k+1位因为进         位的关系变成了1。左边的位因为没有进位,正好和x原来对应的位上的值相反。二者按位与,得到:第k+1位上为1,左边右边都为   0。结果为2^k,即         x中包含的2的最大次方的因子。
        总结一下:x&(-x),当x为0时结果为0;x为奇数时,结果为1;x为偶数时,结果为x中2的最大次方的因子。 比如x=32,其中2的最大次方因子  为                 2^5,故x&(-x)结果为32;当x=28,其中2的最大次方因子为4,故x & (-x)结果为4。当x=24,其中2的最大次方因子为8,故 x&(-x)结果为  8。

下面进行解释:

以i=6为例(注意:a_x表示数字a是x进制表示形式):

(i)_10 = (0110)_2

(i-1)_10=(0101)_2

i xor (i-1) =(0011)_2

i and (i xor (i-1))  =(0010)_2

2^k = 2

C[6] = C[6-2+1]+…+A[6]=A[5]+A[6]

数组C的具体含义如下图所示:

当我们修改A[i]的值时,可以从C[i]往根节点一路上溯,调整这条路上的所有C[]即可,这个操作的复杂度在最坏情况下就是树的高度即O(logn)。另外,对于求数列的前n项和,只需找到n以前的所有最大子树,把其根节点的C加起来即可。不难发现,这些子树的数目是n在二进制时1的个数,或者说是把n展开成2的幂方和时的项数,因此,求和操作的复杂度也是O(logn)。

树状数组能快速求任意区间的和:A[i] + A[i+1] + … + A[j],设sum(k) = A[1]+A[2]+…+A[k],则A[i] + A[i+1] + … + A[j] = sum(j)-sum(i-1)。


下面是别人总结的题目和代码(借鉴一下):

hdu 1541 Stars

题意:略。

思路:

树状数组经典入门题。

ans[i]: the amount of stars of the level i;

sum[i]: 横坐标为x的点,满足的the amount of the stars;

注意的地方:

(1)题目所给的点已经排好序了。

(2)由于x可能取0,而lowbit(0)=0,故add(0,1)会死循环。这就是为什么我一开始TLE的原因。所以将所有的 x++.

[cpp] view plaincopyprint?
  1. const int MAX1 = 15555, MAX2 = 32222;  
  2. int ans[MAX1], sum[MAX2], n,x,y;  
  3.   
  4. int lowbit(int x){  
  5.     return x & (-x);  
  6. }  
  7.   
  8. int getsum(int pos){  
  9.     int ret = 0;  
  10.     while(pos > 0){  
  11.         ret += sum[pos];  
  12.         pos -= lowbit(pos);  
  13.     }  
  14.     return ret;  
  15. }  
  16.   
  17. void add(int pos, int num){  
  18.     while(pos < MAX2){  
  19.         sum[pos] += num;  
  20.         pos += lowbit(pos);  
  21.     }  
  22. }  
  23.   
  24. int main()  
  25. {  
  26.     while(scanf("%d", &n) != EOF){  
  27.         memset(sum, 0, sizeof(sum));  
  28.         memset(ans, 0, sizeof(ans));  
  29.         FOR(i,1,n){  
  30.             scanf("%d%d", &x, &y);  
  31.             x++;   //注意加1,不然会在add(0,1)处死循环  
  32.             ans[getsum(x)]++;  
  33.             add(x, 1);  
  34.         }  
  35.         FOR(i,0,n-1)  
  36.             printf("%dn", ans[i]);  
  37.     }  
  38.     return 0;  
  39. }  



poj 2182 Lost Cows 
题意:有一个序列a:1,2,…,N(2 <= N <= 8,000). 现该序列为乱序,已知第i个数前面的有a[i]个小于它的数。求出该序列的排列方式。
思路:由后向前推。易知最后一个数的真实值为a[N]+1。将a[N]+1在序列中删去,更新a[i],那么第N-1个数的真实值为a[N-1]+1。由此类推。
由于数据范围较小,用两层for循环的简单方法就可以解决。
这里给出树状数组的解法:


[cpp] view plaincopyprint?
  1. const int MAX = 8010;  
  2. int a[MAX], n, cnt[MAX], ans[MAX];  
  3.   
  4. int lowbit(int x){  
  5.     return x & (-x);  
  6. }  
  7. void add(int pos, int val){  
  8.     while(pos <= n){  
  9.         cnt[pos] += val;  
  10.         pos += lowbit(pos);  
  11.     }  
  12. }  
  13. int sum(int pos){  
  14.     int res = 0;  
  15.     while(pos > 0){  
  16.         res += cnt[pos];  
  17.         pos -= lowbit(pos);  
  18.     }  
  19.     return res;  
  20. }  
  21.   
  22. //二分找到第一个等于x的位置  
  23. int binary_search(int x){  
  24.     int low = 1, high = n, mid;  
  25.     while(low <= high){  
  26.         mid = (low + high) >> 1;  
  27.         int k = sum(mid);  
  28.         if(k >= x) high = mid-1;  
  29.         else low = mid+1;  
  30.     }  
  31.     while(sum(mid) < x) mid++;  
  32.     return mid;  
  33. }  
  34.   
  35. int main()  
  36. {  
  37.     while(scanf("%d", &n) != EOF){  
  38.         a[1] = 0;  
  39.         FOR(i,2,n) scanf("%d", &a[i]);  
  40.         memset(cnt, 0, sizeof(cnt));  
  41.         FOR(i,1,n) add(i,1);  
  42.         for(int i = n; i >= 1; i--){  
  43.             ans[i] = binary_search(a[i]+1);  
  44.             add(ans[i], -1);  
  45.         }  
  46.         FOR(i,1,n) printf("%d\n", ans[i]);  
  47.     }  
  48.     return 0;  
  49. }  



 poj 2481 Cows

题意:两个区间:[Si, Ei] and [Sj, Ej].(0 <= S < E <= 105). 若 Si <= Sj and Ej <= Ei and Ei – Si > Ej – Sj, 则第i个区间覆盖第j个区间。给定N个区间(1 <= N <= 10^5),分别求出对于第i个区间,共有多少个区间能将它覆盖。

思路:初看好像挺复杂的。其实可以把区间[S, E]看成点(S, E),这样题目就转化为hdu 1541 Stars。只是这里是求该点左上方的点的个数。

虽然如此,我还是WA了不少,有一些细节没注意到。给点排序时是先按y由大到小排序,再按x由小到大排序。而不能先按x排序。比如n=3, [1,5], [1,4], [3,5]的例子。另外还要注意对相同点的处理。


[cpp] view plaincopyprint?
  1. const int MAX = 100010;  
  2.   
  3. struct Node{  
  4.     int x, y, id, ans;  
  5. }seq[MAX];  
  6. int sum[MAX], n;  
  7.   
  8. int cmp1(const void *n1, const void *n2){  
  9.     int res = ((Node*)n2)->y - ((Node*)n1)->y;  
  10.     if(res == 0) return ((Node*)n1)->x - ((Node*)n2)->x;  
  11.     else return res;  
  12. }  
  13. int cmp2(const void *n1, const void *n2){  
  14.     return ((Node*)n1)->id - ((Node*)n2)->id;  
  15. }  
  16.   
  17. int lowbit(int x){  
  18.     return x & (-x);  
  19. }  
  20. void add(int pos, int val){  
  21.     while(pos < MAX){         //我这里总是习惯性的写成n,浪费了很多时间。其实是横坐标x的最大范围。  
  22.         sum[pos]+=val;  
  23.         pos+=lowbit(pos);  
  24.     }  
  25. }  
  26. int getsum(int pos){  
  27.     int res = 0;  
  28.     while(pos>0){  
  29.         res+=sum[pos];  
  30.         pos-=lowbit(pos);  
  31.     }  
  32.     return res;  
  33. }  
  34.   
  35. int main()  
  36. {  
  37.     while(scanf("%d", &n) && n){  
  38.         FOR(i,1,n){  
  39.             scanf("%d%d", &seq[i].x, &seq[i].y);  
  40.             seq[i].x++, seq[i].y++;  
  41.             seq[i].id = i;  
  42.         }  
  43.         qsort(seq+1, n, sizeof(Node), cmp1);  
  44.         memset(sum, 0, sizeof(sum));  
  45.         seq[1].ans = 0;  
  46.         add(seq[1].x, 1);  
  47.         int fa = 1;  
  48.         FOR(i,2,n){  
  49.             if(seq[i].x == seq[fa].x && seq[i].y == seq[fa].y){  
  50.                 seq[i].ans = seq[fa].ans;  
  51.             }else{  
  52.                 fa = i;  
  53.                 seq[i].ans = getsum(seq[i].x);  
  54.             }  
  55.   
  56.             add(seq[i].x, 1);  
  57.         }  
  58.         qsort(seq+1, n, sizeof(Node), cmp2);  
  59.         printf("%d", seq[1].ans);  
  60.         FOR(i,2,n) printf(" %d", seq[i].ans);  
  61.         printf("\n");  
  62.     }  
  63.     return 0;  
  64. }  


poj 2155 Matrix
二维树状数组经典题
题意:给一个N*N的矩阵,里面的值不是0,就是1。初始时每一个格子的值为0。
现对该矩阵有两种操作:(共T次)
1.C x1 y1 x2 y2:将左上角为(x1, y1),右下角为(x2, y2)这个范围的子矩阵里的值全部取反。
2.Q x y:查询矩阵中第i行,第j列的值。
(2 <= N <= 1000, 1 <= T <= 50000)
思路:参见国家集训队论文:武森《浅谈信息学竞赛中的“0”和“1”》
1. 根据这个题目中介绍的这个矩阵中的数的特点不是 1 就是 0,这样我们只需记录每个格子改变过几次,即可判断这个格子的数字。
2. 先考虑一维的情况:
若要修改[x,y]区间的值,其实可以先只修改 x 和 y+1 这两个点的值(将这两个点的值加1)。查询k点的值时,其修改次数即为 sum(cnt[1] + … + cnt[k])。
3. 二维的情况:
道理同一维。要修改范围[x1, y1, x2, y2],只需修改这四个点:(x1,y1), (x1,y2+1), (x2+1,y1), (x2+1,y2+1)。查询点(x,y)的值时,其修改次数为 sum(cnt[1, 1, x, y])。
4. 而区间求和,便可用树状数组来实现。

[cpp] view plaincopyprint?
  1. const int MAX = 1010;  
  2. int n, cnt[MAX][MAX];  
  3.   
  4. int lowbit(int x){  
  5.     return x & (-x);  
  6. }  
  7. void add(int x, int y, int val){  
  8.     for(int i = x; i <= n; i += lowbit(i))  
  9.         for(int j = y; j <= n; j += lowbit(j))  
  10.             cnt[i][j] += val;  
  11. }  
  12. int sum(int x, int y){  
  13.     int res = 0;  
  14.     for(int i = x; i > 0; i -= lowbit(i))  
  15.         for(int j = y; j > 0; j -= lowbit(j))  
  16.             res += cnt[i][j];  
  17.     return res;  
  18. }  
  19.   
  20. int main()  
  21. {  
  22.     int t, m, x1, y1, x2, y2;  
  23.     char op[10];  
  24.     cin >> t;  
  25.     while(t--){  
  26.         scanf("%d%d", &n, &m);  
  27.         memset(cnt, 0, sizeof(cnt));  
  28.         while(m--){  
  29.             scanf("%s%d%d", op, &x1, &y1);  
  30.             if(op[0] == 'C'){  
  31.                 scanf("%d%d", &x2, &y2);  
  32.                 add(x1, y1, 1);  
  33.                 add(x1, y2+1, 1);  
  34.                 add(x2+1, y1, 1);  
  35.                 add(x2+1, y2+1, 1);  
  36.             }else{  
  37.                 printf("%d\n", sum(x1,y1) % 2);  
  38.             }  
  39.         }  
  40.         printf("\n");  
  41.     }  

0 0