ZOJ 3231 Apple Transportation 树形dp

来源:互联网 发布:爱奇艺 网络环境危险 编辑:程序博客网 时间:2024/06/08 07:46

题目链接:

http://acm.hust.edu.cn/vjudge/problem/viewProblem.action?id=14656

题意:

给一棵树,问树有点权值和边权值。问保证点权值方差最小的前提下,怎样花费最小代价改变点权值。已知从一个点移动一个权值到相邻点需要代价为两点的连接边权值

思路:

树形dp第一题。

第一个难点是抛弃方差最小的限制条件,发现最后生成树的点权值是sum/n的下界或上界,并且取上界的点个数一定。

dp[i][j]表示i为根的子树有j个取上界的点,深搜处理子节点后dp即可。

注意的是每个子节点的值都要加上,即分组背包。

然而WA了三天。刚开始是用容器模拟链表做,后面改成前向星还是WA。原来关键是gmin和gmax的使用问题。因为通常都是define,会出现e = gmin(1,100)+50的判断,实际为e = 1<100?1:(100+50)

哭死在键盘前。

源码:

#include <cstdio>

#include <cstring>

#include <cmath>

#include <string>

#include <algorithm>

#include <iostream>

#include <queue>

#include <vector>

#include <cstdlib>

using namespace std;

#define gmax(a,b) a>b?a:b

#define gmin(a,b) a<b?a:b

typedef long long LL;

#define MAXN 110

#define INF 100000000000LL

//vector<int>dd[MAXN];///son of each node

int data[MAXN],tdata[MAXN],num[MAXN];///apple numbers, number of nodes, sum of numbers

//int road[MAXN][MAXN];///distance between two nodes

LL dp[MAXN][MAXN];

LL tdp[MAXN];

int n,tot,ave,sum,lv;

int up[MAXN];

int head[MAXN];

struct D

{

    int u,v,w;

    int next;

    D(){}

    D(int __u,int __v,int __w){u = __u;v = __v;w = __w;}

}jia[MAXN*2];

int zheng(int a)

{

    return a > 0 ? a : -a;

}

void add_edge(int u,int v,int w)

{

    jia[tot] = D(u,v,w);

    jia[tot].next = head[u];

    head[u] = tot++;

}

void update(int u,int fa)

{

//    for(int i=0; i<n; i++)

//        tdp[i] = INF;

    fill(tdp, tdp+n+1, INF);

    tdp[0] = 0;

    int cap = min(lv, num[u]);

    for(int i=head[u]; i!=-1; i=jia[i].next){

        int tmark = jia[i].v;

        int tw = jia[i].w;

        if(tmark == fa)

            continue;

        else{

            for(int j=cap; j>=0; j--){///?

                LL res = INF;

                for(int k=0; k <= num[tmark] && k<=j; k++)

                    if(tdp[j-k] != INF && dp[tmark][k] != INF)

                        res = min(res, tdp[j-k] + dp[tmark][k]);

                tdp[j] = res;

            }

        }

    }

//    printf("u = %d\n",u);

//    printf("tdp\n");

//    for(int i=0; i<=cap; i++)

//        printf("%lld ",tdp[i]);

//    printf("\n");

    for(int i=0; i<=cap; i++){

        LL temp = up[u] * zheng(ave * num[u] + i - data[u]);///second

//        printf("temp = %lld ",temp);

        if(i == 0){

            dp[u][i] = tdp[i] + temp;

//            printf("tdp[i] = %lld ,dp[u][i] = %lld\n",tdp[i],dp[u][i]);

        }

        else{

            dp[u][i] = min(tdp[i-1],tdp[i]) + temp;

//            printf("tdp[i-1] = %lld ,tdp[i] = %lld, dp[u][i] = %lld\n",tdp[i-1],tdp[i],dp[u][i]);

        }

    }

//    printf("u = %d,tdp\n",u);

//    for(int i=0; i<=cap; i++)

//        printf("%lld ",dp[u][i]);

//    printf("\n");

}

void dfs2(int mark, int fa, int w)

{

    bool leaf = true;

    up[mark] = w;

    num[mark] = 1;

    data[mark] = tdata[mark];

    for(int i = head[mark]; i!=-1; i = jia[i].next){

        int v = jia[i].v;

        int tw = jia[i].w;

        if(v == fa)

            continue;

        else{

//            printf("v = %d,fa = %d\n",v,fa);

            leaf = false;

            dfs2(v,mark,tw);

            data[mark] += data[v];

            num[mark] += num[v];

        }

    }

    if(leaf){

        LL temp;

        temp = up[mark] * zheng(ave * num[mark] - data[mark]);

        dp[mark][0] = temp;

        temp = up[mark] * zheng(ave * num[mark] + 1 - data[mark]);

        dp[mark][1] = temp;

    }

    else{

            update(mark, fa);

        }

//    printf("mark = %d\n",mark);

//    int cap = gmin(num[mark],lv);

//    for(int i=0; i<=cap; i++)

//        printf("%lld ",dp[mark][i]);

//    printf("\n");

 

}

int main()

{

//    freopen("data.txt","r",stdin);

//    freopen("data2.txt","w",stdout);

    while(scanf("%d",&n) != EOF){

        sum = 0;

        tot = 0;

        memset(head, -1, sizeof(head));

        for(int i=0; i<n; i++){

            scanf("%d",&tdata[i]);

            sum += tdata[i];

        }

        int u,v,t;

        for(int i=1; i<n; i++){

            scanf("%d%d%d",&u,&v,&t);

            add_edge(u,v,t);

            add_edge(v,u,t);

        }

        ave = sum / n;

        lv = sum % n;

        dfs2(0,-1,0);

        printf("%lld\n",dp[0][lv]);

    }

    return 0;

}

/*

 

4

3 99 41 12

0 1 14

1 2 23

2 3 50

*/

 

0 0
原创粉丝点击