树链剖分学习专题

来源:互联网 发布:电动汽车数据 编辑:程序博客网 时间:2024/05/01 20:47

学习树链剖分,具体的可以参考蒋一瑶的论文
解释的非常清楚,非常通俗易懂
就是第一次dfs找到每个节点往下的节点最多的儿子
第二次dfs就是把重链连起来,重新标号,先给重链标号,然后给轻链标号
然后给重新标号的节点,用线段树或者其他数据结构来维护
如果是点的权值,就直接存标号上
如果是边,就把边权存在高度比较大的那个节点标号上
然后就是操作
一般有3种
点修改,查询链
链修改,查询链

每次查询都是比较它们重链上的祖先是否相同,然后一段一段的比较,最后移动到同一条重链,感觉思想很简单,做个模版

点修改查询链:

struct Edge{    int u,v,next,c;}edge[MAX*2];int head[MAX];int tot;//size是子树节点个数,son记录重链是哪个子节点int top[MAX],son[MAX],size[MAX],dep[MAX];//top记录重链上的祖先int tid[MAX],fa[MAX];//tid为先重链后其他边的新标号int id[MAX];//新标号的点原来的标号int label;int num[MAX];int sum[MAX<<2];void init(){    mem(head,-1);    mem(son,-1);    label=0;    tot=0;}void add_edge(int a,int b,int c){    edge[tot]=(Edge){a,b,head[a],c};    head[a]=tot++;}//找重边void dfs1(int u,int f,int d){    dep[u]=d;    fa[u]=f;    size[u]=1;    for(int i=head[u];i!=-1;i=edge[i].next){        int v=edge[i].v;        if(v==f) continue;        dfs1(v,u,d+1);        size[u]+=size[v];        if(son[u]==-1||size[v]>size[son[u]]) son[u]=v;    }}//连接重链void dfs2(int u,int ance){    top[u]=ance;    tid[u]=++label;    id[tid[u]]=u;    if(son[u]==-1) return;    dfs2(son[u],ance);    for(int i=head[u];i!=-1;i=edge[i].next){        int v=edge[i].v;        if(v==fa[u]||v==son[u]) continue;        dfs2(v,v);    }}void pushup(int rt){    sum[rt]=sum[lrt]+sum[rrt];}void build(int l,int r,int rt){    if(l==r){        sum[rt]=num[id[l]];        return;    }    middle;    build(lson);    build(rson);    pushup(rt);}void update(int l,int r,int rt,int pos,int d){    if(l==r){        sum[rt]=d;        return;    }    middle;    if(pos<=m) update(lson,pos,d);    else update(rson,pos,d);    pushup(rt);}int query(int l,int r,int rt,int L,int R){    if(L<=l&&r<=R) return sum[rt];    middle;    int ans=0;    if(L<=m) ans+=query(lson,L,R);    if(R>m) ans+=query(rson,L,R);    return ans;}int n;int que(int x,int y){    int ans=0;    while(top[x]!=top[y]){        if(dep[top[x]]<dep[top[y]]) swap(x,y);        ans+=query(1,n,1,tid[top[x]],tid[x]);        x=fa[top[x]];    }    if(x==y) return ans;    if(dep[x]<dep[y]) swap(x,y);    ans+=query(1,n,1,tid[son[y]],tid[x]);    return ans;}int main(){    int q,s;    cin>>n>>q>>s;    init();    for(int i=1;i<n;i++){        int a,b,c;        scanf("%d%d%d",&a,&b,&c);        add_edge(a,b,c);        add_edge(b,a,c);    }    dfs1(1,-1,0);    dfs2(1,1);    num[1]=0;    for(int i=0;i<tot;i+=2){        int x=edge[i].u;        int y=edge[i].v;        if(dep[x]<dep[y]) swap(x,y);        num[x]=edge[i].c;    }    build(1,n,1);    while(q--){        int op;        scanf("%d",&op);        if(op==0){            int a;            scanf("%d",&a);            printf("%d\n",que(a,s));            s=a;        }        else{            int a,b;            scanf("%d%d",&a,&b);            a--;            int x=edge[2*a].u;            int y=edge[2*a].v;            if(x==fa[y]) swap(x,y);            update(1,n,1,tid[x],b);        }    }    return 0;}    

链修改查询链

struct Edge{    int u,v,next,c;}edge[MAX*2];int head[MAX];int tot;//size是子树节点个数,son记录重链是哪个子节点int top[MAX],son[MAX],size[MAX],dep[MAX];//top记录重链上的祖先int tid[MAX],fa[MAX];//tid为先重链后其他边的新标号int id[MAX];//新标号的点原来的标号int col[MAX];int label;int num[MAX];int maxv[MAX<<2];int minv[MAX<<2];void init(){    mem(head,-1);    mem(son,-1);    label=0;    tot=0;}void add_edge(int a,int b,int c){    edge[tot]=(Edge){a,b,head[a],c};    head[a]=tot++;}//找重边void dfs1(int u,int f,int d){    dep[u]=d;    fa[u]=f;    size[u]=1;    for(int i=head[u];i!=-1;i=edge[i].next){        int v=edge[i].v;        if(v==f) continue;        dfs1(v,u,d+1);        size[u]+=size[v];        if(son[u]==-1||size[v]>size[son[u]]) son[u]=v;    }}//连接重链void dfs2(int u,int ance){    top[u]=ance;    tid[u]=++label;    id[tid[u]]=u;    if(son[u]==-1) return;    dfs2(son[u],ance);    for(int i=head[u];i!=-1;i=edge[i].next){        int v=edge[i].v;        if(v==fa[u]||v==son[u]) continue;        dfs2(v,v);    }}void pushup(int rt){    maxv[rt]=max(maxv[lrt],maxv[rrt]);    minv[rt]=min(minv[lrt],minv[rrt]);}void pushdown(int rt){    if(col[rt]){        col[lrt]^=1;        col[rrt]^=1;        col[rt]=0;        int a=maxv[lrt];        int b=minv[lrt];        maxv[lrt]=-b;        minv[lrt]=-a;        a=maxv[rrt];        b=minv[rrt];        maxv[rrt]=-b;        minv[rrt]=-a;    }}void build(int l,int r,int rt){    col[rt]=0;    if(l==r){        maxv[rt]=minv[rt]=num[id[l]];        return;    }    middle;    build(lson);    build(rson);    pushup(rt);}void update1(int l,int r,int rt,int pos,int d){    if(l==r){        maxv[rt]=minv[rt]=d;        return;    }    middle;    pushdown(rt);    if(pos<=m) update1(lson,pos,d);    else update1(rson,pos,d);    pushup(rt);}void update2(int l,int r,int rt,int L,int R){    if(L<=l&&r<=R){        col[rt]^=1;        int a=maxv[rt];        int b=minv[rt];        maxv[rt]=-b;        minv[rt]=-a;        return;    }    middle;    pushdown(rt);    if(L<=m) update2(lson,L,R);    if(R>m) update2(rson,L,R);    pushup(rt);}int query(int l,int r,int rt,int L,int R){    if(L<=l&&r<=R) return maxv[rt];    middle;    int ans=-INF;    pushdown(rt);    if(L<=m) ans=max(ans,query(lson,L,R));    if(R>m) ans=max(ans,query(rson,L,R));    pushup(rt);    return ans;}int n;void change(int x,int y){    while(top[x]!=top[y]){        if(dep[top[x]]<dep[top[y]]) swap(x,y);        update2(1,n,1,tid[top[x]],tid[x]);        x=fa[top[x]];    }    if(x==y) return;    if(dep[x]<dep[y]) swap(x,y);    update2(1,n,1,tid[son[y]],tid[x]);}int que(int x,int y){    int ans=-INF;    while(top[x]!=top[y]){        if(dep[top[x]]<dep[top[y]]) swap(x,y);        ans=max(ans,query(1,n,1,tid[top[x]],tid[x]));        x=fa[top[x]];    }    if(x==y) return ans;    if(dep[x]<dep[y]) swap(x,y);    ans=max(ans,query(1,n,1,tid[son[y]],tid[x]));    return ans;}char s[10];int main(){    int t;    cin>>t;    while(t--){        cin>>n;        init();        for(int i=1;i<n;i++){            int a,b,c;            scanf("%d%d%d",&a,&b,&c);            add_edge(a,b,c);            add_edge(b,a,c);        }        dfs1(1,-1,0);        dfs2(1,1);        num[1]=0;        for(int i=0;i<tot;i+=2){            int x=edge[i].u;            int y=edge[i].v;            if(dep[x]<dep[y]) swap(x,y);            num[x]=edge[i].c;        }        build(1,n,1);        while(scanf("%s",s)){            if(s[0]=='D') break;            int a,b;            scanf("%d%d",&a,&b);            if(s[0]=='N'){                change(a,b);            }            else if(s[0]=='C'){                a--;                int x=edge[2*a].u;                int y=edge[2*a].v;                if(x==fa[y]) swap(x,y);                update1(1,n,1,tid[x],b);            }            else printf("%d\n",que(a,b));        }    }    return 0;}       
0 0
原创粉丝点击