您的位置:首页 > 产品设计 > UI/UE

POJ2778 DNA Sequence (AC自动机+矩阵快速幂)

2017-12-17 01:38 399 查看

POJ2778 DNA Sequence

原题地址

http://poj.org/problem?id=2778

题意:

给出有m种有疾病的DNA序列,问有多少种长度为n的DNA序列不包含任何一种有疾病的DNA序列。(仅含A,T,C,G四个字符)

数据范围

0 <= m <= 10,1 <= n <=2000000000,给出的疾病串的长度<=10

题解:

首先要知道的预备知识:

(有向/无向)图中从u点到v点长为n的路径数

=原图的邻接矩阵自乘n次后 mat[u][v]的值

具体证明可以看这里

大概就是一个乘法原理+加法原理。

于是对于这道题,

首先我们要把疾病节点去掉,

这当然要考虑到一个点它的fail是疾病节点,那么疾病的标记是要下传的。

我们的补全AC自动机显然是个DAG,

对于剩下的图,我们搞出它的邻接矩阵,

矩阵快速幂n次,

答案就是从root开始走n步到各个点的方案数之和。

ans=∑mat[0][v]

代码:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
#define LL long long
using namespace std;
const int N=110;
const int mod=100000;
queue<int> Q;
int ch
[4],fail
,isword
,num[130],n,m,tail,root,sz;
char s
;
struct Mat
{
long long a

;
void init(){memset(a,0,sizeof(a));}
}ret,base;
Mat operator*(const Mat &A,const Mat &B)
{
Mat C; C.init();
for(int i=0;i<=sz;i++)
for(int j=0;j<=sz;j++)
for(int k=0;k<=sz;k++)
C.a[i][j]=(C.a[i][j]+1LL*A.a[i][k]*B.a[k][j])%mod;
return C;
}
void init()
{
memset(fail,0,sizeof(fail));
memset(isword,0,sizeof(isword));
memset(ch,0,sizeof(ch));
tail=0; root=0; ret.init(); base.init();
while(!Q.empty()) Q.pop();
}
void insert()
{
int len=strlen(s); int tmp=root;
for(int i=0;i<len;i++)
{
int c=num[s[i]];
if(!ch[tmp][c]){ch[tmp][c]=++tail;}
tmp=ch[tmp][c];
}
isword[tmp]=1;
}
void getfail()
{
for(int i=0;i<4;i++)
if(ch[root][i]) {fail[ch[root][i]]=root; Q.push(ch[root][i]);}
while(!Q.empty())
{
int top=Q.front(); Q.pop();
for(int i=0;i<4;i++)
{
if(!ch[top][i]) {ch[top][i]=ch[fail[top]][i]; continue;}
int u=ch[top][i];
fail[u]=ch[fail[top]][i];
if(isword[fail[u]]) isword[u]=1;
Q.push(u);
}
}
}
int main()
{
num['A']=0; num['G']=1; num['C']=2; num['T']=3;
while(~scanf("%d%d",&m,&n))
{
init();
for(int i=1;i<=m;i++) {scanf("%s",s); insert();}
getfail();
for(int i=0;i<=tail;i++)
{
if(isword[i]) continue;
for(int c=0;c<4;c++)
{
if(isword[ch[i][c]]) continue;
base.a[i][ch[i][c]]++;
}
}
for(int i=0;i<=tail;i++) ret.a[i][i]=1; sz=tail;
for(int j=n;j;j>>=1)
{
if(j&1) ret=ret*base;
base=base*base;
}
int ans=0;
for(int i=0;i<=tail;i++) ans=(ans+ret.a[0][i])%mod;
printf("%d\n",ans);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: