【BZOJ】1036 [ZJOI2008]树的统计Count 树链剖分

来源:互联网 发布:河海大学单片机实验 编辑:程序博客网 时间:2024/06/07 01:50

题目传送门

这题还是树链剖分,并没有什么较大的改动,只是由于先前的手贱,现在打程序谨慎了很多,这一题一次就AC了。

这题和上一题相比,只是多了一个树上两点的路径和查询,这一点在洛谷的树链剖分模板就体现了。

给定两个节点,每次选取深度较大的一个节点沿着该节点所在的重链向上爬,并用线段树统计答案。当两个节点已经在同一重链上时,再用线段树统计一次两点间的答案就行了。

这就是统计两个节点间的所有节点的键值和的办法,最大值的统计和这个差不多。

附上AC代码:

#include <cstdio>#include <cctype>#include <string>#include <cstring>#include <algorithm>#define N 30010#define lt (k<<1)#define rt (k<<1|1)#define mid ((l+r)>>1)using namespace std;struct side{int to,nt;}s[N*2];struct tree{long long w,mx;}t[N*4];int n,m,a[N],x,y,size[N],dis[N],f[N],hs[N],h[N],num,id,top[N],wz[N];void read(int& a){static char c=getchar();a=0;int f=1;while (!isdigit(c)) {if (c=='-') f=-1;c=getchar();}while (isdigit(c)) a=a*10+c-'0',c=getchar();a*=f;return;}void add(int x,int y){s[num]=(side){y,h[x]},h[x]=num++;s[num]=(side){x,h[y]},h[y]=num++;}void so(int x){size[x]=1,dis[x]=dis[f[x]]+1;for (int i=h[x]; ~i; i=s[i].nt)if (s[i].to!=f[x]){f[s[i].to]=x,so(s[i].to),size[x]+=size[s[i].to];if (size[s[i].to]>size[hs[x]]) hs[x]=s[i].to;}return;}void so(int x,int fa){top[x]=fa,wz[x]=++id;if (!hs[x]) return;so(hs[x],fa);for (int i=h[x]; ~i; i=s[i].nt)if (s[i].to!=f[x]&&s[i].to!=hs[x])so(s[i].to,s[i].to);return;}void change(int k,int l,int r,int ql,int qr,long long w){if (l>qr||r<ql) return;if (l>=ql&&r<=qr){t[k].mx=t[k].w=w;return;}change(lt,l,mid,ql,qr,w),change(rt,mid+1,r,ql,qr,w);t[k].mx=max(t[lt].mx,t[rt].mx),t[k].w=t[lt].w+t[rt].w;return;}long long query_max(int k,int l,int r,int ql,int qr){if (l>qr||r<ql) return -2e9;if (l>=ql&&r<=qr) return t[k].mx;return max(query_max(lt,l,mid,ql,qr),query_max(rt,mid+1,r,ql,qr));}long long query_sum(int k,int l,int r,int ql,int qr){if (l>qr||r<ql) return 0;if (l>=ql&&r<=qr) return t[k].w;return query_sum(lt,l,mid,ql,qr)+query_sum(rt,mid+1,r,ql,qr);}long long find_max(int x,int y){long long sum=-2e9;while (top[x]!=top[y]){if (dis[top[x]]<dis[top[y]]) swap(x,y);sum=max(query_max(1,1,n,wz[top[x]],wz[x]),sum),x=f[top[x]];}if (dis[x]>dis[y]) swap(x,y);return max(query_max(1,1,n,wz[x],wz[y]),sum);}long long find_sum(int x,int y){long long sum=0;while (top[x]!=top[y]){if (dis[top[x]]<dis[top[y]]) swap(x,y);sum+=query_sum(1,1,n,wz[top[x]],wz[x]),x=f[top[x]];}if (dis[x]>dis[y]) swap(x,y);return sum+query_sum(1,1,n,wz[x],wz[y]);}int main(void){read(n),memset(h,-1,sizeof h);for (int i=1; i<n; ++i) read(x),read(y),add(x,y);for (int i=1; i<=n; ++i) read(a[i]);so(1),so(1,1);for (int i=1; i<=n; ++i) change(1,1,n,wz[i],wz[i],a[i]);read(m);while (m--){static char c=getchar();string s="";while (!isalpha(c)) c=getchar();while (isalpha(c)) s+=c,c=getchar();read(x),read(y);if (s=="CHANGE") change(1,1,n,wz[x],wz[x],y);if (s=="QMAX") printf("%lld\n",find_max(x,y));if (s=="QSUM") printf("%lld\n",find_sum(x,y));}return 0;}

0 0