[二分 树形DP] CEOI2017. Mousetrap

来源:互联网 发布:访问服务器8080端口 编辑:程序博客网 时间:2024/06/13 00:51

第一次做这种DP题

t 为根,那么老鼠的决策肯定是先往上走一段(或不走),再往子树中走。

如果老鼠往子树中走,我们肯定是等它走到某个位置不能走了,然后把他当前在的位置当根节点这段路上的其他支路都封死,这样是最优的。

那么我们可以树形DP出老鼠走到以这个点为根的子树的时候,需要的最少步数。

wi 表示走到这个节点为根的子树时候的最小步数( wi 要算上从这个点到根路径上的支路的数量)

那么 wi=secmaxusoni{wu} 其中 secmax{} 是次大值,因为可以把通往 wi 最大的儿子的路封上,老鼠就会走向次大儿子。

那么可以二分答案,然后从 m 点开始往上扫,这就相当于老鼠先往上走一段路,如果一个儿子的 wi 大于 mid ,那么这个点就要封掉,因为老鼠只要往这个节点走,你的步数就会大于 mid ,如果在老鼠走到这个点之前,你不能把这条路上的所有不合法的支路都封掉,那么答案也会大于 mid

这样复杂度就是 O(nlogn)

#include <cstdio>#include <iostream>#include <algorithm>using namespace std;const int N=1000010;int n,m,t,cnt,G[N],w[N],d[N];struct edge{    int t,nx;}E[N<<1];inline char nc(){    static char buf[100000],*p1=buf,*p2=buf;    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;}inline void rea(int &x){    char c=nc(); x=0;    for(;c>'9'||c<'0';c=nc());for(;c>='0'&&c<='9';x=x*10+c-'0',c=nc());}inline void addedge(int x,int y){    E[++cnt].t=y; E[cnt].nx=G[x]; G[x]=cnt; d[x]++;    E[++cnt].t=x; E[cnt].nx=G[y]; G[y]=cnt; d[y]++;}int fa[N],dpt[N];int Q[N],qt;void dfs(int x,int f){    fa[x]=f; dpt[x]=dpt[f]+1;    for(int i=G[x];i;i=E[i].nx)        if(E[i].t!=f) dfs(E[i].t,x);}void calc(int x,int y){    for(int i=G[x];i;i=E[i].nx)        if(E[i].t!=fa[x]) calc(E[i].t,y+d[x]-1);    if(d[x]==1) w[x]=y;    else if(d[x]==2) w[x]=y+1;    else{        int imax=0,sec=0;        for(int i=G[x];i;i=E[i].nx)            if(E[i].t!=fa[x])                if(w[E[i].t]>imax)                    sec=imax,imax=w[E[i].t];                else                    sec=max(sec,w[E[i].t]);        w[x]=sec;    }}int vis[N];inline bool check(int X){    int lst=0;    for(int k=1;k<=qt;k++){        int x=Q[k],cur=0;        for(int i=G[x];i;i=E[i].nx)            if(E[i].t!=fa[x] && w[E[i].t]+lst>X && !vis[E[i].t])                cur++;        lst+=cur;        if(lst>k || lst>X) return 0;    }    return 1;}int main(){    rea(n); rea(t); rea(m);    for(int i=1,x,y;i<n;i++)        rea(x),rea(y),addedge(x,y);    dfs(t,0);     for(int u=m;u!=t;u=fa[u]) Q[++qt]=u,vis[u]=1;    for(int k=qt,cur=0;k;k--){        int x=Q[k]; cur+=d[x]-2;        for(int i=G[x];i;i=E[i].nx)            if(E[i].t!=fa[x] && !vis[E[i].t]) calc(E[i].t,cur+(x==m));    }    int L=0,R=n,mid,ans;    while(L<=R)         check(mid=L+R>>1)?R=(ans=mid)-1:L=mid+1;    printf("%d\n",ans);    return 0;}