【BZOJ1036】[ZJOI2008]树的统计Count 树链剖分

来源:互联网 发布:数控g92螺纹编程实例 编辑:程序博客网 时间:2024/04/26 05:28

此题为树链剖分模板题：先将树剖分为若干条重链，使每条重链在线段树上占据一段连续区间，再用线段树维护区间最大值和区间和，即可支持单点修改、路径最大值查询和路径求和查询。

树链剖分部分请参看【算法杂谈_02】树链剖分

[ZJOI2008]树的统计Count 树链剖分 C++代码:

// BZOJ 1036 [ZJOI2008] Count: heavy-light decomposition + segment tree.
//
// Input: n, then n-1 undirected edges, then the n node weights, then m
// operations. An operation word starting with 'C' sets the weight of node u
// to t; a word whose second letter is 'M' prints the maximum weight on the
// path u..v; any other word prints the sum of weights on the path u..v
// (the problem uses CHANGE / QMAX / QSUM).
#include <algorithm>
#include <climits>
#include <cstdio>
#include <utility>

namespace {

constexpr int kMaxN = 300010;      // node capacity; nodes are numbered 1..n
constexpr int kNegInf = -INT_MAX;  // identity for max (same value as -0x7fffffff)

int n;
int weight[kMaxN];

// Forward-star adjacency list; every undirected edge is stored twice.
int edge_count;
int head[kMaxN];
int next_edge[2 * kMaxN];  // not `next`: that name is ambiguous with std::next
int edge_to[2 * kMaxN];

// Heavy-light decomposition data.
int parent[kMaxN];
int depth[kMaxN];
int subtree_size[kMaxN];  // not `size`: ambiguous with std::size since C++17
int heavy[kMaxN];         // child with the largest subtree, 0 for a leaf
int chain_top[kMaxN];     // shallowest node of the heavy chain containing a node
int pos_of[kMaxN];        // node -> segment-tree position in 1..n
int node_at[kMaxN];       // segment-tree position -> node
int bfs_order[kMaxN];

// Segment tree over positions 1..n. Sums are 64-bit so that a path of up to
// kMaxN nodes cannot overflow.
int seg_max[4 * kMaxN];
long long seg_sum[4 * kMaxN];

// Appends the directed edge from -> to to the adjacency list.
void add_edge(int from, int to) {
    edge_to[++edge_count] = to;
    next_edge[edge_count] = head[from];
    head[from] = edge_count;
}

// Roots the tree at `root`, computes parent/depth/subtree sizes/heavy children,
// and lays every heavy chain out on consecutive positions. Iterative, so a
// path-shaped tree cannot exhaust the call stack.
void decompose(int root) {
    int count = 0;
    bfs_order[count++] = root;
    parent[root] = 0;
    depth[root] = 1;
    for (int i = 0; i < count; ++i) {
        const int x = bfs_order[i];
        for (int e = head[x]; e != 0; e = next_edge[e]) {
            const int y = edge_to[e];
            if (y == parent[x]) continue;
            parent[y] = x;
            depth[y] = depth[x] + 1;
            bfs_order[count++] = y;
        }
    }

    // Reverse BFS order visits every child before its parent.
    for (int i = count - 1; i >= 0; --i) {
        const int x = bfs_order[i];
        subtree_size[x] = 1;
        heavy[x] = 0;  // subtree_size[0] stays 0, so any child beats "none"
        for (int e = head[x]; e != 0; e = next_edge[e]) {
            const int y = edge_to[e];
            if (y == parent[x]) continue;
            subtree_size[x] += subtree_size[y];
            if (subtree_size[y] > subtree_size[heavy[x]]) heavy[x] = y;
        }
    }

    // Each chain head (the root or a light child) claims the next block of
    // positions for its whole heavy chain, so every chain is contiguous.
    int next_pos = 0;
    for (int i = 0; i < count; ++i) {
        const int h = bfs_order[i];
        if (h != root && heavy[parent[h]] == h) continue;  // not a chain head
        for (int x = h; x != 0; x = heavy[x]) {
            chain_top[x] = h;
            pos_of[x] = ++next_pos;
            node_at[next_pos] = x;
        }
    }
}

// Recomputes an internal segment-tree node from its two children.
void pull(int node) {
    seg_sum[node] = seg_sum[2 * node] + seg_sum[2 * node + 1];
    seg_max[node] = std::max(seg_max[2 * node], seg_max[2 * node + 1]);
}

// Builds the segment tree for positions l..r from the node weights.
void build(int node, int l, int r) {
    if (l == r) {
        seg_max[node] = weight[node_at[l]];
        seg_sum[node] = weight[node_at[l]];
        return;
    }
    const int mid = (l + r) / 2;
    build(2 * node, l, mid);
    build(2 * node + 1, mid + 1, r);
    pull(node);
}

// Sets the value stored at position p to `value`.
void update(int node, int l, int r, int p, int value) {
    if (l == r) {
        seg_max[node] = value;
        seg_sum[node] = value;
        return;
    }
    const int mid = (l + r) / 2;
    if (p <= mid) {
        update(2 * node, l, mid, p, value);
    } else {
        update(2 * node + 1, mid + 1, r, p, value);
    }
    pull(node);
}

// Maximum over positions ql..qr (requires ql <= qr within l..r).
int query_max(int node, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return seg_max[node];
    const int mid = (l + r) / 2;
    int best = kNegInf;
    if (ql <= mid) best = std::max(best, query_max(2 * node, l, mid, ql, qr));
    if (qr > mid) best = std::max(best, query_max(2 * node + 1, mid + 1, r, ql, qr));
    return best;
}

// Sum over positions ql..qr (requires ql <= qr within l..r).
long long query_sum(int node, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return seg_sum[node];
    const int mid = (l + r) / 2;
    long long total = 0;
    if (ql <= mid) total += query_sum(2 * node, l, mid, ql, qr);
    if (qr > mid) total += query_sum(2 * node + 1, mid + 1, r, ql, qr);
    return total;
}

// Maximum node weight on the tree path u..v.
int path_max(int u, int v) {
    int best = kNegInf;
    while (chain_top[u] != chain_top[v]) {
        // Always lift the endpoint whose chain head is deeper.
        if (depth[chain_top[u]] < depth[chain_top[v]]) std::swap(u, v);
        best = std::max(best, query_max(1, 1, n, pos_of[chain_top[u]], pos_of[u]));
        u = parent[chain_top[u]];
    }
    if (depth[u] > depth[v]) std::swap(u, v);
    return std::max(best, query_max(1, 1, n, pos_of[u], pos_of[v]));
}

// Sum of node weights on the tree path u..v.
long long path_sum(int u, int v) {
    long long total = 0;
    while (chain_top[u] != chain_top[v]) {
        if (depth[chain_top[u]] < depth[chain_top[v]]) std::swap(u, v);
        total += query_sum(1, 1, n, pos_of[chain_top[u]], pos_of[u]);
        u = parent[chain_top[u]];
    }
    if (depth[u] > depth[v]) std::swap(u, v);
    return total + query_sum(1, 1, n, pos_of[u], pos_of[v]);
}

// Reads one operation word. Returns 1 for a change ('C...'), 2 for a path
// maximum (second letter 'M'), 3 for a path sum, or 0 if no word can be read.
int read_op() {
    char word[16];
    if (std::scanf("%15s", word) != 1) return 0;  // width bound prevents overflow
    if (word[0] == 'C') return 1;
    if (word[1] == 'M') return 2;
    return 3;
}

}  // namespace

int main() {
    // n <= 0 would make build(1, 1, 0) recurse forever.
    if (std::scanf("%d", &n) != 1 || n <= 0) return 0;
    for (int i = 1; i < n; ++i) {
        int x = 0;
        int y = 0;
        if (std::scanf("%d%d", &x, &y) != 2) return 0;
        add_edge(x, y);
        add_edge(y, x);
    }
    decompose(1);
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%d", &weight[i]) != 1) return 0;
    }
    build(1, 1, n);

    int m = 0;
    if (std::scanf("%d", &m) != 1) return 0;
    for (int i = 0; i < m; ++i) {
        const int op = read_op();
        int x = 0;
        int y = 0;
        if (op == 0 || std::scanf("%d%d", &x, &y) != 2) break;
        if (op == 1) {
            update(1, 1, n, pos_of[x], y);
        } else if (op == 2) {
            std::printf("%d\n", path_max(x, y));
        } else {
            std::printf("%lld\n", path_sum(x, y));
        }
    }
    return 0;
}


0 0
原创粉丝点击