对主席树的一点理解 -- 例题POJ 2104

来源:互联网 发布:淘宝会员中心在哪里 编辑:程序博客网 时间:2024/05/22 06:10

断断续续看了许久的主席树,简单记录一下。

什么样的题目用主席树呢,比如POJ 2104 求区间第K大的数是谁?

当时做这个题时,感觉分块+二分可以搞,就写了好久,改了好久,始终TLE,还是学学主席树把= =。

先吐槽一下线段树:

线段树竟然是被一个黄嘉泰的大佬因不会划分树来代替的,,,,,因缩写是HJT取名为主席树= =!orz

主席树大体思路:

我们怎样求区间第K大数呢:

假如我们能够利用前缀和思想,把每个数都以它为根建立一棵线段树,上面能够统计第几大的数出现的次数的话,我们就可以根据像二叉平衡树那样查找第K大的数。


假如我们建立一个线段树,区间划分和正常的线段树一样,只不过区间代表的是第几大数,不在是区间上的点。

先声明一下线段树的结点:

struct Node{    int l,r,sum;}p[maxn*40];
// L代表的是“第几大的左端点”,R代表的是“第几大的右端点”,sum 代表是1~i个数在[L,R]这个第几大区间上出现的次数。

那么我们就可以n 次更新(修改)线段树 ,根据这个数在所有数中第几大来确定往左树走还是往右树上走,直到走到叶子结点(L == R)为止。

n 次更新线段树,因为每次都是往左或者往右 这是一个log级别的次数,因为我们只需要nlogn 个空间就可以完成一个有n 个版本的线段树。

假设root[i]表示以第i 个数为根的 根节点编号。

那么我们的update函数就可以这样写:

先建立一个新结点,和root[i-1]相等,这样就可以让新结点的左右子树和root[i-1]的一样了,相当于是引用,部分引用,部分修改。然后这样修改新结点的sum变量就可以了。 这样一直走一直走,根据第i 个数的第几大来确定往左还是往右走。

int update(int l,int r,int c,int k){    int nc = ++cnt;    p[nc] = p[c];    p[nc].sum++;    int mid = l+r>>1;    if (l == r) return nc;    if (mid >= k) p[nc].l = update(l,mid,p[c].l,k);    else p[nc].r = update(mid+1,r,p[c].r,k);    return nc;}

query查询函数:

比如说我们要查询[x,y]这个区间上第k 大的数:

刚开始我们肯定从根结点(1~n)开始。

我们先求出[x,y]区间上,第1~mid大的数有几个(假设有sum 个),如果sum >= k,那么这个第k大数肯定在根节点的左子树,否则如果sum < k,第k 大数肯定在右子树上。那么问题就是如何求一个区间上的1~mid 大的呢?

根据我们建树的性质,我们让第y 个版本的线段树目前结点(根结点)的左儿子的sum 减去 第x-1个版本的线段树目前结点(根节点)的左儿子的sum。这个差就是[x,y]区间上1~mid大数的个数。这样我们找到最后找到一个叶子结点L==R,那么这个L 就是原数组离散化后的下标。

int query(int l,int r,int x,int y,int k){    if (l == r) return l;    int mid = l + r >> 1;    int sum = p[p[y].l ].sum - p[p[x].l ].sum;    if (sum >= k) return query(l,mid,p[x].l,p[y].l,k);    else return query(mid+1,r,p[x].r,p[y].r,k-sum);}

好了,这样区间第K大数就解决了,其实理解了,感觉很巧妙的= =。

常用的离散化方法:

    sort(v.begin(),v.end());    v.erase(unique(v.begin(),v.end()),v.end());

参考代码:

#include <cstdio>#include <cstring>#include <algorithm>#include <vector>using namespace std;const int maxn = 1e5 + 10;int root[maxn];vector<int>v;struct Node{    int l,r,sum;}p[maxn*40];int cnt = 0;int build(int l,int r){    int rt = ++cnt;    p[rt].sum = 0;    p[rt].l = p[rt].r = 0;    if (l == r) return rt;    int mid = l+r>>1;    p[rt].l = build(l,mid);    p[rt].r = build(mid+1,r);    return rt;}int a[maxn];int getid(int x){    return lower_bound(v.begin(),v.end(),x) - v.begin() + 1;}int update(int l,int r,int c,int k){    int nc = ++cnt;    p[nc] = p[c];    p[nc].sum++;    int mid = l+r>>1;    if (l == r) return nc;    if (mid >= k) p[nc].l = update(l,mid,p[c].l,k);    else p[nc].r = update(mid+1,r,p[c].r,k);    return nc;}int query(int l,int r,int x,int y,int k){    if (l == r) return l;    int mid = l + r >> 1;    int sum = p[p[y].l ].sum - p[p[x].l ].sum;    if (sum >= k) return query(l,mid,p[x].l,p[y].l,k);    else return query(mid+1,r,p[x].r,p[y].r,k-sum);}int main(){    int n, q;    scanf("%d %d",&n, &q);    root[0] = build(1,n);    for (int i = 1; i <= n; ++i){        int x;        scanf("%d",&x);        a[i] = x;        v.push_back(x);    }    sort(v.begin(),v.end());    v.erase(unique(v.begin(),v.end()),v.end());    for (int i = 1; i <= n; ++i){        root[i] = update(1,n,root[i-1],getid(a[i]));    }    while(q--){        int x,y,k;        scanf("%d %d %d",&x, &y, &k);        int ans = query(1,n,root[x-1], root[y], k);        printf("%d\n",v[ans-1]);    }    return 0;}





1 0
原创粉丝点击