hdu 4035 经典概率dp求期望

来源:互联网 发布:datagridview清除数据 编辑:程序博客网 时间:2024/06/05 00:38

求期望要用到全期望公式来来分类讨论:

k[i]:表示死掉回到1的概率

e[i]:表示成功逃走的概率

那么我们设定随机变量X:在节点i处开始,逃走所走的边数

那么E[i]就是从节点i开始,要逃走的边数的期望

如果i是叶子节点:

E[i] = k[i]*E[1] + e[i]*0 + (1-k[i]-e[i])*(E(parent(i))+1);       (1)

如果i不是叶子节点:

与i相连的节点的总数为m,j是i的孩子节点

E[i] = k[i]*E[1] + e[i]*0 + (1-k[i]-e[i])/m*(E(parent(i))+1) + (1-k[i]-e[i])/m*sum(E[j]+1);      (2)

为了简化计算过程,我们设Ai = k[i] , Bi = (1-k[i]-e[i])/m , Ci = Bi*sum(E[j]+1)+Bi;

那么E[i] = Ai*E[1] + Bi*E[p]+ Ci; (3)

导出E[j] = Aj*E[1] + Bj*E[i] + Cj;     (4)

进而得到sum(E[j]) = sum ( Aj*E[1] + Bj*E(i) + Cj )     (5)

把(5)代入到(3)中得到

E[i] = Ai*E[1] + Bi*E[p] + Bi*sum(Aj*E[1] + Bj*E[i] + Cj + 1 ) + Bi    (6)

 =>(1-(1-ki-ei)/m*SUM(Bj))*E(i)=(ki+(1-ki-ei)/m*SUM(Aj))*E(1)+(1-ki-ei)/m *E(father)+(1-ki-ei+(1-ki-ei)/m*SUM(cj));
所以与上述2式对比得到: 

Ai=(ki+(1-ki-ei)/m*SUM(Aj))       / (1-(1-ki-ei)/m*SUM(Bj))

Bi=(1-ki-ei)/m                   / (1-(1-ki-ei)/m*SUM(Bj))

Ci=(1-ki-ei+(1-ki-ei)/m*SUM(cj)) / (1-(1-ki-ei)/m*SUM(Bj))

所以Ai,Bi,Ci只与i的孩子Aj,Bj,Cj和本身ki,ei有关

于是可以从叶子开始逆推得到A1,B1,C1 

在叶子节点: 

Ai=ki; 

Bi=(1-ki-ei); 

Ci=(1-ki-ei); 

而E(1)=A1*E(1)+B1*0+C1;

=>E(1)=C1/(1-A1);

当A趋近于1时,那么无解,精度卡到1e-9才能过

#include <iostream>#include <cstring>#include <cstdio>#include <algorithm>#include <cmath>#define MAX 10007#define eps 1e-9using namespace std;int t,n,x,y;double k[MAX],e[MAX],A,B,C;struct Edge{    int v,next;}edge[MAX<<1];int head[MAX];int cc;void add ( int u , int v ){    edge[cc].v = v;    edge[cc].next = head[u];    head[u] = cc++;}void dfs ( int u , int p ){    double a,b,c,t;    a = b = c = 0.0;    int m = 0;    for ( int i = head[u] ; ~i ; i = edge[i].next )    {        int v = edge[i].v;        if ( v == p ) continue;        dfs ( v , u );        a += A;        b += B;        c += C;        m++;    }    if ( p != -1 ) m++;    t = (1-k[u]-e[u])/m;    A = (k[u]+t*a)/(1-t*b);    B = t/(1-t*b);    C = (1-k[u]-e[u]+t*c)/(1-t*b);}int main ( ){    int Case = 1;    scanf ( "%d" , &t );    while ( t-- )    {        cc = 0;        memset ( head , -1 , sizeof ( head ) );        scanf ( "%d" , &n );        for ( int i = 1 ; i < n ; i++ )        {            scanf ( "%d%d" , &x , &y );            add ( x , y );            add ( y , x );        }        for ( int i = 1 ; i <= n ; i++ )        {            scanf ( "%lf%lf" , &k[i] , &e[i] );            k[i] /= 100.0;            e[i] /= 100.0;        }        dfs ( 1 , -1 );        printf ( "Case %d: " , Case++ );        if ( fabs( A-1 )  < eps )            puts ( "impossible");        else printf ( "%.6lf\n" , C/(1-A) );    }}


0 0
原创粉丝点击