树链剖分 模板 洛谷P3384

来源:互联网 发布:51单片机有趣的小制作 编辑:程序博客网 时间:2024/05/22 14:17

题目要求:

已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

代码

#include<bits/stdc++.h>using namespace std;#define int long longint n,root,num_e,num_s,mod,m,opt,u,v;inline void _read(int &x){    x=0;char c=getchar();    while(c<'0'||c>'9') c=getchar();    while(c>='0'&&c<='9'){        x=(x<<1)+(x<<3)+c-'0';        c=getchar();    }}inline void read(int &x){    _read(x);}inline void read(int &x,int &y){    _read(x);    _read(y);}inline void read(int &x,int &y,int &z){    _read(x);    _read(y);    _read(z);}inline void read(int &x,int &y,int &z,int &o){    _read(x);    _read(y);    _read(z);    _read(o);}#define maxn 1000005struct edge{    int to,nex;}e[maxn];int head[maxn],mp[maxn],a[maxn];void add(int x,int y){    e[++num_e].to=y;e[num_e].nex=head[x];head[x]=num_e;    e[++num_e].to=x;e[num_e].nex=head[y];head[y]=num_e;}struct tre{    int siz,dep,fa,top,dis,wson,s,e;}t[maxn];struct point{    int l,r,lazy,sum;}p[maxn*6];void dfs1(int,int);void dfs2(int,int);void build_tree(int,int,int);void push_down(int);void up_date(int,int,int,int);int ask(int,int,int);void solve1();void solve2();void solve3();void solve4();#undef intint main(){    read(n,m,root,mod);    for(int i=1;i<=n;i++) read(a[i]);    u,v;    for(int i=1;i<n;i++) read(u,v),add(u,v);//cout<<"   !!";    dfs1(root,1);//for(int i=1;i<=n;i++) cout<<t[i].wson<<' '; cout<<endl;    dfs2(root,root);    build_tree(1,num_s,1);    for(int i=1;i<=m;i++){        read(opt);        if(opt==1) solve1();        else if(opt==2) solve2();        else if(opt==3) solve3();        else solve4();    }    return 0;}#define int long longvoid dfs1(int x,int depth){    t[x].siz=1;    t[x].dep=depth;    for(int i=head[x];i;i=e[i].nex){        int y=e[i].to;        if(t[x].fa!=y){            t[y].fa=x;            dfs1(y,depth+1);            t[x].siz+=t[y].siz;            if(t[t[x].wson].siz<t[y].siz) t[x].wson=y;        }    }}void dfs2(int x,int fir){    t[x].top=fir;//cout<<"  **"<<x<<' '<<t[x].wson;    t[x].s=++num_s;    mp[num_s]=x;    if(t[x].wson){        dfs2(t[x].wson,fir);        for(int i=head[x];i;i=e[i].nex){            int y=e[i].to;            if(y!=t[x].fa&&y!=t[x].wson){                dfs2(y,y);            }        }    }    t[x].e=num_s;}/***********************************愉快的树剖到此结束。。5 5 2 247 3 7 8 01 21 53 14 13 4 23 2 24 51 5 1 32 1 3痛苦的线段树已经开始。。。************************************/void build_tree(int l,int r,int num){    p[num].l=l,p[num].r=r;p[num].lazy;    if(l==r){        p[num].sum=a[mp[l]];        return;    }    int m=(l+r)>>1;    build_tree(l,m,num<<1);    build_tree(m+1,r,num<<1|1);    p[num].sum=p[num<<1].sum+p[num<<1|1].sum;}inline void push_down(int x){    if(!p[x].lazy) return;    p[x<<1].sum+=(p[x<<1].r-p[x<<1].l+1)*p[x].lazy; p[x<<1].lazy+=p[x].lazy;    p[x<<1|1].sum+=(p[x<<1|1].r-p[x<<1|1].l+1)*p[x].lazy; p[x<<1|1].lazy+=p[x].lazy;    p[x].lazy=0;}void up_date(int l,int r,int x,int num){    if(l<=p[num].l&&r>=p[num].r){        p[num].lazy+=x;        p[num].sum+=x*(p[num].r-p[num].l+1);        return;    }    push_down(num);    int m=(p[num].l+p[num].r)>>1;    if(l<=m) up_date(l,r,x,num<<1);    if(r>m) up_date(l,r,x,num<<1|1);    p[num].sum=p[num<<1].sum+p[num<<1|1].sum;}int ask(int l,int r,int num){    if(l<=p[num].l&&r>=p[num].r) return p[num].sum;    int m=(p[num].l+p[num].r)>>1;int summ=0;    push_down(num);    if(l<=m) summ+=ask(l,r,num<<1);    if(r>m) summ+=ask(l,r,num<<1|1);    return summ;}void solve1(){    int x,y,w;    read(x,y,w);    int f1=t[x].top,f2=t[y].top;    while(f1!=f2){        if(t[f1].dep<t[f2].dep){            swap(f1,f2);            swap(x,y);        }        up_date(t[f1].s,t[x].s,w,1);        x=t[f1].fa;        f1=t[x].top;    }    if(t[x].dep<t[y].dep) up_date(t[x].s,t[y].s,w,1);    else up_date(t[y].s,t[x].s,w,1);}void solve2(){    int x,y;    read(x,y);    int f1=t[x].top,f2=t[y].top,summ=0;    while(f1!=f2){        if(t[f1].dep<t[f2].dep){            swap(f1,f2);            swap(x,y);        }        summ+=ask(t[f1].s,t[x].s,1);        summ%=mod;        x=t[f1].fa;        f1=t[x].top;    }    if(t[x].dep<t[y].dep) summ+=ask(t[x].s,t[y].s,1);    else summ+=ask(t[y].s,t[x].s,1);    printf("%lld\n",summ%mod);}void solve3(){    int x,w;    read(x,w);//cout<<"    !! "<<t[x].s<<' '<<t[x].e<<endl;    up_date(t[x].s,t[x].e,w,1);}void solve4(){    int x;    read(x);    printf("%lld\n",ask(t[x].s,t[x].e,1)%mod);}#undef int