树链剖分学习笔记

来源:互联网 发布:linux vi 怎么查找 编辑:程序博客网 时间:2024/06/07 23:03

树链剖分名副其实,就是一棵树剖成一堆链,比如像这样:
这里写图片描述
好吧实际上是这样:
这里写图片描述
可以看到这棵树被我们分成了好多好多链,每条红色的线就是一条链
具体怎么剖呢?
首先,找出所有节点中到根节点最远的一个点,把它到根节点的路径变成一条链,然后你可以想象这条路径已经从树中删除,然后递归处理剩下的子树……
这样一棵树就变成了很多很多很萌很萌的链~(萌个卵),把每条链中的节点依次放进线段树里,同一条链上的节点紧挨着,然后就可以开始玩这棵树啦~

求两个点的距离怎么办啊?
普通青年:找LCA
文艺青年:上链剖
二逼青年:Floyd
这里写图片描述
我想找两个蓝色节点的距离,怎么办呢?
首先我们可以确定它们在同一条链上,所以在线段树里它们是连续的一段,所以我们直接用线段树来找区间和就可以啦(注意由于线段树里记录的是每个节点,所以要把边权放到点上变成点权)

这里写图片描述
那么如果是这两个绿色节点呢?
对于每个节点,我们先记录下他的链顶节点,即这条链中深度最浅的那个节点,我们把它叫做一个节点的top,对于这两个绿色节点,我们就找他们两个的top,top更深的节点,我们记作节点x,节点x就变成x的top的父亲,如下图
这里写图片描述
节点x是左边那个绿色节点,它的top是黄色节点,它的top的父亲是橙色节点,然后我们就要计算节点x到top的距离,加入总答案,这个操作由于是在同一条链上,所以只需要线段树就能解决。不要忘记加上黄色节点到橙色节点的边的权值,计入总答案。
这时候我们可以递归找橙色节点和绿色节点的距离,重复之前的步骤,直至找到答案。
要注意的就是变化的节点是top较深的节点,否则会出事……

那么求两个点路径上的最大值怎么办呢?
和上面一样就好了……

修改一条边的值呢?
使用线段树单点修改操作

修改一条路径上所有边的权值?
使用线段树区间修改操作

然后是喜闻乐见的代码,这里用的是bzoj 1036 树的统计
题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=1036
题目大意:一棵树,三种操作:1.修改一个点的点权 2.寻找两点路径中最大的点权 3.寻找两点路径中的点权和

边与加边操作:

struct Link //记录边的结构体{    int s,t,next;}l[200000];int g[200000];  //边表中最开始的那条边
void Add_Link(int s,int t)  //建边操作{    l[++cnt].s = s;    l[cnt].t = t;    l[cnt].next = g[s];    g[s] = cnt;    return ;}

线段树:

struct tr   //线段树结构体{    int l,r,mid,val;    int sum,maxn;}a[1000000];void build(int x,int l,int r)   //建树操作{    a[x].l = l;    a[x].r = r;    int mid = (l+r)/2;    a[x].mid = mid;    if(l != r)    {        build(x*2,l,mid);        build(x*2+1,mid+1,r);    }}void rejs(int x)    //重计算线段树中一个节点的各个值{    if(a[x].l != a[x].r)    {        a[x].sum = a[x*2].sum + a[x*2+1].sum;        a[x].maxn = max(a[x*2].maxn,a[x*2+1].maxn);    }    else    {        a[x].sum = a[x].val;        a[x].maxn = a[x].val;    }}void change(int x,int y,int val)    //改变一个点的点权{    if(a[x].l != a[x].r)    {        if(y <= a[x].mid)            change(x*2,y,val);        else            change(x*2+1,y,val);        return ;    }    a[x].val = val;    while(x != 0)    {        rejs(x);        x /= 2;    }}int findmax(int x,int l,int r)  //区间最大值{    if(a[x].l == l && a[x].r == r)        return a[x].maxn;    else if(r <= a[x].mid)        return findmax(x*2,l,r);    else if(l > a[x].mid)        return findmax(x*2+1,l,r);    else        return max(findmax(x*2,l,a[x].mid),findmax(x*2+1,a[x].mid+1,r));}int findsum(int x,int l,int r)  //区间之和{    if(a[x].l == l && a[x].r == r)        return a[x].sum;    else if(r <= a[x].mid)        return findsum(x*2,l,r);    else if(l > a[x].mid)        return findsum(x*2+1,l,r);    else        return findsum(x*2,l,a[x].mid)+findsum(x*2+1,a[x].mid+1,r);}

链剖:
上面只讲了思想,这里讲具体怎么剖

struct point{    int siz,dep,son,fa;    int top,num;}p[200000];

这里记录的几个值:
siz 以这棵树为根的子树的大小
dep 该节点的深度
son 剖成链后这个节点所在链中的下一个节点
fa 这个节点的父亲节点
top 该链的定点
num 这个点在线段树中的位置

首先我们需要一个dfs来大致剖出这棵树:

void dfs1(int x){    p[x].siz = 1;    p[x].son = 0;    int w = g[x];    while(w)    {        int k = l[w].t;        if(p[x].fa != k)        {            p[k].fa = x;            p[k].dep = p[x].dep + 1;            dfs1(k);            if(p[k].siz > p[p[x].son].siz)                p[x].son = k;            p[x].siz += p[k].siz;        }        w = l[w].next;    }    return ;}

这里我并没有严格的采用剖最长链的做法,而是选择贪心的剖最大的子树,应该是会导致程序运行慢一些

然后我们需要另一个dfs来找top,并把点放进线段树里:

void dfs2(int x,int topx){    p[x].top = topx;    p[x].num = ++cnt;    int k = p[x].son;    if(!k)        return ;    dfs2(k,topx);    int w = g[x];    while(w)    {        k = l[w].t;        if(k != p[x].fa && k != p[x].son)            dfs2(k,k);        w = l[w].next;    }}

可以看到在这个dfs里我们才真正的把它剖开,把每个点放进应该放的位置,并找出top

最后是找一条路径点权的最大值和总和:

int treemax(int x,int y){    if(p[x].top == p[y].top)    {        if(p[x].num > p[y].num)            swap(x,y);        return findmax(1,p[x].num,p[y].num);    }    if(p[p[x].top].dep < p[p[y].top].dep)        swap(x,y);    int ans1 = findmax(1,p[p[x].top].num,p[x].num);    int ans2 = treemax(p[p[x].top].fa,y);    return max(ans1,ans2);}int treesum(int x,int y){    if(p[x].top == p[y].top)    {        if(p[x].num > p[y].num)            swap(x,y);        return findsum(1,p[x].num,p[y].num);    }    if(p[p[x].top].dep < p[p[y].top].dep)        swap(x,y);    int ans1 = findsum(1,p[p[x].top].num,p[x].num);    int ans2 = treesum(p[p[x].top].fa,y);    return ans1+ans2;}

终于弄完了好累……

0 0
原创粉丝点击