【bzoj1036】【树链剖分】 [ZJOI2008]树的统计Count

来源:互联网 发布:java图片base64编码 编辑:程序博客网 时间:2024/04/20 10:59

Description

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

Input

  输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Output

  对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

Sample Input

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

Sample Output

4
1
2
2
10
6
5
6
5
16

思路

这题一看题干就知是裸的树链剖分
树链剖分就是把树拆成一系列链,然后用数据结构对链进行维护。

通常的剖分方法是轻重链剖分,所谓轻重链就是对于节点u的所有子结点v,size[v]最大的v与u的边是重边,其它边是轻边,其中size[v]是以v为根的子树的节点个数,全部由重边组成的路径是重路径,根据论文上的证明,任意一点到根的路径上存在不超过logn条轻边和logn条重路径。

这样我们考虑用数据结构来维护重路径上的查询,轻边直接查询。

通常用来维护的数据结构是线段树,splay较少见。

具体步骤

预处理

1:第一遍dfs
求出树每个结点的深度deep[x],其为根的子树大小size[x]
以及祖先的信息fa[x][i]表示x往上距离为2^i的祖先

2:第二遍dfs

根节点为起点,向下拓展构建重链

选择最大的一个子树的根继承当前重链

其余节点,都以该节点为起点向下重新拉一条重链

ž给每个结点分配一个位置编号,每条重链就相当于一段区间,用数据结构去维护。

把所有的重链首尾相接,放到同一个数据结构上,然后维护这一个整体即可

修改操作

ž1、单独修改一个点的权值

根据其编号直接在数据结构中修改就行了。

2、修改点u和点v的路径上的权值

(1)若u和v在同一条重链上

直接用数据结构修改pos[u]至pos[v]间的值。

(2)若u和v不在同一条重链上

一边进行修改,一边将u和v往同一条重链上靠,然后就变成了情况(1)。

查询操作

查询操作的分析过程同修改过程
题目不同,选用不同的数据结构来维护值,通常有线段树和splay

代码一

%:pragma GCC optimize("O2")#include <bits/stdc++.h>#define N 30005#define INF 0x7fffffff#define ls (rt<<1)#define rs (rt<<1|1)#define mid ((tr[rt].l+tr[rt].r)/2)using namespace std;inline int read(){    int ret=0,f=1;char c=getchar();    for(;!isdigit(c);c=getchar())if(c=='-')f=-1;    for(;isdigit(c);c=getchar())ret=ret*10+c-'0';    return ret*f;}int n,q,sz=0,v[N],dep[N],siz[N],he[N];int fa[N],pos[N],bl[N],pp=0;struct pppp{int to,nxt;}a[N<<1];struct ppp{int l,r,ma,sum;}tr[N<<2];inline void IN(int u,int v){    a[++pp].to=v;a[pp].nxt=he[u];he[u]=pp;    a[++pp].to=u;a[pp].nxt=he[v];he[v]=pp;}void dfs1(int x){    siz[x]=1;    for(int i=he[x];~i;i=a[i].nxt){        if(a[i].to==fa[x])continue;        dep[a[i].to]=dep[x]+1;        fa[a[i].to]=x;        dfs1(a[i].to);        siz[x]+=siz[a[i].to];    }}void dfs2(int x,int chain){    int k=0;++sz;    pos[x]=sz;bl[x]=chain;    for(int i=he[x];~i;i=a[i].nxt){        if(dep[a[i].to]>dep[x]&&siz[a[i].to]>siz[k])            k=a[i].to;    }    if(k==0)return ;    dfs2(k,chain);    for(int i=he[x];~i;i=a[i].nxt){        if(dep[a[i].to]>dep[x]&&k!=a[i].to)            dfs2(a[i].to,a[i].to);    }}void build(int rt,int l,int r){    tr[rt].l=l;tr[rt].r=r;    if(l==r)return ;    build(ls,l,mid);    build(rs,mid+1,r);}void change(int rt,int x,int y){    if(tr[rt].l==tr[rt].r){tr[rt].sum=tr[rt].ma=y;return ;}    if(x<=mid)change(ls,x,y);    else change(rs,x,y);    tr[rt].sum=tr[ls].sum+tr[rs].sum;    tr[rt].ma=max(tr[ls].ma,tr[rs].ma);}int querysum(int rt,int x,int y){    if(tr[rt].l==x&&tr[rt].r==y)return tr[rt].sum;    if(y<=mid)return querysum(ls,x,y);    else if(x>mid)return querysum(rs,x,y);    else return querysum(ls,x,mid)+querysum(rs,mid+1,y);}int queryma(int rt,int x,int y){    if(tr[rt].l==x&&tr[rt].r==y)return tr[rt].ma;    if(y<=mid)return queryma(ls,x,y);    else if(x>mid)return queryma(rs,x,y);    else return max(queryma(ls,x,mid),queryma(rs,mid+1,y));}inline int solvesum(int x,int y){    int ans=0;    while(bl[x]!=bl[y]){        if(dep[bl[x]]<dep[bl[y]])swap(x,y);        ans+=querysum(1,pos[bl[x]],pos[x]);        x=fa[bl[x]];    }    if(pos[x]>pos[y])swap(x,y);    ans+=querysum(1,pos[x],pos[y]);    return ans;}inline int solvema(int x,int y){    int ma=-INF;    while(bl[x]!=bl[y]){        if(dep[bl[x]]<dep[bl[y]])swap(x,y);        ma=max(ma,queryma(1,pos[bl[x]],pos[x]));        x=fa[bl[x]];    }    if(pos[x]>pos[y])swap(x,y);    ma=max(ma,queryma(1,pos[x],pos[y]));    return ma;}int main(){    memset(he,-1,sizeof(he));    n=read();int x,y,q;    for(int i=1;i<n;++i){        x=read();y=read();        IN(x,y);    }    for(int i=1;i<=n;++i)scanf("%d",&v[i]);    dfs1(1);dfs2(1,1);    build(1,1,n);    for(int i=1;i<=n;++i)change(1,pos[i],v[i]);    scanf("%d",&q);char ch[10];    for(int i=1;i<=q;++i){        scanf("%s%d%d",ch,&x,&y);        if(ch[0]=='C'){v[x]=y;change(1,pos[x],y);}        else{            if(ch[1]=='M')printf("%d\n",solvema(x,y));            else printf("%d\n",solvesum(x,y));        }    }    return 0;}

代码模板二

link cut tree也是十分优美的
动态树可以维护一个动态的森林,支持树的合并(两棵合并成一棵),分离(把某个点和它父亲点分开),动态LCA,树上的点权和边权维护、查询(单点或者树上的一条路径),换根。
感谢hzwer学长的代码

%:pragma GCC optimize("O2")#include <bits/stdc++.h>#define inf 1000000000#define ll long longusing namespace std;int read(){    int x=0,f=1;char ch=getchar();    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}    return x*f;}int n,m,top;int fa[30005],c[30005][2],u[30005],v[30005],q[30005];ll w[30005],sum[30005],mx[30005];bool rev[30005];bool isroot(int x){    return c[fa[x]][0]!=x&&c[fa[x]][1]!=x;}void update(int x){    int l=c[x][0],r=c[x][1];    sum[x]=sum[l]+sum[r]+w[x];    mx[x]=max(w[x],max(mx[l],mx[r]));}void pushdown(int x){    int l=c[x][0],r=c[x][1];    if(rev[x]){        rev[x]^=1;rev[l]^=1;rev[r]^=1;        swap(c[x][0],c[x][1]);    }}void rotate(int x){    int y=fa[x],z=fa[y],l,r;    l=(c[y][1]==x);r=l^1;    if(!isroot(y))c[z][c[z][1]==y]=x;    fa[c[x][r]]=y;fa[y]=x;fa[x]=z;    c[y][l]=c[x][r];c[x][r]=y;    update(y);update(x);}void splay(int x){    q[++top]=x;    for(int i=x;!isroot(i);i=fa[i])        q[++top]=fa[i];    while(top)pushdown(q[top--]);    while(!isroot(x)){        int y=fa[x],z=fa[y];        if(!isroot(y)){            if(c[y][0]==x^c[z][0]==y)rotate(x);            else rotate(y);        }        rotate(x);    }}void access(int x){    for(int t=0;x;t=x,x=fa[x])        splay(x),c[x][1]=t,update(x);}void makeroot(int x){    access(x);splay(x);rev[x]^=1;}void link(int x,int y){    makeroot(x);fa[x]=y;}void split(int x,int y){    makeroot(x);access(y);splay(y);}int main(){    n=read();mx[0]=-inf;    for(int i=1;i<n;i++)u[i]=read(),v[i]=read();    for(int i=1;i<=n;i++){        w[i]=read();        sum[i]=mx[i]=w[i];    }    for(int i=1;i<n;i++)link(u[i],v[i]);    m=read();    char ch[10];    int u,v;    while(m--){        scanf("%s",ch);        u=read();v=read();        if(ch[1]=='H'){            splay(u);            w[u]=v;            update(u);        }        if(ch[1]=='M')            split(u,v),printf("%lld\n",mx[v]);        if(ch[1]=='S')            split(u,v),printf("%lld\n",sum[v]);    }    return 0;}