POJ 2763 树链剖分+线段树维护区间和

来源:互联网 发布:手机如何注册淘宝会员 编辑:程序博客网 时间:2024/06/18 14:08

这题是我第一次自己查错并且A掉的树剖模板题。

思路很简单,树剖后扔到线段树里就行了。

唯一的难点是：

1. 题目给的是边权，需要转成点权：把每条边的权值赋给它两个端点中深度较大的那个点（即子节点）。因为除根以外每个点恰好有一条连向父亲的边，这个映射是一一对应的；相应地，查询路径和时要跳过两点 LCA 本身的点权（它代表的是 LCA 上方那条边，不在路径上）。

2. 修改操作给出的是边的编号，为了方便，我直接开了几个数组（from、too、len）来保存每条边的两个端点和权值，剖分后再把 from[i] 调整为深度较大的端点，修改时直接改这个点的点权即可。

//查了半天错,结果是线段树打错了个字母。。。

AC代码如下

#include <cstdio>
#include <algorithm>
#include <vector>

// POJ 2763 "Housewife Wind".
//
// Input (repeated until EOF): n q s, then n-1 weighted edges "a b w", then q
// operations:
//   0 u      -> print the total weight of the path from the current node s
//               to u, then set s = u;
//   (else) i w -> set the weight of the i-th input edge (1-based) to w.
//
// Technique: heavy-light decomposition + point-assign / range-sum segment
// tree. Every edge weight is stored on its deeper endpoint (each non-root
// node owns exactly the edge to its parent), so a path sum is the sum of the
// node values on the path EXCLUDING the lowest common ancestor.

namespace {

const int MAXN = 200100;  // capacity of node-indexed arrays; requires n < MAXN

// Adjacency list; edge ids start at 1 and 0 terminates a list.
int head[MAXN];
int nxt[MAXN * 2];
int to[MAXN * 2];
int wt[MAXN * 2];
int edge_cnt = 0;

// The input edges, kept so a "change edge i" operation can find its deeper
// endpoint (edge_u[i] is normalised to the deeper end after decomposition).
int edge_u[MAXN];
int edge_v[MAXN];
int edge_w[MAXN];

// Heavy-light decomposition data, indexed by node id 1..n.
int parent[MAXN];     // 0 for the root
int depth[MAXN];      // root has depth 1
int subtree[MAXN];    // subtree size
int heavy[MAXN];      // child with the largest subtree, 0 for a leaf
int chain_top[MAXN];  // first (shallowest) node of the node's heavy chain
int pos[MAXN];        // 1-based position in the segment tree's base array
int node_at[MAXN];    // inverse of pos
int node_val[MAXN];   // weight of the edge to the parent (0 for the root)

long long seg[MAXN * 4];  // range sums; long long leaves headroom for large paths
int n;                    // node count of the current test case

void add_edge(int x, int y, int w) {
    ++edge_cnt;
    to[edge_cnt] = y;
    wt[edge_cnt] = w;
    nxt[edge_cnt] = head[x];
    head[x] = edge_cnt;
}

// Roots the tree at `root` and fills every HLD array. Implemented
// iteratively over a BFS order so a path-shaped tree with ~10^5 nodes cannot
// overflow the call stack (the MSVC-only stack-size pragma is not needed).
void decompose(int root) {
    std::vector<int> order;
    order.reserve(n);
    order.push_back(root);
    parent[root] = 0;
    depth[root] = 1;
    node_val[root] = 0;

    // Pass 1 (top-down): parents, depths, and the edge weight each child owns.
    for (std::size_t k = 0; k < order.size(); ++k) {
        const int u = order[k];
        subtree[u] = 1;
        heavy[u] = 0;
        for (int e = head[u]; e != 0; e = nxt[e]) {
            const int v = to[e];
            if (v == parent[u]) continue;
            parent[v] = u;
            depth[v] = depth[u] + 1;
            node_val[v] = wt[e];
            order.push_back(v);
        }
    }

    // Pass 2 (bottom-up): in reverse BFS order every descendant of u is
    // finished before u itself, so subtree[u] is final when it is pushed up.
    for (std::size_t k = order.size(); k-- > 1;) {
        const int u = order[k];
        const int p = parent[u];
        subtree[p] += subtree[u];
        if (heavy[p] == 0 || subtree[u] > subtree[heavy[p]]) heavy[p] = u;
    }

    // Pass 3: each chain head (the root, or a node that is not its parent's
    // heavy child) lays out its whole heavy chain at consecutive positions.
    int next_pos = 0;
    for (const int u : order) {
        if (u != root && heavy[parent[u]] == u) continue;  // inside a chain
        for (int x = u; x != 0; x = heavy[x]) {
            chain_top[x] = u;
            pos[x] = ++next_pos;
            node_at[next_pos] = x;
        }
    }
}

// Builds node k covering base positions [l, r].
void build(int k, int l, int r) {
    if (l == r) {
        seg[k] = node_val[node_at[l]];
        return;
    }
    const int mid = (l + r) / 2;
    build(2 * k, l, mid);
    build(2 * k + 1, mid + 1, r);
    seg[k] = seg[2 * k] + seg[2 * k + 1];
}

// Sets base position p to v inside node k covering [l, r].
void point_set(int k, int l, int r, int p, int v) {
    if (l == r) {
        seg[k] = v;
        return;
    }
    const int mid = (l + r) / 2;
    if (p <= mid) {
        point_set(2 * k, l, mid, p, v);
    } else {
        point_set(2 * k + 1, mid + 1, r, p, v);
    }
    seg[k] = seg[2 * k] + seg[2 * k + 1];
}

// Sum of base positions [ql, qr] intersected with node k's range [l, r].
long long range_sum(int k, int l, int r, int ql, int qr) {
    if (qr < l || r < ql) return 0;
    if (ql <= l && r <= qr) return seg[k];
    const int mid = (l + r) / 2;
    return range_sum(2 * k, l, mid, ql, qr) +
           range_sum(2 * k + 1, mid + 1, r, ql, qr);
}

// Total edge weight on the tree path x..y.
long long path_sum(int x, int y) {
    long long total = 0;
    while (chain_top[x] != chain_top[y]) {
        // Always lift the endpoint whose chain head is deeper.
        if (depth[chain_top[x]] < depth[chain_top[y]]) std::swap(x, y);
        total += range_sum(1, 1, n, pos[chain_top[x]], pos[x]);
        x = parent[chain_top[x]];
    }
    if (x == y) return total;
    if (depth[x] > depth[y]) std::swap(x, y);
    // x is now the LCA; its own value is the edge ABOVE it, so skip it.
    total += range_sum(1, 1, n, pos[x] + 1, pos[y]);
    return total;
}

}  // namespace

int main() {
    int q = 0;
    int s = 0;
    while (std::scanf("%d%d%d", &n, &q, &s) == 3) {
        if (n < 1 || n >= MAXN) break;  // outside the supported input range

        // Reset per-test-case state; the original never did, so a second
        // test case kept the previous case's edges and heavy children.
        edge_cnt = 0;
        std::fill(head, head + n + 1, 0);

        for (int i = 1; i < n; ++i) {
            std::scanf("%d%d%d", &edge_u[i], &edge_v[i], &edge_w[i]);
            add_edge(edge_u[i], edge_v[i], edge_w[i]);
            add_edge(edge_v[i], edge_u[i], edge_w[i]);
        }

        decompose(1);
        // Normalise each input edge so edge_u[i] is its deeper endpoint,
        // i.e. the node that stores the edge's weight.
        for (int i = 1; i < n; ++i) {
            if (depth[edge_u[i]] < depth[edge_v[i]]) std::swap(edge_u[i], edge_v[i]);
        }
        build(1, 1, n);

        while (q-- > 0) {
            int op = 0;
            std::scanf("%d", &op);
            if (op != 0) {
                int idx = 0;
                int w = 0;
                std::scanf("%d%d", &idx, &w);
                if (idx >= 1 && idx < n) point_set(1, 1, n, pos[edge_u[idx]], w);
            } else {
                int u = 0;
                std::scanf("%d", &u);
                std::printf("%lld\n", path_sum(s, u));  // no per-line flush
                s = u;
            }
        }
    }
    return 0;
}


0 0
原创粉丝点击