您的位置:首页 > 其它

CodeChef - COUNTARI Arithmetic Progressions FFT 分块

2017-09-30 18:10 465 查看
这题因为ijk大小关系的限制,所以不能像三个傻瓜那题一样直接FFT,排序后排出情况。

所以一开始想到的是对每个位置都做一次FFT,即枚举Aj,用Ai和Ak做FFT,但这复杂度明显是不行的O(N∗30000∗log30000)

然后看了题解才知道还有分块这种方法。。

具体的分法就不说了,网上有一大堆,最后的复杂度就是O(N2K+k∗30000∗16)

k取到30就能过了,不过为什么看着感觉k取到1000才能过呢。。

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <cmath>

using namespace std;
const int MA
c3ed
XN = 262144;
const int LIM = 61000;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const double pi=acos(-1.0);
int num[MAXN],n,block,size;
LL pre[LIM],in[LIM],nex[LIM];
LL ans = 0;

struct cp
{
double x,y;
cp() {}
cp(double x,double y):x(x),y(y) {}
inline double real() { return x; }
inline cp operator * (const cp& r) const { return cp(x*r.x-y*r.y,x*r.y+y*r.x); }
inline cp operator - (const cp& r) const { return cp(x-r.x,y-r.y); }
inline cp operator + (const cp& r) const { return cp(x+r.x,y+r.y); }
};

cp a[MAXN],b[MAXN];
LL r[MAXN],res[MAXN];
LL ax[MAXN],bx[MAXN];

void fft_init(int nm,int k)
{
for (int i=0;i<nm;i++) r[i] = (r[i>>1]>>1) | ((i&1) << (k-1));
}

void fft(cp ax[],int nm,int op)
{
for (int i=0;i<nm;i++) if (i<r[i]) swap(ax[i],ax[r[i]]);
for (int h=2,m=1;h<=nm;h<<=1,m<<=1)
{
cp wn = cp(cos(op*2*pi/h),sin(op*2*pi/h));
for (int i=0;i<nm;i+=h)
{
cp w(1,0);
for (int j=i;j<i+m;++j,w=w*wn)
{
cp t=w*ax[j+m];
ax[j+m] = ax[j]-t;
ax[j] = ax[j]+t;
}
}
}
if (op==-1) for (int i=0;i<nm;i++) ax[i].x /= nm;
}

void trans(LL ax[],LL bx[],int n,int m)
{
int nm=1,k=0;
while (nm < 2*n || nm<2*m) nm<<=1,k++;

for (int i=0;i<n;i++) a[i] = cp(ax[i],0);
for (int i=0;i<m;i++) b[i] = cp(bx[i],0);
for (int i=n;i<nm;i++) a[i] = cp(0,0);
for (int i=m;i<nm;i++) b[i] = cp(0,0);

fft_init(nm,k);
fft(a,nm,1);fft(b,nm,1);
for (int i=0;i<nm;i++) a[i] = a[i]*b[i];
fft(a,nm,-1);
nm = n+m-1;
for (int i=0;i<nm;i++) res[i] = (LL)(a[i].real()+0.5);
}

int main()
{
while (scanf("%d",&n)!=EOF)
{
memset(in,0,sizeof in);
memset(pre,0,sizeof pre);
memset(nex,0,sizeof nex);
for (int i=1;i<=n;i++)
{
scanf("%d",&num[i]);
nex[num[i]] ++;
}
block = 30;
size = (n+block-1)/block;
for (int b=1;b<=block;b++)
{
int s=(b-1)*size +1,e=min(b*size,n);
for (int i=s;i<=e;i++) nex[num[i]]--;
trans(pre,nex,30001,30001);
for (int i=s;i<=e;i++)
{
for (int j=i+1;j<=e;j++)
{
if (2*num[i] - num[j]>=1 )
{
ans += in[ 2*num[i] - num[j] ];//3 in
ans += pre[ 2*num[i] - num[j] ];//2 in  1 prev
}
if ( 2*num[j]-num[i]>=1 )
ans += nex[ 2*num[j]-num[i] ];
}
ans += res[ 2*num[i] ];
in [ num[i] ] ++;
}
for (int i=s;i<=e;i++) pre[num[i]]++,in[num[i]]--;
}
printf("%lld\n",ans);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: