bzoj1912 [Apio2010]patrol 巡逻(树的直径[变式])

来源:互联网 发布:祖马龙 知乎 编辑:程序博客网 时间:2024/05/17 19:21

Description
这里写图片描述

Input
第一行包含两个整数 n, K(1 ≤ K ≤ 2)。接下来 n – 1行,每行两个整数 a, b, 表示村庄a与b之间有一条道路(1 ≤ a, b ≤ n)。

Output
输出一个整数,表示新建了K 条道路后能达到的最小巡逻距离。

Sample Input
8 1
1 2
3 1
3 4
5 3
7 5
8 5
5 6

Sample Output
11

HINT
10%的数据中,n ≤ 1000, K = 1;
30%的数据中,K = 1;
80%的数据中,每个村庄相邻的村庄数不超过 25;
90%的数据中,每个村庄相邻的村庄数不超过 150;
100%的数据中,3 ≤ n ≤ 100,000, 1 ≤ K ≤ 2。

分析:
先考虑K==1的情况
加了一条边之后,图就变成了一棵环套树
显然我们如果要让巡逻距离尽可能短,那么就要使环上的边尽量多,
(因为不在环上的边都要走两遍
怎么让环大呢,
就是让添加边变成环之前的那条链尽可能长
这就是树上最长链

树的直径

求法:
两遍dfs,
第一遍dfs找到一个最远点,再从最远点dfs,最后得出的最长dis就是树的直径

那么K==1的时候就是求一个直径dis
ans=2*(n-1-dis)+dis+1

K==2时,就是在K==1求出解得基础上,求一个次长直径

注意

如果我们什么处理都没有,直接求一个次长链(次短路方法),
可能会和最长链重合,那么最长链上的一部分就会走两遍
所以我们在求出最长链之后,把最长链上的边权赋为-1,
这样再跑一个裸的直径就好了
(这样就可以保证可以在新求出的直径中尽量少重合原先的直径)

tip

代码中我用了两种求直径的方法
注意dp返回值是当前点的f值
然而答案要单独统计

这里写代码片#include<cstdio>#include<cstring>#include<iostream>using namespace std;const int N=100010;struct node{    int x,y,nxt,v;};node way[N<<1];int st[N],tot=-1,n,K;int len1,len2,pre[N],ansx,ans;int f[N],g[N];void add(int u,int w,int z){    tot++;    way[tot].x=u;way[tot].y=w;way[tot].v=z;way[tot].nxt=st[u];st[u]=tot;    tot++;    way[tot].x=w;way[tot].y=u;way[tot].v=z;way[tot].nxt=st[w];st[w]=tot;}void dfs(int now,int fa,int dis){    if (dis>ans)    {        ans=dis;        ansx=now;    }    for (int i=st[now];i!=-1;i=way[i].nxt)        if (way[i].y!=fa)        {            pre[way[i].y]=i;            dfs(way[i].y,now,dis+way[i].v);        }}void change(int s,int t){    for (int i=t;i!=s;i=way[pre[i]].x)    {        way[pre[i]].v=-1;        way[pre[i]^1].v=-1;    }}int dp(int now,int fa){    f[now]=0;g[now]=0;    for (int i=st[now];i!=-1;i=way[i].nxt)        if (way[i].y!=fa)        {            int len=dp(way[i].y,now)+way[i].v;            if (len>f[now])            {                g[now]=f[now];                f[now]=len;            }            else if (len>g[now]) g[now]=len;        }    len2=max(len2,f[now]+g[now]);   //统计答案     return f[now];  //dp返回的是最长链 }int main(){    memset(st,-1,sizeof(st));    scanf("%d%d",&n,&K);    for (int i=1;i<n;i++)    {        int u,w;        scanf("%d%d",&u,&w);        add(u,w,1);    }    memset(pre,-1,sizeof(pre));    ans=ansx=0;    dfs(1,0,0);    int p=ansx;    ans=ansx=0;    memset(pre,-1,sizeof(pre));    dfs(p,0,0);    if (K==1)    {        printf("%d",2*(n-1-ans)+ans+1);        return 0;    }    len1=ans;    change(p,ansx);    dp(1,0);    printf("%d\n",2*(n-1-len1-len2)+len1+len2+2);    return 0;}