[BZOJ1500][NOI2005]维修数列(splay)

来源:互联网 发布:pdf图片提取软件 编辑:程序博客网 时间:2024/05/05 19:09

题目描述

传送门

题解

splay模板题(笑
这题是我今年1月左右看到的。记得当时只调过了样例然后交上去过了一个点就用了我一天的时间。果然我还是太傻逼。
填上了好大一个坑。
前5个都是splay的基础操作,就是第6个比较麻烦。因为要维护区间最大连续子序列和,可以维护3个值:Max子树最大连续子序列和,maxl子树自左端起最大连续子序列和,maxr子树自右端起最大连续子序列和。注意这里的Max是不允许有空串的,因为这个即为最后的答案,但是maxl和maxr都是允许有空串的,这样的话update的时候比较好搞。

坑点:
①覆盖的值可能为0,所以覆盖的标记刚开始不能为0.
②由于Max是不允许有空串的,而maxl和maxr是允许有空串的,所以刚开始要将Max[0]赋为-inf,否则的话如果某个点没有儿子的话求Max的时候会出错。
③打翻转标记的含义为“对当前点做过但是对当前点的儿子没有做过”,也就是说打了标记的这个点,它的左右两个儿子已经交换了,但是它的儿子的儿子没有交换。这个和平常写splay的时候有区别。原因在于,在更新x的时候x的两个儿子的maxl和maxr对x是有影响的,所以不能在没有pushdown的情况下直接更新。不要忘了翻转的时候maxl和maxr也是要交换的。
④内存池的使用:每一次删除一段区间,实际上是将一棵树丢到内存池里,每次从内存池中弹出一个点时,要将它的两个儿子的子树重新丢到内存池里。

代码

#include<iostream>#include<cstring>#include<cstdio>using namespace std;#define N 500005#define inf 1000000000int n,m,pos,tot,c,root;int l,r,q[5000005];int a[N],b[N],f[N],ch[N][2],size[N],key[N],same[N],rev[N],sum[N],Max[N],maxl[N],maxr[N];char opt[20];void clear(int x){    f[x]=ch[x][0]=ch[x][1]=size[x]=key[x]=rev[x]=sum[x]=maxl[x]=maxr[x]=0;    same[x]=inf;    Max[x]=-inf;}void delp(int x){    q[++r]=x;}int newp(){    ++l;    if (ch[q[l]][0]) delp(ch[q[l]][0]);    if (ch[q[l]][1]) delp(ch[q[l]][1]);    clear(q[l]);    return q[l];}int get(int x){    return ch[f[x]][1]==x;}void update(int x){    size[x]=size[ch[x][0]]+size[ch[x][1]]+1;    sum[x]=key[x]+sum[ch[x][0]]+sum[ch[x][1]];    maxl[x]=max(maxl[ch[x][0]],sum[ch[x][0]]+key[x]+maxl[ch[x][1]]);    maxr[x]=max(maxr[ch[x][1]],sum[ch[x][1]]+key[x]+maxr[ch[x][0]]);    Max[x]=max(Max[ch[x][0]],Max[ch[x][1]]);    Max[x]=max(Max[x],maxr[ch[x][0]]+key[x]+maxl[ch[x][1]]);}void pushdown(int x){    if (x)    {        if (same[x]!=inf)        {            if (ch[x][0])            {                key[ch[x][0]]=same[x];                sum[ch[x][0]]=same[x]*size[ch[x][0]];                if (same[x]>=0)                    Max[ch[x][0]]=maxl[ch[x][0]]=maxr[ch[x][0]]=sum[ch[x][0]];                else                    Max[ch[x][0]]=same[x],maxl[ch[x][0]]=maxr[ch[x][0]]=0;                same[ch[x][0]]=same[x];            }            if (ch[x][1])            {                key[ch[x][1]]=same[x];                sum[ch[x][1]]=same[x]*size[ch[x][1]];                if (same[x]>=0)                    Max[ch[x][1]]=maxl[ch[x][1]]=maxr[ch[x][1]]=sum[ch[x][1]];                else                    Max[ch[x][1]]=same[x],maxl[ch[x][1]]=maxr[ch[x][1]]=0;                same[ch[x][1]]=same[x];            }            same[x]=inf;        }        if (rev[x])        {            if (ch[x][0])            {                int now=ch[x][0];                rev[now]^=1;                swap(ch[now][0],ch[now][1]);                swap(maxl[now],maxr[now]);            }            if (ch[x][1])            {                int now=ch[x][1];                rev[now]^=1;                swap(ch[now][0],ch[now][1]);                swap(maxl[now],maxr[now]);            }            rev[x]=0;        }    }}int build(int l,int r,int fa,int *a){    if (l>r) return 0;    int mid=(l+r)>>1,now=newp();    f[now]=fa;key[now]=a[mid];Max[now]=a[mid];    int lch=build(l,mid-1,now,a);    int rch=build(mid+1,r,now,a);    ch[now][0]=lch,ch[now][1]=rch;    update(now);    return now;}void rotate(int x){    pushdown(f[x]);    pushdown(x);    int old=f[x],oldf=f[old],wh=get(x);    ch[old][wh]=ch[x][wh^1];    if (ch[old][wh]) f[ch[old][wh]]=old;    ch[x][wh^1]=old;    f[old]=x;    if (oldf) ch[oldf][ch[oldf][1]==old]=x;    f[x]=oldf;    update(old);    update(x);}void splay(int x,int tar){    for (int fa;(fa=f[x])!=tar;rotate(x))        if (f[fa]!=tar)            rotate( (get(x)==get(fa))?fa:x );    if (!tar) root=x;}int find(int x){    int now=root;    while (1)    {        pushdown(now);        if (x<=size[ch[now][0]]) now=ch[now][0];        else        {            x-=size[ch[now][0]];            if (x==1) return now;            --x;            now=ch[now][1];        }    }}int main(){    freopen("input.in","r",stdin);    freopen("my.out","w",stdout);    scanf("%d%d",&n,&m);    a[1]=-inf;a[n+2]=inf;Max[0]=-inf;    for (int i=1;i<=n;++i) scanf("%d",&a[i+1]);    for (int i=1;i<=500000;++i) delp(i);    root=build(1,n+2,0,a);    for (int i=1;i<=m;++i)    {        scanf("%s",opt);        switch(opt[0])        {            case 'I':                {                    scanf("%d%d",&pos,&tot);                    for (int i=1;i<=tot;++i) scanf("%d",&b[i]);                    int now=build(1,tot,0,b);                    int aa=find(pos+1);                    int bb=find(pos+2);                    splay(aa,0);                    splay(bb,aa);                    ch[ch[root][1]][0]=now;                    f[now]=ch[root][1];                    update(ch[root][1]);                    update(root);                    break;                }            case 'D':                {                    scanf("%d%d",&pos,&tot);                    int aa=find(pos);                    int bb=find(pos+tot+1);                    splay(aa,0);                    splay(bb,aa);                    delp(ch[ch[root][1]][0]);                    ch[ch[root][1]][0]=0;                    update(ch[root][1]);                    update(root);                    break;                }            case 'R':                {                    scanf("%d%d",&pos,&tot);                    if (tot==1) continue;                    int aa=find(pos);                    int bb=find(pos+tot+1);                    splay(aa,0);                    splay(bb,aa);                    int now=ch[ch[root][1]][0];                    swap(ch[now][0],ch[now][1]);                    swap(maxl[now],maxr[now]);                    rev[now]^=1;                    break;                }            case 'G':                {                    scanf("%d%d",&pos,&tot);                    int aa=find(pos);                    int bb=find(pos+tot+1);                    splay(aa,0);                    splay(bb,aa);                    int ans=sum[ch[ch[root][1]][0]];                    printf("%d\n",ans);                    break;                }            case 'M':                {                    if (opt[2]=='K')                    {                        scanf("%d%d%d",&pos,&tot,&c);                        int aa=find(pos);                        int bb=find(pos+tot+1);                        splay(aa,0);                        splay(bb,aa);                        int now=ch[ch[root][1]][0];                        key[now]=c;                        sum[now]=c*size[now];                        if (c>=0)                            Max[now]=maxl[now]=maxr[now]=sum[now];                        else                            Max[now]=c,maxl[now]=maxr[now]=0;                        same[now]=c;                        update(ch[root][1]);                        update(root);                    }                    else                    {                        int aa=find(1);                        int bb=find(size[root]);                        splay(aa,0);                        splay(bb,aa);                        int ans=Max[ch[ch[root][1]][0]];                        printf("%d\n",ans);                    }                    break;                }        }    }}
0 0
原创粉丝点击