浅谈树的点分治

来源:互联网 发布:mac u盘装系统win7 编辑:程序博客网 时间:2024/06/06 02:54

…我竟然又没保存就关了…这篇博文写了两边w…


首先树的点分的大概意思如下

对于一些树上问题,我们在处理的时候,如果是遍历每个点来进行统计等操作,会发现复杂度可能会比较高,这个时候我们就可以在树上运用分治的思想,将树上问题分解为小的问题,再分解为更小的问题,递归到底层之后再递归回去,可以极大地对算法进行加速(大概是可以优化一个n到logn)

具体操作我们需要依赖寻求树的重心,如果已经明白了如何去求一棵子树的重心的朋友可以跳过这一段,每次在递归处理子树的时候,我们总是先找出其重心,在进行统计,因为树的重心可以近似地把树分的比较均匀(重心的定义是将该点割掉后,所生成的最大联通块最小,我们可以通过dfs来获得这个点)


具体操作如下:
第一次dfs计算出子树的大小

void dfs_size(int u,int fa){    maxn[u]=-0x3f3f3f3f;    size[u]=1;    for(register int i=head[u];i;i=line[i].nxt){        int v=line[i].to;        if(!vis[v]&&v!=fa){            dfs_size(v,u);            size[u]+=size[v];            maxn[u]=max(maxn[u],size[v]);        }    }}

第二次dfs计算出重心:

void dfs_root(int u,int fa,int r){    maxn[u]=max(maxn[u],size[r]-size[u]);    if(maxn[u]<minn) minn=maxn[u],root=u;    for(register int i=head[u];i;i=line[i].to){        int v=line[i].to;        if(!vis[v]&&v!=fa){            dfs_root(v,u,r);        }    }}

其中maxn数组存储的是每个点割掉之后形成的最大的联通块的结点数
minn记录的是最小结点数

我们来看一道例题
POJ 1741,相信很多人都做过


Tree

Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.


Input

The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.


Output

For each test case output the answer on a single line.


Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0


Sample Output

8


如果直接暴力统计的话,复杂度不能接受,这个时候我们就需要点分治来进行处理
算法流程大概是这样的

int work(int u){    solve(u);    delete(u);//vis[v]=true;    for(u's son){        if(!vis[u's son]){            work(u's son);        }    }}

首先,我们可以把该树以重心进行划分,然后递归地处理,具体方式是,处理一棵子树中所有点到重心的距离,假设我们已经求到了dis数组表示一个点到当前子树重心的距离,其中dis数组是一个无序数组,那么我们可以用如下的经典算法进行统计路径长度小于等于k的点对

int ans(int u,int d){    int temp=0;    cnt=0;    dfs_dis(u,d,0);    sort(dis+1,dis+cnt+1);    int i=1,j=cnt;    while(i<j){        while(dis[i]+dis[j]>k&&i<j) j--;        temp+=j-i;        i++;    }    return temp;}

正确性是显然的,我们可以这样理解

这里写图片描述
其中的dfs_dis函数即是计算每个点到重心的距离

void dfs_dis(int u,int d,int fa){    dis[++cnt]=d;    for(register int i=head[u];i;i=line[i].nxt){        int v=line[i].to;        if(v!=fa&&!vis[v]){            dfs_dis(v,d+line[i].w,u);        }    }}

接下来只剩一个主的dfs函数,算法流程大概是这样

void dfs(int u){    dfs_size(u,0);    dfs_root(u,u,0);    tot+=ans(root,0);    vis[u]=true;    for(register int i=head[from];i;i=line[i].nxt){        int v=line[i].to;        if(!vis[v]){            tot-=ans(v,line[i].w);            dfs(v);         }    }}

我们会发现我们统计了所有点到重心的距离,然后进行了操作,然而有一些操作是违规的,因为我们规定每一条路径必须经过重心(没有经过重心的可以递归处理,在下一层累加)
这里写图片描述
我们称蓝色的为一条合法路径,而红色的为一条非法路径,我们在计算的时候其实是会把所有非法路径也计算进去(不仅非法,而且错误,因为之前的dis统计的是各点到重心的距离),现在在这句话中我们可以减去所有的非法路径

tot-=ans(v,line[i].w);

这样计算出来的相当于还是以原来的重心为为重心所计算出来的每个点离重心的值
也就是计算出以当前根的某一没有被打过标记的子点为起点,往下把起所有的数据进行处理,然后计算出在该子树的点对数即可,这样可以删掉所有的非法路径,原理是上面的一行不起眼的代码

vis[u]=true;

这样一来可以将dfs_ans的活动范围固定在以这个子点为根的子树中

这里写图片描述
至此,所有的算法流程已经讲解完毕,请大家结合代码再进行深入的理解(或许已经懂了233)

POJ 1741 完整代码:

#include<iostream>#include<cstring>#include<cstdio>#include<algorithm>#define MAXN 10000+100 #define MAXM 50000+100using namespace std;struct Line{    int from,to,nxt,w;}line[MAXM];int head[MAXN];int n,k,ans,root,tail,cnt,minn=0x7ffffff,tot;bool vis[MAXN];int maxn[MAXN],size[MAXN],dis[MAXN];void add_line(int from,int to,int w){    tail++;    line[tail].from=from;    line[tail].to=to;    line[tail].w=w;    line[tail].nxt=head[from];    head[from]=tail;}void dfs_dis(int u,int d,int fa){    dis[++cnt]=d;    for(register int i=head[u];i;i=line[i].nxt){        int v=line[i].to;        if(v!=fa&&!vis[v])        dfs_dis(v,d+line[i].w,u);    }}int ass(int u,int d){    int temp=0;    cnt=0;    dfs_dis(u,d,0);    sort(dis+1,dis+cnt+1);    int i=1,j=cnt;    while(i<j){        while(dis[i]+dis[j]>k&&i<j) j--;        temp+=j-i;        i++;    }    return temp;}void dfs_root(int r,int u,int fa){//POJ 174    maxn[u]=max(maxn[u],size[r]-size[u]);//上端子树     if(maxn[u]<minn) minn=maxn[u],root=u;    for(register int i=head[u];i;i=line[i].nxt){        int v=line[i].to;        if(v!=fa&&!vis[v]) dfs_root(r,v,u);     }}void dfs_size(int u,int fa){//我要动态差错了233    size[u]=1;    maxn[u]=-0x3f3f3f3f;    for(register int i=head[u];i;i=line[i].nxt){        int v=line[i].to;        if(!vis[v]&&v!=fa){            dfs_size(v,u);            size[u]+=size[v];            maxn[u]=max(maxn[u],size[v]);         }     }}void dfs(int u){    dfs_size(u,0);//以该点为根节点dfs     minn=0x7fffffff;    dfs_root(u,u,0);    tot+=ass(root,0);    vis[root]=true;     for(register int i=head[root];i;i=line[i].nxt){        int v=line[i].to;                 if(!vis[v]){            tot-=ass(v,line[i].w);            dfs(v);        }                       }}void init(){     memset(vis,false,sizeof(vis));    memset(head,0,sizeof(head));    tail=0;    tot=0;}int main(){    int i,j,u,v,w;    while(scanf("%d%d",&n,&k)==2){        if(n==0)            return 0;        init();        for(i=1;i<=n-1;i++){            scanf("%d%d%d",&u,&v,&w);            add_line(u,v,w);            add_line(v,u,w);        }        dfs(1);        printf("%d\n",tot);    }}

这里写图片描述

原创粉丝点击