树链剖分步骤

来源:互联网 发布:js打开新的页面 编辑:程序博客网 时间:2024/05/30 20:07

刚学了树链剖分,做了几题熟悉了一下,在此总结一下树剖的步骤吧

准备

deep[]:记录该节点的深度siz[]:记录以该节点为根的子树的节点数fa[]:记录该节点的父亲是谁son[]:记录该节点的重儿子是谁top[]:记录该节点所在的重链的根节点是谁w[]:记录该节点投影到数轴后的位置(即DFS序,线段树要用)

具体步骤

  1. 读入,用空间池储存边(注意是双向的);
  2. 第一次DFS:
    记录dep[],siz[],fa[],son[](找siz最大的儿子做重儿子)
  3. 第二次DFS:
    记录w[];
    顺着son[]处理重链记录top[](包括自己)
    之后处理轻链
    【伪代码】
    设现处理x节点,其重父亲为tp
    dfs2(x,tp)
    time++;
    w[x]=time;
    top[x]=tp;
    if(有重儿子){
    dfs2(重儿子,tp);
    枚举其他儿子v->dfs2(v,v);
    }

  4. 建线段树

  5. 修改:从x到y的节点加d
    while(top[x]!=top[y]){        if(dep[top[x]]<dep[top[y]])  swap(x,y);        insert(1,1,n,w[top[x]],w[x],d);        x=fa[top[x]];    }    if(dep[x] > dep[y]) swap(x,y);    insert(1,1,n,w[x],w[y],d);

6.查询(和线段树一样就不说了)

贴个代码吧,hdu3966,比较经典的树剖题

#include <cstdio>#include <cstdlib>#include <cstring>#include <cmath>#include <iostream>#include <string>#include <algorithm>using namespace std;#define zero(a) memset(a,0,sizeof(a))#define minus(a) memset(a,-1,sizeof(a))const int MAX_N = 50100;struct edge{    int idx;    int next;}e[MAX_N * 3];int siz[MAX_N];int dep[MAX_N];int fa[MAX_N];int son[MAX_N];int top[MAX_N];int w[MAX_N];int a[MAX_N];int h[MAX_N];int n,m,q,ep,time;int cnt[MAX_N * 4];inline void add(int x,int y){    ep++;    e[ep].idx = y;    e[ep].next = h[x];    h[x] = ep;    return;}void dfs1(int x,int fat,int deep){    dep[x] = deep;    siz[x] = 1;    fa[x] = fat;    for(int i=h[x];i!=-1;i=e[i].next){        int idx = e[i].idx;        if(idx == fat)        continue;        dfs1(idx,x,deep+1);        siz[x]+=siz[idx];        if(son[x]==-1 || siz[idx]>siz[son[x]])        son[x]=idx;    }    return;}void dfs2(int x,int tp){    time++;    w[x]=time;    top[x] = tp;    if(son[x]!=-1){        dfs2(son[x],tp);   //ÖØÁ´         for(int i=h[x];i!=-1;i=e[i].next){            int v = e[i].idx;            if(v!=son[x] && v!=fa[x]) dfs2(v,v); //ÇáÁ´         }    }    return;}void insert(int rt,int l,int r,int a,int b,int val){    if(b<l||a>r)    return;    if(a<=l&&b>=r){        cnt[rt]+=val;        return;    }    int lson = rt*2;    int rson = rt*2+1;    int mid = (l+r)/2;    insert(lson,l,mid,a,b,val);    insert(rson,mid+1,r,a,b,val);    return;}inline void revise(int x,int y,int d){    while(top[x]!=top[y]){        if(dep[top[x]]<dep[top[y]]) swap(x,y);        insert(1,1,n,w[top[x]],w[x],d);        x=fa[top[x]];    }    if(dep[x] > dep[y]) swap(x,y);    insert(1,1,n,w[x],w[y],d);    return;}int query(int rt,int l,int r,int idx,int sum){    if(l==r)    return sum+cnt[rt];    int lson = rt*2;    int rson = rt*2+1;    int mid = (l+r)/2;    if(mid>=idx)    return query(lson,l,mid,idx,sum+cnt[rt]);    else    return query(rson,mid+1,r,idx,sum+cnt[rt]);}inline void ask(int x){    //for(int i=1;i<=n;i++)    //cout<<i<<" "<<cnt[i]<<endl;    int add = query(1,1,n,w[x],0);    printf("%d\n",a[x]+add);    return;}inline void init(){    zero(siz);    zero(dep);    zero(fa);    zero(top);    zero(w);    zero(e);    zero(cnt);    zero(a);    minus(son);    minus(h);    ep=0;    time=0;    return;}inline void read(){    for(int i=1;i<=n;i++)    scanf("%d",&a[i]);    for(int i=1;i<=m;i++){        int x,y;        scanf("%d %d",&x,&y);        add(x,y);        add(y,x);    }    return;}inline void build(){    dfs1(1,1,1);    dfs2(1,1);    return;}inline void solve(){    char c[5];    int x,y,d;    for(int i=1;i<=q;i++){        scanf("%s",&c);        if(c[0]=='Q'){            scanf("%d",&x);            ask(x);            continue;        }        scanf("%d %d %d",&x,&y,&d);        if(c[0]=='D')        d=-d;        revise(x,y,d);            }    return;}int main(){    while(scanf("%d %d %d",&n,&m,&q)!=EOF){        init();        read();        build();        solve();    }    return 0;}
1 0
原创粉丝点击