luogu3384 树链剖分

来源:互联网 发布:mysql导入excel文件 编辑:程序博客网 时间:2024/05/21 15:06

题目

  https://www.luogu.org/problem/show?pid=3384

题解

  辣鸡题目毁我青春!

  以前写线段树,指针从来不赋初值,偏偏这道题乖张,本来应该1A的题目我提交了29遍!我的AC率,我的时间!

  呃,吐槽完了……说到底还是自己的习惯不好。

  轻重边剖分之后,每个点的tid其实就是它的dfs序。把节点按tid从小到大输出,得到的就是这棵树的一种先序遍历(优先走重儿子;允许我乱用概念吧)。因此,以某个节点为根的子树,恰好对应从刚进入这层dfs时的时间戳到退出这层dfs时的时间戳之间的那一段连续节点,于是子树修改和子树查询都可以转化为线段树上的区间操作。

代码

// Heavy-light decomposition + DFS order (Luogu P3384).
//
// Input (stdin): "N M R P", then N node values, N-1 edges and M operations on
// the tree rooted at R:
//   1 x y z : add z to every node on the path x..y
//   2 x y   : print the sum of the nodes on the path x..y, modulo P
//   3 x z   : add z to every node in the subtree of x
//   4 x     : print the sum of the nodes in the subtree of x, modulo P
#include <algorithm>
#include <cstdio>
#include <utility>

// Array capacity. Every undirected edge is stored twice, so the adjacency
// arrays limit the tree to kMaxN / 2 nodes.
const int kMaxN = 200010;

int N, M, R, P;
int l[kMaxN], r[kMaxN];  // l[v]..r[v]: DFS positions occupied by the subtree of v
int tim;                 // DFS timestamp counter
int fa[kMaxN], sz[kMaxN], son[kMaxN], tid[kMaxN], top[kMaxN], deep[kMaxN];
int head[kMaxN], nxt[kMaxN], to[kMaxN], tot;  // forward-star adjacency list
int val[kMaxN];          // val[v]: initial value of node v
int w[kMaxN];            // w[t]: initial value of the node whose DFS position is t

// Reduces x into [0, P); negative values are mapped to their non-negative
// residue so that every stored sum and tag stays inside the int range.
int norm(long long x) {
    long long m = x % P;
    if (m < 0) m += P;
    return static_cast<int>(m);
}

// Segment tree node over DFS positions. Nodes are allocated once in build()
// and stay alive until the process exits.
struct segtree {
    int l, r;            // closed interval [l, r] covered by this node
    int sum;             // interval sum modulo P, NOT yet including the tag d
    int d;               // pending per-element addition, kept in [0, P)
    segtree *lch, *rch;  // both NULL for a leaf, both set otherwise
    segtree() : l(0), r(0), sum(0), d(0), lch(NULL), rch(NULL) {}
} *root;

// Adds the directed edge a -> b to the adjacency list.
void adde(int a, int b) {
    to[++tot] = b;
    nxt[tot] = head[a];
    head[a] = tot;
}

// Applies the pending tag of p to its own sum and hands it to the children.
void pushdown(segtree *p) {
    // Widen before multiplying: length * tag does not fit in int in general.
    p->sum = norm(p->sum + static_cast<long long>(p->r - p->l + 1) * p->d);
    if (p->lch) {
        p->lch->d = norm(static_cast<long long>(p->lch->d) + p->d);
        p->rch->d = norm(static_cast<long long>(p->rch->d) + p->d);
    }
    p->d = 0;
}

// Recomputes the sum of an inner node from its children (no-op on a leaf).
void update(segtree *p) {
    if (p->lch == NULL) return;
    // The children's tags must be folded in before their sums are read.
    pushdown(p->lch);
    pushdown(p->rch);
    p->sum = norm(static_cast<long long>(p->lch->sum) + p->rch->sum);
}

// Adds d to every position in [l, r] inside the subtree of p.
void segadd(segtree *p, int l, int r, int d) {
    pushdown(p);
    if (l <= p->l && r >= p->r) {
        // Fully covered: park the addition in the tag (it is 0 after pushdown).
        p->d = norm(static_cast<long long>(p->d) + d);
        return;
    }
    const int mid = (p->l + p->r) >> 1;
    if (l <= mid) segadd(p->lch, l, r, d);
    if (r > mid) segadd(p->rch, l, r, d);
    update(p);
}

// Returns the sum of positions [l, r] inside the subtree of p, modulo P.
int segsum(segtree *p, int l, int r) {
    pushdown(p);
    if (l <= p->l && r >= p->r) return p->sum;
    const int mid = (p->l + p->r) >> 1;
    int ans = 0;
    if (l <= mid) ans = norm(static_cast<long long>(ans) + segsum(p->lch, l, r));
    if (r > mid) ans = norm(static_cast<long long>(ans) + segsum(p->rch, l, r));
    return ans;
}

// Builds the subtree of p over [l, r], taking the leaf values from w[].
void build(segtree *p, int l, int r) {
    p->l = l;
    p->r = r;
    if (l == r) {
        p->sum = norm(w[l]);
        return;
    }
    const int mid = (l + r) >> 1;
    build(p->lch = new segtree, l, mid);
    build(p->rch = new segtree, mid + 1, r);
    update(p);
}

// First pass: fills fa, deep, sz and picks the heavy child son[] of each node.
void dfs1(int pos) {
    sz[pos] = 1;
    for (int e = head[pos]; e; e = nxt[e]) {
        const int v = to[e];
        if (v == fa[pos]) continue;
        fa[v] = pos;
        deep[v] = deep[pos] + 1;
        dfs1(v);
        // son[pos] == 0 means "none yet"; sz[0] is 0, so any child wins.
        if (sz[v] > sz[son[pos]]) son[pos] = v;
        sz[pos] += sz[v];
    }
}

// Second pass: assigns DFS positions and chain tops. The heavy child is
// visited first so that every heavy chain occupies consecutive positions.
void dfs2(int pos, int tp) {
    top[pos] = tp;
    tid[pos] = ++tim;
    l[pos] = tim;
    if (son[pos]) dfs2(son[pos], tp);
    for (int e = head[pos]; e; e = nxt[e]) {
        const int v = to[e];
        if (v == fa[pos] || v == son[pos]) continue;
        dfs2(v, v);  // a light child starts a new chain
    }
    r[pos] = tim;
}

// Returns true when x is a node id usable as an index into the node arrays.
bool is_node(int x) {
    return x >= 1 && x <= N;
}

// Reads the header, node values and edges, then builds the decomposition and
// the segment tree. Returns false if the input is truncated or a value would
// index outside the fixed-size arrays.
bool init() {
    if (scanf("%d%d%d%d", &N, &M, &R, &P) != 4) return false;
    if (N < 1 || N > kMaxN / 2 || !is_node(R) || P < 1) return false;
    for (int i = 1; i <= N; i++) {
        if (scanf("%d", val + i) != 1) return false;
    }
    for (int i = 1; i < N; i++) {
        int a, b;
        if (scanf("%d%d", &a, &b) != 2) return false;
        if (!is_node(a) || !is_node(b)) return false;
        adde(a, b);
        adde(b, a);
    }
    dfs1(R);
    dfs2(R, R);
    for (int i = 1; i <= N; i++) w[tid[i]] = val[i];
    build(root = new segtree, 1, tim);
    return true;
}

// Adds d to every node on the path between a and b.
void add(int a, int b, int d) {
    int ta = top[a], tb = top[b];
    while (ta != tb) {
        // Always lift the endpoint whose chain top is deeper.
        if (deep[ta] < deep[tb]) {
            std::swap(a, b);
            std::swap(ta, tb);
        }
        segadd(root, tid[ta], tid[a], d);
        a = fa[ta];
        ta = top[a];
    }
    if (deep[a] > deep[b]) std::swap(a, b);
    segadd(root, tid[a], tid[b], d);
}

// Returns the sum of the nodes on the path between a and b, modulo P.
int sum(int a, int b) {
    int ta = top[a], tb = top[b], ans = 0;
    while (ta != tb) {
        if (deep[ta] < deep[tb]) {
            std::swap(a, b);
            std::swap(ta, tb);
        }
        ans = norm(static_cast<long long>(ans) + segsum(root, tid[ta], tid[a]));
        a = fa[ta];
        ta = top[a];
    }
    if (deep[a] > deep[b]) std::swap(a, b);
    ans = norm(static_cast<long long>(ans) + segsum(root, tid[a], tid[b]));
    return ans;
}

int main() {
    if (!init()) return 1;
    for (int i = 1; i <= M; i++) {
        int type, x, y, z;
        if (scanf("%d", &type) != 1) break;  // input ended early
        bool ok = true;
        switch (type) {
        case 1:
            ok = scanf("%d%d%d", &x, &y, &z) == 3 && is_node(x) && is_node(y);
            if (ok) add(x, y, z);
            break;
        case 2:
            ok = scanf("%d%d", &x, &y) == 2 && is_node(x) && is_node(y);
            if (ok) printf("%d\n", sum(x, y));
            break;
        case 3:
            ok = scanf("%d%d", &x, &z) == 2 && is_node(x);
            if (ok) segadd(root, l[x], r[x], z);
            break;
        case 4:
            ok = scanf("%d", &x) == 1 && is_node(x);
            if (ok) printf("%d\n", segsum(root, l[x], r[x]));
            break;
        default:
            break;  // unknown operation types consume no further input
        }
        if (!ok) break;  // malformed operation: stop instead of indexing out of range
    }
    return 0;
}


0 0
原创粉丝点击