BZOJ 4033 [HAOI 2015] 树DP 解题报告

来源:互联网 发布:淘宝退的运费险在哪里 编辑:程序博客网 时间:2024/05/21 21:36

4033: [HAOI2015]树上染色

Description

有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并将其他的N-K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。问收益最大值是多少。

Input

第一行两个整数N,K。接
下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。
输入保证所有点之间是联通的。
N<=2000,0<=K<=N

Output

输出一个正整数,表示收益的最大值。

Sample Input

5 2
1 2 3
1 5 1
2 3 1
2 4 2

Sample Output

17

【解题报告】
dalao的题解
http://ydc.blog.uoj.ac/blog/336
摘录一部分:
首先我们要知道的是,这么一份伪代码的复杂度是O(n2)O(n2)的

void Tree_Dp(int p){    size[p]=1;    each(x,son[p])    {        Tree_Dp(x);        for(int i=0;i<=size[p];++i)            for(int j=0;j<=size[x];++j)                update dp[p][i+j];        size[p]+=size[x];    }}

复杂度的证明?考虑那个二重循环,可以看做分别枚举两棵子树的每个点。你会发现,点对(u,v)(u,v),只会在Tree_Dp(lca(u,v)lca(u,v))处被考虑到,所以复杂度是O(n2)O(n2)
换句话说,ii所考虑到的点的DFS序一定小于jj所考虑到的点
于是我们就可以dpi,jdpi,j表示子树ii选了jj个黑点,然后转移暴力枚举子树里选了几个黑点,复杂度还是O(n2)O(n2)的,关键是怎么转移
其实说白了还是设状态的问题,如果设他是子树ii选了jj个黑点,子树内的同色点对距离和的话,是不好转移的对吧……所以我们要把他们往子树外连的也考虑进来
这么讲吧,假设我们把任意一对同色点之间的路径给标一下,那么dpi,jdpi,j记录的就是,子树ii选了jj个黑点,子树内的所有被标过的路径权值和,这样就能转移了
简单地说,就是改变一下状态定义,就能让他不仅考虑子树ii内部的和,把子树外的延伸到子树内的那些边也考虑了

代码如下:

/**************************************************************    Problem: 4033    User: onepointo    Language: C++    Result: Accepted    Time:960 ms    Memory:64284 kb****************************************************************/#include<cstdio>#include<cstring>#include<algorithm>using namespace std;#define max(a,b) (a<b)?b:a#define inf 0x3f3f3f3f#define N 2010#define LL long longLL n,m;LL cnt,head[N],size[N];struct Edge{LL to,nxt,w;}e[N<<1];LL dp[N][N];void adde(LL u,LL v,LL w){    e[++cnt].to=v;e[cnt].w=w;    e[cnt].nxt=head[u];head[u]=cnt;    e[++cnt].to=u;e[cnt].w=w;    e[cnt].nxt=head[v];head[v]=cnt;}void dfs(LL u,LL fa){    LL tmp[N];size[u]=1;    fill(dp[u]+2,dp[u]+n+1,-inf);       for(int i=head[u];~i;i=e[i].nxt)    {        LL v=e[i].to;        if(v==fa) continue;        dfs(v,u);        fill(tmp,tmp+size[u]+size[v]+1,-inf);        for(LL j=0;j<=size[u];++j)        for(LL k=0;k<=size[v];++k)        {            tmp[j+k]=max(tmp[j+k],dp[u][j]+dp[v][k]+e[i].w*(k*(m-k)+(size[v]-k)*(n-m-(size[v]-k))));            }        size[u]+=size[v];        for(LL j=0;j<=size[u];++j) dp[u][j]=tmp[j];    }}int main(){    cnt=-1;    memset(head,-1,sizeof(head));    scanf("%lld%lld",&n,&m);    for(int i=1;i<n;++i)    {        LL u,v,w;scanf("%lld%lld%lld",&u,&v,&w);        adde(u,v,w);    }    dfs(1,1);    printf("%lld\n",dp[1][m]);    return 0;}