poj 2104 归并树(线段树)

来源:互联网 发布:centos 查看时区 编辑:程序博客网 时间:2024/05/21 17:28
 * 题意:给定一个序列key[1..n]和m个询问{s,t,rank}(1 <= n <= 100 000, 1 <= m <= 5 000),对于每个询问输出区间[s,t]中第rank小的值

分析:由于2761和这题差不多,且数据量是这题的10倍,所以我一开始就把2761的SBT代码交上去,结果竟然是TLE,估计是栽在了"Case Time Limit: 2000MS"上面了。最终还是用了别人的思路,由此接触到一种很巧妙的结构:归并树

归并树可以用简单的一句话概括:利用类似线段树的树型结构记录合并排序的过程。

回顾一下如何利用归并树解决这道题:

1,建立归并树后我们得到了序列key[]的非降序排列,由于此时key[]内元素的rank是非递减的,因此key[]中属于指定区间[s,t]内的元素的rank也是非递减的,所以我们可以用二分法枚举key[]中的元素并求得它在[s,t]中的 rank值,直到该rank值和询问中的rank值相等;
2,那对于key[]中的某个元素val,如何求得它在指定区间[s,t]中的rank?这就要利用到刚建好的归并树:我们可以利用类似线段树的 query[s,t]操作找到所有属于[s,t]的子区间,然后累加val分别在这些子区间内的rank,得到的就是val在区间[s,t]中的 rank,注意到这和合并排序的合并过程一致;
3,由于属于子区间的元素的排序结果已经记录下来,所以val在子区间内的rank可以通过二分法得到。

上面三步经过了三次二分操作(query也是种二分),于是每次询问的复杂度是O(log n * log n * log n)

PS:写二分查找时要注意细节。。
 */
//============================================================================

#include <iostream>#include <algorithm>using namespace std;#define MAX 100005int seg[17 + 5][MAX];struct NODE{    int l, r;    int mid()    {        return (l + r) >> 1;    }};NODE tree[MAX * 3];int str[MAX + 1];int s, t;void build(int v, int l, int r, int deep){    tree[v].l = l;    tree[v].r = r;    if (l == r)    {        seg[deep][l] = str[l];        return;    }    int mid = tree[v].mid();    build(v * 2, l, mid, deep + 1);    build(v * 2 + 1, mid + 1, r, deep + 1);    int i = l, j = mid + 1;    int k = l;    while (i <= mid && j <= r)    {        if (seg[deep + 1][i] < seg[deep + 1][j])        {            seg[deep][k++] = seg[deep + 1][i++];        }        else        {            seg[deep][k++] = seg[deep + 1][j++];        }    }    if (i == mid + 1)        while (j <= r)            seg[deep][k++] = seg[deep + 1][j++];    else        while (i <= mid)            seg[deep][k++] = seg[deep + 1][i++];}int find(int v, int deep, int val){    if (s <= tree[v].l && t >= tree[v].r)    {        return lower_bound(&seg[deep][tree[v].l], &seg[deep][tree[v].r] + 1,                val) - &seg[deep][tree[v].l];    }    int res = 0;    if (s <= tree[v * 2].r)        res += find(v * 2, deep + 1, val);    if (t >= tree[v * 2 + 1].l)        res += find(v * 2 + 1, deep + 1, val);    return res;}int main(){    freopen("in", "r", stdin);    int n, m, q;    int i, l, r, pos, k;    while (scanf("%d %d", &n, &q) != EOF)    {        for (i = 1; i <= n; i++)        {            scanf("%d", str + i);        }        build(1, 1, n, 1);        while (q--)        {            scanf("%d %d %d", &s, &t, &k);            k--;            l = 1, r = n;            while (l < r)            {                m = (l + r + 1) >> 1;                pos = find(1, 1, seg[1][m]);                if (pos <= k)                    l = m;                else                    r = m - 1;            }            printf("%d/n", seg[1][l]);        }    }    return 0;}


原创粉丝点击