poj2057--The Lost House(树状dp,求期望)

来源:互联网 发布:三轴点胶机编程教学 编辑:程序博客网 时间:2024/06/05 02:43

题目链接:点击打开链接

题目大意:蜗牛把壳落在了一个树梢上,壳在每一个树梢上的概率是相同的。现在他从树根开始爬,在树杈中可能会有毛毛虫,告诉它壳是否在这个树枝上。每个树枝的长度为1,问最终能找到壳需要爬行的距离期望值最小是多少。

求期望值 = ∑到第i个树梢的距离*在第i个树梢上的概率(i为叶子节点) = 到所有叶子节点的和/叶子节点数。也就是说要求一个序列,按这个序列到达每一个节点的和是最小的。

现在需要判断的就是怎么找到这一个序列,也就是在分叉中如何判断先走哪个叉。

假设根为s,有两个叉为a和b,那么如果壳在s上,有两种可能:

先走a,后走b。

期望值 = 在a上找到壳的期望*壳在a的概率 + (由a返回s的步数+在b上找到壳的期望)*壳在b的概率

先走b,后走a。

期望值 = 在b上找到壳的期望*壳在b的概率 + (由b返回s的步数+在a上找到壳的期望)*壳在a的概率

这两种方式的区别在于 第一个 有一个k1 =(有a返回s的步数*壳在b的概率)。第二个有k2=(在b返回s的步数*壳在a的概率)。如果要求期望值最小。那么就是找到k1和k2中小的一种方式,这样就得到两个叉之间的排序关系,也就可以找出任意几个叉访问的序列,就可以得出最终的序列。

三个数组:

re存储由子节点返回父节点的步数(受到re[子节点]和是否有毛毛虫的控制)

num:(如果壳在这个子树上的)存储该节点需要的平均期望

p:存储壳在该子树上的概率,为计算方便统一乘以总节点数。

re[s] = ∑(re[i])+2 或者 re[s] = 2;

num[s] = 按顺序得到期望值/叶子节点数

p[s] = ∑(p[i])


#include <cstdio>#include <cstring>#include <algorithm>using namespace std ;#define eqs 1e-8struct node{    int u , v ;    int next ;}edge[1010] , temp[1010] ;int head[1010] , cnt ;double re[1010] , num[1010] , p[1010] ;int k[1010] ;void add(int u,int v) {    edge[cnt].u = u ; edge[cnt].v = v ;    edge[cnt].next = head[u] ; head[u] = cnt++ ;}int cmp(node a,node b) {    return re[ b.v ]*p[ a.v ] - re[ a.v ]*p[ b.v ] > eqs  ;}void dfs(int u) {    re[u] = num[u] = p[u] = 0 ;    if( head[u] == -1 ) {        re[u] = 2 ;        num[u] = 1 ;        p[u] = 1 ;        return ;    }    int i , j = 0 , v , s ;    for(i = head[u] ; i != -1 ; i = edge[i].next) {        dfs(edge[i].v) ;    }    for(i = head[u] ; i != -1 ; i = edge[i].next) {        temp[j++] = edge[i] ;        p[u] += p[edge[i].v] ;        re[u] += re[ edge[i].v ] ;    }    if( k[u] ) re[u] = 0 ;    re[u] += 2 ;    sort(temp,temp+j,cmp) ;    for(i = 0 , s = 0 , num[u] = p[u] ; i < j ; i++) {        num[u] += (s+num[ temp[i].v ])*p[ temp[i].v ] ;        s += re[ temp[i].v ] ;    }    num[u] /= p[u] ;    return ;}int main() {    int n , i , u , v ;    char s[10] ;    while( scanf("%d", &n) && n ) {        memset(head,-1,sizeof(head)) ;        memset(k,0,sizeof(k)) ;        cnt = 0 ;        for(i = 1 ; i <= n ; i++) {            scanf("%d %s", &u, s) ;            if( s[0] == 'Y' ) k[i] = 1 ;            if(u == -1) u = 0 ;            add(u,i) ;        }        dfs(1) ;        printf("%.4f\n", num[1]-1.0) ;    }    return 0 ;}




0 0
原创粉丝点击