codeforce 855C. Helga Hufflepuff's Cup 树形DP

来源:互联网 发布:益盟软件怎么样 编辑:程序博客网 时间:2024/05/20 00:35

题目链接 :http://codeforces.com/problemset/problem/855/C
题意 : 给你一颗树,这颗树有n个节点,每个节点有m中取值,给定一种最高等级取值 k, 值为k的节点数最多为x个(x<=10) 。 最高等级节点周围的节点的值都要小于k。 问你取值方案数。

解题思路, 思路其实很简单,这就是一个简单的树形dp求方案数的变形。每个节点可以描述为,节点的编号,这个节点的取值。 但这样不好转移,因为不知道他的子节点的状态和它本身的关系。 所以我们再加一维。
dp[i][j][k] 以节点i为根的子树,有j个节点为最高等级节点,当前节点取值为k的方案数。
但这样的话 k最高为1e9 太大了。 但可以很轻易的发现,k的取值其实可以分为三类,小于,等于,大于最高等级。
这样整个数组只要开 1e5 * 11 * 3就够了。
状态转移的话。 有点复杂,讲(wo)不(bu)太(hui)清
大致思想就是利用乘法原理。

思路很简单,但是我写了好久都没a。。。。。
看了题解才知道可以用一个中间数组储存转移后的状态,学到了。。。。。。。

复杂度:

一共n个节点,每个节点访问一次,每次状态转移复杂度小于50. 反正,总复杂度还是O(能过)级别的。
代码跑了700ms

#include<iostream>#include<cstdio>#include<cstring>#include<string>#include<algorithm>using namespace std;const int MAX=1e5+10;const int mod=1e9+7;long long dp[MAX][12][3];long long tmp[12][3];long long sz[MAX];long long n,m,k,x;class node{public:    int u,v,next;};node nodes[MAX<<1];int head[MAX];int tot;void add(int u,int v){    nodes[tot].u=u;    nodes[tot].v=v;    nodes[tot].next=head[u];    head[u]=tot;    tot++;}void init(){    memset(head,-1,sizeof head);    tot=0;}void dfs(int u,int per){    dp[u][0][0]=k-1;    dp[u][0][2]=m-k;    dp[u][1][1]=1;    sz[u]=1;    for(int i=head[u];i!=-1;i=nodes[i].next)    {        int v=nodes[i].v;        if(v==per)            continue;        dfs(v,u);        memset(tmp,0,sizeof tmp);        for(int j=0;j<=10;j++)        {            for(int k=0;k<=10;k++)            {                if(j+k>x)                    continue;                tmp[j+k][0]=(tmp[j+k][0]%mod+dp[u][j][0]%mod*(dp[v][k][0] + dp[v][k][1] + dp[v][k][2])%mod)%mod;                tmp[j+k][1] = (tmp[j+k][1]%mod + dp[u][j][1]%mod*(dp[v][k][0])%mod)%mod;                tmp[j+k][2] = (tmp[j+k][2]%mod + dp[u][j][2]%mod*(dp[v][k][0] + dp[v][k][2])%mod)%mod;            }        }        memcpy(dp[u],tmp,sizeof tmp);    }}int main(){    scanf("%lld %lld",&n,&m);    init();    for(int i=1;i<n;i++)    {        int u,v;        scanf("%d %d",&u,&v);        add(u,v);        add(v,u);    }    scanf("%lld %lld",&k,&x);    dfs(1,-1);    int ans=0;    for(int i=0;i<=x;i++)        for(int j=0;j<=2;j++)            ans=(ans+dp[1][i][j])%mod;    cout<<ans<<endl;}
阅读全文
0 0
原创粉丝点击