hdu 5905 Black White Tree 树dp

来源:互联网 发布:淘宝网羊绒衫 编辑:程序博客网 时间:2024/04/28 01:58

我们可以思考对一个子树如果他节点数固定那么黑点数一定是连续变化的,那我们可以很容易想到找到每个子树大小的黑点最大值和最小值也就是dp[i]表示子树大小为i的黑点最大值,dp1[i]为最小值,剩下的就是dp。

#include<cstdio>#include<cstring>#include<iostream>#include<algorithm>#include<vector>using namespace std;const int maxn=2004;int dp[maxn][maxn],dp1[maxn][maxn];int ma[maxn],mi[maxn];int size[maxn];vector<int>g[maxn];int a[maxn];int be[maxn][maxn],en[maxn][maxn];void dfs(int u,int pa){    size[u]=1;    for(int v:g[u]){        if(v!=pa){            dfs(v,u);            size[u]+=size[v];        }    }}int cmp(int a,int b){    return size[a]<size[b];}void dfs1(int u,int pa){    for(int i=1;i<=size[u];i++){        dp[u][i]=-1000000000;        dp1[u][i]=1000000000;    }    if(a[u]){        dp[u][1]=1;        dp1[u][1]=1;    }    else{        dp[u][1]=0;        dp1[u][1]=0;    }    sort(g[u].begin(),g[u].end(),cmp);    for(int v:g[u]){        if(v!=pa){            dfs1(v,u);        }    }    int all=1;    for(int v:g[u]){        if(v!=pa){            for(int j=all;j>0;j--){                for(int i=1;i<=size[v];i++){                    dp[u][i+j]=max(dp[u][i+j],dp[u][j]+dp[v][i]);                    dp1[u][i+j]=min(dp1[u][i+j],dp1[u][j]+dp1[v][i]);                }            }            all+=size[v];        }    }    for(int i=1;i<=size[u];i++){        be[i][dp1[u][i]]++;        be[i][dp[u][i]+1]--;    }}char c[maxn];int main(){    int t,n;    cin>>t;    while(t--){        scanf("%d",&n);        for(int i=1;i<=n;i++) g[i].clear();        for(int i=0;i<=n;i++){            for(int j=0;j<=i;j++){                be[i][j]=0;            }        }        be[0][0]=1;        scanf("%s",c);        for(int i=1;i<=n;i++) a[i]=c[i-1]-'0';        for(int i=1;i<n;i++){            int u,v;            scanf("%d%d",&u,&v);            g[u].push_back(v);            g[v].push_back(u);        }        dfs(1,-1);        dfs1(1,-1);        long long ans=0;        for(int i=0;i<=n;i++){            int all=0;            for(int j=0;j<=i;j++){                all+=be[i][j];                if(all){                    ans+=(i-j+1)*(j+1);                }            }        }        printf("%lld\n",ans);    }}


0 0
原创粉丝点击