hdu 4918 Query on the subtree (动态点分治+动态开点+线段树)

来源:互联网 发布:java特殊字符处理 编辑:程序博客网 时间:2024/06/05 06:13

题目描述

传送门

题目大意:一棵n个节点的树,每个节点有一个权值val
操作1:修改点x的权值
操作2:查询与x的距离小于等于d的节点的权值和。

题解

如果修改的话应该有很多种做法的。
首先建立重心树,对于每个点维护两棵权值线段树,一棵表示u(作为重心)的子树中到u距离为x的点的权值和,一棵表示到u的父重心距离为x的点的权值和。
那么每次查询的时候就是u的子树中距离为[0,d]的权值和+u在重心树中所有祖先子树中的答案。
第一部分直接从线段树中查,第二部分设u到祖先的距离为D,先得到祖先重心子树中所有[0,d-D]的权值和,再减去u所属的祖先的子重心子树中的答案,就是我们维护的第二棵线段树中的信息。
对于修改,每次只会影响logn个点的信息。
所以时间复杂度是O(nlog2n)

代码

#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>#include<cmath>#define N 200003#define inf 1000000000using namespace std;int n,m,tot,nxt[N],point[N],v[N],deep[N],fa[N][20],mi[20],belong[N];int rt[N],rtc[N],root,size[N],f[N],vis[N],sum,sz,ans,val[N];struct data{    int ls,rs,sum;}tr[N*60],c[N*60];void add(int x,int y){    tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;    tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;}void dfs(int x,int father){    deep[x]=deep[father]+1;    for (int i=1;i<=19;i++) {        if (deep[x]-mi[i]<0) break;        fa[x][i]=fa[fa[x][i-1]][i-1];    }    for (int i=point[x];i;i=nxt[i]){        if (v[i]==father) continue;        fa[v[i]][0]=x;        dfs(v[i],x);    }}int lca(int x,int y){    if (deep[x]<deep[y]) swap(x,y);    int k=deep[x]-deep[y];    for (int i=0;i<=19;i++)     if ((k>>i)&1) x=fa[x][i];    if (x==y) return x;    for (int i=19;i>=0;i--)     if (fa[x][i]!=fa[y][i])       x=fa[x][i],y=fa[y][i];    return fa[x][0];}void getroot(int x,int father){    f[x]=0; size[x]=1;    for (int i=point[x];i;i=nxt[i]){        if (v[i]==father||vis[v[i]]) continue;        getroot(v[i],x);        size[x]+=size[v[i]];        f[x]=max(f[x],size[v[i]]);    }    f[x]=max(f[x],sum-size[x]);    if (f[x]<f[root]) root=x;}int dis(int x,int y){    return deep[x]+deep[y]-2*deep[lca(x,y)];}void divi(int x,int father){    belong[x]=father; vis[x]=1;    for (int i=point[x];i;i=nxt[i]){        if (vis[v[i]]) continue;        root=0; sum=size[v[i]];        getroot(v[i],x);        divi(root,x);    }}void update(int now){    int l=tr[now].ls; int r=tr[now].rs;    tr[now].sum=tr[l].sum+tr[r].sum;}void insert(int &i,int l,int r,int x,int val){    if (!i) i=++sz,tr[i].ls=tr[i].rs=tr[i].sum=0;    if (l==r) {        tr[i].sum+=val;        return;    }     int mid=(l+r)/2;    if (x<=mid) insert(tr[i].ls,l,mid,x,val);    else insert(tr[i].rs,mid+1,r,x,val);    update(i);}int qjsum(int i,int l,int r,int ll,int rr){    if (ll>rr) return 0;    if (ll<=l&&r<=rr) return tr[i].sum;    int mid=(l+r)/2; int ans=0;    if (ll<=mid) ans+=qjsum(tr[i].ls,l,mid,ll,rr);    if (rr>mid) ans+=qjsum(tr[i].rs,mid+1,r,ll,rr);    return ans;}void change(int u,int v,int val){    int D=dis(u,v);     insert(rt[u],0,n,D,val);    if (!belong[u]) return;    int f=belong[u]; D=dis(f,v);     insert(rtc[u],0,n,D,val);    change(f,v,val);}void calc(int u,int son,int v,int d){    if (!u) return;    if (u==son) ans+=qjsum(rt[u],0,n,0,d);    else {        int D=dis(u,v);        ans+=qjsum(rt[u],0,n,0,d-D);        ans-=qjsum(rtc[son],0,n,0,d-D);     }    calc(belong[u],u,v,d);}int main(){    freopen("a.in","r",stdin);    freopen("my.out","w",stdout);    mi[0]=1;    for (int i=1;i<=19;i++) mi[i]=mi[i-1]*2;    while (scanf("%d%d",&n,&m)!=EOF) {        tot=0;        memset(point,0,sizeof(point)); sz=0;        memset(vis,0,sizeof(vis));        memset(fa,0,sizeof(fa));        memset(deep,0,sizeof(deep));        memset(rt,0,sizeof(rt));        memset(rtc,0,sizeof(rtc));        for (int i=1;i<=n;i++) scanf("%d",&val[i]);        for (int i=1;i<n;i++) {            int x,y; scanf("%d%d",&x,&y);            add(x,y);        }        dfs(1,0);        sum=n; f[0]=inf; root=0;        getroot(1,0); divi(root,0);        for (int i=1;i<=n;i++) change(i,i,val[i]);        for (int i=1;i<=m;i++) {            char s[10]; int x,v1; ans=0;            scanf("%s%d%d",s+1,&x,&v1);            if (s[1]=='!') {                change(x,x,v1-val[x]);                val[x]=v1;            }            if (s[1]=='?') calc(x,x,x,v1),printf("%d\n",ans);        }    }}
0 0
原创粉丝点击