[JZOJ5058]采蘑菇

来源:互联网 发布:php 记录日志 编辑:程序博客网 时间:2024/04/28 09:54

题目大意

给定一棵n个节点的树,每个点有一个颜色种类ci
对于每一个点x,你需要统计从x出发的所有路径的颜色种类数之和。

1n3×105,0cin


题目分治

首先这题虚树肯定可以做,这里不讲。
考虑使用点分治,先不考虑有多种颜色。假设我只想统计出现过某一种颜色的路径总数。
对于分治重心c,在分治过程中做到点x
 如果xc的路径上已经有了这一种颜色,那么x的答案显然就要加上当前分治层的点数减去x所在子树的点数。
 如果xc的路径上没有这种颜色,我们就要通过预处理求出所有到c路径(不包括c)包含该种颜色的路径条数统计出来,减去x所在子树中满足同样条件的路径数,加在x的答案上面。
现在考虑将其扩展到多种颜色上。
对于分治重心c,在分治过程中做到点x
 设其到c路径上颜色种类数为cnt,分治层点数减去x所在子树的点数是siz,答案就要加上cnt×siz
 对于那些没有出现过的颜色种类,我们考虑先预处理不包含颜色i的到c路径(不包含c)条数fi,那么答案就要加上每种没有出现过的颜色的f值,算的时候由于我们是深搜实现,一开始用sum记录f值和,每新出现一种颜色就减掉对应的f,消失一种颜色就加上,这样就可以快速求出要加的值了。当然,我们还要减去和x同一棵子树的路径,这个在实现的时候,我们再深搜计算一个子树之前先深搜一边把这个子树的路径从f中删掉就好了,深搜计算完之后再深搜加回去就好了。
时间复杂度O(nlogn)


代码实现

#include <iostream>#include <cstdio>#include <cctype>using namespace std;typedef long long LL;int read(){    int x=0,f=1;    char ch=getchar();    while (!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();    while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();    return x*f;}int buf[30];void write(LL x){    if (x<0) putchar('-'),x=-x;    for (;x;x/=10) buf[++buf[0]]=x%10;    if (!buf[0]) buf[++buf[0]]=0;    for (;buf[0];putchar('0'+buf[buf[0]--]));}const int N=300050;const int E=N<<1;int last[N],fa[N],size[N],col[N],f[N],ext[N],que[N];int nxt[E],tov[E];bool vis[N];LL ans[N];int n,tot,head,tail,cur;LL sum;void insert(int x,int y){tov[++tot]=y,nxt[tot]=last[x],last[x]=tot;}int core(int og){    int i,x,y,ret,rets=n,tmp;    for (head=0,fa[que[tail=1]=og]=0;head<tail;)        for (size[x=que[++head]]=1,i=last[x];i;i=nxt[i])            if ((y=tov[i])!=fa[x]&&!vis[y])                fa[que[++tail]=y]=x;    for (head=tail;head>1;--head) size[fa[que[head]]]+=size[que[head]];    for (head=1;head<=tail;++head)    {        for (tmp=size[og]-size[x=que[head]],i=last[x];i;i=nxt[i])            if ((y=tov[i])!=fa[x]&&!vis[y]) tmp=max(tmp,size[y]);        if (tmp<rets) ret=x,rets=tmp;    }    return ret;}void dfs(int x){    size[x]=1;    for (int i=last[x],y;i;i=nxt[i])        if ((y=tov[i])!=fa[x]&&!vis[y])            fa[y]=x,dfs(y),size[x]+=size[y];}void count(int x,int *f,int sig){    if (!ext[col[x]]++) f[col[x]]+=sig*size[x],sum+=sig*size[x],++cur;    for (int i=last[x],y;i;i=nxt[i])        if ((y=tov[i])!=fa[x]&&!vis[y]) count(y,f,sig);    if (!--ext[col[x]]) --cur;}void calc(int x,int siz,int c){    if (!ext[col[x]]++) sum-=f[col[x]],++cur;    ans[x]+=1ll*siz*cur+sum,ans[c]+=cur;    for (int i=last[x],y;i;i=nxt[i])        if ((y=tov[i])!=fa[x]&&!vis[y]) calc(y,siz,c);    if (!--ext[col[x]]) sum+=f[col[x]],--cur;}void solve(int x){    int c=core(x);    ++ans[c],size[c]=1;    for (int i=last[c],y;i;i=nxt[i])        if (!vis[y=tov[i]]) fa[y]=c,dfs(y),size[c]+=size[y];    for (int i=last[c],y;i;i=nxt[i])        if (!vis[y=tov[i]]) count(y,f,1);    for (int i=last[c],y;i;i=nxt[i])        if (!vis[y=tov[i]]) count(y,f,-1),++ext[col[c]],sum-=f[col[c]],cur=1,calc(y,size[c]-size[y],c),--ext[col[c]],sum+=f[col[c]],cur=0,count(y,f,1);    for (int i=last[c],y;i;i=nxt[i])        if (!vis[y=tov[i]]) count(y,f,-1);    vis[c]=1;    for (int i=last[c],y;i;i=nxt[i])        if (!vis[y=tov[i]]) solve(y);}int main(){    freopen("mushroom.in","r",stdin),freopen("mushroom.out","w",stdout);    n=read();    for (int i=1;i<=n;++i) col[i]=read();    for (int i=1,x,y;i<n;++i) x=read(),y=read(),insert(x,y),insert(y,x);    solve(1);    for (int i=1;i<=n;++i) write(ans[i]),putchar('\n');    fclose(stdin),fclose(stdout);    return 0;}
0 0