主席树讲解

来源:互联网 发布:铁岭网络招聘 编辑:程序博客网 时间:2024/06/05 03:05



——————————————————————————————————————————————————————————————————————————————————————————————

以下转自http://prominences.weebly.com/1/post/2013/02/1.html

可持久化线段树,也叫作函数式线段树,也就是主席树,(。。。因为先驱就是fotile主席。。Orz。。。)
网上的教程很少啊,有的教程写得特别简单,4行中文,然后就是一篇代码~~
这里,我将从查找区间第k小值(不带修改)题的可持久化线段树做法中,讲一讲主席树。
/*只是略懂,若有错误,还请多多包涵!*/
可持久化数据结构(Persistent data structure)就是利用函数式编程的思想使其支持询问历史版本、同时充分利用它们之间的共同数据来减少时间和空间消耗。/*找不到比较科学的定义,就拿这个凑凑数吧~~~*/
这个数据结构很坑啊,我研究了一整天才差不多理解了一些(。。太笨了。。。)。所以,要理解好每一个域或变量的意义。
开讲!
一些数据结构,比如线段树或平衡树,他们一般是要么维护每个元素在原序列中的排列顺序,要么是维护每个元素的大小顺序,若是像二者兼得。。(反正我是觉得很。。)那么,这道题就想想主席树吧~/*还可以用划分树做*/
开讲!~好像说过一边了
既然叫函数式线段树,那么就应该有跟普通线段树相同的地方。一颗线段树,只能维护一段区间里的元素。但是,每个询问的区间都不一样,若是对每段区间都单独建立的线段树,那~萎定了~。因此,就要想,如何在少建,或建得快的情况下,能利用一些方法,得出某个区间里的情况。
比如一棵线段树,记为tree[i][j],表示区间[i,j]的线段树。那么,要得到它的情况,可以利用另外两棵树,tree[1][i-1]tree[1][j],得出来。也就是说,可以由建树的一系列历史版本推出。
那么,怎么创建这些树呢?
首先,离散化数据。因为如果数据太大的话,线段树会爆~~
在所有树中,是按照当前区间元素的离散值(也就是用大小排序)储存的,在每个节点,存的是这个区间每个元素出现的次数之和(data域)。出现的次数,也就是存了多少数进来(建树时,是一个数一个数地存进来的)。
先建议棵线段树,所有的节点data域为0。再一个节点一个节点地添加。把每个数按照自己的离散值,放到树中合适的位置,然后data+1,回溯的时候也要+1。当然,不能放到那棵空树中,要重新建树。第i棵树存的是区间(原序列)[1,i]。但是,如果是这样,那么会MLE+TLE。因此,要充分利用历史版本。用两个指针,分指当前空树和前一棵树。因为每棵树的结构是一样的,只是里面的data域不同,但是两棵相邻的树,只有一数只差,因此,如果元素要进左子树的话,右子树就会跟上个树这个区间的右子树是完全一样的,因此,可以直接将本树本节点的右子树指针接到上棵树当前节点的右儿子,这样即省时间,又省空间。
每添加一个节点(也就是新建一棵树)的复杂度是O(logn),因此,这一步的复杂度是O(nlogn)
建完之后,要怎么查找呢?
跟一般的,在整棵树中找第k个数是一样的。如果一个节点的左权值(左子树上点的数量之和)大于k,那么就到左子树查找,否则到右子树查找。其实主席树是一样的。对于任意两棵树(分别存区间[1,i]和区间[1,j] i<j),在同一节点上(两节点所表示的区间相同),data域之差表示的是,原序列区间[i,j]在当前节点所表示的区间里,出现多少次(有多少数的大小是在这个区间里的)。同理,对于同一节点,如果在两棵树中,它们的左权值之差大于等于k,那么要求的数就在左孩子,否则在右孩子。当定位到叶子节点时,就可以输出了。

———————————————————————————————————————————————————————————————————————————————————————————


鄙人的一些理解:所谓主席树呢,就是对原来的数列[1..n]的每一个前缀[1..i](1≤i≤n)建立一棵线段树,线段树的每一个节点存某个前缀[1..i]中属于区间[L..R]的数一共有多少个(比如根节点是[1..n],一共i个数,sum[root] = i;根节点的左儿子是[1..(L+R)/2],若不大于(L+R)/2的数有x个,那么sum[root.left] = x)。若要查找[i..j]中第k大数时,设某结点x,那么x.sum[j] - x.sum[i - 1]就是[i..j]中在结点x内的数字总数。而对每一个前缀都建一棵树,会MLE,观察到每个[1..i]和[1..i-1]只有一条路是不一样的,那么其他的结点只要用回前一棵树的结点即可,时空复杂度为O(nlogn)。

 

代码(最原始的树所有结点的值都为0,就算建好一棵树了……):

复制代码
 1 #include <cstdio> 2 #include <algorithm> 3 using namespace std; 4  5 const int MAXN = 100010; 6  7 struct Node { 8     int L, R, sum; 9 };10 Node T[MAXN * 20];11 int T_cnt;12 13 void insert(int &num, int &x, int L, int R) {14     T[T_cnt++] = T[x]; x = T_cnt - 1;15     ++T[x].sum;16     if(L == R) return ;17     int mid = (L + R) >> 1;18     if(num <= mid) insert(num, T[x].L, L, mid);19     else insert(num, T[x].R, mid + 1, R);20 }21 22 int query(int i, int j, int k, int L, int R) {23     if(L == R) return L;24     int t = T[T[j].L].sum - T[T[i].L].sum;25     int mid = (R + L) >> 1;26     if(k <= t) return query(T[i].L, T[j].L, k, L, mid);27     else return query(T[i].R, T[j].R, k - t, mid + 1, R);28 }29 30 struct A {31     int x, idx;32     bool operator < (const A &rhs) const {33         return x < rhs.x;34     }35 };36 37 A a[MAXN];38 int rank[MAXN], root[MAXN];39 int n, m;40 41 int main() {42     T[0].L = T[0].R = T[0].sum = 0;43     root[0] = 0;44     while(scanf("%d%d", &n, &m) != EOF) {45         for(int i = 1; i <= n; ++i) {46             scanf("%d", &a[i].x);47             a[i].idx = i;48         }49         sort(a + 1, a + n + 1);50         for(int i = 1; i <= n; ++i) rank[a[i].idx] = i;51         T_cnt = 1;52         for(int i = 1; i <= n; ++i) {53             root[i] = root[i - 1];54             insert(rank[i], root[i], 1, n);55         }56         while(m--) {57             int i, j, k;58             scanf("%d%d%d", &i, &j, &k);59             printf("%d\n", a[query(root[i - 1], root[j], k, 1, n)].x);60         }61     }62 }
复制代码

/* ***********************************************Author        :kuangbinCreated Time  :2013-9-5 23:54:37File Name     :F:\2013ACM练习\专题学习\主席树\SPOJ_DQUERY.cpp************************************************ */#include <stdio.h>#include <string.h>#include <iostream>#include <algorithm>#include <vector>#include <queue>#include <set>#include <map>#include <string>#include <math.h>#include <stdlib.h>#include <time.h>using namespace std;/* * 给出一个序列,查询区间内有多少个不相同的数 */const int MAXN = 30010;const int M = MAXN * 100;int n,q,tot;int a[MAXN];int T[M],lson[M],rson[M],c[M];int build(int l,int r){    int root = tot++;    c[root] = 0;    if(l != r)    {        int mid = (l+r)>>1;        lson[root] = build(l,mid);        rson[root] = build(mid+1,r);    }    return root;}int update(int root,int pos,int val){    int newroot = tot++, tmp = newroot;    c[newroot] = c[root] + val;    int l = 1, r = n;    while(l < r)    {        int mid = (l+r)>>1;        if(pos <= mid)        {            lson[newroot] = tot++; rson[newroot] = rson[root];            newroot = lson[newroot]; root = lson[root];            r = mid;        }        else        {            rson[newroot] = tot++; lson[newroot] = lson[root];            newroot = rson[newroot]; root = rson[root];            l = mid+1;        }        c[newroot] = c[root] + val;    }    return tmp;}int query(int root,int pos){    int ret = 0;    int l = 1, r = n;    while(pos < r)    {        int mid = (l+r)>>1;        if(pos <= mid)        {            r = mid;            root = lson[root];        }        else        {            ret += c[lson[root]];            root = rson[root];            l = mid+1;        }    }    return ret + c[root];}int main(){    //freopen("in.txt","r",stdin);    //freopen("out.txt","w",stdout);    while(scanf("%d",&n) == 1)    {        tot = 0;        for(int i = 1;i <= n;i++)            scanf("%d",&a[i]);        T[n+1] = build(1,n);        map<int,int>mp;        for(int i = n;i>= 1;i--)        {            if(mp.find(a[i]) == mp.end())            {                T[i] = update(T[i+1],i,1);            }            else            {                int tmp = update(T[i+1],mp[a[i]],-1);                T[i] = update(tmp,i,1);            }            mp[a[i]] = i;        }        scanf("%d",&q);        while(q--)        {            int l,r;            scanf("%d%d",&l,&r);            printf("%d\n",query(T[l],r));        }    }    return 0;}



原创粉丝点击