树状数组及其应用

来源:互联网 发布:seo推广培训 编辑:程序博客网 时间:2024/05/25 16:37

树状数组是一个查询和修改复杂度都为log(n)的数据结构。主要用于查询任意两位置之间的所有元素之和,每次可以修改某一处元素的值。

以下是树状数组的存储方式(图片来源于互联网)


可以看出:

C[1]=A[1]

C[2]=A[1]+A[2]

C[3]=A[3]

C[4]=A[1]+A[2]+A[3]+A[4]

C[5]=A[5]

C[6]=A[5]+A[6]

C[7]=A[7]

C[8]=A[1]+A[2]+A[3]+A[4]+A[5]+A[6]+A[7]+A[8]

……

严格的定义为:C[n] = A[n-2^k+1] + …… + A[n],其中,k为n的二进制表示中最低位的1。即:C[n]为A[n]在内的前2^k个数之和,C[n]的辖域为2 ^ k。

则查询前x个数之和可以如下进行:首先确定C[x]的辖域为low_bit(x),则累加完这些元素之和C[x]后,继续求解1~C[x]-low_bit(x)这个区间的元素之和,x = 0时停止迭代。函数如下:

int getsum(int idx){int sum = 0;for(int i = idx; i > 0; i -= lowbit(i))sum += c[i];return sum;}

其中low_bit求解如下:

int lowbit(int x){return x & (-x);}

结合负数的补码表示方式:取反加一,若最低位为第i位,则取反后第i位为0,第0~i - 1位均为1,加一后使得第i位为1,则再进行&操作,得到的就是2 ^ i。

修改时,对某个元素进行修改,则要沿着该元素所在路径一直向父节点修改,复杂度为O(logn)。关键是父节点的确定。其实树状数组可以理解为减少了一半节点的线段树,如下图所示(图片来源于http://www.cppblog.com/Ylemzy/articles/98322.html):


其中的空白节点,即是树状数组相对线段树节省的节点,空白节点可以理解为该节点的兄弟节点,这两个子树拥有同样数目的元素(辖域相同),则父节点的辖域为左右子树的元素个数之和,父节点的下标为c[x] + low_bit(x)。则修改的代码如下:

void update(int idx, int delta){for(int i = idx; i <= n; i += low_bit(i))c[i] += delta;}

例题:hdu1166 http://acm.hdu.edu.cn/showproblem.php?pid=1166 代码如下:

#include <cstdio>#include <cstring>using namespace std;#define N 50005int n, a[N], c[N] ,sum[N];char opt[10];int low_bit(int x){return x & (-x); //负数取反加1,最低位的1处被set}int get_sum(int k){int res = 0;for(int i = k; i > 0; i -= low_bit(i))res += c[i];return res;}void update(int idx, int delta){for(int i = idx; i <= n; i += low_bit(i))c[i] += delta;}int main(){int tc, ca = 0;scanf("%d", &tc);while(tc --){scanf("%d", &n);sum[0] = 0;for(int i = 1; i <= n; ++i){scanf("%d", &a[i]);sum[i] = sum[i - 1] + a[i];   //也可直接在此处一个一个update,不过需要将c清零,且复杂度为nlogn}for(int i = 1; i <= n; ++i)c[i] = sum[i] - sum[i - low_bit(i)];getchar();printf("Case %d:\n", ++ca);while(scanf("%s", opt) == 1 && strcmp(opt, "End")){            int op1, op2;            scanf("%d %d", &op1, &op2);            if(strcmp(opt, "Query") == 0){                printf("%d\n", get_sum(op2) - get_sum(op1 - 1));            }            else{                int delta = strcmp(opt, "Add") == 0 ? op2 : -op2;                update(op1, delta);            }}}return 0;}

除此之外,当元素值在某一范围之内时,可以用来求第k小/大的数。类似哈希的思想,a[i]中存储的是值为a[i]的元素个数,则sum(x)为不大于x的元素个数。

方法①:常规方法,使用二分查找第一个sum不小于k的sum(x),即sum(x - 1) <k, sum(x) >= k,但速度上比方法②略慢一些,因为每次二分时都调用了get_sum函数;

方法②:二进制增量法不断set x的每一位,使其向目标值ans靠近,从最高位开始,试探该位置一后不大于x的元素个数是否小于k,若小于k则该位可以置一,并记录当前范围内元素的个数(此处的处理节约了get_sum的计算开销);注意c[x]代表的是处于区间[x - low_bit(x) + 1, x]的元素个数。代码如下:

int getkth(int a, int k){    int ans = 0, cnt = -getsum(a);    for(int i = 20; i >= 0; i --){        ans += 1 << i;        if(ans >= MAXN || cnt + c[ans] >= k)            ans -= 1 << i;        else            cnt += c[ans]; //将新扩展的区间中元素个数累加到当前总个数中    }    return ans + 1;}

上述代码中,在每轮的循环中,由于变量i是递减的,若点ans可行,则c[ans]的辖域为(1 << i),代表的是处于区间[ans - 1 << i + 1, ans]之间的元素个数,而ans' = ans - 1 << i,即是上一轮循环确定的范围,cnt为[1, ans']区间的元素个数。

例题:hdu2852 http://acm.hdu.edu.cn/showproblem.php?pid=2852 代码如下:

#include <cstdio>#include <cstring>using namespace std;#define MAXN 100000int c[MAXN];int lowbit(int x){    return x & (-x);}void add(int idx, int delta){    for(int i = idx; i < MAXN; i += lowbit(i))        c[i] += delta;}int getsum(int idx){    int sum = 0;    for(int i = idx; i > 0; i -= lowbit(i))        sum += c[i];    return sum;}int getkth(int a, int k){    int ans = 0, cnt = -getsum(a);    for(int i = 20; i >= 0; i --){        ans += 1 << i;        if(ans >= MAXN || cnt + c[ans] >= k)            ans -= 1 << i;        else            cnt += c[ans];    }    return ans + 1;}int main(){    int m;    while(scanf("%d", &m) == 1){        int ops;        memset(c, 0, sizeof(c));        for(int i = 0; i < m; i ++){            scanf("%d", &ops);            if(ops != 2){                int val;                scanf("%d", &val);                if(ops == 0)                    add(val, 1);                else{                    if(getsum(val - 1) != getsum(val))                        add(val, -1);                    else                        printf("No Elment!\n");                }            }            else{                int a, k;                scanf("%d %d", &a, &k);                int ans = getkth(a, k);                if(ans == MAXN)                    printf("Not Find!\n");                else                    printf("%d\n", ans);            }        }    }}
0 0
原创粉丝点击