树链剖分详解及其模板题

来源:互联网 发布:网络视频广告的形成 编辑:程序博客网 时间:2024/06/03 20:25

转自foreverpiano的安心小窝~

1 前言

如果给你一棵树,求点u到点v路径上点的权值之和,你可能会说:倍增啊!

那如果出题人:我还要你支持修改某个点的权值!

或者再j一点:我还要你支持修改点u到点v路径上点的权值!

那就得用树链剖分了。

2 什么是树链剖分

上面那个问题,树上区间修改。 区间修改最常见做法就是线段树了。

那我们怎么用线段树维护一颗。。。普通的树呢?

那就给普通的树的每个节点标个号,然后放线段树里呗。区间维护。

但如果随便标号,那点u到点v路径不一定标号是连续的啊,你线段树维护个j啊
所以我们现在引入一个(堆)姿势:

重(zhong)儿子:siz[u]为v的子节点中siz值最大的,那么u就是v的重儿子。
轻儿子:v的其它子节点。
重边:点v与其重儿子的连边。
轻边:点v与其轻儿子的连边。
重链:由连续的重边连成的一条链。
轻链:由连续的轻边组成的一条链。

我们先来看张图:

这里写图片描述
图中标在边上了,但也不影响我们学习。。。
标号方法是:
跑dfs,先给当前节点标号,

再给重儿子标号(重儿子和当前节点在一个重链上),

然后对重儿子递归,

最后给剩下的别的儿子标号(别的儿子不和当前节点在一个重链上,所以新建
重链,把新建的重链的顶端节点设为那个”别的儿子“)、

递归(图中是先给重(zhong)边标号,再给剩下的边标号)。标号从小到大。

不难发现一条重链上的标号是连续的,比如点1到点14,点2到点12

这意味着在线段树中,它们是在一个连续的区间里的,而不是像随便标号时断断续续的。

这样就很好用线段树处理了。

如果两个节点不在一条重链上呢?

比如图中的点11和点10,我们要求它们之间的路径上的点权和

那我们就看,点11所在的重链是11->6->2,点10所在的重链是10(10所在的重链只有一个点,就是10)。所以我们就先求出重链11->6->2上的点权和、重链10上的点权和。这两条重链在线段树上都是一段连续的区间,可以直接log2n求出
这时候我们发现还有4->1的重链没有计算,就把它的点权和计算出来,三个重链的点权和加在一起就得到了答案。
所以我们要记录的是:
1. pos[i] 点i的标号
2. top[i] 点i所在重链的顶端节点
3. siz[i] 以点i为根的子树的大小
4. dep[i] 点i的深度
5. fa[i] 点i的父亲节点
6. son[]数组,用来保存重儿子

具体步骤:

我们先跑一边dfs,算出fa、size、dep
然后再跑一边dfs,根据size[i]找出点i的重儿子,然后算出pos、top。
搞完这些就很easy了,因为一段重链在线段树里是一段连续的区间(这是坠重要的)。

我们在查询/修改从点u到点v的路径时,
先找到所在重链的顶端节点(top)深度较深的
(因为这样能让u和v同步提升,防止一个提到根节点了,另一个没提,这时你就不知道提谁了)
注意不能按照u和v的深度来提!

比如top较深的点是u,然后就用线段树处理区间(因为top[u]的标号一定比u要小),再设置u为fa[top[u]],把u往上提,直至u和v在一条重链上(即top[u]==top[v])

这时候可能u和v之间还有一段距离,此时u和v已经在一条重链上,直接处理它们之间的区间就行了。

然后复杂度就是:这里写图片描述
同时这个复杂度也是一般的树链剖分的复杂度。因为重链个数不会超过个,线段树复杂度是的。网上有证明,我就不做过多赘述”.



洛谷 P3384 【模板】树链剖分

此代码是用线段树维护的,这道题主要是维护树中的点权。

#include<bits/stdc++.h>#define ls (rt<<1)#define rs (rt<<1|1)#define mid (tr[rt].l+(tr[rt].r-tr[rt].l)/2)#define maxn 600000#define ll long longusing  namespace std;//1. pos[i] 点i的标号//2. top[i] 点i所在重链的顶端节点//3. siz[i] 以点i为根的子树的大小//4. dep[i] 点i的深度//5. fa[i] 点i的父亲节点//son[]数组,用来保存重儿子struct TREE{ll l,r,sum,tag;}tr[maxn<<2];ll n,m,rr,mod;int u,v,pos[maxn],sz[maxn],top[maxn],son[maxn],fa[maxn],ww[maxn],w[maxn],id[maxn],deep[maxn],cnt=0;vector<int> g[maxn];//动态数组 template<typename tp>void read(tp & dig){//读入优化     char c=getchar();dig=0;    while(!isdigit(c))c=getchar();    while(isdigit(c))dig=dig*10+c-'0',c=getchar();}inline void pushup(int rt){tr[rt].sum=tr[ls].sum+tr[rs].sum;}inline void build(int l,int r,int rt){//建立一棵线段树     tr[rt].l=l,tr[rt].r=r;    if(l==r){tr[rt].sum=w[id[l]];return ;}    int midd=l+(r-l)/2;    build(l,midd,ls),build(midd+1,r,rs);    pushup(rt);}inline void pushdown(int rt){//下传标记     if(tr[rt].tag){        tr[ls].tag=(tr[ls].tag+tr[rt].tag)%mod;        tr[rs].tag=(tr[rs].tag+tr[rt].tag)%mod;        tr[ls].sum=(tr[ls].sum+(tr[ls].r-tr[ls].l+1)*tr[rt].tag)%mod;        tr[rs].sum=(tr[rs].sum+(tr[rs].r-tr[rs].l+1)*tr[rt].tag)%mod;        tr[rt].tag=0;    }}inline void update(int l,int r,int c,int rt){//区间、单点修改     if(l<=tr[rt].l&&tr[rt].r<=r){           tr[rt].sum=(tr[rt].sum+(c*(tr[rt].r-tr[rt].l+1)%mod))%mod;        tr[rt].tag+=c%mod;        return ;    }    pushdown(rt);    if(l<=mid) update(l,r,c,ls);//为啥是这样写?     if(r>mid) update(l,r,c,rs);    pushup(rt);}inline ll query(int l,int r,int rt){//区间查询     if(l<=tr[rt].l&&tr[rt].r<=r) return tr[rt].sum;    pushdown(rt);    int ans=0;    if(l<=mid) ans+=query(l,r,ls),ans%=mod;//if(r>mid) ans+=query(l,r,rs),ans%=mod;    return ans%mod;}inline void dfs1(int x,int fat,int dep) {    //这里应该就是树链剖分的第一个操作,把树上每个节点的size,fa等统统搞出来    deep[x]=dep,fa[x]=fat,sz[x]=1;//fat表示x的爸爸     for(int i=0;i<g[x].size();i++){        int v=g[x][i];        if(v!=fat){            dfs1(v,x,dep+1);            sz[x]+=sz[v];            if(son[x]==-1||sz[son[x]]<sz[v]) son[x]=v;//son[x]记录x的儿子中的重点         }    }}inline void dfs2(int x,int tp){//这个操作应该就是处理轻链和重链     top[x]=tp;pos[x]=++cnt;id[pos[x]]=x;//pos记录的是遍历到该点时的时间,id记录的是当时间为pos[x]时的编号是xif(son[x]==-1) return ;//到了叶子节点就退出。     dfs2(son[x],tp);//继续递归,并且是处理的重链和重点     for(int i=0;i<g[x].size();i++){        int v=g[x][i];        if(v!=fa[x]&&v!=son[x]) dfs2(v,v);//处理轻链     }//tp记录的是重点的祖先节点编号 }inline ll add(int t1,int t2,int c,int ok){//ok为一个标记,表示是进行哪一个操作     ll u=t1,v=t2,ans=0;    while(top[u]!=top[v]){//在不同链上的情况         if(deep[top[u]]>deep[top[v]]) swap(u,v);         if(!ok)update(pos[top[v]],pos[v],c,1);//这里是更新值,例如本题中的从a到b都加上某一个值         else ans+=query(pos[top[v]],pos[v],1),ans%=mod;        v=fa[top[v]];//继续操作下去,直到这两个点爬到他们的公共节点     }    if(deep[u]>deep[v]) swap(u,v);//处理在同一条链上的情况     if(!ok) update(pos[u],pos[v],c,1);    else ans+=query(pos[u],pos[v],1);    return ans%=mod;}//这个add估计是把轻重链加到线段树中 int main(){    memset(son,-1,sizeof(son));    read(n),read(m),read(rr),read(mod);//rr为根节点     for(int i=1;i<=n;i++) read(w[i]);    for(int i=1;i<n;i++)        read(u),read(v),g[u].push_back(v),g[v].push_back(u);    dfs1(rr,-1,1),    dfs2(rr,rr);//important    build(1,n,1);    for(int i=1;i<=m;i++){        int xx,t1,t2,t3;        cin>>xx;        if(xx==1) read(t1),read(t2),read(t3),add(t1,t2,t3,0);        if(xx==3) read(t1),read(t2),update(pos[t1],pos[t1]+sz[t1]-1,t2,1);        if(xx==2) read(t1),read(t2),printf("%lld\n",add(t1,t2,0,1)%mod);        if(xx==4) read(t2),printf("%lld\n",query(pos[t2],pos[t2]+sz[t2]-1,1)%mod);    }    return 0;}/*5 5 2 247 3 7 8 0 1 21 53 14 13 4 23 2 24 51 5 1 32 1 3*/ 

以上代码摘自http://www.cnblogs.com/foreverpiano/p/7142189.html

本题的另一种写法:(本校的一位大佬用树状数组来维护的,在洛谷上跑的飞快)

#include<cstdio>#define re register intchar ss[1<<17],*A=ss,*B=ss;inline char gc(){if(A==B){B=(A=ss)+fread(ss,1,1<<17,stdin);if(A==B)return EOF;}return*A++;}template<class T>inline void sdf(T&x){    char c;re y=1;while(c=gc(),c<48||57<c)if(c=='-')y=-1;x=c^48;    while(c=gc(),47<c&&c<58)x=(x<<1)+(x<<3)+(c^48);x*=y;}char sr[1<<20],z[20];int C=-1,Z;template<class T>inline void wer(T x){    re y=0;if(x<0)y=1,x=-x;    while(z[++Z]=x%10+'0',x/=10);if(y)z[++Z]='-';    while(sr[++C]=z[Z],--Z);sr[++C]='\n';}const int N=5e5+5;typedef int array[N];typedef long long ll;struct edges{int nx,to;}e[N<<1];int n,m,s,mod,tot,tmp;array a,b,depth,fa,son,size,top,id,real,fi;ll c1[N],c2[N];void dfs1(re u){    depth[u]=depth[fa[u]]+(size[u]=1);    for(re i=fi[u],v;i;i=e[i].nx)        if((v=e[i].to)!=fa[u]){            fa[v]=u;dfs1(v);size[u]+=size[v];            if(size[v]>size[son[u]])son[u]=v;        }}void dfs2(re u){    if(son[fa[u]]==u)top[u]=top[fa[u]];    else top[u]=u;    real[id[u]=++tmp]=u;    if(son[u])dfs2(son[u]);    for(re v,i=fi[u];i;i=e[i].nx)        if((v=e[i].to)!=fa[u]&&v!=son[u])            dfs2(v);}inline void insert(re x,ll w){for(re i=x;i<=n;i+=i&(-i))c1[i]+=w,c2[i]+=(ll)x*w;}inline ll sigma(re x){ll sum=0;for(re i=x;i;i-=i&(-i))sum+=(ll)(x+1)*c1[i]-c2[i];return sum;}inline void swap(re&x,re&y){re t=x;x=y;y=t;}inline void change(re u,re v,re w){    while(top[u]!=top[v]){        if(depth[top[u]]<depth[top[v]])swap(u,v);        insert(id[top[u]],w);insert(id[u]+1,-w);        u=fa[top[u]];    }    if(depth[u]>depth[v])swap(u,v);    insert(id[u],w);insert(id[v]+1,-w);}inline ll sum(re u,re v){    ll sum=0;    while(top[u]!=top[v]){        if(depth[top[u]]<depth[top[v]])swap(u,v);        sum+=sigma(id[u])-sigma(id[top[u]]-1);        u=fa[top[u]];    }    if(depth[u]>depth[v])swap(u,v);    sum+=sigma(id[v])-sigma(id[u]-1);    return sum%mod;}inline void add(re u,re v){e[++tot]=(edges){fi[u],v};fi[u]=tot;}int main(){    sdf(n);sdf(m);sdf(s);sdf(mod);    for(re i=1;i<=n;++i)sdf(a[i]);    for(re u,v,i=1;i<n;++i)sdf(u),sdf(v),add(u,v),add(v,u);    dfs1(s);dfs2(s);    for(re i=1;i<=n;i++)b[i]=a[real[i]];    for(re i=1;i<=n;i++)insert(i,b[i]-b[i-1]);    re op,x,y,z;    while(m--){        sdf(op);sdf(x);        if(op==1)sdf(y),sdf(z),change(x,y,z);        else if(op==2)sdf(y),wer(sum(x,y));        else if(op==3)sdf(y),insert(id[x],y),insert(id[x]+size[x]-1+1,-y);        else wer((sigma(id[x]+size[x]-1)-sigma(id[x]-1))%mod);    }    fwrite(sr,1,C+1,stdout);return 0;}

另外几篇博客:
树链剖分
树链剖分基础模板
树链剖分模板题



poj 3237 tree

【题目大意】
第一个操作:将第i条边的权值变为v。
第二个操作:将点a到点b路径上的边权变为其相反数。
第三个操作:查询点a到点b路径上的边权最大值。
【题目解析】
这道题是一道裸的熟练剖分的题,和上面那一道题不同的是,这道题修改和查询的值为边的权值。所以要处理一些细节,将在下面代码中解释。

/*ls:线段树组中p节点左儿子的编号,相当于p<<1(在代码中有很多地方要用到,懒的打)rs:右儿子mid:当前在线段树中的区间的中点deep:每个节点在树中的深度sz:以该点i为根的树的大小fa:点i的父亲节点road:第i条边所指向的点pos:第二个dfs中当前节点遍历到的时间a:父亲节点到当前儿子的边权值son_cost:当前节点与重儿子连的边的权值top:点i所在重链能到的最顶上的节点son:记录节点i的重儿子tree:线段树组mx:当前区间的最大边权值mi:当前区间的最小边权值fg:相当于线段树中的tag*/#include<iostream>#include<cstdio>#include<algorithm>#include<cmath>#include<cstring>#define ls p<<1#define rs (p<<1)+1#define mid (tree[p].l+(tree[p].r-tree[p].l)/2)using namespace std;struct Node{    int l,r,p,sum;}tree[500005];struct arr{    int nd,nx,co,id;}edge[100005];//这个地方既可以用邻接链表储存,也可以用动态数组来存 int head[100005],deep[100005],pos[100005],son_cost[100005];int top[100005],fa[100005],sz[100005],son[100005],road[100005],a[100005];int t,n,cnt,tot,ans;int mx[500005],mi[500005],fg[500005];inline void add(int u,int v,int w,int id){ edge[++cnt].nd=v; edge[cnt].co=w; edge[cnt].id=id;edge[cnt].nx=head[u]; head[u]=cnt; }inline int read(){//读入优化    int x=0,w=1;char ch=0;    while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();    if(ch=='-') w=-1,ch=getchar();    while(ch>='0'&&ch<='9') x=x*10+ch-48,ch=getchar();    return x*w;}void dfs1(int u,int father,int depth){//建立起一棵树,并且查找出重儿子,节点深度,节点的大小    deep[u]=depth;fa[u]=father;sz[u]=1;    for(int i=head[u];i;i=edge[i].nx){        int v=edge[i].nd;        if(v!=father){            road[edge[i].id]=v;            dfs1(v,u,depth+1);            sz[u]+=sz[v];            if(son[u]==-1||sz[son[u]]<sz[v]) {                son[u]=v;                son_cost[u]=edge[i].co;            }//第二个if是为了记录重儿子是哪一个 ,并且记录它与重儿子的边权值        }    }}void dfs2(int u,int tp,int cost){//这个操作是将重轻链的最顶端的点求出来     top[u]=tp;pos[u]=++tot;a[tot]=cost;//a的作用是把每个边权值记录下来,方便放入线段树中    if(son[u]==-1) return;    dfs2(son[u],tp,son_cost[u]);    for(int i=head[u];i;i=edge[i].nx){        int v=edge[i].nd;        if(v!=fa[u]&&v!=son[u]) dfs2(v,v,edge[i].co);    }} inline void pushup(int p){  mx[p]=max(mx[ls],mx[rs]); mi[p]=min(mi[ls],mi[rs]);  }//线段树中的上传操作inline void push_down(int p){//下传标记    if(fg[p]){        fg[ls]^=1; fg[rs]^=1;        int tmp=mx[ls]; mx[ls]=-mi[ls]; mi[ls]=-tmp;        tmp=mx[rs]; mx[rs]= -mi[rs]; mi[rs]= -tmp;         fg[p]=0;    }//用mi数组的原因是,此题在做取相反数操作的时候,会变成负数,方便取最大值}void build(int l,int r,int p){//建树    tree[p].l=l; tree[p].r=r;    if(l==r) {        mx[p]=a[l];        mi[p]=a[l];//边权        return;    }    int midd=(l+r)>>1;    build(l,midd,ls);    build(midd+1,r,rs);    pushup(p);}void change(int p,int l,int data){//把某条边的权值改掉    if (tree[p].l==tree[p].r){        mi[p]=mx[p]=data;    } else {        push_down(p);        if (l<=mid) change(ls,l,data);        else change(rs,l,data);        pushup(p);    }}void nega(int p,int l,int r){//取相反数操作    if(l<=tree[p].l&&tree[p].r<=r){        fg[p]^=1;         int tmp=mx[p];        mx[p]=-mi[p];        mi[p]=-tmp;        return;    }    push_down(p);    if(l<=mid) nega(ls,l,r);    if(r>mid) nega(rs,l,r);    pushup(p);}void NEGATE(int x,int y,int n){//树链剖分中的必要操作    while(top[x]!=top[y]){        if(deep[top[x]]<deep[top[y]]) swap(x,y);        nega(1,pos[top[x]],pos[x]);        x=fa[top[x]];    }    if(x==y) return;    if(deep[x]>deep[y]) swap(x,y);    nega(1,pos[son[x]],pos[y]);}int query(int p,int l,int r){    if(l<=tree[p].l&&tree[p].r<=r) return mx[p];    push_down(p);    int ans = -(1<<30);    if(l<=mid) ans=max(ans,query(ls,l,r));    if(r>mid) ans=max(ans,query(rs,l,r));    pushup(p);    return ans;}int QUERY(int x,int y,int n){//树链剖分的必要操作,和上面的NEGATE差不多,只是适应不同需要    if(x==y) return 0;    int ans=-(1<<30);    while(top[x]!=top[y]){        if(deep[top[x]]<deep[top[y]]) swap(x,y);        ans=max(ans,query(1,pos[top[x]],pos[x]));        x=fa[top[x]];    }    if(x==y) return ans;    if(deep[x]>deep[y]) swap(x,y);    ans=max(ans,query(1,pos[son[x]],pos[y]));    return ans;} void init(){//初始化    tot=cnt=0;    memset(head,0,sizeof(head));    memset(son,-1,sizeof(head));    memset(fg,0,sizeof(fg));}int main(){    t=read();    for(int k=1;k<=t;k++){        n=read();init();        for(int i=1;i<=n-1;i++){            int u=read(),v=read(),w=read();             add(u,v,w,i);add(v,u,w,i);        }        dfs1(1,0,0);         dfs2(1,1,0);         build(2,n,1);        char ch[10];        while(~scanf("%s",ch)){            if(*ch=='D') break;            int u=read(),v=read();            if(*ch=='Q')printf("%d\n",QUERY(u,v,n));            else if(*ch=='C') change(1,pos[road[u]],v);            else NEGATE(u,v,n);        }    }}

以上代码摘自G_lory(部分修改过)

原创粉丝点击