HDU 5909 Tree Cutting (点分治+树形DP|FWT+树形DP)

来源:互联网 发布:java统一协议才能注册 编辑:程序博客网 时间:2024/06/06 02:36

题目描述

传送门

题目大意:给出一棵树,求异或和为[0..m-1]的非空连通子图的个数。

题解1

FWT+树形DP
f[i][j]表示以i为根异或和为j的连通子树的个数(注意必须是i的子树中)
f[x][j^k]=f[x][j^k]+f[x][j]f[son][k] 这个转移方程的瓶颈在于f[x][j]f[son][k],转移是O(m2)
可以发现转移实际上就是异或卷积,可以用FWT优化。
FWT这东西第一次接触,感觉原理就算懂了也会忘,所以直接背的板子。。。
时间复杂度O(nmlogm)

代码1

#include<iostream>#include<cstdio>#include<algorithm>#include<cstring>#include<cmath>#define N 2003#define p 1000000007#define LL long long using namespace std;const LL ret=(p+1)/2;LL dp[N][N],ans[N],tmp[N];int tot,n,m,nxt[N],v[N],point[N],val[N];void add(int x,int y){    tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;    tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;}void FWT(LL *a,int n){    for (int i=1;i<n;i<<=1)     for (int p1=i<<1,j=0;j<n;j+=p1)      for (int k=0;k<i;k++) {        LL x=a[j+k]; LL y=a[j+k+i];        a[j+k]=(x+y)%p;         a[j+k+i]=(x-y+p)%p;      }}void UFWT(LL *a,int n){    for (int i=1;i<n;i<<=1)     for (int p1=i<<1,j=0;j<n;j+=p1)      for (int k=0;k<i;k++){        LL x=a[j+k]; LL y=a[j+k+i];        a[j+k]=(x+y)%p*ret%p;        a[j+k+i]=((x-y)*ret%p+p)%p;      }}void solve(LL *a,LL *b,int n){    FWT(a,n); FWT(b,n);    for (int i=0;i<n;i++) a[i]=a[i]*b[i]%p;    UFWT(a,n);}void dfs(int x,int fa){    dp[x][val[x]]=1;    for (int i=point[x];i;i=nxt[i]) {        if (v[i]==fa) continue;        dfs(v[i],x);        for (int j=0;j<m;j++) tmp[j]=dp[x][j];        solve(dp[x],dp[v[i]],m);        for (int j=0;j<m;j++)          dp[x][j]=(tmp[j]+dp[x][j])%p;    }    for (int i=0;i<m;i++)     ans[i]=(ans[i]+dp[x][i])%p;}int main(){    freopen("a.in","r",stdin);//  freopen("my.out","w",stdout);    int T; scanf("%d",&T);    while (T--) {        tot=0;        memset(point,0,sizeof(point));        memset(dp,0,sizeof(dp));        memset(ans,0,sizeof(ans));        scanf("%d%d",&n,&m);        for (int i=1;i<=n;i++) scanf("%d",&val[i]);        for (int i=1;i<n;i++) {            int x,y; scanf("%d%d",&x,&y);            add(x,y);        }        dfs(1,0);        for (int i=0;i<m-1;i++) printf("%I64d ",ans[i]);        printf("%I64d\n",ans[m-1]);    }}

题解2

点分治+树形DP
对于每个节点,点分治到他的时候直接做树形依赖即可。
f[son][j^val[son]]=f[x][j],带入计算,最后用计算后f[x]数组更新总答案
这样每个点都之后被计算logn次,时间复杂度O(nmlogn)

代码2

#include<iostream>#include<cstdio>#include<algorithm>#include<cstring>#include<cmath>#define N 2003#define p 1000000007#define LL long long using namespace std;LL dp[N][N],ans[N];int tot,n,m,nxt[N],v[N],point[N],val[N],f[N],size[N],root,vis[N],sum;void add(int x,int y){    tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;    tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;}void getroot(int x,int fa){    f[x]=0; size[x]=1;    for (int i=point[x];i;i=nxt[i]) {        if (v[i]==fa||vis[v[i]]) continue;        getroot(v[i],x);        size[x]+=size[v[i]];        f[x]=max(f[x],size[v[i]]);    }    f[x]=max(f[x],sum-size[x]);    if (f[x]<f[root]) root=x;}void calc(int x,int fa){    for (int i=point[x];i;i=nxt[i]) {        if (v[i]==fa||vis[v[i]]) continue;        for (int j=0;j<m;j++) dp[v[i]][j^val[v[i]]]=dp[x][j];        calc(v[i],x);        for (int j=0;j<m;j++)         dp[x][j]=(dp[x][j]+dp[v[i]][j])%p;        for (int j=0;j<m;j++) dp[v[i]][j]=0;    }}void solve(int x){    vis[x]=1;    dp[x][val[x]]=1; calc(x,0);     for (int i=0;i<m;i++) ans[i]=(ans[i]+dp[x][i])%p;    for (int i=0;i<m;i++) dp[x][i]=0;    for (int i=point[x];i;i=nxt[i]) {        if (vis[v[i]]) continue;        sum=size[v[i]]; root=0;        getroot(v[i],x);        solve(root);    }}int main(){    freopen("a.in","r",stdin);//  freopen("my.out","w",stdout);    int T; scanf("%d",&T);    while (T--) {        tot=0;        memset(point,0,sizeof(point));        memset(dp,0,sizeof(dp));        memset(ans,0,sizeof(ans));        memset(vis,0,sizeof(vis));        scanf("%d%d",&n,&m);        for (int i=1;i<=n;i++) scanf("%d",&val[i]);        for (int i=1;i<n;i++) {            int x,y; scanf("%d%d",&x,&y);            add(x,y);        }        f[0]=p; root=0; sum=n;        getroot(1,0);         solve(root);        for (int i=0;i<m-1;i++) printf("%I64d ",ans[i]);        printf("%I64d\n",ans[m-1]);    }}
原创粉丝点击