[Usaco2011open][BZOJ2444] Soldering

Source: Internet. Editor: 程序博客网. Time: 2024/04/27 14:27

【Problem Description】

Description


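The cows are playing with wires! They have learned how to solder: two wires are joined by soldering an endpoint of one wire onto some point in the middle of the other (note: two wires cannot be joined endpoint to endpoint, i.e. a "middle" point excludes the endpoints). Of course, several wires can be soldered onto the same middle point. (The soldering points must also be integer points. The English statement does not seem to say this; it is just how I understood it.)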

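The cows are going to build a marvelous structure: a graph with N (1 <= N <= 50,000) nodes and N-1 edges in which any two nodes are connected. Each edge is described by two integers A and B (1 <= A <= N; 1 <= B <= N; A != B).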

The cows must buy the wires from the local store, and longer wires cost more: a wire of length L costs L*L. Wires may not be spliced together or cut.

Given the structure the cows plan to build, help them find the minimum cost.

Input


* Line 1: A single integer N

* Lines 2..N: Two integers A and B describing an edge

Output


* Line 1: A single integer, the minimum cost. Note that this integer may not fit in 32 bits.

Sample Input

6
1 2
1 3
1 4
1 5
1 6



Sample Output

7

OUTPUT DETAILS:

Since every node is connected to node 1, we only need to buy one wire of length 2 and three wires of length 1.
The total cost is 2 * 2 + 1 * 1 + 1 * 1 + 1 * 1 = 7.


【Approach】

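The English editorial (at the bottom of this post) was too long for me to get through, so I asked an expert for advice, then optimized the O(N²) algorithm a little and it squeaked through.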

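The state f[i][j] is the minimum cost for the subtree rooted at i, given that exactly one wire reaches i and will keep extending upward later, and that wire's current length is j. The cost counts every wire except that one, since its final length is not known yet. j = 0 means no wire passes upward through i, so the edge from i to its parent starts a new wire at i (i is a leaf, or two wires were joined through i).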

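There are three transitions when a child subtree is merged into i (blue is the wire coming up from the newly added child; red is the wire that was previously going to continue upward from i):

1. The blue wire ends at i.

2. The blue wire replaces the earlier wire as the one that leaves the subtree, and the red wire ends at i.

3. The blue wire stops at i and is joined with the red wire into a single wire passing through i; the green wire (a new wire starting at i) does not necessarily continue out of the subtree.

If no wire continues upward from i yet (j = 0), i is already the middle of a joined wire, so only transition 1 applies.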


【Optimization】

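Maintain f in a queue (a vector is recommended) holding pairs (j, f[i][j]). Only one queue per depth is needed, since the DFS works on one vertex per depth at a time. Each time a child subtree has updated the queue, delete two kinds of useless entries: first, among entries with the same j keep only the best one; second, if jx > jy and fx[i][j] > fy[i][j], delete x. (The code below actually prunes with the stronger condition on f + j², the cost if the wire stopped at i: a later extension by L costs f + (j + L)² = (f + j²) + 2jL + L², so a state with smaller j and no larger f + j² is never worse.) P.S. The intended solution finds certain properties and uses a convex hull to optimize the tree DP; the official editorial is at the bottom.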

【Code】

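#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>
#define ll long long
using namespace std;

const int N = 50001;

// (l, co): l = length of the wire that will continue upward (0 = none),
//          co = total cost of all other wires in the subtree
struct st {
    int l; ll co;
    st() {}
    st(int x, ll y) { l = x; co = y; }
    ll cal() { return co + (ll)l * l; }  // cost if the wire ends here (cast avoids int overflow)
    bool operator<(const st x) const { return l < x.l || (l == x.l && co < x.co); }
};

vector<st> f[N], g;   // f[d]: states of the vertex currently processed at depth d
vector<int> l[N];     // adjacency lists
int n, i, x, y;
bool vis[N];

void dp(int x, int d) {
    vis[x] = 1;
    f[d].clear();
    bool leaf = 1;
    int m1 = l[x].size(), m2, m3;
    for (int k = 0; k < m1; k++)
        if (!vis[l[x][k]]) {
            leaf = 0;
            dp(l[x][k], d + 1);
            m3 = f[d + 1].size();
            for (int i = 0; i < m3; i++) f[d + 1][i].l++;   // add the edge child -> x
            if (!f[d].size()) { f[d] = f[d + 1]; continue; }
            g.clear();
            m2 = f[d].size();
            for (int i = 0; i < m2; i++)
                for (int j = 0; j < m3; j++)
                    if (!f[d][i].l)   // x is already the middle of a joined wire
                        g.push_back(st(0, f[d][i].co + f[d + 1][j].cal()));
                    else {
                        // 1. the child's wire ends at x
                        g.push_back(st(f[d][i].l, f[d][i].co + f[d + 1][j].cal()));
                        // 2. the child's wire continues upward, the previous wire ends at x
                        g.push_back(st(f[d + 1][j].l, f[d + 1][j].co + f[d][i].cal()));
                        // 3. both wires are joined into one wire through x
                        g.push_back(st(0, f[d][i].cal() + f[d + 1][j].cal() + (((ll)f[d][i].l * f[d + 1][j].l) << 1)));
                    }
            // prune: in increasing l, keep a state only if its cal() is strictly smaller
            sort(g.begin(), g.end());
            int num = 0;
            m2 = g.size();
            for (int i = 1; i < m2; i++)
                if (g[i].cal() < g[num].cal()) g[++num] = g[i];
            g.resize(num + 1);
            f[d] = g;
        }
    if (leaf) f[d].push_back(st(0, 0));
}

int main() {
    scanf("%d\n", &n);
    for (i = 1; i < n; i++) {
        scanf("%d%d\n", &x, &y);
        l[x].push_back(y);
        l[y].push_back(x);
    }
    dp(1, 0);
    // Root with >= 2 children: a wire cannot end at the root while others are soldered
    // there, so only l = 0 states are valid, and the first entry is the best of them.
    // Root with exactly 1 child: the wire coming up just ends at the root, so every
    // state is valid and the answer is the minimum cal().
    ll ans = f[0][0].cal();
    if (l[1].size() == 1)
        for (int k = 1; k < (int)f[0].size(); k++) ans = min(ans, f[0][k].cal());
    printf("%lld\n", ans);
}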

【Official Editorial】

       The starting point here is a dynamic programming algorithm. Arbitrarily root the tree and consider "cutting off" a particular subtree in a soldering. This leaves one (or none if a wire was cut off at its endpoints) "cut wire" which extends out of the subtree to the parent and a set of wires that are wholly within the subtree. Now, note that all that is relevant is the length of the "cut wire" within the subtree and the total cost of all the other wires. This is because the cut wire is the only wire whose cost depends on the rest of the soldering.

       This gives a relatively simple dynamic programming solution: for each vertex (defining a subtree) store, for each possible cut wire length, the minimum cost of the other wires; if there is no cut wire this can be taken as a wire of length 0. We will compute these from the bottom up. To compute these values, note that if there is a cut wire it must extend down to one of the children; the cost for a cut wire going through a particular child is the cost for the cut wire through the child's subtree plus the minimum cost soldering covering each of the other subtrees. If there is no cut wire, then the edge going to the parent must be soldered onto the middle of another wire; then one can just check all pairs of lengths and distinct children to find two "cut wires" for two children to merge into one wire. Now, note that the maximum length of the cut wire for each subtree is at most the number of nodes it contains, so the number of pairs of lengths for any two distinct children is at most the number of pairs of nodes in the two children; summing over all pairs of children, this is the number of pairs of nodes whose lowest common ancestor is the root of the subtree. Then the total work done over the whole algorithm is only the total number of pairs of nodes, or O(N²).
       Now, at this point it will be convenient to assume, in the discussion of the algorithm, that each vertex has at most two children. In fact, this is not a problem: a vertex V with more children can be "split up" by giving it a direct edge to one of its children and attaching the remainder to a new vertex V' joined to V by an edge of length 0 (a length-0 edge does not break the algorithm, even though all edges in the problem have length 1), then iterating this until no vertex has more than two children.
To further reduce the runtime, one must note the convexity properties of squaring the length. A length/cost pair (l, c) for a subtree corresponds to the function (L + l)² + c, where L is the length of the cut wire outside the subtree. But one only cares about those functions that are the minimum for some value of L: since (L + l)² + c = L² + 2L·l + (l² + c), and L² is common to all pairs, this is the lower envelope of the lines with slope 2l and intercept l² + c, equivalent to a convex hull. All pairs not on the envelope can be deleted. One can then binary search the convex hull to find the optimal pairing for any particular length of the wire outside the subtree. Then, to find the optimal pair of lengths in the two children to merge into one wire, one can simply take all the lengths in the smaller subtree and binary search the convex hull of the larger subtree to find the best partner within that subtree for each. Finally, to efficiently find the convex hulls for all subtrees, one can represent the convex hulls with binary search trees (std::set does fine here); to get the possibilities from either child, one can offset the values in the larger child subtree's convex hull (by storing offset values that are added to all the pairs in the hull, since both length and cost change as you merge subtrees) and then insert each pair (offset) from the smaller subtree into it. The total number of operations on the binary search trees is then at most the sum, over all nodes, of the size of the smaller child subtree (in fact it can be smaller, as a convex hull can have fewer elements than its subtree has nodes). This can be shown to be O(N log N): each time an element is inserted from the smaller side, the set containing it at least doubles in size, so this happens at most log N times per element. Each tree operation is O(log N), so the overall runtime is O(N log² N).

Below is Neal Wu's O(N²) implementation:

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

FILE *in = fopen ("solder.in", "r"), *out = fopen ("solder.out", "w");

const int MAXN = 50005;
const long long LLINF = 1LL << 60;

int N, down [MAXN];
long long *dp [MAXN], mindp [MAXN];
vector <int> adj [MAXN];

// root the tree, remove parent links, and compute down[] (height of each subtree)
void init_dfs (int num, int par)
{
    down [num] = 1;
    int par_ind = -1;

    for (int i = 0; i < (int) adj [num].size (); i++)
    {
        int child = adj [num][i];

        if (child == par)
        {
            par_ind = i;
            continue;
        }

        init_dfs (child, num);
        down [num] = max (down [num], down [child] + 1);
    }

    if (par_ind != -1)
        adj [num].erase (adj [num].begin () + par_ind);
}

void solve_dfs (int num)
{
    for (int i = 0; i < (int) adj [num].size (); i++)
        solve_dfs (adj [num][i]);

    long long dp1 = 1;

    if (adj [num].size () > 1)
    {
        for (int i = 0; i < (int) adj [num].size (); i++)
            dp1 += mindp [adj [num][i]];

        // best pair of cut wires from two distinct children joined through num
        long long best_two = LLINF;

        for (int i = 0; i < (int) adj [num].size (); i++)
            for (int j = i + 1; j < (int) adj [num].size (); j++)
            {
                int child1 = adj [num][i], child2 = adj [num][j];

                for (int a = 1; a <= down [child1]; a++)
                    for (int b = 1; b <= down [child2]; b++)
                        best_two = min (best_two, dp [child1][a] + dp[child2][b] + 2LL * a * b - mindp [child1] - mindp [child2]);
            }

        dp1 += best_two;
    }

    dp [num] = new long long [down [num] + 1];
    dp [num][1] = dp1;

    if (adj [num].size () == 1)
    {
        dp [num][1] = LLINF;
        dp [num][0] = mindp [adj [num][0]];
    }
    else
        dp [num][0] = dp [num][1] - 1;

    // the cut wire continues from one child: extend it by the edge to the parent
    for (int k = 1; k < down [num]; k++)
    {
        long long sum = 0, best_link = LLINF;

        for (int i = 0; i < (int) adj [num].size (); i++)
        {
            int child = adj [num][i];
            sum += mindp [child];

            if (k <= down [child])
                best_link = min (best_link, dp [child][k] - mindp [child]);
        }

        dp [num][k + 1] = sum + best_link + 2 * k + 1;
    }

    mindp [num] = LLINF;

    for (int k = 1; k <= down [num]; k++)
        mindp [num] = min (mindp [num], dp [num][k]);

    // the children's arrays were allocated with new[], so they need delete[]
    for (int i = 0; i < (int) adj [num].size (); i++)
        delete [] dp [adj [num][i]];
}

int main ()
{
    fscanf (in, "%d", &N);

    for (int i = 1, a, b; i < N; i++)
    {
        fscanf (in, "%d %d", &a, &b); a--; b--;
        adj [a].push_back (b);
        adj [b].push_back (a);
    }

    init_dfs (0, -1);
    solve_dfs (0);
    fprintf (out, "%lld\n", dp [0][0]);
    return 0;
}

And below is Michael Cohen's impressive full implementation:

#include <fstream>
#include <vector>
#include <set>
#define endl '\n'
using namespace std;

struct poss {
    long long depth;     // depth of the lower end of the cut wire
    long long cost;      // cost of the other wires (the set's offset is added separately)
    long long takeover;  // smallest query value at which this entry is the best one
    bool tcheck;         // marks a search key that compares by takeover only
};

bool operator<(poss a, poss b) {
    if (a.tcheck || b.tcheck) return (a.takeover < b.takeover);
    if (a.depth > b.depth) return true;
    if (a.depth < b.depth) return false;
    return (a.cost < b.cost);
}

int N;
vector<int> edges[50000];
bool visited[50000];
long long depth[50000];
long long offset[50000];
set<poss>* best[50000];

void recurse(int node) {
    visited[node] = true;
    long long bestPair = -1;
    long long allSoFar = 0;
    for (int i = 0; i < edges[node].size(); i++) {
        if (visited[edges[node][i]]) continue;
        depth[edges[node][i]] = depth[node]+1;
        recurse(edges[node][i]);
        long long tadd;
        {
            poss when = { 0, 0, -depth[node], true };
            set<poss>::iterator which = best[edges[node][i]]->upper_bound(when);
            which--;
            tadd = (depth[node]-which->depth)*(depth[node]-which->depth)+which->cost+offset[edges[node][i]];
        }
        if (bestPair != -1) bestPair += tadd;
        if (best[node] == NULL) {
            best[node] = best[edges[node][i]];
            offset[node] = offset[edges[node][i]];
        }
        else {
            set<poss>* s = best[node], * t = best[edges[node][i]];
            long long os = offset[node]+tadd, ot = offset[edges[node][i]]+allSoFar;
            if (s->size() < t->size()) {
                set<poss>* tem = s; s = t; t = tem;
                long long to = os; os = ot; ot = to;   // long long: offsets can exceed 32 bits
            }
            // best way to join a wire from t with a wire from s through node
            for (set<poss>::iterator it = t->begin(); it != t->end(); it++) {
                poss when = { 0, 0, it->depth-2*depth[node], true };
                set<poss>::iterator which = s->upper_bound(when);
                which--;
                long long thisPair = (it->depth+which->depth-2*depth[node])*(it->depth+which->depth-2*depth[node])+it->cost+which->cost+offset[node]+offset[edges[node][i]];
                if (bestPair == -1 || thisPair < bestPair) bestPair = thisPair;
            }
            // insert t's entries into s, repairing the hull on both sides of each insert
            for (set<poss>::iterator it = t->begin(); it != t->end(); it++) {
                poss p = *it;
                p.cost += ot-os;
                set<poss>::iterator where = s->insert(p).first;
                bool killed = false;
                while (where != s->begin()) {
                    set<poss>::iterator prev = where;
                    prev--;
                    if (prev->depth == where->depth) {
                        s->erase(where);
                        killed = true;
                        break;
                    }
                    p.takeover = (where->cost-prev->cost+where->depth*where->depth-prev->depth*prev->depth)/(2*prev->depth-2*where->depth);
                    while ((2*prev->depth-2*where->depth)*p.takeover < where->cost-prev->cost+where->depth*where->depth-prev->depth*prev->depth)
                        p.takeover++;
                    s->erase(where);
                    where = s->insert(p).first;
                    if (where->takeover <= prev->takeover) s->erase(prev);
                    else break;
                }
                if (killed) continue;
                if (where == s->begin()) {
                    p.takeover = -1000000000;
                    s->erase(where);
                    where = s->insert(p).first;
                }
                set<poss>::iterator next = where;
                next++;
                while (next != s->end()) {
                    if (next->depth == where->depth) {
                        s->erase(next);
                        next = where;
                        next++;
                        continue;
                    }
                    poss n = *next;
                    n.takeover = (next->cost-where->cost+next->depth*next->depth-where->depth*where->depth)/(2*where->depth-2*next->depth);
                    while ((2*where->depth-2*next->depth)*n.takeover < next->cost-where->cost+next->depth*next->depth-where->depth*where->depth)
                        n.takeover++;
                    if (n.takeover <= where->takeover) {
                        s->erase(where);
                        break;
                    }
                    s->erase(next);
                    next = s->insert(n).first;
                    set<poss>::iterator nnext = next;
                    nnext++;
                    if (nnext != s->end() && nnext->takeover <= next->takeover) {
                        s->erase(next);
                        next = nnext;
                    }
                    else break;
                }
            }
            best[node] = s;
            offset[node] = os;
            delete t;
        }
        allSoFar += tadd;
    }
    if (best[node] == NULL) {          // leaf: a single wire ends here
        best[node] = new set<poss>();
        poss p = { depth[node], 0, -1000000000, false };
        best[node]->insert(p);
    }
    else if (bestPair != -1) {         // at least two children: a new wire may start at node
        poss p = { depth[node], bestPair-offset[node], 0, false };
        while (!best[node]->empty()) {
            p.takeover = (p.cost-best[node]->rbegin()->cost+p.depth*p.depth-best[node]->rbegin()->depth*best[node]->rbegin()->depth)/(2*best[node]->rbegin()->depth-2*p.depth);
            while ((2*best[node]->rbegin()->depth-2*p.depth)*p.takeover < p.cost-best[node]->rbegin()->cost+p.depth*p.depth-best[node]->rbegin()->depth*best[node]->rbegin()->depth)
                p.takeover++;
            if (p.takeover > best[node]->rbegin()->takeover) break;
            best[node]->erase(*(best[node]->rbegin()));
        }
        if (best[node]->empty()) p.takeover = -1000000000;
        best[node]->insert(p);
    }
}

int main() {
    ifstream inp("solder.in");
    ofstream outp("solder.out");
    inp >> N;
    for (int i = 0; i < N-1; i++) {
        int a, b;
        inp >> a >> b;
        a--, b--;
        edges[a].push_back(b);
        edges[b].push_back(a);
    }
    recurse(0);
    if (edges[0].size() == 1) {
        poss when = { 0, 0, 0, true };
        set<poss>::iterator which = best[0]->upper_bound(when);
        which--;
        outp << which->depth*which->depth+which->cost+offset[0] << endl;
    }
    else {
        poss p = *(best[0]->rbegin());
        outp << p.cost+offset[0] << endl;
    }
    return 0;
}



