【JZOJ 4587】 Snow的追寻

来源:互联网 发布:教务网络管理系统郑铁 编辑:程序博客网 时间:2024/05/20 14:44

Description

这里写图片描述
这里写图片描述

Analysis

此题本人跑得最快,rank1了233
求出树的欧拉序。顺便用序弄出rmq快速求lca。
我们知道,一个节点在序中表示的是一段区间。而题目询问的是一些树的直径。
我们可以用线段树维护区间表示的节点内的直径,可是怎么合并?

直径性质

两棵树用一条边合并,新树直径两端一定是原本两棵树直径四个端点中的两个。

具体证明可以看crazy的课件。
所以,按上述方法合并区间,对于两个不能走的子树,就不合并之。
然后这样合并完就是询问的直径啦。

Code

#include<cstdio>#include<cmath>#include<cstring>#include<algorithm>#define fo(i,a,b) for(int i=a;i<=b;i++)#define efo(i,v) for(int i=last[v];i;i=next[i])using namespace std;const int N=100010,M=N*2;int n,m,tot,to[M],next[M],last[N],dep[N],fir[N],las[N],rmq[M][20];struct segment{    int u,v,l;}tr[M*4],ans;struct node{    int d,v;}a[M];void link(int u,int v){    to[++tot]=v,next[tot]=last[u],last[u]=tot;}int lca(int u,int v){    int l=fir[u],r=fir[v];    if(l>r) swap(l,r);    int len=log2(r-l+1);    int x=rmq[l][len],y=rmq[r-(1<<len)+1][len];    return a[x].d<a[y].d?a[x].v:a[y].v;}void dfs(int v,int from,int d){    a[++m].d=d,a[m].v=v,dep[v]=d;    efo(i,v)    {        int u=to[i];        if(u==from) continue;        dfs(u,v,d+1);        a[++m].d=d,a[m].v=v;    }}int dis(int u,int v){    return dep[u]+dep[v]-2*dep[lca(u,v)];}void merge(segment &c,segment a,segment b){    c=a.l>b.l?a:b;    int t=dis(a.u,b.u);    if(t>c.l) c.l=t,c.u=a.u,c.v=b.u;    t=dis(a.u,b.v);    if(t>c.l) c.l=t,c.u=a.u,c.v=b.v;    t=dis(a.v,b.u);    if(t>c.l) c.l=t,c.u=a.v,c.v=b.u;    t=dis(a.v,b.v);    if(t>c.l) c.l=t,c.u=a.v,c.v=b.v;}void build(int v,int l,int r){    if(l==r)    {        tr[v].u=tr[v].v=a[l].v,tr[v].l=0;        return;    }    int mid=(l+r)>>1;    build(v+v,l,mid);    build(v+v+1,mid+1,r);    merge(tr[v],tr[v+v],tr[v+v+1]);}void find(int v,int l,int r,int x,int y){    if(x>y) return;    if(l==x && r==y)    {        merge(ans,ans,tr[v]);        return;    }    int mid=(l+r)>>1;    if(y<=mid) find(v+v,l,mid,x,y);    else    if(x>mid) find(v+v+1,mid+1,r,x,y);    else    find(v+v,l,mid,x,mid),find(v+v+1,mid+1,r,mid+1,y);}int main(){    freopen("snow.in","r",stdin);    freopen("snow.out","w",stdout);    int _,u,v;    scanf("%d %d",&n,&_);    fo(i,1,n-1)    {        scanf("%d %d",&u,&v);        link(u,v),link(v,u);    }    dfs(1,0,0);    fo(i,1,m)    {        if(!fir[a[i].v]) fir[a[i].v]=i;        las[a[i].v]=i;        rmq[i][0]=i;    }    fo(j,1,int(log2(m)))        fo(i,1,m-(1<<j)+1)        {            int x=rmq[i][j-1],y=rmq[i+(1<<(j-1))][j-1];            rmq[i][j]=a[x].d<a[y].d?x:y;        }    build(1,1,m);    while(_--)    {        scanf("%d %d",&u,&v);        ans.u=ans.v=ans.l=0;        int x1=fir[u],y1=las[u],x2=fir[v],y2=las[v];        find(1,1,m,1,min(x1,x2)-1);        find(1,1,m,max(y1,y2)+1,m);        if(x2>y1) find(1,1,m,y1+1,x2-1);        if(x1>y2) find(1,1,m,y2+1,x1-1);        printf("%d\n",ans.l);    }    return 0;}
0 0