树状数组点更新,区间更新理解

来源:互联网 发布:汇天下p2p源码 编辑:程序博客网 时间:2024/06/06 05:23

对于一个数列A1A2A3…An,要求支持两种操作:
1.查询[x,y]区间的区间和
2.把[x,y]区间每个元素加val
事实上线段树也可以解决这样的问题,用上一点lazy的思想,每次只更新小区间的区间和,查询的时候加上祖先节点的影响就可以了。
我们用树状数组也可以解决这样的问题,并且效率会更高一些,复杂度都是nlog(n),但树状数组的常数更小,空间占用也更少。

首先回忆一下树状数组维护前缀和的过程,首先对于任意一个整数x,比如11,表达成二进制为1011,而1011 = 1000 + 10 + 1,这样的划分不超过log(11)次,考虑让一个数组C[]维护一小段区间和,让C[11]维护以A[11]作为结尾,长度为1的区间(这个时候只有A[11]一个值),让C[11-1]维护以A[11-1]作为结尾,长度为10(2进制)的区间和,同理让C[11-1-2]维护以A[11-1-2]作为结尾,长度为1000(2进制)的区间和,这样我们的小区间就覆盖了整个[1,11]的区间,前缀和就可以累加得到了,基于这样的思想我们设定了lowbit函数,他是这样的:

int lowbit(int x) {    return x & -x;}

这个函数利用位运算技巧返回了整数x的最低位的1和后续的0组成的整数,比如10 = 1010(2进制),那么lowbit(10) = 2 = 10(2进制),根据上面的讨论,这个函数实际上用来分解整数达到划分区间的目的。
因此我们用C[i]维护A[i-lowbit(i)+1]A[i-lowbit(i)+2]…A[i]的区间和
假设C数组已经构造好了,那么我们查询从[1,x]区间的前缀和函数是这样的:

int query(int x) {    int res = 0;    while(x > 0)     {        C[x] += lowbit(x); x -= lowbit(x);    }    return res;}

联系上面11的例子,11 = 1000 + 10 + 1,所以C[11]只用管A[11]就可以了,剩余的10个不管,C[10]只用管A[9],A[10]就可以了,C[8]管剩下所有的。
当某个A[i]加上一个值val,如何更新呢

void update(int x, int val){    while(x <= n)    {        C[x] += val; x += lowbit(x);    }}

首先可以明确的是,x + lowbit(x)必定会使x的二进制数发生进位,而且是最小的进位,事实上很容易发现进位之后的数一定会管到原来的A[i],画出树状图就可以发现
C数组也可以递推求得,就不介绍了,代码如下,很好懂

memset(c, 0, sizeof(c));for(int i = 1; i <= n; i++){     c[i] += a[i];     int father = i + lowbit(i);     if(father <= n) c[father] += c[i];}

对一段区间每个值加上一个val如何处理呢?
假设存在一个数组add[],add[x] = val 表示把[x,n]这个区间每个元素+val.
这样当我们把[x,y]区间每个元素都+val的时候,把问题转化为把[x,n]的区间每个元素+val,把[y+1,n]每个元素-val
这样当我们查询[1,x]区间的区间和时,实际的
sum[x] = A[1] + A[2] +… + A[x] + add[1] * (x+1-1) + add[2] * (x+1-2) + … + add[x] * (x+1-x) = (A[1] + A[2] + … + A[x]) + (x+1)(add[1] + add[2] + …+ add[x] ) - (1 * add[1] + 2*add[2] + …+x *add[x])

做到这里问题就转化为了求前缀和了,第一项是A[x]的前缀和直接维护即可,第二项和第三项分别用树状数组维护即可。代码很容易看懂。结合题目poj3468,以下是ac代码

#include <cstdio>#include <cstring>#include <iostream>#include <algorithm>using namespace std;const int maxn = 100005;typedef long long LL;int n,q,a[maxn];LL sum[maxn];LL c1[maxn],c2[maxn]; //c1维护add[i]的前缀和,c2维护add[i]*i的前缀和int lowbit(int x) { return x & (-x); }void update1(int x,int val){    while(x <= n)    {        c1[x] += val; x += lowbit(x);    }}LL query1(int x){    LL res = 0;    while(x > 0)    {        res += c1[x]; x -= lowbit(x);    }    return res;}void update2(int x,int val){    while(x <= n)    {        c2[x] += val; x += lowbit(x);    }}LL query2(int x){    LL res = 0;    while(x > 0)    {        res += c2[x]; x -= lowbit(x);    }    return res;}int main(){    scanf("%d%d",&n,&q);    memset(a,0,sizeof(a));    memset(c1,0,sizeof(c1));    memset(c2,0,sizeof(c2));    memset(sum, 0, sizeof(sum));    for(int i = 1; i <= n; i++) scanf("%d" ,&a[i]);    for(int i = 1; i <= n; i++) sum[i] = sum[i-1] + a[i];    for(int i = 1; i <= q; i++)    {        char cmd[2];        scanf("%s",cmd);        if(cmd[0] == 'Q')        {            int left,right;            scanf("%d%d" ,&left,&right);            LL x = sum[right] + (right + 1) * query1(right) - query2(right);            LL y = sum[left-1] + left * query1(left-1) - query2(left-1);            printf("%I64d\n" ,x-y);        }        else if(cmd[0] == 'C')        {            int left,right,val;            scanf("%d%d%d" ,&left,&right,&val);            update1(left,val); //转化为点更新            update1(right+1,-val);            update2(left,val*left);            update2(right+1,-val*(right+1));        }    }    return 0;}
0 0
原创粉丝点击