主席树(可持久化线段树)入门专题

来源:互联网 发布:家庭电路设计软件 编辑:程序博客网 时间:2024/05/18 02:30

1.poj 2104 查询区间第k小。

主席树其实相当于建立了n棵线段树,第i棵线段树是根据区间【1,i】按值建立的。对于每一棵线段树我们记录它对应的区间每个数出现的次数,所以首先要对所有的数离散化。

先考虑最简单的情况,只查询【1,n】的第k小,对于【1,n】我们按值建立一棵线段树,对于a[i]我们在位置a[i]上加1。查询第k小那么先看左子区间出现了多少个数cnt,假设左区间出现的数cnt>=k,那么直接递归到左区间查询(因为是按值建立的,左区间的数肯定小于右区间),否则递归到右区间查询第k-cnt小(左区间已经有了最小的cnt个数了)


对于任意区间查询【l,r】,我们只需要比较第l-1棵线段树和第r棵线段树,【l,r】之间的数就是第r棵线段树相比于第l-1棵多出来的数。只需要对比两颗树同一个节点,对比到哪个数为止第r棵比第l-1棵刚好多出k个数。(先比较左区间cntr-cntl,cntr-cntl>=k,则递归到左区间,否则递归查询右区间k-cntr-cntl)


主席树就相当于n棵线段树,但是对比建立在【1,i】的线段树和【1,i+1】的线段树,只多出了一个值,也就是相当于单点更新他们之间只有logn个节点是不同的,所以可以将【1,i+1】的一些节点指针指向前一棵的共同部分。每次新增的空间只需要logn。


代码是根据kuangbin模板抄的。。。。

#include <iostream>#include <cstdio>#include <cstring>#include <cmath>#include <cctype>#include <string>#include <vector>#include <map>#include <set>#include <vector>#include <queue>#include <stack>#include <algorithm>using namespace std;const int maxn=1e5+10;const int M=maxn*30;int n,q,m,tot;int a[maxn], t[maxn];int T[maxn], lson[M], rson[M], c[M];void init_hash(){    for(int i=1; i<=n; i++)        t[i]=a[i];    sort(t+1, t+n+1);    m=unique(t+1, t+n+1)-t-1;}int hash(int x){    return lower_bound(t+1, t+1+m, x)-t;}int build(int l, int r){    int rt=tot++;    c[rt]=0;    if(l!=r){        int mid=(l+r)>>1;        lson[rt]=build(l,mid);        rson[rt]=build(mid+1, r);    }    return rt;}int update(int rt, int pos, int val){    int newrt=tot++,tmp=newrt;    c[newrt]=c[rt]+val;    int l=1, r=m;    while(l<r){        int mid=(l+r)>>1;        if(pos<=mid){            lson[newrt]=tot++; rson[newrt]=rson[rt];            newrt=lson[newrt]; rt=lson[rt];            r=mid;        }        else{            rson[newrt]=tot++; lson[newrt]=lson[rt];            newrt=rson[newrt]; rt=rson[rt];            l=mid+1;        }        c[newrt]=c[rt]+val;    }    return tmp;}int query(int lrt, int rrt, int k){    int l=1, r=m;    while(l<r){        int mid=(l+r)>>1;        if(c[lson[lrt]]-c[lson[rrt]]>=k){            r=mid;            lrt=lson[lrt];            rrt=lson[rrt];        }        else{            l=mid+1;            k-=c[lson[lrt]]-c[lson[rrt]];            lrt=rson[lrt];            rrt=rson[rrt];        }    }    return l;}int main(){    while(cin>>n>>q){                for(int i=1; i<=n; i++)            scanf("%d", a+i);        init_hash();        tot=0;        T[n+1]=build(1,m);        for(int i=n; i; i--){            int pos=hash(a[i]);            T[i]=update(T[i+1], pos, 1);        }        while(q--){            int l,r,k;            scanf("%d%d%d", &l, &r, &k);            printf("%d\n", t[query(T[l], T[r+1], k)]);        }    }    return 0;}



2.hdu 4417 区间查询<=H的数有多少个

查询【l,r】区间只需要将第r棵线段树【0,H】区间的总数减去第l-1棵的就行了。

#include <iostream>#include <cstdio>#include <cstring>#include <cmath>#include <cctype>#include <string>#include <vector>#include <map>#include <set>#include <vector>#include <queue>#include <stack>#include <algorithm>using namespace std;const int maxn=1e5+10;const int maxm=maxn*30;int n,m, N;int a[maxn],b[2*maxn];int T[maxn],tot;int lson[maxm],rson[maxm], cnt[maxm];int build(int l, int r){    int rt=tot++;    cnt[rt]=0;    if(l==r) return rt;    int mid=(l+r)>>1;    lson[rt]=build(l, mid);    rson[rt]=build(mid+1, r);    return rt;}int update(int rt, int pos, int v){    int newrt=tot++, ret=newrt;    int l=1, r=N;    cnt[newrt]=cnt[rt]+v;    while(l<r){        int mid=(l+r)>>1;        if(pos<=mid){            lson[newrt]=tot++; rson[newrt]=rson[rt];            newrt=lson[newrt]; rt=lson[rt];            r=mid;        }        else{            lson[newrt]=lson[rt]; rson[newrt]=tot++;            newrt=rson[newrt]; rt=rson[rt];            l=mid+1;        }        cnt[newrt]=cnt[rt]+v;    }    return ret;}int query(int lrt, int rrt, int pos){    int ret=0;    int l=1,r=N;    while(l<r){        int mid=(l+r)>>1;        if(pos<=mid){            lrt=lson[lrt]; rrt=lson[rrt];            r=mid;        }        else{            ret+=cnt[lson[rrt]]-cnt[lson[lrt]];            lrt=rson[lrt]; rrt=rson[rrt];            l=mid+1;        }    }    ret+=cnt[rrt]-cnt[lrt];    return ret;}int l[maxn], r[maxn], h[maxn];int main(){    int t;    cin>>t;    for(int tt=1; tt<=t; tt++){        cin>>n>>m;        for(int i=1; i<=n; i++){            scanf("%d", a+i);            b[i]=a[i];        }        for(int i=0; i<m; i++){            scanf("%d%d%d", l+i, r+i, h+i);            b[n+1+i]=h[i];        }        sort(b+1, b+1+n+m);        N=unique(b+1, b+1+n+m)-b-1;        tot=0;        T[0]=build(1,N);        for(int i=1; i<=n; i++){            int v=lower_bound(b+1, b+1+N, a[i])-b;            T[i]=update(T[i-1], v, 1);        }        printf("Case %d:\n", tt);        for(int i=0; i<m; i++){            int v=lower_bound(b+1, b+1+N, h[i])-b;            printf("%d\n", query(T[l[i]], T[r[i]+1], v));        }    }    return 0;}



3.hdu 4348 可持久化线段树,区间更新,不下放的懒惰标记(空间优化)

主席树其实就是可持久化线段树。可持久化就是每次修改操作,尽量用新节点表示而不是直接修改原来的点,这样所有的历史版本都得以保留。

主要麻烦的就是区间更新。区间更新对于完全覆盖的区间要用lazy标记。但是每次lazy下放的时候两个子区间都发生修改需要创造两个新的节点,这样到最后下放到最后一层相当于消耗了O(n)个新节点,空间会爆。

这道题用的空间优化就是不下放标记。标记就打在那个区间节点上,而查询的时候,往下递归时遇到标记就累加,最后把标记的影响加到总答案里。这样就不需要创造那么多新节点了。

#include <iostream>#include <cstdio>#include <cstring>#include <cmath>#include <cctype>#include <string>#include <vector>#include <map>#include <set>#include <vector>#include <queue>#include <stack>#include <algorithm>using namespace std;const int maxn=1e5+1000;const int maxm=30*maxn;typedef long long LL;int n, m;int a[maxn];int T[maxn], tot=0;int lson[maxm],rson[maxm], lazy[maxm];LL sum[maxm];void push_up(int rt, int l, int r){    sum[rt]=sum[lson[rt]]+sum[rson[rt]]+(LL)lazy[rt]*(r-l+1);}int build(int l, int r){    int rt=tot++;    lazy[rt]=0;    if(l==r){        sum[rt]=a[l];        return rt;    }    int mid=(l+r)>>1;    lson[rt]=build(l,mid);    rson[rt]=build(mid+1, r);    push_up(rt, l, r);    return rt;}int update(int rt, int l, int r, int ll, int rr, int v){    int newrt=tot++;    lazy[newrt]=lazy[rt];    if(ll<=l && r<=rr){        lson[newrt]=lson[rt], rson[newrt]=rson[rt];        lazy[newrt]=lazy[rt]+v;        sum[newrt]=sum[rt]+(LL)v*(r-l+1);        return newrt;    }    int mid=(l+r)>>1;    if(rr<=mid){        rson[newrt]=rson[rt];        lson[newrt]=update(lson[rt], l, mid, ll, rr, v);    }    else if(ll>mid){        lson[newrt]=lson[rt];        rson[newrt]=update(rson[rt], mid+1, r,ll, rr, v);    }    else{        lson[newrt]=update(lson[rt], l, mid, ll, mid, v);        rson[newrt]=update(rson[rt], mid+1, r, mid+1, rr, v);    }    push_up(newrt, l, r);    return newrt;}LL query(int rt, int l, int r, int ll, int rr, int  la){    if(ll<=l && r<=rr){        return sum[rt]+(LL)la*(r-l+1);    }    la+=lazy[rt];    int mid=(l+r)>>1;    if(rr<=mid)        return query(lson[rt], l, mid, ll, rr, la);    else if(ll>mid)        return query(rson[rt], mid+1, r, ll, rr, la);    else        return query(lson[rt], l, mid, ll, mid, la)+query(rson[rt], mid+1, r, mid+1, rr, la);}int main(){    while(cin>>n>>m){        for(int i=1; i<=n; i++)            scanf("%d", a+i);        tot=0;        T[0]=build(1,n);        int tag=0;        while(m--){            char s[5];            int l,r,d;            scanf("%s%d", s, &l);            if(s[0]=='B'){                tag=l;            }            else{                scanf("%d", &r);                if(s[0]=='Q'){                    LL res=query(T[tag], 1, n, l, r, 0);                    printf("%I64d\n", res);                }                else{                    scanf("%d", &d);                    if(s[0]=='C'){                        tag++;                        T[tag]=update(T[tag-1], 1, n, l, r, d);                    }                    else{                        LL res=query(T[d], 1, n, l, r, 0);                        printf("%I64d\n", res);                    }                }            }        }    }    return 0;}



4 0
原创粉丝点击