Algorithm---LCA(倍增算法)

来源:互联网 发布:中南大学网络考试答案 编辑:程序博客网 时间:2024/06/05 05:04

基本思想:(参考:from lanshui_Yang)

deep[i] 表示 i节点的深度, fa[i,j]表示 i 的 2^j (即2的j次方) 倍祖先,那么fa[i , 0]即为节点i 的父亲,然后就有一个递推式子:

fa[i,j]= fa [ fa [i,j-1] , j-1 ] 

可以这样理解:

设tmp = fa [i, j - 1] ,tmp2 = fa [tmp, j - 1 ] ,即tmp 是i 的第2 ^ (j - 1) 倍祖先,tmp2 是tmp 的第2 ^ (j - 1) 倍祖先 , 所以tmp2 是i 的第 2 ^ (j - 1) + 2 ^ (j - 1) =  2^ j 倍祖先,注意:这里的“倍”可不能理解为倍数的意思,而是距离节点i有多远的意思,节点i的第2 ^ j 倍祖先表示的节点u满足deep[ u ] - deep[ i ] = 2 ^ j
这样子一个O(NlogN)的预处理求出每个节点的 2^k 的祖先  
然后对于每一个询问的点对a, b的最近公共祖先就是: 

先判断是否 d[x]< d[y] ,如果是的话就交换一下(保证 x 的深度大于 y 的深度), 然后把 x 调到与 y 同深度, 同深度以后再把a, b 同时往上调,调到有一个最小的 j 满足fa [x,j] != fa [y,j] (x,y是在不断更新的), 最后再把(x,y)往上调(x=p[x,0], y=p[y,0])  ,一个一个向上调直到x = y, 这时 x或y 就是他们的最近公共祖先。

 Ps:如果还是不明白,就手动模拟一棵节点数为9的树(如下图所示),很快就会理解的。还有我不得不感叹一句 :二进制真的很神奇!!                  

#include<iostream>#include<cstring>#include<algorithm>#include<string>#include<cmath>#include<vector>#include<cstdio>#define mem(a , b) memset(a , b , sizeof(a))using namespace std ;inline void RD(int &a){    a = 0 ;    char t ;    do    {        t = getchar() ;    }    while (t < '0' || t > '9') ;    a = t - '0' ;    while ((t = getchar()) >= '0' && t <= '9')    {        a = a * 10 + t - '0' ;    }}inline void OT(int a){    if(a >= 10)    {        OT(a / 10) ;    }    putchar(a % 10 + '0') ;}const int MAXN = 10005 ;const int M = 30 ;vector<int> G[MAXN] ;bool vis[MAXN] ;int deep[MAXN] ;int fa[MAXN][M] ;int n ;int root ;void chu(){    mem(vis , 0) ;    mem(deep , 0) ;    mem(fa , 0) ;    int i ;    for(i = 0 ; i <= n ; i ++)        G[i].clear() ;}void dfs(int u){    vis[u] = true ;    int i ;    for(i = 0 ; i < G[u].size() ; i ++)    {        int v = G[u][i] ;        if(!vis[v])        {            deep[v] = deep[u] + 1 ;            dfs(v) ;        }    }}void bz()  // 倍增祖先{    int i , j ;    for(j = 1 ; j < M ; j ++)    {        for(i = 1 ; i <= n ; i ++)        {            fa[i][j] = fa[ fa[i][j - 1] ][j - 1] ;        }    }}void swap(int &x , int &y){    int tmp = x ;    x = y ;    y = tmp ;}int LCA(int u , int v){    if(deep[u] < deep[v]) swap(u , v) ;    int d = deep[u] - deep[v] ;    int i ;    for(i = 0 ; i < M ; i ++)    {        if( (1 << i) & d )  // 注意此处,动手模拟一下,就会明白的        {            u = fa[u][i] ;        }    }    if(u == v) return u ;    for(i = M - 1 ; i >= 0 ; i --)    {        if(fa[u][i] != fa[v][i])        {            u = fa[u][i] ;            v = fa[v][i] ;        }    }    u = fa[u][0] ;    return u ;}void init(){    scanf("%d" , &n) ;    chu() ;    int i ;    for(i = 0 ; i < n - 1 ; i ++)    {        int a , b ;        scanf("%d%d" , &a , &b) ;        G[a].push_back(b) ;        fa[b][0] = a ;        if(fa[a][0] == 0)        {            root = a ;        }    }    deep[root] = 1 ;    dfs(root) ;    bz() ;    int u , v ;    scanf("%d%d" , &u , &v) ;    printf("%d\n", LCA(u , v)) ;}int main(){    int T ;    scanf("%d" , &T) ;    while (T --)    {        init() ;    }    return 0 ;}


0 0