您的位置:首页 > 运维架构

【XSY2166】Hope 分治 FFT

2017-11-02 20:44 246 查看

题目描述

  对于一个1到n的排列a1,a2,a3,…,an,我们定义这个排列的P值和Q值:

  对于每个ai,如果存在一个最小的j使得i<j且ai<aj,那么将ai和aj连一条无向边。于是就得到一幅图。计算这幅图每个联通块的大小,将它们相乘,得到P。记Q=Pk。

  对于1到n的所有排列,我们想知道它们的Q值之和。由于答案可能很大,请将答案对998244353取模。

  n,k≤100000

题解

  考虑从小到大插入这n个数。

  设fi为所有1~i的排列的Q值之和。

  考虑i的位置,当i在第j个位置的时候,前面j个点是联通的,后面i−j个点与前面j个点不连通。

fififi=∑j=1i(i−1j−1)(j−1)!jkfi−j=∑j=1i(i−1)!jkfi−j(i−j)!=(i−1)!∑j=1ijkfi−j(i−j)!

  用分治FFT加速。

  时间复杂度:O(nlogk+nlog2n)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
if(a>b)
swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
const ll p=998244353;
ll fp(ll a,ll b)
{
ll s=1;
while(b)
{
if(b&1)
s=s*a%p;
a=a*a%p;
b>>=1;
}
return s;
}
namespace ntt
{
const ll g=3;
ll w1[270010];
ll w2[270010];
int rev[270010];
int n;
void init(int m)
{
n=1;
while(n<m)
n<<=1;
int i;
for(i=2;i<=n;i<<=1)
{
w1[i]=fp(g,(p-1)/i);
w2[i]=fp(w1[i],p-2);
}
rev[0]=0;
for(i=1;i<n;i++)
rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
}
void ntt(ll *a,int t)
{
ll u,v,w,wn;
int i,j,k;
for(i=0;i<n;i++)
if(rev[i]<i)
swap(a[i],a[rev[i]]);
for(i=2;i<=n;i<<=1)
{
wn=(t==1?w1[i]:w2[i]);
for(j=0;j<n;j+=i)
{
w=1;
for(k=j;k<j+i/2;k++)
{
u=a[k];
v=a[k+i/2]*w%p;
a[k]=(u+v)%p;
a[k+i/2]=(u-v)%p;
w=w*wn%p;
}
}
}
if(t==-1)
{
ll inv=fp(n,p-2);
for(i=0;i<n;i++)
a[i]=a[i]*inv%p;
}
}
ll x[270010];
ll y[270010];
void copy_clear(ll *a,ll *b,int m)
{
int i;
for(i=0;i<m;i++)
a[i]=b[i];
for(i=m;i<n;i++)
a[i]=0;
}
void copy(ll *a,ll *b,int m)
{
int i;
for(i=0;i<m;i++)
a[i]=b[i];
}
void inverse(ll *a,ll *b,int m)
{
if(m==1)
{
b[0]=fp(a[0],p-2);
return;
}
inverse(a,b,m>>1);
init(2*m);
copy_clear(x,a,m);
copy_clear(y,b,m>>1);
ntt(x,1);
ntt(y,1);
int i;
for(i=0;i<n;i++)
x[i]=(2*y[i]%p-x[i]*y[i]%p*y[i]%p+p)%p;
ntt(x,-1);
copy(b,x,m);
}
};
ll fac[300010];
ll ifac[300010];
ll inv[300010];
ll f[300010];
ll a[300010];
ll b[300010];
ll ex[300010];
int n,k;
void solve(int l,int r)
{
if(l==r)
{
f[l]=f[l]*fac[l-1]%p;
return;
}
int mid=(l+r)>>1;
solve(l,mid);
ntt::init(r-l+1);
int i;
for(i=l;i<=mid;i++)
a[i-l]=f[i]*ifac[i];
for(i=l;i<=r;i++)
b[i-l]=ex[i-l];
for(i=mid-l+1;i<ntt::n;i++)
a[i]=0;
for(i=r-l+1;i<ntt::n;i++)
b[i]=0;
ntt::ntt(a,1);
ntt::ntt(b,1);
for(i=0;i<ntt::n;i++)
a[i]=a[i]*b[i]%p;
ntt::ntt(a,-1);
for(i=mid+1;i<=r;i++)
f[i]+=a[i-l];
solve(mid+1,r);
}
int main()
{
open("xsy2166");
scanf("%d%d",&n,&k);
int i;
fac[0]=fac[1]=ifac[0]=ifac[1]=inv[0]=inv[1]=1;
for(i=2;i<=n;i++)
{
inv[i]=-(p/i)*inv[p%i]%p;
fac[i]=fac[i-1]*i%p;
ifac[i]=ifac[i-1]*inv[i]%p;
}
for(i=1;i<=n;i++)
{
ex[i]=fp(i,k);
f[i]=ex[i];
}
solve(1,n);
ll ans=(f
+p)%p;
printf("%lld\n",ans);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: