【HDU1402】 【FFT求大数乘法】

来源:互联网 发布:淘宝店加盟代理 编辑:程序博客网 时间:2024/05/21 11:01

传送门:HDU 1402 A * B Problem Plus

描述:

A * B Problem Plus

Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (Java/Others)
Total Submission(s): 18015    Accepted Submission(s): 3981


Problem Description
Calculate A * B.
 

Input
Each line will contain two integers A and B. Process to end of file.

Note: the length of each integer will not exceed 50000.
 

Output
For each case, output A * B in one line.
 

Sample Input
1210002
 

Sample Output
22000
 

Author
DOOM III
 

Recommend
DOOM III


分析:

这题的数据量是5w, 也就是传统意义上的n^2算法是不可取的。这里就用到了FFT加快计算


FFT一般的作用就是使得多项式乘法的复杂度降到nlogn。利用FFT可以快速求出循环卷积。

那么卷积又是什么样一个东西。

----------------------------------------以下内容转自http://blog.sina.com.cn/s/blog_6733026501019ubf.html--------------------

信号处理中的一个重要运算是卷积.初学卷积的时候,往往是在连续的情形,
  两个函数f(x),g(x)的卷积,是∫f(u)g(x-u)du
  当然,证明卷积的一些性质并不困难,比如交换,结合等等,但是对于卷积运算的来处,初学者就不甚了了。
  
  其实,从离散的情形看卷积,或许更加清楚,
  对于两个序列f[n],g[n],一般可以将其卷积定义为s[x]= ∑f[k]g[x-k]
  
  卷积的一个典型例子,其实就是初中就学过的多项式相乘的运算,
  比如(x*x+3*x+2)(2*x+5)
  一般计算顺序是这样,
  (x*x+3*x+2)(2*x+5)
  = (x*x+3*x+2)*2*x+(x*x+3*x+2)*5
  = 2*x*x*x+3*2*x*x+2*2*x+ 5*x*x+3*5*x+10
  然后合并同类项的系数,
  2 x*x*x
  3*2+1*5 x*x
  2*2+3*5 x
  2*5
  ----------
  2*x*x*x+11*x*x+19*x+10
  
  实际上,从线性代数可以知道,多项式构成一个向量空间,其基底可选为
  {1,x,x*x,x*x*x,...}
  如此,则任何多项式均可与无穷维空间中的一个坐标向量相对应,
  如,(x*x+3*x+2)对应于
  (1 3 2),
  (2*x+5)对应于
  (2,5).
  
  线性空间中没有定义两个向量间的卷积运算,而只有加法,数乘两种运算,而实际上,多项式的乘法,就无法在线性空间中说明.可见线性空间的理论多么局限了.
  但如果按照我们上面对向量卷积的定义来处理坐标向量,
  (1 3 2)*(2 5)
  则有
  2 3 1
  _ _ 2 5
  --------
      2
  
  
  2 3 1
  _ 2 5
  -----
    6+5=11
  
  2 3 1
  2 5
  -----
  4+15 =19
  
  
  _ 2 3 1
  2 5
  -------
    10
  
   或者说,
  (1 3 2)*(2 5)=(2 11 19 10)
  
  回到多项式的表示上来,
  (x*x+3*x+2)(2*x+5)= 2*x*x*x+11*x*x+19*x+10
  
  似乎很神奇,结果跟我们用传统办法得到的是完全一样的.
  换句话,多项式相乘,相当于系数向量的卷积.
  
  其实,琢磨一下,道理也很简单,
  卷积运算实际上是分别求 x*x*x ,x*x,x,1的系数,也就是说,他把加法和求和杂合在一起做了。(传统的办法是先做乘法,然后在合并同类项的时候才作加法)
  以x*x的系数为例,得到x*x,或者是用x*x乘5,或者是用3x乘2x,也就是
  2 3 1
  _ 2 5
  -----
   6+5=11
  其实,这正是向量的内积.如此则,卷积运算,可以看作是一串内积运算.既然是一串内积运算,则我们可以试图用矩阵表示上述过程。
  
  [ 2 3 1 0 0 0]
  [ 0 2 3 1 0 0]==A
  [ 0 0 2 3 1 0]
  [ 0 0 0 2 3 1]
  
  [0 0 2 5 0 0]' == x
  
  b= Ax=[ 2 11 19 10]'
  
  采用行的观点看Ax,则b的每行都是一个内积。
  A的每一行都是序列[2 3 1]的一个移动位置。
  
  ---------
  
  显然,在这个特定的背景下,我们知道,卷积满足交换,结合等定律,因为,众所周知的,多项式的乘法满足交换律,结合律.在一般情形下,其实也成立.
  
  在这里,我们发现多项式,除了构成特定的线性空间外,基与基之间还存在某种特殊的联系,正是这种联系,给予多项式空间以特殊的性质.
  
  在学向量的时候,一般都会举这个例子,甲有三个苹果,5个橘子,乙有5个苹果,三个橘子,则共有几个苹果,橘子。老师反复告诫,橘子就是橘子,苹果就是苹果,可不能混在一起。所以有(3,5)+(5,3)=(8,8).是的,橘子和苹果无论怎么加,都不会出什么问题的,但是,如果考虑橘子乘橘子,或者橘子乘苹果,这问题就不大容易说清了。
  
  又如复数,如果仅仅定义复数为数对(a,b),仅仅在线性空间的层面看待C2,那就未免太简单了。实际上,只要加上一条(a,b)*(c,d)=(ac-bd,ad+bc)
  则情况马上改观,复变函数的内容多么丰富多彩,是众所周知的。
  
  另外,回想信号处理里面的一条基本定理,频率域的乘积,相当于时域或空域信号的卷积.恰好跟这里的情形完全对等.这后面存在什么样的隐态联系,需要继续参详.
  
  从这里看,高等的卷积运算其实不过是一种初等的运算的抽象而已.中学学过的数学里面,其实还蕴涵着许多高深的内容(比如交换代数)。温故而知新,斯言不谬.
  
  其实这道理一点也不复杂,人类繁衍了多少万年了,但过去n多年,人们只知道男女媾精,乃能繁衍后代。精子,卵子的发现,生殖机制的研究,也就是最近多少年的事情。
  
  孔子说,道在人伦日用中,看来我们应该多用审视的眼光看待周围,乃至自身,才能知其然,而知其所以然。


----------------------------------------------------------完毕------------------------------


然后我们就知道卷积大概的作用了。

那么FFT本来是信号里面的东西,而我没学过信号。 所以看的也不怎么懂。

大概就是对离散的信号,先将其转变为一些正弦函数,然后这些正弦函数叠加能构成这个离散信号,但是这些正弦函数易于处理。处理完之后就可以再转变回来。

两个过程叫做DFT和IDFT。


如果令上面的x=10

那么就可以把两个大整数相乘看做是多项式乘法。

最后求出各系数后再进位即可


代码一:

[cpp] view plain copy
  1. #include <iostream>  
  2. #include <cstdio>  
  3. #include <algorithm>  
  4. #include <cstring>  
  5. #include <cmath>  
  6. #include <map>  
  7. #include <queue>  
  8. #include <set>  
  9. #include <vector>  
  10. using namespace std;  
  11. #define L(x) (1 << (x))  
  12. const double PI = acos(-1.0);  
  13. const int Maxn = 133015;  
  14. double ax[Maxn], ay[Maxn], bx[Maxn], by[Maxn];  
  15. char sa[Maxn/2],sb[Maxn/2];  
  16. int sum[Maxn];  
  17. int x1[Maxn],x2[Maxn];  
  18.   
  19. int revv(int x, int bits){  
  20.   int ret = 0;  
  21.   for (int i = 0; i < bits; i++){  
  22.     ret <<= 1;  
  23.     ret |= x & 1;  
  24.     x >>= 1;  
  25.   }  
  26.   return ret;  
  27. }  
  28.   
  29. void fft(double * a, double * b, int n, bool rev){  
  30.   int bits = 0;  
  31.   while (1 << bits < n) ++bits;  
  32.   for (int i = 0; i < n; i++){  
  33.     int j = revv(i, bits);  
  34.     if (i < j)  
  35.       swap(a[i], a[j]), swap(b[i], b[j]);  
  36.   }  
  37.   for (int len = 2; len <= n; len <<= 1){  
  38.     int half = len >> 1;  
  39.     double wmx = cos(2 * PI / len), wmy = sin(2 * PI / len);  
  40.     if (rev) wmy = -wmy;  
  41.     for (int i = 0; i < n; i += len){  
  42.       double wx = 1, wy = 0;  
  43.       for (int j = 0; j < half; j++){  
  44.         double cx = a[i + j], cy = b[i + j];  
  45.         double dx = a[i + j + half], dy = b[i + j + half];  
  46.         double ex = dx * wx - dy * wy, ey = dx * wy + dy * wx;  
  47.         a[i + j] = cx + ex, b[i + j] = cy + ey;  
  48.         a[i + j + half] = cx - ex, b[i + j + half] = cy - ey;  
  49.         double wnx = wx * wmx - wy * wmy, wny = wx * wmy + wy * wmx;  
  50.         wx = wnx, wy = wny;  
  51.       }  
  52.     }  
  53.   }  
  54.     if (rev)  
  55.     {  
  56.         for (int i = 0; i < n; i++)  
  57.             a[i] /= n, b[i] /= n;  
  58.     }  
  59. }  
  60.   
  61. int solve(int a[],int na,int b[],int nb,int ans[]){  
  62.   int len = max(na, nb), ln;  
  63.   for(ln=0; L(ln)<len; ++ln);  
  64.   len=L(++ln);  
  65.   for(int i = 0; i < len ; ++i){  
  66.     if (i >= na) ax[i] = 0, ay[i] =0;  
  67.     else ax[i] = a[i], ay[i] = 0;  
  68.   }  
  69.   fft(ax, ay, len, 0);  
  70.   for (int i = 0; i < len; ++i){  
  71.     if (i >= nb) bx[i] = 0, by[i] = 0;  
  72.     else bx[i] = b[i], by[i] = 0;  
  73.   }  
  74.   fft(bx, by, len, 0);  
  75.   for (int i = 0; i < len; ++i){  
  76.     double cx = ax[i] * bx[i] - ay[i] * by[i];  
  77.     double cy = ax[i] * by[i] + ay[i] * bx[i];  
  78.     ax[i] = cx, ay[i] = cy;  
  79.   }  
  80.   fft(ax, ay, len, 1);  
  81.   for(int i = 0; i < len; ++i)  
  82.     ans[i] = (int)(ax[i] + 0.5);  
  83.   return len;  
  84. }  
  85.   
  86. int main(){  
  87.   int l1,l2,l;  
  88.   int i;  
  89.   while(gets(sa)){  
  90.     gets(sb);  
  91.     memset(sum, 0, sizeof(sum));  
  92.     l1 = strlen(sa);  
  93.     l2 = strlen(sb);  
  94.     for(i = 0; i < l1; i++)  
  95.       x1[i] = sa[l1-i-1]-'0';  
  96.     for(i = 0; i < l2; i++)  
  97.       x2[i] = sb[l2-i-1]-'0';  
  98.     l = solve(x1, l1, x2, l2, sum);  
  99.     for(i = 0; i<l || sum[i] >= 10; i++){ // 进位  
  100.       sum[i + 1] += sum[i] / 10;  
  101.       sum[i] %= 10;  
  102.     }  
  103.     l = i;  
  104.     while(sum[l] <= 0 && l>0)   l--; // 检索最高位  
  105.     for(i = l; i >= 0; i--)   putchar(sum[i] + '0'); // 倒序输出  
  106.     putchar('\n');  
  107.   }  
  108.   return 0;  
  109. }  

乘法其实就是做线性卷积。

用DFT得方法可以求循环卷积,但是当循环卷积长度LN+M-1,就可以做线性卷积了。

使用FFT将两个数列转换成傅里叶域,在这的乘积就是时域的卷积。

 

给几个学习的链接吧:

http://wenku.baidu.com/view/8bfb0bd476a20029bd642d85.html  (这主要看那个FFT的流程图

http://wlsyzx.yzu.edu.cn/kcwz/szxhcl/kechenneirong/jiaoan/jiaoan3.htm   这有DFT的原理。

(得好好研究研究_(:зゝ∠)_)

代码二:

(bin神模板)

[cpp] view plain copy
  1. #include <stdio.h>  
  2. #include <string.h>  
  3. #include <iostream>  
  4. #include <algorithm>  
  5. #include <math.h>  
  6. using namespace std;  
  7.   
  8. const double PI = acos(-1.0);  
  9. //复数结构体  
  10. struct complex{  
  11.   double r,i;  
  12.   complex(double _r = 0.0,double _i = 0.0){  
  13.     r = _r; i = _i;  
  14.   }  
  15.   complex operator +(const complex &b){  
  16.     return complex(r+b.r,i+b.i);  
  17.   }  
  18.   complex operator -(const complex &b){  
  19.     return complex(r-b.r,i-b.i);  
  20.   }  
  21.   complex operator *(const complex &b){  
  22.     return complex(r*b.r-i*b.i,r*b.i+i*b.r);  
  23.   }  
  24. };  
  25. /* 
  26.  * 进行FFT和IFFT前的反转变换。 
  27.  * 位置i和 (i二进制反转后位置)互换 
  28.  * len必须去2的幂 
  29.  */  
  30. void change(complex y[],int len){  
  31.   int i,j,k;  
  32.   for(i = 1, j = len/2;i < len-1; i++){  
  33.     if(i < j)swap(y[i],y[j]);  
  34.     //交换互为小标反转的元素,i<j保证交换一次  
  35.     //i做正常的+1,j左反转类型的+1,始终保持i和j是反转的  
  36.     k = len/2;  
  37.     while( j >= k){  
  38.       j -= k;  
  39.       k /= 2;  
  40.     }  
  41.      if(j < k) j += k;  
  42.   }  
  43. }  
  44. /* 
  45.  * 做FFT 
  46.  * len必须为2^k形式, 
  47.  * on==1时是DFT,on==-1时是IDFT 
  48.  */  
  49. void fft(complex y[],int len,int on){  
  50.   change(y,len);  
  51.   for(int h = 2; h <= len; h <<= 1){  
  52.     complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));  
  53.     for(int j = 0;j < len;j+=h){  
  54.       complex w(1,0);  
  55.       for(int k = j;k < j+h/2;k++){  
  56.         complex u = y[k];  
  57.         complex t = w*y[k+h/2];  
  58.         y[k] = u+t;  
  59.         y[k+h/2] = u-t;  
  60.         w = w*wn;  
  61.       }  
  62.     }  
  63.   }  
  64.   if(on == -1)  
  65.   for(int i = 0;i < len;i++)  
  66.     y[i].r /= len;  
  67. }  
  68.   
  69. const int maxn = 200010;  
  70. complex x1[maxn],x2[maxn];  
  71. char str1[maxn/2],str2[maxn/2];  
  72. int sum[maxn];  
  73. int main(){  
  74.   while(scanf("%s%s",str1,str2)==2){  
  75.     int len1 = strlen(str1);  
  76.     int len2 = strlen(str2);  
  77.     int len = 1;  
  78.     while(len < len1*2 || len < len2*2)len<<=1;  
  79.     for(int i = 0;i < len1;i++)  
  80.       x1[i] = complex(str1[len1-1-i]-'0',0);  
  81.     for(int i = len1;i < len;i++)  
  82.       x1[i] = complex(0,0);  
  83.     for(int i = 0;i < len2;i++)  
  84.       x2[i] = complex(str2[len2-1-i]-'0',0);  
  85.     for(int i = len2;i < len;i++)  
  86.       x2[i] = complex(0,0);  
  87.     //求DFT  
  88.     fft(x1,len,1);  
  89.     fft(x2,len,1);  
  90.     for(int i = 0;i < len;i++)  
  91.       x1[i] = x1[i]*x2[i];  
  92.     fft(x1,len,-1);  
  93.     for(int i = 0;i < len;i++)  
  94.       sum[i] = (int)(x1[i].r+0.5);  
  95.     for(int i = 0;i < len;i++){   // 进位  
  96.       sum[i+1]+=sum[i]/10;  
  97.       sum[i]%=10;  
  98.     }  
  99.     len = len1+len2-1;  
  100.     while(sum[len] <= 0 && len > 0)len--;// 检索最高位  
  101.     for(int i = len;i >= 0;i--)  // 倒序输出  
  102.       printf("%c",sum[i]+'0');  
  103.     printf("\n");  
  104.   }  
  105.   return 0;  
  106. }