树链剖分 ZJOI2008 树的统计 Count

来源:互联网 发布:js时间戳获取月份 编辑:程序博客网 时间:2024/05/17 02:41

(同步个人博客http://sxysxy.org/blogs/3 到csdn)

题目大意:

有一棵有N个节点,N-1条边的树,每个节点都有一个权值w。现在要求支持以下操作: CHANGE u t 把节点u的权值修改为t,QMAX u v 询问从节点u到v路径上权值最大的节点,QSUM询问从节点u到节点v上所有节点的权值之和。(N <= 3W ,修改/询问次数 <= 20W)

LCT并不会写orz,于是树链剖分,这也是我第一道用树链剖分AC的题,纪念之。

(备注,本题可以在http://syzoj.com/problem/47 , http://cojs.tk/cogs/problem/problem.php?pid=1688, 以及 http://www.lydsy.com/JudgeOnline/problem.php?id=1036 上提交)

  • 首先这个树链剖分的大体思路是这样:树上两个节点之间的路径是唯一的(不绕圈,每条边只经过一次),也就是说对于同一对点u,v他们之间路途覆盖的范围是不变的,现在要求维护这个范围上的信息,于是第一时间就想到线段树。

  • 然而线段树维护的区间总是在某种意义上连续的。我们就需要给点(这道题中维护的是点,树链剖分也可以维护边的信息)进行编号。树链剖分中一个重要的步骤就是给节点(边)编号。剖分后的树链(顺便一说:树链即树上的路径)有着特殊的性质,能够使得我们能在较少时间内完成树上的查询维护操作。下边对此进行一些讲解。

废话完毕

  • 我们定义 size[x]表示以x为根节点的子树的节点数,deep[x]表示x在的深度(这里设树根的深度为1),parent[x]为x的父节点,son[x]表示x节点的“重儿子“(重儿子指的是x的子节点中size最大的节点),top[x]为x所在的链的顶端的节点。 另外定义两个辅助的数组,intree[x]表示树中节点x在线段树(用来维护树上信息的线段树)中的节点号,outree是intree的反函数,outree[x]表示线段树中节点x对应要维护的树中的节点号。上面这些信息可以通过对树进行两次dfs得到。(也能bfs,过程见后面给出的代码)

  • 剖分时要将树链剖分成轻链重链,重链是由重边组成的链,节点x与son[x]的连边是一条重边,x与他的其余子节点的连边为轻边。轻链就是轻边组成的链。

下面是一个剖分的树的图例:
红色的边的重边,绿色的是轻边
剖分后的树具有以下性质:

  • 1.对于树中两个节点u, v,若u是v的子节点(或者说deep[u] > deep[v]),且u, v之间的连边是轻边,那么必定有size[u]*2 < size[v]
  • 2.从根节点到树中任意一点的路径上,存在不超过logN条轻链和不超过logN条重链。(N是节点总数)
  • 3.树中重链上的节点(或者边,这要看具体问题)在线段树重的intree值是连续的。也就是说重链上的节点在线段树中的编号形成了一段连续的区间,我们是可以直接对这个区间进行查询的。
  • 4.同一条链上的节点中,deep小的一定是deep大的的节点的父节点,这就意味着同一条链上deep小的节点比deep大的节点在线段树中的intree值小(这是废话,但是在查询是这一点很有用,线段树接受的查询的区间[l,r]满足l<=r,我们需要保证这一点,而保证这一点的方法就是通过deep值判断。。。)。

修改x节点的权值就直接修改线段树里面intree[x]的值,能学到树剖想必线段树这方面不需要过多赘述。

  • 查询的时候,对于节点u,v,如果u,v在同一条链上(即top[u] == top[v]),直接向线段树查询就可以。
  • 当top[u] != top[v]时,记t1 = top[u], t2 = top[v], 我们钦定deep[t1] >= deep[t2](若不满足则交换u,v,交换t1, t2满足),。此时查询t1到u(t1是u所在链的顶端,即deep[t1] <= deep[u],即线段树中intree[t1] <= intree[u],满足线段树查询的要求(见性质4),可以进行查询。)查询完毕及时更新答案,然后令u = parent[t1], t1 = top[u],t1与u之间形成一段新的链(即查询的时候一段一段地查),对这段链再查询。。。重复以上操作直到top[u] == top[v],这时候就可以直接向线段树查询了,这是最后一次更新答案,完成查询。

(PS:对于需要维护树上边的信息的树链剖分,可以实行”边化点”,让每条边两端较深的一点代替这条边。之后就和维护树上点的一样啦。

下面给出本题我的AC代码:

#include <cstdio>#include <cstdlib>#include <cstdarg>#include <cstring>#include <string>#include <vector>#include <queue>#include <list>#include <algorithm>using namespace std;#define MAXN (30010)#define BETTER_CODE  __attribute__((optimize("O3")))vector<int> G[MAXN];int top[MAXN], son[MAXN], parent[MAXN], value[MAXN], size[MAXN], deep[MAXN];int intree[MAXN], outree[MAXN];int num;BETTER_CODEvoid dfs1(int cur, int fa, int dep){    parent[cur] = fa;    deep[cur] = dep;    size[cur] = 1;    son[cur] = 0;    for(int i = 0; i < G[cur].size(); i++)    {        int nx = G[cur][i];        if(nx != fa)        {            dfs1(nx, cur, dep+1);            size[cur] += size[nx];            if(size[son[cur]] < size[nx])                son[cur] = nx;        }    }}BETTER_CODEvoid dfs2(int cur, int tp){    top[cur] = tp;    intree[cur] = ++num;    outree[intree[cur]] = cur;    if(!son[cur])return;    dfs2(son[cur], tp);    for(int i = 0; i < G[cur].size(); i++)    {        int nx = G[cur][i];        if(nx == parent[cur] || nx == son[cur])continue;        dfs2(nx, nx);    }}class segtree{public:    struct node    {        int l, r;        int maxi;        int sum;    }ns[MAXN<<2];    #define mid(a,b) ((a+b)>>1)    #define ls(x) (x<<1)    #define rs(x) ((x<<1)|1)    BETTER_CODE    void build(int c, int l, int r)    {        ns[c].l = l;        ns[c].r = r;        if(l == r)        {            ns[c].maxi = value[outree[l]];            ns[c].sum = value[outree[l]];            return;        }        int m = mid(l, r);        build(ls(c), l, m);        build(rs(c), m+1, r);        ns[c].maxi = max(ns[ls(c)].maxi, ns[rs(c)].maxi);        ns[c].sum = ns[ls(c)].sum + ns[rs(c)].sum;    }    BETTER_CODE    void update(int c, int v)    {        if(ns[c].l == ns[c].r)        {            ns[c].maxi = value[outree[ns[c].l]];                ns[c].sum = value[outree[ns[c].l]];            return;        }else if(v <= ns[ls(c)].r)            update(ls(c), v);        else            update(rs(c), v);        ns[c].maxi = max(ns[ls(c)].maxi, ns[rs(c)].maxi);        ns[c].sum = ns[ls(c)].sum + ns[rs(c)].sum;    }    BETTER_CODE    int askmax(int c, int l, int r)    {        int t = ls(c);        if(l == ns[c].l && r == ns[c].r)            return ns[c].maxi;        else if(r <= ns[t].r)            return askmax(t, l, r);        else if(l >= ns[t|1].l)            return askmax(t|1, l, r);        else if(l <= ns[t].r && r >= ns[t|1].l)            return max(askmax(t, l, ns[t].r), askmax(t|1, ns[t|1].l, r));    }    BETTER_CODE    int asksum(int c, int l, int r)    {        int t = ls(c);        if(l == ns[c].l && r == ns[c].r)            return ns[c].sum;        else if(r <= ns[t].r)            return asksum(t, l, r);        else if(l >= ns[t|1].l)            return asksum(t|1, l, r);        else if(l <= ns[t].r && r >= ns[t|1].l)            return asksum(t, l, ns[t].r)+asksum(t|1, ns[t|1].l, r);    }};segtree ST;BETTER_CODEint querymax(int u, int v){    int t1 = top[u];    int t2 = top[v];    int ans = -0x2333333;    while(t1 != t2)    {        //假设t1比t2深,这里如果发现deep[t1]<deep[t2]则交换        if(deep[t1] < deep[t2])        {            swap(t1, t2);            swap(u, v);        }        ans = max(ans, ST.askmax(1, intree[t1], intree[u]));        u = parent[t1];        t1 = top[u];    }    //then t1 == t2    if(deep[u] > deep[v])        ans = max(ans, ST.askmax(1, intree[v], intree[u]));    else        ans = max(ans, ST.askmax(1, intree[u], intree[v]));    return ans;}BETTER_CODEint querysum(int u, int v){    int t1 = top[u];    int t2 = top[v];    int ans = 0;    while(t1 != t2)    {        //假设t1比t2深,这里如果发现deep[t1]<deep[t2]则交换        if(deep[t1] < deep[t2])        {            swap(t1, t2);            swap(u, v);        }        ans += ST.asksum(1, intree[t1], intree[u]);        u = parent[t1];        t1 = top[u];    }    //then t1 == t2    if(deep[u] > deep[v])        ans += ST.asksum(1, intree[v], intree[u]);    else        ans += ST.asksum(1, intree[u], intree[v]);    return ans;}void change(int t, int v){    value[t] = v;    ST.update(1, intree[t]);}BETTER_CODEint main(){    int n;    scanf("%d", &n);    for(int i = 1; i < n; i++)    {        int a, b;        scanf("%d %d", &a, &b);        G[a].push_back(b);        G[b].push_back(a);    }    for(int i = 1; i <= n; i++)        scanf("%d", value + i);    num = 0;    dfs1(1, 0, 1);    dfs2(1, 1);    ST.build(1, 1, num);    char buf[233];    int q;    scanf("%d", &q);    while(q--)    {        int x, y;        scanf("%s", buf);        scanf("%d %d", &x, &y);        if(buf[0] == 'Q')        {            if(buf[1] == 'M')                printf("%d\n", querymax(x, y));            else                printf("%d\n", querysum(x, y));        }else if(buf[0] == 'C')            change(x, y);    }    return 0;}
0 0