树链剖分学习小记

来源:互联网 发布:ios版nba2k17捏脸数据 编辑:程序博客网 时间:2024/06/16 14:16

树链剖分学习小记

平常在一棵树上,从点u到点v询问一些最大值,求和之类的,都是先打个lca然后在类暴力一下,小题可以对,但遇到大题就挂了。然后就去看了一下树链剖分%%%
其实也不难。。。
树链剖分其实就是用数据结构去维护上面的点或链,降低一下复杂度。不过如果随意的去维护,会搞得很乱,时间消耗可能比暴搜还慢。
现在有一种剖分的方法:轻重链剖分(启发式剖分)。
从一个节点走下去,儿子多的为重儿子,链接他的为重边。其余的为轻儿子,链接他的事轻边。重链连成的链叫重链。
——>定义size[i]为以i为根节点的子树的节点数量,fa[i]为i节点的父节点,top[i]为i所在重链的第一个节点,w[i]为父边在线段树中的序号(如果要处理的是点,就是当前这儿点在线段树中的序号),son[i]为节点i所在的重链的儿子节点(除了叶节点,其他都有son),deep[i]为节点i在树中的深度是多少。用两个dfs把这些预处理好。
然后就是查询(修改直接在线段树上做就行了)。设要询问节点x到节点y,f1=top[x],f2=top[y]。首先从deep大的(就是深度更深的)开始做,处理出top到当前这个点的值,假设deep[x]>deep[y],先处理w[f1]到w[x]这段区间,因为这段区间是沿着重链的,在线段树上是连续的,可以直接处理。然后当前这个节点再更新为top的父亲,如上是x=fa[f1],f1=top[x]。一直这样做知道f1=f2为止,此时都在同一条重链上,也可以直接处理。由于是分成一条一条的重链,所以时效很快。

入门题:【ZJOI2008】树的统计

Description

    一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。    我们将以下面的形式来要求你对这棵树完成一些操作:    I.    CHANGE u t : 把结点u的权值改为t    II.   QMAX u v: 询问从点u到点v的路径上的节点的最大权值    III.  QSUM u v: 询问从点u到点v的路径上的节点的权值和    注意:从点u到点v的路径上的节点包括u和v本身

Solution

    很裸的树链剖分

Code

#include<iostream>#include<cstring>#include<cstdio>#include<algorithm>#include<cmath>#define fo(i,a,b) for(i=a;i<=b;i++)using namespace std;const int maxn=100000;int i,j,k,l,t,n,m,ans,size[maxn],deep[maxn],top[maxn],fa[maxn],son[maxn],w[maxn];int first[maxn],last[maxn],next[maxn],num,a[maxn],tot,q,p,fw[maxn];char s[10];struct node{    int l,r,sum,da;}tree[maxn*5];void add(int x,int y){    last[++num]=y;    next[num]=first[x];    first[x]=num;    last[++num]=x;    next[num]=first[y];    first[y]=num;}void dfs1(int x,int y){    int i,j=0,k=0;    size[x]=1;    for(i=first[x];i;i=next[i]){        if(last[i]!=y){            fa[last[i]]=x;            deep[last[i]]=deep[x]+1;            dfs1(last[i],x);            size[x]+=size[last[i]];             if(k<size[last[i]]){                k=size[last[i]];                son[x]=last[i];            }        }    }}void dfs2(int x,int y){    top[x]=y;    w[x]=++tot;    fw[tot]=a[x];    if(!son[x]) return;    dfs2(son[x],y);    for(int i=first[x];i;i=next[i]){        if(last[i]!=fa[x]&&last[i]!=son[x]){            dfs2(last[i],last[i]);        }    }}void build(int x,int l,int r){    int i,j,mid;    if(l==r){        tree[x].da=tree[x].sum=fw[l];    }    else{        mid=(l+r)/2;        build(x*2,l,mid);        build(x*2+1,mid+1,r);        tree[x].sum=tree[x*2].sum+tree[x*2+1].sum;        tree[x].da=max(tree[x*2].da,tree[x*2+1].da);    }}void change(int x,int l,int r,int y,int z){    int i,mid;    if(l==r){        tree[x].da=tree[x].sum=z;    }    else{        mid=(l+r)/2;        if(y<=mid) change(x*2,l,mid,y,z);        else change(x*2+1,mid+1,r,y,z);        tree[x].da=max(tree[x*2].da,tree[x*2+1].da);        tree[x].sum=tree[x*2].sum+tree[x*2+1].sum;    }}int query_max(int x,int l,int r,int y,int z){    int i,mid;    if(l==y&&r==z){        return tree[x].da;    }    else{        mid=(l+r)/2;        if(y>mid) return query_max(x*2+1,mid+1,r,y,z);        else if(z<=mid)return query_max(x*2,l,mid,y,z);        else{            return max(query_max(x*2+1,mid+1,r,mid+1,z),query_max(x*2,l,mid,y,mid));        }     }}int query_sum(int x,int l,int r,int y,int z){    int i,mid;    if(l==y&&r==z){        return tree[x].sum;    }    else{        mid=(l+r)/2;        if(y>mid) return query_sum(x*2+1,mid+1,r,y,z);        else if(z<=mid)return query_sum(x*2,l,mid,y,z);        else{            return query_sum(x*2+1,mid+1,r,mid+1,z)+            query_sum(x*2,l,mid,y,mid);        }     }}int find_max(int x,int y){    int f1=top[x],f2=top[y],o=-10000000;    while(f1!=f2){        if(deep[f1]<deep[f2]){            swap(f1,f2);swap(x,y);        }        o=max(o,query_max(1,1,n,w[f1],w[x]));        x=fa[f1];f1=top[x];    }    if(deep[x]>deep[y]) swap(x,y);    return max(o,query_max(1,1,n,w[x],w[y]));}int find_sum(int x,int y){    int f1=top[x],f2=top[y],o=0;    while(f1!=f2){        if(deep[f1]<deep[f2]){            swap(f1,f2);swap(x,y);        }        o+=query_sum(1,1,n,w[f1],w[x]);        x=fa[f1];f1=top[x];    }    if(deep[x]>deep[y]) swap(x,y);    return o+query_sum(1,1,n,w[x],w[y]);}int main(){    scanf("%d",&n);    fo(i,1,n-1){        scanf("%d%d",&k,&l);        add(k,l);    }    fo(i,1,n) scanf("%d",&a[i]);    deep[1]=1;    dfs1(1,0);    dfs2(1,1);    build(1,1,n);    scanf("%d",&q);    fo(p,1,q){        scanf("%s%d%d",s+1,&k,&l);        if(s[1]=='C'){            change(1,1,n,w[k],l);        }        else if(s[2]=='M'){            printf("%d\n",find_max(k,l));        }        else if(s[2]=='S'){            printf("%d\n",find_sum(k,l));        }    }}
1 0
原创粉丝点击