hdu 4661 Message Passing - 树形dp

来源:互联网 发布:照片文字编辑软件 编辑:程序博客网 时间:2024/05/16 11:19

题目链接: http://acm.hdu.edu.cn/showproblem.php?pid=4661

题意:

有n个人,标号从1到n,这n个人之间构成一棵树,每个人有一个唯一的消息,每次一个人可以告诉与他相邻的一个人他知道的所有消息,现在要求有多少种方式,使得用最少的交换次数让每个人都知道所有的消息。

解题思路:

首先可以确定的是最少传递次数就是边的数量*2,方法就是先讲所有信息都汇集到一个点,
 然后再从这个点传到其他所有点。


 现在就是枚举中心节点,计算中心节点的拓扑排序数x[i],那么
 最后答案就是: ans = Sum(x[i]*x[i]).


一个点的拓扑排序数是指,从这个点出发遍历其他所有点的不同方法数。
用dp[u]记录以u为根的子树的拓扑排序数,num[u]记录u为根的节点数
两个子树合并后的拓扑排序数为dp[u] = dp[v1]*dp[v2]*C(num[v1]+num[v2],num[v1]).


但是上面的方法只能求出一个点为根的拓扑排序数,要求是要所有的点为根的拓扑排序数的平方和,
所以还要能够从一个点为根的情况递推到其他点为根拓扑排序数。


假设当前点为u, 父节点为fa,fa的子树中除了u这颗其他子树为t,那么现在已知fa的拓扑排序数,如何求u的拓扑排序数。
dp[fa]=dp[t]*dp[u]*C(n-1,num[u]);


把u看做根时,newdp[u] = dp[u]*dp[t]*C(n-1,num[u]-1);
两式化简 newdp[u] = dp[fa]*num[u]/(n-num[u]);
然后就是,再dfs一遍,根据这个公式推出所有点为根的拓扑排序数。


#pragma comment(linker,"/stack:100000000,100000000")#include <cstdio>#include <cstring>#include <iostream>#include <queue>#include <vector>#include <map>#include <set>#include <cmath>#include <algorithm>#include <functional>#include <cmath>#include <bitset>using namespace std;typedef long long LL;const int maxn = 1000010;const LL M = 1e9 + 7;struct Edge{    int to,next;}e[maxn<<1];int head[maxn],cnt,n;LL ans,dp[maxn],num[maxn];LL fac[maxn];LL ext_gcd(LL a,LL b,LL &x, LL&y){    if(b == 0){        x=1,y=0;        return a;    }    LL d = ext_gcd(b,a%b,x,y);    LL t = x;    x = y;    y = t - a/b*y;    return d;}LL pow_mod(LL a,LL b){    LL res = 1;    while(b){        if(b&1) {            res = res*a;            if(res > M) res %= M;        }        b >>= 1;        a = a*a;        if(a>M) a%= M;    }    return res;}LL niyuan(LL t){//    LL x,y;//    ext_gcd(t,M,x,y);//    if(x > M) x %= M;//    if(x < 0) x = (x%M + M)%M;//    return x;    return pow_mod(t,M-2);}void init(){    cnt=0;    memset(head,-1,sizeof(head));}void add(int u,int v){    e[cnt].to = v;    e[cnt].next = head[u];    head[u] = cnt++;}// C(n,m)LL C(int n,int m){    return fac[n]*niyuan(fac[m]*fac[n-m]%M)%M;}void dfs1(int u,int fa){    num[u]=1, dp[u] = 1;    for(int i=head[u]; i!=-1; i=e[i].next){        int v = e[i].to;        if(v == fa) continue;        dfs1(v,u);        num[u] += num[v];        dp[u] = dp[u]*dp[v]%M*C(num[u]-1,num[v])%M;    }}void dfs(int u,int fa){    if( u!=1 ){   // u==1 已经算出来了        dp[u] = dp[fa]*num[u]%M*niyuan(n - num[u])%M;    }    ans = (ans + dp[u]*dp[u]%M)%M;    for(int i=head[u]; i!=-1; i=e[i].next){        int v = e[i].to;        if(v == fa) continue;        dfs(v,u);    }}int main(){    int T,u,v;    cin >> T;    fac[0] = 1;    for(int i=1;i<=1000000;i++) {        fac[i] = fac[i-1]*i%M;    }    while(T--){        init();        cin >> n;        for(int i=0;i<n-1;i++){            cin >> u >> v;            add(u,v);            add(v,u);        }        dfs1(1,-1);        ans = 0;        dfs(1,-1);        cout << ans << endl;    }    return 0;}


0 0