您的位置:首页 > 其它

BZOJ2553: [BeiJing2011]禁忌 AC自动机 期望DP 矩阵

2016-12-27 14:24 344 查看
题目大意:给出n个禁忌串,定义任意一个字符串的禁忌伤害是不相交的禁忌子串的最大个数,求长度为len的随机串禁忌伤害期望。

所有的串只包含前alphabet个拉丁字母。

N ≤ 5,len ≤109,1 ≤ alphabet ≤ 26。

对所有禁忌串建AC自动机,问题转化成:每次在AC自动机上走一步,走到危险节点(即有危险标记的节点或者能沿fail走到危险标记的节点)就自动回到根,求危险节点的期望经过次数。

考虑期望的定义:sigma(权值 * 概率)。权值就是1,因此将概率累加即可。可以先不考虑危险节点,则其为一个普通的图,用矩阵乘法即可求出走n步后从i点走到j点的概率。考虑危险节点,每一次转移,都会有一些节点以一定概率转移到危险节点,走到这些点的概率*转移概率就是这次转移对答案的贡献,可以随着矩阵的递推而计算每一次的贡献,因此新建一个代表所有危险点的节点,能走到危险点的点同时向根和这个点连边,同时这个点自身连一条1边,代表着累加上一次的结果。自乘len次后f[0][danger]就表示了从0号节点走len步,danger节点的期望经过次数。

据说有精度问题,要long double才能过。

#include<cstdio>
#include<queue>
using std::queue;
typedef long double cnt_t;
int n,len,alpha;
struct node
{
node *s[26],*fail;
bool da;
node():s(),fail(),da(){}
inline void* operator new(size_t);
}a[100];
inline void* node::operator new(size_t)
{
static size_t t=0xffffffffu;
return a+ ++t;
}
struct trie
{
node *rt;
trie():rt(new node){}
inline void insert(const char *s)
{
node *now=rt;
while(*s)
{
node*& x=now->s[*s-'a'];
if(!x) x=new node;
now=x;
++s;
}
now->da=1;
}
};
size_t maxn;
struct martix
{
cnt_t p[80][80];
cnt_t* operator [](size_t x){return p[x];}
martix():p(){}
martix(size_t):p()
{
for(size_t i=0;i<=maxn;++i)
p[i][i]=1.0;
}
martix operator *(martix &b) const
{
martix res;
for(size_t i=0;i<=maxn;++i)
for(size_t j=0;j<=maxn;++j)
for(size_t k=0;k<=maxn;++k)
res[i][j]+=p[i][k]*b[k][j];
return res;
}
}ans;
struct ac_auto:public trie
{
ac_auto():trie(){}
inline size_t operator() ()
{
static queue<node*> q;
size_t res=1;
rt->fail=rt;
for(int i=0;i<alpha;++i)
{
node*& x=rt->s[i];
if(!x) x=rt;
else x->fail=rt,q.push(x);
}
while(!q.empty())
{
node *now=q.front();q.pop();
now->da|=now->fail->da;
for(int i=0;i<alpha;++i)
{
node*& x=now->s[i];
if(!x) {x=now->fail->s[i];continue;}
x->fail=now->fail->s[i];
q.push(x);
}
++res;
}
return res;
}
}ac;
template<typename T,typename Int>
T power(T x,Int a)
{
T ans(1);
while(a)
{
if(a&1) ans=ans*x;
x=x*x;
a>>=1;
}
return ans;
}
char s[20];
int main()
{
scanf("%d%d%d",&n,&len,&alpha);
while(n--) scanf("%s",s),ac.insert(s);
maxn=ac();
ans[maxn][maxn]=1.0;
cnt_t base=1.0/alpha;
for(size_t i=0;i<maxn;++i)
for(int k=0;k<alpha;++k)
{
size_t j=a[i].s[k]-a;
if(a[j].da) ans[i][0]+=base,ans[i][maxn]+=base;
else ans[i][j]+=base;
}
ans=power(ans,len);
double res=ans[0][maxn];
printf("%.10lf\n",res);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: