您的位置:首页 > 其它

hdu 5730(分治FFT)

2016-07-23 19:50 337 查看
$dp[i] = \sum_{j=1}^{i} a_j \cdot dp[i-j] \pmod{313}$, with $dp[0] = 1$.

这里

#include <bits/stdc++.h>

// MSVC-specific linker directive requesting a ~100 MB stack; other compilers
// typically ignore unknown pragmas. The recursion in solve() halves the range
// each level, so it is only O(log n) deep -- this looks like template boilerplate.
#pragma comment(linker, "/STACK:102400000,102400000")
// Relied upon below for unqualified swap / cos / sin / acos.
using namespace std;

// Contest-template shorthands. In the code shown here only Pi (used by FFT)
// and N (buffer sizes) are referenced; LL, pii, MP, the segment-tree helpers
// ls/rs/md/lson/rson, mod, eps, inf and M appear to be unused leftovers.
// NOTE(review): ls/rs are not fully parenthesized -- safe only as array indices.
#define LL long long
#define pii pair<int,int>
#define MP make_pair
#define ls i << 1
#define rs ls | 1
#define md (ll + rr >> 1)
#define lson ll, md, ls
#define rson md + 1, rr, rs
// Expands to acos(-1.0) at every use site.
#define Pi acos(-1.0)
// Not the modulus this solution uses: 313 is hard-coded in solve() and main().
#define mod 1000000007
#define eps 1e-12
#define inf 0x3f3f3f3f
// N sizes the global arrays (the FFT buffers x/y are N<<2 long).
#define N 200010
#define M 1200020

// Minimal complex number used as the element type of the FFT buffers.
// r is the real part, i the imaginary part; both default to zero, so
// Complex(x) is the real number x.
struct Complex {
    double r, i;

    Complex(double re = 0, double im = 0) : r(re), i(im) {}

    // Component-wise sum.
    Complex operator+(const Complex &o) const {
        return Complex{r + o.r, i + o.i};
    }

    // Component-wise difference.
    Complex operator-(const Complex &o) const {
        return Complex{r - o.r, i - o.i};
    }

    // (r + i*j)(o.r + o.i*j) = (r*o.r - i*o.i) + (r*o.i + o.r*i)*j
    Complex operator*(const Complex &o) const {
        const double re = r * o.r - i * o.i;
        const double im = r * o.i + o.r * i;
        return Complex{re, im};
    }
};
// Permutes y[0..len) in place into bit-reversed index order, the input order
// expected by the iterative butterflies in FFT(). Intended for power-of-two
// len (the only lengths solve() produces).
// `rev` tracks the bit-reversed partner of `pos`; each step performs a
// "reversed increment": the carry runs from the top bit downward.
void change(Complex *y, int len){
    int rev = len >> 1;                  // partner of pos == 1
    for (int pos = 1; pos < len - 1; ++pos) {
        if (pos < rev)
            swap(y[pos], y[rev]);        // swap each pair exactly once
        int bit = len >> 1;
        for (; rev >= bit; bit >>= 1)
            rev -= bit;                  // clear the set high bits
        rev += bit;                      // loop exit guarantees rev < bit
    }
}
// In-place iterative radix-2 FFT of y[0..len).
// len is expected to be a power of two (solve() only passes such lengths).
// on = 1  : forward transform with twiddle exp(-2*pi*i/h).
// on = -1 : inverse transform with twiddle exp(+2*pi*i/h), followed by a
//           division by len of the REAL parts only -- callers read back .r.
void FFT(Complex *y, int len, int on){
// Reorder into bit-reversed index order so the butterflies can work in place.
change(y, len);
// h = length of the sub-transforms being combined at this stage.
for(int h = 2; h <= len; h <<= 1){
// h-th root of unity exp(-on*2*pi*i/h).
Complex wn = Complex(cos(-on*2*Pi/h), sin(-on*2*Pi/h));
for(int j = 0; j < len; j += h){
// Twiddle starts at 1 and is advanced by repeated multiplication by wn,
// so its rounding error grows with h.
Complex w = Complex(1, 0);
for(int k = j; k < j + h / 2; ++k){
// Butterfly: combine y[k] with the twiddled element half a block away.
Complex u = y[k];
Complex t = w * y[k+h/2];
y[k] = u + t;
y[k+h/2] = u - t;
w = w * wn;
}
}
}
// Inverse scaling by 1/len; the imaginary parts are deliberately left unscaled.
if(on == -1){
for(int i = 0; i < len; ++i)
y[i].r /= len;
}
}
int dp
, a
;
Complex x[N<<2], y[N<<2];
// CDQ divide-and-conquer with FFT for the recurrence
//   dp[i] = sum_{j=0}^{i-1} a[i-j] * dp[j]   (mod 313).
// Precondition: for every i in [L, R], dp[i] already holds (mod 313) the
// contributions of all dp[j] with j < L.
// Postcondition: dp[L..R] are final.
// Uses the global buffers x and y as scratch.
void solve(int L, int R){
    if (L == R) return;              // dp[L] already has every contribution
    int mid = (L + R) >> 1;
    solve(L, mid);                   // finalize the left half first

    // Transform length: smallest power of two strictly greater than R-L+1.
    // Cyclic wrap-around then only lands below index mid-L+1, i.e. outside
    // the range read back below.
    int len = 1, len1 = R - L + 1;
    while (len <= len1) len <<= 1;

    // x = a[0..len1), y = dp[L..mid], both zero-padded to exactly len entries.
    for (int i = 0; i < len1; ++i) x[i] = Complex(a[i], 0);
    for (int i = len1; i < len; ++i) x[i] = Complex(0, 0);
    for (int i = L; i <= mid; ++i) y[i - L] = Complex(dp[i], 0);
    for (int i = mid - L + 1; i < len; ++i) y[i] = Complex(0, 0);

    FFT(x, len, 1);
    FFT(y, len, 1);
    for (int i = 0; i < len; ++i) x[i] = x[i] * y[i];
    FFT(x, len, -1);

    // Add the left half's contribution to every i in (mid, R].
    for (int i = mid + 1; i <= R; ++i) {
        // x[i-L].r ~= sum_{j=L}^{mid} dp[j] * a[i-j]. With up to (R-L)/2 + 1
        // terms, each below 313*313, it can exceed INT_MAX, and converting an
        // out-of-range double to int is undefined behavior. Round into 64 bits
        // and reduce before combining with the int dp[i].
        long long contrib = static_cast<long long>(x[i - L].r + 0.5);
        dp[i] = static_cast<int>((dp[i] + contrib % 313) % 313);
    }
    solve(mid + 1, R);               // right half now has all j <= mid terms
}
// Processes test cases until EOF or a terminating n == 0. For each case:
// reads a[1..n] (reduced mod 313), sets dp[0] = 1 and clears dp[1..n],
// runs solve(0, n), and prints dp[n].
int main(){
    int n;
    // Require a successfully parsed integer: on malformed input scanf returns
    // 0 (not EOF), and n would otherwise be used uninitialized.
    while (scanf("%d", &n) == 1 && n) {
        for (int i = 1; i <= n; ++i)
            scanf("%d", &a[i]), a[i] %= 313;
        for (int i = 1; i <= n; ++i) dp[i] = 0;
        dp[0] = 1;
        solve(0, n);
        printf("%d\n", dp[n]);
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: