hdu4616 树形dp,分治思想

来源:互联网 发布:热血传奇装备数据库 编辑:程序博客网 时间:2024/05/21 15:06

刚做这题的时候感觉很麻烦,而别人却说是经典题,当时我一点儿没看出来经典。。。知道今天看了漆子超的论文,里面第一个例题就和这个差不多。

先是分治思想,然后对于每个根节点,大概想像成向下拉了许多条链出来就行了,有个小技巧可以将时间复杂度降低。

#include<iostream>#include<cstdio>#include<string>#include<cstring>#include<algorithm>#include<cmath>using namespace std;typedef long long ll;const int maxn=50005;const int inf=0x3fffffff;ll val[maxn];int trap[maxn];ll ans,n,cc;ll dp[maxn][5][2];struct edge{    int to,next;}ee[maxn*2];int e[maxn],ecnt;void addedge(int u,int v){    ee[ecnt].to=v;ee[ecnt].next=e[u];e[u]=ecnt++;    ee[ecnt].to=u;ee[ecnt].next=e[v];e[v]=ecnt++;}void init(){    int i,j,k;    for(i=1;i<=n;++i)    {        for(j=0;j<=cc;++j)        {            for(k=0;k<2;++k)            {                dp[i][j][k]=-inf;            }        }    }}void dfs(int f,int u){    int v,i,j,k;    dp[u][trap[u]][trap[u]]=val[u];    for(i=e[u];i!=-1;i=ee[i].next)    {        v=ee[i].to;        if(v==f)            continue;        dfs(u,v);        for(j=0;j<=cc;++j)        {            for(k=0;j+k<=cc;++k)            {                if(j+k<cc)                {                    ans=max(ans,dp[u][j][0]+dp[v][k][0]);                }                if(j+k<=cc)                {                    ans=max(ans,dp[u][j][1]+dp[v][k][1]);                }                if(j!=cc)                {                    ans=max(ans,dp[u][j][0]+dp[v][k][1]);                }                if(k!=cc)                {                    ans=max(ans,dp[u][j][1]+dp[v][k][0]);                }            }        }        for(j=0;j<cc;++j)        {            dp[u][j+trap[u]][0]=max(dp[u][j+trap[u]][0],dp[v][j][0]+val[u]);            dp[u][j+trap[u]][1]=max(dp[u][j+trap[u]][1],dp[v][j][1]+val[u]);        }        if(trap[u]==0)        {            dp[u][cc][1]=max(dp[u][cc][1],dp[v][cc][1]+val[u]);        }    }}int main(){    int t,i,j,u,v;    scanf("%d",&t);    while(t--)    {        scanf("%d%d",&n,&cc);        for(i=1;i<=n;++i)        {            scanf("%d%d",&val[i],&trap[i]);        }        memset(e,-1,sizeof(e));ecnt=0;        for(i=1;i<n;++i)        {            scanf("%d%d",&u,&v);            u++;v++;            addedge(u,v);        }        init();        ans=-inf;        dfs(-1,1);        printf("%d\n",ans);    }    return 0;}


 

原创粉丝点击