UOJ 34 多项式乘法

来源:互联网 发布:腾讯云mysql 编辑:程序博客网 时间:2024/06/05 15:07

快速傅里叶变换

关于FFT网上的教材不多,而且大多与算法问题关系不大。强烈推荐一个。这个讲得真的很不错:从多项式乘法到快速傅里叶变换

本弱数学知识不够多,复数、单位根之类的知识都是下午临时补的。。。从下午开始看FFT,看到晚上,总算大概是把递归版FFT的思路看懂了吧。(迭代版的还没看懂。。。有空慢慢钻研)UPD(2017.4.3):已经成功实现了迭代版,详见这份代码的下一份代码……

注意到IDFT的时候需要把所有单位根取倒数,那么有一个复数倒数式子:
设复数Z=a+b i
那么1/Z=(a-bi) / [(a+bi)(a-bi)] =(a-bi) / (a²+b²) =a / (a²+b² ) - b i / (a²+b²)
然后我们发现单位根的a²+b²=sin²+cos²=1,所以只要变化连接a和bi的符号即可


递归版

代码打了注释,以供日后观赏学习

#include<cstdio>#include<cmath>#define N 1000000using namespace std;struct complex{    double r, i;    complex(double a=0, double b=0):r(a),i(b){}    complex operator + (complex a){return complex(r+a.r,i+a.i);}        complex operator - (complex a){return complex(r-a.r,i-a.i);}        complex operator * (complex a){return complex(a.r*r-a.i*i,a.r*i+r*a.i);}    complex operator ^ (int f){return complex(r,i*f);}//方便起见,我用异或表示是否取倒数 }a[N], b[N], w[N], temp[N];const double eps = 1e-2;int L;void FFT(int n, complex buffer[], int beg, int step, int f){    if(n==1)return;    int m=n>>1;    FFT(m,buffer,beg,step<<1,f);    FFT(m,buffer,beg+step,step<<1,f);    /*下面是最麻烦(复杂)的一部分*/    for(int i = 0; i < m; i++)    {        int pos=2*step*i;        //由递归的下一层为上一层贡献,由式子知上下层是两倍的关系         //[beg+2*i*step]和 [beg+(2*i+1)*step]贡献,因为递归下去之后step变大,两个贡献是交错的         temp[i]=buffer[beg+pos]+(w[i*step]^f)*buffer[beg+pos+step];//根据式子         temp[i+m]=buffer[beg+pos]-(w[i*step]^f)*buffer[beg+pos+step];//根据式子    }    for(int i = 0; i < n; i++)        buffer[beg+i*step]=temp[i];}void init_w(int n){    double pi=acos(-1.0);    for(int i = 0; i < n; i++)        w[i]=complex(cos(2*i*pi/n),sin(2*i*pi/n));//计算每一个单位根啦 }int main(){    int n, m;    scanf("%d%d",&n,&m);    for(int i = 0; i <= n; i++)        scanf("%lf",&a[i].r);    for(int i = 0; i <= m; i++)        scanf("%lf",&b[i].r);    for(L=1; L<=n+m; L<<=1);//补成2的幂,方便操作     init_w(L);//预处理单位根     FFT(L,a,0,1,1);    FFT(L,b,0,1,1);    for(int i = 0; i <= L; i++)        a[i]=a[i]*b[i];    FFT(L,a,0,1,-1);    for(int i = 0, ii = (m+n); i<= ii; i++)        printf("%d ",(int)((a[i].r+eps)/L));//小心精度误差,需要+eps    return 0;}

迭代版

做法思想都一样,都是奇偶分类然后大力化简……
和递归版不同的地方在于,递归版的分类是手动分组,通过每次设定step和beg来做。迭代版在刚开始就先把组分好了,然后就只要从下往上迭代直接做。

补充一下刚开始没理解好的东西。

在迭代的过程中,设当前迭代块的大小为i,则深度为n/i,记nn次单位根是wknk=0,1,...,n1
此时数组里的区间[l,l+i1]的第k(k=0,1,...,i1)个位置的含义是,wkn/i 在系数是al,al+1,...,al+i1的情况下的求值答案(当然这里的a是已经分组过的)。

最终迭代块为n,深度为1时就是答案。

两边用推出来的式子合并即可。

#include<cstdio> #include<cmath>#include<algorithm>#define N 400005 using namespace std;namespace Endless{    typedef double db;    struct complex    {        db r, i;        complex operator + (const complex &that) const {return (complex){r+that.r, i+that.i};}        complex operator - (const complex &that) const {return (complex){r-that.r, i-that.i};}        complex operator * (const complex &that) const {return (complex){r*that.r-i*that.i,r*that.i+i*that.r};}    }a[N], b[N], w[N], re_w[N];    const db pi = acos(-1.0);    int la, lb, n;    void init_w()    {        for(int i = 0; i < n; i++)        {            w[i] = (complex){cos(2*pi*i/n), sin(2*pi*i/n)};            re_w[i] = (complex){cos(2*pi*i/n), -sin(2*pi*i/n)};        }    }    void FFT(complex *a, complex *w)    {        for(int i = 0, j = 0; i < n; i++)        {            if(i > j) swap(a[i], a[j]);            for(int l = n >> 1; (j ^= l) < l; l >>= 1);        }        for(int i = 2; i <= n; i <<= 1)        {            int m = i >> 1;            for(int j = 0; j < n; j += i)            {                for(int k = 0; k < m; k++)                {                    complex tmp = w[n/i*k] * a[j+k+m];                    a[j+k+m] = a[j+k] - tmp;                     a[j+k] = a[j+k] + tmp;                }            }        }    }    void main()    {        scanf("%d%d",&la,&lb);         for(int i = 0; i <= la; i++) scanf("%lf",&a[i].r), a[i].i = 0;        for(int i = 0; i <= lb; i++) scanf("%lf",&b[i].r), b[i].i = 0;        for(n = 1; n <= (la+lb); n <<= 1); init_w();        FFT(a, w); FFT(b, w);        for(int i = 0; i < n; i++) a[i] = a[i] * b[i];        FFT(a, re_w);        for(int i = 0; i <= la+lb; i++) printf("%.0lf ",(a[i].r+1e-3)/n);    }}int main(){    Endless::main();}
1 0