FFT之大数乘法

来源:互联网 发布:手机淘宝5.8.0 编辑:程序博客网 时间:2024/05/21 15:41
 #include <iostream>
#include <stdio.h>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <string>
#include <vector>

// Big-integer multiplication via FFT: each input is read as a string of
// decimal digits, the two digit sequences are convolved with a complex
// radix-2 FFT, and the convolution is carried back into base 10.

const double PI = std::acos(-1.0);

// Minimal complex number providing the operations the FFT needs.
struct Vir
{
    double re, im;
    Vir(double _re = 0., double _im = 0.) : re(_re), im(_im) {}
    Vir operator*(const Vir &r) const { return Vir(re * r.re - im * r.im, re * r.im + im * r.re); }
    Vir operator+(const Vir &r) const { return Vir(re + r.re, im + r.im); }
    Vir operator-(const Vir &r) const { return Vir(re - r.re, im - r.im); }
};

// Permutes a[0..len) so that each element moves to the index obtained by
// reversing the low `loglen` bits of its position.
// Assumes len == 1 << loglen.
void bit_rev(Vir *a, int loglen, int len)
{
    for (int i = 0; i < len; ++i)
    {
        int t = i, p = 0;
        for (int j = 0; j < loglen; ++j)
        {
            p = (p << 1) | (t & 1);
            t >>= 1;
        }
        // Swap each pair exactly once: only when the partner index is smaller.
        if (p < i)
            std::swap(a[p], a[i]);
    }
}

// In-place iterative radix-2 FFT of a[0..len).
//   loglen, len : transform size; len must equal 1 << loglen.
//   on          : +1 uses twiddles exp(+2*pi*i*j/m);
//                 -1 uses exp(-2*pi*i*j/m) and divides every output by len,
//                 so FFT(a, .., 1) followed by FFT(a, .., -1) restores a.
void FFT(Vir *a, int loglen, int len, int on)
{
    bit_rev(a, loglen, len);

    // Twiddle table reused across stages. Each factor is evaluated directly
    // with cos/sin instead of by repeatedly multiplying by the principal
    // root, which accumulates rounding error as the transform grows.
    std::vector<Vir> w(len / 2);
    for (int s = 1, m = 2; s <= loglen; ++s, m <<= 1)
    {
        const int half = m / 2;
        for (int j = 0; j < half; ++j)
        {
            const double ang = 2 * PI * on * j / m;
            w[j] = Vir(std::cos(ang), std::sin(ang));
        }
        for (int i = 0; i < len; i += m)
        {
            for (int j = 0; j < half; ++j)
            {
                const Vir u = a[i + j];
                const Vir v = w[j] * a[i + j + half];
                a[i + j] = u + v;
                a[i + j + half] = u - v;
            }
        }
    }
    if (on == -1)
    {
        for (int i = 0; i < len; ++i)
        {
            a[i].re /= len;
            a[i].im /= len;
        }
    }
}

// Reads pairs of non-negative decimal integers (whitespace separated) until
// input runs out and prints each product on its own line. Inputs are assumed
// to consist only of the characters '0'-'9'; this is not validated. Storage
// is sized per input, so there is no fixed length limit.
int main()
{
    std::string a, b;
    std::vector<Vir> pa, pb;
    std::vector<int> ans;
    std::string out;

    // Continue only while BOTH numbers were read; a lone trailing token ends the loop.
    while (std::cin >> a >> b)
    {
        const int lena = static_cast<int>(a.size());
        const int lenb = static_cast<int>(b.size());

        // The product has at most lena + lenb digits; round the transform
        // size up to a power of two that can hold all of them.
        int n = 1, loglen = 0;
        while (n < lena + lenb)
        {
            n <<= 1;
            ++loglen;
        }

        // Little-endian digit vectors (least significant digit first),
        // zero-padded to length n.
        pa.assign(n, Vir());
        pb.assign(n, Vir());
        for (int i = 0; i < lena; ++i)
            pa[i] = Vir(a[lena - 1 - i] - '0', 0.);
        for (int i = 0; i < lenb; ++i)
            pb[i] = Vir(b[lenb - 1 - i] - '0', 0.);

        FFT(&pa[0], loglen, n, 1);
        FFT(&pb[0], loglen, n, 1);
        for (int i = 0; i < n; ++i)
            pa[i] = pa[i] * pb[i];
        FFT(&pa[0], loglen, n, -1);

        // Round each convolution coefficient (non-negative in exact
        // arithmetic) to the nearest integer, then propagate carries.
        // ans[n] absorbs the final carry out of ans[n - 1].
        ans.assign(n + 1, 0);
        for (int i = 0; i < n; ++i)
            ans[i] = static_cast<int>(pa[i].re + 0.5);
        for (int i = 0; i < n; ++i)
        {
            ans[i + 1] += ans[i] / 10;
            ans[i] %= 10;
        }

        // Skip leading zeros but always keep at least one digit.
        int pos = lena + lenb - 1;
        while (pos > 0 && ans[pos] <= 0)
            --pos;

        out.clear();
        for (; pos >= 0; --pos)
            out += static_cast<char>('0' + ans[pos]);
        puts(out.c_str());
    }
    return 0;
}

 

0 0
原创粉丝点击