FFT

来源:互联网 发布:java优缺点 编辑:程序博客网 时间:2024/06/06 07:05

本博客只是部分理解,非教学对学习有一定帮助作用,但是建议配合书来看
没想到还没学会反演就先弄了FFT
首先FFT的作用是快速多项式乘法
具体思想就是我们发现对于一个多项式当我们把它看做函数的时候,n+1个点必能确定一个n次的多项式,具体证明用矩阵乘法,蒟蒻并不太懂。
那么我们发现如果两个多项式(函数)相乘同一个x对应的y值一定是原来两个y的积
那么我们不妨先将多项式用点表示然后乘起来再变回系数表达。
这个过程中我们发现

w2=1

在复数域中有n个解,且这些解将复数平面平分。之后我们根据欧拉老先生的公式
e^(iu)=cosu+i*sinu
之后我们变变变就变了
我也不太懂啊
我们发现这么变就是徒增复杂度
但是我们发现其实这个过程是可以二分的
具体原因是消去引理和拆分引理
这个证明的过程不妨看算导

看一个代码
这是一道模板题。

给你两个多项式,请输出乘起来后的多项式。

输入格式
第一行两个整数 nn 和 mm,分别表示两个多项式的次数。

第二行 n+1n+1 个整数,分别表示第一个多项式的 00 到 nn 次项前的系数。

第三行 m+1m+1 个整数,分别表示第一个多项式的 00 到 mm 次项前的系数。

输出格式
一行 n+m+1n+m+1 个整数,分别表示乘起来后的多项式的 00 到 n+mn+m 次项前的系数。

样例一
input

1 2
1 2
1 2 1

output

1 4 5 2

explanation

(1+2x)⋅(1+2x+x2)=1+4x+5x2+2x3(1+2x)⋅(1+2x+x2)=1+4x+5x2+2x3。

#include<bits/stdc++.h>#define pi acos(-1)#define N 300020using namespace std;typedef complex<double>E;int n,m;E a[N],b[N];void fft(E *x,int n,int type){    if(n==1)return ; //如果已经削成一次的就直接返回     E l[n>>1],r[n>>1];//将原来的n次(注意一定是2的某次幂) 分为两个n/2次    //fft的核心思想就是消去引理和拆分引理     for(int i=0;i<n;i+=2)        l[i>>1]=x[i],r[i>>1]=x[i+1];        //将原来的值分别赋给新的多项式     fft(l,n>>1,type);fft(r,n>>1,type);    //分别进行fft     E wn(cos(2*pi/n),sin(type*2*pi/n)),w(1,0),t;    for(int i=0;i<n>>1;i++,w*=wn)        t=w*r[i],x[i]=l[i]+t,x[i+(n>>1)]=l[i]-t;        //fft的系数变点 }int main(){    scanf("%d%d",&n,&m);    for(int i=0,x;i<=n;i++)scanf("%d",&x),a[i]=x;    for(int i=0,x;i<=m;i++)scanf("%d",&x),b[i]=x;    //新的多项式需要为N+M次的多项式     m=n+m;for(n=1;n<=m;n<<=1);//需要倍增到的最大值     fft(a,n,1);//对a,b分别进行倍增的fft     fft(b,n,1);    for(int i=0;i<=n;i++)a[i]=a[i]*b[i];     //点直接相乘     fft(a,n,-1);    //反向fft,即点变系数     for(int i=0;i<=m;i++)        printf("%d ",(int)(a[i].real()/n+0.5));}

刚才的非递归写法没保存就算了
找到了
我觉得zcy写的很好看我写的就很丑了

#include<bits/stdc++.h>using namespace std;inline int read(){    char ch=getchar();    int num=0,f=1;    while(ch<'0'||ch>'9') {        if(ch=='-' ) f=-1;        ch=getchar();    }    while(ch<='9'&&ch>='0') {        num=(num<<1)+(num<<2)+ch-'0';        ch=getchar();    }    return num*f;}#define N 3000200#define FFT#ifdef FFTtypedef complex<double> COM;int n,m,l;COM a[N],b[N];int r[N];#define pi acos(-1)void dft(COM *f,int kind)   //kind=1时为DFT,=-1时为IDFT{    for(int i=0;i<n;i++)if(i<r[i])swap(f[i],f[r[i]]);    for(int i=1;i<n;i<<=1)          //枚举当前做DFT的子序列长为2i    {        COM x,y,wn(cos(pi/i),kind*sin(pi/i));    //单位根ωi        for(int j=0;j<n;j+=i<<1)        {            COM w(1,0);            for(int k=0;k<i;k++)            {                x=f[j+k],y=w*f[j+i+k]   ;                f[j+k]=x+y,f[j+i+k]=x-y;                w=w*wn;            }        }    }    //if(kind==-1)for(int i=0;i<n;i++)f[i].real()/=n;}#endifint main(){    cin>>n>>m;    for(int i=0,x;i<=n;i++) a[i]=read();    for(int i=0,x;i<=m;i++) b[i]=read();    m=n+m; for(n=1;n<=m;n<<=1)l++;    for(int i=0;i<=n;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));    dft(a,1);    dft(b,1);    for(int i= 0;i<=n;i++)a[i]=a[i]*b[i];    dft(a,-1);    for(int i=0;i<=m;i++)        printf("%d ",(int)(a[i].real()/n+0.5));}

直接看NTT吧

#include<bits/stdc++.h>using namespace std;inline int read(){    char ch=getchar();    int num=0,f=1;    while(ch<'0'||ch>'9') {        if(ch=='-' ) f=-1;        ch=getchar();    }    while(ch<='9'&&ch>='0') {        num=(num<<1)+(num<<2)+ch-'0';        ch=getchar();    }    return num*f;}#define N 3000200#define FFT#ifdef FFT#define MOD 998244353int n,m,l;int a[N],b[N];int r[N];inline int QPow(int d,int z){  int ans=1;  for(;z;z>>=1,d=1ll*d*d%MOD)    if(z&1)ans=1ll*ans*d%MOD;  return ans;}void dft(int *f,int kind){    for(int i=0;i<n;i++)if(i<r[i])swap(f[i],f[r[i]]);    for(int i=1;i<n;i<<=1)    {        int x,y,gn=QPow(3,(MOD-1)/(i<<1));    //单位根ωi        for(int j=0;j<n;j+=i<<1)        {            int g=1;            for(int k=0;k<i;++k,g=1ll*g*gn%MOD)            {                x=f[j+k],y=1ll*g*f[j+i+k]%MOD ;                f[j+k]=(x+y)%MOD,f[j+i+k]=(x-y+MOD)%MOD;            }        }    }    if(kind==1) return ;    reverse(f+1,f+n);    int y=QPow(n,MOD-2);    for(int i=0;i<n;i++) f[i]=1ll*f[i]*y%MOD;}#endifint main(){    cin>>n>>m;//  cout<<n<<m<<endl;    for(int i=0;i<=n;++i) a[i]=read();    for(int i=0;i<=m;++i) b[i]=read();    m=n+m; for(n=1;n<=m;n<<=1)l++;    for(int i=0;i<n;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));    dft(a,1);    dft(b,1);    for(int i=0;i<n;++i)a[i]=1ll*a[i]*b[i]%MOD;    dft(a,-1);    for(int i=0;i<=m;++i)    printf("%d ",a[i]);    return 0;}
原创粉丝点击