您的位置:首页 > 其它

用 FFT 实现大数乘法

2016-01-06 15:34 337 查看
#include <stdio.h>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
using namespace std;
// Base buffer capacity unit (the global buffers are sized N * 2).
// A typed constant rather than the former unparenthesized `#define N 50500*2`,
// which expanded incorrectly inside larger expressions (e.g. `202000 / N`
// evaluated to 8 instead of 2). The value, and therefore every `N * 2`, is unchanged.
const int N = 50500 * 2;
// Pi to full double precision, derived from acos(-1).
const double PI = acos(-1.0);
// Minimal complex number used by the FFT: `re` is the real part, `im` the
// imaginary part. Both default to 0, so a default-constructed Vir is 0 + 0i.
struct Vir
{
    double re, im;
    Vir(double _re = 0., double _im = 0.) :re(_re), im(_im){}
    // Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
    // The operators are const members taking const references, so they also
    // work on const operands (the previous non-const versions did not).
    Vir operator*(const Vir &r) const { return Vir(re*r.re - im*r.im, re*r.im + im*r.re); }
    // Component-wise sum.
    Vir operator+(const Vir &r) const { return Vir(re + r.re, im + r.im); }
    // Component-wise difference.
    Vir operator-(const Vir &r) const { return Vir(re - r.re, im - r.im); }
};
18 void bit_rev(Vir *a, int loglen, int len)
19 {
20     for (int i = 0; i < len; ++i)
21     {
22         int t = i, p = 0;
23         for (int j = 0; j < loglen; ++j)
24         {
25             p <<= 1;
26             p = p | (t & 1);
27             t >>= 1;
28         }
29         if (p < i)
30         {
31             Vir temp = a[p];
32             a[p] = a[i];
33             a[i] = temp;
34         }
35     }
36 }
37 void FFT(Vir *a, int loglen, int len, int on)
38 {
39     bit_rev(a, loglen, len);
40
41     for (int s = 1, m = 2; s <= loglen; ++s, m <<= 1)
42     {
43         Vir wn = Vir(cos(2 * PI*on / m), sin(2 * PI*on / m));
44         for (int i = 0; i < len; i += m)
45         {
46             Vir w = Vir(1.0, 0);
47             for (int j = 0; j < m / 2; ++j)
48             {
49                 Vir u = a[i + j];
50                 Vir v = w*a[i + j + m / 2];
51                 a[i + j] = u + v;
52                 a[i + j + m / 2] = u - v;
53                 w = w*wn;
54             }
55         }
56     }
57     if (on == -1)
58     {
59         for (int i = 0; i < len; ++i) a[i].re /= len, a[i].im /= len;
60     }
61 }
62 char a[N * 2], b[N * 2];
63 Vir pa[N * 2], pb[N * 2];
64 int ans[N * 2];
65 int main()
66 {
67     while (scanf("%s%s", a, b) != EOF)
68     {
69         int lena = strlen(a);
70         int lenb = strlen(b);
71         int n = 1, loglen = 0;
72         while (n < lena + lenb) n <<= 1, loglen++;
73         for (int i = 0, j = lena - 1; i < n; ++i, --j)
74             pa[i] = Vir(j >= 0 ? a[j] - '0' : 0., 0.);
75         for (int i = 0, j = lenb - 1; i < n; ++i, --j)
76             pb[i] = Vir(j >= 0 ? b[j] - '0' : 0., 0.);
77         for (int i = 0; i <= n; ++i) ans[i] = 0;
78
79         FFT(pa, loglen, n, 1);
80         FFT(pb, loglen, n, 1);
81         for (int i = 0; i < n; ++i)
82             pa[i] = pa[i] * pb[i];
83         FFT(pa, loglen, n, -1);
84
85         for (int i = 0; i < n; ++i) ans[i] = pa[i].re + 0.5;
86         for (int i = 0; i<n; ++i) ans[i + 1] += ans[i] / 10, ans[i] %= 10;
87
88         int pos = lena + lenb - 1;
89         for (; pos>0 && ans[pos] <= 0; --pos);
90         for (; pos >= 0; --pos) printf("%d", ans[pos]);
91         puts("");
92     }
93     return 0;
94 }


 
内容来自用户分享和网络整理，不保证内容的准确性；如有侵权内容，可联系管理员处理。点击这里给我发消息
标签: