您的位置:首页 > 其它

FFT 模板

2016-05-17 21:53 253 查看

FFT(Fast Fourier Transformation/快速傅立叶变换),确切地说应该称之为 FDFT(Fast Discrete Fourier Transformation/快速离散傅立叶变换),因为 FFT 是为 DFT 问题而设计的一种快速算法。在深入讨论之前,有必要特别指出这一点。

DFT 问题:

给定一个复数域上的 $n-1$ 次多项式 $A(x)$ 的系数表示(cofficient representation)$(a_0, a_1,\dots, a_{n-1})$,求 $A(x)$ 的某个点-值表示(point-value representation):

\[((x_0, y_0), (x_1, y_1), (x_2, y_2), \dots, (x_{n-1}, y_{n-1}))\]

 

注意:数学上(或者信号处理上)所谓的离散傅里叶变换的定义跟这里的定义是不同的,我会补充这部分内容。 

 

FFT 的递归实现(《算法导论》)

 

#include <bits/stdc++.h>
#define rep(i, l, r) for(int i=l; i<r; i++)
using namespace std;
const double PI(acos(-1));

struct Complex{
double r, i;
Complex(double r, double i):r(r), i(i){}
Complex(int n):r(cos(2*PI/n)), i(sin(2*PI/n)){}    //!!error-prone
Complex():r(0), i(0){}    //default constructor
Complex &operator*=(const Complex &a){
double R=r*a.r-i*a.i, I=r*a.i+a.r*i;
r=R, i=I;
return *this;
}
Complex operator+(const Complex a){
return Complex(r+a.r, i+a.i);
}
Complex operator-(const Complex a){
return Complex(r-a.r, i-a.i);
}
Complex operator*(const Complex a){
return Complex(r*a.r-i*a.i, r*a.i+a.r*i);
}
void out(){
cout<<r<<' '<<i<<endl;
}
};

const int N(1<<17);
int ans
;
Complex a
, b
;
char s
, t
;

void bit_revrese_swap(Complex *a, int n){
for(int i=1, j=n>>1, k; i<n-1; i++){
if(i < j) swap(a[i],a[j]);
//tricky
for(k=n>>1; j>=k; j-=k, k>>=1);    //inspect the highest "1"
j+=k;
}
}

void FFT(Complex* a, int n, int t){
bit_revrese_swap(a, n);
for(int i=2; i<=n; i<<=1){
Complex wi(i*t);
for(int j=0; j<n; j+=i){
Complex w(1, 0);
for(int k=j, h=i>>1; k<j+h; k++){
Complex t=w*a[k+h], u=a[k];
a[k]=u+t;
a[k+h]=u-t;
w*=wi;
}
}
}
if(t==-1) rep(i, 0, n) a[i].r/=n;    //!!error-prone
}

int trans(int x){
int i=0;
for(; x>1<<i; i++);
return 1<<i;
}

int main(){
for(; ~scanf("%s%s", s, t); ){
int n=strlen(s), m=strlen(t), l=trans(n+m-1);
rep(i, 0, n) a[i]=Complex(s[n-1-i]-'0', 0);
rep(i, n, l) a[i]=Complex(0, 0);
rep(i, 0, m) b[i]=Complex(t[m-1-i]-'0', 0);
rep(i, m, l) b[i]=Complex(0, 0);

FFT(a, l, 1), FFT(b, l, 1);
rep(i, 0, l) a[i]*=b[i];
FFT(a, l, -1);
rep(i, 0, l) ans[i]=(int)(a[i].r+0.5); ans[l]=0;    //error-prone
rep(i, 0, l) ans[i+1]+=ans[i]/10, ans[i]%=10;
int c=l;
for(;c && !ans[c]; --c);
for(; ~c; putchar(ans[c--]+'0'));    //error-prone
puts("");
}
return 0;
}

 

 

 

Comment:

1. 此代码是为 HDU1402 写的。代码中,凡注释error-prone处,都应特别小心。我犯的最傻逼的错误是第9行,应当是2*PI,我写成PI了。

2. 这个FFT的递归实现完全是参照《算法导论》(3rd. Ed. Chapter 30, Polynomials and the FFT) 的,但这个实现常数大,空间复杂度高,TLE了

3. 迭代实现后面会补上

4. FFT的数值稳定性(精度)问题,还有待考虑。(UPD)多次做多项式乘法时,精度损失较快,这时将double 换成long double可缓解精度损失。

 

FFT的非递归实现(算法导论)

Version I: 手写Complex类

#include <bits/stdc++.h>
#define rep(i, l, r) for(int i=l; i<r; i++)
using namespace std;
const double PI(acos(-1));

struct Complex{
double r, i;
Complex(double r, double i):r(r), i(i){}
Complex(int n):r(cos(2*PI/n)), i(sin(2*PI/n)){}    //!!error-prone
Complex():r(0), i(0){}    //default constructor
Complex &operator *=(const Complex &a){
double R=r*a.r-i*a.i, I=r*a.i+a.r*i;
r=R, i=I;
return *this;
}
Complex operator+(const Complex a){
return Complex(r+a.r, i+a.i);
}
Complex operator-(const Complex a){
return Complex(r-a.r, i-a.i);
}
Complex operator*(const Complex a){
return Complex(r*a.r-i*a.i, r*a.i+a.r*i);
}
void out(){
cout<<r<<' '<<i<<endl;
}
};

const int N(1<<17);
int ans
;
Complex a
, b
;
char s
, t
;

void bit_revrese_swap(Complex *a, int n){
for(int i=1, j=n>>1, k; i<n-1; i++){
if(i < j) swap(a[i],a[j]);
//tricky
for(k=n>>1; j>=k; j-=k, k>>=1);    //inspect the highest "1"
j+=k;
}
}

void FFT(Complex* a, int n, int t){
bit_revrese_swap(a, n);
for(int i=2; i<=n; i<<=1){
Complex wi(i*t);
for(int j=0; j<n; j+=i){
Complex w(1, 0);
for(int k=j, h=i>>1; k<j+h; k++){
Complex t=w*a[k+h], u=a[k];
a[k]=u+t;
a[k+h]=u-t;
w*=wi;
}
}
}
if(t==-1) rep(i, 0, n) a[i].r/=n;    //!!error-prone
}

int trans(int x){
int i=0;
for(; x>1<<i; i++);
return 1<<i;
}

int main(){
for(; ~scanf("%s%s", s, t); ){
int n=strlen(s), m=strlen(t), l=trans(n+m-1);
rep(i, 0, n) a[i]=Complex(s[n-1-i]-'0', 0);
rep(i, n, l) a[i]=Complex(0, 0);
rep(i, 0, m) b[i]=Complex(t[m-1-i]-'0', 0);
rep(i, m, l) b[i]=Complex(0, 0);

FFT(a, l, 1), FFT(b, l, 1);
rep(i, 0, l) a[i]*=b[i];
FFT(a, l, -1);
rep(i, 0, l) ans[i]=(int)(a[i].r+0.5); ans[l]=0;    //error-prone
rep(i, 0, l) ans[i+1]+=ans[i]/10, ans[i]%=10;
int Complex=l;
for(;Complex && !ans[Complex]; --Complex);
for(; ~Complex; putchar(ans[Complex--]+'0'));    //error-prone
puts("");
}
return 0;
}

Commet:

1. bit_reverse_swap()函数是对算法导论上的bit_reverse_copy()的改进,将下标互为bit-reverse的两元素互换位置,就免去了copy所需的空间。

2.bit_revrese_copy()不太好懂,需要一点解释:

void bit_revrese_swap(Complex *a, int n){
for(int i=1, j=n>>1, k; i<n-1; i++){
if(i < j) swap(a[i],a[j]);
//tricky
for(k=n>>1; j>=k; j-=k, k>>=1);    //inspect the highest "1"
j+=k;
}
}

将$i$的bit-reverse记作$rev(i)$。

(i). 由于$rev(0)=1, rev(n-1)=n-1$($n$是$2$的幂),所以第2行的主循环可从$i$从$1$循环到$n-2$。同时$j$从$rev(1)= \frac{n}{2}$,“循环”到$rev(n-2)$。

(ii). 第3行的判断 if(i < j) 避免了重复交换

(iii).第5行的循环的作用就是将$j$从$rev(i)$变成$rev(i+1)$:

首先应当注意到,$i$的最低位恰是$rev(i)$的最高位。若$rev(i)$的最高位是$0$那么$rev(i+1)$就是$rev(i)+ \frac{n}{2}$,否则,$i$加上$1$后,最低位将变成$0$,并且向高一位进$1$。相应的,$rev(i+1)$的最高位应置$0$(即代码中的 j-=k),并且向低一位"进“$1$(对应代码中的 k>>=1)。这样从高位往低位检查,遇到$1$(对应代码中的条件j>=k)就进位,遇到$0$就退出循环。
3. 我写代码时把第58行的==写成了=,结果DEBUG 一个多小时。。。

Version II: 用C++标准库中的complex<double>类,代码短一些,但也会慢一些:

#include <bits/stdc++.h>
#define rep(i, l, r) for(int i=l; i<r; i++)
using namespace std;
const double PI(acos(-1));
typedef complex<double> C;

const int N(1<<17);
int ans
;
C a
, b
;
char s
, t
;

void bit_revrese_swap(C *a, int n){
for(int i=1, j=n>>1, k; i<n-1; i++){
if(i < j) swap(a[i],a[j]);
//tricky
for(k=n>>1; j>=k; j-=k, k>>=1);    //inspect the highest "1"
j+=k;
}
}

void FFT(C* a, int n, int t){
bit_revrese_swap(a, n);
for(int i=2; i<=n; i<<=1){
C wi(cos(t*2*PI/i), sin(t*2*PI/i));
for(int j=0; j<n; j+=i){
C w(1);
for(int k=j, h=i>>1; k<j+h; k++){
C t=w*a[k+h], u=a[k];
a[k]=u+t;
a[k+h]=u-t;
w*=wi;
}
}
}
if(t==-1) rep(i, 0, n) a[i]/=n;    //!!error-prone: typo ==/=
}

int trans(int x){
int i=0;
for(; x>1<<i; i++);
return 1<<i;
}

int main(){
for(; ~scanf("%s%s", s, t); ){
int n=strlen(s), m=strlen(t), l=trans(n+m-1);
rep(i, 0, n) a[i]=C(s[n-1-i]-'0');
rep(i, n, l) a[i]=C(0);
rep(i, 0, m) b[i]=C(t[m-1-i]-'0');
rep(i, m, l) b[i]=C(0);

FFT(a, l, 1), FFT(b, l, 1);
rep(i, 0, l) a[i]*=b[i];
FFT(a, l, -1);
rep(i, 0, l) ans[i]=(int)(a[i].real()+0.5); ans[l]=0;    //error-prone
rep(i, 0, l) ans[i+1]+=ans[i]/10, ans[i]%=10;
int p=l;
for(;p && !ans

; --p); for(; ~p; putchar(ans[p--]+'0')); //error-prone puts(""); } return 0; }

[p] 

 

 

 



内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: