楼教主男人八题 POJ 1741(树分治(我自然是看题解搞懂的))

来源:互联网 发布:阿里云域名优惠 2017 编辑:程序博客网 时间:2024/05/22 18:55
题意就是求树上距离小于等于K的点对有多少个
n2的算法肯定不行,因为1W个点
这就需要分治。可以看09年漆子超的论文 http://wenku.baidu.com/view/e087065f804d2b160b4ec0b5.html###
本题用到的是关于点的分治。
一个重要的问题是,为了防止退化,所以每次都要找到树的重心然后分治下去,所谓重心,就是删掉此结点后,剩下的结点最多的树结点个数最小
每次分治,我们首先算出重心,为了计算重心,需要进行两次dfs,第一次把以每个结点为根的子树大小求出来,第二次是从这些结点中找重心
找到重心后,需要统计所有结点到重心的距离,看其中有多少对小于等于K,这里采用的方法就是把所有的距离存在一个数组里,进行快速排序,这是nlogn的,然后用一个经典的相向搜索O(n)时间内解决。但是这些求出来满足小于等于K的里面只有那些路径经过重心的点对才是有效的,也就是说在同一颗子树上的肯定不算数的,所以对每颗子树,把子树内部的满足条件的点对减去。

最后的复杂度是n logn logn    其中每次快排是nlogn 而递归的深度为logn

我的理解是:比如现在有一棵树,从点1开始遍历整棵树,算出每个节点为根所在的子树中点的数量

然后再遍历树,找出重心,重心的定义如上,也可以看论文(有证明),然后算出从重心到每个节点的距离(不计算已经当过根的节点),快排之后用相向搜索找两段dis相加小于等于k的个数(具体见代码)

不过这样找的会有重复,所以每次我们找的符合题目的对数都是经过重心root的对数,所以要减去不经过重心的个数

原理还是可以理解的,代码比较复杂,确实挺难写,仍需努力,早日成为真男人QAQ

#include <map>#include <set>#include <stack>#include <queue>#include <cmath>#include <string>#include <vector>#include <cstdio>#include <cctype>#include <cstring>#include <sstream>#include <cstdlib>#include <iostream>#include <algorithm>using namespace std;#define   MAX       10005#define   MAXN      2000005#define   lson      l,m,rt<<1#define   rson      m+1,r,rt<<1|1#define   lrt       rt<<1#define   rrt       rt<<1|1#define   mid       int m=(r+l)>>1#define   LL        long long#define   ull       unsigned long long#define   mem0(x)   memset(x,0,sizeof(x))#define   mem1(x)   memset(x,-1,sizeof(x))#define   meminf(x) memset(x,INF,sizeof(x))#define   lowbit(x) (x&-x)const LL     mod   = 1000000;const int    prime = 999983;const int    INF   = 0x3f3f3f3f;const int    INFF  = 1e9;const double pi    = 3.141592653589793;const double inf   = 1e18;const double eps   = 1e-10;struct Edge{    int v,cost,next;}edge[MAX*2];int head[MAX];int maxn[MAX];int siz[MAX];int vis[MAX];int dis[MAX];int tot;int root;int mi;int n,k;int ans;int num;void add_edge(int a,int b,int c){    edge[tot]=(Edge){b,c,head[a]};    head[a]=tot++;}void dfssize(int u,int fa){    siz[u]=1;    maxn[u]=0;    for(int i=head[u];i!=-1;i=edge[i].next){        int v=edge[i].v;        if(v!=fa&&!vis[v]){            dfssize(v,u);            siz[u]+=siz[v];            maxn[u]=max(maxn[u],siz[v]);        }    }}void dfsroot(int r,int u,int fa){    if(siz[r]-siz[u]>maxn[u]) maxn[u]=siz[r]-siz[u];    if(maxn[u]<mi){        mi=maxn[u];        root=u;    }    for(int i=head[u];i!=-1;i=edge[i].next){        int v=edge[i].v;        if(v!=fa&&!vis[v]) dfsroot(r,v,u);    }}void dfsdis(int u,int fa,int d){    dis[num++]=d;    for(int i=head[u];i!=-1;i=edge[i].next){        int v=edge[i].v;        if(v!=fa&&!vis[v]) dfsdis(v,u,d+edge[i].cost);    }}int calc(int u,int d){    int ret=0;    num=0;    dfsdis(u,-1,d);    sort(dis,dis+num);    int i=0,j=num-1;    while(i<j){        while(dis[i]+dis[j]>k&&i<j) j--;        ret+=j-i;        i++;    }    return ret;}void dfs(int u){    mi=n;    dfssize(u,-1);    dfsroot(u,u,-1);    ans+=calc(root,0);    vis[root]=1;    for(int i=head[root];i!=-1;i=edge[i].next){        int v=edge[i].v;        if(!vis[v]){            ans-=calc(v,edge[i].cost);//不经过重心的对数,从子节点v开始计算,距离是cost            dfs(v);        }    }}int main(){    while(scanf("%d%d",&n,&k)){        if(!n&&!k) break;        mem1(head);        mem0(vis);        ans=0;        tot=0;        for(int i=1;i<n;i++){            int a,b,c;            scanf("%d%d%d",&a,&b,&c);            add_edge(a,b,c);            add_edge(b,a,c);        }        dfs(1);        printf("%d\n",ans);    }    return 0;}


0 0
原创粉丝点击