HDU 5909 Tree Cutting

来源:互联网 发布:java 线程池 如何使用 编辑:程序博客网 时间:2024/05/18 02:45

链接:http://acm.hdu.edu.cn/showproblem.php?pid=5909

题意:给定一棵无根树,每个点有一个权值;对每个 k∈[0, m),统计点权异或和恰好等于 k 的非空连通子图(子树)个数,答案对 10^9+7 取模。

分析:需要求出所有子树的异或和分布,题解给出的两种方法我都写了一遍。第一种是树形 DP + FWT 加速异或卷积,复杂度 O(n·m·log m)。第二种是树分治(点分治):因为是无根树,每次用树形 DP 统计包含当前重心的方案数,然后删掉重心,对剩下的各个连通块递归处理,复杂度 O(n·m·log n)。

fwt代码:

// HDU 5909 "Tree Cutting": tree DP + fast Walsh-Hadamard transform (xor convolution).
//
// The tree is rooted at vertex 1. dp[a][x] counts connected subgraphs whose highest
// vertex is a and whose vertex-weight xor equals x. When child c is merged into a:
//     dp[a] <- (dp[a] xor-convolved with dp[c]) + dp[a]
// i.e. the subgraph either continues through the edge a-c into c's side or it does not.
// The xor convolution uses an FWT of length m, so m is assumed to be a power of two
// and every weight w[i] is assumed to lie in [0, m).
#include<map>
#include<set>
#include<stack>
#include<cmath>
#include<queue>
#include<bitset>
#include<math.h>
#include<vector>
#include<string>
#include<stdio.h>
#include<cstring>
#include<iostream>
#include<algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;

typedef long long ll;
const int N = 1e3 + 30;        // bound for both the vertex count and m
const ll MOD = 1000000007;

int n, m, w[N], tot, u[N], v[2 * N], pre[2 * N];
ll d[N], ans[N], dp[N][N];

// Adds the undirected edge a-b to the adjacency lists (head u[], next pre[], target v[]).
void add(int a, int b) {
    v[tot] = b; pre[tot] = u[a]; u[a] = tot++;
    v[tot] = a; pre[tot] = u[b]; u[b] = tot++;
}

// In-place Walsh-Hadamard transform of a[0..len) for xor convolution.
// Input and output values are kept in [0, MOD); the original version let (x - y)
// go negative, so negative residues leaked into the printed answers.
// inv = 1 gives the forward transform. inv = (MOD+1)/2, the inverse of 2, gives
// the inverse transform: it is applied once per butterfly level, so over the
// log2(len) levels the result is divided by len in total.
void fwt(ll a[], int len, ll inv) {
    for (int h = 1; h < len; h <<= 1)
        for (int i = 0; i < len; i += h << 1)
            for (int j = i; j < i + h; j++) {
                ll x = a[j], y = a[j + h];
                a[j] = (x + y) % MOD * inv % MOD;
                a[j + h] = (x - y + MOD) % MOD * inv % MOD;
            }
}

// Fills dp[a] for the subtree of a (parent b) and adds it into ans[].
void dfs(int a, int b) {
    dp[a][w[a]] = 1;
    for (int i = u[a]; ~i; i = pre[i]) {
        int c = v[i];
        if (c == b) continue;
        dfs(c, a);
        for (int j = 0; j < m; j++) d[j] = dp[a][j];   // "edge a-c not used" term
        fwt(dp[a], m, 1);
        // dp[c] is overwritten here; dfs(c, a) has already added it into ans[].
        fwt(dp[c], m, 1);
        for (int j = 0; j < m; j++) dp[a][j] = dp[a][j] * dp[c][j] % MOD;
        fwt(dp[a], m, (MOD + 1) / 2);
        for (int j = 0; j < m; j++) dp[a][j] = (dp[a][j] + d[j]) % MOD;
    }
    for (int j = 0; j < m; j++) ans[j] = (ans[j] + dp[a][j]) % MOD;
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &m);
        for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
        memset(ans, 0, sizeof(ans));
        tot = 0;
        memset(u, -1, sizeof(u));
        for (int i = 1; i < n; i++) {
            int a, b;
            scanf("%d%d", &a, &b);
            add(a, b);
        }
        memset(dp, 0, sizeof(dp));
        dfs(1, 0);
        for (int i = 0; i < m - 1; i++) printf("%lld ", ans[i]);
        printf("%lld\n", ans[m - 1]);
    }
    return 0;
}

树分治代码:

// HDU 5909 "Tree Cutting": centroid (tree) divide and conquer.
//
// For each centroid a, get_ans counts the connected subgraphs of the current
// component that contain a, grouped by vertex-weight xor. Then a is removed and
// every remaining component is processed recursively. Each connected subgraph is
// counted exactly once: at the first centroid that it contains.
// Every weight w[i] is assumed to lie in [0, m).
#include<map>
#include<set>
#include<stack>
#include<cmath>
#include<queue>
#include<bitset>
#include<math.h>
#include<vector>
#include<string>
#include<stdio.h>
#include<cstring>
#include<iostream>
#include<algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;

typedef long long ll;
const int N = 1e3 + 30;        // bound for both the vertex count and m
const ll MOD = 1000000007;

int n, m, w[N], tot, u[N], v[2 * N], pre[2 * N];
ll ans[N], dp[N][N];

// Adds the undirected edge a-b to the adjacency lists (head u[], next pre[], target v[]).
void add(int a, int b) {
    v[tot] = b; pre[tot] = u[a]; u[a] = tot++;
    v[tot] = a; pre[tot] = u[b]; u[b] = tot++;
}

int sum, root, f[N], vis[N], siz[N];

// Searches the unvisited component containing a (parent b), whose size must be in `sum`.
// f[x] is the largest piece left after removing x. root ends up as the vertex with the
// smallest f, i.e. a centroid. The caller sets root = 0 beforehand; f[0] = n acts as a sentinel.
void get_root(int a, int b) {
    f[a] = 0;
    siz[a] = 1;
    for (int i = u[a]; ~i; i = pre[i])
        if (v[i] != b && !vis[v[i]]) {
            get_root(v[i], a);
            siz[a] += siz[v[i]];
            f[a] = max(f[a], siz[v[i]]);
        }
    f[a] = max(f[a], sum - siz[a]);
    if (f[a] < f[root]) root = a;
}

// Sets siz[] for the unvisited component containing a, rooted at a (parent b).
void get_size(int a, int b) {
    siz[a] = 1;
    for (int i = u[a]; ~i; i = pre[i])
        if (v[i] != b && !vis[v[i]]) {
            get_size(v[i], a);
            siz[a] += siz[v[i]];
        }
}

// Pushes the partial counts in dp[a] down to each unvisited child and back up, so that
// afterwards dp[a] counts every connected subgraph that contains the start vertex.
void get_ans(int a, int b) {
    for (int i = u[a]; ~i; i = pre[i])
        if (v[i] != b && !vis[v[i]]) {
            int c = v[i];
            for (int j = 0; j < m; j++) dp[c][j] = 0;
            // Extend every configuration of dp[a] by vertex c.
            for (int j = 0; j < m; j++) dp[c][j ^ w[c]] = (dp[c][j ^ w[c]] + dp[a][j]) % MOD;
            get_ans(c, a);
            for (int j = 0; j < m; j++) dp[a][j] = (dp[a][j] + dp[c][j]) % MOD;
        }
}

// Counts the subgraphs through centroid a, then recurses on each remaining component.
void dfs_div(int a, int b) {
    for (int i = 0; i < m; i++) dp[a][i] = 0;
    dp[a][w[a]] = 1;
    vis[a] = 1;
    get_ans(a, b);
    for (int i = 0; i < m; i++) ans[i] = (ans[i] + dp[a][i]) % MOD;
    for (int i = u[a]; ~i; i = pre[i])
        if (v[i] != b && !vis[v[i]]) {
            // siz[] left over from the search that found `a` was rooted at another vertex,
            // so siz[v[i]] may not be this component's size. Recount before choosing its centroid.
            get_size(v[i], 0);
            sum = siz[v[i]];
            root = 0;
            get_root(v[i], 0);
            dfs_div(root, 0);
        }
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &m);
        for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
        tot = 0;
        memset(u, -1, sizeof(u));
        for (int i = 1; i < n; i++) {
            int a, b;
            scanf("%d%d", &a, &b);
            add(a, b);
        }
        memset(vis, 0, sizeof(vis));
        memset(ans, 0, sizeof(ans));
        root = 0;
        f[0] = sum = n;
        get_root(1, 0);
        dfs_div(root, 0);
        for (int i = 0; i < m - 1; i++) printf("%lld ", ans[i]);
        printf("%lld\n", ans[m - 1]);
    }
    return 0;
}


1 0
原创粉丝点击