区间第K值——主席树详解

来源:互联网 发布:万税 知乎 编辑:程序博客网 时间:2024/05/21 10:42

序:这是一篇迟到的题解,机房的小伙伴们系统地学主席树应该是七月份的时候,然而我没赶上趟,当时压根没看懂主席树是什么东东。 昨天晚上决定重新来过,于是请教了一位大神1113(这是他的博客,不过好像因为手机验证的原因很久没有更新了),他告诉我了主席树的始末,然后我就秒懂了,原来并没有想象中的那么复杂,相信看完了这篇题解,你也会这么觉得的。下面开始正文:

Description

    给定一个长度为n的序列和m个询问,对于询问,输出区间[L,R]K大的数。

Solution

    对于一个问题,我们一般采取的方法就是化繁为简,这样更容易分析题目的特点,然后逐个击破。对于问题中的m个询问,我们可以先看看如何解决1个询问。

    对于一个询问(L,R,K),刚开始最容易想到的是二分答案,但是check的时候又不得不O(n)地扫一遍区间,总的复杂度就变成O(nlogn),这显然不是很优。我们完全可以O(n)地求出第k大的值,我们可以先用计数的方法记下区间中每种数的个数(在这之前要先把整个数组离散化),然后从大到小维护一个cnt的后缀和,表示大于或等于当前数字的数的个数,当这个后缀和大于或等于k时,我们也就找到了这个区间内第K大的数了。省去离散的部分,代码如下:

void solve(){    //离散(未写出)    //计数,用cnt计数     for(int i=L;i<=R;i++)cnt[A[i]]++;    //维护后缀和,len是unique之后数组长度     for(int i=len;i>=1;i--){        cnt[i]+=cnt[i+1];        if(cnt[i]>=K){            //找到第K大数             Pt(tmp[i]);            putchar('\n');            break;        }    }}

    上面方法的复杂度是O(n)的,假如用来解决m个询问的话总复杂度是O(nm),这样只能过70%的数据,那么能不能再优化一下呢?这里就要用到主席树了。

    对于上面一次询问O(n)的算法, 我们在两个地方用到了O(n),一个是在计数的时候,一个是在维护后缀和的时候。这里先讲怎么优化维护后缀和的算法,也就是假设我们已经计好区间内的数了,我们怎么更快地找到第K大数呢,这里只要维护一个线段树就行了,我们让线段树以离散后的权值为下标,然后在节点存下当前区间的数的个数sum,只要写一个find函数来从右往左找到第K个数即可。每次先询问右区间,假如右区间的sumK,那么就继续在右区间从右往左找第K个数;假如右区间的sum<K,那么就要在左区间从右往左找第Ksum个数,直到L=R。下面给出代码:

struct Segment_Tree{    struct node{        int L,R,sum;//sum表示[L,R]中元素总数    }tree[N<<2];    int find(int k,int p){        if(tree[p].L==tree[p].R)return tree[p].L;        if(tree[p<<1|1].sum>=k)//右区间元素个数大于或等于k,就往右区间找            return find(k,p<<1|1);        return find(k-tree[p<<1|1].sum,p<<1);//往左区间找    }}T;

    虽然说查询只有O(logn),但是我们建树要O(nlogn)啊!?不是的,我们并不是对于每次询问都建一个线段树,而是在询问前就已经建好了,这就是主席树。下面讲讲如何建树:我们最初要先build一棵空线段树,并且要动态开节点,而不是像之前的线段树一样用p2p2+1表示左右儿子。然后我们从左往右遍历整个数组,每遍历一个数就在之前的树的基础之上再建一棵树,然而对于一个数,他从根节点down到叶子节点最多只会经过logn个节点,所以其他节点的信息都不会改变,我们只要再多开logn个新节点,再加上原来的节点就可以表示出当前的线段树了,这就是动态开节点的妙处。我们用root[i]表示第i棵线段树的根节点编号,第i线段树就表示已经插入了[1,i]区间内节点的线段树,这就相当于是一棵前缀线段树,节点内存了[1,i]区间内各种数字的个数。查询[L,R]区间时,只要把第R棵线段树的sum减去第L1棵线段树的sum,就可以得到[L,R]区间的sum了。时间和空间复杂度都是O(nlogn)的,于是问题就迎刃而解了。

Code

#include<stdio.h>#include<algorithm>#include<iostream>#define N 30005#define M 30005using namespace std;template <class T>inline void Rd(T &res){    char c;res=0;int k=1;    while(c=getchar(),c<48&&c!='-');    if(c=='-'){k=-1;c='0';}    do{        res=(res<<3)+(res<<1)+(c^48);    }while(c=getchar(),c>=48);    res*=k;}template <class T>inline void Pt(T res){    if(res<0){        putchar('-');        res=-res;    }    if(res>=10)Pt(res/10);    putchar(res%10+48);}struct opr{    int L,R,k;}Q[M];int n,m,len;int A[N],tmp[N];struct Segment_Tree{    struct node{        //L是左区间编号,R是右区间编号,sum是当前区间的元素个数        int L,R,sum;            }tree[N*15];    int tot,root[N];//tot用来动态开节点,root存第i棵前缀线段树的节点编号    void build(int L,int R,int &p){        //[L,R]表示当前区间,p是当前节点编号        p=++tot;        tree[p].sum=0;        if(L==R)return;        int mid=(L+R)>>1;        build(L,mid,tree[p].L);        build(mid+1,R,tree[p].R);    }    void insert(int t,int L,int R,int x,int &p){        //t是原来节点编号,p是当前节点编号,[L,R]表示当前区间,x是要加的位置        p=++tot;        tree[p]=tree[t];        tree[p].sum++;        if(L==R)return;        int mid=(L+R)>>1;        if(x<=mid)insert(tree[p].L,L,mid,x,tree[p].L);        else insert(tree[p].R,mid+1,R,x,tree[p].R);    }    int find(int t,int L,int R,int k,int p){        //t表示第L-1棵线段树的节点,p表示第R棵线段树的节点,[L,R]表示当前区间        if(L==R)return L;        int mid=(L+R)>>1;        int cnt=tree[tree[p].R].sum-tree[tree[t].R].sum;        if(cnt>=k)return find(tree[t].R,mid+1,R,k,tree[p].R);        return find(tree[t].L,L,mid,k-cnt,tree[p].L);    }    Segment_Tree(){tot=0;}}T;void solve(){    //建树    T.build(1,len,T.root[0]);    for(int i=1;i<=n;i++)    T.insert(T.root[i-1],1,len,A[i],T.root[i]);    //查询    for(int i=1;i<=m;i++){        int L=Q[i].L,R=Q[i].R,k=Q[i].k;        Pt(tmp[T.find(T.root[L-1],1,len,k,T.root[R])]);        putchar('\n');    }}int main(){    Rd(n);Rd(m);    for(int i=1;i<=n;i++){        Rd(A[i]);        tmp[i]=A[i];    }    for(int i=1;i<=m;i++){        Rd(Q[i].L);Rd(Q[i].R);Rd(Q[i].k);    }    //离散    sort(tmp+1,tmp+n+1);    len=unique(tmp+1,tmp+n+1)-tmp-1;    for(int i=1;i<=n;i++)    A[i]=lower_bound(tmp+1,tmp+len+1,A[i])-tmp;    solve();    return 0;}