[CC]Prime Distance On Tree

来源:互联网 发布:2015网络流行语俏皮话 编辑:程序博客网 时间:2024/05/05 01:51

题意简述

n点的树,边权都为1,任选两点求其距离为素数的概率。

数据范围

1n5×104

思路

点分治+FFT。
只需要求出所有长度为素数的路径条数。
点分治统计答案的时候,我们构建一个生成函数,表示每种距离的条数。
将其平方就是答案,这一步我们可以利用FFT高效实现。
但是多统计了在同一子树的方案,用同样的方法计算减去。
时间复杂度O(nlog2n)
大常数需要跑1.4s。

#include<cstdio>#include<cstring>#include<cmath>#include<complex>using namespace std;#define MAXN 50010#define pi acos(-1)#define MAXL 131080typedef complex<double> C;struct edge{    int s,t,next;}e[MAXN<<1];int head[MAXN],cnt;void addedge(int s,int t){    e[cnt].s=s;e[cnt].t=t;e[cnt].next=head[s];head[s]=cnt++;    e[cnt].s=t;e[cnt].t=s;e[cnt].next=head[t];head[t]=cnt++;}int n,all,rt,u,v,f[MAXN],size[MAXN],num[MAXN];bool vis[MAXN];int len,tmp,ti;int r[MAXL];C a[MAXL];int p_cnt;int prime[MAXN];bool not_prime[MAXN];long long ans;void fft(C *a,int f){    for (int i=0;i<len;i++)        if (i<r[i])            swap(a[i],a[r[i]]);    for (int i=1;i<len;i<<=1)    {        C wn(cos(pi/i),f*sin(pi/i));        for (int j=0;j<len;j+=(i<<1))        {            C w=1;            for (int k=0;k<i;k++,w*=wn)            {                C x=a[j+k],y=w*a[i+j+k];                a[j+k]=x+y,a[i+j+k]=x-y;            }        }    }}void get_rt(int node,int lastfa){    f[node]=0;    size[node]=1;    for (int i=head[node];i!=-1;i=e[i].next)        if (e[i].t!=lastfa&&!vis[e[i].t])        {            get_rt(e[i].t,node);            size[node]+=size[e[i].t];            f[node]=max(f[node],size[e[i].t]);        }    f[node]=max(f[node],all-size[node]);    if (f[node]<f[rt])        rt=node;}void get_dis(int node,int lastfa,int sum,int f){    num[sum]+=f;    for (int i=head[node];i!=-1;i=e[i].next)        if (e[i].t!=lastfa&&!vis[e[i].t])            get_dis(e[i].t,node,sum+1,f);}void solve(int node){    vis[node]=1;    for (int i=head[node];i!=-1;i=e[i].next)        if (!vis[e[i].t])        {            get_dis(e[i].t,node,1,1);            for (tmp=size[e[i].t]<<1,len=1,ti=0;len<=tmp;len<<=1)                ti++;            for (int j=0;j<len;j++)                r[j]=(r[j>>1]>>1)|((j&1)<<(ti-1));            for (int j=0;j<=size[e[i].t];j++)                a[j]=num[j];            for (int j=size[e[i].t]+1;j<len;j++)                a[j]=0;            fft(a,1);            for (int j=0;j<len;j++)                a[j]=a[j]*a[j];            fft(a,-1);            for (int j=0;j<len;j++)                a[j]/=len;            for (int j=1;j<=p_cnt&&prime[j]<len;j++)                ans-=(long long)(a[prime[j]].real()+0.5);            get_dis(e[i].t,node,1,-1);        }    get_dis(node,node,0,1);    for (tmp=size[node]<<1,len=1,ti=0;len<=tmp;len<<=1)        ti++;    for (int i=0;i<len;i++)        r[i]=(r[i>>1]>>1)|((i&1)<<(ti-1));    for (int i=0;i<=size[node];i++)        a[i]=num[i];    for (int i=size[node]+1;i<len;i++)        a[i]=0;    fft(a,1);    for (int i=0;i<len;i++)        a[i]=a[i]*a[i];    fft(a,-1);    for (int i=0;i<len;i++)        a[i]/=len;    for (int i=1;i<=p_cnt&&prime[i]<len;i++)        ans+=(long long)(a[prime[i]].real()+0.5);    get_dis(node,node,0,-1);    for (int i=head[node];i!=-1;i=e[i].next)        if (!vis[e[i].t])        {            all=size[e[i].t];            rt=0;            get_rt(e[i].t,e[i].t);            get_rt(rt,rt);            solve(rt);        }}void sieve(int n){    for (int i=2;i<=n;i++)    {        if (!not_prime[i])            prime[++p_cnt]=i;        for (int j=1;j<=p_cnt&&prime[j]*i<=n;j++)        {            not_prime[i*prime[j]]=1;            if (i%prime[j]==0)                break;        }    }}int main(){    scanf("%d",&n);    sieve(n);    memset(head,0xff,sizeof(head));    cnt=0;    for (int i=1;i<n;i++)    {        scanf("%d%d",&u,&v);        addedge(u,v);    }    rt=0;    f[rt]=n+1;    all=n;    get_rt(1,1);    get_rt(rt,rt);    solve(rt);    printf("%.8lf",1.0*ans/(1.0*n*(n-1)));    return 0;}
0 0