您的位置:首页 > 其它

【bzoj4566】[Haoi2016]找相同字符

2017-12-03 23:40 302 查看

4566: [Haoi2016]找相同字符

Time Limit: 20 Sec  Memory Limit: 256 MB
Submit: 640  Solved: 350

[Submit][Status][Discuss]

Description

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。

Input

两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母

Output

输出一个整数表示答案

Sample Input

aabb

bbaa

Sample Output

10

HINT

Source



[Submit][Status][Discuss]

后缀自动机模板系列。。

建两个后缀自动机,然后同时在两个自动机上dfs,答案即为,对应节点的endpos的大小的乘积的和

说起来非常简单,但是实现起来本垃圾就有问题了

endpos的大小要怎么统计。。。。。?????

然后被学长的模板误导了一波,后来才知道,只要先建好自动机,然后在parent树上拓扑拓扑就完了

代码:
#include<cstdio>
#include<cmath>
#include<queue>
#include<stack>
#include<vector>
#include<algorithm>
#include<cstring>
using namespace std;

typedef long long LL;

const int maxn = 1200010;
const int maxs = 200100;

char s[maxn];
int n;
int fa[maxn],ch[maxn][30],Max[maxn],du[maxn],Q[maxn],rt[3],tot,last;
bool vis[maxn];
LL cnt[maxn],ans;

inline LL getint()
{
LL ret = 0,f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
ret = ret * 10 + c - '0',c = getchar();
return ret * f;
}

inline void insert(int x,int d)
{
int v = last;
Max[++tot] = Max[v] + 1; last = tot; fa[last] = rt[d]; cnt[last] = 1;
while (v && !ch[v][x]) ch[v][x] = last , v = fa[v];
if (!v) {du[rt[d]]++; return;}
int p = ch[v][x];
if (Max[p] != Max[v] + 1)
{
int np = ++tot;
Max[np] = Max[v] + 1; fa[np] = fa[p]; fa[p] = np; fa[last] = np; du[np] = 2;
while (v && ch[v][x] == p) ch[v][x] = np , v = fa[v];
for (int i = 1; i <= 26; i++) ch[np][i] = ch[p][i];
}
else fa[last] = p , du[p]++;
}

inline void top()
{
int head = 0,tail = 0;
for (int i = 1; i <= tot; i++)
if (!du[i]) Q[++tail] = i;
while (head < tail)
{
int u = Q[++head];
cnt[fa[u]] += cnt[u];
--du[fa[u]];
if (!du[fa[u]])
Q[++tail] = fa[u];
}
int test;
test = 1;
}

inline void dfs(int u,int v)
{
ans += cnt[u] * cnt[v];
for (int i = 1; i <= 26; i++)
{
if (!ch[u][i] || !ch[v][i]) continue;
dfs(ch[u][i],ch[v][i]);
}
}

int main()
{
rt[1] = last = ++tot;
scanf("%s",s + 1); n = strlen(s + 1);
for (int i = 1; i <= n; i++)
insert(s[i] - 'a' + 1,1);

rt[2] = last = ++tot;
scanf("%s",s + 1); n = strlen(s + 1);
for (int i = 1; i <= n; i++) insert(s[i] - 'a' + 1,2);

top();
cnt[0] = cnt[rt[1]] = cnt[rt[2]] = 0;
dfs(rt[1],rt[2]);
printf("%lld",ans);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: