概率+树规 熟练剖分

来源:互联网 发布:java局域网访问数据库 编辑:程序博客网 时间:2024/05/16 01:06

这里写图片描述
这里写图片描述

根节点不一定是1,但是是一个确定的点,看谁不是儿子就行了。。
这道题我们考虑从儿子推到根。设f[i][j]表示以i为根的子树中,最长轻链长度为j的概率。
因为每一个son被选为重儿子的概率相同,且重儿子对父亲贡献和轻儿子不同,所以要每一个点为重儿子,之后挨个枚举每个儿子。这个效率是N^2,然后要枚举链的长度,如果枚举到size[root],相当于N^3,废掉了。。但只要枚举到size[son]+1就好了。更大不会再有意义,概率一定为零。
在计算答案时要再次用到f数组,所以要先存一下,再转过去。
对于每一次枚举重儿子时枚举所有儿子时的计算方法。g[i][j]表示i节点子树中最长链长为0~j的概率之和。
f[i][j]=f[son][j]*g[i][j]+g[son][j]*f[i][j]-f[i][j]*f[son][j];(对于重儿子)
考虑一次向答案上添加一个子节点。那么父节点最长为j的链出现地方有两种可能。1,出现在之前已经添加到答案中的子节点里(f[i][j]),那么这时新添加的子节点的长度只要在0~j中哪一个都无所谓了。同理,最长j出现在新添加的节点中,那之前添加的多长也就无所谓了。。。
那么对于轻儿子呢?不同于重儿子,轻儿子与父亲的连边会对答案造成贡献,所以只要把son的j改成j-1即可。
最后统计答案:根节点子树中最长轻链长为j的概率*j,加个和就好了。
对于除法,乘逆元即可。
注:最好别看我代码,巨丑,还难理解。。。我把f和g数组合二为一了。。。。

#pragma GCC optimize("O3")#include<cstdio>#include<cstdlib>#include<cstring>#include<iostream>#include<algorithm>#define N 3005#define mod 1000000007#define ll long longusing namespace std;int read(){    int sum=0,f=1;char x=getchar();    while(x<'0'||x>'9'){if(x=='-')f=-1;x=getchar();}    while(x>='0'&&x<='9'){sum=(sum<<1)+(sum<<3)+x-'0';x=getchar();}    return sum*f;}struct road{int v,next;}lu[N*2];int n,e,adj[N];ll out[N],sz[N],g[N],h[N],f[N][N];bool v[N];inline ll cheng(ll x,int m){    ll ans=1;    while(m)    {        if(m&1)ans=ans*x%mod;        x=x*x%mod;        m/=2;    }    return ans;}inline void dfs(int x){    sz[x]=1;    for(int i=adj[x];i;i=lu[i].next)    {        dfs(lu[i].v);        sz[x]+=sz[lu[i].v];    }    ll fm=cheng(out[x],mod-2);    for(int i=adj[x];i;i=lu[i].next)    {        int zz=lu[i].v;        for(int j=0;j<=n;j++)g[j]=1;        for(int j=adj[x];j;j=lu[j].next)        {            int to=lu[j].v;            for(int k=0;k<=sz[to]+1;k++)            {                ll s=g[k],sum=f[to][k];                if(k)s-=g[k-1],sum-=f[to][k-1];                if(s<0)s+=mod;if(sum<0)sum+=mod;                if(to==zz)                    h[k]=((sum*g[k]+s*f[to][k]-sum*s)%mod+mod)%mod;                 else if(k)                {                    sum=f[to][k-1];if(k!=1)sum-=f[to][k-2];                    if(sum<0)sum==mod;                    h[k]=((sum*g[k]+s*f[to][k-1]-sum*s)%mod+mod)%mod;                }            }            g[0]=h[0];h[0]=0;            for(int k=1;k<=sz[to]+1;k++)g[k]=(g[k-1]+h[k])%mod,h[k]=0;        }        for(int j=sz[x];j>=1;j--)g[j]=(g[j]-g[j-1]+mod)%mod;        for(int j=0;j<=sz[x];j++)f[x][j]=(f[x][j]+g[j]*fm%mod)%mod;    }    if(!adj[x])f[x][0]=1;    for(int i=1;i<=n;i++)f[x][i]=(f[x][i]+f[x][i-1])%mod;}int main(){    //freopen("tree.in","r",stdin);    //freopen("tree.out","w",stdout);    n=read();    for(int i=1;i<=n;i++)    {        int k=read();out[i]=k;        for(int j=1;j<=k;j++)        {            int x=read();v[x]=1;            lu[++e]=(road){x,adj[i]};adj[i]=e;        }    }    int root;    for(int i=1;i<=n;i++)if(!v[i])root=i;    dfs(root);    ll ans=0;    for(ll i=1;i<=n;i++)        ans=(ans+i*(f[root][i]-f[root][i-1]+mod)%mod)%mod;    cout<<ans;}