51 nod 苹果曼和树

来源:互联网 发布:logitech g502 mac 编辑:程序博客网 时间:2024/05/29 04:21

1500 苹果曼和树

题目来源: CodeForces

基准时间限制:1 秒 空间限制:131072 KB 分值: 80 难度:5级算法题

 收藏

 关注

 

苹果曼有一棵n个点的树。有一些(至少一个)结点被标记为黑色,有一些结点被标记为白色。

现在考虑一个包含k(0 ≤ k < n)条树边的集合。如果苹果曼删除这些边,那么会将这个树分成(k+1)个部分。每个部分还是一棵树。

现在苹果曼想知道有多少种边的集合,可以使得删除之后每一个部分恰好包含一个黑色结点。答案对1000000007 取余即可。

 

Input

单组测试数据。

第一行有一个整数n (2 ≤ n ≤ 10^5),表示树中结点的数目。

第二行有n-1个整数p[0],p[1],...,p[n-2] (0 ≤ p[i] ≤ i)。表示p[i](i+1)之间有一条边。结点从0开始编号。

第三行给出每个结点的颜色,包含n个整数x[0],x[1],...,x[n-1] (x[i]0或者1)。如果x[i]1,那么第i个点就是黑色的,否则是白色的。

Output

输出答案占一行。

Input示例

3

0 0

0 1 1

Output示例

2

System Message (题目提供者)

C++的运行时限为:1000 ms,空间限制为:131072 KB 


题解:本道题是树形dp一类题目(方案树)的通用解题方法,就是考虑每个节点与其儿子节点的合法关系更具乘法原理统计方案。对于题目的条件,可以简化问题:将给定的树分为k个联通块,每个联通块有且仅有一个黑点。抓住关键条件:有且仅有一个黑点,

那么对于每个节点,就用dp【i】【0】表示以i为根节点的子树中删边后i所在的联通块没有黑点的方案数,dp【i】【1】表示以i为根节点的子树中删完边后i所在的联通块有一个黑点的方案数。

分情况讨论:

i为黑点:对于儿子节点son,若son所在的联通块有一个黑点,那就断开i与son的边,方案数为dp【i】【1】*dp【son】【1】;若son所在的联通块没有黑点,那就保留i与son的边,方案数为dp【i】【1】*dp【son】【0】(dp的巧妙在于没有讨论删边,直接讨论点与点的情况)。所以i为黑点的dp【i】【1】点总方案数为:dp【i】【1】=(dp【i】【1】*(dp【son】【1】+dp【son】【0】));那么dp【i】【0】喃,很显然,因为i为黑点,dp【i】【0】点值永远为0不必讨论。

i为白点:对于儿子son,若son所在的联通块没有黑点可以保留也可以删边,对于son所在的联通块有一个黑点,则删边。所以dp【i】【0】=dp【i】【0】*(dp【son】【1】+dp【son】【0】);因为每个联通块至少要有一个黑点。所以dp【i】【1】=(dp【i】【1】*(dp【son】【1】+dp【son】【0】)+dp【i】【0】*dp【son】【1】)

总结:做树形dp的题一定要弄清楚关系,儿子与父亲的关系,有时更行关系并不仅限于父亲和儿子,还有可能是儿子与祖先。对于题目中的一些操作(此题中的删边),当很难枚举和表示时,通常可以通过表示其他状态来间接表示这些操作,毕竟操作就是改变状态。树形dp中的方案数问题通常都要用到乘法原理。

#include <iostream>#include <cstring>#include <cstdio>#include <algorithm>#define N 300000#define mod 1000000007using namespace std;int n;int last[N<<1],to[N<<1],head[N],cnt=0;int v[N];long long dp[N][3];void ins(int u,int v){last[++cnt]=head[u];head[u]=cnt;to[cnt]=v;}void dfs(int x,int fa){dp[x][v[x]]=(long long)1;for(int i=head[x];~i;i=last[i]){if(to[i]==fa) continue;dfs(to[i],x);dp[x][1]=(dp[x][1]*(dp[to[i]][1]+dp[to[i]][0])%mod+dp[x][0]*dp[to[i]][1])%mod;dp[x][0]=(dp[x][0]*(dp[to[i]][1]+dp[to[i]][0]))%mod;}}int main(){//freopen("in.txt","r",stdin);memset(head,-1,sizeof(head));int tmp;scanf("%d",&n);for(int i=0;i<n-1;i++){scanf("%d",&tmp);ins(i+1,tmp);ins(tmp,i+1);}for(int i=0;i<n;i++) scanf("%d",&v[i]);dfs(0,-1);printf("%lld",dp[0][1]);return 0;}


原创粉丝点击