树链剖分简介(BZOJ1036)(洛谷2590)

来源:互联网 发布:js regexp 编辑:程序博客网 时间:2024/06/06 19:47

建议学树剖的同学先学会DFS序,并有一定的数据结构基础

定义

树链剖分,计算机术语,指一种对树进行划分的算法,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、SBT、SPLAY、线段树等)来维护每一条链。

那它有什么用呢?

它可以解决一些树上路径的一些问题。

一般的树剖是轻重链剖分。当然还有更高级的长链剖分等。这里先讲轻重链剖分。

先定义几个东西

sz[x]: 以x为根的子树的节点个数
to[x]: x中节点最多的儿子编号(重儿子)
重边: 连接x与to[x]的边
轻边: 除了重边以外的边
重链: 由重边构成的链
轻链: 除了重边以外的链
tp[x]: x所在的重链的起始点

剖分完后有一些性质:

轻边(U,V),size(V)<=size(U)/2。 ž从根到某一点的路径上,不超过O(logN)条轻边,不超过O(logN)条重路径。

容易发现树剖是启发式的。

实现

轻重链剖分

实现轻重链剖分我们可以用两遍dfs实现

第一遍DFS:求出每个节点的父亲、深度、sz[]、to[]等。这个比较简单。

void dfs1(int x,int depth){    sz[x]=1,dep[x]=depth;    for (int i=h[x];i;i=ed[i].next)        if (ed[i].to!=fa[x]){            int v=ed[i].to;            fa[v]=x,dfs1(v,depth+1),sz[x]+=sz[v];            if (sz[v]>sz[to[x]]) to[x]=v;        }}

第二遍DFS:优先遍历重儿子。对于每一个重儿子,继承当前节点的重链。对于其它儿子,重新拉一条以自己为端点的重链。

void dfs2(int x){    id[x]=++nd,in[nd]=x;//id[x]存的是每个节点对应的区间位置,in[nd]则相反    for (int i=h[x];i;i=ed[i].next)        if (ed[i].to==to[x])//优先遍历重儿子            tp[ed[i].to]=tp[x],dfs2(ed[i].to);//继承当前重链    for (int i=h[x];i;i=ed[i].next)//再遍历其他儿子        if (ed[i].to!=fa[x]&&ed[i].to!=to[x])            tp[ed[i].to]=ed[i].to,dfs2(ed[i].to);//重新拉一条重链}

仔细体会一下就可以发现,轻重链剖分其实就是优先遍历重儿子的一个DFS序。每一条重链/每一个子树都是连续的一段区间。

路径操作

对于x到y的一条路径,可以把它划分为若干条重链(当然最后一条可能不完整)。

假设tp[x]>=tp[y]。
当tp[x]!=tp[y]时,对x到tp[x]这一段进行操作,然后把x赋成fa[tp[x]],重复上述操作。
当tp[x]=tp[y]时,对x到y这一段进行操作,结束操作。

下面给出伪代码:

各种类型 函数名称(int x,int y){    while (tp[x]!=tp[y]){        if (dep[tp[x]]<dep[tp[y]]) swap(x,y);        各种操作(从id[tp[x]]到id[x])        x=fa[tp[x]];    }    if (dep[x]<dep[y]) swap(x,y);    各种操作(从id[y]到id[x])}

例子(模板)

以BZOJ1036(洛谷P2590)为例:
这道题的操作有单点修改和询问路径最大值及路径点权和,所以我们可以用线段树维护。

代码:

#include<cctype>#include<cstdio>#include<cstring>#include<algorithm>#define N 30005using namespace std;struct tree{ int l,r,sum,mx; }t[N*4];struct edge{ int next,to; }ed[N*2];int n,m,k,nd,a[N],fa[N],dep[N],tp[N],to[N],h[N],sz[N],id[N],in[N];inline int _read(){    int num=0,f=1; char ch=getchar();    while (!isdigit(ch)) { if (ch=='-') f=-1; ch=getchar(); }    while (isdigit(ch)) num=(num<<3)+(num<<1)+(ch^48),ch=getchar();    return num*f;}void addedge(int x,int y){    ed[++k].next=h[x],ed[k].to=y,h[x]=k;}void dfs1(int x,int depth){//第一遍DFS    sz[x]=1,dep[x]=depth;    for (int i=h[x];i;i=ed[i].next)        if (ed[i].to!=fa[x]){            int v=ed[i].to;            fa[v]=x,dfs1(v,depth+1),sz[x]+=sz[v];            if (sz[v]>sz[to[x]]) to[x]=v;        }}void dfs2(int x){//第二遍DFS    id[x]=++nd,in[nd]=x;    for (int i=h[x];i;i=ed[i].next)        if (ed[i].to==to[x])            tp[ed[i].to]=tp[x],dfs2(ed[i].to);    for (int i=h[x];i;i=ed[i].next)        if (ed[i].to!=fa[x]&&ed[i].to!=to[x])            tp[ed[i].to]=ed[i].to,dfs2(ed[i].to);}void build(int l,int r,int x){//线段树建树    t[x].l=l,t[x].r=r;    if (l==r) { t[x].mx=t[x].sum=a[in[l]]; return; }    int mid=l+r>>1; build(l,mid,x*2),build(mid+1,r,x*2+1);    t[x].sum=t[x*2].sum+t[x*2+1].sum;    t[x].mx=max(t[x*2].mx,t[x*2+1].mx);}void mdfy(int x,int p,int w){//修改    if (t[x].l==t[x].r&&t[x].l==p){        t[x].mx=t[x].sum=w; return;    }    int mid=t[x].l+t[x].r>>1;    if (p<=mid) mdfy(x*2,p,w);     else mdfy(x*2+1,p,w);    t[x].sum=t[x*2].sum+t[x*2+1].sum;    t[x].mx=max(t[x*2].mx,t[x*2+1].mx);}tree find(int x,int l,int r){//区间查询    tree ans; ans.mx=-30001,ans.sum=0;    if (t[x].l>r||t[x].r<l) return ans;    if (t[x].l>=l&&t[x].r<=r) return t[x];    tree p=find(x*2,l,r),q=find(x*2+1,l,r);    ans.mx=max(p.mx,q.mx),ans.sum=p.sum+q.sum;    return ans;}tree srch(int x,int y){//路径查询    tree ans; ans.mx=-30001,ans.sum=0;    while (tp[x]!=tp[y]){        if (dep[tp[x]]<dep[tp[y]]) swap(x,y);        tree ret=find(1,id[tp[x]],id[x]);        ans.mx=max(ans.mx,ret.mx);        ans.sum+=ret.sum,x=fa[tp[x]];    }    if (dep[x]<dep[y]) swap(x,y);    tree ret=find(1,id[y],id[x]);    ans.mx=max(ans.mx,ret.mx);    ans.sum+=ret.sum; return ans;}int main(){    n=_read();    for (int i=1;i<n;i++){        int u=_read(),v=_read();        addedge(u,v),addedge(v,u);    }    for (int i=1;i<=n;i++) a[i]=_read();    dfs1(1,1),tp[1]=1,dfs2(1),build(1,n,1);    m=_read(); char s[10]; int x,y;    while (m--){        scanf("%s%d%d",s,&x,&y);        switch (s[1]){            case 'H': mdfy(1,id[x],y); break;            case 'M': printf("%d\n",srch(x,y).mx); break;            case 'S': printf("%d\n",srch(x,y).sum); break;        }    }    return 0;}
原创粉丝点击