您的位置:首页 > 其它

Gym 100341C AVL TREE(NTT快速数论变换)

2015-09-02 03:18 323 查看
题意:给定n和h,求节点数为n(n<2^18)且高度为h(h<15)的平衡二叉树有多少种(mod 3*2^18+1)。

思路:直接dp的话复杂度为O(n*h*n),其实仔细观察它的状态转移方程,可以发现它其实就是一个卷积,也就是说对于某个高度的多项式f[h]来说,它可以由f[h-1]和f[h-2]的多项式乘法得到,那么我们就可以用NTT来加速计算过程。



自己太弱做了一晚上并没有学透NTT的原理....只学会了套模板.....模板转自http://blog.csdn.net/ACdreamers/article/details/39026505

里面有详细解释,还有另一篇写的不错的博客 http://blog.miskcoo.com/2015/04/polynomial-multiplication-and-fast-fourier-transform#i-15

#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#include<vector>
#include<map>
#include<queue>
#include<stack>
#include<string>
#include<map>
#include<set>
#include<ctime>
#define eps 1e-6
#define LL long long
#define pii (pair<int, int>)
//#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;

const int N = 65536;
const int P = 786433;
const int G = 10;
const int NUM = 18;

LL wn[NUM];

LL quick_mod(LL a, LL b, LL m)
{
LL ans = 1;
a %= m;
while(b)
{
if(b & 1)
{
ans = ans * a % m;
b--;
}
b >>= 1;
a = a * a % m;
}
return ans;
}

void GetWn()
{
for(int i=0; i<NUM; i++)
{
int t = 1 << i;
wn[i] = quick_mod(G, (P - 1) / t, P);
}
}

void Rader(LL a[], int len)
{
int j = len >> 1;
for(int i=1; i<len-1; i++)
{
if(i < j) swap(a[i], a[j]);
int k = len >> 1;
while(j >= k)
{
j -= k;
k >>= 1;
}
if(j < k) j += k;
}
}

void NTT(LL a[], int len, int on)
{
Rader(a, len);
int id = 0;
for(int h = 2; h <= len; h <<= 1)
{
id++;
for(int j = 0; j < len; j += h)
{
LL w = 1;
for(int k = j; k < j + h / 2; k++)
{
LL u = a[k] % P;
LL t = w * a[k + h / 2] % P;
a[k] = (u + t) % P;
a[k + h / 2] = (u - t + P) % P;
w = w * wn[id] % P;
}
}
}
if(on == -1)
{
for(int i = 1; i < len / 2; i++)
swap(a[i], a[len - i]);
LL Inv = quick_mod(len, P - 2, P);
for(int i = 0; i < len; i++)
a[i] = a[i] * Inv % P;
}
}
LL f[20][70000], tmp[70000];
int n, h;
void solve(LL a[], LL b[], LL c[], int len, int hentai) {
for(int i = 0; i < len; i++) tmp[i] = a[i]*b[i];
NTT(tmp, len, -1);
if(hentai==1) for(int i = 1; i < len; i++) c[i] = (c[i]+tmp[i-1]*hentai) % P;
else for(int i = 1; i < len; i++) c[i] = c[i]+tmp[i-1]*hentai;
}

int main() {
freopen("avl.in", "r", stdin);
freopen("avl.out", "w", stdout);
cin >> n >> h;
if(!h) {
if(n==1) puts("1");
else puts("0");
return 0;
}
GetWn();
f[0][1] = 1; f[1][2] = 2; f[1][3] = 1;
for(int i = 2; i <= h; i++) {
int len = 1<<(i+1);
NTT(f[i-2], len, 1);
NTT(f[i-1], len, 1);
solve(f[i-2], f[i-1], f[i], len, 2);
solve(f[i-1], f[i-1], f[i], len, 1);
if(i!=h) NTT(f[i-2], len, -1);
if(i!=h) NTT(f[i-1], len, -1);
}
cout << f[h]
<< endl;
return 0;
}

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  GYM NTT