您的位置:首页 > 其它

【bzoj3992】 SDOI2015—序列统计

2017-02-13 23:06 357 查看
http://www.lydsy.com/JudgeOnline/problem.php?id=3992 (题目链接)

题意

  集合${S}$中有若干个不超过${m}$的非负整数,问由这些数组成一个长度${n}$的序列,使序列中的数的乘积对${m}$取模正好等于${x}$,问存在多少方案。

Solution

  好神的题。算法还是要多复习,我连${NTT}$都忘记怎么写了T_T

  这还是我的第一发原根→_→。

  一个数如果有原根,那么它会有很多原根,所以如果对时间没有特殊限制,我们枚举${rt=2~~to~~inf}$,然后判断是否存在${t<m-1}$使${rt^t=1}$。虽然我并不知道为什么可以那样check。。

  我们可以很简单的列出dp方程${f_{i,j}}$表示,已经放到了第${i}$个数,它们的乘积是${j}$的方案数。转移也就很显然了:$${f[i][j]=\sum_{k=1}^{m-1}f_{i-1,j*inv[k]}}$$

  复杂度${O(nm^2)}$,于是我们就可以获得10分的高分,是不是很良心啊。

  考虑这个东西怎么优化,我们把每一个${j}$都写成${m}$的原根的几次方,然后乘就变成加辣,然后我们就可以卷积辣。

  然后你发现${n}$有${10^9}$,我们快速幂一波,然后就AC辣。

细节

  一开始没想清没注意到还是循环卷积卧槽T_T

代码

// bzoj3992
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<ctime>
#define LL long long
#define inf (1ll<<30)
#define MOD 1004535809
#define Pi acos(-1.0)
#define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout);
using namespace std;

const int maxn=20010;
int f[maxn],g[maxn],rev[maxn],vis[maxn];
int n,m,rt,S,X,N,L;

int power(int a,int b,int c) {
int res=1;
while (b) {
if (b&1) res=(LL)res*a%c;
b>>=1;a=(LL)a*a%c;
}
return res;
}
void root(int p) {
if (p==2) {rt=1;return;}
for (rt=2;;rt++) {
int flag=1;
for (int i=2;i*i<p;i++)
if (power(rt,(p-1)/i,p)==1) {flag=0;break;}
if (flag) break;
}
}
namespace NTT {
LL A[maxn],B[maxn];
void NTT(LL *a,int f) {
for (int i=0;i<N;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
for (int i=1;i<N;i<<=1) {
LL gn=power(3,(MOD-1)/(i<<1),MOD);
for (int p=i<<1,j=0;j<N;j+=p) {
LL g=1;
for (int k=0;k<i;k++,(g*=gn)%=MOD) {
LL x=a[k+j],y=g*a[k+j+i]%MOD;
a[k+j]=(x+y)%MOD,a[k+j+i]=(x-y+MOD)%MOD;
}
}
}
if (f==-1) reverse(a+1,a+N);
}
void Init(int *a,int *b) {
for (int i=0;i<N;i++) A[i]=a[i],B[i]=b[i];
NTT(A,1);NTT(B,1);
for (int i=0;i<N;i++) (A[i]*=B[i])%=MOD;
NTT(A,-1);
LL ev=power(N,MOD-2,MOD);
for (int i=0;i<N;i++) (A[i]*=ev)%=MOD;
for (int i=0;i<m-1;i++) a[i]=(A[i]+A[i+m-1])%MOD;
}
}
using namespace NTT;

int main() {
scanf("%d%d%d%d",&n,&m,&X,&S);
root(m);
for (int x,i=1;i<=S;i++) scanf("%d",&x),vis[x]=1;
for (int p=1,i=0;i<m-1;i++,(p*=rt)%=m) if (vis[p]) f[i]=1;
for (N=1,L=-1;N<(m-1)*2;N<<=1) L++;
for (int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1) | ((i&1)<<L);
g[0]=1;
while (n) {
if (n&1) Init(g,f);
n>>=1;Init(f,f);
}
for (int i=0,p=1;i<m-1;i++,(p*=rt)%=m)
if (p==X) {printf("%d",g[i]);break;}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: