[POJ1987]Distance Statistics(点分治)

来源:互联网 发布:淘宝代刷远程单安全吗 编辑:程序博客网 时间:2024/05/16 01:22

【题目链接】http://poj.org/problem?id=1987
【题目大意】给定一棵树,每条边有权值,求距离<=k的点对数
【解题思路】树上点分治基础题,更新答案时用容斥原理把子树信息一起处理
【呆马】

#include<cstdio>#include<algorithm>#include<cmath>#include<cstdlib>#include<iostream>const int N=40001;using namespace std;struct st{int to,next,v;} e[N<<1];int n,x,y,z,k,cnt,m,num,root,t,p,i,ans,fi[N],f[N],siz[N],b[N],d[N],s[N];bool vis[N];char ch;void add(int x,int y,int z){    e[++cnt].to=y; e[cnt].next=fi[x]; e[cnt].v=z; fi[x]=cnt;    e[++cnt].to=x; e[cnt].next=fi[y]; e[cnt].v=z; fi[y]=cnt;}void getroot(int x,int fa){    f[x]=0;    siz[x]=1;    for (int i=fi[x];i;i=e[i].next)        if (e[i].to!=fa && !vis[e[i].to])        {            int y=e[i].to;            getroot(y,x);            f[x]=max(f[x],siz[y]);            siz[x]+=siz[y];         }    f[x]=max(f[x],num-siz[x]);    if (f[x]<f[root]) root=x;}void go(int x,int fa){    ans++;    for (int i=fi[x];i;i=e[i].next)        if (e[i].to!=fa && !vis[e[i].to])        {            d[++t]=s[e[i].to]=s[x]+e[i].v;            if (d[t]>k)            {                t--;                return;            }            go(e[i].to,x);        }}void calc(int x){    m=0;    for (int i=fi[x];i;i=e[i].next)        if (!vis[e[i].to] && e[i].v<=k)        {            d[t=1]=s[e[i].to]=e[i].v;            go(e[i].to,x);            sort(d+1,d+t+1);            p=t;            for (int i=1;i<=t;i++)            {                b[++m]=d[i];                for (;p>=0 && d[i]+d[p]>k;p--);                if (p>=i) ans-=p-i+1;            }        }    sort(b+1,b+m+1);    p=m;    for (int i=1;i<=m;i++)    {        for (;p>=0 && b[i]+b[p]>k;p--);        if (b[i]>k || !p) break;        if (p>=i) ans+=p-i+1;    }}void part(int x){    vis[x]=1;    calc(x);    for (int i=fi[x];i;i=e[i].next)        if (!vis[e[i].to])        {            root=0;            num=siz[e[i].to];            getroot(e[i].to,0);            part(root);        }}int main(){        scanf("%d%d\n",&n,&x);        for (i=1;i<n;i++)        {            scanf("%d%d%d %c\n",&x,&y,&z,&ch);            add(x,y,z);        }        scanf("%d",&k);        f[0]=1e9;        num=n;        getroot(1,0);        part(root);        printf("%d",ans);}
0 0