HDU6035 Colorful Tree

来源:互联网 发布:淘宝导航条怎么装修 编辑:程序博客网 时间:2024/06/03 17:39

计算所有路径上不同颜色数的和

先假设每种颜色对每条路径都有贡献,再对每种颜色减去完全不经过该颜色结点的路径数:把该颜色的结点全部删掉后,树被分成若干连通块,大小为 s 的连通块内有 s(s-1)/2 条这样的路径。(一开始觉得答案很大,自信地取了模。。。实际上答案不需要取模,用 long long 直接输出即可。)

O(nlogn):树形DP+线段树合并

// HDU 6035 "Colorful Tree" -- O(n log n) solution: tree DP + segment-tree merge.
//
// Task: over all unordered pairs of distinct nodes, sum the number of distinct
// colours on the path between them.
//
// Idea: start from "every colour present in the tree contributes to every
// path" (n*(n-1)/2 * colnum) and subtract, per colour c, the number of paths
// that contain no node of colour c. Removing every c-coloured node splits the
// tree into components; a component of size s holds s*(s-1)/2 such paths.
//
// Per node u we keep a segment tree indexed by colour. After u is finished,
// position c holds the number of nodes in u's subtree that lie inside the
// subtree of some c-coloured node of u's subtree.
#include <stdio.h>
#include <math.h>
#include <string.h>
#include <stdlib.h>
#include <string>
#include <iostream>
#include <algorithm>

typedef long long ll;

const int MAXN = (int)2e5 + 10;
const int MOD = (int)1e9 + 7;  // unused: the answer is printed exactly as a 64-bit value

// Adjacency list stored as a forward-linked edge pool.
struct node {
    int to, nxt;
} ed[MAXN << 1];
int head[MAXN], cnt;

// Dynamically allocated segment trees over the colour range [1, n].
// Node 0 is the shared "empty" node: it is never written, so ls[0], rs[0] and
// sum[0] stay 0 for the whole run.
int root[MAXN];
int ls[MAXN * 20], rs[MAXN * 20], sum[MAXN * 20], tol;

// Sets position x of the tree rooted at rt (covering [l, r]) to y, creating
// missing nodes on the way, and refreshes the sums along the path.
void insert(int &rt, int l, int r, int x, int y) {
    if (rt == 0) {
        rt = ++tol;
        // Zero the node when it is handed out. This replaces a per-test-case
        // memset over the whole pool, so a test case only pays for the nodes
        // it actually uses.
        ls[rt] = rs[rt] = sum[rt] = 0;
    }
    if (l == r) {
        sum[rt] = y;
        return;
    }
    const int mid = (l + r) >> 1;
    if (x <= mid) {
        insert(ls[rt], l, mid, x, y);
    } else {
        insert(rs[rt], mid + 1, r, x, y);
    }
    sum[rt] = sum[ls[rt]] + sum[rs[rt]];
}

// Returns the value stored at position k of the tree rooted at rt (covering
// [l, r]); a missing node means 0.
int query(int rt, int l, int r, int k) {
    while (rt != 0 && l != r) {
        const int mid = (l + r) >> 1;
        if (k <= mid) {
            rt = ls[rt];
            r = mid;
        } else {
            rt = rs[rt];
            l = mid + 1;
        }
    }
    return rt == 0 ? 0 : sum[rt];
}

// Merges tree v into tree u position-wise (values are added) and returns the
// root of the result. Nodes of v may be adopted by u, so v must not be used
// afterwards. Recursion depth is bounded by the height of the tree.
int merge(int u, int v) {
    if (u == 0 || v == 0) return u | v;
    ls[u] = merge(ls[u], ls[v]);
    rs[u] = merge(rs[u], rs[v]);
    sum[u] += sum[v];
    return u;
}

// Appends the directed edge u -> v to the adjacency list.
void addedge(int u, int v) {
    ed[cnt].to = v;
    ed[cnt].nxt = head[u];
    head[u] = cnt++;
}

int col[MAXN], sz[MAXN], vis[MAXN];
int n;
ll de, ans;  // de: number of (path, colour) pairs where the path misses the colour

// Post-order traversal from u (whose parent is pre; pass 0 for the root).
//
// Fills sz[] and root[] for the subtree of u and adds to `de`, for every
// (parent p, child v) pair, the paths inside the component that hangs below p
// through v and contains no node of colour col[p].
//
// Implemented with an explicit stack: the recursive form needs one call frame
// per tree level and can overflow the call stack on a path-shaped tree with
// n around 2e5.
void dfs(int u, int pre) {
    static int stk[MAXN];       // nodes on the current root-to-node path
    static int par[MAXN];       // parent of each node on the path
    static int nextEdge[MAXN];  // next adjacency entry still to be examined
    static int covered[MAXN];   // nodes below p already inside a col[p]-coloured subtree

    int top = 0;
    par[u] = pre;
    nextEdge[u] = head[u];
    sz[u] = 1;
    covered[u] = 0;
    stk[top++] = u;

    while (top > 0) {
        const int cur = stk[top - 1];
        const int e = nextEdge[cur];
        if (e != -1) {
            nextEdge[cur] = ed[e].nxt;
            const int v = ed[e].to;
            if (v != par[cur]) {
                par[v] = cur;
                nextEdge[v] = head[v];
                sz[v] = 1;
                covered[v] = 0;
                stk[top++] = v;
            }
            continue;
        }

        // Every child of cur is finished.
        --top;

        // Insert first, merge second: insert() overwrites the leaf, so it must
        // run while root[cur] is still empty. After the merges the leaf for
        // col[cur] adds up to sz[cur].
        insert(root[cur], 1, n, col[cur], sz[cur] - covered[cur]);
        for (int i = head[cur]; i != -1; i = ed[i].nxt) {
            const int v = ed[i].to;
            if (v != par[cur]) {
                root[cur] = merge(root[cur], root[v]);
            }
        }

        if (top > 0) {
            // Report the finished child to its parent. root[cur] is queried
            // here, before the parent merges (and thereby invalidates) it.
            const int p = stk[top - 1];
            sz[p] += sz[cur];
            const ll qr = query(root[cur], 1, n, col[p]);
            const ll szv = sz[cur] - qr;
            de += szv * (szv - 1) / 2;
            covered[p] += (int)qr;
        }
    }
}

// Reads test cases until end of input. Each case: n, then n colours, then
// n-1 edges. Prints "Case #k: answer" per case.
// NOTE(review): colours are assumed to lie in [1, n] (the segment trees and
// the vis[] scan only cover that range) -- confirm against the problem statement.
int main() {
    int ca = 1;
    // Compare against 1 rather than testing for EOF only, so malformed input
    // (scanf returning 0) ends the loop instead of spinning forever.
    while (scanf("%d", &n) == 1) {
        de = 0;
        tol = 0;  // segment-tree nodes are re-zeroed lazily in insert()
        cnt = 0;
        for (int i = 1; i <= n; i++) {
            root[i] = 0;
            vis[i] = 0;
            head[i] = -1;
        }
        for (int i = 1; i <= n; i++) {
            if (scanf("%d", &col[i]) != 1) return 0;  // truncated input
            vis[col[i]] = 1;
        }
        for (int i = 1; i < n; i++) {
            int u, v;
            if (scanf("%d%d", &u, &v) != 2) return 0;  // truncated input
            addedge(u, v);
            addedge(v, u);
        }

        dfs(1, 0);

        // Component that contains the root, for every colour other than the
        // root's own.
        for (int i = 1; i <= n; i++) {
            if (vis[i] && i != col[1]) {
                const ll qr = query(root[1], 1, n, i);
                const ll szv = n - qr;
                de += szv * (szv - 1) / 2;
            }
        }

        int colnum = 0;
        for (int i = 1; i <= n; i++) colnum += vis[i];

        ans = 1LL * n * (n - 1) / 2 * colnum - de;
        printf("Case #%d: %lld\n", ca++, ans);
    }
    return 0;
}
O(n):

// HDU 6035 "Colorful Tree" -- O(n) solution with one global counter per colour.
//
// Task: over all unordered pairs of distinct nodes, sum the number of distinct
// colours on the path between them.
//
// Idea: start from "every colour present in the tree contributes to every
// path" (n*(n-1)/2 * colnum) and subtract, per colour c, the number of paths
// that contain no node of colour c. Removing every c-coloured node splits the
// tree into components; a component of size s holds s*(s-1)/2 such paths.
//
// sum[c] accumulates, over the nodes finished so far, how many nodes lie
// inside the subtree of some c-coloured node. The growth of sum[col[p]] while
// one child subtree of p is traversed is exactly the part of that child
// subtree that is cut off from p by nodes of p's colour.
#include <stdio.h>
#include <math.h>
#include <string.h>
#include <stdlib.h>
#include <string>
#include <iostream>
#include <algorithm>

typedef long long ll;

const int MAXN = (int)2e5 + 10;
const int MOD = (int)1e9 + 7;  // unused: the answer is printed exactly as a 64-bit value

// Adjacency list stored as a forward-linked edge pool.
struct node {
    int to, nxt;
} ed[MAXN << 1];
int head[MAXN], cnt;

int sum[MAXN];  // per-colour counter described in the header comment

// Appends the directed edge u -> v to the adjacency list.
void addedge(int u, int v) {
    ed[cnt].to = v;
    ed[cnt].nxt = head[u];
    head[u] = cnt++;
}

int col[MAXN], sz[MAXN], vis[MAXN];
int n;
ll de, ans;  // de: number of (path, colour) pairs where the path misses the colour

// Post-order traversal from u (whose parent is pre; pass 0 for the root).
//
// Fills sz[] for the subtree of u, updates sum[], and adds to `de`, for every
// (parent p, child v) pair, the paths inside the component that hangs below p
// through v and contains no node of colour col[p].
//
// Implemented with an explicit stack: the recursive form needs one call frame
// per tree level and can overflow the call stack on a path-shaped tree with
// n around 2e5.
void dfs(int u, int pre) {
    static int stk[MAXN];       // nodes on the current root-to-node path
    static int par[MAXN];       // parent of each node on the path
    static int nextEdge[MAXN];  // next adjacency entry still to be examined
    static int covered[MAXN];   // nodes below p already inside a col[p]-coloured subtree
    static int snap[MAXN];      // sum[col[p]] just before descending into the current child

    int top = 0;
    par[u] = pre;
    nextEdge[u] = head[u];
    sz[u] = 1;
    covered[u] = 0;
    stk[top++] = u;

    while (top > 0) {
        const int cur = stk[top - 1];
        const int e = nextEdge[cur];
        if (e != -1) {
            nextEdge[cur] = ed[e].nxt;
            const int v = ed[e].to;
            if (v != par[cur]) {
                // Remember the counter so that the child's contribution can be
                // measured as a difference once the child is finished.
                snap[cur] = sum[col[cur]];
                par[v] = cur;
                nextEdge[v] = head[v];
                sz[v] = 1;
                covered[v] = 0;
                stk[top++] = v;
            }
            continue;
        }

        // Every child of cur is finished: count the nodes of this subtree
        // that were not already counted under a deeper node of the same colour.
        --top;
        sum[col[cur]] += sz[cur] - covered[cur];

        if (top > 0) {
            const int p = stk[top - 1];
            sz[p] += sz[cur];
            const int qr = sum[col[p]] - snap[p];
            const int szv = sz[cur] - qr;
            de += 1LL * szv * (szv - 1) / 2;
            covered[p] += qr;
        }
    }
}

// Reads test cases until end of input. Each case: n, then n colours, then
// n-1 edges. Prints "Case #k: answer" per case.
// NOTE(review): colours are assumed to lie in [1, n] (the vis[] scan only
// covers that range) -- confirm against the problem statement.
int main() {
    int ca = 1;
    // Compare against 1 rather than testing for EOF only, so malformed input
    // (scanf returning 0) ends the loop instead of spinning forever.
    while (scanf("%d", &n) == 1) {
        de = 0;
        cnt = 0;
        for (int i = 1; i <= n; i++) {
            vis[i] = 0;
            head[i] = -1;
        }
        for (int i = 1; i <= n; i++) {
            if (scanf("%d", &col[i]) != 1) return 0;  // truncated input
            vis[col[i]] = 1;
            sum[col[i]] = 0;  // reset only the counters this test case uses
        }
        for (int i = 1; i < n; i++) {
            int u, v;
            if (scanf("%d%d", &u, &v) != 2) return 0;  // truncated input
            addedge(u, v);
            addedge(v, u);
        }

        dfs(1, 0);

        // Component that contains the root, for every colour other than the
        // root's own.
        for (int i = 1; i <= n; i++) {
            if (vis[i] && i != col[1]) {
                const int qr = sum[i];
                const int szv = n - qr;
                de += 1LL * szv * (szv - 1) / 2;
            }
        }

        int colnum = 0;
        for (int i = 1; i <= n; i++) colnum += vis[i];

        ans = 1LL * n * (n - 1) / 2 * colnum - de;
        printf("Case #%d: %lld\n", ca++, ans);
    }
    return 0;
}



原创粉丝点击