poj 3782 LCA+树链剖分

来源:互联网 发布:淘宝上买东西怎样付款 编辑:程序博客网 时间:2024/05/02 04:26

题意很简单,我们很容易求得两个点的LCA,至于求完之后呢,我们用树链剖分来维护这个值,线段树里的元素有元素最大值,元素最小值,正向答案(后面的减前面的最大值)以及逆向答案,注意,树链剖分转移时一定要记录两个区间段之间最大最小值之差。具体的方案可以参考代码。

#include<cstdio>#include<cstring>#include<algorithm>#include<iostream>#define maxn 100005using namespace std;int top[maxn],fa[maxn],size[maxn],son[maxn],id[maxn],dep[maxn];int head[maxn],a[maxn],num,tot,b[maxn];void init(void){    memset(head,-1,sizeof(head));    num=0;    tot=0;}struct ppi{    int to;    int next;}pp1[maxn<<1];struct pi{    int le;    int ri;    int sum1,sum2;    int max;    int min;}pp[maxn<<2];void add(int u,int v){    pp1[tot].to=v;    pp1[tot].next=head[u];    head[u]=tot++;    pp1[tot].to=u;    pp1[tot].next=head[v];    head[v]=tot++;}void build(int tot,int l,int r){    pp[tot].le=l;    pp[tot].ri=r;    pp[tot].sum1=0;    pp[tot].sum2=0;    if(l==r){        pp[tot].max=b[l];        pp[tot].min=b[l];        return ;    }    build(2*tot,l,(l+r)/2);    build(2*tot+1,(l+r)/2+1,r);    pp[tot].sum1=max(pp[2*tot].sum1,pp[2*tot+1].sum1);    pp[tot].sum1=max(pp[tot].sum1,pp[2*tot+1].max-pp[2*tot].min);    pp[tot].sum2=max(pp[2*tot].sum2,pp[2*tot+1].sum2);    pp[tot].sum2=max(pp[tot].sum2,pp[2*tot].max-pp[2*tot+1].min);    pp[tot].max=max(pp[2*tot].max,pp[2*tot+1].max);    pp[tot].min=min(pp[2*tot].min,pp[2*tot+1].min);}void dfs1(int u,int pa,int d){    dep[u]=d;    fa[u]=pa;    size[u]=1;    son[u]=0;    int k,v;    k=head[u];    while(k!=-1){        v=pp1[k].to;        if(v!=pa){            dfs1(v,u,d+1);            size[u]+=size[v];            if(size[son[u]]<size[v]) son[u]=v;        }        k=pp1[k].next;    }}void dfs2(int u,int pa,int tp){    top[u]=tp;    id[u]=++num;    if(son[u]) dfs2(son[u],u,tp);    int k,v;    k=head[u];    while(k!=-1){        v=pp1[k].to;        if(v!=pa&&v!=son[u]){            dfs2(v,u,v);        }        k=pp1[k].next;    }}int query1(int tot,int l,int r){    if(pp[tot].le>=l&&pp[tot].ri<=r){        return pp[tot].max;    }    int s=0;    int mid=(pp[tot].le+pp[tot].ri)/2;    if(l<=mid) s=max(s,query1(2*tot,l,r));    if(r>mid)   s=max(s,query1(2*tot+1,l,r));    return s;}int query2(int tot,int l,int r){    if(pp[tot].le>=l&&pp[tot].ri<=r){        return pp[tot].min;    }    int s=1000000000;    int mid=(pp[tot].le+pp[tot].ri)/2;    if(l<=mid) s=min(s,query2(2*tot,l,r));    if(r>mid)   s=min(s,query2(2*tot+1,l,r));    return s;}int query(int tot,int l,int r,int p){    if(pp[tot].le>=l&&pp[tot].ri<=r){        if(p==0) return pp[tot].sum1;        return pp[tot].sum2;    }    int s=0;    int mid=(pp[tot].le+pp[tot].ri)/2;    if(l<=mid) s=max(s,query(2*tot,l,r,p));    if(r>mid) s=max(s,query(2*tot+1,l,r,p));    if(l<=mid&&r>mid){        if(p==0){            s=max(s,query1(2*tot+1,l,r)-query2(2*tot,l,r));        }        else{            s=max(s,query1(2*tot,l,r)-query2(2*tot+1,l,r));        }    }    return s;}int get(int u,int v,int p){    int to1,s=0;    to1=top[u];    int m1=1000000000,m2=0;    while(dep[to1]>dep[v]){        s=max(s,query(1,id[to1],id[u],p));        if(p==0){            s=max(s,m2-query2(1,id[to1],id[u]));            m2=max(m2,query1(1,id[to1],id[u]));        }        else{            s=max(s,query1(1,id[to1],id[u])-m1);            m1=min(m1,query2(1,id[to1],id[u]));        }        u=fa[to1];        to1=top[u];    }    to1=v;    s=max(s,query(1,id[to1],id[u],p));    if(p==0){        s=max(s,m2-query2(1,id[to1],id[u]));        m2=max(m2,query1(1,id[to1],id[u]));    }    else{        s=max(s,query1(1,id[to1],id[u])-m1);        m1=min(m1,query2(1,id[to1],id[u]));    }    u=fa[to1];    to1=top[u];    return s;}int get1(int u,int v){    int s=0;    int to1=top[u];    while(dep[to1]>dep[v]){        s=max(s,query1(1,id[to1],id[u]));        u=fa[to1];        to1=top[u];    }    to1=v;    s=max(s,query1(1,id[to1],id[u]));    return s;}int get2(int u,int v){    int s=1000000000;    int to1=top[u];    while(dep[to1]>dep[v]){        s=min(s,query2(1,id[to1],id[u]));        u=fa[to1];        to1=top[u];    }    to1=v;    s=min(s,query2(1,id[to1],id[u]));    return s;}int deap[maxn],vis[maxn],dis[maxn],kk[22][maxn];int maxlog;struct pppi{    int to;    int cost;}pp2;void init(int v,int p,int d){    vis[v]=1;    deap[v]=d;    int i;    kk[0][v]=p;    for(i=1;i<maxlog;i++)    {        if(kk[i-1][v]<0)            kk[i][v]=-1;        else        {            kk[i][v]=kk[i-1][kk[i-1][v]];        }    }    for(i=head[v];i!=-1;i=pp1[i].next)    {        if(!vis[pp1[i].to]){            init(pp1[i].to,v,d+1);        }    }    return ;}int find(int a,int b){    if(deap[a]>deap[b])        swap(a,b);    int i,f;    f=deap[b]-deap[a];    for(i=0;i<maxlog;i++)    {        if((f>>i)&1)            b=kk[i][b];    }    if(b==a)        return a;    for(i=maxlog-1;i>=0;i--)    {        if(kk[i][a]!=kk[i][b])        {            a=kk[i][a];            b=kk[i][b];        }    }    return kk[0][a];}int main(){    int i,k,n,m,p;    while(cin>>n){        init();        maxlog=20;        for(i=1;i<=n;i++) scanf("%d",&a[i]);        for(i=1;i<n;i++){            scanf("%d%d",&p,&k);            add(p,k);        }        dfs1(1,1,1);        dfs2(1,1,1);        for(i=1;i<=n;i++){            b[id[i]]=a[i];        }        build(1,1,n);        memset(vis,0,sizeof(vis));        init(1,-1,0);        scanf("%d",&m);        for(i=0;i<m;i++){            int a,b;            scanf("%d%d",&a,&b);            p=find(a,b);            int s=0;            s=max(s,get(a,p,1));            s=max(s,get(b,p,0));            s=max(s,get1(b,p)-get2(a,p));            printf("%d\n",s);        }    }}


0 0
原创粉丝点击