POJ 1330 Nearest Common Ancestors(求LCA的三种方法)

来源:互联网 发布:安安电子狗软件 编辑:程序博客网 时间:2024/03/29 20:02

1.离线Tarjan

       设我们求点对(u,v)的最近公共祖先。

       利用在DFS过程中,从点u第一次到点v的过程中,必定是从u开始,经过u和v的最近公共祖先的S,然后到达v的,这时候u和v都是在S为根结点的子树里的。

       如果我们在访问v的时候,u已经被访问过了,这时候,如果我们知道u在以哪个结点S为根结点的子树里,我们就知道它们的公共祖先了,即S。借助并查集,在DFS过程中,我们每到达一个节点x,便创建一棵以x为根结点的子树(即在并查集中令fa[x]=x),将不断将它的子结点连同子结点的子结点……合并到这棵子树下(即借助并查集的并操作令fa[son of x]=x),注意必须在子结点结束DFS过程后才合并,因为我们查询的点对(u,v)也有可能在以结点sone of x为根的子树下。则我们访问到u或者v的时候,判断v或者u是否已经被访问过,如果被访问过,那么v或者u所在的子树的根结点S即是它们的公共祖先,查找v或u所在子树的根节点借助并查集的查操作(即可即find_fa(v或者u))。

#include <iostream>#include <cstring>#include <cstdio>using namespace std;const int N=10005;int q1,q2,res,fa[N],deg[N];int find_fa(int u){if(fa[u]==u) return u;return fa[u]=find_fa(fa[u]);}struct Edge{int to;Edge *next;}memo[N*2],*cur,*adj[N];void addEdge(Edge *head[],int u,int v){cur->to=v;cur->next=head[u];head[u]=cur++;}void tarjan(int u){fa[u]=u;for(Edge *it=adj[u];it;it=it->next){int to=it->to;tarjan(to);fa[to]=u;}if(u==q1||u==q2){if(u!=q1) swap(q1,q2);if(fa[q2]) res=find_fa(fa[q2]);}}void init(int n){for(int i=0;i<=n;i++){adj[i]=NULL;deg[i]=0;fa[i]=0;}cur=memo;}int main(){int t;scanf("%d",&t);while(t--){int n,u,v;scanf("%d",&n);init(n);for(int i=0;i<n-1;i++){scanf("%d%d",&u,&v);addEdge(adj,u,v);deg[v]++;}scanf("%d%d",&q1,&q2);for(int i=1;i<=n;i++) if(deg[i]==0) {tarjan(i);break;}printf("%d\n",res);}return 0;}

2.倍增法

       倍增法先利用一次dfs处理出每个结点的深度及其每个结点i的2^j层的祖先是谁——也可以表述成从结点i向上跳2*j层(即dp[i][j]

       对于询问u,v的公共祖先,假设u的深度比v小,先将u跳到与v同一层的某个u的祖先u1,这时候,如果u1与v是同一个点,那么它们的公共祖先就是这u1,即也是v。

       否则,这时候,我们就要在树上向上跳若干层,使得u和v第一次交汇,也即得到u的某个祖先u2和v的某个祖先v2,满足u2=v2,那么那个交汇的点就是它们的最近公共祖先。如果我们跳到这个交汇点的下一层的结点,设为u3和v3,即满足dp[u3][0]=u2。注意到,从从u跳到u3的过程中,始终满足dp[u’][i]!=dp[v’][i]。

       设u到u3的层数为x,那么这个x是可以表达成一个二进制数的,即x=2^k1+2^k2+..,即我们要从u达到u3要跳2^k1,2^k2…层——不就是我们处理出来的dp函数?那么,我们寻找满足dp[u’][k]!=dp[v’][k]的最大k值,然后跳到那一层(即令u’’=dp[u’][k],v’’=dp[v’][k]),然后再寻找满足dp[u’’][k]!=dp[v’’][k]的最大k值,再跳到那一层。。。这样到最后我们就到了u3层。如果还不理解可以画个图帮助理解下。

#include <iostream>#include <cstdio>#include <cstring>using namespace std;const int N=10005;const int Log=20;int dp[N][Log],depth[N],deg[N];struct Edge{int to;Edge *next;}memo[N*2],*cur,*head[N];void addEdge(int u,int v){cur->to=v;cur->next=head[u];head[u]=cur++;}void dfs(int u){depth[u]=depth[dp[u][0]]+1;for(int i=1;i<Log;i++) dp[u][i]=dp[dp[u][i-1]][i-1];for(Edge *it=head[u];it;it=it->next){dfs(it->to);}}int lca(int u,int v){if(depth[u]<depth[v]) swap(u,v);for(int st=1<<(Log-1),i=Log-1;i>=0;i--,st>>=1){if(st<=depth[u]-depth[v]){u=dp[u][i];}}if(u==v) return u;for(int i=Log-1;i>=0;i--){if(dp[v][i]!=dp[u][i]){v=dp[v][i];u=dp[u][i];}}return dp[u][0];}void init(int n){for(int i=0;i<=n;i++){dp[i][0]=0;head[i]=NULL;deg[i]=0;}cur=memo;}int main(){int t;scanf("%d",&t);while(t--){int n,u,v;scanf("%d",&n);init(n);for(int i=0;i<n-1;i++) {scanf("%d%d",&u,&v);addEdge(u,v);deg[v]++;dp[v][0]=u;}for(int i=1;i<=n;i++) if(deg[i]==0) {dfs(i);break;}scanf("%d%d",&u,&v);printf("%d\n",lca(u,v));}return 0;}

3. 转化成RMQ

       RMQ:区间最小值询问问题。

       RMQ(A,i,j):对于线性序列A中,询问区间[i,j]上的最小值。

       ST(Sparse Table)算法是一个非常有名的在线处理RMQ问题的算法,它可以在O(logN)时间内进行预处理,然后在O(1)的时候内回答每个查询。

      首先是预处理,用动态规划DP解决。设A[i]是要求区间最值的数列,dp[i,j]表示从第i个数起连续2^j个数中的最大值。例如数列3 2 4 5 6 8 1 2 9 7,dp[1,0]表示从第1个数起,长度为2^0=1的最大值,其实就是3这个数。dp[1,2]=5,dp[1,3]=8,dp[2,0]=2,dp[2,1]=4……从这里可以看出dp[i,0]其实就等于A[i]。这样,DP的状态,初值都已经有了,剩下的就是状态转移方程。我们把dp[i,j]平均分成两段(因为dp[i,j]一定是偶数个数字),从i+2^(j-1)-1为一段,从i+2^(j-1)到i+2^j-1为一段(长度都为2^(j-1))。用上例说明,当i=1,j=3时就是3 2 4 5 和 6 8 1 2 这两段。F[i,j]就是这两段的最大值中的最大值。于是我们得到了动态规划方程dp[i,j]=max(dp[i,j-1],dp[i+2^(j-1),2^(j-1)。

       然后是查询。取k=[log2(j-i+1)],则有:RMQ(A,i,j)=min(dp[i,k],dp[j-2^k+1,k])。举例说明,要求区间[2,8]的最大值,就要把它分成[2,5]和[5,8]两个区间,因为这两个区间的最大值我们可以直接由dp[2,2]和dp[5,2]得到。

       对有根树T进行DFS,将遍历到的结点按照顺序记下,我们将得到一个长度为2N – 1的序列,称之为T的欧拉序列F。

       每个结点都在欧拉序列中出现,我们记录结点u在欧拉序列中第一次出现的位置为pos(u)。

       下图是一个例子:

       根据DFS的性质,对于两结点u、v,从pos(u)遍历到pos(v)的过程中经过LCA(u,v)有且仅有一次,且深度是深度序列B[pos(u)…pos(v)]中最小的。

       即LCA(T, u, v) =RMQ(B, pos(u), pos(v))

       有两种写法。

       第一 种:在rmq_depth里记录在如上图所示的深度序列,在rmq_hash里记录每个深度序列里每个位置对应的点,dp[i][j]表示从i开始的长度为2^j的序列中深度最小的值的位置。比如查询(u,v)的最近公共祖先,即是rmq_hash[min(dp[u][k],dp[v-(1<<k)+1][k])];

#include <iostream>#include <cstdio>#include <cstring>#include <cmath>using namespace std;const int N=10005;const int Log=30;int deg[N],dp[N*2][Log];int dfn;//dfs过程中用到的时间戳int rmq_pos[N];//表示节点u第一次出现的位置int rmq_depth[N*2];//表示路径上的每个点的深度int rmq_hash[N*2];//表示路径上的每个深度代表的点struct Edge{int to;Edge *next;}memo[N*2],*cur,*adj[N];void addEdge(Edge *head[],int u,int v){cur->to=v;cur->next=head[u];head[u]=cur++;}void dfs(int u,int d){rmq_pos[u]=dfn;rmq_depth[dfn]=d;rmq_hash[dfn++]=u;for(Edge *it=adj[u];it;it=it->next){int v=it->to;dfs(v,d+1);rmq_depth[dfn]=d;rmq_hash[dfn++]=u;}}void solve(int n){for(int i=1;i<=n;i++) dp[i][0]=i;for(int j=1;(1<<j)<=n;j++) for(int i=1;i+(1<<j)-1<=n;i++){int tmp1=i,tmp2=i+(1<<(j-1));if(rmq_depth[dp[tmp1][j-1]]<rmq_depth[dp[tmp2][j-1]])dp[i][j]=dp[tmp1][j-1];else dp[i][j]=dp[tmp2][j-1];}}int rmq(int u,int v){u=rmq_pos[u],v=rmq_pos[v];if(u>v) swap(u,v);int k=(int)(log(v*1.0-u+1)/log(2.0));int tmp1=u,tmp2=v-(1<<k)+1;return rmq_hash[min(dp[tmp1][k],dp[tmp2][k])];}void init(int n){dfn=1;for(int i=0;i<=n;i++){adj[i]=NULL;deg[i]=0;}cur=memo;}int main(){int t;scanf("%d",&t);while(t--){int n,u,v;scanf("%d",&n);init(n);for(int i=0;i<n-1;i++){scanf("%d%d",&u,&v);addEdge(adj,u,v);deg[v]++;}for(int i=1;i<=n;i++) if(deg[i]==0){dfs(i,0);break;}solve(2*n-1);scanf("%d%d",&u,&v);printf("%d\n",rmq(u,v));}return 0;}
       第二种:对第一种写法做了些改进,即如果在一棵子树下,深度值较大的同时编号也较大,那么在对dp数组进行处理时就可以不用借助上面的rmq_depth了。

#include <iostream>#include <cstdio>#include <cstring>#include <cmath>using namespace std;const int N=10005;const int Log=30;int deg[N],dp[N*2][Log];int dfn;int rmq_low[N*2];int rmq_pos[N*2];int rmq_hash[N*2];struct Edge{int to;Edge *next;}memo[N*2],*cur,*adj[N];void addEdge(Edge *head[],int u,int v){cur->to=v;cur->next=head[u];head[u]=cur++;}void dfs(int u){int tmp=dfn;rmq_low[dfn]=dfn;rmq_pos[u]=dfn;rmq_hash[dfn++]=u;for(Edge *it=adj[u];it;it=it->next){int v=it->to;dfs(v);rmq_low[dfn]=tmp;rmq_hash[dfn++]=u;}}void solve(int n){for(int i=1;i<=n;i++) dp[i][0]=rmq_low[i];for(int j=1;(1<<j)<=n;j++)for(int i=1;i+(1<<j)-1<=n;i++)dp[i][j]=min(dp[i][j-1],dp[i+(1<<(j-1))][j-1]);}int rmq(int u,int v){u=rmq_pos[u],v=rmq_pos[v];if(u>v) swap(u,v);int k=(int)(log(v*1.0-u+1)/log(2.0));int tmp1=u,tmp2=v-(1<<k)+1;return rmq_hash[min(dp[tmp1][k],dp[tmp2][k])];}void init(int n){dfn=1;for(int i=0;i<=n;i++){deg[i]=0;adj[i]=NULL;}cur=memo;}int main(){int t;scanf("%d",&t);while(t--){int n,u,v;scanf("%d",&n);init(n);for(int i=0;i<n-1;i++){scanf("%d%d",&u,&v);addEdge(adj,u,v);deg[v]++;}for(int i=1;i<=n;i++) if(deg[i]==0){dfs(i);break;}solve(2*n-1);scanf("%d%d",&u,&v);printf("%d\n",rmq(u,v));}return 0;}


原创粉丝点击