【算法】树上启发式合并算法

来源:互联网 发布:沪江背单词软件 编辑:程序博客网 时间:2024/05/17 07:45

        树上启发式合并算法是启发式合并算法在树上的应用。下面我直接通过一个例子来讲解这个算法。

        例:给定一棵有根树,树的结点编号为1~n,根结点为结点1。结点i有颜色col[i],其中1≤col[i]≤n。要求回答m个询问,每个询问回答颜色c在子树u中出现多少次。

        显然要将查询离线处理,即对子树u的查询都“挂”到结点u上。我们用cnt[c]表示颜色c出现的次数,那么一种容易想到的暴力做法如下:

        0.cnt[]数组初始化为0;

        1.从结点1开始dfs整棵树。dfs至结点u时,按如下步骤处理挂在子树u内的结点上的查询:

        (1)dfs结点u的各个儿子;

        (2)遍历一遍子树u,在遍历的同时更新cnt[]数组;

        (3)给出挂在结点u上的查询的答案;

        (4)遍历一遍子树u,在遍历的同时更新cnt[]数组(把对cnt[]数组的贡献抹去)。

        显然这个暴力做法的时间复杂度为O(n²)。容易想到的一个优化是在步骤1(0)中u的最后一个儿子对cnt[]数组的贡献可以保留着,这样步骤1(1)就不需要遍历以该儿子为根的子树。这个优化有多大效果呢?事实上,如果我们适当改变dfs顺序,使得被dfs的最后一个儿子对应的子树是诸位儿子中最大的,那么这个优化可以把时间复杂度降至O(nlgn)。

        带这种优化的暴力做法就是树上启发式合并。在遍历时,用bool变量keep来表示子树u对cnt[]数组的贡献是否要保留。算法步骤如下:

        0.cnt[]数组初始化为0;

        1.从结点1开始dfs整棵树。dfs至结点u时,按如下步骤处理挂在子树u内的结点上的查询:

        (1)找到结点u的对应最大子树的儿子bc;

        (2)dfs结点u的各个儿子,其中结点bc是u的最后一个被遍历到的儿子,递归处理时若该子结点非bc则keep赋值为0,否则赋值为1;

        (3)遍历一遍子树u,但不遍历子树bc,在遍历的同时更新cnt[]数组;

        (4)给出挂在结点u上的查询的答案;

        (5)若keep为0,遍历一遍子树u,在遍历的同时更新cnt[]数组(把对cnt[]数组的贡献抹去)。

        下面简单地计算算法的时间复杂度:算法的耗时来自于对查询的回答与各结点对cnt[]数组的操作。由于每次查询都是O(1)的,所以查询的总复杂度为O(m),可以忽略。考虑结点u对cnt[]数组的操作次数,设u的祖先从近到远依次为w_1,w_2,...,w_t。处理子树u时,结点u第一次对cnt[]数组进行操作,而结点u下一次对cnt[]数组进行操作,发生在u下一次不在某个祖先的儿子子树中最大的那棵中时,我们来说明这样的事只能发生O(lgn)次。若u不在w_k的儿子子树中最大的那棵中,则有size(w_k-1)*2≤size(w_k),由于size(w_t)=n,所以这样的事(u不在w_k的儿子子树中最大的那棵中)只能发生O(lgn)次。综上,各节点对cnt[]数组的操作次数为O(nlgn)。

        下面给出树上启发式合并的关键代码:

void dfs1(int u,int fu)          //sz[u]为子树u的大小,st[u]与ft[u]分别为子树u的dfs开始时间与结束时间,ver[time]为time时刻dfs的结点编号{sz[u]=1;st[u]=++t_c;ver[t_c]=u;for (int i=0;i<G[u].size();i++){int v=G[u][i];if (v==fu) continue;dfs1(v,u);sz[u]+=sz[v];}ft[u]=t_c;}void dfs2(int u,int fu,bool keep){int bc=0;for (int i=0;i<G[u].size();i++){int v=G[u][i];if (v==fu) continue;if (!bc||sz[bc]<sz[v]) bc=v;}for (int i=0;i<G[u].size();i++){int v=G[u][i];if (v==fu||v==bc) continue;dfs2(v,u,0);}if (bc) dfs2(bc,u,1);for (int i=0;i<G[u].size();i++){int v=G[u][i];if (v==fu||v==bc) continue;for (int j=st[v];j<=ft[v];j++) in(ver[j]);}in(u);ans[u]=getans();if (!keep) for (int j=st[u];j<=ft[u];j++) out(ver[j]);}

        习题与解答(待更新)

codeforce741D

参考资料

0 1