poj 1741 树分治入门

来源:互联网 发布:linux cpu调为高性能 编辑:程序博客网 时间:2024/06/01 16:59

题意:统计距离<=k的点对,n<=10000.

这道题用的树的点分治方法。点分治基于树的重心。

树的重心,定义是删除某个点后得到的最大(节点数)子树的节点数最小。性质是,可以证明删除掉重心后,每个子树的大小<=n/2。这个性质保证了基于重心的分治算法深度不会超过logn。

这题递归求解子树后,将所有子树节点d[i]距离排序,用两个指针扫过去就可以得到对于每一个d[i]满足的d[j]+d[i]<=k的j的个数。要注意的是减去重复的个数(这个比较麻烦)。


#include <iostream>#include <cstring>#include <cstdio>#include <algorithm>using namespace std;typedef long long ll;const int inf=0x3f3f3f3f;const int maxn=1e4+10;const int maxm=maxn<<1;int n,k;int tot,first[maxn],nxt[maxm], to[maxm], cost[maxm];int vis[maxn],siz[maxn],maxs[maxn]={inf}, dep[maxn];int a[maxn], in[maxn];void addedge(int u, int v, int w){    nxt[tot]=first[u];    to[tot]=v;    cost[tot]=w;    first[u]=tot++;}void init(){    memset(first ,-1, sizeof(first));    tot=0;    memset(vis, 0, sizeof(vis));    for(int i=1; i<n; i++){        int u,v,w;        scanf("%d%d%d", &u, &v, &w);        addedge(u,v,w);        addedge(v,u,w);    }}void getsize(int u, int pre) //记录节点数{    siz[u]=1;    maxs[u]=0;    for(int i=first[u]; i!=-1; i=nxt[i]){        int v=to[i];        if(v==pre || vis[v]) continue;        getsize(v, u);        siz[u]+=siz[v];        maxs[u]=max(maxs[u], siz[v]);    }}void getroot(int u, int pre, int num, int &rt)//获取重心{    maxs[u]=max(maxs[u], num-siz[u]);    if(maxs[u]<maxs[rt])        rt=u;    for(int i=first[u]; i!=-1; i=nxt[i]){        int v=to[i];        if(v==pre||vis[v]) continue;        getroot(v, u, num, rt);    }}void getdep(int u, int pre, int d, int &cnt){    a[cnt++]=d;    dep[u]=d;    for(int i=first[u]; i!=-1; i=nxt[i]){        int v=to[i], w=cost[i];        if(v==pre||vis[v]) continue;        getdep(v, u, d+w, cnt);    }}ll count(int l, int r) //对于排好序的数组,用两个指针扫描得到计数。{    ll ret=0;    int ptr=r;    for(int i=l; i<=r && a[i]<=k && ptr>=l; i++){        while(a[ptr]+a[i]>k && ptr>=l) ptr--;        ret+=ptr-l+1;    }    return ret;}ll solve(int u){    ll ret=0;    int rt=0;    getsize(u, -1);    getroot(u, -1, siz[u], rt);    vis[rt]=1;       for(int i=first[rt]; i!=-1; i=nxt[i]){        int v=to[i], w=cost[i];        if(vis[v])continue;        ret+=solve(v);//递归求解子树    }    ll tmp=0;    int cnt=0;    for(int i=first[rt]; i!=-1; i=nxt[i]){        int v=to[i], w=cost[i];        if(vis[v]) continue;        int p=cnt;        getdep(v, rt, w, cnt);        sort(a+p, a+cnt);        tmp-=count(p, cnt-1);    }    a[cnt++]=0;    sort(a, a+cnt);    tmp+=count(0, cnt-1);    ret+=(tmp-1)/2;    vis[rt]=0;      return ret;}int main(){    while(cin>>n>>k &&n+k){        init();        ll ans=solve(1);        cout<<ans<<endl;    }    return 0;}


0 0