hihocoder 1479 三等分 树型dp

来源:互联网 发布:python筛选excel数据 编辑:程序博客网 时间:2024/06/10 21:44

描述

小Hi最近参加了一场比赛,这场比赛中小Hi被要求将一棵树拆成3份,使得每一份中所有节点的权值和相等。

比赛结束后,小Hi发现虽然大家得到的树几乎一模一样,但是每个人的方法都有所不同。于是小Hi希望知道,对于一棵给定的有根树,在选取其中2个非根节点并将它们与它们的父亲节点分开后,所形成的三棵子树的节点权值之和能够两两相等的方案有多少种。

两种方案被看做不同的方案,当且仅当形成方案的2个节点不完全相同。

输入

每个输入文件包含多组输入,在输入的第一行为一个整数T,表示数据的组数。

每组输入的第一行为一个整数N,表示给出的这棵树的节点数。

接下来N行,依次描述结点1~N,其中第i行为两个整数Vi和Pi,分别描述这个节点的权值和其父亲节点的编号。

父亲节点编号为0的节点为这棵树的根节点。

对于30%的数据,满足3<=N<=100

对于100%的数据,满足3<=N<=100000, |Vi|<=100, T<=10

输出

对于每组输入,输出一行Ans,表示方案的数量。

样例输入
231 01 11 241 01 11 21 3
样例输出
10

统计所形成的三棵子树的节点权值之和能够两两相等的方案,等价于在这树上取两个不同且非根结点,形成三棵子树后子树节点权值之和两两相等。

 

树型dp求解,每个节点维护res(以这个节点为根的子树节点权值和),cnt(以这个节点为根的子树权值等于sum/3的节点个数)。

 

一开始能够想到的一个A节点res=sum/3,那么在这棵子树外再找一个res=sum/3的B节点进行组合不就得出一种方案了吗。可是这里面是分两类的

1.      B不是A的祖先,那么后来枚举B的时候A又被算了一次。记为2*s1

2.      B是A的祖先,其实这种情况是错误的,因为A、B分别取出后,A子树res=sum/3,B子树res=0(因为A子树本来就是B的一部分啊),这种方案是错误的要除去,且记为p

 

还有一种情况是一个节点res=sum*2/3,那么这个节点与其子树内不包括它自己,任意一个res=sum/3的节点相组合就是一种方案,记为s2。

 

因为不能选root,所以cnt对res[root]=sum/3情况不予考虑

2*s1+p=cnt[root]^2 - (res[x]=sum/3&&x!=root)

P= (cnt[x]>0&&res[x]=sum/3&&x!=root)

S2= (res[x]=sum*2/3&&x!=root)

最后ans=s1+s2

Dp时维护好数据最后求解即可

 

 

#include <bits/stdc++.h>using namespace std;typedef long long ll;const int maxn=1e5+8;vector<int> g[maxn];int root;ll sum;ll res[maxn];ll cnt[maxn];ll a[maxn];ll s2,s3,sig;void dfs(int x,int p){    res[x]=a[x];    cnt[x]=0;    if(g[x].size()<=1){        if(res[x]==sum)cnt[x]=1;        sig+=cnt[x];        //cout<<"xx="<<x<<" "<<sum<<" "<<res[x]<<" "<<cnt[x]<<endl;        return ;    }    for(int i=0;i<g[x].size();i++){        int u=g[x][i];        if(u==p)continue;        dfs(u,x);        res[x]+=res[u];        cnt[x]+=cnt[u];    }    if(res[x]==sum&&x!=root)cnt[x]+=1;    if(res[x]==sum&&x!=root)sig+=cnt[x];    if(cnt[x]>0&&res[x]==sum&&x!=root)s3+=(cnt[x]-1);    if(res[x]==2*sum&&x!=root)s2+=(res[x]==sum?cnt[x]-1:cnt[x]);}int main(){    int T;    scanf("%d",&T);    while(T--){        int n;        scanf("%d",&n);        for(int i=0;i<n+7;i++)g[i].clear();        sum=0;        for(int i=1;i<=n;i++){            int v,p;            scanf("%d%d",&v,&p);            a[i]=v;            g[i].push_back(p);            g[p].push_back(i);            if(p==0)root=i;            sum+=v;        }        if(sum%3){printf("0\n");continue;}        sum/=3;        s2=s3=sig=0;        dfs(root,0);//        for(int i=1;i<=n;i++){//            printf("id==%d res==%I64d cnt==%I64d\n",i,res[i],cnt[i]);//        }//        cout<<" s2=="<<s2<<" s3=="<<s3<<" sig=="<<sig<<endl;        ll ans=((cnt[root]*cnt[root]-sig)-s3)/2+s2;        cout<<ans<<endl;    }    return 0;}


0 0
原创粉丝点击