[bzoj 2588] Spoj 10628. Count on a tree:函数式线段树

来源:互联网 发布:ie11 找不到js 编辑:程序博客网 时间:2024/06/10 05:34

本题求一棵树上两点之间简单路径上第k小的点权,强制在线。函数式线段树(主席树)是一种可持久化数据结构。第i棵线段树保存前缀[1..i]的信息,在树上,前缀可扩展为结点i到根的简单路径。具体地,第i棵线段树中的位置x为a,意味着[1..i]中有a个元素具有性质x——按权值建线段树。这样建出来的线段树有两个性质:
1. 形态相同。
2. 可加减,比如第i棵线段树-第j棵线段树=[j+1..i]。
因此,我们可以在第(i-1)棵线段树的基础上建第i棵线段树。离散化后,两者至多有lg n个结点不同;那些相同的结点,复用即可。自始至终,我们并未修改结点,只是新建和复用——函数式的思想。

新建一棵树的代码如下,其中MAXD=lg MAXN向上取整:

int ptr = 1, root[MAXN+1];struct Node {    int v, lc, rc;} T[MAXN*(MAXD+1)];void build(int& y, int x, int a, int l, int r){    T[y = ptr++] = T[x];    ++T[y].v;    if (l == r)        return;    int m = (l+r)/2;    if (a <= m)        build(T[y].lc, T[x].lc, a, l, m);    else        build(T[y].rc, T[x].rc, a, m+1, r);}

查询,像普通线段树一样二分。其中lca(u, v)返回u和v的最近公共祖先,anc[a][0]是a的父亲,j是u、v间简单路径上权在[l, r]范围内的点数:

int query(int u, int v, int k){    int l = 1, r = top, a = lca(u, v), x = root[u], y = root[v], z = root[a], t = root[anc[a][0]];    while (l < r) {        int m = (l+r)/2, j = T[T[x].lc].v + T[T[y].lc].v - T[T[z].lc].v - T[T[t].lc].v;        if (j >= k) {            x = T[x].lc;            y = T[y].lc;            z = T[z].lc;            t = T[t].lc;            r = m;        } else {            x = T[x].rc;            y = T[y].rc;            z = T[z].rc;            t = T[t].rc;            l = m+1;            k -= j;        }    }    return l;}

所以还需要找lca,这里采用倍增算法。完整代码如下:

#include <cstdio>#include <algorithm>using namespace std;const int MAXN = 100000, MAXD = 17;int N, M, e_ptr = 1, top = 1, maxd = 1;int w[MAXN+1], h1[MAXN+1], h[MAXN+1], fst[MAXN+1];struct Edge {    int v, next;} E[MAXN*2];inline void add_edge(int u, int v){    E[e_ptr] = (Edge){v, fst[u]}; fst[u] = e_ptr++;    E[e_ptr] = (Edge){u, fst[v]}; fst[v] = e_ptr++;}inline int read(){    int x = 0;    char ch = getchar();    while (ch<'0' || ch>'9')        ch = getchar();    while (ch>='0' && ch<='9') {        x = x*10 + ch - '0';        ch = getchar();    }    return x;}namespace work {    int ptr = 1, anc[MAXN+1][MAXD+1], root[MAXN+1], depth[MAXN+1];    struct Node {        int v, lc, rc;    } T[MAXN*(MAXD+1)];    void build(int& y, int x, int a, int l, int r)    {        T[y = ptr++] = T[x];        ++T[y].v;        if (l == r)            return;        int m = (l+r)/2;        if (a <= m)            build(T[y].lc, T[x].lc, a, l, m);        else            build(T[y].rc, T[x].rc, a, m+1, r);    }    void dfs(int u, int fa)    {        for (int i = 1; i <= maxd; ++i)            anc[u][i] = anc[anc[u][i-1]][i-1];        build(root[u], root[fa], lower_bound(h+1, h+top+1, w[u])-h, 1, top);        for (int i = fst[u]; i; i = E[i].next) {            int v = E[i].v;            if (v != fa) {                anc[v][0] = u;                depth[v] = depth[u]+1;                dfs(v, u);            }        }    }    inline void swim(int& x, int h)    {        for (int i = 0; h; ++i, h >>= 1)            if (h & 1)                x = anc[x][i];    }    int lca(int u, int v)    {        if (depth[u] > depth[v])            swap(u, v);        swim(v, depth[v]-depth[u]);        for (int i = maxd; i >= 0; --i)            if (anc[u][i] != anc[v][i]) {                u = anc[u][i];                v = anc[v][i];            }        if (u != v)            u = anc[u][0];        return u;    }    int query(int u, int v, int k)    {        int l = 1, r = top, a = lca(u, v), x = root[u], y = root[v], z = root[a], t = root[anc[a][0]];        while (l < r) {            int m = (l+r)/2, j = T[T[x].lc].v + T[T[y].lc].v - T[T[z].lc].v - T[T[t].lc].v;            if (j >= k) {                x = T[x].lc;                y = T[y].lc;                z = T[z].lc;                t = T[t].lc;                r = m;            } else {                x = T[x].rc;                y = T[y].rc;                z = T[z].rc;                t = T[t].rc;                l = m+1;                k -= j;            }        }        return l;    }}int main(){    N = read();    M = read();    while ((1<<maxd) < N)        ++maxd;    for (int i = 1; i <= N; ++i)        h1[i] = w[i] = read();    sort(h1+1, h1+N+1);    h[1] = h1[1];    for (int i = 2; i <= N; ++i)        if (h1[i] != h1[i-1])            h[++top] = h1[i];    for (int i = 1; i < N; ++i) {        int u = read(), v = read();        add_edge(u, v);    }    work::dfs(1, 0);    int lastans = 0;    for (int i = 1; i <= M; ++i) {        int u = read(), v = read(), k = read();        u ^= lastans;        printf("%d", lastans = h[work::query(u, v, k)]);        if (i != M)            putchar('\n');    }    return 0;}

最后不要换行,否则PE。

1 0