【树链剖分】bzoj4034: [HAOI2015]树上操作

来源:互联网 发布:国债收益率升高知乎 编辑:程序博客网 时间:2024/06/07 22:45

~biu~

Description

submit
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

Input

第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1
行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。再接下来 M 行,每行分别表示一次操作。其中
第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。

Output

对于每个询问操作,输出该询问的答案。答案之间用换行隔开。

Sample Input

5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3

Sample Output

6
9
13

HINT

对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。

思路

一个裸的树链剖分
首先连边,然后跑两边dfs
第一遍求每个节点的父节点和子树大小
第二遍重新编号每个节点并求出重链和每个点子树中的最大值
然后建一个空树,主要是为了找到结构体中每个区间的左右端点
再把每个重新编号的元素添加进树
最后根据每个要求求解即可

代码

#include <bits/stdc++.h>#define ls (rt<<1)#define rs (rt<<1|1)#define mid ((tr[rt].l+tr[rt].r)>>1)#define N 100001#define ll long longusing namespace std;inline int read(){    int ret=0,f=1;char c=getchar();    for(;!isdigit(c);c=getchar())if(c=='-')f=-1;    for(;isdigit(c);c=getchar())ret=ret*10+c-'0';    return ret*f;}int n,m,v[N],pos[N],top[N],bl[N],pp=0,he[N];int siz[N],ma[N],cnt=0,fa[N];struct pppp{int l,r;ll sum,tag;}tr[N<<2];struct derpp{int to,nxt;}a[N<<2];inline void add(int x,int y){    a[++pp]=(derpp){y,he[x]};he[x]=pp;    a[++pp]=(derpp){x,he[y]};he[y]=pp;}void build(int rt,int l,int r){    tr[rt].l=l;tr[rt].r=r;    if(l==r)return ;    build(ls,l,mid);    build(rs,mid+1,r);}void dfs(int x){    siz[x]=1;    for(int i=he[x];~i;i=a[i].nxt){        int v=a[i].to;        if(v!=fa[x]){            fa[v]=x;            dfs(v);            siz[x]+=siz[v];            ma[x]=max(ma[x],ma[v]);        }    }}void dfs2(int x,int father){    bl[x]=father;pos[x]=ma[x]=++cnt;    int k=0,v;    for(int i=he[x];~i;i=a[i].nxt){        v=a[i].to;        if(v!=fa[x]&&siz[v]>siz[k])k=v;    }    if(k){        dfs2(k,father);        ma[x]=max(ma[x],ma[k]);    }    for(int i=he[x];~i;i=a[i].nxt){        v=a[i].to;        if(v!=fa[x]&&v!=k){            dfs2(v,v);            ma[x]=max(ma[x],ma[v]);        }    }}void pushdown(int rt){    if(tr[rt].l==tr[rt].r)return ;    tr[ls].tag+=tr[rt].tag;tr[rs].tag+=tr[rt].tag;    tr[ls].sum+=tr[rt].tag*(mid-tr[rt].l+1);    tr[rs].sum+=tr[rt].tag*(tr[rt].r-mid);    tr[rt].tag=0;}void add(int rt,int x,int y,ll val){    if(tr[rt].tag)pushdown(rt);    if(tr[rt].l==x&&tr[rt].r==y){tr[rt].tag+=val;tr[rt].sum+=(tr[rt].r-tr[rt].l+1)*val;return ;}    if(x<=mid)add(ls,x,min(mid,y),val);    if(y>=mid+1)add(rs,max(x,mid+1),y,val);    tr[rt].sum=tr[ls].sum+tr[rs].sum;}ll query(int rt,int x,int y){    if(tr[rt].tag)pushdown(rt);    if(tr[rt].l==x&&tr[rt].r==y)return tr[rt].sum;    ll ans=0;    if(x<=mid)ans+=query(ls,x,min(mid,y));    if(y>=mid+1)ans+=query(rs,max(mid+1,x),y);    return ans;}ll query(int x){    ll ans=0;    while(bl[x]!=1){        ans+=query(1,pos[bl[x]],pos[x]);        x=fa[bl[x]];    }    ans+=query(1,1,pos[x]);    return ans;}int main(){int x,y,opt;    memset(he,-1,sizeof(he));    n=read();m=read();    for(int i=1;i<=n;++i)v[i]=read();    for(int i=1;i<n;++i){        x=read();y=read();        add(x,y);    }    dfs(1);dfs2(1,1);    build(1,1,n);    for(int i=1;i<=n;++i)add(1,pos[i],pos[i],v[i]);    while(m--){        opt=read();x=read();        if(opt==1){            y=read();            add(1,pos[x],pos[x],y);        }        else if(opt==2){            y=read();            add(1,pos[x],ma[x],y);        }        else{            printf("%lld\n",query(x));        }    }    return 0;}