您的位置:首页 > 其它

【动态规划20】bzoj4818[sdoi2017]序列计数(dp+矩阵快速幂)

2017-06-26 22:47 525 查看

题目描述

Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望

,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。

输入输出格式

一行三个数,n,m,p。

1<=n<=10^9,1<=m<=2×10^7,1<=p<=100

一行一个数,满足Alice的要求的序列数量,答案对20170408取模。

题目显然容斥原理,用所有方案-不含质数的方案为所求。

首先,有一个非常菜的dp方程。

f[i][j]表示前i个数%p==j的序列个数

f[i][j]+=f[i−1][(j−k)%p](1<=k<=m)

那么就是一个显然的矩阵乘法。

首先f[i][]只与f[i-1][]相关,那么我们可以将原来的i维消掉。

那么之后的转移就是将所有f[(j−k)%p](1<=k<=m)转移到f[j]

那矩阵就很好构建了,从1到m枚举k,矩阵的第j行的(j-k)%p列的值就加一,很好理解。(对于质数在枚举过程中特殊判断即可)

但是若是每一行都从1到m枚举k,时间复杂度上是不允许的。

但是很显然,我们可以意识到矩阵的第j+1行的(j+1-k)%p列,实际上就是j行的前一列(直接写出来有点像废话..),所以直接暴力从矩阵的上一行转移,这样子最后就是一个p*p的矩阵。

直接快速幂搞就完事了。

#include<bits/stdc++.h>
#define fer(i,j,n) for(int i=j;i<=n;i++)
#define far(i,j,n) for(int i=j;i>=n;i--)
#define ll long long
const int maxn=20000010;
const int INF=1e9+7;
const int mod=20170408;
using namespace std;
/*----------------------------------------------------------------------------*/
inline ll read()
{
char ls;ll x=0,sng=1;
for(;ls<'0'||ls>'9';ls=getchar())if(ls=='-')sng=-1;
for(;ls>='0'&&ls<='9';ls=getchar())x=x*10+ls-'0';
return x*sng;
}
/*----------------------------------------------------------------------------*/
int n,m,p,cnt;
int prime[maxn],f[110];
bool flag[maxn];
struct kaga
{
ll v[110][110];
kaga friend operator *(kaga a,kaga b)
{
kaga c;
fer(i,0,p-1)
fer(j,0,p-1)
{
c.v[i][j]=0;
fer(k,0,p-1)
c.v[i][j]+=a.v[i][k]*b.v[k][j]%mod;
}
return c;
}
kaga friend operator ^(kaga a,ll k)
{
kaga c;
fer(i,0,p-1)
fer(j,0,p-1)
if(i==j)c.v[i][j]=1;
else c.v[i][j]=0;
for(;k;k>>=1,a=a*a)
if(k&1)c=c*a;
return c;
}
void friend print(kaga a)
{
fer(i,0,p-1)
{
fer(j,0,p-1)
cout<<a.v[i][j]<<" ";
cout<<endl;
}
}
}a;
void Prime(int n)
{
memset(flag,0,sizeof(flag));
flag[1]=1;
cnt=0;
fer(i,2,n)
{
if(!flag[i])prime[++cnt]=i;
for(int j=1;j<=cnt&&i*prime[j]<=n;j++)
{
flag[i*prime[j]]=1;
if(!(i%prime[j]))break;
}
}
}
int solve1()
{
fer(i,1,m)f[i%p]++;
fer(i,1,m)a.v[(-i%p+p)%p][0]++;
fer(i,1,p-1)
fer(j,0,p-1)
a.v[j][i]=a.v[(j-1+p)%p][i-1];
a=a^(n-1);
int ans=0;
fer(i,0,p-1)ans=(ans+(ll)f[i]*a.v[i][0]%mod)%mod;
return ans;
}
int solve2()
{
memset(f,0,sizeof(f));
fer(i,1,m)if(flag[i])f[i%p]++;
memset(a.v,0,sizeof(a.v));
fer(i,1,m)if(flag[i])a.v[(-i%p+p)%p][0]++;
fer(i,1,p-1)
fer(j,0,p-1)
a.v[j][i]=a.v[(j-1+p)%p][i-1];
a=a^(n-1);
int ans=0;
fer(i,0,p-1)ans=(ans+(ll)f[i]*a.v[i][0]%mod)%mod;
return ans;
}
int main()
{
n=read();m=read();p=read();
Prime(m);
cout<<(solve1()-solve2()+mod)%mod;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: