您的位置:首页 > 其它

HDU 3518 && HDU 4416【后缀自动机len的使用】

2015-08-28 10:43 295 查看
max:即代码中 len 变量,它表示该状态能够接受的最长的字符串长度。

min:表示该状态能够接受的最短的字符串长度。实际上等于该状态的 fail 指针指向的结点的 len + 1。

max-min+1:表示该状态能够接受的不同的字符串数。

那么在HDU 3518 中。

求的是,不相交且出现次数大于2的子串个数。

我们这么想,如果你把一个串直接塞进后缀自动机里面,比较难以处理(不是没有办法,只是我看不太懂那是怎么回事)不相交的问题。那么我们就想到将字符串拆成两部分,然后两部分分别hash,在一个串里找另一个串中是否存在一个hash值,这样就必然是不相交的,如果找到了相同的hash值,则是在两个字符串中各出现了一次,在原串中出现两次。但是这么做会TLE。

于是,就通过后缀自动机中的len来计数。

最上面讲过了
len[u]-len[fa[u]]+1
表示能够接受的不同串的个数,那么就是规定了一个上届和下届,如果我们把上届或者下届的意义改了,就可以计算我们需要的串的个数了。

在这题中,我们将串拆成两部分,第一部分塞进后缀自动机,用后一部分来匹配,匹配每一个节点能够匹配的最远距离mi[u]mi[u],那么讲状态u接受的字符串区间分成了两部分(len[fa[u]],mi[u]]和(mi[u],len[u]]( len[fa[u]],mi[u] ]和(mi[u],len[u]],很明显,前面是后半部分能够匹配的那部分串,后面是不能匹配的。那么我们需要的就是将每个节点的(len[fa[u]],mi[u]]( len[fa[u]],mi[u] ]部分加起来就行了。

同时还要注意的是,从u→fa[u]u\rightarrow fa[u]累加的情况。fail指针,指向了一个能够表示当前状态表示的所有字符串的最长公共后缀的结点。简单说,就是fail指针指向了一个v,从v到root表示的串是u到root表示的串中的一部分,然后要求的是最长。

那么也就是说,如果str在u这个状态匹配了mi[u]mi[u],那么str很可能在fa[u]也匹配了mi[u]mi[u](当然有些情况下,mi[u]mi[u]会超过len[fa[u]],可以直接取min),所以必须要累加过去。

[code]//      whn6325689
//      Mr.Phoebe
//      http://blog.csdn.net/u013007900 #include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>
#include <functional>
#include <numeric>
#pragma comment(linker, "/STACK:1024000000,1024000000")

using namespace std;

#define eps 1e-9
#define PI acos(-1.0)
#define INF 0x3f3f3f3f
#define LLINF 1LL<<62
#define speed std::ios::sync_with_stdio(false);

typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef complex<ld> point;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef vector<int> vi;

#define CLR(x,y) memset(x,y,sizeof(x))
#define CPY(x,y) memcpy(x,y,sizeof(x))
#define clr(a,x,size) memset(a,x,sizeof(a[0])*(size))
#define cpy(a,x,size) memcpy(a,x,sizeof(a[0])*(size))

#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define lowbit(x) (x&(-x))

#define MID(x,y) (x+((y-x)>>1))
#define ls (idx<<1)
#define rs (idx<<1|1)
#define lson ls,l,mid
#define rson rs,mid+1,r
#define root 1,1,n

template<class T>
inline bool read(T &n)
{
    T x = 0, tmp = 1;
    char c = getchar();
    while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
    if(c == EOF) return false;
    if(c == '-') c = getchar(), tmp = -1;
    while(c >= '0' && c <= '9') x *= 10, x += (c - '0'),c = getchar();
    n = x*tmp;
    return true;
}
template <class T>
inline void write(T n)
{
    if(n < 0)
    {
        putchar('-');
        n = -n;
    }
    int len = 0,data[20];
    while(n)
    {
        data[len++] = n%10;
        n /= 10;
    }
    if(!len) data[len++] = 0;
    while(len--) putchar(data[len]+48);
}
//-----------------------------------

const int MAXN=2010;
const int MAC=26;

struct SAM
{
    int len[MAXN],next[MAXN][MAC],fa[MAXN],L,last;
    int sum[MAXN],topo[MAXN];
    int mi[MAXN];
    SAM(){init();}
    void init()
    {
        L=0;last=newnode(0,-1);
    }
    int newnode(int l,int pre)
    {
        fa[L]=pre;
        mi[L]=0;
        for(int i=0;i<MAC;i++)  next[L][i]=-1;
        len[L]=l;
        return L++;
    }
    void add(int x)
    {
        int pre=last,now=newnode(len[pre]+1,-1);
        last=now;
        while(~pre && next[pre][x]==-1)
        {
            next[pre][x]=now;
            pre=fa[pre];
        }
        if(pre==-1) fa[now]=0;
        else
        {
            int bro=next[pre][x];
            if(len[bro]==len[pre]+1)    fa[now]=bro;
            else
            {
                int fail=newnode(len[pre]+1,fa[bro]);
                memcpy(next[fail],next[bro],sizeof next[bro]);
                fa[bro]=fa[now]=fail;
                while(~pre && next[pre][x]==bro)
                {
                    next[pre][x]=fail;pre=fa[pre];
                }
            }
        }
    }

    void toposort()
    {
        CLR(sum,0);
        for(int i=0;i<L;i++)    sum[len[i]]++;
        for(int i=1;i<L;i++)    sum[i]+=sum[i-1];
        for(int i=0;i<L;i++)    topo[--sum[len[i]]]=i;
    }

    void query(char *S)
    {
        int u=0,x,cnt=0;
        for(char *sp=S;*sp;sp++)
        {
            x=*sp-'a';
            if(~next[u][x])
            {
                u=next[u][x];
                mi[u]=max(mi[u],++cnt);
            }
            else
            {
                while(~u && next[u][x]==-1)
                    u=fa[u];
                if(u==-1)
                    cnt=0,u=0;
                else
                {
                    cnt=len[u]+1;
                    u=next[u][x];
                    mi[u]=max(mi[u],cnt);
                }
            }
        }
    }

    int build()
    {
        int ans=0;
        toposort();
        for(int i=L-1;i>=0;i--)
        {
            int u=topo[i];
            if(~fa[u])
                mi[fa[u]]=max(mi[fa[u]],mi[u]);
        }
        for(int i=0;i<L;i++)
        {
            mi[i]=min(mi[i],len[i]);
            if(~fa[i] && mi[i]>len[fa[i]])
                ans+=(mi[i]-len[fa[i]]);
        }
        return ans;
    }
}T;

char str[MAXN];

int main()
{
    while(scanf("%s",str)!=EOF && str[0]!='#')
    {
        int le=strlen(str);
        T.init();
        for(int i=1;i<le;i++)
        {
            T.add(str[i-1]-'a');
            T.query(str+i);
        }
        printf("%d\n",T.build());
    }
    return 0;
}


在HDU 4416中

求的是A这个字符串,有多少各子串在BiB_i中没有出现过。

上一题求的是出现过的,这一题求的是没有出现过的。

同样的分析,最终将区间分成两部分(len[fa[u]],mi[u]]和(mi[u],len[u]]( len[fa[u]],mi[u] ]和(mi[u],len[u]],很明显,前面是出现过的,后面是没出现过的,于是我们需要将后面那部分加起来就行了。

[code]//      whn6325689
//      Mr.Phoebe
//      http://blog.csdn.net/u013007900 #include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>
#include <functional>
#include <numeric>
#pragma comment(linker, "/STACK:1024000000,1024000000")

using namespace std;

#define eps 1e-9
#define PI acos(-1.0)
#define INF 0x3f3f3f3f
#define LLINF 1LL<<62
#define speed std::ios::sync_with_stdio(false);

typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef complex<ld> point;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef vector<int> vi;

#define CLR(x,y) memset(x,y,sizeof(x))
#define CPY(x,y) memcpy(x,y,sizeof(x))
#define clr(a,x,size) memset(a,x,sizeof(a[0])*(size))
#define cpy(a,x,size) memcpy(a,x,sizeof(a[0])*(size))

#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define lowbit(x) (x&(-x))

#define MID(x,y) (x+((y-x)>>1))
#define ls (idx<<1)
#define rs (idx<<1|1)
#define lson ls,l,mid
#define rson rs,mid+1,r
#define root 1,1,n

template<class T>
inline bool read(T &n)
{
    T x = 0, tmp = 1;
    char c = getchar();
    while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
    if(c == EOF) return false;
    if(c == '-') c = getchar(), tmp = -1;
    while(c >= '0' && c <= '9') x *= 10, x += (c - '0'),c = getchar();
    n = x*tmp;
    return true;
}
template <class T>
inline void write(T n)
{
    if(n < 0)
    {
        putchar('-');
        n = -n;
    }
    int len = 0,data[20];
    while(n)
    {
        data[len++] = n%10;
        n /= 10;
    }
    if(!len) data[len++] = 0;
    while(len--) putchar(data[len]+48);
}
//-----------------------------------

const int MAXN=(100010<<1);
const int MAC=26;

struct SAM
{
    int len[MAXN],next[MAXN][MAC],fa[MAXN],L,last;
    int mi[MAXN];
    SAM()
    {
        init();
    }
    void init()
    {
        L=last=0;
        newnode(0,-1);
    }
    int newnode(int l,int pre)
    {
        fa[L]=pre;
        for(int i=0; i<MAC; i++)    next[L][i]=-1;
        len[L]=l;
        mi[L]=0;
        return L++;
    }
    void build(const char* p)
    {
        int le=strlen(p);
        for(int i=0; i<le; i++)
            add(p[i]-'a');
        toposort();
    }
    void add(int x)
    {
        int pre=last,now=newnode(len[pre]+1,-1);
        last=now;
        while(~pre && next[pre][x]==-1)
        {
            next[pre][x]=now;
            pre=fa[pre];
        }
        if(pre==-1) fa[now]=0;
        else
        {
            int bro=next[pre][x];
            if(len[bro]==len[pre]+1)    fa[now]=bro;
            else
            {
                int fail=newnode(len[pre]+1,fa[bro]);
                for(int i=0; i<MAC; i++)    next[fail][i]=next[bro][i];
                fa[bro]=fa[now]=fail;
                while(~pre && next[pre][x]==bro)
                {
                    next[pre][x]=fail;
                    pre=fa[pre];
                }
            }
        }
    }
    int sum[MAXN],topo[MAXN];
    void toposort()
    {
        CLR(sum,0);
        for(int i=0; i<L; i++)  sum[len[i]]++;
        for(int i=1; i<L; i++)  sum[i]+=sum[i-1];
        for(int i=0; i<L; i++)  topo[--sum[len[i]]]=i;
    }

    void query(const char* S)
    {
        int u=0,x,cnt=0;
        for(const char* sp=S; *sp; sp++)
        {
            x=*sp-'a';
            if(~next[u][x])
            {
                u=next[u][x];
                mi[u]=max(mi[u],++cnt);
            }
            else
            {
                while(~u && next[u][x]==-1)
                    u=fa[u];
                if(u==-1)
                    cnt=0,u=0;
                else
                {
                    cnt=len[u]+1;
                    u=next[u][x];
                    mi[u]=max(mi[u],cnt);
                }
            }
        }
    }
} T;

char str[MAXN];
int q,k;

int main()
{
    int t,cas=1;
    scanf("%d",&t);
    while(t--)
    {
        scanf("%d",&q);
        T.init();
        scanf("%s",str);
        T.build(str);
        while(q--)
        {
            scanf("%s",str);
            T.query(str);
        }
        for(int i=T.L-1; i>=1; i--)
        {
            int u=T.topo[i];
            if(T.mi[u]>0)
                T.mi[T.fa[u]]=max(T.mi[T.fa[u]],T.mi[u]);
            else
                T.mi[u]=T.len[T.fa[u]];
        }
        ll ans=0;
        for(int i=1; i<T.L; i++)
            if(T.len[i]>T.mi[i])
                ans+=T.len[i]-T.mi[i];
        printf("Case %d: %lld\n",cas++,ans);
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: