树状数组入门[例题详解]

来源:互联网 发布:德保罗大学 知乎 编辑:程序博客网 时间:2024/06/07 02:40

一周时间学习了树状数组入门的十道题,特写此贴总结下树状数组

树状数组的定义

树状数组(Binary Indexed Tree,BIT) 二叉索引树
主要支持两种操作

  • Add(x,d)操作:让Ax增加d
  • Query(L,R):计算Al+(Al+1)+…+Ar.

且两种操作的时间都是log(n)

这里写图片描述
[ 1 ] 树状数组的本职是单点修改+区间查询 维护前缀和 每次修改向上传数据 然后查询区间的时候也是从下往上加值

比如
C[1]掌管A[1]
C[2]掌管A[1]和A[2]
C[3]掌管A[3]
C[4]掌管A[1]A[2]A[3]A[4]
….
C[8]掌管A[1…..8]

所以每次修改A[x]都要向上更新所有掌管A[x]的C[i],树状数组主要是维护C[]数组
而每次向上要找到下一个掌管他的C[]只需要加上lowbit(i)

同理每次求[a,b]区间的和时,= [1,b] - [1,a-1]
求[1,b]区间和的时候只需加上所有独立的C[i]。(比如加上C[4]就不需要加C[2],因为C[4]中包括C[2]。)i向下找到下一个独立的C[i],向下找的方法就是减去lowbit(i)。累加就是区间和了。

int lowbit(int x){return x & -x;}
void add(int x,ll k,int d){    for(int i = x ;i<=maxn;i+=lowbit(i)) c[d][i] += k;}
ll sum(int x,int d){    ll ans = 0;    for(int i = x ;i;i-=lowbit(i)) ans += c[d][i];    return ans;}

A - Color the ball

HDU - 1556

这题就是N个气球每次对[a,b]区间的气球涂色,每次询问第i个气球被涂了多少次色。

显然朴素的树状数组实现单点修改+区间求和

对于这题区间修改+单点查询,思想延续树状数组。加上差分数组的思想很容易知道我们只需要用树状数组维护一个差分数组,每次修改区间[a,b],我们只需要
add(C[a],1) add(C[b+1],-1)

查询单点值根据差分数组的定义,就是求A[1,x]的和

代码

/**    树状数组区间修改+单点查询(差分数组思想+树状数组)*/#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>using namespace std;const int maxn = 1e5+10;int c[maxn];int lowbit(int x){return x&-x;}void add(int x,int k){    for(int i = x;i<maxn;i+=lowbit(i)) c[i] += k;}int sum(int x) {    int sum = 0;    for(int i=x;i;i-=lowbit(i)) sum += c[i];    return sum;}int main(){    int n;    while(~scanf("%d",&n),n)    {        memset(c,0,sizeof(c));        int a,b;        for(int i=0;i<n;i++) {            scanf("%d%d",&a,&b);            add(a,1);add(b+1,-1);        }        for(int i=1;i<n;i++) {            printf("%d ",sum(i));        }        printf("%d\n",sum(n));    }    return 0;}

B - Matrix

POJ - 2155

二维树状数组的区间修改+单点查询

二维树状数组和一维类似这里给出二维树状数组的更新查询代码

int lowbit(int x) { return x & -x;}void add(int x,int y,int k){    for(int i=x;i<=n;i+=lowbit(i))        for(int j=y;j<=n;j+=lowbit(j))            c[i][j] += k;}int sum(int x,int y){    int ans = 0;    for(int i=x;i;i-=lowbit(i))        for(int j=y;j;j-=lowbit(j))            ans += c[i][j];    return ans;}

区间更新,利用差分数组的思想

首先考虑一维的情况,我要一段区间取反,假设是 [l, r]。那么我只需要C[l]+1,C[r+1]+1,假设查询 k 的时候,只需要查询前 k 的和 mod 2 的结果即可。

这种方法可以推广到二维的情况,利用容斥原理。假设修改的子矩阵左上角和右下角分别为 x1 y1 x2 y2,首先C[x1][y1]+1, C[x2][y2]+1,不过这时要 C[x1][y2+1]+1, C[x2+1][y1]+1。

这里写图片描述

这里红色的就是上面需要标记的点,标记后表示以这个点为右下角的子矩阵内的点全部取反一次。这里绿色代表取反一次,棕色是取反了两次,蓝色是取反了四次。最后实际进行取反操作的就是要求取反的子矩阵内的点(即绿色的区域)。

注意这里的前缀和用二位树状数组维护,进行单点更新,单点查询

#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>using namespace std;const int maxn = 1010;int c[maxn][maxn];int n,k;int lowbit(int x) { return x & -x;}void add(int x,int y,int k){    for(int i=x;i<=n;i+=lowbit(i))        for(int j=y;j<=n;j+=lowbit(j))            c[i][j] += k;}int sum(int x,int y){    int ans = 0;    for(int i=x;i;i-=lowbit(i))        for(int j=y;j;j-=lowbit(j))            ans += c[i][j];    return ans;}int main(){    bool first = true;    char op;    int x1,y1,x2,y2,caset;    scanf("%d",&caset);    while(caset--)    {        if(first) first = false;else printf("\n");        scanf("%d%d",&n,&k);        memset(c,0,sizeof(c));        while(k--)        {            scanf("%*c%c",&op);            if(op == 'C'){                scanf("%d%d%d%d",&x1,&y1,&x2,&y2);                add(x1,y1,1);add(x1,y2+1,1);                add(x2+1,y1,1);add(x2+1,y2+1,1);   ///类似于容斥原理(根据add函数对于右下角矩形区域的更新)            }            else {                scanf("%d%d",&x1,&x2);                printf("%d\n",sum(x1,x2)%2);            }        }    }    return 0;}

C - Ultra-QuickSort

POJ - 2299

求冒泡排序逆序对的个数
详细题解参考这篇 逆序对

#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>using namespace std;typedef long long ll;typedef struct node{    int x,tag;}point;const int maxn = 5e5+10;point num[maxn];int c[maxn];int position[maxn];int lowbit(int x) { return x & -x;}void add(int pos,int x) {    for(int i=pos;i<maxn;i += lowbit(i)) c[i] += x;}ll sum(int pos){    ll ans = 0;    for(int i=pos;i;i -= lowbit(i)) ans += c[i];    return ans;}bool cmp(const point &a,const point &b){    return a.x < b.x;}void init(){    memset(c,0,sizeof(c));}int main(){    int n;    while(~scanf("%d",&n),n)    {        init();        for(int i=1;i<=n;i++) scanf("%d",&num[i].x),num[i].tag = i;        sort(num+1,num+1+n,cmp);        for(int i=1;i<=n;i++) position[num[i].tag] = i;             ///表示这个数放在第几个位置        ll ans = 0;        for(int i=1;i<=n;i++) {            add(position[i],1);            ans += i - sum(position[i]);        }        printf("%lld\n",ans);    }}

D - Japan

POJ - 3067

题意:日本岛东海岸与西海岸分别有N和M个城市,现在修高速公路连接东西海岸的城市,求交点个数。

题解:记每条告诉公路为(x,y), 即东岸的第x个城市与西岸的第y个城市修一条路。当两条路有交点时,满足(x1-x2)*(y1-y2) < 0。所以,将每条路按x从小到达排序,若x相同,按y从小到大排序。 然后按排序后的公路用树状数组在线更新,求y的逆序数之 和 即为交点个数。

#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>using namespace std;typedef long long ll;const int maxn = 5e5+10;typedef struct node{    int x,y;}point;point num[maxn];int c[maxn];bool cmp(const point &a,const point &b){    if(a.x == b.x) return a.y < b.y;    return a.x < b.x;}int lowbit(int x) { return x & -x;}void add(int x,int k){    for(int i=x;i<maxn;i+=lowbit(i)) c[i]+=k;}int sum(int x){    int ans = 0;    for(int i=x;i;i-=lowbit(i)) ans += c[i];    return ans;}int main(){    int caset,n,m,k,t=1;    scanf("%d",&caset);    while(caset--)    {        memset(c,0,sizeof(c));        scanf("%d%d%d",&n,&m,&k);        for(int i=0;i<k;i++){            scanf("%d%d",&num[i].x,&num[i].y);        }        sort(num,num+k,cmp);        ll ans = 0;        for(int i=1;i<=k;i++)        {            add(num[i-1].y,1);            ans += i-sum(num[i-1].y);        }        printf("Test case %d: %lld\n",t++,ans);    }    return 0;}

E - Stars

POJ - 2352

题目大意:
在坐标上有n个星星,如果某个星星坐标为(x, y), 它的左下位置为:(x0,y0),x0<=x 且y0<=y。如果左下位置有a个星星,就表示这个星星属于level x
按照y递增,如果y相同则x递增的顺序给出n个星星,求出所有level水平的数量。

分析与总结:
因为输入是按照按照y递增,如果y相同则x递增的顺序给出的, 所以,对于第i颗星星,它的level就是之前出现过的星星中,横坐标x小于等于i星横坐标的那些星星的总数量(前面的y一定比后面的y小)。
所以,需要找到一种数据结构来记录所有星星的x值,方便的求出所有值为0~x的星星总数量。
树状数组和线段树都是很适合处理这种问题。

#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>using namespace std;const int maxn = 32100;typedef struct node{    int x,y;}point;point star[15010];int c[maxn],num[maxn];int n;int lowbit(int x) {return x & -x;}bool cmp(const point &a,const point &b){    if(a.x == b.x) return a.y < b.y;    return a.x < b.x;}void add(int x,int k){    for(int i = x;i<maxn;i+=lowbit(i))            c[i] += k;}int sum(int x){    int ans = 0;    for(int i=x;i;i-=lowbit(i))            ans += c[i];    return ans;}int main(){    while(~scanf("%d",&n))    {        memset(c,0,sizeof(c));        memset(num,0,sizeof(num));        for(int i=1,u,v;i<=n;i++)        {            scanf("%d%d",&u,&v);            num[sum(u+1)]++;            add(u+1,1);        }        for(int i=0;i<n;i++) printf("%d\n",num[i]);    }    return 0;}

F - Mobile phones

POJ - 1195

二维树状数组单点更新+区间求和的裸题

#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>using namespace std;const int maxn = 1025;typedef long long ll;int c[maxn][maxn];int n;int lowbit(int x) { return x & -x;}void add(int x,int y,int k){    for(int i=x;i<=n;i+=lowbit(i))        for(int j=y;j<=n;j+=lowbit(j))            c[i][j] += k;}ll sum(int x,int y){    ll ans = 0;    for(int i = x;i;i-=lowbit(i))        for(int j=y;j;j-=lowbit(j))            ans += c[i][j];    return ans;}int main(){    int op,x,y,k,x1,y1;    while(~scanf("%d",&op))    {        if(op == 0) {            scanf("%d",&n);            memset(c,0,sizeof(c));        }        else if(op == 1){            scanf("%d%d%d",&x,&y,&k);            add(x+1,y+1,k);        }        else if(op == 2){            scanf("%d%d%d%d",&x,&y,&x1,&y1);            printf("%lld\n",sum(x1+1,y1+1) - sum(x1+1,y) - sum(x,y1+1) + sum(x,y));        }        else break;    }    return 0;}

G - Cows

POJ - 2481

给出n头牛,每头牛都喜欢喜欢三叶草???然后每一头牛都有喜欢的三叶草的范围[S,E],现在给出一个定义如果Si <= Sj &&Ei >= Ej && Ei - Si > Ej - Sj 则认为第i头牛要比第j头牛重。然后问你对于每一头牛有多少头牛比他重,

这样说起来可能很抽象,于是我们换一个说法,就是有n个区间,每个区间都有两个端点坐标,如果另一个区间真包含于该区间,则认为另一个区间要重于该区间,于是问你对于每一个区间有多少个区间真包含于该区间。

跟上上一道Star类似

我们可以先按一个端点进行排序,保证在一维的映射中解决。
我们先把最重的牛放上去。怎么样的牛被认为是最重的呢,首先左端点最小,右端点最大

根据上面最重的定义我们先对左端点从小到大进行进行排序,如果左端点相同,按照右端点大的优先排序。这样就保证了最重的定义。树状数组维护右端点的是否存在。然后对于每一头牛,要统计有多少头牛比他重的方法就是看有多少头牛的右端点大于该牛的右端点。因为我们已经保证了左端点的递增性,所以只需要判断右端点即可。

同理我们也可以按照右端点进行排序,要使得最重的牛优先,于是右端点按照从大到小的方式排序,保证重的优先,如果右端点相同,左端点从小到大的方式。树状数组维护左端点的是否存在。然后对于牛x要统计有多少头牛重于该牛,只需要统计在牛x之前比牛x的左端点小的个数即可。

以上两种方法都可以,注意集合的真真子集,左右端点都相同的时候需要特殊判断一下。

#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>using namespace std;const int maxn = 1e5+10;typedef struct node{    int x,y,tag;}point;point num[maxn];int c[maxn],store[maxn];bool cmp(const point &a,const point &b){    if(a.x == b.x) return a.y > b.y;    return a.x < b.x;}int lowbit(int x) { return x & -x;}void add(int x,int k){    for(int i=x;i<maxn;i+=lowbit(i)) c[i] += k;}int sum(int x){    int ans = 0;    for(int i=x;i;i-=lowbit(i)) ans += c[i];    return ans;}int main(){    int n;    while(~scanf("%d",&n),n)    {        memset(c,0,sizeof(c));        memset(store,0,sizeof(store));        for(int i=1;i<=n;i++) scanf("%d%d",&num[i].x,&num[i].y),num[i].tag = i,num[i].x++,num[i].y++;        sort(num+1,num+1+n,cmp);        for(int i=1;i<=n;i++){            if(num[i].x == num[i-1].x && num[i].y == num[i-1].y) store[num[i].tag] = store[num[i-1].tag];            else {                //printf("%d %d\n",num[i].tag,sum(num[i].y));                int temp;                if(num[i].y == 1 ) temp = 0;                else temp = sum(num[i].y-1);                store[num[i].tag] = i - 1 - temp;            }            add(num[i].y,1);        }        for(int i=1;i<n;i++) printf("%d ",store[i]);printf("%d\n",store[n]);    }    return 0;}

H - Apple Tree

POJ - 3321

一颗完全二叉树,结点从1到n,有两个操作,1.对一个结点做异或(如果这个结点有苹果则摘了,没有则种上去),2.计算一个结点的的子树下有多少个苹果

思路,我们先一遍DFS序,记录第一次访问到该结点的序号,和最后一次访问到该结点的序号。
操作一:用树状数组维护第i个结点的值。
操作二:显然要计算结点x下的子树的苹果个数,就是sum(DFS序最后一次访问-DFS序第一次访问);

#include<iostream>#include<cstdio>#include<algorithm>#include<cstring>#include<vector>using namespace std;const int maxn = 100005;int cntedge,cntdfs,n;int c[maxn],head[maxn],ri[maxn],le[maxn];bool vis[maxn],s[maxn];typedef struct node{    int v,next;}Edge;Edge edge[maxn+maxn];int lowbit(int x){return x & -x;}void add(int x,int k){    for(int i=x;i<=n;i+=lowbit(i)) c[i]+=k;}int sum(int x){    int ans = 0;    for(int i=x;i;i-=lowbit(i)) ans += c[i];    return ans;}void addedge(int u,int v){    edge[cntedge].v = v;    edge[cntedge].next = head[u];    head[u] = cntedge++;}void dfs(int x){    vis[x] = 1;    le[x] = cntdfs;    for(int i=head[x];i!=-1;i=edge[i].next) if(!vis[edge[i].v])    {        cntdfs++;        dfs(edge[i].v);    }    ri[x] = cntdfs;}int main(){    char op[2];    scanf("%d",&n);    {        cntedge = 0;cntdfs = 1;        memset(c,0,sizeof(c));        memset(s,0,sizeof(s));        memset(head,-1,sizeof(head));        for(int i=1,u,v;i<n;i++) {            scanf("%d%d",&u,&v);            addedge(u,v);        }        dfs(1);        for(int i=1;i<=n;i++) {            s[i] = 1;            add(le[i],1);        }        int m;        scanf("%d",&m);        for(int i=0,x;i<m;i++){            scanf("%s%d",&op,&x);            if(op[0] == 'Q') printf("%d\n",sum(ri[x]) - sum(le[x]-1));            else {                if(s[x]) add(le[x],-1);                else add(le[x],1);                s[x] = !s[x];            }        }    }    return 0;}

I - MooFest

POJ - 1990

这道题做完想起来还是有点吃力

首先说下题意:就是有 n 头 牛(QWQ 为什么又是牛),然后这n头牛被放置在一个水平面上。对于每一头牛给出他的听力限度和他所在的位置,每两头牛进行交流的开销是dist(i,j) * max{val[i],val[j]} = i牛和j牛之间的距离 * i牛和j牛的最大听力限度。输出每两头牛交流的开销和。n头牛的话总共有(n-1) * n/2种交流配对。就是这n *(n-1)/2种开销的和。

如果枚举所有的组合O(n^2)的时间复制度,肯定超时。
于是怎么来考虑这个问题呢,求和的话树状数组十分适合这种情况。
根据前几题的经验。显然我们要先保证数据的优先性。
对于max{val[i],val[j]}我们可以先从小到大排序来处理,用树状数组c[]维护第i个是否出现过,于是我们放第x个的时候,第x个的val[x]是最大的,然后已经在树状数组中出现的就是比val[x]小的,这个时候需要用val[x] * (所有已经出现的牛到牛x的距离和),这个地方如果用枚举的话显然也会超时,我们看到求距离和的时候显然也是个求和问题。再开一个树状数组维护下距离和,维护的距离的位置到源点的距离。于是对于牛x,要计算所有以他为最大值的开销的时候,= val[x] * (所有已经出现的牛到牛x的距离和),有了前面记录距离和的树状数组,就可以很高效的得到 所有已经出现的牛到牛x的距离和
牛x之前有坐标小于牛x的也有坐标大于牛x的

对于坐标小于牛x的我们有 val[x] * (坐标小于牛x的个数(sum(x,0) * 牛x距离源点的距离 - 所有坐标小于牛x的牛到源点的距离和(sum(x,1 ))。

对于坐标大于牛x的我们有 val[x] * (所有坐标大于牛x的牛到源点的距离和(sum(maxn,1) -sum(x,1)) - 坐标大于牛x的个数(sum(maxn,0) - sum(x,0)) * 牛距离源点的个数)

#include<iostream>#include<cstdio>#include<algorithm>#include<cstring>using namespace std;const int maxn = 20010;typedef long long ll;typedef struct node{    int val,pos;    bool operator < (const struct node &a) const{        return val < a.val;    }}point;point points[maxn];ll c[2][maxn];int lowbit(int x){return x & -x;}void add(int x,ll k,int d){    for(int i = x ;i<=maxn;i+=lowbit(i)) c[d][i] += k;}ll sum(int x,int d){    ll ans = 0;    for(int i = x ;i;i-=lowbit(i)) ans += c[d][i];    return ans;}int main(){    int n;    while(~scanf("%d",&n))    {        memset(c,0,sizeof(c));        for(int i=0;i<n;i++) scanf("%d%d",&points[i].val,&points[i].pos);        sort(points,points+n);        ll ans = 0;        for(int i=0;i<n;i++)        {            ll cntl = sum(points[i].pos,0);            ll cntr = i - cntl;            ll dl = sum(points[i].pos,1),dr = sum(maxn-1,1) - dl;            ans += points[i].val*(cntl*points[i].pos- dl + dr - cntr*points[i].pos);            add(points[i].pos,1,0);            add(points[i].pos,points[i].pos,1);        }        printf("%lld\n",ans);    }    return 0;}

J - BST

POJ - 2309

树状数组原理题

答案为 x - lowbit (x) + 1 和 x + lowbit(x) - 1

#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>using namespace std;typedef long long ll;int lowbit(int x){return x & -x;}void get(int x,int &l,ll &r){    l =  x - lowbit(x) + 1;    r = (ll)x + (ll)lowbit(x) - 1;}int main(){    int x,n,l;    ll r;    scanf("%d",&n);    while(n--)    {        scanf("%d",&x);        get(x,l,r);        printf("%d %lld\n",l,r);    }    return 0;}
原创粉丝点击