Snow的追寻

来源:互联网 发布:网络掘金1000例 刘铭火 编辑:程序博客网 时间:2024/05/09 14:05

Description

给出一棵有根树,1为根。
给出q次询问,每次询问x,y表示除x,y为根的子树外,剩下的树的直径的长度。
n,q<=10^5

Solution

既然和子树有关,那么我们就维护树的dfs序。
然后每个区间维护直径的长度。用线段树,同51nod1766树上的最远点对.
那么不能用x,y为根的子树就是不能用某两个区间。这样就把原序列分成了最多三个区间,合并起来就好了。

Code

#include<cmath>#include<cstdio>#include<cstring>#include<algorithm>#define fo(i,a,b) for(int i=a;i<=b;i++)#define fd(i,a,b) for(int i=a;i>=b;i--)#define rep(i,a) for(int i=last[a];i;i=next[i])#define N 100005using namespace std;struct note{int a,b;}tr[N*5],p[2],tmp;bool cmp(note x,note y) {return x.a<y.a;}int n,m,x,y,l,tot,top,c[N],d[N],dfn[N],q[N*2],fir[N],size[N],f[N*2][18];int t[N*2],next[N*2],last[N],mi[18];void add(int x,int y) {    t[++l]=y;next[l]=last[x];last[x]=l;}void dfs(int x,int y) {    dfn[x]=++tot;c[tot]=x;size[x]=1;d[x]=d[y]+1;q[++top]=x;fir[x]=top;    rep(i,x) if (t[i]!=y) dfs(t[i],x),size[x]+=size[t[i]],q[++top]=x;}int lca(int x,int y) {    x=fir[x];y=fir[y];    if (x>y) swap(x,y);    int z=log2(y-x+1);    if (d[q[f[x][z]]]<d[q[f[y-mi[z]+1][z]]]) return q[f[x][z]];    else return q[f[y-mi[z]+1][z]];}int len(int x,int y) {    int z=lca(x,y);    return d[x]+d[y]-2*d[z];}note merge(note y,note z) {    note x;int mx=0,l;    if (!(y.a+y.b)) return z;    l=len(y.a,y.b);if (l>mx) mx=l,x.a=y.a,x.b=y.b;    l=len(z.a,z.b);if (l>mx) mx=l,x.a=z.a,x.b=z.b;    l=len(y.a,z.a);if (l>mx) mx=l,x.a=y.a,x.b=z.a;    l=len(y.a,z.b);if (l>mx) mx=l,x.a=y.a,x.b=z.b;    l=len(y.b,z.a);if (l>mx) mx=l,x.a=y.b,x.b=z.a;    l=len(y.b,z.b);if (l>mx) mx=l,x.a=y.b,x.b=z.b;    return x;}void build(int v,int l,int r) {    if (l==r) {tr[v].a=tr[v].b=c[l];return;}    int m=(l+r)/2;    build(v*2,l,m);build(v*2+1,m+1,r);    tr[v]=merge(tr[v*2],tr[v*2+1]);}note find(int v,int l,int r,int x,int y) {    if (l==x&&r==y) return tr[v];    int m=(l+r)/2;    if (y<=m) return find(v*2,l,m,x,y);    else if (x>m) return find(v*2+1,m+1,r,x,y);    else return merge(find(v*2,l,m,x,m),find(v*2+1,m+1,r,m+1,y));}int main() {    freopen("snow.in","r",stdin);    freopen("snow.out","w",stdout);    scanf("%d%d",&n,&m);    fo(i,1,n-1) scanf("%d%d",&x,&y),    add(x,y),add(y,x);dfs(1,0);mi[0]=1;         fo(i,1,top) f[i][0]=i;    fo(i,1,log2(top)) mi[i]=mi[i-1]*2;    fo(j,1,log2(top))        fo(i,1,top-mi[j]+1)            if (d[q[f[i][j-1]]]<d[q[f[i+mi[j-1]][j-1]]]) f[i][j]=f[i][j-1];            else f[i][j]=f[i+mi[j-1]][j-1];    build(1,1,n);    for(;m;m--) {        scanf("%d%d",&x,&y);        if (x==1||y==1) {printf("0\n");continue;}        p[0].a=dfn[x];p[0].b=dfn[x]+size[x]-1;        p[1].a=dfn[y];p[1].b=dfn[y]+size[y]-1;        sort(p,p+2,cmp);tmp.a=tmp.b=0;        if (p[0].a>1) tmp=merge(tmp,find(1,1,n,1,p[0].a-1));        if (p[0].b+1<=p[1].a-1) tmp=merge(tmp,find(1,1,n,p[0].b+1,p[1].a-1));        int ri=max(p[0].b,p[1].b);        if (ri<n) tmp=merge(tmp,find(1,1,n,ri+1,n));        printf("%d\n",len(tmp.a,tmp.b));    }}
0 0
原创粉丝点击