hdu 4679 Terrorist’s destroy (树形dp)

来源:互联网 发布:华为手机连不上4g网络 编辑:程序博客网 时间:2024/05/18 00:46

hdu 4679 Terrorist’s destroy

4小时55分敲完代码,0调试,跑出样例直接交了,结果爆栈,扩栈交了一发,居然A了,我和我的小伙伴们都惊呆了。。。

题意:给出有n个节点的一棵树,树上的边有权值。我们切断一条边,将整棵树分成两颗,计算一个值,这个值的计算方法,v = b * max ( d1 , d2 ) ;其中,b为所切的边的权值,d1 ,d2 为切断后形成的两颗树的树上的最长路。对于每一条边,会计算出一个v值,问切那条边时,v值最小,如果有多个最小的v值,输出边的id最小的那条。

解题思路:树形dp(对于整棵树而言,根节点为1号节点)。。对于每一个节点,要记录很多东西,l1 , l2 , l3 分别表示该节点下,最长的,次长的,第三长链的长度,f1 , f2 , f3分别表示最长的,次长的,第三长的链来自于哪个儿子节点。t1 , t2分别表示该节点下的儿子中的最长路和次长路的长度,g1 , g2分别表示最长路,次长路的儿子是哪个儿子。dp[u]表示u节点下的这颗子树的最长路。怎么更新这些信息呢?对于l1 , l2 , l3 , f1 , f2 , f3 枚举每个儿子,递推上来吧,还是很好做的,t1 , t2 就枚举 dp[v] (v表示u的儿子节点)得出来。dp[u]就是max ( t1 , t2 , l1 + l2  ) 。这就可以推出这棵树上每个节点的信息了。然后我么就要算切断某一条边的时候的值了。对于根节点下的边,我们枚举一下,答案还是很容易得出的,d1就是dp[v], d2的话,看枚举的这个v是不是g1,能取t1就取 t1 , 否则就取t2,但这样d2并不一定是最优的,还要从最长,次长,第三长里面取出能去的最长的两条去拼一下,得出一个最大值就是d2了。对于每个节点v,算一次值,更新下答案。然后就要处理以1的儿子为树根了,这是就要把父亲这个节点转变成儿子了,也就是我们要处理的下一个树根多了一个特殊的儿子,就是原来的它的父亲,我们对这个父亲,转变成儿子,要传两个信息下去,它能传下去的最长的l(就是链)的值,和他传下去的最长的t(就是将它变成儿子后,能得到的新的最长路值),而这两个信息,又要根据之前预处理的那些东西进行计算了,链的值就是从能取的最长的链里取一个最长的+1,最长路的值,就是根据能取的儿子的最长路里的值,以及能去的最长的两条边拼出的最大值(这就是为什么要处理三条最长的链了,因为最长的链有可能来自于我要转换成根的那个儿子,而这是不能取的)。转变了树的结构后,我们对于新的根节点,又可以计算它下面的每条边能计算的v值,一直递归,计算出每条边的v值就好了。

挫代码一份,不忍直视:

#pragma comment(linker, "/STACK:1024000000,1024000000")#include<stdio.h>#include<string.h>#include<algorithm>#define ll __int64using namespace std ;const int maxn = 111111 ;struct Node {    int l1 , l2 , l3 , f1 , f2 , f3 ;    int t1 , t2 , g1 , g2 ;    void init () {        l1 = l2 = l3 = f1 = f2 = f3 = t1 = t2 = g1 = g2 = 0 ;    }} p[maxn] ;struct Edge {    int t , next , v , id ;} edge[maxn<<1] ;int head[maxn] , tot , fuck ;ll dp[maxn] , ans ;void new_edge ( int a , int b , int c ) {    edge[tot].t = b ;    edge[tot].v = c ;    edge[tot].next = head[a] ;    edge[tot].id =     head[a] = tot ++ ;    edge[tot].t = a ;    edge[tot].v = c ;    edge[tot].next = head[b] ;    head[b] = tot ++ ;}void init ( int n ) {    int i ;    tot = 0 ;    for ( i = 0 ; i <= n ; i ++ ) head[i] = -1 ;    for ( i = 0 ; i <= n ; i ++ ) {        dp[i] = 0 ;        p[i].init () ;    }    ans = (ll) 1111111 * 111111 ;}void cal ( int u , int fa ) {    int i , j , k ;    for ( i = head[u] ; i != -1 ; i = edge[i].next ) {        int v = edge[i].t ;        if ( v == fa ) continue ;        cal ( v , u ) ;        if ( p[v].l1 + 1 >= p[u].l1 ) {            p[u].l3 = p[u].l2 ;            p[u].f3 = p[u].f2 ;            p[u].l2 = p[u].l1 ;            p[u].f2 = p[u].f1 ;            p[u].l1 = p[v].l1 + 1 ;            p[u].f1 = v ;        }        else if ( p[v].l1 + 1 >= p[u].l2 ) {            p[u].l3 = p[u].l2 ;            p[u].f3 = p[u].f2 ;            p[u].l2 = p[v].l1 + 1 ;            p[u].f2 = v ;        }        else if ( p[v].l1 + 1 >= p[u].l3 ) {            p[u].l3 = p[v].l1 + 1 ;            p[u].f3 = v ;        }        if ( dp[v] >= p[u].t1 ) {            p[u].t2 = p[u].t1 ;            p[u].g2 = p[u].g1 ;            p[u].t1 = dp[v] ;            p[u].g1 = v ;        }        else if ( dp[v] >= p[u].t2 ) {            p[u].t2 = dp[v] ;            p[u].g2 = v ;        }    }    dp[u] = max ( p[u].t1 , p[u].l1 + p[u].l2 ) ;}void dfs ( int u , int fa , int fv , int ft ) {    int i ;    ll k ;    for ( i = head[u] ; i != -1 ; i = edge[i].next ) {        int v = edge[i].t ;        if ( v == fa ) continue ;        int d1 = dp[v] ;        int ak[10] , T = 0 ;        if ( p[u].f1 != v ) ak[++T] = p[u].l1 ;        if ( p[u].f2 != v ) ak[++T] = p[u].l2 ;        if ( p[u].f3 != v ) ak[++T] = p[u].l3 ;        ak[++T] = fv ;        sort ( ak + 1 , ak + T + 1 ) ;        int d2 = ak[T] + ak[T-1] ;        d2 = max ( d2 , ft ) ;        if ( v != p[u].g1 ) d2 = max ( d2 , p[u].t1 ) ;        else d2 = max ( d2 , p[u].t2 ) ;        k = (ll) edge[i].v * ( max ( d1 , d2 ) ) ;        if ( k <= ans ) {            if ( k < ans ) ans = k , fuck = i / 2 + 1 ;            else if ( i / 2 + 1 < fuck ) fuck = i / 2 + 1 ;        }        int fvv = fv + 1 , ftt = ft ;        T = 0 ;        if ( p[u].f1 != v ) ak[++T] = p[u].l1 ;        if ( p[u].f2 != v ) ak[++T] = p[u].l2 ;        if ( p[u].f3 != v ) ak[++T] = p[u].l3 ;        sort ( ak + 1 , ak + T + 1 ) ;        fvv = max ( fvv , ak[T] + 1 ) ;        if ( v != p[u].g1 ) ftt = max ( ftt , p[u].t1 ) ;        else ftt = max ( ftt , p[u].t2 ) ;        ak[++T] = fv ;        sort ( ak + 1 , ak + T + 1 ) ;        ftt = max ( ftt , ak[T] + ak[T-1] ) ;        dfs ( v , u , fvv , ftt ) ;    }}int main () {    int cas , ca = 0 ;    int i , j , n , m , a , b , c ;    ll k ;    scanf ( "%d" , &cas ) ;    while ( cas -- ) {        scanf ( "%d" , &n ) ;        init ( n ) ;        for ( i = 1 ; i < n ; i ++ ) {            scanf ( "%d%d%d" , &a , &b , &c ) ;            new_edge ( a , b , c ) ;        }        cal ( 1 , 0 ) ;        for ( i = head[1] ; i != -1 ; i = edge[i].next ) {            int v = edge[i].t ;            int d1 = dp[v] , d2 = p[1].t1 ;            if ( v == p[1].g1 ) d2 = p[1].t2 ;            if ( v == p[1].f1 ) d2 = max ( d2 , p[1].l2 + p[1].l3 ) ;            else if ( v == p[1].f2 ) d2 = max ( d2 , p[1].l1 + p[1].l3 ) ;            else d2 = max ( d2 , p[1].l1 + p[1].l2 ) ;            k = (ll) edge[i].v * ( max ( d1 , d2 ) ) ;            if ( k <= ans ) {                if ( k < ans ) ans = k , fuck = i / 2 + 1 ;                else if ( i / 2 + 1 < fuck ) fuck = i / 2 + 1 ;            }            int fv , ft ;            fv = p[1].l1 + 1 ;            if ( v == p[1].f1 ) fv = p[1].l2 + 1 ;            ft = p[1].t1 ;            if ( v == p[1].g1 ) ft = p[1].t2 ;            int ak[10] , T = 0 ;            if ( p[1].f1 != v ) ak[++T] = p[1].l1 ;            if ( p[1].f2 != v ) ak[++T] = p[1].l2 ;            if ( p[1].f3 != v ) ak[++T] = p[1].l3 ;            sort ( ak + 1 , ak + T + 1 ) ;            ft = max ( ft , ak[T] + ak[T-1] ) ;            dfs ( v , 1 , fv , ft ) ;        }        printf ( "Case #%d: %d\n" , ++ ca , fuck ) ;    }}