您的位置:首页 > 其它

HDU 4029 Distinct Sub-matrix [后缀数组]

2012-08-31 00:14 316 查看
  看了大牛的代码,在斌牛的指导下,终于AC了这题。

  题目很短,就是问一个由大写字母组成的矩阵中有多少个不同的子矩阵。

  从1到m枚举宽度,对每个宽度进行HASH,hash[i][j]表示s[i][j...j+width-1]这个串的hash值,然后再将hash值按照hash[i][0],hash[i+1][0]..hash[n-1][0],#,hash[i][1]...hash[n-1][1],这样竖着的顺序连接起来。并在每一列的串之间用一个符号隔开,这样形成了一个串,再求这个串的不重复子串有多少个,最后将所有宽度的不重复子串和加起来就可以了。这个应该比较容易理解,当枚举宽度为width时,h[i1][j]..h[i2][j]构成的串实际上就是高从i1~i2,宽从j~j+width-1这样一个矩阵。

  至于怎样求不重复子串有多少个,显然有后缀数组可以解决,子串的个数减去height[i]的和就可以了。几乎没怎么写过后缀数组,用的是罗赛骞的代码,对它的代码不熟悉,一开始总RE,后来看了它的代码才知道原来他da和calheight传的不是一个n,看来还是要写一份自己的模版比较靠谱。。

  编码中还是有几个问题要解决。一个是hash的问题,第一遍枚举width=1的时候,hash[i][j]就是该个字符,第二遍枚举width=2,直接在第一次hash[i][j]的值上进行操作hash[i][j]=hash[i][j]*BASE+mat[i][j+1],之后一直扩展这个hash值就可以了。我这里用了双重hash,用两个unsigned int记录hash值,乘不一样的BASE,最后根据这两个值是否都相等判断两个字符串是否相同。两个不同的字符串这两个hash值都相同的概率是极小的,可以忽略不计。。后来试了用一个unsigned int也可以,看人品了,有的BASE不行,有的BASE会WA,最后测了419这个神奇的数字是可以的。。

  还有就是对加的m-w个相隔符要用0~m-w来标记,这样这些以相隔符开头的串就排在了前面,然后从m-w+1统计到len-1就可以了。

#include <stdio.h>
#include <string.h>
#include <algorithm>
#define MAXN 130
#define MAXL 130*130
typedef unsigned long long ULL;
typedef unsigned int UINT;
const ULL BASE1=419,BASE2=131;
struct HASH{
UINT h1,h2;
HASH(){}
HASH(UINT _h1,UINT _h2):h1(_h1),h2(_h2){}
bool operator ==(const HASH& hh)const{return h1==hh.h1&&h2==hh.h2;}
bool operator <(const HASH& hh)const{return h1<hh.h1||h1==hh.h1&&h2<hh.h2;}
}h[MAXN][MAXN],st[MAXL];
char mz[MAXN][MAXN];
int cas,n,m,len;
int wa[MAXL],wb[MAXL],wv[MAXL],ws[MAXL],r[MAXL],ord[MAXL],sa[MAXL];
int cmp(int *r,int a,int b,int l)
{return r[a]==r[b]&&r[a+l]==r[b+l];}
void da(int *r,int *sa,int n,int m)
{
int i,j,p,*x=wa,*y=wb,*t;
for(i=0;i<m;i++) ws[i]=0;
for(i=0;i<n;i++) ws[x[i]=r[i]]++;
for(i=1;i<m;i++) ws[i]+=ws[i-1];
for(i=n-1;i>=0;i--) sa[--ws[x[i]]]=i;
for(j=1,p=1;p<n;j*=2,m=p)
{
for(p=0,i=n-j;i<n;i++) y[p++]=i;
for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
for(i=0;i<n;i++) wv[i]=x[y[i]];
for(i=0;i<m;i++) ws[i]=0;
for(i=0;i<n;i++) ws[wv[i]]++;
for(i=1;i<m;i++) ws[i]+=ws[i-1];
for(i=n-1;i>=0;i--) sa[--ws[wv[i]]]=y[i];
for(t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
return;
}
int rank[MAXL],height[MAXL];
void calheight(int *r,int *sa,int n)
{
int i,j,k=0;
for(i=0;i<=n;i++) rank[sa[i]]=i;
for(i=0;i<n;height[rank[i++]]=k)
for(k?k--:0,j=sa[rank[i]-1];r[i+k]==r[j+k];k++);
return;
}
bool cmpp(int a,int b){
return st[a]<st[b];
}
int main(){
//freopen("test.in","r",stdin);
scanf("%d",&cas);
for(int ca=1;ca<=cas;ca++){
scanf("%d%d",&n,&m);
for(int i=0;i<n;i++)scanf("%s",mz[i]);
for(int i=0;i<n;i++)for(int j=0;j<m;j++)h[i][j].h1=h[i][j].h2=0ULL;
unsigned long long ans=0;
for(int w=1;w<=m;w++){
//memset(rank,-1,sizeof rank);
for(int i=0;i<n;i++){
for(int j=w-1;j<m;j++){
h[i][j+1-w].h1=h[i][j+1-w].h1*BASE1+mz[i][j];
h[i][j+1-w].h2=h[i][j+1-w].h2*BASE2+mz[i][j];
}
}
len=0;
for(int j=0;j+w<=m;j++){
for(int i=0;i<n;i++)st[len++]=h[i][j];
st[len++]=HASH(0,0);
}
for(int i=0;i<len;i++)ord[i]=i;
std::sort(ord,ord+len,cmpp);
r[ord[0]]=0;
for(int i=1;i<len;i++){
if(st[ord[i]]==st[ord[i-1]]&&st[ord[i]].h1)r[ord[i]]=r[ord[i-1]];
else r[ord[i]]=r[ord[i-1]]+1;
}
da(r,sa,len,r[ord[len-1]]+1);
calheight(r,sa,len-1);
ULL tmp=(1+n)*n/2*(m-w+1);
for(int i=m-w+1;i<len;i++)tmp-=height[i];
ans+=tmp;
}
printf("Case #%d: %I64u\n",ca,ans);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: