hdu4747 mex 线段树

来源:互联网 发布:jsp中嵌入java代码 编辑:程序博客网 时间:2024/06/06 00:09

题意:给一个序列不超过200000个元素,定义mex(i,j)是区间[i,j]之间所没有的最小非负整数。求sum(mex[i,j])对于所有1<=i<=j<=n;

解法:线段树。先求出mex(1,1),mex(1,2),mex(1,3)...mex(1,n) 而且这必然是递增的。

   然后 sum[i=1,1<=j<=n]就算出来了,然后去掉arr[1],这时候会影响到的是下一个arr[1]出现前mex值大于arr[1]的那些位置,而且由于mex具有单调性,如果有必然是连续的一个区间,所以区间修改即可。修改完,讲arr[1]置0,求和就是sum[i=2,2<=j<=n],以此不断往后计算。就得到了sum(mex[1<=i<=n,i<=j<=n]);


代码:

/******************************************************* @author:xiefubao*******************************************************/#pragma comment(linker, "/STACK:102400000,102400000")#include <iostream>#include <cstring>#include <cstdlib>#include <cstdio>#include <queue>#include <vector>#include <algorithm>#include <cmath>#include <map>#include <set>#include <stack>#include <string.h>//freopen ("in.txt" , "r" , stdin);using namespace std;#define eps 1e-8#define zero(_) (abs(_)<=eps)const double pi=acos(-1.0);typedef long long LL;const int Max=200010;const LL INF=0x3FFFFFFF;int arr[Max];int mex[Max];map<int,int> maps;struct node{    LL sum;//表示和    int ma;//表示区间最大值    int lazy;//表示区间值一样    int l,r;    node* pl,*pr;} tree[3*Max];int tot=0;int next[Max];void build(node* p,int left,int right){    p->l=left;    p->r=right;    if(left==right)    {        p->ma=mex[left];        p->sum=mex[left];        p->lazy=0;        return ;    }    int middle=(left+right)/2;    tot++;    p->pl=tree+tot;    build(p->pl,left,middle);    tot++;    p->pr=tree+tot;    build(p->pr,middle+1,right);    p->sum=p->pl->sum+p->pr->sum;    p->ma=max(p->pl->ma,p->pr->ma);}void update(node* p,int value,int left,int right);void down(node* p){    if(p->l==p->r)        return ;    p->lazy=0;    int middle=(p->l+p->r)/2;    update(p->pl,p->ma,p->l,middle);    update(p->pr,p->ma,middle+1,p->r);}int findpos(node* p,int value){    if(p->l==p->r)    {        if(p->ma>=value)            return p->l;    }    if(p->lazy)        down(p);    if(p->pl->ma>value)        return findpos(p->pl,value);    return findpos(p->pr,value);}void update(node* p,int value,int left,int right){    if(p->l==left&&p->r==right)    {        p->lazy=1;        p->ma=value;        p->sum=(p->r-p->l+1)*value;        return;    }    LL ans=0;    int middle=(p->l+p->r)/2;    if(p->lazy)    {        down(p);    }    if(left>middle)    {        update(p->pr,value,left,right);    }    else if(right<=middle)        update(p->pl,value,left,right);    else    {        update(p->pl,value,left,middle);        update(p->pr,value,middle+1,right);    }    p->sum=p->pl->sum+p->pr->sum;    p->ma=max(p->pl->ma,p->pr->ma);}int n;int main(){    while(cin>>n&&n)    {        tot=0;        memset(tree,0,sizeof tree);        memset(next,0,sizeof next);        maps.clear();        for(int i=1; i<=n; i++)            scanf("%d",arr+i);        int p=0;        for(int i=1; i<=n; i++)        {            next[maps[arr[i]]]=i;            maps[arr[i]]=i;            while(maps.find(p)!=maps.end())p++;            mex[i]=p;        }        build(tree,1,n);        LL ans=0;        for(int i=1; i<=n; i++)        {            ans+=tree->sum;            if(next[i]==0) next[i]=n+1;            if(tree->ma>arr[i])            {                int pos=findpos(tree,arr[i]);                if(pos<=next[i]-1)                    update(tree,arr[i],pos,next[i]-1);            }            update(tree,0,i,i);        }        cout<<ans<<"\n";    }    return 0;}

0 0