线段树常用模板(转)

来源:互联网 发布:大数据架构师年龄要求 编辑:程序博客网 时间:2024/05/21 10:17

此模板来自博主:不忘初心0924

线段树,顾名思义是在树上的线段,通过建树来维护你需要的操作,基本的操作有:区间求和,区间求最值,区间异或(这个实际上和区间更新差不多,就是加上值这个操作换成了异或),区间覆盖,扫描线求面积,线段树求区间连续字段。

下面从最基础的区间求最值开始:

给你一个长度为n的序列,例如长度为5的序列{9,1,8,5,6},最大值和最小值,当然朴素算法在数据量小的时候是可以的,但是在n的数量级特别大的时候就显得略加笨重。由此想出一种办法,将数组里的数构成一棵数。如下图:

 

这棵树中叶子保存的是每个下标的数值,两个叶子的父亲,保存的是叶子中最小的那一个,这样一直到根节点得出数列中最小值。最大值也是一样。说起来很简单,接着就是用代码实现,线段树是一个二叉树。数据量大的时候也可能是一个完全二叉树,用一个sum数组来存储每个节点的数值,然后递归建树。

复制代码
void build(int i,int l,int r){    if(l==r)    {        scanf("%d",&sum[i]);        return ;    }    int m=(l+r)/2;    build(i*2,l,m);    build(i*2+1,m+1,r);    pushup(i);//收集子节点的结果}         Pushup()函数是将当前节点向下更新         void pushup(int i){         sum[i]=min(sum[i*2],sum[i*2+1]);}
复制代码

当维护用途不同的时候push函数的用法是不一样的。下面每种用途的线段树push函数的写法都有讲解。

然后是更新操作:

复制代码
void update(int id ,int val,int i,int l,int r){         if(l==r)         {                  sum[i]=val;//这里的操作的修改id点的值                  return;         }         int m=(l+r)/2;         if(id<=m) update(id,val,i*2,l,m);         else update(id,val,i*2+1,m+1,r);         pushup(i);}
复制代码

修改,查询操作都是从根节点开始遍历,然后当你遍历到的当前区间,在需要区间之内的时候,就可以进行你需要的操作了。

查询操作:

查询的时候,一个区间的最值要不就在左区间,要不就在右区间,要不然就在左加右区间(虽然很像废话,但是就是这样的)

复制代码
int query (int rt,int L,int R,int l,int r){    if(L<=l&&r<=R)        return sum[rt];    int m=(r+l)>>1;    int ret=0;    if(L<=m)        ret=min(ret,query(rt*2,L,R,l,m)    if(R>m)        ret=min(ret,query(rt*2+1,L,R,m+1,r);    return ret;}
复制代码

区间求和(单点更新,区间更新):        

单点更新:

复制代码
int sum[N*4];void pushup(int i){         sum[i]=sum[i*2]+sum[i*2+1];}void build(int i,int l,int r){    if(l==r)    {        scanf("%d",&sum[i]);        return ;    }     int m=(l+r)/2;    build(i*2,l,m);    build(i*2+1,m+1,r);    pushup(i);//收集子节点的结果}/*在当前区间[l, r]内查询区间[ql, qr]间的目标值  且能执行这个函数的前提是:[l,r]与[ql,qr]的交集非空  其实本函数返回的结果也是 它们交集的目标值  */int query(int ql,int qr,int i,int l,int r){         if(ql<=l&&r<=qr) return sum[i];                  int m=(l+r)/2;         int cur=0;         if(ql<=m) cur+=query(ql,qr,i*2,l,m);         if(m<qr) cur+=query(ql,qr,i*2+1,m+1,r);         return cur;}/*update这个函数就有点定制的意味了  本题是单点更新,所以是在区间[l,r]内使得第id数的值+val  如果是区间更新,可以update的参数需要将id变为ql和qr  */void update(int id ,int val,int i,int l,int r){         if(l==r)         {                  sum[i]+=val;                  return;         }         int m=(l+r)/2;         if(id<=m) update(id,val,i*2,l,m);         else update(id,val,i*2+1,m+1,r);         pushup(i);}
复制代码

区间更新:

这里要引进一个概念叫:延迟更新,就是当你想要更新某一个区间的时候,先更新代表这个区间的树干,然后再去更新枝叶的值,另外区间更新的值记录到addv数组中,更新的时候单点的只需要就上对应addv数组中节点修改的值就行了。

每当有add加到i节点上,直接更新i节点的sum.

也就是说如果要查询区间[1,n]的sum值,直接sum[1]即可,不用再去考虑1的addv[1]值.

复制代码
const int MAXN=100000+100;typedef long long LL;#define lson i*2,l,m#define rson i*2+1,m+1,rLL sum[MAXN*4];LL addv[MAXN*4];void PushDown(int i,int num)//这就是延迟操作,更新当前结点的叶子{    if(addv[i])    {        sum[i*2] +=addv[i]*(num-(num/2));//每个点的需要更新的值乘以的个数        sum[i*2+1] +=addv[i]*(num/2);//同上        addv[i*2] +=addv[i];//这个区间需要更新的个数        addv[i*2+1]+=addv[i];        addv[i]=0;    }}void PushUp(int i){    sum[i]=sum[i*2]+sum[i*2+1];}void build(int i,int l,int r){    addv[i]=0;//将延迟操作更改的值需要记录到addv数组中,现在将它初始化    if(l==r)    {        scanf("%I64d",&sum[i]);        return ;    }    int m=(l+r)/2;    build(lson);    build(rson);    PushUp(i);}void update(int ql,int qr,int add,int i,int l,int r){    if(ql<=l&&r<=qr)    {        addv[i]+=add;        sum[i] += (LL)add*(r-l+1);        return ;    }    PushDown(i,r-l+1);//向下更新枝叶的值    int m=(l+r)/2;    if(ql<=m) update(ql,qr,add,lson);    if(m<qr) update(ql,qr,add,rson);    PushUp(i);}LL query(int ql,int qr,int i,int l,int r){    if(ql<=l&&r<=qr)    {        return sum[i];    }    PushDown(i,r-l+1);    int m=(l+r)/2;    LL res=0;    if(ql<=m) res+=query(ql,qr,lson);    if(m<qr) res+=query(ql,qr,rson);    return res;}
复制代码

 

区间扫描线:

扫描线是用来计算平面图形的面积的,既然是扫描线那么肯定要有一条线,用结构体表示出一条线,它的数据有线段的左右端点,高的坐标(也就是y轴坐标),再加一个标记,是图形的上边界,还是下边界。

计算图形面积的时候,将图形分解成一个个小矩形,然后通过扫描线进行扫描,每一个小矩形的面积就是宽度乘以扫描线的长度,宽度就是高度差,所谓扫描线长度实际就是用线段树维护的区间长度,那么线段树怎么维护区间长度呐?引进一个cover数组记录哪一段区间是覆盖的,初始为0,0是没覆盖,1是覆盖,具体操作如下

例题地址:HDU 3265 Posters(线段树:扫描线)

http://acm.hdu.edu.cn/showproblem.PHP?pid=3265

复制代码
#include<cstring>#include<algorithm>using namespace std;const int MAXN=55555;#define lson i*2,l,m#define rson i*2+1,m+1,rint cnt[MAXN*4],sum[MAXN*4];struct node{    int l,r,h,d;    node(){}    node(int a,int b,int c,int d):l(a),r(b),h(c),d(d){}    bool operator < (const node & b)const    {        if (h == b.h) return d > b.d;//这句话不写也AC,但是还是写上保险,对于本题来说写不写没区别        return h<b.h;    }}nodes[MAXN*8];void PushUp(int i,int l,int r){    if(cnt[i])//如果当前区间有覆盖的位置        sum[i]=r-l+1;    else if(l==r)        sum[i]=0;    else        sum[i]=sum[i*2]+sum[i*2+1];}void build(int i,int l,int r){    cnt[i]=0;//每扫描一次这个数组就要清零一次    sum[i]=0;    if(l==r)        return ;    int m=(l+r)>>1;    build(lson);    build(rson);    //PushUp(i,l,r);}void update(int ql,int qr,int v,int i,int l,int r){    if(ql<=l&&r<=qr)    {        cnt[i]+=v;        PushUp(i,l,r);        return ;    }    int m=(l+r)>>1;//这里一定小心,如果是m=(l+r)/2,会无限递归,栈溢出,如ql=qr=-1且l=-1,r=0的时候    if(ql<=m) update(ql,qr,v,lson);    if(m<qr) update(ql,qr,v,rson);    PushUp(i,l,r);}int main(){    int t;    while(scanf("%d",&t)==1&&t)    {        int m=0;        int lbd=50000,rbd=0;        for(int i=1;i<=t;i++)        {            int x1,y1,x2,y2,x3,y3,x4,y4;            scanf("%d%d%d%d%d%d%d%d",&x1,&y1,&x2,&y2,&x3,&y3,&x4,&y4);            lbd=min(lbd,x1);            rbd=max(rbd,x2);//找出扫描线的最大区间            nodes[++m]= node(x1,x3,y1,1);            nodes[++m]= node(x1,x3,y2,-1);            nodes[++m]= node(x4,x2,y1,1);            nodes[++m]= node(x4,x2,y2,-1);            nodes[++m]= node(x3,x4,y1,1);            nodes[++m]= node(x3,x4,y3,-1);            nodes[++m]= node(x3,x4,y4,1);            nodes[++m]= node(x3,x4,y2,-1);        }        sort(nodes+1,nodes+m+1);//这个题而言的        build(1,lbd,rbd-1);        long long ans=0;        for(int i=1;i<m;i++)        {            int ql=nodes[i].l;            int qr=nodes[i].r-1;            if(ql<=qr)update(ql,qr,nodes[i].d,1,lbd,rbd-1);            ans+= (long long)sum[1]*(nodes[i+1].h-nodes[i].h);//这个是进行离散化后的高度差        }        printf("%I64d\n",ans);    }}
复制代码

 

区间覆盖:

         实际上在上面也说了引进cover数组表示i段区间是不是被覆盖。

POJ 2777 Count Color(线段树:区间覆盖)

例题地址:

http://poj.org/problem?id=2777

题意:

        题意:有一个长板子,多次操作,有两种操作,第一种是给从a到b那段染一种颜色c,另一种是询问a到b有多少种不同的颜色。

操作时候引进cover数组这里表示区间i覆盖了(染过色了),i=0表示没染过色,可以有i-n种颜色,查询的时候需要用一个visit数组记录此区间用过几种颜色,然后在进行判断。累加

复制代码
#include <iostream>#include<cstdio>#include<algorithm>#include<cstring>#include<cmath>using namespace std;#define lson i*2,l,m#define rson i*2+1,m+1,rconst int MAXN = 100000 + 100;bool vis[35];int cnt;//用来计数颜色struct IntervalTree{    int color[MAXN * 4];     void build(int i, int l, int r)    {        color[i] = 1;        if(l == r) return ;        int m = (l + r) / 2;        build(lson);        build(rson);    }     void PushDown(int i)    {        if(color[i] > 0)//当前区间染过色,那么左右区间都是和父节点颜色相同的            color[i * 2] = color[i * 2 + 1] = color[i];    }     void PushUp(int i)    {        if(color[i * 2] == -1 || color[i * 2 + 1] == -1)//左右区间的颜色相同的时候才能更新到父节点            color[i] = -1;        else if(color[i * 2] == color[i * 2 + 1])            color[i] = color[i * 2];        else            color[i] = -1;    }     void update(int ql, int qr, int v, int i, int l, int r)    {        if(ql <= l && r <= qr)        {            color[i] = v;            return ;        }        PushDown(i);        int m = (l + r) / 2;        if(ql <= m) update(ql, qr, v, lson);        if(m < qr) update(ql, qr, v, rson);        PushUp(i);    }     void query(int ql, int qr, int i, int l, int r)    {        if(color[i] > 0)        {            if(vis[color[i]] == false)                cnt++;            vis[color[i]] = true;            return ;        }        //PushDown(i);        int m = (l + r) / 2;        if(ql <= m) query(ql, qr, lson);        if(m < qr) query(ql, qr, rson);    }};IntervalTree T;int main(){    int n, t, q;    while(scanf("%d%d%d", &n, &t, &q) == 3)    {        T.build(1, 1, n);        while(q--)        {            char str[10];            scanf("%s", str);            if(str[0] == 'C')            {                int x, y, z;                scanf("%d%d%d", &x, &y, &z);                T.update(x, y, z, 1, 1, n);            }            else if(str[0] == 'P')            {                int x, y;                scanf("%d%d", &x, &y);                memset(vis, 0, sizeof(vis));                cnt = 0;                T.query(x, y, 1, 1, n);                printf("%d\n", cnt);            }        }    }    return 0;}
复制代码

区间维护最大连续字段:

HDU1540 Tunnel Warfare(线段树:维护最大连续子串)

例题地址

http://acm.hdu.edu.cn/showproblem.PHP?pid=1540

题意:

       3种操作:

1.D x: 该操作就是单点更新

2.Q x: 该操作可以分解为查区间[1,x]的最大连续0后缀长L和区间[x,n]的最大连续0前缀长R,则R+L-1即为所求。

3.R : 该操作其实就是update,不过需要一个stack来保存以前D过的点.

最大连续字段操作的时候,线段树需要维护区间最大前缀,最大后缀,和此区间是否有值。

详细注释看代码

复制代码
#include<iostream>#include<cstdio>#include<algorithm>#include<cstring>#include<cmath>#include<stack>using namespace std;const int MAXN=50000+1000;#define lson i*2,l,m#define rson i*2+1,m+1,r#define root 1,1,nint cover[MAXN*4],pre[MAXN*4],suf[MAXN*4];                               //用来存放最大前缀,最大后缀void PushUp(int i,int l,int r){    int m=(l+r)/2;    //cover    if(cover[i*2]==-1 || cover[i*2+1]==-1)//只有左右区间都存在的时候父节点才能更新为存在        cover[i]=-1;    else if(cover[i*2] != cover[i*2+1])        cover[i]=-1;    else        cover[i]=cover[i*2];     //pre    pre[i]=pre[i*2];    if(pre[i]== m-l+1)pre[i] +=pre[i*2+1];//如果做区间内的最大前缀是最区间的值,那么最大前缀在父节点表示的区间内     //suf    suf[i]=suf[i*2+1];//同上    if(suf[i]==r-m) suf[i]+=suf[i*2];}void PushDown(int i,int l,int r){    int m=(l+r)/2;    if(cover[i]!=-1)    {        cover[i*2]=cover[i*2+1]=cover[i];        suf[i*2]=pre[i*2]= (cover[i]?0:m-l+1);//此区间覆盖了,那么才能是右区间的长度        suf[i*2+1]=pre[i*2+1]= (cover[i]?0:r-m);//同上    }}void build(int i,int l,int r){    if(l==r)    {        cover[i]=0;        suf[i]=pre[i]=1;//前缀和初始化为1个单位长度        return ;    }    int m=(l+r)/2;    build(lson);    build(rson);    PushUp(i,l,r);}void update(int p,int v,int i,int l,int r){    if(l==r)    {        cover[i]=v;        suf[i]=pre[i]= (v?0:1);//既然到叶子上了当然是1或0了        return ;    }    PushDown(i,l,r);    int m=(l+r)/2;    if(p<=m) update(p,v,lson);    else update(p,v,rson);    PushUp(i,l,r);}int query_pre(int ql,int qr,int i,int l,int r)//查找[ql,qr]与[l,r]的公共部分的最大前缀连0的长度{    if(ql<=l && r<=qr)        return pre[i];    PushDown(i,l,r);    int m=(l+r)/2;    if(qr<=m) return query_pre(ql,qr,lson);    if(m<ql) return query_pre(ql,qr,rson);    int L = query_pre(ql,qr,lson);if(L == m+1-max(l,ql) )//这里表示如果做区间的前缀和刚好等于左区间长度那么肯定在右区间还能找到更大的 L +=query_pre(ql,qr,rson);    return L;}int query_suf(int ql,int qr,int i,int l,int r)//查找[ql,qr]与[l,r]的公共部分的最大后缀连续0的长度{    if(ql<=l && r<=qr)        return suf[i];    PushDown(i,l,r);    int m=(l+r)/2;    if(qr<=m) return query_suf(ql,qr,lson);    if(m<ql) return query_suf(ql,qr,rson);    int R = query_suf(ql,qr,rson);    if(R == min(r,qr)-m ) R +=query_suf(ql,qr,lson);    return R;} int main(){    int n,m;    while(scanf("%d%d",&n,&m)==2)    {        build(root);        stack<int> sq;        while(m--)        {            char str[10];            int x;            scanf("%s",str);            if(str[0]=='D')            {                scanf("%d",&x);                sq.push(x);                update(x,1,root);            }            else if(str[0]=='Q')            {                scanf("%d",&x);                int L=query_suf(1,x,root);                int R=query_pre(x,n,root);                if(L==0)                    printf("0\n");                else                    printf("%d\n",L+R-1);            }            else if(str[0]=='R')            {                if(!sq.empty())                {                    int x= sq.top();sq.pop();                    update(x,0,root);                }            }        }    }    return 0;}

0 0