FFT学习笔记(DFT,IDFT)

来源:互联网 发布:数据库系统分为 编辑:程序博客网 时间:2024/06/07 06:23

昨天参悟了一天FFT,总算是理解了,今天的莫比乌斯反演也不太懂,干脆弃疗,决定来认真水一发博客。

什么是FFT?

FFT(Fast Fourier Transformation),即为快速傅氏变换,是离散傅氏变换(DFT)的快速算法,它是根据离散傅氏变换的奇、偶、虚、实等特性,对离散傅立叶变换的算法进行改进获得的。

FFT的作用?

主要用于加速多项式乘法(形如an x^n + a(n - 1) x^(n - 1) + …… + a1 x + a0),同时可以优化很多与多项式乘法相近的内容,比如高精度乘法(令x为10)。

先明确几个概念:



复数:

由两个部分组成,实数部分,虚数部分,形如 :a,ib(a为实数部分) 其中i^2 = -1,显然i不是一个实数。

复数的运算法则:

加法:实数部分相加,虚数部分相加

减法:实数部分相减,虚数部分相减

乘法:

我们来举一个例子:

(a,ib)* (c,id)

= ac + iad + ibc + i^2bd

= (ac - bd) + i(ad + bc)

=(ac - bd,i(ad + bc))

(i ^ 2 = -1)


我们考虑用坐标系来表示一下复数,


可以理解一下,对后文的一些讲解会有所帮助。(注意y轴的默认单位长度为i)

由坐标轴可以得出,复数(a,ib)的模长为sqrt(a^2 + b^2)

同理我么可以得出复数的乘法运算的直观体现,模长相乘,幅角相加。(自己可以带入两个(1,i1)计算来很好的理解)




多项式的系数表示与点值表示。

我们知道一个最高次项为n的多项式,有n + 1个系数,x^n……x^0的对应的系数。

如果我们将这n+1个系数构成一个n+1维的向量,显然可以唯一的确定出一个多项式。

那么这个向量就是系数表达式。


如果我们带入n个数字,求算出n个对应的值,那么这些值就构成了点值表达式。


我们同样可以认为这个点值表达式可以唯一确定出一个多项式。

证明如下:(转自Menci,鸣谢作者,链接详见左侧友情链接)

证明:假设命题不成立,存在两个不同的 n - 1n1 次多项式 A(x)A(x)B(x)B(x),满足对于任何 i \in [0,\ n - 1]i[0, n1],有 A(x_i) = B(x_i)A(xi)=B(xi)

令 C(x) = A(x) - B(x)C(x)=A(x)B(x),则 C(x)C(x) 也是一个 n - 1n1 次多项式。对于任何 i \in [0,\ n - 1]i[0, n1],有 C(x_i) = 0C(xi)=0

即 C(x)C(x) 有 nn 个根,这与代数基本定理(一个 n - 1n1 次多项式在复数域上有且仅有 n - 1n1 个根)相矛盾,故 C(x)C(x) 并不是一个 n - 1n1 次多项式,原命题成立,证毕。

插值:已知点值表达,求系数表达式


单位根:我们上文提及虚数可以在坐标系内表示。我们可以在坐标系内做半径为1的圆,作为单位元,如果把单位圆分成n分,那么最靠近x轴正半轴的一份的考上的边即为wn = w0,即为单位根,剩下的依次为w1,w2,w3……wn-1;

其中单位根的幅角2π/n ,由欧拉公式可以得出cos2k2n2π+isin2k2n2π=coskn2π+isinkn2π

我们在求解点值表达式时,通常带入单位根,举个例子

如果有n项,那么我们可以分别带入wn^0,wn^2,wn^(n-1),这样子便于计算,此结论是前人证明,在此不详细叙述。

同时我们随手得出几个小的结论。

没有什么比画图更能说明这个两个结论了。(结合上文提及的复数乘法)

折半定理:


\omega_{2n} ^ {2k} = \omega_n ^ k
ω2n2k=ωnk

ωnk+2n=ωnk



好,我们开始步入正题,使用FFT进行多项式乘法在nlogn的时间内进行运算。

我们简述一下FFT的流程,先将这两个多项式转换为点值表达式,然后在线性将两个多项式每一位相乘,然后将得出的新的点值表达式转换会系数表达式输出即可。

我们闲来考虑第一部分,将系数表达式转化为点值表达式。


我们先明确一下,我们一下所指的所有多项式,最高次项均为2^k - 1。如果不足,请默认在高位补零。

通常将系数表达式转化为点值表达式有两种方法,递归与迭代,递归由于传参可能涉及到数组,所以通常效率稍微差些,而迭代版则不存在这个问题,但是递归更便于理解,所以我们从递归说起。

我们定义一个函数DFT(vector<复数>) vector内存着每一位的系数,可以将系数表达式转化为点值表达式

首先我们先明确边界,如果只有一个数,那么系数表达式就是点值表达式,直接返回即可。

如果没有到边界,我们进行如下操作。

先将每个元素分别按照下标的奇偶处理,分别递归操作。

这时候我们举个栗子来观察一下

a3 (w4) ^ 3 + a2(w4)^2 + a1(w4)^1 + a0

按奇偶分成两个递归

a3(w2) ^ 1 + a1  a2(w2) ^ 1 + a0

我们可以根据之前的折半定理 

将递归出的左式转化为,a3(w4) ^ 2 + a1(w4),我么发现这个式子×w4即为上文的4项式子的奇数部分的值。

我们递归出的右式,通过折半定理,可以转化成 a2(w4)^2 + a0 即为4项式的偶数项的和。


我们上文讲的式对于带入一个值求出的点值表达式,而我们的DFT是返回带入n个单位根,每个答案分别存在vector中一位。

所以我们要宏观的再来考虑一下,我们用f(i)表示带入i,当前多项式得到的结果。

我们来宏观的考虑一下。

f(w0) ,f(w1),f(w2),f(w3)  对应的多项式 a3(wx)^3 + a2(wx) ^ 2 + a1(wx) ^ 1 + a0


f(w1),f(w0)  对应的多项式 a3(w x/2)^1 + a1    f(w1),f(w0) 对应的多项式 a2(w x/2)^1 + a0


根据我们上文的推理,我们可以得出

f(w0)  a3(w0)^3 + a2(w0) ^ 2 + a1(w0) ^ 1 + a0

=w0(a3(w0) ^ 1 + a1) +   a2(w0) ^ 1 + a0

= 左f(w0) * w0 + 右f(w0)


同理我们只需要将w1……wn-1同样处理,只是每次不是在 左×w0而是(w1……wn-1)即可。

这样子我认为已经很详细的写出了如何在递归中求出点值表达式了。

我们给出代码:

vector<pot> DFT(vector<pot> a){if (a.size() == 1) return a; vector<pot> a1,a0,y1,y0,ans;for (int i = 0;i < a.size();i++){if (i & 1)a1.push_back(a[i]);elsea0.push_back(a[i]);}y0 = DFT(a0,pd);y1 = DFT(a1,pd);pot wn;wn = pot(cos(2.0f * PI /(double) a.size()),sin(2.0f * PI / (double) a.size()));pot w = pot(1.0,0.0);ans.resize(a.size());for (int i = 0;i < a.size() / 2;i++){ans[i] = y0[i] + y1[i] * w;ans[i + (a.size() >> 1)] = y0[i] - y1[i] * w;w = w * wn;}return ans;}

现在我们只需要用上述代码将两个多项式的点值表达式,然后将两个多项式的点值表达式相乘,最后在利用下面的IDFT转化回系数表达式,就可以很轻易的求出多项式乘法了。

后来有些小伙伴私信我问上述代码的第二个for循环内的计算对称部分的为什么在y1前加了负号,这个问题之前忘记书写了,这里解释一下。

我们考虑ak((wn) ^ p)^k 对应的在右侧部分的为ak((wn) ^ (p + n/2))^k

然后我们上文有提及两个很基本的小性质,其中第二个可以把 (wn) ^ (p + n/2) 转换为-wn^p,而当我们位于右侧部分,也就是所谓的奇数部分,最外侧的^k不为偶数,算出的值为负,所以需要在右侧加上-号

有人问wn为什么那么算,就是在复数坐标轴上很简单的几何意义,可以自己画一下。


递归如果已经理解,那么迭代就非常容易理解了,在此给出代码,基本思路跟递归是相同的,只不过我们通过一个for循环来枚举长度而已,但注意此时我们发现迭代中缺少了递归中奇偶分类,但是非常幸运,我们是可以预先推算吃迭代的处理顺序,从而提前处理好奇偶数的位置关系 ,这里给出基于二分的nlogn处理方式,这里非常显然,不做任何讲解。然而有一种更加优美的写法,通过二进制的奇妙操作,在常数较短的时间内进行处理奇数偶数。

暴力nlogn写法:

void pre(int l, int r){if (l < r){int mid = (l + r) >> 1;static pot dl[500010], dr[500010];for (int i = l; i < r; i += 2){dl[(i - l) >> 1] = dla[i];dr[(i - l) >> 1] = dla[i + 1];}memcpy(dla + l, dl, (mid - l + 1) * sizeof(dla[0]));memcpy(dla + mid + 1, dr, (r - mid) * sizeof(dla[0]));for (int i = l; i < r; i += 2){dl[(i - l) >> 1] = dlb[i];dr[(i - l) >> 1] = dlb[i + 1];}memcpy(dlb + l, dl, (mid - l + 1) * sizeof(dlb[0]));memcpy(dlb + mid + 1, dr, (r - mid) * sizeof(dlb[0]));pre(l, mid);pre(mid + 1, r);}}

二进制翻转方法

我们可以用一个

000 001 010 011 100 101 110 111 0   1   2   3   4   5   6   7 0   2   4   6 - 1   3   5   7 0   4 - 2   6 - 1   5 - 3   7 0 - 4 - 2 - 6 - 1 - 5 - 3 - 7000 100 010 110 001 101 011 111(本段演示及相关二进制转化代码均转自Menci,鸣谢作者,链接详见友情链接)
int k = 0;while ((1 << k) < n) k++;for (int i = 0; i < n; i++) {    int t = 0;    for (int j = 0; j < k; j++) if (i & (1 << j)) t |= (1 << (k - j - 1));    if (i < t) std::swap(a[i], a[t]);}

void DFT()
{for (int mi = 2;mi <= n;mi <<= 1){pot wn;wn = pot(cos(PI * 2 / (double) mi),-sin(PI * 2 / (double) mi));for (int j = 0;j < n;j += mi){int midn = j + (mi >> 1);pot w = pot(1.0f,0);for (int k = j;k < midn;k++){pot tp = dla[k];dla[k] = dla[k] + w * dla[k + (mi >> 1)];dla[k + (mi >> 1)] = tp - w * dla[k + (mi >> 1)];tp = dlb[k];dlb[k] = dlb[k] + w * dlb[k + (mi >> 1)];dlb[k + (mi >> 1)] = tp - w * dlb[k + (mi >> 1)];w = w * wn;}}}}


我们来考虑一下,如何将点值表达式转化为系数表达式。

我们把从系数表达式求成点值表达式的过程抽象为矩阵乘法


A矩阵                                                 (wn的幂<0时)D矩阵 否则为V矩阵,即D是V的逆矩阵F矩阵


以下为具体的矩阵推导过程,(鸣谢xys在此的帮助,其github详见友情链接)

F=V×A(显而易见)

E=D×V(由一些奇妙的定理可得,E为长度为n的单位矩阵,即对角线为n,其余区域为0)`

I=1/nE (I为单位长度是1的单位矩阵)

 =1/nD × V

1/n D = V^(-1) 与F=V×A连理得

1/n DF = A


我们回头来看,这tmd不就是DFT的逆过程么,只需要在前面加一个1/n即可。

所以我们只需要在DFT内加一个小的改动,并将结果进行一个小处理即可。

具体改动是将

wn = pot(cos(2.0f * PI /(double) a.size()),sin(2.0f * PI / (double) a.size()));

在sin前面加上一个负号,并且在操作完全结束后,将ans数组/n即可。

至于为什么这么做,因为实际上,我们在操作中不会单独搞出一个D矩阵,而是继续将就着用V矩阵来节约代码量,所以我们需要将wn进行修改。而ans/n是因为我们

1/n DF = A 前有一个n/1,最后上传一下二合一的DFT与IDFT,传参为0或任意非零值。

迭代版本:

void DFT(int pd){for (int mi = 2;mi <= n;mi <<= 1){pot wn;if (!pd)wn = pot(cos(PI * 2 / (double) mi),sin(PI * 2 / (double) mi));elsewn = pot(cos(PI * 2 / (double) mi),-sin(PI * 2 / (double) mi));for (int j = 0;j < n;j += mi){int midn = j + (mi >> 1);pot w = pot(1.0f,0);for (int k = j;k < midn;k++){pot tp = dla[k];dla[k] = dla[k] + w * dla[k + (mi >> 1)];dla[k + (mi >> 1)] = tp - w * dla[k + (mi >> 1)];tp = dlb[k];dlb[k] = dlb[k] + w * dlb[k + (mi >> 1)];dlb[k + (mi >> 1)] = tp - w * dlb[k + (mi >> 1)];w = w * wn;}}}}


vector<pot> DFT(vector<pot> a,int pd){//printf("DFT : %d\n", a.size());if (a.size() == 1) return a; vector<pot> a1,a0,y1,y0,ans;for (int i = 0;i < a.size();i++){if (i & 1)a1.push_back(a[i]);elsea0.push_back(a[i]);}y0 = DFT(a0,pd);y1 = DFT(a1,pd);pot wn;if (pd == 0) wn = pot(cos(2.0f * PI /(double) a.size()),-sin(2.0f * PI / (double) a.size()));elsewn = pot(cos(2.0f * PI /(double) a.size()),sin(2.0f * PI / (double) a.size()));pot w = pot(1.0,0.0);ans.resize(a.size());for (int i = 0;i < a.size() / 2;i++){ans[i] = y0[i] + y1[i] * w;ans[i + (a.size() >> 1)] = y0[i] - y1[i] * w;w = w * wn;}return ans;}

最后强调一个细节,使用该方法来对齐2的n次方。

int l = la + lb + 2;int k = 0;while (l > 0){k++;l >>= 1;}n = 1 << k;


第一次写技术含量这么高的博客,如果有什么不周到的,大家可以提出,或私信我,有人问上面出现的wnw0的关系,我这默认wn== w0,酱紫。

看完全文的小伙伴辛苦了,有兴趣可以点点友链之类的。=ω=

0 0