您的位置:首页 > 产品设计 > UI/UE

[Bluestein's Algorithm][DFT] CodeChef REALSET

2018-03-01 21:42 405 查看

Solution

由 $F(b)\neq 0$ 且 $F(a*b)=F(a)\cdot F(b)=0$(循环卷积的 $\mathrm{DFT}$ 等于两者 $\mathrm{DFT}$ 的逐项乘积),在 $F(b)$ 不为 $0$ 的位置上 $F(a)$ 必为 $0$,因此 $F(a)$ 至少有一项为 $0$。

考虑向量 $a$ 的 $\mathrm{DFT}$。

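Bluestein's Algorithm:利用 $ik=\frac{i^2+k^2-(k-i)^2}{2}$,

$$X_k=\sum_{i=0}^{n-1}x_i\omega^{ik}=\sum_{i=0}^{n-1}x_i\omega^{\frac{-(k-i)^2+i^2+k^2}{2}}=\omega^{\frac{k^2}{2}}\sum_{i=0}^{n-1}\left(x_i\omega^{\frac{i^2}{2}}\right)\omega^{-\frac{(k-i)^2}{2}},$$

右边的和式就是序列 $x_i\omega^{i^2/2}$ 与 $\omega^{-j^2/2}$ 的卷积。这里需要的 $\omega^{1/2}$ 怎么找呢?根据上面的式子,如果使用 $\mathrm{NTT}$,只要找到一个质数 $P$ 满足 $2n\mid(P-1)$,设 $g$ 为模 $P$ 的原根,则 $w=g^{\frac{P-1}{2n}}$ 是 $2n$ 次单位根,满足 $w^2=\omega$,即可作为上式中的 $\omega^{1/2}$。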

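然而这个 $P$ 并不适合直接做 $\mathrm{FNT}$(卷积长度是不小于 $2n$ 的 $2$ 的幂,而 $P-1$ 未必含有这么大的 $2$ 的幂因子),所以我们要找两个 $\mathrm{FNT}$ 模数 $P_0,P_1$,使得 $P_0\times P_1\ge nP^2$(卷积的每一项是至多 $n$ 个小于 $P^2$ 的乘积之和,这样在模 $P_0P_1$ 意义下得到的就是精确值)。最后使用 $\mathrm{CRT}$ 合并到模 $P_0\times P_1$ 意义下,再模 $P$ 就好了。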

Bluestein's Algorithm 实际上就是用卷积来实现 $\mathrm{DFT}$,时间复杂度 $O(n\log n)$。

#include <bits/stdc++.h>
// Debug helper: prints "<expression> = <value>" to stderr.
#define show(x) cerr << #x << " = " << x << endl
using namespace std;
using ll = long long;
using Pairs = pair<int, int>;
using com = complex<double>;  // not referenced elsewhere in this file

constexpr int N = 404040;     // capacity of every fixed-size work array in this file
const double PI = acos(-1);   // not referenced elsewhere in this file

/*
 * Buffered single-character reader over stdin.
 * When the 100000-byte buffer is exhausted it is refilled with one fread call;
 * if that call yields no bytes, EOF (converted to char) is returned.
 */
inline char get(void) {
    static char pool[100000];
    static char *head = pool, *tail = pool;
    if (head != tail) return *head++;
    // Buffer drained: refill from the start.
    head = pool;
    tail = pool + fread(pool, 1, 100000, stdin);
    if (head == tail) return EOF;
    return *head++;
}
/*
 * Reads the next integer from get() into x.
 * Non-digit characters before the number are skipped; if any of them is '-',
 * the value is negated. If input ends before a digit appears, x is left as 0
 * instead of looping forever on EOF.
 */
template<typename T>
inline void read(T &x) {
    // get() returns EOF converted to char; compare in the same type so this also
    // works where plain char is unsigned.
    const char kEof = static_cast<char>(EOF);
    x = 0;
    int sgn = 0;
    char c = get();
    for (; c < '0' || c > '9'; c = get()) {
        if (c == kEof) return;
        if (c == '-') sgn = 1;
    }
    for (; c >= '0' && c <= '9'; c = get()) x = x * 10 + c - '0';
    if (sgn) x = -x;
}

/*
 * Returns a^b mod P by binary exponentiation (square-and-multiply).
 * Products are formed in long long, so any int inputs are safe from overflow.
 * Assumes b >= 0.
 */
inline int pwr(int a, int b, int P) {
    int result = 1;
    for (int base = a; b; b >>= 1) {
        if (b & 1) result = (ll)result * base % P;
        base = (ll)base * base % P;
    }
    return result;
}
/*
 * Returns a * b mod P using binary double-and-add.
 * This avoids the 64-bit overflow of a direct a * b for moduli near 1e18.
 * Assumes b >= 0 (a negative b never reaches zero) and that a + a and P + P
 * fit in a long long.
 */
inline ll mul(ll a, ll b, ll P) {
    ll acc = 0;
    for (ll addend = a; b != 0; b >>= 1) {
        if (b & 1) acc = (acc + addend) % P;
        addend = (addend + addend) % P;
    }
    return acc;
}
/*
 * Modular inverse of x via Fermat's little theorem: x^(MOD-2) mod MOD.
 * Valid only when MOD is prime and x is not a multiple of MOD.
 */
inline int inv(int x, int MOD) {
    const int exponent = MOD - 2;
    return pwr(x, exponent, MOD);
}

int rev
;
int num;

struct FFTtool {
int MOD, G;
int ww
, iw
;

inline void pre(int n) {
num = n;
int g = pwr(G, (MOD - 1) / n, MOD);
ww[0] = iw[0] = 1;
for (int i = 1; i < n; i++)
iw[n - i] = ww[i] = (ll)ww[i - 1] * g % MOD;
}
inline void fft(int *a, int n, int f) {
int x, y;
int *w = (f == 1) ? ww : iw;
for (int i = 0; i < n; i++)
if (rev[i] > i)
swap(a[i], a[rev[i]]);
for (int i = 1; i < n; i <<= 1)
for (int j = 0; j < n; j += (i << 1))
for (int k = 0; k < i; k++) {
x = a[j + k];
y = (ll)a[j + k + i] * w[num / (i << 1) * k] % MOD;
a[j + k] = (x + y) % MOD;
//a[j + k + i] = 992018338;
a[j + k + i] = (x - y + MOD) % MOD;
}
}
inline void conv(int *a, int *b, int *c, int m) {
static int res
;
fft(a, m, 1); fft(b, m, 1);
for (int i = 0; i < m; i++)
res[i] = (ll)a[i] * b[i] % MOD;
fft(res, m, -1);
int invn = inv(m, MOD);
for (int i = 0; i < m; i++)
c[i] = (ll)res[i] * invn % MOD;
}
} S[2];

int A[2]
, B[2]
;
inline void dft(int *a, int n, int P, int g) {
static int A
, B
;
int m, L, w = pwr(g, (P - 1) / n / 2, P);
for (m = 1, L = 0; m < n; m <<= 1) ++L;
m <<= 1;
for (int i = 0; i < m; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
for (int i = 0; i < m; i++)
A[i] = B[i] = 0;
for (int i = 0; i < n; i++) {
A[i] = (ll)(a[i] + P) * pwr(w, (ll)i * i % (2 * n), P) % P;
B[i] = pwr(w, 2 * n - (ll)i * i % (2 * n), P);
B[m - i] = B[i];
}
for (int d = 0; d < 2; d++) {
for (int i = 0; i < m; i++) {
::A[d][i] = A[i];
::B[d][i] = B[i];
}
S[d].conv(::A[d], ::B[d], ::A[d], m);
}
int p0 = S[0].MOD, p1 = S[1].MOD;
int t0 = inv(p1, p0), t1 = inv(p0, p1);
ll M = (ll)p0 * p1;
for (int i = 0; i < m; i++) {
int a0 = ::A[0][i], a1 = ::A[1][i];
a[i] = (mul((ll)a0 * t0, p1, M) + mul((ll)a1 * t1, p0, M)) % M % P;
}
}

/*
 * Returns the smallest primitive root modulo the prime P, or 1 when P == 2.
 * g is a primitive root iff g^((P-1)/q) != 1 (mod P) for every prime factor q
 * of P - 1. Checking only those prime factors is sufficient and much cheaper
 * than checking every divisor.
 * P must be prime; otherwise the result is meaningless and the search may not
 * terminate.
 */
inline int getG(int P) {
    if (P == 2) return 1;  // the unit group {1} is generated by 1
    const int phi = P - 1;
    std::vector<int> primes;  // distinct prime factors of phi
    int rest = phi;
    for (int q = 2; q <= rest / q; q++)
        if (rest % q == 0) {
            primes.push_back(q);
            while (rest % q == 0) rest /= q;
        }
    if (rest > 1) primes.push_back(rest);
    for (int g = 2; ; g++) {
        bool generator = true;
        for (int q : primes)
            if (pwr(g, phi / q, P) == 1) { generator = false; break; }
        if (generator) return g;
    }
}
/*
 * Returns 1 if P is prime and 0 otherwise, by trial division up to sqrt(P).
 * Values below 2 (including negatives) are not prime. The bound d <= P / d is
 * used instead of d * d <= P so that it cannot overflow.
 */
inline int isPrime(int P) {
    if (P < 2) return 0;
    for (int d = 2; d <= P / d; ++d)
        if (P % d == 0) return 0;
    return 1;
}

int test, n, m, L, x;
int a
;

/*
 * For each test case: read n and a[0..n), transform a modulo a prime P with
 * 2n | P - 1, and print YES if some transformed entry is zero, NO otherwise.
 */
int main(void) {
#ifdef LOCAL
    // Local debugging only. Unconditional redirection would make a judge run
    // ignore its stdin/stdout.
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
#endif
    read(test);
    // Two NTT-friendly moduli for the exact convolution inside dft (combined by CRT).
    S[0].MOD = 998244353; S[0].G = 3;   // 119 * 2^23 + 1
    S[1].MOD = 1005060097; S[1].G = 5;  // 1917 * 2^19 + 1
    S[0].pre(1 << 18);
    S[1].pre(1 << 18);
    while (test--) {
        read(n);
        for (int i = 0; i < n; i++) read(a[i]);
        // First prime of the form k * 2n + 1, starting from floor(10^7 / 2n) * 2n + 1,
        // so that 2n divides P - 1.
        int P = 10000000 / (2 * n) * (2 * n) + 1;
        while (!isPrime(P)) P += 2 * n;
        int G = getG(P);
        dft(a, n, P, G);
        bool hasZero = false;
        for (int i = 0; i < n && !hasZero; i++)
            hasZero = (a[i] == 0);
        puts(hasZero ? "YES" : "NO");
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: