HDU 5293 Tree chain problem(数链问题)【LCA+树形dp+dfs序+树状数组】

来源:互联网 发布:家用网络监控 编辑:程序博客网 时间:2024/06/02 02:30

话说这题很骚

原题链接如下:
http://acm.hdu.edu.cn/showproblem.php?pid=5293

本题大意如下:

Tree chain problem

Problem Description

有一棵树,树的节点编号为1,2,…,n。
树上有m条路,每条路有1个权值,现在要从这些路径中选一些,选出的路径不能有公共点。
求这些路径的权值和最大是多少

Input

第一行,一个整数T,表示有T组测试数据(T<=10)
对于每一组数据,第一行两个整数n和m,分别表示树上的节点数和路径数(1<=ai,bi<=n)
接下来n-1行,每行两个整数ai,bi,表示在结点ai和bi之间有一条边(1<=ai,bi<=n)
接下来m行,每行三个整数u,v和val,表示在结点u和v之间有一条路径,权值为val(1<=u,v<=n,1 <=val<=1000)

Output

一个整数,表示所选路径的最大权值和

Sample Input

1
7 3
1 2
1 3
2 4
2 5
3 6
3 7
2 3 4
4 5 3
6 7 3

Sample Output

6

Hint

话说这题真的骚到爆炸。。。
在做这题之前dfs序是什么我还真不知道
而在前不久学的我认为很可能被我遗弃很长时间的树状数组竟然也被用上了0.0

不过在做这一题之前,我们先来A掉简化版(博主在内部考试碰到的题,AC以后被老师怂恿去做这个进阶题。。)
简化版问题和原题没有什么大差别,只是每一条路径上是没有权值的,最后输出能用上的路径条数最大值
(你可以把它当做每条边上的权值均为1)
虽说这两题看起来差不多,但是难度差距极大,方法也不一样。。
对于这一个简化题,首先我们从这棵树是一条链的情况开始考虑
那么,你是不是发现了什么,没错,这样再次简化了问题以后这题就成为了一道极为经典的简单贪心题。
(USACO 2014 January Silver 录制比赛)
做这种题的时候我们一般都是根据右端点从小到大排序的
那么对于这一题树状的结构我们该怎么排序呢?
首先为了解题的简单,我们要保证已经使用的路径不会被撤销,也就是说要保证这种情况已经是最优的
那么类比链上的解法很容易想到LCA,用LCA的深度从大到小排序可以保证最优,因为两条路径若是冲突,选用LCA深度较大的显然可以使其他的路径有更大的选择空间
排序完毕以后就要开始贪心了,我们可以知道当以某一点x为lca的一条路径(即链)被使用过以后,以点x为根的子树中的点不会也不能被使用到,于是我们可以把这一棵子树上的所有点都mark掉,而之后判断某一条链能不能使用只需要判断这一条链的两个顶点有没有被mark过就行了(因为若一个点被mark过,他的所有子节点一定都被mark过,而父亲结点却不一定);
本简化版题目代码如下:
(PS:这不是 Tree chain problem 的代码)

#include<stdio.h>#include<vector>#include<algorithm>#include<iostream>using namespace std;#define M 100005vector<int>edge[M];//边vector<int>BIGedge[M];//链int dis[M],fa[M],tfa[M],mark[M];//dis[i]代表结点i的深度//fa[i]代表结点i的以返回的最大父亲,这是用来求LCA的//tfa[i]代表结点i的直接父亲struct node{    int st,ed,lca;    bool operator <(const node &A)const{        return dis[lca]>dis[A.lca];//排序。。。。    }}s[M];int getfa(int x){    if(x==fa[x])return x;    return fa[x]=getfa(fa[x]);//路径压缩}void dfs(int x,int pre,int d){    dis[x]=d;//标记深度    for(int i=0;i<edge[x].size();i++){        int y=edge[x][i];        if(y==pre)continue;        dfs(y,x,d+1);        fa[y]=x;        tfa[y]=x;    }for(int i=0;i<BIGedge[x].size();i++){        int id=BIGedge[x][i];        if(x==s[id].st&&dis[s[id].ed])s[id].lca=getfa(s[id].ed);        if(x==s[id].ed&&dis[s[id].st])s[id].lca=getfa(s[id].st);        //若x为这条边的某一个顶点a且另一个顶点b已经被遍历过,则LCA为b的已返回的最大父亲        //也就是说,遍历a的时候是从哪一点往深处dfs到达b点的,LCA就是那一个点        //这是一个求LCA比较好的办法    }}void Mark(int x){    mark[x]=1;    for(int i=0;i<edge[x].size();i++){        int y=edge[x][i];        if(dis[y]>dis[x])Mark(y);//不要往父亲结点上走。。。    }//mark}int main(){    int n,m;    scanf("%d %d",&n,&m);    for(int i=1;i<n;i++){        int a,b;        scanf("%d %d",&a,&b);        edge[a].push_back(b);        edge[b].push_back(a);//造树    }for(int i=1;i<=m;i++){        scanf("%d %d",&s[i].st,&s[i].ed);        BIGedge[s[i].st].push_back(i);        BIGedge[s[i].ed].push_back(i);//加入链    }for(int i=1;i<=n;i++)fa[i]=i;    dfs(1,0,1);//求LCA    sort(s+1,s+m+1);    int ans=0;    for(int i=1;i<=m;i++){        int a=s[i].st,b=s[i].ed;        if(mark[a]||mark[b])continue;//不要这条链        ans++;//要这条链        Mark(s[i].lca);    }printf("%d\n",ans);    return 0;}

AC了简化版以后我们来思考一下原题,原题显然是不能和简化版一样用贪心的,但是我们想到简化版的题其实还有一种树形dp的做法(这里就不敲了,其实和原题代码差不多)
我们想,这题是不是也可以用树形dp写呢?
若dp[i]代表在根结点为i的子树上的最大权值和,
sum[i]代表i的所有子结点的dp值的和,即sum[i]=Σdp[q];(q为i的子结点)
那么dp[i]的状态转移公式,有两种可能,第一种:
结点i上不出现链,那么 dp[i] = ∑dp[q] (q为i的子节点)=sum[i];
第二种:
结点i上出现链,如果选择加入这条链,那么
dp[i]
=val(链的权值) + ∑dp[p] (p为链上结点的子结点,)
=val + ∑sum[k] - ∑dp[k] (k为链上的结点)
于是乎——
dp[i]
=max(∑sum[k],val + ∑sum[k] - ∑dp[k]);
这个时候我们的工作重心就转移到了提取链上的sum值上了
很显然,直接暴力找链上的点的复杂度是O(n^2),会超时。。。
于是我们就要想办法使用数据结构去维护这个值。。
这个过程没有什么捷径了。。
把自己学过的数据结构在脑子里全部过一遍,先排除掉一些
最后实在不知道可不可行就都敲一遍吧
就算最后敲不出来也不亏。。。
这个题目用树状数组和线段树都是可以的
这时候问题就来了,
树状数组虽然是树状的,但存储的仍旧是线性的信息
于是我们要把这棵树变成线性
于是就用到了dfs序
任意一颗子树在dfs序中都是连续的,于是很容易更新和查询
然后一堆错误,
一顿调试
一通对拍
最终AC。。。。
过程很艰难。。。。。。。
虽然我的代码是O(nlogn)的,但是由于系数比较大,跑了1500多毫秒。。
幸好这题时限比较大有3000ms
这题似乎也可以用数链剖分写,但是我不会

AC代码如下:

#include<stdio.h>#include<string.h>#include<vector>#include<algorithm>#include<iostream>using namespace std;#define M 100005vector<int>edge[M];//边vector<int>BIGedge[M];//链vector<int>Chain[M];//Chian[x]表示以点x为LCA的链的集合int dis[M],fa[M];int sum[2*M],dp[2*M];int l[M],r[M];//dfs序int c1[2*M],c2[2*M];//树状数组//c1存储sum值,c2存储dp值(其实完全可以一起存)struct node{    int st,ed,lca,val;}s[M];void Rd(int &res){    char c;res=0;    while(c=getchar(),!isdigit(c));    do{        res=(res<<3)+(res<<1)+(c^48);    }while(c=getchar(),isdigit(c));}int getfa(int x){    if(x==fa[x])return x;    return fa[x]=getfa(fa[x]);}int tot=0;void dfs(int x,int pre,int d){    l[x]=++tot;//造dfs序    dis[x]=d;    for(int i=0;i<edge[x].size();i++){        int y=edge[x][i];        if(y==pre)continue;        dfs(y,x,d+1);        fa[y]=x;    }for(int i=0;i<BIGedge[x].size();i++){//求LCA,在简化版题目的代码中已解释        int id=BIGedge[x][i];        if(x==s[id].st&&dis[s[id].ed])s[id].lca=getfa(s[id].ed);        if(x==s[id].ed&&dis[s[id].st])s[id].lca=getfa(s[id].st);    }r[x]=tot;}void init(int n){//孩子,有多组数据,别忘了清空数组,之前一个名为zqh_wz的编程王者就挂在这上面了    tot=0for(int i=1;i<=n;i++){edge[i].clear();BIGedge[i].clear();Chain[i].clear();}    memset(dp,0,sizeof(dp));    memset(sum,0,sizeof(sum));    memset(c1,0,sizeof(c1));    memset(c2,0,sizeof(c2));    memset(l,0,sizeof(l));    memset(r,0,sizeof(r));}int lowbit(int x){//树状数组不要告诉我你不会      return x&-x;}  void add(int i,int k,int *c,int n){//*c传入数组的地址    while(i<=n){        c[i]+=k;        i+=lowbit(i);      }  }//更新sum,dp值int getsum(int i,int *c){      int res=0;      while(i){         res+=c[i];          i-=lowbit(i);      }return res;}//导出sum,dp值void Treedp(int x,int pre,int n){//最终环节,树形dp    for(int i=0;i<edge[x].size();i++){        int y=edge[x][i];        if(y==pre)continue;        Treedp(y,x,n);        sum[x]+=dp[y];//sum[x]=Σdp[y](y为x的子结点),用树形dp值,更新树形sum值    }dp[x]=sum[x];    for(int i=0;i<Chain[x].size();i++){          int a=s[Chain[x][i]].st,b=s[Chain[x][i]].ed;//找链        int tmp=getsum(l[a],c1)+getsum(l[b],c1)-getsum(l[a],c2)-getsum(l[b],c2)+sum[x];//取出线性sum值和dp值          dp[x]=max(dp[x],tmp+s[Chain[x][i]].val); //用来更新树形的dp值    }add(l[x],sum[x],c1,n*2); //一大波更新还有30秒到达战场(更新线性dp、sum值)    add(r[x],-sum[x],c1,n*2);      add(l[x],dp[x],c2,n*2);      add(r[x],-dp[x],c2,n*2);  }int main(){    int Cas;    scanf("%d",&Cas);    while(Cas--){        int n,m;        scanf("%d %d",&n,&m);        init(n);        for(int i=1;i<n;i++){            int a,b;            Rd(a);Rd(b);            edge[a].push_back(b);            edge[b].push_back(a);        }for(int i=1;i<=m;i++){            Rd(s[i].st);Rd(s[i].ed);Rd(s[i].val);            BIGedge[s[i].st].push_back(i);            BIGedge[s[i].ed].push_back(i);        }for(int i=1;i<=n;i++)fa[i]=i;        dfs(1,0,1);//求lCA,dfs序        for(int i=1;i<=m;i++){            Chain[s[i].lca].push_back(i);//加链        }Treedp(1,0,n);//solve环节        printf("%d\n",dp[1]);输出根节点的dp值即可    }return 0;}
2 0
原创粉丝点击