poj1741(点分治)

来源:互联网 发布:apple pencil配套软件 编辑:程序博客网 时间:2024/06/05 17:17

点分治这个算法听了一年了,最近没有比赛就学了一下。
题面
题意:给你一棵树,问距离小于等于L 的点对有多少个。

点分治定义了一个叫重心的点。即是割去它后最大的连通块最小的点。然后有证明最大的连通块大小不大于总点数的一半。

对于这题,先找到重心x,然后要处理的便是经过点x的路径。可以用dfs求出所有点到x的距离,然后排序,单调寻找点对。

若有两个点来自同一个连通块,它们分别到x的距离的和小于等于L,我们也把它们算了进去,故还要减去这一部分。
然后递归处理每个连通块即可。

#include <iostream>#include <fstream>#include <algorithm>#include <cmath>#include <ctime>#include <cstdio>#include <cstdlib>#include <cstring>using namespace std;#define mmst(a, b) memset(a, b, sizeof(a))#define mmcp(a, b) memcpy(a, b, sizeof(b))typedef long long LL;const int N=20005,oo=1e9+7;void read(int &hy){    hy=0;    char cc=getchar();    while(cc<'0'||cc>'9')    cc=getchar();    while(cc>='0'&&cc<='9')    {        hy=(hy<<3)+(hy<<1)+cc-'0';        cc=getchar();    }}int n,L;int to[N],nex[N],val[N],head[N],cnt;int siz[N],d[N],vis[N],pre[N],root,sum;int dep[N],a[N];int ans;void add(int u,int v,int w){    to[++cnt]=v;    val[cnt]=w;    nex[cnt]=head[u];    head[u]=cnt;}void dfsRoot(int x,int fa){    pre[x]=fa;    siz[x]=1;    d[x]=0;    for(int h=head[x];h;h=nex[h])    if(!vis[to[h]]&&to[h]!=fa)    {        dfsRoot(to[h],x);        siz[x]+=siz[to[h]];        d[x]=max(d[x],siz[to[h]]);    }    d[x]=max(d[x],sum-siz[x]);    if(d[x]<d[root])    root=x;}void dfsDeep(int x,int fa){    a[++a[0]]=dep[x];    for(int h=head[x];h;h=nex[h])    if(!vis[to[h]]&&to[h]!=fa)    {        dep[to[h]]=dep[x]+val[h];        dfsDeep(to[h],x);    }}int cal(int x,int now){    dep[x]=now;    a[0]=0;    dfsDeep(x,0);    sort(a+1,a+1+a[0]);    int l=1,r=a[0],ans=0;    while(l<r)    {        if(a[l]+a[r]<=L)        ans+=r-l,l++;        else r--;    }    return ans;}void dfsSol(int x){    siz[pre[x]]=sum-siz[x];    vis[x]=1;    ans+=cal(x,0);    for(int h=head[x];h;h=nex[h])    if(!vis[to[h]])    {        ans-=cal(to[h],val[h]);        sum=siz[to[h]];        root=0;        dfsRoot(to[h],0);        dfsSol(root);    }}int main(){    while(1)    {        cin>>n>>L;        if(n==0)        break;        cnt=0;        mmst(vis,0);        mmst(head,0);        ans=0;        for(int i=1;i<=n-1;i++)        {            int u,v,w;            read(u);            read(v);            read(w);            add(u,v,w);            add(v,u,w);        }        sum=n;        root=0;        d[0]=oo;        dfsRoot(1,0);        dfsSol(root);        printf("%d\n",ans);    }}

这里写图片描述