HDU-6057 Kanade's convolution(多项式FWT)

来源:互联网 发布:sql统计总金额 编辑:程序博客网 时间:2024/06/14 08:37

HDU-6057
官方题解

解释一下官方题解:
因为y中为0的位置,x中若为1,肯定a,b该位置上都为1,若为0, a和b该位置都为0,即y中为0的位置,确定x后,a和b的该位置上的值都是确定的, 而y中为1的位置,x中肯定为1,a,b有 a1b0,a0b1两种选择。于是最后有2bit(y)种选择方法。
重写式子的前三步很容易理解
最后一步可以证明
若x xor y =k 则 x and y =y <==> bit[x]-bit[y]=bit[k]
从左向右 显然
从右向左 k中为0的位置,x和y一定相同。 而k中为1的位置,x和y的位置肯定不同,这些位置有bit[k]个, 只有当这些位置 x均为1,y均为0时 可以使条件成立 ,此时x and y = y。

如下代码实际执行了一个按位中1个数分类,然后根据位数之差FWT的过程
因为FWT可以线性相加,c[i]数组其实同时执行了b和a位数差为i的所有的FWT,它们的和累加在了一起。
注释里是错误写法,它没有把之前对应数组的所有FWT后的值累加到一个数组中

#include <bits/stdc++.h>using namespace std;const int MAXN=1<<20;const int MOD=998244353;const int inv2=(MOD+1)>>1;void fwt(int a[],int len,int mode){    if(mode)        for(int d=1;d<len;d<<=1)            for(int m=d<<1,i=0;i<len;i+=m)                for(int j=0;j<d;j++)                {                    int x=a[i+j],y=a[i+j+d];                    a[i+j]=(x+y)%MOD,a[i+j+d]=(x-y)%MOD;                }    else        for(int d=1;d<len;d<<=1)            for(int m=d<<1,i=0;i<len;i+=m)                for(int j=0;j<d;j++)                {                    int x=a[i+j],y=a[i+j+d];                    a[i+j]=1ll*(x+y)*inv2%MOD,a[i+j+d]=1ll*(x-y)*inv2%MOD;                }}int m,len,bit[MAXN];void init(){    len=1<<m;    for(int i=0;i<len;i++)        bit[i]=bit[i>>1]+(i&1);}int a[22][MAXN],b[22][MAXN],c[22][MAXN];int main(){    scanf("%d",&m);    init();    int ta,tb;    for(int i=0;i<len;i++)    {        scanf("%d",&ta);        a[bit[i]][i]=1ll*ta*(1<<bit[i])%MOD;    }    for(int i=0;i<len;i++)    {        scanf("%d",&ta);        b[bit[i]][i]=ta;    }    for(int i=0;i<=m;i++)    {        fwt(a[i],len,1);        fwt(b[i],len,1);    }    for(int i=0;i<=m;i++)        for(int j=i;j<=m;j++)            for(int k=0;k<len;k++)                c[j-i][k]=(c[j-i][k]+1ll*b[j][k]*a[i][k]%MOD)%MOD;//    for(int k=0;k<len;k++)//        for(int i=bit[k];i<=m;i++)//        {//            c[bit[k]][k]=c[bit[k]][k]+1ll*b[i][k]*a[i-bit[k]][k]%MOD;//        }    for(int i=0;i<=m;i++)    {        fwt(c[i],len,0);        for(int j=0;j<len;j++)            if(c[i][j]<0)                c[i][j]+=MOD;    }    long long ans=0,tmp=1;    for(int i=0;i<len;i++)    {        ans+=c[bit[i]][i]*tmp%MOD;        tmp=tmp*1526%MOD;    }    ans%=MOD;    printf("%lld\n",ans);    return 0;}
原创粉丝点击